From 2a8241328de14f55991e1a22178bc1abf64e88a8 Mon Sep 17 00:00:00 2001 From: Alex Yang <himself65@outlook.com> Date: Thu, 3 Oct 2024 17:12:33 -0700 Subject: [PATCH] fix: lazy load openai (#1294) --- .changeset/afraid-pots-raise.md | 7 + examples/agent/large_toolcall_with_gpt4o.ts | 2 +- packages/llamaindex/src/Settings.ts | 7 - .../llamaindex/src/embeddings/fireworks.ts | 2 +- .../llamaindex/src/embeddings/together.ts | 2 +- packages/llamaindex/src/index.edge.ts | 16 ++- packages/llamaindex/src/llm/deepinfra.ts | 2 +- packages/llamaindex/src/llm/deepseek.ts | 4 +- packages/llamaindex/src/llm/fireworks.ts | 2 +- packages/llamaindex/src/llm/together.ts | 2 +- packages/llamaindex/tests/init.test.ts | 7 + packages/llm/groq/src/llm.ts | 11 +- packages/llm/openai/src/azure.ts | 25 ++-- packages/llm/openai/src/embedding.ts | 81 ++++++++---- packages/llm/openai/src/index.ts | 3 - packages/llm/openai/src/llm.ts | 125 +++++++----------- 16 files changed, 148 insertions(+), 150 deletions(-) create mode 100644 .changeset/afraid-pots-raise.md create mode 100644 packages/llamaindex/tests/init.test.ts diff --git a/.changeset/afraid-pots-raise.md b/.changeset/afraid-pots-raise.md new file mode 100644 index 000000000..1fe33ecf9 --- /dev/null +++ b/.changeset/afraid-pots-raise.md @@ -0,0 +1,7 @@ +--- +"llamaindex": patch +"@llamaindex/groq": patch +"@llamaindex/openai": patch +--- + +fix(core): set `Settings.llm` to OpenAI by default and support lazy load openai diff --git a/examples/agent/large_toolcall_with_gpt4o.ts b/examples/agent/large_toolcall_with_gpt4o.ts index aa067b75b..522adf1e2 100644 --- a/examples/agent/large_toolcall_with_gpt4o.ts +++ b/examples/agent/large_toolcall_with_gpt4o.ts @@ -13,7 +13,7 @@ import { FunctionTool, OpenAI, ToolCallOptions } from "llamaindex"; } })(); -async function callLLM(init: Partial<OpenAI>) { +async function callLLM(init: { model: string }) { const csvData = "Country,Average Height (cm)\nNetherlands,156\nDenmark,158\nNorway,160"; diff --git a/packages/llamaindex/src/Settings.ts b/packages/llamaindex/src/Settings.ts index b522a6a81..3294de30b 100644 --- a/packages/llamaindex/src/Settings.ts +++ b/packages/llamaindex/src/Settings.ts @@ -2,7 +2,6 @@ import { type CallbackManager, Settings as CoreSettings, } from "@llamaindex/core/global"; -import { OpenAI } from "@llamaindex/openai"; import { PromptHelper } from "@llamaindex/core/indices"; @@ -61,12 +60,6 @@ class GlobalSettings implements Config { } get llm(): LLM { - // fixme: we might need check internal error instead of try-catch here - try { - CoreSettings.llm; - } catch (error) { - CoreSettings.llm = new OpenAI(); - } return CoreSettings.llm; } diff --git a/packages/llamaindex/src/embeddings/fireworks.ts b/packages/llamaindex/src/embeddings/fireworks.ts index 8338884c1..03676f9b2 100644 --- a/packages/llamaindex/src/embeddings/fireworks.ts +++ b/packages/llamaindex/src/embeddings/fireworks.ts @@ -2,7 +2,7 @@ import { getEnv } from "@llamaindex/env"; import { OpenAIEmbedding } from "@llamaindex/openai"; export class FireworksEmbedding extends OpenAIEmbedding { - constructor(init?: Partial<OpenAIEmbedding>) { + constructor(init?: Omit<Partial<OpenAIEmbedding>, "session">) { const { apiKey = getEnv("FIREWORKS_API_KEY"), additionalSessionOptions = {}, diff --git a/packages/llamaindex/src/embeddings/together.ts b/packages/llamaindex/src/embeddings/together.ts index 1ed43fef7..f51189de0 100644 --- a/packages/llamaindex/src/embeddings/together.ts +++ b/packages/llamaindex/src/embeddings/together.ts @@ -2,7 +2,7 @@ 
import { getEnv } from "@llamaindex/env"; import { OpenAIEmbedding } from "@llamaindex/openai"; export class TogetherEmbedding extends OpenAIEmbedding { - constructor(init?: Partial<OpenAIEmbedding>) { + constructor(init?: Omit<Partial<OpenAIEmbedding>, "session">) { const { apiKey = getEnv("TOGETHER_API_KEY"), additionalSessionOptions = {}, diff --git a/packages/llamaindex/src/index.edge.ts b/packages/llamaindex/src/index.edge.ts index 3df610225..935b27e5b 100644 --- a/packages/llamaindex/src/index.edge.ts +++ b/packages/llamaindex/src/index.edge.ts @@ -1,3 +1,15 @@ +//#region initial setup for OpenAI +import { OpenAI } from "@llamaindex/openai"; +import { Settings } from "./Settings.js"; + +try { + Settings.llm; +} catch { + Settings.llm = new OpenAI(); +} + +//#endregion + export { LlamaParseReader, type Language, @@ -28,12 +40,12 @@ export type { JSONArray, JSONObject, JSONValue, + LlamaIndexEventMaps, LLMEndEvent, LLMStartEvent, LLMStreamEvent, LLMToolCallEvent, LLMToolResultEvent, - LlamaIndexEventMaps, } from "@llamaindex/core/global"; export * from "@llamaindex/core/indices"; export * from "@llamaindex/core/llms"; @@ -61,7 +73,7 @@ export * from "./postprocessors/index.js"; export * from "./QuestionGenerator.js"; export * from "./selectors/index.js"; export * from "./ServiceContext.js"; -export { Settings } from "./Settings.js"; export * from "./storage/StorageContext.js"; export * from "./tools/index.js"; export * from "./types.js"; +export { Settings }; diff --git a/packages/llamaindex/src/llm/deepinfra.ts b/packages/llamaindex/src/llm/deepinfra.ts index c2c8bde81..e56e9469c 100644 --- a/packages/llamaindex/src/llm/deepinfra.ts +++ b/packages/llamaindex/src/llm/deepinfra.ts @@ -6,7 +6,7 @@ const DEFAULT_MODEL = "mistralai/Mixtral-8x22B-Instruct-v0.1"; const BASE_URL = "https://api.deepinfra.com/v1/openai"; export class DeepInfra extends OpenAI { - constructor(init?: Partial<OpenAI>) { + constructor(init?: Omit<Partial<OpenAI>, "session">) { const { apiKey = getEnv(ENV_VARIABLE_NAME), additionalSessionOptions = {}, diff --git a/packages/llamaindex/src/llm/deepseek.ts b/packages/llamaindex/src/llm/deepseek.ts index d8a258683..1638475fb 100644 --- a/packages/llamaindex/src/llm/deepseek.ts +++ b/packages/llamaindex/src/llm/deepseek.ts @@ -10,7 +10,9 @@ type DeepSeekModelName = keyof typeof DEEPSEEK_MODELS; const DEFAULT_MODEL: DeepSeekModelName = "deepseek-coder"; export class DeepSeekLLM extends OpenAI { - constructor(init?: Partial<OpenAI> & { model?: DeepSeekModelName }) { + constructor( + init?: Omit<Partial<OpenAI>, "session"> & { model?: DeepSeekModelName }, + ) { const { apiKey = getEnv("DEEPSEEK_API_KEY"), additionalSessionOptions = {}, diff --git a/packages/llamaindex/src/llm/fireworks.ts b/packages/llamaindex/src/llm/fireworks.ts index 3e5979f4d..4aa9cd8dc 100644 --- a/packages/llamaindex/src/llm/fireworks.ts +++ b/packages/llamaindex/src/llm/fireworks.ts @@ -2,7 +2,7 @@ import { getEnv } from "@llamaindex/env"; import { OpenAI } from "@llamaindex/openai"; export class FireworksLLM extends OpenAI { - constructor(init?: Partial<OpenAI>) { + constructor(init?: Omit<Partial<OpenAI>, "session">) { const { apiKey = getEnv("FIREWORKS_API_KEY"), additionalSessionOptions = {}, diff --git a/packages/llamaindex/src/llm/together.ts b/packages/llamaindex/src/llm/together.ts index 4d314bcc0..fffaea08d 100644 --- a/packages/llamaindex/src/llm/together.ts +++ b/packages/llamaindex/src/llm/together.ts @@ -2,7 +2,7 @@ import { getEnv } from "@llamaindex/env"; import { OpenAI } from 
"@llamaindex/openai"; export class TogetherLLM extends OpenAI { - constructor(init?: Partial<OpenAI>) { + constructor(init?: Omit<Partial<OpenAI>, "session">) { const { apiKey = getEnv("TOGETHER_API_KEY"), additionalSessionOptions = {}, diff --git a/packages/llamaindex/tests/init.test.ts b/packages/llamaindex/tests/init.test.ts new file mode 100644 index 000000000..39ea25bb1 --- /dev/null +++ b/packages/llamaindex/tests/init.test.ts @@ -0,0 +1,7 @@ +import { expect, test, vi } from "vitest"; + +test("init without error", async () => { + vi.stubEnv("OPENAI_API_KEY", undefined); + const { Settings } = await import("llamaindex"); + expect(Settings.llm).toBeDefined(); +}); diff --git a/packages/llm/groq/src/llm.ts b/packages/llm/groq/src/llm.ts index 5c058f923..bcf725245 100644 --- a/packages/llm/groq/src/llm.ts +++ b/packages/llm/groq/src/llm.ts @@ -4,7 +4,7 @@ import GroqSDK, { type ClientOptions } from "groq-sdk"; export class Groq extends OpenAI { constructor( - init?: Partial<OpenAI> & { + init?: Omit<Partial<OpenAI>, "session"> & { additionalSessionOptions?: ClientOptions; }, ) { @@ -22,9 +22,10 @@ export class Groq extends OpenAI { ...rest, }); - this.session.openai = new GroqSDK({ - apiKey, - ...init?.additionalSessionOptions, - }) as any; + this.lazySession = async () => + new GroqSDK({ + apiKey, + ...init?.additionalSessionOptions, + }) as any; } } diff --git a/packages/llm/openai/src/azure.ts b/packages/llm/openai/src/azure.ts index 7c64513c7..533594c10 100644 --- a/packages/llm/openai/src/azure.ts +++ b/packages/llm/openai/src/azure.ts @@ -2,11 +2,6 @@ import { getEnv } from "@llamaindex/env"; import type { AzureClientOptions } from "openai"; -export interface AzureOpenAIConfig extends AzureClientOptions { - /** @deprecated use "deployment" instead */ - deploymentName?: string | undefined; -} - // NOTE we're not supporting the legacy models as they're not available for new deployments // https://learn.microsoft.com/en-us/azure/ai-services/openai/concepts/legacy-models // If you have a need for them, please open an issue on GitHub @@ -85,14 +80,15 @@ const DEFAULT_API_VERSION = "2023-05-15"; //^ NOTE: this will change over time, if you want to pin it, use a specific version export function getAzureConfigFromEnv( - init?: Partial<AzureOpenAIConfig> & { model?: string }, -): AzureOpenAIConfig { + init?: Partial<AzureClientOptions> & { model?: string }, +): AzureClientOptions { const deployment = - init?.deploymentName ?? - init?.deployment ?? - getEnv("AZURE_OPENAI_DEPLOYMENT") ?? // From Azure docs - getEnv("AZURE_OPENAI_API_DEPLOYMENT_NAME") ?? // LCJS compatible - init?.model; // Fall back to model name, Python compatible + init && "deploymentName" in init && typeof init.deploymentName === "string" + ? init?.deploymentName + : (init?.deployment ?? + getEnv("AZURE_OPENAI_DEPLOYMENT") ?? // From Azure docs + getEnv("AZURE_OPENAI_API_DEPLOYMENT_NAME") ?? // LCJS compatible + init?.model); // Fall back to model name, Python compatible return { apiKey: init?.apiKey ?? @@ -110,15 +106,10 @@ export function getAzureConfigFromEnv( getEnv("OPENAI_API_VERSION") ?? // Python compatible getEnv("AZURE_OPENAI_API_VERSION") ?? 
// LCJS compatible DEFAULT_API_VERSION, - deploymentName: deployment, // LCJS compatible deployment, // For Azure OpenAI }; } -export function getAzureBaseUrl(config: AzureOpenAIConfig): string { - return `${config.endpoint}/openai/deployments/${config.deploymentName}`; -} - export function getAzureModel(openAIModel: string) { for (const [key, value] of Object.entries( ALL_AZURE_OPENAI_EMBEDDING_MODELS, diff --git a/packages/llm/openai/src/embedding.ts b/packages/llm/openai/src/embedding.ts index 91946eff8..3e390c6f8 100644 --- a/packages/llm/openai/src/embedding.ts +++ b/packages/llm/openai/src/embedding.ts @@ -1,14 +1,16 @@ import { BaseEmbedding } from "@llamaindex/core/embeddings"; -import { Tokenizers } from "@llamaindex/env"; -import type { ClientOptions as OpenAIClientOptions } from "openai"; -import type { AzureOpenAIConfig } from "./azure.js"; +import { getEnv, Tokenizers } from "@llamaindex/env"; +import type { + AzureClientOptions, + AzureOpenAI as AzureOpenAILLM, + ClientOptions as OpenAIClientOptions, + OpenAI as OpenAILLM, +} from "openai"; import { getAzureConfigFromEnv, getAzureModel, shouldUseAzure, } from "./azure.js"; -import type { OpenAISession } from "./llm.js"; -import { getOpenAISession } from "./llm.js"; export const ALL_OPENAI_EMBEDDING_MODELS = { "text-embedding-ada-002": { @@ -32,6 +34,8 @@ export const ALL_OPENAI_EMBEDDING_MODELS = { type ModelKeys = keyof typeof ALL_OPENAI_EMBEDDING_MODELS; +type LLMInstance = Pick<AzureOpenAILLM | OpenAILLM, "embeddings" | "apiKey">; + export class OpenAIEmbedding extends BaseEmbedding { /** embeddding model. defaults to "text-embedding-ada-002" */ model: string; @@ -51,14 +55,26 @@ export class OpenAIEmbedding extends BaseEmbedding { | Omit<Partial<OpenAIClientOptions>, "apiKey" | "maxRetries" | "timeout"> | undefined; - /** session object */ - session: OpenAISession; + // use lazy here to avoid check OPENAI_API_KEY immediately + lazySession: () => Promise<LLMInstance>; + #session: Promise<LLMInstance> | null = null; + get session() { + if (!this.#session) { + this.#session = this.lazySession(); + } + return this.#session; + } /** * OpenAI Embedding * @param init - initial parameters */ - constructor(init?: Partial<OpenAIEmbedding> & { azure?: AzureOpenAIConfig }) { + constructor( + init?: Omit<Partial<OpenAIEmbedding>, "lazySession"> & { + session?: LLMInstance | undefined; + azure?: AzureClientOptions; + }, + ) { super(); this.model = init?.model ?? "text-embedding-ada-002"; @@ -77,7 +93,6 @@ export class OpenAIEmbedding extends BaseEmbedding { if (key) { this.embedInfo = ALL_OPENAI_EMBEDDING_MODELS[key]; } - if (init?.azure || shouldUseAzure()) { const azureConfig = { ...getAzureConfigFromEnv({ @@ -85,26 +100,32 @@ export class OpenAIEmbedding extends BaseEmbedding { }), ...init?.azure, }; - - this.apiKey = azureConfig.apiKey; - this.session = - init?.session ?? - getOpenAISession({ - azure: true, - maxRetries: this.maxRetries, - timeout: this.timeout, - ...this.additionalSessionOptions, - ...azureConfig, - }); + this.apiKey = + init?.session?.apiKey ?? azureConfig.apiKey ?? getEnv("OPENAI_API_KEY"); + this.lazySession = async () => + import("openai").then( + async ({ AzureOpenAI }) => + init?.session ?? + new AzureOpenAI({ + maxRetries: this.maxRetries, + timeout: this.timeout!, + ...this.additionalSessionOptions, + ...azureConfig, + }), + ); } else { - this.apiKey = init?.apiKey ?? undefined; - this.session = - init?.session ?? 
- getOpenAISession({ - apiKey: this.apiKey, - maxRetries: this.maxRetries, - timeout: this.timeout, - ...this.additionalSessionOptions, + this.apiKey = init?.session?.apiKey ?? getEnv("OPENAI_API_KEY"); + this.lazySession = async () => + import("openai").then(({ OpenAI }) => { + return ( + init?.session ?? + new OpenAI({ + apiKey: this.apiKey, + maxRetries: this.maxRetries, + timeout: this.timeout!, + ...this.additionalSessionOptions, + }) + ); }); } } @@ -118,7 +139,9 @@ export class OpenAIEmbedding extends BaseEmbedding { // TODO: ensure this for every sub class by calling it in the base class input = this.truncateMaxTokens(input); - const { data } = await this.session.openai.embeddings.create( + const { data } = await ( + await this.session + ).embeddings.create( this.dimensions ? { model: this.model, diff --git a/packages/llm/openai/src/index.ts b/packages/llm/openai/src/index.ts index b74136de4..068984a8e 100644 --- a/packages/llm/openai/src/index.ts +++ b/packages/llm/openai/src/index.ts @@ -10,9 +10,6 @@ export { GPT4_MODELS, O1_MODELS, OpenAI, - OpenAISession, type OpenAIAdditionalChatOptions, type OpenAIAdditionalMetadata, } from "./llm"; - -export { type AzureOpenAIConfig } from "./azure"; diff --git a/packages/llm/openai/src/llm.ts b/packages/llm/openai/src/llm.ts index a10f658db..5a0457e61 100644 --- a/packages/llm/openai/src/llm.ts +++ b/packages/llm/openai/src/llm.ts @@ -1,12 +1,11 @@ import { getEnv } from "@llamaindex/env"; -import type OpenAILLM from "openai"; import type { - ClientOptions, + AzureClientOptions, + AzureOpenAI as AzureOpenAILLM, ClientOptions as OpenAIClientOptions, + OpenAI as OpenAILLM, } from "openai"; -import { AzureOpenAI, OpenAI as OrigOpenAI } from "openai"; import type { ChatModel } from "openai/resources/chat/chat"; -import { isDeepEqual } from "remeda"; import { wrapEventCaller, wrapLLMEvent } from "@llamaindex/core/decorator"; import { @@ -35,64 +34,12 @@ import type { ChatCompletionUserMessageParam, } from "openai/resources/chat/completions"; import type { ChatCompletionMessageParam } from "openai/resources/index.js"; -import type { AzureOpenAIConfig } from "./azure.js"; import { getAzureConfigFromEnv, getAzureModel, shouldUseAzure, } from "./azure.js"; -export class OpenAISession { - openai: Pick<OrigOpenAI, "chat" | "embeddings">; - - constructor(options: ClientOptions & { azure?: boolean } = {}) { - if (options.azure) { - this.openai = new AzureOpenAI(options as AzureOpenAIConfig); - } else { - if (!options.apiKey) { - options.apiKey = getEnv("OPENAI_API_KEY"); - } - - if (!options.apiKey) { - throw new Error("Set OpenAI Key in OPENAI_API_KEY env variable"); // Overriding OpenAI package's error message - } - - this.openai = new OrigOpenAI({ - ...options, - }); - } - } -} - -// I'm not 100% sure this is necessary vs. just starting a new session -// every time we make a call. They say they try to reuse connections -// so in theory this is more efficient, but we should test it in the future. -const defaultOpenAISession: { - session: OpenAISession; - options: ClientOptions; -}[] = []; - -/** - * Get a session for the OpenAI API. If one already exists with the same options, - * it will be returned. Otherwise, a new session will be created. 
- * @param options - * @returns - */ -export function getOpenAISession( - options: ClientOptions & { azure?: boolean } = {}, -) { - let session = defaultOpenAISession.find((session) => { - return isDeepEqual(session.options, options); - })?.session; - - if (!session) { - session = new OpenAISession(options); - defaultOpenAISession.push({ session, options }); - } - - return session; -} - export const GPT4_MODELS = { "chatgpt-4o-latest": { contextWindow: 128000, @@ -182,6 +129,8 @@ export type OpenAIAdditionalChatOptions = Omit< | "toolChoice" >; +type LLMInstance = Pick<AzureOpenAILLM | OpenAILLM, "chat" | "apiKey">; + export class OpenAI extends ToolCallLLM<OpenAIAdditionalChatOptions> { model: | ChatModel @@ -196,14 +145,24 @@ export class OpenAI extends ToolCallLLM<OpenAIAdditionalChatOptions> { apiKey?: string | undefined = undefined; maxRetries: number; timeout?: number; - session: OpenAISession; additionalSessionOptions?: | undefined | Omit<Partial<OpenAIClientOptions>, "apiKey" | "maxRetries" | "timeout">; + // use lazy here to avoid check OPENAI_API_KEY immediately + lazySession: () => Promise<LLMInstance>; + #session: Promise<LLMInstance> | null = null; + get session() { + if (!this.#session) { + this.#session = this.lazySession(); + } + return this.#session; + } + constructor( - init?: Partial<OpenAI> & { - azure?: AzureOpenAIConfig; + init?: Omit<Partial<OpenAI>, "session"> & { + session?: LLMInstance | undefined; + azure?: AzureClientOptions; }, ) { super(); @@ -216,6 +175,8 @@ export class OpenAI extends ToolCallLLM<OpenAIAdditionalChatOptions> { this.timeout = init?.timeout ?? 60 * 1000; // Default is 60 seconds this.additionalChatOptions = init?.additionalChatOptions; this.additionalSessionOptions = init?.additionalSessionOptions; + this.apiKey = + init?.session?.apiKey ?? init?.apiKey ?? getEnv("OPENAI_API_KEY"); if (init?.azure || shouldUseAzure()) { const azureConfig = { @@ -225,25 +186,26 @@ export class OpenAI extends ToolCallLLM<OpenAIAdditionalChatOptions> { ...init?.azure, }; - this.apiKey = azureConfig.apiKey; - this.session = + this.lazySession = async () => init?.session ?? - getOpenAISession({ - azure: true, - maxRetries: this.maxRetries, - timeout: this.timeout, - ...this.additionalSessionOptions, - ...azureConfig, + import("openai").then(({ AzureOpenAI }) => { + return new AzureOpenAI({ + maxRetries: this.maxRetries, + timeout: this.timeout!, + ...this.additionalSessionOptions, + ...azureConfig, + }); }); } else { - this.apiKey = init?.apiKey ?? undefined; - this.session = + this.lazySession = async () => init?.session ?? 
- getOpenAISession({ - apiKey: this.apiKey, - maxRetries: this.maxRetries, - timeout: this.timeout, - ...this.additionalSessionOptions, + import("openai").then(({ OpenAI }) => { + return new OpenAI({ + apiKey: this.apiKey, + maxRetries: this.maxRetries, + timeout: this.timeout!, + ...this.additionalSessionOptions, + }); }); } } @@ -382,7 +344,9 @@ export class OpenAI extends ToolCallLLM<OpenAIAdditionalChatOptions> { } // Non-streaming - const response = await this.session.openai.chat.completions.create({ + const response = await ( + await this.session + ).chat.completions.create({ ...baseRequestParams, stream: false, }); @@ -414,11 +378,12 @@ export class OpenAI extends ToolCallLLM<OpenAIAdditionalChatOptions> { protected async *streamChat( baseRequestParams: OpenAILLM.Chat.ChatCompletionCreateParams, ): AsyncIterable<ChatResponseChunk<ToolCallLLMMessageOptions>> { - const stream: AsyncIterable<OpenAILLM.Chat.ChatCompletionChunk> = - await this.session.openai.chat.completions.create({ - ...baseRequestParams, - stream: true, - }); + const stream: AsyncIterable<OpenAILLM.Chat.ChatCompletionChunk> = await ( + await this.session + ).chat.completions.create({ + ...baseRequestParams, + stream: true, + }); // TODO: add callback to streamConverter and use streamConverter here // this will be used to keep track of the current tool call, make sure input are valid json object. -- GitLab
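
The core of this patch is the lazy-session pattern: both OpenAI and OpenAIEmbedding replace the eagerly constructed `session` with a `lazySession` factory that is only resolved, via a dynamic `import("openai")`, the first time `session` is read, so importing llamaindex no longer fails when OPENAI_API_KEY is unset. Below is a minimal standalone sketch of that pattern; the class name `LazyOpenAI` and the `demo` helper are illustrative assumptions and not part of the llamaindex API introduced by this patch.

// Minimal sketch of the lazy-session pattern this patch introduces.
// Assumption: the OpenAI SDK ("openai" package) is installed; the names
// below are illustrative only.
import type { OpenAI as OpenAIClient } from "openai";

class LazyOpenAI {
  apiKey?: string;

  // Deferred factory: the SDK is neither imported nor constructed until
  // `session` is first read, so a missing OPENAI_API_KEY cannot throw at
  // module-load or construction time.
  lazySession: () => Promise<OpenAIClient>;
  #session: Promise<OpenAIClient> | null = null;

  constructor(init?: { apiKey?: string }) {
    this.apiKey = init?.apiKey ?? process.env.OPENAI_API_KEY;
    this.lazySession = async () =>
      import("openai").then(
        ({ OpenAI }) => new OpenAI({ apiKey: this.apiKey }),
      );
  }

  // Cache the promise so every caller shares one client instance.
  get session(): Promise<OpenAIClient> {
    if (!this.#session) {
      this.#session = this.lazySession();
    }
    return this.#session;
  }
}

// Usage: construction never touches the key; only the first awaited access
// resolves the dynamic import and builds the client (which is where the SDK
// would complain about a missing key).
async function demo() {
  const llm = new LazyOpenAI();
  const client = await llm.session;
  const res = await client.chat.completions.create({
    model: "gpt-4o-mini",
    messages: [{ role: "user", content: "hello" }],
  });
  console.log(res.choices[0]?.message.content);
}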