From 4fcbdf710ee9b3dc107ea32c2107d7e736bb2da7 Mon Sep 17 00:00:00 2001
From: Thuc Pham <51660321+thucpn@users.noreply.github.com>
Date: Fri, 5 Apr 2024 08:33:23 +0700
Subject: [PATCH] Add tool calls for openai streaming (#682)

Co-authored-by: Marcus Schiesser <mail@marcusschiesser.de>
Co-authored-by: Alex Yang <himself65@outlook.com>
---
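Usage notes (text between the `---` marker and the diffstat is ignored by
`git am`): OpenAI streams a tool call as fragments, and this patch
accumulates them during streaming and attaches the completed list to the
final chunk as `additionalKwargs.toolCalls`. A minimal consumption sketch,
assuming the `args` object built in examples/toolsStream.ts below and that
`MessageToolCall` is re-exported from the package root like the other llm
types:

    import type { LLMChatParamsBase, MessageToolCall } from "llamaindex";
    import { OpenAI } from "llamaindex";

    // pass the same messages/tools object built in examples/toolsStream.ts
    async function streamWithTools(llm: OpenAI, args: LLMChatParamsBase) {
      const stream = await llm.chat({ ...args, stream: true });
      let toolCalls: MessageToolCall[] = [];
      for await (const chunk of stream) {
        process.stdout.write(chunk.delta); // text deltas stream as before
        // the accumulated tool calls ride on the final chunk only
        toolCalls = chunk.additionalKwargs?.toolCalls ?? toolCalls;
      }
      // each tool call's arguments are one complete JSON string by now
      return toolCalls.map((call) => ({
        name: call.function.name,
        args: JSON.parse(call.function.arguments),
      }));
    }
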
 .changeset/orange-lions-remember.md           |   5 +
 examples/toolsStream.ts                       |  46 +++
 packages/core/src/ChatHistory.ts              |   2 +-
 packages/core/src/QuestionGenerator.ts        |   2 +-
 packages/core/src/ServiceContext.ts           |   2 +-
 packages/core/src/Settings.ts                 |   2 +-
 packages/core/src/llm/LLM.ts                  | 294 +---------------
 packages/core/src/llm/fireworks.ts            |   2 +-
 packages/core/src/llm/groq.ts                 |   2 +-
 packages/core/src/llm/open_ai.ts              | 318 +++++++++++++++++-
 packages/core/src/llm/together.ts             |   2 +-
 packages/core/src/llm/types.ts                |  11 +
 packages/core/tests/CallbackManager.test.ts   |   9 +-
 packages/core/tests/Embedding.test.ts         |   9 +-
 .../core/tests/MetadataExtractors.test.ts     |   9 +-
 packages/core/tests/Selectors.test.ts         |   8 +-
 packages/core/tests/agent/OpenAIAgent.test.ts |   8 +-
 .../tests/agent/runner/AgentRunner.test.ts    |  10 +-
 .../core/tests/indices/SummaryIndex.test.ts   |   8 +-
 .../tests/indices/VectorStoreIndex.test.ts    |   8 +-
 .../core/tests/objects/ObjectIndex.test.ts    |   8 +-
 packages/core/tests/utility/mockOpenAI.ts     |   2 +-
 packages/core/tests/vitest.config.ts          |   8 +
 packages/core/tests/vitest.setup.ts           |  22 ++
 24 files changed, 425 insertions(+), 372 deletions(-)
 create mode 100644 .changeset/orange-lions-remember.md
 create mode 100644 examples/toolsStream.ts
 create mode 100644 packages/core/tests/vitest.config.ts
 create mode 100644 packages/core/tests/vitest.setup.ts
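
How the accumulation works: the API emits each tool call as a series of
deltas — typically the first carries the call id and function name, and
later ones append slices of the JSON argument string. The sketch below
mirrors what `updateToolCalls` in open_ai.ts does, using simplified
stand-in types and made-up delta values rather than the openai package's
chunk types:

    type ToolCall = { id: string; function: { name: string; arguments: string } };
    type ToolCallDelta = { id?: string; function?: { name?: string; arguments?: string } };

    function accumulate(calls: ToolCall[], deltas: ToolCallDelta[] = []): void {
      deltas.forEach((delta, i) => {
        // create the slot on the first fragment, then append onto it
        const call = calls[i] ?? { id: "", function: { name: "", arguments: "" } };
        if (delta.id) call.id = delta.id;
        if (delta.function?.name) call.function.name += delta.function.name;
        if (delta.function?.arguments) call.function.arguments += delta.function.arguments;
        calls[i] = call;
      });
    }

    // three chunks carrying one streamed tool call:
    const calls: ToolCall[] = [];
    accumulate(calls, [{ id: "call_abc", function: { name: "wikipedia_tool" } }]);
    accumulate(calls, [{ function: { arguments: '{"query":' } }]);
    accumulate(calls, [{ function: { arguments: '"Goethe"}' } }]);
    console.log(calls[0].function.arguments); // {"query":"Goethe"}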

diff --git a/.changeset/orange-lions-remember.md b/.changeset/orange-lions-remember.md
new file mode 100644
index 000000000..a298a96fb
--- /dev/null
+++ b/.changeset/orange-lions-remember.md
@@ -0,0 +1,5 @@
+---
+"llamaindex": patch
+---
+
+Support streaming for OpenAI tool calls
diff --git a/examples/toolsStream.ts b/examples/toolsStream.ts
new file mode 100644
index 000000000..b59114dc9
--- /dev/null
+++ b/examples/toolsStream.ts
@@ -0,0 +1,46 @@
+import { ChatResponseChunk, LLMChatParamsBase, OpenAI } from "llamaindex";
+
+async function main() {
+  const llm = new OpenAI({ model: "gpt-4-turbo-preview" });
+
+  const args: LLMChatParamsBase = {
+    messages: [
+      {
+        content: "Who was Goethe?",
+        role: "user",
+      },
+    ],
+    tools: [
+      {
+        type: "function",
+        function: {
+          name: "wikipedia_tool",
+          description: "A tool that uses a query engine to search Wikipedia.",
+          parameters: {
+            type: "object",
+            properties: {
+              query: {
+                type: "string",
+                description: "The query to search for",
+              },
+            },
+            required: ["query"],
+          },
+        },
+      },
+    ],
+    toolChoice: "auto",
+  };
+
+  const stream = await llm.chat({ ...args, stream: true });
+  let chunk: ChatResponseChunk | null = null;
+  for await (chunk of stream) {
+    process.stdout.write(chunk.delta);
+  }
+  console.log(chunk?.additionalKwargs?.toolCalls?.[0]);
+}
+
+(async function () {
+  await main();
+  console.log("Done");
+})();
diff --git a/packages/core/src/ChatHistory.ts b/packages/core/src/ChatHistory.ts
index d3f261515..dae76cd5d 100644
--- a/packages/core/src/ChatHistory.ts
+++ b/packages/core/src/ChatHistory.ts
@@ -1,7 +1,7 @@
 import { globalsHelper } from "./GlobalsHelper.js";
 import type { SummaryPrompt } from "./Prompt.js";
 import { defaultSummaryPrompt, messagesToHistoryStr } from "./Prompt.js";
-import { OpenAI } from "./llm/LLM.js";
+import { OpenAI } from "./llm/open_ai.js";
 import type { ChatMessage, LLM, MessageType } from "./llm/types.js";
 
 /**
diff --git a/packages/core/src/QuestionGenerator.ts b/packages/core/src/QuestionGenerator.ts
index 4b36174e3..0e7785a9e 100644
--- a/packages/core/src/QuestionGenerator.ts
+++ b/packages/core/src/QuestionGenerator.ts
@@ -5,7 +5,7 @@ import type {
   BaseQuestionGenerator,
   SubQuestion,
 } from "./engines/query/types.js";
-import { OpenAI } from "./llm/LLM.js";
+import { OpenAI } from "./llm/open_ai.js";
 import type { LLM } from "./llm/types.js";
 import { PromptMixin } from "./prompts/index.js";
 import type {
diff --git a/packages/core/src/ServiceContext.ts b/packages/core/src/ServiceContext.ts
index 3b4ffbbad..48a318f02 100644
--- a/packages/core/src/ServiceContext.ts
+++ b/packages/core/src/ServiceContext.ts
@@ -1,7 +1,7 @@
 import { PromptHelper } from "./PromptHelper.js";
 import { OpenAIEmbedding } from "./embeddings/OpenAIEmbedding.js";
 import type { BaseEmbedding } from "./embeddings/types.js";
-import { OpenAI } from "./llm/LLM.js";
+import { OpenAI } from "./llm/open_ai.js";
 import type { LLM } from "./llm/types.js";
 import { SimpleNodeParser } from "./nodeParsers/SimpleNodeParser.js";
 import type { NodeParser } from "./nodeParsers/types.js";
diff --git a/packages/core/src/Settings.ts b/packages/core/src/Settings.ts
index 7d29a2aa3..e2b2f4c0e 100644
--- a/packages/core/src/Settings.ts
+++ b/packages/core/src/Settings.ts
@@ -1,6 +1,6 @@
 import { CallbackManager } from "./callbacks/CallbackManager.js";
 import { OpenAIEmbedding } from "./embeddings/OpenAIEmbedding.js";
-import { OpenAI } from "./llm/LLM.js";
+import { OpenAI } from "./llm/open_ai.js";
 
 import { PromptHelper } from "./PromptHelper.js";
 import { SimpleNodeParser } from "./nodeParsers/SimpleNodeParser.js";
diff --git a/packages/core/src/llm/LLM.ts b/packages/core/src/llm/LLM.ts
index eb17a75d7..8b68f7466 100644
--- a/packages/core/src/llm/LLM.ts
+++ b/packages/core/src/llm/LLM.ts
@@ -1,27 +1,10 @@
-import type OpenAILLM from "openai";
-import type { ClientOptions as OpenAIClientOptions } from "openai";
-import {
-  type OpenAIStreamToken,
-  type StreamCallbackResponse,
-} from "../callbacks/CallbackManager.js";
-
-import type { ChatCompletionMessageParam } from "openai/resources/index.js";
+import { type StreamCallbackResponse } from "../callbacks/CallbackManager.js";
+
 import type { LLMOptions } from "portkey-ai";
-import { Tokenizers } from "../GlobalsHelper.js";
-import { wrapEventCaller } from "../internal/context/EventCaller.js";
 import { getCallbackManager } from "../internal/settings/CallbackManager.js";
 import type { AnthropicSession } from "./anthropic.js";
 import { getAnthropicSession } from "./anthropic.js";
-import type { AzureOpenAIConfig } from "./azure.js";
-import {
-  getAzureBaseUrl,
-  getAzureConfigFromEnv,
-  getAzureModel,
-  shouldUseAzure,
-} from "./azure.js";
 import { BaseLLM } from "./base.js";
-import type { OpenAISession } from "./open_ai.js";
-import { getOpenAISession } from "./open_ai.js";
 import type { PortkeySession } from "./portkey.js";
 import { getPortkeySession } from "./portkey.js";
 import { ReplicateSession } from "./replicate_ai.js";
@@ -36,279 +19,6 @@ import type {
 } from "./types.js";
 import { wrapLLMEvent } from "./utils.js";
 
-export const GPT4_MODELS = {
-  "gpt-4": { contextWindow: 8192 },
-  "gpt-4-32k": { contextWindow: 32768 },
-  "gpt-4-32k-0613": { contextWindow: 32768 },
-  "gpt-4-turbo-preview": { contextWindow: 128000 },
-  "gpt-4-1106-preview": { contextWindow: 128000 },
-  "gpt-4-0125-preview": { contextWindow: 128000 },
-  "gpt-4-vision-preview": { contextWindow: 128000 },
-};
-
-// NOTE we don't currently support gpt-3.5-turbo-instruct and don't plan to in the near future
-export const GPT35_MODELS = {
-  "gpt-3.5-turbo": { contextWindow: 4096 },
-  "gpt-3.5-turbo-0613": { contextWindow: 4096 },
-  "gpt-3.5-turbo-16k": { contextWindow: 16384 },
-  "gpt-3.5-turbo-16k-0613": { contextWindow: 16384 },
-  "gpt-3.5-turbo-1106": { contextWindow: 16384 },
-  "gpt-3.5-turbo-0125": { contextWindow: 16384 },
-};
-
-/**
- * We currently support GPT-3.5 and GPT-4 models
- */
-export const ALL_AVAILABLE_OPENAI_MODELS = {
-  ...GPT4_MODELS,
-  ...GPT35_MODELS,
-};
-
-export const isFunctionCallingModel = (model: string): boolean => {
-  const isChatModel = Object.keys(ALL_AVAILABLE_OPENAI_MODELS).includes(model);
-  const isOld = model.includes("0314") || model.includes("0301");
-  return isChatModel && !isOld;
-};
-
-/**
- * OpenAI LLM implementation
- */
-export class OpenAI extends BaseLLM {
-  // Per completion OpenAI params
-  model: keyof typeof ALL_AVAILABLE_OPENAI_MODELS | string;
-  temperature: number;
-  topP: number;
-  maxTokens?: number;
-  additionalChatOptions?: Omit<
-    Partial<OpenAILLM.Chat.ChatCompletionCreateParams>,
-    | "max_tokens"
-    | "messages"
-    | "model"
-    | "temperature"
-    | "top_p"
-    | "stream"
-    | "tools"
-    | "toolChoice"
-  >;
-
-  // OpenAI session params
-  apiKey?: string = undefined;
-  maxRetries: number;
-  timeout?: number;
-  session: OpenAISession;
-  additionalSessionOptions?: Omit<
-    Partial<OpenAIClientOptions>,
-    "apiKey" | "maxRetries" | "timeout"
-  >;
-
-  constructor(
-    init?: Partial<OpenAI> & {
-      azure?: AzureOpenAIConfig;
-    },
-  ) {
-    super();
-    this.model = init?.model ?? "gpt-3.5-turbo";
-    this.temperature = init?.temperature ?? 0.1;
-    this.topP = init?.topP ?? 1;
-    this.maxTokens = init?.maxTokens ?? undefined;
-
-    this.maxRetries = init?.maxRetries ?? 10;
-    this.timeout = init?.timeout ?? 60 * 1000; // Default is 60 seconds
-    this.additionalChatOptions = init?.additionalChatOptions;
-    this.additionalSessionOptions = init?.additionalSessionOptions;
-
-    if (init?.azure || shouldUseAzure()) {
-      const azureConfig = getAzureConfigFromEnv({
-        ...init?.azure,
-        model: getAzureModel(this.model),
-      });
-
-      if (!azureConfig.apiKey) {
-        throw new Error(
-          "Azure API key is required for OpenAI Azure models. Please set the AZURE_OPENAI_KEY environment variable.",
-        );
-      }
-
-      this.apiKey = azureConfig.apiKey;
-      this.session =
-        init?.session ??
-        getOpenAISession({
-          azure: true,
-          apiKey: this.apiKey,
-          baseURL: getAzureBaseUrl(azureConfig),
-          maxRetries: this.maxRetries,
-          timeout: this.timeout,
-          defaultQuery: { "api-version": azureConfig.apiVersion },
-          ...this.additionalSessionOptions,
-        });
-    } else {
-      this.apiKey = init?.apiKey ?? undefined;
-      this.session =
-        init?.session ??
-        getOpenAISession({
-          apiKey: this.apiKey,
-          maxRetries: this.maxRetries,
-          timeout: this.timeout,
-          ...this.additionalSessionOptions,
-        });
-    }
-  }
-
-  get metadata() {
-    const contextWindow =
-      ALL_AVAILABLE_OPENAI_MODELS[
-        this.model as keyof typeof ALL_AVAILABLE_OPENAI_MODELS
-      ]?.contextWindow ?? 1024;
-    return {
-      model: this.model,
-      temperature: this.temperature,
-      topP: this.topP,
-      maxTokens: this.maxTokens,
-      contextWindow,
-      tokenizer: Tokenizers.CL100K_BASE,
-      isFunctionCallingModel: isFunctionCallingModel(this.model),
-    };
-  }
-
-  mapMessageType(
-    messageType: MessageType,
-  ): "user" | "assistant" | "system" | "function" | "tool" {
-    switch (messageType) {
-      case "user":
-        return "user";
-      case "assistant":
-        return "assistant";
-      case "system":
-        return "system";
-      case "function":
-        return "function";
-      case "tool":
-        return "tool";
-      default:
-        return "user";
-    }
-  }
-
-  toOpenAIMessage(messages: ChatMessage[]) {
-    return messages.map((message) => {
-      const additionalKwargs = message.additionalKwargs ?? {};
-
-      if (message.additionalKwargs?.toolCalls) {
-        additionalKwargs.tool_calls = message.additionalKwargs.toolCalls;
-        delete additionalKwargs.toolCalls;
-      }
-
-      return {
-        role: this.mapMessageType(message.role),
-        content: message.content,
-        ...additionalKwargs,
-      };
-    });
-  }
-
-  chat(
-    params: LLMChatParamsStreaming,
-  ): Promise<AsyncIterable<ChatResponseChunk>>;
-  chat(params: LLMChatParamsNonStreaming): Promise<ChatResponse>;
-  @wrapEventCaller
-  @wrapLLMEvent
-  async chat(
-    params: LLMChatParamsNonStreaming | LLMChatParamsStreaming,
-  ): Promise<ChatResponse | AsyncIterable<ChatResponseChunk>> {
-    const { messages, stream, tools, toolChoice } = params;
-    const baseRequestParams: OpenAILLM.Chat.ChatCompletionCreateParams = {
-      model: this.model,
-      temperature: this.temperature,
-      max_tokens: this.maxTokens,
-      tools: tools,
-      tool_choice: toolChoice,
-      messages: this.toOpenAIMessage(messages) as ChatCompletionMessageParam[],
-      top_p: this.topP,
-      ...this.additionalChatOptions,
-    };
-
-    // Streaming
-    if (stream) {
-      return this.streamChat(params);
-    }
-
-    // Non-streaming
-    const response = await this.session.openai.chat.completions.create({
-      ...baseRequestParams,
-      stream: false,
-    });
-
-    const content = response.choices[0].message?.content ?? null;
-
-    const kwargsOutput: Record<string, any> = {};
-
-    if (response.choices[0].message?.tool_calls) {
-      kwargsOutput.toolCalls = response.choices[0].message.tool_calls;
-    }
-
-    return {
-      message: {
-        content,
-        role: response.choices[0].message.role,
-        additionalKwargs: kwargsOutput,
-      },
-    };
-  }
-
-  @wrapEventCaller
-  protected async *streamChat({
-    messages,
-  }: LLMChatParamsStreaming): AsyncIterable<ChatResponseChunk> {
-    const baseRequestParams: OpenAILLM.Chat.ChatCompletionCreateParams = {
-      model: this.model,
-      temperature: this.temperature,
-      max_tokens: this.maxTokens,
-      messages: messages.map(
-        (message) =>
-          ({
-            role: this.mapMessageType(message.role),
-            content: message.content,
-          }) as ChatCompletionMessageParam,
-      ),
-      top_p: this.topP,
-      ...this.additionalChatOptions,
-    };
-
-    const chunk_stream: AsyncIterable<OpenAIStreamToken> =
-      await this.session.openai.chat.completions.create({
-        ...baseRequestParams,
-        stream: true,
-      });
-
-    // TODO: add callback to streamConverter and use streamConverter here
-    //Indices
-    let idx_counter: number = 0;
-    for await (const part of chunk_stream) {
-      if (!part.choices.length) continue;
-
-      //Increment
-      part.choices[0].index = idx_counter;
-      const is_done: boolean =
-        part.choices[0].finish_reason === "stop" ? true : false;
-      //onLLMStream Callback
-
-      const stream_callback: StreamCallbackResponse = {
-        index: idx_counter,
-        isDone: is_done,
-        token: part,
-      };
-      getCallbackManager().dispatchEvent("stream", stream_callback);
-
-      idx_counter++;
-
-      yield {
-        delta: part.choices[0].delta.content ?? "",
-      };
-    }
-    return;
-  }
-}
-
 export const ALL_AVAILABLE_LLAMADEUCE_MODELS = {
   "Llama-2-70b-chat-old": {
     contextWindow: 4096,
diff --git a/packages/core/src/llm/fireworks.ts b/packages/core/src/llm/fireworks.ts
index 8621dd01f..f7814559b 100644
--- a/packages/core/src/llm/fireworks.ts
+++ b/packages/core/src/llm/fireworks.ts
@@ -1,5 +1,5 @@
 import { getEnv } from "@llamaindex/env";
-import { OpenAI } from "./LLM.js";
+import { OpenAI } from "./open_ai.js";
 
 export class FireworksLLM extends OpenAI {
   constructor(init?: Partial<OpenAI>) {
diff --git a/packages/core/src/llm/groq.ts b/packages/core/src/llm/groq.ts
index b29431749..083e305cb 100644
--- a/packages/core/src/llm/groq.ts
+++ b/packages/core/src/llm/groq.ts
@@ -1,5 +1,5 @@
 import { getEnv } from "@llamaindex/env";
-import { OpenAI } from "./LLM.js";
+import { OpenAI } from "./open_ai.js";
 
 export class Groq extends OpenAI {
   constructor(init?: Partial<OpenAI>) {
diff --git a/packages/core/src/llm/open_ai.ts b/packages/core/src/llm/open_ai.ts
index 336844aaa..b9987a73f 100644
--- a/packages/core/src/llm/open_ai.ts
+++ b/packages/core/src/llm/open_ai.ts
@@ -1,16 +1,43 @@
 import { getEnv } from "@llamaindex/env";
 import _ from "lodash";
-import type { ClientOptions } from "openai";
-import OpenAI from "openai";
+import type OpenAILLM from "openai";
+import type {
+  ClientOptions,
+  ClientOptions as OpenAIClientOptions,
+} from "openai";
+import { OpenAI as OrigOpenAI } from "openai";
 
-export class AzureOpenAI extends OpenAI {
+import type { ChatCompletionMessageParam } from "openai/resources/index.js";
+import { Tokenizers } from "../GlobalsHelper.js";
+import { wrapEventCaller } from "../internal/context/EventCaller.js";
+import { getCallbackManager } from "../internal/settings/CallbackManager.js";
+import type { AzureOpenAIConfig } from "./azure.js";
+import {
+  getAzureBaseUrl,
+  getAzureConfigFromEnv,
+  getAzureModel,
+  shouldUseAzure,
+} from "./azure.js";
+import { BaseLLM } from "./base.js";
+import type {
+  ChatMessage,
+  ChatResponse,
+  ChatResponseChunk,
+  LLMChatParamsNonStreaming,
+  LLMChatParamsStreaming,
+  MessageToolCall,
+  MessageType,
+} from "./types.js";
+import { wrapLLMEvent } from "./utils.js";
+
+export class AzureOpenAI extends OrigOpenAI {
   protected override authHeaders() {
     return { "api-key": this.apiKey };
   }
 }
 
 export class OpenAISession {
-  openai: OpenAI;
+  openai: OrigOpenAI;
 
   constructor(options: ClientOptions & { azure?: boolean } = {}) {
     if (!options.apiKey) {
@@ -24,7 +51,7 @@ export class OpenAISession {
     if (options.azure) {
       this.openai = new AzureOpenAI(options);
     } else {
-      this.openai = new OpenAI({
+      this.openai = new OrigOpenAI({
         ...options,
         // defaultHeaders: { "OpenAI-Beta": "assistants=v1" },
       });
@@ -60,3 +87,284 @@ export function getOpenAISession(
 
   return session;
 }
+
+export const GPT4_MODELS = {
+  "gpt-4": { contextWindow: 8192 },
+  "gpt-4-32k": { contextWindow: 32768 },
+  "gpt-4-32k-0613": { contextWindow: 32768 },
+  "gpt-4-turbo-preview": { contextWindow: 128000 },
+  "gpt-4-1106-preview": { contextWindow: 128000 },
+  "gpt-4-0125-preview": { contextWindow: 128000 },
+  "gpt-4-vision-preview": { contextWindow: 128000 },
+};
+
+// NOTE we don't currently support gpt-3.5-turbo-instruct and don't plan to in the near future
+export const GPT35_MODELS = {
+  "gpt-3.5-turbo": { contextWindow: 4096 },
+  "gpt-3.5-turbo-0613": { contextWindow: 4096 },
+  "gpt-3.5-turbo-16k": { contextWindow: 16384 },
+  "gpt-3.5-turbo-16k-0613": { contextWindow: 16384 },
+  "gpt-3.5-turbo-1106": { contextWindow: 16384 },
+  "gpt-3.5-turbo-0125": { contextWindow: 16384 },
+};
+
+/**
+ * We currently support GPT-3.5 and GPT-4 models
+ */
+export const ALL_AVAILABLE_OPENAI_MODELS = {
+  ...GPT4_MODELS,
+  ...GPT35_MODELS,
+};
+
+export const isFunctionCallingModel = (model: string): boolean => {
+  const isChatModel = Object.keys(ALL_AVAILABLE_OPENAI_MODELS).includes(model);
+  const isOld = model.includes("0314") || model.includes("0301");
+  return isChatModel && !isOld;
+};
+
+/**
+ * OpenAI LLM implementation
+ */
+export class OpenAI extends BaseLLM {
+  // Per completion OpenAI params
+  model: keyof typeof ALL_AVAILABLE_OPENAI_MODELS | string;
+  temperature: number;
+  topP: number;
+  maxTokens?: number;
+  additionalChatOptions?: Omit<
+    Partial<OpenAILLM.Chat.ChatCompletionCreateParams>,
+    | "max_tokens"
+    | "messages"
+    | "model"
+    | "temperature"
+    | "top_p"
+    | "stream"
+    | "tools"
+    | "toolChoice"
+  >;
+
+  // OpenAI session params
+  apiKey?: string = undefined;
+  maxRetries: number;
+  timeout?: number;
+  session: OpenAISession;
+  additionalSessionOptions?: Omit<
+    Partial<OpenAIClientOptions>,
+    "apiKey" | "maxRetries" | "timeout"
+  >;
+
+  constructor(
+    init?: Partial<OpenAI> & {
+      azure?: AzureOpenAIConfig;
+    },
+  ) {
+    super();
+    this.model = init?.model ?? "gpt-3.5-turbo";
+    this.temperature = init?.temperature ?? 0.1;
+    this.topP = init?.topP ?? 1;
+    this.maxTokens = init?.maxTokens ?? undefined;
+
+    this.maxRetries = init?.maxRetries ?? 10;
+    this.timeout = init?.timeout ?? 60 * 1000; // Default is 60 seconds
+    this.additionalChatOptions = init?.additionalChatOptions;
+    this.additionalSessionOptions = init?.additionalSessionOptions;
+
+    if (init?.azure || shouldUseAzure()) {
+      const azureConfig = getAzureConfigFromEnv({
+        ...init?.azure,
+        model: getAzureModel(this.model),
+      });
+
+      if (!azureConfig.apiKey) {
+        throw new Error(
+          "Azure API key is required for OpenAI Azure models. Please set the AZURE_OPENAI_KEY environment variable.",
+        );
+      }
+
+      this.apiKey = azureConfig.apiKey;
+      this.session =
+        init?.session ??
+        getOpenAISession({
+          azure: true,
+          apiKey: this.apiKey,
+          baseURL: getAzureBaseUrl(azureConfig),
+          maxRetries: this.maxRetries,
+          timeout: this.timeout,
+          defaultQuery: { "api-version": azureConfig.apiVersion },
+          ...this.additionalSessionOptions,
+        });
+    } else {
+      this.apiKey = init?.apiKey ?? undefined;
+      this.session =
+        init?.session ??
+        getOpenAISession({
+          apiKey: this.apiKey,
+          maxRetries: this.maxRetries,
+          timeout: this.timeout,
+          ...this.additionalSessionOptions,
+        });
+    }
+  }
+
+  get metadata() {
+    const contextWindow =
+      ALL_AVAILABLE_OPENAI_MODELS[
+        this.model as keyof typeof ALL_AVAILABLE_OPENAI_MODELS
+      ]?.contextWindow ?? 1024;
+    return {
+      model: this.model,
+      temperature: this.temperature,
+      topP: this.topP,
+      maxTokens: this.maxTokens,
+      contextWindow,
+      tokenizer: Tokenizers.CL100K_BASE,
+      isFunctionCallingModel: isFunctionCallingModel(this.model),
+    };
+  }
+
+  mapMessageType(
+    messageType: MessageType,
+  ): "user" | "assistant" | "system" | "function" | "tool" {
+    switch (messageType) {
+      case "user":
+        return "user";
+      case "assistant":
+        return "assistant";
+      case "system":
+        return "system";
+      case "function":
+        return "function";
+      case "tool":
+        return "tool";
+      default:
+        return "user";
+    }
+  }
+
+  toOpenAIMessage(messages: ChatMessage[]) {
+    return messages.map((message) => {
+      const additionalKwargs = message.additionalKwargs ?? {};
+
+      if (message.additionalKwargs?.toolCalls) {
+        additionalKwargs.tool_calls = message.additionalKwargs.toolCalls;
+        delete additionalKwargs.toolCalls;
+      }
+
+      return {
+        role: this.mapMessageType(message.role),
+        content: message.content,
+        ...additionalKwargs,
+      };
+    });
+  }
+
+  chat(
+    params: LLMChatParamsStreaming,
+  ): Promise<AsyncIterable<ChatResponseChunk>>;
+  chat(params: LLMChatParamsNonStreaming): Promise<ChatResponse>;
+  @wrapEventCaller
+  @wrapLLMEvent
+  async chat(
+    params: LLMChatParamsNonStreaming | LLMChatParamsStreaming,
+  ): Promise<ChatResponse | AsyncIterable<ChatResponseChunk>> {
+    const { messages, stream, tools, toolChoice } = params;
+    const baseRequestParams: OpenAILLM.Chat.ChatCompletionCreateParams = {
+      model: this.model,
+      temperature: this.temperature,
+      max_tokens: this.maxTokens,
+      tools: tools,
+      tool_choice: toolChoice,
+      messages: this.toOpenAIMessage(messages) as ChatCompletionMessageParam[],
+      top_p: this.topP,
+      ...this.additionalChatOptions,
+    };
+
+    // Streaming
+    if (stream) {
+      return this.streamChat(baseRequestParams);
+    }
+
+    // Non-streaming
+    const response = await this.session.openai.chat.completions.create({
+      ...baseRequestParams,
+      stream: false,
+    });
+
+    const content = response.choices[0].message?.content ?? null;
+
+    const kwargsOutput: Record<string, any> = {};
+
+    if (response.choices[0].message?.tool_calls) {
+      kwargsOutput.toolCalls = response.choices[0].message.tool_calls;
+    }
+
+    return {
+      message: {
+        content,
+        role: response.choices[0].message.role,
+        additionalKwargs: kwargsOutput,
+      },
+    };
+  }
+
+  @wrapEventCaller
+  protected async *streamChat(
+    baseRequestParams: OpenAILLM.Chat.ChatCompletionCreateParams,
+  ): AsyncIterable<ChatResponseChunk> {
+    const stream: AsyncIterable<OpenAILLM.Chat.ChatCompletionChunk> =
+      await this.session.openai.chat.completions.create({
+        ...baseRequestParams,
+        stream: true,
+      });
+
+    // TODO: add callback to streamConverter and use streamConverter here
+    // running index of streamed chunks, passed to the "stream" callback event
+    let idxCounter: number = 0;
+    const toolCalls: MessageToolCall[] = [];
+    for await (const part of stream) {
+      if (!part.choices.length) continue;
+      const choice = part.choices[0];
+      updateToolCalls(toolCalls, choice.delta.tool_calls);
+
+      const isDone: boolean = choice.finish_reason !== null;
+
+      getCallbackManager().dispatchEvent("stream", {
+        index: idxCounter++,
+        isDone,
+        token: part,
+      });
+
+      yield {
+        // add tool calls to final chunk
+        additionalKwargs: isDone ? { toolCalls } : undefined,
+        delta: choice.delta.content ?? "",
+      };
+    }
+    return;
+  }
+}
+
+function updateToolCalls(
+  toolCalls: MessageToolCall[],
+  toolCallDeltas?: OpenAILLM.Chat.Completions.ChatCompletionChunk.Choice.Delta.ToolCall[],
+) {
+  // The first delta carries the call id; later deltas append name/argument fragments.
+  function augmentToolCall(
+    toolCall?: MessageToolCall,
+    toolCallDelta?: OpenAILLM.Chat.Completions.ChatCompletionChunk.Choice.Delta.ToolCall,
+  ): MessageToolCall {
+    toolCall ??= { function: { name: "", arguments: "" } } as MessageToolCall;
+    if (toolCallDelta?.id) toolCall.id = toolCallDelta.id;
+    if (toolCallDelta?.function?.arguments) {
+      toolCall.function.arguments += toolCallDelta.function.arguments;
+    }
+    if (toolCallDelta?.function?.name) {
+      toolCall.function.name += toolCallDelta.function.name;
+    }
+    return toolCall;
+  }
+  // Write the result back so a tool call created from its first delta is kept.
+  toolCallDeltas?.forEach((toolCallDelta, i) => {
+    toolCalls[i] = augmentToolCall(toolCalls[i], toolCallDelta);
+  });
+}
diff --git a/packages/core/src/llm/together.ts b/packages/core/src/llm/together.ts
index 65651cdf7..0ab3fc443 100644
--- a/packages/core/src/llm/together.ts
+++ b/packages/core/src/llm/together.ts
@@ -1,5 +1,5 @@
 import { getEnv } from "@llamaindex/env";
-import { OpenAI } from "./LLM.js";
+import { OpenAI } from "./open_ai.js";
 
 export class TogetherLLM extends OpenAI {
   constructor(init?: Partial<OpenAI>) {
diff --git a/packages/core/src/llm/types.ts b/packages/core/src/llm/types.ts
index 8aa9548c4..1131473d9 100644
--- a/packages/core/src/llm/types.ts
+++ b/packages/core/src/llm/types.ts
@@ -84,6 +84,7 @@ export interface ChatResponse {
 
 export interface ChatResponseChunk {
   delta: string;
+  additionalKwargs?: Record<string, any>;
 }
 
 export interface CompletionResponse {
@@ -139,3 +140,13 @@ export interface MessageContentDetail {
  * Extended type for the content of a message that allows for multi-modal messages.
  */
 export type MessageContent = string | MessageContentDetail[];
+
+interface ToolCallFunction {
+  arguments: string;
+  name: string;
+}
+
+export interface MessageToolCall {
+  id: string;
+  function: ToolCallFunction;
+}
diff --git a/packages/core/tests/CallbackManager.test.ts b/packages/core/tests/CallbackManager.test.ts
index 86f31b183..3e2f2749e 100644
--- a/packages/core/tests/CallbackManager.test.ts
+++ b/packages/core/tests/CallbackManager.test.ts
@@ -20,20 +20,13 @@ import { CallbackManager } from "llamaindex/callbacks/CallbackManager";
 import { OpenAIEmbedding } from "llamaindex/embeddings/index";
 import { SummaryIndex } from "llamaindex/indices/summary/index";
 import { VectorStoreIndex } from "llamaindex/indices/vectorStore/index";
-import { OpenAI } from "llamaindex/llm/LLM";
+import { OpenAI } from "llamaindex/llm/open_ai";
 import {
   ResponseSynthesizer,
   SimpleResponseBuilder,
 } from "llamaindex/synthesizers/index";
 import { mockEmbeddingModel, mockLlmGeneration } from "./utility/mockOpenAI.js";
 
-// Mock the OpenAI getOpenAISession function during testing
-vi.mock("llamaindex/llm/open_ai", () => {
-  return {
-    getOpenAISession: vi.fn().mockImplementation(() => null),
-  };
-});
-
 describe("CallbackManager: onLLMStream and onRetrieve", () => {
   let serviceContext: ServiceContext;
   let streamCallbackData: StreamCallbackResponse[] = [];
diff --git a/packages/core/tests/Embedding.test.ts b/packages/core/tests/Embedding.test.ts
index e0b2d2bf6..ab863ead1 100644
--- a/packages/core/tests/Embedding.test.ts
+++ b/packages/core/tests/Embedding.test.ts
@@ -3,16 +3,9 @@ import {
   SimilarityType,
   similarity,
 } from "llamaindex/embeddings/index";
-import { beforeAll, describe, expect, test, vi } from "vitest";
+import { beforeAll, describe, expect, test } from "vitest";
 import { mockEmbeddingModel } from "./utility/mockOpenAI.js";
 
-// Mock the OpenAI getOpenAISession function during testing
-vi.mock("llamaindex/llm/open_ai", () => {
-  return {
-    getOpenAISession: vi.fn().mockImplementation(() => null),
-  };
-});
-
 describe("similarity", () => {
   test("throws error on mismatched lengths", () => {
     const embedding1 = [1, 2, 3];
diff --git a/packages/core/tests/MetadataExtractors.test.ts b/packages/core/tests/MetadataExtractors.test.ts
index 0ca64b372..f9337b3b0 100644
--- a/packages/core/tests/MetadataExtractors.test.ts
+++ b/packages/core/tests/MetadataExtractors.test.ts
@@ -8,7 +8,7 @@ import {
   SummaryExtractor,
   TitleExtractor,
 } from "llamaindex/extractors/index";
-import { OpenAI } from "llamaindex/llm/LLM";
+import { OpenAI } from "llamaindex/llm/open_ai";
 import { SimpleNodeParser } from "llamaindex/nodeParsers/index";
 import { afterAll, beforeAll, describe, expect, test, vi } from "vitest";
 import {
@@ -17,13 +17,6 @@ import {
   mockLlmGeneration,
 } from "./utility/mockOpenAI.js";
 
-// Mock the OpenAI getOpenAISession function during testing
-vi.mock("llamaindex/llm/open_ai", () => {
-  return {
-    getOpenAISession: vi.fn().mockImplementation(() => null),
-  };
-});
-
 describe("[MetadataExtractor]: Extractors should populate the metadata", () => {
   let serviceContext: ServiceContext;
 
diff --git a/packages/core/tests/Selectors.test.ts b/packages/core/tests/Selectors.test.ts
index cbda332a5..8bf9ed18d 100644
--- a/packages/core/tests/Selectors.test.ts
+++ b/packages/core/tests/Selectors.test.ts
@@ -1,4 +1,4 @@
-import { describe, expect, test, vi } from "vitest";
+import { describe, expect, test } from "vitest";
 // from unittest.mock import patch
 
 import { serviceContextFromDefaults } from "llamaindex/ServiceContext";
@@ -6,12 +6,6 @@ import { OpenAI } from "llamaindex/llm/index";
 import { LLMSingleSelector } from "llamaindex/selectors/index";
 import { mocStructuredkLlmGeneration } from "./utility/mockOpenAI.js";
 
-vi.mock("llamaindex/llm/open_ai", () => {
-  return {
-    getOpenAISession: vi.fn().mockImplementation(() => null),
-  };
-});
-
 describe("LLMSelector", () => {
   test("should be able to output a selection with a reason", async () => {
     const serviceContext = serviceContextFromDefaults({});
diff --git a/packages/core/tests/agent/OpenAIAgent.test.ts b/packages/core/tests/agent/OpenAIAgent.test.ts
index b6006105e..8180464a3 100644
--- a/packages/core/tests/agent/OpenAIAgent.test.ts
+++ b/packages/core/tests/agent/OpenAIAgent.test.ts
@@ -1,7 +1,7 @@
 import { OpenAIAgent } from "llamaindex/agent/index";
 import { OpenAI } from "llamaindex/llm/index";
 import { FunctionTool } from "llamaindex/tools/index";
-import { beforeEach, describe, expect, it, vi } from "vitest";
+import { beforeEach, describe, expect, it } from "vitest";
 import { mockLlmToolCallGeneration } from "../utility/mockOpenAI.js";
 
 // Define a function to sum two numbers
@@ -24,12 +24,6 @@ const sumJSON = {
   required: ["a", "b"],
 };
 
-vi.mock("llamaindex/llm/open_ai", () => {
-  return {
-    getOpenAISession: vi.fn().mockImplementation(() => null),
-  };
-});
-
 describe("OpenAIAgent", () => {
   let openaiAgent: OpenAIAgent;
 
diff --git a/packages/core/tests/agent/runner/AgentRunner.test.ts b/packages/core/tests/agent/runner/AgentRunner.test.ts
index ab11c34c6..95e943083 100644
--- a/packages/core/tests/agent/runner/AgentRunner.test.ts
+++ b/packages/core/tests/agent/runner/AgentRunner.test.ts
@@ -1,19 +1,13 @@
 import { OpenAIAgentWorker } from "llamaindex/agent/index";
 import { AgentRunner } from "llamaindex/agent/runner/base";
-import { OpenAI } from "llamaindex/llm/LLM";
-import { beforeEach, describe, expect, it, vi } from "vitest";
+import { OpenAI } from "llamaindex/llm/open_ai";
+import { beforeEach, describe, expect, it } from "vitest";
 
 import {
   DEFAULT_LLM_TEXT_OUTPUT,
   mockLlmGeneration,
 } from "../../utility/mockOpenAI.js";
 
-vi.mock("llamaindex/llm/open_ai", () => {
-  return {
-    getOpenAISession: vi.fn().mockImplementation(() => null),
-  };
-});
-
 describe("Agent Runner", () => {
   let agentRunner: AgentRunner;
 
diff --git a/packages/core/tests/indices/SummaryIndex.test.ts b/packages/core/tests/indices/SummaryIndex.test.ts
index f43df0ce7..d273e78e0 100644
--- a/packages/core/tests/indices/SummaryIndex.test.ts
+++ b/packages/core/tests/indices/SummaryIndex.test.ts
@@ -10,16 +10,10 @@ import { rmSync } from "node:fs";
 import { mkdtemp } from "node:fs/promises";
 import { tmpdir } from "node:os";
 import { join } from "node:path";
-import { afterAll, beforeAll, describe, expect, it, vi } from "vitest";
+import { afterAll, beforeAll, describe, expect, it } from "vitest";
 
 const testDir = await mkdtemp(join(tmpdir(), "test-"));
 
-vi.mock("llamaindex/llm/open_ai", () => {
-  return {
-    getOpenAISession: vi.fn().mockImplementation(() => null),
-  };
-});
-
 import { mockServiceContext } from "../utility/mockServiceContext.js";
 
 describe("SummaryIndex", () => {
diff --git a/packages/core/tests/indices/VectorStoreIndex.test.ts b/packages/core/tests/indices/VectorStoreIndex.test.ts
index 50365b59e..1537eba40 100644
--- a/packages/core/tests/indices/VectorStoreIndex.test.ts
+++ b/packages/core/tests/indices/VectorStoreIndex.test.ts
@@ -9,16 +9,10 @@ import { rmSync } from "node:fs";
 import { mkdtemp } from "node:fs/promises";
 import { tmpdir } from "node:os";
 import { join } from "node:path";
-import { afterAll, beforeAll, describe, expect, test, vi } from "vitest";
+import { afterAll, beforeAll, describe, expect, test } from "vitest";
 
 const testDir = await mkdtemp(join(tmpdir(), "test-"));
 
-vi.mock("llamaindex/llm/open_ai", () => {
-  return {
-    getOpenAISession: vi.fn().mockImplementation(() => null),
-  };
-});
-
 import { mockServiceContext } from "../utility/mockServiceContext.js";
 
 describe.sequential("VectorStoreIndex", () => {
diff --git a/packages/core/tests/objects/ObjectIndex.test.ts b/packages/core/tests/objects/ObjectIndex.test.ts
index cd74ac261..f71fd2ef2 100644
--- a/packages/core/tests/objects/ObjectIndex.test.ts
+++ b/packages/core/tests/objects/ObjectIndex.test.ts
@@ -5,13 +5,7 @@ import {
   SimpleToolNodeMapping,
   VectorStoreIndex,
 } from "llamaindex";
-import { beforeAll, describe, expect, test, vi } from "vitest";
-
-vi.mock("llamaindex/llm/open_ai", () => {
-  return {
-    getOpenAISession: vi.fn().mockImplementation(() => null),
-  };
-});
+import { beforeAll, describe, expect, test } from "vitest";
 
 import { mockServiceContext } from "../utility/mockServiceContext.js";
 
diff --git a/packages/core/tests/utility/mockOpenAI.ts b/packages/core/tests/utility/mockOpenAI.ts
index 97fd1e418..f90de391d 100644
--- a/packages/core/tests/utility/mockOpenAI.ts
+++ b/packages/core/tests/utility/mockOpenAI.ts
@@ -1,6 +1,6 @@
 import type { CallbackManager } from "llamaindex/callbacks/CallbackManager";
 import type { OpenAIEmbedding } from "llamaindex/embeddings/index";
-import type { OpenAI } from "llamaindex/llm/LLM";
+import type { OpenAI } from "llamaindex/llm/open_ai";
 import type { LLMChatParamsBase } from "llamaindex/llm/types";
 import { vi } from "vitest";
 
diff --git a/packages/core/tests/vitest.config.ts b/packages/core/tests/vitest.config.ts
new file mode 100644
index 000000000..08384ff5e
--- /dev/null
+++ b/packages/core/tests/vitest.config.ts
@@ -0,0 +1,8 @@
+import { defineConfig } from "vitest/config";
+
+export default defineConfig({
+  test: {
+    include: ["**/*.test.ts"],
+    setupFiles: ["./vitest.setup.ts"],
+  },
+});
diff --git a/packages/core/tests/vitest.setup.ts b/packages/core/tests/vitest.setup.ts
new file mode 100644
index 000000000..ec1583acd
--- /dev/null
+++ b/packages/core/tests/vitest.setup.ts
@@ -0,0 +1,22 @@
+// eslint-disable-next-line turbo/no-undeclared-env-vars
+process.env.OPENAI_API_KEY = "sk-1234567890abcdef1234567890abcdef";
+const originalFetch = globalThis.fetch;
+
+globalThis.fetch = function fetch(...args: Parameters<typeof originalFetch>) {
+  let url = args[0];
+  if (typeof url !== "string") {
+    if (url instanceof Request) {
+      url = url.url;
+    } else {
+      url = url.toString();
+    }
+  }
+  const parsedUrl = new URL(url);
+  if (parsedUrl.hostname.includes("api.openai.com")) {
+    // todo: mock api using https://mswjs.io
+    throw new Error(
+      "Make sure to return a mock response for OpenAI API requests in your test.",
+    );
+  }
+  return originalFetch(...args);
+};
-- 
GitLab