diff --git a/.changeset/short-candles-drive.md b/.changeset/short-candles-drive.md
new file mode 100644
index 0000000000000000000000000000000000000000..ebd25584ed19c053251f78537b1918d7bac179ff
--- /dev/null
+++ b/.changeset/short-candles-drive.md
@@ -0,0 +1,8 @@
+---
+"llamaindex": patch
+"@llamaindex/ollama": patch
+---
+
+feat: support ollama tool call
+
+Note that `OllamaEmbedding` is no longer a subclass of `Ollama`.
diff --git a/e2e/fixtures/llm/ollama.ts b/e2e/fixtures/llm/ollama.ts
new file mode 100644
index 0000000000000000000000000000000000000000..b16fea13337cc861b63d45c7fbca2d2c598e8080
--- /dev/null
+++ b/e2e/fixtures/llm/ollama.ts
@@ -0,0 +1,3 @@
+import { OpenAI } from "./openai.js";
+
+export class Ollama extends OpenAI {}
diff --git a/e2e/node/ollama.e2e.ts b/e2e/node/ollama.e2e.ts
new file mode 100644
index 0000000000000000000000000000000000000000..9aadcdb4cf81c965f1eb78849d0209e5a21124c5
--- /dev/null
+++ b/e2e/node/ollama.e2e.ts
@@ -0,0 +1,35 @@
+import { Ollama } from "@llamaindex/ollama";
+import assert from "node:assert";
+import { test } from "node:test";
+import { getWeatherTool } from "./fixtures/tools.js";
+import { mockLLMEvent } from "./utils.js";
+
+await test("ollama", async (t) => {
+  await mockLLMEvent(t, "ollama");
+  await t.test("ollama function call", async (t) => {
+    const llm = new Ollama({
+      model: "llama3.2",
+    });
+    const chatResponse = await llm.chat({
+      messages: [
+        {
+          role: "user",
+          content: "What is the weather in Paris?",
+        },
+      ],
+      tools: [getWeatherTool],
+    });
+    if (
+      chatResponse.message.options &&
+      "toolCall" in chatResponse.message.options
+    ) {
+      assert.equal(chatResponse.message.options.toolCall.length, 1);
+      assert.equal(
+        chatResponse.message.options.toolCall[0]!.name,
+        getWeatherTool.metadata.name,
+      );
+    } else {
+      throw new Error("Expected tool calls in response");
+    }
+  });
+});
diff --git a/e2e/node/snapshot/ollama.snap b/e2e/node/snapshot/ollama.snap
new file mode 100644
index 0000000000000000000000000000000000000000..6b309ce4b83d64888d2cbbdb30e7d72c2ea2fc8a
--- /dev/null
+++ b/e2e/node/snapshot/ollama.snap
@@ -0,0 +1,5 @@
+{
+  "llmEventStart": [],
+  "llmEventEnd": [],
+  "llmEventStream": []
+}
\ No newline at end of file
diff --git a/packages/core/src/llms/base.ts b/packages/core/src/llms/base.ts
index b04defc65d8415ede985380e33e99c00ae350c13..89b9ecd82ca08dc0bc6c4f596b595cf2826d704a 100644
--- a/packages/core/src/llms/base.ts
+++ b/packages/core/src/llms/base.ts
@@ -1,5 +1,4 @@
-import { streamConverter } from "../utils";
-import { extractText } from "../utils/llms";
+import { extractText, streamConverter } from "../utils";
 import type {
   ChatResponse,
   ChatResponseChunk,
diff --git a/packages/llamaindex/src/embeddings/OllamaEmbedding.ts b/packages/llamaindex/src/embeddings/OllamaEmbedding.ts
index 2bd40a48eeaa17652ef2d5a46e15fc3e0aa9b2b1..8b1bff9bc3dfe3e1131c172785f200c8dbd55c6e 100644
--- a/packages/llamaindex/src/embeddings/OllamaEmbedding.ts
+++ b/packages/llamaindex/src/embeddings/OllamaEmbedding.ts
@@ -1,7 +1 @@
-import type { BaseEmbedding } from "@llamaindex/core/embeddings";
-import { Ollama } from "@llamaindex/ollama";
-
-/**
- * OllamaEmbedding is an alias for Ollama that implements the BaseEmbedding interface.
- */
-export class OllamaEmbedding extends Ollama implements BaseEmbedding {}
+export { OllamaEmbedding } from "@llamaindex/ollama";
diff --git a/packages/providers/ollama/src/index.ts b/packages/providers/ollama/src/index.ts
index 1bdcd4b81af422e56ac4fe09ce13a1fc1faead57..8bee7b24adcc9c1d0629089a1d3b9fe746b08201 100644
--- a/packages/providers/ollama/src/index.ts
+++ b/packages/providers/ollama/src/index.ts
@@ -1,16 +1,20 @@
 import { BaseEmbedding } from "@llamaindex/core/embeddings";
-import type {
-  ChatResponse,
-  ChatResponseChunk,
-  CompletionResponse,
-  LLM,
-  LLMChatParamsNonStreaming,
-  LLMChatParamsStreaming,
-  LLMCompletionParamsNonStreaming,
-  LLMCompletionParamsStreaming,
-  LLMMetadata,
+import {
+  ToolCallLLM,
+  type BaseTool,
+  type ChatResponse,
+  type ChatResponseChunk,
+  type CompletionResponse,
+  type LLMChatParamsNonStreaming,
+  type LLMChatParamsStreaming,
+  type LLMCompletionParamsNonStreaming,
+  type LLMCompletionParamsStreaming,
+  type LLMMetadata,
+  type ToolCallLLMMessageOptions,
 } from "@llamaindex/core/llms";
 import { extractText, streamConverter } from "@llamaindex/core/utils";
+import { randomUUID } from "@llamaindex/env";
+import type { ChatRequest, GenerateRequest, Tool } from "ollama";
 import {
   Ollama as OllamaBase,
   type Config,
@@ -38,7 +42,8 @@ export type OllamaParams = {
   options?: Partial<Options>;
 };
 
-export class Ollama extends BaseEmbedding implements LLM {
+export class Ollama extends ToolCallLLM {
+  supportToolCall: boolean = true;
   public readonly ollama: OllamaBase;
 
   // https://ollama.ai/library
@@ -78,12 +83,16 @@ export class Ollama extends BaseEmbedding implements LLM {
   chat(
     params: LLMChatParamsStreaming,
   ): Promise<AsyncIterable<ChatResponseChunk>>;
-  chat(params: LLMChatParamsNonStreaming): Promise<ChatResponse>;
+  chat(
+    params: LLMChatParamsNonStreaming,
+  ): Promise<ChatResponse<ToolCallLLMMessageOptions>>;
   async chat(
     params: LLMChatParamsNonStreaming | LLMChatParamsStreaming,
-  ): Promise<ChatResponse | AsyncIterable<ChatResponseChunk>> {
-    const { messages, stream } = params;
-    const payload = {
+  ): Promise<
+    ChatResponse<ToolCallLLMMessageOptions> | AsyncIterable<ChatResponseChunk>
+  > {
+    const { messages, stream, tools } = params;
+    const payload: ChatRequest = {
       model: this.model,
       messages: messages.map((message) => ({
         role: message.role,
@@ -94,11 +103,30 @@ export class Ollama extends BaseEmbedding implements LLM {
         ...this.options,
       },
     };
+    if (tools) {
+      payload.tools = tools.map((tool) => Ollama.toTool(tool));
+    }
     if (!stream) {
       const chatResponse = await this.ollama.chat({
         ...payload,
         stream: false,
       });
+      if (chatResponse.message.tool_calls) {
+        return {
+          message: {
+            role: "assistant",
+            content: chatResponse.message.content,
+            options: {
+              toolCall: chatResponse.message.tool_calls.map((toolCall) => ({
+                name: toolCall.function.name,
+                input: toolCall.function.arguments,
+                id: randomUUID(),
+              })),
+            },
+          },
+          raw: chatResponse,
+        };
+      }
 
       return {
         message: {
@@ -126,7 +154,7 @@ export class Ollama extends BaseEmbedding implements LLM {
     params: LLMCompletionParamsStreaming | LLMCompletionParamsNonStreaming,
   ): Promise<CompletionResponse | AsyncIterable<CompletionResponse>> {
     const { prompt, stream } = params;
-    const payload = {
+    const payload: GenerateRequest = {
       model: this.model,
       prompt: extractText(prompt),
       stream: !!stream,
@@ -152,15 +180,39 @@ export class Ollama extends BaseEmbedding implements LLM {
     }
   }
 
+  static toTool(tool: BaseTool): Tool {
+    return {
+      type: "function",
+      function: {
+        name: tool.metadata.name,
+        description: tool.metadata.description,
+        parameters: {
+          type: tool.metadata.parameters?.type,
+          required: tool.metadata.parameters?.required,
+          properties: tool.metadata.parameters?.properties,
+        },
+      },
+    };
+  }
+}
+
+export class OllamaEmbedding extends BaseEmbedding {
+  private readonly llm: Ollama;
+
+  constructor(params: OllamaParams) {
+    super();
+    this.llm = new Ollama(params);
+  }
+
   private async getEmbedding(prompt: string): Promise<number[]> {
     const payload = {
-      model: this.model,
+      model: this.llm.model,
       prompt,
       options: {
-        ...this.options,
+        ...this.llm.options,
       },
     };
-    const response = await this.ollama.embeddings({
+    const response = await this.llm.ollama.embeddings({
       ...payload,
     });
     return response.embedding;