diff --git a/.changeset/smart-shirts-hear.md b/.changeset/smart-shirts-hear.md
new file mode 100644
index 0000000000000000000000000000000000000000..d4e29b9e276fcdfba68b70c473245a1dee2141ac
--- /dev/null
+++ b/.changeset/smart-shirts-hear.md
@@ -0,0 +1,8 @@
+---
+"@llamaindex/mistral": minor
+"@llamaindex/examples": minor
+---
+
+Added support for function calling in the Mistral provider
+Updated the model list for the Mistral provider
+Added an example of tool calling with Mistral
diff --git a/examples/mistral/agent.ts b/examples/mistral/agent.ts
new file mode 100644
index 0000000000000000000000000000000000000000..a8e4ee649232225967fced9342a104b0db2aebc9
--- /dev/null
+++ b/examples/mistral/agent.ts
@@ -0,0 +1,31 @@
+import { mistral } from "@llamaindex/mistral";
+import { agent, tool } from "llamaindex";
+import { z } from "zod";
+import { WikipediaTool } from "../wiki";
+
+const workflow = agent({
+  tools: [
+    tool({
+      name: "weather",
+      description: "Get the weather",
+      parameters: z.object({
+        location: z.string().describe("The location to get the weather for"),
+      }),
+      execute: ({ location }) => `The weather in ${location} is sunny`,
+    }),
+    new WikipediaTool(),
+  ],
+  llm: mistral({
+    apiKey: process.env.MISTRAL_API_KEY,
+    model: "mistral-small-latest",
+  }),
+});
+
+async function main() {
+  const result = await workflow.run(
+    "What is the weather in New York? What's the history of New York from Wikipedia in 3 sentences?",
+  );
+  console.log(result.data);
+}
+
+void main();
diff --git a/examples/mistral.ts b/examples/mistral/mistral.ts
similarity index 100%
rename from examples/mistral.ts
rename to examples/mistral/mistral.ts
diff --git a/packages/providers/mistral/package.json b/packages/providers/mistral/package.json
index a005343eaa10e8d5093d9699598453cb2a46c062..9f53db7a9cc140148f5ffb1f622fb038494291b8 100644
--- a/packages/providers/mistral/package.json
+++ b/packages/providers/mistral/package.json
@@ -27,10 +27,12 @@
   },
   "scripts": {
     "build": "bunchee",
-    "dev": "bunchee --watch"
+    "dev": "bunchee --watch",
+    "test": "vitest run"
   },
   "devDependencies": {
-    "bunchee": "6.4.0"
+    "bunchee": "6.4.0",
+    "vitest": "^2.1.5"
   },
   "dependencies": {
     "@llamaindex/core": "workspace:*",
diff --git a/packages/providers/mistral/src/llm.ts b/packages/providers/mistral/src/llm.ts
index 8f3792443efcc244b6841a2d9f431595f40ced86..f7d0ba77d32713b942e4b3d40915dbbfafc778e5 100644
--- a/packages/providers/mistral/src/llm.ts
+++ b/packages/providers/mistral/src/llm.ts
@@ -1,21 +1,51 @@
+import { wrapEventCaller } from "@llamaindex/core/decorator";
 import {
-  BaseLLM,
+  ToolCallLLM,
+  type BaseTool,
   type ChatMessage,
   type ChatResponse,
   type ChatResponseChunk,
   type LLMChatParamsNonStreaming,
   type LLMChatParamsStreaming,
+  type PartialToolCall,
+  type ToolCallLLMMessageOptions,
 } from "@llamaindex/core/llms";
+import { extractText } from "@llamaindex/core/utils";
 import { getEnv } from "@llamaindex/env";
 import { type Mistral } from "@mistralai/mistralai";
-import type { ContentChunk } from "@mistralai/mistralai/models/components";
+import type {
+  AssistantMessage,
+  ChatCompletionRequest,
+  ChatCompletionStreamRequest,
+  ContentChunk,
+  Tool,
+  ToolMessage,
+} from "@mistralai/mistralai/models/components";
 
 export const ALL_AVAILABLE_MISTRAL_MODELS = {
   "mistral-tiny": { contextWindow: 32000 },
   "mistral-small": { contextWindow: 32000 },
   "mistral-medium": { contextWindow: 32000 },
+  "mistral-small-latest": { contextWindow: 32000 },
+  "mistral-large-latest": { contextWindow: 131000 },
+  "codestral-latest": { contextWindow: 256000 },
+  "pixtral-large-latest": { contextWindow: 131000 },
+  "mistral-saba-latest": { contextWindow: 32000 },
+  "ministral-3b-latest": { contextWindow: 131000 },
+  "ministral-8b-latest": { contextWindow: 131000 },
+  "mistral-embed": { contextWindow: 8000 },
+  "mistral-moderation-latest": { contextWindow: 8000 },
 };
 
+export const TOOL_CALL_MISTRAL_MODELS = [
+  "mistral-small-latest",
+  "mistral-large-latest",
+  "codestral-latest",
+  "pixtral-large-latest",
+  "ministral-8b-latest",
+  "ministral-3b-latest",
+];
+
 export class MistralAISession {
   apiKey: string;
   // eslint-disable-next-line @typescript-eslint/no-explicit-any
@@ -46,7 +76,7 @@ export class MistralAISession {
 /**
  * MistralAI LLM implementation
  */
-export class MistralAI extends BaseLLM {
+export class MistralAI extends ToolCallLLM<ToolCallLLMMessageOptions> {
   // Per completion MistralAI params
   model: keyof typeof ALL_AVAILABLE_MISTRAL_MODELS;
   temperature: number;
@@ -60,7 +90,7 @@ export class MistralAI extends BaseLLM {
   constructor(init?: Partial<MistralAI>) {
     super();
-    this.model = init?.model ?? "mistral-small";
+    this.model = init?.model ?? "mistral-small-latest";
     this.temperature = init?.temperature ?? 0.1;
     this.topP = init?.topP ?? 1;
     this.maxTokens = init?.maxTokens ?? undefined;
@@ -80,8 +110,51 @@ export class MistralAI extends BaseLLM {
     };
   }
 
-  // eslint-disable-next-line @typescript-eslint/no-explicit-any
-  private buildParams(messages: ChatMessage[]): any {
+  get supportToolCall() {
+    return TOOL_CALL_MISTRAL_MODELS.includes(this.metadata.model);
+  }
+
+  formatMessages(messages: ChatMessage<ToolCallLLMMessageOptions>[]) {
+    return messages.map((message) => {
+      const options = message.options ?? {};
+      //tool call message
+      if ("toolCall" in options) {
+        return {
+          role: "assistant",
+          content: extractText(message.content),
+          toolCalls: options.toolCall.map((toolCall) => {
+            return {
+              id: toolCall.id,
+              type: "function",
+              function: {
+                name: toolCall.name,
+                arguments: toolCall.input,
+              },
+            };
+          }),
+        } satisfies AssistantMessage;
+      }
+
+      //tool result message
+      if ("toolResult" in options) {
+        return {
+          role: "tool",
+          content: extractText(message.content),
+          toolCallId: options.toolResult.id,
+        } satisfies ToolMessage;
+      }
+
+      return {
+        role: message.role,
+        content: extractText(message.content),
+      };
+    });
+  }
+
+  private buildParams(
+    messages: ChatMessage<ToolCallLLMMessageOptions>[],
+    tools?: BaseTool[],
+  ) {
     return {
       model: this.model,
       temperature: this.temperature,
@@ -89,25 +162,49 @@ export class MistralAI extends BaseLLM {
       topP: this.topP,
       safeMode: this.safeMode,
       randomSeed: this.randomSeed,
-      messages,
+      messages: this.formatMessages(messages),
+      tools: tools?.map(MistralAI.toTool),
+    };
+  }
+
+  static toTool(tool: BaseTool): Tool {
+    if (!tool.metadata.parameters) {
+      throw new Error("Tool parameters are required");
+    }
+
+    return {
+      type: "function",
+      function: {
+        name: tool.metadata.name,
+        description: tool.metadata.description,
+        parameters: tool.metadata.parameters,
+      },
     };
   }
 
   chat(
     params: LLMChatParamsStreaming,
   ): Promise<AsyncIterable<ChatResponseChunk>>;
-  chat(params: LLMChatParamsNonStreaming): Promise<ChatResponse>;
+  chat(
+    params: LLMChatParamsNonStreaming<ToolCallLLMMessageOptions>,
+  ): Promise<ChatResponse>;
   async chat(
     params: LLMChatParamsNonStreaming | LLMChatParamsStreaming,
-  ): Promise<ChatResponse | AsyncIterable<ChatResponseChunk>> {
-    const { messages, stream } = params;
+  ): Promise<
+    | ChatResponse<ToolCallLLMMessageOptions>
+    | AsyncIterable<ChatResponseChunk<ToolCallLLMMessageOptions>>
+  > {
+    const { messages, stream, tools } = params;
     // Streaming
     if (stream) {
-      return this.streamChat(params);
+      return this.streamChat(messages, tools);
    }
     // Non-streaming
     const client = await this.session.getClient();
-    const response = await client.chat.complete(this.buildParams(messages));
+    const buildParams = this.buildParams(messages, tools);
+    const response = await client.chat.complete(
+      buildParams as ChatCompletionRequest,
+    );
 
     if (!response || !response.choices || !response.choices[0]) {
       throw new Error("Unexpected response format from Mistral API");
@@ -121,28 +218,100 @@ export class MistralAI extends BaseLLM {
       message: {
         role: "assistant",
         content: this.extractContentAsString(content),
+        options: response.choices[0]!.message?.toolCalls
+          ? {
+              toolCall: response.choices[0]!.message.toolCalls.map(
+                (toolCall) => ({
+                  id: toolCall.id,
+                  name: toolCall.function.name,
+                  input: this.extractArgumentsAsString(
+                    toolCall.function.arguments,
+                  ),
+                }),
+              ),
+            }
+          : {},
       },
     };
   }
 
-  protected async *streamChat({
-    messages,
-  }: LLMChatParamsStreaming): AsyncIterable<ChatResponseChunk> {
+  @wrapEventCaller
+  protected async *streamChat(
+    messages: ChatMessage[],
+    tools?: BaseTool[],
+  ): AsyncIterable<ChatResponseChunk<ToolCallLLMMessageOptions>> {
     const client = await this.session.getClient();
-    const chunkStream = await client.chat.stream(this.buildParams(messages));
+    const buildParams = this.buildParams(
+      messages,
+      tools,
+    ) as ChatCompletionStreamRequest;
+    const chunkStream = await client.chat.stream(buildParams);
+
+    let currentToolCall: PartialToolCall | null = null;
+    const toolCallMap = new Map<string, PartialToolCall>();
 
     for await (const chunk of chunkStream) {
-      if (!chunk.data || !chunk.data.choices || !chunk.data.choices.length)
-        continue;
+      if (!chunk.data?.choices?.[0]?.delta) continue;
 
       const choice = chunk.data.choices[0];
-      if (!choice) continue;
+      if (!(choice.delta.content || choice.delta.toolCalls)) continue;
+
+      let shouldEmitToolCall: PartialToolCall | null = null;
+
+      if (choice.delta.toolCalls?.[0]) {
+        const toolCall = choice.delta.toolCalls[0];
+
+        if (toolCall.id) {
+          if (currentToolCall && toolCall.id !== currentToolCall.id) {
+            shouldEmitToolCall = {
+              ...currentToolCall,
+              input: JSON.parse(currentToolCall.input),
+            };
+          }
+
+          currentToolCall = {
+            id: toolCall.id,
+            name: toolCall.function!.name!,
+            input: this.extractArgumentsAsString(toolCall.function!.arguments),
+          };
+
+          toolCallMap.set(toolCall.id, currentToolCall!);
+        } else if (currentToolCall && toolCall.function?.arguments) {
+          currentToolCall.input += this.extractArgumentsAsString(
+            toolCall.function.arguments,
+          );
+        }
+      }
+
+      const isDone: boolean = choice.finishReason !== null;
+
+      if (isDone && currentToolCall) {
+        //emitting last tool call
+        shouldEmitToolCall = {
+          ...currentToolCall,
+          input: JSON.parse(currentToolCall.input),
+        };
+      }
 
       yield {
         raw: chunk.data,
         delta: this.extractContentAsString(choice.delta.content),
+        options: shouldEmitToolCall
+          ? { toolCall: [shouldEmitToolCall] }
+          : currentToolCall
+            ? { toolCall: [currentToolCall] }
+            : {},
       };
     }
+
+    toolCallMap.clear();
+  }
+
+  private extractArgumentsAsString(
+    // eslint-disable-next-line @typescript-eslint/no-explicit-any
+    args: string | { [k: string]: any } | null | undefined,
+  ): string {
+    return typeof args === "string" ? args : JSON.stringify(args) || "";
   }
 
   private extractContentAsString(
diff --git a/packages/providers/mistral/tests/index.test.ts b/packages/providers/mistral/tests/index.test.ts
new file mode 100644
index 0000000000000000000000000000000000000000..43fdb9ea0b3e402155db5761623c43f215a89d45
--- /dev/null
+++ b/packages/providers/mistral/tests/index.test.ts
@@ -0,0 +1,116 @@
+import type { ChatMessage } from "@llamaindex/core/llms";
+import { setEnvs } from "@llamaindex/env";
+import { beforeAll, describe, expect, test } from "vitest";
+import { MistralAI } from "../src/index";
+
+beforeAll(() => {
+  setEnvs({
+    MISTRAL_API_KEY: "valid",
+  });
+});
+
+describe("Message Formatting", () => {
+  describe("Basic Message Formatting", () => {
+    test("Mistral formats basic messages correctly", () => {
+      const mistral = new MistralAI();
+      const inputMessages: ChatMessage[] = [
+        {
+          content: "You are a helpful assistant.",
+          role: "assistant",
+        },
+        {
+          content: "Hello?",
+          role: "user",
+        },
+      ];
+      const expectedOutput = [
+        {
+          content: "You are a helpful assistant.",
+          role: "assistant",
+        },
+        {
+          content: "Hello?",
+          role: "user",
+        },
+      ];
+
+      expect(mistral.formatMessages(inputMessages)).toEqual(expectedOutput);
+    });
+
+    test("Mistral handles multi-turn conversation correctly", () => {
+      const mistral = new MistralAI();
+      const inputMessages: ChatMessage[] = [
+        { content: "Hi", role: "user" },
+        { content: "Hello! How can I help?", role: "assistant" },
+        { content: "What's the weather?", role: "user" },
+      ];
+      const expectedOutput = [
+        { content: "Hi", role: "user" },
+        { content: "Hello! How can I help?", role: "assistant" },
+        { content: "What's the weather?", role: "user" },
+      ];
+      expect(mistral.formatMessages(inputMessages)).toEqual(expectedOutput);
+    });
+  });
+
+  describe("Tool Message Formatting", () => {
+    const toolCallMessages: ChatMessage[] = [
+      {
+        role: "user",
+        content: "What's the weather in London?",
+      },
+      {
+        role: "assistant",
+        content: "Let me check the weather.",
+        options: {
+          toolCall: [
+            {
+              id: "call_123",
+              name: "weather",
+              input: JSON.stringify({ location: "London" }),
+            },
+          ],
+        },
+      },
+      {
+        role: "assistant",
+        content: "The weather in London is sunny, +20°C",
+        options: {
+          toolResult: {
+            id: "call_123",
+          },
+        },
+      },
+    ];
+
+    test("Mistral formats tool calls correctly", () => {
+      const mistral = new MistralAI();
+      const expectedOutput = [
+        {
+          role: "user",
+          content: "What's the weather in London?",
+        },
+        {
+          role: "assistant",
+          content: "Let me check the weather.",
+          toolCalls: [
+            {
+              type: "function",
+              id: "call_123",
+              function: {
+                name: "weather",
+                arguments: '{"location":"London"}',
+              },
+            },
+          ],
+        },
+        {
+          role: "tool",
+          content: "The weather in London is sunny, +20°C",
+          toolCallId: "call_123",
+        },
+      ];
+      expect(mistral.formatMessages(toolCallMessages)).toEqual(expectedOutput);
+    });
+  });
+});
diff --git a/pnpm-lock.yaml b/pnpm-lock.yaml
index 2b1aab4a8096cd902b90a6a57a9c0dd5d3798807..42c5e203d2fb834f0514d867ea4df8ab7bc7756a 100644
--- a/pnpm-lock.yaml
+++ b/pnpm-lock.yaml
@@ -1266,6 +1266,9 @@ importers:
       bunchee:
         specifier: 6.4.0
         version: 6.4.0(typescript@5.7.3)
+      vitest:
+        specifier: ^2.1.5
+        version: 2.1.5(@edge-runtime/vm@4.0.4)(@types/node@22.13.5)(happy-dom@15.11.7)(lightningcss@1.29.1)(msw@2.7.0(@types/node@22.13.5)(typescript@5.7.3))(terser@5.38.2)
 
   packages/providers/mixedbread:
     dependencies:
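
For reference, a minimal sketch of exercising the new tool-calling path directly through `MistralAI.chat`, without the agent wrapper used in `examples/mistral/agent.ts`. The `mistral()` factory, `tool()` helper, and the `message.options.toolCall` shape come from this diff; the surrounding script is illustrative only and not part of the patch.

```ts
import { mistral } from "@llamaindex/mistral";
import { tool } from "llamaindex";
import { z } from "zod";

// LLM configured with one of the models listed in TOOL_CALL_MISTRAL_MODELS.
const llm = mistral({
  apiKey: process.env.MISTRAL_API_KEY,
  model: "mistral-small-latest",
});

// Same weather tool as in examples/mistral/agent.ts.
const weatherTool = tool({
  name: "weather",
  description: "Get the weather",
  parameters: z.object({
    location: z.string().describe("The location to get the weather for"),
  }),
  execute: ({ location }) => `The weather in ${location} is sunny`,
});

async function main() {
  // Non-streaming chat with tools; buildParams converts the tool metadata
  // into Mistral's function-calling format via MistralAI.toTool.
  const response = await llm.chat({
    messages: [{ role: "user", content: "What's the weather in London?" }],
    tools: [weatherTool],
  });

  // If the model decided to call the tool, the provider surfaces it on
  // message.options.toolCall as { id, name, input } entries.
  console.log(response.message.options);
}

void main();
```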