diff --git a/e2e/mock-module.js b/e2e/mock-module.js
index 28f55f3aa6994635a3bad033496e09f83f203085..7c664a8efabce1578a6c68e73e91a802f88bf498 100644
--- a/e2e/mock-module.js
+++ b/e2e/mock-module.js
@@ -15,7 +15,17 @@ export async function resolve(specifier, context, nextResolve) {
   const targetUrl = fileURLToPath(result.url).replace(/\.js$/, ".ts");
   let relativePath = relative(packageDistDir, targetUrl);
   // todo: make it more generic if we have more sub modules fixtures in the future
-  if (relativePath.startsWith("../../llm/openai")) {
+  if (relativePath.startsWith("../../llm/anthropic")) {
+    relativePath = relativePath.replace(
+      "../../llm/anthropic/dist/index.ts",
+      "llm/anthropic.ts",
+    );
+  } else if (relativePath.startsWith("../../llm/ollama")) {
+    relativePath = relativePath.replace(
+      "../../llm/ollama/dist/index.ts",
+      "llm/ollama.ts",
+    );
+  } else if (relativePath.startsWith("../../llm/openai")) {
     relativePath = relativePath.replace(
       "../../llm/openai/dist/index.ts",
       "llm/openai.ts",
diff --git a/e2e/node/snapshot/ollama.snap b/e2e/node/snapshot/ollama.snap
index 6b309ce4b83d64888d2cbbdb30e7d72c2ea2fc8a..b2f4b076551a1351c59899710e35aa79ac62cdae 100644
--- a/e2e/node/snapshot/ollama.snap
+++ b/e2e/node/snapshot/ollama.snap
@@ -1,5 +1,37 @@
 {
-  "llmEventStart": [],
-  "llmEventEnd": [],
+  "llmEventStart": [
+    {
+      "id": "PRESERVE_0",
+      "messages": [
+        {
+          "role": "user",
+          "content": "What is the weather in Paris?"
+        }
+      ]
+    }
+  ],
+  "llmEventEnd": [
+    {
+      "id": "PRESERVE_0",
+      "response": {
+        "message": {
+          "role": "assistant",
+          "content": "",
+          "options": {
+            "toolCall": [
+              {
+                "name": "getWeather",
+                "input": {
+                  "city": "Paris"
+                },
+                "id": "5d198775-5268-4552-993b-9ecb4425385b"
+              }
+            ]
+          }
+        },
+        "raw": null
+      }
+    }
+  ],
   "llmEventStream": []
 }
\ No newline at end of file
diff --git a/packages/llamaindex/src/agent/index.ts b/packages/llamaindex/src/agent/index.ts
index 8d9790f26d147b7d0dac5f6e9f91456ac6c8ca00..1ec22ddb7f21ebeeeeeac85f74987897d0bbe192 100644
--- a/packages/llamaindex/src/agent/index.ts
+++ b/packages/llamaindex/src/agent/index.ts
@@ -1,4 +1,9 @@
 export * from "@llamaindex/core/agent";
+export {
+  OllamaAgent,
+  OllamaAgentWorker,
+  type OllamaAgentParams,
+} from "@llamaindex/ollama";
 export {
   AnthropicAgent,
   AnthropicAgentWorker,
@@ -16,7 +21,6 @@ export {
   ReActAgent,
   type ReACTAgentParams,
 } from "./react.js";
-
 // todo: ParallelAgent
 // todo: CustomAgent
 // todo: ReactMultiModal
diff --git a/packages/providers/ollama/src/agent.ts b/packages/providers/ollama/src/agent.ts
new file mode 100644
index 0000000000000000000000000000000000000000..69dd75a68c4b0f5301461c6ca7f9760590833746
--- /dev/null
+++ b/packages/providers/ollama/src/agent.ts
@@ -0,0 +1,33 @@
+import {
+  LLMAgent,
+  LLMAgentWorker,
+  type LLMAgentParams,
+} from "@llamaindex/core/agent";
+import { Settings } from "@llamaindex/core/global";
+import { Ollama } from "./llm";
+
+// This is likely not necessary anymore, but leaving it here just in case it's in use elsewhere
+
+export type OllamaAgentParams = LLMAgentParams & {
+  model?: string;
+};
+
+export class OllamaAgentWorker extends LLMAgentWorker {}
+
+export class OllamaAgent extends LLMAgent {
+  constructor(params: OllamaAgentParams) {
+    const llm =
+      params.llm ??
+      (Settings.llm instanceof Ollama
+        ? (Settings.llm as Ollama)
+        : !params.model
+          ? 
(() => { + throw new Error("No model provided"); + })() + : new Ollama({ model: params.model })); + super({ + ...params, + llm, + }); + } +} diff --git a/packages/providers/ollama/src/embedding.ts b/packages/providers/ollama/src/embedding.ts new file mode 100644 index 0000000000000000000000000000000000000000..6a94b11bb47c64d7232e7381944599fce1f9ce4e --- /dev/null +++ b/packages/providers/ollama/src/embedding.ts @@ -0,0 +1,29 @@ +import { BaseEmbedding } from "@llamaindex/core/embeddings"; +import { Ollama, type OllamaParams } from "./llm"; + +export class OllamaEmbedding extends BaseEmbedding { + private readonly llm: Ollama; + + constructor(params: OllamaParams) { + super(); + this.llm = new Ollama(params); + } + + private async getEmbedding(prompt: string): Promise<number[]> { + const payload = { + model: this.llm.model, + prompt, + options: { + ...this.llm.options, + }, + }; + const response = await this.llm.ollama.embeddings({ + ...payload, + }); + return response.embedding; + } + + async getTextEmbedding(text: string): Promise<number[]> { + return this.getEmbedding(text); + } +} diff --git a/packages/providers/ollama/src/index.ts b/packages/providers/ollama/src/index.ts index 8bee7b24adcc9c1d0629089a1d3b9fe746b08201..002d16f41c3ec08a3e7fd3d6dd52ab085d349120 100644 --- a/packages/providers/ollama/src/index.ts +++ b/packages/providers/ollama/src/index.ts @@ -1,224 +1,7 @@ -import { BaseEmbedding } from "@llamaindex/core/embeddings"; -import { - ToolCallLLM, - type BaseTool, - type ChatResponse, - type ChatResponseChunk, - type CompletionResponse, - type LLMChatParamsNonStreaming, - type LLMChatParamsStreaming, - type LLMCompletionParamsNonStreaming, - type LLMCompletionParamsStreaming, - type LLMMetadata, - type ToolCallLLMMessageOptions, -} from "@llamaindex/core/llms"; -import { extractText, streamConverter } from "@llamaindex/core/utils"; -import { randomUUID } from "@llamaindex/env"; -import type { ChatRequest, GenerateRequest, Tool } from "ollama"; -import { - Ollama as OllamaBase, - type Config, - type ChatResponse as OllamaChatResponse, - type GenerateResponse as OllamaGenerateResponse, - type Options, -} from "ollama/browser"; - -const messageAccessor = (part: OllamaChatResponse): ChatResponseChunk => { - return { - raw: part, - delta: part.message.content, - }; -}; - -const completionAccessor = ( - part: OllamaGenerateResponse, -): CompletionResponse => { - return { text: part.response, raw: part }; -}; - -export type OllamaParams = { - model: string; - config?: Partial<Config>; - options?: Partial<Options>; -}; - -export class Ollama extends ToolCallLLM { - supportToolCall: boolean = true; - public readonly ollama: OllamaBase; - - // https://ollama.ai/library - model: string; - - options: Partial<Omit<Options, "num_ctx" | "top_p" | "temperature">> & - Pick<Options, "num_ctx" | "top_p" | "temperature"> = { - num_ctx: 4096, - top_p: 0.9, - temperature: 0.7, - }; - - constructor(params: OllamaParams) { - super(); - this.model = params.model; - this.ollama = new OllamaBase(params.config); - if (params.options) { - this.options = { - ...this.options, - ...params.options, - }; - } - } - - get metadata(): LLMMetadata { - const { temperature, top_p, num_ctx } = this.options; - return { - model: this.model, - temperature: temperature, - topP: top_p, - maxTokens: this.options.num_ctx, - contextWindow: num_ctx, - tokenizer: undefined, - }; - } - - chat( - params: LLMChatParamsStreaming, - ): Promise<AsyncIterable<ChatResponseChunk>>; - chat( - params: LLMChatParamsNonStreaming, - ): 
Promise<ChatResponse<ToolCallLLMMessageOptions>>; - async chat( - params: LLMChatParamsNonStreaming | LLMChatParamsStreaming, - ): Promise< - ChatResponse<ToolCallLLMMessageOptions> | AsyncIterable<ChatResponseChunk> - > { - const { messages, stream, tools } = params; - const payload: ChatRequest = { - model: this.model, - messages: messages.map((message) => ({ - role: message.role, - content: extractText(message.content), - })), - stream: !!stream, - options: { - ...this.options, - }, - }; - if (tools) { - payload.tools = tools.map((tool) => Ollama.toTool(tool)); - } - if (!stream) { - const chatResponse = await this.ollama.chat({ - ...payload, - stream: false, - }); - if (chatResponse.message.tool_calls) { - return { - message: { - role: "assistant", - content: chatResponse.message.content, - options: { - toolCall: chatResponse.message.tool_calls.map((toolCall) => ({ - name: toolCall.function.name, - input: toolCall.function.arguments, - id: randomUUID(), - })), - }, - }, - raw: chatResponse, - }; - } - - return { - message: { - role: "assistant", - content: chatResponse.message.content, - }, - raw: chatResponse, - }; - } else { - const stream = await this.ollama.chat({ - ...payload, - stream: true, - }); - return streamConverter(stream, messageAccessor); - } - } - - complete( - params: LLMCompletionParamsStreaming, - ): Promise<AsyncIterable<CompletionResponse>>; - complete( - params: LLMCompletionParamsNonStreaming, - ): Promise<CompletionResponse>; - async complete( - params: LLMCompletionParamsStreaming | LLMCompletionParamsNonStreaming, - ): Promise<CompletionResponse | AsyncIterable<CompletionResponse>> { - const { prompt, stream } = params; - const payload: GenerateRequest = { - model: this.model, - prompt: extractText(prompt), - stream: !!stream, - options: { - ...this.options, - }, - }; - if (!stream) { - const response = await this.ollama.generate({ - ...payload, - stream: false, - }); - return { - text: response.response, - raw: response, - }; - } else { - const stream = await this.ollama.generate({ - ...payload, - stream: true, - }); - return streamConverter(stream, completionAccessor); - } - } - - static toTool(tool: BaseTool): Tool { - return { - type: "function", - function: { - name: tool.metadata.name, - description: tool.metadata.description, - parameters: { - type: tool.metadata.parameters?.type, - required: tool.metadata.parameters?.required, - properties: tool.metadata.parameters?.properties, - }, - }, - }; - } -} - -export class OllamaEmbedding extends BaseEmbedding { - private readonly llm: Ollama; - - constructor(params: OllamaParams) { - super(); - this.llm = new Ollama(params); - } - - private async getEmbedding(prompt: string): Promise<number[]> { - const payload = { - model: this.llm.model, - prompt, - options: { - ...this.llm.options, - }, - }; - const response = await this.llm.ollama.embeddings({ - ...payload, - }); - return response.embedding; - } - - async getTextEmbedding(text: string): Promise<number[]> { - return this.getEmbedding(text); - } -} +export { + OllamaAgent, + OllamaAgentWorker, + type OllamaAgentParams, +} from "./agent"; +export { OllamaEmbedding } from "./embedding"; +export { Ollama, type OllamaParams } from "./llm"; diff --git a/packages/providers/ollama/src/llm.ts b/packages/providers/ollama/src/llm.ts new file mode 100644 index 0000000000000000000000000000000000000000..f82da2ff32bd0785ad6a2e13f72ca3ce598c549e --- /dev/null +++ b/packages/providers/ollama/src/llm.ts @@ -0,0 +1,209 @@ +import { wrapLLMEvent } from 
"@llamaindex/core/decorator"; +import { + ToolCallLLM, + type BaseTool, + type ChatResponse, + type ChatResponseChunk, + type CompletionResponse, + type LLMChatParamsNonStreaming, + type LLMChatParamsStreaming, + type LLMCompletionParamsNonStreaming, + type LLMCompletionParamsStreaming, + type LLMMetadata, + type ToolCallLLMMessageOptions, +} from "@llamaindex/core/llms"; +import { extractText, streamConverter } from "@llamaindex/core/utils"; +import { randomUUID } from "@llamaindex/env"; +import type { ChatRequest, GenerateRequest, Tool } from "ollama"; +import { + Ollama as OllamaBase, + type Config, + type ChatResponse as OllamaChatResponse, + type GenerateResponse as OllamaGenerateResponse, + type Options, +} from "ollama/browser"; + +const messageAccessor = (part: OllamaChatResponse): ChatResponseChunk => { + return { + raw: part, + delta: part.message.content, + }; +}; + +const completionAccessor = ( + part: OllamaGenerateResponse, +): CompletionResponse => { + return { text: part.response, raw: part }; +}; + +export type OllamaParams = { + model: string; + config?: Partial<Config>; + options?: Partial<Options>; +}; + +export class Ollama extends ToolCallLLM { + supportToolCall: boolean = true; + public readonly ollama: OllamaBase; + + // https://ollama.ai/library + model: string; + + options: Partial<Omit<Options, "num_ctx" | "top_p" | "temperature">> & + Pick<Options, "num_ctx" | "top_p" | "temperature"> = { + num_ctx: 4096, + top_p: 0.9, + temperature: 0.7, + }; + + constructor(params: OllamaParams) { + super(); + this.model = params.model; + this.ollama = new OllamaBase(params.config); + if (params.options) { + this.options = { + ...this.options, + ...params.options, + }; + } + } + + get metadata(): LLMMetadata { + const { temperature, top_p, num_ctx } = this.options; + return { + model: this.model, + temperature: temperature, + topP: top_p, + maxTokens: this.options.num_ctx, + contextWindow: num_ctx, + tokenizer: undefined, + }; + } + + chat( + params: LLMChatParamsStreaming<ToolCallLLMMessageOptions>, + ): Promise<AsyncIterable<ChatResponseChunk>>; + chat( + params: LLMChatParamsNonStreaming<ToolCallLLMMessageOptions>, + ): Promise<ChatResponse<ToolCallLLMMessageOptions>>; + @wrapLLMEvent + async chat( + params: + | LLMChatParamsNonStreaming<object, ToolCallLLMMessageOptions> + | LLMChatParamsStreaming<object, ToolCallLLMMessageOptions>, + ): Promise< + ChatResponse<ToolCallLLMMessageOptions> | AsyncIterable<ChatResponseChunk> + > { + const { messages, stream, tools } = params; + const payload: ChatRequest = { + model: this.model, + messages: messages.map((message) => { + if (message.options && "toolResult" in message.options) { + return { + role: "tool", + content: message.options.toolResult.result, + }; + } + + return { + role: message.role, + content: extractText(message.content), + }; + }), + stream: !!stream, + options: { + ...this.options, + }, + }; + if (tools) { + payload.tools = tools.map((tool) => Ollama.toTool(tool)); + } + if (!stream) { + const chatResponse = await this.ollama.chat({ + ...payload, + stream: false, + }); + if (chatResponse.message.tool_calls) { + return { + message: { + role: "assistant", + content: chatResponse.message.content, + options: { + toolCall: chatResponse.message.tool_calls.map((toolCall) => ({ + name: toolCall.function.name, + input: toolCall.function.arguments, + id: randomUUID(), + })), + }, + }, + raw: chatResponse, + }; + } + + return { + message: { + role: "assistant", + content: chatResponse.message.content, + }, + raw: 
chatResponse, + }; + } else { + const stream = await this.ollama.chat({ + ...payload, + stream: true, + }); + return streamConverter(stream, messageAccessor); + } + } + + complete( + params: LLMCompletionParamsStreaming, + ): Promise<AsyncIterable<CompletionResponse>>; + complete( + params: LLMCompletionParamsNonStreaming, + ): Promise<CompletionResponse>; + async complete( + params: LLMCompletionParamsStreaming | LLMCompletionParamsNonStreaming, + ): Promise<CompletionResponse | AsyncIterable<CompletionResponse>> { + const { prompt, stream } = params; + const payload: GenerateRequest = { + model: this.model, + prompt: extractText(prompt), + stream: !!stream, + options: { + ...this.options, + }, + }; + if (!stream) { + const response = await this.ollama.generate({ + ...payload, + stream: false, + }); + return { + text: response.response, + raw: response, + }; + } else { + const stream = await this.ollama.generate({ + ...payload, + stream: true, + }); + return streamConverter(stream, completionAccessor); + } + } + + static toTool(tool: BaseTool): Tool { + return { + type: "function", + function: { + name: tool.metadata.name, + description: tool.metadata.description, + parameters: { + type: tool.metadata.parameters?.type, + required: tool.metadata.parameters?.required, + properties: tool.metadata.parameters?.properties, + }, + }, + }; + } +}
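
For context, here is a minimal usage sketch of the entry points this diff moves into the `@llamaindex/ollama` package (`Ollama`, `OllamaEmbedding`, with `OllamaAgent`/`OllamaAgentWorker` also re-exported). It assumes a locally running Ollama server; the model names `"llama3.1"` and `"nomic-embed-text"` are illustrative placeholders, not part of the change:

```ts
import { Ollama, OllamaEmbedding } from "@llamaindex/ollama";

// Chat LLM from the new llm.ts module; options map onto Ollama's request options.
const llm = new Ollama({
  model: "llama3.1", // placeholder: any tool-capable model pulled into the local Ollama server
  options: { temperature: 0.7 },
});

// Non-streaming chat, mirroring the request captured in the e2e snapshot above.
const response = await llm.chat({
  messages: [{ role: "user", content: "What is the weather in Paris?" }],
});
console.log(response.message.content);

// Embeddings now live in their own OllamaEmbedding class instead of the old index.ts barrel.
const embedModel = new OllamaEmbedding({ model: "nomic-embed-text" }); // placeholder embedding model
const vector = await embedModel.getTextEmbedding("What is the weather in Paris?");
console.log(vector.length);

// OllamaAgent / OllamaAgentWorker are re-exported both here and from the llamaindex
// agent barrel; as agent.ts shows, OllamaAgent falls back to Settings.llm or builds
// an Ollama instance from the optional `model` param when no `llm` is passed.
```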