Unverified Commit 4d4cd8ac authored by Alex Yang, committed by GitHub
feat: support ollama tool call (#1461)

parent 4fc001c8
---
"llamaindex": patch
"@llamaindex/ollama": patch
---
feat: support ollama tool call
Note that `OllamaEmbedding` now is not the subclass of `Ollama`.
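A minimal usage sketch of the new surface (model names are illustrative; `llama3.2` matches the e2e test below, and `nomic-embed-text` is an assumed locally pulled embedding model):

```ts
import { Ollama, OllamaEmbedding } from "@llamaindex/ollama";

// Ollama now extends ToolCallLLM, so tools can be passed straight to chat()
// (see the e2e test below for a full tool-call round trip).
const llm = new Ollama({ model: "llama3.2" });

// OllamaEmbedding is now a standalone BaseEmbedding that wraps its own
// Ollama instance instead of inheriting from it.
const embedModel = new OllamaEmbedding({ model: "nomic-embed-text" });
```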
import { OpenAI } from "./openai.js";
export class Ollama extends OpenAI {}
import { Ollama } from "@llamaindex/ollama";
import assert from "node:assert";
import { test } from "node:test";
import { getWeatherTool } from "./fixtures/tools.js";
import { mockLLMEvent } from "./utils.js";
await test("ollama", async (t) => {
  await mockLLMEvent(t, "ollama");
  await t.test("ollama function call", async (t) => {
    const llm = new Ollama({
      model: "llama3.2",
    });
    const chatResponse = await llm.chat({
      messages: [
        {
          role: "user",
          content: "What is the weather in Paris?",
        },
      ],
      tools: [getWeatherTool],
    });
    if (
      chatResponse.message.options &&
      "toolCall" in chatResponse.message.options
    ) {
      assert.equal(chatResponse.message.options.toolCall.length, 1);
      assert.equal(
        chatResponse.message.options.toolCall[0]!.name,
        getWeatherTool.metadata.name,
      );
    } else {
      throw new Error("Expected tool calls in response");
    }
  });
});
{
  "llmEventStart": [],
  "llmEventEnd": [],
  "llmEventStream": []
}
\ No newline at end of file
import { streamConverter } from "../utils";
import { extractText } from "../utils/llms";
import { extractText, streamConverter } from "../utils";
import type {
  ChatResponse,
  ChatResponseChunk,
......
import type { BaseEmbedding } from "@llamaindex/core/embeddings";
import { Ollama } from "@llamaindex/ollama";
/**
 * OllamaEmbedding is an alias for Ollama that implements the BaseEmbedding interface.
 */
export class OllamaEmbedding extends Ollama implements BaseEmbedding {}
export { OllamaEmbedding } from "@llamaindex/ollama";
import { BaseEmbedding } from "@llamaindex/core/embeddings";
import type {
  ChatResponse,
  ChatResponseChunk,
  CompletionResponse,
  LLM,
  LLMChatParamsNonStreaming,
  LLMChatParamsStreaming,
  LLMCompletionParamsNonStreaming,
  LLMCompletionParamsStreaming,
  LLMMetadata,
import {
  ToolCallLLM,
  type BaseTool,
  type ChatResponse,
  type ChatResponseChunk,
  type CompletionResponse,
  type LLMChatParamsNonStreaming,
  type LLMChatParamsStreaming,
  type LLMCompletionParamsNonStreaming,
  type LLMCompletionParamsStreaming,
  type LLMMetadata,
  type ToolCallLLMMessageOptions,
} from "@llamaindex/core/llms";
import { extractText, streamConverter } from "@llamaindex/core/utils";
import { randomUUID } from "@llamaindex/env";
import type { ChatRequest, GenerateRequest, Tool } from "ollama";
import {
  Ollama as OllamaBase,
  type Config,
@@ -38,7 +42,8 @@ export type OllamaParams = {
  options?: Partial<Options>;
};
export class Ollama extends BaseEmbedding implements LLM {
export class Ollama extends ToolCallLLM {
  supportToolCall: boolean = true;
  public readonly ollama: OllamaBase;
  // https://ollama.ai/library
@@ -78,12 +83,16 @@ export class Ollama extends BaseEmbedding implements LLM {
  chat(
    params: LLMChatParamsStreaming,
  ): Promise<AsyncIterable<ChatResponseChunk>>;
  chat(params: LLMChatParamsNonStreaming): Promise<ChatResponse>;
  chat(
    params: LLMChatParamsNonStreaming,
  ): Promise<ChatResponse<ToolCallLLMMessageOptions>>;
  async chat(
    params: LLMChatParamsNonStreaming | LLMChatParamsStreaming,
  ): Promise<ChatResponse | AsyncIterable<ChatResponseChunk>> {
    const { messages, stream } = params;
    const payload = {
  ): Promise<
    ChatResponse<ToolCallLLMMessageOptions> | AsyncIterable<ChatResponseChunk>
  > {
    const { messages, stream, tools } = params;
    const payload: ChatRequest = {
      model: this.model,
      messages: messages.map((message) => ({
        role: message.role,
@@ -94,11 +103,30 @@ export class Ollama extends BaseEmbedding implements LLM {
        ...this.options,
      },
    };
    if (tools) {
      payload.tools = tools.map((tool) => Ollama.toTool(tool));
    }
    if (!stream) {
      const chatResponse = await this.ollama.chat({
        ...payload,
        stream: false,
      });
      if (chatResponse.message.tool_calls) {
        return {
          message: {
            role: "assistant",
            content: chatResponse.message.content,
            options: {
              toolCall: chatResponse.message.tool_calls.map((toolCall) => ({
                name: toolCall.function.name,
                input: toolCall.function.arguments,
                id: randomUUID(),
              })),
            },
          },
          raw: chatResponse,
        };
      }
      return {
        message: {
@@ -126,7 +154,7 @@ export class Ollama extends BaseEmbedding implements LLM {
    params: LLMCompletionParamsStreaming | LLMCompletionParamsNonStreaming,
  ): Promise<CompletionResponse | AsyncIterable<CompletionResponse>> {
    const { prompt, stream } = params;
    const payload = {
    const payload: GenerateRequest = {
      model: this.model,
      prompt: extractText(prompt),
      stream: !!stream,
@@ -152,15 +180,39 @@ export class Ollama extends BaseEmbedding implements LLM {
    }
  }
  static toTool(tool: BaseTool): Tool {
    return {
      type: "function",
      function: {
        name: tool.metadata.name,
        description: tool.metadata.description,
        parameters: {
          type: tool.metadata.parameters?.type,
          required: tool.metadata.parameters?.required,
          properties: tool.metadata.parameters?.properties,
        },
      },
    };
  }
}
export class OllamaEmbedding extends BaseEmbedding {
  private readonly llm: Ollama;
  constructor(params: OllamaParams) {
    super();
    this.llm = new Ollama(params);
  }
  private async getEmbedding(prompt: string): Promise<number[]> {
    const payload = {
      model: this.model,
      model: this.llm.model,
      prompt,
      options: {
        ...this.options,
        ...this.llm.options,
      },
    };
    const response = await this.ollama.embeddings({
    const response = await this.llm.ollama.embeddings({
      ...payload,
    });
    return response.embedding;
......