From 57e763808379743cefc4308b3ac7f456174eee7a Mon Sep 17 00:00:00 2001 From: Marcus Schiesser <mail@marcusschiesser.de> Date: Thu, 22 Aug 2024 16:30:04 +0700 Subject: [PATCH] feat: Use the retrieval defaults from LlamaCloud (#247) --- .changeset/weak-students-pay.md | 5 +++++ helpers/env-variables.ts | 1 - templates/components/engines/python/agent/engine.py | 4 ++-- templates/components/engines/python/chat/engine.py | 5 ++--- templates/components/engines/typescript/chat/chat.ts | 2 +- templates/types/extractor/fastapi/app/engine/engine.py | 4 ++-- .../multiagent/fastapi/app/agents/query_engine/agent.py | 5 ++++- 7 files changed, 16 insertions(+), 10 deletions(-) create mode 100644 .changeset/weak-students-pay.md diff --git a/.changeset/weak-students-pay.md b/.changeset/weak-students-pay.md new file mode 100644 index 00000000..e40a3f71 --- /dev/null +++ b/.changeset/weak-students-pay.md @@ -0,0 +1,5 @@ +--- +"create-llama": patch +--- + +Use the retrieval defaults from LlamaCloud diff --git a/helpers/env-variables.ts b/helpers/env-variables.ts index 222b36c6..77b60ac4 100644 --- a/helpers/env-variables.ts +++ b/helpers/env-variables.ts @@ -396,7 +396,6 @@ const getEngineEnvs = (): EnvVar[] => { name: "TOP_K", description: "The number of similar embeddings to return when retrieving documents.", - value: "3", }, { name: "STREAM_TIMEOUT", diff --git a/templates/components/engines/python/agent/engine.py b/templates/components/engines/python/agent/engine.py index 3efc1314..854757e2 100644 --- a/templates/components/engines/python/agent/engine.py +++ b/templates/components/engines/python/agent/engine.py @@ -9,14 +9,14 @@ from llama_index.core.tools.query_engine import QueryEngineTool def get_chat_engine(filters=None, params=None): system_prompt = os.getenv("SYSTEM_PROMPT") - top_k = os.getenv("TOP_K", "3") + top_k = int(os.getenv("TOP_K", 0)) tools = [] # Add query tool if index exists index = get_index() if index is not None: query_engine = index.as_query_engine( - similarity_top_k=int(top_k), filters=filters + filters=filters, **({"similarity_top_k": top_k} if top_k != 0 else {}) ) query_engine_tool = QueryEngineTool.from_defaults(query_engine=query_engine) tools.append(query_engine_tool) diff --git a/templates/components/engines/python/chat/engine.py b/templates/components/engines/python/chat/engine.py index b1fd361c..61fc7aad 100644 --- a/templates/components/engines/python/chat/engine.py +++ b/templates/components/engines/python/chat/engine.py @@ -9,7 +9,7 @@ from llama_index.core.chat_engine import CondensePlusContextChatEngine def get_chat_engine(filters=None, params=None): system_prompt = os.getenv("SYSTEM_PROMPT") citation_prompt = os.getenv("SYSTEM_CITATION_PROMPT", None) - top_k = int(os.getenv("TOP_K", 3)) + top_k = int(os.getenv("TOP_K", 0)) node_postprocessors = [] if citation_prompt: @@ -26,8 +26,7 @@ def get_chat_engine(filters=None, params=None): ) retriever = index.as_retriever( - similarity_top_k=top_k, - filters=filters, + filters=filters, **({"similarity_top_k": top_k} if top_k != 0 else {}) ) return CondensePlusContextChatEngine.from_defaults( diff --git a/templates/components/engines/typescript/chat/chat.ts b/templates/components/engines/typescript/chat/chat.ts index e60c7970..c0841aa5 100644 --- a/templates/components/engines/typescript/chat/chat.ts +++ b/templates/components/engines/typescript/chat/chat.ts @@ -10,7 +10,7 @@ export async function createChatEngine(documentIds?: string[], params?: any) { ); } const retriever = index.asRetriever({ - similarityTopK: process.env.TOP_K ? parseInt(process.env.TOP_K) : 3, + similarityTopK: process.env.TOP_K ? parseInt(process.env.TOP_K) : undefined, filters: generateFilters(documentIds || []), }); diff --git a/templates/types/extractor/fastapi/app/engine/engine.py b/templates/types/extractor/fastapi/app/engine/engine.py index fbd92f67..07e5f5ed 100644 --- a/templates/types/extractor/fastapi/app/engine/engine.py +++ b/templates/types/extractor/fastapi/app/engine/engine.py @@ -7,7 +7,7 @@ from app.engine.index import get_index def get_query_engine(output_cls): - top_k = os.getenv("TOP_K", 3) + top_k = int(os.getenv("TOP_K", 0)) index = get_index() if index is None: @@ -21,7 +21,7 @@ def get_query_engine(output_cls): sllm = Settings.llm.as_structured_llm(output_cls) return index.as_query_engine( - similarity_top_k=int(top_k), llm=sllm, response_mode="tree_summarize", + **({"similarity_top_k": top_k} if top_k != 0 else {}), ) diff --git a/templates/types/multiagent/fastapi/app/agents/query_engine/agent.py b/templates/types/multiagent/fastapi/app/agents/query_engine/agent.py index bee1f017..4ed24e5e 100644 --- a/templates/types/multiagent/fastapi/app/agents/query_engine/agent.py +++ b/templates/types/multiagent/fastapi/app/agents/query_engine/agent.py @@ -19,7 +19,10 @@ def get_query_engine_tool() -> QueryEngineTool: index = get_index() if index is None: raise ValueError("Index not found. Please create an index first.") - query_engine = index.as_query_engine(similarity_top_k=int(os.getenv("TOP_K", 3))) + top_k = int(os.getenv("TOP_K", 0)) + query_engine = index.as_query_engine( + **({"similarity_top_k": top_k} if top_k != 0 else {}) + ) return QueryEngineTool( query_engine=query_engine, metadata=ToolMetadata( -- GitLab