From 5cd12fa90d301939cc4f02359556bdc00532d7a2 Mon Sep 17 00:00:00 2001
From: Huu Le <39040748+leehuwuj@users.noreply.github.com>
Date: Thu, 29 Aug 2024 14:24:57 +0700
Subject: [PATCH] bump create-llama to 0.11 and update event handler (#260)

---
 .changeset/slow-papayas-camp.md                    |  5 +
 .changeset/sour-donkeys-develop.md                 |  5 +
 helpers/python.ts                                  | 73 ++++++++++----
 .../components/engines/python/agent/engine.py     | 10 +-
 .../components/engines/python/chat/engine.py      | 20 +++-
 .../vectordbs/python/llamacloud/index.py           | 98 ++++++++++++++-----
 .../components/vectordbs/python/none/index.py      | 35 ++++---
 .../extractor/fastapi/app/engine/index.py          | 20 +++-
 .../fastapi/app/services/extractor.py              |  3 +-
 .../types/extractor/fastapi/pyproject.toml         |  2 +-
 .../streaming/fastapi/app/api/routers/chat.py      |  7 +-
 .../fastapi/app/api/services/file.py               |  6 +-
 .../streaming/fastapi/app/engine/index.py          | 20 +++-
 .../types/streaming/fastapi/pyproject.toml         |  2 +-
 14 files changed, 225 insertions(+), 81 deletions(-)
 create mode 100644 .changeset/slow-papayas-camp.md
 create mode 100644 .changeset/sour-donkeys-develop.md

diff --git a/.changeset/slow-papayas-camp.md b/.changeset/slow-papayas-camp.md
new file mode 100644
index 00000000..fbbf48ed
--- /dev/null
+++ b/.changeset/slow-papayas-camp.md
@@ -0,0 +1,5 @@
+---
+"create-llama": patch
+---
+
+Use callback manager properly
diff --git a/.changeset/sour-donkeys-develop.md b/.changeset/sour-donkeys-develop.md
new file mode 100644
index 00000000..6f0f4fef
--- /dev/null
+++ b/.changeset/sour-donkeys-develop.md
@@ -0,0 +1,5 @@
+---
+"create-llama": patch
+---
+
+Bump create-llama version to 0.11.1
diff --git a/helpers/python.ts b/helpers/python.ts
index 445f2f9c..0a2f25aa 100644
--- a/helpers/python.ts
+++ b/helpers/python.ts
@@ -12,6 +12,7 @@ import {
   InstallTemplateArgs,
   ModelConfig,
   TemplateDataSource,
+  TemplateType,
   TemplateVectorDB,
 } from "./types";
 
@@ -26,6 +27,7 @@ const getAdditionalDependencies = (
   vectorDb?: TemplateVectorDB,
   dataSources?: TemplateDataSource[],
   tools?: Tool[],
+  templateType?: TemplateType,
 ) => {
   const dependencies: Dependency[] = [];
 
@@ -128,7 +130,7 @@ const getAdditionalDependencies = (
     case "llamacloud":
       dependencies.push({
         name: "llama-index-indices-managed-llama-cloud",
-        version: "^0.2.7",
+        version: "^0.3.0",
       });
       break;
   }
@@ -147,77 +149,99 @@ const getAdditionalDependencies = (
     case "ollama":
       dependencies.push({
         name: "llama-index-llms-ollama",
-        version: "0.1.2",
+        version: "0.3.0",
       });
       dependencies.push({
         name: "llama-index-embeddings-ollama",
-        version: "0.1.2",
+        version: "0.3.0",
       });
       break;
     case "openai":
-      dependencies.push({
-        name: "llama-index-agent-openai",
-        version: "0.2.6",
-      });
+      if (templateType !== "multiagent") {
+        dependencies.push({
+          name: "llama-index-llms-openai",
+          version: "^0.2.0",
+        });
+        dependencies.push({
+          name: "llama-index-embeddings-openai",
+          version: "^0.2.3",
+        });
+        dependencies.push({
+          name: "llama-index-agent-openai",
+          version: "^0.3.0",
+        });
+      }
       break;
     case "groq":
+      // Fastembed==0.2.0 does not support python3.13 at the moment
+      // Fixed the python version less than 3.13
+      dependencies.push({
+        name: "python",
+        version: "^3.11,<3.13",
+      });
       dependencies.push({
         name: "llama-index-llms-groq",
-        version: "0.1.4",
+        version: "0.2.0",
       });
       dependencies.push({
        name: "llama-index-embeddings-fastembed",
-        version: "^0.1.4",
+        version: "^0.2.0",
      });
      break;
    case "anthropic":
+      // Fastembed==0.2.0 does not support python3.13 at the moment
+      // Fixed the python version less than 3.13
+      dependencies.push({
+        name: "python",
+        version: "^3.11,<3.13",
+      });
       dependencies.push({
         name: "llama-index-llms-anthropic",
-        version: "0.1.10",
+        version: "0.3.0",
       });
       dependencies.push({
         name: "llama-index-embeddings-fastembed",
-        version: "^0.1.4",
+        version: "^0.2.0",
       });
       break;
     case "gemini":
       dependencies.push({
         name: "llama-index-llms-gemini",
-        version: "0.1.10",
+        version: "0.3.4",
       });
       dependencies.push({
         name: "llama-index-embeddings-gemini",
-        version: "0.1.6",
+        version: "^0.2.0",
       });
       break;
     case "mistral":
       dependencies.push({
         name: "llama-index-llms-mistralai",
-        version: "0.1.17",
+        version: "0.2.1",
       });
       dependencies.push({
         name: "llama-index-embeddings-mistralai",
-        version: "0.1.4",
+        version: "0.2.0",
       });
       break;
     case "azure-openai":
       dependencies.push({
         name: "llama-index-llms-azure-openai",
-        version: "0.1.10",
+        version: "0.2.0",
       });
       dependencies.push({
         name: "llama-index-embeddings-azure-openai",
-        version: "0.1.11",
+        version: "0.2.4",
       });
       break;
     case "t-systems":
       dependencies.push({
         name: "llama-index-agent-openai",
-        version: "0.2.2",
+        version: "0.3.0",
       });
       dependencies.push({
         name: "llama-index-llms-openai-like",
-        version: "0.1.3",
+        version: "0.2.0",
       });
       break;
   }
@@ -227,7 +251,7 @@ const getAdditionalDependencies = (
 
 const mergePoetryDependencies = (
   dependencies: Dependency[],
-  existingDependencies: Record<string, Omit<Dependency, "name">>,
+  existingDependencies: Record<string, Omit<Dependency, "name"> | string>,
 ) => {
   for (const dependency of dependencies) {
     let value = existingDependencies[dependency.name] ?? {};
@@ -246,7 +270,13 @@ const mergePoetryDependencies = (
       );
     }
 
-    existingDependencies[dependency.name] = value;
+    // Serialize separately only if extras are provided
+    if (value.extras && value.extras.length > 0) {
+      existingDependencies[dependency.name] = value;
+    } else {
+      // Otherwise, serialize just the version string
+      existingDependencies[dependency.name] = value.version;
+    }
   }
 };
 
@@ -388,6 +418,7 @@ export const installPythonTemplate = async ({
     vectorDb,
     dataSources,
     tools,
+    template,
   );
 
   if (observability && observability !== "none") {
diff --git a/templates/components/engines/python/agent/engine.py b/templates/components/engines/python/agent/engine.py
index 854757e2..22a30d0e 100644
--- a/templates/components/engines/python/agent/engine.py
+++ b/templates/components/engines/python/agent/engine.py
@@ -1,19 +1,22 @@
 import os
 
-from app.engine.index import get_index
+from app.engine.index import IndexConfig, get_index
 from app.engine.tools import ToolFactory
 from llama_index.core.agent import AgentRunner
+from llama_index.core.callbacks import CallbackManager
 from llama_index.core.settings import Settings
 from llama_index.core.tools.query_engine import QueryEngineTool
 
 
-def get_chat_engine(filters=None, params=None):
+def get_chat_engine(filters=None, params=None, event_handlers=None):
     system_prompt = os.getenv("SYSTEM_PROMPT")
     top_k = int(os.getenv("TOP_K", 0))
     tools = []
+    callback_manager = CallbackManager(handlers=event_handlers or [])
 
     # Add query tool if index exists
-    index = get_index()
+    index_config = IndexConfig(callback_manager=callback_manager, **(params or {}))
+    index = get_index(index_config)
     if index is not None:
         query_engine = index.as_query_engine(
             filters=filters, **({"similarity_top_k": top_k} if top_k != 0 else {})
@@ -28,5 +31,6 @@ def get_chat_engine(filters=None, params=None):
         llm=Settings.llm,
         tools=tools,
         system_prompt=system_prompt,
+        callback_manager=callback_manager,
         verbose=True,
     )
diff --git a/templates/components/engines/python/chat/engine.py b/templates/components/engines/python/chat/engine.py
index 61fc7aad..cb7e0082 100644
--- a/templates/components/engines/python/chat/engine.py
+++ b/templates/components/engines/python/chat/engine.py
@@ -1,22 +1,31 @@
 import os
 
-from app.engine.index import get_index
+from app.engine.index import IndexConfig, get_index
 from app.engine.node_postprocessors import NodeCitationProcessor
 from fastapi import HTTPException
+from llama_index.core.callbacks import CallbackManager
 from llama_index.core.chat_engine import CondensePlusContextChatEngine
+from llama_index.core.memory import ChatMemoryBuffer
+from llama_index.core.settings import Settings
 
 
-def get_chat_engine(filters=None, params=None):
+def get_chat_engine(filters=None, params=None, event_handlers=None):
     system_prompt = os.getenv("SYSTEM_PROMPT")
     citation_prompt = os.getenv("SYSTEM_CITATION_PROMPT", None)
     top_k = int(os.getenv("TOP_K", 0))
+    llm = Settings.llm
+    memory = ChatMemoryBuffer.from_defaults(
+        token_limit=llm.metadata.context_window - 256
+    )
+    callback_manager = CallbackManager(handlers=event_handlers or [])
 
     node_postprocessors = []
     if citation_prompt:
         node_postprocessors = [NodeCitationProcessor()]
         system_prompt = f"{system_prompt}\n{citation_prompt}"
 
-    index = get_index(params)
+    index_config = IndexConfig(callback_manager=callback_manager, **(params or {}))
+    index = get_index(index_config)
     if index is None:
         raise HTTPException(
             status_code=500,
@@ -29,8 +38,11 @@ def get_chat_engine(filters=None, params=None):
         filters=filters, **({"similarity_top_k": top_k} if top_k != 0 else {})
     )
 
-    return CondensePlusContextChatEngine.from_defaults(
+    return CondensePlusContextChatEngine(
+        llm=llm,
+        memory=memory,
         system_prompt=system_prompt,
         retriever=retriever,
         node_postprocessors=node_postprocessors,
+        callback_manager=callback_manager,
     )
diff --git a/templates/components/vectordbs/python/llamacloud/index.py b/templates/components/vectordbs/python/llamacloud/index.py
index 0a4ba795..570f7223 100644
--- a/templates/components/vectordbs/python/llamacloud/index.py
+++ b/templates/components/vectordbs/python/llamacloud/index.py
@@ -1,41 +1,87 @@
 import logging
 import os
 
-from llama_index.indices.managed.llama_cloud import LlamaCloudIndex
+from typing import Optional
+
+from llama_index.core.callbacks import CallbackManager
 from llama_index.core.ingestion.api_utils import (
     get_client as llama_cloud_get_client,
 )
+from llama_index.indices.managed.llama_cloud import LlamaCloudIndex
+from pydantic import BaseModel, Field, validator
 
 logger = logging.getLogger("uvicorn")
 
 
-def get_client():
-    return llama_cloud_get_client(
-        os.getenv("LLAMA_CLOUD_API_KEY"),
-        os.getenv("LLAMA_CLOUD_BASE_URL"),
+class LlamaCloudConfig(BaseModel):
+    # Private attributes
+    api_key: str = Field(
+        default=os.getenv("LLAMA_CLOUD_API_KEY"),
+        exclude=True,  # Exclude from the model representation
+    )
+    base_url: Optional[str] = Field(
+        default=os.getenv("LLAMA_CLOUD_BASE_URL"),
+        exclude=True,
     )
+    organization_id: Optional[str] = Field(
+        default=os.getenv("LLAMA_CLOUD_ORGANIZATION_ID"),
+        exclude=True,
+    )
+    # Configuration attributes, can be set by the user
+    pipeline: str = Field(
+        description="The name of the pipeline to use",
+        default=os.getenv("LLAMA_CLOUD_INDEX_NAME"),
+    )
+    project: str = Field(
+        description="The name of the LlamaCloud project",
+        default=os.getenv("LLAMA_CLOUD_PROJECT_NAME"),
+    )
+
+    # Validate and throw error if the env variables are not set before starting the app
+    @validator("pipeline", "project", "api_key", pre=True, always=True)
+    @classmethod
+    def validate_env_vars(cls, value):
+        if value is None:
+            raise ValueError(
+                "Please set LLAMA_CLOUD_INDEX_NAME, LLAMA_CLOUD_PROJECT_NAME and LLAMA_CLOUD_API_KEY"
+                " to your environment variables or config them in .env file"
+            )
+        return value
+
+    def to_client_kwargs(self) -> dict:
+        return {
+            "api_key": self.api_key,
+            "base_url": self.base_url,
+        }
 
 
-def get_index(params=None):
-    configParams = params or {}
-    pipelineConfig = configParams.get("llamaCloudPipeline", {})
-    name = pipelineConfig.get("pipeline", os.getenv("LLAMA_CLOUD_INDEX_NAME"))
-    project_name = pipelineConfig.get("project", os.getenv("LLAMA_CLOUD_PROJECT_NAME"))
-    api_key = os.getenv("LLAMA_CLOUD_API_KEY")
-    base_url = os.getenv("LLAMA_CLOUD_BASE_URL")
-    organization_id = os.getenv("LLAMA_CLOUD_ORGANIZATION_ID")
-
-    if name is None or project_name is None or api_key is None:
-        raise ValueError(
-            "Please set LLAMA_CLOUD_INDEX_NAME, LLAMA_CLOUD_PROJECT_NAME and LLAMA_CLOUD_API_KEY"
-            " to your environment variables or config them in .env file"
-        )
-
-    index = LlamaCloudIndex(
-        name=name,
-        project_name=project_name,
-        api_key=api_key,
-        base_url=base_url,
-        organization_id=organization_id,
+
+class IndexConfig(BaseModel):
+    llama_cloud_pipeline_config: LlamaCloudConfig = Field(
+        default=LlamaCloudConfig(),
+        alias="llamaCloudPipeline",
+    )
+    callback_manager: Optional[CallbackManager] = Field(
+        default=None,
     )
 
+    def to_index_kwargs(self) -> dict:
+        return {
+            "name": self.llama_cloud_pipeline_config.pipeline,
+            "project_name": self.llama_cloud_pipeline_config.project,
+            "api_key": self.llama_cloud_pipeline_config.api_key,
+            "base_url": self.llama_cloud_pipeline_config.base_url,
+            "organization_id": self.llama_cloud_pipeline_config.organization_id,
+            "callback_manager": self.callback_manager,
+        }
+
+
+def get_index(config: IndexConfig = None):
+    if config is None:
+        config = IndexConfig()
+    index = LlamaCloudIndex(**config.to_index_kwargs())
     return index
+
+
+def get_client():
+    config = LlamaCloudConfig()
+    return llama_cloud_get_client(**config.to_client_kwargs())
diff --git a/templates/components/vectordbs/python/none/index.py b/templates/components/vectordbs/python/none/index.py
index 65fd5ad5..cd61c539 100644
--- a/templates/components/vectordbs/python/none/index.py
+++ b/templates/components/vectordbs/python/none/index.py
@@ -1,23 +1,26 @@
-import os
 import logging
+import os
 from datetime import timedelta
+from typing import Optional
 
-from cachetools import cached, TTLCache
-from llama_index.core.storage import StorageContext
+from cachetools import TTLCache, cached
+from llama_index.core.callbacks import CallbackManager
 from llama_index.core.indices import load_index_from_storage
+from llama_index.core.storage import StorageContext
+from pydantic import BaseModel, Field
 
 logger = logging.getLogger("uvicorn")
 
 
-@cached(
-    TTLCache(maxsize=10, ttl=timedelta(minutes=5).total_seconds()),
-    key=lambda *args, **kwargs: "global_storage_context",
-)
-def get_storage_context(persist_dir: str) -> StorageContext:
-    return StorageContext.from_defaults(persist_dir=persist_dir)
+class IndexConfig(BaseModel):
+    callback_manager: Optional[CallbackManager] = Field(
+        default=None,
+    )
 
 
-def get_index(params=None):
+def get_index(config: IndexConfig = None):
+    if config is None:
+        config = IndexConfig()
     storage_dir = os.getenv("STORAGE_DIR", "storage")
     # check if storage already exists
     if not os.path.exists(storage_dir):
@@ -25,6 +28,16 @@ def get_index(params=None):
     # load the existing index
     logger.info(f"Loading index from {storage_dir}...")
     storage_context = get_storage_context(storage_dir)
-    index = load_index_from_storage(storage_context)
+    index = load_index_from_storage(
+        storage_context, callback_manager=config.callback_manager
+    )
     logger.info(f"Finished loading index from {storage_dir}")
     return index
+
+
+@cached(
+    TTLCache(maxsize=10, ttl=timedelta(minutes=5).total_seconds()),
+    key=lambda *args, **kwargs: "global_storage_context",
+)
+def get_storage_context(persist_dir: str) -> StorageContext:
+    return StorageContext.from_defaults(persist_dir=persist_dir)
diff --git a/templates/types/extractor/fastapi/app/engine/index.py b/templates/types/extractor/fastapi/app/engine/index.py
index e1adcb80..c24e39f9 100644
--- a/templates/types/extractor/fastapi/app/engine/index.py
+++ b/templates/types/extractor/fastapi/app/engine/index.py
@@ -1,17 +1,31 @@
 import logging
+from typing import Optional
+
+from llama_index.core.callbacks import CallbackManager
 from llama_index.core.indices import VectorStoreIndex
-from app.engine.vectordb import get_vector_store
+from pydantic import BaseModel, Field
 
+from app.engine.vectordb import get_vector_store
 
 logger = logging.getLogger("uvicorn")
 
 
-def get_index(params=None):
+class IndexConfig(BaseModel):
+    callback_manager: Optional[CallbackManager] = Field(
+        default=None,
+    )
+
+
+def get_index(config: IndexConfig = None):
+    if config is None:
+        config = IndexConfig()
     logger.info("Connecting vector store...")
     store = get_vector_store()
     # Load the index from the vector store
     # If you are using a vector store that doesn't store text,
     # you must load the index from both the vector store and the document store
-    index = VectorStoreIndex.from_vector_store(store)
+    index = VectorStoreIndex.from_vector_store(
+        store, callback_manager=config.callback_manager
+    )
     logger.info("Finished load index from vector store.")
     return index
diff --git a/templates/types/extractor/fastapi/app/services/extractor.py b/templates/types/extractor/fastapi/app/services/extractor.py
index c6e041b4..7ddcbfbf 100644
--- a/templates/types/extractor/fastapi/app/services/extractor.py
+++ b/templates/types/extractor/fastapi/app/services/extractor.py
@@ -1,4 +1,5 @@
 import logging
+
 from app.engine import get_query_engine
 from app.services.model import IMPORTS
 
@@ -33,4 +34,4 @@ class ExtractorService:
         query_engine = get_query_engine(schema_model)
         response = await query_engine.aquery(query)
         output_data = response.response.dict()
-        return schema_model(**output_data).json(indent=2)
+        return schema_model(**output_data).model_dump_json(indent=2)
diff --git a/templates/types/extractor/fastapi/pyproject.toml b/templates/types/extractor/fastapi/pyproject.toml
index 5953fed0..3603c6cb 100644
--- a/templates/types/extractor/fastapi/pyproject.toml
+++ b/templates/types/extractor/fastapi/pyproject.toml
@@ -13,7 +13,7 @@ python = "^3.11,<4.0"
 fastapi = "^0.109.1"
 uvicorn = { extras = ["standard"], version = "^0.23.2" }
 python-dotenv = "^1.0.0"
-llama-index = "^0.10.58"
+llama-index = "^0.11.1"
 cachetools = "^5.3.3"
 reflex = "^0.5.9"
 
diff --git a/templates/types/streaming/fastapi/app/api/routers/chat.py b/templates/types/streaming/fastapi/app/api/routers/chat.py
index ace2a3b8..39894361 100644
--- a/templates/types/streaming/fastapi/app/api/routers/chat.py
+++ b/templates/types/streaming/fastapi/app/api/routers/chat.py
@@ -38,11 +38,10 @@ async def chat(
         logger.info(
             f"Creating chat engine with filters: {str(filters)}",
         )
-        chat_engine = get_chat_engine(filters=filters, params=params)
-
         event_handler = EventCallbackHandler()
-        chat_engine.callback_manager.handlers.append(event_handler)  # type: ignore
-
+        chat_engine = get_chat_engine(
+            filters=filters, params=params, event_handlers=[event_handler]
+        )
         response = await chat_engine.astream_chat(last_message_content, messages)
         process_response_nodes(response.source_nodes, background_tasks)
 
diff --git a/templates/types/streaming/fastapi/app/api/services/file.py b/templates/types/streaming/fastapi/app/api/services/file.py
index 72107f8d..9441db6e 100644
--- a/templates/types/streaming/fastapi/app/api/services/file.py
+++ b/templates/types/streaming/fastapi/app/api/services/file.py
@@ -5,8 +5,7 @@ from io import BytesIO
 from pathlib import Path
 from typing import Any, List, Tuple
 
-
-from app.engine.index import get_index
+from app.engine.index import IndexConfig, get_index
 from llama_index.core import VectorStoreIndex
 from llama_index.core.ingestion import IngestionPipeline
 from llama_index.core.readers.file.base import (
@@ -77,7 +76,8 @@ class PrivateFileService:
         file_data, extension = PrivateFileService.preprocess_base64_file(base64_content)
 
         # Add the nodes to the index and persist it
-        current_index = get_index(params)
+        index_config = IndexConfig(**params)
+        current_index = get_index(index_config)
 
         # Insert the documents into the index
         if isinstance(current_index, LlamaCloudIndex):
diff --git a/templates/types/streaming/fastapi/app/engine/index.py b/templates/types/streaming/fastapi/app/engine/index.py
index e1adcb80..c24e39f9 100644
--- a/templates/types/streaming/fastapi/app/engine/index.py
+++ b/templates/types/streaming/fastapi/app/engine/index.py
@@ -1,17 +1,31 @@
 import logging
+from typing import Optional
+
+from llama_index.core.callbacks import CallbackManager
 from llama_index.core.indices import VectorStoreIndex
-from app.engine.vectordb import get_vector_store
+from pydantic import BaseModel, Field
 
+from app.engine.vectordb import get_vector_store
 
 logger = logging.getLogger("uvicorn")
 
 
-def get_index(params=None):
+class IndexConfig(BaseModel):
+    callback_manager: Optional[CallbackManager] = Field(
+        default=None,
+    )
+
+
+def get_index(config: IndexConfig = None):
+    if config is None:
+        config = IndexConfig()
     logger.info("Connecting vector store...")
     store = get_vector_store()
     # Load the index from the vector store
     # If you are using a vector store that doesn't store text,
     # you must load the index from both the vector store and the document store
-    index = VectorStoreIndex.from_vector_store(store)
+    index = VectorStoreIndex.from_vector_store(
+        store, callback_manager=config.callback_manager
+    )
     logger.info("Finished load index from vector store.")
     return index
diff --git a/templates/types/streaming/fastapi/pyproject.toml b/templates/types/streaming/fastapi/pyproject.toml
index f622b1d8..218d6d69 100644
--- a/templates/types/streaming/fastapi/pyproject.toml
+++ b/templates/types/streaming/fastapi/pyproject.toml
@@ -14,7 +14,7 @@ fastapi = "^0.109.1"
 uvicorn = { extras = ["standard"], version = "^0.23.2" }
 python-dotenv = "^1.0.0"
 aiostream = "^0.5.2"
-llama-index = "0.10.58"
+llama-index = "0.11.1"
 cachetools = "^5.3.3"
 
 [build-system]
--
GitLab
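
For reference, a minimal sketch (not part of the patch) of how the new event_handlers parameter is meant to be consumed from application code. RetrievalLogger below is a hypothetical handler written only for illustration; in the templates above, EventCallbackHandler from app.api.routers.events plays this role. It assumes llama-index 0.11.x, where BaseCallbackHandler lives in llama_index.core.callbacks.base_handler.

# Hypothetical handler, for illustration only; any BaseCallbackHandler subclass works.
from typing import Any, Dict, List, Optional

from llama_index.core.callbacks.base_handler import BaseCallbackHandler
from llama_index.core.callbacks.schema import CBEventType


class RetrievalLogger(BaseCallbackHandler):
    """Print a line whenever the engine starts or finishes a retrieval."""

    def __init__(self) -> None:
        super().__init__(event_starts_to_ignore=[], event_ends_to_ignore=[])

    def on_event_start(
        self,
        event_type: CBEventType,
        payload: Optional[Dict[str, Any]] = None,
        event_id: str = "",
        parent_id: str = "",
        **kwargs: Any,
    ) -> str:
        if event_type == CBEventType.RETRIEVE:
            print("retrieval started")
        return event_id

    def on_event_end(
        self,
        event_type: CBEventType,
        payload: Optional[Dict[str, Any]] = None,
        event_id: str = "",
        **kwargs: Any,
    ) -> None:
        if event_type == CBEventType.RETRIEVE:
            print("retrieval finished")

    # Tracing hooks are required by the abstract base class; no-ops here.
    def start_trace(self, trace_id: Optional[str] = None) -> None:
        pass

    def end_trace(
        self,
        trace_id: Optional[str] = None,
        trace_map: Optional[Dict[str, List[str]]] = None,
    ) -> None:
        pass


# Inside a generated template app (module path assumed), the handler is passed at
# construction time rather than appended to chat_engine.callback_manager afterwards:
#   from app.engine.engine import get_chat_engine
#   chat_engine = get_chat_engine(event_handlers=[RetrievalLogger()])

Because get_chat_engine now builds a CallbackManager from event_handlers and forwards it through IndexConfig to the loaded index, handlers registered this way observe retrieval and LLM events from every component, which is the "use callback manager properly" change described in the changeset.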