diff --git a/.changeset/calm-eggs-type.md b/.changeset/calm-eggs-type.md new file mode 100644 index 0000000000000000000000000000000000000000..4ac61d8da39347e49a0b10c616024f4e03da2dcb --- /dev/null +++ b/.changeset/calm-eggs-type.md @@ -0,0 +1,7 @@ +--- +"@llamaindex/readers": patch +"@llamaindex/core": patch +"@llamaindex/doc": patch +--- + +Expose more content to fix the issue with unavailable documentation links, and adjust the documentation based on the latest code. diff --git a/.changeset/mighty-eagles-wink.md b/.changeset/mighty-eagles-wink.md new file mode 100644 index 0000000000000000000000000000000000000000..affdbe49c74cc6ed5214a61fe09ae7334262daa8 --- /dev/null +++ b/.changeset/mighty-eagles-wink.md @@ -0,0 +1,13 @@ +--- +"@llamaindex/huggingface": minor +"@llamaindex/anthropic": minor +"@llamaindex/mistral": minor +"@llamaindex/google": minor +"@llamaindex/ollama": minor +"@llamaindex/openai": minor +"@llamaindex/core": minor +"@llamaindex/examples": minor +--- + +Added support for structured output in the chat api of openai and ollama +Added structured output parameter in the provider diff --git a/.changeset/smart-shirts-hear.md b/.changeset/smart-shirts-hear.md new file mode 100644 index 0000000000000000000000000000000000000000..d4e29b9e276fcdfba68b70c473245a1dee2141ac --- /dev/null +++ b/.changeset/smart-shirts-hear.md @@ -0,0 +1,8 @@ +--- +"@llamaindex/mistral": minor +"@llamaindex/examples": minor +--- + +Added support for function calling in mistral provider +Update model list for mistral provider +Added example for the tool call in mistral diff --git a/apps/next/scripts/validate-links.mts b/apps/next/scripts/validate-links.mts index cafaee980fad5621bece6f9a6cdecc3cc452010e..86a1393fe93aaa2e432c661091d618022c5c087c 100644 --- a/apps/next/scripts/validate-links.mts +++ b/apps/next/scripts/validate-links.mts @@ -162,7 +162,12 @@ async function validateLinks(): Promise<LinkValidationResult[]> { const invalidLinks = links.filter(({ link }) => { // Check if the link exists in valid routes // First normalize the link (remove any query string or hash) - const normalizedLink = link.split("#")[0].split("?")[0]; + const baseLink = link.split("?")[0].split("#")[0]; + // Remove the trailing slash if present. + // This works with links like "api/interfaces/MetadataFilter#operator" and "api/interfaces/MetadataFilter/#operator". + const normalizedLink = baseLink.endsWith("/") + ? 
baseLink.slice(0, -1) + : baseLink; // Remove llamaindex/ prefix if it exists as it's the root of the docs let routePath = normalizedLink; @@ -192,8 +197,7 @@ async function main() { try { // Check for invalid internal links - const validationResults: LinkValidationResult[] = []; - await validateLinks(); + const validationResults: LinkValidationResult[] = await validateLinks(); // Check for relative links const relativeLinksResults = await findRelativeLinks(); diff --git a/apps/next/src/content/docs/llamaindex/modules/data_loaders/index.mdx b/apps/next/src/content/docs/llamaindex/modules/data_loaders/index.mdx index 295540506e9055666f79001e47a870625d613f43..188ddba65b3b4fa35410c5a9e822f23bbcff058a 100644 --- a/apps/next/src/content/docs/llamaindex/modules/data_loaders/index.mdx +++ b/apps/next/src/content/docs/llamaindex/modules/data_loaders/index.mdx @@ -35,7 +35,7 @@ Currently, the following readers are mapped to specific file types: - [TextFileReader](/docs/api/classes/TextFileReader): `.txt` - [PDFReader](/docs/api/classes/PDFReader): `.pdf` -- [PapaCSVReader](/docs/api/classes/PapaCSVReader): `.csv` +- [CSVReader](/docs/api/classes/CSVReader): `.csv` - [MarkdownReader](/docs/api/classes/MarkdownReader): `.md` - [DocxReader](/docs/api/classes/DocxReader): `.docx` - [HTMLReader](/docs/api/classes/HTMLReader): `.htm`, `.html` diff --git a/apps/next/src/content/docs/llamaindex/modules/data_stores/chat_stores/index.mdx b/apps/next/src/content/docs/llamaindex/modules/data_stores/chat_stores/index.mdx index cadcc2ad4170a5b8efca631e072af6dcacb0951e..4fe5e8f3ebad4476d62a14ae4f916346313d92bc 100644 --- a/apps/next/src/content/docs/llamaindex/modules/data_stores/chat_stores/index.mdx +++ b/apps/next/src/content/docs/llamaindex/modules/data_stores/chat_stores/index.mdx @@ -12,5 +12,5 @@ Check the [LlamaIndexTS Github](https://github.com/run-llama/LlamaIndexTS) for t ## API Reference -- [BaseChatStore](/docs/api/interfaces/BaseChatStore) +- [BaseChatStore](/docs/api/classes/BaseChatStore) diff --git a/apps/next/src/content/docs/llamaindex/modules/evaluation/correctness.mdx b/apps/next/src/content/docs/llamaindex/modules/evaluation/correctness.mdx index 50cb3c856ad3b37a4858c67fe57b4e2ded5ad6c5..c1189dfddb56f14cd7ff685b9851c3e50f278071 100644 --- a/apps/next/src/content/docs/llamaindex/modules/evaluation/correctness.mdx +++ b/apps/next/src/content/docs/llamaindex/modules/evaluation/correctness.mdx @@ -74,4 +74,4 @@ the response is not correct with a score of 2.5 ## API Reference -- [CorrectnessEvaluator](/docs/api/classes/CorrectnessEvaluator) +- [CorrectnessEvaluator](/docs/api/classes/CorrectnessEvaluator) \ No newline at end of file diff --git a/apps/next/src/content/docs/llamaindex/modules/prompt/index.mdx b/apps/next/src/content/docs/llamaindex/modules/prompt/index.mdx index f26d0dd3cc7fcd390758efb83cdbf405ef92d2c6..d53ff387cf00e03262aa37d799a3fd01f7a94fec 100644 --- a/apps/next/src/content/docs/llamaindex/modules/prompt/index.mdx +++ b/apps/next/src/content/docs/llamaindex/modules/prompt/index.mdx @@ -28,14 +28,21 @@ Answer:`; ### 1. Customizing the default prompt on initialization -The first method is to create a new instance of `ResponseSynthesizer` (or the module you would like to update the prompt) and pass the custom prompt to the `responseBuilder` parameter. Then, pass the instance to the `asQueryEngine` method of the index. 
+The first method is to create a new instance of a Response Synthesizer (or the module you would like to update the prompt) by using the getResponseSynthesizer function. Instead of passing the custom prompt to the deprecated responseBuilder parameter, call getResponseSynthesizer with the mode as the first argument and supply the new prompt via the options parameter. ```ts -// Create an instance of response synthesizer +// Create an instance of Response Synthesizer + +// Deprecated usage: const responseSynthesizer = new ResponseSynthesizer({ responseBuilder: new CompactAndRefine(undefined, newTextQaPrompt), }); +// Current usage: +const responseSynthesizer = getResponseSynthesizer('compact', { + textQATemplate: newTextQaPrompt +}) + // Create index const index = await VectorStoreIndex.fromDocuments([document]); @@ -75,5 +82,5 @@ const response = await queryEngine.query({ ## API Reference -- [ResponseSynthesizer](/docs/api/classes/ResponseSynthesizer) +- [Response Synthesizer](/docs/llamaindex/modules/response_synthesizer) - [CompactAndRefine](/docs/api/classes/CompactAndRefine) diff --git a/apps/next/src/content/docs/llamaindex/modules/response_synthesizer.mdx b/apps/next/src/content/docs/llamaindex/modules/response_synthesizer.mdx index bda0d53bfdb70fb5f2886256a66c0f5551f93b2b..8e94f5bd7f253e1c0cd661198108bc242d45e509 100644 --- a/apps/next/src/content/docs/llamaindex/modules/response_synthesizer.mdx +++ b/apps/next/src/content/docs/llamaindex/modules/response_synthesizer.mdx @@ -1,5 +1,5 @@ --- -title: ResponseSynthesizer +title: Response Synthesizer --- The ResponseSynthesizer is responsible for sending the query, nodes, and prompt templates to the LLM to generate a response. There are a few key modes for generating a response: @@ -12,15 +12,17 @@ The ResponseSynthesizer is responsible for sending the query, nodes, and prompt multiple compact prompts. The same as `refine`, but should result in less LLM calls. - `TreeSummarize`: Given a set of text chunks and the query, recursively construct a tree and return the root node as the response. Good for summarization purposes. -- `SimpleResponseBuilder`: Given a set of text chunks and the query, apply the query to each text - chunk while accumulating the responses into an array. Returns a concatenated string of all - responses. Good for when you need to run the same query separately against each text - chunk. +- `MultiModal`: Combines textual inputs with additional modality-specific metadata to generate an integrated response. + It leverages a text QA template to build a prompt that incorporates various input types and produces either streaming or complete responses. + This approach is ideal for use cases where enriching the answer with multi-modal context (such as images, audio, or other data) + can enhance the output quality. 
```typescript -import { NodeWithScore, TextNode, ResponseSynthesizer } from "llamaindex"; +import { NodeWithScore, TextNode, getResponseSynthesizer, responseModeSchema } from "llamaindex"; -const responseSynthesizer = new ResponseSynthesizer(); +// you can also use responseModeSchema.Enum.refine, responseModeSchema.Enum.tree_summarize, responseModeSchema.Enum.multi_modal +// or you can use the CompactAndRefine, Refine, TreeSummarize, or MultiModal classes directly +const responseSynthesizer = getResponseSynthesizer(responseModeSchema.Enum.compact); const nodesWithScore: NodeWithScore[] = [ { @@ -55,8 +57,9 @@ for await (const chunk of stream) { ## API Reference -- [ResponseSynthesizer](/docs/api/classes/ResponseSynthesizer) +- [getResponseSynthesizer](/docs/api/functions/getResponseSynthesizer) +- [responseModeSchema](/docs/api/variables/responseModeSchema) - [Refine](/docs/api/classes/Refine) - [CompactAndRefine](/docs/api/classes/CompactAndRefine) - [TreeSummarize](/docs/api/classes/TreeSummarize) -- [SimpleResponseBuilder](/docs/api/classes/SimpleResponseBuilder) +- [MultiModal](/docs/api/classes/MultiModal) diff --git a/apps/next/typedoc.json b/apps/next/typedoc.json index a545e44cd683978a62f47ecb66ba674877e1b30f..5b79b46f71ca4dff3265e8ebe3c7ef23f8965795 100644 --- a/apps/next/typedoc.json +++ b/apps/next/typedoc.json @@ -1,8 +1,13 @@ { "plugin": ["typedoc-plugin-markdown", "typedoc-plugin-merge-modules"], - "entryPoints": ["../../packages/**/src/index.ts"], + "entryPoints": [ + "../../packages/{,**/}index.ts", + "../../packages/readers/src/*.ts", + "../../packages/cloud/src/{reader,utils}.ts" + ], "exclude": [ "../../packages/autotool/**/src/index.ts", + "../../packages/cloud/src/client/index.ts", "**/node_modules/**", "**/dist/**", "**/test/**", diff --git a/e2e/fixtures/llm/openai.ts b/e2e/fixtures/llm/openai.ts index ba35f31670a84c59c386eb0a12c0c69288da34e7..14601b8739ad078916dd121714764af5181dff02 100644 --- a/e2e/fixtures/llm/openai.ts +++ b/e2e/fixtures/llm/openai.ts @@ -42,6 +42,7 @@ export class OpenAI implements LLM { contextWindow: 2048, tokenizer: undefined, isFunctionCallingModel: true, + structuredOutput: false, }; } diff --git a/examples/jsonExtract.ts b/examples/jsonExtract.ts index 4622177e7d19b9d7a06d1f8c4977c34b427a237f..6d4a5f476a11726b73b7374b09c52d3bfcee0f5f 100644 --- a/examples/jsonExtract.ts +++ b/examples/jsonExtract.ts @@ -1,4 +1,5 @@ import { OpenAI } from "@llamaindex/openai"; +import { z } from "zod"; // Example using OpenAI's chat API to extract JSON from a sales call transcript // using json_mode see https://platform.openai.com/docs/guides/text-generation/json-mode for more details @@ -6,22 +7,47 @@ import { OpenAI } from "@llamaindex/openai"; const transcript = "[Phone rings]\n\nJohn: Hello, this is John.\n\nSarah: Hi John, this is Sarah from XYZ Company. I'm calling to discuss our new product, the XYZ Widget, and see if it might be a good fit for your business.\n\nJohn: Hi Sarah, thanks for reaching out. I'm definitely interested in learning more about the XYZ Widget. Can you give me a quick overview of what it does?\n\nSarah: Of course! The XYZ Widget is a cutting-edge tool that helps businesses streamline their workflow and improve productivity. It's designed to automate repetitive tasks and provide real-time data analytics to help you make informed decisions.\n\nJohn: That sounds really interesting. I can see how that could benefit our team. 
Do you have any case studies or success stories from other companies who have used the XYZ Widget?\n\nSarah: Absolutely, we have several case studies that I can share with you. I'll send those over along with some additional information about the product. I'd also love to schedule a demo for you and your team to see the XYZ Widget in action.\n\nJohn: That would be great. I'll make sure to review the case studies and then we can set up a time for the demo. In the meantime, are there any specific action items or next steps we should take?\n\nSarah: Yes, I'll send over the information and then follow up with you to schedule the demo. In the meantime, feel free to reach out if you have any questions or need further information.\n\nJohn: Sounds good, I appreciate your help Sarah. I'm looking forward to learning more about the XYZ Widget and seeing how it can benefit our business.\n\nSarah: Thank you, John. I'll be in touch soon. Have a great day!\n\nJohn: You too, bye."; +const exampleSchema = z.object({ + summary: z.string(), + products: z.array(z.string()), + rep_name: z.string(), + prospect_name: z.string(), + action_items: z.array(z.string()), +}); + +const example = { + summary: + "High-level summary of the call transcript. Should not exceed 3 sentences.", + products: ["product 1", "product 2"], + rep_name: "Name of the sales rep", + prospect_name: "Name of the prospect", + action_items: ["action item 1", "action item 2"], +}; + async function main() { const llm = new OpenAI({ - model: "gpt-4-1106-preview", - additionalChatOptions: { response_format: { type: "json_object" } }, + model: "gpt-4o", }); - const example = { - summary: - "High-level summary of the call transcript. Should not exceed 3 sentences.", - products: ["product 1", "product 2"], - rep_name: "Name of the sales rep", - prospect_name: "Name of the prospect", - action_items: ["action item 1", "action item 2"], - }; - + //response format as zod schema const response = await llm.chat({ + messages: [ + { + role: "system", + content: `You are an expert assistant for summarizing and extracting insights from sales call transcripts.`, + }, + { + role: "user", + content: `Here is the transcript: \n------\n${transcript}\n------`, + }, + ], + responseFormat: exampleSchema, + }); + + console.log(response.message.content); + + //response format as json_object + const response2 = await llm.chat({ messages: [ { role: "system", @@ -34,9 +60,10 @@ async function main() { content: `Here is the transcript: \n------\n${transcript}\n------`, }, ], + responseFormat: { type: "json_object" }, }); - console.log(response.message.content); + console.log(response2.message.content); } main().catch(console.error); diff --git a/examples/mistral/agent.ts b/examples/mistral/agent.ts new file mode 100644 index 0000000000000000000000000000000000000000..a8e4ee649232225967fced9342a104b0db2aebc9 --- /dev/null +++ b/examples/mistral/agent.ts @@ -0,0 +1,31 @@ +import { mistral } from "@llamaindex/mistral"; +import { agent, tool } from "llamaindex"; +import { z } from "zod"; +import { WikipediaTool } from "../wiki"; + +const workflow = agent({ + tools: [ + tool({ + name: "weather", + description: "Get the weather", + parameters: z.object({ + location: z.string().describe("The location to get the weather for"), + }), + execute: ({ location }) => `The weather in ${location} is sunny`, + }), + new WikipediaTool(), + ], + llm: mistral({ + apiKey: process.env.MISTRAL_API_KEY, + model: "mistral-small-latest", + }), +}); + +async function main() { + const result = await 
workflow.run( + "What is the weather in New York? What's the history of New York from Wikipedia in 3 sentences?", + ); + console.log(result.data); +} + +void main(); diff --git a/examples/mistral.ts b/examples/mistral/mistral.ts similarity index 100% rename from examples/mistral.ts rename to examples/mistral/mistral.ts diff --git a/packages/community/src/llm/bedrock/index.ts b/packages/community/src/llm/bedrock/index.ts index 5d091241fa1432594d3286d2a4a07c00a9fb27ef..1e491247f797093a64cba8e4670de4a949a5a068 100644 --- a/packages/community/src/llm/bedrock/index.ts +++ b/packages/community/src/llm/bedrock/index.ts @@ -381,6 +381,7 @@ export class Bedrock extends ToolCallLLM<BedrockAdditionalChatOptions> { maxTokens: this.maxTokens, contextWindow: BEDROCK_FOUNDATION_LLMS[this.model] ?? 128000, tokenizer: undefined, + structuredOutput: false, }; } diff --git a/packages/core/src/llms/base.ts b/packages/core/src/llms/base.ts index 46306bfecec1b5e39115bc929b27e56fdf89e4e3..c1456c1cbc699b4af17076ff11e879a8ddf0de15 100644 --- a/packages/core/src/llms/base.ts +++ b/packages/core/src/llms/base.ts @@ -28,11 +28,12 @@ export abstract class BaseLLM< async complete( params: LLMCompletionParamsStreaming | LLMCompletionParamsNonStreaming, ): Promise<CompletionResponse | AsyncIterable<CompletionResponse>> { - const { prompt, stream } = params; + const { prompt, stream, responseFormat } = params; if (stream) { const stream = await this.chat({ messages: [{ content: prompt, role: "user" }], stream: true, + ...(responseFormat ? { responseFormat } : {}), }); return streamConverter(stream, (chunk) => { return { @@ -41,9 +42,12 @@ export abstract class BaseLLM< }; }); } + const chatResponse = await this.chat({ messages: [{ content: prompt, role: "user" }], + ...(responseFormat ? 
{ responseFormat } : {}), }); + return { text: extractText(chatResponse.message.content), raw: chatResponse.raw, diff --git a/packages/core/src/llms/type.ts b/packages/core/src/llms/type.ts index 787c7ffa8fdcb1573a5d687607d63512b11cc79f..5f85dc86e6063c28d6833b8886da01a8c838ce85 100644 --- a/packages/core/src/llms/type.ts +++ b/packages/core/src/llms/type.ts @@ -1,5 +1,6 @@ import type { Tokenizers } from "@llamaindex/env/tokenizers"; import type { JSONSchemaType } from "ajv"; +import { z } from "zod"; import type { JSONObject, JSONValue } from "../global"; /** @@ -106,6 +107,7 @@ export type LLMMetadata = { maxTokens?: number | undefined; contextWindow: number; tokenizer: Tokenizers | undefined; + structuredOutput: boolean; }; export interface LLMChatParamsBase< @@ -115,6 +117,7 @@ export interface LLMChatParamsBase< messages: ChatMessage<AdditionalMessageOptions>[]; additionalChatOptions?: AdditionalChatOptions; tools?: BaseTool[]; + responseFormat?: z.ZodType | object; } export interface LLMChatParamsStreaming< @@ -133,6 +136,7 @@ export interface LLMChatParamsNonStreaming< export interface LLMCompletionParamsBase { prompt: MessageContent; + responseFormat?: z.ZodType | object; } export interface LLMCompletionParamsStreaming extends LLMCompletionParamsBase { diff --git a/packages/core/src/response-synthesizers/factory.ts b/packages/core/src/response-synthesizers/factory.ts index 0e6ae17537ed3c794c6ac368e59a2159b0f8fabd..7f9af5f4ea40ead5430b09d8af02adc6b8a8c1ff 100644 --- a/packages/core/src/response-synthesizers/factory.ts +++ b/packages/core/src/response-synthesizers/factory.ts @@ -23,7 +23,7 @@ import { } from "./base-synthesizer"; import { createMessageContent } from "./utils"; -const responseModeSchema = z.enum([ +export const responseModeSchema = z.enum([ "refine", "compact", "tree_summarize", @@ -35,7 +35,7 @@ export type ResponseMode = z.infer<typeof responseModeSchema>; /** * A response builder that uses the query to ask the LLM generate a better response using multiple text chunks. */ -class Refine extends BaseSynthesizer { +export class Refine extends BaseSynthesizer { textQATemplate: TextQAPrompt; refineTemplate: RefinePrompt; @@ -213,7 +213,7 @@ class Refine extends BaseSynthesizer { /** * CompactAndRefine is a slight variation of Refine that first compacts the text chunks into the smallest possible number of chunks. */ -class CompactAndRefine extends Refine { +export class CompactAndRefine extends Refine { async getResponse( query: MessageContent, nodes: NodeWithScore[], @@ -267,7 +267,7 @@ class CompactAndRefine extends Refine { /** * TreeSummarize repacks the text chunks into the smallest possible number of chunks and then summarizes them, then recursively does so until there's one chunk left. 
*/ -class TreeSummarize extends BaseSynthesizer { +export class TreeSummarize extends BaseSynthesizer { summaryTemplate: TreeSummarizePrompt; constructor( @@ -370,7 +370,7 @@ class TreeSummarize extends BaseSynthesizer { } } -class MultiModal extends BaseSynthesizer { +export class MultiModal extends BaseSynthesizer { metadataMode: MetadataMode; textQATemplate: TextQAPrompt; diff --git a/packages/core/src/response-synthesizers/index.ts b/packages/core/src/response-synthesizers/index.ts index a782d514f18ba7b7fd0f6410a499345153c3bbd5..907958b237786f7f19c295aaf8898599b55a691c 100644 --- a/packages/core/src/response-synthesizers/index.ts +++ b/packages/core/src/response-synthesizers/index.ts @@ -2,7 +2,15 @@ export { BaseSynthesizer, type BaseSynthesizerOptions, } from "./base-synthesizer"; -export { getResponseSynthesizer, type ResponseMode } from "./factory"; +export { + CompactAndRefine, + MultiModal, + Refine, + TreeSummarize, + getResponseSynthesizer, + responseModeSchema, + type ResponseMode, +} from "./factory"; export type { SynthesizeEndEvent, SynthesizeQuery, diff --git a/packages/core/src/utils/mock.ts b/packages/core/src/utils/mock.ts index 2a29e775ac27fc5d6b776549c8d64621d3caabbf..bd9a14c7f973c7d7c9fcfaac1663845dad9aaf20 100644 --- a/packages/core/src/utils/mock.ts +++ b/packages/core/src/utils/mock.ts @@ -35,6 +35,7 @@ export class MockLLM extends ToolCallLLM { topP: 0.5, contextWindow: 1024, tokenizer: undefined, + structuredOutput: false, }; } diff --git a/packages/providers/anthropic/src/llm.ts b/packages/providers/anthropic/src/llm.ts index 4637d289a5c207e1ff81e63348a2bacb8a8d3966..09e032a82ebcca8e1110a41f30555cf57e59d117 100644 --- a/packages/providers/anthropic/src/llm.ts +++ b/packages/providers/anthropic/src/llm.ts @@ -191,6 +191,7 @@ export class Anthropic extends ToolCallLLM< ].contextWindow : 200000, tokenizer: undefined, + structuredOutput: false, }; } diff --git a/packages/providers/google/src/base.ts b/packages/providers/google/src/base.ts index 877bf37a4897b9005387f08bf47776359d01af95..ff67fb0251bb1b71d739d9a3851d65e19cb649d4 100644 --- a/packages/providers/google/src/base.ts +++ b/packages/providers/google/src/base.ts @@ -241,6 +241,7 @@ export class Gemini extends ToolCallLLM<GeminiAdditionalChatOptions> { maxTokens: this.maxTokens, contextWindow: GEMINI_MODEL_INFO_MAP[this.model].contextWindow, tokenizer: undefined, + structuredOutput: false, }; } diff --git a/packages/providers/huggingface/src/llm.ts b/packages/providers/huggingface/src/llm.ts index 83befb8e485555ca3a5e6da6ec46f344ed0d022a..dbd089432876b282efadd7d556bb6019e0e451ef 100644 --- a/packages/providers/huggingface/src/llm.ts +++ b/packages/providers/huggingface/src/llm.ts @@ -57,6 +57,7 @@ export class HuggingFaceLLM extends BaseLLM { maxTokens: this.maxTokens, contextWindow: this.contextWindow, tokenizer: undefined, + structuredOutput: false, }; } diff --git a/packages/providers/huggingface/src/shared.ts b/packages/providers/huggingface/src/shared.ts index 7225a9f77ca1326253e11836b2307f6946960546..e383ff4d3fb8b8b105604ec5cb0777dca93d64e9 100644 --- a/packages/providers/huggingface/src/shared.ts +++ b/packages/providers/huggingface/src/shared.ts @@ -123,6 +123,7 @@ export class HuggingFaceInferenceAPI extends BaseLLM { maxTokens: this.maxTokens, contextWindow: this.contextWindow, tokenizer: undefined, + structuredOutput: false, }; } diff --git a/packages/providers/mistral/package.json b/packages/providers/mistral/package.json index 
a005343eaa10e8d5093d9699598453cb2a46c062..9f53db7a9cc140148f5ffb1f622fb038494291b8 100644 --- a/packages/providers/mistral/package.json +++ b/packages/providers/mistral/package.json @@ -27,10 +27,12 @@ }, "scripts": { "build": "bunchee", - "dev": "bunchee --watch" + "dev": "bunchee --watch", + "test": "vitest run" }, "devDependencies": { - "bunchee": "6.4.0" + "bunchee": "6.4.0", + "vitest": "^2.1.5" }, "dependencies": { "@llamaindex/core": "workspace:*", diff --git a/packages/providers/mistral/src/llm.ts b/packages/providers/mistral/src/llm.ts index 8f3792443efcc244b6841a2d9f431595f40ced86..d5f594b04d2fac777c0c49f3886e60169ba9a4c6 100644 --- a/packages/providers/mistral/src/llm.ts +++ b/packages/providers/mistral/src/llm.ts @@ -1,21 +1,51 @@ +import { wrapEventCaller } from "@llamaindex/core/decorator"; import { - BaseLLM, + ToolCallLLM, + type BaseTool, type ChatMessage, type ChatResponse, type ChatResponseChunk, type LLMChatParamsNonStreaming, type LLMChatParamsStreaming, + type PartialToolCall, + type ToolCallLLMMessageOptions, } from "@llamaindex/core/llms"; +import { extractText } from "@llamaindex/core/utils"; import { getEnv } from "@llamaindex/env"; import { type Mistral } from "@mistralai/mistralai"; -import type { ContentChunk } from "@mistralai/mistralai/models/components"; +import type { + AssistantMessage, + ChatCompletionRequest, + ChatCompletionStreamRequest, + ContentChunk, + Tool, + ToolMessage, +} from "@mistralai/mistralai/models/components"; export const ALL_AVAILABLE_MISTRAL_MODELS = { "mistral-tiny": { contextWindow: 32000 }, "mistral-small": { contextWindow: 32000 }, "mistral-medium": { contextWindow: 32000 }, + "mistral-small-latest": { contextWindow: 32000 }, + "mistral-large-latest": { contextWindow: 131000 }, + "codestral-latest": { contextWindow: 256000 }, + "pixtral-large-latest": { contextWindow: 131000 }, + "mistral-saba-latest": { contextWindow: 32000 }, + "ministral-3b-latest": { contextWindow: 131000 }, + "ministral-8b-latest": { contextWindow: 131000 }, + "mistral-embed": { contextWindow: 8000 }, + "mistral-moderation-latest": { contextWindow: 8000 }, }; +export const TOOL_CALL_MISTRAL_MODELS = [ + "mistral-small-latest", + "mistral-large-latest", + "codestral-latest", + "pixtral-large-latest", + "ministral-8b-latest", + "ministral-3b-latest", +]; + export class MistralAISession { apiKey: string; // eslint-disable-next-line @typescript-eslint/no-explicit-any @@ -46,7 +76,7 @@ export class MistralAISession { /** * MistralAI LLM implementation */ -export class MistralAI extends BaseLLM { +export class MistralAI extends ToolCallLLM<ToolCallLLMMessageOptions> { // Per completion MistralAI params model: keyof typeof ALL_AVAILABLE_MISTRAL_MODELS; temperature: number; @@ -60,7 +90,7 @@ export class MistralAI extends BaseLLM { constructor(init?: Partial<MistralAI>) { super(); - this.model = init?.model ?? "mistral-small"; + this.model = init?.model ?? "mistral-small-latest"; this.temperature = init?.temperature ?? 0.1; this.topP = init?.topP ?? 1; this.maxTokens = init?.maxTokens ?? 
undefined; @@ -77,11 +107,55 @@ export class MistralAI extends BaseLLM { maxTokens: this.maxTokens, contextWindow: ALL_AVAILABLE_MISTRAL_MODELS[this.model].contextWindow, tokenizer: undefined, + structuredOutput: false, }; } - // eslint-disable-next-line @typescript-eslint/no-explicit-any - private buildParams(messages: ChatMessage[]): any { + get supportToolCall() { + return TOOL_CALL_MISTRAL_MODELS.includes(this.metadata.model); + } + + formatMessages(messages: ChatMessage<ToolCallLLMMessageOptions>[]) { + return messages.map((message) => { + const options = message.options ?? {}; + //tool call message + if ("toolCall" in options) { + return { + role: "assistant", + content: extractText(message.content), + toolCalls: options.toolCall.map((toolCall) => { + return { + id: toolCall.id, + type: "function", + function: { + name: toolCall.name, + arguments: toolCall.input, + }, + }; + }), + } satisfies AssistantMessage; + } + + //tool result message + if ("toolResult" in options) { + return { + role: "tool", + content: extractText(message.content), + toolCallId: options.toolResult.id, + } satisfies ToolMessage; + } + + return { + role: message.role, + content: extractText(message.content), + }; + }); + } + + private buildParams( + messages: ChatMessage<ToolCallLLMMessageOptions>[], + tools?: BaseTool[], + ) { return { model: this.model, temperature: this.temperature, @@ -89,25 +163,49 @@ export class MistralAI extends BaseLLM { topP: this.topP, safeMode: this.safeMode, randomSeed: this.randomSeed, - messages, + messages: this.formatMessages(messages), + tools: tools?.map(MistralAI.toTool), + }; + } + + static toTool(tool: BaseTool): Tool { + if (!tool.metadata.parameters) { + throw new Error("Tool parameters are required"); + } + + return { + type: "function", + function: { + name: tool.metadata.name, + description: tool.metadata.description, + parameters: tool.metadata.parameters, + }, }; } chat( params: LLMChatParamsStreaming, ): Promise<AsyncIterable<ChatResponseChunk>>; - chat(params: LLMChatParamsNonStreaming): Promise<ChatResponse>; + chat( + params: LLMChatParamsNonStreaming<ToolCallLLMMessageOptions>, + ): Promise<ChatResponse>; async chat( params: LLMChatParamsNonStreaming | LLMChatParamsStreaming, - ): Promise<ChatResponse | AsyncIterable<ChatResponseChunk>> { - const { messages, stream } = params; + ): Promise< + | ChatResponse<ToolCallLLMMessageOptions> + | AsyncIterable<ChatResponseChunk<ToolCallLLMMessageOptions>> + > { + const { messages, stream, tools } = params; // Streaming if (stream) { - return this.streamChat(params); + return this.streamChat(messages, tools); } // Non-streaming const client = await this.session.getClient(); - const response = await client.chat.complete(this.buildParams(messages)); + const buildParams = this.buildParams(messages, tools); + const response = await client.chat.complete( + buildParams as ChatCompletionRequest, + ); if (!response || !response.choices || !response.choices[0]) { throw new Error("Unexpected response format from Mistral API"); @@ -121,28 +219,100 @@ export class MistralAI extends BaseLLM { message: { role: "assistant", content: this.extractContentAsString(content), + options: response.choices[0]!.message?.toolCalls + ? 
{ + toolCall: response.choices[0]!.message.toolCalls.map( + (toolCall) => ({ + id: toolCall.id, + name: toolCall.function.name, + input: this.extractArgumentsAsString( + toolCall.function.arguments, + ), + }), + ), + } + : {}, }, }; } - protected async *streamChat({ - messages, - }: LLMChatParamsStreaming): AsyncIterable<ChatResponseChunk> { + @wrapEventCaller + protected async *streamChat( + messages: ChatMessage[], + tools?: BaseTool[], + ): AsyncIterable<ChatResponseChunk<ToolCallLLMMessageOptions>> { const client = await this.session.getClient(); - const chunkStream = await client.chat.stream(this.buildParams(messages)); + const buildParams = this.buildParams( + messages, + tools, + ) as ChatCompletionStreamRequest; + const chunkStream = await client.chat.stream(buildParams); + + let currentToolCall: PartialToolCall | null = null; + const toolCallMap = new Map<string, PartialToolCall>(); for await (const chunk of chunkStream) { - if (!chunk.data || !chunk.data.choices || !chunk.data.choices.length) - continue; + if (!chunk.data?.choices?.[0]?.delta) continue; const choice = chunk.data.choices[0]; - if (!choice) continue; + if (!(choice.delta.content || choice.delta.toolCalls)) continue; + + let shouldEmitToolCall: PartialToolCall | null = null; + + if (choice.delta.toolCalls?.[0]) { + const toolCall = choice.delta.toolCalls[0]; + + if (toolCall.id) { + if (currentToolCall && toolCall.id !== currentToolCall.id) { + shouldEmitToolCall = { + ...currentToolCall, + input: JSON.parse(currentToolCall.input), + }; + } + + currentToolCall = { + id: toolCall.id, + name: toolCall.function!.name!, + input: this.extractArgumentsAsString(toolCall.function!.arguments), + }; + + toolCallMap.set(toolCall.id, currentToolCall!); + } else if (currentToolCall && toolCall.function?.arguments) { + currentToolCall.input += this.extractArgumentsAsString( + toolCall.function.arguments, + ); + } + } + + const isDone: boolean = choice.finishReason !== null; + + if (isDone && currentToolCall) { + //emitting last tool call + shouldEmitToolCall = { + ...currentToolCall, + input: JSON.parse(currentToolCall.input), + }; + } yield { raw: chunk.data, delta: this.extractContentAsString(choice.delta.content), + options: shouldEmitToolCall + ? { toolCall: [shouldEmitToolCall] } + : currentToolCall + ? { toolCall: [currentToolCall] } + : {}, }; } + + toolCallMap.clear(); + } + + private extractArgumentsAsString( + // eslint-disable-next-line @typescript-eslint/no-explicit-any + args: string | { [k: string]: any } | null | undefined, + ): string { + return typeof args === "string" ? 
args : JSON.stringify(args) || ""; } private extractContentAsString( diff --git a/packages/providers/mistral/tests/index.test.ts b/packages/providers/mistral/tests/index.test.ts new file mode 100644 index 0000000000000000000000000000000000000000..43fdb9ea0b3e402155db5761623c43f215a89d45 --- /dev/null +++ b/packages/providers/mistral/tests/index.test.ts @@ -0,0 +1,116 @@ +import type { ChatMessage } from "@llamaindex/core/llms"; +import { setEnvs } from "@llamaindex/env"; +import { beforeAll, describe, expect, test } from "vitest"; +import { MistralAI } from "../src/index"; + +beforeAll(() => { + setEnvs({ + MISTRAL_API_KEY: "valid", + }); +}); + +describe("Message Formatting", () => { + describe("Basic Message Formatting", () => { + test("Mistral formats basic messages correctly", () => { + const mistral = new MistralAI(); + const inputMessages: ChatMessage[] = [ + { + content: "You are a helpful assistant.", + role: "assistant", + }, + { + content: "Hello?", + role: "user", + }, + ]; + const expectedOutput = [ + { + content: "You are a helpful assistant.", + role: "assistant", + }, + { + content: "Hello?", + role: "user", + }, + ]; + + expect(mistral.formatMessages(inputMessages)).toEqual(expectedOutput); + }); + + test("Mistral handles multi-turn conversation correctly", () => { + const mistral = new MistralAI(); + const inputMessages: ChatMessage[] = [ + { content: "Hi", role: "user" }, + { content: "Hello! How can I help?", role: "assistant" }, + { content: "What's the weather?", role: "user" }, + ]; + const expectedOutput = [ + { content: "Hi", role: "user" }, + { content: "Hello! How can I help?", role: "assistant" }, + { content: "What's the weather?", role: "user" }, + ]; + expect(mistral.formatMessages(inputMessages)).toEqual(expectedOutput); + }); + }); + + describe("Tool Message Formatting", () => { + const toolCallMessages: ChatMessage[] = [ + { + role: "user", + content: "What's the weather in London?", + }, + { + role: "assistant", + content: "Let me check the weather.", + options: { + toolCall: [ + { + id: "call_123", + name: "weather", + input: JSON.stringify({ location: "London" }), + }, + ], + }, + }, + { + role: "assistant", + content: "The weather in London is sunny, +20°C", + options: { + toolResult: { + id: "call_123", + }, + }, + }, + ]; + + test("Mistral formats tool calls correctly", () => { + const mistral = new MistralAI(); + const expectedOutput = [ + { + role: "user", + content: "What's the weather in London?", + }, + { + role: "assistant", + content: "Let me check the weather.", + toolCalls: [ + { + type: "function", + id: "call_123", + function: { + name: "weather", + arguments: '{"location":"London"}', + }, + }, + ], + }, + { + role: "tool", + content: "The weather in London is sunny, +20°C", + toolCallId: "call_123", + }, + ]; + expect(mistral.formatMessages(toolCallMessages)).toEqual(expectedOutput); + }); + }); +}); diff --git a/packages/providers/ollama/package.json b/packages/providers/ollama/package.json index 2131665d194c6ae24da4babf62ed70ba91418241..e6444258100ef24d8151bc0de46938c65ca2b287 100644 --- a/packages/providers/ollama/package.json +++ b/packages/providers/ollama/package.json @@ -37,5 +37,17 @@ "@llamaindex/env": "workspace:*", "ollama": "^0.5.10", "remeda": "^2.17.3" + }, + "peerDependencies": { + "zod": "^3.24.2", + "zod-to-json-schema": "^3.23.3" + }, + "peerDependenciesMeta": { + "zod": { + "optional": true + }, + "zod-to-json-schema": { + "optional": true + } } } diff --git a/packages/providers/ollama/src/llm.ts 
b/packages/providers/ollama/src/llm.ts index 15bbf31cf9aaf3135c82e61a7cd4f3ca1b692749..36cbcd995ec5e8b90c47f83ad7a668ef22a417bd 100644 --- a/packages/providers/ollama/src/llm.ts +++ b/packages/providers/ollama/src/llm.ts @@ -57,6 +57,22 @@ export type OllamaParams = { options?: Partial<Options>; }; +async function getZod() { + try { + return await import("zod"); + } catch (e) { + throw new Error("zod is required for structured output"); + } +} + +async function getZodToJsonSchema() { + try { + return await import("zod-to-json-schema"); + } catch (e) { + throw new Error("zod-to-json-schema is required for structured output"); + } +} + export class Ollama extends ToolCallLLM { supportToolCall: boolean = true; public readonly ollama: OllamaBase; @@ -92,6 +108,7 @@ export class Ollama extends ToolCallLLM { maxTokens: this.options.num_ctx, contextWindow: num_ctx, tokenizer: undefined, + structuredOutput: true, }; } @@ -109,7 +126,7 @@ export class Ollama extends ToolCallLLM { ): Promise< ChatResponse<ToolCallLLMMessageOptions> | AsyncIterable<ChatResponseChunk> > { - const { messages, stream, tools } = params; + const { messages, stream, tools, responseFormat } = params; const payload: ChatRequest = { model: this.model, messages: messages.map((message) => { @@ -130,9 +147,20 @@ export class Ollama extends ToolCallLLM { ...this.options, }, }; + if (tools) { payload.tools = tools.map((tool) => Ollama.toTool(tool)); } + + if (responseFormat && this.metadata.structuredOutput) { + const [{ zodToJsonSchema }, { z }] = await Promise.all([ + getZodToJsonSchema(), + getZod(), + ]); + if (responseFormat instanceof z.ZodType) + payload.format = zodToJsonSchema(responseFormat); + } + if (!stream) { const chatResponse = await this.ollama.chat({ ...payload, diff --git a/packages/providers/openai/package.json b/packages/providers/openai/package.json index b824ffba2c0bb2fc729fd114816a23b85aed07b6..a60e61744580adde0c261f457e1b4b73c60a84b5 100644 --- a/packages/providers/openai/package.json +++ b/packages/providers/openai/package.json @@ -35,6 +35,7 @@ "dependencies": { "@llamaindex/core": "workspace:*", "@llamaindex/env": "workspace:*", - "openai": "^4.86.0" + "openai": "^4.86.0", + "zod": "^3.24.2" } } diff --git a/packages/providers/openai/src/llm.ts b/packages/providers/openai/src/llm.ts index d24ad9a34a964795eddae57477e835309e2f6a7b..675737f2be6802ec2cf0f44e2997809820b9f545 100644 --- a/packages/providers/openai/src/llm.ts +++ b/packages/providers/openai/src/llm.ts @@ -22,6 +22,7 @@ import type { ClientOptions as OpenAIClientOptions, OpenAI as OpenAILLM, } from "openai"; +import { zodResponseFormat } from "openai/helpers/zod"; import type { ChatModel } from "openai/resources/chat/chat"; import type { ChatCompletionAssistantMessageParam, @@ -32,7 +33,12 @@ import type { ChatCompletionToolMessageParam, ChatCompletionUserMessageParam, } from "openai/resources/chat/completions"; -import type { ChatCompletionMessageParam } from "openai/resources/index.js"; +import type { + ChatCompletionMessageParam, + ResponseFormatJSONObject, + ResponseFormatJSONSchema, +} from "openai/resources/index.js"; +import { z } from "zod"; import { AzureOpenAIWithUserAgent, getAzureConfigFromEnv, @@ -292,6 +298,7 @@ export class OpenAI extends ToolCallLLM<OpenAIAdditionalChatOptions> { maxTokens: this.maxTokens, contextWindow, tokenizer: Tokenizers.CL100K_BASE, + structuredOutput: true, }; } @@ -385,7 +392,8 @@ export class OpenAI extends ToolCallLLM<OpenAIAdditionalChatOptions> { | ChatResponse<ToolCallLLMMessageOptions> | 
AsyncIterable<ChatResponseChunk<ToolCallLLMMessageOptions>> > { - const { messages, stream, tools, additionalChatOptions } = params; + const { messages, stream, tools, responseFormat, additionalChatOptions } = + params; const baseRequestParams = <OpenAILLM.Chat.ChatCompletionCreateParams>{ model: this.model, temperature: this.temperature, @@ -408,6 +416,20 @@ export class OpenAI extends ToolCallLLM<OpenAIAdditionalChatOptions> { if (!isTemperatureSupported(baseRequestParams.model)) delete baseRequestParams.temperature; + //add response format for the structured output + if (responseFormat && this.metadata.structuredOutput) { + if (responseFormat instanceof z.ZodType) + baseRequestParams.response_format = zodResponseFormat( + responseFormat, + "response_format", + ); + else { + baseRequestParams.response_format = responseFormat as + | ResponseFormatJSONObject + | ResponseFormatJSONSchema; + } + } + // Streaming if (stream) { return this.streamChat(baseRequestParams); diff --git a/packages/providers/perplexity/src/llm.ts b/packages/providers/perplexity/src/llm.ts index bf3fc7e01b8e2ded5810cded54629e84b8719ea5..b9b342f4c8cd7d38e84374462a8ae5de1350fc30 100644 --- a/packages/providers/perplexity/src/llm.ts +++ b/packages/providers/perplexity/src/llm.ts @@ -64,6 +64,7 @@ export class Perplexity extends OpenAI { contextWindow: PERPLEXITY_MODELS[this.model as PerplexityModelName]?.contextWindow, tokenizer: Tokenizers.CL100K_BASE, + structuredOutput: false, }; } } diff --git a/packages/providers/replicate/src/llm.ts b/packages/providers/replicate/src/llm.ts index f757411799df129bbb8110ae63acff6a45bb93b6..f76f1671842703d4f2ecf0b8bdea5250e71c3c35 100644 --- a/packages/providers/replicate/src/llm.ts +++ b/packages/providers/replicate/src/llm.ts @@ -145,6 +145,7 @@ export class ReplicateLLM extends BaseLLM { maxTokens: this.maxTokens, contextWindow: ALL_AVAILABLE_REPLICATE_MODELS[this.model].contextWindow, tokenizer: undefined, + structuredOutput: false, }; } diff --git a/packages/providers/vercel/src/llm.ts b/packages/providers/vercel/src/llm.ts index 453b2fc1691c70de2f9a3f7b064af205029670fb..8134a16cf39f39a12e70b774670e784456ea6ba1 100644 --- a/packages/providers/vercel/src/llm.ts +++ b/packages/providers/vercel/src/llm.ts @@ -41,6 +41,7 @@ export class VercelLLM extends ToolCallLLM<VercelAdditionalChatOptions> { topP: 1, contextWindow: 128000, tokenizer: undefined, + structuredOutput: false, }; } diff --git a/packages/readers/package.json b/packages/readers/package.json index 9a7d6338bca7d67b47ce6e85caaecc53adddb922..80953af9c912346478d0ab392d68e88a7f41461d 100644 --- a/packages/readers/package.json +++ b/packages/readers/package.json @@ -230,7 +230,6 @@ "mammoth": "^1.7.2", "mongodb": "^6.7.0", "notion-md-crawler": "^1.0.0", - "papaparse": "^5.4.1", "unpdf": "^0.12.1" } } diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml index 66f99453de50e1e1570a9f0756853ca66e395933..5c0e6f580f68cc6da66bb0142bcbdfacf936156b 100644 --- a/pnpm-lock.yaml +++ b/pnpm-lock.yaml @@ -1269,6 +1269,9 @@ importers: bunchee: specifier: 6.4.0 version: 6.4.0(typescript@5.7.3) + vitest: + specifier: ^2.1.5 + version: 2.1.5(@edge-runtime/vm@4.0.4)(@types/node@22.13.5)(happy-dom@15.11.7)(lightningcss@1.29.1)(msw@2.7.0(@types/node@22.13.5)(typescript@5.7.3))(terser@5.38.2) packages/providers/mixedbread: dependencies: @@ -1300,6 +1303,12 @@ importers: remeda: specifier: ^2.17.3 version: 2.20.1 + zod: + specifier: ^3.24.2 + version: 3.24.2 + zod-to-json-schema: + specifier: ^3.23.3 + version: 3.24.1(zod@3.24.2) devDependencies: bunchee: 
specifier: 6.4.0 @@ -1316,6 +1325,9 @@ importers: openai: specifier: ^4.86.0 version: 4.86.0(ws@8.18.0(bufferutil@4.0.9))(zod@3.24.2) + zod: + specifier: ^3.24.2 + version: 3.24.2 devDependencies: bunchee: specifier: 6.4.0 @@ -1686,9 +1698,6 @@ importers: notion-md-crawler: specifier: ^1.0.0 version: 1.0.1 - papaparse: - specifier: ^5.4.1 - version: 5.5.2 unpdf: specifier: ^0.12.1 version: 0.12.1 @@ -9838,9 +9847,6 @@ packages: pako@1.0.11: resolution: {integrity: sha512-4hLB8Py4zZce5s4yd9XzopqwVv/yGNhV1Bl8NTmCq1763HeK2+EwVTv+leGeL13Dnh2wfbqowVPXCIO0z4taYw==} - papaparse@5.5.2: - resolution: {integrity: sha512-PZXg8UuAc4PcVwLosEEDYjPyfWnTEhOrUfdv+3Bx+NuAb+5NhDmXzg5fHWmdCh1mP5p7JAZfFr3IMQfcntNAdA==} - parent-module@1.0.1: resolution: {integrity: sha512-GQ2EWRpQV8/o+Aw8YqtfZZPfNRWZYkbidE9k5rpl/hC3vtHHBfGm2Ifi6qWV+coDGkrUKZAxE3Lot5kcsRlh+g==} engines: {node: '>=6'} @@ -22412,8 +22418,6 @@ snapshots: pako@1.0.11: {} - papaparse@5.5.2: {} - parent-module@1.0.1: dependencies: callsites: 3.1.0
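
A minimal usage sketch of the `responseFormat` parameter introduced by this change, condensed from `examples/jsonExtract.ts`; the schema fields and prompt text are illustrative placeholders, not part of the API, and the behavior assumes a provider whose metadata reports `structuredOutput: true` (OpenAI, Ollama).

```typescript
import { OpenAI } from "@llamaindex/openai";
import { z } from "zod";

// Illustrative schema; the field names are placeholders, not part of the API.
const callSummarySchema = z.object({
  summary: z.string(),
  action_items: z.array(z.string()),
});

async function main() {
  const llm = new OpenAI({ model: "gpt-4o" });

  // Passing a Zod schema: providers that report `structuredOutput: true`
  // in their metadata translate it into their native structured-output
  // request format (zodResponseFormat for OpenAI, JSON schema for Ollama).
  const structured = await llm.chat({
    messages: [{ role: "user", content: "Summarize this call: ..." }],
    responseFormat: callSummarySchema,
  });
  console.log(structured.message.content);

  // Plain JSON mode is still accepted as an object literal.
  const jsonMode = await llm.chat({
    messages: [{ role: "user", content: "Reply in JSON: ..." }],
    responseFormat: { type: "json_object" },
  });
  console.log(jsonMode.message.content);
}

main().catch(console.error);
```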