Unverified commit 2a824132, authored by Alex Yang, committed by GitHub

fix: lazy load openai (#1294)

parent 0b20ff9f
Showing 148 additions and 150 deletions
---
"llamaindex": patch
"@llamaindex/groq": patch
"@llamaindex/openai": patch
---
fix(core): set `Settings.llm` to OpenAI by default and support lazy load openai
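The gist of the change: instead of constructing the OpenAI SDK client (and requiring `OPENAI_API_KEY`) eagerly in the constructor, the client is created behind a lazy `session` getter that dynamically imports `openai` on first use. Below is a minimal sketch of that pattern, assuming a hypothetical `LazyOpenAI` class and an arbitrary model name; the actual classes in the diffs that follow also handle Azure config, custom sessions, retries, and timeouts.

// Minimal sketch of the lazy-session pattern (hypothetical class, not the real API surface).
import type { OpenAI as OpenAIClient } from "openai";

class LazyOpenAI {
  apiKey?: string = process.env.OPENAI_API_KEY;
  // The SDK is only imported and instantiated on first use, so constructing this
  // class never throws when OPENAI_API_KEY is missing.
  lazySession: () => Promise<OpenAIClient> = async () => {
    const { OpenAI } = await import("openai");
    return new OpenAI({ apiKey: this.apiKey });
  };
  #session: Promise<OpenAIClient> | null = null;
  get session() {
    if (!this.#session) {
      // Memoize the promise so the dynamic import and client setup run once.
      this.#session = this.lazySession();
    }
    return this.#session;
  }
  async complete(prompt: string) {
    const client = await this.session;
    const res = await client.chat.completions.create({
      model: "gpt-4o-mini", // arbitrary example model
      messages: [{ role: "user", content: prompt }],
    });
    return res.choices[0]?.message.content ?? "";
  }
}

With this shape, `new LazyOpenAI()` (and therefore a default `Settings.llm`) never throws when the key is missing; a missing key only surfaces on the first actual request, which is what the new test at the end of this changeset checks.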
@@ -13,7 +13,7 @@ import { FunctionTool, OpenAI, ToolCallOptions } from "llamaindex";
   }
 })();
-async function callLLM(init: Partial<OpenAI>) {
+async function callLLM(init: { model: string }) {
   const csvData =
     "Country,Average Height (cm)\nNetherlands,156\nDenmark,158\nNorway,160";
...
@@ -2,7 +2,6 @@ import {
   type CallbackManager,
   Settings as CoreSettings,
 } from "@llamaindex/core/global";
-import { OpenAI } from "@llamaindex/openai";
 import { PromptHelper } from "@llamaindex/core/indices";
@@ -61,12 +60,6 @@ class GlobalSettings implements Config {
   }
   get llm(): LLM {
-    // fixme: we might need check internal error instead of try-catch here
-    try {
-      CoreSettings.llm;
-    } catch (error) {
-      CoreSettings.llm = new OpenAI();
-    }
     return CoreSettings.llm;
   }
...
@@ -2,7 +2,7 @@ import { getEnv } from "@llamaindex/env";
 import { OpenAIEmbedding } from "@llamaindex/openai";
 export class FireworksEmbedding extends OpenAIEmbedding {
-  constructor(init?: Partial<OpenAIEmbedding>) {
+  constructor(init?: Omit<Partial<OpenAIEmbedding>, "session">) {
     const {
       apiKey = getEnv("FIREWORKS_API_KEY"),
       additionalSessionOptions = {},
...
@@ -2,7 +2,7 @@ import { getEnv } from "@llamaindex/env";
 import { OpenAIEmbedding } from "@llamaindex/openai";
 export class TogetherEmbedding extends OpenAIEmbedding {
-  constructor(init?: Partial<OpenAIEmbedding>) {
+  constructor(init?: Omit<Partial<OpenAIEmbedding>, "session">) {
     const {
       apiKey = getEnv("TOGETHER_API_KEY"),
       additionalSessionOptions = {},
...
+//#region initial setup for OpenAI
+import { OpenAI } from "@llamaindex/openai";
+import { Settings } from "./Settings.js";
+try {
+  Settings.llm;
+} catch {
+  Settings.llm = new OpenAI();
+}
+//#endregion
 export {
   LlamaParseReader,
   type Language,
@@ -28,12 +40,12 @@ export type {
   JSONArray,
   JSONObject,
   JSONValue,
-  LlamaIndexEventMaps,
   LLMEndEvent,
   LLMStartEvent,
   LLMStreamEvent,
   LLMToolCallEvent,
   LLMToolResultEvent,
+  LlamaIndexEventMaps,
 } from "@llamaindex/core/global";
 export * from "@llamaindex/core/indices";
 export * from "@llamaindex/core/llms";
@@ -61,7 +73,7 @@ export * from "./postprocessors/index.js";
 export * from "./QuestionGenerator.js";
 export * from "./selectors/index.js";
 export * from "./ServiceContext.js";
-export { Settings } from "./Settings.js";
 export * from "./storage/StorageContext.js";
 export * from "./tools/index.js";
 export * from "./types.js";
+export { Settings };
@@ -6,7 +6,7 @@ const DEFAULT_MODEL = "mistralai/Mixtral-8x22B-Instruct-v0.1";
 const BASE_URL = "https://api.deepinfra.com/v1/openai";
 export class DeepInfra extends OpenAI {
-  constructor(init?: Partial<OpenAI>) {
+  constructor(init?: Omit<Partial<OpenAI>, "session">) {
     const {
       apiKey = getEnv(ENV_VARIABLE_NAME),
       additionalSessionOptions = {},
...
@@ -10,7 +10,9 @@ type DeepSeekModelName = keyof typeof DEEPSEEK_MODELS;
 const DEFAULT_MODEL: DeepSeekModelName = "deepseek-coder";
 export class DeepSeekLLM extends OpenAI {
-  constructor(init?: Partial<OpenAI> & { model?: DeepSeekModelName }) {
+  constructor(
+    init?: Omit<Partial<OpenAI>, "session"> & { model?: DeepSeekModelName },
+  ) {
     const {
       apiKey = getEnv("DEEPSEEK_API_KEY"),
       additionalSessionOptions = {},
...
@@ -2,7 +2,7 @@ import { getEnv } from "@llamaindex/env";
 import { OpenAI } from "@llamaindex/openai";
 export class FireworksLLM extends OpenAI {
-  constructor(init?: Partial<OpenAI>) {
+  constructor(init?: Omit<Partial<OpenAI>, "session">) {
     const {
       apiKey = getEnv("FIREWORKS_API_KEY"),
       additionalSessionOptions = {},
...
@@ -2,7 +2,7 @@ import { getEnv } from "@llamaindex/env";
 import { OpenAI } from "@llamaindex/openai";
 export class TogetherLLM extends OpenAI {
-  constructor(init?: Partial<OpenAI>) {
+  constructor(init?: Omit<Partial<OpenAI>, "session">) {
     const {
       apiKey = getEnv("TOGETHER_API_KEY"),
       additionalSessionOptions = {},
...
+import { expect, test, vi } from "vitest";
+test("init without error", async () => {
+  vi.stubEnv("OPENAI_API_KEY", undefined);
+  const { Settings } = await import("llamaindex");
+  expect(Settings.llm).toBeDefined();
+});
@@ -4,7 +4,7 @@ import GroqSDK, { type ClientOptions } from "groq-sdk";
 export class Groq extends OpenAI {
   constructor(
-    init?: Partial<OpenAI> & {
+    init?: Omit<Partial<OpenAI>, "session"> & {
       additionalSessionOptions?: ClientOptions;
     },
   ) {
@@ -22,9 +22,10 @@ export class Groq extends OpenAI {
       ...rest,
     });
-    this.session.openai = new GroqSDK({
-      apiKey,
-      ...init?.additionalSessionOptions,
-    }) as any;
+    this.lazySession = async () =>
+      new GroqSDK({
+        apiKey,
+        ...init?.additionalSessionOptions,
+      }) as any;
   }
 }
@@ -2,11 +2,6 @@ import { getEnv } from "@llamaindex/env";
 import type { AzureClientOptions } from "openai";
-export interface AzureOpenAIConfig extends AzureClientOptions {
-  /** @deprecated use "deployment" instead */
-  deploymentName?: string | undefined;
-}
 // NOTE we're not supporting the legacy models as they're not available for new deployments
 // https://learn.microsoft.com/en-us/azure/ai-services/openai/concepts/legacy-models
 // If you have a need for them, please open an issue on GitHub
@@ -85,14 +80,15 @@ const DEFAULT_API_VERSION = "2023-05-15";
 //^ NOTE: this will change over time, if you want to pin it, use a specific version
 export function getAzureConfigFromEnv(
-  init?: Partial<AzureOpenAIConfig> & { model?: string },
-): AzureOpenAIConfig {
+  init?: Partial<AzureClientOptions> & { model?: string },
+): AzureClientOptions {
   const deployment =
-    init?.deploymentName ??
-    init?.deployment ??
-    getEnv("AZURE_OPENAI_DEPLOYMENT") ?? // From Azure docs
-    getEnv("AZURE_OPENAI_API_DEPLOYMENT_NAME") ?? // LCJS compatible
-    init?.model; // Fall back to model name, Python compatible
+    init && "deploymentName" in init && typeof init.deploymentName === "string"
+      ? init?.deploymentName
+      : (init?.deployment ??
+        getEnv("AZURE_OPENAI_DEPLOYMENT") ?? // From Azure docs
+        getEnv("AZURE_OPENAI_API_DEPLOYMENT_NAME") ?? // LCJS compatible
+        init?.model); // Fall back to model name, Python compatible
   return {
     apiKey:
       init?.apiKey ??
@@ -110,15 +106,10 @@ export function getAzureConfigFromEnv(
       getEnv("OPENAI_API_VERSION") ?? // Python compatible
       getEnv("AZURE_OPENAI_API_VERSION") ?? // LCJS compatible
       DEFAULT_API_VERSION,
-    deploymentName: deployment, // LCJS compatible
     deployment, // For Azure OpenAI
   };
 }
-export function getAzureBaseUrl(config: AzureOpenAIConfig): string {
-  return `${config.endpoint}/openai/deployments/${config.deploymentName}`;
-}
 export function getAzureModel(openAIModel: string) {
   for (const [key, value] of Object.entries(
     ALL_AZURE_OPENAI_EMBEDDING_MODELS,
...
 import { BaseEmbedding } from "@llamaindex/core/embeddings";
-import { Tokenizers } from "@llamaindex/env";
-import type { ClientOptions as OpenAIClientOptions } from "openai";
-import type { AzureOpenAIConfig } from "./azure.js";
+import { getEnv, Tokenizers } from "@llamaindex/env";
+import type {
+  AzureClientOptions,
+  AzureOpenAI as AzureOpenAILLM,
+  ClientOptions as OpenAIClientOptions,
+  OpenAI as OpenAILLM,
+} from "openai";
 import {
   getAzureConfigFromEnv,
   getAzureModel,
   shouldUseAzure,
 } from "./azure.js";
-import type { OpenAISession } from "./llm.js";
-import { getOpenAISession } from "./llm.js";
 export const ALL_OPENAI_EMBEDDING_MODELS = {
   "text-embedding-ada-002": {
@@ -32,6 +34,8 @@ export const ALL_OPENAI_EMBEDDING_MODELS = {
 type ModelKeys = keyof typeof ALL_OPENAI_EMBEDDING_MODELS;
+type LLMInstance = Pick<AzureOpenAILLM | OpenAILLM, "embeddings" | "apiKey">;
 export class OpenAIEmbedding extends BaseEmbedding {
   /** embeddding model. defaults to "text-embedding-ada-002" */
   model: string;
@@ -51,14 +55,26 @@ export class OpenAIEmbedding extends BaseEmbedding {
     | Omit<Partial<OpenAIClientOptions>, "apiKey" | "maxRetries" | "timeout">
     | undefined;
-  /** session object */
-  session: OpenAISession;
+  // use lazy here to avoid check OPENAI_API_KEY immediately
+  lazySession: () => Promise<LLMInstance>;
+  #session: Promise<LLMInstance> | null = null;
+  get session() {
+    if (!this.#session) {
+      this.#session = this.lazySession();
+    }
+    return this.#session;
+  }
   /**
    * OpenAI Embedding
    * @param init - initial parameters
    */
-  constructor(init?: Partial<OpenAIEmbedding> & { azure?: AzureOpenAIConfig }) {
+  constructor(
+    init?: Omit<Partial<OpenAIEmbedding>, "lazySession"> & {
+      session?: LLMInstance | undefined;
+      azure?: AzureClientOptions;
+    },
+  ) {
     super();
     this.model = init?.model ?? "text-embedding-ada-002";
@@ -77,7 +93,6 @@ export class OpenAIEmbedding extends BaseEmbedding {
     if (key) {
       this.embedInfo = ALL_OPENAI_EMBEDDING_MODELS[key];
     }
     if (init?.azure || shouldUseAzure()) {
       const azureConfig = {
         ...getAzureConfigFromEnv({
@@ -85,26 +100,32 @@ export class OpenAIEmbedding extends BaseEmbedding {
         }),
         ...init?.azure,
       };
-      this.apiKey = azureConfig.apiKey;
-      this.session =
-        init?.session ??
-        getOpenAISession({
-          azure: true,
-          maxRetries: this.maxRetries,
-          timeout: this.timeout,
-          ...this.additionalSessionOptions,
-          ...azureConfig,
-        });
+      this.apiKey =
+        init?.session?.apiKey ?? azureConfig.apiKey ?? getEnv("OPENAI_API_KEY");
+      this.lazySession = async () =>
+        import("openai").then(
+          async ({ AzureOpenAI }) =>
+            init?.session ??
+            new AzureOpenAI({
+              maxRetries: this.maxRetries,
+              timeout: this.timeout!,
+              ...this.additionalSessionOptions,
+              ...azureConfig,
+            }),
+        );
     } else {
-      this.apiKey = init?.apiKey ?? undefined;
-      this.session =
-        init?.session ??
-        getOpenAISession({
-          apiKey: this.apiKey,
-          maxRetries: this.maxRetries,
-          timeout: this.timeout,
-          ...this.additionalSessionOptions,
-        });
+      this.apiKey = init?.session?.apiKey ?? getEnv("OPENAI_API_KEY");
+      this.lazySession = async () =>
+        import("openai").then(({ OpenAI }) => {
+          return (
+            init?.session ??
+            new OpenAI({
+              apiKey: this.apiKey,
+              maxRetries: this.maxRetries,
+              timeout: this.timeout!,
+              ...this.additionalSessionOptions,
+            })
+          );
+        });
     }
   }
@@ -118,7 +139,9 @@ export class OpenAIEmbedding extends BaseEmbedding {
     // TODO: ensure this for every sub class by calling it in the base class
     input = this.truncateMaxTokens(input);
-    const { data } = await this.session.openai.embeddings.create(
+    const { data } = await (
+      await this.session
+    ).embeddings.create(
       this.dimensions
         ? {
             model: this.model,
...
@@ -10,9 +10,6 @@ export {
   GPT4_MODELS,
   O1_MODELS,
   OpenAI,
-  OpenAISession,
   type OpenAIAdditionalChatOptions,
   type OpenAIAdditionalMetadata,
 } from "./llm";
-export { type AzureOpenAIConfig } from "./azure";
 import { getEnv } from "@llamaindex/env";
-import type OpenAILLM from "openai";
 import type {
-  ClientOptions,
+  AzureClientOptions,
+  AzureOpenAI as AzureOpenAILLM,
   ClientOptions as OpenAIClientOptions,
+  OpenAI as OpenAILLM,
 } from "openai";
-import { AzureOpenAI, OpenAI as OrigOpenAI } from "openai";
 import type { ChatModel } from "openai/resources/chat/chat";
-import { isDeepEqual } from "remeda";
 import { wrapEventCaller, wrapLLMEvent } from "@llamaindex/core/decorator";
 import {
@@ -35,64 +34,12 @@ import type {
   ChatCompletionUserMessageParam,
 } from "openai/resources/chat/completions";
 import type { ChatCompletionMessageParam } from "openai/resources/index.js";
-import type { AzureOpenAIConfig } from "./azure.js";
 import {
   getAzureConfigFromEnv,
   getAzureModel,
   shouldUseAzure,
 } from "./azure.js";
-export class OpenAISession {
-  openai: Pick<OrigOpenAI, "chat" | "embeddings">;
-  constructor(options: ClientOptions & { azure?: boolean } = {}) {
-    if (options.azure) {
-      this.openai = new AzureOpenAI(options as AzureOpenAIConfig);
-    } else {
-      if (!options.apiKey) {
-        options.apiKey = getEnv("OPENAI_API_KEY");
-      }
-      if (!options.apiKey) {
-        throw new Error("Set OpenAI Key in OPENAI_API_KEY env variable"); // Overriding OpenAI package's error message
-      }
-      this.openai = new OrigOpenAI({
-        ...options,
-      });
-    }
-  }
-}
-// I'm not 100% sure this is necessary vs. just starting a new session
-// every time we make a call. They say they try to reuse connections
-// so in theory this is more efficient, but we should test it in the future.
-const defaultOpenAISession: {
-  session: OpenAISession;
-  options: ClientOptions;
-}[] = [];
-/**
- * Get a session for the OpenAI API. If one already exists with the same options,
- * it will be returned. Otherwise, a new session will be created.
- * @param options
- * @returns
- */
-export function getOpenAISession(
-  options: ClientOptions & { azure?: boolean } = {},
-) {
-  let session = defaultOpenAISession.find((session) => {
-    return isDeepEqual(session.options, options);
-  })?.session;
-  if (!session) {
-    session = new OpenAISession(options);
-    defaultOpenAISession.push({ session, options });
-  }
-  return session;
-}
 export const GPT4_MODELS = {
   "chatgpt-4o-latest": {
     contextWindow: 128000,
@@ -182,6 +129,8 @@ export type OpenAIAdditionalChatOptions = Omit<
   | "toolChoice"
 >;
+type LLMInstance = Pick<AzureOpenAILLM | OpenAILLM, "chat" | "apiKey">;
 export class OpenAI extends ToolCallLLM<OpenAIAdditionalChatOptions> {
   model:
     | ChatModel
@@ -196,14 +145,24 @@ export class OpenAI extends ToolCallLLM<OpenAIAdditionalChatOptions> {
   apiKey?: string | undefined = undefined;
   maxRetries: number;
   timeout?: number;
-  session: OpenAISession;
   additionalSessionOptions?:
     | undefined
     | Omit<Partial<OpenAIClientOptions>, "apiKey" | "maxRetries" | "timeout">;
+  // use lazy here to avoid check OPENAI_API_KEY immediately
+  lazySession: () => Promise<LLMInstance>;
+  #session: Promise<LLMInstance> | null = null;
+  get session() {
+    if (!this.#session) {
+      this.#session = this.lazySession();
+    }
+    return this.#session;
+  }
   constructor(
-    init?: Partial<OpenAI> & {
-      azure?: AzureOpenAIConfig;
+    init?: Omit<Partial<OpenAI>, "session"> & {
+      session?: LLMInstance | undefined;
+      azure?: AzureClientOptions;
     },
   ) {
     super();
@@ -216,6 +175,8 @@ export class OpenAI extends ToolCallLLM<OpenAIAdditionalChatOptions> {
     this.timeout = init?.timeout ?? 60 * 1000; // Default is 60 seconds
     this.additionalChatOptions = init?.additionalChatOptions;
     this.additionalSessionOptions = init?.additionalSessionOptions;
+    this.apiKey =
+      init?.session?.apiKey ?? init?.apiKey ?? getEnv("OPENAI_API_KEY");
     if (init?.azure || shouldUseAzure()) {
       const azureConfig = {
@@ -225,25 +186,26 @@ export class OpenAI extends ToolCallLLM<OpenAIAdditionalChatOptions> {
         ...init?.azure,
       };
-      this.apiKey = azureConfig.apiKey;
-      this.session =
+      this.lazySession = async () =>
         init?.session ??
-        getOpenAISession({
-          azure: true,
-          maxRetries: this.maxRetries,
-          timeout: this.timeout,
-          ...this.additionalSessionOptions,
-          ...azureConfig,
-        });
+        import("openai").then(({ AzureOpenAI }) => {
+          return new AzureOpenAI({
+            maxRetries: this.maxRetries,
+            timeout: this.timeout!,
+            ...this.additionalSessionOptions,
+            ...azureConfig,
+          });
+        });
     } else {
-      this.apiKey = init?.apiKey ?? undefined;
-      this.session =
+      this.lazySession = async () =>
         init?.session ??
-        getOpenAISession({
-          apiKey: this.apiKey,
-          maxRetries: this.maxRetries,
-          timeout: this.timeout,
-          ...this.additionalSessionOptions,
-        });
+        import("openai").then(({ OpenAI }) => {
+          return new OpenAI({
+            apiKey: this.apiKey,
+            maxRetries: this.maxRetries,
+            timeout: this.timeout!,
+            ...this.additionalSessionOptions,
+          });
+        });
     }
   }
@@ -382,7 +344,9 @@ export class OpenAI extends ToolCallLLM<OpenAIAdditionalChatOptions> {
     }
     // Non-streaming
-    const response = await this.session.openai.chat.completions.create({
+    const response = await (
+      await this.session
+    ).chat.completions.create({
       ...baseRequestParams,
       stream: false,
     });
@@ -414,11 +378,12 @@ export class OpenAI extends ToolCallLLM<OpenAIAdditionalChatOptions> {
   protected async *streamChat(
     baseRequestParams: OpenAILLM.Chat.ChatCompletionCreateParams,
   ): AsyncIterable<ChatResponseChunk<ToolCallLLMMessageOptions>> {
-    const stream: AsyncIterable<OpenAILLM.Chat.ChatCompletionChunk> =
-      await this.session.openai.chat.completions.create({
-        ...baseRequestParams,
-        stream: true,
-      });
+    const stream: AsyncIterable<OpenAILLM.Chat.ChatCompletionChunk> = await (
+      await this.session
+    ).chat.completions.create({
+      ...baseRequestParams,
+      stream: true,
+    });
     // TODO: add callback to streamConverter and use streamConverter here
     // this will be used to keep track of the current tool call, make sure input are valid json object.
...