From 91a18e7057e942dbd1a8c53360111f49bafed20d Mon Sep 17 00:00:00 2001
From: ANKIT VARSHNEY <132201033+AVtheking@users.noreply.github.com>
Date: Sun, 16 Mar 2025 09:28:28 +0530
Subject: [PATCH] feat: add support for structured output with zod schema.
 (#1749)
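
Adds an optional `responseFormat` parameter to the chat/completion params and a
`structuredOutput` flag to the LLM metadata. When a provider supports structured
output, a Zod schema (or a plain response-format object such as
{ type: "json_object" }) can be passed and is mapped to the provider's native
JSON-schema response format.

A minimal usage sketch (illustrative schema; assumes OPENAI_API_KEY is set in
the environment):

    import { OpenAI } from "@llamaindex/openai";
    import { z } from "zod";

    const llm = new OpenAI({ model: "gpt-4o" });

    const response = await llm.chat({
      messages: [{ role: "user", content: "Summarize the sales call." }],
      // A Zod schema is converted to the provider's structured-output format
      responseFormat: z.object({ summary: z.string() }),
    });

    // Expected to be a JSON string matching the schema
    console.log(response.message.content);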

---
 .changeset/mighty-eagles-wink.md             | 13 +++++
 e2e/fixtures/llm/openai.ts                   |  1 +
 examples/jsonExtract.ts                      | 51 +++++++++++++++-----
 packages/community/src/llm/bedrock/index.ts  |  1 +
 packages/core/src/llms/base.ts               |  6 ++-
 packages/core/src/llms/type.ts               |  4 ++
 packages/core/src/utils/mock.ts              |  1 +
 packages/providers/anthropic/src/llm.ts      |  1 +
 packages/providers/google/src/base.ts        |  1 +
 packages/providers/huggingface/src/llm.ts    |  1 +
 packages/providers/huggingface/src/shared.ts |  1 +
 packages/providers/mistral/src/llm.ts        |  1 +
 packages/providers/ollama/package.json       | 12 +++++
 packages/providers/ollama/src/llm.ts         | 30 +++++++++++-
 packages/providers/openai/package.json       |  3 +-
 packages/providers/openai/src/llm.ts         | 26 +++++++++-
 packages/providers/perplexity/src/llm.ts     |  1 +
 packages/providers/replicate/src/llm.ts      |  1 +
 packages/providers/vercel/src/llm.ts         |  1 +
 pnpm-lock.yaml                               |  9 ++++
 20 files changed, 148 insertions(+), 17 deletions(-)
 create mode 100644 .changeset/mighty-eagles-wink.md

diff --git a/.changeset/mighty-eagles-wink.md b/.changeset/mighty-eagles-wink.md
new file mode 100644
index 000000000..affdbe49c
--- /dev/null
+++ b/.changeset/mighty-eagles-wink.md
@@ -0,0 +1,13 @@
+---
+"@llamaindex/huggingface": minor
+"@llamaindex/anthropic": minor
+"@llamaindex/mistral": minor
+"@llamaindex/google": minor
+"@llamaindex/ollama": minor
+"@llamaindex/openai": minor
+"@llamaindex/core": minor
+"@llamaindex/examples": minor
+---
+
+Added support for structured output in the chat API of OpenAI and Ollama.
+Added a structuredOutput flag to the provider metadata and a responseFormat parameter to the chat and completion APIs.
diff --git a/e2e/fixtures/llm/openai.ts b/e2e/fixtures/llm/openai.ts
index ba35f3167..14601b873 100644
--- a/e2e/fixtures/llm/openai.ts
+++ b/e2e/fixtures/llm/openai.ts
@@ -42,6 +42,7 @@ export class OpenAI implements LLM {
       contextWindow: 2048,
       tokenizer: undefined,
       isFunctionCallingModel: true,
+      structuredOutput: false,
     };
   }
 
diff --git a/examples/jsonExtract.ts b/examples/jsonExtract.ts
index 4622177e7..6d4a5f476 100644
--- a/examples/jsonExtract.ts
+++ b/examples/jsonExtract.ts
@@ -1,4 +1,5 @@
 import { OpenAI } from "@llamaindex/openai";
+import { z } from "zod";
 
 // Example using OpenAI's chat API to extract JSON from a sales call transcript
 // using json_mode see https://platform.openai.com/docs/guides/text-generation/json-mode for more details
@@ -6,22 +7,47 @@ import { OpenAI } from "@llamaindex/openai";
 const transcript =
   "[Phone rings]\n\nJohn: Hello, this is John.\n\nSarah: Hi John, this is Sarah from XYZ Company. I'm calling to discuss our new product, the XYZ Widget, and see if it might be a good fit for your business.\n\nJohn: Hi Sarah, thanks for reaching out. I'm definitely interested in learning more about the XYZ Widget. Can you give me a quick overview of what it does?\n\nSarah: Of course! The XYZ Widget is a cutting-edge tool that helps businesses streamline their workflow and improve productivity. It's designed to automate repetitive tasks and provide real-time data analytics to help you make informed decisions.\n\nJohn: That sounds really interesting. I can see how that could benefit our team. Do you have any case studies or success stories from other companies who have used the XYZ Widget?\n\nSarah: Absolutely, we have several case studies that I can share with you. I'll send those over along with some additional information about the product. I'd also love to schedule a demo for you and your team to see the XYZ Widget in action.\n\nJohn: That would be great. I'll make sure to review the case studies and then we can set up a time for the demo. In the meantime, are there any specific action items or next steps we should take?\n\nSarah: Yes, I'll send over the information and then follow up with you to schedule the demo. In the meantime, feel free to reach out if you have any questions or need further information.\n\nJohn: Sounds good, I appreciate your help Sarah. I'm looking forward to learning more about the XYZ Widget and seeing how it can benefit our business.\n\nSarah: Thank you, John. I'll be in touch soon. Have a great day!\n\nJohn: You too, bye.";
 
+const exampleSchema = z.object({
+  summary: z.string(),
+  products: z.array(z.string()),
+  rep_name: z.string(),
+  prospect_name: z.string(),
+  action_items: z.array(z.string()),
+});
+
+const example = {
+  summary:
+    "High-level summary of the call transcript. Should not exceed 3 sentences.",
+  products: ["product 1", "product 2"],
+  rep_name: "Name of the sales rep",
+  prospect_name: "Name of the prospect",
+  action_items: ["action item 1", "action item 2"],
+};
+
 async function main() {
   const llm = new OpenAI({
-    model: "gpt-4-1106-preview",
-    additionalChatOptions: { response_format: { type: "json_object" } },
+    model: "gpt-4o",
   });
 
-  const example = {
-    summary:
-      "High-level summary of the call transcript. Should not exceed 3 sentences.",
-    products: ["product 1", "product 2"],
-    rep_name: "Name of the sales rep",
-    prospect_name: "Name of the prospect",
-    action_items: ["action item 1", "action item 2"],
-  };
-
+  // Response format as a Zod schema
   const response = await llm.chat({
+    messages: [
+      {
+        role: "system",
+        content: `You are an expert assistant for summarizing and extracting insights from sales call transcripts.`,
+      },
+      {
+        role: "user",
+        content: `Here is the transcript: \n------\n${transcript}\n------`,
+      },
+    ],
+    responseFormat: exampleSchema,
+  });
+
+  console.log(response.message.content);
+
+  // Response format as json_object
+  const response2 = await llm.chat({
     messages: [
       {
         role: "system",
@@ -34,9 +60,10 @@ async function main() {
         content: `Here is the transcript: \n------\n${transcript}\n------`,
       },
     ],
+    responseFormat: { type: "json_object" },
   });
 
-  console.log(response.message.content);
+  console.log(response2.message.content);
 }
 
 main().catch(console.error);
diff --git a/packages/community/src/llm/bedrock/index.ts b/packages/community/src/llm/bedrock/index.ts
index 5d091241f..1e491247f 100644
--- a/packages/community/src/llm/bedrock/index.ts
+++ b/packages/community/src/llm/bedrock/index.ts
@@ -381,6 +381,7 @@ export class Bedrock extends ToolCallLLM<BedrockAdditionalChatOptions> {
       maxTokens: this.maxTokens,
       contextWindow: BEDROCK_FOUNDATION_LLMS[this.model] ?? 128000,
       tokenizer: undefined,
+      structuredOutput: false,
     };
   }
 
diff --git a/packages/core/src/llms/base.ts b/packages/core/src/llms/base.ts
index 46306bfec..c1456c1cb 100644
--- a/packages/core/src/llms/base.ts
+++ b/packages/core/src/llms/base.ts
@@ -28,11 +28,12 @@ export abstract class BaseLLM<
   async complete(
     params: LLMCompletionParamsStreaming | LLMCompletionParamsNonStreaming,
   ): Promise<CompletionResponse | AsyncIterable<CompletionResponse>> {
-    const { prompt, stream } = params;
+    const { prompt, stream, responseFormat } = params;
     if (stream) {
       const stream = await this.chat({
         messages: [{ content: prompt, role: "user" }],
         stream: true,
+        ...(responseFormat ? { responseFormat } : {}),
       });
       return streamConverter(stream, (chunk) => {
         return {
@@ -41,9 +42,12 @@ export abstract class BaseLLM<
         };
       });
     }
+
     const chatResponse = await this.chat({
       messages: [{ content: prompt, role: "user" }],
+      ...(responseFormat ? { responseFormat } : {}),
     });
+
     return {
       text: extractText(chatResponse.message.content),
       raw: chatResponse.raw,
diff --git a/packages/core/src/llms/type.ts b/packages/core/src/llms/type.ts
index 787c7ffa8..5f85dc86e 100644
--- a/packages/core/src/llms/type.ts
+++ b/packages/core/src/llms/type.ts
@@ -1,5 +1,6 @@
 import type { Tokenizers } from "@llamaindex/env/tokenizers";
 import type { JSONSchemaType } from "ajv";
+import { z } from "zod";
 import type { JSONObject, JSONValue } from "../global";
 
 /**
@@ -106,6 +107,7 @@ export type LLMMetadata = {
   maxTokens?: number | undefined;
   contextWindow: number;
   tokenizer: Tokenizers | undefined;
+  structuredOutput: boolean;
 };
 
 export interface LLMChatParamsBase<
@@ -115,6 +117,7 @@ export interface LLMChatParamsBase<
   messages: ChatMessage<AdditionalMessageOptions>[];
   additionalChatOptions?: AdditionalChatOptions;
   tools?: BaseTool[];
+  responseFormat?: z.ZodType | object;
 }
 
 export interface LLMChatParamsStreaming<
@@ -133,6 +136,7 @@ export interface LLMChatParamsNonStreaming<
 
 export interface LLMCompletionParamsBase {
   prompt: MessageContent;
+  responseFormat?: z.ZodType | object;
 }
 
 export interface LLMCompletionParamsStreaming extends LLMCompletionParamsBase {
diff --git a/packages/core/src/utils/mock.ts b/packages/core/src/utils/mock.ts
index 2a29e775a..bd9a14c7f 100644
--- a/packages/core/src/utils/mock.ts
+++ b/packages/core/src/utils/mock.ts
@@ -35,6 +35,7 @@ export class MockLLM extends ToolCallLLM {
       topP: 0.5,
       contextWindow: 1024,
       tokenizer: undefined,
+      structuredOutput: false,
     };
   }
 
diff --git a/packages/providers/anthropic/src/llm.ts b/packages/providers/anthropic/src/llm.ts
index 4637d289a..09e032a82 100644
--- a/packages/providers/anthropic/src/llm.ts
+++ b/packages/providers/anthropic/src/llm.ts
@@ -191,6 +191,7 @@ export class Anthropic extends ToolCallLLM<
             ].contextWindow
           : 200000,
       tokenizer: undefined,
+      structuredOutput: false,
     };
   }
 
diff --git a/packages/providers/google/src/base.ts b/packages/providers/google/src/base.ts
index 877bf37a4..ff67fb025 100644
--- a/packages/providers/google/src/base.ts
+++ b/packages/providers/google/src/base.ts
@@ -241,6 +241,7 @@ export class Gemini extends ToolCallLLM<GeminiAdditionalChatOptions> {
       maxTokens: this.maxTokens,
       contextWindow: GEMINI_MODEL_INFO_MAP[this.model].contextWindow,
       tokenizer: undefined,
+      structuredOutput: false,
     };
   }
 
diff --git a/packages/providers/huggingface/src/llm.ts b/packages/providers/huggingface/src/llm.ts
index 83befb8e4..dbd089432 100644
--- a/packages/providers/huggingface/src/llm.ts
+++ b/packages/providers/huggingface/src/llm.ts
@@ -57,6 +57,7 @@ export class HuggingFaceLLM extends BaseLLM {
       maxTokens: this.maxTokens,
       contextWindow: this.contextWindow,
       tokenizer: undefined,
+      structuredOutput: false,
     };
   }
 
diff --git a/packages/providers/huggingface/src/shared.ts b/packages/providers/huggingface/src/shared.ts
index 7225a9f77..e383ff4d3 100644
--- a/packages/providers/huggingface/src/shared.ts
+++ b/packages/providers/huggingface/src/shared.ts
@@ -123,6 +123,7 @@ export class HuggingFaceInferenceAPI extends BaseLLM {
       maxTokens: this.maxTokens,
       contextWindow: this.contextWindow,
       tokenizer: undefined,
+      structuredOutput: false,
     };
   }
 
diff --git a/packages/providers/mistral/src/llm.ts b/packages/providers/mistral/src/llm.ts
index f7d0ba77d..d5f594b04 100644
--- a/packages/providers/mistral/src/llm.ts
+++ b/packages/providers/mistral/src/llm.ts
@@ -107,6 +107,7 @@ export class MistralAI extends ToolCallLLM<ToolCallLLMMessageOptions> {
       maxTokens: this.maxTokens,
       contextWindow: ALL_AVAILABLE_MISTRAL_MODELS[this.model].contextWindow,
       tokenizer: undefined,
+      structuredOutput: false,
     };
   }
 
diff --git a/packages/providers/ollama/package.json b/packages/providers/ollama/package.json
index 2131665d1..e64442581 100644
--- a/packages/providers/ollama/package.json
+++ b/packages/providers/ollama/package.json
@@ -37,5 +37,17 @@
     "@llamaindex/env": "workspace:*",
     "ollama": "^0.5.10",
     "remeda": "^2.17.3"
+  },
+  "peerDependencies": {
+    "zod": "^3.24.2",
+    "zod-to-json-schema": "^3.23.3"
+  },
+  "peerDependenciesMeta": {
+    "zod": {
+      "optional": true
+    },
+    "zod-to-json-schema": {
+      "optional": true
+    }
   }
 }
diff --git a/packages/providers/ollama/src/llm.ts b/packages/providers/ollama/src/llm.ts
index 15bbf31cf..36cbcd995 100644
--- a/packages/providers/ollama/src/llm.ts
+++ b/packages/providers/ollama/src/llm.ts
@@ -57,6 +57,22 @@ export type OllamaParams = {
   options?: Partial<Options>;
 };
 
+async function getZod() {
+  try {
+    return await import("zod");
+  } catch (e) {
+    throw new Error("zod is required for structured output");
+  }
+}
+
+async function getZodToJsonSchema() {
+  try {
+    return await import("zod-to-json-schema");
+  } catch (e) {
+    throw new Error("zod-to-json-schema is required for structured output");
+  }
+}
+
 export class Ollama extends ToolCallLLM {
   supportToolCall: boolean = true;
   public readonly ollama: OllamaBase;
@@ -92,6 +108,7 @@ export class Ollama extends ToolCallLLM {
       maxTokens: this.options.num_ctx,
       contextWindow: num_ctx,
       tokenizer: undefined,
+      structuredOutput: true,
     };
   }
 
@@ -109,7 +126,7 @@ export class Ollama extends ToolCallLLM {
   ): Promise<
     ChatResponse<ToolCallLLMMessageOptions> | AsyncIterable<ChatResponseChunk>
   > {
-    const { messages, stream, tools } = params;
+    const { messages, stream, tools, responseFormat } = params;
     const payload: ChatRequest = {
       model: this.model,
       messages: messages.map((message) => {
@@ -130,9 +147,20 @@ export class Ollama extends ToolCallLLM {
         ...this.options,
       },
     };
+
     if (tools) {
       payload.tools = tools.map((tool) => Ollama.toTool(tool));
     }
+
+    if (responseFormat && this.metadata.structuredOutput) {
+      const [{ zodToJsonSchema }, { z }] = await Promise.all([
+        getZodToJsonSchema(),
+        getZod(),
+      ]);
+      if (responseFormat instanceof z.ZodType)
+        payload.format = zodToJsonSchema(responseFormat);
+    }
+
     if (!stream) {
       const chatResponse = await this.ollama.chat({
         ...payload,
diff --git a/packages/providers/openai/package.json b/packages/providers/openai/package.json
index b824ffba2..a60e61744 100644
--- a/packages/providers/openai/package.json
+++ b/packages/providers/openai/package.json
@@ -35,6 +35,7 @@
   "dependencies": {
     "@llamaindex/core": "workspace:*",
     "@llamaindex/env": "workspace:*",
-    "openai": "^4.86.0"
+    "openai": "^4.86.0",
+    "zod": "^3.24.2"
   }
 }
diff --git a/packages/providers/openai/src/llm.ts b/packages/providers/openai/src/llm.ts
index d24ad9a34..675737f2b 100644
--- a/packages/providers/openai/src/llm.ts
+++ b/packages/providers/openai/src/llm.ts
@@ -22,6 +22,7 @@ import type {
   ClientOptions as OpenAIClientOptions,
   OpenAI as OpenAILLM,
 } from "openai";
+import { zodResponseFormat } from "openai/helpers/zod";
 import type { ChatModel } from "openai/resources/chat/chat";
 import type {
   ChatCompletionAssistantMessageParam,
@@ -32,7 +33,12 @@ import type {
   ChatCompletionToolMessageParam,
   ChatCompletionUserMessageParam,
 } from "openai/resources/chat/completions";
-import type { ChatCompletionMessageParam } from "openai/resources/index.js";
+import type {
+  ChatCompletionMessageParam,
+  ResponseFormatJSONObject,
+  ResponseFormatJSONSchema,
+} from "openai/resources/index.js";
+import { z } from "zod";
 import {
   AzureOpenAIWithUserAgent,
   getAzureConfigFromEnv,
@@ -292,6 +298,7 @@ export class OpenAI extends ToolCallLLM<OpenAIAdditionalChatOptions> {
       maxTokens: this.maxTokens,
       contextWindow,
       tokenizer: Tokenizers.CL100K_BASE,
+      structuredOutput: true,
     };
   }
 
@@ -385,7 +392,8 @@ export class OpenAI extends ToolCallLLM<OpenAIAdditionalChatOptions> {
     | ChatResponse<ToolCallLLMMessageOptions>
     | AsyncIterable<ChatResponseChunk<ToolCallLLMMessageOptions>>
   > {
-    const { messages, stream, tools, additionalChatOptions } = params;
+    const { messages, stream, tools, responseFormat, additionalChatOptions } =
+      params;
     const baseRequestParams = <OpenAILLM.Chat.ChatCompletionCreateParams>{
       model: this.model,
       temperature: this.temperature,
@@ -408,6 +416,20 @@ export class OpenAI extends ToolCallLLM<OpenAIAdditionalChatOptions> {
     if (!isTemperatureSupported(baseRequestParams.model))
       delete baseRequestParams.temperature;
 
+    // Add response format for structured output
+    if (responseFormat && this.metadata.structuredOutput) {
+      if (responseFormat instanceof z.ZodType)
+        baseRequestParams.response_format = zodResponseFormat(
+          responseFormat,
+          "response_format",
+        );
+      else {
+        baseRequestParams.response_format = responseFormat as
+          | ResponseFormatJSONObject
+          | ResponseFormatJSONSchema;
+      }
+    }
+
     // Streaming
     if (stream) {
       return this.streamChat(baseRequestParams);
diff --git a/packages/providers/perplexity/src/llm.ts b/packages/providers/perplexity/src/llm.ts
index bf3fc7e01..b9b342f4c 100644
--- a/packages/providers/perplexity/src/llm.ts
+++ b/packages/providers/perplexity/src/llm.ts
@@ -64,6 +64,7 @@ export class Perplexity extends OpenAI {
       contextWindow:
         PERPLEXITY_MODELS[this.model as PerplexityModelName]?.contextWindow,
       tokenizer: Tokenizers.CL100K_BASE,
+      structuredOutput: false,
     };
   }
 }
diff --git a/packages/providers/replicate/src/llm.ts b/packages/providers/replicate/src/llm.ts
index f75741179..f76f16718 100644
--- a/packages/providers/replicate/src/llm.ts
+++ b/packages/providers/replicate/src/llm.ts
@@ -145,6 +145,7 @@ export class ReplicateLLM extends BaseLLM {
       maxTokens: this.maxTokens,
       contextWindow: ALL_AVAILABLE_REPLICATE_MODELS[this.model].contextWindow,
       tokenizer: undefined,
+      structuredOutput: false,
     };
   }
 
diff --git a/packages/providers/vercel/src/llm.ts b/packages/providers/vercel/src/llm.ts
index 453b2fc16..8134a16cf 100644
--- a/packages/providers/vercel/src/llm.ts
+++ b/packages/providers/vercel/src/llm.ts
@@ -41,6 +41,7 @@ export class VercelLLM extends ToolCallLLM<VercelAdditionalChatOptions> {
       topP: 1,
       contextWindow: 128000,
       tokenizer: undefined,
+      structuredOutput: false,
     };
   }
 
diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml
index 42c5e203d..c828d9063 100644
--- a/pnpm-lock.yaml
+++ b/pnpm-lock.yaml
@@ -1300,6 +1300,12 @@ importers:
       remeda:
         specifier: ^2.17.3
         version: 2.20.1
+      zod:
+        specifier: ^3.24.2
+        version: 3.24.2
+      zod-to-json-schema:
+        specifier: ^3.23.3
+        version: 3.24.1(zod@3.24.2)
     devDependencies:
       bunchee:
         specifier: 6.4.0
@@ -1316,6 +1322,9 @@ importers:
       openai:
         specifier: ^4.86.0
         version: 4.86.0(ws@8.18.0(bufferutil@4.0.9))(zod@3.24.2)
+      zod:
+        specifier: ^3.24.2
+        version: 3.24.2
     devDependencies:
       bunchee:
         specifier: 6.4.0
-- 
GitLab