From afc11a3a994bad81d176d3836b71ada43cecafbd Mon Sep 17 00:00:00 2001 From: "Huu Le (Lee)" <39040748+leehuwuj@users.noreply.github.com> Date: Tue, 5 Mar 2024 16:59:51 +0700 Subject: [PATCH] feat: add embedding model option to create-llama (#608) --- create-app.ts | 2 + e2e/utils.ts | 3 ++ helpers/index.ts | 6 +++ helpers/types.ts | 1 + index.ts | 7 ++++ questions.ts | 38 ++++++++++++++++++- .../types/simple/fastapi/app/settings.py | 8 +++- .../types/streaming/fastapi/app/settings.py | 8 +++- 8 files changed, 67 insertions(+), 6 deletions(-) diff --git a/create-app.ts b/create-app.ts index 865de6bf..1b5ed303 100644 --- a/create-app.ts +++ b/create-app.ts @@ -34,6 +34,7 @@ export async function createApp({ openAiKey, llamaCloudKey, model, + embeddingModel, communityProjectPath, llamapack, vectorDb, @@ -80,6 +81,7 @@ export async function createApp({ openAiKey, llamaCloudKey, model, + embeddingModel, communityProjectPath, llamapack, vectorDb, diff --git a/e2e/utils.ts b/e2e/utils.ts index 739930c7..cab23e08 100644 --- a/e2e/utils.ts +++ b/e2e/utils.ts @@ -14,6 +14,7 @@ import { export type AppType = "--frontend" | "--no-frontend" | ""; const MODEL = "gpt-3.5-turbo"; +const EMBEDDING_MODEL = "text-embedding-ada-002"; export type CreateLlamaResult = { projectName: string; appProcess: ChildProcess; @@ -106,6 +107,8 @@ export async function runCreateLlama( vectorDb, "--model", MODEL, + "--embedding-model", + EMBEDDING_MODEL, "--open-ai-key", process.env.OPENAI_API_KEY || "testKey", appType, diff --git a/helpers/index.ts b/helpers/index.ts index 77c445e6..cb5e4372 100644 --- a/helpers/index.ts +++ b/helpers/index.ts @@ -29,6 +29,7 @@ const createEnvLocalFile = async ( llamaCloudKey?: string; vectorDb?: TemplateVectorDB; model?: string; + embeddingModel?: string; framework?: TemplateFramework; dataSource?: TemplateDataSource; }, @@ -47,6 +48,10 @@ const createEnvLocalFile = async ( content += `OPENAI_API_KEY=${opts?.openAiKey}\n`; } + if (opts?.embeddingModel) { + content += `EMBEDDING_MODEL=${opts?.embeddingModel}\n`; + } + if (opts?.llamaCloudKey) { content += `LLAMA_CLOUD_API_KEY=${opts?.llamaCloudKey}\n`; } @@ -213,6 +218,7 @@ export const installTemplate = async ( llamaCloudKey: props.llamaCloudKey, vectorDb: props.vectorDb, model: props.model, + embeddingModel: props.embeddingModel, framework: props.framework, dataSource: props.dataSource, }); diff --git a/helpers/types.ts b/helpers/types.ts index 9923e030..52ef94b9 100644 --- a/helpers/types.ts +++ b/helpers/types.ts @@ -39,6 +39,7 @@ export interface InstallTemplateArgs { llamaCloudKey?: string; forBackend?: string; model: string; + embeddingModel: string; communityProjectPath?: string; llamapack?: string; vectorDb?: TemplateVectorDB; diff --git a/index.ts b/index.ts index ce58b435..804eae6c 100644 --- a/index.ts +++ b/index.ts @@ -119,6 +119,12 @@ const program = new Commander.Command(packageJson.name) ` Select OpenAI model to use. E.g. gpt-3.5-turbo. +`, + ) + .option( + "--embedding-model <embeddingModel>", + ` + Select OpenAI embedding model to use. E.g. text-embedding-ada-002. `, ) .option( @@ -281,6 +287,7 @@ async function run(): Promise<void> { openAiKey: program.openAiKey, llamaCloudKey: program.llamaCloudKey, model: program.model, + embeddingModel: program.embeddingModel, communityProjectPath: program.communityProjectPath, llamapack: program.llamapack, vectorDb: program.vectorDb, diff --git a/questions.ts b/questions.ts index 21ea3891..194c3cce 100644 --- a/questions.ts +++ b/questions.ts @@ -69,6 +69,7 @@ const defaults: QuestionArgs = { openAiKey: "", llamaCloudKey: "", model: "gpt-3.5-turbo", + embeddingModel: "text-embedding-ada-002", communityProjectPath: "", llamapack: "", postInstallAction: "dependencies", @@ -443,6 +444,38 @@ export const askQuestions = async ( } } + if (!program.embeddingModel && program.framework === "fastapi") { + if (ciInfo.isCI) { + program.embeddingModel = getPrefOrDefault("embeddingModel"); + } else { + const { embeddingModel } = await prompts( + { + type: "select", + name: "embeddingModel", + message: "Which embedding model would you like to use?", + choices: [ + { + title: "text-embedding-ada-002", + value: "text-embedding-ada-002", + }, + { + title: "text-embedding-3-small", + value: "text-embedding-3-small", + }, + { + title: "text-embedding-3-large", + value: "text-embedding-3-large", + }, + ], + initial: 0, + }, + handlers, + ); + program.embeddingModel = embeddingModel; + preferences.embeddingModel = embeddingModel; + } + } + if (program.files) { // If user specified files option, then the program should use context engine program.engine == "context"; @@ -527,8 +560,9 @@ export const askQuestions = async ( } if ( - program.dataSource?.type === "file" || - (program.dataSource?.type === "folder" && program.framework === "fastapi") + (program.dataSource?.type === "file" || + program.dataSource?.type === "folder") && + program.framework === "fastapi" ) { if (ciInfo.isCI) { program.llamaCloudKey = getPrefOrDefault("llamaCloudKey"); diff --git a/templates/types/simple/fastapi/app/settings.py b/templates/types/simple/fastapi/app/settings.py index e221a6b4..bd49f945 100644 --- a/templates/types/simple/fastapi/app/settings.py +++ b/templates/types/simple/fastapi/app/settings.py @@ -1,10 +1,14 @@ import os from llama_index.llms.openai import OpenAI +from llama_index.embeddings.openai import OpenAIEmbedding from llama_index.core.settings import Settings def init_settings(): - model = os.getenv("MODEL", "gpt-3.5-turbo") - Settings.llm = OpenAI(model=model) + llm_model = os.getenv("MODEL", "gpt-3.5-turbo") + embedding_model = os.getenv("EMBEDDING_MODEL", "text-embedding-ada-002") + + Settings.llm = OpenAI(model=llm_model) + Settings.embed_model = OpenAIEmbedding(model=embedding_model) Settings.chunk_size = 1024 Settings.chunk_overlap = 20 diff --git a/templates/types/streaming/fastapi/app/settings.py b/templates/types/streaming/fastapi/app/settings.py index e221a6b4..bd49f945 100644 --- a/templates/types/streaming/fastapi/app/settings.py +++ b/templates/types/streaming/fastapi/app/settings.py @@ -1,10 +1,14 @@ import os from llama_index.llms.openai import OpenAI +from llama_index.embeddings.openai import OpenAIEmbedding from llama_index.core.settings import Settings def init_settings(): - model = os.getenv("MODEL", "gpt-3.5-turbo") - Settings.llm = OpenAI(model=model) + llm_model = os.getenv("MODEL", "gpt-3.5-turbo") + embedding_model = os.getenv("EMBEDDING_MODEL", "text-embedding-ada-002") + + Settings.llm = OpenAI(model=llm_model) + Settings.embed_model = OpenAIEmbedding(model=embedding_model) Settings.chunk_size = 1024 Settings.chunk_overlap = 20 -- GitLab