From 4fcbdf710ee9b3dc107ea32c2107d7e736bb2da7 Mon Sep 17 00:00:00 2001 From: Thuc Pham <51660321+thucpn@users.noreply.github.com> Date: Fri, 5 Apr 2024 08:33:23 +0700 Subject: [PATCH] Add tool calls for openai streaming (#682) Co-authored-by: Marcus Schiesser <mail@marcusschiesser.de> Co-authored-by: Alex Yang <himself65@outlook.com> --- .changeset/orange-lions-remember.md | 5 + examples/toolsStream.ts | 46 +++ packages/core/src/ChatHistory.ts | 2 +- packages/core/src/QuestionGenerator.ts | 2 +- packages/core/src/ServiceContext.ts | 2 +- packages/core/src/Settings.ts | 2 +- packages/core/src/llm/LLM.ts | 294 +--------------- packages/core/src/llm/fireworks.ts | 2 +- packages/core/src/llm/groq.ts | 2 +- packages/core/src/llm/open_ai.ts | 318 +++++++++++++++++- packages/core/src/llm/together.ts | 2 +- packages/core/src/llm/types.ts | 11 + packages/core/tests/CallbackManager.test.ts | 9 +- packages/core/tests/Embedding.test.ts | 9 +- .../core/tests/MetadataExtractors.test.ts | 9 +- packages/core/tests/Selectors.test.ts | 8 +- packages/core/tests/agent/OpenAIAgent.test.ts | 8 +- .../tests/agent/runner/AgentRunner.test.ts | 10 +- .../core/tests/indices/SummaryIndex.test.ts | 8 +- .../tests/indices/VectorStoreIndex.test.ts | 8 +- .../core/tests/objects/ObjectIndex.test.ts | 8 +- packages/core/tests/utility/mockOpenAI.ts | 2 +- packages/core/tests/vitest.config.ts | 8 + packages/core/tests/vitest.setup.ts | 22 ++ 24 files changed, 425 insertions(+), 372 deletions(-) create mode 100644 .changeset/orange-lions-remember.md create mode 100644 examples/toolsStream.ts create mode 100644 packages/core/tests/vitest.config.ts create mode 100644 packages/core/tests/vitest.setup.ts diff --git a/.changeset/orange-lions-remember.md b/.changeset/orange-lions-remember.md new file mode 100644 index 000000000..a298a96fb --- /dev/null +++ b/.changeset/orange-lions-remember.md @@ -0,0 +1,5 @@ +--- +"llamaindex": patch +--- + +Support streaming for OpenAI tool calls diff --git a/examples/toolsStream.ts b/examples/toolsStream.ts new file mode 100644 index 000000000..b59114dc9 --- /dev/null +++ b/examples/toolsStream.ts @@ -0,0 +1,46 @@ +import { ChatResponseChunk, LLMChatParamsBase, OpenAI } from "llamaindex"; + +async function main() { + const llm = new OpenAI({ model: "gpt-4-turbo-preview" }); + + const args: LLMChatParamsBase = { + messages: [ + { + content: "Who was Goethe?", + role: "user", + }, + ], + tools: [ + { + type: "function", + function: { + name: "wikipedia_tool", + description: "A tool that uses a query engine to search Wikipedia.", + parameters: { + type: "object", + properties: { + query: { + type: "string", + description: "The query to search for", + }, + }, + required: ["query"], + }, + }, + }, + ], + toolChoice: "auto", + }; + + const stream = await llm.chat({ ...args, stream: true }); + let chunk: ChatResponseChunk | null = null; + for await (chunk of stream) { + process.stdout.write(chunk.delta); + } + console.log(chunk?.additionalKwargs?.toolCalls[0]); +} + +(async function () { + await main(); + console.log("Done"); +})(); diff --git a/packages/core/src/ChatHistory.ts b/packages/core/src/ChatHistory.ts index d3f261515..dae76cd5d 100644 --- a/packages/core/src/ChatHistory.ts +++ b/packages/core/src/ChatHistory.ts @@ -1,7 +1,7 @@ import { globalsHelper } from "./GlobalsHelper.js"; import type { SummaryPrompt } from "./Prompt.js"; import { defaultSummaryPrompt, messagesToHistoryStr } from "./Prompt.js"; -import { OpenAI } from "./llm/LLM.js"; +import { OpenAI } from "./llm/open_ai.js"; import 
type { ChatMessage, LLM, MessageType } from "./llm/types.js"; /** diff --git a/packages/core/src/QuestionGenerator.ts b/packages/core/src/QuestionGenerator.ts index 4b36174e3..0e7785a9e 100644 --- a/packages/core/src/QuestionGenerator.ts +++ b/packages/core/src/QuestionGenerator.ts @@ -5,7 +5,7 @@ import type { BaseQuestionGenerator, SubQuestion, } from "./engines/query/types.js"; -import { OpenAI } from "./llm/LLM.js"; +import { OpenAI } from "./llm/open_ai.js"; import type { LLM } from "./llm/types.js"; import { PromptMixin } from "./prompts/index.js"; import type { diff --git a/packages/core/src/ServiceContext.ts b/packages/core/src/ServiceContext.ts index 3b4ffbbad..48a318f02 100644 --- a/packages/core/src/ServiceContext.ts +++ b/packages/core/src/ServiceContext.ts @@ -1,7 +1,7 @@ import { PromptHelper } from "./PromptHelper.js"; import { OpenAIEmbedding } from "./embeddings/OpenAIEmbedding.js"; import type { BaseEmbedding } from "./embeddings/types.js"; -import { OpenAI } from "./llm/LLM.js"; +import { OpenAI } from "./llm/open_ai.js"; import type { LLM } from "./llm/types.js"; import { SimpleNodeParser } from "./nodeParsers/SimpleNodeParser.js"; import type { NodeParser } from "./nodeParsers/types.js"; diff --git a/packages/core/src/Settings.ts b/packages/core/src/Settings.ts index 7d29a2aa3..e2b2f4c0e 100644 --- a/packages/core/src/Settings.ts +++ b/packages/core/src/Settings.ts @@ -1,6 +1,6 @@ import { CallbackManager } from "./callbacks/CallbackManager.js"; import { OpenAIEmbedding } from "./embeddings/OpenAIEmbedding.js"; -import { OpenAI } from "./llm/LLM.js"; +import { OpenAI } from "./llm/open_ai.js"; import { PromptHelper } from "./PromptHelper.js"; import { SimpleNodeParser } from "./nodeParsers/SimpleNodeParser.js"; diff --git a/packages/core/src/llm/LLM.ts b/packages/core/src/llm/LLM.ts index eb17a75d7..8b68f7466 100644 --- a/packages/core/src/llm/LLM.ts +++ b/packages/core/src/llm/LLM.ts @@ -1,27 +1,10 @@ -import type OpenAILLM from "openai"; -import type { ClientOptions as OpenAIClientOptions } from "openai"; -import { - type OpenAIStreamToken, - type StreamCallbackResponse, -} from "../callbacks/CallbackManager.js"; - -import type { ChatCompletionMessageParam } from "openai/resources/index.js"; +import { type StreamCallbackResponse } from "../callbacks/CallbackManager.js"; + import type { LLMOptions } from "portkey-ai"; -import { Tokenizers } from "../GlobalsHelper.js"; -import { wrapEventCaller } from "../internal/context/EventCaller.js"; import { getCallbackManager } from "../internal/settings/CallbackManager.js"; import type { AnthropicSession } from "./anthropic.js"; import { getAnthropicSession } from "./anthropic.js"; -import type { AzureOpenAIConfig } from "./azure.js"; -import { - getAzureBaseUrl, - getAzureConfigFromEnv, - getAzureModel, - shouldUseAzure, -} from "./azure.js"; import { BaseLLM } from "./base.js"; -import type { OpenAISession } from "./open_ai.js"; -import { getOpenAISession } from "./open_ai.js"; import type { PortkeySession } from "./portkey.js"; import { getPortkeySession } from "./portkey.js"; import { ReplicateSession } from "./replicate_ai.js"; @@ -36,279 +19,6 @@ import type { } from "./types.js"; import { wrapLLMEvent } from "./utils.js"; -export const GPT4_MODELS = { - "gpt-4": { contextWindow: 8192 }, - "gpt-4-32k": { contextWindow: 32768 }, - "gpt-4-32k-0613": { contextWindow: 32768 }, - "gpt-4-turbo-preview": { contextWindow: 128000 }, - "gpt-4-1106-preview": { contextWindow: 128000 }, - "gpt-4-0125-preview": { contextWindow: 128000 
}, - "gpt-4-vision-preview": { contextWindow: 128000 }, -}; - -// NOTE we don't currently support gpt-3.5-turbo-instruct and don't plan to in the near future -export const GPT35_MODELS = { - "gpt-3.5-turbo": { contextWindow: 4096 }, - "gpt-3.5-turbo-0613": { contextWindow: 4096 }, - "gpt-3.5-turbo-16k": { contextWindow: 16384 }, - "gpt-3.5-turbo-16k-0613": { contextWindow: 16384 }, - "gpt-3.5-turbo-1106": { contextWindow: 16384 }, - "gpt-3.5-turbo-0125": { contextWindow: 16384 }, -}; - -/** - * We currently support GPT-3.5 and GPT-4 models - */ -export const ALL_AVAILABLE_OPENAI_MODELS = { - ...GPT4_MODELS, - ...GPT35_MODELS, -}; - -export const isFunctionCallingModel = (model: string): boolean => { - const isChatModel = Object.keys(ALL_AVAILABLE_OPENAI_MODELS).includes(model); - const isOld = model.includes("0314") || model.includes("0301"); - return isChatModel && !isOld; -}; - -/** - * OpenAI LLM implementation - */ -export class OpenAI extends BaseLLM { - // Per completion OpenAI params - model: keyof typeof ALL_AVAILABLE_OPENAI_MODELS | string; - temperature: number; - topP: number; - maxTokens?: number; - additionalChatOptions?: Omit< - Partial<OpenAILLM.Chat.ChatCompletionCreateParams>, - | "max_tokens" - | "messages" - | "model" - | "temperature" - | "top_p" - | "stream" - | "tools" - | "toolChoice" - >; - - // OpenAI session params - apiKey?: string = undefined; - maxRetries: number; - timeout?: number; - session: OpenAISession; - additionalSessionOptions?: Omit< - Partial<OpenAIClientOptions>, - "apiKey" | "maxRetries" | "timeout" - >; - - constructor( - init?: Partial<OpenAI> & { - azure?: AzureOpenAIConfig; - }, - ) { - super(); - this.model = init?.model ?? "gpt-3.5-turbo"; - this.temperature = init?.temperature ?? 0.1; - this.topP = init?.topP ?? 1; - this.maxTokens = init?.maxTokens ?? undefined; - - this.maxRetries = init?.maxRetries ?? 10; - this.timeout = init?.timeout ?? 60 * 1000; // Default is 60 seconds - this.additionalChatOptions = init?.additionalChatOptions; - this.additionalSessionOptions = init?.additionalSessionOptions; - - if (init?.azure || shouldUseAzure()) { - const azureConfig = getAzureConfigFromEnv({ - ...init?.azure, - model: getAzureModel(this.model), - }); - - if (!azureConfig.apiKey) { - throw new Error( - "Azure API key is required for OpenAI Azure models. Please set the AZURE_OPENAI_KEY environment variable.", - ); - } - - this.apiKey = azureConfig.apiKey; - this.session = - init?.session ?? - getOpenAISession({ - azure: true, - apiKey: this.apiKey, - baseURL: getAzureBaseUrl(azureConfig), - maxRetries: this.maxRetries, - timeout: this.timeout, - defaultQuery: { "api-version": azureConfig.apiVersion }, - ...this.additionalSessionOptions, - }); - } else { - this.apiKey = init?.apiKey ?? undefined; - this.session = - init?.session ?? - getOpenAISession({ - apiKey: this.apiKey, - maxRetries: this.maxRetries, - timeout: this.timeout, - ...this.additionalSessionOptions, - }); - } - } - - get metadata() { - const contextWindow = - ALL_AVAILABLE_OPENAI_MODELS[ - this.model as keyof typeof ALL_AVAILABLE_OPENAI_MODELS - ]?.contextWindow ?? 
1024; - return { - model: this.model, - temperature: this.temperature, - topP: this.topP, - maxTokens: this.maxTokens, - contextWindow, - tokenizer: Tokenizers.CL100K_BASE, - isFunctionCallingModel: isFunctionCallingModel(this.model), - }; - } - - mapMessageType( - messageType: MessageType, - ): "user" | "assistant" | "system" | "function" | "tool" { - switch (messageType) { - case "user": - return "user"; - case "assistant": - return "assistant"; - case "system": - return "system"; - case "function": - return "function"; - case "tool": - return "tool"; - default: - return "user"; - } - } - - toOpenAIMessage(messages: ChatMessage[]) { - return messages.map((message) => { - const additionalKwargs = message.additionalKwargs ?? {}; - - if (message.additionalKwargs?.toolCalls) { - additionalKwargs.tool_calls = message.additionalKwargs.toolCalls; - delete additionalKwargs.toolCalls; - } - - return { - role: this.mapMessageType(message.role), - content: message.content, - ...additionalKwargs, - }; - }); - } - - chat( - params: LLMChatParamsStreaming, - ): Promise<AsyncIterable<ChatResponseChunk>>; - chat(params: LLMChatParamsNonStreaming): Promise<ChatResponse>; - @wrapEventCaller - @wrapLLMEvent - async chat( - params: LLMChatParamsNonStreaming | LLMChatParamsStreaming, - ): Promise<ChatResponse | AsyncIterable<ChatResponseChunk>> { - const { messages, stream, tools, toolChoice } = params; - const baseRequestParams: OpenAILLM.Chat.ChatCompletionCreateParams = { - model: this.model, - temperature: this.temperature, - max_tokens: this.maxTokens, - tools: tools, - tool_choice: toolChoice, - messages: this.toOpenAIMessage(messages) as ChatCompletionMessageParam[], - top_p: this.topP, - ...this.additionalChatOptions, - }; - - // Streaming - if (stream) { - return this.streamChat(params); - } - - // Non-streaming - const response = await this.session.openai.chat.completions.create({ - ...baseRequestParams, - stream: false, - }); - - const content = response.choices[0].message?.content ?? null; - - const kwargsOutput: Record<string, any> = {}; - - if (response.choices[0].message?.tool_calls) { - kwargsOutput.toolCalls = response.choices[0].message.tool_calls; - } - - return { - message: { - content, - role: response.choices[0].message.role, - additionalKwargs: kwargsOutput, - }, - }; - } - - @wrapEventCaller - protected async *streamChat({ - messages, - }: LLMChatParamsStreaming): AsyncIterable<ChatResponseChunk> { - const baseRequestParams: OpenAILLM.Chat.ChatCompletionCreateParams = { - model: this.model, - temperature: this.temperature, - max_tokens: this.maxTokens, - messages: messages.map( - (message) => - ({ - role: this.mapMessageType(message.role), - content: message.content, - }) as ChatCompletionMessageParam, - ), - top_p: this.topP, - ...this.additionalChatOptions, - }; - - const chunk_stream: AsyncIterable<OpenAIStreamToken> = - await this.session.openai.chat.completions.create({ - ...baseRequestParams, - stream: true, - }); - - // TODO: add callback to streamConverter and use streamConverter here - //Indices - let idx_counter: number = 0; - for await (const part of chunk_stream) { - if (!part.choices.length) continue; - - //Increment - part.choices[0].index = idx_counter; - const is_done: boolean = - part.choices[0].finish_reason === "stop" ? 
true : false; - //onLLMStream Callback - - const stream_callback: StreamCallbackResponse = { - index: idx_counter, - isDone: is_done, - token: part, - }; - getCallbackManager().dispatchEvent("stream", stream_callback); - - idx_counter++; - - yield { - delta: part.choices[0].delta.content ?? "", - }; - } - return; - } -} - export const ALL_AVAILABLE_LLAMADEUCE_MODELS = { "Llama-2-70b-chat-old": { contextWindow: 4096, diff --git a/packages/core/src/llm/fireworks.ts b/packages/core/src/llm/fireworks.ts index 8621dd01f..f7814559b 100644 --- a/packages/core/src/llm/fireworks.ts +++ b/packages/core/src/llm/fireworks.ts @@ -1,5 +1,5 @@ import { getEnv } from "@llamaindex/env"; -import { OpenAI } from "./LLM.js"; +import { OpenAI } from "./open_ai.js"; export class FireworksLLM extends OpenAI { constructor(init?: Partial<OpenAI>) { diff --git a/packages/core/src/llm/groq.ts b/packages/core/src/llm/groq.ts index b29431749..083e305cb 100644 --- a/packages/core/src/llm/groq.ts +++ b/packages/core/src/llm/groq.ts @@ -1,5 +1,5 @@ import { getEnv } from "@llamaindex/env"; -import { OpenAI } from "./LLM.js"; +import { OpenAI } from "./open_ai.js"; export class Groq extends OpenAI { constructor(init?: Partial<OpenAI>) { diff --git a/packages/core/src/llm/open_ai.ts b/packages/core/src/llm/open_ai.ts index 336844aaa..b9987a73f 100644 --- a/packages/core/src/llm/open_ai.ts +++ b/packages/core/src/llm/open_ai.ts @@ -1,16 +1,43 @@ import { getEnv } from "@llamaindex/env"; import _ from "lodash"; -import type { ClientOptions } from "openai"; -import OpenAI from "openai"; +import type OpenAILLM from "openai"; +import type { + ClientOptions, + ClientOptions as OpenAIClientOptions, +} from "openai"; +import { OpenAI as OrigOpenAI } from "openai"; -export class AzureOpenAI extends OpenAI { +import type { ChatCompletionMessageParam } from "openai/resources/index.js"; +import { Tokenizers } from "../GlobalsHelper.js"; +import { wrapEventCaller } from "../internal/context/EventCaller.js"; +import { getCallbackManager } from "../internal/settings/CallbackManager.js"; +import type { AzureOpenAIConfig } from "./azure.js"; +import { + getAzureBaseUrl, + getAzureConfigFromEnv, + getAzureModel, + shouldUseAzure, +} from "./azure.js"; +import { BaseLLM } from "./base.js"; +import type { + ChatMessage, + ChatResponse, + ChatResponseChunk, + LLMChatParamsNonStreaming, + LLMChatParamsStreaming, + MessageToolCall, + MessageType, +} from "./types.js"; +import { wrapLLMEvent } from "./utils.js"; + +export class AzureOpenAI extends OrigOpenAI { protected override authHeaders() { return { "api-key": this.apiKey }; } } export class OpenAISession { - openai: OpenAI; + openai: OrigOpenAI; constructor(options: ClientOptions & { azure?: boolean } = {}) { if (!options.apiKey) { @@ -24,7 +51,7 @@ export class OpenAISession { if (options.azure) { this.openai = new AzureOpenAI(options); } else { - this.openai = new OpenAI({ + this.openai = new OrigOpenAI({ ...options, // defaultHeaders: { "OpenAI-Beta": "assistants=v1" }, }); @@ -60,3 +87,284 @@ export function getOpenAISession( return session; } + +export const GPT4_MODELS = { + "gpt-4": { contextWindow: 8192 }, + "gpt-4-32k": { contextWindow: 32768 }, + "gpt-4-32k-0613": { contextWindow: 32768 }, + "gpt-4-turbo-preview": { contextWindow: 128000 }, + "gpt-4-1106-preview": { contextWindow: 128000 }, + "gpt-4-0125-preview": { contextWindow: 128000 }, + "gpt-4-vision-preview": { contextWindow: 128000 }, +}; + +// NOTE we don't currently support gpt-3.5-turbo-instruct and don't plan to in the 
near future +export const GPT35_MODELS = { + "gpt-3.5-turbo": { contextWindow: 4096 }, + "gpt-3.5-turbo-0613": { contextWindow: 4096 }, + "gpt-3.5-turbo-16k": { contextWindow: 16384 }, + "gpt-3.5-turbo-16k-0613": { contextWindow: 16384 }, + "gpt-3.5-turbo-1106": { contextWindow: 16384 }, + "gpt-3.5-turbo-0125": { contextWindow: 16384 }, +}; + +/** + * We currently support GPT-3.5 and GPT-4 models + */ +export const ALL_AVAILABLE_OPENAI_MODELS = { + ...GPT4_MODELS, + ...GPT35_MODELS, +}; + +export const isFunctionCallingModel = (model: string): boolean => { + const isChatModel = Object.keys(ALL_AVAILABLE_OPENAI_MODELS).includes(model); + const isOld = model.includes("0314") || model.includes("0301"); + return isChatModel && !isOld; +}; + +/** + * OpenAI LLM implementation + */ +export class OpenAI extends BaseLLM { + // Per completion OpenAI params + model: keyof typeof ALL_AVAILABLE_OPENAI_MODELS | string; + temperature: number; + topP: number; + maxTokens?: number; + additionalChatOptions?: Omit< + Partial<OpenAILLM.Chat.ChatCompletionCreateParams>, + | "max_tokens" + | "messages" + | "model" + | "temperature" + | "top_p" + | "stream" + | "tools" + | "toolChoice" + >; + + // OpenAI session params + apiKey?: string = undefined; + maxRetries: number; + timeout?: number; + session: OpenAISession; + additionalSessionOptions?: Omit< + Partial<OpenAIClientOptions>, + "apiKey" | "maxRetries" | "timeout" + >; + + constructor( + init?: Partial<OpenAI> & { + azure?: AzureOpenAIConfig; + }, + ) { + super(); + this.model = init?.model ?? "gpt-3.5-turbo"; + this.temperature = init?.temperature ?? 0.1; + this.topP = init?.topP ?? 1; + this.maxTokens = init?.maxTokens ?? undefined; + + this.maxRetries = init?.maxRetries ?? 10; + this.timeout = init?.timeout ?? 60 * 1000; // Default is 60 seconds + this.additionalChatOptions = init?.additionalChatOptions; + this.additionalSessionOptions = init?.additionalSessionOptions; + + if (init?.azure || shouldUseAzure()) { + const azureConfig = getAzureConfigFromEnv({ + ...init?.azure, + model: getAzureModel(this.model), + }); + + if (!azureConfig.apiKey) { + throw new Error( + "Azure API key is required for OpenAI Azure models. Please set the AZURE_OPENAI_KEY environment variable.", + ); + } + + this.apiKey = azureConfig.apiKey; + this.session = + init?.session ?? + getOpenAISession({ + azure: true, + apiKey: this.apiKey, + baseURL: getAzureBaseUrl(azureConfig), + maxRetries: this.maxRetries, + timeout: this.timeout, + defaultQuery: { "api-version": azureConfig.apiVersion }, + ...this.additionalSessionOptions, + }); + } else { + this.apiKey = init?.apiKey ?? undefined; + this.session = + init?.session ?? + getOpenAISession({ + apiKey: this.apiKey, + maxRetries: this.maxRetries, + timeout: this.timeout, + ...this.additionalSessionOptions, + }); + } + } + + get metadata() { + const contextWindow = + ALL_AVAILABLE_OPENAI_MODELS[ + this.model as keyof typeof ALL_AVAILABLE_OPENAI_MODELS + ]?.contextWindow ?? 
1024; + return { + model: this.model, + temperature: this.temperature, + topP: this.topP, + maxTokens: this.maxTokens, + contextWindow, + tokenizer: Tokenizers.CL100K_BASE, + isFunctionCallingModel: isFunctionCallingModel(this.model), + }; + } + + mapMessageType( + messageType: MessageType, + ): "user" | "assistant" | "system" | "function" | "tool" { + switch (messageType) { + case "user": + return "user"; + case "assistant": + return "assistant"; + case "system": + return "system"; + case "function": + return "function"; + case "tool": + return "tool"; + default: + return "user"; + } + } + + toOpenAIMessage(messages: ChatMessage[]) { + return messages.map((message) => { + const additionalKwargs = message.additionalKwargs ?? {}; + + if (message.additionalKwargs?.toolCalls) { + additionalKwargs.tool_calls = message.additionalKwargs.toolCalls; + delete additionalKwargs.toolCalls; + } + + return { + role: this.mapMessageType(message.role), + content: message.content, + ...additionalKwargs, + }; + }); + } + + chat( + params: LLMChatParamsStreaming, + ): Promise<AsyncIterable<ChatResponseChunk>>; + chat(params: LLMChatParamsNonStreaming): Promise<ChatResponse>; + @wrapEventCaller + @wrapLLMEvent + async chat( + params: LLMChatParamsNonStreaming | LLMChatParamsStreaming, + ): Promise<ChatResponse | AsyncIterable<ChatResponseChunk>> { + const { messages, stream, tools, toolChoice } = params; + const baseRequestParams: OpenAILLM.Chat.ChatCompletionCreateParams = { + model: this.model, + temperature: this.temperature, + max_tokens: this.maxTokens, + tools: tools, + tool_choice: toolChoice, + messages: this.toOpenAIMessage(messages) as ChatCompletionMessageParam[], + top_p: this.topP, + ...this.additionalChatOptions, + }; + + // Streaming + if (stream) { + return this.streamChat(baseRequestParams); + } + + // Non-streaming + const response = await this.session.openai.chat.completions.create({ + ...baseRequestParams, + stream: false, + }); + + const content = response.choices[0].message?.content ?? null; + + const kwargsOutput: Record<string, any> = {}; + + if (response.choices[0].message?.tool_calls) { + kwargsOutput.toolCalls = response.choices[0].message.tool_calls; + } + + return { + message: { + content, + role: response.choices[0].message.role, + additionalKwargs: kwargsOutput, + }, + }; + } + + @wrapEventCaller + protected async *streamChat( + baseRequestParams: OpenAILLM.Chat.ChatCompletionCreateParams, + ): AsyncIterable<ChatResponseChunk> { + const stream: AsyncIterable<OpenAILLM.Chat.ChatCompletionChunk> = + await this.session.openai.chat.completions.create({ + ...baseRequestParams, + stream: true, + }); + + // TODO: add callback to streamConverter and use streamConverter here + //Indices + let idxCounter: number = 0; + const toolCalls: MessageToolCall[] = []; + for await (const part of stream) { + if (!part.choices.length) continue; + const choice = part.choices[0]; + updateToolCalls(toolCalls, choice.delta.tool_calls); + + const isDone: boolean = choice.finish_reason !== null; + + getCallbackManager().dispatchEvent("stream", { + index: idxCounter++, + isDone: isDone, + token: part, + }); + + yield { + // add tool calls to final chunk + additionalKwargs: isDone ? { toolCalls: toolCalls } : undefined, + delta: choice.delta.content ?? 
"", + }; + } + return; + } +} + +function updateToolCalls( + toolCalls: MessageToolCall[], + toolCallDeltas?: OpenAILLM.Chat.Completions.ChatCompletionChunk.Choice.Delta.ToolCall[], +) { + function augmentToolCall( + toolCall?: MessageToolCall, + toolCallDelta?: OpenAILLM.Chat.Completions.ChatCompletionChunk.Choice.Delta.ToolCall, + ) { + toolCall = + toolCall ?? + ({ function: { name: "", arguments: "" } } as MessageToolCall); + if (toolCallDelta?.function?.arguments) { + toolCall.function.arguments += toolCallDelta.function.arguments; + } + if (toolCallDelta?.function?.name) { + toolCall.function.name += toolCallDelta.function.name; + } + } + if (toolCallDeltas) { + toolCallDeltas?.forEach((toolCall, i) => { + augmentToolCall(toolCalls[i], toolCall); + }); + } +} diff --git a/packages/core/src/llm/together.ts b/packages/core/src/llm/together.ts index 65651cdf7..0ab3fc443 100644 --- a/packages/core/src/llm/together.ts +++ b/packages/core/src/llm/together.ts @@ -1,5 +1,5 @@ import { getEnv } from "@llamaindex/env"; -import { OpenAI } from "./LLM.js"; +import { OpenAI } from "./open_ai.js"; export class TogetherLLM extends OpenAI { constructor(init?: Partial<OpenAI>) { diff --git a/packages/core/src/llm/types.ts b/packages/core/src/llm/types.ts index 8aa9548c4..1131473d9 100644 --- a/packages/core/src/llm/types.ts +++ b/packages/core/src/llm/types.ts @@ -84,6 +84,7 @@ export interface ChatResponse { export interface ChatResponseChunk { delta: string; + additionalKwargs?: Record<string, any>; } export interface CompletionResponse { @@ -139,3 +140,13 @@ export interface MessageContentDetail { * Extended type for the content of a message that allows for multi-modal messages. */ export type MessageContent = string | MessageContentDetail[]; + +interface Function { + arguments: string; + name: string; +} + +export interface MessageToolCall { + id: string; + function: Function; +} diff --git a/packages/core/tests/CallbackManager.test.ts b/packages/core/tests/CallbackManager.test.ts index 86f31b183..3e2f2749e 100644 --- a/packages/core/tests/CallbackManager.test.ts +++ b/packages/core/tests/CallbackManager.test.ts @@ -20,20 +20,13 @@ import { CallbackManager } from "llamaindex/callbacks/CallbackManager"; import { OpenAIEmbedding } from "llamaindex/embeddings/index"; import { SummaryIndex } from "llamaindex/indices/summary/index"; import { VectorStoreIndex } from "llamaindex/indices/vectorStore/index"; -import { OpenAI } from "llamaindex/llm/LLM"; +import { OpenAI } from "llamaindex/llm/open_ai"; import { ResponseSynthesizer, SimpleResponseBuilder, } from "llamaindex/synthesizers/index"; import { mockEmbeddingModel, mockLlmGeneration } from "./utility/mockOpenAI.js"; -// Mock the OpenAI getOpenAISession function during testing -vi.mock("llamaindex/llm/open_ai", () => { - return { - getOpenAISession: vi.fn().mockImplementation(() => null), - }; -}); - describe("CallbackManager: onLLMStream and onRetrieve", () => { let serviceContext: ServiceContext; let streamCallbackData: StreamCallbackResponse[] = []; diff --git a/packages/core/tests/Embedding.test.ts b/packages/core/tests/Embedding.test.ts index e0b2d2bf6..ab863ead1 100644 --- a/packages/core/tests/Embedding.test.ts +++ b/packages/core/tests/Embedding.test.ts @@ -3,16 +3,9 @@ import { SimilarityType, similarity, } from "llamaindex/embeddings/index"; -import { beforeAll, describe, expect, test, vi } from "vitest"; +import { beforeAll, describe, expect, test } from "vitest"; import { mockEmbeddingModel } from "./utility/mockOpenAI.js"; -// Mock the 
OpenAI getOpenAISession function during testing -vi.mock("llamaindex/llm/open_ai", () => { - return { - getOpenAISession: vi.fn().mockImplementation(() => null), - }; -}); - describe("similarity", () => { test("throws error on mismatched lengths", () => { const embedding1 = [1, 2, 3]; diff --git a/packages/core/tests/MetadataExtractors.test.ts b/packages/core/tests/MetadataExtractors.test.ts index 0ca64b372..f9337b3b0 100644 --- a/packages/core/tests/MetadataExtractors.test.ts +++ b/packages/core/tests/MetadataExtractors.test.ts @@ -8,7 +8,7 @@ import { SummaryExtractor, TitleExtractor, } from "llamaindex/extractors/index"; -import { OpenAI } from "llamaindex/llm/LLM"; +import { OpenAI } from "llamaindex/llm/open_ai"; import { SimpleNodeParser } from "llamaindex/nodeParsers/index"; import { afterAll, beforeAll, describe, expect, test, vi } from "vitest"; import { @@ -17,13 +17,6 @@ import { mockLlmGeneration, } from "./utility/mockOpenAI.js"; -// Mock the OpenAI getOpenAISession function during testing -vi.mock("llamaindex/llm/open_ai", () => { - return { - getOpenAISession: vi.fn().mockImplementation(() => null), - }; -}); - describe("[MetadataExtractor]: Extractors should populate the metadata", () => { let serviceContext: ServiceContext; diff --git a/packages/core/tests/Selectors.test.ts b/packages/core/tests/Selectors.test.ts index cbda332a5..8bf9ed18d 100644 --- a/packages/core/tests/Selectors.test.ts +++ b/packages/core/tests/Selectors.test.ts @@ -1,4 +1,4 @@ -import { describe, expect, test, vi } from "vitest"; +import { describe, expect, test } from "vitest"; // from unittest.mock import patch import { serviceContextFromDefaults } from "llamaindex/ServiceContext"; @@ -6,12 +6,6 @@ import { OpenAI } from "llamaindex/llm/index"; import { LLMSingleSelector } from "llamaindex/selectors/index"; import { mocStructuredkLlmGeneration } from "./utility/mockOpenAI.js"; -vi.mock("llamaindex/llm/open_ai", () => { - return { - getOpenAISession: vi.fn().mockImplementation(() => null), - }; -}); - describe("LLMSelector", () => { test("should be able to output a selection with a reason", async () => { const serviceContext = serviceContextFromDefaults({}); diff --git a/packages/core/tests/agent/OpenAIAgent.test.ts b/packages/core/tests/agent/OpenAIAgent.test.ts index b6006105e..8180464a3 100644 --- a/packages/core/tests/agent/OpenAIAgent.test.ts +++ b/packages/core/tests/agent/OpenAIAgent.test.ts @@ -1,7 +1,7 @@ import { OpenAIAgent } from "llamaindex/agent/index"; import { OpenAI } from "llamaindex/llm/index"; import { FunctionTool } from "llamaindex/tools/index"; -import { beforeEach, describe, expect, it, vi } from "vitest"; +import { beforeEach, describe, expect, it } from "vitest"; import { mockLlmToolCallGeneration } from "../utility/mockOpenAI.js"; // Define a function to sum two numbers @@ -24,12 +24,6 @@ const sumJSON = { required: ["a", "b"], }; -vi.mock("llamaindex/llm/open_ai", () => { - return { - getOpenAISession: vi.fn().mockImplementation(() => null), - }; -}); - describe("OpenAIAgent", () => { let openaiAgent: OpenAIAgent; diff --git a/packages/core/tests/agent/runner/AgentRunner.test.ts b/packages/core/tests/agent/runner/AgentRunner.test.ts index ab11c34c6..95e943083 100644 --- a/packages/core/tests/agent/runner/AgentRunner.test.ts +++ b/packages/core/tests/agent/runner/AgentRunner.test.ts @@ -1,19 +1,13 @@ import { OpenAIAgentWorker } from "llamaindex/agent/index"; import { AgentRunner } from "llamaindex/agent/runner/base"; -import { OpenAI } from "llamaindex/llm/LLM"; -import { 
beforeEach, describe, expect, it, vi } from "vitest"; +import { OpenAI } from "llamaindex/llm/open_ai"; +import { beforeEach, describe, expect, it } from "vitest"; import { DEFAULT_LLM_TEXT_OUTPUT, mockLlmGeneration, } from "../../utility/mockOpenAI.js"; -vi.mock("llamaindex/llm/open_ai", () => { - return { - getOpenAISession: vi.fn().mockImplementation(() => null), - }; -}); - describe("Agent Runner", () => { let agentRunner: AgentRunner; diff --git a/packages/core/tests/indices/SummaryIndex.test.ts b/packages/core/tests/indices/SummaryIndex.test.ts index f43df0ce7..d273e78e0 100644 --- a/packages/core/tests/indices/SummaryIndex.test.ts +++ b/packages/core/tests/indices/SummaryIndex.test.ts @@ -10,16 +10,10 @@ import { rmSync } from "node:fs"; import { mkdtemp } from "node:fs/promises"; import { tmpdir } from "node:os"; import { join } from "node:path"; -import { afterAll, beforeAll, describe, expect, it, vi } from "vitest"; +import { afterAll, beforeAll, describe, expect, it } from "vitest"; const testDir = await mkdtemp(join(tmpdir(), "test-")); -vi.mock("llamaindex/llm/open_ai", () => { - return { - getOpenAISession: vi.fn().mockImplementation(() => null), - }; -}); - import { mockServiceContext } from "../utility/mockServiceContext.js"; describe("SummaryIndex", () => { diff --git a/packages/core/tests/indices/VectorStoreIndex.test.ts b/packages/core/tests/indices/VectorStoreIndex.test.ts index 50365b59e..1537eba40 100644 --- a/packages/core/tests/indices/VectorStoreIndex.test.ts +++ b/packages/core/tests/indices/VectorStoreIndex.test.ts @@ -9,16 +9,10 @@ import { rmSync } from "node:fs"; import { mkdtemp } from "node:fs/promises"; import { tmpdir } from "node:os"; import { join } from "node:path"; -import { afterAll, beforeAll, describe, expect, test, vi } from "vitest"; +import { afterAll, beforeAll, describe, expect, test } from "vitest"; const testDir = await mkdtemp(join(tmpdir(), "test-")); -vi.mock("llamaindex/llm/open_ai", () => { - return { - getOpenAISession: vi.fn().mockImplementation(() => null), - }; -}); - import { mockServiceContext } from "../utility/mockServiceContext.js"; describe.sequential("VectorStoreIndex", () => { diff --git a/packages/core/tests/objects/ObjectIndex.test.ts b/packages/core/tests/objects/ObjectIndex.test.ts index cd74ac261..f71fd2ef2 100644 --- a/packages/core/tests/objects/ObjectIndex.test.ts +++ b/packages/core/tests/objects/ObjectIndex.test.ts @@ -5,13 +5,7 @@ import { SimpleToolNodeMapping, VectorStoreIndex, } from "llamaindex"; -import { beforeAll, describe, expect, test, vi } from "vitest"; - -vi.mock("llamaindex/llm/open_ai", () => { - return { - getOpenAISession: vi.fn().mockImplementation(() => null), - }; -}); +import { beforeAll, describe, expect, test } from "vitest"; import { mockServiceContext } from "../utility/mockServiceContext.js"; diff --git a/packages/core/tests/utility/mockOpenAI.ts b/packages/core/tests/utility/mockOpenAI.ts index 97fd1e418..f90de391d 100644 --- a/packages/core/tests/utility/mockOpenAI.ts +++ b/packages/core/tests/utility/mockOpenAI.ts @@ -1,6 +1,6 @@ import type { CallbackManager } from "llamaindex/callbacks/CallbackManager"; import type { OpenAIEmbedding } from "llamaindex/embeddings/index"; -import type { OpenAI } from "llamaindex/llm/LLM"; +import type { OpenAI } from "llamaindex/llm/open_ai"; import type { LLMChatParamsBase } from "llamaindex/llm/types"; import { vi } from "vitest"; diff --git a/packages/core/tests/vitest.config.ts b/packages/core/tests/vitest.config.ts new file mode 100644 index 
000000000..08384ff5e --- /dev/null +++ b/packages/core/tests/vitest.config.ts @@ -0,0 +1,8 @@ +import { defineConfig } from "vitest/config"; + +export default defineConfig({ + test: { + include: ["**/*.test.ts"], + setupFiles: ["./vitest.setup.ts"], + }, +}); diff --git a/packages/core/tests/vitest.setup.ts b/packages/core/tests/vitest.setup.ts new file mode 100644 index 000000000..ec1583acd --- /dev/null +++ b/packages/core/tests/vitest.setup.ts @@ -0,0 +1,22 @@ +// eslint-disable-next-line turbo/no-undeclared-env-vars +process.env.OPENAI_API_KEY = "sk-1234567890abcdef1234567890abcdef"; +const originalFetch = globalThis.fetch; + +globalThis.fetch = function fetch(...args: Parameters<typeof originalFetch>) { + let url = args[0]; + if (typeof url !== "string") { + if (url instanceof Request) { + url = url.url; + } else { + url = url.toString(); + } + } + const parsedUrl = new URL(url); + if (parsedUrl.hostname.includes("api.openai.com")) { + // todo: mock api using https://mswjs.io + throw new Error( + "Make sure to return a mock response for OpenAI API requests in your test.", + ); + } + return originalFetch(...args); +}; -- GitLab
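
A minimal consumer-side sketch (not part of the patch itself) of how the streaming tool-call support added above could be used. Per the new streamChat implementation, tool-call deltas are accumulated internally and only attached to the final chunk (the one where finish_reason is non-null) as additionalKwargs.toolCalls, using the MessageToolCall shape added to types.ts. The tool definition below mirrors examples/toolsStream.ts; the assumption that function.arguments is a complete JSON string ready for JSON.parse is illustrative, not guaranteed by the patch.

import { ChatResponseChunk, OpenAI } from "llamaindex";

async function runWithTools() {
  const llm = new OpenAI({ model: "gpt-4-turbo-preview" });

  const stream = await llm.chat({
    messages: [{ content: "Who was Goethe?", role: "user" }],
    tools: [
      {
        type: "function",
        function: {
          name: "wikipedia_tool",
          description: "A tool that uses a query engine to search Wikipedia.",
          parameters: {
            type: "object",
            properties: {
              query: { type: "string", description: "The query to search for" },
            },
            required: ["query"],
          },
        },
      },
    ],
    toolChoice: "auto",
    stream: true,
  });

  // Keep a reference to the last chunk: per streamChat, only the final chunk
  // (finish_reason !== null) carries the accumulated tool calls.
  let lastChunk: ChatResponseChunk | null = null;
  for await (const chunk of stream) {
    process.stdout.write(chunk.delta); // plain text deltas while streaming
    lastChunk = chunk;
  }

  const toolCalls = lastChunk?.additionalKwargs?.toolCalls ?? [];
  for (const call of toolCalls) {
    // function.arguments is accumulated from the streamed deltas; parsing it
    // as JSON is an assumption for illustration.
    const args = JSON.parse(call.function.arguments);
    console.log(`Model requested ${call.function.name} with`, args);
  }
}

runWithTools().catch(console.error);

Because the tool calls only arrive with the terminal chunk, a caller that wants both incremental text and the resulting tool invocation has to drain the whole stream first, as shown here and in examples/toolsStream.ts.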