From 85de3d9a503fec159962569fb77329a43af4594e Mon Sep 17 00:00:00 2001 From: James Braza <jamesbraza@gmail.com> Date: Tue, 10 Oct 2023 11:55:21 -0400 Subject: [PATCH] More `ruff` enables (PYI, SIM, etc.) (#8038) --- benchmarks/agent/math_tasks.py | 4 +- benchmarks/embeddings/bench_embeddings.py | 2 +- experimental/cli/cli_add.py | 4 +- experimental/cli/cli_init.py | 4 +- experimental/cli/cli_query.py | 4 +- experimental/cli/configuration.py | 14 +- experimental/colbert_index/base.py | 2 +- llama_index/agent/react/formatter.py | 2 +- llama_index/callbacks/llama_debug.py | 2 +- .../callbacks/open_inference_callback.py | 2 +- llama_index/callbacks/simple_llm_handler.py | 2 +- llama_index/chat_engine/context.py | 2 +- llama_index/chat_engine/types.py | 4 +- llama_index/embeddings/adapter.py | 2 +- llama_index/evaluation/benchmarks/hotpotqa.py | 2 +- llama_index/evaluation/dataset_generation.py | 2 +- llama_index/indices/base_retriever.py | 2 +- llama_index/indices/loading.py | 2 +- llama_index/indices/managed/base.py | 2 +- llama_index/indices/managed/vectara/base.py | 2 +- .../indices/managed/vectara/retriever.py | 2 +- llama_index/indices/postprocessor/node.py | 2 +- llama_index/llm_predictor/base.py | 4 +- llama_index/llms/azure_openai.py | 2 +- llama_index/llms/portkey.py | 4 +- llama_index/llms/portkey_utils.py | 2 +- llama_index/llms/predibase.py | 2 +- .../extractors/metadata_extractors.py | 3 +- llama_index/node_parser/file/html.py | 4 +- llama_index/node_parser/file/json.py | 4 +- llama_index/node_parser/file/markdown.py | 6 +- llama_index/objects/tool_node_mapping.py | 2 +- .../query_engine/citation_query_engine.py | 2 +- .../knowledge_graph_query_engine.py | 2 +- .../query_engine/retriever_query_engine.py | 2 +- llama_index/readers/bagel.py | 20 +-- llama_index/readers/deeplake.py | 4 +- llama_index/readers/discord_reader.py | 24 ++- llama_index/readers/file/docs_reader.py | 2 +- llama_index/readers/file/flat_reader.py | 2 +- .../github_repository_reader.py | 72 ++++---- llama_index/readers/google_readers/gsheets.py | 3 +- llama_index/readers/myscale.py | 2 +- llama_index/readers/psychic.py | 2 +- llama_index/readers/redis/utils.py | 2 +- .../response_synthesizers/accumulate.py | 4 +- llama_index/selectors/llm_selectors.py | 4 +- .../storage/docstore/firestore_docstore.py | 2 +- .../index_store/firestore_indexstore.py | 2 +- .../tools/tool_spec/load_and_search/base.py | 2 +- llama_index/tts/elevenlabs.py | 2 - llama_index/utils.py | 4 +- llama_index/vector_stores/cassandra.py | 2 +- llama_index/vector_stores/cogsearch.py | 12 +- llama_index/vector_stores/dynamodb.py | 2 +- llama_index/vector_stores/elasticsearch.py | 4 +- llama_index/vector_stores/myscale.py | 2 +- llama_index/vector_stores/neo4jvector.py | 6 +- llama_index/vector_stores/postgres.py | 163 +++++++++--------- llama_index/vector_stores/redis.py | 2 +- llama_index/vector_stores/rocksetdb.py | 8 +- llama_index/vector_stores/tair.py | 2 +- llama_index/vector_stores/timescalevector.py | 2 +- pyproject.toml | 64 +------ tests/conftest.py | 2 +- tests/indices/postprocessor/test_base.py | 2 +- tests/indices/query/test_embedding_utils.py | 4 +- tests/indices/tree/test_index.py | 2 +- tests/llm_predictor/vellum/test_predictor.py | 4 +- .../vellum/test_prompt_registry.py | 6 +- tests/llms/test_localai.py | 5 +- tests/llms/test_openai_utils.py | 2 +- tests/memory/test_chat_memory_buffer.py | 2 +- tests/readers/test_file.py | 5 +- tests/text_splitter/test_code_splitter.py | 10 +- tests/text_splitter/test_sentence_splitter.py | 4 
+- tests/tools/test_utils.py | 2 +- 77 files changed, 251 insertions(+), 326 deletions(-) diff --git a/benchmarks/agent/math_tasks.py b/benchmarks/agent/math_tasks.py index 76b6e09552..bf79876337 100644 --- a/benchmarks/agent/math_tasks.py +++ b/benchmarks/agent/math_tasks.py @@ -7,12 +7,12 @@ from llama_index.tools.function_tool import FunctionTool def add(a: int, b: int) -> int: - """Add two integers and returns the result integer""" + """Add two integers and return the result integer.""" return a + b def multiply(a: int, b: int) -> int: - """Multiple two integers and returns the result integer""" + """Multiply two integers and return the result integer.""" return a * b diff --git a/benchmarks/embeddings/bench_embeddings.py b/benchmarks/embeddings/bench_embeddings.py index b3498cd04b..1d0320ecb5 100644 --- a/benchmarks/embeddings/bench_embeddings.py +++ b/benchmarks/embeddings/bench_embeddings.py @@ -16,7 +16,7 @@ def generate_strings(num_strings: int = 100, string_length: int = 10) -> List[st offset 0: [0:string_length], [string_length:2*string_length], ... offset 1: [1:1+string_length], [1+string_length:1+2*string_length],... ... - """ + """ # noqa: D415 content = ( SimpleDirectoryReader("../../examples/paul_graham_essay/data") .load_data()[0] diff --git a/experimental/cli/cli_add.py b/experimental/cli/cli_add.py index aecb98d9ba..076c52b26d 100644 --- a/experimental/cli/cli_add.py +++ b/experimental/cli/cli_add.py @@ -7,7 +7,7 @@ from .configuration import load_index, save_index def add_cli(args: Namespace) -> None: - """Handle subcommand "add" """ + """Handle subcommand "add".""" index = load_index() for p in args.files: @@ -26,7 +26,7 @@ def add_cli(args: Namespace) -> None: def register_add_cli(subparsers: _SubParsersAction) -> None: - """Register subcommand "add" to ArgumentParser""" + """Register subcommand "add" to ArgumentParser.""" parser = subparsers.add_parser("add") parser.add_argument( "files", diff --git a/experimental/cli/cli_init.py b/experimental/cli/cli_init.py index 2d1b68fbe6..65728d31ed 100644 --- a/experimental/cli/cli_init.py +++ b/experimental/cli/cli_init.py @@ -4,13 +4,13 @@ from .configuration import load_config, save_config def init_cli(args: Namespace) -> None: - """Handle subcommand "init" """ + """Handle subcommand "init".""" config = load_config(args.directory) save_config(config, args.directory) def register_init_cli(subparsers: _SubParsersAction) -> None: - """Register subcommand "init" to ArgumentParser""" + """Register subcommand "init" to ArgumentParser.""" parser = subparsers.add_parser("init") parser.add_argument( "directory", diff --git a/experimental/cli/cli_query.py b/experimental/cli/cli_query.py index 3550f73a71..dc4543fa12 100644 --- a/experimental/cli/cli_query.py +++ b/experimental/cli/cli_query.py @@ -4,14 +4,14 @@ from .configuration import load_index def query_cli(args: Namespace) -> None: - """Handle subcommand "query" """ + """Handle subcommand "query".""" index = load_index() query_engine = index.as_query_engine() print(query_engine.query(args.query)) def register_query_cli(subparsers: _SubParsersAction) -> None: - """Register subcommand "query" to ArgumentParser""" + """Register subcommand "query" to ArgumentParser.""" parser = subparsers.add_parser("query") parser.add_argument( "query", diff --git a/experimental/cli/configuration.py b/experimental/cli/configuration.py index 395393b48b..4d2d8516cb 100644 --- a/experimental/cli/configuration.py +++ b/experimental/cli/configuration.py @@ -28,7 +28,7 @@ DEFAULT_CONFIG = { def 
load_config(root: str = ".") -> ConfigParser: - """Load configuration from file""" + """Load configuration from file.""" config = ConfigParser() config.read_dict(DEFAULT_CONFIG) config.read(os.path.join(root, CONFIG_FILE_NAME)) @@ -36,13 +36,13 @@ def load_config(root: str = ".") -> ConfigParser: def save_config(config: ConfigParser, root: str = ".") -> None: - """Load configuration to file""" + """Save configuration to file.""" with open(os.path.join(root, CONFIG_FILE_NAME), "w") as fd: config.write(fd) def load_index(root: str = ".") -> BaseIndex[Any]: - """Load existing index file""" + """Load existing index file.""" config = load_config(root) service_context = _load_service_context(config) @@ -69,14 +69,14 @@ def load_index(root: str = ".") -> BaseIndex[Any]: def save_index(index: BaseIndex[Any], root: str = ".") -> None: - """Save index to file""" + """Save index to file.""" config = load_config(root) persist_dir = config["store"]["persist_dir"] index.storage_context.persist(persist_dir=persist_dir) def _load_service_context(config: ConfigParser) -> ServiceContext: - """Internal function to load service context based on configuration""" + """Internal function to load service context based on configuration.""" embed_model = _load_embed_model(config) llm_predictor = _load_llm_predictor(config) return ServiceContext.from_defaults( @@ -90,7 +90,7 @@ def _load_storage_context(config: ConfigParser) -> StorageContext: def _load_llm_predictor(config: ConfigParser) -> LLMPredictor: - """Internal function to load LLM predictor based on configuration""" + """Internal function to load LLM predictor based on configuration.""" model_type = config["llm_predictor"]["type"].lower() if model_type == "default": llm = _load_llm(config["llm_predictor"]) @@ -110,7 +110,7 @@ def _load_llm(section: SectionProxy) -> LLM: def _load_embed_model(config: ConfigParser) -> BaseEmbedding: - """Internal function to load embedding model based on configuration""" + """Internal function to load embedding model based on configuration.""" model_type = config["embed_model"]["type"] if model_type == "default": return OpenAIEmbedding() diff --git a/experimental/colbert_index/base.py b/experimental/colbert_index/base.py index 487a1ba2cc..c0d3bfa143 100644 --- a/experimental/colbert_index/base.py +++ b/experimental/colbert_index/base.py @@ -145,7 +145,7 @@ class ColbertIndex(BaseIndex[IndexDict]): """ doc_ids, _, scores = self.store.search(text=query_str, k=top_k) - node_doc_ids = list(map(lambda id: self._docs_pos_to_node_id[id], doc_ids)) + node_doc_ids = [self._docs_pos_to_node_id[id] for id in doc_ids] nodes = self.docstore.get_nodes(node_doc_ids) nodes_with_score = [] diff --git a/llama_index/agent/react/formatter.py b/llama_index/agent/react/formatter.py index 97d131fe31..7680521239 100644 --- a/llama_index/agent/react/formatter.py +++ b/llama_index/agent/react/formatter.py @@ -11,7 +11,7 @@ from llama_index.tools import BaseTool def get_react_tool_descriptions(tools: Sequence[BaseTool]) -> List[str]: - """Tool""" + """Tool.""" tool_descs = [] for tool in tools: tool_desc = ( diff --git a/llama_index/callbacks/llama_debug.py b/llama_index/callbacks/llama_debug.py index 1fff689514..dd1073ec8e 100644 --- a/llama_index/callbacks/llama_debug.py +++ b/llama_index/callbacks/llama_debug.py @@ -116,7 +116,7 @@ class LlamaDebugHandler(BaseCallbackHandler): def _get_time_stats_from_event_pairs( self, event_pairs: List[List[CBEvent]] ) -> EventStats: - """Calculate time-based stats for a set of event pairs""" + """Calculate 
time-based stats for a set of event pairs.""" total_secs = 0.0 for event_pair in event_pairs: start_time = datetime.strptime(event_pair[0].time, TIMESTAMP_FORMAT) diff --git a/llama_index/callbacks/open_inference_callback.py b/llama_index/callbacks/open_inference_callback.py index 0e6a0ff50f..9b105ca241 100644 --- a/llama_index/callbacks/open_inference_callback.py +++ b/llama_index/callbacks/open_inference_callback.py @@ -112,7 +112,7 @@ def as_dataframe(data: Iterable[BaseDataType]) -> "DataFrame": @dataclass class TraceData: - """Trace data""" + """Trace data.""" query_data: QueryData = field(default_factory=QueryData) node_datas: List[NodeData] = field(default_factory=list) diff --git a/llama_index/callbacks/simple_llm_handler.py b/llama_index/callbacks/simple_llm_handler.py index 3a89f34faa..cc53a35960 100644 --- a/llama_index/callbacks/simple_llm_handler.py +++ b/llama_index/callbacks/simple_llm_handler.py @@ -5,7 +5,7 @@ from llama_index.callbacks.schema import CBEventType, EventPayload class SimpleLLMHandler(BaseCallbackHandler): - """Callback handler for printing llms inputs/outputs""" + """Callback handler for printing llms inputs/outputs.""" def __init__(self) -> None: super().__init__(event_starts_to_ignore=[], event_ends_to_ignore=[]) diff --git a/llama_index/chat_engine/context.py b/llama_index/chat_engine/context.py index 10ee892c04..dde6e5698d 100644 --- a/llama_index/chat_engine/context.py +++ b/llama_index/chat_engine/context.py @@ -127,7 +127,7 @@ class ContextChatEngine(BaseChatEngine): return self._context_template.format(context_str=context_str), nodes def _get_prefix_messages_with_context(self, context_str: str) -> List[ChatMessage]: - """Get the prefix messages with context""" + """Get the prefix messages with context.""" # ensure we grab the user-configured system prompt system_prompt = "" prefix_messages = self._prefix_messages diff --git a/llama_index/chat_engine/types.py b/llama_index/chat_engine/types.py index 6b96ba91cb..5b73cb5da7 100644 --- a/llama_index/chat_engine/types.py +++ b/llama_index/chat_engine/types.py @@ -18,12 +18,12 @@ logger.setLevel(logging.WARNING) def is_function(message: ChatMessage) -> bool: - """Utility for ChatMessage responses from OpenAI models""" + """Utility for ChatMessage responses from OpenAI models.""" return "function_call" in message.additional_kwargs class ChatResponseMode(str, Enum): - """Flag toggling waiting/streaming in `Agent._chat`""" + """Flag toggling waiting/streaming in `Agent._chat`.""" WAIT = "wait" STREAM = "stream" diff --git a/llama_index/embeddings/adapter.py b/llama_index/embeddings/adapter.py index e27e0e3144..a02516225e 100644 --- a/llama_index/embeddings/adapter.py +++ b/llama_index/embeddings/adapter.py @@ -47,7 +47,7 @@ class AdapterEmbeddingModel(BaseEmbedding): embed_batch_size: int = DEFAULT_EMBED_BATCH_SIZE, callback_manager: Optional[CallbackManager] = None, ) -> None: - """Init params""" + """Init params.""" import torch from llama_index.embeddings.adapter_utils import BaseAdapter, LinearLayer diff --git a/llama_index/evaluation/benchmarks/hotpotqa.py b/llama_index/evaluation/benchmarks/hotpotqa.py index f09ef85629..4651fb70f7 100644 --- a/llama_index/evaluation/benchmarks/hotpotqa.py +++ b/llama_index/evaluation/benchmarks/hotpotqa.py @@ -22,7 +22,7 @@ hotpot/hotpot_dev_distractor_v1.json""" class HotpotQAEvaluator: """ - Refer to https://hotpotqa.github.io/ for more details on the dataset + Refer to https://hotpotqa.github.io/ for more details on the dataset. 
""" def _download_datasets(self) -> Dict[str, str]: diff --git a/llama_index/evaluation/dataset_generation.py b/llama_index/evaluation/dataset_generation.py index 299e24c159..3232f5d7fd 100644 --- a/llama_index/evaluation/dataset_generation.py +++ b/llama_index/evaluation/dataset_generation.py @@ -1,4 +1,4 @@ -"""Dataset generation from documents""" +"""Dataset generation from documents.""" from __future__ import annotations import asyncio diff --git a/llama_index/indices/base_retriever.py b/llama_index/indices/base_retriever.py index 2885698a53..3d2f52b5b9 100644 --- a/llama_index/indices/base_retriever.py +++ b/llama_index/indices/base_retriever.py @@ -47,7 +47,7 @@ class BaseRetriever(ABC): def get_service_context(self) -> Optional[ServiceContext]: """Attempts to resolve a service context. Short-circuits at self.service_context, self._service_context, - or self._index.service_context + or self._index.service_context. """ if hasattr(self, "service_context"): return self.service_context diff --git a/llama_index/indices/loading.py b/llama_index/indices/loading.py index b917f93e33..4c65946e70 100644 --- a/llama_index/indices/loading.py +++ b/llama_index/indices/loading.py @@ -50,7 +50,7 @@ def load_indices_from_storage( index_ids: Optional[Sequence[str]] = None, **kwargs: Any, ) -> List[BaseIndex]: - """Load multiple indices from storage context + """Load multiple indices from storage context. Args: storage_context (StorageContext): storage context containing diff --git a/llama_index/indices/managed/base.py b/llama_index/indices/managed/base.py index ac649407bb..928ba66d50 100644 --- a/llama_index/indices/managed/base.py +++ b/llama_index/indices/managed/base.py @@ -20,7 +20,7 @@ class BaseManagedIndex(BaseIndex[IndexDict], ABC): The managed service can index documents into a managed service. How documents are structured into nodes is a detail for the managed service, and not exposed in this interface (although could be controlled by - configuration parameters) + configuration parameters). Args: show_progress (bool): Whether to show tqdm progress bars. Defaults to False. diff --git a/llama_index/indices/managed/vectara/base.py b/llama_index/indices/managed/vectara/base.py index 21db9ccc4b..abd6c1c491 100644 --- a/llama_index/indices/managed/vectara/base.py +++ b/llama_index/indices/managed/vectara/base.py @@ -203,7 +203,7 @@ class VectaraIndex(BaseManagedIndex): ) -> Optional[str]: """Vectara provides a way to add files (binary or text) directly via our API where pre-processing and chunking occurs internally in an optimal way - This method provides a way to use that API in Llama_index + This method provides a way to use that API in Llama_index. # ruff: noqa: E501 Full API Docs: https://docs.vectara.com/docs/api-reference/indexing-apis/ diff --git a/llama_index/indices/managed/vectara/retriever.py b/llama_index/indices/managed/vectara/retriever.py index 25155d1b64..5fb08c2637 100644 --- a/llama_index/indices/managed/vectara/retriever.py +++ b/llama_index/indices/managed/vectara/retriever.py @@ -1,5 +1,5 @@ """Vectara index. -An index that that is built on top of Vectara +An index that that is built on top of Vectara. 
""" import json diff --git a/llama_index/indices/postprocessor/node.py b/llama_index/indices/postprocessor/node.py index 6a32ab1260..b101109736 100644 --- a/llama_index/indices/postprocessor/node.py +++ b/llama_index/indices/postprocessor/node.py @@ -362,7 +362,7 @@ class LongContextReorder(BaseNodePostprocessor): performance typically arises when crucial data is positioned at the start or conclusion of the input context. Additionally, as the input context lengthens, performance drops notably, even - in models designed for long contexts." + in models designed for long contexts.". """ @classmethod diff --git a/llama_index/llm_predictor/base.py b/llama_index/llm_predictor/base.py index 776226587c..5de8690182 100644 --- a/llama_index/llm_predictor/base.py +++ b/llama_index/llm_predictor/base.py @@ -263,7 +263,7 @@ class LLMPredictor(BaseLLMPredictor): self, formatted_prompt: str, ) -> str: - """Add system and query wrapper prompts to base prompt""" + """Add system and query wrapper prompts to base prompt.""" extended_prompt = formatted_prompt if self.system_prompt: extended_prompt = self.system_prompt + "\n\n" + extended_prompt @@ -276,7 +276,7 @@ class LLMPredictor(BaseLLMPredictor): return extended_prompt def _extend_messages(self, messages: List[ChatMessage]) -> List[ChatMessage]: - """Add system prompt to chat message list""" + """Add system prompt to chat message list.""" if self.system_prompt: messages = [ ChatMessage(role=MessageRole.SYSTEM, content=self.system_prompt), diff --git a/llama_index/llms/azure_openai.py b/llama_index/llms/azure_openai.py index 2dd0eae9c6..21bf029ab4 100644 --- a/llama_index/llms/azure_openai.py +++ b/llama_index/llms/azure_openai.py @@ -10,7 +10,7 @@ AZURE_OPENAI_API_TYPE = "azure" class AzureOpenAI(OpenAI): """ - Azure OpenAI + Azure OpenAI. To use this, you must first deploy a model on Azure OpenAI. Unlike OpenAI, you need to specify a `engine` parameter to identify diff --git a/llama_index/llms/portkey.py b/llama_index/llms/portkey.py index 4fcbcdf452..c9e444af99 100644 --- a/llama_index/llms/portkey.py +++ b/llama_index/llms/portkey.py @@ -1,5 +1,5 @@ """ -Portkey integration with Llama_index for enhanced monitoring +Portkey integration with Llama_index for enhanced monitoring. """ from typing import TYPE_CHECKING, Any, List, Optional, Sequence, Union, cast @@ -38,7 +38,7 @@ if TYPE_CHECKING: class Portkey(CustomLLM): - """_summary_ + """_summary_. Args: LLM (_type_): _description_ diff --git a/llama_index/llms/portkey_utils.py b/llama_index/llms/portkey_utils.py index d15cb41367..b328ea402f 100644 --- a/llama_index/llms/portkey_utils.py +++ b/llama_index/llms/portkey_utils.py @@ -1,5 +1,5 @@ """ -Utility Tools for the Portkey Class +Utility Tools for the Portkey Class. 
This file module contains a collection of utility functions designed to enhance the functionality and usability of the Portkey class diff --git a/llama_index/llms/predibase.py b/llama_index/llms/predibase.py index e0a279653b..c70dc631d1 100644 --- a/llama_index/llms/predibase.py +++ b/llama_index/llms/predibase.py @@ -14,7 +14,7 @@ from llama_index.llms.custom import CustomLLM class PredibaseLLM(CustomLLM): - """Predibase LLM""" + """Predibase LLM.""" model_name: str = Field(description="The Predibase model to use.") predibase_api_key: str = Field(description="The Predibase API key to use.") diff --git a/llama_index/node_parser/extractors/metadata_extractors.py b/llama_index/node_parser/extractors/metadata_extractors.py index 88846e61d6..166e5b559a 100644 --- a/llama_index/node_parser/extractors/metadata_extractors.py +++ b/llama_index/node_parser/extractors/metadata_extractors.py @@ -403,7 +403,8 @@ class SummaryExtractor(MetadataFeatureExtractor): """ Summary extractor. Node-level extractor with adjacent sharing. Extracts `section_summary`, `prev_section_summary`, `next_section_summary` - metadata fields + metadata fields. + Args: llm_predictor (Optional[BaseLLMPredictor]): LLM predictor summaries (List[str]): list of summaries to extract: 'self', 'prev', 'next' diff --git a/llama_index/node_parser/file/html.py b/llama_index/node_parser/file/html.py index 60b1e387f8..c45498a1a9 100644 --- a/llama_index/node_parser/file/html.py +++ b/llama_index/node_parser/file/html.py @@ -88,7 +88,7 @@ class HTMLNodeParser(NodeParser): return all_nodes def get_nodes_from_node(self, node: BaseNode) -> List[TextNode]: - """Get nodes from document""" + """Get nodes from document.""" try: from bs4 import BeautifulSoup except ImportError: @@ -144,7 +144,7 @@ class HTMLNodeParser(NodeParser): node: BaseNode, metadata: dict, ) -> TextNode: - """Build node from single text split""" + """Build node from single text split.""" node = build_nodes_from_splits( [text_split], node, self.include_metadata, self.include_prev_next_rel )[0] diff --git a/llama_index/node_parser/file/json.py b/llama_index/node_parser/file/json.py index 207143e55c..aa8e79a4b6 100644 --- a/llama_index/node_parser/file/json.py +++ b/llama_index/node_parser/file/json.py @@ -80,7 +80,7 @@ class JSONNodeParser(NodeParser): return all_nodes def get_nodes_from_node(self, node: BaseNode) -> List[TextNode]: - """Get nodes from document""" + """Get nodes from document.""" text = node.get_content(metadata_mode=MetadataMode.NONE) try: data = json.loads(text) @@ -132,7 +132,7 @@ class JSONNodeParser(NodeParser): node: BaseNode, metadata: dict, ) -> TextNode: - """Build node from single text split""" + """Build node from single text split.""" node = build_nodes_from_splits( [text_split], node, self.include_metadata, self.include_prev_next_rel )[0] diff --git a/llama_index/node_parser/file/markdown.py b/llama_index/node_parser/file/markdown.py index c5f9920509..13ccf80f8e 100644 --- a/llama_index/node_parser/file/markdown.py +++ b/llama_index/node_parser/file/markdown.py @@ -80,7 +80,7 @@ class MarkdownNodeParser(NodeParser): return all_nodes def get_nodes_from_node(self, node: BaseNode) -> List[TextNode]: - """Get nodes from document""" + """Get nodes from document.""" text = node.get_content(metadata_mode=MetadataMode.NONE) markdown_nodes = [] lines = text.split("\n") @@ -115,7 +115,7 @@ class MarkdownNodeParser(NodeParser): def _update_metadata( self, headers_metadata: dict, new_header: str, new_header_level: int ) -> dict: - """Update the markdown 
headers for metadata + """Update the markdown headers for metadata. Removes all headers that are equal or less than the level of the newly found header @@ -136,7 +136,7 @@ class MarkdownNodeParser(NodeParser): node: BaseNode, metadata: dict, ) -> TextNode: - """Build node from single text split""" + """Build node from single text split.""" node = build_nodes_from_splits( [text_split], node, self.include_metadata, self.include_prev_next_rel )[0] diff --git a/llama_index/objects/tool_node_mapping.py b/llama_index/objects/tool_node_mapping.py index e17d9cfb58..18ee0c9ae5 100644 --- a/llama_index/objects/tool_node_mapping.py +++ b/llama_index/objects/tool_node_mapping.py @@ -1,4 +1,4 @@ -"""Tool mapping""" +"""Tool mapping.""" from typing import Any, Optional, Sequence diff --git a/llama_index/query_engine/citation_query_engine.py b/llama_index/query_engine/citation_query_engine.py index 7f01c049bd..52bc021a53 100644 --- a/llama_index/query_engine/citation_query_engine.py +++ b/llama_index/query_engine/citation_query_engine.py @@ -136,7 +136,7 @@ class CitationQueryEngine(BaseQueryEngine): # class-specific args **kwargs: Any, ) -> "CitationQueryEngine": - """Initialize a CitationQueryEngine object." + """Initialize a CitationQueryEngine object. Args: index: (BastGPTIndex): index to use for querying diff --git a/llama_index/query_engine/knowledge_graph_query_engine.py b/llama_index/query_engine/knowledge_graph_query_engine.py index 53b8b9430f..8d00b2a448 100644 --- a/llama_index/query_engine/knowledge_graph_query_engine.py +++ b/llama_index/query_engine/knowledge_graph_query_engine.py @@ -1,4 +1,4 @@ -""" Knowledge Graph Query Engine""" +"""Knowledge Graph Query Engine.""" import logging from typing import Any, List, Optional, Sequence diff --git a/llama_index/query_engine/retriever_query_engine.py b/llama_index/query_engine/retriever_query_engine.py index c6bfd389c2..ad9915390e 100644 --- a/llama_index/query_engine/retriever_query_engine.py +++ b/llama_index/query_engine/retriever_query_engine.py @@ -67,7 +67,7 @@ class RetrieverQueryEngine(BaseQueryEngine): # class-specific args **kwargs: Any, ) -> "RetrieverQueryEngine": - """Initialize a RetrieverQueryEngine object. Args: retriever (BaseRetriever): A retriever object. 
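Note on the bagel.py hunk that follows: it collapses unions of single-member `Literal` types into one `Literal`, one of the PYI-family fixes this PR enables. A minimal, self-contained sketch of why the two spellings are interchangeable (the alias names below are illustrative, not from this patch):

from typing import Literal, Union

# Before: a union of single-valued Literal types.
OperatorBefore = Union[Literal["$and"], Literal["$or"]]

# After: one Literal listing every allowed value.
OperatorAfter = Literal["$and", "$or"]

def describe(op: OperatorAfter) -> str:
    # Type checkers treat OperatorBefore and OperatorAfter as the same
    # set of accepted strings; only the spelling gets shorter.
    return f"combining clauses with {op}"

print(describe("$and"))  # combining clauses with $and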
diff --git a/llama_index/readers/bagel.py b/llama_index/readers/bagel.py index aecb94ca65..cdf647f8c4 100644 --- a/llama_index/readers/bagel.py +++ b/llama_index/readers/bagel.py @@ -18,15 +18,8 @@ Metadatas = List[Metadata] # Metadata Query Grammar LiteralValue = Union[str, int, float] -LogicalOperator = Union[Literal["$and"], Literal["$or"]] -WhereOperator = Union[ - Literal["$gt"], - Literal["$gte"], - Literal["$lt"], - Literal["$lte"], - Literal["$ne"], - Literal["$eq"], -] +LogicalOperator = Literal["$and", "$or"] +WhereOperator = Literal["$gt", "$gte", "$lt", "$lte", "$ne", "$eq"] OperatorExpression = Dict[Union[WhereOperator, LogicalOperator], LiteralValue] Where = Dict[ @@ -47,14 +40,7 @@ OneOrMany = Union[T, List[T]] # This should ust be List[Literal["documents", "embeddings", "metadatas", "distances"]] # However, this provokes an incompatibility with the Overrides library and Python 3.7 -Include = List[ - Union[ - Literal["documents"], - Literal["embeddings"], - Literal["metadatas"], - Literal["distances"], - ] -] +Include = List[Literal["documents", "embeddings", "metadatas", "distances"]] LiteralValue = LiteralValue LogicalOperator = LogicalOperator diff --git a/llama_index/readers/deeplake.py b/llama_index/readers/deeplake.py index fa9fb3471c..00c00f43a6 100644 --- a/llama_index/readers/deeplake.py +++ b/llama_index/readers/deeplake.py @@ -30,7 +30,7 @@ def vector_search( distance_metric: distance function 'L2' for Euclidean, 'L1' for Nuclear, 'Max' l-infinity distance, 'cos' for cosine similarity, 'dot' for dot product returns: - nearest_indices: List, indices of nearest neighbors + nearest_indices: List, indices of nearest neighbors. """ # Calculate the distance between the query_vector and all data_vectors if isinstance(query_vector, list): @@ -62,7 +62,7 @@ class DeepLakeReader(BaseReader): self, token: Optional[str] = None, ): - """Initializing the deepLake reader""" + """Initializing the deepLake reader.""" import_err_msg = ( "`deeplake` package not found, please run `pip install deeplake`" ) diff --git a/llama_index/readers/discord_reader.py b/llama_index/readers/discord_reader.py index 8ea42a0f2e..0b400d2b38 100644 --- a/llama_index/readers/discord_reader.py +++ b/llama_index/readers/discord_reader.py @@ -70,20 +70,18 @@ async def read_channel( ### Wraps each message in a Document containing the text \ # as well as some useful metadata properties. 
- return list( - map( - lambda msg: Document( - text=msg.content, - metadata={ - "message_id": msg.id, - "username": msg.author.name, - "created_at": msg.created_at, - "edited_at": msg.edited_at, - }, - ), - messages, + return [ + Document( + text=msg.content, + metadata={ + "message_id": msg.id, + "username": msg.author.name, + "created_at": msg.created_at, + "edited_at": msg.edited_at, + }, ) - ) + for msg in messages + ] class DiscordReader(BasePydanticReader): diff --git a/llama_index/readers/file/docs_reader.py b/llama_index/readers/file/docs_reader.py index 4817086932..4203e37c54 100644 --- a/llama_index/readers/file/docs_reader.py +++ b/llama_index/readers/file/docs_reader.py @@ -71,7 +71,7 @@ class DocxReader(BaseReader): class HWPReader(BaseReader): - """Hwp Parser""" + """Hwp Parser.""" def __init__(self, *args: Any, **kwargs: Any) -> None: super().__init__(*args, **kwargs) diff --git a/llama_index/readers/file/flat_reader.py b/llama_index/readers/file/flat_reader.py index 4ba16679cf..05130cfb38 100644 --- a/llama_index/readers/file/flat_reader.py +++ b/llama_index/readers/file/flat_reader.py @@ -7,7 +7,7 @@ from llama_index.schema import Document class FlatReader(BaseReader): - """Flat reader + """Flat reader. Extract raw text from a file and save the file type in the metadata """ diff --git a/llama_index/readers/github_readers/github_repository_reader.py b/llama_index/readers/github_readers/github_repository_reader.py index ef3265e7b0..bceee82511 100644 --- a/llama_index/readers/github_readers/github_repository_reader.py +++ b/llama_index/readers/github_readers/github_repository_reader.py @@ -348,44 +348,42 @@ class GithubRepositoryReader(BaseReader): + f"as {file_extension} with " + f"{reader.__class__.__name__}", ) - with tempfile.TemporaryDirectory() as tmpdirname: - with tempfile.NamedTemporaryFile( - dir=tmpdirname, - suffix=f".{file_extension}", - mode="w+b", - delete=False, - ) as tmpfile: - print_if_verbose( - self._verbose, - "created a temporary file" - + f"{tmpfile.name} for parsing {file_path}", - ) - tmpfile.write(file_content) - tmpfile.flush() - tmpfile.close() - try: - docs = reader.load_data(pathlib.Path(tmpfile.name)) - parsed_file = "\n\n".join([doc.get_content() for doc in docs]) - except Exception as e: - print_if_verbose(self._verbose, f"error while parsing {file_path}") - logger.error( - "Error while parsing " - + f"{file_path} with " - + f"{reader.__class__.__name__}:\n{e}" - ) - parsed_file = None - finally: - os.remove(tmpfile.name) - if parsed_file is None: - return None - return Document( - text=parsed_file, - id_=tree_sha, - metadata={ - "file_path": file_path, - "file_name": tree_path, - }, + with tempfile.TemporaryDirectory() as tmpdirname, tempfile.NamedTemporaryFile( + dir=tmpdirname, + suffix=f".{file_extension}", + mode="w+b", + delete=False, + ) as tmpfile: + print_if_verbose( + self._verbose, + "created a temporary file" + f"{tmpfile.name} for parsing {file_path}", + ) + tmpfile.write(file_content) + tmpfile.flush() + tmpfile.close() + try: + docs = reader.load_data(pathlib.Path(tmpfile.name)) + parsed_file = "\n\n".join([doc.get_content() for doc in docs]) + except Exception as e: + print_if_verbose(self._verbose, f"error while parsing {file_path}") + logger.error( + "Error while parsing " + + f"{file_path} with " + + f"{reader.__class__.__name__}:\n{e}" ) + parsed_file = None + finally: + os.remove(tmpfile.name) + if parsed_file is None: + return None + return Document( + text=parsed_file, + id_=tree_sha, + metadata={ + "file_path": 
file_path, + "file_name": tree_path, + }, + ) if __name__ == "__main__": diff --git a/llama_index/readers/google_readers/gsheets.py b/llama_index/readers/google_readers/gsheets.py index 03a2a642a5..74a3f0d4cd 100644 --- a/llama_index/readers/google_readers/gsheets.py +++ b/llama_index/readers/google_readers/gsheets.py @@ -104,8 +104,7 @@ class GoogleSheetsReader(BasePydanticReader): .execute() ) sheet_text += ( - "\n".join(map(lambda row: "\t".join(row), response.get("values", []))) - + "\n" + "\n".join("\t".join(row) for row in response.get("values", [])) + "\n" ) return sheet_text diff --git a/llama_index/readers/myscale.py b/llama_index/readers/myscale.py index aa166307d9..fbb31db01e 100644 --- a/llama_index/readers/myscale.py +++ b/llama_index/readers/myscale.py @@ -21,7 +21,7 @@ def format_list_to_string(lst: List) -> str: class MyScaleSettings: - """MyScale Client Configuration + """MyScale Client Configuration. Attribute: table (str) : Table name to operate on. diff --git a/llama_index/readers/psychic.py b/llama_index/readers/psychic.py index 8c0edf44af..28485c75c8 100644 --- a/llama_index/readers/psychic.py +++ b/llama_index/readers/psychic.py @@ -47,7 +47,7 @@ class PsychicReader(BaseReader): def load_data( self, connector_id: Optional[str] = None, account_id: Optional[str] = None ) -> List[Document]: - """Load data from a Psychic connection + """Load data from a Psychic connection. Args: connector_id (str): The connector ID to connect to diff --git a/llama_index/readers/redis/utils.py b/llama_index/readers/redis/utils.py index 007aa6ee90..da7e1e3e12 100644 --- a/llama_index/readers/redis/utils.py +++ b/llama_index/readers/redis/utils.py @@ -70,7 +70,7 @@ def get_redis_query( sort: bool = True, filters: str = "*", ) -> "Query": - """Create a vector query for use with a SearchIndex + """Create a vector query for use with a SearchIndex. Args: return_fields (t.List[str]): A list of fields to return in the query results diff --git a/llama_index/response_synthesizers/accumulate.py b/llama_index/response_synthesizers/accumulate.py index b13b5485d1..a95a63b228 100644 --- a/llama_index/response_synthesizers/accumulate.py +++ b/llama_index/response_synthesizers/accumulate.py @@ -47,7 +47,7 @@ class Accumulate(BaseSynthesizer): separator: str = "\n---------------------\n", **response_kwargs: Any, ) -> RESPONSE_TEXT_TYPE: - """Apply the same prompt to text chunks and return async responses""" + """Apply the same prompt to text chunks and return async responses.""" if self._streaming: raise ValueError("Unable to stream in Accumulate response mode") @@ -68,7 +68,7 @@ class Accumulate(BaseSynthesizer): separator: str = "\n---------------------\n", **response_kwargs: Any, ) -> RESPONSE_TEXT_TYPE: - """Apply the same prompt to text chunks and return responses""" + """Apply the same prompt to text chunks and return responses.""" if self._streaming: raise ValueError("Unable to stream in Accumulate response mode") diff --git a/llama_index/selectors/llm_selectors.py b/llama_index/selectors/llm_selectors.py index f596cda908..1ed035b198 100644 --- a/llama_index/selectors/llm_selectors.py +++ b/llama_index/selectors/llm_selectors.py @@ -41,7 +41,7 @@ def _structured_output_to_selector_result(output: Any) -> SelectorResult: class LLMSingleSelector(BaseSelector): - """LLM single selector + """LLM single selector. LLM-based selector that chooses one out of many options. 
@@ -124,7 +124,7 @@ class LLMSingleSelector(BaseSelector): class LLMMultiSelector(BaseSelector): - """LLM multi selector + """LLM multi selector. LLM-based selector that chooses multiple out of many options. diff --git a/llama_index/storage/docstore/firestore_docstore.py b/llama_index/storage/docstore/firestore_docstore.py index 8f335700f9..26f71dac4b 100644 --- a/llama_index/storage/docstore/firestore_docstore.py +++ b/llama_index/storage/docstore/firestore_docstore.py @@ -34,7 +34,7 @@ class FirestoreDocumentStore(KVDocumentStore): Args: project (str): The project which the client acts on behalf of. database (str): The database name that the client targets. - namespace (str): namespace for the docstore + namespace (str): namespace for the docstore. """ firestore_kvstore = FirestoreKVStore(project=project, database=database) return cls(firestore_kvstore, namespace) diff --git a/llama_index/storage/index_store/firestore_indexstore.py b/llama_index/storage/index_store/firestore_indexstore.py index 1cea883bf2..8f777f849a 100644 --- a/llama_index/storage/index_store/firestore_indexstore.py +++ b/llama_index/storage/index_store/firestore_indexstore.py @@ -32,7 +32,7 @@ class FirestoreIndexStore(KVIndexStore): Args: project (str): The project which the client acts on behalf of. database (str): The database name that the client targets. - namespace (str): namespace for the docstore + namespace (str): namespace for the docstore. """ firestore_kvstore = FirestoreKVStore(project=project, database=database) return cls(firestore_kvstore, namespace) diff --git a/llama_index/tools/tool_spec/load_and_search/base.py b/llama_index/tools/tool_spec/load_and_search/base.py index 0ee04b533d..bde2bc6a8b 100644 --- a/llama_index/tools/tool_spec/load_and_search/base.py +++ b/llama_index/tools/tool_spec/load_and_search/base.py @@ -17,7 +17,7 @@ from llama_index.tools.utils import create_schema_from_function class LoadAndSearchToolSpec(BaseToolSpec): - """Load and Search Tool + """Load and Search Tool. This tool can be used with other tools that load large amounts of information. Compared to OndemandLoaderTool this returns two tools, diff --git a/llama_index/tts/elevenlabs.py b/llama_index/tts/elevenlabs.py index 5b8f2309a1..64aa65868b 100644 --- a/llama_index/tts/elevenlabs.py +++ b/llama_index/tts/elevenlabs.py @@ -14,8 +14,6 @@ class ElevenLabsTTS(BaseTTS): """ def __init__(self, api_key: Optional[str] = None) -> None: - """ """ - super().__init__() self.api_key = api_key diff --git a/llama_index/utils.py b/llama_index/utils.py index 59241bb6c4..c2be210b13 100644 --- a/llama_index/utils.py +++ b/llama_index/utils.py @@ -241,7 +241,7 @@ def get_transformer_tokenizer_fn(model_name: str) -> Callable[[str], List[str]]: """ Args: model_name(str): the model name of the tokenizer. - For instance, fxmarty/tiny-llama-fast-tokenizer + For instance, fxmarty/tiny-llama-fast-tokenizer. """ try: from transformers import AutoTokenizer @@ -255,7 +255,7 @@ def get_transformer_tokenizer_fn(model_name: str) -> Callable[[str], List[str]]: def get_cache_dir() -> str: """Locate a platform-appropriate cache directory for llama_index, - and create it if it doesn't yet exist + and create it if it doesn't yet exist. 
""" # User override if "LLAMA_INDEX_CACHE_DIR" in os.environ: diff --git a/llama_index/vector_stores/cassandra.py b/llama_index/vector_stores/cassandra.py index 2722259289..6afdc7aab5 100644 --- a/llama_index/vector_stores/cassandra.py +++ b/llama_index/vector_stores/cassandra.py @@ -171,7 +171,7 @@ class CassandraVectorStore(VectorStore): @property def client(self) -> Any: - """Return the underlying cassIO vector table object""" + """Return the underlying cassIO vector table object.""" return self.vector_table @staticmethod diff --git a/llama_index/vector_stores/cogsearch.py b/llama_index/vector_stores/cogsearch.py index ac194b76c9..51e99fb327 100644 --- a/llama_index/vector_stores/cogsearch.py +++ b/llama_index/vector_stores/cogsearch.py @@ -27,7 +27,7 @@ class MetadataIndexFieldType(int, enum.Enum): """ Enumeration representing the supported types for metadata fields in an Azure Cognitive Search Index, corresponds with types supported in a flat - metadata dictionary + metadata dictionary. """ STRING = auto() # "Edm.String" @@ -38,7 +38,7 @@ class MetadataIndexFieldType(int, enum.Enum): class IndexManagement(int, enum.Enum): - """Enumeration representing the supported index management operations""" + """Enumeration representing the supported index management operations.""" NO_VALIDATION = auto() VALIDATE_INDEX = auto() @@ -85,7 +85,7 @@ class CognitiveSearchVectorStore(VectorStore): self._create_index(index_name) def _create_metadata_index_fields(self) -> List[Any]: - """Create a list of index fields for storing metadata values""" + """Create a list of index fields for storing metadata values.""" from azure.search.documents.indexes.models import SimpleField index_fields = [] @@ -113,7 +113,7 @@ class CognitiveSearchVectorStore(VectorStore): def _create_index(self, index_name: Optional[str]) -> None: """ Creates a default index based on the supplied index name, key field names and - metadata filtering keys + metadata filtering keys. """ from azure.search.documents.indexes.models import ( HnswParameters, @@ -438,7 +438,7 @@ class CognitiveSearchVectorStore(VectorStore): return ids def _create_index_document(self, node: BaseNode) -> Dict[str, Any]: - """Create Cognitive Search index document from embedding result""" + """Create Cognitive Search index document from embedding result.""" doc: Dict[str, Any] = {} doc["id"] = node.node_id doc["chunk"] = node.get_content(metadata_mode=MetadataMode.NONE) or "" @@ -478,7 +478,7 @@ class CognitiveSearchVectorStore(VectorStore): self._search_client.delete_documents(docs_to_delete) def _create_odata_filter(self, metadata_filters: MetadataFilters) -> str: - """Generate an OData filter string using supplied metadata filters""" + """Generate an OData filter string using supplied metadata filters.""" odata_filter: List[str] = [] for f in metadata_filters.filters: if not isinstance(f, ExactMatchFilter): diff --git a/llama_index/vector_stores/dynamodb.py b/llama_index/vector_stores/dynamodb.py index 8e4c8e6dfb..bd89f72447 100644 --- a/llama_index/vector_stores/dynamodb.py +++ b/llama_index/vector_stores/dynamodb.py @@ -109,7 +109,7 @@ class DynamoDBVectorStore(VectorStore): ) def query(self, query: VectorStoreQuery, **kwargs: Any) -> VectorStoreQueryResult: - """Get nodes for response""" + """Get nodes for response.""" if query.filters is not None: raise ValueError( "Metadata filters not implemented for SimpleVectorStore yet." 
diff --git a/llama_index/vector_stores/elasticsearch.py b/llama_index/vector_stores/elasticsearch.py index dab60f7eb9..cba936152b 100644 --- a/llama_index/vector_stores/elasticsearch.py +++ b/llama_index/vector_stores/elasticsearch.py @@ -199,12 +199,12 @@ class ElasticsearchStore(VectorStore): @property def client(self) -> Any: - """Get async elasticsearch client""" + """Get async elasticsearch client.""" return self._client @staticmethod def get_user_agent() -> str: - """Get user agent for elasticsearch client""" + """Get user agent for elasticsearch client.""" import llama_index return f"llama_index-py-vs/{llama_index.__version__}" diff --git a/llama_index/vector_stores/myscale.py b/llama_index/vector_stores/myscale.py index 05e4f25c31..94022cd956 100644 --- a/llama_index/vector_stores/myscale.py +++ b/llama_index/vector_stores/myscale.py @@ -241,7 +241,7 @@ class MyScaleVectorStore(VectorStore): raise NotImplementedError("Delete not yet implemented for MyScale index.") def drop(self) -> None: - """Drop MyScale Index and table""" + """Drop MyScale Index and table.""" self._client.command( f"DROP TABLE IF EXISTS {self.config.database}.{self.config.table}" ) diff --git a/llama_index/vector_stores/neo4jvector.py b/llama_index/vector_stores/neo4jvector.py index c25a0c5fa3..be3afe357d 100644 --- a/llama_index/vector_stores/neo4jvector.py +++ b/llama_index/vector_stores/neo4jvector.py @@ -10,7 +10,7 @@ from llama_index.vector_stores.utils import metadata_dict_to_node, node_to_metad def check_if_not_null(props: List[str], values: List[Any]) -> None: - """Check if variable is not null and raise error accordingly""" + """Check if variable is not null and raise error accordingly.""" for prop, value in zip(props, values): if not value: raise ValueError(f"Parameter `{prop}` must not be None or empty string") @@ -19,12 +19,12 @@ def check_if_not_null(props: List[str], values: List[Any]) -> None: def sort_by_index_name( lst: List[Dict[str, Any]], index_name: str ) -> List[Dict[str, Any]]: - """Sort first element to match the index_name if exists""" + """Sort first element to match the index_name if exists.""" return sorted(lst, key=lambda x: x.get("index_name") != index_name) def clean_params(params: List[BaseNode]) -> List[Dict[str, Any]]: - """Convert BaseNode object to a dictionary to be imported into Neo4j""" + """Convert BaseNode object to a dictionary to be imported into Neo4j.""" clean_params = [] for record in params: text = record.get_content(metadata_mode=MetadataMode.NONE) diff --git a/llama_index/vector_stores/postgres.py b/llama_index/vector_stores/postgres.py index faa32c99d9..a603af5c13 100644 --- a/llama_index/vector_stores/postgres.py +++ b/llama_index/vector_stores/postgres.py @@ -1,6 +1,5 @@ import logging -from collections import namedtuple -from typing import Any, List, Optional, Type +from typing import Any, List, NamedTuple, Optional, Type from llama_index.bridge.pydantic import PrivateAttr from llama_index.schema import BaseNode, MetadataMode, TextNode @@ -13,9 +12,12 @@ from llama_index.vector_stores.types import ( ) from llama_index.vector_stores.utils import metadata_dict_to_node, node_to_metadata_dict -DBEmbeddingRow = namedtuple( - "DBEmbeddingRow", ["node_id", "text", "metadata", "similarity"] -) + +class DBEmbeddingRow(NamedTuple): + node_id: str # FIXME: verify this type hint + text: str + metadata: dict + similarity: float _logger = logging.getLogger(__name__) @@ -30,7 +32,7 @@ def get_data_model( embed_dim: int = 1536, ) -> Any: """ - This part create a dynamic 
sqlalchemy model with a new table + This function creates a dynamic sqlalchemy model with a new table. """ from pgvector.sqlalchemy import Vector from sqlalchemy import Column, Computed @@ -229,18 +231,16 @@ class PGVectorStore(BasePydanticVectorStore): self._async_session = async_sessionmaker(self._async_engine) def _create_tables_if_not_exists(self) -> None: - with self._session() as session: - with session.begin(): - self._base.metadata.create_all(session.connection()) + with self._session() as session, session.begin(): + self._base.metadata.create_all(session.connection()) def _create_extension(self) -> None: import sqlalchemy - with self._session() as session: - with session.begin(): - statement = sqlalchemy.text("CREATE EXTENSION IF NOT EXISTS vector") - session.execute(statement) - session.commit() + with self._session() as session, session.begin(): + statement = sqlalchemy.text("CREATE EXTENSION IF NOT EXISTS vector") + session.execute(statement) + session.commit() def _initialize(self) -> None: if not self._is_initialized: @@ -264,25 +264,23 @@ class PGVectorStore(BasePydanticVectorStore): def add(self, nodes: List[BaseNode]) -> List[str]: self._initialize() ids = [] - with self._session() as session: - with session.begin(): - for node in nodes: - ids.append(node.node_id) - item = self._node_to_table_row(node) - session.add(item) - session.commit() + with self._session() as session, session.begin(): + for node in nodes: + ids.append(node.node_id) + item = self._node_to_table_row(node) + session.add(item) + session.commit() return ids async def async_add(self, nodes: List[BaseNode]) -> List[str]: self._initialize() ids = [] - async with self._async_session() as session: - async with session.begin(): - for node in nodes: - ids.append(node.node_id) - item = self._node_to_table_row(node) - session.add(item) - await session.commit() + async with self._async_session() as session, session.begin(): + for node in nodes: + ids.append(node.node_id) + item = self._node_to_table_row(node) + session.add(item) + await session.commit() return ids def _apply_filters_and_limit( self, @@ -325,20 +323,19 @@ class PGVectorStore(BasePydanticVectorStore): metadata_filters: Optional[MetadataFilters] = None, ) -> List[DBEmbeddingRow]: stmt = self._build_query(embedding, limit, metadata_filters) - with self._session() as session: - with session.begin(): - res = session.execute( - stmt, + with self._session() as session, session.begin(): + res = session.execute( + stmt, + ) + return [ + DBEmbeddingRow( + node_id=item.node_id, + text=item.text, + metadata=item.metadata_, + similarity=(1 - distance) if distance is not None else 0, ) - return [ - DBEmbeddingRow( - node_id=item.node_id, - text=item.text, - metadata=item.metadata_, - similarity=(1 - distance) if distance is not None else 0, - ) - for item, distance in res.all() - ] + for item, distance in res.all() + ] async def _aquery_with_score( self, @@ -347,18 +344,17 @@ class PGVectorStore(BasePydanticVectorStore): metadata_filters: Optional[MetadataFilters] = None, ) -> List[DBEmbeddingRow]: stmt = self._build_query(embedding, limit, metadata_filters) - async with self._async_session() as async_session: - async with async_session.begin(): - res = await async_session.execute(stmt) - return [ - DBEmbeddingRow( - node_id=item.node_id, - text=item.text, - metadata=item.metadata_, - similarity=(1 - distance) if distance is not None else 0, - ) - for item, distance in res.all() - ] + async with self._async_session() as async_session, async_session.begin(): + res = await 
async_session.execute(stmt) + return [ + DBEmbeddingRow( + node_id=item.node_id, + text=item.text, + metadata=item.metadata_, + similarity=(1 - distance) if distance is not None else 0, + ) + for item, distance in res.all() + ] def _build_sparse_query( self, @@ -392,18 +388,17 @@ class PGVectorStore(BasePydanticVectorStore): metadata_filters: Optional[MetadataFilters] = None, ) -> List[DBEmbeddingRow]: stmt = self._build_sparse_query(query_str, limit, metadata_filters) - async with self._async_session() as async_session: - async with async_session.begin(): - res = await async_session.execute(stmt) - return [ - DBEmbeddingRow( - node_id=item.node_id, - text=item.text, - metadata=item.metadata_, - similarity=rank, - ) - for item, rank in res.all() - ] + async with self._async_session() as async_session, async_session.begin(): + res = await async_session.execute(stmt) + return [ + DBEmbeddingRow( + node_id=item.node_id, + text=item.text, + metadata=item.metadata_, + similarity=rank, + ) + for item, rank in res.all() + ] def _sparse_query_with_rank( self, @@ -412,18 +407,17 @@ class PGVectorStore(BasePydanticVectorStore): metadata_filters: Optional[MetadataFilters] = None, ) -> List[DBEmbeddingRow]: stmt = self._build_sparse_query(query_str, limit, metadata_filters) - with self._session() as session: - with session.begin(): - res = session.execute(stmt) - return [ - DBEmbeddingRow( - node_id=item.node_id, - text=item.text, - metadata=item.metadata_, - similarity=rank, - ) - for item, rank in res.all() - ] + with self._session() as session, session.begin(): + res = session.execute(stmt) + return [ + DBEmbeddingRow( + node_id=item.node_id, + text=item.text, + metadata=item.metadata_, + similarity=rank, + ) + for item, rank in res.all() + ] async def _async_hybrid_query( self, query: VectorStoreQuery @@ -540,15 +534,14 @@ class PGVectorStore(BasePydanticVectorStore): import sqlalchemy self._initialize() - with self._session() as session: - with session.begin(): - stmt = sqlalchemy.text( - f"DELETE FROM public.data_{self.table_name} where " - f"(metadata_->>'doc_id')::text = '{ref_doc_id}' " - ) + with self._session() as session, session.begin(): + stmt = sqlalchemy.text( + f"DELETE FROM public.data_{self.table_name} where " + f"(metadata_->>'doc_id')::text = '{ref_doc_id}' " + ) - session.execute(stmt) - session.commit() + session.execute(stmt) + session.commit() def _dedup_results(results: List[DBEmbeddingRow]) -> List[DBEmbeddingRow]: diff --git a/llama_index/vector_stores/redis.py b/llama_index/vector_stores/redis.py index 0833546140..ff1c9774de 100644 --- a/llama_index/vector_stores/redis.py +++ b/llama_index/vector_stores/redis.py @@ -124,7 +124,7 @@ class RedisVectorStore(VectorStore): @property def client(self) -> "RedisType": - """Return the redis client instance""" + """Return the redis client instance.""" return self._redis_client def add(self, nodes: List[BaseNode]) -> List[str]: diff --git a/llama_index/vector_stores/rocksetdb.py b/llama_index/vector_stores/rocksetdb.py index 5b715ba773..0c3c1e596e 100644 --- a/llama_index/vector_stores/rocksetdb.py +++ b/llama_index/vector_stores/rocksetdb.py @@ -24,7 +24,7 @@ T = TypeVar("T", bound="RocksetVectorStore") def _get_rockset() -> ModuleType: """Gets the rockset module and raises an ImportError if - the rockset package hasn't been installed + the rockset package hasn't been installed. 
Returns: rockset module (ModuleType) @@ -124,7 +124,7 @@ class RocksetVectorStore(VectorStore): return self.rs def add(self, nodes: List[BaseNode]) -> List[str]: - """Stores vectors in the collection + """Stores vectors in the collection. Args: nodes (List[BaseNode]): List of nodes with embeddings @@ -151,7 +151,7 @@ class RocksetVectorStore(VectorStore): ] def delete(self, ref_doc_id: str, **delete_kwargs: Any) -> None: - """Deletes nodes stored in the collection by their ref_doc_id + """Deletes nodes stored in the collection by their ref_doc_id. Args: ref_doc_id (str): The ref_doc_id of the document @@ -177,7 +177,7 @@ class RocksetVectorStore(VectorStore): ) def query(self, query: VectorStoreQuery, **kwargs: Any) -> VectorStoreQueryResult: - """Gets nodes relevant to a query + """Gets nodes relevant to a query. Args: query (llama_index.vector_stores.types.VectorStoreQuery): The query diff --git a/llama_index/vector_stores/tair.py b/llama_index/vector_stores/tair.py index e611be1cac..7eb1501253 100644 --- a/llama_index/vector_stores/tair.py +++ b/llama_index/vector_stores/tair.py @@ -122,7 +122,7 @@ class TairVectorStore(VectorStore): @property def client(self) -> "Tair": - """Return the Tair client instance""" + """Return the Tair client instance.""" return self._tair_client def add(self, nodes: List[BaseNode]) -> List[str]: diff --git a/llama_index/vector_stores/timescalevector.py b/llama_index/vector_stores/timescalevector.py index bb67bda141..ec90fdfc5f 100644 --- a/llama_index/vector_stores/timescalevector.py +++ b/llama_index/vector_stores/timescalevector.py @@ -15,7 +15,7 @@ from llama_index.vector_stores.utils import metadata_dict_to_node, node_to_metad class IndexType(enum.Enum): - """Enumerator for the supported Index types""" + """Enumerator for the supported Index types.""" TIMESCALE_VECTOR = 1 PGVECTOR_IVFFLAT = 2 diff --git a/pyproject.toml b/pyproject.toml index a02d1e4f79..95eb022344 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -18,7 +18,9 @@ exclude = [ "notebooks", ] ignore = [ - "COM812", + "COM812", # Too aggressive + "D212", # Using D213 + "D417", # Too aggressive "F541", # Messes with prompts.py "TCH002", "UP006", # Messes with pydantic @@ -31,22 +33,7 @@ select = [ "B011", "B013", "B014", - "C400", - "C401", - "C402", - "C403", - "C404", - "C405", - "C406", - "C408", - "C409", - "C410", - "C411", - "C413", - "C414", - "C416", - "C418", - "C419", + "C4", "COM812", "COM819", "D201", @@ -60,23 +47,9 @@ select = [ "D213", "D214", "D215", - "D400", - "D403", - "D405", - "D406", - "D407", - "D408", - "D409", - "D410", - "D411", - "D412", - "D413", - "D416", - "E703", - "E711", - "E712", - "E713", - "E714", + "D3", + "D4", + "E7", "EXE004", "F504", "F541", @@ -102,18 +75,7 @@ select = [ "PT006", "PT02", "PTH201", - "PYI009", - "PYI010", - "PYI011", - "PYI012", - "PYI014", - "PYI015", - "PYI020", - "PYI026", - "PYI029", - "PYI032", - "PYI053", - "PYI054", + "PYI", "Q", "RET501", "RET502", @@ -128,15 +90,7 @@ select = [ "SIM103", "SIM109", "SIM118", - "SIM201", - "SIM202", - "SIM208", - "SIM211", - "SIM212", - "SIM220", - "SIM221", - "SIM222", - "SIM223", + "SIM2", "SIM300", "SIM9", "TCH005", diff --git a/tests/conftest.py b/tests/conftest.py index 54f5d47b95..cc6dac336c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -122,7 +122,7 @@ class CachedOpenAIApiKeys: openai.api_key = "sk-" + "a" * 48 # No matter what, set the environment variable back to what it was - def __exit__(self, *exc: Any) -> None: + def __exit__(self, *exc: object) -> None: 
os.environ["OPENAI_API_KEY"] = str(self.api_env_variable_was) os.environ["OPENAI_API_TYPE"] = str(self.api_env_type_was) openai.api_key = self.openai_api_key_was diff --git a/tests/indices/postprocessor/test_base.py b/tests/indices/postprocessor/test_base.py index 2d64821c00..955002fb51 100644 --- a/tests/indices/postprocessor/test_base.py +++ b/tests/indices/postprocessor/test_base.py @@ -25,7 +25,7 @@ from llama_index.schema import ( ) from llama_index.storage.docstore.simple_docstore import SimpleDocumentStore -spacy_installed = True if find_spec("spacy") else False +spacy_installed = bool(find_spec("spacy")) def test_forward_back_processor(tmp_path: Path) -> None: diff --git a/tests/indices/query/test_embedding_utils.py b/tests/indices/query/test_embedding_utils.py index 2dbb83e8de..646650d6f6 100644 --- a/tests/indices/query/test_embedding_utils.py +++ b/tests/indices/query/test_embedding_utils.py @@ -1,4 +1,4 @@ -""" Test embedding utility functions""" +""" Test embedding utility functions.""" import numpy as np from llama_index.indices.query.embedding_utils import ( @@ -8,7 +8,7 @@ from llama_index.indices.query.embedding_utils import ( def test_get_top_k_mmr_embeddings() -> None: - """Test Maximum Marginal Relevance""" + """Test Maximum Marginal Relevance.""" # Results score should follow from the mmr algorithm query_embedding = [5.0, 0.0, 0.0] embeddings = [[4.0, 3.0, 0.0], [3.0, 4.0, 0.0], [-4.0, 3.0, 0.0]] diff --git a/tests/indices/tree/test_index.py b/tests/indices/tree/test_index.py index 9cb651ba2d..629ca8d8d1 100644 --- a/tests/indices/tree/test_index.py +++ b/tests/indices/tree/test_index.py @@ -197,7 +197,7 @@ def test_insert( def test_twice_insert_empty( mock_service_context: ServiceContext, ) -> None: - """# test twice insert from empty (with_id)""" + """# test twice insert from empty (with_id).""" tree = TreeIndex.from_documents([], service_context=mock_service_context) # test first insert diff --git a/tests/llm_predictor/vellum/test_predictor.py b/tests/llm_predictor/vellum/test_predictor.py index cdf5f16ecd..938111fa1b 100644 --- a/tests/llm_predictor/vellum/test_predictor.py +++ b/tests/llm_predictor/vellum/test_predictor.py @@ -11,7 +11,7 @@ def test_predict__basic( vellum_predictor_factory: Callable[..., VellumPredictor], dummy_prompt: BasePromptTemplate, ) -> None: - """When the Vellum API returns expected values, so should our predictor""" + """When the Vellum API returns expected values, so should our predictor.""" vellum_client = mock_vellum_client_factory( compiled_prompt_text="What's you're favorite greeting?", completion_text="Hello, world!", @@ -29,7 +29,7 @@ def test_stream__basic( vellum_predictor_factory: Callable[..., VellumPredictor], dummy_prompt: BasePromptTemplate, ) -> None: - """When the Vellum API streams expected values, so should our predictor""" + """When the Vellum API streams expected values, so should our predictor.""" import vellum vellum_client = mock_vellum_client_factory( diff --git a/tests/llm_predictor/vellum/test_prompt_registry.py b/tests/llm_predictor/vellum/test_prompt_registry.py index 0e845bb4c2..0e3146425f 100644 --- a/tests/llm_predictor/vellum/test_prompt_registry.py +++ b/tests/llm_predictor/vellum/test_prompt_registry.py @@ -13,7 +13,7 @@ def test_from_prompt__new( mock_vellum_client_factory: Callable[..., mock.MagicMock], vellum_prompt_registry_factory: Callable[..., VellumPromptRegistry], ) -> None: - """We should register a new prompt if no deployment exists""" + """We should register a new prompt if no deployment 
exists.""" from vellum.core import ApiError dummy_prompt = PromptTemplate(template="What's your favorite {thing}?") @@ -32,7 +32,7 @@ def test_from_prompt__existing( mock_vellum_client_factory: Callable[..., mock.MagicMock], vellum_prompt_registry_factory: Callable[..., VellumPromptRegistry], ) -> None: - """We shouldn't register a new prompt if a deployment id or name is provided""" + """We shouldn't register a new prompt if a deployment id or name is provided.""" dummy_prompt = PromptTemplate( template="What's your favorite {thing}?", metadata={"vellum_deployment_id": "abc"}, @@ -54,7 +54,7 @@ def test_get_compiled_prompt__basic( mock_vellum_client_factory: Callable[..., mock.MagicMock], vellum_prompt_registry_factory: Callable[..., VellumPromptRegistry], ) -> None: - """Verify that we can get a compiled prompt from the registry""" + """Verify that we can get a compiled prompt from the registry.""" registered_prompt = VellumRegisteredPrompt( deployment_id="abc", deployment_name="my-deployment", diff --git a/tests/llms/test_localai.py b/tests/llms/test_localai.py index a468ad65a4..5dafe59b49 100644 --- a/tests/llms/test_localai.py +++ b/tests/llms/test_localai.py @@ -77,9 +77,8 @@ def test_forgetting_kwarg() -> None: with patch( "llama_index.llms.openai.completion_with_retry", return_value={} - ) as mock_completion: - with pytest.raises(NotImplementedError, match="/chat/completions"): - llm.complete("A long time ago in a galaxy far, far away") + ) as mock_completion, pytest.raises(NotImplementedError, match="/chat/completions"): + llm.complete("A long time ago in a galaxy far, far away") mock_completion.assert_not_called() diff --git a/tests/llms/test_openai_utils.py b/tests/llms/test_openai_utils.py index 0a9d1f8844..b61491efb7 100644 --- a/tests/llms/test_openai_utils.py +++ b/tests/llms/test_openai_utils.py @@ -57,7 +57,7 @@ def openi_message_dicts_with_function_calling() -> List[dict]: def azure_openi_message_dicts_with_function_calling() -> List[dict]: """ Taken from: - - https://learn.microsoft.com/en-us/azure/ai-services/openai/how-to/function-calling + - https://learn.microsoft.com/en-us/azure/ai-services/openai/how-to/function-calling. 
""" return [ { diff --git a/tests/memory/test_chat_memory_buffer.py b/tests/memory/test_chat_memory_buffer.py index c263984268..7b116880d5 100644 --- a/tests/memory/test_chat_memory_buffer.py +++ b/tests/memory/test_chat_memory_buffer.py @@ -69,7 +69,7 @@ def test_dict_save_load() -> None: def test_pickle() -> None: - """Unpickleable tiktoken tokenizer should be circumvented when pickling""" + """Unpickleable tiktoken tokenizer should be circumvented when pickling.""" memory = ChatMemoryBuffer.from_defaults() bytes_ = pickle.dumps(memory) assert isinstance(pickle.loads(bytes_), ChatMemoryBuffer) diff --git a/tests/readers/test_file.py b/tests/readers/test_file.py index 27b12b9afe..4d1bee1d01 100644 --- a/tests/readers/test_file.py +++ b/tests/readers/test_file.py @@ -343,6 +343,5 @@ def test_error_if_not_dir_or_file() -> None: SimpleDirectoryReader("not_a_dir") with pytest.raises(ValueError, match="File"): SimpleDirectoryReader(input_files=["not_a_file"]) - with TemporaryDirectory() as tmp_dir: - with pytest.raises(ValueError, match="No files"): - SimpleDirectoryReader(tmp_dir) + with TemporaryDirectory() as tmp_dir, pytest.raises(ValueError, match="No files"): + SimpleDirectoryReader(tmp_dir) diff --git a/tests/text_splitter/test_code_splitter.py b/tests/text_splitter/test_code_splitter.py index 075469c56a..e0580d9a39 100644 --- a/tests/text_splitter/test_code_splitter.py +++ b/tests/text_splitter/test_code_splitter.py @@ -5,7 +5,7 @@ from llama_index.text_splitter import CodeSplitter def test_python_code_splitter() -> None: - """Test case for code splitting using python""" + """Test case for code splitting using python.""" if "CI" in os.environ: return @@ -26,7 +26,7 @@ def baz(): def test_typescript_code_splitter() -> None: - """Test case for code splitting using typescript""" + """Test case for code splitting using typescript.""" if "CI" in os.environ: return @@ -49,7 +49,7 @@ function baz() { def test_html_code_splitter() -> None: - """Test case for code splitting using typescript""" + """Test case for code splitting using typescript.""" if "CI" in os.environ: return @@ -82,7 +82,7 @@ def test_html_code_splitter() -> None: def test_tsx_code_splitter() -> None: - """Test case for code splitting using typescript""" + """Test case for code splitting using typescript.""" if "CI" in os.environ: return @@ -120,7 +120,7 @@ export default ExampleComponent;""" def test_cpp_code_splitter() -> None: - """Test case for code splitting using typescript""" + """Test case for code splitting using typescript.""" if "CI" in os.environ: return diff --git a/tests/text_splitter/test_sentence_splitter.py b/tests/text_splitter/test_sentence_splitter.py index 82b1df214d..928dcb30ab 100644 --- a/tests/text_splitter/test_sentence_splitter.py +++ b/tests/text_splitter/test_sentence_splitter.py @@ -58,7 +58,7 @@ def test_split_with_metadata(english_text: str) -> None: def test_edge_case() -> None: - """Test case from: https://github.com/jerryjliu/llama_index/issues/7287""" + """Test case from: https://github.com/jerryjliu/llama_index/issues/7287.""" text = "\n\nMarch 2020\n\nL&D Metric (Org) - 2.92%\n\n| Training Name | Category | Duration (hrs) | Invitees | Attendance | Target Training Hours | Actual Training Hours | Adoption % |\n| ---------------------------------------------------------------------------------------------------------------------- | --------------- | -------------- | -------- | ---------- | --------------------- | --------------------- | ---------- |\n| Overview of Data Analytics | Technical 
diff --git a/tests/llms/test_openai_utils.py b/tests/llms/test_openai_utils.py
index 0a9d1f8844..b61491efb7 100644
--- a/tests/llms/test_openai_utils.py
+++ b/tests/llms/test_openai_utils.py
@@ -57,7 +57,7 @@ def openi_message_dicts_with_function_calling() -> List[dict]:
 def azure_openi_message_dicts_with_function_calling() -> List[dict]:
     """
     Taken from:
-    - https://learn.microsoft.com/en-us/azure/ai-services/openai/how-to/function-calling
+    - https://learn.microsoft.com/en-us/azure/ai-services/openai/how-to/function-calling.
     """
     return [
         {
diff --git a/tests/memory/test_chat_memory_buffer.py b/tests/memory/test_chat_memory_buffer.py
index c263984268..7b116880d5 100644
--- a/tests/memory/test_chat_memory_buffer.py
+++ b/tests/memory/test_chat_memory_buffer.py
@@ -69,7 +69,7 @@ def test_dict_save_load() -> None:
 
 
 def test_pickle() -> None:
-    """Unpickleable tiktoken tokenizer should be circumvented when pickling"""
+    """Unpickleable tiktoken tokenizer should be circumvented when pickling."""
     memory = ChatMemoryBuffer.from_defaults()
     bytes_ = pickle.dumps(memory)
     assert isinstance(pickle.loads(bytes_), ChatMemoryBuffer)
diff --git a/tests/readers/test_file.py b/tests/readers/test_file.py
index 27b12b9afe..4d1bee1d01 100644
--- a/tests/readers/test_file.py
+++ b/tests/readers/test_file.py
@@ -343,6 +343,5 @@ def test_error_if_not_dir_or_file() -> None:
         SimpleDirectoryReader("not_a_dir")
     with pytest.raises(ValueError, match="File"):
         SimpleDirectoryReader(input_files=["not_a_file"])
-    with TemporaryDirectory() as tmp_dir:
-        with pytest.raises(ValueError, match="No files"):
-            SimpleDirectoryReader(tmp_dir)
+    with TemporaryDirectory() as tmp_dir, pytest.raises(ValueError, match="No files"):
+        SimpleDirectoryReader(tmp_dir)
diff --git a/tests/text_splitter/test_code_splitter.py b/tests/text_splitter/test_code_splitter.py
index 075469c56a..e0580d9a39 100644
--- a/tests/text_splitter/test_code_splitter.py
+++ b/tests/text_splitter/test_code_splitter.py
@@ -5,7 +5,7 @@ from llama_index.text_splitter import CodeSplitter
 
 
 def test_python_code_splitter() -> None:
-    """Test case for code splitting using python"""
+    """Test case for code splitting using python."""
     if "CI" in os.environ:
         return
 
@@ -26,7 +26,7 @@ def baz():
 
 
 def test_typescript_code_splitter() -> None:
-    """Test case for code splitting using typescript"""
+    """Test case for code splitting using typescript."""
     if "CI" in os.environ:
         return
 
@@ -49,7 +49,7 @@ function baz() {
 
 
 def test_html_code_splitter() -> None:
-    """Test case for code splitting using typescript"""
+    """Test case for code splitting using html."""
     if "CI" in os.environ:
         return
 
@@ -82,7 +82,7 @@ def test_html_code_splitter() -> None:
 
 
 def test_tsx_code_splitter() -> None:
-    """Test case for code splitting using typescript"""
+    """Test case for code splitting using tsx."""
     if "CI" in os.environ:
         return
 
@@ -120,7 +120,7 @@ export default ExampleComponent;"""
 
 
 def test_cpp_code_splitter() -> None:
-    """Test case for code splitting using typescript"""
+    """Test case for code splitting using cpp."""
     if "CI" in os.environ:
         return
 
diff --git a/tests/text_splitter/test_sentence_splitter.py b/tests/text_splitter/test_sentence_splitter.py
index 82b1df214d..928dcb30ab 100644
--- a/tests/text_splitter/test_sentence_splitter.py
+++ b/tests/text_splitter/test_sentence_splitter.py
@@ -58,7 +58,7 @@ def test_split_with_metadata(english_text: str) -> None:
 
 
 def test_edge_case() -> None:
-    """Test case from: https://github.com/jerryjliu/llama_index/issues/7287"""
+    """Test case from: https://github.com/jerryjliu/llama_index/issues/7287."""
     text = "\n\nMarch 2020\n\nL&D Metric (Org) - 2.92%\n\n| Training Name | Category | Duration (hrs) | Invitees | Attendance | Target Training Hours | Actual Training Hours | Adoption % |\n| ---------------------------------------------------------------------------------------------------------------------- | --------------- | -------------- | -------- | ---------- | --------------------- | --------------------- | ---------- |\n| Overview of Data Analytics | Technical | 1 | 23 | 10 | 23 | 10 | 43.5 |\n| Sales & Learning Best Practices - Introduction to OTT Platforms | Technical | 0.5 | 16 | 12 | 8 | 6 | 75 |\n| Leading Through OKRs | Lifeskill | 1 | 1 | 1 | 1 | 1 | 100 |\n| COVID: Lockdown Awareness Session | Lifeskill | 2 | 1 | 1 | 2 | 2 | 100 |\n| Navgati Interview | Lifeskill | 2 | 6 | 6 | 12 | 12 | 100 |\n| leadership Summit | Leadership | 18 | 42 | 42 | 756 | 756 | 100 |\n| AWS - AI/ML - Online Conference | Project Related | 15 | 2 | 2 | 30 | 30 | 100 |\n"
     splitter = SentenceSplitter(tokenizer=tiktoken.get_encoding("gpt2").encode)
     chunks = splitter.split_text(text)
@@ -109,7 +109,7 @@ def test_split_texts_multiple() -> None:
 
 
 def test_split_texts_with_metadata(english_text: str) -> None:
-    """Test case for a list of texts with metadata"""
+    """Test case for a list of texts with metadata."""
     chunk_size = 100
     metadata_str = "word " * 50
     tokenizer = tiktoken.get_encoding("cl100k_base")
diff --git a/tests/tools/test_utils.py b/tests/tools/test_utils.py
index 6f5cc1ae36..f727d8400b 100644
--- a/tests/tools/test_utils.py
+++ b/tests/tools/test_utils.py
@@ -23,7 +23,7 @@ def test_create_schema_from_function() -> None:
     assert schema["properties"]["a"]["type"] == "boolean"
 
     def test_fn2(x: int = 1) -> None:
-        """Optional input"""
+        """Optional input."""
 
     SchemaCls = create_schema_from_function("test_schema", test_fn2)
     schema = SchemaCls.schema()
--
GitLab