From edd0f662346bce989935e4cde0d4e18dc9e5126c Mon Sep 17 00:00:00 2001
From: Alex Yang <himself65@outlook.com>
Date: Fri, 12 Jan 2024 18:41:48 -0600
Subject: [PATCH] feat: support Together AI (#373)

---
 examples/together-ai.ts                       | 30 +++++++++++++++++++
 .../core/src/embeddings/OpenAIEmbedding.ts    |  2 +-
 packages/core/src/embeddings/index.ts         |  1 +
 packages/core/src/embeddings/together.ts      | 16 ++++++++++
 packages/core/src/llm/LLM.ts                  |  8 +++--
 packages/core/src/llm/index.ts                |  1 +
 packages/core/src/llm/together.ts             | 14 +++++++++
 packages/eslint-config-custom/index.js        |  1 +
 8 files changed, 70 insertions(+), 3 deletions(-)
 create mode 100644 examples/together-ai.ts
 create mode 100644 packages/core/src/embeddings/together.ts
 create mode 100644 packages/core/src/llm/together.ts

diff --git a/examples/together-ai.ts b/examples/together-ai.ts
new file mode 100644
index 000000000..257ce466f
--- /dev/null
+++ b/examples/together-ai.ts
@@ -0,0 +1,31 @@
+import { TogetherEmbedding, TogetherLLM } from "llamaindex";
+
+// process.env.TOGETHER_API_KEY is required
+const together = new TogetherLLM({
+  model: "mistralai/Mixtral-8x7B-Instruct-v0.1",
+});
+
+(async () => {
+  // Third argument `true` requests a streamed response (an async generator).
+  const generator = await together.chat(
+    [
+      {
+        role: "system",
+        content: "You are an AI assistant",
+      },
+      {
+        role: "user",
+        content: "Tell me about San Francisco",
+      },
+    ],
+    undefined,
+    true,
+  );
+  console.log("Chatting with Together AI...");
+  for await (const message of generator) {
+    process.stdout.write(message);
+  }
+  const embedding = new TogetherEmbedding();
+  const vector = await embedding.getTextEmbedding("Hello world!");
+  console.log("vector:", vector);
+})().catch(console.error);
diff --git a/packages/core/src/embeddings/OpenAIEmbedding.ts b/packages/core/src/embeddings/OpenAIEmbedding.ts
index 106c6cbff..6bbbfba3a 100644
--- a/packages/core/src/embeddings/OpenAIEmbedding.ts
+++ b/packages/core/src/embeddings/OpenAIEmbedding.ts
@@ -14,7 +14,7 @@ export enum OpenAIEmbeddingModelType {
 }
 
 export class OpenAIEmbedding extends BaseEmbedding {
-  model: OpenAIEmbeddingModelType;
+  model: OpenAIEmbeddingModelType | string;
 
   // OpenAI session params
   apiKey?: string = undefined;
diff --git a/packages/core/src/embeddings/index.ts b/packages/core/src/embeddings/index.ts
index 32d6535bd..80a788f58 100644
--- a/packages/core/src/embeddings/index.ts
+++ b/packages/core/src/embeddings/index.ts
@@ -3,5 +3,6 @@ export * from "./HuggingFaceEmbedding";
 export * from "./MistralAIEmbedding";
 export * from "./MultiModalEmbedding";
 export * from "./OpenAIEmbedding";
+export { TogetherEmbedding } from "./together";
 export * from "./types";
 export * from "./utils";
diff --git a/packages/core/src/embeddings/together.ts b/packages/core/src/embeddings/together.ts
new file mode 100644
index 000000000..dde47c30c
--- /dev/null
+++ b/packages/core/src/embeddings/together.ts
@@ -0,0 +1,23 @@
+import { OpenAIEmbedding } from "./OpenAIEmbedding";
+
+// Default embedding model served by Together AI.
+const DEFAULT_TOGETHER_EMBED_MODEL = "togethercomputer/m2-bert-80M-32k-retrieval";
+
+// Embeddings from Together AI via its OpenAI-compatible /v1 endpoint.
+export class TogetherEmbedding extends OpenAIEmbedding {
+  override model: string;
+
+  constructor(init?: Partial<OpenAIEmbedding>) {
+    super({
+      // Env var is only a default; an explicit init.apiKey wins via the spread.
+      apiKey: process.env.TOGETHER_API_KEY,
+      ...init,
+      additionalSessionOptions: {
+        // Keep caller-provided session options but pin the Together base URL.
+        ...init?.additionalSessionOptions,
+        baseURL: "https://api.together.xyz/v1",
+      },
+    });
+    this.model = init?.model ?? DEFAULT_TOGETHER_EMBED_MODEL;
+  }
+}
diff --git a/packages/core/src/llm/LLM.ts b/packages/core/src/llm/LLM.ts
index 06a0857a6..a90e938d7 100644
--- a/packages/core/src/llm/LLM.ts
+++ b/packages/core/src/llm/LLM.ts
@@ -129,7 +129,7 @@ export class OpenAI implements LLM {
   hasStreaming: boolean = true;
 
   // Per completion OpenAI params
-  model: keyof typeof ALL_AVAILABLE_OPENAI_MODELS;
+  model: keyof typeof ALL_AVAILABLE_OPENAI_MODELS | string;
   temperature: number;
   topP: number;
   maxTokens?: number;
@@ -205,12 +205,16 @@ export class OpenAI implements LLM {
   }
 
   get metadata() {
+    const contextWindow =
+      ALL_AVAILABLE_OPENAI_MODELS[
+        this.model as keyof typeof ALL_AVAILABLE_OPENAI_MODELS
+      ]?.contextWindow ?? 1024;
     return {
       model: this.model,
       temperature: this.temperature,
       topP: this.topP,
       maxTokens: this.maxTokens,
-      contextWindow: ALL_AVAILABLE_OPENAI_MODELS[this.model].contextWindow,
+      contextWindow,
       tokenizer: Tokenizers.CL100K_BASE,
     };
   }
diff --git a/packages/core/src/llm/index.ts b/packages/core/src/llm/index.ts
index 5c1a9f3ed..74e0b91d9 100644
--- a/packages/core/src/llm/index.ts
+++ b/packages/core/src/llm/index.ts
@@ -1,3 +1,4 @@
 export * from "./LLM";
 export * from "./mistral";
 export { Ollama } from "./ollama";
+export { TogetherLLM } from "./together";
diff --git a/packages/core/src/llm/together.ts b/packages/core/src/llm/together.ts
new file mode 100644
index 000000000..f972faf7d
--- /dev/null
+++ b/packages/core/src/llm/together.ts
@@ -0,0 +1,17 @@
+import { OpenAI } from "./LLM";
+
+// Chat LLM from Together AI via its OpenAI-compatible /v1 endpoint.
+export class TogetherLLM extends OpenAI {
+  constructor(init?: Partial<OpenAI>) {
+    super({
+      // Env var is only a default; an explicit init.apiKey wins via the spread.
+      // (Placed BEFORE ...init to match TogetherEmbedding — otherwise the env
+      // var would silently clobber a caller-supplied key.)
+      apiKey: process.env.TOGETHER_API_KEY,
+      ...init,
+      additionalSessionOptions: {
+        // Keep caller-provided session options but pin the Together base URL.
+        ...init?.additionalSessionOptions,
+        baseURL: "https://api.together.xyz/v1",
+      },
+    });
+  }
+}
diff --git a/packages/eslint-config-custom/index.js b/packages/eslint-config-custom/index.js
index 4383def9e..ff53536b5 100644
--- a/packages/eslint-config-custom/index.js
+++ b/packages/eslint-config-custom/index.js
@@ -10,6 +10,7 @@ module.exports = {
           "REPLICATE_API_TOKEN",
           "ANTHROPIC_API_KEY",
           "ASSEMBLYAI_API_KEY",
+          "TOGETHER_API_KEY",
 
           "ASTRA_DB_APPLICATION_TOKEN",
           "ASTRA_DB_ENDPOINT",
-- 
GitLab