Unverified Commit 91a18e70 authored by ANKIT VARSHNEY, committed by GitHub

feat: add support for structured output with zod schema. (#1749)

parent d1c1f99e
Showing with 148 additions and 17 deletions
---
"@llamaindex/huggingface": minor
"@llamaindex/anthropic": minor
"@llamaindex/mistral": minor
"@llamaindex/google": minor
"@llamaindex/ollama": minor
"@llamaindex/openai": minor
"@llamaindex/core": minor
"@llamaindex/examples": minor
---
Added support for structured output in the chat API of OpenAI and Ollama.
Added a structured output flag to the metadata of each provider.
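For orientation, here is a minimal sketch of the new call shape; the schema and model name below are illustrative, not part of this commit:

```ts
import { OpenAI } from "@llamaindex/openai";
import { z } from "zod";

// Illustrative schema; any zod object schema can be passed the same way.
const schema = z.object({
  summary: z.string(),
  action_items: z.array(z.string()),
});

async function demo() {
  const llm = new OpenAI({ model: "gpt-4o" });
  // Passing a zod schema as responseFormat requests structured output.
  const response = await llm.chat({
    messages: [{ role: "user", content: "Summarize this call: ..." }],
    responseFormat: schema,
  });
  console.log(response.message.content);
}

demo().catch(console.error);
```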
@@ -42,6 +42,7 @@ export class OpenAI implements LLM {
contextWindow: 2048,
tokenizer: undefined,
isFunctionCallingModel: true,
structuredOutput: false,
};
}
...
import { OpenAI } from "@llamaindex/openai";
import { z } from "zod";
// Example using OpenAI's chat API to extract JSON from a sales call transcript
// using json_mode see https://platform.openai.com/docs/guides/text-generation/json-mode for more details
@@ -6,22 +7,47 @@ import { OpenAI } from "@llamaindex/openai";
const transcript =
"[Phone rings]\n\nJohn: Hello, this is John.\n\nSarah: Hi John, this is Sarah from XYZ Company. I'm calling to discuss our new product, the XYZ Widget, and see if it might be a good fit for your business.\n\nJohn: Hi Sarah, thanks for reaching out. I'm definitely interested in learning more about the XYZ Widget. Can you give me a quick overview of what it does?\n\nSarah: Of course! The XYZ Widget is a cutting-edge tool that helps businesses streamline their workflow and improve productivity. It's designed to automate repetitive tasks and provide real-time data analytics to help you make informed decisions.\n\nJohn: That sounds really interesting. I can see how that could benefit our team. Do you have any case studies or success stories from other companies who have used the XYZ Widget?\n\nSarah: Absolutely, we have several case studies that I can share with you. I'll send those over along with some additional information about the product. I'd also love to schedule a demo for you and your team to see the XYZ Widget in action.\n\nJohn: That would be great. I'll make sure to review the case studies and then we can set up a time for the demo. In the meantime, are there any specific action items or next steps we should take?\n\nSarah: Yes, I'll send over the information and then follow up with you to schedule the demo. In the meantime, feel free to reach out if you have any questions or need further information.\n\nJohn: Sounds good, I appreciate your help Sarah. I'm looking forward to learning more about the XYZ Widget and seeing how it can benefit our business.\n\nSarah: Thank you, John. I'll be in touch soon. Have a great day!\n\nJohn: You too, bye.";
const exampleSchema = z.object({
summary: z.string(),
products: z.array(z.string()),
rep_name: z.string(),
prospect_name: z.string(),
action_items: z.array(z.string()),
});
const example = {
summary:
"High-level summary of the call transcript. Should not exceed 3 sentences.",
products: ["product 1", "product 2"],
rep_name: "Name of the sales rep",
prospect_name: "Name of the prospect",
action_items: ["action item 1", "action item 2"],
};
async function main() {
const llm = new OpenAI({
model: "gpt-4o",
});
//response format as zod schema
const response = await llm.chat({
messages: [
{
role: "system",
content: `You are an expert assistant for summarizing and extracting insights from sales call transcripts.`,
},
{
role: "user",
content: `Here is the transcript: \n------\n${transcript}\n------`,
},
],
responseFormat: exampleSchema,
});
console.log(response.message.content);
//response format as json_object
const response2 = await llm.chat({
messages: [
{
role: "system",
@@ -34,9 +60,10 @@ async function main() {
content: `Here is the transcript: \n------\n${transcript}\n------`,
},
],
responseFormat: { type: "json_object" },
});
console.log(response2.message.content);
}
main().catch(console.error);
@@ -381,6 +381,7 @@ export class Bedrock extends ToolCallLLM<BedrockAdditionalChatOptions> {
maxTokens: this.maxTokens,
contextWindow: BEDROCK_FOUNDATION_LLMS[this.model] ?? 128000,
tokenizer: undefined,
structuredOutput: false,
};
}
...
@@ -28,11 +28,12 @@ export abstract class BaseLLM<
async complete(
params: LLMCompletionParamsStreaming | LLMCompletionParamsNonStreaming,
): Promise<CompletionResponse | AsyncIterable<CompletionResponse>> {
const { prompt, stream, responseFormat } = params;
if (stream) {
const stream = await this.chat({
messages: [{ content: prompt, role: "user" }],
stream: true,
...(responseFormat ? { responseFormat } : {}),
});
return streamConverter(stream, (chunk) => {
return {
@@ -41,9 +42,12 @@ export abstract class BaseLLM<
};
});
}
const chatResponse = await this.chat({
messages: [{ content: prompt, role: "user" }],
...(responseFormat ? { responseFormat } : {}),
});
return {
text: extractText(chatResponse.message.content),
raw: chatResponse.raw,
...
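In short, complete() now forwards an optional responseFormat to chat(), so structured output also works for single-prompt completions. A hedged usage sketch (the model name and schema are illustrative):

```ts
import { OpenAI } from "@llamaindex/openai";
import { z } from "zod";

async function demo() {
  const llm = new OpenAI({ model: "gpt-4o" });
  // complete() builds a single user message and passes responseFormat on to chat().
  const result = await llm.complete({
    prompt: "List three colors of the French flag.",
    responseFormat: z.object({ colors: z.array(z.string()) }),
  });
  console.log(result.text);
}

demo().catch(console.error);
```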
import type { Tokenizers } from "@llamaindex/env/tokenizers";
import type { JSONSchemaType } from "ajv";
import { z } from "zod";
import type { JSONObject, JSONValue } from "../global";
/**
@@ -106,6 +107,7 @@ export type LLMMetadata = {
maxTokens?: number | undefined;
contextWindow: number;
tokenizer: Tokenizers | undefined;
structuredOutput: boolean;
};
export interface LLMChatParamsBase<
@@ -115,6 +117,7 @@ export interface LLMChatParamsBase<
messages: ChatMessage<AdditionalMessageOptions>[];
additionalChatOptions?: AdditionalChatOptions;
tools?: BaseTool[];
responseFormat?: z.ZodType | object;
}
export interface LLMChatParamsStreaming<
@@ -133,6 +136,7 @@ export interface LLMChatParamsNonStreaming<
export interface LLMCompletionParamsBase {
prompt: MessageContent;
responseFormat?: z.ZodType | object;
}
export interface LLMCompletionParamsStreaming extends LLMCompletionParamsBase {
...
@@ -35,6 +35,7 @@ export class MockLLM extends ToolCallLLM {
topP: 0.5,
contextWindow: 1024,
tokenizer: undefined,
structuredOutput: false,
};
}
...
@@ -191,6 +191,7 @@ export class Anthropic extends ToolCallLLM<
].contextWindow
: 200000,
tokenizer: undefined,
structuredOutput: false,
};
}
...
@@ -241,6 +241,7 @@ export class Gemini extends ToolCallLLM<GeminiAdditionalChatOptions> {
maxTokens: this.maxTokens,
contextWindow: GEMINI_MODEL_INFO_MAP[this.model].contextWindow,
tokenizer: undefined,
structuredOutput: false,
};
}
...
@@ -57,6 +57,7 @@ export class HuggingFaceLLM extends BaseLLM {
maxTokens: this.maxTokens,
contextWindow: this.contextWindow,
tokenizer: undefined,
structuredOutput: false,
};
}
...
@@ -123,6 +123,7 @@ export class HuggingFaceInferenceAPI extends BaseLLM {
maxTokens: this.maxTokens,
contextWindow: this.contextWindow,
tokenizer: undefined,
structuredOutput: false,
};
}
...
@@ -107,6 +107,7 @@ export class MistralAI extends ToolCallLLM<ToolCallLLMMessageOptions> {
maxTokens: this.maxTokens,
contextWindow: ALL_AVAILABLE_MISTRAL_MODELS[this.model].contextWindow,
tokenizer: undefined,
structuredOutput: false,
};
}
...
@@ -37,5 +37,17 @@
"@llamaindex/env": "workspace:*",
"ollama": "^0.5.10",
"remeda": "^2.17.3"
},
"peerDependencies": {
"zod": "^3.24.2",
"zod-to-json-schema": "^3.23.3"
},
"peerDependenciesMeta": {
"zod": {
"optional": true
},
"zod-to-json-schema": {
"optional": true
}
}
}
@@ -57,6 +57,22 @@ export type OllamaParams = {
options?: Partial<Options>;
};
async function getZod() {
try {
return await import("zod");
} catch (e) {
throw new Error("zod is required for structured output");
}
}
async function getZodToJsonSchema() {
try {
return await import("zod-to-json-schema");
} catch (e) {
throw new Error("zod-to-json-schema is required for structured output");
}
}
export class Ollama extends ToolCallLLM {
supportToolCall: boolean = true;
public readonly ollama: OllamaBase;
@@ -92,6 +108,7 @@ export class Ollama extends ToolCallLLM {
maxTokens: this.options.num_ctx,
contextWindow: num_ctx,
tokenizer: undefined,
structuredOutput: true,
};
}
@@ -109,7 +126,7 @@ export class Ollama extends ToolCallLLM {
): Promise<
ChatResponse<ToolCallLLMMessageOptions> | AsyncIterable<ChatResponseChunk>
> {
const { messages, stream, tools, responseFormat } = params;
const payload: ChatRequest = {
model: this.model,
messages: messages.map((message) => {
@@ -130,9 +147,20 @@ export class Ollama extends ToolCallLLM {
...this.options,
},
};
if (tools) {
payload.tools = tools.map((tool) => Ollama.toTool(tool));
}
if (responseFormat && this.metadata.structuredOutput) {
const [{ zodToJsonSchema }, { z }] = await Promise.all([
getZodToJsonSchema(),
getZod(),
]);
if (responseFormat instanceof z.ZodType)
payload.format = zodToJsonSchema(responseFormat);
}
if (!stream) {
const chatResponse = await this.ollama.chat({
...payload,
...
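As the hunk above shows, when a zod schema is passed and the model reports structuredOutput: true, the Ollama adapter converts it with zod-to-json-schema and sends the result in the request's format field. A standalone sketch of that conversion (the schema here is illustrative):

```ts
import { z } from "zod";
import { zodToJsonSchema } from "zod-to-json-schema";

// Illustrative schema; this mirrors what the Ollama adapter does internally.
const schema = z.object({
  summary: z.string(),
  action_items: z.array(z.string()),
});

// Ollama accepts a JSON Schema object in the chat request's `format` field.
const format = zodToJsonSchema(schema);
console.log(JSON.stringify(format, null, 2));
```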
@@ -35,6 +35,7 @@
"dependencies": {
"@llamaindex/core": "workspace:*",
"@llamaindex/env": "workspace:*",
"openai": "^4.86.0",
"zod": "^3.24.2"
}
}
@@ -22,6 +22,7 @@ import type {
ClientOptions as OpenAIClientOptions,
OpenAI as OpenAILLM,
} from "openai";
import { zodResponseFormat } from "openai/helpers/zod";
import type { ChatModel } from "openai/resources/chat/chat";
import type {
ChatCompletionAssistantMessageParam,
@@ -32,7 +33,12 @@ import type {
ChatCompletionToolMessageParam,
ChatCompletionUserMessageParam,
} from "openai/resources/chat/completions";
import type {
ChatCompletionMessageParam,
ResponseFormatJSONObject,
ResponseFormatJSONSchema,
} from "openai/resources/index.js";
import { z } from "zod";
import {
AzureOpenAIWithUserAgent,
getAzureConfigFromEnv,
@@ -292,6 +298,7 @@ export class OpenAI extends ToolCallLLM<OpenAIAdditionalChatOptions> {
maxTokens: this.maxTokens,
contextWindow,
tokenizer: Tokenizers.CL100K_BASE,
structuredOutput: true,
};
}
@@ -385,7 +392,8 @@ export class OpenAI extends ToolCallLLM<OpenAIAdditionalChatOptions> {
| ChatResponse<ToolCallLLMMessageOptions>
| AsyncIterable<ChatResponseChunk<ToolCallLLMMessageOptions>>
> {
const { messages, stream, tools, responseFormat, additionalChatOptions } =
params;
const baseRequestParams = <OpenAILLM.Chat.ChatCompletionCreateParams>{
model: this.model,
temperature: this.temperature,
@@ -408,6 +416,20 @@ export class OpenAI extends ToolCallLLM<OpenAIAdditionalChatOptions> {
if (!isTemperatureSupported(baseRequestParams.model))
delete baseRequestParams.temperature;
//add response format for the structured output
if (responseFormat && this.metadata.structuredOutput) {
if (responseFormat instanceof z.ZodType)
baseRequestParams.response_format = zodResponseFormat(
responseFormat,
"response_format",
);
else {
baseRequestParams.response_format = responseFormat as
| ResponseFormatJSONObject
| ResponseFormatJSONSchema;
}
}
// Streaming
if (stream) {
return this.streamChat(baseRequestParams);
...
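For comparison, the OpenAI path uses the zodResponseFormat helper from the openai package, which wraps a zod schema in a json_schema response format. A standalone sketch (the schema is illustrative):

```ts
import { zodResponseFormat } from "openai/helpers/zod";
import { z } from "zod";

// Illustrative schema; the adapter assigns the helper's output to
// baseRequestParams.response_format, as in the hunk above.
const schema = z.object({
  summary: z.string(),
  action_items: z.array(z.string()),
});

const response_format = zodResponseFormat(schema, "response_format");
console.log(response_format.type); // "json_schema"
```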
@@ -64,6 +64,7 @@ export class Perplexity extends OpenAI {
contextWindow:
PERPLEXITY_MODELS[this.model as PerplexityModelName]?.contextWindow,
tokenizer: Tokenizers.CL100K_BASE,
structuredOutput: false,
};
}
}
...
@@ -145,6 +145,7 @@ export class ReplicateLLM extends BaseLLM {
maxTokens: this.maxTokens,
contextWindow: ALL_AVAILABLE_REPLICATE_MODELS[this.model].contextWindow,
tokenizer: undefined,
structuredOutput: false,
};
}
...
@@ -41,6 +41,7 @@ export class VercelLLM extends ToolCallLLM<VercelAdditionalChatOptions> {
topP: 1,
contextWindow: 128000,
tokenizer: undefined,
structuredOutput: false,
};
}
...
@@ -1300,6 +1300,12 @@ importers:
remeda:
specifier: ^2.17.3
version: 2.20.1
zod:
specifier: ^3.24.2
version: 3.24.2
zod-to-json-schema:
specifier: ^3.23.3
version: 3.24.1(zod@3.24.2)
devDependencies:
bunchee:
specifier: 6.4.0
@@ -1316,6 +1322,9 @@ importers:
openai:
specifier: ^4.86.0
version: 4.86.0(ws@8.18.0(bufferutil@4.0.9))(zod@3.24.2)
zod:
specifier: ^3.24.2
version: 3.24.2
devDependencies:
bunchee:
specifier: 6.4.0
...