From 57e763808379743cefc4308b3ac7f456174eee7a Mon Sep 17 00:00:00 2001
From: Marcus Schiesser <mail@marcusschiesser.de>
Date: Thu, 22 Aug 2024 16:30:04 +0700
Subject: [PATCH] feat: Use the retrieval defaults from LlamaCloud (#247)

---
 .changeset/weak-students-pay.md                              | 5 +++++
 helpers/env-variables.ts                                     | 1 -
 templates/components/engines/python/agent/engine.py          | 4 ++--
 templates/components/engines/python/chat/engine.py           | 5 ++---
 templates/components/engines/typescript/chat/chat.ts         | 2 +-
 templates/types/extractor/fastapi/app/engine/engine.py       | 4 ++--
 .../multiagent/fastapi/app/agents/query_engine/agent.py      | 5 ++++-
 7 files changed, 16 insertions(+), 10 deletions(-)
 create mode 100644 .changeset/weak-students-pay.md

diff --git a/.changeset/weak-students-pay.md b/.changeset/weak-students-pay.md
new file mode 100644
index 00000000..e40a3f71
--- /dev/null
+++ b/.changeset/weak-students-pay.md
@@ -0,0 +1,5 @@
+---
+"create-llama": patch
+---
+
+Use the retrieval defaults from LlamaCloud
diff --git a/helpers/env-variables.ts b/helpers/env-variables.ts
index 222b36c6..77b60ac4 100644
--- a/helpers/env-variables.ts
+++ b/helpers/env-variables.ts
@@ -396,7 +396,6 @@ const getEngineEnvs = (): EnvVar[] => {
       name: "TOP_K",
       description:
         "The number of similar embeddings to return when retrieving documents.",
-      value: "3",
     },
     {
       name: "STREAM_TIMEOUT",
diff --git a/templates/components/engines/python/agent/engine.py b/templates/components/engines/python/agent/engine.py
index 3efc1314..854757e2 100644
--- a/templates/components/engines/python/agent/engine.py
+++ b/templates/components/engines/python/agent/engine.py
@@ -9,14 +9,14 @@ from llama_index.core.tools.query_engine import QueryEngineTool
 
 def get_chat_engine(filters=None, params=None):
     system_prompt = os.getenv("SYSTEM_PROMPT")
-    top_k = os.getenv("TOP_K", "3")
+    top_k = int(os.getenv("TOP_K", 0))
     tools = []
 
     # Add query tool if index exists
     index = get_index()
     if index is not None:
         query_engine = index.as_query_engine(
-            similarity_top_k=int(top_k), filters=filters
+            filters=filters, **({"similarity_top_k": top_k} if top_k != 0 else {})
         )
         query_engine_tool = QueryEngineTool.from_defaults(query_engine=query_engine)
         tools.append(query_engine_tool)
diff --git a/templates/components/engines/python/chat/engine.py b/templates/components/engines/python/chat/engine.py
index b1fd361c..61fc7aad 100644
--- a/templates/components/engines/python/chat/engine.py
+++ b/templates/components/engines/python/chat/engine.py
@@ -9,7 +9,7 @@ from llama_index.core.chat_engine import CondensePlusContextChatEngine
 def get_chat_engine(filters=None, params=None):
     system_prompt = os.getenv("SYSTEM_PROMPT")
     citation_prompt = os.getenv("SYSTEM_CITATION_PROMPT", None)
-    top_k = int(os.getenv("TOP_K", 3))
+    top_k = int(os.getenv("TOP_K", 0))
 
     node_postprocessors = []
     if citation_prompt:
@@ -26,8 +26,7 @@ def get_chat_engine(filters=None, params=None):
         )
 
     retriever = index.as_retriever(
-        similarity_top_k=top_k,
-        filters=filters,
+        filters=filters, **({"similarity_top_k": top_k} if top_k != 0 else {})
     )
 
     return CondensePlusContextChatEngine.from_defaults(
diff --git a/templates/components/engines/typescript/chat/chat.ts b/templates/components/engines/typescript/chat/chat.ts
index e60c7970..c0841aa5 100644
--- a/templates/components/engines/typescript/chat/chat.ts
+++ b/templates/components/engines/typescript/chat/chat.ts
@@ -10,7 +10,7 @@ export async function createChatEngine(documentIds?: string[], params?: any) {
     );
   }
   const retriever = index.asRetriever({
-    similarityTopK: process.env.TOP_K ? parseInt(process.env.TOP_K) : 3,
+    similarityTopK: process.env.TOP_K ? parseInt(process.env.TOP_K) : undefined,
     filters: generateFilters(documentIds || []),
   });
 
diff --git a/templates/types/extractor/fastapi/app/engine/engine.py b/templates/types/extractor/fastapi/app/engine/engine.py
index fbd92f67..07e5f5ed 100644
--- a/templates/types/extractor/fastapi/app/engine/engine.py
+++ b/templates/types/extractor/fastapi/app/engine/engine.py
@@ -7,7 +7,7 @@ from app.engine.index import get_index
 
 
 def get_query_engine(output_cls):
-    top_k = os.getenv("TOP_K", 3)
+    top_k = int(os.getenv("TOP_K", 0))
 
     index = get_index()
     if index is None:
@@ -21,7 +21,7 @@ def get_query_engine(output_cls):
     sllm = Settings.llm.as_structured_llm(output_cls)
 
     return index.as_query_engine(
-        similarity_top_k=int(top_k),
         llm=sllm,
         response_mode="tree_summarize",
+        **({"similarity_top_k": top_k} if top_k != 0 else {}),
     )
diff --git a/templates/types/multiagent/fastapi/app/agents/query_engine/agent.py b/templates/types/multiagent/fastapi/app/agents/query_engine/agent.py
index bee1f017..4ed24e5e 100644
--- a/templates/types/multiagent/fastapi/app/agents/query_engine/agent.py
+++ b/templates/types/multiagent/fastapi/app/agents/query_engine/agent.py
@@ -19,7 +19,10 @@ def get_query_engine_tool() -> QueryEngineTool:
     index = get_index()
     if index is None:
         raise ValueError("Index not found. Please create an index first.")
-    query_engine = index.as_query_engine(similarity_top_k=int(os.getenv("TOP_K", 3)))
+    top_k = int(os.getenv("TOP_K", 0))
+    query_engine = index.as_query_engine(
+        **({"similarity_top_k": top_k} if top_k != 0 else {})
+    )
     return QueryEngineTool(
         query_engine=query_engine,
         metadata=ToolMetadata(
-- 
GitLab