From afc11a3a994bad81d176d3836b71ada43cecafbd Mon Sep 17 00:00:00 2001
From: "Huu Le (Lee)" <39040748+leehuwuj@users.noreply.github.com>
Date: Tue, 5 Mar 2024 16:59:51 +0700
Subject: [PATCH] feat: add embedding model option to create-llama (#608)

---
 create-app.ts                                 |  2 +
 e2e/utils.ts                                  |  3 ++
 helpers/index.ts                              |  6 +++
 helpers/types.ts                              |  1 +
 index.ts                                      |  7 ++++
 questions.ts                                  | 38 ++++++++++++++++++-
 .../types/simple/fastapi/app/settings.py      |  8 +++-
 .../types/streaming/fastapi/app/settings.py   |  8 +++-
 8 files changed, 67 insertions(+), 6 deletions(-)

diff --git a/create-app.ts b/create-app.ts
index 865de6bf..1b5ed303 100644
--- a/create-app.ts
+++ b/create-app.ts
@@ -34,6 +34,7 @@ export async function createApp({
   openAiKey,
   llamaCloudKey,
   model,
+  embeddingModel,
   communityProjectPath,
   llamapack,
   vectorDb,
@@ -80,6 +81,7 @@ export async function createApp({
     openAiKey,
     llamaCloudKey,
     model,
+    embeddingModel,
     communityProjectPath,
     llamapack,
     vectorDb,
diff --git a/e2e/utils.ts b/e2e/utils.ts
index 739930c7..cab23e08 100644
--- a/e2e/utils.ts
+++ b/e2e/utils.ts
@@ -14,6 +14,7 @@ import {
 
 export type AppType = "--frontend" | "--no-frontend" | "";
 const MODEL = "gpt-3.5-turbo";
+const EMBEDDING_MODEL = "text-embedding-ada-002";
 export type CreateLlamaResult = {
   projectName: string;
   appProcess: ChildProcess;
@@ -106,6 +107,8 @@ export async function runCreateLlama(
     vectorDb,
     "--model",
     MODEL,
+    "--embedding-model",
+    EMBEDDING_MODEL,
     "--open-ai-key",
     process.env.OPENAI_API_KEY || "testKey",
     appType,
diff --git a/helpers/index.ts b/helpers/index.ts
index 77c445e6..cb5e4372 100644
--- a/helpers/index.ts
+++ b/helpers/index.ts
@@ -29,6 +29,7 @@ const createEnvLocalFile = async (
     llamaCloudKey?: string;
     vectorDb?: TemplateVectorDB;
     model?: string;
+    embeddingModel?: string;
     framework?: TemplateFramework;
     dataSource?: TemplateDataSource;
   },
@@ -47,6 +48,10 @@ const createEnvLocalFile = async (
     content += `OPENAI_API_KEY=${opts?.openAiKey}\n`;
   }
 
+  if (opts?.embeddingModel) {
+    content += `EMBEDDING_MODEL=${opts?.embeddingModel}\n`;
+  }
+
   if (opts?.llamaCloudKey) {
     content += `LLAMA_CLOUD_API_KEY=${opts?.llamaCloudKey}\n`;
   }
@@ -213,6 +218,7 @@ export const installTemplate = async (
       llamaCloudKey: props.llamaCloudKey,
       vectorDb: props.vectorDb,
       model: props.model,
+      embeddingModel: props.embeddingModel,
       framework: props.framework,
       dataSource: props.dataSource,
     });
diff --git a/helpers/types.ts b/helpers/types.ts
index 9923e030..52ef94b9 100644
--- a/helpers/types.ts
+++ b/helpers/types.ts
@@ -39,6 +39,7 @@ export interface InstallTemplateArgs {
   llamaCloudKey?: string;
   forBackend?: string;
   model: string;
+  embeddingModel: string;
   communityProjectPath?: string;
   llamapack?: string;
   vectorDb?: TemplateVectorDB;
diff --git a/index.ts b/index.ts
index ce58b435..804eae6c 100644
--- a/index.ts
+++ b/index.ts
@@ -119,6 +119,12 @@ const program = new Commander.Command(packageJson.name)
     `
 
   Select OpenAI model to use. E.g. gpt-3.5-turbo.
+`,
+  )
+  .option(
+    "--embedding-model <embeddingModel>",
+    `
+  Select OpenAI embedding model to use. E.g. text-embedding-ada-002.
 `,
   )
   .option(
@@ -281,6 +287,7 @@ async function run(): Promise<void> {
     openAiKey: program.openAiKey,
     llamaCloudKey: program.llamaCloudKey,
     model: program.model,
+    embeddingModel: program.embeddingModel,
     communityProjectPath: program.communityProjectPath,
     llamapack: program.llamapack,
     vectorDb: program.vectorDb,
diff --git a/questions.ts b/questions.ts
index 21ea3891..194c3cce 100644
--- a/questions.ts
+++ b/questions.ts
@@ -69,6 +69,7 @@ const defaults: QuestionArgs = {
   openAiKey: "",
   llamaCloudKey: "",
   model: "gpt-3.5-turbo",
+  embeddingModel: "text-embedding-ada-002",
   communityProjectPath: "",
   llamapack: "",
   postInstallAction: "dependencies",
@@ -443,6 +444,38 @@ export const askQuestions = async (
     }
   }
 
+  if (!program.embeddingModel && program.framework === "fastapi") {
+    if (ciInfo.isCI) {
+      program.embeddingModel = getPrefOrDefault("embeddingModel");
+    } else {
+      const { embeddingModel } = await prompts(
+        {
+          type: "select",
+          name: "embeddingModel",
+          message: "Which embedding model would you like to use?",
+          choices: [
+            {
+              title: "text-embedding-ada-002",
+              value: "text-embedding-ada-002",
+            },
+            {
+              title: "text-embedding-3-small",
+              value: "text-embedding-3-small",
+            },
+            {
+              title: "text-embedding-3-large",
+              value: "text-embedding-3-large",
+            },
+          ],
+          initial: 0,
+        },
+        handlers,
+      );
+      program.embeddingModel = embeddingModel;
+      preferences.embeddingModel = embeddingModel;
+    }
+  }
+
   if (program.files) {
     // If user specified files option, then the program should use context engine
     program.engine == "context";
@@ -527,8 +560,9 @@ export const askQuestions = async (
   }
 
   if (
-    program.dataSource?.type === "file" ||
-    (program.dataSource?.type === "folder" && program.framework === "fastapi")
+    (program.dataSource?.type === "file" ||
+      program.dataSource?.type === "folder") &&
+    program.framework === "fastapi"
   ) {
     if (ciInfo.isCI) {
       program.llamaCloudKey = getPrefOrDefault("llamaCloudKey");
diff --git a/templates/types/simple/fastapi/app/settings.py b/templates/types/simple/fastapi/app/settings.py
index e221a6b4..bd49f945 100644
--- a/templates/types/simple/fastapi/app/settings.py
+++ b/templates/types/simple/fastapi/app/settings.py
@@ -1,10 +1,14 @@
 import os
 from llama_index.llms.openai import OpenAI
+from llama_index.embeddings.openai import OpenAIEmbedding
 from llama_index.core.settings import Settings
 
 
 def init_settings():
-    model = os.getenv("MODEL", "gpt-3.5-turbo")
-    Settings.llm = OpenAI(model=model)
+    llm_model = os.getenv("MODEL", "gpt-3.5-turbo")
+    embedding_model = os.getenv("EMBEDDING_MODEL", "text-embedding-ada-002")
+
+    Settings.llm = OpenAI(model=llm_model)
+    Settings.embed_model = OpenAIEmbedding(model=embedding_model)
     Settings.chunk_size = 1024
     Settings.chunk_overlap = 20
diff --git a/templates/types/streaming/fastapi/app/settings.py b/templates/types/streaming/fastapi/app/settings.py
index e221a6b4..bd49f945 100644
--- a/templates/types/streaming/fastapi/app/settings.py
+++ b/templates/types/streaming/fastapi/app/settings.py
@@ -1,10 +1,14 @@
 import os
 from llama_index.llms.openai import OpenAI
+from llama_index.embeddings.openai import OpenAIEmbedding
 from llama_index.core.settings import Settings
 
 
 def init_settings():
-    model = os.getenv("MODEL", "gpt-3.5-turbo")
-    Settings.llm = OpenAI(model=model)
+    llm_model = os.getenv("MODEL", "gpt-3.5-turbo")
+    embedding_model = os.getenv("EMBEDDING_MODEL", "text-embedding-ada-002")
+
+    Settings.llm = OpenAI(model=llm_model)
+    Settings.embed_model = OpenAIEmbedding(model=embedding_model)
     Settings.chunk_size = 1024
     Settings.chunk_overlap = 20
-- 
GitLab