From c60182a9254bbc9b2d254de5975816ea6f663df2 Mon Sep 17 00:00:00 2001 From: Huu Le <39040748+leehuwuj@users.noreply.github.com> Date: Tue, 8 Oct 2024 15:17:38 +0700 Subject: [PATCH] Add mypy checker (#346) --- .husky/pre-commit | 1 + e2e/python/resolve_dependencies.spec.ts | 238 +++++++++++------- helpers/python.ts | 2 +- .../components/engines/python/agent/engine.py | 7 +- .../engines/python/agent/tools/__init__.py | 34 ++- .../engines/python/agent/tools/artifact.py | 2 +- .../engines/python/agent/tools/duckduckgo.py | 26 +- .../components/engines/python/chat/engine.py | 2 +- .../components/loaders/python/__init__.py | 8 +- templates/components/loaders/python/db.py | 9 +- templates/components/loaders/python/web.py | 6 +- .../components/settings/python/llmhub.py | 16 +- .../components/settings/python/settings.py | 80 ++++-- .../vectordbs/python/astra/vectordb.py | 2 +- .../vectordbs/python/chroma/vectordb.py | 3 +- .../vectordbs/python/milvus/vectordb.py | 3 +- .../components/vectordbs/python/none/index.py | 2 +- .../extractor/fastapi/app/engine/generate.py | 4 +- .../extractor/fastapi/app/engine/index.py | 2 +- .../fastapi/app/api/routers/events.py | 16 +- .../fastapi/app/api/routers/models.py | 27 +- .../streaming/fastapi/app/engine/generate.py | 11 +- .../types/streaming/fastapi/pyproject.toml | 17 ++ 23 files changed, 339 insertions(+), 179 deletions(-) diff --git a/.husky/pre-commit b/.husky/pre-commit index d466e587..f9926bf4 100644 --- a/.husky/pre-commit +++ b/.husky/pre-commit @@ -1,2 +1,3 @@ pnpm format pnpm lint +uvx ruff format --check templates/ diff --git a/e2e/python/resolve_dependencies.spec.ts b/e2e/python/resolve_dependencies.spec.ts index a1b02802..f1e5ddaf 100644 --- a/e2e/python/resolve_dependencies.spec.ts +++ b/e2e/python/resolve_dependencies.spec.ts @@ -15,6 +15,8 @@ const dataSource: string = process.env.DATASOURCE ? process.env.DATASOURCE : "--example-file"; +// TODO: add support for other templates + if ( dataSource === "--example-file" // XXX: this test provides its own data source - only trigger it on one data source (usually the CI matrix will trigger multiple data sources) ) { @@ -45,14 +47,86 @@ if ( const observabilityOptions = ["llamatrace", "traceloop"]; - // Run separate tests for each observability option to reduce CI runtime - test.describe("Test resolve python dependencies with observability", () => { - // Testing with streaming template, vectorDb: none, tools: none, and dataSource: --example-file - for (const observability of observabilityOptions) { - test(`observability: ${observability}`, async () => { + test.describe("Mypy check", () => { + test.describe.configure({ retries: 0 }); + + // Test vector databases + for (const vectorDb of vectorDbs) { + test(`Mypy check for vectorDB: ${vectorDb}`, async () => { const cwd = await createTestDir(); + const { pyprojectPath } = await createAndCheckLlamaProject({ + options: { + cwd, + templateType: "streaming", + templateFramework, + dataSource: "--example-file", + vectorDb, + tools: "none", + port: 3000, + externalPort: 8000, + postInstallAction: "none", + templateUI: undefined, + appType: "--no-frontend", + llamaCloudProjectName: undefined, + llamaCloudIndexName: undefined, + observability: undefined, + }, + }); + + const pyprojectContent = fs.readFileSync(pyprojectPath, "utf-8"); + if (vectorDb !== "none") { + if (vectorDb === "pg") { + expect(pyprojectContent).toContain( + "llama-index-vector-stores-postgres", + ); + } else { + expect(pyprojectContent).toContain( + `llama-index-vector-stores-${vectorDb}`, + ); + } + } + }); + } - await createAndCheckLlamaProject({ + // Test tools + for (const tool of toolOptions) { + test(`Mypy check for tool: ${tool}`, async () => { + const cwd = await createTestDir(); + const { pyprojectPath } = await createAndCheckLlamaProject({ + options: { + cwd, + templateType: "streaming", + templateFramework, + dataSource: "--example-file", + vectorDb: "none", + tools: tool, + port: 3000, + externalPort: 8000, + postInstallAction: "none", + templateUI: undefined, + appType: "--no-frontend", + llamaCloudProjectName: undefined, + llamaCloudIndexName: undefined, + observability: undefined, + }, + }); + + const pyprojectContent = fs.readFileSync(pyprojectPath, "utf-8"); + if (tool === "wikipedia.WikipediaToolSpec") { + expect(pyprojectContent).toContain("wikipedia"); + } + if (tool === "google.GoogleSearchToolSpec") { + expect(pyprojectContent).toContain("google"); + } + }); + } + + // Test data sources + for (const dataSource of dataSources) { + const dataSourceType = dataSource.split(" ")[0]; + test(`Mypy check for data source: ${dataSourceType}`, async () => { + const cwd = await createTestDir(); + const { pyprojectPath } = await createAndCheckLlamaProject({ options: { cwd, templateType: "streaming", @@ -60,87 +134,51 @@ if ( dataSource, vectorDb: "none", tools: "none", - port: 3000, // port, not used - externalPort: 8000, // externalPort, not used - postInstallAction: "none", // postInstallAction - templateUI: undefined, // ui - appType: "--no-frontend", // appType - llamaCloudProjectName: undefined, // llamaCloudProjectName - llamaCloudIndexName: undefined, // llamaCloudIndexName - observability, + port: 3000, + externalPort: 8000, + postInstallAction: "none", + templateUI: undefined, + appType: "--no-frontend", + llamaCloudProjectName: undefined, + llamaCloudIndexName: undefined, + observability: undefined, }, }); + + const pyprojectContent = fs.readFileSync(pyprojectPath, "utf-8"); + if (dataSource.includes("--web-source")) { + expect(pyprojectContent).toContain("llama-index-readers-web"); + } + if (dataSource.includes("--db-source")) { + expect(pyprojectContent).toContain("llama-index-readers-database"); + } }); } - }); - test.describe("Test resolve python dependencies", () => { - for (const vectorDb of vectorDbs) { - for (const tool of toolOptions) { - for (const dataSource of dataSources) { - const dataSourceType = dataSource.split(" ")[0]; - const toolDescription = tool === "none" ? "no tools" : tool; - const optionDescription = `vectorDb: ${vectorDb}, ${toolDescription}, dataSource: ${dataSourceType}`; - - test(`options: ${optionDescription}`, async () => { - const cwd = await createTestDir(); - - const { pyprojectPath, projectPath } = - await createAndCheckLlamaProject({ - options: { - cwd, - templateType: "streaming", - templateFramework, - dataSource, - vectorDb, - tools: tool, - port: 3000, // port, not used - externalPort: 8000, // externalPort, not used - postInstallAction: "none", // postInstallAction - templateUI: undefined, // ui - appType: "--no-frontend", // appType - llamaCloudProjectName: undefined, // llamaCloudProjectName - llamaCloudIndexName: undefined, // llamaCloudIndexName - observability: undefined, // observability - }, - }); - - // Additional checks for specific dependencies - - // Verify that specific dependencies are in pyproject.toml - const pyprojectContent = fs.readFileSync(pyprojectPath, "utf-8"); - if (vectorDb !== "none") { - if (vectorDb === "pg") { - expect(pyprojectContent).toContain( - "llama-index-vector-stores-postgres", - ); - } else { - expect(pyprojectContent).toContain( - `llama-index-vector-stores-${vectorDb}`, - ); - } - } - if (tool !== "none") { - if (tool === "wikipedia.WikipediaToolSpec") { - expect(pyprojectContent).toContain("wikipedia"); - } - if (tool === "google.GoogleSearchToolSpec") { - expect(pyprojectContent).toContain("google"); - } - } - - // Check for data source specific dependencies - if (dataSource.includes("--web-source")) { - expect(pyprojectContent).toContain("llama-index-readers-web"); - } - if (dataSource.includes("--db-source")) { - expect(pyprojectContent).toContain( - "llama-index-readers-database ", - ); - } - }); - } - } + // Test observability options + for (const observability of observabilityOptions) { + test(`Mypy check for observability: ${observability}`, async () => { + const cwd = await createTestDir(); + + const { pyprojectPath } = await createAndCheckLlamaProject({ + options: { + cwd, + templateType: "streaming", + templateFramework, + dataSource: "--example-file", + vectorDb: "none", + tools: "none", + port: 3000, + externalPort: 8000, + postInstallAction: "none", + templateUI: undefined, + appType: "--no-frontend", + llamaCloudProjectName: undefined, + llamaCloudIndexName: undefined, + observability, + }, + }); + }); } }); } @@ -161,21 +199,39 @@ async function createAndCheckLlamaProject({ const pyprojectPath = path.join(projectPath, "pyproject.toml"); expect(fs.existsSync(pyprojectPath)).toBeTruthy(); - // Run poetry lock + const env = { + ...process.env, + POETRY_VIRTUALENVS_IN_PROJECT: "true", + }; + + // Run poetry install + try { + const { stdout: installStdout, stderr: installStderr } = await execAsync( + "poetry install", + { cwd: projectPath, env }, + ); + console.log("poetry install stdout:", installStdout); + console.error("poetry install stderr:", installStderr); + } catch (error) { + console.error("Error running poetry install:", error); + throw error; + } + + // Run poetry run mypy try { - const { stdout, stderr } = await execAsync( - "poetry config virtualenvs.in-project true && poetry lock --no-update", - { cwd: projectPath }, + const { stdout: mypyStdout, stderr: mypyStderr } = await execAsync( + "poetry run mypy .", + { cwd: projectPath, env }, ); - console.log("poetry lock stdout:", stdout); - console.error("poetry lock stderr:", stderr); + console.log("poetry run mypy stdout:", mypyStdout); + console.error("poetry run mypy stderr:", mypyStderr); } catch (error) { - console.error("Error running poetry lock:", error); + console.error("Error running mypy:", error); throw error; } - // Check if poetry.lock file was created - expect(fs.existsSync(path.join(projectPath, "poetry.lock"))).toBeTruthy(); + // If we reach this point without throwing an error, the test passes + expect(true).toBeTruthy(); return { pyprojectPath, projectPath }; } diff --git a/helpers/python.ts b/helpers/python.ts index 20d93394..a1be02d0 100644 --- a/helpers/python.ts +++ b/helpers/python.ts @@ -123,7 +123,7 @@ const getAdditionalDependencies = ( extras: ["rsa"], }); dependencies.push({ - name: "psycopg2", + name: "psycopg2-binary", version: "^2.9.9", }); break; diff --git a/templates/components/engines/python/agent/engine.py b/templates/components/engines/python/agent/engine.py index c71d3704..d1129305 100644 --- a/templates/components/engines/python/agent/engine.py +++ b/templates/components/engines/python/agent/engine.py @@ -1,17 +1,19 @@ import os +from typing import List from app.engine.index import IndexConfig, get_index from app.engine.tools import ToolFactory from llama_index.core.agent import AgentRunner from llama_index.core.callbacks import CallbackManager from llama_index.core.settings import Settings +from llama_index.core.tools import BaseTool from llama_index.core.tools.query_engine import QueryEngineTool def get_chat_engine(filters=None, params=None, event_handlers=None, **kwargs): system_prompt = os.getenv("SYSTEM_PROMPT") top_k = int(os.getenv("TOP_K", 0)) - tools = [] + tools: List[BaseTool] = [] callback_manager = CallbackManager(handlers=event_handlers or []) # Add query tool if index exists @@ -25,7 +27,8 @@ def get_chat_engine(filters=None, params=None, event_handlers=None, **kwargs): tools.append(query_engine_tool) # Add additional tools - tools += ToolFactory.from_env() + configured_tools: List[BaseTool] = ToolFactory.from_env() + tools.extend(configured_tools) return AgentRunner.from_llm( llm=Settings.llm, diff --git a/templates/components/engines/python/agent/tools/__init__.py b/templates/components/engines/python/agent/tools/__init__.py index 6b218432..f9ede661 100644 --- a/templates/components/engines/python/agent/tools/__init__.py +++ b/templates/components/engines/python/agent/tools/__init__.py @@ -1,7 +1,8 @@ import importlib import os +from typing import Dict, List, Union -import yaml +import yaml # type: ignore from llama_index.core.tools.function_tool import FunctionTool from llama_index.core.tools.tool_spec.base import BaseToolSpec @@ -17,7 +18,8 @@ class ToolFactory: ToolType.LOCAL: "app.engine.tools", } - def load_tools(tool_type: str, tool_name: str, config: dict) -> list[FunctionTool]: + @staticmethod + def load_tools(tool_type: str, tool_name: str, config: dict) -> List[FunctionTool]: source_package = ToolFactory.TOOL_SOURCE_PACKAGE_MAP[tool_type] try: if "ToolSpec" in tool_name: @@ -43,24 +45,32 @@ class ToolFactory: @staticmethod def from_env( map_result: bool = False, - ) -> list[FunctionTool] | dict[str, FunctionTool]: + ) -> Union[Dict[str, List[FunctionTool]], List[FunctionTool]]: """ Load tools from the configured file. - Params: - - use_map: if True, return map of tool name and the tool itself + + Args: + map_result: If True, return a map of tool names to their corresponding tools. + + Returns: + A dictionary of tool names to lists of FunctionTools if map_result is True, + otherwise a list of FunctionTools. """ - if map_result: - tools = {} - else: - tools = [] + tools: Union[Dict[str, List[FunctionTool]], List[FunctionTool]] = ( + {} if map_result else [] + ) + if os.path.exists("config/tools.yaml"): with open("config/tools.yaml", "r") as f: tool_configs = yaml.safe_load(f) for tool_type, config_entries in tool_configs.items(): for tool_name, config in config_entries.items(): - tool = ToolFactory.load_tools(tool_type, tool_name, config) + loaded_tools = ToolFactory.load_tools( + tool_type, tool_name, config + ) if map_result: - tools[tool_name] = tool + tools[tool_name] = loaded_tools # type: ignore else: - tools.extend(tool) + tools.extend(loaded_tools) # type: ignore + return tools diff --git a/templates/components/engines/python/agent/tools/artifact.py b/templates/components/engines/python/agent/tools/artifact.py index 9b132c64..4c877b2f 100644 --- a/templates/components/engines/python/agent/tools/artifact.py +++ b/templates/components/engines/python/agent/tools/artifact.py @@ -87,7 +87,7 @@ class CodeGeneratorTool: ChatMessage(role="user", content=user_message), ] try: - sllm = Settings.llm.as_structured_llm(output_cls=CodeArtifact) + sllm = Settings.llm.as_structured_llm(output_cls=CodeArtifact) # type: ignore response = sllm.chat(messages) data: CodeArtifact = response.raw return data.model_dump() diff --git a/templates/components/engines/python/agent/tools/duckduckgo.py b/templates/components/engines/python/agent/tools/duckduckgo.py index ec0f6332..b9a2f9ee 100644 --- a/templates/components/engines/python/agent/tools/duckduckgo.py +++ b/templates/components/engines/python/agent/tools/duckduckgo.py @@ -21,14 +21,15 @@ def duckduckgo_search( "Please install it by running: `poetry add duckduckgo_search` or `pip install duckduckgo_search`" ) - params = { - "keywords": query, - "region": region, - "max_results": max_results, - } results = [] with DDGS() as ddg: - results = list(ddg.text(**params)) + results = list( + ddg.text( + keywords=query, + region=region, + max_results=max_results, + ) + ) return results @@ -51,13 +52,14 @@ def duckduckgo_image_search( "duckduckgo_search package is required to use this function." "Please install it by running: `poetry add duckduckgo_search` or `pip install duckduckgo_search`" ) - params = { - "keywords": query, - "region": region, - "max_results": max_results, - } with DDGS() as ddg: - results = list(ddg.images(**params)) + results = list( + ddg.images( + keywords=query, + region=region, + max_results=max_results, + ) + ) return results diff --git a/templates/components/engines/python/chat/engine.py b/templates/components/engines/python/chat/engine.py index 83c73cc6..75b9fb9e 100644 --- a/templates/components/engines/python/chat/engine.py +++ b/templates/components/engines/python/chat/engine.py @@ -43,6 +43,6 @@ def get_chat_engine(filters=None, params=None, event_handlers=None, **kwargs): memory=memory, system_prompt=system_prompt, retriever=retriever, - node_postprocessors=node_postprocessors, + node_postprocessors=node_postprocessors, # type: ignore callback_manager=callback_manager, ) diff --git a/templates/components/loaders/python/__init__.py b/templates/components/loaders/python/__init__.py index 4a278a4d..fb61aeb5 100644 --- a/templates/components/loaders/python/__init__.py +++ b/templates/components/loaders/python/__init__.py @@ -1,20 +1,22 @@ import logging +from typing import Any, Dict, List -import yaml +import yaml # type: ignore from app.engine.loaders.db import DBLoaderConfig, get_db_documents from app.engine.loaders.file import FileLoaderConfig, get_file_documents from app.engine.loaders.web import WebLoaderConfig, get_web_documents +from llama_index.core import Document logger = logging.getLogger(__name__) -def load_configs(): +def load_configs() -> Dict[str, Any]: with open("config/loaders.yaml") as f: configs = yaml.safe_load(f) return configs -def get_documents(): +def get_documents() -> List[Document]: documents = [] config = load_configs() for loader_type, loader_config in config.items(): diff --git a/templates/components/loaders/python/db.py b/templates/components/loaders/python/db.py index b6e3d8f0..2d146b7c 100644 --- a/templates/components/loaders/python/db.py +++ b/templates/components/loaders/python/db.py @@ -1,5 +1,6 @@ import logging from typing import List + from pydantic import BaseModel logger = logging.getLogger(__name__) @@ -11,7 +12,13 @@ class DBLoaderConfig(BaseModel): def get_db_documents(configs: list[DBLoaderConfig]): - from llama_index.readers.database import DatabaseReader + try: + from llama_index.readers.database import DatabaseReader + except ImportError: + logger.error( + "Failed to import DatabaseReader. Make sure llama_index is installed." + ) + raise docs = [] for entry in configs: diff --git a/templates/components/loaders/python/web.py b/templates/components/loaders/python/web.py index a9bf281f..e0623ba2 100644 --- a/templates/components/loaders/python/web.py +++ b/templates/components/loaders/python/web.py @@ -1,3 +1,5 @@ +from typing import List, Optional + from pydantic import BaseModel, Field @@ -8,8 +10,8 @@ class CrawlUrl(BaseModel): class WebLoaderConfig(BaseModel): - driver_arguments: list[str] = Field(default=None) - urls: list[CrawlUrl] + driver_arguments: Optional[List[str]] = Field(default_factory=list) + urls: List[CrawlUrl] def get_web_documents(config: WebLoaderConfig): diff --git a/templates/components/settings/python/llmhub.py b/templates/components/settings/python/llmhub.py index 2c46b252..ae8f2400 100644 --- a/templates/components/settings/python/llmhub.py +++ b/templates/components/settings/python/llmhub.py @@ -1,7 +1,11 @@ -from llama_index.embeddings.openai import OpenAIEmbedding -from llama_index.core.settings import Settings -from typing import Dict +import logging import os +from typing import Dict + +from llama_index.core.settings import Settings +from llama_index.embeddings.openai import OpenAIEmbedding + +logger = logging.getLogger(__name__) DEFAULT_MODEL = "gpt-3.5-turbo" DEFAULT_EMBEDDING_MODEL = "text-embedding-3-large" @@ -50,7 +54,11 @@ def embedding_config_from_env() -> Dict: def init_llmhub(): - from llama_index.llms.openai_like import OpenAILike + try: + from llama_index.llms.openai_like import OpenAILike + except ImportError: + logger.error("Failed to import OpenAILike. Make sure llama_index is installed.") + raise llm_configs = llm_config_from_env() embedding_configs = embedding_config_from_env() diff --git a/templates/components/settings/python/settings.py b/templates/components/settings/python/settings.py index 620a4379..681974ce 100644 --- a/templates/components/settings/python/settings.py +++ b/templates/components/settings/python/settings.py @@ -33,8 +33,13 @@ def init_settings(): def init_ollama(): - from llama_index.embeddings.ollama import OllamaEmbedding - from llama_index.llms.ollama.base import DEFAULT_REQUEST_TIMEOUT, Ollama + try: + from llama_index.embeddings.ollama import OllamaEmbedding + from llama_index.llms.ollama.base import DEFAULT_REQUEST_TIMEOUT, Ollama + except ImportError: + raise ImportError( + "Ollama support is not installed. Please install it with `poetry add llama-index-llms-ollama` and `poetry add llama-index-embeddings-ollama`" + ) base_url = os.getenv("OLLAMA_BASE_URL") or "http://127.0.0.1:11434" request_timeout = float( @@ -55,25 +60,29 @@ def init_openai(): from llama_index.llms.openai import OpenAI max_tokens = os.getenv("LLM_MAX_TOKENS") - config = { - "model": os.getenv("MODEL"), - "temperature": float(os.getenv("LLM_TEMPERATURE", DEFAULT_TEMPERATURE)), - "max_tokens": int(max_tokens) if max_tokens is not None else None, - } - Settings.llm = OpenAI(**config) + Settings.llm = OpenAI( + model=os.getenv("MODEL", "gpt-4o-mini"), + temperature=float(os.getenv("LLM_TEMPERATURE", DEFAULT_TEMPERATURE)), + max_tokens=int(max_tokens) if max_tokens is not None else None, + ) dimensions = os.getenv("EMBEDDING_DIM") - config = { - "model": os.getenv("EMBEDDING_MODEL"), - "dimensions": int(dimensions) if dimensions is not None else None, - } - Settings.embed_model = OpenAIEmbedding(**config) + Settings.embed_model = OpenAIEmbedding( + model=os.getenv("EMBEDDING_MODEL", "text-embedding-3-small"), + dimensions=int(dimensions) if dimensions is not None else None, + ) def init_azure_openai(): from llama_index.core.constants import DEFAULT_TEMPERATURE - from llama_index.embeddings.azure_openai import AzureOpenAIEmbedding - from llama_index.llms.azure_openai import AzureOpenAI + + try: + from llama_index.embeddings.azure_openai import AzureOpenAIEmbedding + from llama_index.llms.azure_openai import AzureOpenAI + except ImportError: + raise ImportError( + "Azure OpenAI support is not installed. Please install it with `poetry add llama-index-llms-azure-openai` and `poetry add llama-index-embeddings-azure-openai`" + ) llm_deployment = os.environ["AZURE_OPENAI_LLM_DEPLOYMENT"] embedding_deployment = os.environ["AZURE_OPENAI_EMBEDDING_DEPLOYMENT"] @@ -105,26 +114,37 @@ def init_azure_openai(): def init_fastembed(): - """ - Use Qdrant Fastembed as the local embedding provider. - """ - from llama_index.embeddings.fastembed import FastEmbedEmbedding + try: + from llama_index.embeddings.fastembed import FastEmbedEmbedding + except ImportError: + raise ImportError( + "FastEmbed support is not installed. Please install it with `poetry add llama-index-embeddings-fastembed`" + ) embed_model_map: Dict[str, str] = { # Small and multilingual "all-MiniLM-L6-v2": "sentence-transformers/all-MiniLM-L6-v2", # Large and multilingual - "paraphrase-multilingual-mpnet-base-v2": "sentence-transformers/paraphrase-multilingual-mpnet-base-v2", # noqa: E501 + "paraphrase-multilingual-mpnet-base-v2": "sentence-transformers/paraphrase-multilingual-mpnet-base-v2", } + embedding_model = os.getenv("EMBEDDING_MODEL") + if embedding_model is None: + raise ValueError("EMBEDDING_MODEL environment variable is not set") + # This will download the model automatically if it is not already downloaded Settings.embed_model = FastEmbedEmbedding( - model_name=embed_model_map[os.getenv("EMBEDDING_MODEL")] + model_name=embed_model_map[embedding_model] ) def init_groq(): - from llama_index.llms.groq import Groq + try: + from llama_index.llms.groq import Groq + except ImportError: + raise ImportError( + "Groq support is not installed. Please install it with `poetry add llama-index-llms-groq`" + ) Settings.llm = Groq(model=os.getenv("MODEL")) # Groq does not provide embeddings, so we use FastEmbed instead @@ -132,7 +152,12 @@ def init_groq(): def init_anthropic(): - from llama_index.llms.anthropic import Anthropic + try: + from llama_index.llms.anthropic import Anthropic + except ImportError: + raise ImportError( + "Anthropic support is not installed. Please install it with `poetry add llama-index-llms-anthropic`" + ) model_map: Dict[str, str] = { "claude-3-opus": "claude-3-opus-20240229", @@ -148,8 +173,13 @@ def init_anthropic(): def init_gemini(): - from llama_index.embeddings.gemini import GeminiEmbedding - from llama_index.llms.gemini import Gemini + try: + from llama_index.embeddings.gemini import GeminiEmbedding + from llama_index.llms.gemini import Gemini + except ImportError: + raise ImportError( + "Gemini support is not installed. Please install it with `poetry add llama-index-llms-gemini` and `poetry add llama-index-embeddings-gemini`" + ) model_name = f"models/{os.getenv('MODEL')}" embed_model_name = f"models/{os.getenv('EMBEDDING_MODEL')}" diff --git a/templates/components/vectordbs/python/astra/vectordb.py b/templates/components/vectordbs/python/astra/vectordb.py index f84b329e..15899e05 100644 --- a/templates/components/vectordbs/python/astra/vectordb.py +++ b/templates/components/vectordbs/python/astra/vectordb.py @@ -15,6 +15,6 @@ def get_vector_store(): token=token, api_endpoint=endpoint, collection_name=collection, - embedding_dimension=int(os.getenv("EMBEDDING_DIM")), + embedding_dimension=int(os.getenv("EMBEDDING_DIM", 768)), ) return store diff --git a/templates/components/vectordbs/python/chroma/vectordb.py b/templates/components/vectordbs/python/chroma/vectordb.py index 2a71e0a2..f577408b 100644 --- a/templates/components/vectordbs/python/chroma/vectordb.py +++ b/templates/components/vectordbs/python/chroma/vectordb.py @@ -1,4 +1,5 @@ import os + from llama_index.vector_stores.chroma import ChromaVectorStore @@ -18,7 +19,7 @@ def get_vector_store(): ) store = ChromaVectorStore.from_params( host=os.getenv("CHROMA_HOST"), - port=int(os.getenv("CHROMA_PORT")), + port=os.getenv("CHROMA_PORT", "8001"), collection_name=collection_name, ) return store diff --git a/templates/components/vectordbs/python/milvus/vectordb.py b/templates/components/vectordbs/python/milvus/vectordb.py index 7da817c9..3f41940b 100644 --- a/templates/components/vectordbs/python/milvus/vectordb.py +++ b/templates/components/vectordbs/python/milvus/vectordb.py @@ -1,4 +1,5 @@ import os + from llama_index.vector_stores.milvus import MilvusVectorStore @@ -15,6 +16,6 @@ def get_vector_store(): user=os.getenv("MILVUS_USERNAME"), password=os.getenv("MILVUS_PASSWORD"), collection_name=collection, - dim=int(os.getenv("EMBEDDING_DIM")), + dim=int(os.getenv("EMBEDDING_DIM", 768)), ) return store diff --git a/templates/components/vectordbs/python/none/index.py b/templates/components/vectordbs/python/none/index.py index cd61c539..cba2610c 100644 --- a/templates/components/vectordbs/python/none/index.py +++ b/templates/components/vectordbs/python/none/index.py @@ -3,7 +3,7 @@ import os from datetime import timedelta from typing import Optional -from cachetools import TTLCache, cached +from cachetools import TTLCache, cached # type: ignore from llama_index.core.callbacks import CallbackManager from llama_index.core.indices import load_index_from_storage from llama_index.core.storage import StorageContext diff --git a/templates/types/extractor/fastapi/app/engine/generate.py b/templates/types/extractor/fastapi/app/engine/generate.py index c6f641f7..b45c21cc 100644 --- a/templates/types/extractor/fastapi/app/engine/generate.py +++ b/templates/types/extractor/fastapi/app/engine/generate.py @@ -6,7 +6,7 @@ load_dotenv() import logging import os -from llama_index.core.ingestion import IngestionPipeline +from llama_index.core.ingestion import DocstoreStrategy, IngestionPipeline from llama_index.core.node_parser import SentenceSplitter from llama_index.core.settings import Settings from llama_index.core.storage import StorageContext @@ -41,7 +41,7 @@ def run_pipeline(docstore, vector_store, documents): Settings.embed_model, ], docstore=docstore, - docstore_strategy="upserts_and_delete", + docstore_strategy=DocstoreStrategy.UPSERTS_AND_DELETE, # type: ignore vector_store=vector_store, ) diff --git a/templates/types/extractor/fastapi/app/engine/index.py b/templates/types/extractor/fastapi/app/engine/index.py index c24e39f9..5674994b 100644 --- a/templates/types/extractor/fastapi/app/engine/index.py +++ b/templates/types/extractor/fastapi/app/engine/index.py @@ -16,7 +16,7 @@ class IndexConfig(BaseModel): ) -def get_index(config: IndexConfig = None): +def get_index(config: Optional[IndexConfig] = None) -> VectorStoreIndex: if config is None: config = IndexConfig() logger.info("Connecting vector store...") diff --git a/templates/types/streaming/fastapi/app/api/routers/events.py b/templates/types/streaming/fastapi/app/api/routers/events.py index 94cc5851..d19196a7 100644 --- a/templates/types/streaming/fastapi/app/api/routers/events.py +++ b/templates/types/streaming/fastapi/app/api/routers/events.py @@ -1,13 +1,13 @@ -import json import asyncio +import json import logging -from typing import AsyncGenerator, Dict, Any, List, Optional +from typing import Any, AsyncGenerator, Dict, List, Optional + from llama_index.core.callbacks.base import BaseCallbackHandler from llama_index.core.callbacks.schema import CBEventType from llama_index.core.tools.types import ToolOutput from pydantic import BaseModel - logger = logging.getLogger(__name__) @@ -31,15 +31,20 @@ class CallbackEvent(BaseModel): return None def get_tool_message(self) -> dict | None: + if self.payload is None: + return None func_call_args = self.payload.get("function_call") if func_call_args is not None and "tool" in self.payload: tool = self.payload.get("tool") + if tool is None: + return None return { "type": "events", "data": { "title": f"Calling tool: {tool.name} with inputs: {func_call_args}", }, } + return None def _is_output_serializable(self, output: Any) -> bool: try: @@ -49,6 +54,8 @@ class CallbackEvent(BaseModel): return False def get_agent_tool_response(self) -> dict | None: + if self.payload is None: + return None response = self.payload.get("response") if response is not None: sources = response.sources @@ -74,6 +81,7 @@ class CallbackEvent(BaseModel): }, }, } + return None def to_response(self): try: @@ -114,11 +122,13 @@ class EventCallbackHandler(BaseCallbackHandler): event_type: CBEventType, payload: Optional[Dict[str, Any]] = None, event_id: str = "", + parent_id: str = "", **kwargs: Any, ) -> str: event = CallbackEvent(event_id=event_id, event_type=event_type, payload=payload) if event.to_response() is not None: self._aqueue.put_nowait(event) + return event_id def on_event_end( self, diff --git a/templates/types/streaming/fastapi/app/api/routers/models.py b/templates/types/streaming/fastapi/app/api/routers/models.py index c5c8d4a3..17c63e59 100644 --- a/templates/types/streaming/fastapi/app/api/routers/models.py +++ b/templates/types/streaming/fastapi/app/api/routers/models.py @@ -1,6 +1,6 @@ import logging import os -from typing import Any, Dict, List, Literal, Optional +from typing import Any, Dict, List, Literal, Optional, Union from llama_index.core.llms import ChatMessage, MessageRole from llama_index.core.schema import NodeWithScore @@ -62,15 +62,23 @@ class ArtifactAnnotation(BaseModel): class Annotation(BaseModel): type: str - data: AnnotationFileData | List[str] | AgentAnnotation | ArtifactAnnotation + data: Union[AnnotationFileData, List[str], AgentAnnotation, ArtifactAnnotation] - def to_content(self) -> str | None: + def to_content(self) -> Optional[str]: if self.type == "document_file": - # We only support generating context content for CSV files for now - csv_files = [file for file in self.data.files if file.filetype == "csv"] - if len(csv_files) > 0: - return "Use data from following CSV raw content\n" + "\n".join( - [f"```csv\n{csv_file.content.value}\n```" for csv_file in csv_files] + if isinstance(self.data, AnnotationFileData): + # We only support generating context content for CSV files for now + csv_files = [file for file in self.data.files if file.filetype == "csv"] + if len(csv_files) > 0: + return "Use data from following CSV raw content\n" + "\n".join( + [ + f"```csv\n{csv_file.content.value}\n```" + for csv_file in csv_files + ] + ) + else: + logger.warning( + f"Unexpected data type for document_file annotation: {type(self.data)}" ) else: logger.warning( @@ -213,6 +221,7 @@ class ChatData(BaseModel): for annotation in message.annotations: if ( annotation.type == "document_file" + and isinstance(annotation.data, AnnotationFileData) and annotation.data.files is not None ): for fi in annotation.data.files: @@ -242,7 +251,7 @@ class SourceNodes(BaseModel): ) @classmethod - def get_url_from_metadata(cls, metadata: Dict[str, Any]) -> str: + def get_url_from_metadata(cls, metadata: Dict[str, Any]) -> Optional[str]: url_prefix = os.getenv("FILESERVER_URL_PREFIX") if not url_prefix: logger.warning( diff --git a/templates/types/streaming/fastapi/app/engine/generate.py b/templates/types/streaming/fastapi/app/engine/generate.py index 1bca2e28..f27cdbd5 100644 --- a/templates/types/streaming/fastapi/app/engine/generate.py +++ b/templates/types/streaming/fastapi/app/engine/generate.py @@ -6,15 +6,16 @@ load_dotenv() import logging import os -from app.engine.loaders import get_documents -from app.engine.vectordb import get_vector_store -from app.settings import init_settings -from llama_index.core.ingestion import IngestionPipeline +from llama_index.core.ingestion import DocstoreStrategy, IngestionPipeline from llama_index.core.node_parser import SentenceSplitter from llama_index.core.settings import Settings from llama_index.core.storage import StorageContext from llama_index.core.storage.docstore import SimpleDocumentStore +from app.engine.loaders import get_documents +from app.engine.vectordb import get_vector_store +from app.settings import init_settings + logging.basicConfig(level=logging.INFO) logger = logging.getLogger() @@ -40,7 +41,7 @@ def run_pipeline(docstore, vector_store, documents): Settings.embed_model, ], docstore=docstore, - docstore_strategy="upserts_and_delete", + docstore_strategy=DocstoreStrategy.UPSERTS_AND_DELETE, # type: ignore vector_store=vector_store, ) diff --git a/templates/types/streaming/fastapi/pyproject.toml b/templates/types/streaming/fastapi/pyproject.toml index 9d857a08..05ae67d0 100644 --- a/templates/types/streaming/fastapi/pyproject.toml +++ b/templates/types/streaming/fastapi/pyproject.toml @@ -17,6 +17,23 @@ aiostream = "^0.5.2" cachetools = "^5.3.3" llama-index = "0.11.6" +[tool.poetry.group.dev.dependencies] +mypy = "^1.8.0" + [build-system] requires = ["poetry-core"] build-backend = "poetry.core.masonry.api" + +[tool.mypy] +python_version = "3.11" +plugins = "pydantic.mypy" +exclude = [ "tests", "venv", ".venv", "output", "config" ] +check_untyped_defs = true +warn_unused_ignores = false +show_error_codes = true +namespace_packages = true +ignore_missing_imports = true +follow_imports = "silent" +implicit_optional = true +strict_optional = false +disable_error_code = ["return-value", "import-untyped", "assignment"] -- GitLab