Skip to content
Snippets Groups Projects
Code owners
Assign users and groups as approvers for specific file changes. Learn more.
openai.ts 14.52 KiB
import { getEnv } from "@llamaindex/env";
import _ from "lodash";
import type OpenAILLM from "openai";
import type {
  ClientOptions,
  ClientOptions as OpenAIClientOptions,
} from "openai";
import { AzureOpenAI, OpenAI as OrigOpenAI } from "openai";
import type { ChatModel } from "openai/resources/chat/chat";

import {
  type BaseTool,
  type ChatMessage,
  type ChatResponse,
  type ChatResponseChunk,
  type LLM,
  type LLMChatParamsNonStreaming,
  type LLMChatParamsStreaming,
  type LLMMetadata,
  type MessageType,
  type PartialToolCall,
  ToolCallLLM,
  type ToolCallLLMMessageOptions,
} from "@llamaindex/core/llms";
import {
  extractText,
  wrapEventCaller,
  wrapLLMEvent,
} from "@llamaindex/core/utils";
import { Tokenizers } from "@llamaindex/env";
import type {
  ChatCompletionAssistantMessageParam,
  ChatCompletionMessageToolCall,
  ChatCompletionRole,
  ChatCompletionSystemMessageParam,
  ChatCompletionTool,
  ChatCompletionToolMessageParam,
  ChatCompletionUserMessageParam,
} from "openai/resources/chat/completions";
import type { ChatCompletionMessageParam } from "openai/resources/index.js";
import type { AzureOpenAIConfig } from "./azure.js";
import {
  getAzureConfigFromEnv,
  getAzureModel,
  shouldUseAzure,
} from "./azure.js";

/**
 * Holds the client used to talk to OpenAI (or Azure OpenAI).
 *
 * When `options.azure` is truthy an `AzureOpenAI` client is built from
 * `options`; otherwise a standard OpenAI client is built, falling back to the
 * `OPENAI_API_KEY` environment variable when no API key is supplied.
 *
 * The caller's `options` object is never modified, so it can safely be used
 * as a cache key (see `getOpenAISession`).
 */
export class OpenAISession {
  openai: Pick<OrigOpenAI, "chat" | "embeddings">;

  /**
   * @param options client options; `azure: true` selects the Azure client.
   * @throws Error when not using Azure and no API key is passed or found in
   *   the `OPENAI_API_KEY` environment variable.
   */
  constructor(options: ClientOptions & { azure?: boolean } = {}) {
    if (options.azure) {
      this.openai = new AzureOpenAI(options as AzureOpenAIConfig);
    } else {
      // `||` rather than `??` so an empty-string key also falls back to the
      // environment, as the previous truthiness check did. The key is kept in
      // a local instead of being written back into `options`.
      const apiKey = options.apiKey || getEnv("OPENAI_API_KEY");

      if (!apiKey) {
        throw new Error("Set OpenAI Key in OPENAI_API_KEY env variable"); // Overriding OpenAI package's error message
      }

      this.openai = new OrigOpenAI({
        ...options,
        apiKey,
      });
    }
  }
}

// Process-wide cache of sessions, keyed by the options each was created with.
// Reusing a session lets the underlying client reuse its connections instead
// of starting fresh on every call.
// NOTE(review): the efficiency gain over creating a new session per call has
// not been measured.
const defaultOpenAISession: {
  session: OpenAISession;
  options: ClientOptions;
}[] = [];

/**
 * Get a session for the OpenAI API. If one was already created with
 * deep-equal options (lodash `isEqual`), it is reused; otherwise a new session
 * is created and cached.
 *
 * The cache stores a shallow snapshot of `options` taken *before* the session
 * is constructed. This keeps an entry's key stable even if the caller's object
 * (or the session constructor) mutates it later; otherwise lookups stop
 * matching and a new session would be created and cached on every call.
 *
 * @param options client options plus an optional `azure` flag selecting the
 *   Azure client.
 * @returns the cached or newly created session.
 */
export function getOpenAISession(
  options: ClientOptions & { azure?: boolean } = {},
) {
  const cached = defaultOpenAISession.find((entry) =>
    _.isEqual(entry.options, options),
  );
  if (cached) {
    return cached.session;
  }

  const keySnapshot = { ...options };
  const session = new OpenAISession(options);
  defaultOpenAISession.push({ session, options: keySnapshot });
  return session;
}

/**
 * GPT-4-family chat models and their context window sizes.
 *
 * NOTE(review): "gpt-4o-2024-09-14" and "gpt-4o-2024-10-14" do not appear to
 * match published OpenAI snapshot names. Confirm before relying on them; they
 * are kept so existing lookups keep working.
 */
export const GPT4_MODELS = {
  "chatgpt-4o-latest": {
    contextWindow: 128000,
  },
  "gpt-4": { contextWindow: 8192 },
  "gpt-4-32k": { contextWindow: 32768 },
  "gpt-4-32k-0613": { contextWindow: 32768 },
  "gpt-4-turbo": { contextWindow: 128000 },
  "gpt-4-turbo-preview": { contextWindow: 128000 },
  "gpt-4-1106-preview": { contextWindow: 128000 },
  "gpt-4-0125-preview": { contextWindow: 128000 },
  "gpt-4-vision-preview": { contextWindow: 128000 },
  "gpt-4o": { contextWindow: 128000 },
  "gpt-4o-2024-05-13": { contextWindow: 128000 },
  "gpt-4o-mini": { contextWindow: 128000 },
  "gpt-4o-mini-2024-07-18": { contextWindow: 128000 },
  "gpt-4o-2024-08-06": { contextWindow: 128000 },
  "gpt-4o-2024-09-14": { contextWindow: 128000 },
  "gpt-4o-2024-10-14": { contextWindow: 128000 },
  // Original 8K-context GPT-4 snapshot ("gpt-4" above has the same window).
  "gpt-4-0613": { contextWindow: 8192 },
  "gpt-4-turbo-2024-04-09": { contextWindow: 128000 },
  // Legacy 8K-context GPT-4 snapshot.
  "gpt-4-0314": { contextWindow: 8192 },
  "gpt-4-32k-0314": { contextWindow: 32768 },
};

// NOTE: we don't currently support gpt-3.5-turbo-instruct and don't plan to in the near future.
/** GPT-3.5-family chat models and their context window sizes. */
export const GPT35_MODELS = {
  "gpt-3.5-turbo": { contextWindow: 16385 },
  "gpt-3.5-turbo-0613": { contextWindow: 4096 },
  "gpt-3.5-turbo-16k": { contextWindow: 16385 },
  "gpt-3.5-turbo-16k-0613": { contextWindow: 16385 },
  "gpt-3.5-turbo-1106": { contextWindow: 16385 },
  "gpt-3.5-turbo-0125": { contextWindow: 16385 },
  // The original gpt-3.5-turbo snapshot had a 4K context window.
  "gpt-3.5-turbo-0301": { contextWindow: 4096 },
};

/** OpenAI o1-family chat models and their context window sizes. */
export const O1_MODELS = {
  "o1-preview": { contextWindow: 128000 },
  "o1-preview-2024-09-12": { contextWindow: 128000 },
  "o1-mini": { contextWindow: 128000 },
  "o1-mini-2024-09-12": { contextWindow: 128000 },
};

/**
 * Every OpenAI chat model known to this module (GPT-4, GPT-3.5 and o1
 * families), mapped to its context window size.
 *
 * `satisfies` verifies at compile time that every `ChatModel` name from the
 * installed `openai` package has an entry, while keeping the literal key type
 * (used for lookups such as `keyof typeof ALL_AVAILABLE_OPENAI_MODELS`).
 */
export const ALL_AVAILABLE_OPENAI_MODELS = {
  ...GPT4_MODELS,
  ...GPT35_MODELS,
  ...O1_MODELS,
} satisfies Record<ChatModel, { contextWindow: number }>;

/**
 * Report whether `llm` uses an OpenAI chat model that supports tool calling.
 *
 * The model name is read from `OpenAI` instances, or from any other LLM that
 * exposes a string `model` property; LLMs without one are rejected. A name
 * qualifies when it is a key of `ALL_AVAILABLE_OPENAI_MODELS`, is not a legacy
 * snapshot (contains "0314" or "0301"), and is not an o1-family model (starts
 * with "o1").
 *
 * NOTE(review): the `llm is OpenAI` predicate is also true for non-`OpenAI`
 * LLMs whose model name qualifies, so the narrowing is not guaranteed sound.
 */
export function isFunctionCallingModel(llm: LLM): llm is OpenAI {
  const modelName =
    llm instanceof OpenAI
      ? llm.model
      : "model" in llm && typeof llm.model === "string"
        ? llm.model
        : undefined;
  if (modelName === undefined) {
    return false;
  }

  const isKnownModel = Object.prototype.hasOwnProperty.call(
    ALL_AVAILABLE_OPENAI_MODELS,
    modelName,
  );
  const isLegacySnapshot = ["0314", "0301"].some((tag) =>
    modelName.includes(tag),
  );
  const isO1Family = modelName.startsWith("o1");
  return isKnownModel && !isLegacySnapshot && !isO1Family;
}

/** Extra metadata fields exposed by `OpenAI.metadata`; currently none. */
export type OpenAIAdditionalMetadata = {};

/**
 * Extra OpenAI chat-completion request parameters that may be passed through
 * `additionalChatOptions`. Parameters that `OpenAI.chat` fills in itself
 * (model, messages, temperature, top_p, max_tokens, tools, stream) are omitted,
 * so the type does not allow overriding them.
 */
export type OpenAIAdditionalChatOptions = Omit<
  Partial<OpenAILLM.Chat.ChatCompletionCreateParams>,
  | "max_tokens"
  | "messages"
  | "model"
  | "temperature"
  | "top_p"
  | "stream"
  | "tools"
  // NOTE(review): the other keys use OpenAI's snake_case request names; the
  // request field is `tool_choice`, so "toolChoice" likely omits nothing.
  // Confirm whether `tool_choice` is meant to stay overridable.
  | "toolChoice"
>;

/**
 * LlamaIndex LLM backed by the OpenAI (or Azure OpenAI) chat completions API,
 * with tool calling and streaming support.
 */
export class OpenAI extends ToolCallLLM<OpenAIAdditionalChatOptions> {
  model:
    | ChatModel
    // string & {} is a hack to allow any string, but still give autocomplete
    | (string & {});
  temperature: number;
  topP: number;
  maxTokens?: number | undefined;
  additionalChatOptions?: OpenAIAdditionalChatOptions | undefined;

  // OpenAI session params
  apiKey?: string | undefined = undefined;
  maxRetries: number;
  timeout?: number;
  session: OpenAISession;
  additionalSessionOptions?:
    | undefined
    | Omit<Partial<OpenAIClientOptions>, "apiKey" | "maxRetries" | "timeout">;

  /**
   * @param init optional overrides for the public fields, plus an `azure`
   *   config. Azure is used when `init.azure` is given or `shouldUseAzure()`
   *   returns true; settings from `getAzureConfigFromEnv` are merged with
   *   `init.azure` (the latter wins). Defaults: model "gpt-4o", temperature
   *   0.1, topP 1, maxRetries 10, timeout 60 seconds. A supplied
   *   `init.session` is used as-is instead of `getOpenAISession`.
   */
  constructor(
    init?: Partial<OpenAI> & {
      azure?: AzureOpenAIConfig;
    },
  ) {
    super();
    this.model = init?.model ?? "gpt-4o";
    this.temperature = init?.temperature ?? 0.1;
    this.topP = init?.topP ?? 1;
    this.maxTokens = init?.maxTokens ?? undefined;

    this.maxRetries = init?.maxRetries ?? 10;
    this.timeout = init?.timeout ?? 60 * 1000; // Default is 60 seconds
    this.additionalChatOptions = init?.additionalChatOptions;
    this.additionalSessionOptions = init?.additionalSessionOptions;

    if (init?.azure || shouldUseAzure()) {
      const azureConfig = {
        ...getAzureConfigFromEnv({
          model: getAzureModel(this.model),
        }),
        ...init?.azure,
      };

      this.apiKey = azureConfig.apiKey;
      this.session =
        init?.session ??
        getOpenAISession({
          azure: true,
          maxRetries: this.maxRetries,
          timeout: this.timeout,
          ...this.additionalSessionOptions,
          ...azureConfig,
        });
    } else {
      this.apiKey = init?.apiKey ?? undefined;
      this.session =
        init?.session ??
        getOpenAISession({
          apiKey: this.apiKey,
          maxRetries: this.maxRetries,
          timeout: this.timeout,
          ...this.additionalSessionOptions,
        });
    }
  }

  /** Whether the configured model supports tool calling (see `isFunctionCallingModel`). */
  get supportToolCall() {
    return isFunctionCallingModel(this);
  }

  /**
   * Model settings. `contextWindow` comes from `ALL_AVAILABLE_OPENAI_MODELS`
   * and falls back to 1024 for unknown model names.
   */
  get metadata(): LLMMetadata & OpenAIAdditionalMetadata {
    const contextWindow =
      ALL_AVAILABLE_OPENAI_MODELS[
        this.model as keyof typeof ALL_AVAILABLE_OPENAI_MODELS
      ]?.contextWindow ?? 1024;
    return {
      model: this.model,
      temperature: this.temperature,
      topP: this.topP,
      maxTokens: this.maxTokens,
      contextWindow,
      tokenizer: Tokenizers.CL100K_BASE,
    };
  }

  /** Map a LlamaIndex message role to an OpenAI role; unknown roles become "user". */
  static toOpenAIRole(messageType: MessageType): ChatCompletionRole {
    switch (messageType) {
      case "user":
        return "user";
      case "assistant":
        return "assistant";
      case "system":
        return "system";
      default:
        return "user";
    }
  }

  /**
   * Convert LlamaIndex chat messages to OpenAI request messages.
   *
   * - `options.toolResult` becomes a "tool" message (text content).
   * - `options.toolCall` becomes an "assistant" message with `tool_calls`;
   *   non-string inputs are JSON-stringified.
   * - "user" messages keep their content unchanged.
   * - Anything else uses `toOpenAIRole` with text-only content.
   */
  static toOpenAIMessage(
    messages: ChatMessage<ToolCallLLMMessageOptions>[],
  ): ChatCompletionMessageParam[] {
    return messages.map((message) => {
      const options = message.options ?? {};
      if ("toolResult" in options) {
        return {
          tool_call_id: options.toolResult.id,
          role: "tool",
          content: extractText(message.content),
        } satisfies ChatCompletionToolMessageParam;
      } else if ("toolCall" in options) {
        return {
          role: "assistant",
          content: extractText(message.content),
          tool_calls: options.toolCall.map((toolCall) => {
            return {
              id: toolCall.id,
              type: "function",
              function: {
                name: toolCall.name,
                arguments:
                  typeof toolCall.input === "string"
                    ? toolCall.input
                    : JSON.stringify(toolCall.input),
              },
            };
          }),
        } satisfies ChatCompletionAssistantMessageParam;
      } else if (message.role === "user") {
        return {
          role: "user",
          content: message.content,
        } satisfies ChatCompletionUserMessageParam;
      }

      const response:
        | ChatCompletionSystemMessageParam
        | ChatCompletionUserMessageParam
        | ChatCompletionMessageToolCall = {
        // fixme(alex): type assertion
        role: OpenAI.toOpenAIRole(message.role) as never,
        // fixme: should not extract text, but assert content is string
        content: extractText(message.content),
      };
      return response;
    });
  }

  /**
   * Send a chat request. With `stream: true` the result is an async iterable
   * of chunks (see `streamChat`); otherwise a single response whose
   * `options.toolCall` lists any tool calls with their raw JSON argument
   * strings. Per-call `additionalChatOptions` override the instance-level ones.
   */
  chat(
    params: LLMChatParamsStreaming<
      OpenAIAdditionalChatOptions,
      ToolCallLLMMessageOptions
    >,
  ): Promise<AsyncIterable<ChatResponseChunk<ToolCallLLMMessageOptions>>>;
  chat(
    params: LLMChatParamsNonStreaming<
      OpenAIAdditionalChatOptions,
      ToolCallLLMMessageOptions
    >,
  ): Promise<ChatResponse<ToolCallLLMMessageOptions>>;
  @wrapEventCaller
  @wrapLLMEvent
  async chat(
    params:
      | LLMChatParamsNonStreaming<
          OpenAIAdditionalChatOptions,
          ToolCallLLMMessageOptions
        >
      | LLMChatParamsStreaming<
          OpenAIAdditionalChatOptions,
          ToolCallLLMMessageOptions
        >,
  ): Promise<
    | ChatResponse<ToolCallLLMMessageOptions>
    | AsyncIterable<ChatResponseChunk<ToolCallLLMMessageOptions>>
  > {
    const { messages, stream, tools, additionalChatOptions } = params;
    // o1-family models (same "o1" prefix test as isFunctionCallingModel) reject
    // a non-default `temperature` and the `max_tokens` parameter; they expect
    // `max_completion_tokens`. Adapt the request so the defaults (temperature
    // 0.1) don't make every o1 call fail.
    const isO1Model = this.model.startsWith("o1");
    const baseRequestParams = <OpenAILLM.Chat.ChatCompletionCreateParams>{
      model: this.model,
      ...(isO1Model
        ? { max_completion_tokens: this.maxTokens }
        : { temperature: this.temperature, max_tokens: this.maxTokens }),
      tools: tools?.map(OpenAI.toTool),
      messages: OpenAI.toOpenAIMessage(messages),
      top_p: this.topP,
      // Per-call options take precedence over the instance-level defaults.
      ...this.additionalChatOptions,
      ...additionalChatOptions,
    };

    if (
      Array.isArray(baseRequestParams.tools) &&
      baseRequestParams.tools.length === 0
    ) {
      // remove empty tools array to avoid OpenAI error
      delete baseRequestParams.tools;
    }

    // Streaming
    if (stream) {
      return this.streamChat(baseRequestParams);
    }

    // Non-streaming
    const response = await this.session.openai.chat.completions.create({
      ...baseRequestParams,
      stream: false,
    });

    const choice = response.choices[0]!;
    const toolCalls = choice.message.tool_calls;

    return {
      raw: response,
      message: {
        content: choice.message.content ?? "",
        role: choice.message.role,
        options: toolCalls
          ? {
              toolCall: toolCalls.map((toolCall) => ({
                id: toolCall.id,
                name: toolCall.function.name,
                input: toolCall.function.arguments,
              })),
            }
          : {},
      },
    };
  }

  // todo: this wrapper is ugly, refactor it
  /**
   * Stream a chat completion, translating OpenAI chunks into LlamaIndex chunks.
   *
   * Text is reported in `delta`. Tool calls are reported in `options.toolCall`.
   * While a call's arguments are still streaming, the in-progress call (raw
   * argument string) is reported. Once the call is complete — a new tool-call
   * id appears, or the stream reports a finish reason — it is reported with its
   * arguments parsed by `finalizeToolCall`.
   */
  @wrapEventCaller
  protected async *streamChat(
    baseRequestParams: OpenAILLM.Chat.ChatCompletionCreateParams,
  ): AsyncIterable<ChatResponseChunk<ToolCallLLMMessageOptions>> {
    const stream: AsyncIterable<OpenAILLM.Chat.ChatCompletionChunk> =
      await this.session.openai.chat.completions.create({
        ...baseRequestParams,
        stream: true,
      });

    // TODO: add callback to streamConverter and use streamConverter here
    // Tool call whose arguments are currently being streamed.
    let currentToolCall: PartialToolCall | null = null;
    for await (const part of stream) {
      const choice = part.choices[0];
      if (!choice) continue;
      const { delta } = choice;
      // `!= null` so a provider that omits finish_reason is not treated as done.
      const isDone = choice.finish_reason != null;
      // Skip chunks with neither text nor tool-call data, but never the
      // terminal chunk. Its delta is usually empty, and it is the only place
      // where the last tool call gets finalized.
      if (!delta.content && !delta.tool_calls && !isDone) continue;

      let shouldEmitToolCall: PartialToolCall | null = null;
      const toolCallDelta = delta.tool_calls?.[0];
      if (toolCallDelta?.id) {
        // A new id starts a new tool call, so the previous one is complete.
        if (currentToolCall && toolCallDelta.id !== currentToolCall.id) {
          shouldEmitToolCall = OpenAI.finalizeToolCall(currentToolCall);
        }
        currentToolCall = {
          name: toolCallDelta.function?.name ?? "",
          id: toolCallDelta.id,
          // `?? ""` keeps later `+=` from producing "undefined..." strings.
          input: toolCallDelta.function?.arguments ?? "",
        };
      } else if (toolCallDelta?.function?.arguments) {
        // Continuation chunk: append the next fragment of the JSON arguments.
        currentToolCall!.input += toolCallDelta.function.arguments;
      }

      if (isDone && currentToolCall) {
        // End of the stream: finalize the tool call still in progress.
        shouldEmitToolCall = OpenAI.finalizeToolCall(currentToolCall);
      }

      const reportedToolCall = shouldEmitToolCall ?? currentToolCall;
      yield {
        raw: part,
        options: reportedToolCall ? { toolCall: [reportedToolCall] } : {},
        delta: delta.content ?? "",
      };
    }
  }

  /**
   * Parse the accumulated JSON argument string of a completed streamed tool
   * call. Empty arguments become `{}`. If the arguments are not valid JSON
   * (for example because the stream was cut off), the call is returned
   * unchanged so consumers still receive the raw string instead of the stream
   * throwing.
   */
  private static finalizeToolCall(toolCall: PartialToolCall): PartialToolCall {
    const rawInput = toolCall.input.trim() === "" ? "{}" : toolCall.input;
    try {
      // At runtime `input` now holds the parsed arguments object; the
      // PartialToolCall annotation mirrors the previous inline JSON.parse.
      return { ...toolCall, input: JSON.parse(rawInput) };
    } catch {
      return toolCall;
    }
  }

  /** Convert a LlamaIndex tool to an OpenAI function tool; `parameters` is sent only when defined. */
  static toTool(tool: BaseTool): ChatCompletionTool {
    return {
      type: "function",
      function: tool.metadata.parameters
        ? {
            name: tool.metadata.name,
            description: tool.metadata.description,
            parameters: tool.metadata.parameters,
          }
        : {
            name: tool.metadata.name,
            description: tool.metadata.description,
          },
    };
  }
}