From 548f0687f1d5c9715dffc9de29e715d5e90c50e3 Mon Sep 17 00:00:00 2001
From: Alex Yang <himself65@outlook.com>
Date: Thu, 4 Jan 2024 18:03:00 -0600
Subject: [PATCH] feat(core): init support for Ollama (#305)
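
Introduce an Ollama integration for the core package. The new Ollama
class implements the LLM interface and extends BaseEmbedding, talking
to a local Ollama server (default http://127.0.0.1:11434) through its
/api/chat, /api/generate, and /api/embeddings endpoints. chat() and
complete() support both buffered and streaming (newline-delimited
JSON) responses, and getTextEmbedding()/getQueryEmbedding() expose
embeddings. A runnable demo lives in examples/ollama.ts.

Minimal usage, mirroring the example (assumes a local Ollama server
is running and the llama2 model has been pulled):

  const llm = new Ollama({ model: "llama2" });
  const response = await llm.chat([{ content: "Hi", role: "user" }]);
  console.log(response.message.content);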

---
 examples/ollama.ts              |  43 ++++++++
 examples/tsconfig.json          |  11 ++
 packages/core/src/llm/index.ts  |   1 +
 packages/core/src/llm/ollama.ts | 220 ++++++++++++++++++++++++++++++++
 4 files changed, 275 insertions(+)
 create mode 100644 examples/ollama.ts
 create mode 100644 examples/tsconfig.json
 create mode 100644 packages/core/src/llm/ollama.ts

diff --git a/examples/ollama.ts b/examples/ollama.ts
new file mode 100644
index 000000000..9b00d11d2
--- /dev/null
+++ b/examples/ollama.ts
@@ -0,0 +1,43 @@
+import { Ollama } from "llamaindex";
+
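+// Assumes a local Ollama server is running (`ollama serve`) and that the
+// llama2 model has been pulled (`ollama pull llama2`).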
+(async () => {
+  const llm = new Ollama({ model: "llama2", temperature: 0.75 });
+  {
+    const response = await llm.chat([
+      { content: "Tell me a joke.", role: "user" },
+    ]);
+    console.log("Response 1:", response.message.content);
+  }
+  {
+    const response = await llm.complete("How are you?");
+    console.log("Response 2:", response.message.content);
+  }
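+  // Passing `true` as the third argument requests a streaming response,
+  // returned as an AsyncGenerator of string fragments.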
+  {
+    const response = await llm.chat(
+      [{ content: "Tell me a joke.", role: "user" }],
+      undefined,
+      true,
+    );
+    console.log("Response 3:");
+    for await (const message of response) {
+      process.stdout.write(message); // no newline
+    }
+    console.log(); // newline
+  }
+  {
+    const response = await llm.complete("How are you?", undefined, true);
+    console.log("Response 4:");
+    for await (const message of response) {
+      process.stdout.write(message); // no newline
+    }
+    console.log(); // newline
+  }
+  {
+    const embedding = await llm.getTextEmbedding("Hello world!");
+    console.log("Embedding:", embedding);
+  }
+})();
diff --git a/examples/tsconfig.json b/examples/tsconfig.json
new file mode 100644
index 000000000..a9998d1bc
--- /dev/null
+++ b/examples/tsconfig.json
@@ -0,0 +1,11 @@
+{
+  "compilerOptions": {
+    "target": "es2016",
+    "module": "commonjs",
+    "esModuleInterop": true,
+    "forceConsistentCasingInFileNames": true,
+    "strict": true,
+    "skipLibCheck": true
+  },
+  "include": ["./**/*.ts"]
+}
diff --git a/packages/core/src/llm/index.ts b/packages/core/src/llm/index.ts
index edb5d2876..5c1a9f3ed 100644
--- a/packages/core/src/llm/index.ts
+++ b/packages/core/src/llm/index.ts
@@ -1,2 +1,3 @@
 export * from "./LLM";
 export * from "./mistral";
+export { Ollama } from "./ollama";
diff --git a/packages/core/src/llm/ollama.ts b/packages/core/src/llm/ollama.ts
new file mode 100644
index 000000000..ee1708422
--- /dev/null
+++ b/packages/core/src/llm/ollama.ts
@@ -0,0 +1,220 @@
+import { ok } from "node:assert";
+import { MessageContent } from "../ChatEngine";
+import { CallbackManager, Event } from "../callbacks/CallbackManager";
+import { BaseEmbedding } from "../embeddings";
+import { ChatMessage, ChatResponse, LLM, LLMMetadata } from "./LLM";
+
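+// Extract the incremental text from one parsed streaming event: /api/chat
+// nests it under message.content, /api/generate returns it as `response`.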
+const messageAccessor = (data: any) => data.message.content;
+const completionAccessor = (data: any) => data.response;
+
+// https://github.com/jmorganca/ollama
+export class Ollama extends BaseEmbedding implements LLM {
+  readonly hasStreaming = true;
+
+  // https://ollama.ai/library
+  model: string;
+  baseURL: string = "http://127.0.0.1:11434";
+  temperature: number = 0.7;
+  topP: number = 0.9;
+  contextWindow: number = 4096;
+  requestTimeout: number = 60 * 1000; // Default is 60 seconds
+  additionalChatOptions?: Record<string, unknown>;
+  callbackManager?: CallbackManager;
+
+  constructor(
+    init: Partial<Ollama> & {
+      // model is required
+      model: string;
+    },
+  ) {
+    super();
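+    // Assign model explicitly to satisfy strict property initialization;
+    // Object.assign then copies the remaining init fields over the defaults.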
+    this.model = init.model;
+    Object.assign(this, init);
+  }
+
+  get metadata(): LLMMetadata {
+    return {
+      model: this.model,
+      temperature: this.temperature,
+      topP: this.topP,
+      maxTokens: undefined,
+      contextWindow: this.contextWindow,
+      tokenizer: undefined,
+    };
+  }
+
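+  // When `streaming` is true the call resolves to an AsyncGenerator of
+  // response fragments; otherwise it resolves to a single ChatResponse.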
+  async chat<
+    T extends boolean | undefined = undefined,
+    R = T extends true ? AsyncGenerator<string, void, unknown> : ChatResponse,
+  >(
+    messages: ChatMessage[],
+    parentEvent?: Event | undefined,
+    streaming?: T,
+  ): Promise<R> {
+    const payload = {
+      model: this.model,
+      messages: messages.map((message) => ({
+        role: message.role,
+        content: message.content,
+      })),
+      stream: !!streaming,
+      options: {
+        temperature: this.temperature,
+        num_ctx: this.contextWindow,
+        top_p: this.topP,
+        ...this.additionalChatOptions,
+      },
+    };
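+    // POST to Ollama's chat endpoint. With stream: true the body arrives
+    // as newline-delimited JSON; otherwise it is a single JSON object.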
+    const response = await fetch(`${this.baseURL}/api/chat`, {
+      body: JSON.stringify(payload),
+      method: "POST",
+      signal: AbortSignal.timeout(this.requestTimeout),
+      headers: {
+        "Content-Type": "application/json",
+      },
+    });
+    if (!streaming) {
+      const raw = await response.json();
+      const { message } = raw;
+      return {
+        message: {
+          role: "assistant",
+          content: message.content,
+        },
+        raw,
+      } satisfies ChatResponse as R;
+    } else {
+      const stream = response.body;
+      ok(stream, "stream is null");
+      ok(stream instanceof ReadableStream, "stream is not readable");
+      return this.streamChat(stream, messageAccessor, parentEvent) as R;
+    }
+  }
+
+  private async *streamChat(
+    stream: ReadableStream<Uint8Array>,
+    accessor: (data: any) => string,
+    parentEvent?: Event,
+  ): AsyncGenerator<string, void, unknown> {
+    const reader = stream.getReader();
+    const decoder = new TextDecoder("utf-8");
+    // A chunk from the reader is not guaranteed to end on a line boundary,
+    // so buffer any trailing partial line until the next chunk arrives.
+    let buffer = "";
+    while (true) {
+      const { done, value } = await reader.read();
+      if (done) {
+        return;
+      }
+      buffer += decoder.decode(value, { stream: true });
+      const lines = buffer.split("\n");
+      // The last element is either "" (the chunk ended on a newline) or an
+      // incomplete line; carry it over to the next read.
+      buffer = lines.pop() ?? "";
+      for (const rawLine of lines) {
+        const line = rawLine.trim();
+        if (line === "") {
+          continue;
+        }
+        const json = JSON.parse(line);
+        if (json.error) {
+          throw new Error(json.error);
+        }
+        yield accessor(json);
+      }
+    }
+  }
+
+  async complete<
+    T extends boolean | undefined = undefined,
+    R = T extends true ? AsyncGenerator<string, void, unknown> : ChatResponse,
+  >(
+    prompt: MessageContent,
+    parentEvent?: Event | undefined,
+    streaming?: T | undefined,
+  ): Promise<R> {
+    const payload = {
+      model: this.model,
+      prompt: prompt,
+      stream: !!streaming,
+      options: {
+        temperature: this.temperature,
+        num_ctx: this.contextWindow,
+        top_p: this.topP,
+        ...this.additionalChatOptions,
+      },
+    };
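+    // POST to Ollama's completion endpoint; streaming works like /api/chat,
+    // but the text arrives in a top-level `response` field.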
+    const response = await fetch(`${this.baseURL}/api/generate`, {
+      body: JSON.stringify(payload),
+      method: "POST",
+      signal: AbortSignal.timeout(this.requestTimeout),
+      headers: {
+        "Content-Type": "application/json",
+      },
+    });
+    if (!streaming) {
+      const raw = await response.json();
+      return {
+        message: {
+          role: "assistant",
+          content: raw.response,
+        },
+        raw,
+      } satisfies ChatResponse as R;
+    } else {
+      const stream = response.body;
+      ok(stream, "stream is null");
+      ok(stream instanceof ReadableStream, "stream is not readable");
+      return this.streamChat(stream, completionAccessor, parentEvent) as R;
+    }
+  }
+
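+  // The Ollama HTTP API does not expose a tokenizer, so token counting is
+  // left unimplemented here.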
+  tokens(messages: ChatMessage[]): number {
+    throw new Error("Method not implemented.");
+  }
+
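+  // Ollama's /api/embeddings endpoint returns { embedding: number[] } for a
+  // single prompt string.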
+  private async getEmbedding(prompt: string): Promise<number[]> {
+    const payload = {
+      model: this.model,
+      prompt,
+      options: {
+        temperature: this.temperature,
+        num_ctx: this.contextWindow,
+        top_p: this.topP,
+        ...this.additionalChatOptions,
+      },
+    };
+    const response = await fetch(`${this.baseURL}/api/embeddings`, {
+      body: JSON.stringify(payload),
+      method: "POST",
+      signal: AbortSignal.timeout(this.requestTimeout),
+      headers: {
+        "Content-Type": "application/json",
+      },
+    });
+    const { embedding } = await response.json();
+    return embedding;
+  }
+
+  async getTextEmbedding(text: string): Promise<number[]> {
+    return this.getEmbedding(text);
+  }
+
+  async getQueryEmbedding(query: string): Promise<number[]> {
+    return this.getEmbedding(query);
+  }
+}
-- 
GitLab