Skip to content
Snippets Groups Projects
Unverified Commit 57e76380 authored by Marcus Schiesser's avatar Marcus Schiesser Committed by GitHub
Browse files

feat: Use the retrieval defaults from LlamaCloud (#247)

parent 22ac2cae
No related branches found
No related tags found
No related merge requests found
---
"create-llama": patch
---
Use the retrieval defaults from LlamaCloud
...@@ -396,7 +396,6 @@ const getEngineEnvs = (): EnvVar[] => { ...@@ -396,7 +396,6 @@ const getEngineEnvs = (): EnvVar[] => {
name: "TOP_K", name: "TOP_K",
description: description:
"The number of similar embeddings to return when retrieving documents.", "The number of similar embeddings to return when retrieving documents.",
value: "3",
}, },
{ {
name: "STREAM_TIMEOUT", name: "STREAM_TIMEOUT",
......
...@@ -9,14 +9,14 @@ from llama_index.core.tools.query_engine import QueryEngineTool ...@@ -9,14 +9,14 @@ from llama_index.core.tools.query_engine import QueryEngineTool
def get_chat_engine(filters=None, params=None): def get_chat_engine(filters=None, params=None):
system_prompt = os.getenv("SYSTEM_PROMPT") system_prompt = os.getenv("SYSTEM_PROMPT")
top_k = os.getenv("TOP_K", "3") top_k = int(os.getenv("TOP_K", 0))
tools = [] tools = []
# Add query tool if index exists # Add query tool if index exists
index = get_index() index = get_index()
if index is not None: if index is not None:
query_engine = index.as_query_engine( query_engine = index.as_query_engine(
similarity_top_k=int(top_k), filters=filters filters=filters, **({"similarity_top_k": top_k} if top_k != 0 else {})
) )
query_engine_tool = QueryEngineTool.from_defaults(query_engine=query_engine) query_engine_tool = QueryEngineTool.from_defaults(query_engine=query_engine)
tools.append(query_engine_tool) tools.append(query_engine_tool)
......
...@@ -9,7 +9,7 @@ from llama_index.core.chat_engine import CondensePlusContextChatEngine ...@@ -9,7 +9,7 @@ from llama_index.core.chat_engine import CondensePlusContextChatEngine
def get_chat_engine(filters=None, params=None): def get_chat_engine(filters=None, params=None):
system_prompt = os.getenv("SYSTEM_PROMPT") system_prompt = os.getenv("SYSTEM_PROMPT")
citation_prompt = os.getenv("SYSTEM_CITATION_PROMPT", None) citation_prompt = os.getenv("SYSTEM_CITATION_PROMPT", None)
top_k = int(os.getenv("TOP_K", 3)) top_k = int(os.getenv("TOP_K", 0))
node_postprocessors = [] node_postprocessors = []
if citation_prompt: if citation_prompt:
...@@ -26,8 +26,7 @@ def get_chat_engine(filters=None, params=None): ...@@ -26,8 +26,7 @@ def get_chat_engine(filters=None, params=None):
) )
retriever = index.as_retriever( retriever = index.as_retriever(
similarity_top_k=top_k, filters=filters, **({"similarity_top_k": top_k} if top_k != 0 else {})
filters=filters,
) )
return CondensePlusContextChatEngine.from_defaults( return CondensePlusContextChatEngine.from_defaults(
......
...@@ -10,7 +10,7 @@ export async function createChatEngine(documentIds?: string[], params?: any) { ...@@ -10,7 +10,7 @@ export async function createChatEngine(documentIds?: string[], params?: any) {
); );
} }
const retriever = index.asRetriever({ const retriever = index.asRetriever({
similarityTopK: process.env.TOP_K ? parseInt(process.env.TOP_K) : 3, similarityTopK: process.env.TOP_K ? parseInt(process.env.TOP_K) : undefined,
filters: generateFilters(documentIds || []), filters: generateFilters(documentIds || []),
}); });
......
...@@ -7,7 +7,7 @@ from app.engine.index import get_index ...@@ -7,7 +7,7 @@ from app.engine.index import get_index
def get_query_engine(output_cls): def get_query_engine(output_cls):
top_k = os.getenv("TOP_K", 3) top_k = int(os.getenv("TOP_K", 0))
index = get_index() index = get_index()
if index is None: if index is None:
...@@ -21,7 +21,7 @@ def get_query_engine(output_cls): ...@@ -21,7 +21,7 @@ def get_query_engine(output_cls):
sllm = Settings.llm.as_structured_llm(output_cls) sllm = Settings.llm.as_structured_llm(output_cls)
return index.as_query_engine( return index.as_query_engine(
similarity_top_k=int(top_k),
llm=sllm, llm=sllm,
response_mode="tree_summarize", response_mode="tree_summarize",
**({"similarity_top_k": top_k} if top_k != 0 else {}),
) )
...@@ -19,7 +19,10 @@ def get_query_engine_tool() -> QueryEngineTool: ...@@ -19,7 +19,10 @@ def get_query_engine_tool() -> QueryEngineTool:
index = get_index() index = get_index()
if index is None: if index is None:
raise ValueError("Index not found. Please create an index first.") raise ValueError("Index not found. Please create an index first.")
query_engine = index.as_query_engine(similarity_top_k=int(os.getenv("TOP_K", 3))) top_k = int(os.getenv("TOP_K", 0))
query_engine = index.as_query_engine(
**({"similarity_top_k": top_k} if top_k != 0 else {})
)
return QueryEngineTool( return QueryEngineTool(
query_engine=query_engine, query_engine=query_engine,
metadata=ToolMetadata( metadata=ToolMetadata(
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment