diff --git a/.gitmodules b/.gitmodules new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/CHANGELOG.md b/CHANGELOG.md index eee2a842b65066cd1a22f4242b9c1191a2cb7edf..e0238f97ef1c50a448377f58409c5c22ab5ce35a 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -1,5 +1,18 @@ # ChangeLog +## [0.9.0] - 2023-11-15 + +### New Features / Breaking Changes / Deprecations + +- New `IngestionPipline` concept for ingesting and transforming data +- Data ingestion and transforms are now automatically cached +- Updated interface for node parsing/text splitting/metadata extraction modules +- Changes to the default tokenizer, as well as customizing the tokenizer +- Packaging/Installation changes with PyPi (reduced bloat, new install options) +- More predictable and consistent import paths +- Plus, in beta: MultiModal RAG Modules for handling text and images! +- Find more details at: https://pretty-sodium-5e0.notion.site/Alpha-Preview-LlamaIndex-v0-9-8f815bfdd4c346c1a696e013fccefe5e + ## [0.8.69.post1] - 2023-11-13 ### Bug Fixes / Nits diff --git a/README.md b/README.md index 89df1f3030314a58958facf0c3a9bbbaf0c0eb86..22f679e564be5fdc106bd16310d04d0cc56974fb 100644 --- a/README.md +++ b/README.md @@ -95,6 +95,14 @@ llm = Replicate( additional_kwargs={"top_p": 1, "max_new_tokens": 300}, ) +# set tokenizer to match LLM +from llama_index import set_global_tokenizer +from transformers import AutoTokenizer + +set_global_tokenizer( + AutoTokenizer.from_pretrained("NousResearch/Llama-2-7b-chat-hf").encode +) + from llama_index.embeddings import HuggingFaceEmbedding from llama_index import ServiceContext diff --git a/benchmarks/struct_indices/spider/generate_sql.py b/benchmarks/struct_indices/spider/generate_sql.py index 399253c358bbf8dff4c1add91541b48195bc31cc..0c8b9d0d01df1929f10a0bda7ce5a9107711c4bb 100644 --- a/benchmarks/struct_indices/spider/generate_sql.py +++ b/benchmarks/struct_indices/spider/generate_sql.py @@ -9,7 +9,8 @@ from typing import Any, cast from sqlalchemy import create_engine, text from tqdm import tqdm -from llama_index import LLMPredictor, SQLDatabase, SQLStructStoreIndex +from llama_index import LLMPredictor, SQLDatabase +from llama_index.indices import SQLStructStoreIndex from llama_index.llms.openai import OpenAI logging.getLogger("root").setLevel(logging.WARNING) diff --git a/benchmarks/struct_indices/spider/spider_utils.py b/benchmarks/struct_indices/spider/spider_utils.py index 7c90611fc14347062d2aa2cb3c0c7381af9ec9bc..4ab08aeedec00e57587ebc447cfdb41166df07b7 100644 --- a/benchmarks/struct_indices/spider/spider_utils.py +++ b/benchmarks/struct_indices/spider/spider_utils.py @@ -6,7 +6,8 @@ from typing import Dict, Tuple from sqlalchemy import create_engine, text -from llama_index import LLMPredictor, SQLDatabase, SQLStructStoreIndex +from llama_index import LLMPredictor, SQLDatabase +from llama_index.indices import SQLStructStoreIndex from llama_index.llms.openai import OpenAI diff --git a/docs/api_reference/service_context/node_parser.rst b/docs/api_reference/service_context/node_parser.rst index 598af8cabcf49ce006d9c4280f5a6c56c626ebcd..f005686dba85cec7625c9393361c7b7635b4389a 100644 --- a/docs/api_reference/service_context/node_parser.rst +++ b/docs/api_reference/service_context/node_parser.rst @@ -5,8 +5,6 @@ Node Parser :members: :inherited-members: -.. autopydantic_model:: llama_index.node_parser.extractors.metadata_extractors.MetadataExtractor - .. autopydantic_model:: llama_index.node_parser.extractors.metadata_extractors.SummaryExtractor .. autopydantic_model:: llama_index.node_parser.extractors.metadata_extractors.QuestionsAnsweredExtractor @@ -17,4 +15,4 @@ Node Parser .. autopydantic_model:: llama_index.node_parser.extractors.metadata_extractors.EntityExtractor -.. autopydantic_model:: llama_index.node_parser.extractors.metadata_extractors.MetadataFeatureExtractor +.. autopydantic_model:: llama_index.node_parser.extractors.metadata_extractors.BaseExtractor diff --git a/docs/examples/agent/multi_document_agents-v1.ipynb b/docs/examples/agent/multi_document_agents-v1.ipynb index 21fb189433734647c4634ff2007353dcd15098a4..40c470abe518c29290ebc950bf52f7f72d7bed7a 100644 --- a/docs/examples/agent/multi_document_agents-v1.ipynb +++ b/docs/examples/agent/multi_document_agents-v1.ipynb @@ -264,7 +264,7 @@ "from llama_index.agent import OpenAIAgent\n", "from llama_index import load_index_from_storage, StorageContext\n", "from llama_index.tools import QueryEngineTool, ToolMetadata\n", - "from llama_index.node_parser import SimpleNodeParser\n", + "from llama_index.node_parser import SentenceSplitter\n", "import os\n", "from tqdm.notebook import tqdm\n", "import pickle\n", @@ -341,7 +341,7 @@ "\n", "\n", "async def build_agents(docs):\n", - " node_parser = SimpleNodeParser.from_defaults()\n", + " node_parser = SentenceSplitter()\n", "\n", " # Build agents dictionary\n", " agents_dict = {}\n", @@ -446,7 +446,7 @@ " ObjectRetriever,\n", ")\n", "from llama_index.retrievers import BaseRetriever\n", - "from llama_index.indices.postprocessor import CohereRerank\n", + "from llama_index.postprocessor import CohereRerank\n", "from llama_index.tools import QueryPlanTool\n", "from llama_index.query_engine import SubQuestionQueryEngine\n", "from llama_index.llms import OpenAI\n", diff --git a/docs/examples/agent/multi_document_agents.ipynb b/docs/examples/agent/multi_document_agents.ipynb index 66e7f5b10bace3f3d5576ccc6ac9f8b49bc1bf47..be544ddb23ac9dcf783606a372656462b93fedb2 100644 --- a/docs/examples/agent/multi_document_agents.ipynb +++ b/docs/examples/agent/multi_document_agents.ipynb @@ -213,10 +213,10 @@ "source": [ "from llama_index.agent import OpenAIAgent\n", "from llama_index import load_index_from_storage, StorageContext\n", - "from llama_index.node_parser import SimpleNodeParser\n", + "from llama_index.node_parser import SentenceSplitter\n", "import os\n", "\n", - "node_parser = SimpleNodeParser.from_defaults()\n", + "node_parser = SentenceSplitter()\n", "\n", "# Build agents dictionary\n", "agents = {}\n", diff --git a/docs/examples/agent/openai_agent_query_cookbook.ipynb b/docs/examples/agent/openai_agent_query_cookbook.ipynb index 6465110ca604d5efe34d3ac5844834098cbc8946..8b50df6f54c6fb7eccdf3cb8f6dcc8b28e4c8c6d 100644 --- a/docs/examples/agent/openai_agent_query_cookbook.ipynb +++ b/docs/examples/agent/openai_agent_query_cookbook.ipynb @@ -677,19 +677,17 @@ "metadata": {}, "outputs": [], "source": [ - "from llama_index.node_parser import SimpleNodeParser\n", "from llama_index import ServiceContext\n", "from llama_index.storage import StorageContext\n", "from llama_index.vector_stores import PineconeVectorStore\n", - "from llama_index.text_splitter import TokenTextSplitter\n", + "from llama_index.node_parser import TokenTextSplitter\n", "from llama_index.llms import OpenAI\n", "\n", "# define node parser and LLM\n", "chunk_size = 1024\n", "llm = OpenAI(temperature=0, model=\"gpt-4\")\n", "service_context = ServiceContext.from_defaults(chunk_size=chunk_size, llm=llm)\n", - "text_splitter = TokenTextSplitter(chunk_size=chunk_size)\n", - "node_parser = SimpleNodeParser.from_defaults(text_splitter=text_splitter)\n", + "node_parser = TokenTextSplitter(chunk_size=chunk_size)\n", "\n", "# define pinecone vector index\n", "vector_store = PineconeVectorStore(\n", diff --git a/docs/examples/callbacks/OpenInferenceCallback.ipynb b/docs/examples/callbacks/OpenInferenceCallback.ipynb index 7403cf1e2a2e13b363ef96f7eda12e7453289fee..f6ff52b07a0429b50d250b8a12132381c1075407 100644 --- a/docs/examples/callbacks/OpenInferenceCallback.ipynb +++ b/docs/examples/callbacks/OpenInferenceCallback.ipynb @@ -514,7 +514,7 @@ } ], "source": [ - "parser = SimpleNodeParser.from_defaults()\n", + "parser = SentenceSplitter()\n", "nodes = parser.get_nodes_from_documents(documents)\n", "print(nodes[0].text)" ] diff --git a/docs/examples/docstore/DocstoreDemo.ipynb b/docs/examples/docstore/DocstoreDemo.ipynb index bfd7d15a116da630ef1d3d347c1859b53dae4df3..788c2a397048691cd1309d6f6969fa47d9ede2c1 100644 --- a/docs/examples/docstore/DocstoreDemo.ipynb +++ b/docs/examples/docstore/DocstoreDemo.ipynb @@ -120,9 +120,9 @@ "metadata": {}, "outputs": [], "source": [ - "from llama_index.node_parser import SimpleNodeParser\n", + "from llama_index.node_parser import SentenceSplitter\n", "\n", - "nodes = SimpleNodeParser.from_defaults().get_nodes_from_documents(documents)" + "nodes = SentenceSplitter().get_nodes_from_documents(documents)" ] }, { diff --git a/docs/examples/docstore/DynamoDBDocstoreDemo.ipynb b/docs/examples/docstore/DynamoDBDocstoreDemo.ipynb index 96259a71d19c8693aa27443189482b7af63b6555..e7f895c00594e0aae42aa2de508ab1a9c6432ac0 100644 --- a/docs/examples/docstore/DynamoDBDocstoreDemo.ipynb +++ b/docs/examples/docstore/DynamoDBDocstoreDemo.ipynb @@ -125,9 +125,9 @@ "metadata": {}, "outputs": [], "source": [ - "from llama_index.node_parser import SimpleNodeParser\n", + "from llama_index.node_parser import SentenceSplitter\n", "\n", - "nodes = SimpleNodeParser().get_nodes_from_documents(documents)" + "nodes = SentenceSplitter().get_nodes_from_documents(documents)" ] }, { diff --git a/docs/examples/docstore/FirestoreDemo.ipynb b/docs/examples/docstore/FirestoreDemo.ipynb index f7a62e488713036d6b8b86849a57ebd86e61a5dd..f2c9a9343aa87d0f90c0697043a93d6da177313f 100644 --- a/docs/examples/docstore/FirestoreDemo.ipynb +++ b/docs/examples/docstore/FirestoreDemo.ipynb @@ -112,9 +112,9 @@ "metadata": {}, "outputs": [], "source": [ - "from llama_index.node_parser import SimpleNodeParser\n", + "from llama_index.node_parser import SentenceSplitter\n", "\n", - "nodes = SimpleNodeParser.from_defaults().get_nodes_from_documents(documents)" + "nodes = SentenceSplitter().get_nodes_from_documents(documents)" ] }, { diff --git a/docs/examples/docstore/MongoDocstoreDemo.ipynb b/docs/examples/docstore/MongoDocstoreDemo.ipynb index 52feb99c5268d311cf334522c3dba623404710e2..2f45117a545e54a147cb2ad557291e6c40f4e682 100644 --- a/docs/examples/docstore/MongoDocstoreDemo.ipynb +++ b/docs/examples/docstore/MongoDocstoreDemo.ipynb @@ -127,9 +127,9 @@ "metadata": {}, "outputs": [], "source": [ - "from llama_index.node_parser import SimpleNodeParser\n", + "from llama_index.node_parser import SentenceSplitter\n", "\n", - "nodes = SimpleNodeParser.from_defaults().get_nodes_from_documents(documents)" + "nodes = SentenceSplitter().get_nodes_from_documents(documents)" ] }, { diff --git a/docs/examples/docstore/RedisDocstoreIndexStoreDemo.ipynb b/docs/examples/docstore/RedisDocstoreIndexStoreDemo.ipynb index 0141208048db7cf044624b471b44954aa147fcb5..3ed734ba028dfae732cf3c02f3bc892fc76b2778 100644 --- a/docs/examples/docstore/RedisDocstoreIndexStoreDemo.ipynb +++ b/docs/examples/docstore/RedisDocstoreIndexStoreDemo.ipynb @@ -155,9 +155,9 @@ "metadata": {}, "outputs": [], "source": [ - "from llama_index.node_parser import SimpleNodeParser\n", + "from llama_index.node_parser import SentenceSplitter\n", "\n", - "nodes = SimpleNodeParser.from_defaults().get_nodes_from_documents(documents)" + "nodes = SentenceSplitter().get_nodes_from_documents(documents)" ] }, { diff --git a/docs/examples/evaluation/HotpotQADistractor.ipynb b/docs/examples/evaluation/HotpotQADistractor.ipynb index afa9ce612b05d1951e2f2d2511b17577b6381f9b..a62f910db0d55eda41fd1192bc4817d6331a6e57 100644 --- a/docs/examples/evaluation/HotpotQADistractor.ipynb +++ b/docs/examples/evaluation/HotpotQADistractor.ipynb @@ -171,7 +171,7 @@ } ], "source": [ - "from llama_index.indices.postprocessor import SentenceTransformerRerank\n", + "from llama_index.postprocessor import SentenceTransformerRerank\n", "\n", "rerank = SentenceTransformerRerank(top_n=3)\n", "\n", diff --git a/docs/examples/evaluation/retrieval/retriever_eval.ipynb b/docs/examples/evaluation/retrieval/retriever_eval.ipynb index 559f25aa291479c6f66035d9bd171e139cffbc4a..bec863dfa7393603abb60e5d316ed4589fe177fa 100644 --- a/docs/examples/evaluation/retrieval/retriever_eval.ipynb +++ b/docs/examples/evaluation/retrieval/retriever_eval.ipynb @@ -54,7 +54,7 @@ "source": [ "from llama_index.evaluation import generate_question_context_pairs\n", "from llama_index import VectorStoreIndex, SimpleDirectoryReader, ServiceContext\n", - "from llama_index.node_parser import SimpleNodeParser\n", + "from llama_index.node_parser import SentenceSplitter\n", "from llama_index.llms import OpenAI" ] }, @@ -95,7 +95,7 @@ "metadata": {}, "outputs": [], "source": [ - "node_parser = SimpleNodeParser.from_defaults(chunk_size=512)\n", + "node_parser = SentenceSplitter(chunk_size=512)\n", "nodes = node_parser.get_nodes_from_documents(documents)" ] }, diff --git a/docs/examples/finetuning/cross_encoder_finetuning/cross_encoder_finetuning.ipynb b/docs/examples/finetuning/cross_encoder_finetuning/cross_encoder_finetuning.ipynb index 9f68b7c3daf59e3e1148fbff0ed54b5be37c20f2..1fdfdaf2ed7c0fbffbd812d38ce0637d63dc44d1 100644 --- a/docs/examples/finetuning/cross_encoder_finetuning/cross_encoder_finetuning.ipynb +++ b/docs/examples/finetuning/cross_encoder_finetuning/cross_encoder_finetuning.ipynb @@ -1040,7 +1040,7 @@ "# We evaluate by calculating hits for each (question, context) pair,\n", "# we retrieve top-k documents with the question, and\n", "# it’s a hit if the results contain the context\n", - "from llama_index.indices.postprocessor import SentenceTransformerRerank\n", + "from llama_index.postprocessor import SentenceTransformerRerank\n", "from llama_index import (\n", " VectorStoreIndex,\n", " SimpleDirectoryReader,\n", @@ -1716,7 +1716,7 @@ } ], "source": [ - "from llama_index.indices.postprocessor import SentenceTransformerRerank\n", + "from llama_index.postprocessor import SentenceTransformerRerank\n", "from llama_index import (\n", " VectorStoreIndex,\n", " SimpleDirectoryReader,\n", @@ -1905,7 +1905,7 @@ "metadata": {}, "outputs": [], "source": [ - "from llama_index.indices.postprocessor import SentenceTransformerRerank\n", + "from llama_index.postprocessor import SentenceTransformerRerank\n", "from llama_index import (\n", " VectorStoreIndex,\n", " SimpleDirectoryReader,\n", diff --git a/docs/examples/finetuning/embeddings/finetune_embedding.ipynb b/docs/examples/finetuning/embeddings/finetune_embedding.ipynb index b5678970101ccca472ebd76a318b67fffc204970..6b95bc00275d1a3b0c5d4b356d66cfb7dd2ee541 100644 --- a/docs/examples/finetuning/embeddings/finetune_embedding.ipynb +++ b/docs/examples/finetuning/embeddings/finetune_embedding.ipynb @@ -44,7 +44,7 @@ "import json\n", "\n", "from llama_index import SimpleDirectoryReader\n", - "from llama_index.node_parser import SimpleNodeParser\n", + "from llama_index.node_parser import SentenceSplitter\n", "from llama_index.schema import MetadataMode" ] }, @@ -99,7 +99,7 @@ " if verbose:\n", " print(f\"Loaded {len(docs)} docs\")\n", "\n", - " parser = SimpleNodeParser.from_defaults()\n", + " parser = SentenceSplitter()\n", " nodes = parser.get_nodes_from_documents(docs, show_progress=verbose)\n", "\n", " if verbose:\n", diff --git a/docs/examples/finetuning/embeddings/finetune_embedding_adapter.ipynb b/docs/examples/finetuning/embeddings/finetune_embedding_adapter.ipynb index a79843af5bd0c55cb7e751e880f56624903db20c..2f6558712e4d2861faa3c35febb3734538fbbbca 100644 --- a/docs/examples/finetuning/embeddings/finetune_embedding_adapter.ipynb +++ b/docs/examples/finetuning/embeddings/finetune_embedding_adapter.ipynb @@ -47,7 +47,7 @@ "import json\n", "\n", "from llama_index import SimpleDirectoryReader\n", - "from llama_index.node_parser import SimpleNodeParser\n", + "from llama_index.node_parser import SentenceSplitter\n", "from llama_index.schema import MetadataMode" ] }, @@ -102,7 +102,7 @@ " if verbose:\n", " print(f\"Loaded {len(docs)} docs\")\n", "\n", - " parser = SimpleNodeParser.from_defaults()\n", + " parser = SentenceSplitter()\n", " nodes = parser.get_nodes_from_documents(docs, show_progress=verbose)\n", "\n", " if verbose:\n", diff --git a/docs/examples/finetuning/knowledge/finetune_knowledge.ipynb b/docs/examples/finetuning/knowledge/finetune_knowledge.ipynb index 55965fbc228f8f61f2753f77d4af8b6a981f1647..f467525a32824a67261cb7a3bccf6a4721b5c265 100644 --- a/docs/examples/finetuning/knowledge/finetune_knowledge.ipynb +++ b/docs/examples/finetuning/knowledge/finetune_knowledge.ipynb @@ -177,7 +177,7 @@ "outputs": [], "source": [ "from llama_index.evaluation import DatasetGenerator\n", - "from llama_index.node_parser import SimpleNodeParser\n", + "from llama_index.node_parser import SentenceSplitter\n", "\n", "# try evaluation modules\n", "from llama_index.evaluation import RelevancyEvaluator, FaithfulnessEvaluator\n", @@ -191,7 +191,7 @@ "metadata": {}, "outputs": [], "source": [ - "node_parser = SimpleNodeParser.from_defaults()\n", + "node_parser = SentenceSplitter()\n", "nodes = node_parser.get_nodes_from_documents(docs)" ] }, diff --git a/docs/examples/finetuning/knowledge/finetune_retrieval_aug.ipynb b/docs/examples/finetuning/knowledge/finetune_retrieval_aug.ipynb index 67bdd5c6f3980e2d645c19939d5c82752ab0ab3f..dacba8d7a4eec3db5cb98fa6df82bfb248c5a1a1 100644 --- a/docs/examples/finetuning/knowledge/finetune_retrieval_aug.ipynb +++ b/docs/examples/finetuning/knowledge/finetune_retrieval_aug.ipynb @@ -159,7 +159,7 @@ "metadata": {}, "outputs": [], "source": [ - "from llama_index.node_parser import SimpleNodeParser\n", + "from llama_index.node_parser import SentenceSplitter\n", "from llama_index import VectorStoreIndex" ] }, @@ -170,7 +170,7 @@ "metadata": {}, "outputs": [], "source": [ - "node_parser = SimpleNodeParser.from_defaults()\n", + "node_parser = SentenceSplitter()\n", "nodes = node_parser.get_nodes_from_documents(docs)" ] }, diff --git a/docs/examples/finetuning/openai_fine_tuning_functions.ipynb b/docs/examples/finetuning/openai_fine_tuning_functions.ipynb index 419449db27c5c63d63dc259d122172fce15af88d..879527fe0e62726368637c1b100d16f60b6330f6 100644 --- a/docs/examples/finetuning/openai_fine_tuning_functions.ipynb +++ b/docs/examples/finetuning/openai_fine_tuning_functions.ipynb @@ -460,7 +460,7 @@ "source": [ "from llama_hub.file.pymu_pdf.base import PyMuPDFReader\n", "from llama_index import Document, ServiceContext\n", - "from llama_index.node_parser import SimpleNodeParser\n", + "from llama_index.node_parser import SentenceSplitter\n", "from pathlib import Path" ] }, @@ -494,7 +494,7 @@ "outputs": [], "source": [ "chunk_size = 1024\n", - "node_parser = SimpleNodeParser.from_defaults(chunk_size=chunk_size)\n", + "node_parser = SentenceSplitter(chunk_size=chunk_size)\n", "nodes = node_parser.get_nodes_from_documents(docs)" ] }, diff --git a/docs/examples/index_structs/knowledge_graph/KnowledgeGraphDemo.ipynb b/docs/examples/index_structs/knowledge_graph/KnowledgeGraphDemo.ipynb index e3932abbeb5c6ab94bb0d6421bd243f5d5a01172..bb238dec752d906e9e8f22c8a960ac2f6e2cc515 100644 --- a/docs/examples/index_structs/knowledge_graph/KnowledgeGraphDemo.ipynb +++ b/docs/examples/index_structs/knowledge_graph/KnowledgeGraphDemo.ipynb @@ -431,7 +431,7 @@ "metadata": {}, "outputs": [], "source": [ - "from llama_index.node_parser import SimpleNodeParser" + "from llama_index.node_parser import SentenceSplitter" ] }, { @@ -441,7 +441,7 @@ "metadata": {}, "outputs": [], "source": [ - "node_parser = SimpleNodeParser.from_defaults()" + "node_parser = SentenceSplitter()" ] }, { diff --git a/docs/examples/index_structs/knowledge_graph/KuzuGraphDemo.ipynb b/docs/examples/index_structs/knowledge_graph/KuzuGraphDemo.ipynb index 721455d6992fcb6003a809458803b664d45fbdbc..a8e18dea3b4dcc83bd3bace0b8891eedebee93dc 100644 --- a/docs/examples/index_structs/knowledge_graph/KuzuGraphDemo.ipynb +++ b/docs/examples/index_structs/knowledge_graph/KuzuGraphDemo.ipynb @@ -549,7 +549,7 @@ "metadata": {}, "outputs": [], "source": [ - "from llama_index.node_parser import SimpleNodeParser" + "from llama_index.node_parser import SentenceSplitter" ] }, { @@ -559,7 +559,7 @@ "metadata": {}, "outputs": [], "source": [ - "node_parser = SimpleNodeParser.from_defaults()" + "node_parser = SentenceSplitter()" ] }, { diff --git a/docs/examples/index_structs/knowledge_graph/NebulaGraphKGIndexDemo.ipynb b/docs/examples/index_structs/knowledge_graph/NebulaGraphKGIndexDemo.ipynb index 4441284317600f3002fa36845afe1964dbe23589..68fa6426c8fe23e0a287547cff472b0df6608213 100644 --- a/docs/examples/index_structs/knowledge_graph/NebulaGraphKGIndexDemo.ipynb +++ b/docs/examples/index_structs/knowledge_graph/NebulaGraphKGIndexDemo.ipynb @@ -984,7 +984,7 @@ "metadata": {}, "outputs": [], "source": [ - "from llama_index.node_parser import SimpleNodeParser" + "from llama_index.node_parser import SentenceSplitter" ] }, { @@ -994,7 +994,7 @@ "metadata": {}, "outputs": [], "source": [ - "node_parser = SimpleNodeParser.from_defaults()" + "node_parser = SentenceSplitter()" ] }, { diff --git a/docs/examples/index_structs/knowledge_graph/Neo4jKGIndexDemo.ipynb b/docs/examples/index_structs/knowledge_graph/Neo4jKGIndexDemo.ipynb index 9cc37e0b8ab1bee02de7384c56caf5686d0b0a0f..82272d9a40158b3bd47ff04d00e5370d91d60722 100644 --- a/docs/examples/index_structs/knowledge_graph/Neo4jKGIndexDemo.ipynb +++ b/docs/examples/index_structs/knowledge_graph/Neo4jKGIndexDemo.ipynb @@ -470,7 +470,7 @@ "metadata": {}, "outputs": [], "source": [ - "from llama_index.node_parser import SimpleNodeParser" + "from llama_index.node_parser import SentenceSplitter" ] }, { @@ -480,7 +480,7 @@ "metadata": {}, "outputs": [], "source": [ - "node_parser = SimpleNodeParser.from_defaults()" + "node_parser = SentenceSplitter()" ] }, { diff --git a/docs/examples/llm/huggingface.ipynb b/docs/examples/llm/huggingface.ipynb index a7bf9ad3f708bd845caed46c95c7e42cafee7519..5b433e3a3ce44584ece81834b3998a9c9f531655 100644 --- a/docs/examples/llm/huggingface.ipynb +++ b/docs/examples/llm/huggingface.ipynb @@ -157,6 +157,29 @@ "print(completion_response)" ] }, + { + "cell_type": "markdown", + "id": "dda1be10", + "metadata": {}, + "source": [ + "If you are modifying the LLM, you should also change the global tokenizer to match!" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "12e0f3c0", + "metadata": {}, + "outputs": [], + "source": [ + "from llama_index import set_global_tokenizer\n", + "from transformers import AutoTokenizer\n", + "\n", + "set_global_tokenizer(\n", + " AutoTokenizer.from_pretrained(\"HuggingFaceH4/zephyr-7b-alpha\").encode\n", + ")" + ] + }, { "cell_type": "markdown", "id": "3fa723d6-4308-4d94-9609-8c51ce8184c3", diff --git a/docs/examples/llm/llama_2_llama_cpp.ipynb b/docs/examples/llm/llama_2_llama_cpp.ipynb index e56fb7273fc49fd899e7aa9315d0083cfeaa0119..c342bdc726c7135e18462836d1b80b27b4c7f349 100644 --- a/docs/examples/llm/llama_2_llama_cpp.ipynb +++ b/docs/examples/llm/llama_2_llama_cpp.ipynb @@ -348,7 +348,24 @@ "source": [ "## Query engine set up with LlamaCPP\n", "\n", - "We can simply pass in the `LlamaCPP` LLM abstraction to the `LlamaIndex` query engine as usual:" + "We can simply pass in the `LlamaCPP` LLM abstraction to the `LlamaIndex` query engine as usual.\n", + "\n", + "But first, let's change the global tokenizer to match our LLM." + ] + }, + { + "cell_type": "code", + "execution_count": null, + "id": "d8ff0c0b", + "metadata": {}, + "outputs": [], + "source": [ + "from llama_index import set_global_tokenizer\n", + "from transformers import AutoTokenizer\n", + "\n", + "set_global_tokenizer(\n", + " AutoTokenizer.from_pretrained(\"NousResearch/Llama-2-7b-chat-hf\").encode\n", + ")" ] }, { diff --git a/docs/examples/low_level/evaluation.ipynb b/docs/examples/low_level/evaluation.ipynb index 31e0f513882dd26e6f23d2c73c4a51614c24e125..bd94324723722c2f98ecc7f272fed1c8084794a1 100644 --- a/docs/examples/low_level/evaluation.ipynb +++ b/docs/examples/low_level/evaluation.ipynb @@ -91,7 +91,7 @@ "outputs": [], "source": [ "from llama_index import VectorStoreIndex, ServiceContext\n", - "from llama_index.node_parser import SimpleNodeParser\n", + "from llama_index.node_parser import SentenceSplitter\n", "from llama_index.llms import OpenAI" ] }, @@ -103,7 +103,7 @@ "outputs": [], "source": [ "llm = OpenAI(model=\"gpt-4\")\n", - "node_parser = SimpleNodeParser.from_defaults(chunk_size=1024)\n", + "node_parser = SentenceSplitter(chunk_size=1024)\n", "service_context = ServiceContext.from_defaults(llm=llm)" ] }, diff --git a/docs/examples/low_level/ingestion.ipynb b/docs/examples/low_level/ingestion.ipynb index d35129d0598ab80222dbd5f93fb047e608a5faaa..5893a571b1f947cce4c077a488d771a70e87950f 100644 --- a/docs/examples/low_level/ingestion.ipynb +++ b/docs/examples/low_level/ingestion.ipynb @@ -357,7 +357,7 @@ "metadata": {}, "outputs": [], "source": [ - "from llama_index.text_splitter import SentenceSplitter" + "from llama_index.node_parser import SentenceSplitter" ] }, { @@ -367,7 +367,7 @@ "metadata": {}, "outputs": [], "source": [ - "text_splitter = SentenceSplitter(\n", + "text_parser = SentenceSplitter(\n", " chunk_size=1024,\n", " # separator=\" \",\n", ")" @@ -401,7 +401,7 @@ "\n", "We inject metadata from the document into each node.\n", "\n", - "This essentially replicates logic in our `SimpleNodeParser`." + "This essentially replicates logic in our `SentenceSplitter`." ] }, { @@ -471,22 +471,19 @@ "metadata": {}, "outputs": [], "source": [ - "from llama_index.node_parser.extractors import (\n", - " MetadataExtractor,\n", + "from llama_index.extractors import (\n", " QuestionsAnsweredExtractor,\n", " TitleExtractor,\n", ")\n", + "from llama_index.ingestion import IngestionPipeline\n", "from llama_index.llms import OpenAI\n", "\n", "llm = OpenAI(model=\"gpt-3.5-turbo\")\n", "\n", - "metadata_extractor = MetadataExtractor(\n", - " extractors=[\n", - " TitleExtractor(nodes=5, llm=llm),\n", - " QuestionsAnsweredExtractor(questions=3, llm=llm),\n", - " ],\n", - " in_place=False,\n", - ")" + "extractors = [\n", + " TitleExtractor(nodes=5, llm=llm),\n", + " QuestionsAnsweredExtractor(questions=3, llm=llm),\n", + "]" ] }, { @@ -496,7 +493,10 @@ "metadata": {}, "outputs": [], "source": [ - "nodes = metadata_extractor.process_nodes(nodes)" + "pipeline = IngestionPipeline(\n", + " transformations=extractors,\n", + ")\n", + "nodes = pipeline.run(nodes=nodes, in_place=False)" ] }, { diff --git a/docs/examples/low_level/oss_ingestion_retrieval.ipynb b/docs/examples/low_level/oss_ingestion_retrieval.ipynb index e398da900dac4d756f070b4b2de3456ac13f1aee..1ce6c1bcb07b5358840bc560479cb1f4eee105f4 100644 --- a/docs/examples/low_level/oss_ingestion_retrieval.ipynb +++ b/docs/examples/low_level/oss_ingestion_retrieval.ipynb @@ -297,7 +297,7 @@ "metadata": {}, "outputs": [], "source": [ - "from llama_index.text_splitter import SentenceSplitter" + "from llama_index.node_parser.text import SentenceSplitter" ] }, { @@ -307,7 +307,7 @@ "metadata": {}, "outputs": [], "source": [ - "text_splitter = SentenceSplitter(\n", + "text_parser = SentenceSplitter(\n", " chunk_size=1024,\n", " # separator=\" \",\n", ")" @@ -324,7 +324,7 @@ "# maintain relationship with source doc index, to help inject doc metadata in (3)\n", "doc_idxs = []\n", "for doc_idx, doc in enumerate(documents):\n", - " cur_text_chunks = text_splitter.split_text(doc.text)\n", + " cur_text_chunks = text_parser.split_text(doc.text)\n", " text_chunks.extend(cur_text_chunks)\n", " doc_idxs.extend([doc_idx] * len(cur_text_chunks))" ] diff --git a/docs/examples/low_level/vector_store.ipynb b/docs/examples/low_level/vector_store.ipynb index e6e49c64ec3ef17f38fb371e74f03955580cafef..e1ebbe8bdb6c9edb6ef6846258679895b4d9fed8 100644 --- a/docs/examples/low_level/vector_store.ipynb +++ b/docs/examples/low_level/vector_store.ipynb @@ -91,9 +91,9 @@ "metadata": {}, "outputs": [], "source": [ - "from llama_index.node_parser import SimpleNodeParser\n", + "from llama_index.node_parser import SentenceSplitter\n", "\n", - "node_parser = SimpleNodeParser.from_defaults(chunk_size=256)\n", + "node_parser = SentenceSplitter(chunk_size=256)\n", "nodes = node_parser.get_nodes_from_documents(documents)" ] }, diff --git a/docs/examples/metadata_extraction/EntityExtractionClimate.ipynb b/docs/examples/metadata_extraction/EntityExtractionClimate.ipynb index 28d19105b46ac606d2c78c3dfe45fe8edb02bc10..78648d6f178368c8ff7d59a176b51061ef5694bc 100644 --- a/docs/examples/metadata_extraction/EntityExtractionClimate.ipynb +++ b/docs/examples/metadata_extraction/EntityExtractionClimate.ipynb @@ -83,11 +83,8 @@ } ], "source": [ - "from llama_index.node_parser.extractors.metadata_extractors import (\n", - " MetadataExtractor,\n", - " EntityExtractor,\n", - ")\n", - "from llama_index.node_parser.simple import SimpleNodeParser\n", + "from llama_index.extractors.metadata_extractors import EntityExtractor\n", + "from llama_index.node_parser import SentenceSplitter\n", "\n", "entity_extractor = EntityExtractor(\n", " prediction_threshold=0.5,\n", @@ -95,11 +92,9 @@ " device=\"cpu\", # set to \"cuda\" if you have a GPU\n", ")\n", "\n", - "metadata_extractor = MetadataExtractor(extractors=[entity_extractor])\n", + "node_parser = SentenceSplitter()\n", "\n", - "node_parser = SimpleNodeParser.from_defaults(\n", - " metadata_extractor=metadata_extractor\n", - ")" + "transformations = [node_parser, entity_extractor]" ] }, { @@ -179,6 +174,8 @@ } ], "source": [ + "from llama_index.ingestion import IngestionPipeline\n", + "\n", "import random\n", "\n", "random.seed(42)\n", @@ -186,7 +183,9 @@ "# 100 documents takes about 5 minutes on CPU\n", "documents = random.sample(documents, 100)\n", "\n", - "nodes = node_parser.get_nodes_from_documents(documents)" + "pipline = IngestionPipeline(transformations=transformations)\n", + "\n", + "nodes = pipline.run(documents=documents)" ] }, { diff --git a/docs/examples/metadata_extraction/MarvinMetadataExtractorDemo.ipynb b/docs/examples/metadata_extraction/MarvinMetadataExtractorDemo.ipynb index 1517306f864d43e1de77e71fd5b45d9ece1b2573..3c13c744516e7781fac906c9d4c7d02cbee787da 100644 --- a/docs/examples/metadata_extraction/MarvinMetadataExtractorDemo.ipynb +++ b/docs/examples/metadata_extraction/MarvinMetadataExtractorDemo.ipynb @@ -38,12 +38,8 @@ "from llama_index import SimpleDirectoryReader\n", "from llama_index.indices.service_context import ServiceContext\n", "from llama_index.llms import OpenAI\n", - "from llama_index.node_parser import SimpleNodeParser\n", - "from llama_index.node_parser.extractors import (\n", - " MetadataExtractor,\n", - ")\n", - "from llama_index.text_splitter import TokenTextSplitter\n", - "from llama_index.node_parser.extractors.marvin_metadata_extractor import (\n", + "from llama_index.node_parser import TokenTextSplitter\n", + "from llama_index.extractors.marvin_metadata_extractor import (\n", " MarvinMetadataExtractor,\n", ")" ] @@ -112,7 +108,7 @@ "# construct text splitter to split texts into chunks for processing\n", "# this takes a while to process, you can increase processing time by using larger chunk_size\n", "# file size is a factor too of course\n", - "text_splitter = TokenTextSplitter(\n", + "node_parser = TokenTextSplitter(\n", " separator=\" \", chunk_size=512, chunk_overlap=128\n", ")\n", "\n", @@ -122,22 +118,16 @@ "set_global_service_context(service_context)\n", "\n", "# create metadata extractor\n", - "metadata_extractor = MetadataExtractor(\n", - " extractors=[\n", - " MarvinMetadataExtractor(\n", - " marvin_model=SportsSupplement, llm_model_string=llm_model\n", - " ), # let's extract custom entities for each node.\n", - " ],\n", - ")\n", - "\n", - "# create node parser to parse nodes from document\n", - "node_parser = SimpleNodeParser(\n", - " text_splitter=text_splitter,\n", - " metadata_extractor=metadata_extractor,\n", - ")\n", + "metadata_extractor = MarvinMetadataExtractor(\n", + " marvin_model=SportsSupplement, llm_model_string=llm_model\n", + ") # let's extract custom entities for each node.\n", "\n", "# use node_parser to get nodes from the documents\n", - "nodes = node_parser.get_nodes_from_documents(documents)" + "from llama_index.ingestion import IngestionPipeline\n", + "\n", + "pipeline = IngestionPipeline(transformations=[node_parser, metadata_extractor])\n", + "\n", + "nodes = pipeline.run(documents=documents, show_progress=True)" ] }, { diff --git a/docs/examples/metadata_extraction/MetadataExtractionSEC.ipynb b/docs/examples/metadata_extraction/MetadataExtractionSEC.ipynb index 8e2b6aac0e0adda8a363763a8dea9e9516a326e7..15ee39c1a0dda240dcde18e1d1cac2ffef1f5252 100644 --- a/docs/examples/metadata_extraction/MetadataExtractionSEC.ipynb +++ b/docs/examples/metadata_extraction/MetadataExtractionSEC.ipynb @@ -20,7 +20,7 @@ "\n", "To combat this, we use LLMs to extract certain contextual information relevant to the document to better help the retrieval and language models disambiguate similar-looking passages.\n", "\n", - "We do this through our brand-new `MetadataExtractor` modules." + "We do this through our brand-new `Metadata Extractor` modules." ] }, { @@ -90,7 +90,7 @@ "We create a node parser that extracts the document title and hypothetical question embeddings relevant to the document chunk.\n", "\n", "We also show how to instantiate the `SummaryExtractor` and `KeywordExtractor`, as well as how to create your own custom extractor \n", - "based on the `MetadataFeatureExtractor` base class" + "based on the `BaseExtractor` base class" ] }, { @@ -100,15 +100,13 @@ "metadata": {}, "outputs": [], "source": [ - "from llama_index.node_parser import SimpleNodeParser\n", - "from llama_index.node_parser.extractors import (\n", - " MetadataExtractor,\n", + "from llama_index.extractors import (\n", " SummaryExtractor,\n", " QuestionsAnsweredExtractor,\n", " TitleExtractor,\n", " KeywordExtractor,\n", " EntityExtractor,\n", - " MetadataFeatureExtractor,\n", + " BaseExtractor,\n", ")\n", "from llama_index.text_splitter import TokenTextSplitter\n", "\n", @@ -117,7 +115,7 @@ ")\n", "\n", "\n", - "class CustomExtractor(MetadataFeatureExtractor):\n", + "class CustomExtractor(BaseExtractor):\n", " def extract(self, nodes):\n", " metadata_list = [\n", " {\n", @@ -132,21 +130,16 @@ " return metadata_list\n", "\n", "\n", - "metadata_extractor = MetadataExtractor(\n", - " extractors=[\n", - " TitleExtractor(nodes=5, llm=llm),\n", - " QuestionsAnsweredExtractor(questions=3, llm=llm),\n", - " # EntityExtractor(prediction_threshold=0.5),\n", - " # SummaryExtractor(summaries=[\"prev\", \"self\"], llm=llm),\n", - " # KeywordExtractor(keywords=10, llm=llm),\n", - " # CustomExtractor()\n", - " ],\n", - ")\n", + "extractors = [\n", + " TitleExtractor(nodes=5, llm=llm),\n", + " QuestionsAnsweredExtractor(questions=3, llm=llm),\n", + " # EntityExtractor(prediction_threshold=0.5),\n", + " # SummaryExtractor(summaries=[\"prev\", \"self\"], llm=llm),\n", + " # KeywordExtractor(keywords=10, llm=llm),\n", + " # CustomExtractor()\n", + "]\n", "\n", - "node_parser = SimpleNodeParser.from_defaults(\n", - " text_splitter=text_splitter,\n", - " metadata_extractor=metadata_extractor,\n", - ")" + "transformations = [text_splitter] + extractors" ] }, { @@ -201,7 +194,11 @@ "metadata": {}, "outputs": [], "source": [ - "uber_nodes = node_parser.get_nodes_from_documents(uber_docs)" + "from llama_index.ingestion import IngestionPipeline\n", + "\n", + "pipeline = IngestionPipeline(transformations=transformations)\n", + "\n", + "uber_nodes = pipeline.run(documents=uber_docs)" ] }, { @@ -251,7 +248,11 @@ "metadata": {}, "outputs": [], "source": [ - "lyft_nodes = node_parser.get_nodes_from_documents(lyft_docs)" + "from llama_index.ingestion import IngestionPipeline\n", + "\n", + "pipeline = IngestionPipeline(transformations=transformations)\n", + "\n", + "lyft_nodes = pipeline.run(documents=lyft_docs)" ] }, { @@ -298,7 +299,7 @@ "from llama_index.question_gen.prompts import DEFAULT_SUB_QUESTION_PROMPT_TMPL\n", "\n", "service_context = ServiceContext.from_defaults(\n", - " llm=llm, node_parser=node_parser\n", + " llm=llm, text_splitter=text_splitter\n", ")\n", "question_gen = LLMQuestionGenerator.from_defaults(\n", " service_context=service_context,\n", diff --git a/docs/examples/metadata_extraction/MetadataExtraction_LLMSurvey.ipynb b/docs/examples/metadata_extraction/MetadataExtraction_LLMSurvey.ipynb index 19d0249e625694166f119ae612e6548051cac146..9f49dc916d54851c529645cc0ea6c8000bd9ba69 100644 --- a/docs/examples/metadata_extraction/MetadataExtraction_LLMSurvey.ipynb +++ b/docs/examples/metadata_extraction/MetadataExtraction_LLMSurvey.ipynb @@ -84,7 +84,7 @@ "metadata": {}, "outputs": [], "source": [ - "os.environ[\"OPENAI_API_KEY\"] = \"YOUR_API_KEY_HERE\"\n", + "os.environ[\"OPENAI_API_KEY\"] = \"sk-...\"\n", "openai.api_key = os.environ[\"OPENAI_API_KEY\"]" ] }, @@ -137,38 +137,29 @@ "metadata": {}, "outputs": [], "source": [ - "from llama_index.node_parser import SimpleNodeParser\n", - "from llama_index.node_parser.extractors import (\n", - " MetadataExtractor,\n", + "from llama_index.node_parser import TokenTextSplitter\n", + "from llama_index.extractors import (\n", " SummaryExtractor,\n", " QuestionsAnsweredExtractor,\n", ")\n", - "from llama_index.text_splitter import TokenTextSplitter\n", "\n", - "text_splitter = TokenTextSplitter(\n", + "node_parser = TokenTextSplitter(\n", " separator=\" \", chunk_size=256, chunk_overlap=128\n", ")\n", "\n", "\n", - "metadata_extractor_1 = MetadataExtractor(\n", - " extractors=[\n", - " QuestionsAnsweredExtractor(questions=3, llm=llm),\n", - " ],\n", - " in_place=False,\n", - ")\n", - "\n", - "metadata_extractor = MetadataExtractor(\n", - " extractors=[\n", - " SummaryExtractor(summaries=[\"prev\", \"self\", \"next\"], llm=llm),\n", - " QuestionsAnsweredExtractor(questions=3, llm=llm),\n", - " ],\n", - " in_place=False,\n", - ")\n", + "extractors_1 = [\n", + " QuestionsAnsweredExtractor(\n", + " questions=3, llm=llm, metadata_mode=MetadataMode.EMBED\n", + " ),\n", + "]\n", "\n", - "node_parser = SimpleNodeParser.from_defaults(\n", - " text_splitter=text_splitter,\n", - " # metadata_extractor=metadata_extractor,\n", - ")" + "extractors_2 = [\n", + " SummaryExtractor(summaries=[\"prev\", \"self\", \"next\"], llm=llm),\n", + " QuestionsAnsweredExtractor(\n", + " questions=3, llm=llm, metadata_mode=MetadataMode.EMBED\n", + " ),\n", + "]" ] }, { @@ -292,7 +283,21 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "fc9d6aa373674dd79d293a55d9eec319", + "model_id": "e828522a65bb4304bdaeae041a2eee31", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Parsing documents into nodes: 0%| | 0/8 [00:00<?, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "f8d5a4483c4e49f5b26afc869b3f39dd", "version_major": 2, "version_minor": 0 }, @@ -305,8 +310,12 @@ } ], "source": [ - "# process nodes with metadata extractor\n", - "nodes_1 = metadata_extractor_1.process_nodes(nodes)" + "from llama_index.ingestion import IngestionPipeline\n", + "\n", + "# process nodes with metadata extractors\n", + "pipeline = IngestionPipeline(transformations=[node_parser, *extractors_1])\n", + "\n", + "nodes_1 = pipeline.run(nodes=nodes, in_place=False, show_progress=True)" ] }, { @@ -320,9 +329,9 @@ "output_type": "stream", "text": [ "[Excerpt from document]\n", - "questions_this_excerpt_can_answer: 1. What is the correlation between conventional metrics like BLEU and ROUGE and human judgments of fluency and adequacy in natural language processing tasks?\n", - "2. How well do metrics like BLEU and ROUGE perform in tasks that require creativity and diversity?\n", - "3. Why are exact match metrics like BLEU and ROUGE not suitable for tasks like abstractive summarization or dialogue?\n", + "questions_this_excerpt_can_answer: 1. What is the correlation between conventional metrics like BLEU and ROUGE and human judgments in evaluating fluency and adequacy in natural language processing tasks?\n", + "2. How do conventional metrics like BLEU and ROUGE perform in tasks that require creativity and diversity?\n", + "3. Why are exact match metrics like BLEU and ROUGE not suitable for tasks like abstractive summarization or dialogue in natural language processing?\n", "Excerpt:\n", "-----\n", "is to measure the distance that words would\n", @@ -361,7 +370,21 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "4a05842ddcb541df9355b5b025768bed", + "model_id": "f74cdd7524064d42bb4c250db190b0f6", + "version_major": 2, + "version_minor": 0 + }, + "text/plain": [ + "Parsing documents into nodes: 0%| | 0/8 [00:00<?, ?it/s]" + ] + }, + "metadata": {}, + "output_type": "display_data" + }, + { + "data": { + "application/vnd.jupyter.widget-view+json": { + "model_id": "ec9a4e2ac9424fa4b425b08facf6adfe", "version_major": 2, "version_minor": 0 }, @@ -375,7 +398,7 @@ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "6200878d19b04a42926e798bcbaec509", + "model_id": "09dcebe7f8ee437d88792ddf1eecce14", "version_major": 2, "version_minor": 0 }, @@ -391,7 +414,9 @@ "# 2nd pass: run summaries, and then metadata extractor\n", "\n", "# process nodes with metadata extractor\n", - "nodes_2 = metadata_extractor.process_nodes(nodes)" + "pipeline = IngestionPipeline(transformations=[node_parser, *extractors_2])\n", + "\n", + "nodes_2 = pipeline.run(nodes=nodes, in_place=False, show_progress=True)" ] }, { @@ -415,10 +440,10 @@ "[Excerpt from document]\n", "prev_section_summary: The section discusses the comparison between BERTScore and MoverScore, two metrics used to evaluate the quality of text generation models. MoverScore is described as a metric that measures the effort required to transform one text sequence into another by mapping semantically related words. The section also highlights the limitations of conventional benchmarks and metrics, such as poor correlation with human judgments and low correlation with tasks requiring creativity.\n", "next_section_summary: The section discusses the limitations of current evaluation metrics in natural language processing tasks. It highlights three main issues: lack of creativity and diversity in metrics, poor adaptability to different tasks, and poor reproducibility. The section mentions specific metrics like BLEU and ROUGE, and also references studies that have reported high variance in metric scores.\n", - "section_summary: The section discusses the limitations of conventional benchmarks and metrics used to measure the distance between word sequences. It highlights two main issues: poor correlation with human judgments and poor adaptability to different tasks. The metrics like BLEU and ROUGE have been found to have low correlation with human evaluations of fluency and adequacy, as well as tasks requiring creativity and diversity. Additionally, these metrics are not suitable for tasks like abstractive summarization or dialogue due to their reliance on n-gram overlap.\n", + "section_summary: The section discusses the limitations of conventional benchmarks and metrics used to measure the distance between word sequences. It highlights two main issues: the poor correlation between these metrics and human judgments, and their limited adaptability to different tasks. The section mentions specific metrics like BLEU and ROUGE, which have been found to have low correlation with human evaluations of fluency, adequacy, creativity, and diversity. It also points out that metrics based on n-gram overlap, such as BLEU and ROUGE, are not suitable for tasks like abstractive summarization or dialogue.\n", "questions_this_excerpt_can_answer: 1. What are the limitations of conventional benchmarks and metrics in measuring the distance between word sequences?\n", - "2. How do metrics like BLEU and ROUGE correlate with human judgments of fluency and adequacy?\n", - "3. Why are metrics like BLEU and ROUGE not suitable for tasks like abstractive summarization or dialogue?\n", + "2. How do metrics like BLEU and ROUGE correlate with human judgments in terms of fluency, adequacy, creativity, and diversity?\n", + "3. Why are metrics based on n-gram overlap, such as BLEU and ROUGE, not suitable for tasks like abstractive summarization or dialogue?\n", "Excerpt:\n", "-----\n", "is to measure the distance that words would\n", @@ -525,17 +550,7 @@ "execution_count": null, "id": "bd729fb8-1e00-4cd0-9505-a86a7daa89d0", "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "\u001b[34m\u001b[1mwandb\u001b[0m: Logged trace tree to W&B.\n", - "\u001b[34m\u001b[1mwandb\u001b[0m: Logged trace tree to W&B.\n", - "\u001b[34m\u001b[1mwandb\u001b[0m: Logged trace tree to W&B.\n" - ] - } - ], + "outputs": [], "source": [ "# try out different query engines\n", "\n", @@ -583,17 +598,7 @@ "execution_count": null, "id": "1e1e448d-632c-42a0-ad60-4a315491945f", "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "\u001b[34m\u001b[1mwandb\u001b[0m: Logged trace tree to W&B.\n", - "\u001b[34m\u001b[1mwandb\u001b[0m: Logged trace tree to W&B.\n", - "\u001b[34m\u001b[1mwandb\u001b[0m: Logged trace tree to W&B.\n" - ] - } - ], + "outputs": [], "source": [ "# query_str = \"In the original RAG paper, can you describe the two main approaches for generation and compare them?\"\n", "query_str = (\n", @@ -691,17 +696,7 @@ "execution_count": null, "id": "5afe4f69-b676-43fd-bc15-39e18e94801f", "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "\u001b[34m\u001b[1mwandb\u001b[0m: Logged trace tree to W&B.\n", - "\u001b[34m\u001b[1mwandb\u001b[0m: Logged trace tree to W&B.\n", - "\u001b[34m\u001b[1mwandb\u001b[0m: Logged trace tree to W&B.\n" - ] - } - ], + "outputs": [], "source": [ "# query_str = \"What are some reproducibility issues with the ROUGE metric? Give some details related to benchmarks and also describe other ROUGE issues. \"\n", "query_str = (\n", @@ -759,7 +754,7 @@ { "data": { "text/plain": [ - "{}" + "{'questions_this_excerpt_can_answer': '1. What is the advantage of using BERTScore over simpler metrics like BLEU and ROUGE?\\n2. How does MoverScore differ from BERTScore in terms of token matching?\\n3. What tasks have shown better correlation with BERTScore, such as image captioning and machine translation?'}" ] }, "execution_count": null, @@ -774,9 +769,9 @@ ], "metadata": { "kernelspec": { - "display_name": "llama_index_v2", + "display_name": "llama-index", "language": "python", - "name": "llama_index_v2" + "name": "llama-index" }, "language_info": { "codemirror_mode": { diff --git a/docs/examples/metadata_extraction/PydanticExtractor.ipynb b/docs/examples/metadata_extraction/PydanticExtractor.ipynb index cde6faea8047e4e6b73257983194e7e267b76ff9..c61469ff70702b0ee8a8b39a1443d7e18fe3e1c8 100644 --- a/docs/examples/metadata_extraction/PydanticExtractor.ipynb +++ b/docs/examples/metadata_extraction/PydanticExtractor.ipynb @@ -124,10 +124,7 @@ "outputs": [], "source": [ "from llama_index.program.openai_program import OpenAIPydanticProgram\n", - "from llama_index.node_parser.extractors import (\n", - " PydanticProgramExtractor,\n", - " MetadataExtractor,\n", - ")\n", + "from llama_index.extractors import PydanticProgramExtractor\n", "\n", "EXTRACT_TEMPLATE_STR = \"\"\"\\\n", "Here is the content of the section:\n", @@ -145,9 +142,7 @@ "\n", "program_extractor = PydanticProgramExtractor(\n", " program=openai_program, input_key=\"input\", show_progress=True\n", - ")\n", - "\n", - "metadata_extractor = MetadataExtractor(extractors=[program_extractor])" + ")" ] }, { @@ -170,7 +165,7 @@ "# load in blog\n", "\n", "from llama_hub.web.simple_web.base import SimpleWebPageReader\n", - "from llama_index.node_parser import SimpleNodeParser\n", + "from llama_index.node_parser import SentenceSplitter\n", "\n", "reader = SimpleWebPageReader(html_to_text=True)\n", "docs = reader.load_data(urls=[\"https://eugeneyan.com/writing/llm-patterns/\"])" @@ -183,8 +178,13 @@ "metadata": {}, "outputs": [], "source": [ - "node_parser = SimpleNodeParser.from_defaults(chunk_size=1024)\n", - "orig_nodes = node_parser.get_nodes_from_documents(docs)" + "from llama_index.ingestion import IngestionPipeline\n", + "\n", + "node_parser = SentenceSplitter(chunk_size=1024)\n", + "\n", + "pipline = IngestionPipeline(transformations=[node_parser, program_extractor])\n", + "\n", + "orig_nodes = pipline.run(documents=docs)" ] }, { @@ -312,7 +312,7 @@ } ], "source": [ - "new_nodes = metadata_extractor.process_nodes(orig_nodes)" + "new_nodes = program_extractor.process_nodes(orig_nodes)" ] }, { diff --git a/docs/examples/node_postprocessor/CohereRerank.ipynb b/docs/examples/node_postprocessor/CohereRerank.ipynb index 87dbb08a0f27491b5c54cde947e2c3124ec81f08..1fcacfe2eff98afa4118f4e8409214a3ced2e8ad 100644 --- a/docs/examples/node_postprocessor/CohereRerank.ipynb +++ b/docs/examples/node_postprocessor/CohereRerank.ipynb @@ -99,7 +99,7 @@ "outputs": [], "source": [ "import os\n", - "from llama_index.indices.postprocessor.cohere_rerank import CohereRerank\n", + "from llama_index.postprocessor.cohere_rerank import CohereRerank\n", "\n", "\n", "api_key = os.environ[\"COHERE_API_KEY\"]\n", diff --git a/docs/examples/node_postprocessor/FileNodeProcessors.ipynb b/docs/examples/node_postprocessor/FileNodeProcessors.ipynb index 45ce4adf087e8c2d64c493cc3528470a8881640e..3185c5d69134af1d9559ad0a34b7c56d76e4dea2 100644 --- a/docs/examples/node_postprocessor/FileNodeProcessors.ipynb +++ b/docs/examples/node_postprocessor/FileNodeProcessors.ipynb @@ -16,9 +16,9 @@ "source": [ "# File Based Node Parsers\n", "\n", - "The combination of the `SimpleFileNodeParser` and `FlatReader` are designed to allow opening a variety of file types and automatically selecting the best NodeParser to process the files. The `FlatReader` loads the file in a raw text format and attaches the file information to the metadata, then the `SimpleFileNodeParser` maps file types to node parsers in `node_parser/file`, selecting the best node parser for the job.\n", + "The `SimpleFileNodeParser` and `FlatReader` are designed to allow opening a variety of file types and automatically selecting the best `NodeParser` to process the files. The `FlatReader` loads the file in a raw text format and attaches the file information to the metadata, then the `SimpleFileNodeParser` maps file types to node parsers in `node_parser/file`, selecting the best node parser for the job.\n", "\n", - "The `SimpleFileNodeParser` does not perform token based chunking of the text, and one of the other node parsers, in particular ones that accept an instance of a `TextSplitter`, can be chained to further split the content. \n", + "The `SimpleFileNodeParser` does not perform token based chunking of the text, and is intended to be used in combination with a token node parser.\n", "\n", "Let's look at an example of using the `FlatReader` and `SimpleFileNodeParser` to load content. For the README file I will be using the LlamaIndex README and the HTML file is the Stack Overflow landing page, however any README and HTML file will work." ] @@ -60,7 +60,7 @@ } ], "source": [ - "from llama_index.node_parser.simple_file import SimpleFileNodeParser\n", + "from llama_index.node_parser.file import SimpleFileNodeParser\n", "from llama_index.readers.file.flat_reader import FlatReader\n", "from pathlib import Path" ] @@ -419,15 +419,13 @@ } ], "source": [ - "from llama_index.node_parser import SimpleNodeParser\n", - "from llama_index.text_splitter import SentenceSplitter\n", + "from llama_index.node_parser import SentenceSplitter\n", "\n", "# For clarity in the demo, make small splits without overlap\n", - "splitter = SentenceSplitter(chunk_size=200, chunk_overlap=0)\n", - "splitting_parser = SimpleNodeParser(text_splitter=splitter)\n", + "splitting_parser = SentenceSplitter(chunk_size=200, chunk_overlap=0)\n", "\n", - "html_chunked_nodes = splitting_parser.get_nodes_from_documents(html_nodes)\n", - "md_chunked_nodes = splitting_parser.get_nodes_from_documents(md_nodes)\n", + "html_chunked_nodes = splitting_parser(html_nodes)\n", + "md_chunked_nodes = splitting_parser(md_nodes)\n", "print(f\"\\n\\nHTML parsed nodes: {len(html_nodes)}\")\n", "print(html_nodes[0].text)\n", "\n", @@ -466,9 +464,17 @@ } ], "source": [ - "md_chunked_nodes = splitting_parser.get_nodes_from_documents(\n", - " parser.get_nodes_from_documents(reader.load_data(Path(\"./README.md\")))\n", + "from llama_index.ingestion import IngestionPipeline\n", + "\n", + "pipeline = IngestionPipeline(\n", + " documents=reader.load_data(Path(\"./README.md\")),\n", + " transformations=[\n", + " SimpleFileNodeParser(),\n", + " SentenceSplitter(chunk_size=200, chunk_overlap=0),\n", + " ],\n", ")\n", + "\n", + "md_chunked_nodes = pipeline.run()\n", "print(md_chunked_nodes)" ] } diff --git a/docs/examples/node_postprocessor/LLMReranker-Gatsby.ipynb b/docs/examples/node_postprocessor/LLMReranker-Gatsby.ipynb index 34d74da07293250e0de45b23c2c719f20488ed5b..7e9e6c61304a9e6185768d99a6e84144df0910af 100644 --- a/docs/examples/node_postprocessor/LLMReranker-Gatsby.ipynb +++ b/docs/examples/node_postprocessor/LLMReranker-Gatsby.ipynb @@ -43,7 +43,7 @@ " ServiceContext,\n", " LLMPredictor,\n", ")\n", - "from llama_index.indices.postprocessor import LLMRerank\n", + "from llama_index.postprocessor import LLMRerank\n", "from llama_index.llms import OpenAI\n", "from IPython.display import Markdown, display" ] @@ -148,7 +148,7 @@ ], "source": [ "from llama_index.retrievers import VectorIndexRetriever\n", - "from llama_index.indices.query.schema import QueryBundle\n", + "from llama_index.schema import QueryBundle\n", "import pandas as pd\n", "from IPython.display import display, HTML\n", "\n", diff --git a/docs/examples/node_postprocessor/LLMReranker-Lyft-10k.ipynb b/docs/examples/node_postprocessor/LLMReranker-Lyft-10k.ipynb index c93f255ad92092ec9eb65c970a9e7b29a6a63a20..00aec9e230a25eb0fec8cf5222a45bec03e4a3c4 100644 --- a/docs/examples/node_postprocessor/LLMReranker-Lyft-10k.ipynb +++ b/docs/examples/node_postprocessor/LLMReranker-Lyft-10k.ipynb @@ -52,7 +52,7 @@ " ServiceContext,\n", " LLMPredictor,\n", ")\n", - "from llama_index.indices.postprocessor import LLMRerank\n", + "from llama_index.postprocessor import LLMRerank\n", "\n", "from llama_index.llms import OpenAI\n", "from IPython.display import Markdown, display" @@ -171,7 +171,7 @@ ], "source": [ "from llama_index.retrievers import VectorIndexRetriever\n", - "from llama_index.indices.query.schema import QueryBundle\n", + "from llama_index.schema import QueryBundle\n", "import pandas as pd\n", "from IPython.display import display, HTML\n", "from copy import deepcopy\n", diff --git a/docs/examples/node_postprocessor/LongContextReorder.ipynb b/docs/examples/node_postprocessor/LongContextReorder.ipynb index ff1191547a19fedb9b2db28625d92476d870d195..6d76dfc59201f63033ccd59194137da4a3a13b07 100644 --- a/docs/examples/node_postprocessor/LongContextReorder.ipynb +++ b/docs/examples/node_postprocessor/LongContextReorder.ipynb @@ -133,7 +133,7 @@ "metadata": {}, "outputs": [], "source": [ - "from llama_index.indices.postprocessor import LongContextReorder\n", + "from llama_index.postprocessor import LongContextReorder\n", "\n", "reorder = LongContextReorder()\n", "\n", diff --git a/docs/examples/node_postprocessor/LongLLMLingua.ipynb b/docs/examples/node_postprocessor/LongLLMLingua.ipynb index 0f7f6076b5313b16af67fc0a22934890fa98d130..2ea9e40b10596bbeefb7382a23cbae7f959750c2 100644 --- a/docs/examples/node_postprocessor/LongLLMLingua.ipynb +++ b/docs/examples/node_postprocessor/LongLLMLingua.ipynb @@ -175,7 +175,7 @@ "source": [ "from llama_index.query_engine import RetrieverQueryEngine\n", "from llama_index.response_synthesizers import CompactAndRefine\n", - "from llama_index.indices.postprocessor import LongLLMLinguaPostprocessor\n", + "from llama_index.postprocessor import LongLLMLinguaPostprocessor\n", "\n", "node_postprocessor = LongLLMLinguaPostprocessor(\n", " instruction_str=\"Given the context, please answer the final question\",\n", @@ -224,7 +224,7 @@ "metadata": {}, "outputs": [], "source": [ - "from llama_index.indices.query.schema import QueryBundle\n", + "from llama_index.schema import QueryBundle\n", "\n", "# outline steps in RetrieverQueryEngine for clarity:\n", "# postprocess (compress), synthesize\n", diff --git a/docs/examples/node_postprocessor/MetadataReplacementDemo.ipynb b/docs/examples/node_postprocessor/MetadataReplacementDemo.ipynb index cc3fc230e48f1d3d68af3ae5de525091a5bbad7d..52f6fb1848e11e45f251ebd250b754d5e866ae9d 100644 --- a/docs/examples/node_postprocessor/MetadataReplacementDemo.ipynb +++ b/docs/examples/node_postprocessor/MetadataReplacementDemo.ipynb @@ -88,7 +88,10 @@ "from llama_index import ServiceContext, set_global_service_context\n", "from llama_index.llms import OpenAI\n", "from llama_index.embeddings import OpenAIEmbedding, HuggingFaceEmbedding\n", - "from llama_index.node_parser import SentenceWindowNodeParser, SimpleNodeParser\n", + "from llama_index.node_parser import (\n", + " SentenceWindowNodeParser,\n", + ")\n", + "from llama_index.text_splitter import SentenceSplitter\n", "\n", "# create the sentence window node parser w/ default settings\n", "node_parser = SentenceWindowNodeParser.from_defaults(\n", @@ -96,7 +99,9 @@ " window_metadata_key=\"window\",\n", " original_text_metadata_key=\"original_text\",\n", ")\n", - "simple_node_parser = SimpleNodeParser.from_defaults()\n", + "\n", + "# base node parser is a sentence splitter\n", + "text_splitter = SentenceSplitter()\n", "\n", "llm = OpenAI(model=\"gpt-3.5-turbo\", temperature=0.1)\n", "embed_model = HuggingFaceEmbedding(\n", @@ -187,7 +192,7 @@ "metadata": {}, "outputs": [], "source": [ - "base_nodes = simple_node_parser.get_nodes_from_documents(documents)" + "base_nodes = text_splitter.get_nodes_from_documents(documents)" ] }, { @@ -244,7 +249,7 @@ } ], "source": [ - "from llama_index.indices.postprocessor import MetadataReplacementPostProcessor\n", + "from llama_index.postprocessor import MetadataReplacementPostProcessor\n", "\n", "query_engine = sentence_index.as_query_engine(\n", " similarity_top_k=2,\n", diff --git a/docs/examples/node_postprocessor/OptimizerDemo.ipynb b/docs/examples/node_postprocessor/OptimizerDemo.ipynb index 74934346e004b111815d4f7e5c31ebc40151ed1d..e22f7b97f25e44a98b19744f580f66b3556d5210 100644 --- a/docs/examples/node_postprocessor/OptimizerDemo.ipynb +++ b/docs/examples/node_postprocessor/OptimizerDemo.ipynb @@ -184,7 +184,7 @@ "source": [ "import time\n", "from llama_index import VectorStoreIndex\n", - "from llama_index.indices.postprocessor import SentenceEmbeddingOptimizer\n", + "from llama_index.postprocessor import SentenceEmbeddingOptimizer\n", "\n", "print(\"Without optimization\")\n", "start_time = time.time()\n", diff --git a/docs/examples/node_postprocessor/PII.ipynb b/docs/examples/node_postprocessor/PII.ipynb index f9118ce83b003ad74d437459302113e46ab63662..658bc21245d3d8dbaad93d1d7c337f2162d133ec 100644 --- a/docs/examples/node_postprocessor/PII.ipynb +++ b/docs/examples/node_postprocessor/PII.ipynb @@ -69,7 +69,7 @@ "logging.basicConfig(stream=sys.stdout, level=logging.INFO)\n", "logging.getLogger().addHandler(logging.StreamHandler(stream=sys.stdout))\n", "\n", - "from llama_index.indices.postprocessor import (\n", + "from llama_index.postprocessor import (\n", " PIINodePostprocessor,\n", " NERPIINodePostprocessor,\n", ")\n", diff --git a/docs/examples/node_postprocessor/PrevNextPostprocessorDemo.ipynb b/docs/examples/node_postprocessor/PrevNextPostprocessorDemo.ipynb index 59d805ddb4e846a5fec4bae21f2e8bc9e1bed414..7886e137d92731f0ecfde8822e867ca8aae43a04 100644 --- a/docs/examples/node_postprocessor/PrevNextPostprocessorDemo.ipynb +++ b/docs/examples/node_postprocessor/PrevNextPostprocessorDemo.ipynb @@ -46,11 +46,11 @@ "outputs": [], "source": [ "from llama_index import VectorStoreIndex, SimpleDirectoryReader, ServiceContext\n", - "from llama_index.indices.postprocessor import (\n", + "from llama_index.postprocessor import (\n", " PrevNextNodePostprocessor,\n", " AutoPrevNextNodePostprocessor,\n", ")\n", - "from llama_index.node_parser import SimpleNodeParser\n", + "from llama_index.node_parser import SentenceSplitter\n", "from llama_index.storage.docstore import SimpleDocumentStore" ] }, diff --git a/docs/examples/node_postprocessor/RecencyPostprocessorDemo.ipynb b/docs/examples/node_postprocessor/RecencyPostprocessorDemo.ipynb index 8b13fe6df585a1104cff2f7e137ff28725de84b4..e4ab3ef526d8d7660ce81093b73d2de046eed349 100644 --- a/docs/examples/node_postprocessor/RecencyPostprocessorDemo.ipynb +++ b/docs/examples/node_postprocessor/RecencyPostprocessorDemo.ipynb @@ -28,11 +28,11 @@ ], "source": [ "from llama_index import VectorStoreIndex, SimpleDirectoryReader, ServiceContext\n", - "from llama_index.indices.postprocessor import (\n", + "from llama_index.postprocessor import (\n", " FixedRecencyPostprocessor,\n", " EmbeddingRecencyPostprocessor,\n", ")\n", - "from llama_index.node_parser import SimpleNodeParser\n", + "from llama_index.text_splitter import SentenceSplitter\n", "from llama_index.storage.docstore import SimpleDocumentStore\n", "from llama_index.response.notebook_utils import display_response" ] @@ -87,10 +87,11 @@ ").load_data()\n", "\n", "# define service context (wrapper container around current classes)\n", - "service_context = ServiceContext.from_defaults(chunk_size=512)\n", + "text_splitter = SentenceSplitter(chunk_size=512)\n", + "service_context = ServiceContext.from_defaults(text_splitter=text_splitter)\n", "\n", - "# use node parser in service context to parse into nodes\n", - "nodes = service_context.node_parser.get_nodes_from_documents(documents)\n", + "# use node parser to parse into nodes\n", + "nodes = text_splitter.get_nodes_from_documents(documents)\n", "\n", "# add to docstore\n", "docstore = SimpleDocumentStore()\n", diff --git a/docs/examples/node_postprocessor/SentenceTransformerRerank.ipynb b/docs/examples/node_postprocessor/SentenceTransformerRerank.ipynb index 9deacf36f994b7970a3fd1b453e4c4f5319804b1..cd10c28c971bde680313beea5ede65cfcf357026 100644 --- a/docs/examples/node_postprocessor/SentenceTransformerRerank.ipynb +++ b/docs/examples/node_postprocessor/SentenceTransformerRerank.ipynb @@ -119,7 +119,7 @@ "metadata": {}, "outputs": [], "source": [ - "from llama_index.indices.postprocessor import SentenceTransformerRerank\n", + "from llama_index.postprocessor import SentenceTransformerRerank\n", "\n", "rerank = SentenceTransformerRerank(\n", " model=\"cross-encoder/ms-marco-MiniLM-L-2-v2\", top_n=3\n", diff --git a/docs/examples/node_postprocessor/TimeWeightedPostprocessorDemo.ipynb b/docs/examples/node_postprocessor/TimeWeightedPostprocessorDemo.ipynb index 4cdf065559452a353e937c11e62719952e0b9b30..9a9b7cab3336f4d15622f3c83e71ad24cc515ea4 100644 --- a/docs/examples/node_postprocessor/TimeWeightedPostprocessorDemo.ipynb +++ b/docs/examples/node_postprocessor/TimeWeightedPostprocessorDemo.ipynb @@ -27,10 +27,10 @@ ], "source": [ "from llama_index import VectorStoreIndex, SimpleDirectoryReader, ServiceContext\n", - "from llama_index.indices.postprocessor import (\n", + "from llama_index.postprocessor import (\n", " TimeWeightedPostprocessor,\n", ")\n", - "from llama_index.node_parser import SimpleNodeParser\n", + "from llama_index.text_splitter import SentenceSplitter\n", "from llama_index.storage.docstore import SimpleDocumentStore\n", "from llama_index.response.notebook_utils import display_response\n", "from datetime import datetime, timedelta" @@ -83,13 +83,13 @@ "\n", "\n", "# define service context (wrapper container around current classes)\n", - "service_context = ServiceContext.from_defaults(chunk_size=512)\n", - "node_parser = service_context.node_parser\n", + "text_splitter = SentenceSplitter(chunk_size=512)\n", + "service_context = ServiceContext.from_defaults(text_splitter=text_splitter)\n", "\n", "# use node parser in service context to parse docs into nodes\n", - "nodes1 = node_parser.get_nodes_from_documents([doc1])\n", - "nodes2 = node_parser.get_nodes_from_documents([doc2])\n", - "nodes3 = node_parser.get_nodes_from_documents([doc3])\n", + "nodes1 = text_splitter.get_nodes_from_documents([doc1])\n", + "nodes2 = text_splitter.get_nodes_from_documents([doc2])\n", + "nodes3 = text_splitter.get_nodes_from_documents([doc3])\n", "\n", "\n", "# fetch the modified chunk from each document, set metadata\n", diff --git a/docs/examples/prompts/prompt_mixin.ipynb b/docs/examples/prompts/prompt_mixin.ipynb index 772fb7081905ed4b51c73af803f4c14027c75939..27f19e28ef42968600aa2663a298879b5e980d69 100644 --- a/docs/examples/prompts/prompt_mixin.ipynb +++ b/docs/examples/prompts/prompt_mixin.ipynb @@ -581,7 +581,7 @@ ")\n", "from llama_index.selectors import LLMMultiSelector\n", "from llama_index.evaluation import FaithfulnessEvaluator, DatasetGenerator\n", - "from llama_index.indices.postprocessor import LLMRerank" + "from llama_index.postprocessor import LLMRerank" ] }, { diff --git a/docs/examples/prompts/prompts_rag.ipynb b/docs/examples/prompts/prompts_rag.ipynb index c9fc50c9db39b16c0702f30c68587089b46772e7..94f9dea25ea142759dd2f1dd1a4775817e1c868f 100644 --- a/docs/examples/prompts/prompts_rag.ipynb +++ b/docs/examples/prompts/prompts_rag.ipynb @@ -828,12 +828,12 @@ "metadata": {}, "outputs": [], "source": [ - "from llama_index.indices.postprocessor import (\n", + "from llama_index.postprocessor import (\n", " NERPIINodePostprocessor,\n", " SentenceEmbeddingOptimizer,\n", ")\n", "from llama_index import ServiceContext\n", - "from llama_index.indices.query.schema import QueryBundle\n", + "from llama_index.schema import QueryBundle\n", "from llama_index.schema import NodeWithScore, TextNode" ] }, diff --git a/docs/examples/query_engine/SQLAutoVectorQueryEngine.ipynb b/docs/examples/query_engine/SQLAutoVectorQueryEngine.ipynb index d0027bdd6d43b5b28465dc58e07889d757b870dd..33c692a01dcc80a9b3bca895d5f7544f91b68be1 100644 --- a/docs/examples/query_engine/SQLAutoVectorQueryEngine.ipynb +++ b/docs/examples/query_engine/SQLAutoVectorQueryEngine.ipynb @@ -170,19 +170,17 @@ "metadata": {}, "outputs": [], "source": [ - "from llama_index.node_parser.simple import SimpleNodeParser\n", "from llama_index import ServiceContext, LLMPredictor\n", "from llama_index.storage import StorageContext\n", "from llama_index.vector_stores import PineconeVectorStore\n", - "from llama_index.text_splitter import TokenTextSplitter\n", + "from llama_index.node_parser import TokenTextSplitter\n", "from llama_index.llms import OpenAI\n", "\n", "# define node parser and LLM\n", "chunk_size = 1024\n", "llm = OpenAI(temperature=0, model=\"gpt-4\", streaming=True)\n", "service_context = ServiceContext.from_defaults(chunk_size=chunk_size, llm=llm)\n", - "text_splitter = TokenTextSplitter(chunk_size=chunk_size)\n", - "node_parser = SimpleNodeParser.from_defaults(text_splitter=text_splitter)\n", + "node_parser = TokenTextSplitter(chunk_size=chunk_size)\n", "\n", "# define pinecone vector index\n", "vector_store = PineconeVectorStore(\n", diff --git a/docs/examples/query_engine/SQLJoinQueryEngine.ipynb b/docs/examples/query_engine/SQLJoinQueryEngine.ipynb index 6cfaea78b5d07a464f5a6451d7f91d36f72409df..07d19dc598f879496022143e546d6eeb2bfd52bc 100644 --- a/docs/examples/query_engine/SQLJoinQueryEngine.ipynb +++ b/docs/examples/query_engine/SQLJoinQueryEngine.ipynb @@ -137,19 +137,17 @@ "metadata": {}, "outputs": [], "source": [ - "from llama_index.node_parser.simple import SimpleNodeParser\n", "from llama_index import ServiceContext, LLMPredictor\n", "from llama_index.storage import StorageContext\n", "from llama_index.vector_stores import PineconeVectorStore\n", - "from llama_index.text_splitter import TokenTextSplitter\n", + "from llama_index.node_parser import TokenTextSplitter\n", "from llama_index.llms import OpenAI\n", "\n", "# define node parser and LLM\n", "chunk_size = 1024\n", "llm = OpenAI(temperature=0, model=\"gpt-4\", streaming=True)\n", "service_context = ServiceContext.from_defaults(chunk_size=chunk_size, llm=llm)\n", - "text_splitter = TokenTextSplitter(chunk_size=chunk_size)\n", - "node_parser = SimpleNodeParser.from_defaults(text_splitter=text_splitter)\n", + "node_parser = TokenTextSplitter(chunk_size=chunk_size)\n", "\n", "# # define pinecone vector index\n", "# vector_store = PineconeVectorStore(pinecone_index=pinecone_index, namespace='wiki_cities')\n", diff --git a/docs/examples/query_engine/pgvector_sql_query_engine.ipynb b/docs/examples/query_engine/pgvector_sql_query_engine.ipynb index 94885f0c2378a6ac0c7c9d023a033e30eff03633..8bb261556ea1f24a38418f679248b6d3836c6d59 100644 --- a/docs/examples/query_engine/pgvector_sql_query_engine.ipynb +++ b/docs/examples/query_engine/pgvector_sql_query_engine.ipynb @@ -98,9 +98,9 @@ "metadata": {}, "outputs": [], "source": [ - "from llama_index.node_parser import SimpleNodeParser\n", + "from llama_index.node_parser import SentenceSplitter\n", "\n", - "node_parser = SimpleNodeParser.from_defaults()\n", + "node_parser = SentenceSplitter()\n", "nodes = node_parser.get_nodes_from_documents(docs)" ] }, diff --git a/docs/examples/query_engine/sec_tables/tesla_10q_table.ipynb b/docs/examples/query_engine/sec_tables/tesla_10q_table.ipynb index 96215100aea62051f9237d68dde73b517d3d06ad..92f628559b2c064fb925561d5f1435a5fc025132 100644 --- a/docs/examples/query_engine/sec_tables/tesla_10q_table.ipynb +++ b/docs/examples/query_engine/sec_tables/tesla_10q_table.ipynb @@ -314,7 +314,7 @@ "from llama_index.retrievers import RecursiveRetriever\n", "\n", "recursive_retriever = RecursiveRetriever(\n", - " \"vector\",\n", + " \"vector\",SentenceSplitter\n", " retriever_dict={\"vector\": vector_retriever},\n", " node_dict=node_mappings_2021,\n", " verbose=True,\n", diff --git a/docs/examples/retrievers/auto_merging_retriever.ipynb b/docs/examples/retrievers/auto_merging_retriever.ipynb index 4028fce785553eb7e44040c5f2f457f7c4b9a4e0..5b950f06f9d85161a01bffc4af5f759a2af5f41c 100644 --- a/docs/examples/retrievers/auto_merging_retriever.ipynb +++ b/docs/examples/retrievers/auto_merging_retriever.ipynb @@ -143,7 +143,10 @@ "metadata": {}, "outputs": [], "source": [ - "from llama_index.node_parser import HierarchicalNodeParser, SimpleNodeParser" + "from llama_index.node_parser import (\n", + " HierarchicalNodeParser,\n", + " SentenceSplitter,\n", + ")" ] }, { diff --git a/docs/examples/retrievers/bm25_retriever.ipynb b/docs/examples/retrievers/bm25_retriever.ipynb index 063752fae33341bd7abbfb8c76ecd0dea63ee092..1a80efd4c0ca2986e10117eb109b8a0f081c762d 100644 --- a/docs/examples/retrievers/bm25_retriever.ipynb +++ b/docs/examples/retrievers/bm25_retriever.ipynb @@ -596,7 +596,7 @@ } ], "source": [ - "from llama_index.indices.postprocessor import SentenceTransformerRerank\n", + "from llama_index.postprocessor import SentenceTransformerRerank\n", "\n", "reranker = SentenceTransformerRerank(top_n=4, model=\"BAAI/bge-reranker-base\")" ] diff --git a/docs/examples/retrievers/ensemble_retrieval.ipynb b/docs/examples/retrievers/ensemble_retrieval.ipynb index 1579d25407e98a71de0349fc8a579636d083b3e3..34e43b845e8e4ec458d32fe95e1ad46fbed2dc6b 100644 --- a/docs/examples/retrievers/ensemble_retrieval.ipynb +++ b/docs/examples/retrievers/ensemble_retrieval.ipynb @@ -362,7 +362,7 @@ "outputs": [], "source": [ "# define reranker\n", - "from llama_index.indices.postprocessor import (\n", + "from llama_index.postprocessor import (\n", " LLMRerank,\n", " SentenceTransformerRerank,\n", " CohereRerank,\n", diff --git a/docs/examples/retrievers/recurisve_retriever_nodes_braintrust.ipynb b/docs/examples/retrievers/recurisve_retriever_nodes_braintrust.ipynb index a9d946d332e9dd9c7b5759c282984abf7673b249..7673559e904ef64a907b3ab3fc4a57e09660b536 100644 --- a/docs/examples/retrievers/recurisve_retriever_nodes_braintrust.ipynb +++ b/docs/examples/retrievers/recurisve_retriever_nodes_braintrust.ipynb @@ -128,7 +128,7 @@ "metadata": {}, "outputs": [], "source": [ - "from llama_index.node_parser import SimpleNodeParser\n", + "from llama_index.node_parser import SentenceSplitter\n", "from llama_index.schema import IndexNode" ] }, @@ -139,7 +139,7 @@ "metadata": {}, "outputs": [], "source": [ - "node_parser = SimpleNodeParser.from_defaults(chunk_size=1024)" + "node_parser = SentenceSplitter(chunk_size=1024)" ] }, { @@ -260,9 +260,7 @@ "outputs": [], "source": [ "sub_chunk_sizes = [128, 256, 512]\n", - "sub_node_parsers = [\n", - " SimpleNodeParser.from_defaults(chunk_size=c) for c in sub_chunk_sizes\n", - "]\n", + "sub_node_parsers = [SentenceSplitter(chunk_size=c) for c in sub_chunk_sizes]\n", "\n", "all_nodes = []\n", "\n", @@ -386,12 +384,11 @@ "metadata": {}, "outputs": [], "source": [ - "from llama_index.node_parser import SimpleNodeParser\n", + "from llama_index.node_parser import SentenceSplitter\n", "from llama_index.schema import IndexNode\n", - "from llama_index.node_parser.extractors import (\n", + "from llama_index.extractors import (\n", " SummaryExtractor,\n", " QuestionsAnsweredExtractor,\n", - " MetadataExtractor,\n", ")" ] }, @@ -402,12 +399,10 @@ "metadata": {}, "outputs": [], "source": [ - "metadata_extractor = MetadataExtractor(\n", - " extractors=[\n", - " SummaryExtractor(summaries=[\"self\"], show_progress=True),\n", - " QuestionsAnsweredExtractor(questions=5, show_progress=True),\n", - " ],\n", - ")" + "extractors = [\n", + " SummaryExtractor(summaries=[\"self\"], show_progress=True),\n", + " QuestionsAnsweredExtractor(questions=5, show_progress=True),\n", + "]" ] }, { @@ -418,7 +413,9 @@ "outputs": [], "source": [ "# run metadata extractor across base nodes, get back dictionaries\n", - "metadata_dicts = metadata_extractor.extract(base_nodes)" + "metadata_dicts = []\n", + "for extractor in extractors:\n", + " metadata_dicts.extend(extractor.extract(base_nodes))" ] }, { diff --git a/docs/examples/retrievers/recursive_retriever_nodes.ipynb b/docs/examples/retrievers/recursive_retriever_nodes.ipynb index e81becd7ba2143be3843d9299b0110c130c49bc8..b1419746ec471483a91bb1744712b2a75ec08c2e 100644 --- a/docs/examples/retrievers/recursive_retriever_nodes.ipynb +++ b/docs/examples/retrievers/recursive_retriever_nodes.ipynb @@ -133,7 +133,7 @@ "metadata": {}, "outputs": [], "source": [ - "from llama_index.node_parser import SimpleNodeParser\n", + "from llama_index.node_parser import SentenceSplitter\n", "from llama_index.schema import IndexNode" ] }, @@ -144,7 +144,7 @@ "metadata": {}, "outputs": [], "source": [ - "node_parser = SimpleNodeParser.from_defaults(chunk_size=1024)" + "node_parser = SentenceSplitter(chunk_size=1024)" ] }, { @@ -265,9 +265,7 @@ "outputs": [], "source": [ "sub_chunk_sizes = [128, 256, 512]\n", - "sub_node_parsers = [\n", - " SimpleNodeParser.from_defaults(chunk_size=c) for c in sub_chunk_sizes\n", - "]\n", + "sub_node_parsers = [SentenceSplitter(chunk_size=c) for c in sub_chunk_sizes]\n", "\n", "all_nodes = []\n", "for base_node in base_nodes:\n", @@ -390,12 +388,11 @@ "metadata": {}, "outputs": [], "source": [ - "from llama_index.node_parser import SimpleNodeParser\n", + "from llama_index.node_parser import SentenceSplitter\n", "from llama_index.schema import IndexNode\n", - "from llama_index.node_parser.extractors import (\n", + "from llama_index.extractors import (\n", " SummaryExtractor,\n", " QuestionsAnsweredExtractor,\n", - " MetadataExtractor,\n", ")" ] }, @@ -406,12 +403,10 @@ "metadata": {}, "outputs": [], "source": [ - "metadata_extractor = MetadataExtractor(\n", - " extractors=[\n", - " SummaryExtractor(summaries=[\"self\"], show_progress=True),\n", - " QuestionsAnsweredExtractor(questions=5, show_progress=True),\n", - " ],\n", - ")" + "extractors = [\n", + " SummaryExtractor(summaries=[\"self\"], show_progress=True),\n", + " QuestionsAnsweredExtractor(questions=5, show_progress=True),\n", + "]" ] }, { @@ -422,7 +417,9 @@ "outputs": [], "source": [ "# run metadata extractor across base nodes, get back dictionaries\n", - "metadata_dicts = metadata_extractor.extract(base_nodes)" + "metadata_dicts = []\n", + "for extractor in extractors:\n", + " metadata_dicts.extend(extractor.extract(base_nodes))" ] }, { diff --git a/docs/examples/vector_stores/PineconeIndexDemo-0.6.0.ipynb b/docs/examples/vector_stores/PineconeIndexDemo-0.6.0.ipynb index bf43128bca84d3c4a70849a596ca543d493b6fec..85e9ba59dc77c4f74ef4d47dfab01631bcfef05d 100644 --- a/docs/examples/vector_stores/PineconeIndexDemo-0.6.0.ipynb +++ b/docs/examples/vector_stores/PineconeIndexDemo-0.6.0.ipynb @@ -600,7 +600,7 @@ "metadata": {}, "outputs": [], "source": [ - "from llama_index.indices.postprocessor.node import (\n", + "from llama_index.postprocessor.node import (\n", " AutoPrevNextNodePostprocessor,\n", ")\n", "\n", diff --git a/docs/examples/vector_stores/SimpleIndexDemo.ipynb b/docs/examples/vector_stores/SimpleIndexDemo.ipynb index d80dc983adeda7d6a499a3e9150458c098044075..3f2b5a921a1af4521d17264b5398c0a668241256 100644 --- a/docs/examples/vector_stores/SimpleIndexDemo.ipynb +++ b/docs/examples/vector_stores/SimpleIndexDemo.ipynb @@ -372,7 +372,7 @@ "metadata": {}, "outputs": [], "source": [ - "from llama_index.indices.query.schema import QueryBundle" + "from llama_index.schema import QueryBundle" ] }, { diff --git a/docs/examples/vector_stores/TypesenseDemo.ipynb b/docs/examples/vector_stores/TypesenseDemo.ipynb index 3261118570184f50c111d3b99c4c51bcc88fdb0d..4852dbe371b1c929140af54439355468e5dcbfc3 100644 --- a/docs/examples/vector_stores/TypesenseDemo.ipynb +++ b/docs/examples/vector_stores/TypesenseDemo.ipynb @@ -133,7 +133,7 @@ } ], "source": [ - "from llama_index.indices.query.schema import QueryBundle\n", + "from llama_index.schema import QueryBundle\n", "from llama_index.embeddings import OpenAIEmbedding\n", "\n", "# By default, typesense vector store uses vector search. You need to provide the embedding yourself.\n", diff --git a/docs/module_guides/indexing/metadata_extraction.md b/docs/module_guides/indexing/metadata_extraction.md index 7c6a99c7b7f1eb51df0fe186ac37dabbea74e878..cba5338bb5497d3612e0da9ea24b3d0cfe3c807c 100644 --- a/docs/module_guides/indexing/metadata_extraction.md +++ b/docs/module_guides/indexing/metadata_extraction.md @@ -15,9 +15,8 @@ First, we define a metadata extractor that takes in a list of feature extractors We then feed this to the node parser, which will add the additional metadata to each node. ```python -from llama_index.node_parser import SimpleNodeParser -from llama_index.node_parser.extractors import ( - MetadataExtractor, +from llama_index.node_parser import SentenceSplitter +from llama_index.extractors import ( SummaryExtractor, QuestionsAnsweredExtractor, TitleExtractor, @@ -25,19 +24,24 @@ from llama_index.node_parser.extractors import ( EntityExtractor, ) -metadata_extractor = MetadataExtractor( - extractors=[ - TitleExtractor(nodes=5), - QuestionsAnsweredExtractor(questions=3), - SummaryExtractor(summaries=["prev", "self"]), - KeywordExtractor(keywords=10), - EntityExtractor(prediction_threshold=0.5), - ], -) +transformations = [ + SentenceSplitter(), + TitleExtractor(nodes=5), + QuestionsAnsweredExtractor(questions=3), + SummaryExtractor(summaries=["prev", "self"]), + KeywordExtractor(keywords=10), + EntityExtractor(prediction_threshold=0.5), +] +``` -node_parser = SimpleNodeParser.from_defaults( - metadata_extractor=metadata_extractor, -) +Then, we can run our transformations on input documents or nodes: + +```python +from llama_index.ingestion import IngestionPipline + +pipeline = IngestionPipline(transformations=transformations) + +nodes = pipeline.run(documents=documents) ``` Here is an sample of extracted metadata: @@ -57,10 +61,10 @@ Here is an sample of extracted metadata: If the provided extractors do not fit your needs, you can also define a custom extractor like so: ```python -from llama_index.node_parser.extractors import MetadataFeatureExtractor +from llama_index.extractors import BaseExtractor -class CustomExtractor(MetadataFeatureExtractor): +class CustomExtractor(BaseExtractor): def extract(self, nodes) -> List[Dict]: metadata_list = [ { @@ -74,3 +78,18 @@ class CustomExtractor(MetadataFeatureExtractor): ``` In a more advanced example, it can also make use of an `llm` to extract features from the node content and the existing metadata. Refer to the [source code of the provided metadata extractors](https://github.com/jerryjliu/llama_index/blob/main/llama_index/node_parser/extractors/metadata_extractors.py) for more details. + +## Modules + +Below you will find guides and tutorials for various metadata extractors. + +```{toctree} +--- +maxdepth: 1 +--- +/examples/metadata_extraction/MetadataExtractionSEC.ipynb +/examples/metadata_extraction/MetadataExtraction_LLMSurvey.ipynb +/examples/metadata_extraction/EntityExtractionClimate.ipynb +/examples/metadata_extraction/MarvinMetadataExtractorDemo.ipynb +/examples/metadata_extraction/PydanticExtractor.ipynb +``` diff --git a/docs/module_guides/indexing/usage_pattern.md b/docs/module_guides/indexing/usage_pattern.md index 570b7b49831a420df0e2637e78f89092dad4ebc9..450057f8a5c32553ac5a3d5c4367e9d8aa01ed97 100644 --- a/docs/module_guides/indexing/usage_pattern.md +++ b/docs/module_guides/indexing/usage_pattern.md @@ -52,9 +52,9 @@ The steps are: 1. Configure a node parser ```python -from llama_index.node_parser import SimpleNodeParser +from llama_index.node_parser import SentenceSplitter -parser = SimpleNodeParser.from_defaults( +parser = SentenceSplitter( chunk_size=512, include_extra_info=False, include_prev_next_rel=False, diff --git a/docs/module_guides/loading/documents_and_nodes/root.md b/docs/module_guides/loading/documents_and_nodes/root.md index 9f69c1c4a27deece9623f7de448b1309474d1638..011af021cb2aef07637be2789f100f3c05d3d7c3 100644 --- a/docs/module_guides/loading/documents_and_nodes/root.md +++ b/docs/module_guides/loading/documents_and_nodes/root.md @@ -34,15 +34,28 @@ index = VectorStoreIndex.from_documents(documents) #### Nodes ```python -from llama_index.node_parser import SimpleNodeParser +from llama_index.node_parser import SentenceSplitter # load documents ... # parse nodes -parser = SimpleNodeParser.from_defaults() +parser = SentenceSplitter() nodes = parser.get_nodes_from_documents(documents) # build index index = VectorStoreIndex(nodes) ``` + +### Document/Node Usage + +Take a look at our in-depth guides for more details on how to use Documents/Nodes. + +```{toctree} +--- +maxdepth: 1 +--- +usage_documents.md +usage_nodes.md +/core_modules/data_modules/transformations/root.md +``` diff --git a/docs/module_guides/loading/documents_and_nodes/usage_documents.md b/docs/module_guides/loading/documents_and_nodes/usage_documents.md index 28ac753489c923b7e7ec59e9866413569b3dad83..2e2f989e304814ea66fb64c17837d65ce021d3c5 100644 --- a/docs/module_guides/loading/documents_and_nodes/usage_documents.md +++ b/docs/module_guides/loading/documents_and_nodes/usage_documents.md @@ -181,5 +181,5 @@ Take a look here! --- maxdepth: 1 --- -/examples/metadata_extraction/MetadataExtractionSEC.ipynb +/core_modules/data_modules/transformations/metadata_extractor_usage_pattern.md ``` diff --git a/docs/module_guides/loading/documents_and_nodes/usage_metadata_extractor.md b/docs/module_guides/loading/documents_and_nodes/usage_metadata_extractor.md index 17372550df69432ad36d3e70f9949abab59458f2..907d7e388df4daa847bf8906d7599d91d9fb6286 100644 --- a/docs/module_guides/loading/documents_and_nodes/usage_metadata_extractor.md +++ b/docs/module_guides/loading/documents_and_nodes/usage_metadata_extractor.md @@ -1,6 +1,6 @@ -# Automated Metadata Extraction for Nodes +# Metadata Extraction Usage Pattern -You can use LLMs to automate metadata extraction with our `MetadataExtractor` modules. +You can use LLMs to automate metadata extraction with our `Metadata Extractor` modules. Our metadata extractor modules include the following "feature extractors": @@ -9,11 +9,10 @@ Our metadata extractor modules include the following "feature extractors": - `TitleExtractor` - extracts a title over the context of each Node - `EntityExtractor` - extracts entities (i.e. names of places, people, things) mentioned in the content of each Node -You can use these feature extractors within our overall `MetadataExtractor` class. Then you can plug in the `MetadataExtractor` into our node parser: +Then you can chain the `Metadata Extractor`s with our node parser: ```python -from llama_index.node_parser.extractors import ( - MetadataExtractor, +from llama_index.extractors import ( TitleExtractor, QuestionsAnsweredExtractor, ) @@ -22,19 +21,31 @@ from llama_index.text_splitter import TokenTextSplitter text_splitter = TokenTextSplitter( separator=" ", chunk_size=512, chunk_overlap=128 ) -metadata_extractor = MetadataExtractor( - extractors=[ - TitleExtractor(nodes=5), - QuestionsAnsweredExtractor(questions=3), - ], +title_extractor = TitleExtractor(nodes=5) +qa_extractor = QuestionsAnsweredExtractor(questions=3) + +# assume documents are defined -> extract nodes +from llama_index.ingestion import IngestionPipeline + +pipeline = IngestionPipeline( + transformations=[text_splitter, title_extractor, qa_extractor] ) -node_parser = SimpleNodeParser.from_defaults( - text_splitter=text_splitter, - metadata_extractor=metadata_extractor, +nodes = pipeline.run( + documents=documents, + in_place=True, + show_progress=True, +) +``` + +or insert into the service context: + +```python +from llama_index import ServiceContext + +service_context = ServiceContext.from_defaults( + transformations=[text_splitter, title_extractor, qa_extractor] ) -# assume documents are defined -> extract nodes -nodes = node_parser.get_nodes_from_documents(documents) ``` ```{toctree} diff --git a/docs/module_guides/loading/documents_and_nodes/usage_nodes.md b/docs/module_guides/loading/documents_and_nodes/usage_nodes.md index 643267b291bc87276e985d84ae51261342730573..51c04f03d98fb183d59f3d3a096cb6363aac3fcc 100644 --- a/docs/module_guides/loading/documents_and_nodes/usage_nodes.md +++ b/docs/module_guides/loading/documents_and_nodes/usage_nodes.md @@ -8,9 +8,9 @@ Nodes are a first-class citizen in LlamaIndex. You can choose to define Nodes an For instance, you can do ```python -from llama_index.node_parser import SimpleNodeParser +from llama_index.node_parser import SentenceSplitter -parser = SimpleNodeParser.from_defaults() +parser = SentenceSplitter() nodes = parser.get_nodes_from_documents(documents) ``` diff --git a/docs/module_guides/loading/ingestion_pipeline/root.md b/docs/module_guides/loading/ingestion_pipeline/root.md new file mode 100644 index 0000000000000000000000000000000000000000..2b94f573cce6a7ede04c6eb996577d6a0d1d3b51 --- /dev/null +++ b/docs/module_guides/loading/ingestion_pipeline/root.md @@ -0,0 +1,153 @@ +# Ingestion Pipeline + +An `IngestionPipeline` uses a concept of `Transformations` that are applied to input data. These `Transformations` are applied to your input data, and the resulting nodes are either returned or inserted into a vector database (if given). Each node+transformation pair is cached, so that subsequent runs (if the cache is persisted) with the same node+transformation combination can use the cached result and save you time. + +## Usage Pattern + +At it's most basic level, you can quickly instantiate an `IngestionPipeline` like so: + +```python +from llama_index import Document +from llama_index.embeddings import OpenAIEmbedding +from llama_index.text_splitter import SentenceSplitter +from llama_index.extractors import TitleExtractor +from llama_index.ingestion import IngestionPipeline, IngestionCache + +# create the pipeline with transformations +pipeline = IngestionPipeline( + transformations=[ + SentenceSplitter(chunk_size=25, chunk_overlap=0), + TitleExtractor(), + OpenAIEmbedding(), + ] +) + +# run the pipeline +nodes = pipeline.run(documents=[Document.example()]) +``` + +## Connecting to Vector Databases + +When running an ingestion pipeline, you can also chose to automatically insert the resulting nodes into a remote vector store. + +Then, you can construct an index from that vector store later on. + +```python +from llama_index import Document +from llama_index.embeddings import OpenAIEmbedding +from llama_index.text_splitter import SentenceSplitter +from llama_index.extractors import TitleExtractor +from llama_index.ingestion import IngestionPipeline +from llama_index.vector_stores.qdrant import QdrantVectorStore + +import qdrant_client + +client = qdrant_client.QdrantClient(location=":memory:") +vector_store = QdrantVectorStore(client=client, collection_name="test_store") + +pipeline = IngestionPipeline( + transformations=[ + SentenceSplitter(chunk_size=25, chunk_overlap=0), + TitleExtractor(), + OpenAIEmbedding(), + ], + vector_store=vector_store, +) + +# Ingest directly into a vector db +pipeline.run(documents=[Document.example()]) + +# Create your index +from llama_index import VectorStoreIndex + +index = VectorStoreIndex.from_vector_store(vector_store) +``` + +## Caching + +In an `IngestionPipeline`, each node + transformation combination is hashed and cached. This saves time on subsequent runs that use the same data. + +The following sections describe some basic usage around caching. + +### Local Cache Management + +Once you have a pipeline, you may want to store and load the cache. + +```python +# save and load +pipeline.cache.persist("./test_cache.json") +new_cache = IngestionCache.from_persist_path("./test_cache.json") + +new_pipeline = IngestionPipeline( + transformations=[ + SentenceSplitter(chunk_size=25, chunk_overlap=0), + TitleExtractor(), + ], + cache=new_cache, +) + +# will run instantly due to the cache +nodes = pipeline.run(documents=[Document.example()]) +``` + +If the cache becomes too large, you can clear it + +```python +# delete all context of the cache +cache.clear() +``` + +### Remote Cache Management + +We support multiple remote storage backends for caches + +- `RedisCache` +- `MongoDBCache` +- `FirestoreCache` + +Here as an example using the `RedisCache`: + +```python +from llama_index import Document +from llama_index.embeddings import OpenAIEmbedding +from llama_index.text_splitter import SentenceSplitter +from llama_index.extractors import TitleExtractor +from llama_index.ingestion import IngestionPipeline, IngestionCache +from llama_index.ingestion.cache import RedisCache + + +pipeline = IngestionPipeline( + transformations=[ + SentenceSplitter(chunk_size=25, chunk_overlap=0), + TitleExtractor(), + OpenAIEmbedding(), + ], + cache=IngestionCache( + cache=RedisCache( + redis_uri="redis://127.0.0.1:6379", collection="test_cache" + ) + ), +) + +# Ingest directly into a vector db +nodes = pipeline.run(documents=[Document.example()]) +``` + +Here, no persist step is needed, since everything is cached as you go in the specified remote collection. + +## Async Support + +The `IngestionPipeline` also has support for async operation + +```python +nodes = await pipeline.arun(documents=documents) +``` + +## Modules + +```{toctree} +--- +maxdepth: 2 +--- +transformations.md +``` diff --git a/docs/module_guides/loading/ingestion_pipeline/transformations.md b/docs/module_guides/loading/ingestion_pipeline/transformations.md new file mode 100644 index 0000000000000000000000000000000000000000..7591a693182b36147de688274c1aa93900eec6b2 --- /dev/null +++ b/docs/module_guides/loading/ingestion_pipeline/transformations.md @@ -0,0 +1,93 @@ +# Transformations + +A transformation is something that takes a list of nodes as an input, and returns a list of nodes. Each component that implements the `Transformation` base class has both a synchronous `__call__()` definition and an async `acall()` definition. + +Current;y, the following components are `Transformation` objects: + +- `TextSplitter` +- `NodeParser` +- `MetadataExtractor` +- `Embeddings`model + +## Usage Pattern + +While transformations are best used with with an [`IngestionPipeline`](./root.md), they can also be used directly. + +```python +from llama_index.text_splitter import SentenceSplitter +from llama_index.extractors import TitleExtractor + +node_parser = SentenceSplitter(chunk_size=512) +extractor = TitleExtractor() + +# use transforms directly +nodes = node_parser(documents) + +# or use a transformation in async +nodes = await extractor.acall(nodes) +``` + +## Combining with ServiceContext + +Transformations can be passed into a service context, and will be used when calling `from_documents()` or `insert()` on an index. + +```python +from llama_index import ServiceContext, VectorStoreIndex +from llama_index.extractors import ( + TitleExtractor, + QuestionsAnsweredExtractor, +) +from llama_index.ingestion import IngestionPipeline +from llama_index.text_splitter import TokenTextSplitter + +transformations = [ + TokenTextSplitter(chunk_size=512, chunk_overlap=128), + TitleExtractor(nodes=5), + QuestionsAnsweredExtractor(questions=3), +] + +service_context = ServiceContext.from_defaults( + transformations=[text_splitter, title_extractor, qa_extractor] +) + +index = VectorStoreIndex.from_documents( + documents, service_context=service_context +) +``` + +## Custom Transformations + +You can implement any transformation yourself by implementing the base class. + +The following custom transformation will remove any special characters or punctutaion in text. + +```python +import re +from llama_index import Document +from llama_index.embeddings import OpenAIEmbedding +from llama_index.text_splitter import SentenceSplitter +from llama_index.ingestion import IngestionPipeline +from llama_index.schema import TransformComponent + + +class TextCleaner(TransformComponent): + def __call__(self, nodes, **kwargs): + for node in nodes: + node.text = re.sub(r"[^0-9A-Za-z ]", "", node.text) + return nodes +``` + +These can then be used directly or in any `IngestionPipeline`. + +```python +# use in a pipeline +pipeline = IngestionPipeline( + transformations=[ + SentenceSplitter(chunk_size=25, chunk_overlap=0), + TextCleaner(), + OpenAIEmbedding(), + ], +) + +nodes = pipeline.run(documents=[Document.example()]) +``` diff --git a/docs/module_guides/loading/node_parsers/modules.md b/docs/module_guides/loading/node_parsers/modules.md new file mode 100644 index 0000000000000000000000000000000000000000..f078bf1b52520b74fe945cf650efc57769811827 --- /dev/null +++ b/docs/module_guides/loading/node_parsers/modules.md @@ -0,0 +1,161 @@ +# Node Parser Modules + +## File-Based Node Parsers + +There are several file-based node parsers, that will create nodes based on the type of content that is being parsed (JSON, Markdown, etc.) + +The simplest flow is to combine the `FlatFileReader` with the `SimpleFileNodeParser` to automatically use the best node parser for each type of content. Then, you may want to chain the file-based node parser with a text-based node parser to account for the actual length of the text. + +### SimpleFileNodeParser + +```python +from llama_index.node_parser.file import SimpleFileNodeParser +from llama_index.readers.file.flat_reader import FlatReader + +md_docs = FlatReader().load_data("./test.md") + +parser = SimpleFileNodeParser() +md_nodes = parser.get_nodes_from_documents(md_docs) +``` + +### HTMLNodeParser + +This node parser uses `beautifulsoup` to parse raw HTML. + +By default, it will parse a select subset of HTML tags, but you can override this. + +The default tags are: `["p", "h1", "h2", "h3", "h4", "h5", "h6", "li", "b", "i", "u", "section"]` + +```python +from llama_index.node_parser import HTMLNodeParser + +parser = HTMLNodeParser(tags=["p", "h1"]) # optional list of tags +nodes = parser.get_nodes_from_documents(html_docs) +``` + +### JSONNodeParser + +The `JSONNodeParser` parses raw JSON. + +```python +from llama_index import JSONNodeParser + +parser = JSONNodeParser() + +nodes = parser.get_nodes_from_documents(json_docs) +``` + +### MarkdownNodeParser + +The `MarkdownNodeParser` parses raw markdown text. + +```python +from llama_index import MarkdownNodeParser + +parser = MarkdownNodeParser() + +nodes = parser.get_nodes_from_documents(markdown_docs) +``` + +## Text-Based Node Parsers + +### CodeSplitter + +Splits raw code-text based on the language it is written in. + +Check the full list of [supported languages here](https://github.com/grantjenks/py-tree-sitter-languages#license). + +```python +from llama_index.node_parser import CodeSplitter + +splitter = CodeSplitter( + language="python", + chunk_lines=40, # lines per chunk + chunk_lines_overlap=15, # lines overlap between chunks + max_chars=1500, # max chars per chunk +) +nodes = splitter.get_nodes_from_documents(documents) +``` + +### LangchainNodeParser + +You can also wrap any existing text splitter from langchain with a node parser. + +```python +from langchain.text_splitter import RecursiveCharacterTextSplitter +from llama_index.node_parser import LangchainNodeParser + +parser = LangchainNodeParser(RecursiveCharacterTextSplitter()) +nodes = parser.get_nodes_from_documents(documents) +``` + +### SentenceSplitter + +The `SentenceSplitter` attempts to split text while respecting the boundaries of sentences. + +```python +from llama_index.node_parser import SentenceSplitter + +splitter = SentenceSplitter( + chunk_size=1024, + chunk_overlap=20, +) +nodes = splitter.get_nodes_from_documents(documents) +``` + +### SentenceWindowNodeParser + +The `SentenceWindowNodeParser` is similar to other node parsers, except that it splits all documents into individual sentences. The resulting nodes also contain the surrounding "window" of sentences around each node in the metadata. Note that this metadata will not be visible to the LLM or embedding model. + +This is most useful for generating embeddings that have a very specific scope. Then, combined with a `MetadataReplacementNodePostProcessor`, you can replace the sentence with it's surrounding context before sending the node to the LLM. + +An example of setting up the parser with default settings is below. In practice, you would usually only want to adjust the window size of sentences. + +```python +import nltk +from llama_index.node_parser import SentenceWindowNodeParser + +node_parser = SentenceWindowNodeParser.from_defaults( + # how many sentences on either side to capture + window_size=3, + # the metadata key that holds the window of surrounding sentences + window_metadata_key="window", + # the metadata key that holds the original sentence + original_text_metadata_key="original_sentence", +) +``` + +A full example can be found [here in combination with the `MetadataReplacementNodePostProcessor`](/examples/node_postprocessor/MetadataReplacementDemo.ipynb). + +### TokenTextSplitter + +The `TokenTextSplitter` attempts to split text while respecting the boundaries of sentences. + +```python +from llama_index.node_parser import TokenTextSplitter + +splitter = TokenTextSplitter( + chunk_size=1024, + chunk_overlap=20, + separator=" ", +) +nodes = splitter.get_nodes_from_documents(documents) +``` + +## Relation-Based Node Parsers + +### HierarchicalNodeParser + +This node parser will chunk nodes into hierarchical nodes. This means a single input will be chunked into several hierarchies of chunk sizes, with each node containing a reference to it's parent node. + +When combined with the `AutoMergingRetriever`, this enables us to automatically replace retrieved nodes with their parents when a majority of children are retrieved. This process provides the LLM with more complete context for response synthesis. + +```python +from llama_index.node_parser import HierarchicalNodeParser + +node_parser = HierarchicalNodeParser.from_defaults( + chunk_sizes=[2048, 512, 128] +) +``` + +A full example can be found [here in combination with the `AutoMergingRetriever`](/examples/retrievers/auto_merging_retriever.ipynb). diff --git a/docs/module_guides/loading/node_parsers/root.md b/docs/module_guides/loading/node_parsers/root.md index 946db9a2b498c99d0909eb9eb3bded6e546f2edb..d6775dcf645b047284c7ddd0d0f258520bdd4946 100644 --- a/docs/module_guides/loading/node_parsers/root.md +++ b/docs/module_guides/loading/node_parsers/root.md @@ -1,143 +1,63 @@ -# Node Parser +# Node Parser Usage Pattern -## Concept - -Node parsers are a simple abstraction that take a list of documents, and chunk them into `Node` objects, such that each node is a specific size. When a document is broken into nodes, all of it's attributes are inherited to the children nodes (i.e. `metadata`, text and metadata templates, etc.). You can read more about [`Node` and `Document` properties](/module_guides/loading/documents_and_nodes/root.md). - -A node parser can configure the chunk size (in tokens) as well as any overlap between chunked nodes. The chunking is done by using a `TokenTextSplitter`, which default to a chunk size of 1024 and a default chunk overlap of 20 tokens. - -## Usage Pattern - -```python -from llama_index.node_parser import SimpleNodeParser - -node_parser = SimpleNodeParser.from_defaults(chunk_size=1024, chunk_overlap=20) -``` - -You can find more usage details and available customization options below. +Node parsers are a simple abstraction that take a list of documents, and chunk them into `Node` objects, such that each node is a specific chunk of the parent document. When a document is broken into nodes, all of it's attributes are inherited to the children nodes (i.e. `metadata`, text and metadata templates, etc.). You can read more about `Node` and `Document` properties [here](/core_modules/data_modules/documents_and_nodes/root.md). ## Getting Started +### Standalone Usage + Node parsers can be used on their own: ```python from llama_index import Document -from llama_index.node_parser import SimpleNodeParser +from llama_index.node_parser import SentenceSplitter -node_parser = SimpleNodeParser.from_defaults(chunk_size=1024, chunk_overlap=20) +node_parser = SentenceSplitter(chunk_size=1024, chunk_overlap=20) nodes = node_parser.get_nodes_from_documents( [Document(text="long text")], show_progress=False ) ``` -Or set inside a `ServiceContext` to be used automatically when an index is constructed using `.from_documents()`: +### Transformation Usage + +Node parsers can be included in any set of transformations with an ingestion pipeline. ```python -from llama_index import SimpleDirectoryReader, VectorStoreIndex, ServiceContext -from llama_index.node_parser import SimpleNodeParser +from llama_index import SimpleDirectoryReader +from llama_index.ingestion import IngestionPipeline +from llama_index.node_parser import TokenTextSplitter documents = SimpleDirectoryReader("./data").load_data() -node_parser = SimpleNodeParser.from_defaults(chunk_size=1024, chunk_overlap=20) -service_context = ServiceContext.from_defaults(node_parser=node_parser) +pipeline = IngestionPipeline(transformations=[TokenTextSplitter(), ...]) -index = VectorStoreIndex.from_documents( - documents, service_context=service_context -) -``` - -## Customization - -There are several options available to customize: - -- `text_splitter` (defaults to `TokenTextSplitter`) - the text splitter used to split text into chunks. -- `include_metadata` (defaults to `True`) - whether or not `Node`s should inherit the document metadata. -- `include_prev_next_rel` (defaults to `True`) - whether or not to include previous/next relationships between chunked `Node`s -- `metadata_extractor` (defaults to `None`) - extra processing to extract helpful metadata. See [more about our metadata extractor](/module_guides/loading/documents_and_nodes/usage_metadata_extractor.md). - -If you don't want to change the `text_splitter`, you can use `SimpleNodeParser.from_defaults()` to easily change the chunk size and chunk overlap. The defaults are 1024 and 20 respectively. - -```python -from llama_index.node_parser import SimpleNodeParser - -node_parser = SimpleNodeParser.from_defaults(chunk_size=1024, chunk_overlap=20) +nodes = pipeline.run(documents=documents) ``` -### Text Splitter Customization - -If you do customize the `text_splitter` from the default `SentenceSplitter`, you can use any splitter from langchain, or optionally our `TokenTextSplitter` or `CodeSplitter`. Each text splitter has options for the default separator, as well as options for additional config. These are useful for languages that are sufficiently different from English. +### Service Context Usage -`SentenceSplitter` default configuration: +Or set inside a `ServiceContext` to be used automatically when an index is constructed using `.from_documents()`: ```python -import tiktoken +from llama_index import SimpleDirectoryReader, VectorStoreIndex, ServiceContext from llama_index.text_splitter import SentenceSplitter -text_splitter = SentenceSplitter( - separator=" ", - chunk_size=1024, - chunk_overlap=20, - paragraph_separator="\n\n\n", - secondary_chunking_regex="[^,.;。]+[,.;。]?", - tokenizer=tiktoken.encoding_for_model("gpt-3.5-turbo").encode, -) - -node_parser = SimpleNodeParser.from_defaults(text_splitter=text_splitter) -``` - -`TokenTextSplitter` default configuration: - -```python -import tiktoken -from llama_index.text_splitter import TokenTextSplitter - -text_splitter = TokenTextSplitter( - separator=" ", - chunk_size=1024, - chunk_overlap=20, - backup_separators=["\n"], - tokenizer=tiktoken.encoding_for_model("gpt-3.5-turbo").encode, -) - -node_parser = SimpleNodeParser.from_defaults(text_splitter=text_splitter) -``` - -`CodeSplitter` configuration: +documents = SimpleDirectoryReader("./data").load_data() -```python -from llama_index.text_splitter import CodeSplitter +text_splitter = SentenceSplitter(chunk_size=1024, chunk_overlap=20) +service_context = ServiceContext.from_defaults(text_splitter=text_splitter) -text_splitter = CodeSplitter( - language="python", - chunk_lines=40, - chunk_lines_overlap=15, - max_chars=1500, +index = VectorStoreIndex.from_documents( + documents, service_context=service_context ) - -node_parser = SimpleNodeParser.from_defaults(text_splitter=text_splitter) ``` -## SentenceWindowNodeParser - -The `SentenceWindowNodeParser` is similar to the `SimpleNodeParser`, except that it splits all documents into individual sentences. The resulting nodes also contain the surrounding "window" of sentences around each node in the metadata. Note that this metadata will not be visible to the LLM or embedding model. +## Modules -This is most useful for generating embeddings that have a very specific scope. Then, combined with a `MetadataReplacementNodePostProcessor`, you can replace the sentence with it's surrounding context before sending the node to the LLM. - -An example of setting up the parser with default settings is below. In practice, you would usually only want to adjust the window size of sentences. - -```python -import nltk -from llama_index.node_parser import SentenceWindowNodeParser - -node_parser = SentenceWindowNodeParser.from_defaults( - # how many sentences on either side to capture - window_size=3, - # the metadata key that holds the window of surrounding sentences - window_metadata_key="window", - # the metadata key that holds the original sentence - original_text_metadata_key="original_sentence", -) +```{toctree} +--- +maxdepth: 2 +--- +modules.md ``` - -A full example can be found [here in combination with the `MetadataReplacementNodePostProcessor`](/examples/node_postprocessor/MetadataReplacementDemo.ipynb). diff --git a/docs/module_guides/models/llms.md b/docs/module_guides/models/llms.md index 7114f092aae38bc5b2c3a35fd4c7f6c28cb0350b..6782483fe0455668b62410c66d68d2a2657c60f3 100644 --- a/docs/module_guides/models/llms.md +++ b/docs/module_guides/models/llms.md @@ -33,6 +33,32 @@ llms/usage_standalone.md llms/usage_custom.md ``` +## A Note on Tokenization + +By default, LlamaIndex uses a global tokenizer for all token counting. This defaults to `cl100k` from tiktoken, which is the tokenizer to match the default LLM `gpt-3.5-turbo`. + +If you change the LLM, you may need to update this tokenizer to ensure accurate token counts, chunking, and prompting. + +The single requirement for a tokenizer is that it is a callable function, that takes a string, and returns a list. + +You can set a global tokenizer like so: + +```python +from llama_index import set_global_tokenizer + +# tiktoken +import tiktoken + +set_global_tokenizer(tiktoken.encoding_for_model("gpt-3.5-turbo").encode) + +# huggingface +from transformers import AutoTokenizer + +set_global_tokenizer( + AutoTokenizer.from_pretrained("HuggingFaceH4/zephyr-7b-beta").encode +) +``` + ## LLM Compatibility Tracking While LLMs are powerful, not every LLM is easy to set up. Furthermore, even with proper setup, some LLMs have trouble performning tasks that require strict instruction following. diff --git a/docs/module_guides/models/llms/usage_custom.md b/docs/module_guides/models/llms/usage_custom.md index 5bc619cb2aab53d2bdcbbc2813a0311462aa17f3..75acbd1f9c4a69c1b496b5f9fe1b6fbb8f2f39a0 100644 --- a/docs/module_guides/models/llms/usage_custom.md +++ b/docs/module_guides/models/llms/usage_custom.md @@ -112,7 +112,7 @@ service_context = ServiceContext.from_defaults( ## Example: Using a HuggingFace LLM -LlamaIndex supports using LLMs from HuggingFace directly. Note that for a completely private experience, also setup a {ref}`local embedding model <custom_embeddings>`. +LlamaIndex supports using LLMs from HuggingFace directly. Note that for a completely private experience, also setup a [local embeddings model](../embeddings.md). Many open-source models from HuggingFace require either some preamble before each prompt, which is a `system_prompt`. Additionally, queries themselves may need an additional wrapper around the `query_str` itself. All this information is usually available from the HuggingFace model card for the model you are using. @@ -175,13 +175,13 @@ Several example notebooks are also listed below: To use a custom LLM model, you only need to implement the `LLM` class (or `CustomLLM` for a simpler interface) You will be responsible for passing the text to the model and returning the newly generated tokens. -Note that for a completely private experience, also setup a {ref}`local embedding model <custom_embeddings>`. +This implementation could be some local model, or even a wrapper around your own API. -Here is a small example using locally running facebook/OPT model and Huggingface's pipeline abstraction: +Note that for a completely private experience, also setup a [local embeddings model](../embeddings.md). + +Here is a small boilerplate example: ```python -import torch -from transformers import pipeline from typing import Optional, List, Mapping, Any from llama_index import ServiceContext, SimpleDirectoryReader, SummaryIndex @@ -195,57 +195,40 @@ from llama_index.llms import ( from llama_index.llms.base import llm_completion_callback -# set context window size -context_window = 2048 -# set number of output tokens -num_output = 256 - -# store the pipeline/model outside of the LLM class to avoid memory issues -model_name = "facebook/opt-iml-max-30b" -pipeline = pipeline( - "text-generation", - model=model_name, - device="cuda:0", - model_kwargs={"torch_dtype": torch.bfloat16}, -) - - class OurLLM(CustomLLM): + context_window: int = 3900 + num_output: int = 256 + model_name: str = "custom" + dummy_response: str = "My response" + @property def metadata(self) -> LLMMetadata: """Get LLM metadata.""" return LLMMetadata( - context_window=context_window, - num_output=num_output, - model_name=model_name, + context_window=self.context_window, + num_output=self.num_output, + model_name=self.model_name, ) @llm_completion_callback() def complete(self, prompt: str, **kwargs: Any) -> CompletionResponse: - prompt_length = len(prompt) - response = pipeline(prompt, max_new_tokens=num_output)[0][ - "generated_text" - ] - - # only return newly generated tokens - text = response[prompt_length:] - return CompletionResponse(text=text) + return CompletionResponse(text=self.dummy_response) @llm_completion_callback() def stream_complete( self, prompt: str, **kwargs: Any ) -> CompletionResponseGen: - raise NotImplementedError() + response = "" + for token in self.dummy_response: + response += token + yield CompletionResponse(text=response, delta=token) # define our LLM llm = OurLLM() service_context = ServiceContext.from_defaults( - llm=llm, - embed_model="local:BAAI/bge-base-en-v1.5", - context_window=context_window, - num_output=num_output, + llm=llm, embed_model="local:BAAI/bge-base-en-v1.5" ) # Load the your data diff --git a/docs/module_guides/querying/node_postprocessors/node_postprocessors.md b/docs/module_guides/querying/node_postprocessors/node_postprocessors.md index 98551f134c06c9b639fd415e4a6935c4ddc78a6a..7ac473a8329e7d260041f073a1e5d2742d248c55 100644 --- a/docs/module_guides/querying/node_postprocessors/node_postprocessors.md +++ b/docs/module_guides/querying/node_postprocessors/node_postprocessors.md @@ -5,7 +5,7 @@ Used to remove nodes that are below a similarity score threshold. ```python -from llama_index.indices.postprocessor import SimilarityPostprocessor +from llama_index.postprocessor import SimilarityPostprocessor postprocessor = SimilarityPostprocessor(similarity_cutoff=0.7) @@ -17,7 +17,7 @@ postprocessor.postprocess_nodes(nodes) Used to ensure certain keywords are either excluded or included. ```python -from llama_index.indices.postprocessor import KeywordNodePostprocessor +from llama_index.postprocessor import KeywordNodePostprocessor postprocessor = KeywordNodePostprocessor( required_keywords=["word1", "word2"], exclude_keywords=["word3", "word4"] @@ -31,7 +31,7 @@ postprocessor.postprocess_nodes(nodes) Used to replace the node content with a field from the node metadata. If the field is not present in the metadata, then the node text remains unchanged. Most useful when used in combination with the `SentenceWindowNodeParser`. ```python -from llama_index.indices.postprocessor import MetadataReplacementPostProcessor +from llama_index.postprocessor import MetadataReplacementPostProcessor postprocessor = MetadataReplacementPostProcessor( target_metadata_key="window", @@ -47,7 +47,7 @@ Models struggle to access significant details found in the center of extended co This module will re-order the retrieved nodes, which can be helpful in cases where a large top-k is needed. ```python -from llama_index.indices.postprocessor import LongContextReorder +from llama_index.postprocessor import LongContextReorder postprocessor = LongContextReorder() @@ -63,7 +63,7 @@ The percentile cutoff is a measure for using the top percentage of relevant sent The threshold cutoff can be specified instead, which uses a raw similarity cutoff for picking which sentences to keep. ```python -from llama_index.indices.postprocessor import SentenceEmbeddingOptimizer +from llama_index.postprocessor import SentenceEmbeddingOptimizer postprocessor = SentenceEmbeddingOptimizer( embed_model=service_context.embed_model, @@ -99,7 +99,7 @@ Full notebook guide is available [here](/examples/node_postprocessor/CohereReran Uses the cross-encoders from the `sentence-transformer` package to re-order nodes, and returns the top N nodes. ```python -from llama_index.indices.postprocessor import SentenceTransformerRerank +from llama_index.postprocessor import SentenceTransformerRerank # We choose a model with relatively high speed and decent accuracy. postprocessor = SentenceTransformerRerank( @@ -118,7 +118,7 @@ Please also refer to the [`sentence-transformer` docs](https://www.sbert.net/doc Uses a LLM to re-order nodes by asking the LLM to return the relevant documents and a score of how relevant they are. Returns the top N ranked nodes. ```python -from llama_index.indices.postprocessor import LLMRerank +from llama_index.postprocessor import LLMRerank postprocessor = LLMRerank(top_n=2, service_context=service_context) @@ -132,7 +132,7 @@ Full notebook guide is available [her for Gatsby](/examples/node_postprocessor/L This postproccesor returns the top K nodes sorted by date. This assumes there is a `date` field to parse in the metadata of each node. ```python -from llama_index.indices.postprocessor import FixedRecencyPostprocessor +from llama_index.postprocessor import FixedRecencyPostprocessor postprocessor = FixedRecencyPostprocessor( tok_k=1, date_key="date" # the key in the metadata to find the date @@ -150,7 +150,7 @@ A full notebook guide is available [here](/examples/node_postprocessor/RecencyPo This postproccesor returns the top K nodes after sorting by date and removing older nodes that are too similar after measuring embedding similarity. ```python -from llama_index.indices.postprocessor import EmbeddingRecencyPostprocessor +from llama_index.postprocessor import EmbeddingRecencyPostprocessor postprocessor = EmbeddingRecencyPostprocessor( service_context=service_context, date_key="date", similarity_cutoff=0.7 @@ -166,7 +166,7 @@ A full notebook guide is available [here](/examples/node_postprocessor/RecencyPo This postproccesor returns the top K nodes applying a time-weighted rerank to each node. Each time a node is retrieved, the time it was retrieved is recorded. This biases search to favor information that has not be returned in a query yet. ```python -from llama_index.indices.postprocessor import TimeWeightedPostprocessor +from llama_index.postprocessor import TimeWeightedPostprocessor postprocessor = TimeWeightedPostprocessor(time_decay=0.99, top_k=1) @@ -182,7 +182,7 @@ The PII (Personal Identifiable Information) postprocssor removes information tha ### LLM Version ```python -from llama_index.indices.postprocessor import PIINodePostprocessor +from llama_index.postprocessor import PIINodePostprocessor postprocessor = PIINodePostprocessor( service_context=service_context # this should be setup with an LLM you trust @@ -196,7 +196,7 @@ postprocessor.postprocess_nodes(nodes) This version uses the default local model from Hugging Face that is loaded when you run `pipeline("ner")`. ```python -from llama_index.indices.postprocessor import NERPIINodePostprocessor +from llama_index.postprocessor import NERPIINodePostprocessor postprocessor = NERPIINodePostprocessor() @@ -212,7 +212,7 @@ Uses pre-defined settings to read the `Node` relationships and fetch either all This is useful when you know the relationships point to important data (either before, after, or both) that should be sent to the LLM if that node is retrieved. ```python -from llama_index.indices.postprocessor import PrevNextNodePostprocessor +from llama_index.postprocessor import PrevNextNodePostprocessor postprocessor = PrevNextNodePostprocessor( docstore=index.docstore, @@ -230,7 +230,7 @@ postprocessor.postprocess_nodes(nodes) The same as PrevNextNodePostprocessor, but lets the LLM decide the mode (next, previous, or both). ```python -from llama_index.indices.postprocessor import AutoPrevNextNodePostprocessor +from llama_index.postprocessor import AutoPrevNextNodePostprocessor postprocessor = AutoPrevNextNodePostprocessor( docstore=index.docstore, diff --git a/docs/module_guides/querying/node_postprocessors/root.md b/docs/module_guides/querying/node_postprocessors/root.md index b6e0e11bd91400cbfd781104314e1e8270a45cf1..350cd4b2745152a005930f6b036ea56437505942 100644 --- a/docs/module_guides/querying/node_postprocessors/root.md +++ b/docs/module_guides/querying/node_postprocessors/root.md @@ -17,7 +17,7 @@ Confused about where node postprocessor fits in the pipeline? Read about [high-l An example of using a node postprocessors is below: ```python -from llama_index.indices.postprocessor import ( +from llama_index.postprocessor import ( SimilarityPostprocessor, CohereRerank, ) @@ -47,7 +47,7 @@ Most commonly, node-postprocessors will be used in a query engine, where they ar ```python from llama_index import VectorStoreIndex, SimpleDirectoryReader -from llama_index.indices.postprocessor import TimeWeightedPostprocessor +from llama_index.postprocessor import TimeWeightedPostprocessor documents = SimpleDirectoryReader("./data").load_data() @@ -70,7 +70,7 @@ response = query_engine.query("query string") Or used as a standalone object for filtering retrieved nodes: ```python -from llama_index.indices.postprocessor import SimilarityPostprocessor +from llama_index.postprocessor import SimilarityPostprocessor nodes = index.as_retriever().retrieve("test query str") @@ -84,7 +84,7 @@ filtered_nodes = processor.postprocess_nodes(nodes) As you may have noticed, the postprocessors take `NodeWithScore` objects as inputs, which is just a wrapper class with a `Node` and a `score` value. ```python -from llama_index.indices.postprocessor import SimilarityPostprocessor +from llama_index.postprocessor import SimilarityPostprocessor from llama_index.schema import Node, NodeWithScore nodes = [ @@ -116,7 +116,7 @@ A dummy node-postprocessor can be implemented in just a few lines of code: ```python from llama_index import QueryBundle -from llama_index.indices.postprocessor.base import BaseNodePostprocessor +from llama_index.postprocessor.base import BaseNodePostprocessor from llama_index.schema import NodeWithScore diff --git a/docs/module_guides/storing/customization.md b/docs/module_guides/storing/customization.md index ac6e1ad677ffc547cc43cf3c5d73889fd5adcbb1..4d884bd27364688f34525a7db2f3700ee6da022f 100644 --- a/docs/module_guides/storing/customization.md +++ b/docs/module_guides/storing/customization.md @@ -29,10 +29,10 @@ we use a lower-level API that gives more granular control: from llama_index.storage.docstore import SimpleDocumentStore from llama_index.storage.index_store import SimpleIndexStore from llama_index.vector_stores import SimpleVectorStore -from llama_index.node_parser import SimpleNodeParser +from llama_index.node_parser import SentenceSplitter # create parser and parse document into nodes -parser = SimpleNodeParser.from_defaults() +parser = SentenceSplitter() nodes = parser.get_nodes_from_documents(documents) # create storage context using default stores diff --git a/docs/module_guides/storing/docstores.md b/docs/module_guides/storing/docstores.md index 2ca2075df4328c1c912c1fb26791a1239dcab977..80c15fb902b4f636b63346b2deab6cdcbfb1052b 100644 --- a/docs/module_guides/storing/docstores.md +++ b/docs/module_guides/storing/docstores.md @@ -17,10 +17,10 @@ We support MongoDB as an alternative document store backend that persists data a ```python from llama_index.storage.docstore import MongoDocumentStore -from llama_index.node_parser import SimpleNodeParser +from llama_index.node_parser import SentenceSplitter # create parser and parse document into nodes -parser = SimpleNodeParser.from_defaults() +parser = SentenceSplitter() nodes = parser.get_nodes_from_documents(documents) # create (or load) docstore and add nodes @@ -51,10 +51,10 @@ We support Redis as an alternative document store backend that persists data as ```python from llama_index.storage.docstore import RedisDocumentStore -from llama_index.node_parser import SimpleNodeParser +from llama_index.node_parser import SentenceSplitter # create parser and parse document into nodes -parser = SimpleNodeParser.from_defaults() +parser = SentenceSplitter() nodes = parser.get_nodes_from_documents(documents) # create (or load) docstore and add nodes @@ -84,10 +84,10 @@ We support Firestore as an alternative document store backend that persists data ```python from llama_index.storage.docstore import FirestoreDocumentStore -from llama_index.node_parser import SimpleNodeParser +from llama_index.node_parser import SentenceSplitter # create parser and parse document into nodes -parser = SimpleNodeParser.from_defaults() +parser = SentenceSplitter() nodes = parser.get_nodes_from_documents(documents) # create (or load) docstore and add nodes diff --git a/docs/module_guides/supporting_modules/service_context.md b/docs/module_guides/supporting_modules/service_context.md index 13c788abb62c98cb7e38aef5ffef31c77b545c29..1f36eca0773e7d92b7787a502f75e1d1427c21bd 100644 --- a/docs/module_guides/supporting_modules/service_context.md +++ b/docs/module_guides/supporting_modules/service_context.md @@ -74,14 +74,11 @@ from llama_index import ( PromptHelper, ) from llama_index.llms import OpenAI -from llama_index.text_splitter import TokenTextSplitter -from llama_index.node_parser import SimpleNodeParser +from llama_index.text_splitter import SentenceSplitter llm = OpenAI(model="text-davinci-003", temperature=0, max_tokens=256) embed_model = OpenAIEmbedding() -node_parser = SimpleNodeParser.from_defaults( - text_splitter=TokenTextSplitter(chunk_size=1024, chunk_overlap=20) -) +text_splitter = SentenceSplitter(chunk_size=1024, chunk_overlap=20) prompt_helper = PromptHelper( context_window=4096, num_output=256, @@ -92,7 +89,7 @@ prompt_helper = PromptHelper( service_context = ServiceContext.from_defaults( llm=llm, embed_model=embed_model, - node_parser=node_parser, + text_splitter=text_splitter, prompt_helper=prompt_helper, ) ``` diff --git a/docs/understanding/loading/loading.md b/docs/understanding/loading/loading.md index 5086c03345d12e16cca33a7f264bd22dc436437c..59e79c4e5751b08ffc90757f8b03a26ce803e32f 100644 --- a/docs/understanding/loading/loading.md +++ b/docs/understanding/loading/loading.md @@ -52,12 +52,12 @@ In this example, you load your documents, then create a SimpleNodeParser configu ```python from llama_index import SimpleDirectoryReader, VectorStoreIndex, ServiceContext -from llama_index.node_parser import SimpleNodeParser +from llama_index.text_splitter import SentenceSplitter documents = SimpleDirectoryReader("./data").load_data() -node_parser = SimpleNodeParser.from_defaults(chunk_size=512, chunk_overlap=10) -service_context = ServiceContext.from_defaults(node_parser=node_parser) +text_splitter = SentenceSplitter(chunk_size=512, chunk_overlap=10) +service_context = ServiceContext.from_defaults(text_splitter=text_splitter) index = VectorStoreIndex.from_documents( documents, service_context=service_context @@ -83,6 +83,28 @@ node2 = TextNode(text="<text_chunk>", id_="<node_id>") index = VectorStoreIndex([node1, node2]) ``` +## Creating Nodes from Documents directly + +Using an `IngestionPipeline`, you can have more control over how nodes are created. + +```python +from llama_index import Document +from llama_index.text_splitter import SentenceSplitter +from llama_index.ingestion import IngestionPipeline + +# create the pipeline with transformations +pipeline = IngestionPipeline( + transformations=[ + SentenceSplitter(chunk_size=25, chunk_overlap=0), + ] +) + +# run the pipeline +nodes = pipeline.run(documents=[Document.example()]) +``` + +You can learn more about the [`IngestionPipeline` here.](/module_guides/loading/ingestion_pipeline/root.md) + ## Customizing Documents When creating documents, you can also attach useful metadata that can be used at the querying stage. Any metadata added to a Document will be copied to the Nodes that get created from that document. diff --git a/docs/understanding/querying/querying.md b/docs/understanding/querying/querying.md index 4790578915e3ecbc8bbca9ef5441e4af94934edd..13ea49ff3beb67a0da4d9671766aef9c5f750767 100644 --- a/docs/understanding/querying/querying.md +++ b/docs/understanding/querying/querying.md @@ -41,7 +41,7 @@ from llama_index import ( ) from llama_index.retrievers import VectorIndexRetriever from llama_index.query_engine import RetrieverQueryEngine -from llama_index.indices.postprocessor import SimilarityPostprocessor +from llama_index.postprocessor import SimilarityPostprocessor # build index index = VectorStoreIndex.from_documents(documents) diff --git a/examples/paul_graham_essay/SentenceSplittingDemo.ipynb b/examples/paul_graham_essay/SentenceSplittingDemo.ipynb index fd2f67c10c3b150e38ecec36cfc00b8dded18f90..4823b70ece9c20382d36231d6a85c3f304990a05 100644 --- a/examples/paul_graham_essay/SentenceSplittingDemo.ipynb +++ b/examples/paul_graham_essay/SentenceSplittingDemo.ipynb @@ -17,9 +17,8 @@ "metadata": {}, "outputs": [], "source": [ - "from llama_index.text_splitter import TokenTextSplitter\n", - "from llama_index import SimpleDirectoryReader, Document\n", - "from llama_index.utils import globals_helper\n", + "from llama_index.node_parser import TokenTextSplitter, LangchainNodeParser\n", + "from llama_index import SimpleDirectoryReader, Document, get_tokenizer\n", "from langchain.text_splitter import (\n", " NLTKTextSplitter,\n", " SpacyTextSplitter,\n", @@ -27,10 +26,10 @@ ")\n", "\n", "document = SimpleDirectoryReader(\"data\").load_data()[0]\n", - "text_splitter_default = TokenTextSplitter() # use default settings\n", - "text_chunks = text_splitter_default.split_text(document.text)\n", + "text_parser = TokenTextSplitter() # use default settings\n", + "text_chunks = text_parser.split_text(document.text)\n", "doc_chunks = [Document(text=t) for t in text_chunks]\n", - "tokenizer = globals_helper.tokenizer\n", + "tokenizer = get_tokenizer()\n", "with open(\"splitting_1.txt\", \"w\") as f:\n", " for idx, doc in enumerate(doc_chunks):\n", " f.write(\n", @@ -38,10 +37,10 @@ " + doc.text\n", " )\n", "\n", - "from llama_index.text_splitter import SentenceSplitter\n", + "from llama_index.node_parser import SentenceSplitter\n", "\n", - "sentence_splitter = SentenceSplitter()\n", - "text_chunks = sentence_splitter.split_text(document.text)\n", + "sentence_parser = SentenceSplitter()\n", + "text_chunks = sentence_parser.split_text(document.text)\n", "doc_chunks = [Document(text=t) for t in text_chunks]\n", "with open(\"splitting_2.txt\", \"w\") as f:\n", " for idx, doc in enumerate(doc_chunks):\n", @@ -50,40 +49,16 @@ " + doc.text\n", " )\n", "\n", - "nltk_splitter = NLTKTextSplitter()\n", - "text_chunks = nltk_splitter.split_text(document.text)\n", + "nltk_parser = LangchainNodeParser(NLTKTextSplitter())\n", + "text_chunks = nltk_parser.split_text(document.text)\n", "doc_chunks = [Document(text=t) for t in text_chunks]\n", - "tokenizer = globals_helper.tokenizer\n", + "tokenizer = get_tokenizer()\n", "with open(\"splitting_3.txt\", \"w\") as f:\n", " for idx, doc in enumerate(doc_chunks):\n", " f.write(\n", " \"\\n-------\\n\\n{}. Size: {} tokens\\n\".format(idx, len(tokenizer(doc.text)))\n", " + doc.text\n", - " )\n", - "\n", - "# spacy_splitter = SpacyTextSplitter()\n", - "# text_chunks = spacy_splitter.split_text(document.text)\n", - "# tokenizer = globals_helper.tokenizer\n", - "# with open('splitting_4.txt', 'w') as f:\n", - "# for idx, doc in enumerate(doc_chunks):\n", - "# f.write(\"\\n-------\\n\\n{}. Size: {} tokens\\n\".format(idx, len(tokenizer(doc.text))) + doc.text)\n", - "\n", - "# from langchain.text_splitter import TokenTextSplitter\n", - "# token_text_splitter = TokenTextSplitter()\n", - "# text_chunks = token_text_splitter.split_text(document.text)\n", - "# doc_chunks = [Document(text=t) for t in text_chunks]\n", - "# tokenizer = globals_helper.tokenizer\n", - "# with open('splitting_5.txt', 'w') as f:\n", - "# for idx, doc in enumerate(doc_chunks):\n", - "# f.write(\"\\n-------\\n\\n{}. Size: {} tokens\\n\".format(idx, len(tokenizer(doc.text))) + doc.text)\n", - "\n", - "# recursive_splitter = RecursiveCharacterTextSplitter()\n", - "# text_chunks = recursive_splitter.split_text(document.text)\n", - "# doc_chunks = [Document(text=t) for t in text_chunks]\n", - "# tokenizer = globals_helper.tokenizer\n", - "# with open('splitting_6.txt', 'w') as f:\n", - "# for idx, doc in enumerate(doc_chunks):\n", - "# f.write(\"\\n-------\\n\\n{}. Size: {} tokens\\n\".format(idx, len(tokenizer(doc.text))) + doc.text)" + " )" ] }, { @@ -105,7 +80,6 @@ "from llama_index.text_splitter import SentenceSplitter\n", "from llama_index.schema import Document\n", "from llama_index.indices.service_context import ServiceContext\n", - "from llama_index.node_parser.simple import SimpleNodeParser\n", "from llama_index.indices.vector_store import VectorStoreIndex\n", "import wikipedia" ] @@ -117,10 +91,10 @@ "metadata": {}, "outputs": [], "source": [ - "sentence_splitter = SentenceSplitter()\n", + "sentence_parser = SentenceSplitter()\n", "wikipedia.set_lang(\"zh\")\n", "page = wikipedia.page(\"美国\", auto_suggest=True).content\n", - "sentence_splitter.split_text(page)" + "sentence_parser.split_text(page)" ] }, { @@ -130,8 +104,8 @@ "metadata": {}, "outputs": [], "source": [ - "node_parser = SimpleNodeParser.from_defaults(text_splitter=sentence_splitter)\n", - "service_context = ServiceContext.from_defaults(node_parser=node_parser)\n", + "text_splitter = SentenceSplitter()\n", + "service_context = ServiceContext.from_defaults(text_splitter=text_splitter)\n", "documents = []\n", "documents.append(Document(text=page))\n", "index = VectorStoreIndex.from_documents(documents, service_context=service_context)" diff --git a/examples/test_wiki/TestWikiReader.ipynb b/examples/test_wiki/TestWikiReader.ipynb index 117da3b195e33c3c5bca2e58c970769fae66ad6b..02030465b28db8ca27f39f9829ea75b18029df95 100644 --- a/examples/test_wiki/TestWikiReader.ipynb +++ b/examples/test_wiki/TestWikiReader.ipynb @@ -177,7 +177,7 @@ "source": [ "# set Logging to DEBUG for more detailed outputs\n", "# with keyword lookup\n", - "from llama_index.indices.postprocessor import KeywordNodePostprocessor\n", + "from llama_index.postprocessor import KeywordNodePostprocessor\n", "\n", "\n", "query_engine = index.as_query_engine(\n", diff --git a/experimental/cli/configuration.py b/experimental/cli/configuration.py index 4d2d8516cbaa9fac25681f6d3086623eb6ef2c41..a78cde46bf9b788bf18359df17a570addf6c5060 100644 --- a/experimental/cli/configuration.py +++ b/experimental/cli/configuration.py @@ -5,11 +5,11 @@ from typing import Any, Type from llama_index import ( LLMPredictor, ServiceContext, - SimpleKeywordTableIndex, VectorStoreIndex, ) from llama_index.embeddings.base import BaseEmbedding from llama_index.embeddings.openai import OpenAIEmbedding +from llama_index.indices import SimpleKeywordTableIndex from llama_index.indices.base import BaseIndex from llama_index.indices.loading import load_index_from_storage from llama_index.llm_predictor import StructuredLLMPredictor diff --git a/experimental/colbert_index/base.py b/experimental/colbert_index/base.py index c0d3bfa143c0889cd785b4b1b4f132b3b6fb559d..2662ecaa77147bbcc41be17b3fe5558390f37b91 100644 --- a/experimental/colbert_index/base.py +++ b/experimental/colbert_index/base.py @@ -1,10 +1,10 @@ from typing import Any, Dict, List, Optional, Sequence +from llama_index.core import BaseRetriever from llama_index.data_structs.data_structs import IndexDict from llama_index.indices.base import BaseIndex -from llama_index.indices.base_retriever import BaseRetriever -from llama_index.indices.service_context import ServiceContext from llama_index.schema import BaseNode, NodeWithScore +from llama_index.service_context import ServiceContext from llama_index.storage.docstore.types import RefDocInfo from llama_index.storage.storage_context import StorageContext diff --git a/experimental/colbert_index/retriever.py b/experimental/colbert_index/retriever.py index 3b3e1a084e655344f6ad918ea319fa9e0c001025..6473f23a2da646a867a19d1db654f1724a645502 100644 --- a/experimental/colbert_index/retriever.py +++ b/experimental/colbert_index/retriever.py @@ -1,9 +1,8 @@ from typing import Any, Dict, List, Optional from llama_index.constants import DEFAULT_SIMILARITY_TOP_K -from llama_index.indices.base_retriever import BaseRetriever -from llama_index.indices.query.schema import QueryBundle -from llama_index.schema import NodeWithScore +from llama_index.core import BaseRetriever +from llama_index.schema import NodeWithScore, QueryBundle from llama_index.vector_stores.types import MetadataFilters from .base import ColbertIndex diff --git a/experimental/splitter_playground/app.py b/experimental/splitter_playground/app.py index 6accef9795ecf1101d25097a35b9fbaf514d46b7..6dc1429597160b5a3024c4fcd39f7e71ca37f25a 100644 --- a/experimental/splitter_playground/app.py +++ b/experimental/splitter_playground/app.py @@ -15,9 +15,9 @@ from langchain.text_splitter import TokenTextSplitter as LCTokenTextSplitter from streamlit.runtime.uploaded_file_manager import UploadedFile from llama_index import SimpleDirectoryReader +from llama_index.node_parser.interface import TextSplitter from llama_index.schema import Document from llama_index.text_splitter import CodeSplitter, SentenceSplitter, TokenTextSplitter -from llama_index.text_splitter.types import TextSplitter DEFAULT_TEXT = "The quick brown fox jumps over the lazy dog." @@ -28,7 +28,7 @@ n_cols = st.sidebar.number_input("Columns", value=2, min_value=1, max_value=3) assert isinstance(n_cols, int) -@st.cache_resource(ttl="1h") +@st.cache_resource(ttl=3600) def load_document(uploaded_files: List[UploadedFile]) -> List[Document]: # Read documents temp_dir = tempfile.TemporaryDirectory() diff --git a/llama_index/VERSION b/llama_index/VERSION index c550871d7120e95c8d918a3445c4cd67df9155d0..ac39a106c48515b621e90c028ed94c6f71bc03fa 100644 --- a/llama_index/VERSION +++ b/llama_index/VERSION @@ -1 +1 @@ -0.8.69.post2 +0.9.0 diff --git a/llama_index/__init__.py b/llama_index/__init__.py index 41703f31bff1cd9333f22384d33a4ff123a2afef..77340926418baf0fcf9e339849642b575d2448e6 100644 --- a/llama_index/__init__.py +++ b/llama_index/__init__.py @@ -7,76 +7,51 @@ with open(Path(__file__).absolute().parents[0] / "VERSION") as _f: import logging from logging import NullHandler -from typing import Optional +from typing import Callable, Optional # import global eval handler from llama_index.callbacks.global_handlers import set_global_handler from llama_index.data_structs.struct_type import IndexStructType # embeddings -from llama_index.embeddings.langchain import LangchainEmbedding -from llama_index.embeddings.openai import OpenAIEmbedding +from llama_index.embeddings import OpenAIEmbedding -# structured -from llama_index.indices.common.struct_store.base import SQLDocumentContextBuilder - -# for composability -from llama_index.indices.composability.graph import ComposableGraph -from llama_index.indices.document_summary import ( +# indices +# loading +from llama_index.indices import ( + ComposableGraph, DocumentSummaryIndex, GPTDocumentSummaryIndex, -) -from llama_index.indices.empty import EmptyIndex, GPTEmptyIndex - -# indices -from llama_index.indices.keyword_table import ( GPTKeywordTableIndex, + GPTKnowledgeGraphIndex, + GPTListIndex, GPTRAKEKeywordTableIndex, GPTSimpleKeywordTableIndex, + GPTTreeIndex, + GPTVectorStoreIndex, KeywordTableIndex, + KnowledgeGraphIndex, + ListIndex, RAKEKeywordTableIndex, SimpleKeywordTableIndex, -) -from llama_index.indices.knowledge_graph import ( - GPTKnowledgeGraphIndex, - KnowledgeGraphIndex, -) -from llama_index.indices.list import GPTListIndex, ListIndex, SummaryIndex - -# loading -from llama_index.indices.loading import ( + SummaryIndex, + TreeIndex, + VectorStoreIndex, load_graph_from_storage, load_index_from_storage, load_indices_from_storage, ) +# structured +from llama_index.indices.common.struct_store.base import SQLDocumentContextBuilder + # prompt helper from llama_index.indices.prompt_helper import PromptHelper - -# QueryBundle -from llama_index.indices.query.schema import QueryBundle -from llama_index.indices.service_context import ( - ServiceContext, - set_global_service_context, -) -from llama_index.indices.struct_store.pandas import GPTPandasIndex, PandasIndex -from llama_index.indices.struct_store.sql import ( - GPTSQLStructStoreIndex, - SQLStructStoreIndex, -) -from llama_index.indices.tree import GPTTreeIndex, TreeIndex -from llama_index.indices.vector_store import GPTVectorStoreIndex, VectorStoreIndex -from llama_index.langchain_helpers.memory_wrapper import GPTIndexMemory - -# langchain helper from llama_index.llm_predictor import LLMPredictor # token predictor from llama_index.llm_predictor.mock import MockLLMPredictor -# vellum -from llama_index.llm_predictor.vellum import VellumPredictor, VellumPromptRegistry - # prompts from llama_index.prompts import ( BasePromptTemplate, @@ -86,53 +61,18 @@ from llama_index.prompts import ( PromptTemplate, SelectorPromptTemplate, ) -from llama_index.prompts.prompts import ( - KeywordExtractPrompt, - QueryKeywordExtractPrompt, - QuestionAnswerPrompt, - RefinePrompt, - SummaryPrompt, - TreeInsertPrompt, - TreeSelectMultiplePrompt, - TreeSelectPrompt, -) -from llama_index.readers import ( - BeautifulSoupWebReader, - ChromaReader, - DeepLakeReader, - DiscordReader, - FaissReader, - GithubRepositoryReader, - GoogleDocsReader, - JSONReader, - MboxReader, - MilvusReader, - NotionPageReader, - ObsidianReader, - PineconeReader, - PsychicReader, - QdrantReader, - RssReader, - SimpleDirectoryReader, - SimpleMongoReader, - SimpleWebPageReader, - SlackReader, - StringIterableReader, - TrafilaturaWebReader, - TwitterTweetReader, - WeaviateReader, - WikipediaReader, -) -from llama_index.readers.download import download_loader +from llama_index.readers import SimpleDirectoryReader, download_loader # response from llama_index.response.schema import Response # Response Synthesizer from llama_index.response_synthesizers.factory import get_response_synthesizer - -# readers -from llama_index.schema import Document +from llama_index.schema import Document, QueryBundle +from llama_index.service_context import ( + ServiceContext, + set_global_service_context, +) # storage from llama_index.storage.storage_context import StorageContext @@ -141,6 +81,9 @@ from llama_index.token_counter.mock_embed_model import MockEmbedding # sql wrapper from llama_index.utilities.sql_wrapper import SQLDatabase +# global tokenizer +from llama_index.utils import get_tokenizer, set_global_tokenizer + # best practices for library logging: # https://docs.python.org/3/howto/logging.html#configuring-logging-for-a-library logging.getLogger(__name__).addHandler(NullHandler()) @@ -156,9 +99,6 @@ __all__ = [ "KeywordTableIndex", "RAKEKeywordTableIndex", "TreeIndex", - "SQLStructStoreIndex", - "PandasIndex", - "EmptyIndex", "DocumentSummaryIndex", "KnowledgeGraphIndex", # indices - legacy names @@ -168,18 +108,14 @@ __all__ = [ "GPTRAKEKeywordTableIndex", "GPTListIndex", "ListIndex", - "GPTEmptyIndex", "GPTTreeIndex", "GPTVectorStoreIndex", - "GPTPandasIndex", - "GPTSQLStructStoreIndex", "GPTDocumentSummaryIndex", "Prompt", "PromptTemplate", "BasePromptTemplate", "ChatPromptTemplate", "SelectorPromptTemplate", - "LangchainEmbedding", "OpenAIEmbedding", "SummaryPrompt", "TreeInsertPrompt", @@ -220,7 +156,6 @@ __all__ = [ "VellumPromptRegistry", "MockEmbedding", "SQLDatabase", - "GPTIndexMemory", "SQLDocumentContextBuilder", "SQLContextBuilder", "PromptHelper", @@ -235,6 +170,8 @@ __all__ = [ "get_response_synthesizer", "set_global_service_context", "set_global_handler", + "set_global_tokenizer", + "get_tokenizer", ] # eval global toggle @@ -247,3 +184,6 @@ SQLContextBuilder = SQLDocumentContextBuilder # global service context for ServiceContext.from_defaults() global_service_context: Optional[ServiceContext] = None + +# global tokenizer +global_tokenizer: Optional[Callable[[str], list]] = None diff --git a/llama_index/agent/context_retriever_agent.py b/llama_index/agent/context_retriever_agent.py index a895236ba21d28b7a6bc30dce3e524c97f8baefb..f2a463b14b1a9053f025e60e292d369610af1100 100644 --- a/llama_index/agent/context_retriever_agent.py +++ b/llama_index/agent/context_retriever_agent.py @@ -11,7 +11,7 @@ from llama_index.callbacks import CallbackManager from llama_index.chat_engine.types import ( AgentChatResponse, ) -from llama_index.indices.base_retriever import BaseRetriever +from llama_index.core import BaseRetriever from llama_index.llms.base import LLM, ChatMessage from llama_index.llms.openai import OpenAI from llama_index.llms.openai_utils import is_function_calling_model diff --git a/llama_index/agent/types.py b/llama_index/agent/types.py index 8cd296b896c986fea1a74d6e165ae0dbec8ff9ad..6595cf2bfa46b56ca20f7cef55b48e6ddb41853a 100644 --- a/llama_index/agent/types.py +++ b/llama_index/agent/types.py @@ -3,11 +3,11 @@ from typing import List, Optional from llama_index.callbacks import trace_method from llama_index.chat_engine.types import BaseChatEngine, StreamingAgentChatResponse -from llama_index.indices.query.base import BaseQueryEngine -from llama_index.indices.query.schema import QueryBundle +from llama_index.core import BaseQueryEngine from llama_index.llms.base import ChatMessage from llama_index.prompts.mixin import PromptDictType, PromptMixinType from llama_index.response.schema import RESPONSE_TYPE, Response +from llama_index.schema import QueryBundle class BaseAgent(BaseChatEngine, BaseQueryEngine): diff --git a/llama_index/bridge/pydantic.py b/llama_index/bridge/pydantic.py index 5ba374deab7521fb2bcfa431a51a558383460169..9f9be59f3ab929bbcad834cefd74d3052767c3b6 100644 --- a/llama_index/bridge/pydantic.py +++ b/llama_index/bridge/pydantic.py @@ -1,5 +1,7 @@ try: + import pydantic.v1 as pydantic from pydantic.v1 import ( + BaseConfig, BaseModel, Field, PrivateAttr, @@ -12,8 +14,11 @@ try: ) from pydantic.v1.error_wrappers import ValidationError from pydantic.v1.fields import FieldInfo + from pydantic.v1.generics import GenericModel except ImportError: + import pydantic # type: ignore from pydantic import ( + BaseConfig, BaseModel, Field, PrivateAttr, @@ -26,8 +31,10 @@ except ImportError: ) from pydantic.error_wrappers import ValidationError from pydantic.fields import FieldInfo + from pydantic.generics import GenericModel __all__ = [ + "pydantic", "BaseModel", "Field", "PrivateAttr", @@ -39,4 +46,6 @@ __all__ = [ "StrictStr", "FieldInfo", "ValidationError", + "GenericModel", + "BaseConfig", ] diff --git a/llama_index/callbacks/token_counting.py b/llama_index/callbacks/token_counting.py index 5081fc4eb9e729db87c1d3e1cea8fc9b25503162..3bd483ed203f19944f7ba19d749226dbb705f9ef 100644 --- a/llama_index/callbacks/token_counting.py +++ b/llama_index/callbacks/token_counting.py @@ -3,7 +3,8 @@ from typing import Any, Callable, Dict, List, Optional, cast from llama_index.callbacks.base_handler import BaseCallbackHandler from llama_index.callbacks.schema import CBEventType, EventPayload -from llama_index.utils import globals_helper +from llama_index.utilities.token_counting import TokenCounter +from llama_index.utils import get_tokenizer @dataclass @@ -20,7 +21,7 @@ class TokenCountingEvent: def get_llm_token_counts( - tokenizer: Callable[[str], List], payload: Dict[str, Any], event_id: str = "" + token_counter: TokenCounter, payload: Dict[str, Any], event_id: str = "" ) -> TokenCountingEvent: from llama_index.llms import ChatMessage @@ -31,22 +32,50 @@ def get_llm_token_counts( return TokenCountingEvent( event_id=event_id, prompt=prompt, - prompt_token_count=len(tokenizer(prompt)), + prompt_token_count=token_counter.get_string_tokens(prompt), completion=completion, - completion_token_count=len(tokenizer(completion)), + completion_token_count=token_counter.get_string_tokens(completion), ) elif EventPayload.MESSAGES in payload: messages = cast(List[ChatMessage], payload.get(EventPayload.MESSAGES, [])) messages_str = "\n".join([str(x) for x in messages]) - response = str(payload.get(EventPayload.RESPONSE)) + + response = payload.get(EventPayload.RESPONSE) + response_str = str(response) + + # try getting attached token counts first + try: + usage_dict = response.additional_kwargs # type: ignore + + messages_tokens = usage_dict["prompt_tokens"] + response_tokens = usage_dict["completion_tokens"] + + if messages_tokens == 0 or response_tokens == 0: + raise ValueError("Invalid token counts!") + + return TokenCountingEvent( + event_id=event_id, + prompt=messages_str, + prompt_token_count=messages_tokens, + completion=response_str, + completion_token_count=response_tokens, + ) + + except (ValueError, KeyError): + # Invalid token counts, or no token counts attached + pass + + # Should count tokens ourselves + messages_tokens = token_counter.estimate_tokens_in_messages(messages) + response_tokens = token_counter.get_string_tokens(response_str) return TokenCountingEvent( event_id=event_id, prompt=messages_str, - prompt_token_count=len(tokenizer(messages_str)), - completion=response, - completion_token_count=len(tokenizer(response)), + prompt_token_count=messages_tokens, + completion=response_str, + completion_token_count=response_tokens, ) else: raise ValueError( @@ -74,7 +103,9 @@ class TokenCountingHandler(BaseCallbackHandler): ) -> None: self.llm_token_counts: List[TokenCountingEvent] = [] self.embedding_token_counts: List[TokenCountingEvent] = [] - self.tokenizer = tokenizer or globals_helper.tokenizer + self.tokenizer = tokenizer or get_tokenizer() + + self._token_counter = TokenCounter(tokenizer=self.tokenizer) self._verbose = verbose super().__init__( @@ -117,7 +148,7 @@ class TokenCountingHandler(BaseCallbackHandler): ): self.llm_token_counts.append( get_llm_token_counts( - tokenizer=self.tokenizer, + token_counter=self._token_counter, payload=payload, event_id=event_id, ) @@ -142,7 +173,7 @@ class TokenCountingHandler(BaseCallbackHandler): TokenCountingEvent( event_id=event_id, prompt=chunk, - prompt_token_count=len(self.tokenizer(chunk)), + prompt_token_count=self._token_counter.get_string_tokens(chunk), completion="", completion_token_count=0, ) diff --git a/llama_index/callbacks/wandb_callback.py b/llama_index/callbacks/wandb_callback.py index ee626e9fcc0ffd06b65d9dd272d74c9b29b87d64..c7bc9195951903a938e2ac7b1205a6708b985c5c 100644 --- a/llama_index/callbacks/wandb_callback.py +++ b/llama_index/callbacks/wandb_callback.py @@ -23,13 +23,14 @@ from llama_index.callbacks.schema import ( EventPayload, ) from llama_index.callbacks.token_counting import get_llm_token_counts -from llama_index.utils import globals_helper +from llama_index.utilities.token_counting import TokenCounter +from llama_index.utils import get_tokenizer if TYPE_CHECKING: from wandb import Settings as WBSettings from wandb.sdk.data_types import trace_tree - from llama_index import ( + from llama_index.indices import ( ComposableGraph, GPTEmptyIndex, GPTKeywordTableIndex, @@ -128,7 +129,7 @@ class WandbCallbackHandler(BaseCallbackHandler): "Please install it with `pip install wandb`." ) - from llama_index import ( + from llama_index.indices import ( ComposableGraph, GPTEmptyIndex, GPTKeywordTableIndex, @@ -160,7 +161,9 @@ class WandbCallbackHandler(BaseCallbackHandler): self._cur_trace_id: Optional[str] = None self._trace_map: Dict[str, List[str]] = defaultdict(list) - self.tokenizer = tokenizer or globals_helper.tokenizer + self.tokenizer = tokenizer or get_tokenizer() + self._token_counter = TokenCounter(tokenizer=self.tokenizer) + event_starts_to_ignore = ( event_starts_to_ignore if event_starts_to_ignore else [] ) @@ -463,7 +466,7 @@ class WandbCallbackHandler(BaseCallbackHandler): [str(x) for x in inputs[EventPayload.MESSAGES]] ) - token_counts = get_llm_token_counts(self.tokenizer, outputs) + token_counts = get_llm_token_counts(self._token_counter, outputs) metadata = { "formatted_prompt_tokens_count": token_counts.prompt_token_count, "prediction_tokens_count": token_counts.completion_token_count, diff --git a/llama_index/chat_engine/condense_question.py b/llama_index/chat_engine/condense_question.py index a1e3702896ee8f506b2aa987b8601d57d2979c22..e8c9500701acfadf217034193471f4442e5f2483 100644 --- a/llama_index/chat_engine/condense_question.py +++ b/llama_index/chat_engine/condense_question.py @@ -9,14 +9,14 @@ from llama_index.chat_engine.types import ( StreamingAgentChatResponse, ) from llama_index.chat_engine.utils import response_gen_from_query_engine -from llama_index.indices.query.base import BaseQueryEngine -from llama_index.indices.service_context import ServiceContext +from llama_index.core import BaseQueryEngine from llama_index.llm_predictor.base import LLMPredictor from llama_index.llms.base import ChatMessage, MessageRole from llama_index.llms.generic_utils import messages_to_history_str from llama_index.memory import BaseMemory, ChatMemoryBuffer from llama_index.prompts.base import BasePromptTemplate, PromptTemplate from llama_index.response.schema import RESPONSE_TYPE, StreamingResponse +from llama_index.service_context import ServiceContext from llama_index.tools import ToolOutput logger = logging.getLogger(__name__) diff --git a/llama_index/chat_engine/context.py b/llama_index/chat_engine/context.py index 554222008f319f981601c0cf884491924594c4c0..e9ac30c2922def0e430636534aa490f8d5dd01fe 100644 --- a/llama_index/chat_engine/context.py +++ b/llama_index/chat_engine/context.py @@ -9,14 +9,13 @@ from llama_index.chat_engine.types import ( StreamingAgentChatResponse, ToolOutput, ) -from llama_index.indices.base_retriever import BaseRetriever -from llama_index.indices.postprocessor.types import BaseNodePostprocessor -from llama_index.indices.query.schema import QueryBundle -from llama_index.indices.service_context import ServiceContext +from llama_index.core import BaseRetriever from llama_index.llm_predictor.base import LLMPredictor from llama_index.llms.base import LLM, ChatMessage, MessageRole from llama_index.memory import BaseMemory, ChatMemoryBuffer -from llama_index.schema import MetadataMode, NodeWithScore +from llama_index.postprocessor.types import BaseNodePostprocessor +from llama_index.schema import MetadataMode, NodeWithScore, QueryBundle +from llama_index.service_context import ServiceContext DEFAULT_CONTEXT_TEMPLATE = ( "Context information is below." diff --git a/llama_index/chat_engine/simple.py b/llama_index/chat_engine/simple.py index 7ec546e49f859db7a133fb3c508e26f65f93baa7..3109de021902c8509fddae33c8e7ee5c774bee75 100644 --- a/llama_index/chat_engine/simple.py +++ b/llama_index/chat_engine/simple.py @@ -8,10 +8,10 @@ from llama_index.chat_engine.types import ( BaseChatEngine, StreamingAgentChatResponse, ) -from llama_index.indices.service_context import ServiceContext from llama_index.llm_predictor.base import LLMPredictor from llama_index.llms.base import LLM, ChatMessage from llama_index.memory import BaseMemory, ChatMemoryBuffer +from llama_index.service_context import ServiceContext class SimpleChatEngine(BaseChatEngine): diff --git a/llama_index/composability/joint_qa_summary.py b/llama_index/composability/joint_qa_summary.py index f13db1e23f0f566c3eb50972cc7aaf30d5042102..fe48073400cda7f007750abab33c07d07b97b24f 100644 --- a/llama_index/composability/joint_qa_summary.py +++ b/llama_index/composability/joint_qa_summary.py @@ -4,10 +4,11 @@ from typing import Optional, Sequence from llama_index.indices.list.base import SummaryIndex -from llama_index.indices.service_context import ServiceContext from llama_index.indices.vector_store import VectorStoreIndex +from llama_index.ingestion import run_transformations from llama_index.query_engine.router_query_engine import RouterQueryEngine from llama_index.schema import Document +from llama_index.service_context import ServiceContext from llama_index.storage.storage_context import StorageContext from llama_index.tools.query_engine import QueryEngineTool @@ -55,7 +56,9 @@ class QASummaryQueryEngineBuilder: ) -> RouterQueryEngine: """Build query engine.""" # parse nodes - nodes = self._service_context.node_parser.get_nodes_from_documents(documents) + nodes = run_transformations( + documents, self._service_context.transformations # type: ignore + ) # ingest nodes self._storage_context.docstore.add_documents(nodes, allow_update=True) diff --git a/llama_index/constants.py b/llama_index/constants.py index cb3f8c1ca580b7ac826d301f1b7a40e727265e64..ab0554e2ddb2498eb9c1cc77b09d4f3c57bc81bb 100644 --- a/llama_index/constants.py +++ b/llama_index/constants.py @@ -1,5 +1,6 @@ """Set of constants.""" +DEFAULT_TEMPERATURE = 0.1 DEFAULT_CONTEXT_WINDOW = 3900 # tokens DEFAULT_NUM_OUTPUTS = 256 # tokens DEFAULT_NUM_INPUT_FILES = 10 # files diff --git a/llama_index/core/__init__.py b/llama_index/core/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..475a7162d095f17e39389c0776b5055940479234 --- /dev/null +++ b/llama_index/core/__init__.py @@ -0,0 +1,4 @@ +from llama_index.core.base_query_engine import BaseQueryEngine +from llama_index.core.base_retriever import BaseRetriever + +__all__ = ["BaseRetriever", "BaseQueryEngine"] diff --git a/llama_index/core/base_query_engine.py b/llama_index/core/base_query_engine.py new file mode 100644 index 0000000000000000000000000000000000000000..c7546b79f5ea9cfe851efd684baff967d615387f --- /dev/null +++ b/llama_index/core/base_query_engine.py @@ -0,0 +1,69 @@ +"""Base query engine.""" + +import logging +from abc import abstractmethod +from typing import Any, Dict, List, Optional, Sequence + +from llama_index.callbacks.base import CallbackManager +from llama_index.prompts.mixin import PromptDictType, PromptMixin +from llama_index.response.schema import RESPONSE_TYPE +from llama_index.schema import NodeWithScore, QueryBundle, QueryType + +logger = logging.getLogger(__name__) + + +class BaseQueryEngine(PromptMixin): + def __init__(self, callback_manager: Optional[CallbackManager]) -> None: + self.callback_manager = callback_manager or CallbackManager([]) + + def _get_prompts(self) -> Dict[str, Any]: + """Get prompts.""" + return {} + + def _update_prompts(self, prompts: PromptDictType) -> None: + """Update prompts.""" + + def query(self, str_or_query_bundle: QueryType) -> RESPONSE_TYPE: + with self.callback_manager.as_trace("query"): + if isinstance(str_or_query_bundle, str): + str_or_query_bundle = QueryBundle(str_or_query_bundle) + return self._query(str_or_query_bundle) + + async def aquery(self, str_or_query_bundle: QueryType) -> RESPONSE_TYPE: + with self.callback_manager.as_trace("query"): + if isinstance(str_or_query_bundle, str): + str_or_query_bundle = QueryBundle(str_or_query_bundle) + return await self._aquery(str_or_query_bundle) + + def retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]: + raise NotImplementedError( + "This query engine does not support retrieve, use query directly" + ) + + def synthesize( + self, + query_bundle: QueryBundle, + nodes: List[NodeWithScore], + additional_source_nodes: Optional[Sequence[NodeWithScore]] = None, + ) -> RESPONSE_TYPE: + raise NotImplementedError( + "This query engine does not support synthesize, use query directly" + ) + + async def asynthesize( + self, + query_bundle: QueryBundle, + nodes: List[NodeWithScore], + additional_source_nodes: Optional[Sequence[NodeWithScore]] = None, + ) -> RESPONSE_TYPE: + raise NotImplementedError( + "This query engine does not support asynthesize, use aquery directly" + ) + + @abstractmethod + def _query(self, query_bundle: QueryBundle) -> RESPONSE_TYPE: + pass + + @abstractmethod + async def _aquery(self, query_bundle: QueryBundle) -> RESPONSE_TYPE: + pass diff --git a/llama_index/core/base_retriever.py b/llama_index/core/base_retriever.py new file mode 100644 index 0000000000000000000000000000000000000000..baf22e316004884510621362230ab8a688300d76 --- /dev/null +++ b/llama_index/core/base_retriever.py @@ -0,0 +1,70 @@ +"""Base retriever.""" +from abc import abstractmethod +from typing import List, Optional + +from llama_index.prompts.mixin import PromptDictType, PromptMixin, PromptMixinType +from llama_index.schema import NodeWithScore, QueryBundle, QueryType +from llama_index.service_context import ServiceContext + + +class BaseRetriever(PromptMixin): + """Base retriever.""" + + def _get_prompts(self) -> PromptDictType: + """Get prompts.""" + return {} + + def _get_prompt_modules(self) -> PromptMixinType: + """Get prompt modules.""" + return {} + + def _update_prompts(self, prompts: PromptDictType) -> None: + """Update prompts.""" + + def retrieve(self, str_or_query_bundle: QueryType) -> List[NodeWithScore]: + """Retrieve nodes given query. + + Args: + str_or_query_bundle (QueryType): Either a query string or + a QueryBundle object. + + """ + if isinstance(str_or_query_bundle, str): + str_or_query_bundle = QueryBundle(str_or_query_bundle) + return self._retrieve(str_or_query_bundle) + + async def aretrieve(self, str_or_query_bundle: QueryType) -> List[NodeWithScore]: + if isinstance(str_or_query_bundle, str): + str_or_query_bundle = QueryBundle(str_or_query_bundle) + return await self._aretrieve(str_or_query_bundle) + + @abstractmethod + def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]: + """Retrieve nodes given query. + + Implemented by the user. + + """ + + # TODO: make this abstract + # @abstractmethod + async def _aretrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]: + """Asynchronously retrieve nodes given query. + + Implemented by the user. + + """ + return self._retrieve(query_bundle) + + def get_service_context(self) -> Optional[ServiceContext]: + """Attempts to resolve a service context. + Short-circuits at self.service_context, self._service_context, + or self._index.service_context. + """ + if hasattr(self, "service_context"): + return self.service_context + if hasattr(self, "_service_context"): + return self._service_context + elif hasattr(self, "_index") and hasattr(self._index, "service_context"): + return self._index.service_context + return None diff --git a/llama_index/embeddings/__init__.py b/llama_index/embeddings/__init__.py index e014906445e2d186eb87ada125c33798a17adaeb..cc597ea6fa3f0bb1db5b307f08117454b5910f5d 100644 --- a/llama_index/embeddings/__init__.py +++ b/llama_index/embeddings/__init__.py @@ -5,7 +5,7 @@ from llama_index.embeddings.adapter import ( LinearAdapterEmbeddingModel, ) from llama_index.embeddings.azure_openai import AzureOpenAIEmbedding -from llama_index.embeddings.base import SimilarityMode +from llama_index.embeddings.base import BaseEmbedding, SimilarityMode from llama_index.embeddings.bedrock import BedrockEmbedding from llama_index.embeddings.clarifai import ClarifaiEmbedding from llama_index.embeddings.clip import ClipEmbedding @@ -39,6 +39,7 @@ __all__ = [ "ClarifaiEmbedding", "ClipEmbedding", "CohereEmbedding", + "BaseEmbedding", "DEFAULT_HUGGINGFACE_EMBEDDING_MODEL", "ElasticsearchEmbedding", "GoogleUnivSentEncoderEmbedding", diff --git a/llama_index/embeddings/base.py b/llama_index/embeddings/base.py index 68ed03af6aa5b8be85b14a8ef5d2e2c6f04de4bf..e1e9a92ab8954fd34ba3b855269502f9bc998842 100644 --- a/llama_index/embeddings/base.py +++ b/llama_index/embeddings/base.py @@ -3,14 +3,14 @@ import asyncio from abc import abstractmethod from enum import Enum -from typing import Callable, Coroutine, List, Optional, Tuple +from typing import Any, Callable, Coroutine, List, Optional, Tuple import numpy as np from llama_index.bridge.pydantic import Field, validator from llama_index.callbacks.base import CallbackManager from llama_index.callbacks.schema import CBEventType, EventPayload -from llama_index.schema import BaseComponent +from llama_index.schema import BaseNode, MetadataMode, TransformComponent from llama_index.utils import get_tqdm_iterable # TODO: change to numpy array @@ -49,7 +49,7 @@ def similarity( return product / norm -class BaseEmbedding(BaseComponent): +class BaseEmbedding(TransformComponent): """Base class for embeddings.""" model_name: str = Field( @@ -58,6 +58,8 @@ class BaseEmbedding(BaseComponent): embed_batch_size: int = Field( default=DEFAULT_EMBED_BATCH_SIZE, description="The batch size for embedding calls.", + gt=0, + lte=2048, ) callback_manager: CallbackManager = Field( default_factory=lambda: CallbackManager([]), exclude=True @@ -229,7 +231,10 @@ class BaseEmbedding(BaseComponent): return text_embedding def get_text_embedding_batch( - self, texts: List[str], show_progress: bool = False + self, + texts: List[str], + show_progress: bool = False, + **kwargs: Any, ) -> List[Embedding]: """Get a list of text embeddings, with batching.""" cur_batch: List[str] = [] @@ -324,3 +329,25 @@ class BaseEmbedding(BaseComponent): ) -> float: """Get embedding similarity.""" return similarity(embedding1=embedding1, embedding2=embedding2, mode=mode) + + def __call__(self, nodes: List[BaseNode], **kwargs: Any) -> List[BaseNode]: + embeddings = self.get_text_embedding_batch( + [node.get_content(metadata_mode=MetadataMode.EMBED) for node in nodes], + **kwargs, + ) + + for node, embedding in zip(nodes, embeddings): + node.embedding = embedding + + return nodes + + async def acall(self, nodes: List[BaseNode], **kwargs: Any) -> List[BaseNode]: + embeddings = await self.aget_text_embedding_batch( + [node.get_content(metadata_mode=MetadataMode.EMBED) for node in nodes], + **kwargs, + ) + + for node, embedding in zip(nodes, embeddings): + node.embedding = embedding + + return nodes diff --git a/llama_index/embeddings/huggingface.py b/llama_index/embeddings/huggingface.py index b59a9674c8a8ea4fe0d768fab60c75ad3f77a973..4e86d76d66de9a0f452ef77f62fa4f07918b7c41 100644 --- a/llama_index/embeddings/huggingface.py +++ b/llama_index/embeddings/huggingface.py @@ -20,10 +20,14 @@ from llama_index.utils import get_cache_dir, infer_torch_device if TYPE_CHECKING: import torch +DEFAULT_HUGGINGFACE_LENGTH = 512 + class HuggingFaceEmbedding(BaseEmbedding): tokenizer_name: str = Field(description="Tokenizer name from HuggingFace.") - max_length: int = Field(description="Maximum length of input.") + max_length: int = Field( + default=DEFAULT_HUGGINGFACE_LENGTH, description="Maximum length of input.", gt=0 + ) pooling: Pooling = Field(default=Pooling.CLS, description="Pooling strategy.") normalize: str = Field(default=True, description="Normalize embeddings or not.") query_instruction: Optional[str] = Field( diff --git a/llama_index/embeddings/langchain.py b/llama_index/embeddings/langchain.py index 361911b00242b2e407eb5cf384204eaf91f5fc94..7fda89b84d8a813df5fcd1255d1ad108c4de3e81 100644 --- a/llama_index/embeddings/langchain.py +++ b/llama_index/embeddings/langchain.py @@ -1,12 +1,14 @@ """Langchain Embedding Wrapper Module.""" -from typing import List, Optional +from typing import TYPE_CHECKING, List, Optional -from llama_index.bridge.langchain import Embeddings as LCEmbeddings from llama_index.bridge.pydantic import PrivateAttr from llama_index.callbacks import CallbackManager from llama_index.embeddings.base import DEFAULT_EMBED_BATCH_SIZE, BaseEmbedding +if TYPE_CHECKING: + from llama_index.bridge.langchain import Embeddings as LCEmbeddings + class LangchainEmbedding(BaseEmbedding): """External embeddings (taken from Langchain). @@ -16,12 +18,12 @@ class LangchainEmbedding(BaseEmbedding): embeddings class. """ - _langchain_embedding: LCEmbeddings = PrivateAttr() + _langchain_embedding: "LCEmbeddings" = PrivateAttr() _async_not_implemented_warned: bool = PrivateAttr(default=False) def __init__( self, - langchain_embeddings: LCEmbeddings, + langchain_embeddings: "LCEmbeddings", model_name: Optional[str] = None, embed_batch_size: int = DEFAULT_EMBED_BATCH_SIZE, callback_manager: Optional[CallbackManager] = None, diff --git a/llama_index/embeddings/loading.py b/llama_index/embeddings/loading.py index 7bc5d761941137d6258a49e999c10ff12ac11fdb..7a1f2fa509e3f42527de1ddbbdf69ede9d35baab 100644 --- a/llama_index/embeddings/loading.py +++ b/llama_index/embeddings/loading.py @@ -22,6 +22,8 @@ RECOGNIZED_EMBEDDINGS: Dict[str, Type[BaseEmbedding]] = { def load_embed_model(data: dict) -> BaseEmbedding: """Load Embedding by name.""" + if isinstance(data, BaseEmbedding): + return data name = data.get("class_name", None) if name is None: raise ValueError("Embedding loading requires a class_name") diff --git a/llama_index/embeddings/openai.py b/llama_index/embeddings/openai.py index 7f1654b464a40a6d71d6660427f549c86e3501de..8a8b5f094027b922b25f7491fdbb974a07273e8e 100644 --- a/llama_index/embeddings/openai.py +++ b/llama_index/embeddings/openai.py @@ -265,6 +265,11 @@ class OpenAIEmbedding(BaseEmbedding): self._query_engine = get_engine(mode, model, _QUERY_MODE_MODEL_DICT) self._text_engine = get_engine(mode, model, _TEXT_MODE_MODEL_DICT) + if "model_name" in kwargs: + model_name = kwargs.pop("model_name") + else: + model_name = model + super().__init__( embed_batch_size=embed_batch_size, callback_manager=callback_manager, diff --git a/llama_index/embeddings/utils.py b/llama_index/embeddings/utils.py index b85f8796a6c3ff33a7fdb2ee136bcd906a736f8d..165079ae6cbb03a0513621c2bd275f969ab8e607 100644 --- a/llama_index/embeddings/utils.py +++ b/llama_index/embeddings/utils.py @@ -1,8 +1,9 @@ """Embedding utils for LlamaIndex.""" import os -from typing import List, Optional, Union +from typing import TYPE_CHECKING, List, Optional, Union -from llama_index.bridge.langchain import Embeddings as LCEmbeddings +if TYPE_CHECKING: + from llama_index.bridge.langchain import Embeddings as LCEmbeddings from llama_index.embeddings.base import BaseEmbedding from llama_index.embeddings.clip import ClipEmbedding from llama_index.embeddings.huggingface import HuggingFaceEmbedding @@ -16,7 +17,7 @@ from llama_index.llms.openai_utils import validate_openai_api_key from llama_index.token_counter.mock_embed_model import MockEmbedding from llama_index.utils import get_cache_dir -EmbedType = Union[BaseEmbedding, LCEmbeddings, str] +EmbedType = Union[BaseEmbedding, "LCEmbeddings", str] def save_embedding(embedding: List[float], file_path: str) -> None: @@ -36,6 +37,11 @@ def load_embedding(file_path: str) -> List[float]: def resolve_embed_model(embed_model: Optional[EmbedType] = None) -> BaseEmbedding: """Resolve embed model.""" + try: + from llama_index.bridge.langchain import Embeddings as LCEmbeddings + except ImportError: + LCEmbeddings = None # type: ignore + if embed_model == "default": try: embed_model = OpenAIEmbedding() @@ -79,7 +85,7 @@ def resolve_embed_model(embed_model: Optional[EmbedType] = None) -> BaseEmbeddin model_name=model_name, cache_folder=cache_folder ) - if isinstance(embed_model, LCEmbeddings): + if LCEmbeddings is not None and isinstance(embed_model, LCEmbeddings): embed_model = LangchainEmbedding(embed_model) if embed_model is None: diff --git a/llama_index/evaluation/batch_runner.py b/llama_index/evaluation/batch_runner.py index 4ce2f22e2da44eb09b426e3028baa02707d27eb3..c68cb5e3445544c419b3b34085b4209736f4221d 100644 --- a/llama_index/evaluation/batch_runner.py +++ b/llama_index/evaluation/batch_runner.py @@ -1,9 +1,9 @@ import asyncio from typing import Any, Dict, List, Optional, Sequence, Tuple, cast +from llama_index.core import BaseQueryEngine from llama_index.evaluation.base import BaseEvaluator, EvaluationResult from llama_index.evaluation.eval_utils import asyncio_module -from llama_index.indices.query.base import BaseQueryEngine from llama_index.response.schema import RESPONSE_TYPE, Response diff --git a/llama_index/evaluation/benchmarks/beir.py b/llama_index/evaluation/benchmarks/beir.py index 582203d15f0862d08e3b5a65d2c317ec25af5efa..ddba34fa2ef6570bfc411aadabf231b2fcac8ca6 100644 --- a/llama_index/evaluation/benchmarks/beir.py +++ b/llama_index/evaluation/benchmarks/beir.py @@ -4,7 +4,7 @@ from typing import Callable, Dict, List import tqdm -from llama_index.indices.base_retriever import BaseRetriever +from llama_index.core import BaseRetriever from llama_index.schema import Document from llama_index.utils import get_cache_dir diff --git a/llama_index/evaluation/benchmarks/hotpotqa.py b/llama_index/evaluation/benchmarks/hotpotqa.py index 4651fb70f78a13e2562fd326c52a7e600b1f88ed..2d5ff6bb6e371c7eafdacf49f5c714a252113b2f 100644 --- a/llama_index/evaluation/benchmarks/hotpotqa.py +++ b/llama_index/evaluation/benchmarks/hotpotqa.py @@ -9,11 +9,9 @@ from typing import Any, Dict, List, Optional, Tuple import requests import tqdm -from llama_index.indices.base_retriever import BaseRetriever -from llama_index.indices.query.base import BaseQueryEngine -from llama_index.indices.query.schema import QueryBundle +from llama_index.core import BaseQueryEngine, BaseRetriever from llama_index.query_engine.retriever_query_engine import RetrieverQueryEngine -from llama_index.schema import NodeWithScore, TextNode +from llama_index.schema import NodeWithScore, QueryBundle, TextNode from llama_index.utils import get_cache_dir DEV_DISTRACTOR_URL = """http://curtis.ml.cmu.edu/datasets/\ diff --git a/llama_index/evaluation/correctness.py b/llama_index/evaluation/correctness.py index 8257d63208f65814c736ad11e77bd4666bca0843..57ed131b25a7317d7ad4aaea4395f2036b07587e 100644 --- a/llama_index/evaluation/correctness.py +++ b/llama_index/evaluation/correctness.py @@ -2,7 +2,6 @@ from typing import Any, Optional, Sequence, Union from llama_index.evaluation.base import BaseEvaluator, EvaluationResult -from llama_index.indices.service_context import ServiceContext from llama_index.prompts import ( BasePromptTemplate, ChatMessage, @@ -11,6 +10,7 @@ from llama_index.prompts import ( PromptTemplate, ) from llama_index.prompts.mixin import PromptDictType +from llama_index.service_context import ServiceContext DEFAULT_SYSTEM_TEMPLATE = """ You are an expert evaluation system for a question answering chatbot. diff --git a/llama_index/evaluation/dataset_generation.py b/llama_index/evaluation/dataset_generation.py index 87bad050e499b559a234f21ce8c853af17042359..c197395ba182cb69f335b2d39da22c51c5cdfc3f 100644 --- a/llama_index/evaluation/dataset_generation.py +++ b/llama_index/evaluation/dataset_generation.py @@ -10,8 +10,9 @@ from typing import Dict, List, Tuple from pydantic import BaseModel, Field from llama_index import Document, ServiceContext, SummaryIndex -from llama_index.indices.postprocessor.node import KeywordNodePostprocessor +from llama_index.ingestion import run_transformations from llama_index.llms.openai import OpenAI +from llama_index.postprocessor.node import KeywordNodePostprocessor from llama_index.prompts.base import BasePromptTemplate, PromptTemplate from llama_index.prompts.default_prompts import DEFAULT_TEXT_QA_PROMPT from llama_index.prompts.mixin import PromptDictType, PromptMixin, PromptMixinType @@ -157,7 +158,10 @@ class DatasetGenerator(PromptMixin): """Generate dataset from documents.""" if service_context is None: service_context = _get_default_service_context() - nodes = service_context.node_parser.get_nodes_from_documents(documents) + + nodes = run_transformations( + documents, service_context.transformations, show_progress=show_progress + ) # use node postprocessor to filter nodes required_keywords = required_keywords or [] diff --git a/llama_index/evaluation/eval_utils.py b/llama_index/evaluation/eval_utils.py index e989a171090020a88d9dfd75b91f542e40e04831..2b45a45478ae5621ff932458c436c15600e15402 100644 --- a/llama_index/evaluation/eval_utils.py +++ b/llama_index/evaluation/eval_utils.py @@ -11,8 +11,8 @@ from typing import Any, List import numpy as np import pandas as pd +from llama_index.core import BaseQueryEngine from llama_index.evaluation.base import EvaluationResult -from llama_index.indices.query.base import BaseQueryEngine def asyncio_module(show_progress: bool = False) -> Any: diff --git a/llama_index/evaluation/retrieval/evaluator.py b/llama_index/evaluation/retrieval/evaluator.py index 48a20a96e3e77875da758b15511a1bc8c3a5679e..af8b042fecf818e3b0f7cae6f37806728a323264 100644 --- a/llama_index/evaluation/retrieval/evaluator.py +++ b/llama_index/evaluation/retrieval/evaluator.py @@ -3,13 +3,13 @@ from typing import Any, List, Sequence from llama_index.bridge.pydantic import Field +from llama_index.core import BaseRetriever from llama_index.evaluation.retrieval.base import ( BaseRetrievalEvaluator, ) from llama_index.evaluation.retrieval.metrics_base import ( BaseRetrievalMetric, ) -from llama_index.indices.base_retriever import BaseRetriever class RetrieverEvaluator(BaseRetrievalEvaluator): diff --git a/llama_index/evaluation/semantic_similarity.py b/llama_index/evaluation/semantic_similarity.py index 1e2d7efdd4be8a69be776efc70f9de9cc06a1f6f..c77f2fa085bd6a5ac911f5fd56b2a3cc1b0ddac5 100644 --- a/llama_index/evaluation/semantic_similarity.py +++ b/llama_index/evaluation/semantic_similarity.py @@ -2,8 +2,8 @@ from typing import Any, Callable, Optional, Sequence from llama_index.embeddings.base import SimilarityMode, similarity from llama_index.evaluation.base import BaseEvaluator, EvaluationResult -from llama_index.indices.service_context import ServiceContext from llama_index.prompts.mixin import PromptDictType +from llama_index.service_context import ServiceContext class SemanticSimilarityEvaluator(BaseEvaluator): diff --git a/llama_index/node_parser/extractors/__init__.py b/llama_index/extractors/__init__.py similarity index 60% rename from llama_index/node_parser/extractors/__init__.py rename to llama_index/extractors/__init__.py index a0a1a04dc065d1bf968d9a358cab1f783e996218..781fe513a7f6ab2d96c6a5c8f84ac1f3b2253cfb 100644 --- a/llama_index/node_parser/extractors/__init__.py +++ b/llama_index/extractors/__init__.py @@ -1,11 +1,10 @@ -from llama_index.node_parser.extractors.marvin_metadata_extractor import ( +from llama_index.extractors.interface import BaseExtractor +from llama_index.extractors.marvin_metadata_extractor import ( MarvinMetadataExtractor, ) -from llama_index.node_parser.extractors.metadata_extractors import ( +from llama_index.extractors.metadata_extractors import ( EntityExtractor, KeywordExtractor, - MetadataExtractor, - MetadataFeatureExtractor, PydanticProgramExtractor, QuestionsAnsweredExtractor, SummaryExtractor, @@ -13,13 +12,12 @@ from llama_index.node_parser.extractors.metadata_extractors import ( ) __all__ = [ - "MarvinMetadataExtractor", - "EntityExtractor", - "KeywordExtractor", - "MetadataExtractor", - "MetadataFeatureExtractor", - "QuestionsAnsweredExtractor", "SummaryExtractor", + "QuestionsAnsweredExtractor", "TitleExtractor", + "KeywordExtractor", + "EntityExtractor", + "MarvinMetadataExtractor", + "BaseExtractor", "PydanticProgramExtractor", ] diff --git a/llama_index/extractors/interface.py b/llama_index/extractors/interface.py new file mode 100644 index 0000000000000000000000000000000000000000..6236946bb4b282575e9231e59d3a12da0217b89b --- /dev/null +++ b/llama_index/extractors/interface.py @@ -0,0 +1,116 @@ +"""Node parser interface.""" +from abc import abstractmethod +from copy import deepcopy +from typing import Any, Dict, List, Optional, Sequence, cast + +from typing_extensions import Self + +from llama_index.bridge.pydantic import Field +from llama_index.schema import BaseNode, MetadataMode, TextNode, TransformComponent + +DEFAULT_NODE_TEXT_TEMPLATE = """\ +[Excerpt from document]\n{metadata_str}\n\ +Excerpt:\n-----\n{content}\n-----\n""" + + +class BaseExtractor(TransformComponent): + """Metadata extractor.""" + + is_text_node_only: bool = True + + show_progress: bool = Field(default=True, description="Whether to show progress.") + + metadata_mode: MetadataMode = Field( + default=MetadataMode.ALL, description="Metadata mode to use when reading nodes." + ) + + node_text_template: str = Field( + default=DEFAULT_NODE_TEXT_TEMPLATE, + description="Template to represent how node text is mixed with metadata text.", + ) + disable_template_rewrite: bool = Field( + default=False, description="Disable the node template rewrite." + ) + + in_place: bool = Field( + default=True, description="Whether to process nodes in place." + ) + + @classmethod + def from_dict(cls, data: Dict[str, Any], **kwargs: Any) -> Self: # type: ignore + if isinstance(kwargs, dict): + data.update(kwargs) + + data.pop("class_name", None) + + llm_predictor = data.get("llm_predictor", None) + if llm_predictor: + from llama_index.llm_predictor.loading import load_predictor + + llm_predictor = load_predictor(llm_predictor) + data["llm_predictor"] = llm_predictor + + return cls(**data) + + @classmethod + def class_name(cls) -> str: + """Get class name.""" + return "MetadataExtractor" + + @abstractmethod + def extract(self, nodes: Sequence[BaseNode]) -> List[Dict]: + """Extracts metadata for a sequence of nodes, returning a list of + metadata dictionaries corresponding to each node. + + Args: + nodes (Sequence[Document]): nodes to extract metadata from + + """ + + def process_nodes( + self, + nodes: List[BaseNode], + excluded_embed_metadata_keys: Optional[List[str]] = None, + excluded_llm_metadata_keys: Optional[List[str]] = None, + **kwargs: Any, + ) -> List[BaseNode]: + """Post process nodes parsed from documents. + + Allows extractors to be chained. + + Args: + nodes (List[BaseNode]): nodes to post-process + excluded_embed_metadata_keys (Optional[List[str]]): + keys to exclude from embed metadata + excluded_llm_metadata_keys (Optional[List[str]]): + keys to exclude from llm metadata + """ + if self.in_place: + new_nodes = nodes + else: + new_nodes = [deepcopy(node) for node in nodes] + + cur_metadata_list = self.extract(new_nodes) + for idx, node in enumerate(new_nodes): + node.metadata.update(cur_metadata_list[idx]) + + for idx, node in enumerate(new_nodes): + if excluded_embed_metadata_keys is not None: + node.excluded_embed_metadata_keys.extend(excluded_embed_metadata_keys) + if excluded_llm_metadata_keys is not None: + node.excluded_llm_metadata_keys.extend(excluded_llm_metadata_keys) + if not self.disable_template_rewrite: + if isinstance(node, TextNode): + cast(TextNode, node).text_template = self.node_text_template + + return new_nodes + + def __call__(self, nodes: List[BaseNode], **kwargs: Any) -> List[BaseNode]: + """Post process nodes parsed from documents. + + Allows extractors to be chained. + + Args: + nodes (List[BaseNode]): nodes to post-process + """ + return self.process_nodes(nodes, **kwargs) # type: ignore diff --git a/llama_index/extractors/loading.py b/llama_index/extractors/loading.py new file mode 100644 index 0000000000000000000000000000000000000000..9c73ac5d5a31ed8f54c09c7f2c4645201de593be --- /dev/null +++ b/llama_index/extractors/loading.py @@ -0,0 +1,32 @@ +from llama_index.extractors.metadata_extractors import ( + BaseExtractor, + EntityExtractor, + KeywordExtractor, + QuestionsAnsweredExtractor, + SummaryExtractor, + TitleExtractor, +) + + +def load_extractor( + data: dict, +) -> BaseExtractor: + if isinstance(data, BaseExtractor): + return data + + extractor_name = data.get("class_name", None) + if extractor_name is None: + raise ValueError("Extractor loading requires a class_name") + + if extractor_name == SummaryExtractor.class_name(): + return SummaryExtractor.from_dict(data) + elif extractor_name == QuestionsAnsweredExtractor.class_name(): + return QuestionsAnsweredExtractor.from_dict(data) + elif extractor_name == EntityExtractor.class_name(): + return EntityExtractor.from_dict(data) + elif extractor_name == TitleExtractor.class_name(): + return TitleExtractor.from_dict(data) + elif extractor_name == KeywordExtractor.class_name(): + return KeywordExtractor.from_dict(data) + else: + raise ValueError(f"Unknown extractor name: {extractor_name}") diff --git a/llama_index/node_parser/extractors/marvin_metadata_extractor.py b/llama_index/extractors/marvin_metadata_extractor.py similarity index 71% rename from llama_index/node_parser/extractors/marvin_metadata_extractor.py rename to llama_index/extractors/marvin_metadata_extractor.py index 4b3dc9253879fe7955f0d0dba4e990ee64aa313c..f7157568f591eecca0eb19c93a908163380d7144 100644 --- a/llama_index/node_parser/extractors/marvin_metadata_extractor.py +++ b/llama_index/extractors/marvin_metadata_extractor.py @@ -1,16 +1,25 @@ -from typing import TYPE_CHECKING, Any, Dict, List, Optional, Sequence, Type, cast +from typing import ( + TYPE_CHECKING, + Any, + Dict, + Iterable, + List, + Optional, + Sequence, + Type, + cast, +) if TYPE_CHECKING: from marvin import AIModel from llama_index.bridge.pydantic import BaseModel, Field -from llama_index.node_parser.extractors.metadata_extractors import ( - MetadataFeatureExtractor, -) +from llama_index.extractors.interface import BaseExtractor from llama_index.schema import BaseNode, TextNode +from llama_index.utils import get_tqdm_iterable -class MarvinMetadataExtractor(MetadataFeatureExtractor): +class MarvinMetadataExtractor(BaseExtractor): # Forward reference to handle circular imports marvin_model: Type["AIModel"] = Field( description="The Marvin model to use for extracting custom metadata" @@ -26,22 +35,20 @@ class MarvinMetadataExtractor(MetadataFeatureExtractor): marvin_model: Marvin model to use for extracting metadata llm_model_string: (optional) LLM model string to use for extracting metadata Usage: - #create metadata extractor - metadata_extractor = MetadataExtractor( - extractors=[ - TitleExtractor(nodes=1, llm=llm), - MarvinMetadataExtractor(marvin_model=YourMarvinMetadataModel), - ], - ) + #create extractor list + extractors = [ + TitleExtractor(nodes=1, llm=llm), + MarvinMetadataExtractor(marvin_model=YourMarvinMetadataModel), + ] #create node parser to parse nodes from document - node_parser = SimpleNodeParser( - text_splitter=text_splitter, - metadata_extractor=metadata_extractor, + node_parser = SentenceSplitter( + text_splitter=text_splitter ) #use node_parser to get nodes from documents - nodes = node_parser.get_nodes_from_documents([Document(text=text)]) + from llama_index.ingestion import run_transformations + nodes = run_transformations(documents, [node_parser] + extractors) print(nodes) """ @@ -74,7 +81,11 @@ class MarvinMetadataExtractor(MetadataFeatureExtractor): ai_model = cast(AIModel, self.marvin_model) metadata_list: List[Dict] = [] - for node in nodes: + + nodes_queue: Iterable[BaseNode] = get_tqdm_iterable( + nodes, self.show_progress, "Extracting marvin metadata" + ) + for node in nodes_queue: if self.is_text_node_only and not isinstance(node, TextNode): metadata_list.append({}) continue diff --git a/llama_index/node_parser/extractors/metadata_extractors.py b/llama_index/extractors/metadata_extractors.py similarity index 77% rename from llama_index/node_parser/extractors/metadata_extractors.py rename to llama_index/extractors/metadata_extractors.py index babb122aa6213199724b80836924443ecf787d85..08065a16fad7b896a8bbb3733a07de5330c69b3a 100644 --- a/llama_index/node_parser/extractors/metadata_extractors.py +++ b/llama_index/extractors/metadata_extractors.py @@ -1,5 +1,5 @@ """ -Metadata extractors for nodes. Applied as a post processor to node parsing. +Metadata extractors for nodes. Currently, only `TextNode` is supported. Supported metadata: @@ -19,117 +19,18 @@ The prompts used to generate the metadata are specifically aimed to help disambiguate the document or subsection from other similar documents or subsections. (similar with contrastive learning) """ -from abc import abstractmethod -from copy import deepcopy from functools import reduce -from typing import Any, Callable, Dict, List, Optional, Sequence, cast +from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, cast from llama_index.bridge.pydantic import Field, PrivateAttr -from llama_index.llm_predictor.base import BaseLLMPredictor, LLMPredictor +from llama_index.extractors.interface import BaseExtractor +from llama_index.llm_predictor.base import LLMPredictor from llama_index.llms.base import LLM -from llama_index.node_parser.interface import BaseExtractor from llama_index.prompts import PromptTemplate -from llama_index.schema import BaseNode, MetadataMode, TextNode +from llama_index.schema import BaseNode, TextNode from llama_index.types import BasePydanticProgram from llama_index.utils import get_tqdm_iterable - -class MetadataFeatureExtractor(BaseExtractor): - is_text_node_only: bool = True - show_progress: bool = True - metadata_mode: MetadataMode = MetadataMode.ALL - - @abstractmethod - def extract(self, nodes: Sequence[BaseNode]) -> List[Dict]: - """Extracts metadata for a sequence of nodes, returning a list of - metadata dictionaries corresponding to each node. - - Args: - nodes (Sequence[Document]): nodes to extract metadata from - - """ - - -DEFAULT_NODE_TEXT_TEMPLATE = """\ -[Excerpt from document]\n{metadata_str}\n\ -Excerpt:\n-----\n{content}\n-----\n""" - - -class MetadataExtractor(BaseExtractor): - """Metadata extractor.""" - - extractors: Sequence[MetadataFeatureExtractor] = Field( - default_factory=list, - description="Metadta feature extractors to apply to each node.", - ) - node_text_template: str = Field( - default=DEFAULT_NODE_TEXT_TEMPLATE, - description="Template to represent how node text is mixed with metadata text.", - ) - disable_template_rewrite: bool = Field( - default=False, description="Disable the node template rewrite." - ) - - in_place: bool = Field( - default=True, description="Whether to process nodes in place." - ) - - @classmethod - def class_name(cls) -> str: - return "MetadataExtractor" - - def extract(self, nodes: Sequence[BaseNode]) -> List[Dict]: - """Extract metadata from a document. - - Args: - nodes (Sequence[BaseNode]): nodes to extract metadata from - - """ - metadata_list: List[Dict] = [{} for _ in nodes] - for extractor in self.extractors: - cur_metadata_list = extractor.extract(nodes) - for i, metadata in enumerate(metadata_list): - metadata.update(cur_metadata_list[i]) - - return metadata_list - - def process_nodes( - self, - nodes: List[BaseNode], - excluded_embed_metadata_keys: Optional[List[str]] = None, - excluded_llm_metadata_keys: Optional[List[str]] = None, - ) -> List[BaseNode]: - """Post process nodes parsed from documents. - - Allows extractors to be chained. - - Args: - nodes (List[BaseNode]): nodes to post-process - excluded_embed_metadata_keys (Optional[List[str]]): - keys to exclude from embed metadata - excluded_llm_metadata_keys (Optional[List[str]]): - keys to exclude from llm metadata - """ - if self.in_place: - new_nodes = nodes - else: - new_nodes = [deepcopy(node) for node in nodes] - for extractor in self.extractors: - cur_metadata_list = extractor.extract(new_nodes) - for idx, node in enumerate(new_nodes): - node.metadata.update(cur_metadata_list[idx]) - - for idx, node in enumerate(new_nodes): - if excluded_embed_metadata_keys is not None: - node.excluded_embed_metadata_keys.extend(excluded_embed_metadata_keys) - if excluded_llm_metadata_keys is not None: - node.excluded_llm_metadata_keys.extend(excluded_llm_metadata_keys) - if not self.disable_template_rewrite: - if isinstance(node, TextNode): - cast(TextNode, node).text_template = self.node_text_template - return new_nodes - - DEFAULT_TITLE_NODE_TEMPLATE = """\ Context: {context_str}. Give a title that summarizes all of \ the unique entities, titles or themes found in the context. Title: """ @@ -140,12 +41,12 @@ DEFAULT_TITLE_COMBINE_TEMPLATE = """\ what is the comprehensive title for this document? Title: """ -class TitleExtractor(MetadataFeatureExtractor): +class TitleExtractor(BaseExtractor): """Title extractor. Useful for long documents. Extracts `document_title` metadata field. Args: - llm_predictor (Optional[BaseLLMPredictor]): LLM predictor + llm_predictor (Optional[LLMPredictor]): LLM predictor nodes (int): number of nodes from front to use for title extraction node_template (str): template for node-level title clues extraction combine_template (str): template for combining node-level clues into @@ -153,11 +54,13 @@ class TitleExtractor(MetadataFeatureExtractor): """ is_text_node_only: bool = False # can work for mixture of text and non-text nodes - llm_predictor: BaseLLMPredictor = Field( + llm_predictor: LLMPredictor = Field( description="The LLMPredictor to use for generation." ) nodes: int = Field( - default=5, description="The number of nodes to extract titles from." + default=5, + description="The number of nodes to extract titles from.", + gt=0, ) node_template: str = Field( default=DEFAULT_TITLE_NODE_TEMPLATE, @@ -172,7 +75,7 @@ class TitleExtractor(MetadataFeatureExtractor): self, llm: Optional[LLM] = None, # TODO: llm_predictor arg is deprecated - llm_predictor: Optional[BaseLLMPredictor] = None, + llm_predictor: Optional[LLMPredictor] = None, nodes: int = 5, node_template: str = DEFAULT_TITLE_NODE_TEMPLATE, combine_template: str = DEFAULT_TITLE_COMBINE_TEMPLATE, @@ -201,6 +104,7 @@ class TitleExtractor(MetadataFeatureExtractor): def extract(self, nodes: Sequence[BaseNode]) -> List[Dict]: nodes_to_extract_title: List[BaseNode] = [] + for node in nodes: if len(nodes_to_extract_title) >= self.nodes: break @@ -212,12 +116,15 @@ class TitleExtractor(MetadataFeatureExtractor): # Could not extract title return [] + nodes_queue: Iterable[BaseNode] = get_tqdm_iterable( + nodes_to_extract_title, self.show_progress, "Extracting titles" + ) title_candidates = [ self.llm_predictor.predict( PromptTemplate(template=self.node_template), context_str=cast(TextNode, node).text, ) - for node in nodes_to_extract_title + for node in nodes_queue ] if len(nodes_to_extract_title) > 1: titles = reduce( @@ -236,25 +143,27 @@ class TitleExtractor(MetadataFeatureExtractor): return [{"document_title": title.strip(' \t\n\r"')} for _ in nodes] -class KeywordExtractor(MetadataFeatureExtractor): +class KeywordExtractor(BaseExtractor): """Keyword extractor. Node-level extractor. Extracts `excerpt_keywords` metadata field. Args: - llm_predictor (Optional[BaseLLMPredictor]): LLM predictor + llm_predictor (Optional[LLMPredictor]): LLM predictor keywords (int): number of keywords to extract """ - llm_predictor: BaseLLMPredictor = Field( + llm_predictor: LLMPredictor = Field( description="The LLMPredictor to use for generation." ) - keywords: int = Field(default=5, description="The number of keywords to extract.") + keywords: int = Field( + default=5, description="The number of keywords to extract.", gt=0 + ) def __init__( self, llm: Optional[LLM] = None, # TODO: llm_predictor arg is deprecated - llm_predictor: Optional[BaseLLMPredictor] = None, + llm_predictor: Optional[LLMPredictor] = None, keywords: int = 5, **kwargs: Any, ) -> None: @@ -275,7 +184,11 @@ class KeywordExtractor(MetadataFeatureExtractor): def extract(self, nodes: Sequence[BaseNode]) -> List[Dict]: metadata_list: List[Dict] = [] - for node in nodes: + nodes_queue: Iterable[BaseNode] = get_tqdm_iterable( + nodes, self.show_progress, "Extracting keywords" + ) + + for node in nodes_queue: if self.is_text_node_only and not isinstance(node, TextNode): metadata_list.append({}) continue @@ -309,23 +222,25 @@ that this context can answer. """ -class QuestionsAnsweredExtractor(MetadataFeatureExtractor): +class QuestionsAnsweredExtractor(BaseExtractor): """ Questions answered extractor. Node-level extractor. Extracts `questions_this_excerpt_can_answer` metadata field. Args: - llm_predictor (Optional[BaseLLMPredictor]): LLM predictor + llm_predictor (Optional[LLMPredictor]): LLM predictor questions (int): number of questions to extract prompt_template (str): template for question extraction, embedding_only (bool): whether to use embedding only """ - llm_predictor: BaseLLMPredictor = Field( + llm_predictor: LLMPredictor = Field( description="The LLMPredictor to use for generation." ) questions: int = Field( - default=5, description="The number of questions to generate." + default=5, + description="The number of questions to generate.", + gt=0, ) prompt_template: str = Field( default=DEFAULT_QUESTION_GEN_TMPL, @@ -339,7 +254,7 @@ class QuestionsAnsweredExtractor(MetadataFeatureExtractor): self, llm: Optional[LLM] = None, # TODO: llm_predictor arg is deprecated - llm_predictor: Optional[BaseLLMPredictor] = None, + llm_predictor: Optional[LLMPredictor] = None, questions: int = 5, prompt_template: str = DEFAULT_QUESTION_GEN_TMPL, embedding_only: bool = True, @@ -368,7 +283,7 @@ class QuestionsAnsweredExtractor(MetadataFeatureExtractor): def extract(self, nodes: Sequence[BaseNode]) -> List[Dict]: metadata_list: List[Dict] = [] - nodes_queue = get_tqdm_iterable( + nodes_queue: Iterable[BaseNode] = get_tqdm_iterable( nodes, self.show_progress, "Extracting questions" ) for node in nodes_queue: @@ -399,19 +314,19 @@ Summarize the key topics and entities of the section. \ Summary: """ -class SummaryExtractor(MetadataFeatureExtractor): +class SummaryExtractor(BaseExtractor): """ Summary extractor. Node-level extractor with adjacent sharing. Extracts `section_summary`, `prev_section_summary`, `next_section_summary` metadata fields. Args: - llm_predictor (Optional[BaseLLMPredictor]): LLM predictor + llm_predictor (Optional[LLMPredictor]): LLM predictor summaries (List[str]): list of summaries to extract: 'self', 'prev', 'next' prompt_template (str): template for summary extraction """ - llm_predictor: BaseLLMPredictor = Field( + llm_predictor: LLMPredictor = Field( description="The LLMPredictor to use for generation." ) summaries: List[str] = Field( @@ -430,7 +345,7 @@ class SummaryExtractor(MetadataFeatureExtractor): self, llm: Optional[LLM] = None, # TODO: llm_predictor arg is deprecated - llm_predictor: Optional[BaseLLMPredictor] = None, + llm_predictor: Optional[LLMPredictor] = None, summaries: List[str] = ["self"], prompt_template: str = DEFAULT_SUMMARY_EXTRACT_TEMPLATE, **kwargs: Any, @@ -461,7 +376,7 @@ class SummaryExtractor(MetadataFeatureExtractor): def extract(self, nodes: Sequence[BaseNode]) -> List[Dict]: if not all(isinstance(node, TextNode) for node in nodes): raise ValueError("Only `TextNode` is allowed for `Summary` extractor") - nodes_queue = get_tqdm_iterable( + nodes_queue: Iterable[BaseNode] = get_tqdm_iterable( nodes, self.show_progress, "Extracting summaries" ) node_summaries = [] @@ -509,7 +424,7 @@ DEFAULT_ENTITY_MAP = { DEFAULT_ENTITY_MODEL = "tomaarsen/span-marker-mbert-base-multinerd" -class EntityExtractor(MetadataFeatureExtractor): +class EntityExtractor(BaseExtractor): """ Entity extractor. Extracts `entities` into a metadata field using a default model `tomaarsen/span-marker-mbert-base-multinerd` and the SpanMarker library. @@ -522,9 +437,14 @@ class EntityExtractor(MetadataFeatureExtractor): description="The model name of the SpanMarker model to use.", ) prediction_threshold: float = Field( - default=0.5, description="The confidence threshold for accepting predictions." + default=0.5, + description="The confidence threshold for accepting predictions.", + gte=0.0, + lte=1.0, + ) + span_joiner: str = Field( + default=" ", description="The separator between entity names." ) - span_joiner: str = Field(description="The separator between entity names.") label_entities: bool = Field( default=False, description="Include entity class labels or not." ) @@ -613,7 +533,12 @@ class EntityExtractor(MetadataFeatureExtractor): def extract(self, nodes: Sequence[BaseNode]) -> List[Dict]: # Extract node-level entity metadata metadata_list: List[Dict] = [{} for _ in nodes] - for i, metadata in enumerate(metadata_list): + metadata_queue: Iterable[int] = get_tqdm_iterable( + range(len(nodes)), self.show_progress, "Extracting entities" + ) + + for i in metadata_queue: + metadata = metadata_list[i] node_text = nodes[i].get_content(metadata_mode=self.metadata_mode) words = self._tokenizer(node_text) spans = self._model.predict(words) @@ -644,7 +569,7 @@ Given the contextual information, extract out a {class_name} object.\ """ -class PydanticProgramExtractor(MetadataFeatureExtractor): +class PydanticProgramExtractor(BaseExtractor): """Pydantic program extractor. Uses an LLM to extract out a Pydantic object. Return attributes of that object diff --git a/llama_index/finetuning/cross_encoders/cross_encoder.py b/llama_index/finetuning/cross_encoders/cross_encoder.py index 2bf783c45935c5277617898d81dc64fbfe98da9e..64d37baecdc33019d89d946ae9516bc96f474d89 100644 --- a/llama_index/finetuning/cross_encoders/cross_encoder.py +++ b/llama_index/finetuning/cross_encoders/cross_encoder.py @@ -5,7 +5,7 @@ from llama_index.finetuning.cross_encoders.dataset_gen import ( CrossEncoderFinetuningDatasetSample, ) from llama_index.finetuning.types import BaseCrossEncoderFinetuningEngine -from llama_index.indices.postprocessor import SentenceTransformerRerank +from llama_index.postprocessor import SentenceTransformerRerank class CrossEncoderFinetuneEngine(BaseCrossEncoderFinetuningEngine): diff --git a/llama_index/finetuning/cross_encoders/dataset_gen.py b/llama_index/finetuning/cross_encoders/dataset_gen.py index bb18f117247ab1a74115f13e518fa418f6b6ebbc..3abb04f38306afef1b9cef7c6a08c9c685fbdeee 100644 --- a/llama_index/finetuning/cross_encoders/dataset_gen.py +++ b/llama_index/finetuning/cross_encoders/dataset_gen.py @@ -9,9 +9,8 @@ from tqdm.auto import tqdm from llama_index import VectorStoreIndex from llama_index.llms import ChatMessage, OpenAI from llama_index.llms.base import LLM -from llama_index.node_parser import SimpleNodeParser +from llama_index.node_parser import TokenTextSplitter from llama_index.schema import Document, MetadataMode -from llama_index.text_splitter import TokenTextSplitter @dataclass @@ -42,7 +41,7 @@ def generate_synthetic_queries_over_documents( qa_generate_user_msg: str = DEFAULT_QUERY_GEN_USER_PROMPT, ) -> List[str]: questions = [] - text_splitter = TokenTextSplitter( + node_parser = TokenTextSplitter( separator=" ", chunk_size=max_chunk_length, chunk_overlap=0, @@ -51,7 +50,6 @@ def generate_synthetic_queries_over_documents( ) llm = llm or OpenAI(model="gpt-3.5-turbo-16k", temperature=0.3) - node_parser = SimpleNodeParser(text_splitter=text_splitter) nodes = node_parser.get_nodes_from_documents(documents, show_progress=False) node_dict = { @@ -120,7 +118,7 @@ def generate_ce_fine_tuning_dataset( ) -> List[CrossEncoderFinetuningDatasetSample]: ce_dataset_list = [] - text_splitter = TokenTextSplitter( + node_parser = TokenTextSplitter( separator=" ", chunk_size=max_chunk_length, chunk_overlap=0, @@ -134,7 +132,6 @@ def generate_ce_fine_tuning_dataset( model="gpt-3.5-turbo-16k", temperature=0.1, logit_bias={9642: 1, 2822: 1} ) - node_parser = SimpleNodeParser(text_splitter=text_splitter) nodes = node_parser.get_nodes_from_documents(documents, show_progress=False) index = VectorStoreIndex(nodes) diff --git a/llama_index/finetuning/types.py b/llama_index/finetuning/types.py index 589e90dfcaae269249535b44202df2c5fd3209b5..5fa58d65f78fe35ce087631799782827118d83f8 100644 --- a/llama_index/finetuning/types.py +++ b/llama_index/finetuning/types.py @@ -4,8 +4,8 @@ from abc import ABC, abstractmethod from typing import Any from llama_index.embeddings.base import BaseEmbedding -from llama_index.indices.postprocessor import SentenceTransformerRerank from llama_index.llms.base import LLM +from llama_index.postprocessor import SentenceTransformerRerank class BaseLLMFinetuneEngine(ABC): diff --git a/llama_index/indices/__init__.py b/llama_index/indices/__init__.py index d59a1eb95f369a642cfc8a138c72eef58cd494f8..db65de8858a1c8a32dffa162e3ecd31c495ca9af 100644 --- a/llama_index/indices/__init__.py +++ b/llama_index/indices/__init__.py @@ -1,7 +1,13 @@ """LlamaIndex data structures.""" # indices +from llama_index.indices.composability.graph import ComposableGraph +from llama_index.indices.document_summary import ( + DocumentSummaryIndex, + GPTDocumentSummaryIndex, +) from llama_index.indices.document_summary.base import DocumentSummaryIndex +from llama_index.indices.empty.base import EmptyIndex, GPTEmptyIndex from llama_index.indices.keyword_table.base import ( GPTKeywordTableIndex, KeywordTableIndex, @@ -14,11 +20,31 @@ from llama_index.indices.keyword_table.simple_base import ( GPTSimpleKeywordTableIndex, SimpleKeywordTableIndex, ) +from llama_index.indices.knowledge_graph import ( + GPTKnowledgeGraphIndex, + KnowledgeGraphIndex, +) +from llama_index.indices.list import GPTListIndex, ListIndex, SummaryIndex from llama_index.indices.list.base import GPTListIndex, ListIndex, SummaryIndex +from llama_index.indices.loading import ( + load_graph_from_storage, + load_index_from_storage, + load_indices_from_storage, +) from llama_index.indices.managed.vectara import VectaraIndex +from llama_index.indices.multi_modal import MultiModalVectorStoreIndex +from llama_index.indices.struct_store.pandas import GPTPandasIndex, PandasIndex +from llama_index.indices.struct_store.sql import ( + GPTSQLStructStoreIndex, + SQLStructStoreIndex, +) from llama_index.indices.tree.base import GPTTreeIndex, TreeIndex +from llama_index.indices.vector_store import GPTVectorStoreIndex, VectorStoreIndex __all__ = [ + "load_graph_from_storage", + "load_index_from_storage", + "load_indices_from_storage", "KeywordTableIndex", "SimpleKeywordTableIndex", "RAKEKeywordTableIndex", @@ -26,11 +52,24 @@ __all__ = [ "TreeIndex", "VectaraIndex", "DocumentSummaryIndex", + "KnowledgeGraphIndex", + "PandasIndex", + "VectorStoreIndex", + "SQLStructStoreIndex", + "MultiModalVectorStoreIndex", + "EmptyIndex", + "ComposableGraph", # legacy + "GPTKnowledgeGraphIndex", "GPTKeywordTableIndex", "GPTSimpleKeywordTableIndex", "GPTRAKEKeywordTableIndex", + "GPTDocumentSummaryIndex", "GPTListIndex", "GPTTreeIndex", + "GPTPandasIndex", "ListIndex", + "GPTVectorStoreIndex", + "GPTSQLStructStoreIndex", + "GPTEmptyIndex", ] diff --git a/llama_index/indices/base.py b/llama_index/indices/base.py index 599fd1913cb269086168cd4aa33ba8ac361f848a..bdf448a5aefaeb994d4e667b6292360201e5b2eb 100644 --- a/llama_index/indices/base.py +++ b/llama_index/indices/base.py @@ -4,13 +4,13 @@ from abc import ABC, abstractmethod from typing import Any, Dict, Generic, List, Optional, Sequence, Type, TypeVar, cast from llama_index.chat_engine.types import BaseChatEngine, ChatMode +from llama_index.core import BaseQueryEngine, BaseRetriever from llama_index.data_structs.data_structs import IndexStruct -from llama_index.indices.base_retriever import BaseRetriever -from llama_index.indices.query.base import BaseQueryEngine -from llama_index.indices.service_context import ServiceContext +from llama_index.ingestion import run_transformations from llama_index.llms.openai import OpenAI from llama_index.llms.openai_utils import is_function_calling_model from llama_index.schema import BaseNode, Document +from llama_index.service_context import ServiceContext from llama_index.storage.docstore.types import BaseDocumentStore, RefDocInfo from llama_index.storage.storage_context import StorageContext @@ -95,8 +95,12 @@ class BaseIndex(Generic[IS], ABC): with service_context.callback_manager.as_trace("index_construction"): for doc in documents: docstore.set_document_hash(doc.get_doc_id(), doc.hash) - nodes = service_context.node_parser.get_nodes_from_documents( - documents, show_progress=show_progress + + nodes = run_transformations( + documents, # type: ignore + service_context.transformations, + show_progress=show_progress, + **kwargs, ) return cls( @@ -184,9 +188,12 @@ class BaseIndex(Generic[IS], ABC): def insert(self, document: Document, **insert_kwargs: Any) -> None: """Insert a document.""" with self._service_context.callback_manager.as_trace("insert"): - nodes = self.service_context.node_parser.get_nodes_from_documents( - [document] + nodes = run_transformations( + [document], + self._service_context.transformations, + show_progress=self._show_progress, ) + self.insert_nodes(nodes, **insert_kwargs) self.docstore.set_document_hash(document.get_doc_id(), document.hash) diff --git a/llama_index/indices/base_retriever.py b/llama_index/indices/base_retriever.py index 770d76790de5479489f011d3091dd0494c8adfc7..22087ac2625a0eb7e8511da52eff235b6a103912 100644 --- a/llama_index/indices/base_retriever.py +++ b/llama_index/indices/base_retriever.py @@ -1,70 +1,6 @@ -from abc import abstractmethod -from typing import List, Optional +# for backwards compatibility +from llama_index.core import BaseRetriever -from llama_index.indices.query.schema import QueryBundle, QueryType -from llama_index.indices.service_context import ServiceContext -from llama_index.prompts.mixin import PromptDictType, PromptMixin, PromptMixinType -from llama_index.schema import NodeWithScore - - -class BaseRetriever(PromptMixin): - """Base retriever.""" - - def _get_prompts(self) -> PromptDictType: - """Get prompts.""" - return {} - - def _get_prompt_modules(self) -> PromptMixinType: - """Get prompt modules.""" - return {} - - def _update_prompts(self, prompts: PromptDictType) -> None: - """Update prompts.""" - - def retrieve(self, str_or_query_bundle: QueryType) -> List[NodeWithScore]: - """Retrieve nodes given query. - - Args: - str_or_query_bundle (QueryType): Either a query string or - a QueryBundle object. - - """ - if isinstance(str_or_query_bundle, str): - str_or_query_bundle = QueryBundle(str_or_query_bundle) - return self._retrieve(str_or_query_bundle) - - async def aretrieve(self, str_or_query_bundle: QueryType) -> List[NodeWithScore]: - if isinstance(str_or_query_bundle, str): - str_or_query_bundle = QueryBundle(str_or_query_bundle) - return await self._aretrieve(str_or_query_bundle) - - @abstractmethod - def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]: - """Retrieve nodes given query. - - Implemented by the user. - - """ - - # TODO: make this abstract - # @abstractmethod - async def _aretrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]: - """Asynchronously retrieve nodes given query. - - Implemented by the user. - - """ - return self._retrieve(query_bundle) - - def get_service_context(self) -> Optional[ServiceContext]: - """Attempts to resolve a service context. - Short-circuits at self.service_context, self._service_context, - or self._index.service_context. - """ - if hasattr(self, "service_context"): - return self.service_context - if hasattr(self, "_service_context"): - return self._service_context - elif hasattr(self, "_index") and hasattr(self._index, "service_context"): - return self._index.service_context - return None +__all__ = [ + "BaseRetriever", +] diff --git a/llama_index/indices/common/struct_store/base.py b/llama_index/indices/common/struct_store/base.py index 44df7ee2af2c41bbd3ca2cc5b9fed8c539a62b65..4437ec1009691286b9c7a0920930b5beba3fdcb8 100644 --- a/llama_index/indices/common/struct_store/base.py +++ b/llama_index/indices/common/struct_store/base.py @@ -6,8 +6,8 @@ from typing import Any, Callable, Dict, List, Optional, Sequence, cast from llama_index.callbacks.schema import CBEventType, EventPayload from llama_index.data_structs.table import StructDatapoint -from llama_index.indices.service_context import ServiceContext from llama_index.llm_predictor.base import BaseLLMPredictor +from llama_index.node_parser.interface import TextSplitter from llama_index.prompts import BasePromptTemplate from llama_index.prompts.default_prompt_selectors import ( DEFAULT_REFINE_TABLE_CONTEXT_PROMPT_SEL, @@ -19,7 +19,7 @@ from llama_index.prompts.default_prompts import ( from llama_index.prompts.prompt_type import PromptType from llama_index.response_synthesizers import get_response_synthesizer from llama_index.schema import BaseNode, MetadataMode -from llama_index.text_splitter import TextSplitter +from llama_index.service_context import ServiceContext from llama_index.utilities.sql_wrapper import SQLDatabase from llama_index.utils import truncate_text diff --git a/llama_index/indices/common_tree/base.py b/llama_index/indices/common_tree/base.py index e59a27673b2cb64d6d1b3b936d18814bd01369a0..f43986b1c0e21ec0f14c6d01ae8ead9ea658f83f 100644 --- a/llama_index/indices/common_tree/base.py +++ b/llama_index/indices/common_tree/base.py @@ -8,10 +8,10 @@ from typing import Dict, List, Optional, Sequence, Tuple from llama_index.async_utils import run_async_tasks from llama_index.callbacks.schema import CBEventType, EventPayload from llama_index.data_structs.data_structs import IndexGraph -from llama_index.indices.service_context import ServiceContext from llama_index.indices.utils import get_sorted_node_list, truncate_text from llama_index.prompts import BasePromptTemplate from llama_index.schema import BaseNode, MetadataMode, TextNode +from llama_index.service_context import ServiceContext from llama_index.storage.docstore import BaseDocumentStore from llama_index.storage.docstore.registry import get_default_docstore from llama_index.utils import get_tqdm_iterable diff --git a/llama_index/indices/composability/graph.py b/llama_index/indices/composability/graph.py index 695ab0cdc53853f8be8376931448702e0562431a..d7e5e14c3db829c17e20f2c100f1fea955e26f9e 100644 --- a/llama_index/indices/composability/graph.py +++ b/llama_index/indices/composability/graph.py @@ -2,11 +2,11 @@ from typing import Any, Dict, List, Optional, Sequence, Type, cast +from llama_index.core import BaseQueryEngine from llama_index.data_structs.data_structs import IndexStruct from llama_index.indices.base import BaseIndex -from llama_index.indices.query.base import BaseQueryEngine -from llama_index.indices.service_context import ServiceContext from llama_index.schema import IndexNode, NodeRelationship, ObjectType, RelatedNodeInfo +from llama_index.service_context import ServiceContext from llama_index.storage.storage_context import StorageContext diff --git a/llama_index/indices/document_summary/base.py b/llama_index/indices/document_summary/base.py index 9eecd67c05cb1e69ea5b0a3d782883e5cb84c250..1bf97e20ac5ef0db910c0de478191a288b46bbb1 100644 --- a/llama_index/indices/document_summary/base.py +++ b/llama_index/indices/document_summary/base.py @@ -10,10 +10,9 @@ from collections import defaultdict from enum import Enum from typing import Any, Dict, Optional, Sequence, Union, cast +from llama_index.core import BaseRetriever from llama_index.data_structs.document_summary import IndexDocumentSummary from llama_index.indices.base import BaseIndex -from llama_index.indices.base_retriever import BaseRetriever -from llama_index.indices.service_context import ServiceContext from llama_index.indices.utils import embed_nodes from llama_index.response.schema import Response from llama_index.response_synthesizers import ( @@ -28,6 +27,7 @@ from llama_index.schema import ( RelatedNodeInfo, TextNode, ) +from llama_index.service_context import ServiceContext from llama_index.storage.docstore.types import RefDocInfo from llama_index.storage.storage_context import StorageContext from llama_index.utils import get_tqdm_iterable diff --git a/llama_index/indices/document_summary/retrievers.py b/llama_index/indices/document_summary/retrievers.py index 79bdff26f9983dda5e153b183105b209ec59c9ce..5dc3e47d11c7e71800a0f1f7b786f00c06446644 100644 --- a/llama_index/indices/document_summary/retrievers.py +++ b/llama_index/indices/document_summary/retrievers.py @@ -7,17 +7,16 @@ This module contains retrievers for document summary indices. import logging from typing import Any, Callable, List, Optional -from llama_index.indices.base_retriever import BaseRetriever +from llama_index.core import BaseRetriever from llama_index.indices.document_summary.base import DocumentSummaryIndex -from llama_index.indices.query.schema import QueryBundle -from llama_index.indices.service_context import ServiceContext from llama_index.indices.utils import ( default_format_node_batch_fn, default_parse_choice_select_answer_fn, ) from llama_index.prompts import BasePromptTemplate from llama_index.prompts.default_prompts import DEFAULT_CHOICE_SELECT_PROMPT -from llama_index.schema import NodeWithScore +from llama_index.schema import NodeWithScore, QueryBundle +from llama_index.service_context import ServiceContext from llama_index.vector_stores.types import VectorStoreQuery logger = logging.getLogger(__name__) diff --git a/llama_index/indices/empty/base.py b/llama_index/indices/empty/base.py index b565ab79071aa7c46756ff40ed22511073d50adc..6f74184f486a7fb23d96d6046ba91f30a226b551 100644 --- a/llama_index/indices/empty/base.py +++ b/llama_index/indices/empty/base.py @@ -7,12 +7,11 @@ pure LLM calls. from typing import Any, Dict, Optional, Sequence +from llama_index.core import BaseQueryEngine, BaseRetriever from llama_index.data_structs.data_structs import EmptyIndexStruct from llama_index.indices.base import BaseIndex -from llama_index.indices.base_retriever import BaseRetriever -from llama_index.indices.query.base import BaseQueryEngine -from llama_index.indices.service_context import ServiceContext from llama_index.schema import BaseNode +from llama_index.service_context import ServiceContext from llama_index.storage.docstore.types import RefDocInfo diff --git a/llama_index/indices/empty/retrievers.py b/llama_index/indices/empty/retrievers.py index f4c09a76315156c333749af33b956e5448f03134..f5aeed4baa6c7e64fe62caf2046b27591c86a263 100644 --- a/llama_index/indices/empty/retrievers.py +++ b/llama_index/indices/empty/retrievers.py @@ -1,12 +1,11 @@ """Default query for EmptyIndex.""" from typing import Any, List, Optional -from llama_index.indices.base_retriever import BaseRetriever +from llama_index.core import BaseRetriever from llama_index.indices.empty.base import EmptyIndex -from llama_index.indices.query.schema import QueryBundle from llama_index.prompts import BasePromptTemplate from llama_index.prompts.default_prompts import DEFAULT_SIMPLE_INPUT_PROMPT -from llama_index.schema import NodeWithScore +from llama_index.schema import NodeWithScore, QueryBundle class EmptyIndexRetriever(BaseRetriever): diff --git a/llama_index/indices/keyword_table/base.py b/llama_index/indices/keyword_table/base.py index 6c7c1ac25b17f39f6951b129509e47317ea69055..fb6b17156bf1b144f804a5ce347ff7fdc99a9517 100644 --- a/llama_index/indices/keyword_table/base.py +++ b/llama_index/indices/keyword_table/base.py @@ -13,17 +13,17 @@ from enum import Enum from typing import Any, Dict, Optional, Sequence, Set, Union from llama_index.async_utils import run_async_tasks +from llama_index.core import BaseRetriever from llama_index.data_structs.data_structs import KeywordTable from llama_index.indices.base import BaseIndex -from llama_index.indices.base_retriever import BaseRetriever from llama_index.indices.keyword_table.utils import extract_keywords_given_response -from llama_index.indices.service_context import ServiceContext from llama_index.prompts import BasePromptTemplate from llama_index.prompts.default_prompts import ( DEFAULT_KEYWORD_EXTRACT_TEMPLATE, DEFAULT_QUERY_KEYWORD_EXTRACT_TEMPLATE, ) from llama_index.schema import BaseNode, MetadataMode +from llama_index.service_context import ServiceContext from llama_index.storage.docstore.types import RefDocInfo from llama_index.utils import get_tqdm_iterable diff --git a/llama_index/indices/keyword_table/rake_base.py b/llama_index/indices/keyword_table/rake_base.py index 814ffb882dd23d582f016eb32448958b70f47439..b4188e731282216b08bf1a9d28cf43c24e7768a7 100644 --- a/llama_index/indices/keyword_table/rake_base.py +++ b/llama_index/indices/keyword_table/rake_base.py @@ -6,7 +6,7 @@ Similar to KeywordTableIndex, but uses RAKE instead of GPT. from typing import Any, Set, Union -from llama_index.indices.base_retriever import BaseRetriever +from llama_index.core import BaseRetriever from llama_index.indices.keyword_table.base import ( BaseKeywordTableIndex, KeywordTableRetrieverMode, diff --git a/llama_index/indices/keyword_table/retrievers.py b/llama_index/indices/keyword_table/retrievers.py index 584298f5975257a47afa511b9674f564113a303b..83f43e732daccee4c577ed3e3aa8e91e6438e4f9 100644 --- a/llama_index/indices/keyword_table/retrievers.py +++ b/llama_index/indices/keyword_table/retrievers.py @@ -4,20 +4,19 @@ from abc import abstractmethod from collections import defaultdict from typing import Any, Dict, List, Optional -from llama_index.indices.base_retriever import BaseRetriever +from llama_index.core import BaseRetriever from llama_index.indices.keyword_table.base import BaseKeywordTableIndex from llama_index.indices.keyword_table.utils import ( extract_keywords_given_response, rake_extract_keywords, simple_extract_keywords, ) -from llama_index.indices.query.schema import QueryBundle from llama_index.prompts import BasePromptTemplate from llama_index.prompts.default_prompts import ( DEFAULT_KEYWORD_EXTRACT_TEMPLATE, DEFAULT_QUERY_KEYWORD_EXTRACT_TEMPLATE, ) -from llama_index.schema import NodeWithScore +from llama_index.schema import NodeWithScore, QueryBundle from llama_index.utils import truncate_text DQKET = DEFAULT_QUERY_KEYWORD_EXTRACT_TEMPLATE diff --git a/llama_index/indices/keyword_table/simple_base.py b/llama_index/indices/keyword_table/simple_base.py index b6e5a7299d14c16c85b3088b39d2d4082158032f..f54a57866c8bff67925e0f7a477b5387103f9e82 100644 --- a/llama_index/indices/keyword_table/simple_base.py +++ b/llama_index/indices/keyword_table/simple_base.py @@ -7,7 +7,7 @@ technique that doesn't involve GPT - just uses regex. from typing import Any, Set, Union -from llama_index.indices.base_retriever import BaseRetriever +from llama_index.core import BaseRetriever from llama_index.indices.keyword_table.base import ( BaseKeywordTableIndex, KeywordTableRetrieverMode, diff --git a/llama_index/indices/knowledge_graph/base.py b/llama_index/indices/knowledge_graph/base.py index 5c285ed7f349e7d0f10128ad379f700bb680e443..6387c673aa4cd4b7e717250fc508e48501fef8cf 100644 --- a/llama_index/indices/knowledge_graph/base.py +++ b/llama_index/indices/knowledge_graph/base.py @@ -8,15 +8,15 @@ import logging from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple from llama_index.constants import GRAPH_STORE_KEY +from llama_index.core import BaseRetriever from llama_index.data_structs.data_structs import KG from llama_index.graph_stores.simple import SimpleGraphStore from llama_index.graph_stores.types import GraphStore from llama_index.indices.base import BaseIndex -from llama_index.indices.base_retriever import BaseRetriever -from llama_index.indices.service_context import ServiceContext from llama_index.prompts import BasePromptTemplate from llama_index.prompts.default_prompts import DEFAULT_KG_TRIPLET_EXTRACT_PROMPT from llama_index.schema import BaseNode, MetadataMode +from llama_index.service_context import ServiceContext from llama_index.storage.docstore.types import RefDocInfo from llama_index.storage.storage_context import StorageContext from llama_index.utils import get_tqdm_iterable diff --git a/llama_index/indices/knowledge_graph/retrievers.py b/llama_index/indices/knowledge_graph/retrievers.py index 12bcae26b5b4f86697374eb744be0644d9dcb13c..7a6490c67f9b772b721d8127fd9e0893a7b414f7 100644 --- a/llama_index/indices/knowledge_graph/retrievers.py +++ b/llama_index/indices/knowledge_graph/retrievers.py @@ -4,15 +4,20 @@ from collections import defaultdict from enum import Enum from typing import Any, Callable, Dict, List, Optional, Set, Tuple -from llama_index.indices.base_retriever import BaseRetriever +from llama_index.core import BaseRetriever from llama_index.indices.keyword_table.utils import extract_keywords_given_response from llama_index.indices.knowledge_graph.base import KnowledgeGraphIndex from llama_index.indices.query.embedding_utils import get_top_k_embeddings -from llama_index.indices.query.schema import QueryBundle -from llama_index.indices.service_context import ServiceContext from llama_index.prompts import BasePromptTemplate, PromptTemplate, PromptType from llama_index.prompts.default_prompts import DEFAULT_QUERY_KEYWORD_EXTRACT_TEMPLATE -from llama_index.schema import BaseNode, MetadataMode, NodeWithScore, TextNode +from llama_index.schema import ( + BaseNode, + MetadataMode, + NodeWithScore, + QueryBundle, + TextNode, +) +from llama_index.service_context import ServiceContext from llama_index.storage.storage_context import StorageContext from llama_index.utils import print_text, truncate_text diff --git a/llama_index/indices/list/base.py b/llama_index/indices/list/base.py index 84dc3fd3e62e108d8bb758e65355c41c95df80b0..c4e3321c2e848a12892e2ae5f394957002a8e97d 100644 --- a/llama_index/indices/list/base.py +++ b/llama_index/indices/list/base.py @@ -8,11 +8,11 @@ in sequence in order to answer a given query. from enum import Enum from typing import Any, Dict, Optional, Sequence, Union +from llama_index.core import BaseRetriever from llama_index.data_structs.data_structs import IndexList from llama_index.indices.base import BaseIndex -from llama_index.indices.base_retriever import BaseRetriever -from llama_index.indices.service_context import ServiceContext from llama_index.schema import BaseNode +from llama_index.service_context import ServiceContext from llama_index.storage.docstore.types import RefDocInfo from llama_index.utils import get_tqdm_iterable diff --git a/llama_index/indices/list/retrievers.py b/llama_index/indices/list/retrievers.py index 32913205f2a392ad3c9d8f7e5d48538177029a6a..1ac8c614b038466132a515b90030402c2526eb43 100644 --- a/llama_index/indices/list/retrievers.py +++ b/llama_index/indices/list/retrievers.py @@ -2,11 +2,9 @@ import logging from typing import Any, Callable, List, Optional, Tuple -from llama_index.indices.base_retriever import BaseRetriever +from llama_index.core import BaseRetriever from llama_index.indices.list.base import SummaryIndex from llama_index.indices.query.embedding_utils import get_top_k_embeddings -from llama_index.indices.query.schema import QueryBundle -from llama_index.indices.service_context import ServiceContext from llama_index.indices.utils import ( default_format_node_batch_fn, default_parse_choice_select_answer_fn, @@ -15,7 +13,8 @@ from llama_index.prompts import PromptTemplate from llama_index.prompts.default_prompts import ( DEFAULT_CHOICE_SELECT_PROMPT, ) -from llama_index.schema import BaseNode, MetadataMode, NodeWithScore +from llama_index.schema import BaseNode, MetadataMode, NodeWithScore, QueryBundle +from llama_index.service_context import ServiceContext logger = logging.getLogger(__name__) diff --git a/llama_index/indices/managed/base.py b/llama_index/indices/managed/base.py index 23ece665102ed24d9dbf738d916f904576683afb..d192f6d302d5e45127a0356d135c5782f794b539 100644 --- a/llama_index/indices/managed/base.py +++ b/llama_index/indices/managed/base.py @@ -6,11 +6,11 @@ An index that that is built on top of a managed service. from abc import ABC, abstractmethod from typing import Any, Dict, Optional, Sequence, Type +from llama_index.core import BaseRetriever from llama_index.data_structs.data_structs import IndexDict from llama_index.indices.base import BaseIndex, IndexType -from llama_index.indices.base_retriever import BaseRetriever -from llama_index.indices.service_context import ServiceContext from llama_index.schema import BaseNode, Document +from llama_index.service_context import ServiceContext from llama_index.storage.docstore.types import RefDocInfo from llama_index.storage.storage_context import StorageContext diff --git a/llama_index/indices/managed/vectara/base.py b/llama_index/indices/managed/vectara/base.py index 24ce65a107c7560118ba23da038f44bef0ad1bd4..2f4ca2aeae04b4579f94240b283b440a552ed163 100644 --- a/llama_index/indices/managed/vectara/base.py +++ b/llama_index/indices/managed/vectara/base.py @@ -12,11 +12,11 @@ from typing import Any, Optional, Sequence, Type import requests +from llama_index.core import BaseRetriever from llama_index.data_structs.data_structs import IndexDict, IndexStructType -from llama_index.indices.base_retriever import BaseRetriever from llama_index.indices.managed.base import BaseManagedIndex, IndexType -from llama_index.indices.service_context import ServiceContext from llama_index.schema import BaseNode, Document, MetadataMode, TextNode +from llama_index.service_context import ServiceContext from llama_index.storage.storage_context import StorageContext _logger = logging.getLogger(__name__) diff --git a/llama_index/indices/managed/vectara/retriever.py b/llama_index/indices/managed/vectara/retriever.py index d67b3ff2f993bb605a57a60eccc9b8f7d33f1230..3bc97583d3273e59b32d44f804ab0127f79f5b7a 100644 --- a/llama_index/indices/managed/vectara/retriever.py +++ b/llama_index/indices/managed/vectara/retriever.py @@ -7,11 +7,10 @@ import logging from typing import Any, Dict, List from llama_index.constants import DEFAULT_SIMILARITY_TOP_K -from llama_index.indices.base_retriever import BaseRetriever +from llama_index.core import BaseRetriever from llama_index.indices.managed.types import ManagedIndexQueryMode from llama_index.indices.managed.vectara.base import VectaraIndex -from llama_index.indices.query.schema import QueryBundle -from llama_index.schema import NodeWithScore, TextNode +from llama_index.schema import NodeWithScore, QueryBundle, TextNode _logger = logging.getLogger(__name__) diff --git a/llama_index/indices/multi_modal/base.py b/llama_index/indices/multi_modal/base.py index bba86fb7245ba4777b11e410025cbc769d05ef82..900e83656c237ac25dbc0d85559134196874d7da 100644 --- a/llama_index/indices/multi_modal/base.py +++ b/llama_index/indices/multi_modal/base.py @@ -6,12 +6,10 @@ An index that that is built on top of multiple vector stores for different modal import logging from typing import Any, List, Optional, Sequence, cast +from llama_index.core import BaseQueryEngine, BaseRetriever from llama_index.data_structs.data_structs import IndexDict, MultiModelIndexDict from llama_index.embeddings.multi_modal_base import MultiModalEmbedding from llama_index.embeddings.utils import EmbedType, resolve_embed_model -from llama_index.indices.base_retriever import BaseRetriever -from llama_index.indices.query.base import BaseQueryEngine -from llama_index.indices.service_context import ServiceContext from llama_index.indices.utils import ( async_embed_image_nodes, async_embed_nodes, @@ -20,6 +18,7 @@ from llama_index.indices.utils import ( ) from llama_index.indices.vector_store.base import VectorStoreIndex from llama_index.schema import BaseNode, ImageNode +from llama_index.service_context import ServiceContext from llama_index.storage.storage_context import StorageContext from llama_index.vector_stores.simple import DEFAULT_VECTOR_STORE, SimpleVectorStore from llama_index.vector_stores.types import VectorStore diff --git a/llama_index/indices/multi_modal/retriever.py b/llama_index/indices/multi_modal/retriever.py index 87c900ec2f549b446b6fd067b6133c0f9cf950bc..1cc2239e8a74dc892e4394f428e4148db3b2e9f2 100644 --- a/llama_index/indices/multi_modal/retriever.py +++ b/llama_index/indices/multi_modal/retriever.py @@ -6,9 +6,8 @@ from typing import Any, Dict, List, Optional from llama_index.constants import DEFAULT_SIMILARITY_TOP_K from llama_index.embeddings.multi_modal_base import MultiModalEmbedding from llama_index.indices.multi_modal.base import MultiModalVectorStoreIndex -from llama_index.indices.query.schema import QueryBundle from llama_index.indices.vector_store.retrievers.retriever import VectorIndexRetriever -from llama_index.schema import NodeWithScore +from llama_index.schema import NodeWithScore, QueryBundle from llama_index.vector_stores.types import ( MetadataFilters, VectorStoreQuery, diff --git a/llama_index/indices/postprocessor.py b/llama_index/indices/postprocessor.py new file mode 100644 index 0000000000000000000000000000000000000000..8837b2cc54aca6a168f93e81addf460cf4669fd9 --- /dev/null +++ b/llama_index/indices/postprocessor.py @@ -0,0 +1,38 @@ +# for backward compatibility +from llama_index.postprocessor import ( + AutoPrevNextNodePostprocessor, + CohereRerank, + EmbeddingRecencyPostprocessor, + FixedRecencyPostprocessor, + KeywordNodePostprocessor, + LLMRerank, + LongContextReorder, + LongLLMLinguaPostprocessor, + MetadataReplacementPostProcessor, + NERPIINodePostprocessor, + PIINodePostprocessor, + PrevNextNodePostprocessor, + SentenceEmbeddingOptimizer, + SentenceTransformerRerank, + SimilarityPostprocessor, + TimeWeightedPostprocessor, +) + +__all__ = [ + "SimilarityPostprocessor", + "KeywordNodePostprocessor", + "PrevNextNodePostprocessor", + "AutoPrevNextNodePostprocessor", + "FixedRecencyPostprocessor", + "EmbeddingRecencyPostprocessor", + "TimeWeightedPostprocessor", + "PIINodePostprocessor", + "NERPIINodePostprocessor", + "CohereRerank", + "LLMRerank", + "SentenceEmbeddingOptimizer", + "SentenceTransformerRerank", + "MetadataReplacementPostProcessor", + "LongContextReorder", + "LongLLMLinguaPostprocessor", +] diff --git a/llama_index/indices/prompt_helper.py b/llama_index/indices/prompt_helper.py index 50c693263a7ed2ab007e6423ba84299f5f84b60a..ea6caf31bce544b5d7bb744547acc747fc4a8605 100644 --- a/llama_index/indices/prompt_helper.py +++ b/llama_index/indices/prompt_helper.py @@ -9,18 +9,24 @@ needed), or truncating them so that they fit in a single LLM call. """ import logging +from copy import deepcopy +from string import Formatter from typing import Callable, List, Optional, Sequence from llama_index.bridge.pydantic import Field, PrivateAttr from llama_index.constants import DEFAULT_CONTEXT_WINDOW, DEFAULT_NUM_OUTPUTS from llama_index.llm_predictor.base import LLMMetadata -from llama_index.llms.openai_utils import is_chat_model -from llama_index.prompts import BasePromptTemplate +from llama_index.llms.base import LLM, ChatMessage +from llama_index.node_parser.text.token import TokenTextSplitter +from llama_index.node_parser.text.utils import truncate_text +from llama_index.prompts import ( + BasePromptTemplate, + ChatPromptTemplate, + SelectorPromptTemplate, +) from llama_index.prompts.prompt_utils import get_empty_prompt_txt from llama_index.schema import BaseComponent -from llama_index.text_splitter import TokenTextSplitter -from llama_index.text_splitter.utils import truncate_text -from llama_index.utils import globals_helper +from llama_index.utilities.token_counting import TokenCounter DEFAULT_PADDING = 5 DEFAULT_CHUNK_OVERLAP_RATIO = 0.1 @@ -68,7 +74,7 @@ class PromptHelper(BaseComponent): default=" ", description="The separator when chunking tokens." ) - _tokenizer: Callable[[str], List] = PrivateAttr() + _token_counter: TokenCounter = PrivateAttr() def __init__( self, @@ -84,7 +90,7 @@ class PromptHelper(BaseComponent): raise ValueError("chunk_overlap_ratio must be a float between 0. and 1.") # TODO: make configurable - self._tokenizer = tokenizer or globals_helper.tokenizer + self._token_counter = TokenCounter(tokenizer=tokenizer) super().__init__( context_window=context_window, @@ -114,11 +120,6 @@ class PromptHelper(BaseComponent): else: num_output = llm_metadata.num_output - # TODO: account for token counting in chat models - model_name = llm_metadata.model_name - if is_chat_model(model_name): - context_window -= 150 - return cls( context_window=context_window, num_output=num_output, @@ -132,7 +133,7 @@ class PromptHelper(BaseComponent): def class_name(cls) -> str: return "PromptHelper" - def _get_available_context_size(self, prompt: BasePromptTemplate) -> int: + def _get_available_context_size(self, num_prompt_tokens: int) -> int: """Get available context size. This is calculated as: @@ -143,11 +144,7 @@ class PromptHelper(BaseComponent): Notes: - Available context size is further clamped to be non-negative. """ - empty_prompt_txt = get_empty_prompt_txt(prompt) - num_empty_prompt_tokens = len(self._tokenizer(empty_prompt_txt)) - context_size_tokens = ( - self.context_window - num_empty_prompt_tokens - self.num_output - ) + context_size_tokens = self.context_window - num_prompt_tokens - self.num_output if context_size_tokens < 0: raise ValueError( f"Calculated available context size {context_size_tokens} was" @@ -156,7 +153,11 @@ class PromptHelper(BaseComponent): return context_size_tokens def _get_available_chunk_size( - self, prompt: BasePromptTemplate, num_chunks: int = 1, padding: int = 5 + self, + prompt: BasePromptTemplate, + num_chunks: int = 1, + padding: int = 5, + llm: Optional[LLM] = None, ) -> int: """Get available chunk size. @@ -168,7 +169,47 @@ class PromptHelper(BaseComponent): - By default, we use padding of 5 (to save space for formatting needs). - Available chunk size is further clamped to chunk_size_limit if specified. """ - available_context_size = self._get_available_context_size(prompt) + if isinstance(prompt, SelectorPromptTemplate): + prompt = prompt.select(llm=llm) + + if isinstance(prompt, ChatPromptTemplate): + messages: List[ChatMessage] = prompt.message_templates + + # account for partial formatting + partial_messages = [] + for message in messages: + partial_message = deepcopy(message) + + # get string variables (if any) + template_vars = [ + var + for _, var, _, _ in Formatter().parse(str(message)) + if var is not None + ] + + # figure out which variables are partially formatted + used_vars = {} + for var_name, val in prompt.kwargs.items(): + if var_name in template_vars: + used_vars[var_name] = val + + # format partial message + if len(used_vars) > 0 and partial_message.content is not None: + partial_message.content = partial_message.content.format( + **used_vars + ) + + # add to list of partial messages + partial_messages.append(partial_message) + + num_prompt_tokens = self._token_counter.estimate_tokens_in_messages( + partial_messages + ) + else: + prompt_str = get_empty_prompt_txt(prompt) + num_prompt_tokens = self._token_counter.get_string_tokens(prompt_str) + + available_context_size = self._get_available_context_size(num_prompt_tokens) result = available_context_size // num_chunks - padding if self.chunk_size_limit is not None: result = min(result, self.chunk_size_limit) @@ -179,11 +220,14 @@ class PromptHelper(BaseComponent): prompt: BasePromptTemplate, num_chunks: int = 1, padding: int = DEFAULT_PADDING, + llm: Optional[LLM] = None, ) -> TokenTextSplitter: """Get text splitter configured to maximally pack available context window, taking into account of given prompt, and desired number of chunks. """ - chunk_size = self._get_available_chunk_size(prompt, num_chunks, padding=padding) + chunk_size = self._get_available_chunk_size( + prompt, num_chunks, padding=padding, llm=llm + ) if chunk_size <= 0: raise ValueError(f"Chunk size {chunk_size} is not positive.") chunk_overlap = int(self.chunk_overlap_ratio * chunk_size) @@ -191,7 +235,7 @@ class PromptHelper(BaseComponent): separator=self.separator, chunk_size=chunk_size, chunk_overlap=chunk_overlap, - tokenizer=self._tokenizer, + tokenizer=self._token_counter.tokenizer, ) def truncate( @@ -199,12 +243,14 @@ class PromptHelper(BaseComponent): prompt: BasePromptTemplate, text_chunks: Sequence[str], padding: int = DEFAULT_PADDING, + llm: Optional[LLM] = None, ) -> List[str]: """Truncate text chunks to fit available context window.""" text_splitter = self.get_text_splitter_given_prompt( prompt, num_chunks=len(text_chunks), padding=padding, + llm=llm, ) return [truncate_text(chunk, text_splitter) for chunk in text_chunks] @@ -213,6 +259,7 @@ class PromptHelper(BaseComponent): prompt: BasePromptTemplate, text_chunks: Sequence[str], padding: int = DEFAULT_PADDING, + llm: Optional[LLM] = None, ) -> List[str]: """Repack text chunks to fit available context window. @@ -220,6 +267,8 @@ class PromptHelper(BaseComponent): that more fully "pack" the prompt template given the context_window. """ - text_splitter = self.get_text_splitter_given_prompt(prompt, padding=padding) + text_splitter = self.get_text_splitter_given_prompt( + prompt, padding=padding, llm=llm + ) combined_str = "\n\n".join([c.strip() for c in text_chunks if c.strip()]) return text_splitter.split_text(combined_str) diff --git a/llama_index/indices/query/base.py b/llama_index/indices/query/base.py index 8e7596db9eed8b27b1e33913b09610b15d14c756..87d179f262bea1a7d156bbff7cd6627dcc22007d 100644 --- a/llama_index/indices/query/base.py +++ b/llama_index/indices/query/base.py @@ -1,70 +1,6 @@ -"""Base query engine.""" +# for backwards compatibility +from llama_index.core import BaseQueryEngine -import logging -from abc import abstractmethod -from typing import Any, Dict, List, Optional, Sequence - -from llama_index.callbacks.base import CallbackManager -from llama_index.indices.query.schema import QueryBundle, QueryType -from llama_index.prompts.mixin import PromptDictType, PromptMixin -from llama_index.response.schema import RESPONSE_TYPE -from llama_index.schema import NodeWithScore - -logger = logging.getLogger(__name__) - - -class BaseQueryEngine(PromptMixin): - def __init__(self, callback_manager: Optional[CallbackManager]) -> None: - self.callback_manager = callback_manager or CallbackManager([]) - - def _get_prompts(self) -> Dict[str, Any]: - """Get prompts.""" - return {} - - def _update_prompts(self, prompts: PromptDictType) -> None: - """Update prompts.""" - - def query(self, str_or_query_bundle: QueryType) -> RESPONSE_TYPE: - with self.callback_manager.as_trace("query"): - if isinstance(str_or_query_bundle, str): - str_or_query_bundle = QueryBundle(str_or_query_bundle) - return self._query(str_or_query_bundle) - - async def aquery(self, str_or_query_bundle: QueryType) -> RESPONSE_TYPE: - with self.callback_manager.as_trace("query"): - if isinstance(str_or_query_bundle, str): - str_or_query_bundle = QueryBundle(str_or_query_bundle) - return await self._aquery(str_or_query_bundle) - - def retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]: - raise NotImplementedError( - "This query engine does not support retrieve, use query directly" - ) - - def synthesize( - self, - query_bundle: QueryBundle, - nodes: List[NodeWithScore], - additional_source_nodes: Optional[Sequence[NodeWithScore]] = None, - ) -> RESPONSE_TYPE: - raise NotImplementedError( - "This query engine does not support synthesize, use query directly" - ) - - async def asynthesize( - self, - query_bundle: QueryBundle, - nodes: List[NodeWithScore], - additional_source_nodes: Optional[Sequence[NodeWithScore]] = None, - ) -> RESPONSE_TYPE: - raise NotImplementedError( - "This query engine does not support asynthesize, use aquery directly" - ) - - @abstractmethod - def _query(self, query_bundle: QueryBundle) -> RESPONSE_TYPE: - pass - - @abstractmethod - async def _aquery(self, query_bundle: QueryBundle) -> RESPONSE_TYPE: - pass +__all__ = [ + "BaseQueryEngine", +] diff --git a/llama_index/indices/query/query_transform/base.py b/llama_index/indices/query/query_transform/base.py index f488cc664bfb5033855937c427a95d0271aab783..6313cde6ee2d4f997f0e43aa027a39957f2b2733 100644 --- a/llama_index/indices/query/query_transform/base.py +++ b/llama_index/indices/query/query_transform/base.py @@ -12,13 +12,13 @@ from llama_index.indices.query.query_transform.prompts import ( ImageOutputQueryTransformPrompt, StepDecomposeQueryTransformPrompt, ) -from llama_index.indices.query.schema import QueryBundle, QueryType from llama_index.llm_predictor import LLMPredictor from llama_index.llm_predictor.base import BaseLLMPredictor from llama_index.prompts import BasePromptTemplate from llama_index.prompts.default_prompts import DEFAULT_HYDE_PROMPT from llama_index.prompts.mixin import PromptDictType, PromptMixin, PromptMixinType from llama_index.response.schema import Response +from llama_index.schema import QueryBundle, QueryType from llama_index.utils import print_text diff --git a/llama_index/indices/query/query_transform/feedback_transform.py b/llama_index/indices/query/query_transform/feedback_transform.py index 3a5e150895be06189239875da12b35f07200cb27..0e8342b054550f17162612aa8a1194f15eebe9ce 100644 --- a/llama_index/indices/query/query_transform/feedback_transform.py +++ b/llama_index/indices/query/query_transform/feedback_transform.py @@ -3,11 +3,11 @@ from typing import Dict, Optional from llama_index.evaluation.base import Evaluation from llama_index.indices.query.query_transform.base import BaseQueryTransform -from llama_index.indices.query.schema import QueryBundle from llama_index.llm_predictor import LLMPredictor from llama_index.llm_predictor.base import BaseLLMPredictor from llama_index.prompts.base import BasePromptTemplate, PromptTemplate from llama_index.prompts.mixin import PromptDictType +from llama_index.schema import QueryBundle logger = logging.getLogger(__name__) diff --git a/llama_index/indices/query/schema.py b/llama_index/indices/query/schema.py index 377f5a9dba56ebf78d820f28d13b20077d9f0830..af5ac64b410c55d7a005ae4ccb2808cb06e18fd8 100644 --- a/llama_index/indices/query/schema.py +++ b/llama_index/indices/query/schema.py @@ -1,44 +1,4 @@ -"""Query Schema. +# for backwards compatibility +from llama_index.schema import QueryBundle, QueryType -This schema is used under the hood for all queries, but is primarily -exposed for recursive queries over composable indices. - -""" - -from dataclasses import dataclass -from typing import List, Optional, Union - -from dataclasses_json import DataClassJsonMixin - - -@dataclass -class QueryBundle(DataClassJsonMixin): - """ - Query bundle. - - This dataclass contains the original query string and associated transformations. - - Args: - query_str (str): the original user-specified query string. - This is currently used by all non embedding-based queries. - embedding_strs (list[str]): list of strings used for embedding the query. - This is currently used by all embedding-based queries. - embedding (list[float]): the stored embedding for the query. - """ - - query_str: str - custom_embedding_strs: Optional[List[str]] = None - embedding: Optional[List[float]] = None - - @property - def embedding_strs(self) -> List[str]: - """Use custom embedding strs if specified, otherwise use query str.""" - if self.custom_embedding_strs is None: - if len(self.query_str) == 0: - return [] - return [self.query_str] - else: - return self.custom_embedding_strs - - -QueryType = Union[str, QueryBundle] +__all__ = ["QueryBundle", "QueryType"] diff --git a/llama_index/indices/service_context.py b/llama_index/indices/service_context.py index 5e202d17bf6588f0a03d86ead68a74de5a563f2e..8979ec9e9e7756d4daed383e5857639c69a43ee7 100644 --- a/llama_index/indices/service_context.py +++ b/llama_index/indices/service_context.py @@ -1,367 +1,6 @@ -import logging -from dataclasses import dataclass -from typing import Optional +# for backwards compatibility +from llama_index.service_context import ServiceContext -import llama_index -from llama_index.bridge.pydantic import BaseModel -from llama_index.callbacks.base import CallbackManager -from llama_index.embeddings.base import BaseEmbedding -from llama_index.embeddings.utils import EmbedType, resolve_embed_model -from llama_index.indices.prompt_helper import PromptHelper -from llama_index.llm_predictor import LLMPredictor -from llama_index.llm_predictor.base import BaseLLMPredictor, LLMMetadata -from llama_index.llms.base import LLM -from llama_index.llms.utils import LLMType, resolve_llm -from llama_index.logger import LlamaLogger -from llama_index.node_parser.interface import NodeParser -from llama_index.node_parser.sentence_window import SentenceWindowNodeParser -from llama_index.node_parser.simple import SimpleNodeParser -from llama_index.prompts.base import BasePromptTemplate -from llama_index.text_splitter.types import TextSplitter -from llama_index.types import PydanticProgramMode - -logger = logging.getLogger(__name__) - - -def _get_default_node_parser( - chunk_size: Optional[int] = None, - chunk_overlap: Optional[int] = None, - callback_manager: Optional[CallbackManager] = None, -) -> NodeParser: - """Get default node parser.""" - return SimpleNodeParser.from_defaults( - chunk_size=chunk_size, - chunk_overlap=chunk_overlap, - callback_manager=callback_manager, - ) - - -def _get_default_prompt_helper( - llm_metadata: LLMMetadata, - context_window: Optional[int] = None, - num_output: Optional[int] = None, -) -> PromptHelper: - """Get default prompt helper.""" - if context_window is not None: - llm_metadata.context_window = context_window - if num_output is not None: - llm_metadata.num_output = num_output - return PromptHelper.from_llm_metadata(llm_metadata=llm_metadata) - - -class ServiceContextData(BaseModel): - llm: dict - llm_predictor: dict - prompt_helper: dict - embed_model: dict - node_parser: dict - text_splitter: Optional[dict] - metadata_extractor: Optional[dict] - extractors: Optional[list] - - -@dataclass -class ServiceContext: - """Service Context container. - - The service context container is a utility container for LlamaIndex - index and query classes. It contains the following: - - llm_predictor: BaseLLMPredictor - - prompt_helper: PromptHelper - - embed_model: BaseEmbedding - - node_parser: NodeParser - - llama_logger: LlamaLogger (deprecated) - - callback_manager: CallbackManager - - """ - - llm_predictor: BaseLLMPredictor - prompt_helper: PromptHelper - embed_model: BaseEmbedding - node_parser: NodeParser - llama_logger: LlamaLogger - callback_manager: CallbackManager - - @classmethod - def from_defaults( - cls, - llm_predictor: Optional[BaseLLMPredictor] = None, - llm: Optional[LLMType] = "default", - prompt_helper: Optional[PromptHelper] = None, - embed_model: Optional[EmbedType] = "default", - node_parser: Optional[NodeParser] = None, - llama_logger: Optional[LlamaLogger] = None, - callback_manager: Optional[CallbackManager] = None, - system_prompt: Optional[str] = None, - query_wrapper_prompt: Optional[BasePromptTemplate] = None, - # pydantic program mode (used if output_cls is specified) - pydantic_program_mode: PydanticProgramMode = PydanticProgramMode.DEFAULT, - # node parser kwargs - chunk_size: Optional[int] = None, - chunk_overlap: Optional[int] = None, - # prompt helper kwargs - context_window: Optional[int] = None, - num_output: Optional[int] = None, - # deprecated kwargs - chunk_size_limit: Optional[int] = None, - ) -> "ServiceContext": - """Create a ServiceContext from defaults. - If an argument is specified, then use the argument value provided for that - parameter. If an argument is not specified, then use the default value. - - You can change the base defaults by setting llama_index.global_service_context - to a ServiceContext object with your desired settings. - - Args: - llm_predictor (Optional[BaseLLMPredictor]): LLMPredictor - prompt_helper (Optional[PromptHelper]): PromptHelper - embed_model (Optional[BaseEmbedding]): BaseEmbedding - or "local" (use local model) - node_parser (Optional[NodeParser]): NodeParser - llama_logger (Optional[LlamaLogger]): LlamaLogger (deprecated) - chunk_size (Optional[int]): chunk_size - callback_manager (Optional[CallbackManager]): CallbackManager - system_prompt (Optional[str]): System-wide prompt to be prepended - to all input prompts, used to guide system "decision making" - query_wrapper_prompt (Optional[BasePromptTemplate]): A format to wrap - passed-in input queries. - - Deprecated Args: - chunk_size_limit (Optional[int]): renamed to chunk_size - - """ - if chunk_size_limit is not None and chunk_size is None: - logger.warning( - "chunk_size_limit is deprecated, please specify chunk_size instead" - ) - chunk_size = chunk_size_limit - - if llama_index.global_service_context is not None: - return cls.from_service_context( - llama_index.global_service_context, - llm_predictor=llm_predictor, - prompt_helper=prompt_helper, - embed_model=embed_model, - node_parser=node_parser, - llama_logger=llama_logger, - callback_manager=callback_manager, - chunk_size=chunk_size, - chunk_size_limit=chunk_size_limit, - ) - - callback_manager = callback_manager or CallbackManager([]) - if llm != "default": - if llm_predictor is not None: - raise ValueError("Cannot specify both llm and llm_predictor") - llm = resolve_llm(llm) - llm_predictor = llm_predictor or LLMPredictor( - llm=llm, pydantic_program_mode=pydantic_program_mode - ) - if isinstance(llm_predictor, LLMPredictor): - llm_predictor.llm.callback_manager = callback_manager - if system_prompt: - llm_predictor.system_prompt = system_prompt - if query_wrapper_prompt: - llm_predictor.query_wrapper_prompt = query_wrapper_prompt - - # NOTE: the embed_model isn't used in all indices - embed_model = resolve_embed_model(embed_model) - embed_model.callback_manager = callback_manager - - prompt_helper = prompt_helper or _get_default_prompt_helper( - llm_metadata=llm_predictor.metadata, - context_window=context_window, - num_output=num_output, - ) - - node_parser = node_parser or _get_default_node_parser( - chunk_size=chunk_size, - chunk_overlap=chunk_overlap, - callback_manager=callback_manager, - ) - - llama_logger = llama_logger or LlamaLogger() - - return cls( - llm_predictor=llm_predictor, - embed_model=embed_model, - prompt_helper=prompt_helper, - node_parser=node_parser, - llama_logger=llama_logger, # deprecated - callback_manager=callback_manager, - ) - - @classmethod - def from_service_context( - cls, - service_context: "ServiceContext", - llm_predictor: Optional[BaseLLMPredictor] = None, - llm: Optional[LLMType] = "default", - prompt_helper: Optional[PromptHelper] = None, - embed_model: Optional[EmbedType] = "default", - node_parser: Optional[NodeParser] = None, - llama_logger: Optional[LlamaLogger] = None, - callback_manager: Optional[CallbackManager] = None, - system_prompt: Optional[str] = None, - query_wrapper_prompt: Optional[BasePromptTemplate] = None, - # node parser kwargs - chunk_size: Optional[int] = None, - chunk_overlap: Optional[int] = None, - # prompt helper kwargs - context_window: Optional[int] = None, - num_output: Optional[int] = None, - # deprecated kwargs - chunk_size_limit: Optional[int] = None, - ) -> "ServiceContext": - """Instantiate a new service context using a previous as the defaults.""" - if chunk_size_limit is not None and chunk_size is None: - logger.warning( - "chunk_size_limit is deprecated, please specify chunk_size", - DeprecationWarning, - ) - chunk_size = chunk_size_limit - - callback_manager = callback_manager or service_context.callback_manager - if llm != "default": - if llm_predictor is not None: - raise ValueError("Cannot specify both llm and llm_predictor") - llm = resolve_llm(llm) - llm_predictor = LLMPredictor(llm=llm) - - llm_predictor = llm_predictor or service_context.llm_predictor - if isinstance(llm_predictor, LLMPredictor): - llm_predictor.llm.callback_manager = callback_manager - if system_prompt: - llm_predictor.system_prompt = system_prompt - if query_wrapper_prompt: - llm_predictor.query_wrapper_prompt = query_wrapper_prompt - - # NOTE: the embed_model isn't used in all indices - # default to using the embed model passed from the service context - if embed_model == "default": - embed_model = service_context.embed_model - embed_model = resolve_embed_model(embed_model) - embed_model.callback_manager = callback_manager - - prompt_helper = prompt_helper or service_context.prompt_helper - if context_window is not None or num_output is not None: - prompt_helper = _get_default_prompt_helper( - llm_metadata=llm_predictor.metadata, - context_window=context_window, - num_output=num_output, - ) - - node_parser = node_parser or service_context.node_parser - if chunk_size is not None or chunk_overlap is not None: - node_parser = _get_default_node_parser( - chunk_size=chunk_size, - chunk_overlap=chunk_overlap, - callback_manager=callback_manager, - ) - - llama_logger = llama_logger or service_context.llama_logger - - return cls( - llm_predictor=llm_predictor, - embed_model=embed_model, - prompt_helper=prompt_helper, - node_parser=node_parser, - llama_logger=llama_logger, # deprecated - callback_manager=callback_manager, - ) - - @property - def llm(self) -> LLM: - if not isinstance(self.llm_predictor, LLMPredictor): - raise ValueError("llm_predictor must be an instance of LLMPredictor") - return self.llm_predictor.llm - - def to_dict(self) -> dict: - """Convert service context to dict.""" - llm_dict = self.llm_predictor.llm.to_dict() - llm_predictor_dict = self.llm_predictor.to_dict() - - embed_model_dict = self.embed_model.to_dict() - - prompt_helper_dict = self.prompt_helper.to_dict() - - node_parser_dict = self.node_parser.to_dict() - - metadata_extractor_dict = None - extractor_dicts = None - text_splitter_dict = None - if isinstance(self.node_parser, SimpleNodeParser) and isinstance( - self.node_parser.text_splitter, TextSplitter - ): - text_splitter_dict = self.node_parser.text_splitter.to_dict() - - if isinstance(self.node_parser, (SimpleNodeParser, SentenceWindowNodeParser)): - if self.node_parser.metadata_extractor: - metadata_extractor_dict = self.node_parser.metadata_extractor.to_dict() - extractor_dicts = [] - for extractor in self.node_parser.metadata_extractor.extractors: - extractor_dicts.append(extractor.to_dict()) - - return ServiceContextData( - llm=llm_dict, - llm_predictor=llm_predictor_dict, - prompt_helper=prompt_helper_dict, - embed_model=embed_model_dict, - node_parser=node_parser_dict, - text_splitter=text_splitter_dict, - metadata_extractor=metadata_extractor_dict, - extractors=extractor_dicts, - ).dict() - - @classmethod - def from_dict(cls, data: dict) -> "ServiceContext": - from llama_index.embeddings.loading import load_embed_model - from llama_index.llm_predictor.loading import load_predictor - from llama_index.llms.loading import load_llm - from llama_index.node_parser.extractors.loading import load_extractor - from llama_index.node_parser.loading import load_parser - from llama_index.text_splitter.loading import load_text_splitter - - service_context_data = ServiceContextData.parse_obj(data) - - llm = load_llm(service_context_data.llm) - llm_predictor = load_predictor(service_context_data.llm_predictor, llm=llm) - - embed_model = load_embed_model(service_context_data.embed_model) - - prompt_helper = PromptHelper.from_dict(service_context_data.prompt_helper) - - extractors = None - if service_context_data.extractors: - extractors = [] - for extractor_dict in service_context_data.extractors: - extractors.append(load_extractor(extractor_dict, llm=llm)) - - metadata_extractor = None - if service_context_data.metadata_extractor: - metadata_extractor = load_extractor( - service_context_data.metadata_extractor, - extractors=extractors, - ) - - text_splitter = None - if service_context_data.text_splitter: - text_splitter = load_text_splitter(service_context_data.text_splitter) - - node_parser = load_parser( - service_context_data.node_parser, - text_splitter=text_splitter, - metadata_extractor=metadata_extractor, - ) - - return cls.from_defaults( - llm_predictor=llm_predictor, - prompt_helper=prompt_helper, - embed_model=embed_model, - node_parser=node_parser, - ) - - -def set_global_service_context(service_context: Optional[ServiceContext]) -> None: - """Helper function to set the global service context.""" - llama_index.global_service_context = service_context +__all__ = [ + "ServiceContext", +] diff --git a/llama_index/indices/struct_store/base.py b/llama_index/indices/struct_store/base.py index 1657c4197ec97d498fd8a1eede79292757cf9dbd..e191701795072fd4798b069e4796c9da40545f01 100644 --- a/llama_index/indices/struct_store/base.py +++ b/llama_index/indices/struct_store/base.py @@ -5,10 +5,10 @@ from typing import Any, Callable, Dict, Generic, Optional, Sequence, TypeVar from llama_index.data_structs.table import BaseStructTable from llama_index.indices.base import BaseIndex -from llama_index.indices.service_context import ServiceContext from llama_index.prompts import BasePromptTemplate from llama_index.prompts.default_prompts import DEFAULT_SCHEMA_EXTRACT_PROMPT from llama_index.schema import BaseNode +from llama_index.service_context import ServiceContext from llama_index.storage.docstore.types import RefDocInfo BST = TypeVar("BST", bound=BaseStructTable) diff --git a/llama_index/indices/struct_store/container_builder.py b/llama_index/indices/struct_store/container_builder.py index 011d92cbbd289573ceabd3247491630a9c172b79..35725ccb580a922b51593691d3adc001fba5d3a3 100644 --- a/llama_index/indices/struct_store/container_builder.py +++ b/llama_index/indices/struct_store/container_builder.py @@ -6,9 +6,8 @@ from typing import Any, Dict, List, Optional, Type from llama_index.indices.base import BaseIndex from llama_index.indices.common.struct_store.base import SQLDocumentContextBuilder from llama_index.indices.common.struct_store.schema import SQLContextContainer -from llama_index.indices.query.schema import QueryType from llama_index.readers.base import Document -from llama_index.schema import BaseNode +from llama_index.schema import BaseNode, QueryType from llama_index.utilities.sql_wrapper import SQLDatabase DEFAULT_CONTEXT_QUERY_TMPL = ( diff --git a/llama_index/indices/struct_store/json_query.py b/llama_index/indices/struct_store/json_query.py index a749db43f152d047429825280eabcaca71fe0357..943d6b68ff847751ef2055ec37001df1a816a34c 100644 --- a/llama_index/indices/struct_store/json_query.py +++ b/llama_index/indices/struct_store/json_query.py @@ -2,14 +2,14 @@ import json import logging from typing import Any, Callable, Dict, List, Optional, Union -from llama_index.indices.query.base import BaseQueryEngine -from llama_index.indices.query.schema import QueryBundle -from llama_index.indices.service_context import ServiceContext +from llama_index.core import BaseQueryEngine from llama_index.prompts import BasePromptTemplate, PromptTemplate from llama_index.prompts.default_prompts import DEFAULT_JSON_PATH_PROMPT from llama_index.prompts.mixin import PromptDictType, PromptMixinType from llama_index.prompts.prompt_type import PromptType from llama_index.response.schema import Response +from llama_index.schema import QueryBundle +from llama_index.service_context import ServiceContext from llama_index.utils import print_text logger = logging.getLogger(__name__) diff --git a/llama_index/indices/struct_store/pandas.py b/llama_index/indices/struct_store/pandas.py index 16b41b641992ca3a1306f510cb5db7db18e0e3a1..129b6e927a5081320b8f251a875068e45cb0cfc1 100644 --- a/llama_index/indices/struct_store/pandas.py +++ b/llama_index/indices/struct_store/pandas.py @@ -5,9 +5,8 @@ from typing import Any, Optional, Sequence import pandas as pd +from llama_index.core import BaseQueryEngine, BaseRetriever from llama_index.data_structs.table import PandasStructTable -from llama_index.indices.base_retriever import BaseRetriever -from llama_index.indices.query.base import BaseQueryEngine from llama_index.indices.struct_store.base import BaseStructStoreIndex from llama_index.schema import BaseNode diff --git a/llama_index/indices/struct_store/sql.py b/llama_index/indices/struct_store/sql.py index b2ddb5684b5de988e9c60d875d6b5731a161091d..32ca4425a0ddb19f2f6e30c58b52aa7ac883102b 100644 --- a/llama_index/indices/struct_store/sql.py +++ b/llama_index/indices/struct_store/sql.py @@ -5,17 +5,16 @@ from typing import Any, Optional, Sequence, Union from sqlalchemy import Table +from llama_index.core import BaseQueryEngine, BaseRetriever from llama_index.data_structs.table import SQLStructTable -from llama_index.indices.base_retriever import BaseRetriever from llama_index.indices.common.struct_store.schema import SQLContextContainer from llama_index.indices.common.struct_store.sql import SQLStructDatapointExtractor -from llama_index.indices.query.base import BaseQueryEngine -from llama_index.indices.service_context import ServiceContext from llama_index.indices.struct_store.base import BaseStructStoreIndex from llama_index.indices.struct_store.container_builder import ( SQLContextContainerBuilder, ) from llama_index.schema import BaseNode +from llama_index.service_context import ServiceContext from llama_index.utilities.sql_wrapper import SQLDatabase diff --git a/llama_index/indices/struct_store/sql_query.py b/llama_index/indices/struct_store/sql_query.py index de66d5d550e18b0ed5d14afb6bf9d6c49314b59f..a543d18d4946c1f33ddb1786c08b6e3cc637ab42 100644 --- a/llama_index/indices/struct_store/sql_query.py +++ b/llama_index/indices/struct_store/sql_query.py @@ -5,9 +5,7 @@ from typing import Any, Dict, List, Optional, Union, cast from sqlalchemy import Table -from llama_index.indices.query.base import BaseQueryEngine -from llama_index.indices.query.schema import QueryBundle -from llama_index.indices.service_context import ServiceContext +from llama_index.core import BaseQueryEngine from llama_index.indices.struct_store.container_builder import ( SQLContextContainerBuilder, ) @@ -26,6 +24,8 @@ from llama_index.response.schema import Response from llama_index.response_synthesizers import ( get_response_synthesizer, ) +from llama_index.schema import QueryBundle +from llama_index.service_context import ServiceContext from llama_index.utilities.sql_wrapper import SQLDatabase logger = logging.getLogger(__name__) diff --git a/llama_index/indices/struct_store/sql_retriever.py b/llama_index/indices/struct_store/sql_retriever.py index 63b0066e2f950442f299eafe8dd27a2b3f616af8..2b3b17e17f54974ee32e9ad1c9851570efeede4e 100644 --- a/llama_index/indices/struct_store/sql_retriever.py +++ b/llama_index/indices/struct_store/sql_retriever.py @@ -7,10 +7,8 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union, cast from sqlalchemy import Table +from llama_index.core import BaseRetriever from llama_index.embeddings.base import BaseEmbedding -from llama_index.indices.base_retriever import BaseRetriever -from llama_index.indices.query.schema import QueryBundle, QueryType -from llama_index.indices.service_context import ServiceContext from llama_index.objects.base import ObjectRetriever from llama_index.objects.table_node_mapping import SQLTableSchema from llama_index.prompts import BasePromptTemplate @@ -18,7 +16,8 @@ from llama_index.prompts.default_prompts import ( DEFAULT_TEXT_TO_SQL_PROMPT, ) from llama_index.prompts.mixin import PromptDictType, PromptMixin, PromptMixinType -from llama_index.schema import NodeWithScore, TextNode +from llama_index.schema import NodeWithScore, QueryBundle, QueryType, TextNode +from llama_index.service_context import ServiceContext from llama_index.utilities.sql_wrapper import SQLDatabase logger = logging.getLogger(__name__) diff --git a/llama_index/indices/tree/all_leaf_retriever.py b/llama_index/indices/tree/all_leaf_retriever.py index 91d208c2d2755e9597a5ba465f629eea56238d27..76d693ff1de95fd8b41c2406186ae3debcf746cc 100644 --- a/llama_index/indices/tree/all_leaf_retriever.py +++ b/llama_index/indices/tree/all_leaf_retriever.py @@ -3,12 +3,11 @@ import logging from typing import Any, List, cast +from llama_index.core import BaseRetriever from llama_index.data_structs.data_structs import IndexGraph -from llama_index.indices.base_retriever import BaseRetriever -from llama_index.indices.query.schema import QueryBundle from llama_index.indices.tree.base import TreeIndex from llama_index.indices.utils import get_sorted_node_list -from llama_index.schema import NodeWithScore +from llama_index.schema import NodeWithScore, QueryBundle logger = logging.getLogger(__name__) diff --git a/llama_index/indices/tree/base.py b/llama_index/indices/tree/base.py index daaf94b94dbd86bfff0e12bccf0ce64c8ce57d30..c1365f09db63bc13850cf66f4d9945b025ce414a 100644 --- a/llama_index/indices/tree/base.py +++ b/llama_index/indices/tree/base.py @@ -3,12 +3,12 @@ from enum import Enum from typing import Any, Dict, Optional, Sequence, Union +from llama_index.core import BaseRetriever + # from llama_index.data_structs.data_structs import IndexGraph from llama_index.data_structs.data_structs import IndexGraph from llama_index.indices.base import BaseIndex -from llama_index.indices.base_retriever import BaseRetriever from llama_index.indices.common_tree.base import GPTTreeIndexBuilder -from llama_index.indices.service_context import ServiceContext from llama_index.indices.tree.inserter import TreeIndexInserter from llama_index.prompts import BasePromptTemplate from llama_index.prompts.default_prompts import ( @@ -16,6 +16,7 @@ from llama_index.prompts.default_prompts import ( DEFAULT_SUMMARY_PROMPT, ) from llama_index.schema import BaseNode +from llama_index.service_context import ServiceContext from llama_index.storage.docstore.types import RefDocInfo diff --git a/llama_index/indices/tree/inserter.py b/llama_index/indices/tree/inserter.py index 5ff7829d71b393322c478821b9a20e4af5048a3b..1e8eb526e246cfdc398f14b7d4874d04f761104a 100644 --- a/llama_index/indices/tree/inserter.py +++ b/llama_index/indices/tree/inserter.py @@ -3,7 +3,6 @@ from typing import Optional, Sequence from llama_index.data_structs.data_structs import IndexGraph -from llama_index.indices.service_context import ServiceContext from llama_index.indices.tree.utils import get_numbered_text_from_nodes from llama_index.indices.utils import ( extract_numbers_given_response, @@ -15,6 +14,7 @@ from llama_index.prompts.default_prompts import ( DEFAULT_SUMMARY_PROMPT, ) from llama_index.schema import BaseNode, MetadataMode, TextNode +from llama_index.service_context import ServiceContext from llama_index.storage.docstore import BaseDocumentStore from llama_index.storage.docstore.registry import get_default_docstore diff --git a/llama_index/indices/tree/select_leaf_embedding_retriever.py b/llama_index/indices/tree/select_leaf_embedding_retriever.py index 94031f7f943d8745d9c59d5b0ad5ba16c2200f09..4f5dc95532f011b3a2ae15fd272cafe4e8d49734 100644 --- a/llama_index/indices/tree/select_leaf_embedding_retriever.py +++ b/llama_index/indices/tree/select_leaf_embedding_retriever.py @@ -3,10 +3,9 @@ import logging from typing import Dict, List, Tuple, cast -from llama_index.indices.query.schema import QueryBundle from llama_index.indices.tree.select_leaf_retriever import TreeSelectLeafRetriever from llama_index.indices.utils import get_sorted_node_list -from llama_index.schema import BaseNode, MetadataMode +from llama_index.schema import BaseNode, MetadataMode, QueryBundle logger = logging.getLogger(__name__) diff --git a/llama_index/indices/tree/select_leaf_retriever.py b/llama_index/indices/tree/select_leaf_retriever.py index 2d3f34b543bebaa265863b7eb99c4703c20b02fb..0ef5858b7cbd3daff552f39522a9f36e8a157ce2 100644 --- a/llama_index/indices/tree/select_leaf_retriever.py +++ b/llama_index/indices/tree/select_leaf_retriever.py @@ -3,8 +3,7 @@ import logging from typing import Any, Dict, List, Optional, cast -from llama_index.indices.base_retriever import BaseRetriever -from llama_index.indices.query.schema import QueryBundle +from llama_index.core import BaseRetriever from llama_index.indices.tree.base import TreeIndex from llama_index.indices.tree.utils import get_numbered_text_from_nodes from llama_index.indices.utils import ( @@ -20,7 +19,7 @@ from llama_index.prompts.default_prompts import ( ) from llama_index.response.schema import Response from llama_index.response_synthesizers import get_response_synthesizer -from llama_index.schema import BaseNode, MetadataMode, NodeWithScore +from llama_index.schema import BaseNode, MetadataMode, NodeWithScore, QueryBundle from llama_index.utils import print_text, truncate_text logger = logging.getLogger(__name__) diff --git a/llama_index/indices/tree/tree_root_retriever.py b/llama_index/indices/tree/tree_root_retriever.py index 0456c5b558483edb84abca4386ac0542910a7b25..a79e33420017c318ae8d4b73b0cbc8333d45e827 100644 --- a/llama_index/indices/tree/tree_root_retriever.py +++ b/llama_index/indices/tree/tree_root_retriever.py @@ -2,11 +2,10 @@ import logging from typing import Any, List -from llama_index.indices.base_retriever import BaseRetriever -from llama_index.indices.query.schema import QueryBundle +from llama_index.core import BaseRetriever from llama_index.indices.tree.base import TreeIndex from llama_index.indices.utils import get_sorted_node_list -from llama_index.schema import NodeWithScore +from llama_index.schema import NodeWithScore, QueryBundle logger = logging.getLogger(__name__) diff --git a/llama_index/indices/tree/utils.py b/llama_index/indices/tree/utils.py index c61bc62a8c3b32486c1a6232b7a83b162de14bb1..8c71ecf9d02b54985f5a2e6eba5484a84254866f 100644 --- a/llama_index/indices/tree/utils.py +++ b/llama_index/indices/tree/utils.py @@ -1,8 +1,8 @@ from typing import List, Optional +from llama_index.node_parser.text import TokenTextSplitter +from llama_index.node_parser.text.utils import truncate_text from llama_index.schema import BaseNode -from llama_index.text_splitter import TokenTextSplitter -from llama_index.text_splitter.utils import truncate_text def get_numbered_text_from_nodes( diff --git a/llama_index/indices/vector_store/base.py b/llama_index/indices/vector_store/base.py index 76ed0ee8e41429f2a742570cf52d0c4bcca877db..c8b143e89536d461e481016cb6797407966a1fe8 100644 --- a/llama_index/indices/vector_store/base.py +++ b/llama_index/indices/vector_store/base.py @@ -7,12 +7,12 @@ import logging from typing import Any, Dict, List, Optional, Sequence from llama_index.async_utils import run_async_tasks +from llama_index.core import BaseRetriever from llama_index.data_structs.data_structs import IndexDict from llama_index.indices.base import BaseIndex -from llama_index.indices.base_retriever import BaseRetriever -from llama_index.indices.service_context import ServiceContext from llama_index.indices.utils import async_embed_nodes, embed_nodes from llama_index.schema import BaseNode, ImageNode, IndexNode +from llama_index.service_context import ServiceContext from llama_index.storage.docstore.types import RefDocInfo from llama_index.storage.storage_context import StorageContext from llama_index.vector_stores.types import VectorStore diff --git a/llama_index/indices/vector_store/retrievers/auto_retriever/auto_retriever.py b/llama_index/indices/vector_store/retrievers/auto_retriever/auto_retriever.py index 0d4929230189db4dd75c5c23302698ed6cb67661..c2e661b6d725a04ead62454405f474ca01fb9261 100644 --- a/llama_index/indices/vector_store/retrievers/auto_retriever/auto_retriever.py +++ b/llama_index/indices/vector_store/retrievers/auto_retriever/auto_retriever.py @@ -2,9 +2,7 @@ import logging from typing import Any, List, Optional, cast from llama_index.constants import DEFAULT_SIMILARITY_TOP_K -from llama_index.indices.base_retriever import BaseRetriever -from llama_index.indices.query.schema import QueryBundle -from llama_index.indices.service_context import ServiceContext +from llama_index.core import BaseRetriever from llama_index.indices.vector_store.base import VectorStoreIndex from llama_index.indices.vector_store.retrievers import VectorIndexRetriever from llama_index.indices.vector_store.retrievers.auto_retriever.output_parser import ( @@ -15,7 +13,8 @@ from llama_index.indices.vector_store.retrievers.auto_retriever.prompts import ( VectorStoreQueryPrompt, ) from llama_index.output_parsers.base import OutputParserException, StructuredOutput -from llama_index.schema import NodeWithScore +from llama_index.schema import NodeWithScore, QueryBundle +from llama_index.service_context import ServiceContext from llama_index.vector_stores.types import ( MetadataFilters, VectorStoreInfo, diff --git a/llama_index/indices/vector_store/retrievers/retriever.py b/llama_index/indices/vector_store/retrievers/retriever.py index d5c71ba59364a5f14070c2105d76f16a0f9811f0..64c886913f014b73f06a5a455558d8622d92860b 100644 --- a/llama_index/indices/vector_store/retrievers/retriever.py +++ b/llama_index/indices/vector_store/retrievers/retriever.py @@ -4,12 +4,11 @@ from typing import Any, Dict, List, Optional from llama_index.constants import DEFAULT_SIMILARITY_TOP_K +from llama_index.core import BaseRetriever from llama_index.data_structs.data_structs import IndexDict -from llama_index.indices.base_retriever import BaseRetriever -from llama_index.indices.query.schema import QueryBundle from llama_index.indices.utils import log_vector_store_query_result from llama_index.indices.vector_store.base import VectorStoreIndex -from llama_index.schema import NodeWithScore, ObjectType +from llama_index.schema import NodeWithScore, ObjectType, QueryBundle from llama_index.vector_stores.types import ( MetadataFilters, VectorStoreQuery, diff --git a/llama_index/ingestion/__init__.py b/llama_index/ingestion/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..81219aa40728de417fa56e9d760ab1e62adafaf6 --- /dev/null +++ b/llama_index/ingestion/__init__.py @@ -0,0 +1,13 @@ +from llama_index.ingestion.cache import IngestionCache +from llama_index.ingestion.pipeline import ( + IngestionPipeline, + arun_transformations, + run_transformations, +) + +__all__ = [ + "IngestionCache", + "IngestionPipeline", + "run_transformations", + "arun_transformations", +] diff --git a/llama_index/ingestion/cache.py b/llama_index/ingestion/cache.py new file mode 100644 index 0000000000000000000000000000000000000000..7fcfbddcd31e47f142368feb9043d1ae53ea7112 --- /dev/null +++ b/llama_index/ingestion/cache.py @@ -0,0 +1,92 @@ +from typing import List, Optional + +import fsspec + +from llama_index.bridge.pydantic import BaseModel, Field +from llama_index.schema import BaseNode +from llama_index.storage.docstore.utils import doc_to_json, json_to_doc +from llama_index.storage.kvstore import ( + FirestoreKVStore as FirestoreCache, +) +from llama_index.storage.kvstore import ( + MongoDBKVStore as MongoDBCache, +) +from llama_index.storage.kvstore import ( + RedisKVStore as RedisCache, +) +from llama_index.storage.kvstore import ( + SimpleKVStore as SimpleCache, +) +from llama_index.storage.kvstore.types import ( + BaseKVStore as BaseCache, +) + +DEFAULT_CACHE_NAME = "llama_cache" + + +class IngestionCache(BaseModel): + class Config: + arbitrary_types_allowed = True + + nodes_key = "nodes" + + collection: str = Field( + default=DEFAULT_CACHE_NAME, description="Collection name of the cache." + ) + cache: BaseCache = Field(default_factory=SimpleCache, description="Cache to use.") + + # TODO: add async get/put methods? + def put( + self, key: str, nodes: List[BaseNode], collection: Optional[str] = None + ) -> None: + """Put a value into the cache.""" + collection = collection or self.collection + + val = {self.nodes_key: [doc_to_json(node) for node in nodes]} + self.cache.put(key, val, collection=collection) + + def get( + self, key: str, collection: Optional[str] = None + ) -> Optional[List[BaseNode]]: + """Get a value from the cache.""" + collection = collection or self.collection + node_dicts = self.cache.get(key, collection=collection) + + if node_dicts is None: + return None + + return [json_to_doc(node_dict) for node_dict in node_dicts[self.nodes_key]] + + def clear(self, collection: Optional[str] = None) -> None: + """Clear the cache.""" + collection = collection or self.collection + data = self.cache.get_all(collection=collection) + for key in data: + self.cache.delete(key, collection=collection) + + def persist( + self, persist_path: str, fs: Optional[fsspec.AbstractFileSystem] = None + ) -> None: + """Persist the cache to a directory, if possible.""" + if isinstance(self.cache, SimpleCache): + self.cache.persist(persist_path, fs=fs) + else: + print("Warning: skipping persist, only needed for SimpleCache.") + + @classmethod + def from_persist_path( + cls, persist_path: str, collection: str = DEFAULT_CACHE_NAME + ) -> "IngestionCache": + """Create a IngestionCache from a persist directory.""" + return cls( + collection=collection, + cache=SimpleCache.from_persist_path(persist_path), + ) + + +__all__ = [ + "SimpleCache", + "RedisCache", + "MongoDBCache", + "FirestoreCache", +] diff --git a/llama_index/ingestion/pipeline.py b/llama_index/ingestion/pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..5b59fde4d5a404f38f7e41c3ef36029c4c15531c --- /dev/null +++ b/llama_index/ingestion/pipeline.py @@ -0,0 +1,250 @@ +import re +from hashlib import sha256 +from typing import Any, List, Optional, Sequence + +from llama_index.bridge.pydantic import BaseModel, Field +from llama_index.embeddings.utils import resolve_embed_model +from llama_index.ingestion.cache import IngestionCache +from llama_index.node_parser import SentenceSplitter +from llama_index.readers.base import ReaderConfig +from llama_index.schema import BaseNode, Document, MetadataMode, TransformComponent +from llama_index.service_context import ServiceContext +from llama_index.vector_stores.types import BasePydanticVectorStore + + +def remove_unstable_values(s: str) -> str: + """Remove unstable key/value pairs. + + Examples include: + - <__main__.Test object at 0x7fb9f3793f50> + - <function test_fn at 0x7fb9f37a8900> + """ + pattern = r"<[\w\s_\. ]+ at 0x[a-z0-9]+>" + return re.sub(pattern, "", s) + + +def get_transformation_hash( + nodes: List[BaseNode], transformation: TransformComponent +) -> str: + """Get the hash of a transformation.""" + nodes_str = "".join( + [str(node.get_content(metadata_mode=MetadataMode.ALL)) for node in nodes] + ) + + transformation_dict = transformation.to_dict() + transform_string = remove_unstable_values(str(transformation_dict)) + + return sha256((nodes_str + transform_string).encode("utf-8")).hexdigest() + + +def run_transformations( + nodes: List[BaseNode], + transformations: Sequence[TransformComponent], + in_place: bool = True, + cache: Optional[IngestionCache] = None, + cache_collection: Optional[str] = None, + **kwargs: Any, +) -> List[BaseNode]: + """Run a series of transformations on a set of nodes. + + Args: + nodes: The nodes to transform. + transformations: The transformations to apply to the nodes. + + Returns: + The transformed nodes. + """ + if not in_place: + nodes = list(nodes) + + for transform in transformations: + if cache is not None: + hash = get_transformation_hash(nodes, transform) + cached_nodes = cache.get(hash, collection=cache_collection) + if cached_nodes is not None: + nodes = cached_nodes + else: + nodes = transform(nodes, **kwargs) + cache.put(hash, nodes, collection=cache_collection) + else: + nodes = transform(nodes, **kwargs) + + return nodes + + +async def arun_transformations( + nodes: List[BaseNode], + transformations: Sequence[TransformComponent], + in_place: bool = True, + cache: Optional[IngestionCache] = None, + cache_collection: Optional[str] = None, + **kwargs: Any, +) -> List[BaseNode]: + """Run a series of transformations on a set of nodes. + + Args: + nodes: The nodes to transform. + transformations: The transformations to apply to the nodes. + + Returns: + The transformed nodes. + """ + if not in_place: + nodes = list(nodes) + + for transform in transformations: + if cache is not None: + hash = get_transformation_hash(nodes, transform) + + cached_nodes = cache.get(hash, collection=cache_collection) + if cached_nodes is not None: + nodes = cached_nodes + else: + nodes = await transform.acall(nodes, **kwargs) + cache.put(hash, nodes, collection=cache_collection) + else: + nodes = await transform.acall(nodes, **kwargs) + + return nodes + + +class IngestionPipeline(BaseModel): + """An ingestion pipeline that can be applied to data.""" + + transformations: List[TransformComponent] = Field( + description="Transformations to apply to the data" + ) + + documents: Optional[Sequence[Document]] = Field(description="Documents to ingest") + reader: Optional[ReaderConfig] = Field(description="Reader to use to read the data") + vector_store: Optional[BasePydanticVectorStore] = Field( + description="Vector store to use to store the data" + ) + cache: IngestionCache = Field( + default_factory=IngestionCache, + description="Cache to use to store the data", + ) + disable_cache: bool = Field(default=False, description="Disable the cache") + + def __init__( + self, + transformations: Optional[List[TransformComponent]] = None, + reader: Optional[ReaderConfig] = None, + documents: Optional[Sequence[Document]] = None, + vector_store: Optional[BasePydanticVectorStore] = None, + cache: Optional[IngestionCache] = None, + ) -> None: + if transformations is None: + transformations = self._get_default_transformations() + + super().__init__( + transformations=transformations, + reader=reader, + documents=documents, + vector_store=vector_store, + cache=cache or IngestionCache(), + ) + + @classmethod + def from_service_context( + cls, + service_context: ServiceContext, + reader: Optional[ReaderConfig] = None, + documents: Optional[Sequence[Document]] = None, + vector_store: Optional[BasePydanticVectorStore] = None, + cache: Optional[IngestionCache] = None, + ) -> "IngestionPipeline": + transformations = [ + *service_context.transformations, + service_context.embed_model, + ] + + return cls( + transformations=transformations, + reader=reader, + documents=documents, + vector_store=vector_store, + cache=cache, + ) + + def _get_default_transformations(self) -> List[TransformComponent]: + return [ + SentenceSplitter(), + resolve_embed_model("default"), + ] + + def run( + self, + show_progress: bool = False, + documents: Optional[List[Document]] = None, + nodes: Optional[List[BaseNode]] = None, + cache_collection: Optional[str] = None, + in_place: bool = True, + **kwargs: Any, + ) -> Sequence[BaseNode]: + input_nodes: List[BaseNode] = [] + if documents is not None: + input_nodes += documents + + if nodes is not None: + input_nodes += nodes + + if self.documents is not None: + input_nodes += self.documents + + if self.reader is not None: + input_nodes += self.reader.read() + + nodes = run_transformations( + input_nodes, + self.transformations, + show_progress=show_progress, + cache=self.cache if not self.disable_cache else None, + cache_collection=cache_collection, + in_place=in_place, + **kwargs, + ) + + if self.vector_store is not None: + self.vector_store.add([n for n in nodes if n.embedding is not None]) + + return nodes + + async def arun( + self, + show_progress: bool = False, + documents: Optional[List[Document]] = None, + nodes: Optional[List[BaseNode]] = None, + cache_collection: Optional[str] = None, + in_place: bool = True, + **kwargs: Any, + ) -> Sequence[BaseNode]: + input_nodes: List[BaseNode] = [] + if documents is not None: + input_nodes += documents + + if nodes is not None: + input_nodes += nodes + + if self.documents is not None: + input_nodes += self.documents + + if self.reader is not None: + input_nodes += self.reader.read() + + nodes = await arun_transformations( + input_nodes, + self.transformations, + show_progress=show_progress, + cache=self.cache if not self.disable_cache else None, + cache_collection=cache_collection, + in_place=in_place, + **kwargs, + ) + + if self.vector_store is not None: + await self.vector_store.async_add( + [n for n in nodes if n.embedding is not None] + ) + + return nodes diff --git a/llama_index/langchain_helpers/__init__.py b/llama_index/langchain_helpers/__init__.py index ced105858f1e3558ef6ed5556d0e548659b84440..8b8e0806869e408242cb61e3edac2e3fb4029970 100644 --- a/llama_index/langchain_helpers/__init__.py +++ b/llama_index/langchain_helpers/__init__.py @@ -1 +1,9 @@ """Init file for langchain helpers.""" + +try: + import langchain # noqa +except ImportError: + raise ImportError( + "langchain not installed. " + "Please install langchain with `pip install llama_index[langchain]`." + ) diff --git a/llama_index/langchain_helpers/agents/tools.py b/llama_index/langchain_helpers/agents/tools.py index 7287877e39948fb34f497563f4e9bcdde25eba6c..01801486dfe48cafced23ad1a035884535718d40 100644 --- a/llama_index/langchain_helpers/agents/tools.py +++ b/llama_index/langchain_helpers/agents/tools.py @@ -4,7 +4,7 @@ from typing import Any, Dict, List from llama_index.bridge.langchain import BaseTool from llama_index.bridge.pydantic import BaseModel, Field -from llama_index.indices.query.base import BaseQueryEngine +from llama_index.core import BaseQueryEngine from llama_index.response.schema import RESPONSE_TYPE from llama_index.schema import TextNode diff --git a/llama_index/llm_predictor/base.py b/llama_index/llm_predictor/base.py index 9e2e74985d34d9b54f48ccbe03a4ab8b59290b15..79444c07d79d9935430af1bb09eb039362bf4801 100644 --- a/llama_index/llm_predictor/base.py +++ b/llama_index/llm_predictor/base.py @@ -3,7 +3,9 @@ import logging from abc import ABC, abstractmethod from collections import ChainMap -from typing import Any, List, Optional +from typing import Any, Dict, List, Optional + +from typing_extensions import Self from llama_index.bridge.pydantic import BaseModel, PrivateAttr from llama_index.callbacks.base import CallbackManager @@ -26,6 +28,16 @@ logger = logging.getLogger(__name__) class BaseLLMPredictor(BaseComponent, ABC): """Base LLM Predictor.""" + def dict(self, **kwargs: Any) -> Dict[str, Any]: + data = super().dict(**kwargs) + data["llm"] = self.llm.to_dict() + return data + + def to_dict(self, **kwargs: Any) -> Dict[str, Any]: + data = super().to_dict(**kwargs) + data["llm"] = self.llm.to_dict() + return data + @property @abstractmethod def llm(self) -> LLM: @@ -100,6 +112,22 @@ class LLMPredictor(BaseLLMPredictor): pydantic_program_mode=pydantic_program_mode, ) + @classmethod + def from_dict(cls, data: Dict[str, Any], **kwargs: Any) -> Self: # type: ignore + if isinstance(kwargs, dict): + data.update(kwargs) + + data.pop("class_name", None) + + llm = data.get("llm", "default") + if llm != "default": + from llama_index.llms.loading import load_llm + + llm = load_llm(llm) + + data["llm"] = llm + return cls(**data) + @classmethod def class_name(cls) -> str: return "LLMPredictor" diff --git a/llama_index/llm_predictor/loading.py b/llama_index/llm_predictor/loading.py index ff08dc13ad1e8e12e8d96601acf2a2e2a8808102..aabbf3317fcffa85e4a8090390803c514aa3d4e8 100644 --- a/llama_index/llm_predictor/loading.py +++ b/llama_index/llm_predictor/loading.py @@ -1,22 +1,21 @@ -from typing import Optional - from llama_index.llm_predictor.base import BaseLLMPredictor, LLMPredictor from llama_index.llm_predictor.mock import MockLLMPredictor from llama_index.llm_predictor.structured import StructuredLLMPredictor from llama_index.llm_predictor.vellum.predictor import VellumPredictor -from llama_index.llms.base import LLM -def load_predictor(data: dict, llm: Optional[LLM] = None) -> BaseLLMPredictor: +def load_predictor(data: dict) -> BaseLLMPredictor: """Load predictor by class name.""" + if isinstance(data, BaseLLMPredictor): + return data predictor_name = data.get("class_name", None) if predictor_name is None: raise ValueError("Predictor loading requires a class_name") if predictor_name == LLMPredictor.class_name(): - return LLMPredictor.from_dict(data, llm=llm) + return LLMPredictor.from_dict(data) elif predictor_name == StructuredLLMPredictor.class_name(): - return StructuredLLMPredictor.from_dict(data, llm=llm) + return StructuredLLMPredictor.from_dict(data) elif predictor_name == MockLLMPredictor.class_name(): return MockLLMPredictor.from_dict(data) elif predictor_name == VellumPredictor.class_name(): diff --git a/llama_index/llm_predictor/mock.py b/llama_index/llm_predictor/mock.py index d72cd6711d4dabc7c658fa6ce0652ccd9e5dccf6..7ddaf99f4d0f52351ac0cd04440c7613c6b767ed 100644 --- a/llama_index/llm_predictor/mock.py +++ b/llama_index/llm_predictor/mock.py @@ -13,7 +13,7 @@ from llama_index.token_counter.utils import ( mock_extract_kg_triplets_response, ) from llama_index.types import TokenAsyncGen, TokenGen -from llama_index.utils import globals_helper +from llama_index.utils import get_tokenizer # TODO: consolidate with unit tests in tests/mock_utils/mock_predict.py @@ -21,7 +21,7 @@ from llama_index.utils import globals_helper def _mock_summary_predict(max_tokens: int, prompt_args: Dict) -> str: """Mock summary predict.""" # tokens in response shouldn't be larger than tokens in `context_str` - num_text_tokens = len(globals_helper.tokenizer(prompt_args["context_str"])) + num_text_tokens = len(get_tokenizer()(prompt_args["context_str"])) token_limit = min(num_text_tokens, max_tokens) return " ".join(["summary"] * token_limit) @@ -45,7 +45,7 @@ def _mock_query_select_multiple(num_chunks: int) -> str: def _mock_answer(max_tokens: int, prompt_args: Dict) -> str: """Mock answer.""" # tokens in response shouldn't be larger than tokens in `text` - num_ctx_tokens = len(globals_helper.tokenizer(prompt_args["context_str"])) + num_ctx_tokens = len(get_tokenizer()(prompt_args["context_str"])) token_limit = min(num_ctx_tokens, max_tokens) return " ".join(["answer"] * token_limit) @@ -59,8 +59,8 @@ def _mock_refine(max_tokens: int, prompt: BasePromptTemplate, prompt_args: Dict) existing_answer = prompt.kwargs["existing_answer"] else: existing_answer = prompt_args["existing_answer"] - num_ctx_tokens = len(globals_helper.tokenizer(prompt_args["context_msg"])) - num_exist_tokens = len(globals_helper.tokenizer(existing_answer)) + num_ctx_tokens = len(get_tokenizer()(prompt_args["context_msg"])) + num_exist_tokens = len(get_tokenizer()(existing_answer)) token_limit = min(num_ctx_tokens + num_exist_tokens, max_tokens) return " ".join(["answer"] * token_limit) diff --git a/llama_index/llms/__init__.py b/llama_index/llms/__init__.py index 4eef3c908940aa99f9882ca586ec0306f413caa5..177d54da3af4ad5cce8078ed691a321ee1312b29 100644 --- a/llama_index/llms/__init__.py +++ b/llama_index/llms/__init__.py @@ -3,6 +3,7 @@ from llama_index.llms.anthropic import Anthropic from llama_index.llms.anyscale import Anyscale from llama_index.llms.azure_openai import AzureOpenAI from llama_index.llms.base import ( + LLM, ChatMessage, ChatResponse, ChatResponseAsyncGen, @@ -47,6 +48,7 @@ __all__ = [ "ChatMessage", "ChatResponse", "ChatResponseAsyncGen", + "LLM", "ChatResponseGen", "Clarifai", "Cohere", diff --git a/llama_index/llms/anthropic.py b/llama_index/llms/anthropic.py index 82d79feb81b3bf30c191d800212b9d00a3160fba..0cf10ad7c83150ad7accc73bacd122e8d1bcba9a 100644 --- a/llama_index/llms/anthropic.py +++ b/llama_index/llms/anthropic.py @@ -2,6 +2,7 @@ from typing import Any, Dict, Optional, Sequence from llama_index.bridge.pydantic import Field, PrivateAttr from llama_index.callbacks import CallbackManager +from llama_index.constants import DEFAULT_TEMPERATURE from llama_index.llms.anthropic_utils import ( anthropic_modelname_to_contextsize, messages_to_anthropic_prompt, @@ -27,18 +28,32 @@ from llama_index.llms.generic_utils import ( stream_chat_to_completion_decorator, ) +DEFAULT_ANTHROPIC_MODEL = "claude-2" +DEFAULT_ANTHROPIC_MAX_TOKENS = 512 + class Anthropic(LLM): - model: str = Field(description="The anthropic model to use.") - temperature: float = Field(description="The temperature to use for sampling.") - max_tokens: int = Field(description="The maximum number of tokens to generate.") + model: str = Field( + default=DEFAULT_ANTHROPIC_MODEL, description="The anthropic model to use." + ) + temperature: float = Field( + default=DEFAULT_TEMPERATURE, + description="The temperature to use for sampling.", + gte=0.0, + lte=1.0, + ) + max_tokens: int = Field( + default=DEFAULT_ANTHROPIC_MAX_TOKENS, + description="The maximum number of tokens to generate.", + gt=0, + ) base_url: Optional[str] = Field(default=None, description="The base URL to use.") timeout: Optional[float] = Field( - default=None, description="The timeout to use in seconds." + default=None, description="The timeout to use in seconds.", gte=0 ) max_retries: int = Field( - default=10, description="The maximum number of API retries." + default=10, description="The maximum number of API retries.", gte=0 ) additional_kwargs: Dict[str, Any] = Field( default_factory=dict, description="Additional kwargs for the anthropic API." @@ -49,9 +64,9 @@ class Anthropic(LLM): def __init__( self, - model: str = "claude-2", - temperature: float = 0.1, - max_tokens: int = 512, + model: str = DEFAULT_ANTHROPIC_MODEL, + temperature: float = DEFAULT_TEMPERATURE, + max_tokens: int = DEFAULT_ANTHROPIC_MAX_TOKENS, base_url: Optional[str] = None, timeout: Optional[float] = None, max_retries: int = 10, diff --git a/llama_index/llms/anthropic_utils.py b/llama_index/llms/anthropic_utils.py index ec396314271b7672c264cea2916c18f182b8bee9..e7c73d5f8c4909f174a4dfe35300379dff503b16 100644 --- a/llama_index/llms/anthropic_utils.py +++ b/llama_index/llms/anthropic_utils.py @@ -1,4 +1,4 @@ -from typing import Sequence +from typing import Dict, Sequence from llama_index.llms.base import ChatMessage, MessageRole @@ -6,7 +6,7 @@ HUMAN_PREFIX = "\n\nHuman:" ASSISTANT_PREFIX = "\n\nAssistant:" -CLAUDE_MODELS = { +CLAUDE_MODELS: Dict[str, int] = { "claude-instant-1": 100000, "claude-instant-1.2": 100000, "claude-2": 100000, diff --git a/llama_index/llms/anyscale.py b/llama_index/llms/anyscale.py index 3ef1d6c9b694b762e4d3f54398783955144557bb..714aa86858bff4a4b62ef1d4ce0816a9b8c6be9f 100644 --- a/llama_index/llms/anyscale.py +++ b/llama_index/llms/anyscale.py @@ -1,6 +1,7 @@ from typing import Any, Dict, Optional from llama_index.callbacks import CallbackManager +from llama_index.constants import DEFAULT_NUM_OUTPUTS, DEFAULT_TEMPERATURE from llama_index.llms.anyscale_utils import ( anyscale_modelname_to_contextsize, ) @@ -18,8 +19,8 @@ class Anyscale(OpenAI): def __init__( self, model: str = DEFAULT_MODEL, - temperature: float = 0.1, - max_tokens: int = 256, + temperature: float = DEFAULT_TEMPERATURE, + max_tokens: int = DEFAULT_NUM_OUTPUTS, additional_kwargs: Optional[Dict[str, Any]] = None, max_retries: int = 10, api_base: Optional[str] = DEFAULT_API_BASE, diff --git a/llama_index/llms/everlyai.py b/llama_index/llms/everlyai.py index 55b53feaa9752742e3774860d124bc9de70f2779..1ff6404b5935574e09738051dc0d6244e9686566 100644 --- a/llama_index/llms/everlyai.py +++ b/llama_index/llms/everlyai.py @@ -1,6 +1,7 @@ from typing import Any, Dict, Optional from llama_index.callbacks import CallbackManager +from llama_index.constants import DEFAULT_NUM_OUTPUTS, DEFAULT_TEMPERATURE from llama_index.llms.base import LLMMetadata from llama_index.llms.everlyai_utils import everlyai_modelname_to_contextsize from llama_index.llms.generic_utils import get_from_param_or_env @@ -14,8 +15,8 @@ class EverlyAI(OpenAI): def __init__( self, model: str = DEFAULT_MODEL, - temperature: float = 0.1, - max_tokens: int = 256, + temperature: float = DEFAULT_TEMPERATURE, + max_tokens: int = DEFAULT_NUM_OUTPUTS, additional_kwargs: Optional[Dict[str, Any]] = None, max_retries: int = 10, api_key: Optional[str] = None, diff --git a/llama_index/llms/gradient.py b/llama_index/llms/gradient.py index 5512e01346078c632e37723abfb2db88a7c506a5..cadcdcf6ef3120e5e3acb339f49f296189da5e3f 100644 --- a/llama_index/llms/gradient.py +++ b/llama_index/llms/gradient.py @@ -4,6 +4,7 @@ from typing_extensions import override from llama_index.bridge.pydantic import Field, PrivateAttr from llama_index.callbacks import CallbackManager +from llama_index.constants import DEFAULT_NUM_OUTPUTS from llama_index.llms import ( CompletionResponse, CompletionResponseGen, @@ -19,6 +20,7 @@ class _BaseGradientLLM(CustomLLM): # Config max_tokens: Optional[int] = Field( + default=DEFAULT_NUM_OUTPUTS, description="The number of tokens to generate.", gt=0, lt=512, diff --git a/llama_index/llms/huggingface.py b/llama_index/llms/huggingface.py index 3dbda246ce1f9453aa510476ececb6be6e0f9213..b900550eba585091320963808dac1f34d08ca939 100644 --- a/llama_index/llms/huggingface.py +++ b/llama_index/llms/huggingface.py @@ -4,7 +4,10 @@ from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Sequence, from llama_index.bridge.pydantic import Field, PrivateAttr from llama_index.callbacks import CallbackManager -from llama_index.constants import DEFAULT_CONTEXT_WINDOW, DEFAULT_NUM_OUTPUTS +from llama_index.constants import ( + DEFAULT_CONTEXT_WINDOW, + DEFAULT_NUM_OUTPUTS, +) from llama_index.llms import ChatResponseAsyncGen, CompletionResponseAsyncGen from llama_index.llms.base import ( LLM, @@ -28,6 +31,7 @@ from llama_index.llms.generic_utils import ( ) from llama_index.prompts.base import PromptTemplate +DEFAULT_HUGGINGFACE_MODEL = "StabilityAI/stablelm-tuned-alpha-3b" if TYPE_CHECKING: try: from huggingface_hub import AsyncInferenceClient, InferenceClient @@ -46,22 +50,31 @@ class HuggingFaceLLM(CustomLLM): """HuggingFace LLM.""" model_name: str = Field( + default=DEFAULT_HUGGINGFACE_MODEL, description=( "The model name to use from HuggingFace. " "Unused if `model` is passed in directly." - ) + ), ) context_window: int = Field( - description="The maximum number of tokens available for input." + default=DEFAULT_CONTEXT_WINDOW, + description="The maximum number of tokens available for input.", + gt=0, + ) + max_new_tokens: int = Field( + default=DEFAULT_NUM_OUTPUTS, + description="The maximum number of tokens to generate.", + gt=0, ) - max_new_tokens: int = Field(description="The maximum number of tokens to generate.") system_prompt: str = Field( + default="", description=( "The system prompt, containing any extra instructions or context. " "The model card on HuggingFace should specify if this is needed." ), ) query_wrapper_prompt: str = Field( + default="{query_str}", description=( "The query wrapper prompt, containing the query placeholder. " "The model card on HuggingFace should specify if this is needed. " @@ -69,13 +82,14 @@ class HuggingFaceLLM(CustomLLM): ), ) tokenizer_name: str = Field( + default=DEFAULT_HUGGINGFACE_MODEL, description=( "The name of the tokenizer to use from HuggingFace. " "Unused if `tokenizer` is passed in directly." - ) + ), ) - device_map: Optional[str] = Field( - description="The device_map to use. Defaults to 'auto'." + device_map: str = Field( + default="auto", description="The device_map to use. Defaults to 'auto'." ) stopping_ids: List[int] = Field( default_factory=list, @@ -110,12 +124,12 @@ class HuggingFaceLLM(CustomLLM): def __init__( self, - context_window: int = 4096, - max_new_tokens: int = 256, + context_window: int = DEFAULT_CONTEXT_WINDOW, + max_new_tokens: int = DEFAULT_NUM_OUTPUTS, system_prompt: str = "", query_wrapper_prompt: Union[str, PromptTemplate] = "{query_str}", - tokenizer_name: str = "StabilityAI/stablelm-tuned-alpha-3b", - model_name: str = "StabilityAI/stablelm-tuned-alpha-3b", + tokenizer_name: str = DEFAULT_HUGGINGFACE_MODEL, + model_name: str = DEFAULT_HUGGINGFACE_MODEL, model: Optional[Any] = None, tokenizer: Optional[Any] = None, device_map: Optional[str] = "auto", diff --git a/llama_index/llms/konko.py b/llama_index/llms/konko.py index 871ca6d7f84fc2cb31e692f61aa6e085e669dcdb..10af19d63fc730bd4c315d4060e11a5dc50ad1e3 100644 --- a/llama_index/llms/konko.py +++ b/llama_index/llms/konko.py @@ -2,6 +2,7 @@ from typing import Any, Awaitable, Callable, Dict, Optional, Sequence from llama_index.bridge.pydantic import Field from llama_index.callbacks import CallbackManager +from llama_index.constants import DEFAULT_NUM_OUTPUTS, DEFAULT_TEMPERATURE from llama_index.llms.base import ( LLM, ChatMessage, @@ -35,17 +36,30 @@ from llama_index.llms.konko_utils import ( to_openai_message_dicts, ) +DEFAULT_KONKO_MODEL = "meta-llama/Llama-2-13b-chat-hf" + class Konko(LLM): - model: str = Field(description="The konko model to use.") - temperature: float = Field(description="The temperature to use during generation.") + model: str = Field( + default=DEFAULT_KONKO_MODEL, description="The konko model to use." + ) + temperature: float = Field( + default=DEFAULT_TEMPERATURE, + description="The temperature to use during generation.", + gte=0.0, + lte=1.0, + ) max_tokens: Optional[int] = Field( - description="The maximum number of tokens to generate." + default=DEFAULT_NUM_OUTPUTS, + description="The maximum number of tokens to generate.", + gt=0, ) additional_kwargs: Dict[str, Any] = Field( default_factory=dict, description="Additional kwargs for the konko API." ) - max_retries: int = Field(description="The maximum number of API retries.") + max_retries: int = Field( + default=10, description="The maximum number of API retries.", gte=0 + ) konko_api_key: str = Field(default=None, description="The konko API key.") openai_api_key: str = Field(default=None, description="The Openai API key.") @@ -55,9 +69,9 @@ class Konko(LLM): def __init__( self, - model: str = "meta-llama/Llama-2-13b-chat-hf", - temperature: float = 0.1, - max_tokens: Optional[int] = 256, + model: str = DEFAULT_KONKO_MODEL, + temperature: float = DEFAULT_TEMPERATURE, + max_tokens: Optional[int] = DEFAULT_NUM_OUTPUTS, additional_kwargs: Optional[Dict[str, Any]] = None, max_retries: int = 10, konko_api_key: Optional[str] = None, diff --git a/llama_index/llms/langchain.py b/llama_index/llms/langchain.py index 641c59b222497561d9f8f95894840df4f8617888..145f60a5f93a4b1ffb264850ba82430d2c3f07ca 100644 --- a/llama_index/llms/langchain.py +++ b/llama_index/llms/langchain.py @@ -1,11 +1,11 @@ from threading import Thread -from typing import Any, Generator, Optional, Sequence +from typing import TYPE_CHECKING, Any, Generator, Optional, Sequence -from langchain.base_language import BaseLanguageModel +if TYPE_CHECKING: + from langchain.base_language import BaseLanguageModel from llama_index.bridge.pydantic import PrivateAttr from llama_index.callbacks import CallbackManager -from llama_index.langchain_helpers.streaming import StreamingGeneratorCallbackHandler from llama_index.llms.base import ( LLM, ChatMessage, @@ -19,20 +19,17 @@ from llama_index.llms.base import ( llm_chat_callback, llm_completion_callback, ) -from llama_index.llms.langchain_utils import ( - from_lc_messages, - get_llm_metadata, - to_lc_messages, -) class LangChainLLM(LLM): """Adapter for a LangChain LLM.""" - _llm: BaseLanguageModel = PrivateAttr() + _llm: Any = PrivateAttr() def __init__( - self, llm: BaseLanguageModel, callback_manager: Optional[CallbackManager] = None + self, + llm: "BaseLanguageModel", + callback_manager: Optional[CallbackManager] = None, ) -> None: self._llm = llm super().__init__(callback_manager=callback_manager) @@ -42,15 +39,22 @@ class LangChainLLM(LLM): return "LangChainLLM" @property - def llm(self) -> BaseLanguageModel: + def llm(self) -> "BaseLanguageModel": return self._llm @property def metadata(self) -> LLMMetadata: + from llama_index.llms.langchain_utils import get_llm_metadata + return get_llm_metadata(self._llm) @llm_chat_callback() def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse: + from llama_index.llms.langchain_utils import ( + from_lc_messages, + to_lc_messages, + ) + lc_messages = to_lc_messages(messages) lc_message = self._llm.predict_messages(messages=lc_messages, **kwargs) message = from_lc_messages([lc_message])[0] @@ -65,6 +69,10 @@ class LangChainLLM(LLM): def stream_chat( self, messages: Sequence[ChatMessage], **kwargs: Any ) -> ChatResponseGen: + from llama_index.langchain_helpers.streaming import ( + StreamingGeneratorCallbackHandler, + ) + handler = StreamingGeneratorCallbackHandler() if not hasattr(self._llm, "streaming"): @@ -93,6 +101,10 @@ class LangChainLLM(LLM): @llm_completion_callback() def stream_complete(self, prompt: str, **kwargs: Any) -> CompletionResponseGen: + from llama_index.langchain_helpers.streaming import ( + StreamingGeneratorCallbackHandler, + ) + handler = StreamingGeneratorCallbackHandler() if not hasattr(self._llm, "streaming"): diff --git a/llama_index/llms/litellm.py b/llama_index/llms/litellm.py index 77c2f317b1435c237f880a28bc4bcfabccca65bb..637855dbb763cd578e270c5ee65cab6de05e3c7f 100644 --- a/llama_index/llms/litellm.py +++ b/llama_index/llms/litellm.py @@ -2,6 +2,7 @@ from typing import Any, Awaitable, Callable, Dict, Optional, Sequence from llama_index.bridge.pydantic import Field from llama_index.callbacks import CallbackManager +from llama_index.constants import DEFAULT_TEMPERATURE from llama_index.llms.base import ( LLM, ChatMessage, @@ -35,26 +36,40 @@ from llama_index.llms.litellm_utils import ( validate_litellm_api_key, ) +DEFAULT_LITELLM_MODEL = "gpt-3.5-turbo" + class LiteLLM(LLM): model: str = Field( - description="The LiteLLM model to use." - ) # For complete list of providers https://docs.litellm.ai/docs/providers - temperature: float = Field(description="The temperature to use during generation.") + default=DEFAULT_LITELLM_MODEL, + description=( + "The LiteLLM model to use. " + "For complete list of providers https://docs.litellm.ai/docs/providers" + ), + ) + temperature: float = Field( + default=DEFAULT_TEMPERATURE, + description="The temperature to use during generation.", + gte=0.0, + lte=1.0, + ) max_tokens: Optional[int] = Field( - description="The maximum number of tokens to generate." + description="The maximum number of tokens to generate.", + gt=0, ) additional_kwargs: Dict[str, Any] = Field( default_factory=dict, description="Additional kwargs for the LLM API.", # for all inputs https://docs.litellm.ai/docs/completion/input ) - max_retries: int = Field(description="The maximum number of API retries.") + max_retries: int = Field( + default=10, description="The maximum number of API retries." + ) def __init__( self, - model: str = "gpt-3.5-turbo", - temperature: float = 0.1, + model: str = DEFAULT_LITELLM_MODEL, + temperature: float = DEFAULT_TEMPERATURE, max_tokens: Optional[int] = None, additional_kwargs: Optional[Dict[str, Any]] = None, max_retries: int = 10, diff --git a/llama_index/llms/llama_cpp.py b/llama_index/llms/llama_cpp.py index 735e14407dfd24f095eabe98e335fa66af51a39d..1ff5edec52ce5e8f323459df990677608063e0b2 100644 --- a/llama_index/llms/llama_cpp.py +++ b/llama_index/llms/llama_cpp.py @@ -6,7 +6,11 @@ from tqdm import tqdm from llama_index.bridge.pydantic import Field, PrivateAttr from llama_index.callbacks import CallbackManager -from llama_index.constants import DEFAULT_CONTEXT_WINDOW, DEFAULT_NUM_OUTPUTS +from llama_index.constants import ( + DEFAULT_CONTEXT_WINDOW, + DEFAULT_NUM_OUTPUTS, + DEFAULT_TEMPERATURE, +) from llama_index.llms.base import ( ChatMessage, ChatResponse, @@ -45,11 +49,21 @@ class LlamaCPP(CustomLLM): model_path: Optional[str] = Field( description="The path to the llama-cpp model to use." ) - temperature: float = Field(description="The temperature to use for sampling.") - max_new_tokens: int = Field(description="The maximum number of tokens to generate.") + temperature: float = Field( + default=DEFAULT_TEMPERATURE, + description="The temperature to use for sampling.", + gte=0.0, + lte=1.0, + ) + max_new_tokens: int = Field( + default=DEFAULT_NUM_OUTPUTS, + description="The maximum number of tokens to generate.", + gt=0, + ) context_window: int = Field( default=DEFAULT_CONTEXT_WINDOW, description="The maximum number of context tokens for the model.", + gt=0, ) messages_to_prompt: Callable = Field( description="The function to convert messages to a prompt.", exclude=True @@ -74,7 +88,7 @@ class LlamaCPP(CustomLLM): self, model_url: Optional[str] = None, model_path: Optional[str] = None, - temperature: float = 0.1, + temperature: float = DEFAULT_TEMPERATURE, max_new_tokens: int = DEFAULT_NUM_OUTPUTS, context_window: int = DEFAULT_CONTEXT_WINDOW, messages_to_prompt: Optional[Callable] = None, diff --git a/llama_index/llms/loading.py b/llama_index/llms/loading.py index d33f9227a8ded0e220d9235365dcce0ffcfe1f83..2cbb9e74ad06eb89a43f4e93c357456556c07b00 100644 --- a/llama_index/llms/loading.py +++ b/llama_index/llms/loading.py @@ -35,6 +35,8 @@ RECOGNIZED_LLMS: Dict[str, Type[LLM]] = { def load_llm(data: dict) -> LLM: """Load LLM by name.""" + if isinstance(data, LLM): + return data llm_name = data.get("class_name", None) if llm_name is None: raise ValueError("LLM loading requires a class_name") diff --git a/llama_index/llms/localai.py b/llama_index/llms/localai.py index 1ee2e1cc0113eb1ac2d4d61a15db12cad0e2a72b..e84ad83e58fb279a0c074ba2229dea772f65f183 100644 --- a/llama_index/llms/localai.py +++ b/llama_index/llms/localai.py @@ -23,6 +23,7 @@ class LocalAI(OpenAI): context_window: int = Field( default=DEFAULT_CONTEXT_WINDOW, description="The maximum number of context tokens for the model.", + gt=0, ) globally_use_chat_completions: Optional[bool] = Field( default=None, diff --git a/llama_index/llms/monsterapi.py b/llama_index/llms/monsterapi.py index 7fc52625dc3705d5b046d09763568d7ff2c0cde7..c7f12759ee909ca556fd8537d4a63e84a60dbd33 100644 --- a/llama_index/llms/monsterapi.py +++ b/llama_index/llms/monsterapi.py @@ -17,14 +17,27 @@ from llama_index.llms.generic_utils import ( messages_to_prompt as generic_messages_to_prompt, ) +DEFAULT_MONSTER_TEMP = 0.75 + class MonsterLLM(CustomLLM): model: str = Field(description="The MonsterAPI model to use.") monster_api_key: Optional[str] = Field(description="The MonsterAPI key to use.") - max_new_tokens: int = Field(description="The number of tokens to generate.") - temperature: float = Field(description="The temperature to use for sampling.") + max_new_tokens: int = Field( + default=DEFAULT_NUM_OUTPUTS, + description="The number of tokens to generate.", + gt=0, + ) + temperature: float = Field( + default=DEFAULT_MONSTER_TEMP, + description="The temperature to use for sampling.", + gte=0.0, + lte=1.0, + ) context_window: int = Field( - description="The number of context tokens available to the LLM." + default=DEFAULT_CONTEXT_WINDOW, + description="The number of context tokens available to the LLM.", + gt=0, ) messages_to_prompt: Callable = Field( @@ -41,7 +54,7 @@ class MonsterLLM(CustomLLM): model: str, monster_api_key: Optional[str] = None, max_new_tokens: int = DEFAULT_NUM_OUTPUTS, - temperature: float = 0.75, + temperature: float = DEFAULT_MONSTER_TEMP, context_window: int = DEFAULT_CONTEXT_WINDOW, callback_manager: Optional[CallbackManager] = None, messages_to_prompt: Optional[Callable] = None, diff --git a/llama_index/llms/ollama.py b/llama_index/llms/ollama.py index fc12798e25d387be9030e6154b90bb81669c3047..bcc37c32c9c8d4d5caa90b5cfba198f666bda337 100644 --- a/llama_index/llms/ollama.py +++ b/llama_index/llms/ollama.py @@ -3,7 +3,11 @@ from typing import Any, Callable, Dict, Iterator, Optional, Sequence from llama_index.bridge.pydantic import Field, PrivateAttr from llama_index.callbacks import CallbackManager -from llama_index.constants import DEFAULT_CONTEXT_WINDOW, DEFAULT_NUM_OUTPUTS +from llama_index.constants import ( + DEFAULT_CONTEXT_WINDOW, + DEFAULT_NUM_OUTPUTS, + DEFAULT_TEMPERATURE, +) from llama_index.llms.base import ( ChatMessage, ChatResponse, @@ -27,11 +31,20 @@ from llama_index.llms.generic_utils import ( class Ollama(CustomLLM): base_url: str = Field(description="Base url the model is hosted under.") model: str = Field(description="The Ollama model to use.") - temperature: float = Field(description="The temperature to use for sampling.") + temperature: float = Field( + default=DEFAULT_TEMPERATURE, + description="The temperature to use for sampling.", + gte=0.0, + lte=1.0, + ) context_window: int = Field( - description="The maximum number of context tokens for the model." + default=DEFAULT_CONTEXT_WINDOW, + description="The maximum number of context tokens for the model.", + gt=0, + ) + prompt_key: str = Field( + default="prompt", description="The key to use for the prompt in API calls." ) - prompt_key: str = Field(description="The key to use for the prompt in API calls.") additional_kwargs: Dict[str, Any] = Field( default_factory=dict, description="Additional kwargs for the Ollama API." ) diff --git a/llama_index/llms/openai.py b/llama_index/llms/openai.py index d58a2f4a25708d3d3cecabe2174791c55ecc4d59..b0f7e116a204e4e84a07849ea9b6d6fa43e79aec 100644 --- a/llama_index/llms/openai.py +++ b/llama_index/llms/openai.py @@ -22,6 +22,9 @@ from openai.types.chat.chat_completion_chunk import ( from llama_index.bridge.pydantic import Field, PrivateAttr from llama_index.callbacks import CallbackManager +from llama_index.constants import ( + DEFAULT_TEMPERATURE, +) from llama_index.llms.base import ( LLM, ChatMessage, @@ -55,6 +58,8 @@ from llama_index.llms.openai_utils import ( to_openai_message_dicts, ) +DEFAULT_OPENAI_MODEL = "gpt-3.5-turbo" + @runtime_checkable class Tokenizer(Protocol): @@ -65,10 +70,18 @@ class Tokenizer(Protocol): class OpenAI(LLM): - model: str = Field(description="The OpenAI model to use.") - temperature: float = Field(description="The temperature to use during generation.") + model: str = Field( + default=DEFAULT_OPENAI_MODEL, description="The OpenAI model to use." + ) + temperature: float = Field( + default=DEFAULT_TEMPERATURE, + description="The temperature to use during generation.", + gte=0.0, + lte=1.0, + ) max_tokens: Optional[int] = Field( - default=None, description="The maximum number of tokens to generate." + description="The maximum number of tokens to generate.", + gt=0, ) additional_kwargs: Dict[str, Any] = Field( default_factory=dict, description="Additional kwargs for the OpenAI API." @@ -93,8 +106,8 @@ class OpenAI(LLM): def __init__( self, - model: str = "gpt-3.5-turbo", - temperature: float = 0.1, + model: str = DEFAULT_OPENAI_MODEL, + temperature: float = DEFAULT_TEMPERATURE, max_tokens: Optional[int] = None, additional_kwargs: Optional[Dict[str, Any]] = None, max_retries: int = 3, diff --git a/llama_index/llms/openai_utils.py b/llama_index/llms/openai_utils.py index fd594571843f13c138e8820adb55a8b63cf4a5ad..3ba53e465085cdeae70c4bcd9a3dc2eec4570a19 100644 --- a/llama_index/llms/openai_utils.py +++ b/llama_index/llms/openai_utils.py @@ -28,7 +28,7 @@ DEFAULT_OPENAI_API_BASE = "https://api.openai.com/v1" DEFAULT_OPENAI_API_VERSION = "" -GPT4_MODELS = { +GPT4_MODELS: Dict[str, int] = { # stable model names: # resolves to gpt-4-0314 before 2023-06-27, # resolves to gpt-4-0613 after @@ -47,12 +47,12 @@ GPT4_MODELS = { "gpt-4-32k-0314": 32768, } -AZURE_TURBO_MODELS = { +AZURE_TURBO_MODELS: Dict[str, int] = { "gpt-35-turbo-16k": 16384, "gpt-35-turbo": 4096, } -TURBO_MODELS = { +TURBO_MODELS: Dict[str, int] = { # stable model names: # resolves to gpt-3.5-turbo-0301 before 2023-06-27, # resolves to gpt-3.5-turbo-0613 until 2023-12-11, @@ -71,14 +71,14 @@ TURBO_MODELS = { "gpt-3.5-turbo-0301": 4096, } -GPT3_5_MODELS = { +GPT3_5_MODELS: Dict[str, int] = { "text-davinci-003": 4097, "text-davinci-002": 4097, # instruct models "gpt-3.5-turbo-instruct": 4096, } -GPT3_MODELS = { +GPT3_MODELS: Dict[str, int] = { "text-ada-001": 2049, "text-babbage-001": 2040, "text-curie-001": 2049, diff --git a/llama_index/llms/palm.py b/llama_index/llms/palm.py index 91d2c453f9b45ba9b8ea4f37725ffd51d7392f79..d907acb0e19e7dda04f762bfc7b5af4baec32519 100644 --- a/llama_index/llms/palm.py +++ b/llama_index/llms/palm.py @@ -4,6 +4,7 @@ from typing import Any, Optional from llama_index.bridge.pydantic import Field, PrivateAttr from llama_index.callbacks import CallbackManager +from llama_index.constants import DEFAULT_NUM_OUTPUTS from llama_index.llms.base import ( CompletionResponse, CompletionResponseGen, @@ -12,12 +13,20 @@ from llama_index.llms.base import ( ) from llama_index.llms.custom import CustomLLM +DEFAULT_PALM_MODEL = "models/text-bison-001" + class PaLM(CustomLLM): """PaLM LLM.""" - model_name: str = Field(description="The PaLM model to use.") - num_output: int = Field(description="The number of tokens to generate.") + model_name: str = Field( + default=DEFAULT_PALM_MODEL, description="The PaLM model to use." + ) + num_output: int = Field( + default=DEFAULT_NUM_OUTPUTS, + description="The number of tokens to generate.", + gt=0, + ) generate_kwargs: dict = Field( default_factory=dict, description="Kwargs for generation." ) @@ -27,7 +36,7 @@ class PaLM(CustomLLM): def __init__( self, api_key: Optional[str] = None, - model_name: Optional[str] = "models/text-bison-001", + model_name: Optional[str] = DEFAULT_PALM_MODEL, num_output: Optional[int] = None, callback_manager: Optional[CallbackManager] = None, **generate_kwargs: Any, diff --git a/llama_index/llms/portkey.py b/llama_index/llms/portkey.py index c9e444af99946feaa30d4dfb7d92228b7cb8b318..5acf56b481aba3c0cd12117fc05e0016c11d84e6 100644 --- a/llama_index/llms/portkey.py +++ b/llama_index/llms/portkey.py @@ -36,6 +36,8 @@ if TYPE_CHECKING: PortkeyResponse, ) +DEFAULT_PORTKEY_MODEL = "gpt-3.5-turbo" + class Portkey(CustomLLM): """_summary_. @@ -48,7 +50,7 @@ class Portkey(CustomLLM): description="The mode for using the Portkey integration" ) - model: Optional[str] = Field(default="gpt-3.5-turbo") + model: Optional[str] = Field(default=DEFAULT_PORTKEY_MODEL) llm: "LLMOptions" = Field(description="LLM parameter", default_factory=dict) llms: List["LLMOptions"] = Field(description="LLM parameters", default_factory=list) diff --git a/llama_index/llms/predibase.py b/llama_index/llms/predibase.py index a8cb28fa8dfdf49850f2eb54126c788a3c6d2c68..c7993044cb69ebf36292b4f55d55829309633978 100644 --- a/llama_index/llms/predibase.py +++ b/llama_index/llms/predibase.py @@ -3,7 +3,11 @@ from typing import Any, Optional from llama_index.bridge.pydantic import Field, PrivateAttr from llama_index.callbacks import CallbackManager -from llama_index.constants import DEFAULT_CONTEXT_WINDOW +from llama_index.constants import ( + DEFAULT_CONTEXT_WINDOW, + DEFAULT_NUM_OUTPUTS, + DEFAULT_TEMPERATURE, +) from llama_index.llms.base import ( CompletionResponse, CompletionResponseGen, @@ -18,10 +22,21 @@ class PredibaseLLM(CustomLLM): model_name: str = Field(description="The Predibase model to use.") predibase_api_key: str = Field(description="The Predibase API key to use.") - max_new_tokens: int = Field(description="The number of tokens to generate.") - temperature: float = Field(description="The temperature to use for sampling.") + max_new_tokens: int = Field( + default=DEFAULT_NUM_OUTPUTS, + description="The number of tokens to generate.", + gt=0, + ) + temperature: float = Field( + default=DEFAULT_TEMPERATURE, + description="The temperature to use for sampling.", + gte=0.0, + lte=1.0, + ) context_window: int = Field( - description="The number of context tokens available to the LLM." + default=DEFAULT_CONTEXT_WINDOW, + description="The number of context tokens available to the LLM.", + gt=0, ) _client: Any = PrivateAttr() @@ -30,8 +45,8 @@ class PredibaseLLM(CustomLLM): self, model_name: str, predibase_api_key: Optional[str] = None, - max_new_tokens: int = 256, - temperature: float = 0.1, + max_new_tokens: int = DEFAULT_NUM_OUTPUTS, + temperature: float = DEFAULT_TEMPERATURE, context_window: int = DEFAULT_CONTEXT_WINDOW, callback_manager: Optional[CallbackManager] = None, ) -> None: diff --git a/llama_index/llms/replicate.py b/llama_index/llms/replicate.py index fe81e2ad85bf0f7a15a7c3d297e24ba7d3387b58..8283ffe54064fb52918c36cb4f834678473b1aef 100644 --- a/llama_index/llms/replicate.py +++ b/llama_index/llms/replicate.py @@ -22,15 +22,24 @@ from llama_index.llms.generic_utils import ( messages_to_prompt as generic_messages_to_prompt, ) +DEFAULT_REPLICATE_TEMP = 0.75 + class Replicate(CustomLLM): model: str = Field(description="The Replicate model to use.") + temperature: float = Field( + default=DEFAULT_REPLICATE_TEMP, + description="The temperature to use for sampling.", + gte=0.01, + lte=1.0, + ) image: str = Field( description="The image file for multimodal model to use. (optional)" ) - temperature: float = Field(description="The temperature to use for sampling.") context_window: int = Field( - description="The maximum number of context tokens for the model." + default=DEFAULT_CONTEXT_WINDOW, + description="The maximum number of context tokens for the model.", + gt=0, ) prompt_key: str = Field(description="The key to use for the prompt in API calls.") additional_kwargs: Dict[str, Any] = Field( @@ -46,8 +55,8 @@ class Replicate(CustomLLM): def __init__( self, model: str, + temperature: float = DEFAULT_REPLICATE_TEMP, image: Optional[str] = "", - temperature: float = 0.75, additional_kwargs: Optional[Dict[str, Any]] = None, context_window: int = DEFAULT_CONTEXT_WINDOW, prompt_key: str = "prompt", diff --git a/llama_index/llms/rungpt.py b/llama_index/llms/rungpt.py index 953693f832c16ee7137906eb8c5e33e46d7cc0ba..65ff1e91e8053cec3860a1238908ada8c5073b6e 100644 --- a/llama_index/llms/rungpt.py +++ b/llama_index/llms/rungpt.py @@ -3,7 +3,7 @@ from typing import Any, Dict, List, Optional, Sequence, Tuple, Union from llama_index.bridge.pydantic import Field from llama_index.callbacks import CallbackManager -from llama_index.constants import DEFAULT_CONTEXT_WINDOW +from llama_index.constants import DEFAULT_CONTEXT_WINDOW, DEFAULT_NUM_OUTPUTS from llama_index.llms.base import ( LLM, ChatMessage, @@ -19,16 +19,32 @@ from llama_index.llms.base import ( llm_completion_callback, ) +DEFAULT_RUNGPT_MODEL = "rungpt" +DEFAULT_RUNGPT_TEMP = 0.75 + class RunGptLLM(LLM): """The opengpt of Jina AI models.""" - model: Optional[str] = Field(description="The rungpt model to use.") + model: Optional[str] = Field( + default=DEFAULT_RUNGPT_MODEL, description="The rungpt model to use." + ) endpoint: str = Field(description="The endpoint of serving address.") - temperature: float = Field(description="The temperature to use for sampling.") - max_tokens: Optional[int] = Field(description="Max tokens model generates.") + temperature: float = Field( + default=DEFAULT_RUNGPT_TEMP, + description="The temperature to use for sampling.", + gte=0.0, + lte=1.0, + ) + max_tokens: int = Field( + default=DEFAULT_NUM_OUTPUTS, + description="Max tokens model generates.", + gt=0, + ) context_window: int = Field( - description="The maximum number of context tokens for the model." + default=DEFAULT_CONTEXT_WINDOW, + description="The maximum number of context tokens for the model.", + gt=0, ) additional_kwargs: Dict[str, Any] = Field( default_factory=dict, description="Additional kwargs for the Replicate API." @@ -39,10 +55,10 @@ class RunGptLLM(LLM): def __init__( self, - model: Optional[str] = "rungpt", + model: Optional[str] = DEFAULT_RUNGPT_MODEL, endpoint: str = "0.0.0.0:51002", - temperature: float = 0.75, - max_tokens: Optional[int] = 256, + temperature: float = DEFAULT_RUNGPT_TEMP, + max_tokens: Optional[int] = DEFAULT_NUM_OUTPUTS, context_window: int = DEFAULT_CONTEXT_WINDOW, additional_kwargs: Optional[Dict[str, Any]] = None, callback_manager: Optional[CallbackManager] = None, diff --git a/llama_index/llms/utils.py b/llama_index/llms/utils.py index c4101d101c376b0e1d950349346c94500d984429..53b1bcb3b18e959f82e65aefb1ade932b6a65a7b 100644 --- a/llama_index/llms/utils.py +++ b/llama_index/llms/utils.py @@ -1,20 +1,27 @@ -from typing import Optional, Union +from typing import TYPE_CHECKING, Optional, Union -from langchain.base_language import BaseLanguageModel +if TYPE_CHECKING: + from langchain.base_language import BaseLanguageModel from llama_index.llms.base import LLM -from llama_index.llms.langchain import LangChainLLM from llama_index.llms.llama_cpp import LlamaCPP from llama_index.llms.llama_utils import completion_to_prompt, messages_to_prompt from llama_index.llms.mock import MockLLM from llama_index.llms.openai import OpenAI from llama_index.llms.openai_utils import validate_openai_api_key -LLMType = Union[str, LLM, BaseLanguageModel] +LLMType = Union[str, LLM, "BaseLanguageModel"] def resolve_llm(llm: Optional[LLMType] = None) -> LLM: """Resolve LLM from string or LLM instance.""" + try: + from langchain.base_language import BaseLanguageModel + + from llama_index.llms.langchain import LangChainLLM + except ImportError: + BaseLanguageModel = None # type: ignore + if llm == "default": # return default OpenAI model. If it fails, return LlamaCPP try: @@ -45,7 +52,7 @@ def resolve_llm(llm: Optional[LLMType] = None) -> LLM: completion_to_prompt=completion_to_prompt, model_kwargs={"n_gpu_layers": 1}, ) - elif isinstance(llm, BaseLanguageModel): + elif BaseLanguageModel is not None and isinstance(llm, BaseLanguageModel): # NOTE: if it's a langchain model, wrap it in a LangChainLLM llm = LangChainLLM(llm=llm) elif llm is None: diff --git a/llama_index/llms/xinference.py b/llama_index/llms/xinference.py index 29752a26492ac205f10779fa13265d0d8faf0616..d0f15dd0bb00f92b391b6666a067b899d8c7e701 100644 --- a/llama_index/llms/xinference.py +++ b/llama_index/llms/xinference.py @@ -22,15 +22,20 @@ from llama_index.llms.xinference_utils import ( # an approximation of the ratio between llama and GPT2 tokens TOKEN_RATIO = 2.5 +DEFAULT_XINFERENCE_TEMP = 1.0 class Xinference(CustomLLM): model_uid: str = Field(description="The Xinference model to use.") endpoint: str = Field(description="The Xinference endpoint URL to use.") - temperature: float = Field(description="The temperature to use for sampling.") - max_tokens: int = Field(description="The maximum new tokens to generate as answer.") + temperature: float = Field( + description="The temperature to use for sampling.", gte=0.0, lte=1.0 + ) + max_tokens: int = Field( + description="The maximum new tokens to generate as answer.", gt=0 + ) context_window: int = Field( - description="The maximum number of context tokens for the model." + description="The maximum number of context tokens for the model.", gt=0 ) model_description: Dict[str, Any] = Field( description="The model description from Xinference." @@ -42,7 +47,7 @@ class Xinference(CustomLLM): self, model_uid: str, endpoint: str, - temperature: float = 1.0, + temperature: float = DEFAULT_XINFERENCE_TEMP, max_tokens: Optional[int] = None, callback_manager: Optional[CallbackManager] = None, ) -> None: diff --git a/llama_index/node_parser/__init__.py b/llama_index/node_parser/__init__.py index 24ef8936830fa46d264f50830a9b6d9b661a7f6f..5c32929197eca1253200cabf746271f4907dbd67 100644 --- a/llama_index/node_parser/__init__.py +++ b/llama_index/node_parser/__init__.py @@ -1,24 +1,48 @@ """Node parsers.""" -from llama_index.node_parser.hierarchical import ( +from llama_index.node_parser.file.html import HTMLNodeParser +from llama_index.node_parser.file.json import JSONNodeParser +from llama_index.node_parser.file.markdown import MarkdownNodeParser +from llama_index.node_parser.file.simple_file import SimpleFileNodeParser +from llama_index.node_parser.interface import ( + MetadataAwareTextSplitter, + NodeParser, + TextSplitter, +) +from llama_index.node_parser.relational.hierarchical import ( HierarchicalNodeParser, get_leaf_nodes, get_root_nodes, ) -from llama_index.node_parser.interface import NodeParser -from llama_index.node_parser.sentence_window import SentenceWindowNodeParser -from llama_index.node_parser.simple import SimpleNodeParser -from llama_index.node_parser.unstructured_element import ( +from llama_index.node_parser.relational.unstructured_element import ( UnstructuredElementNodeParser, ) +from llama_index.node_parser.text.code import CodeSplitter +from llama_index.node_parser.text.langchain import LangchainNodeParser +from llama_index.node_parser.text.sentence import SentenceSplitter +from llama_index.node_parser.text.sentence_window import SentenceWindowNodeParser +from llama_index.node_parser.text.token import TokenTextSplitter + +# deprecated, for backwards compatibility +SimpleNodeParser = SentenceSplitter __all__ = [ - "SimpleNodeParser", + "TokenTextSplitter", + "SentenceSplitter", + "CodeSplitter", + "SimpleFileNodeParser", + "HTMLNodeParser", + "MarkdownNodeParser", + "JSONNodeParser", "SentenceWindowNodeParser", "NodeParser", "HierarchicalNodeParser", + "TextSplitter", + "MetadataAwareTextSplitter", + "LangchainNodeParser", "UnstructuredElementNodeParser", - "get_base_nodes_and_mappings", "get_leaf_nodes", "get_root_nodes", + # deprecated, for backwards compatibility + "SimpleNodeParser", ] diff --git a/llama_index/node_parser/extractors/loading.py b/llama_index/node_parser/extractors/loading.py deleted file mode 100644 index 34b2247349a7523094eed9c5598b53591b6a33e5..0000000000000000000000000000000000000000 --- a/llama_index/node_parser/extractors/loading.py +++ /dev/null @@ -1,36 +0,0 @@ -from typing import List, Optional - -from llama_index.llms.base import LLM -from llama_index.node_parser.extractors.metadata_extractors import ( - EntityExtractor, - KeywordExtractor, - MetadataExtractor, - QuestionsAnsweredExtractor, - SummaryExtractor, - TitleExtractor, -) - - -def load_extractor( - data: dict, - extractors: Optional[List[MetadataExtractor]] = None, - llm: Optional[LLM] = None, -) -> MetadataExtractor: - extractor_name = data.get("class_name", None) - if extractor_name is None: - raise ValueError("Extractor loading requires a class_name") - - if extractor_name == MetadataExtractor.class_name(): - return MetadataExtractor.from_dict(data, extractors=extractors) - elif extractor_name == SummaryExtractor.class_name(): - return SummaryExtractor.from_dict(data, llm=llm) - elif extractor_name == QuestionsAnsweredExtractor.class_name(): - return QuestionsAnsweredExtractor.from_dict(data, llm=llm) - elif extractor_name == EntityExtractor.class_name(): - return EntityExtractor.from_dict(data) - elif extractor_name == TitleExtractor.class_name(): - return TitleExtractor.from_dict(data, llm=llm) - elif extractor_name == KeywordExtractor.class_name(): - return KeywordExtractor.from_dict(data, llm=llm) - else: - raise ValueError(f"Unknown extractor name: {extractor_name}") diff --git a/llama_index/node_parser/file/__init__.py b/llama_index/node_parser/file/__init__.py index e69de29bb2d1d6434b8b29ae775ad8c2e48c5391..7b576e38989203e504b9ac2a68646485b81254da 100644 --- a/llama_index/node_parser/file/__init__.py +++ b/llama_index/node_parser/file/__init__.py @@ -0,0 +1,11 @@ +from llama_index.node_parser.file.html import HTMLNodeParser +from llama_index.node_parser.file.json import JSONNodeParser +from llama_index.node_parser.file.markdown import MarkdownNodeParser +from llama_index.node_parser.file.simple_file import SimpleFileNodeParser + +__all__ = [ + "SimpleFileNodeParser", + "HTMLNodeParser", + "MarkdownNodeParser", + "JSONNodeParser", +] diff --git a/llama_index/node_parser/file/html.py b/llama_index/node_parser/file/html.py index c45498a1a99cbc6b1db5fb50d750b8c9d09a7f1a..3575328d1cacfb2ca99092ccaa5985ffa80b568f 100644 --- a/llama_index/node_parser/file/html.py +++ b/llama_index/node_parser/file/html.py @@ -1,9 +1,8 @@ """HTML node parser.""" -from typing import List, Optional, Sequence +from typing import Any, List, Optional, Sequence from llama_index.bridge.pydantic import Field from llama_index.callbacks.base import CallbackManager -from llama_index.callbacks.schema import CBEventType, EventPayload from llama_index.node_parser.interface import NodeParser from llama_index.node_parser.node_utils import build_nodes_from_splits from llama_index.schema import BaseNode, MetadataMode, TextNode @@ -25,15 +24,6 @@ class HTMLNodeParser(NodeParser): from bs4 import Tag - include_metadata: bool = Field( - default=True, description="Whether or not to consider metadata when splitting." - ) - include_prev_next_rel: bool = Field( - default=True, description="Include prev/next node relationships." - ) - callback_manager: CallbackManager = Field( - default_factory=CallbackManager, exclude=True - ) tags: List[str] = Field( default=DEFAULT_TAGS, description="HTML tags to extract text from." ) @@ -60,30 +50,18 @@ class HTMLNodeParser(NodeParser): """Get class name.""" return "HTMLNodeParser" - def get_nodes_from_documents( + def _parse_nodes( self, - documents: Sequence[TextNode], + nodes: Sequence[BaseNode], show_progress: bool = False, + **kwargs: Any, ) -> List[BaseNode]: - """Parse document into nodes. - - Args: - documents (Sequence[TextNode]): TextNodes or Documents to parse - - """ - with self.callback_manager.event( - CBEventType.NODE_PARSING, payload={EventPayload.DOCUMENTS: documents} - ) as event: - all_nodes: List[BaseNode] = [] - documents_with_progress = get_tqdm_iterable( - documents, show_progress, "Parsing documents into nodes" - ) - - for document in documents_with_progress: - nodes = self.get_nodes_from_node(document) - all_nodes.extend(nodes) + all_nodes: List[BaseNode] = [] + nodes_with_progress = get_tqdm_iterable(nodes, show_progress, "Parsing nodes") - event.on_end(payload={EventPayload.NODES: all_nodes}) + for node in nodes_with_progress: + nodes = self.get_nodes_from_node(node) + all_nodes.extend(nodes) return all_nodes @@ -145,9 +123,7 @@ class HTMLNodeParser(NodeParser): metadata: dict, ) -> TextNode: """Build node from single text split.""" - node = build_nodes_from_splits( - [text_split], node, self.include_metadata, self.include_prev_next_rel - )[0] + node = build_nodes_from_splits([text_split], node)[0] if self.include_metadata: node.metadata = {**node.metadata, **metadata} diff --git a/llama_index/node_parser/file/json.py b/llama_index/node_parser/file/json.py index aa8e79a4b6a1bfd9c3f4835ae5562ca1d125818b..5d6e19de145b5c0890b54c3b1c83701a1c27aeae 100644 --- a/llama_index/node_parser/file/json.py +++ b/llama_index/node_parser/file/json.py @@ -1,10 +1,8 @@ """JSON node parser.""" import json -from typing import Dict, Generator, List, Optional, Sequence +from typing import Any, Dict, Generator, List, Optional, Sequence -from llama_index.bridge.pydantic import Field from llama_index.callbacks.base import CallbackManager -from llama_index.callbacks.schema import CBEventType, EventPayload from llama_index.node_parser.interface import NodeParser from llama_index.node_parser.node_utils import build_nodes_from_splits from llama_index.schema import BaseNode, MetadataMode, TextNode @@ -22,16 +20,6 @@ class JSONNodeParser(NodeParser): """ - include_metadata: bool = Field( - default=True, description="Whether or not to consider metadata when splitting." - ) - include_prev_next_rel: bool = Field( - default=True, description="Include prev/next node relationships." - ) - callback_manager: CallbackManager = Field( - default_factory=CallbackManager, exclude=True - ) - @classmethod def from_defaults( cls, @@ -52,30 +40,15 @@ class JSONNodeParser(NodeParser): """Get class name.""" return "JSONNodeParser" - def get_nodes_from_documents( - self, - documents: Sequence[TextNode], - show_progress: bool = False, + def _parse_nodes( + self, nodes: Sequence[BaseNode], show_progress: bool = False, **kwargs: Any ) -> List[BaseNode]: - """Parse document into nodes. - - Args: - documents (Sequence[TextNode]): TextNodes or Documents to parse - - """ - with self.callback_manager.event( - CBEventType.NODE_PARSING, payload={EventPayload.DOCUMENTS: documents} - ) as event: - all_nodes: List[BaseNode] = [] - documents_with_progress = get_tqdm_iterable( - documents, show_progress, "Parsing documents into nodes" - ) - - for document in documents_with_progress: - nodes = self.get_nodes_from_node(document) - all_nodes.extend(nodes) + all_nodes: List[BaseNode] = [] + nodes_with_progress = get_tqdm_iterable(nodes, show_progress, "Parsing nodes") - event.on_end(payload={EventPayload.NODES: all_nodes}) + for node in nodes_with_progress: + nodes = self.get_nodes_from_node(node) + all_nodes.extend(nodes) return all_nodes @@ -91,13 +64,11 @@ class JSONNodeParser(NodeParser): json_nodes = [] if isinstance(data, dict): lines = [*self._depth_first_yield(data, 0, [])] - json_nodes.append(self._build_node_from_split("\n".join(lines), node, {})) + json_nodes.extend(build_nodes_from_splits(["\n".join(lines)], node)) elif isinstance(data, list): for json_object in data: lines = [*self._depth_first_yield(json_object, 0, [])] - json_nodes.append( - self._build_node_from_split("\n".join(lines), node, {}) - ) + json_nodes.extend(build_nodes_from_splits(["\n".join(lines)], node)) else: raise ValueError("JSON is invalid") @@ -125,19 +96,3 @@ class JSONNodeParser(NodeParser): new_path = path[-levels_back:] new_path.append(str(json_data)) yield " ".join(new_path) - - def _build_node_from_split( - self, - text_split: str, - node: BaseNode, - metadata: dict, - ) -> TextNode: - """Build node from single text split.""" - node = build_nodes_from_splits( - [text_split], node, self.include_metadata, self.include_prev_next_rel - )[0] - - if self.include_metadata: - node.metadata = {**node.metadata, **metadata} - - return node diff --git a/llama_index/node_parser/file/markdown.py b/llama_index/node_parser/file/markdown.py index cdd944ff526b0a4d51929e2565d5551299dffac9..7836915dcad86aa3594b11338e161fdc91d52bcb 100644 --- a/llama_index/node_parser/file/markdown.py +++ b/llama_index/node_parser/file/markdown.py @@ -1,10 +1,8 @@ """Markdown node parser.""" import re -from typing import Dict, List, Optional, Sequence +from typing import Any, Dict, List, Optional, Sequence -from llama_index.bridge.pydantic import Field from llama_index.callbacks.base import CallbackManager -from llama_index.callbacks.schema import CBEventType, EventPayload from llama_index.node_parser.interface import NodeParser from llama_index.node_parser.node_utils import build_nodes_from_splits from llama_index.schema import BaseNode, MetadataMode, TextNode @@ -22,16 +20,6 @@ class MarkdownNodeParser(NodeParser): """ - include_metadata: bool = Field( - default=True, description="Whether or not to consider metadata when splitting." - ) - include_prev_next_rel: bool = Field( - default=True, description="Include prev/next node relationships." - ) - callback_manager: CallbackManager = Field( - default_factory=CallbackManager, exclude=True - ) - @classmethod def from_defaults( cls, @@ -52,30 +40,18 @@ class MarkdownNodeParser(NodeParser): """Get class name.""" return "MarkdownNodeParser" - def get_nodes_from_documents( + def _parse_nodes( self, - documents: Sequence[TextNode], + nodes: Sequence[BaseNode], show_progress: bool = False, + **kwargs: Any, ) -> List[BaseNode]: - """Parse document into nodes. - - Args: - documents (Sequence[TextNode]): TextNodes or Documents to parse - - """ - with self.callback_manager.event( - CBEventType.NODE_PARSING, payload={EventPayload.DOCUMENTS: documents} - ) as event: - all_nodes: List[BaseNode] = [] - documents_with_progress = get_tqdm_iterable( - documents, show_progress, "Parsing documents into nodes" - ) - - for document in documents_with_progress: - nodes = self.get_nodes_from_node(document) - all_nodes.extend(nodes) + all_nodes: List[BaseNode] = [] + nodes_with_progress = get_tqdm_iterable(nodes, show_progress, "Parsing nodes") - event.on_end(payload={EventPayload.NODES: all_nodes}) + for node in nodes_with_progress: + nodes = self.get_nodes_from_node(node) + all_nodes.extend(nodes) return all_nodes @@ -137,9 +113,7 @@ class MarkdownNodeParser(NodeParser): metadata: dict, ) -> TextNode: """Build node from single text split.""" - node = build_nodes_from_splits( - [text_split], node, self.include_metadata, self.include_prev_next_rel - )[0] + node = build_nodes_from_splits([text_split], node)[0] if self.include_metadata: node.metadata = {**node.metadata, **metadata} diff --git a/llama_index/node_parser/file/simple_file.py b/llama_index/node_parser/file/simple_file.py new file mode 100644 index 0000000000000000000000000000000000000000..9bd21854c94ef28dddbea6bc8d9f399bfe59ce4a --- /dev/null +++ b/llama_index/node_parser/file/simple_file.py @@ -0,0 +1,82 @@ +"""Simple file node parser.""" +from typing import Any, Dict, List, Optional, Sequence, Type + +from llama_index.callbacks.base import CallbackManager +from llama_index.node_parser.file.html import HTMLNodeParser +from llama_index.node_parser.file.json import JSONNodeParser +from llama_index.node_parser.file.markdown import MarkdownNodeParser +from llama_index.node_parser.interface import NodeParser +from llama_index.schema import BaseNode +from llama_index.utils import get_tqdm_iterable + +FILE_NODE_PARSERS: Dict[str, Type[NodeParser]] = { + ".md": MarkdownNodeParser, + ".html": HTMLNodeParser, + ".json": JSONNodeParser, +} + + +class SimpleFileNodeParser(NodeParser): + """Simple file node parser. + + Splits a document loaded from a file into Nodes using logic based on the file type + automatically detects the NodeParser to use based on file type + + Args: + include_metadata (bool): whether to include metadata in nodes + include_prev_next_rel (bool): whether to include prev/next relationships + + """ + + @classmethod + def from_defaults( + cls, + include_metadata: bool = True, + include_prev_next_rel: bool = True, + callback_manager: Optional[CallbackManager] = None, + ) -> "SimpleFileNodeParser": + callback_manager = callback_manager or CallbackManager([]) + + return cls( + include_metadata=include_metadata, + include_prev_next_rel=include_prev_next_rel, + callback_manager=callback_manager, + ) + + @classmethod + def class_name(cls) -> str: + """Get class name.""" + return "SimpleFileNodeParser" + + def _parse_nodes( + self, + nodes: Sequence[BaseNode], + show_progress: bool = False, + **kwargs: Any, + ) -> List[BaseNode]: + """Parse document into nodes. + + Args: + nodes (Sequence[BaseNode]): nodes to parse + """ + all_nodes: List[BaseNode] = [] + documents_with_progress = get_tqdm_iterable( + nodes, show_progress, "Parsing documents into nodes" + ) + + for document in documents_with_progress: + ext = document.metadata["extension"] + if ext in FILE_NODE_PARSERS: + parser = FILE_NODE_PARSERS[ext]( + include_metadata=self.include_metadata, + include_prev_next_rel=self.include_prev_next_rel, + callback_manager=self.callback_manager, + ) + + nodes = parser.get_nodes_from_documents([document], show_progress) + all_nodes.extend(nodes) + else: + # What to do when file type isn't supported yet? + all_nodes.extend(document) + + return all_nodes diff --git a/llama_index/node_parser/interface.py b/llama_index/node_parser/interface.py index 63198cb708a578e09cf1b1d45103a4918c1c64b9..5e6bd53449a2863d904f7954238a17380b1aaa92 100644 --- a/llama_index/node_parser/interface.py +++ b/llama_index/node_parser/interface.py @@ -1,43 +1,154 @@ """Node parser interface.""" from abc import ABC, abstractmethod -from typing import Dict, List, Sequence +from typing import Any, List, Sequence -from llama_index.schema import BaseComponent, BaseNode, Document +from llama_index.bridge.pydantic import Field +from llama_index.callbacks import CallbackManager, CBEventType, EventPayload +from llama_index.node_parser.node_utils import build_nodes_from_splits +from llama_index.schema import ( + BaseNode, + Document, + MetadataMode, + NodeRelationship, + TransformComponent, +) +from llama_index.utils import get_tqdm_iterable -class NodeParser(BaseComponent, ABC): +class NodeParser(TransformComponent, ABC): """Base interface for node parser.""" + include_metadata: bool = Field( + default=True, description="Whether or not to consider metadata when splitting." + ) + include_prev_next_rel: bool = Field( + default=True, description="Include prev/next node relationships." + ) + callback_manager: CallbackManager = Field( + default_factory=CallbackManager, exclude=True + ) + class Config: arbitrary_types_allowed = True @abstractmethod + def _parse_nodes( + self, + nodes: Sequence[BaseNode], + show_progress: bool = False, + **kwargs: Any, + ) -> List[BaseNode]: + ... + def get_nodes_from_documents( self, documents: Sequence[Document], show_progress: bool = False, + **kwargs: Any, ) -> List[BaseNode]: """Parse documents into nodes. Args: documents (Sequence[Document]): documents to parse + show_progress (bool): whether to show progress bar """ + doc_id_to_document = {doc.doc_id: doc for doc in documents} + with self.callback_manager.event( + CBEventType.NODE_PARSING, payload={EventPayload.DOCUMENTS: documents} + ) as event: + nodes = self._parse_nodes(documents, show_progress=show_progress, **kwargs) -class BaseExtractor(BaseComponent, ABC): - """Base interface for feature extractor.""" + if self.include_metadata: + for node in nodes: + if node.ref_doc_id is not None: + node.metadata.update( + doc_id_to_document[node.ref_doc_id].metadata + ) - class Config: - arbitrary_types_allowed = True + if self.include_prev_next_rel: + for i, node in enumerate(nodes): + if i > 0: + node.relationships[NodeRelationship.PREVIOUS] = nodes[ + i - 1 + ].as_related_node_info() + if i < len(nodes) - 1: + node.relationships[NodeRelationship.NEXT] = nodes[ + i + 1 + ].as_related_node_info() + + event.on_end({EventPayload.NODES: nodes}) + return nodes + + def __call__(self, nodes: List[BaseNode], **kwargs: Any) -> List[BaseNode]: + return self.get_nodes_from_documents(nodes, **kwargs) + + +class TextSplitter(NodeParser): @abstractmethod - def extract( - self, - nodes: List[BaseNode], - ) -> List[Dict]: - """Post process nodes parsed from documents. + def split_text(self, text: str) -> List[str]: + ... - Args: - nodes (List[BaseNode]): nodes to extract from - """ + def split_texts(self, texts: List[str]) -> List[str]: + nested_texts = [self.split_text(text) for text in texts] + return [item for sublist in nested_texts for item in sublist] + + def _parse_nodes( + self, nodes: Sequence[BaseNode], show_progress: bool = False, **kwargs: Any + ) -> List[BaseNode]: + all_nodes: List[BaseNode] = [] + nodes_with_progress = get_tqdm_iterable(nodes, show_progress, "Parsing nodes") + for node in nodes_with_progress: + splits = self.split_text(node.get_content()) + + all_nodes.extend(build_nodes_from_splits(splits, node)) + + return all_nodes + + +class MetadataAwareTextSplitter(TextSplitter): + @abstractmethod + def split_text_metadata_aware(self, text: str, metadata_str: str) -> List[str]: + ... + + def split_texts_metadata_aware( + self, texts: List[str], metadata_strs: List[str] + ) -> List[str]: + if len(texts) != len(metadata_strs): + raise ValueError("Texts and metadata_strs must have the same length") + nested_texts = [ + self.split_text_metadata_aware(text, metadata) + for text, metadata in zip(texts, metadata_strs) + ] + return [item for sublist in nested_texts for item in sublist] + + def _get_metadata_str(self, node: BaseNode) -> str: + """Helper function to get the proper metadata str for splitting.""" + embed_metadata_str = node.get_metadata_str(mode=MetadataMode.EMBED) + llm_metadata_str = node.get_metadata_str(mode=MetadataMode.LLM) + + # use the longest metadata str for splitting + if len(embed_metadata_str) > len(llm_metadata_str): + metadata_str = embed_metadata_str + else: + metadata_str = llm_metadata_str + + return metadata_str + + def _parse_nodes( + self, nodes: Sequence[BaseNode], show_progress: bool = False, **kwargs: Any + ) -> List[BaseNode]: + all_nodes: List[BaseNode] = [] + nodes_with_progress = get_tqdm_iterable(nodes, show_progress, "Parsing nodes") + + for node in nodes_with_progress: + metadata_str = self._get_metadata_str(node) + splits = self.split_text_metadata_aware( + node.get_content(metadata_mode=MetadataMode.NONE), + metadata_str=metadata_str, + ) + all_nodes.extend(build_nodes_from_splits(splits, node)) + + return all_nodes diff --git a/llama_index/node_parser/loading.py b/llama_index/node_parser/loading.py index 0c68348a74b197cb86f25cd1f45df499027e2bc3..95d299a7d5b145dd456064f3dea956fdd8b12124 100644 --- a/llama_index/node_parser/loading.py +++ b/llama_index/node_parser/loading.py @@ -1,30 +1,39 @@ -from typing import Optional +from typing import Dict, Type -from llama_index.node_parser.extractors.metadata_extractors import MetadataExtractor +from llama_index.node_parser.file.html import HTMLNodeParser +from llama_index.node_parser.file.json import JSONNodeParser +from llama_index.node_parser.file.markdown import MarkdownNodeParser +from llama_index.node_parser.file.simple_file import SimpleFileNodeParser from llama_index.node_parser.interface import NodeParser -from llama_index.node_parser.sentence_window import SentenceWindowNodeParser -from llama_index.node_parser.simple import SimpleNodeParser -from llama_index.text_splitter.sentence_splitter import SentenceSplitter -from llama_index.text_splitter.types import SplitterType +from llama_index.node_parser.relational.hierarchical import HierarchicalNodeParser +from llama_index.node_parser.text.code import CodeSplitter +from llama_index.node_parser.text.sentence import SentenceSplitter +from llama_index.node_parser.text.sentence_window import SentenceWindowNodeParser +from llama_index.node_parser.text.token import TokenTextSplitter + +all_node_parsers: Dict[str, Type[NodeParser]] = { + HTMLNodeParser.class_name(): HTMLNodeParser, + JSONNodeParser.class_name(): JSONNodeParser, + MarkdownNodeParser.class_name(): MarkdownNodeParser, + SimpleFileNodeParser.class_name(): SimpleFileNodeParser, + HierarchicalNodeParser.class_name(): HierarchicalNodeParser, + CodeSplitter.class_name(): CodeSplitter, + SentenceSplitter.class_name(): SentenceSplitter, + TokenTextSplitter.class_name(): TokenTextSplitter, + SentenceWindowNodeParser.class_name(): SentenceWindowNodeParser, +} def load_parser( data: dict, - text_splitter: Optional[SplitterType] = None, - metadata_extractor: Optional[MetadataExtractor] = None, ) -> NodeParser: + if isinstance(data, NodeParser): + return data parser_name = data.get("class_name", None) if parser_name is None: raise ValueError("Parser loading requires a class_name") - if parser_name == SimpleNodeParser.class_name(): - return SimpleNodeParser.from_dict( - data, text_splitter=text_splitter, metadata_extractor=metadata_extractor - ) - elif parser_name == SentenceWindowNodeParser.class_name(): - assert isinstance(text_splitter, (type(None), SentenceSplitter)) - return SentenceWindowNodeParser.from_dict( - data, sentence_splitter=text_splitter, metadata_extractor=metadata_extractor - ) + if parser_name not in all_node_parsers: + raise ValueError(f"Invalid parser name: {parser_name}") else: - raise ValueError(f"Unknown parser name: {parser_name}") + return all_node_parsers[parser_name].from_dict(data) diff --git a/llama_index/node_parser/node_utils.py b/llama_index/node_parser/node_utils.py index e8bd2308c3eb5e76f53f3c7642ca91e4d67f6264..391ba4344154260bfaa95feb6424aefc492525d5 100644 --- a/llama_index/node_parser/node_utils.py +++ b/llama_index/node_parser/node_utils.py @@ -9,11 +9,9 @@ from llama_index.schema import ( Document, ImageDocument, ImageNode, - MetadataMode, NodeRelationship, TextNode, ) -from llama_index.text_splitter.types import MetadataAwareTextSplitter, SplitterType from llama_index.utils import truncate_text logger = logging.getLogger(__name__) @@ -22,8 +20,6 @@ logger = logging.getLogger(__name__) def build_nodes_from_splits( text_splits: List[str], document: BaseNode, - include_metadata: bool = True, - include_prev_next_rel: bool = False, ref_doc: Optional[BaseNode] = None, ) -> List[TextNode]: """Build nodes from splits.""" @@ -33,15 +29,10 @@ def build_nodes_from_splits( for i, text_chunk in enumerate(text_splits): logger.debug(f"> Adding chunk: {truncate_text(text_chunk, 50)}") - node_metadata = {} - if include_metadata: - node_metadata = document.metadata - if isinstance(document, ImageDocument): image_node = ImageNode( text=text_chunk, embedding=document.embedding, - metadata=node_metadata, image=document.image, image_path=document.image_path, image_url=document.image_url, @@ -57,7 +48,6 @@ def build_nodes_from_splits( node = TextNode( text=text_chunk, embedding=document.embedding, - metadata=node_metadata, excluded_embed_metadata_keys=document.excluded_embed_metadata_keys, excluded_llm_metadata_keys=document.excluded_llm_metadata_keys, metadata_seperator=document.metadata_seperator, @@ -70,7 +60,6 @@ def build_nodes_from_splits( node = TextNode( text=text_chunk, embedding=document.embedding, - metadata=node_metadata, excluded_embed_metadata_keys=document.excluded_embed_metadata_keys, excluded_llm_metadata_keys=document.excluded_llm_metadata_keys, metadata_seperator=document.metadata_seperator, @@ -82,108 +71,4 @@ def build_nodes_from_splits( else: raise ValueError(f"Unknown document type: {type(document)}") - # account for pure image documents - if len(text_splits) == 0 and isinstance(document, ImageDocument): - node_metadata = {} - if include_metadata: - node_metadata = document.metadata - - image_node = ImageNode( - text="", - embedding=document.embedding, - metadata=node_metadata, - image=document.image, - image_path=document.image_path, - image_url=document.image_url, - excluded_embed_metadata_keys=document.excluded_embed_metadata_keys, - excluded_llm_metadata_keys=document.excluded_llm_metadata_keys, - metadata_seperator=document.metadata_seperator, - metadata_template=document.metadata_template, - text_template=document.text_template, - relationships={NodeRelationship.SOURCE: ref_doc.as_related_node_info()}, - ) - nodes.append(image_node) # type: ignore - - # if include_prev_next_rel, then add prev/next relationships - if include_prev_next_rel: - for i, node in enumerate(nodes): - if i > 0: - node.relationships[NodeRelationship.PREVIOUS] = nodes[ - i - 1 - ].as_related_node_info() - if i < len(nodes) - 1: - node.relationships[NodeRelationship.NEXT] = nodes[ - i + 1 - ].as_related_node_info() - return nodes - - -def get_nodes_from_document( - document: BaseNode, - text_splitter: SplitterType, - include_metadata: bool = True, - include_prev_next_rel: bool = False, -) -> List[TextNode]: - """Get nodes from document. - - NOTE: this function has been deprecated, please use - get_nodes_from_node which supports both documents/nodes. - - """ - return get_nodes_from_node( - document, - text_splitter, - include_metadata=include_metadata, - include_prev_next_rel=include_prev_next_rel, - ref_doc=document, - ) - - -def get_nodes_from_node( - node: BaseNode, - text_splitter: SplitterType, - include_metadata: bool = True, - include_prev_next_rel: bool = False, - ref_doc: Optional[BaseNode] = None, -) -> List[TextNode]: - """Get nodes from document.""" - if include_metadata: - if isinstance(text_splitter, MetadataAwareTextSplitter): - embed_metadata_str = node.get_metadata_str(mode=MetadataMode.EMBED) - llm_metadata_str = node.get_metadata_str(mode=MetadataMode.LLM) - - # use the longest metadata str for splitting - if len(embed_metadata_str) > len(llm_metadata_str): - metadata_str = embed_metadata_str - else: - metadata_str = llm_metadata_str - - text_splits = text_splitter.split_text_metadata_aware( - text=node.get_content(metadata_mode=MetadataMode.NONE), - metadata_str=metadata_str, - ) - else: - logger.warning( - f"include_metadata is set to True but {text_splitter} " - "is not metadata-aware." - "Node content length may exceed expected chunk size." - "Try lowering the chunk size or using a metadata-aware text splitter " - "if this is a problem." - ) - - text_splits = text_splitter.split_text( - node.get_content(metadata_mode=MetadataMode.NONE), - ) - else: - text_splits = text_splitter.split_text( - node.get_content(metadata_mode=MetadataMode.NONE), - ) - - return build_nodes_from_splits( - text_splits, - node, - include_metadata=include_metadata, - include_prev_next_rel=include_prev_next_rel, - ref_doc=ref_doc, - ) diff --git a/llama_index/node_parser/relational/__init__.py b/llama_index/node_parser/relational/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f481d23bbdbef7c66f776a67488235e66591f43f --- /dev/null +++ b/llama_index/node_parser/relational/__init__.py @@ -0,0 +1,9 @@ +from llama_index.node_parser.relational.hierarchical import HierarchicalNodeParser +from llama_index.node_parser.relational.unstructured_element import ( + UnstructuredElementNodeParser, +) + +__all__ = [ + "HierarchicalNodeParser", + "UnstructuredElementNodeParser", +] diff --git a/llama_index/node_parser/hierarchical.py b/llama_index/node_parser/relational/hierarchical.py similarity index 66% rename from llama_index/node_parser/hierarchical.py rename to llama_index/node_parser/relational/hierarchical.py index ed99855661ad093557e852d8a8bc23675469c68b..a3eef65c36c480e156ffd5dd96988c604ec3b77b 100644 --- a/llama_index/node_parser/hierarchical.py +++ b/llama_index/node_parser/relational/hierarchical.py @@ -1,15 +1,13 @@ """Hierarchical node parser.""" -from typing import Dict, List, Optional, Sequence +from typing import Any, Dict, List, Optional, Sequence from llama_index.bridge.pydantic import Field from llama_index.callbacks.base import CallbackManager from llama_index.callbacks.schema import CBEventType, EventPayload -from llama_index.node_parser.extractors.metadata_extractors import MetadataExtractor from llama_index.node_parser.interface import NodeParser -from llama_index.node_parser.node_utils import get_nodes_from_document +from llama_index.node_parser.text.sentence import SentenceSplitter from llama_index.schema import BaseNode, Document, NodeRelationship -from llama_index.text_splitter import TextSplitter, get_default_text_splitter from llama_index.utils import get_tqdm_iterable @@ -45,7 +43,7 @@ def get_root_nodes(nodes: List[BaseNode]) -> List[BaseNode]: class HierarchicalNodeParser(NodeParser): """Hierarchical node parser. - Splits a document into a recursive hierarchy Nodes using a TextSplitter. + Splits a document into a recursive hierarchy Nodes using a NodeParser. NOTE: this will return a hierarchy of nodes in a flat list, where there will be overlap between parent nodes (e.g. with a bigger chunk size), and child nodes @@ -57,12 +55,6 @@ class HierarchicalNodeParser(NodeParser): chunk size 512 - list of third-level nodes, where each node is a child of a second-level node, chunk size 128 - - Args: - text_splitter (Optional[TextSplitter]): text splitter - include_metadata (bool): whether to include metadata in nodes - include_prev_next_rel (bool): whether to include prev/next relationships - """ chunk_sizes: Optional[List[int]] = Field( @@ -71,73 +63,55 @@ class HierarchicalNodeParser(NodeParser): "The chunk sizes to use when splitting documents, in order of level." ), ) - text_splitter_ids: List[str] = Field( + node_parser_ids: List[str] = Field( default_factory=list, description=( - "List of ids for the text splitters to use when splitting documents, " + "List of ids for the node parsers to use when splitting documents, " + "in order of level (first id used for first level, etc.)." ), ) - text_splitter_map: Dict[str, TextSplitter] = Field( - description="Map of text splitter id to text splitter.", - ) - include_metadata: bool = Field( - default=True, description="Whether or not to consider metadata when splitting." - ) - include_prev_next_rel: bool = Field( - default=True, description="Include prev/next node relationships." - ) - metadata_extractor: Optional[MetadataExtractor] = Field( - default=None, description="Metadata extraction pipeline to apply to nodes." - ) - callback_manager: CallbackManager = Field( - default_factory=CallbackManager, exclude=True + node_parser_map: Dict[str, NodeParser] = Field( + description="Map of node parser id to node parser.", ) @classmethod def from_defaults( cls, chunk_sizes: Optional[List[int]] = None, - text_splitter_ids: Optional[List[str]] = None, - text_splitter_map: Optional[Dict[str, TextSplitter]] = None, + node_parser_ids: Optional[List[str]] = None, + node_parser_map: Optional[Dict[str, NodeParser]] = None, include_metadata: bool = True, include_prev_next_rel: bool = True, callback_manager: Optional[CallbackManager] = None, - metadata_extractor: Optional[MetadataExtractor] = None, ) -> "HierarchicalNodeParser": callback_manager = callback_manager or CallbackManager([]) - if text_splitter_ids is None: + if node_parser_ids is None: if chunk_sizes is None: chunk_sizes = [2048, 512, 128] - text_splitter_ids = [ - f"chunk_size_{chunk_size}" for chunk_size in chunk_sizes - ] - text_splitter_map = {} - for chunk_size, text_splitter_id in zip(chunk_sizes, text_splitter_ids): - text_splitter_map[text_splitter_id] = get_default_text_splitter( + node_parser_ids = [f"chunk_size_{chunk_size}" for chunk_size in chunk_sizes] + node_parser_map = {} + for chunk_size, node_parser_id in zip(chunk_sizes, node_parser_ids): + node_parser_map[node_parser_id] = SentenceSplitter( chunk_size=chunk_size, callback_manager=callback_manager, ) else: if chunk_sizes is not None: + raise ValueError("Cannot specify both node_parser_ids and chunk_sizes.") + if node_parser_map is None: raise ValueError( - "Cannot specify both text_splitter_ids and chunk_sizes." - ) - if text_splitter_map is None: - raise ValueError( - "Must specify text_splitter_map if using text_splitter_ids." + "Must specify node_parser_map if using node_parser_ids." ) return cls( chunk_sizes=chunk_sizes, - text_splitter_ids=text_splitter_ids, - text_splitter_map=text_splitter_map, + node_parser_ids=node_parser_ids, + node_parser_map=node_parser_map, include_metadata=include_metadata, include_prev_next_rel=include_prev_next_rel, callback_manager=callback_manager, - metadata_extractor=metadata_extractor, ) @classmethod @@ -151,10 +125,10 @@ class HierarchicalNodeParser(NodeParser): show_progress: bool = False, ) -> List[BaseNode]: """Recursively get nodes from nodes.""" - if level >= len(self.text_splitter_ids): + if level >= len(self.node_parser_ids): raise ValueError( f"Level {level} is greater than number of text " - f"splitters ({len(self.text_splitter_ids)})." + f"splitters ({len(self.node_parser_ids)})." ) # first split current nodes into sub-nodes @@ -163,12 +137,9 @@ class HierarchicalNodeParser(NodeParser): ) sub_nodes = [] for node in nodes_with_progress: - cur_sub_nodes = get_nodes_from_document( - node, - self.text_splitter_map[self.text_splitter_ids[level]], - self.include_metadata, - include_prev_next_rel=self.include_prev_next_rel, - ) + cur_sub_nodes = self.node_parser_map[ + self.node_parser_ids[level] + ].get_nodes_from_documents([node]) # add parent relationship from sub node to parent node # add child relationship from parent node to sub node # NOTE: Only add relationships if level > 0, since we don't want to add @@ -183,7 +154,7 @@ class HierarchicalNodeParser(NodeParser): sub_nodes.extend(cur_sub_nodes) # now for each sub-node, recursively split into sub-sub-nodes, and add - if level < len(self.text_splitter_ids) - 1: + if level < len(self.node_parser_ids) - 1: sub_sub_nodes = self._recursively_get_nodes_from_nodes( sub_nodes, level + 1, @@ -198,6 +169,7 @@ class HierarchicalNodeParser(NodeParser): self, documents: Sequence[Document], show_progress: bool = False, + **kwargs: Any, ) -> List[BaseNode]: """Parse document into nodes. @@ -219,9 +191,12 @@ class HierarchicalNodeParser(NodeParser): nodes_from_doc = self._recursively_get_nodes_from_nodes([doc], 0) all_nodes.extend(nodes_from_doc) - if self.metadata_extractor is not None: - all_nodes = self.metadata_extractor.process_nodes(all_nodes) - event.on_end(payload={EventPayload.NODES: all_nodes}) return all_nodes + + # Unused abstract method + def _parse_nodes( + self, nodes: Sequence[BaseNode], show_progress: bool = False, **kwargs: Any + ) -> List[BaseNode]: + return list(nodes) diff --git a/llama_index/node_parser/unstructured_element.py b/llama_index/node_parser/relational/unstructured_element.py similarity index 90% rename from llama_index/node_parser/unstructured_element.py rename to llama_index/node_parser/relational/unstructured_element.py index f2b940e9776234cf70e08b1549bf6e79868f1b24..aa373f0c842d096ee5c4920420a80bbbd71cd4f4 100644 --- a/llama_index/node_parser/unstructured_element.py +++ b/llama_index/node_parser/relational/unstructured_element.py @@ -8,9 +8,7 @@ from tqdm import tqdm from llama_index.bridge.pydantic import Field from llama_index.callbacks.base import CallbackManager -from llama_index.callbacks.schema import CBEventType, EventPayload from llama_index.llms.openai import LLM, OpenAI -from llama_index.node_parser import SimpleNodeParser from llama_index.node_parser.interface import NodeParser from llama_index.response.schema import PydanticResponse from llama_index.schema import BaseNode, Document, IndexNode, TextNode @@ -105,7 +103,7 @@ def extract_table_summaries( ) -> None: """Go through elements, extract out summaries that are tables.""" from llama_index.indices.list.base import SummaryIndex - from llama_index.indices.service_context import ServiceContext + from llama_index.service_context import ServiceContext llm = llm or OpenAI() llm = cast(LLM, llm) @@ -149,7 +147,9 @@ def _get_nodes_from_buffer( def get_nodes_from_elements(elements: List[Element]) -> List[BaseNode]: """Get nodes and mappings.""" - node_parser = SimpleNodeParser.from_defaults() + from llama_index.node_parser import SentenceSplitter + + node_parser = SentenceSplitter() nodes = [] cur_text_el_buffer: List[str] = [] @@ -267,30 +267,18 @@ class UnstructuredElementNodeParser(NodeParser): # will return a list of Nodes and Index Nodes return get_nodes_from_elements(elements) - def get_nodes_from_documents( + def _parse_nodes( self, - documents: Sequence[TextNode], + nodes: Sequence[BaseNode], show_progress: bool = False, + **kwargs: Any, ) -> List[BaseNode]: - """Parse document into nodes. - - Args: - documents (Sequence[TextNode]): TextNodes or Documents to parse - - """ - with self.callback_manager.event( - CBEventType.NODE_PARSING, payload={EventPayload.DOCUMENTS: documents} - ) as event: - all_nodes: List[BaseNode] = [] - documents_with_progress = get_tqdm_iterable( - documents, show_progress, "Parsing documents into nodes" - ) - - for document in documents_with_progress: - nodes = self.get_nodes_from_node(document) - all_nodes.extend(nodes) + all_nodes: List[BaseNode] = [] + nodes_with_progress = get_tqdm_iterable(nodes, show_progress, "Parsing nodes") - event.on_end(payload={EventPayload.NODES: all_nodes}) + for node in nodes_with_progress: + nodes = self.get_nodes_from_node(node) + all_nodes.extend(nodes) return all_nodes diff --git a/llama_index/node_parser/simple.py b/llama_index/node_parser/simple.py deleted file mode 100644 index a310946f46c57f8043e4c27601de6289b539e1fc..0000000000000000000000000000000000000000 --- a/llama_index/node_parser/simple.py +++ /dev/null @@ -1,107 +0,0 @@ -"""Simple node parser.""" -from typing import List, Optional, Sequence - -from llama_index.bridge.pydantic import Field -from llama_index.callbacks.base import CallbackManager -from llama_index.callbacks.schema import CBEventType, EventPayload -from llama_index.node_parser.extractors.metadata_extractors import MetadataExtractor -from llama_index.node_parser.interface import NodeParser -from llama_index.node_parser.node_utils import get_nodes_from_document -from llama_index.schema import BaseNode, Document -from llama_index.text_splitter import SplitterType, get_default_text_splitter -from llama_index.utils import get_tqdm_iterable - - -class SimpleNodeParser(NodeParser): - """Simple node parser. - - Splits a document into Nodes using a TextSplitter. - - Args: - text_splitter (Optional[TextSplitter]): text splitter - include_metadata (bool): whether to include metadata in nodes - include_prev_next_rel (bool): whether to include prev/next relationships - - """ - - text_splitter: SplitterType = Field( - description="The text splitter to use when splitting documents." - ) - include_metadata: bool = Field( - default=True, description="Whether or not to consider metadata when splitting." - ) - include_prev_next_rel: bool = Field( - default=True, description="Include prev/next node relationships." - ) - metadata_extractor: Optional[MetadataExtractor] = Field( - default=None, description="Metadata extraction pipeline to apply to nodes." - ) - callback_manager: CallbackManager = Field( - default_factory=CallbackManager, exclude=True - ) - - @classmethod - def from_defaults( - cls, - chunk_size: Optional[int] = None, - chunk_overlap: Optional[int] = None, - text_splitter: Optional[SplitterType] = None, - include_metadata: bool = True, - include_prev_next_rel: bool = True, - callback_manager: Optional[CallbackManager] = None, - metadata_extractor: Optional[MetadataExtractor] = None, - ) -> "SimpleNodeParser": - callback_manager = callback_manager or CallbackManager([]) - - text_splitter = text_splitter or get_default_text_splitter( - chunk_size=chunk_size, - chunk_overlap=chunk_overlap, - callback_manager=callback_manager, - ) - return cls( - text_splitter=text_splitter, - include_metadata=include_metadata, - include_prev_next_rel=include_prev_next_rel, - callback_manager=callback_manager, - metadata_extractor=metadata_extractor, - ) - - @classmethod - def class_name(cls) -> str: - return "SimpleNodeParser" - - def get_nodes_from_documents( - self, - documents: Sequence[Document], - show_progress: bool = False, - ) -> List[BaseNode]: - """Parse document into nodes. - - Args: - documents (Sequence[Document]): documents to parse - include_metadata (bool): whether to include metadata in nodes - - """ - with self.callback_manager.event( - CBEventType.NODE_PARSING, payload={EventPayload.DOCUMENTS: documents} - ) as event: - all_nodes: List[BaseNode] = [] - documents_with_progress = get_tqdm_iterable( - documents, show_progress, "Parsing documents into nodes" - ) - - for document in documents_with_progress: - nodes = get_nodes_from_document( - document, - self.text_splitter, - self.include_metadata, - include_prev_next_rel=self.include_prev_next_rel, - ) - all_nodes.extend(nodes) - - if self.metadata_extractor is not None: - all_nodes = self.metadata_extractor.process_nodes(all_nodes) - - event.on_end(payload={EventPayload.NODES: all_nodes}) - - return all_nodes diff --git a/llama_index/node_parser/simple_file.py b/llama_index/node_parser/simple_file.py deleted file mode 100644 index 1d07db855a7ebabda6bc25940528d65339e9197f..0000000000000000000000000000000000000000 --- a/llama_index/node_parser/simple_file.py +++ /dev/null @@ -1,98 +0,0 @@ -"""Simple file node parser.""" -from typing import Dict, List, Optional, Sequence, Type - -from llama_index.bridge.pydantic import Field -from llama_index.callbacks.base import CallbackManager -from llama_index.callbacks.schema import CBEventType, EventPayload -from llama_index.node_parser.file.html import HTMLNodeParser -from llama_index.node_parser.file.json import JSONNodeParser -from llama_index.node_parser.file.markdown import MarkdownNodeParser -from llama_index.node_parser.interface import NodeParser -from llama_index.schema import BaseNode, Document -from llama_index.utils import get_tqdm_iterable - -FILE_NODE_PARSERS: Dict[str, Type[NodeParser]] = { - ".md": MarkdownNodeParser, - ".html": HTMLNodeParser, - ".json": JSONNodeParser, -} - - -class SimpleFileNodeParser(NodeParser): - """Simple file node parser. - - Splits a document loaded from a file into Nodes using logic based on the file type - automatically detects the NodeParser to use based on file type - - Args: - include_metadata (bool): whether to include metadata in nodes - include_prev_next_rel (bool): whether to include prev/next relationships - - """ - - include_metadata: bool = Field( - default=True, description="Whether or not to consider metadata when splitting." - ) - include_prev_next_rel: bool = Field( - default=True, description="Include prev/next node relationships." - ) - callback_manager: CallbackManager = Field( - default_factory=CallbackManager, exclude=True - ) - - @classmethod - def from_defaults( - cls, - include_metadata: bool = True, - include_prev_next_rel: bool = True, - callback_manager: Optional[CallbackManager] = None, - ) -> "SimpleFileNodeParser": - callback_manager = callback_manager or CallbackManager([]) - - return cls( - include_metadata=include_metadata, - include_prev_next_rel=include_prev_next_rel, - callback_manager=callback_manager, - ) - - @classmethod - def class_name(cls) -> str: - """Get class name.""" - return "SimpleFileNodeParser" - - def get_nodes_from_documents( - self, - documents: Sequence[Document], - show_progress: bool = False, - ) -> List[BaseNode]: - """Parse document into nodes. - - Args: - documents (Sequence[Document]): documents to parse - """ - with self.callback_manager.event( - CBEventType.NODE_PARSING, payload={EventPayload.DOCUMENTS: documents} - ) as event: - all_nodes: List[BaseNode] = [] - documents_with_progress = get_tqdm_iterable( - documents, show_progress, "Parsing documents into nodes" - ) - - for document in documents_with_progress: - ext = document.metadata["extension"] - if ext in FILE_NODE_PARSERS: - parser = FILE_NODE_PARSERS[ext]( - include_metadata=self.include_metadata, - include_prev_next_rel=self.include_prev_next_rel, - callback_manager=self.callback_manager, - ) - - nodes = parser.get_nodes_from_documents([document], show_progress) - all_nodes.extend(nodes) - else: - # What to do when file type isn't supported yet? - all_nodes.extend(document) - - event.on_end(payload={EventPayload.NODES: all_nodes}) - - return all_nodes diff --git a/llama_index/node_parser/text/__init__.py b/llama_index/node_parser/text/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..13d4287121276caece70b5fe53d59b24ceefeb78 --- /dev/null +++ b/llama_index/node_parser/text/__init__.py @@ -0,0 +1,13 @@ +from llama_index.node_parser.text.code import CodeSplitter +from llama_index.node_parser.text.langchain import LangchainNodeParser +from llama_index.node_parser.text.sentence import SentenceSplitter +from llama_index.node_parser.text.sentence_window import SentenceWindowNodeParser +from llama_index.node_parser.text.token import TokenTextSplitter + +__all__ = [ + "CodeSplitter", + "LangchainNodeParser", + "SentenceSplitter", + "SentenceWindowNodeParser", + "TokenTextSplitter", +] diff --git a/llama_index/text_splitter/code_splitter.py b/llama_index/node_parser/text/code.py similarity index 84% rename from llama_index/text_splitter/code_splitter.py rename to llama_index/node_parser/text/code.py index cc610cc9eb4530aafef7199b7cc391e196929da1..4d89b68ab0530743754e069b4314d21fc30cc8d5 100644 --- a/llama_index/text_splitter/code_splitter.py +++ b/llama_index/node_parser/text/code.py @@ -1,10 +1,9 @@ """Code splitter.""" -from typing import Any, List, Optional +from typing import Any, List from llama_index.bridge.pydantic import Field -from llama_index.callbacks.base import CallbackManager from llama_index.callbacks.schema import CBEventType, EventPayload -from llama_index.text_splitter.types import TextSplitter +from llama_index.node_parser.interface import TextSplitter DEFAULT_CHUNK_LINES = 40 DEFAULT_LINES_OVERLAP = 15 @@ -24,33 +23,32 @@ class CodeSplitter(TextSplitter): chunk_lines: int = Field( default=DEFAULT_CHUNK_LINES, description="The number of lines to include in each chunk.", + gt=0, ) chunk_lines_overlap: int = Field( default=DEFAULT_LINES_OVERLAP, description="How many lines of code each chunk overlaps with.", + gt=0, ) max_chars: int = Field( - default=DEFAULT_MAX_CHARS, description="Maximum number of characters per chunk." - ) - callback_manager: CallbackManager = Field( - default_factory=CallbackManager, exclude=True + default=DEFAULT_MAX_CHARS, + description="Maximum number of characters per chunk.", + gt=0, ) - def __init__( - self, + @classmethod + def from_defaults( + cls, language: str, - chunk_lines: int = 40, - chunk_lines_overlap: int = 15, - max_chars: int = 1500, - callback_manager: Optional[CallbackManager] = None, - ): - callback_manager = callback_manager or CallbackManager([]) - super().__init__( + chunk_lines: int = DEFAULT_CHUNK_LINES, + chunk_lines_overlap: int = DEFAULT_LINES_OVERLAP, + max_chars: int = DEFAULT_MAX_CHARS, + ) -> "CodeSplitter": + return cls( language=language, chunk_lines=chunk_lines, chunk_lines_overlap=chunk_lines_overlap, max_chars=max_chars, - callback_manager=callback_manager, ) @classmethod diff --git a/llama_index/node_parser/text/langchain.py b/llama_index/node_parser/text/langchain.py new file mode 100644 index 0000000000000000000000000000000000000000..5f938fb006c44fdf37552939e56907f000c7fd31 --- /dev/null +++ b/llama_index/node_parser/text/langchain.py @@ -0,0 +1,45 @@ +from typing import TYPE_CHECKING, List, Optional + +from llama_index.bridge.pydantic import PrivateAttr +from llama_index.callbacks import CallbackManager +from llama_index.node_parser.interface import TextSplitter + +if TYPE_CHECKING: + from langchain.text_splitter import TextSplitter as LC_TextSplitter + + +class LangchainNodeParser(TextSplitter): + """ + Basic wrapper around langchain's text splitter. + + TODO: Figure out how to make this metadata aware. + """ + + _lc_splitter: "LC_TextSplitter" = PrivateAttr() + + def __init__( + self, + lc_splitter: "LC_TextSplitter", + callback_manager: Optional[CallbackManager] = None, + include_metadata: bool = True, + include_prev_next_rel: bool = True, + ): + """Initialize with parameters.""" + try: + from langchain.text_splitter import TextSplitter as LC_TextSplitter # noqa + except ImportError: + raise ImportError( + "Could not run `from langchain.text_splitter import TextSplitter`, " + "please run `pip install langchain`" + ) + + super().__init__( + callback_manager=callback_manager or CallbackManager(), + include_metadata=include_metadata, + include_prev_next_rel=include_prev_next_rel, + ) + self._lc_splitter = lc_splitter + + def split_text(self, text: str) -> List[str]: + """Split text into sentences.""" + return self._lc_splitter.split_text(text) diff --git a/llama_index/text_splitter/sentence_splitter.py b/llama_index/node_parser/text/sentence.py similarity index 81% rename from llama_index/text_splitter/sentence_splitter.py rename to llama_index/node_parser/text/sentence.py index d3f7f07b257aa9ec600efee10323fec2cd744199..4ebd6cb5eb5a8978af68ba3c8fa264354936d6d0 100644 --- a/llama_index/text_splitter/sentence_splitter.py +++ b/llama_index/node_parser/text/sentence.py @@ -6,14 +6,14 @@ from llama_index.bridge.pydantic import Field, PrivateAttr from llama_index.callbacks.base import CallbackManager from llama_index.callbacks.schema import CBEventType, EventPayload from llama_index.constants import DEFAULT_CHUNK_SIZE -from llama_index.text_splitter.types import MetadataAwareTextSplitter -from llama_index.text_splitter.utils import ( +from llama_index.node_parser.interface import MetadataAwareTextSplitter +from llama_index.node_parser.text.utils import ( split_by_char, split_by_regex, split_by_sentence_tokenizer, split_by_sep, ) -from llama_index.utils import globals_helper +from llama_index.utils import get_tokenizer SENTENCE_CHUNK_OVERLAP = 200 CHUNKING_REGEX = "[^,.;。?ï¼]+[,.;。?ï¼]?" @@ -27,7 +27,7 @@ class _Split: class SentenceSplitter(MetadataAwareTextSplitter): - """_Split text with a preference for complete sentences. + """Parse text with a preference for complete sentences. In general, this class tries to keep sentences and paragraphs together. Therefore compared to the original TokenTextSplitter, there are less likely to be @@ -35,11 +35,14 @@ class SentenceSplitter(MetadataAwareTextSplitter): """ chunk_size: int = Field( - default=DEFAULT_CHUNK_SIZE, description="The token chunk size for each chunk." + default=DEFAULT_CHUNK_SIZE, + description="The token chunk size for each chunk.", + gt=0, ) chunk_overlap: int = Field( default=SENTENCE_CHUNK_OVERLAP, description="The token overlap of each chunk when splitting.", + gte=0, ) separator: str = Field( default=" ", description="Default separator for splitting into words" @@ -50,22 +53,9 @@ class SentenceSplitter(MetadataAwareTextSplitter): secondary_chunking_regex: str = Field( default=CHUNKING_REGEX, description="Backup regex for splitting into sentences." ) - chunking_tokenizer_fn: Callable[[str], List[str]] = Field( - exclude=True, - description=( - "Function to split text into sentences. " - "Defaults to `nltk.sent_tokenize`." - ), - ) - callback_manager: CallbackManager = Field( - default_factory=CallbackManager, exclude=True - ) - tokenizer: Callable = Field( - default_factory=globals_helper.tokenizer, # type: ignore - description="Tokenizer for splitting words into tokens.", - exclude=True, - ) + _chunking_tokenizer_fn: Callable[[str], List[str]] = PrivateAttr() + _tokenizer: Callable = PrivateAttr() _split_fns: List[Callable] = PrivateAttr() _sub_sentence_split_fns: List[Callable] = PrivateAttr() @@ -79,6 +69,8 @@ class SentenceSplitter(MetadataAwareTextSplitter): chunking_tokenizer_fn: Optional[Callable[[str], List[str]]] = None, secondary_chunking_regex: str = CHUNKING_REGEX, callback_manager: Optional[CallbackManager] = None, + include_metadata: bool = True, + include_prev_next_rel: bool = True, ): """Initialize with parameters.""" if chunk_overlap > chunk_size: @@ -88,12 +80,14 @@ class SentenceSplitter(MetadataAwareTextSplitter): ) callback_manager = callback_manager or CallbackManager([]) - chunking_tokenizer_fn = chunking_tokenizer_fn or split_by_sentence_tokenizer() - tokenizer = tokenizer or globals_helper.tokenizer + self._chunking_tokenizer_fn = ( + chunking_tokenizer_fn or split_by_sentence_tokenizer() + ) + self._tokenizer = tokenizer or get_tokenizer() self._split_fns = [ split_by_sep(paragraph_separator), - chunking_tokenizer_fn, + self._chunking_tokenizer_fn, ] self._sub_sentence_split_fns = [ @@ -105,12 +99,41 @@ class SentenceSplitter(MetadataAwareTextSplitter): super().__init__( chunk_size=chunk_size, chunk_overlap=chunk_overlap, - chunking_tokenizer_fn=chunking_tokenizer_fn, secondary_chunking_regex=secondary_chunking_regex, separator=separator, paragraph_separator=paragraph_separator, callback_manager=callback_manager, + include_metadata=include_metadata, + include_prev_next_rel=include_prev_next_rel, + ) + + @classmethod + def from_defaults( + cls, + separator: str = " ", + chunk_size: int = DEFAULT_CHUNK_SIZE, + chunk_overlap: int = SENTENCE_CHUNK_OVERLAP, + tokenizer: Optional[Callable] = None, + paragraph_separator: str = DEFAULT_PARAGRAPH_SEP, + chunking_tokenizer_fn: Optional[Callable[[str], List[str]]] = None, + secondary_chunking_regex: str = CHUNKING_REGEX, + callback_manager: Optional[CallbackManager] = None, + include_metadata: bool = True, + include_prev_next_rel: bool = True, + ) -> "SentenceSplitter": + """Initialize with parameters.""" + callback_manager = callback_manager or CallbackManager([]) + return cls( + separator=separator, + chunk_size=chunk_size, + chunk_overlap=chunk_overlap, tokenizer=tokenizer, + paragraph_separator=paragraph_separator, + chunking_tokenizer_fn=chunking_tokenizer_fn, + secondary_chunking_regex=secondary_chunking_regex, + callback_manager=callback_manager, + include_metadata=include_metadata, + include_prev_next_rel=include_prev_next_rel, ) @classmethod @@ -118,7 +141,7 @@ class SentenceSplitter(MetadataAwareTextSplitter): return "SentenceSplitter" def split_text_metadata_aware(self, text: str, metadata_str: str) -> List[str]: - metadata_len = len(self.tokenizer(metadata_str)) + metadata_len = len(self._tokenizer(metadata_str)) effective_chunk_size = self.chunk_size - metadata_len if effective_chunk_size <= 0: raise ValueError( @@ -221,7 +244,7 @@ class SentenceSplitter(MetadataAwareTextSplitter): while len(splits) > 0: cur_split = splits[0] - cur_split_len = len(self.tokenizer(cur_split.text)) + cur_split_len = len(self._tokenizer(cur_split.text)) if cur_split_len > chunk_size: raise ValueError("Single token exceeded chunk size") if cur_chunk_len + cur_split_len > chunk_size and not new_chunk: @@ -263,7 +286,7 @@ class SentenceSplitter(MetadataAwareTextSplitter): return new_chunks def _token_size(self, text: str) -> int: - return len(self.tokenizer(text)) + return len(self._tokenizer(text)) def _get_splits_by_fns(self, text: str) -> Tuple[List[str], bool]: for split_fn in self._split_fns: diff --git a/llama_index/node_parser/sentence_window.py b/llama_index/node_parser/text/sentence_window.py similarity index 55% rename from llama_index/node_parser/sentence_window.py rename to llama_index/node_parser/text/sentence_window.py index 54cabb93fccd9f1c0719e1a568e9f883f3104788..dd7a5749372db079061ebcb586edf21692443162 100644 --- a/llama_index/node_parser/sentence_window.py +++ b/llama_index/node_parser/text/sentence_window.py @@ -1,14 +1,12 @@ """Simple node parser.""" -from typing import Callable, List, Optional, Sequence +from typing import Any, Callable, List, Optional, Sequence from llama_index.bridge.pydantic import Field from llama_index.callbacks.base import CallbackManager -from llama_index.callbacks.schema import CBEventType, EventPayload -from llama_index.node_parser.extractors.metadata_extractors import MetadataExtractor from llama_index.node_parser.interface import NodeParser from llama_index.node_parser.node_utils import build_nodes_from_splits -from llama_index.schema import BaseNode, Document -from llama_index.text_splitter.utils import split_by_sentence_tokenizer +from llama_index.node_parser.text.utils import split_by_sentence_tokenizer +from llama_index.schema import BaseNode, Document, MetadataMode from llama_index.utils import get_tqdm_iterable DEFAULT_WINDOW_SIZE = 3 @@ -36,6 +34,7 @@ class SentenceWindowNodeParser(NodeParser): window_size: int = Field( default=DEFAULT_WINDOW_SIZE, description="The number of sentences on each side of a sentence to capture.", + gt=0, ) window_metadata_key: str = Field( default=DEFAULT_WINDOW_METADATA_KEY, @@ -45,53 +44,11 @@ class SentenceWindowNodeParser(NodeParser): default=DEFAULT_OG_TEXT_METADATA_KEY, description="The metadata key to store the original sentence in.", ) - include_metadata: bool = Field( - default=True, description="Whether or not to consider metadata when splitting." - ) - include_prev_next_rel: bool = Field( - default=True, description="Include prev/next node relationships." - ) - metadata_extractor: Optional[MetadataExtractor] = Field( - default=None, description="Metadata extraction pipeline to apply to nodes." - ) - callback_manager: CallbackManager = Field( - default_factory=CallbackManager, exclude=True - ) - - def __init__( - self, - sentence_splitter: Optional[Callable[[str], List[str]]] = None, - window_size: int = DEFAULT_WINDOW_SIZE, - window_metadata_key: str = DEFAULT_WINDOW_METADATA_KEY, - original_text_metadata_key: str = DEFAULT_OG_TEXT_METADATA_KEY, - include_metadata: bool = True, - include_prev_next_rel: bool = True, - callback_manager: Optional[CallbackManager] = None, - metadata_extractor: Optional[MetadataExtractor] = None, - ) -> None: - """Init params.""" - callback_manager = callback_manager or CallbackManager([]) - sentence_splitter = sentence_splitter or split_by_sentence_tokenizer() - super().__init__( - sentence_splitter=sentence_splitter, - window_size=window_size, - window_metadata_key=window_metadata_key, - original_text_metadata_key=original_text_metadata_key, - include_metadata=include_metadata, - include_prev_next_rel=include_prev_next_rel, - callback_manager=callback_manager, - metadata_extractor=metadata_extractor, - ) @classmethod def class_name(cls) -> str: return "SentenceWindowNodeParser" - @property - def text_splitter(self) -> Callable[[str], List[str]]: - """Get text splitter.""" - return self.sentence_splitter - @classmethod def from_defaults( cls, @@ -102,7 +59,6 @@ class SentenceWindowNodeParser(NodeParser): include_metadata: bool = True, include_prev_next_rel: bool = True, callback_manager: Optional[CallbackManager] = None, - metadata_extractor: Optional[MetadataExtractor] = None, ) -> "SentenceWindowNodeParser": callback_manager = callback_manager or CallbackManager([]) @@ -116,38 +72,22 @@ class SentenceWindowNodeParser(NodeParser): include_metadata=include_metadata, include_prev_next_rel=include_prev_next_rel, callback_manager=callback_manager, - metadata_extractor=metadata_extractor, ) - def get_nodes_from_documents( + def _parse_nodes( self, - documents: Sequence[Document], + nodes: Sequence[BaseNode], show_progress: bool = False, + **kwargs: Any, ) -> List[BaseNode]: - """Parse document into nodes. - - Args: - documents (Sequence[Document]): documents to parse - include_metadata (bool): whether to include metadata in nodes - - """ - with self.callback_manager.event( - CBEventType.NODE_PARSING, payload={EventPayload.DOCUMENTS: documents} - ) as event: - all_nodes: List[BaseNode] = [] - documents_with_progress = get_tqdm_iterable( - documents, show_progress, "Parsing documents into nodes" - ) - - for document in documents_with_progress: - self.sentence_splitter(document.text) - nodes = self.build_window_nodes_from_documents([document]) - all_nodes.extend(nodes) - - if self.metadata_extractor is not None: - all_nodes = self.metadata_extractor.process_nodes(all_nodes) + """Parse document into nodes.""" + all_nodes: List[BaseNode] = [] + nodes_with_progress = get_tqdm_iterable(nodes, show_progress, "Parsing nodes") - event.on_end(payload={EventPayload.NODES: all_nodes}) + for node in nodes_with_progress: + self.sentence_splitter(node.get_content(metadata_mode=MetadataMode.NONE)) + nodes = self.build_window_nodes_from_documents([node]) + all_nodes.extend(nodes) return all_nodes @@ -160,7 +100,8 @@ class SentenceWindowNodeParser(NodeParser): text = doc.text text_splits = self.sentence_splitter(text) nodes = build_nodes_from_splits( - text_splits, doc, include_prev_next_rel=True + text_splits, + doc, ) # add window to each node diff --git a/llama_index/text_splitter/token_splitter.py b/llama_index/node_parser/text/token.py similarity index 78% rename from llama_index/text_splitter/token_splitter.py rename to llama_index/node_parser/text/token.py index 07c59bdd4c537c86c9a8a92c6dde83df82d96b45..95491bc5d52920eb29d6ddf7f7451b009420cfe9 100644 --- a/llama_index/text_splitter/token_splitter.py +++ b/llama_index/node_parser/text/token.py @@ -6,9 +6,9 @@ from llama_index.bridge.pydantic import Field, PrivateAttr from llama_index.callbacks.base import CallbackManager from llama_index.callbacks.schema import CBEventType, EventPayload from llama_index.constants import DEFAULT_CHUNK_OVERLAP, DEFAULT_CHUNK_SIZE -from llama_index.text_splitter.types import MetadataAwareTextSplitter -from llama_index.text_splitter.utils import split_by_char, split_by_sep -from llama_index.utils import globals_helper +from llama_index.node_parser.interface import MetadataAwareTextSplitter +from llama_index.node_parser.text.utils import split_by_char, split_by_sep +from llama_index.utils import get_tokenizer _logger = logging.getLogger(__name__) @@ -20,11 +20,14 @@ class TokenTextSplitter(MetadataAwareTextSplitter): """Implementation of splitting text that looks at word tokens.""" chunk_size: int = Field( - default=DEFAULT_CHUNK_SIZE, description="The token chunk size for each chunk." + default=DEFAULT_CHUNK_SIZE, + description="The token chunk size for each chunk.", + gt=0, ) chunk_overlap: int = Field( default=DEFAULT_CHUNK_OVERLAP, description="The token overlap of each chunk when splitting.", + gte=0, ) separator: str = Field( default=" ", description="Default separator for splitting into words" @@ -32,15 +35,8 @@ class TokenTextSplitter(MetadataAwareTextSplitter): backup_separators: List = Field( default_factory=list, description="Additional separators for splitting." ) - callback_manager: CallbackManager = Field( - default_factory=CallbackManager, exclude=True - ) - tokenizer: Callable = Field( - default_factory=globals_helper.tokenizer, # type: ignore - description="Tokenizer for splitting words into tokens.", - exclude=True, - ) + _tokenizer: Callable = PrivateAttr() _split_fns: List[Callable] = PrivateAttr() def __init__( @@ -51,6 +47,8 @@ class TokenTextSplitter(MetadataAwareTextSplitter): callback_manager: Optional[CallbackManager] = None, separator: str = " ", backup_separators: Optional[List[str]] = ["\n"], + include_metadata: bool = True, + include_prev_next_rel: bool = True, ): """Initialize with parameters.""" if chunk_overlap > chunk_size: @@ -59,7 +57,8 @@ class TokenTextSplitter(MetadataAwareTextSplitter): f"({chunk_size}), should be smaller." ) callback_manager = callback_manager or CallbackManager([]) - tokenizer = tokenizer or globals_helper.tokenizer + + self._tokenizer = tokenizer or get_tokenizer() all_seps = [separator] + (backup_separators or []) self._split_fns = [split_by_sep(sep) for sep in all_seps] + [split_by_char()] @@ -70,7 +69,31 @@ class TokenTextSplitter(MetadataAwareTextSplitter): separator=separator, backup_separators=backup_separators, callback_manager=callback_manager, - tokenizer=tokenizer, + include_metadata=include_metadata, + include_prev_next_rel=include_prev_next_rel, + ) + + @classmethod + def from_defaults( + cls, + chunk_size: int = DEFAULT_CHUNK_SIZE, + chunk_overlap: int = DEFAULT_CHUNK_OVERLAP, + separator: str = " ", + backup_separators: Optional[List[str]] = ["\n"], + callback_manager: Optional[CallbackManager] = None, + include_metadata: bool = True, + include_prev_next_rel: bool = True, + ) -> "TokenTextSplitter": + """Initialize with default parameters.""" + callback_manager = callback_manager or CallbackManager([]) + return cls( + chunk_size=chunk_size, + chunk_overlap=chunk_overlap, + separator=separator, + backup_separators=backup_separators, + callback_manager=callback_manager, + include_metadata=include_metadata, + include_prev_next_rel=include_prev_next_rel, ) @classmethod @@ -79,7 +102,7 @@ class TokenTextSplitter(MetadataAwareTextSplitter): def split_text_metadata_aware(self, text: str, metadata_str: str) -> List[str]: """Split text into chunks, reserving space required for metadata str.""" - metadata_len = len(self.tokenizer(metadata_str)) + DEFAULT_METADATA_FORMAT_LEN + metadata_len = len(self._tokenizer(metadata_str)) + DEFAULT_METADATA_FORMAT_LEN effective_chunk_size = self.chunk_size - metadata_len if effective_chunk_size <= 0: raise ValueError( @@ -129,7 +152,7 @@ class TokenTextSplitter(MetadataAwareTextSplitter): NOTE: the splits contain the separators. """ - if len(self.tokenizer(text)) <= chunk_size: + if len(self._tokenizer(text)) <= chunk_size: return [text] for split_fn in self._split_fns: @@ -139,7 +162,7 @@ class TokenTextSplitter(MetadataAwareTextSplitter): new_splits = [] for split in splits: - split_len = len(self.tokenizer(split)) + split_len = len(self._tokenizer(split)) if split_len <= chunk_size: new_splits.append(split) else: @@ -161,7 +184,7 @@ class TokenTextSplitter(MetadataAwareTextSplitter): cur_chunk: List[str] = [] cur_len = 0 for split in splits: - split_len = len(self.tokenizer(split)) + split_len = len(self._tokenizer(split)) if split_len > chunk_size: _logger.warning( f"Got a split of size {split_len}, ", @@ -183,7 +206,7 @@ class TokenTextSplitter(MetadataAwareTextSplitter): while cur_len > self.chunk_overlap or cur_len + split_len > chunk_size: # pop off the first element first_chunk = cur_chunk.pop(0) - cur_len -= len(self.tokenizer(first_chunk)) + cur_len -= len(self._tokenizer(first_chunk)) cur_chunk.append(split) cur_len += split_len diff --git a/llama_index/text_splitter/utils.py b/llama_index/node_parser/text/utils.py similarity index 96% rename from llama_index/text_splitter/utils.py rename to llama_index/node_parser/text/utils.py index e959d723f1428a07626c2529eb4582d5271e6ec2..1f581c43c369757af2ffcf5eceecebc28a1408cb 100644 --- a/llama_index/text_splitter/utils.py +++ b/llama_index/node_parser/text/utils.py @@ -1,7 +1,9 @@ import logging from typing import Callable, List -from llama_index.text_splitter.types import TextSplitter +from llama_index.node_parser.interface import TextSplitter + +logger = logging.getLogger(__name__) logger = logging.getLogger(__name__) diff --git a/llama_index/objects/base.py b/llama_index/objects/base.py index 7730eb1186bf8eb3b8476c0ad49e5f0a553295c6..ea47bad2ad5dde0468a8f7fc4ed1a0febd4d4d42 100644 --- a/llama_index/objects/base.py +++ b/llama_index/objects/base.py @@ -2,14 +2,14 @@ from typing import Any, Generic, List, Optional, Sequence, Type, TypeVar +from llama_index.core import BaseRetriever from llama_index.indices.base import BaseIndex -from llama_index.indices.base_retriever import BaseRetriever -from llama_index.indices.query.schema import QueryType from llama_index.indices.vector_store.base import VectorStoreIndex from llama_index.objects.base_node_mapping import ( BaseObjectNodeMapping, SimpleObjectNodeMapping, ) +from llama_index.schema import QueryType OT = TypeVar("OT") diff --git a/llama_index/output_parsers/guardrails.py b/llama_index/output_parsers/guardrails.py index 6555f4b0a0c335bb8ef7e50aaa26c16efc7280d9..406ab904ee1e4019816085c027b453399eccefb3 100644 --- a/llama_index/output_parsers/guardrails.py +++ b/llama_index/output_parsers/guardrails.py @@ -12,13 +12,14 @@ except ImportError: PromptCallable = None from copy import deepcopy -from typing import Any, Callable, Optional +from typing import TYPE_CHECKING, Any, Callable, Optional -from llama_index.bridge.langchain import BaseLLM +if TYPE_CHECKING: + from llama_index.bridge.langchain import BaseLLM from llama_index.types import BaseOutputParser -def get_callable(llm: Optional[BaseLLM]) -> Optional[Callable]: +def get_callable(llm: Optional["BaseLLM"]) -> Optional[Callable]: """Get callable.""" if llm is None: return None @@ -32,7 +33,7 @@ class GuardrailsOutputParser(BaseOutputParser): def __init__( self, guard: Guard, - llm: Optional[BaseLLM] = None, + llm: Optional["BaseLLM"] = None, format_key: Optional[str] = None, ): """Initialize a Guardrails output parser.""" @@ -43,7 +44,7 @@ class GuardrailsOutputParser(BaseOutputParser): @classmethod @deprecated(version="0.8.46") def from_rail( - cls, rail: str, llm: Optional[BaseLLM] = None + cls, rail: str, llm: Optional["BaseLLM"] = None ) -> "GuardrailsOutputParser": """From rail.""" if Guard is None: @@ -56,7 +57,7 @@ class GuardrailsOutputParser(BaseOutputParser): @classmethod @deprecated(version="0.8.46") def from_rail_string( - cls, rail_string: str, llm: Optional[BaseLLM] = None + cls, rail_string: str, llm: Optional["BaseLLM"] = None ) -> "GuardrailsOutputParser": """From rail string.""" if Guard is None: @@ -69,7 +70,7 @@ class GuardrailsOutputParser(BaseOutputParser): def parse( self, output: str, - llm: Optional[BaseLLM] = None, + llm: Optional["BaseLLM"] = None, num_reasks: Optional[int] = 1, *args: Any, **kwargs: Any diff --git a/llama_index/output_parsers/langchain.py b/llama_index/output_parsers/langchain.py index a60629a71de50d808e179ab3d8e0ab4da409f60e..a8a86c67d49d9776cdc4caf0af3077385ca848fd 100644 --- a/llama_index/output_parsers/langchain.py +++ b/llama_index/output_parsers/langchain.py @@ -1,9 +1,10 @@ """Base output parser class.""" from string import Formatter -from typing import Any, Optional +from typing import TYPE_CHECKING, Any, Optional -from llama_index.bridge.langchain import BaseOutputParser as LCOutputParser +if TYPE_CHECKING: + from llama_index.bridge.langchain import BaseOutputParser as LCOutputParser from llama_index.types import BaseOutputParser @@ -11,7 +12,7 @@ class LangchainOutputParser(BaseOutputParser): """Langchain output parser.""" def __init__( - self, output_parser: LCOutputParser, format_key: Optional[str] = None + self, output_parser: "LCOutputParser", format_key: Optional[str] = None ) -> None: """Init params.""" self._output_parser = output_parser diff --git a/llama_index/indices/postprocessor/__init__.py b/llama_index/postprocessor/__init__.py similarity index 58% rename from llama_index/indices/postprocessor/__init__.py rename to llama_index/postprocessor/__init__.py index d32240833d52aa0fdee154f92ab889fcf1f1fa63..1e46caf9cb2b8daccb56df4675c586865dd70673 100644 --- a/llama_index/indices/postprocessor/__init__.py +++ b/llama_index/postprocessor/__init__.py @@ -1,30 +1,30 @@ """Node PostProcessor module.""" -from llama_index.indices.postprocessor.cohere_rerank import CohereRerank -from llama_index.indices.postprocessor.llm_rerank import LLMRerank -from llama_index.indices.postprocessor.longllmlingua import LongLLMLinguaPostprocessor -from llama_index.indices.postprocessor.metadata_replacement import ( +from llama_index.postprocessor.cohere_rerank import CohereRerank +from llama_index.postprocessor.llm_rerank import LLMRerank +from llama_index.postprocessor.longllmlingua import LongLLMLinguaPostprocessor +from llama_index.postprocessor.metadata_replacement import ( MetadataReplacementPostProcessor, ) -from llama_index.indices.postprocessor.node import ( +from llama_index.postprocessor.node import ( AutoPrevNextNodePostprocessor, KeywordNodePostprocessor, LongContextReorder, PrevNextNodePostprocessor, SimilarityPostprocessor, ) -from llama_index.indices.postprocessor.node_recency import ( +from llama_index.postprocessor.node_recency import ( EmbeddingRecencyPostprocessor, FixedRecencyPostprocessor, TimeWeightedPostprocessor, ) -from llama_index.indices.postprocessor.optimizer import SentenceEmbeddingOptimizer -from llama_index.indices.postprocessor.pii import ( +from llama_index.postprocessor.optimizer import SentenceEmbeddingOptimizer +from llama_index.postprocessor.pii import ( NERPIINodePostprocessor, PIINodePostprocessor, ) -from llama_index.indices.postprocessor.sbert_rerank import SentenceTransformerRerank +from llama_index.postprocessor.sbert_rerank import SentenceTransformerRerank __all__ = [ "SimilarityPostprocessor", diff --git a/llama_index/indices/postprocessor/cohere_rerank.py b/llama_index/postprocessor/cohere_rerank.py similarity index 93% rename from llama_index/indices/postprocessor/cohere_rerank.py rename to llama_index/postprocessor/cohere_rerank.py index 68f3fcb93f308c4c04db9e48dd5f83cec267b2e1..950b1715834caa4f8c86ea361e90fe7f787baea9 100644 --- a/llama_index/indices/postprocessor/cohere_rerank.py +++ b/llama_index/postprocessor/cohere_rerank.py @@ -3,9 +3,8 @@ from typing import Any, List, Optional from llama_index.bridge.pydantic import Field, PrivateAttr from llama_index.callbacks import CBEventType, EventPayload -from llama_index.indices.postprocessor.types import BaseNodePostprocessor -from llama_index.indices.query.schema import QueryBundle -from llama_index.schema import NodeWithScore +from llama_index.postprocessor.types import BaseNodePostprocessor +from llama_index.schema import NodeWithScore, QueryBundle class CohereRerank(BaseNodePostprocessor): diff --git a/llama_index/indices/postprocessor/llm_rerank.py b/llama_index/postprocessor/llm_rerank.py similarity index 94% rename from llama_index/indices/postprocessor/llm_rerank.py rename to llama_index/postprocessor/llm_rerank.py index 6baa56df0a162d032dbf0f5764eb16a1a821e421..000d10aad71d4797c2b18ccb5f095accb22a7329 100644 --- a/llama_index/indices/postprocessor/llm_rerank.py +++ b/llama_index/postprocessor/llm_rerank.py @@ -2,17 +2,16 @@ from typing import Callable, List, Optional from llama_index.bridge.pydantic import Field, PrivateAttr -from llama_index.indices.postprocessor.types import BaseNodePostprocessor -from llama_index.indices.query.schema import QueryBundle -from llama_index.indices.service_context import ServiceContext from llama_index.indices.utils import ( default_format_node_batch_fn, default_parse_choice_select_answer_fn, ) +from llama_index.postprocessor.types import BaseNodePostprocessor from llama_index.prompts import BasePromptTemplate from llama_index.prompts.default_prompts import DEFAULT_CHOICE_SELECT_PROMPT from llama_index.prompts.mixin import PromptDictType -from llama_index.schema import NodeWithScore +from llama_index.schema import NodeWithScore, QueryBundle +from llama_index.service_context import ServiceContext class LLMRerank(BaseNodePostprocessor): diff --git a/llama_index/indices/postprocessor/longllmlingua.py b/llama_index/postprocessor/longllmlingua.py similarity index 95% rename from llama_index/indices/postprocessor/longllmlingua.py rename to llama_index/postprocessor/longllmlingua.py index e15f0794d51f6233bd0e73919650f620e9d0653d..e6379760f8d595c5129b0cdd8718170a55d67db0 100644 --- a/llama_index/indices/postprocessor/longllmlingua.py +++ b/llama_index/postprocessor/longllmlingua.py @@ -3,9 +3,8 @@ import logging from typing import Any, Dict, List, Optional from llama_index.bridge.pydantic import Field, PrivateAttr -from llama_index.indices.postprocessor.types import BaseNodePostprocessor -from llama_index.indices.query.schema import QueryBundle -from llama_index.schema import MetadataMode, NodeWithScore, TextNode +from llama_index.postprocessor.types import BaseNodePostprocessor +from llama_index.schema import MetadataMode, NodeWithScore, QueryBundle, TextNode logger = logging.getLogger(__name__) diff --git a/llama_index/indices/postprocessor/metadata_replacement.py b/llama_index/postprocessor/metadata_replacement.py similarity index 82% rename from llama_index/indices/postprocessor/metadata_replacement.py rename to llama_index/postprocessor/metadata_replacement.py index 0840aaea456a987ace21d51b78e252dcf995bfbd..82513a9396587f1c30bbf4519ee94aa484c26487 100644 --- a/llama_index/indices/postprocessor/metadata_replacement.py +++ b/llama_index/postprocessor/metadata_replacement.py @@ -1,9 +1,8 @@ from typing import List, Optional from llama_index.bridge.pydantic import Field -from llama_index.indices.postprocessor.types import BaseNodePostprocessor -from llama_index.indices.query.schema import QueryBundle -from llama_index.schema import MetadataMode, NodeWithScore +from llama_index.postprocessor.types import BaseNodePostprocessor +from llama_index.schema import MetadataMode, NodeWithScore, QueryBundle class MetadataReplacementPostProcessor(BaseNodePostprocessor): diff --git a/llama_index/indices/postprocessor/node.py b/llama_index/postprocessor/node.py similarity index 98% rename from llama_index/indices/postprocessor/node.py rename to llama_index/postprocessor/node.py index 3bbb1b39154516ba042b19ac2f1d8b4e4d935df7..a0b8d9b66b0d0c173d9fb7956e3df559655e6f91 100644 --- a/llama_index/indices/postprocessor/node.py +++ b/llama_index/postprocessor/node.py @@ -4,12 +4,11 @@ import logging from typing import Dict, List, Optional, cast from llama_index.bridge.pydantic import Field, validator -from llama_index.indices.postprocessor.types import BaseNodePostprocessor -from llama_index.indices.query.schema import QueryBundle -from llama_index.indices.service_context import ServiceContext +from llama_index.postprocessor.types import BaseNodePostprocessor from llama_index.prompts.base import PromptTemplate from llama_index.response_synthesizers import ResponseMode, get_response_synthesizer -from llama_index.schema import NodeRelationship, NodeWithScore +from llama_index.schema import NodeRelationship, NodeWithScore, QueryBundle +from llama_index.service_context import ServiceContext from llama_index.storage.docstore import BaseDocumentStore logger = logging.getLogger(__name__) diff --git a/llama_index/indices/postprocessor/node_recency.py b/llama_index/postprocessor/node_recency.py similarity index 96% rename from llama_index/indices/postprocessor/node_recency.py rename to llama_index/postprocessor/node_recency.py index 7d8c0dba32b1678e4a7f80e1e1523f8e07a0fefe..55c3bab45658e3bc6f480e5eaff2b4ed691ea36a 100644 --- a/llama_index/indices/postprocessor/node_recency.py +++ b/llama_index/postprocessor/node_recency.py @@ -6,10 +6,9 @@ import numpy as np import pandas as pd from llama_index.bridge.pydantic import Field -from llama_index.indices.postprocessor.types import BaseNodePostprocessor -from llama_index.indices.query.schema import QueryBundle -from llama_index.indices.service_context import ServiceContext -from llama_index.schema import MetadataMode, NodeWithScore +from llama_index.postprocessor.types import BaseNodePostprocessor +from llama_index.schema import MetadataMode, NodeWithScore, QueryBundle +from llama_index.service_context import ServiceContext # NOTE: currently not being used # DEFAULT_INFER_RECENCY_TMPL = ( diff --git a/llama_index/indices/postprocessor/optimizer.py b/llama_index/postprocessor/optimizer.py similarity index 96% rename from llama_index/indices/postprocessor/optimizer.py rename to llama_index/postprocessor/optimizer.py index b706c4b3c1d4523befe733b538631b10f8884f7f..b5b80fe4961751ceb424d83f888036a94e6335fa 100644 --- a/llama_index/indices/postprocessor/optimizer.py +++ b/llama_index/postprocessor/optimizer.py @@ -5,10 +5,9 @@ from typing import Callable, List, Optional from llama_index.bridge.pydantic import Field, PrivateAttr from llama_index.embeddings.base import BaseEmbedding from llama_index.embeddings.openai import OpenAIEmbedding -from llama_index.indices.postprocessor.types import BaseNodePostprocessor from llama_index.indices.query.embedding_utils import get_top_k_embeddings -from llama_index.indices.query.schema import QueryBundle -from llama_index.schema import MetadataMode, NodeWithScore +from llama_index.postprocessor.types import BaseNodePostprocessor +from llama_index.schema import MetadataMode, NodeWithScore, QueryBundle logger = logging.getLogger(__name__) diff --git a/llama_index/indices/postprocessor/pii.py b/llama_index/postprocessor/pii.py similarity index 95% rename from llama_index/indices/postprocessor/pii.py rename to llama_index/postprocessor/pii.py index 3ce47fcb7e1cc48bc031cbe070dfb0f3db54de64..83eb2d6ada309848c786bdcb92b5a35b2a37a4a1 100644 --- a/llama_index/indices/postprocessor/pii.py +++ b/llama_index/postprocessor/pii.py @@ -3,11 +3,10 @@ import json from copy import deepcopy from typing import Callable, Dict, List, Optional, Tuple -from llama_index.indices.postprocessor.types import BaseNodePostprocessor -from llama_index.indices.query.schema import QueryBundle -from llama_index.indices.service_context import ServiceContext +from llama_index.postprocessor.types import BaseNodePostprocessor from llama_index.prompts.base import PromptTemplate -from llama_index.schema import MetadataMode, NodeWithScore +from llama_index.schema import MetadataMode, NodeWithScore, QueryBundle +from llama_index.service_context import ServiceContext DEFAULT_PII_TMPL = ( "The current context information is provided. \n" diff --git a/llama_index/indices/postprocessor/sbert_rerank.py b/llama_index/postprocessor/sbert_rerank.py similarity index 92% rename from llama_index/indices/postprocessor/sbert_rerank.py rename to llama_index/postprocessor/sbert_rerank.py index 73e00bb4cf496195e2795635cb58054878437c5d..ea65c0e0846818104ff58ab4a5975fdc5ec4c7f5 100644 --- a/llama_index/indices/postprocessor/sbert_rerank.py +++ b/llama_index/postprocessor/sbert_rerank.py @@ -2,9 +2,8 @@ from typing import Any, List, Optional from llama_index.bridge.pydantic import Field, PrivateAttr from llama_index.callbacks import CBEventType, EventPayload -from llama_index.indices.postprocessor.types import BaseNodePostprocessor -from llama_index.indices.query.schema import QueryBundle -from llama_index.schema import MetadataMode, NodeWithScore +from llama_index.postprocessor.types import BaseNodePostprocessor +from llama_index.schema import MetadataMode, NodeWithScore, QueryBundle DEFAULT_SENTENCE_TRANSFORMER_MAX_LENGTH = 512 diff --git a/llama_index/indices/postprocessor/types.py b/llama_index/postprocessor/types.py similarity index 93% rename from llama_index/indices/postprocessor/types.py rename to llama_index/postprocessor/types.py index 1179e0da811648575838f38c2e66a23495371918..abdcef3ee8d89adf841e135b2711997f94679028 100644 --- a/llama_index/indices/postprocessor/types.py +++ b/llama_index/postprocessor/types.py @@ -3,9 +3,8 @@ from typing import List, Optional from llama_index.bridge.pydantic import Field from llama_index.callbacks import CallbackManager -from llama_index.indices.query.schema import QueryBundle from llama_index.prompts.mixin import PromptDictType, PromptMixinType -from llama_index.schema import BaseComponent, NodeWithScore +from llama_index.schema import BaseComponent, NodeWithScore, QueryBundle class BaseNodePostprocessor(BaseComponent, ABC): diff --git a/llama_index/program/predefined/evaporate/base.py b/llama_index/program/predefined/evaporate/base.py index b9796fd952a3e5e689fd9f10f9fd1878e0bf8b29..832a2327411d44b2138d273a68f99d547ec6be58 100644 --- a/llama_index/program/predefined/evaporate/base.py +++ b/llama_index/program/predefined/evaporate/base.py @@ -4,7 +4,6 @@ from typing import Any, Dict, Generic, List, Optional, Type import pandas as pd -from llama_index.indices.service_context import ServiceContext from llama_index.program.predefined.df import ( DataFrameRow, DataFrameRowsOnly, @@ -18,6 +17,7 @@ from llama_index.program.predefined.evaporate.prompts import ( SchemaIDPrompt, ) from llama_index.schema import BaseNode, TextNode +from llama_index.service_context import ServiceContext from llama_index.types import BasePydanticProgram, Model from llama_index.utils import print_text diff --git a/llama_index/program/predefined/evaporate/extractor.py b/llama_index/program/predefined/evaporate/extractor.py index bb3c3c15d282e00d1b6a7b379ff9f0d8e4a89807..bf8afb741141d1b7d6b6398874502e103973baa2 100644 --- a/llama_index/program/predefined/evaporate/extractor.py +++ b/llama_index/program/predefined/evaporate/extractor.py @@ -5,8 +5,6 @@ from collections import defaultdict from contextlib import contextmanager from typing import Any, Dict, List, Optional, Set, Tuple -from llama_index.indices.query.schema import QueryBundle -from llama_index.indices.service_context import ServiceContext from llama_index.program.predefined.evaporate.prompts import ( DEFAULT_EXPECTED_OUTPUT_PREFIX_TMPL, DEFAULT_FIELD_EXTRACT_QUERY_TMPL, @@ -15,7 +13,8 @@ from llama_index.program.predefined.evaporate.prompts import ( FnGeneratePrompt, SchemaIDPrompt, ) -from llama_index.schema import BaseNode, MetadataMode, NodeWithScore +from llama_index.schema import BaseNode, MetadataMode, NodeWithScore, QueryBundle +from llama_index.service_context import ServiceContext class TimeoutException(Exception): diff --git a/llama_index/prompts/base.py b/llama_index/prompts/base.py index 828310ac66d4b93ae00ffbecb6e30c81dbed419d..50690067385f1a016cfad1bbfb6df7cf554d4d92 100644 --- a/llama_index/prompts/base.py +++ b/llama_index/prompts/base.py @@ -3,17 +3,18 @@ from abc import ABC, abstractmethod from copy import deepcopy -from typing import Any, Callable, Dict, List, Optional, Tuple +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple from pydantic import Field -from llama_index.bridge.langchain import BasePromptTemplate as LangchainTemplate -from llama_index.bridge.langchain import ConditionalPromptSelector as LangchainSelector +if TYPE_CHECKING: + from llama_index.bridge.langchain import BasePromptTemplate as LangchainTemplate + from llama_index.bridge.langchain import ( + ConditionalPromptSelector as LangchainSelector, + ) from llama_index.bridge.pydantic import BaseModel from llama_index.llms.base import LLM, ChatMessage from llama_index.llms.generic_utils import messages_to_prompt, prompt_to_messages -from llama_index.llms.langchain import LangChainLLM -from llama_index.llms.langchain_utils import from_lc_messages from llama_index.prompts.prompt_type import PromptType from llama_index.prompts.utils import get_template_vars from llama_index.types import BaseOutputParser @@ -273,7 +274,7 @@ class SelectorPromptTemplate(BasePromptTemplate): output_parser=output_parser, ) - def _select(self, llm: Optional[LLM] = None) -> BasePromptTemplate: + def select(self, llm: Optional[LLM] = None) -> BasePromptTemplate: # ensure output parser is up to date self.default_template.output_parser = self.output_parser @@ -304,29 +305,29 @@ class SelectorPromptTemplate(BasePromptTemplate): def format(self, llm: Optional[LLM] = None, **kwargs: Any) -> str: """Format the prompt into a string.""" - prompt = self._select(llm=llm) + prompt = self.select(llm=llm) return prompt.format(**kwargs) def format_messages( self, llm: Optional[LLM] = None, **kwargs: Any ) -> List[ChatMessage]: """Format the prompt into a list of chat messages.""" - prompt = self._select(llm=llm) + prompt = self.select(llm=llm) return prompt.format_messages(**kwargs) def get_template(self, llm: Optional[LLM] = None) -> str: - prompt = self._select(llm=llm) + prompt = self.select(llm=llm) return prompt.get_template(llm=llm) class LangchainPromptTemplate(BasePromptTemplate): - selector: LangchainSelector + selector: Any requires_langchain_llm: bool = False def __init__( self, - template: Optional[LangchainTemplate] = None, - selector: Optional[LangchainSelector] = None, + template: Optional["LangchainTemplate"] = None, + selector: Optional["LangchainSelector"] = None, output_parser: Optional[BaseOutputParser] = None, prompt_type: str = PromptType.CUSTOM, metadata: Optional[Dict[str, Any]] = None, @@ -334,6 +335,14 @@ class LangchainPromptTemplate(BasePromptTemplate): function_mappings: Optional[Dict[str, Callable]] = None, requires_langchain_llm: bool = False, ) -> None: + try: + from llama_index.bridge.langchain import ( + ConditionalPromptSelector as LangchainSelector, + ) + except ImportError: + raise ImportError( + "Must install `llama_index[langchain]` to use LangchainPromptTemplate." + ) if selector is None: if template is None: raise ValueError("Must provide either template or selector.") @@ -363,6 +372,10 @@ class LangchainPromptTemplate(BasePromptTemplate): def partial_format(self, **kwargs: Any) -> "BasePromptTemplate": """Partially format the prompt.""" + from llama_index.bridge.langchain import ( + ConditionalPromptSelector as LangchainSelector, + ) + mapped_kwargs = self._map_all_vars(kwargs) default_prompt = self.selector.default_prompt.partial(**mapped_kwargs) conditionals = [ @@ -380,6 +393,8 @@ class LangchainPromptTemplate(BasePromptTemplate): def format(self, llm: Optional[LLM] = None, **kwargs: Any) -> str: """Format the prompt into a string.""" + from llama_index.llms.langchain import LangChainLLM + if llm is not None: # if llamaindex LLM is provided, and we require a langchain LLM, # then error. but otherwise if `requires_langchain_llm` is False, @@ -401,6 +416,9 @@ class LangchainPromptTemplate(BasePromptTemplate): self, llm: Optional[LLM] = None, **kwargs: Any ) -> List[ChatMessage]: """Format the prompt into a list of chat messages.""" + from llama_index.llms.langchain import LangChainLLM + from llama_index.llms.langchain_utils import from_lc_messages + if llm is not None: # if llamaindex LLM is provided, and we require a langchain LLM, # then error. but otherwise if `requires_langchain_llm` is False, @@ -421,6 +439,8 @@ class LangchainPromptTemplate(BasePromptTemplate): return from_lc_messages(lc_messages) def get_template(self, llm: Optional[LLM] = None) -> str: + from llama_index.llms.langchain import LangChainLLM + if llm is not None: # if llamaindex LLM is provided, and we require a langchain LLM, # then error. but otherwise if `requires_langchain_llm` is False, diff --git a/llama_index/query_engine/__init__.py b/llama_index/query_engine/__init__.py index 4885b4b2bf17939c41a7a1879e918902a5af85f4..8aa6632eb1ca1777526b40c04bc92363cd867259 100644 --- a/llama_index/query_engine/__init__.py +++ b/llama_index/query_engine/__init__.py @@ -1,4 +1,4 @@ -from llama_index.indices.query.base import BaseQueryEngine +from llama_index.core import BaseQueryEngine # SQL from llama_index.indices.struct_store.sql_query import ( diff --git a/llama_index/query_engine/citation_query_engine.py b/llama_index/query_engine/citation_query_engine.py index b91aa09405c44f6102e3831651ba885bcd82d295..2268a866eb6639b89737ca58f594a2246d16390f 100644 --- a/llama_index/query_engine/citation_query_engine.py +++ b/llama_index/query_engine/citation_query_engine.py @@ -2,11 +2,10 @@ from typing import Any, List, Optional, Sequence from llama_index.callbacks.base import CallbackManager from llama_index.callbacks.schema import CBEventType, EventPayload +from llama_index.core import BaseQueryEngine, BaseRetriever from llama_index.indices.base import BaseGPTIndex -from llama_index.indices.base_retriever import BaseRetriever -from llama_index.indices.postprocessor.types import BaseNodePostprocessor -from llama_index.indices.query.base import BaseQueryEngine -from llama_index.indices.query.schema import QueryBundle +from llama_index.node_parser import SentenceSplitter, TextSplitter +from llama_index.postprocessor.types import BaseNodePostprocessor from llama_index.prompts import PromptTemplate from llama_index.prompts.base import BasePromptTemplate from llama_index.prompts.mixin import PromptMixinType @@ -16,9 +15,7 @@ from llama_index.response_synthesizers import ( ResponseMode, get_response_synthesizer, ) -from llama_index.schema import MetadataMode, NodeWithScore, TextNode -from llama_index.text_splitter import get_default_text_splitter -from llama_index.text_splitter.types import TextSplitter +from llama_index.schema import MetadataMode, NodeWithScore, QueryBundle, TextNode CITATION_QA_TEMPLATE = PromptTemplate( "Please provide an answer based solely on the provided sources. " @@ -86,7 +83,7 @@ class CitationQueryEngine(BaseQueryEngine): Size of citation chunks, default=512. Useful for controlling granularity of sources. citation_chunk_overlap (int): Overlap of citation nodes, default=20. - text_splitter (Optional[TextSplitterType]): + text_splitter (Optional[TextSplitter]): A text splitter for creating citation source nodes. Default is a SentenceSplitter. callback_manager (Optional[CallbackManager]): A callback manager. @@ -105,7 +102,7 @@ class CitationQueryEngine(BaseQueryEngine): callback_manager: Optional[CallbackManager] = None, metadata_mode: MetadataMode = MetadataMode.NONE, ) -> None: - self.text_splitter = text_splitter or get_default_text_splitter( + self.text_splitter = text_splitter or SentenceSplitter( chunk_size=citation_chunk_size, chunk_overlap=citation_chunk_overlap ) self._retriever = retriever diff --git a/llama_index/query_engine/cogniswitch_query_engine.py b/llama_index/query_engine/cogniswitch_query_engine.py index baaa369b2de8b4dabb62be96137e98ce0237b3f4..072c0512f1f51682b06ba1a867d71e59a04cf2f6 100644 --- a/llama_index/query_engine/cogniswitch_query_engine.py +++ b/llama_index/query_engine/cogniswitch_query_engine.py @@ -2,9 +2,9 @@ from typing import Any, Dict import requests -from llama_index.indices.query.base import BaseQueryEngine -from llama_index.indices.query.schema import QueryBundle +from llama_index.core import BaseQueryEngine from llama_index.response.schema import Response +from llama_index.schema import QueryBundle class CogniswitchQueryEngine(BaseQueryEngine): diff --git a/llama_index/query_engine/custom.py b/llama_index/query_engine/custom.py index 67ff04bd0c3a7aabab63a259fc1e6c94dbf5ab20..fd6f1915ab6370dfb10e9f8c19c93341792a87cd 100644 --- a/llama_index/query_engine/custom.py +++ b/llama_index/query_engine/custom.py @@ -7,10 +7,10 @@ from pydantic import BaseModel from llama_index.bridge.pydantic import Field from llama_index.callbacks.base import CallbackManager -from llama_index.indices.query.base import BaseQueryEngine -from llama_index.indices.query.schema import QueryBundle, QueryType +from llama_index.core import BaseQueryEngine from llama_index.prompts.mixin import PromptMixinType from llama_index.response.schema import RESPONSE_TYPE, Response +from llama_index.schema import QueryBundle, QueryType STR_OR_RESPONSE_TYPE = Union[Response, str] diff --git a/llama_index/query_engine/flare/answer_inserter.py b/llama_index/query_engine/flare/answer_inserter.py index 9a2c7893eaf16b84f1fbc57458448714fff407ee..1d996e61090ec5d728012715fd78799b4a8b225b 100644 --- a/llama_index/query_engine/flare/answer_inserter.py +++ b/llama_index/query_engine/flare/answer_inserter.py @@ -3,10 +3,10 @@ from abc import abstractmethod from typing import Any, Dict, List, Optional -from llama_index.indices.service_context import ServiceContext from llama_index.prompts.base import BasePromptTemplate, PromptTemplate from llama_index.prompts.mixin import PromptDictType, PromptMixin, PromptMixinType from llama_index.query_engine.flare.schema import QueryTask +from llama_index.service_context import ServiceContext class BaseLookaheadAnswerInserter(PromptMixin): diff --git a/llama_index/query_engine/flare/base.py b/llama_index/query_engine/flare/base.py index 169ee9f6cb787d788e231611338e85dd48386f0f..bd429b75a5b9965ba172c5c6a2dabbbcf88f75f2 100644 --- a/llama_index/query_engine/flare/base.py +++ b/llama_index/query_engine/flare/base.py @@ -7,9 +7,7 @@ Active Retrieval Augmented Generation. from typing import Any, Dict, Optional from llama_index.callbacks.base import CallbackManager -from llama_index.indices.query.base import BaseQueryEngine -from llama_index.indices.query.schema import QueryBundle -from llama_index.indices.service_context import ServiceContext +from llama_index.core import BaseQueryEngine from llama_index.prompts.base import BasePromptTemplate, PromptTemplate from llama_index.prompts.mixin import PromptDictType, PromptMixinType from llama_index.query_engine.flare.answer_inserter import ( @@ -21,6 +19,8 @@ from llama_index.query_engine.flare.output_parser import ( QueryTaskOutputParser, ) from llama_index.response.schema import RESPONSE_TYPE, Response +from llama_index.schema import QueryBundle +from llama_index.service_context import ServiceContext from llama_index.utils import print_text # These prompts are taken from the FLARE repo: diff --git a/llama_index/query_engine/graph_query_engine.py b/llama_index/query_engine/graph_query_engine.py index f325af8b16ad3b5cae1026a328e609cb259d7273..98b594724e1c121c5c19de28f569e4b3317430b6 100644 --- a/llama_index/query_engine/graph_query_engine.py +++ b/llama_index/query_engine/graph_query_engine.py @@ -1,11 +1,10 @@ from typing import Any, Dict, List, Optional, Tuple from llama_index.callbacks.schema import CBEventType, EventPayload +from llama_index.core import BaseQueryEngine from llama_index.indices.composability.graph import ComposableGraph -from llama_index.indices.query.base import BaseQueryEngine -from llama_index.indices.query.schema import QueryBundle from llama_index.response.schema import RESPONSE_TYPE -from llama_index.schema import IndexNode, NodeWithScore, TextNode +from llama_index.schema import IndexNode, NodeWithScore, QueryBundle, TextNode class ComposableGraphQueryEngine(BaseQueryEngine): diff --git a/llama_index/query_engine/knowledge_graph_query_engine.py b/llama_index/query_engine/knowledge_graph_query_engine.py index f0ecba74f0e39d43ce07d347fb5d707f7e8f1cbe..afee41f0f53e3397d8e35c78ca8cee2a460c81ac 100644 --- a/llama_index/query_engine/knowledge_graph_query_engine.py +++ b/llama_index/query_engine/knowledge_graph_query_engine.py @@ -4,18 +4,17 @@ import logging from typing import Any, Dict, List, Optional, Sequence from llama_index.callbacks.schema import CBEventType, EventPayload +from llama_index.core import BaseQueryEngine from llama_index.graph_stores.registry import ( GRAPH_STORE_CLASS_TO_GRAPH_STORE_TYPE, GraphStoreType, ) -from llama_index.indices.query.base import BaseQueryEngine -from llama_index.indices.query.schema import QueryBundle -from llama_index.indices.service_context import ServiceContext from llama_index.prompts.base import BasePromptTemplate, PromptTemplate, PromptType from llama_index.prompts.mixin import PromptDictType, PromptMixinType from llama_index.response.schema import RESPONSE_TYPE from llama_index.response_synthesizers import BaseSynthesizer, get_response_synthesizer -from llama_index.schema import NodeWithScore, TextNode +from llama_index.schema import NodeWithScore, QueryBundle, TextNode +from llama_index.service_context import ServiceContext from llama_index.storage.storage_context import StorageContext from llama_index.utils import print_text diff --git a/llama_index/query_engine/multi_modal.py b/llama_index/query_engine/multi_modal.py index 004695d236a673c4c3de91b59562898e3837a256..305eb25ca0fd933deae23c92a9f873d2a281303e 100644 --- a/llama_index/query_engine/multi_modal.py +++ b/llama_index/query_engine/multi_modal.py @@ -3,11 +3,11 @@ from typing import Any, Dict, List, Optional, Sequence, Tuple from llama_index.callbacks.base import CallbackManager from llama_index.callbacks.schema import CBEventType, EventPayload from llama_index.indices.multi_modal import MultiModalVectorIndexRetriever -from llama_index.indices.postprocessor.types import BaseNodePostprocessor from llama_index.indices.query.base import BaseQueryEngine from llama_index.indices.query.schema import QueryBundle from llama_index.multi_modal_llms.base import MultiModalLLM from llama_index.multi_modal_llms.openai import OpenAIMultiModal +from llama_index.postprocessor.types import BaseNodePostprocessor from llama_index.prompts import BasePromptTemplate from llama_index.prompts.default_prompts import DEFAULT_TEXT_QA_PROMPT from llama_index.prompts.mixin import PromptMixinType diff --git a/llama_index/query_engine/multistep_query_engine.py b/llama_index/query_engine/multistep_query_engine.py index f106c8f3a654d4e1ad3096d5ac375eea42d10b69..fc875ce65ad6c7de1d95db699388a215b7b78df7 100644 --- a/llama_index/query_engine/multistep_query_engine.py +++ b/llama_index/query_engine/multistep_query_engine.py @@ -1,13 +1,12 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, cast from llama_index.callbacks.schema import CBEventType, EventPayload -from llama_index.indices.query.base import BaseQueryEngine +from llama_index.core import BaseQueryEngine from llama_index.indices.query.query_transform.base import StepDecomposeQueryTransform -from llama_index.indices.query.schema import QueryBundle from llama_index.prompts.mixin import PromptMixinType from llama_index.response.schema import RESPONSE_TYPE from llama_index.response_synthesizers import BaseSynthesizer, get_response_synthesizer -from llama_index.schema import NodeWithScore, TextNode +from llama_index.schema import NodeWithScore, QueryBundle, TextNode def default_stop_fn(stop_dict: Dict) -> bool: diff --git a/llama_index/query_engine/pandas_query_engine.py b/llama_index/query_engine/pandas_query_engine.py index dd2f2c69e36c00f675e8c0a2e74439af6ca7b06c..961fcc9ac51790dc881d78310e7bb3671e5e2c5e 100644 --- a/llama_index/query_engine/pandas_query_engine.py +++ b/llama_index/query_engine/pandas_query_engine.py @@ -13,14 +13,14 @@ from typing import Any, Callable, Optional import numpy as np import pandas as pd -from llama_index.indices.query.base import BaseQueryEngine -from llama_index.indices.query.schema import QueryBundle -from llama_index.indices.service_context import ServiceContext +from llama_index.core import BaseQueryEngine from llama_index.indices.struct_store.pandas import PandasIndex from llama_index.prompts import BasePromptTemplate from llama_index.prompts.default_prompts import DEFAULT_PANDAS_PROMPT from llama_index.prompts.mixin import PromptMixinType from llama_index.response.schema import Response +from llama_index.schema import QueryBundle +from llama_index.service_context import ServiceContext from llama_index.utils import print_text logger = logging.getLogger(__name__) diff --git a/llama_index/query_engine/retriever_query_engine.py b/llama_index/query_engine/retriever_query_engine.py index 3ab4561f2befd8566e2cb2be36d18c848bf2d2be..937663b38add16d8757f669ab0c67e69305a3626 100644 --- a/llama_index/query_engine/retriever_query_engine.py +++ b/llama_index/query_engine/retriever_query_engine.py @@ -3,11 +3,8 @@ from typing import Any, List, Optional, Sequence from llama_index.bridge.pydantic import BaseModel from llama_index.callbacks.base import CallbackManager from llama_index.callbacks.schema import CBEventType, EventPayload -from llama_index.indices.base_retriever import BaseRetriever -from llama_index.indices.postprocessor.types import BaseNodePostprocessor -from llama_index.indices.query.base import BaseQueryEngine -from llama_index.indices.query.schema import QueryBundle -from llama_index.indices.service_context import ServiceContext +from llama_index.core import BaseQueryEngine, BaseRetriever +from llama_index.postprocessor.types import BaseNodePostprocessor from llama_index.prompts import BasePromptTemplate from llama_index.prompts.mixin import PromptMixinType from llama_index.response.schema import RESPONSE_TYPE @@ -16,7 +13,8 @@ from llama_index.response_synthesizers import ( ResponseMode, get_response_synthesizer, ) -from llama_index.schema import NodeWithScore +from llama_index.schema import NodeWithScore, QueryBundle +from llama_index.service_context import ServiceContext class RetrieverQueryEngine(BaseQueryEngine): diff --git a/llama_index/query_engine/retry_query_engine.py b/llama_index/query_engine/retry_query_engine.py index e9d6eb35fb42e12c9e231ecda3d4c91d747a1ec8..7a7b20fdb87827d12281e67087c4dcc95277bb9b 100644 --- a/llama_index/query_engine/retry_query_engine.py +++ b/llama_index/query_engine/retry_query_engine.py @@ -2,15 +2,15 @@ import logging from typing import Optional from llama_index.callbacks.base import CallbackManager +from llama_index.core import BaseQueryEngine from llama_index.evaluation.base import BaseEvaluator from llama_index.evaluation.guideline import GuidelineEvaluator -from llama_index.indices.query.base import BaseQueryEngine from llama_index.indices.query.query_transform.feedback_transform import ( FeedbackQueryTransformation, ) -from llama_index.indices.query.schema import QueryBundle from llama_index.prompts.mixin import PromptMixinType from llama_index.response.schema import RESPONSE_TYPE, Response +from llama_index.schema import QueryBundle logger = logging.getLogger(__name__) diff --git a/llama_index/query_engine/retry_source_query_engine.py b/llama_index/query_engine/retry_source_query_engine.py index 3e3661930e3d7745ad10b15b79999b7e78e4974c..13be39f1377702cae21d46dc6d36a274af37ba41 100644 --- a/llama_index/query_engine/retry_source_query_engine.py +++ b/llama_index/query_engine/retry_source_query_engine.py @@ -2,15 +2,14 @@ import logging from typing import Optional from llama_index.callbacks.base import CallbackManager +from llama_index.core import BaseQueryEngine from llama_index.evaluation import BaseEvaluator from llama_index.indices.list.base import SummaryIndex -from llama_index.indices.query.base import BaseQueryEngine -from llama_index.indices.query.schema import QueryBundle -from llama_index.indices.service_context import ServiceContext from llama_index.prompts.mixin import PromptMixinType from llama_index.query_engine.retriever_query_engine import RetrieverQueryEngine from llama_index.response.schema import RESPONSE_TYPE, Response -from llama_index.schema import Document +from llama_index.schema import Document, QueryBundle +from llama_index.service_context import ServiceContext logger = logging.getLogger(__name__) diff --git a/llama_index/query_engine/router_query_engine.py b/llama_index/query_engine/router_query_engine.py index 5ecfdf1991ba7232f1a7a5e5a5362e4fae480749..04242a6d091d0f9a8a84f56106a29c246cebd0ed 100644 --- a/llama_index/query_engine/router_query_engine.py +++ b/llama_index/query_engine/router_query_engine.py @@ -5,10 +5,7 @@ from llama_index.async_utils import run_async_tasks from llama_index.bridge.pydantic import BaseModel from llama_index.callbacks.base import CallbackManager from llama_index.callbacks.schema import CBEventType, EventPayload -from llama_index.indices.base_retriever import BaseRetriever -from llama_index.indices.query.base import BaseQueryEngine -from llama_index.indices.query.schema import QueryBundle -from llama_index.indices.service_context import ServiceContext +from llama_index.core import BaseQueryEngine, BaseRetriever from llama_index.objects.base import ObjectRetriever from llama_index.prompts.default_prompt_selectors import ( DEFAULT_TREE_SUMMARIZE_PROMPT_SEL, @@ -21,9 +18,10 @@ from llama_index.response.schema import ( StreamingResponse, ) from llama_index.response_synthesizers import TreeSummarize -from llama_index.schema import BaseNode +from llama_index.schema import BaseNode, QueryBundle from llama_index.selectors.types import BaseSelector from llama_index.selectors.utils import get_selector_from_context +from llama_index.service_context import ServiceContext from llama_index.tools.query_engine import QueryEngineTool from llama_index.tools.types import ToolMetadata diff --git a/llama_index/query_engine/sql_join_query_engine.py b/llama_index/query_engine/sql_join_query_engine.py index a0433c7d09a75fb108b78922c78c01631e5bec49..30d06ee2b55d8e60e4f8966455caec06d353cd6f 100644 --- a/llama_index/query_engine/sql_join_query_engine.py +++ b/llama_index/query_engine/sql_join_query_engine.py @@ -4,10 +4,8 @@ import logging from typing import Callable, Dict, Optional, Union from llama_index.callbacks.base import CallbackManager -from llama_index.indices.query.base import BaseQueryEngine +from llama_index.core import BaseQueryEngine from llama_index.indices.query.query_transform.base import BaseQueryTransform -from llama_index.indices.query.schema import QueryBundle -from llama_index.indices.service_context import ServiceContext from llama_index.indices.struct_store.sql_query import ( BaseSQLTableQueryEngine, NLSQLTableQueryEngine, @@ -17,9 +15,11 @@ from llama_index.llm_predictor.base import BaseLLMPredictor from llama_index.prompts.base import BasePromptTemplate, PromptTemplate from llama_index.prompts.mixin import PromptDictType, PromptMixinType from llama_index.response.schema import RESPONSE_TYPE, Response +from llama_index.schema import QueryBundle from llama_index.selectors.llm_selectors import LLMSingleSelector from llama_index.selectors.pydantic_selectors import PydanticSingleSelector from llama_index.selectors.utils import get_selector_from_context +from llama_index.service_context import ServiceContext from llama_index.tools.query_engine import QueryEngineTool from llama_index.utils import print_text diff --git a/llama_index/query_engine/sql_vector_query_engine.py b/llama_index/query_engine/sql_vector_query_engine.py index 230670366b60236ba0a7526ebaf9a3b984562933..0f75178426d1c963ac32d4ed81af9188e84e8394 100644 --- a/llama_index/query_engine/sql_vector_query_engine.py +++ b/llama_index/query_engine/sql_vector_query_engine.py @@ -4,7 +4,6 @@ import logging from typing import Any, Optional, Union from llama_index.callbacks.base import CallbackManager -from llama_index.indices.service_context import ServiceContext from llama_index.indices.struct_store.sql_query import ( BaseSQLTableQueryEngine, NLSQLTableQueryEngine, @@ -21,6 +20,7 @@ from llama_index.query_engine.sql_join_query_engine import ( ) from llama_index.selectors.llm_selectors import LLMSingleSelector from llama_index.selectors.pydantic_selectors import PydanticSingleSelector +from llama_index.service_context import ServiceContext from llama_index.tools.query_engine import QueryEngineTool logger = logging.getLogger(__name__) diff --git a/llama_index/query_engine/sub_question_query_engine.py b/llama_index/query_engine/sub_question_query_engine.py index ef8adf932147484e92761923820807b46aca1950..d564520384ce4d73baa645b1e6229c65170acd4f 100644 --- a/llama_index/query_engine/sub_question_query_engine.py +++ b/llama_index/query_engine/sub_question_query_engine.py @@ -6,16 +6,15 @@ from llama_index.async_utils import run_async_tasks from llama_index.bridge.pydantic import BaseModel, Field from llama_index.callbacks.base import CallbackManager from llama_index.callbacks.schema import CBEventType, EventPayload -from llama_index.indices.query.base import BaseQueryEngine -from llama_index.indices.query.schema import QueryBundle -from llama_index.indices.service_context import ServiceContext +from llama_index.core import BaseQueryEngine from llama_index.prompts.mixin import PromptMixinType from llama_index.question_gen.llm_generators import LLMQuestionGenerator from llama_index.question_gen.openai_generator import OpenAIQuestionGenerator from llama_index.question_gen.types import BaseQuestionGenerator, SubQuestion from llama_index.response.schema import RESPONSE_TYPE from llama_index.response_synthesizers import BaseSynthesizer, get_response_synthesizer -from llama_index.schema import NodeWithScore, TextNode +from llama_index.schema import NodeWithScore, QueryBundle, TextNode +from llama_index.service_context import ServiceContext from llama_index.tools.query_engine import QueryEngineTool from llama_index.utils import get_color_mapping, print_text diff --git a/llama_index/query_engine/transform_query_engine.py b/llama_index/query_engine/transform_query_engine.py index fffdadf1453b22f2266d26e5f3b07fe81da1da5f..219d8ecf7e9b117ea23d61f4a1e2476cdd0e5d78 100644 --- a/llama_index/query_engine/transform_query_engine.py +++ b/llama_index/query_engine/transform_query_engine.py @@ -1,12 +1,11 @@ from typing import List, Optional, Sequence from llama_index.callbacks.base import CallbackManager -from llama_index.indices.query.base import BaseQueryEngine +from llama_index.core import BaseQueryEngine from llama_index.indices.query.query_transform.base import BaseQueryTransform -from llama_index.indices.query.schema import QueryBundle from llama_index.prompts.mixin import PromptMixinType from llama_index.response.schema import RESPONSE_TYPE -from llama_index.schema import NodeWithScore +from llama_index.schema import NodeWithScore, QueryBundle class TransformQueryEngine(BaseQueryEngine): diff --git a/llama_index/question_gen/guidance_generator.py b/llama_index/question_gen/guidance_generator.py index d78732bd4abcab81897077f5b4fdcdb3044252e8..e031ec903f149e2cb98b7df7b2cda6a1705a8f34 100644 --- a/llama_index/question_gen/guidance_generator.py +++ b/llama_index/question_gen/guidance_generator.py @@ -1,6 +1,5 @@ from typing import TYPE_CHECKING, List, Optional, Sequence, cast -from llama_index.indices.query.schema import QueryBundle from llama_index.program.guidance_program import GuidancePydanticProgram from llama_index.prompts.guidance_utils import convert_to_handlebars from llama_index.prompts.mixin import PromptDictType @@ -13,6 +12,7 @@ from llama_index.question_gen.types import ( SubQuestion, SubQuestionList, ) +from llama_index.schema import QueryBundle from llama_index.tools.types import ToolMetadata if TYPE_CHECKING: diff --git a/llama_index/question_gen/llm_generators.py b/llama_index/question_gen/llm_generators.py index 63a7501edb16cc3877942cd27f31e2aafa92267a..18b68fd9fafa3ba6c2a0351fd2b0188a017bc080 100644 --- a/llama_index/question_gen/llm_generators.py +++ b/llama_index/question_gen/llm_generators.py @@ -1,7 +1,5 @@ from typing import List, Optional, Sequence, cast -from llama_index.indices.query.schema import QueryBundle -from llama_index.indices.service_context import ServiceContext from llama_index.llm_predictor.base import BaseLLMPredictor from llama_index.output_parsers.base import StructuredOutput from llama_index.prompts.base import BasePromptTemplate, PromptTemplate @@ -13,6 +11,8 @@ from llama_index.question_gen.prompts import ( build_tools_text, ) from llama_index.question_gen.types import BaseQuestionGenerator, SubQuestion +from llama_index.schema import QueryBundle +from llama_index.service_context import ServiceContext from llama_index.tools.types import ToolMetadata from llama_index.types import BaseOutputParser diff --git a/llama_index/question_gen/openai_generator.py b/llama_index/question_gen/openai_generator.py index 62118a88e8e855a7b9c786b09a6f79ec2469789c..c461d32c1c0aefa56ab08064b7708cdc2881a632 100644 --- a/llama_index/question_gen/openai_generator.py +++ b/llama_index/question_gen/openai_generator.py @@ -1,6 +1,5 @@ from typing import List, Optional, Sequence, cast -from llama_index.indices.query.schema import QueryBundle from llama_index.llms.base import LLM from llama_index.llms.openai import OpenAI from llama_index.program.openai_program import OpenAIPydanticProgram @@ -11,6 +10,7 @@ from llama_index.question_gen.types import ( SubQuestion, SubQuestionList, ) +from llama_index.schema import QueryBundle from llama_index.tools.types import ToolMetadata DEFAULT_MODEL_NAME = "gpt-3.5-turbo-0613" diff --git a/llama_index/question_gen/types.py b/llama_index/question_gen/types.py index 0cf3cc7fd8335f5e07972b6556b2b18ada4eafa0..b673447c715d4c18cb4e8951917850f07dfc792f 100644 --- a/llama_index/question_gen/types.py +++ b/llama_index/question_gen/types.py @@ -2,8 +2,8 @@ from abc import abstractmethod from typing import List, Sequence from llama_index.bridge.pydantic import BaseModel -from llama_index.indices.query.schema import QueryBundle from llama_index.prompts.mixin import PromptMixin, PromptMixinType +from llama_index.schema import QueryBundle from llama_index.tools.types import ToolMetadata diff --git a/llama_index/readers/__init__.py b/llama_index/readers/__init__.py index 1487e9664bc77bcd9d1122d42c4a234bdd2b68e2..e90cd2cb5f760d1ca2517aa95584ac83d3faa5aa 100644 --- a/llama_index/readers/__init__.py +++ b/llama_index/readers/__init__.py @@ -11,6 +11,7 @@ definition of a Document - the bare minimum is a `text` property. """ from llama_index.readers.bagel import BagelReader +from llama_index.readers.base import ReaderConfig from llama_index.readers.chatgpt_plugin import ChatGPTRetrievalPluginReader from llama_index.readers.chroma import ChromaReader from llama_index.readers.dashvector import DashVectorReader @@ -91,6 +92,7 @@ __all__ = [ "ChatGPTRetrievalPluginReader", "BagelReader", "HTMLTagReader", + "ReaderConfig", "PDFReader", "DashVectorReader", "download_loader", diff --git a/llama_index/readers/base.py b/llama_index/readers/base.py index 7fe1fd795ea628466d2b691ffd80dcfdce3199f3..4c56c5eb93b35c392fd8fa4cc03e1f2825fb3dd5 100644 --- a/llama_index/readers/base.py +++ b/llama_index/readers/base.py @@ -1,8 +1,9 @@ """Base reader class.""" from abc import ABC -from typing import Any, Dict, Iterable, List +from typing import TYPE_CHECKING, Any, Dict, Iterable, List -from llama_index.bridge.langchain import Document as LCDocument +if TYPE_CHECKING: + from llama_index.bridge.langchain import Document as LCDocument from llama_index.bridge.pydantic import Field from llama_index.schema import BaseComponent, Document @@ -20,7 +21,7 @@ class BaseReader(ABC): """Load data from the input directory.""" return list(self.lazy_load_data(*args, **load_kwargs)) - def load_langchain_documents(self, **load_kwargs: Any) -> List[LCDocument]: + def load_langchain_documents(self, **load_kwargs: Any) -> List["LCDocument"]: """Load data in LangChain document format.""" docs = self.load_data(**load_kwargs) return [d.to_langchain_format() for d in docs] @@ -39,12 +40,12 @@ class BasePydanticReader(BaseReader, BaseComponent): class ReaderConfig(BaseComponent): - """Represents a loader and it's input arguments.""" + """Represents a reader and it's input arguments.""" - loader: BaseReader = Field(..., description="Loader to use.") - loader_args: List[Any] = Field(default_factor=list, description="Loader args.") - loader_kwargs: Dict[str, Any] = Field( - default_factory=dict, description="Loader kwargs." + reader: BasePydanticReader = Field(..., description="Reader to use.") + reader_args: List[Any] = Field(default_factory=list, description="Reader args.") + reader_kwargs: Dict[str, Any] = Field( + default_factory=dict, description="Reader kwargs." ) class Config: @@ -52,4 +53,18 @@ class ReaderConfig(BaseComponent): @classmethod def class_name(cls) -> str: - return "LoaderConfig" + """Get the name identifier of the class.""" + return "ReaderConfig" + + def to_dict(self, **kwargs: Any) -> Dict[str, Any]: + """Convert the class to a dictionary.""" + return { + "loader": self.reader.to_dict(**kwargs), + "reader_args": self.reader_args, + "reader_kwargs": self.reader_kwargs, + "class_name": self.class_name(), + } + + def read(self) -> List[Document]: + """Call the loader with the given arguments.""" + return self.reader.load_data(*self.reader_args, **self.reader_kwargs) diff --git a/llama_index/readers/file/base.py b/llama_index/readers/file/base.py index 27278d1d2d764754b776f2e86fc49cda62ecff8f..1df39de12537bf39ba9eee07f067ccf2afe8074b 100644 --- a/llama_index/readers/file/base.py +++ b/llama_index/readers/file/base.py @@ -1,5 +1,6 @@ """Simple reader that reads files of different formats from a directory.""" import logging +import mimetypes import os from datetime import datetime from pathlib import Path @@ -45,6 +46,9 @@ def default_file_metadata_func(file_path: str) -> Dict: """ return { "file_path": file_path, + "file_name": os.path.basename(file_path), + "file_type": mimetypes.guess_type(file_path)[0], + "file_size": os.path.getsize(file_path), "creation_date": datetime.fromtimestamp( Path(file_path).stat().st_ctime ).strftime("%Y-%m-%d"), @@ -270,6 +274,9 @@ class SimpleDirectoryReader(BaseReader): # TimeWeightedPostprocessor, but excluded for embedding and LLMprompts doc.excluded_embed_metadata_keys.extend( [ + "file_name", + "file_type", + "file_size", "creation_date", "last_modified_date", "last_accessed_date", @@ -277,6 +284,9 @@ class SimpleDirectoryReader(BaseReader): ) doc.excluded_llm_metadata_keys.extend( [ + "file_name", + "file_type", + "file_size", "creation_date", "last_modified_date", "last_accessed_date", diff --git a/llama_index/readers/loading.py b/llama_index/readers/loading.py index 30c7eb0b857915c7a93347973660129de508211f..ff1cdbcac899e538736e3f1a0f4bf30ef7da5b94 100644 --- a/llama_index/readers/loading.py +++ b/llama_index/readers/loading.py @@ -37,6 +37,8 @@ ALL_READERS: Dict[str, Type[BasePydanticReader]] = { def load_reader(data: Dict[str, Any]) -> BasePydanticReader: + if isinstance(data, BasePydanticReader): + return data class_name = data.get("class_name", None) if class_name is None: raise ValueError("Must specify `class_name` in reader data.") diff --git a/llama_index/readers/obsidian.py b/llama_index/readers/obsidian.py index bc21896904af0028ad705a3fe00d5aeb59efbfa9..b5b5b2d5673f5c17ed79f35c31c58418b945119f 100644 --- a/llama_index/readers/obsidian.py +++ b/llama_index/readers/obsidian.py @@ -9,7 +9,6 @@ import os from pathlib import Path from typing import Any, List -from llama_index.bridge.langchain import Document as LCDocument from llama_index.readers.base import BaseReader from llama_index.readers.file.markdown_reader import MarkdownReader from llama_index.schema import Document @@ -38,8 +37,3 @@ class ObsidianReader(BaseReader): content = MarkdownReader().load_data(Path(filepath)) docs.extend(content) return docs - - def load_langchain_documents(self, **load_kwargs: Any) -> List[LCDocument]: - """Load data in LangChain document format.""" - docs = self.load_data(**load_kwargs) - return [d.to_langchain_format() for d in docs] diff --git a/llama_index/response_synthesizers/accumulate.py b/llama_index/response_synthesizers/accumulate.py index 550a5e33f79a345c0d9a004a3ddec5417730f83d..b56fc21f2da4b24f16161d4fb4943f2d9c46812b 100644 --- a/llama_index/response_synthesizers/accumulate.py +++ b/llama_index/response_synthesizers/accumulate.py @@ -2,13 +2,13 @@ import asyncio from typing import Any, List, Optional, Sequence from llama_index.async_utils import run_async_tasks -from llama_index.indices.service_context import ServiceContext from llama_index.prompts import BasePromptTemplate from llama_index.prompts.default_prompt_selectors import ( DEFAULT_TEXT_QA_PROMPT_SEL, ) from llama_index.prompts.mixin import PromptDictType from llama_index.response_synthesizers.base import BaseSynthesizer +from llama_index.service_context import ServiceContext from llama_index.types import RESPONSE_TEXT_TYPE diff --git a/llama_index/response_synthesizers/base.py b/llama_index/response_synthesizers/base.py index 4b72edb67371e8bb05717193492e4413a65600a5..9f77d4a5871b8013e4f5d35004f85e8b939a3e46 100644 --- a/llama_index/response_synthesizers/base.py +++ b/llama_index/response_synthesizers/base.py @@ -13,8 +13,6 @@ from typing import Any, Dict, Generator, List, Optional, Sequence, Union from llama_index.bridge.pydantic import BaseModel from llama_index.callbacks.schema import CBEventType, EventPayload -from llama_index.indices.query.schema import QueryBundle -from llama_index.indices.service_context import ServiceContext from llama_index.prompts.mixin import PromptMixin from llama_index.response.schema import ( RESPONSE_TYPE, @@ -22,7 +20,8 @@ from llama_index.response.schema import ( Response, StreamingResponse, ) -from llama_index.schema import BaseNode, MetadataMode, NodeWithScore +from llama_index.schema import BaseNode, MetadataMode, NodeWithScore, QueryBundle +from llama_index.service_context import ServiceContext from llama_index.types import RESPONSE_TEXT_TYPE logger = logging.getLogger(__name__) diff --git a/llama_index/response_synthesizers/factory.py b/llama_index/response_synthesizers/factory.py index 9164f2148f083aae5ee9783b3ac2e26eb097373d..bbfbaea6560fae75faacb831e8ba9550bfac556f 100644 --- a/llama_index/response_synthesizers/factory.py +++ b/llama_index/response_synthesizers/factory.py @@ -2,7 +2,6 @@ from typing import Callable, Optional from llama_index.bridge.pydantic import BaseModel from llama_index.callbacks.base import CallbackManager -from llama_index.indices.service_context import ServiceContext from llama_index.prompts import BasePromptTemplate from llama_index.prompts.default_prompt_selectors import ( DEFAULT_REFINE_PROMPT_SEL, @@ -23,6 +22,7 @@ from llama_index.response_synthesizers.refine import Refine from llama_index.response_synthesizers.simple_summarize import SimpleSummarize from llama_index.response_synthesizers.tree_summarize import TreeSummarize from llama_index.response_synthesizers.type import ResponseMode +from llama_index.service_context import ServiceContext from llama_index.types import BasePydanticProgram diff --git a/llama_index/response_synthesizers/generation.py b/llama_index/response_synthesizers/generation.py index 5ea5f96aed6b20b6eb4d15b116a1bc7c9750297e..825c282f75c8e188d9e834bee6dcbf87ca584915 100644 --- a/llama_index/response_synthesizers/generation.py +++ b/llama_index/response_synthesizers/generation.py @@ -1,10 +1,10 @@ from typing import Any, Optional, Sequence -from llama_index.indices.service_context import ServiceContext from llama_index.prompts import BasePromptTemplate from llama_index.prompts.default_prompts import DEFAULT_SIMPLE_INPUT_PROMPT from llama_index.prompts.mixin import PromptDictType from llama_index.response_synthesizers.base import BaseSynthesizer +from llama_index.service_context import ServiceContext from llama_index.types import RESPONSE_TEXT_TYPE diff --git a/llama_index/response_synthesizers/refine.py b/llama_index/response_synthesizers/refine.py index 06ab3e051ea57231251366b78949c4c1c410b30f..b031a758edfbef5dd063df5934c2b98c7c9ae243 100644 --- a/llama_index/response_synthesizers/refine.py +++ b/llama_index/response_synthesizers/refine.py @@ -2,7 +2,6 @@ import logging from typing import Any, Callable, Generator, Optional, Sequence, Type, cast from llama_index.bridge.pydantic import BaseModel, Field, ValidationError -from llama_index.indices.service_context import ServiceContext from llama_index.indices.utils import truncate_text from llama_index.llm_predictor.base import BaseLLMPredictor from llama_index.prompts.base import BasePromptTemplate, PromptTemplate @@ -13,6 +12,7 @@ from llama_index.prompts.default_prompt_selectors import ( from llama_index.prompts.mixin import PromptDictType from llama_index.response.utils import get_response_text from llama_index.response_synthesizers.base import BaseSynthesizer +from llama_index.service_context import ServiceContext from llama_index.types import RESPONSE_TEXT_TYPE, BasePydanticProgram logger = logging.getLogger(__name__) diff --git a/llama_index/response_synthesizers/simple_summarize.py b/llama_index/response_synthesizers/simple_summarize.py index 8ab4dd1bc47798d20e3525c42093d7ff9f8a5d2b..0930729a27e89f6f753e8b63da0d623deea411f6 100644 --- a/llama_index/response_synthesizers/simple_summarize.py +++ b/llama_index/response_synthesizers/simple_summarize.py @@ -1,10 +1,10 @@ from typing import Any, Generator, Optional, Sequence, cast -from llama_index.indices.service_context import ServiceContext from llama_index.prompts import BasePromptTemplate from llama_index.prompts.default_prompt_selectors import DEFAULT_TEXT_QA_PROMPT_SEL from llama_index.prompts.mixin import PromptDictType from llama_index.response_synthesizers.base import BaseSynthesizer +from llama_index.service_context import ServiceContext from llama_index.types import RESPONSE_TEXT_TYPE diff --git a/llama_index/response_synthesizers/tree_summarize.py b/llama_index/response_synthesizers/tree_summarize.py index 77d9a15908344485a607ebb3db380c07eb0fe3a3..773726d70faf047ee38c690022a79592697040dc 100644 --- a/llama_index/response_synthesizers/tree_summarize.py +++ b/llama_index/response_synthesizers/tree_summarize.py @@ -2,13 +2,13 @@ import asyncio from typing import Any, List, Optional, Sequence from llama_index.async_utils import run_async_tasks -from llama_index.indices.service_context import ServiceContext from llama_index.prompts import BasePromptTemplate from llama_index.prompts.default_prompt_selectors import ( DEFAULT_TREE_SUMMARIZE_PROMPT_SEL, ) from llama_index.prompts.mixin import PromptDictType from llama_index.response_synthesizers.base import BaseSynthesizer +from llama_index.service_context import ServiceContext from llama_index.types import RESPONSE_TEXT_TYPE, BaseModel diff --git a/llama_index/retrievers/__init__.py b/llama_index/retrievers/__init__.py index 47236851361678f44555d9f5ca38f80f8312d78f..be6aee61f567c9c670fec5cf97c56dd50a581d3b 100644 --- a/llama_index/retrievers/__init__.py +++ b/llama_index/retrievers/__init__.py @@ -1,4 +1,4 @@ -from llama_index.indices.base_retriever import BaseRetriever +from llama_index.core import BaseRetriever from llama_index.indices.empty.retrievers import EmptyIndexRetriever from llama_index.indices.keyword_table.retrievers import KeywordTableSimpleRetriever from llama_index.indices.knowledge_graph.retrievers import ( diff --git a/llama_index/retrievers/auto_merging_retriever.py b/llama_index/retrievers/auto_merging_retriever.py index 197c52555842b7932e905dc81da698b9d72c420a..365df24d769ddc6f2bde8f111902d122a973ee27 100644 --- a/llama_index/retrievers/auto_merging_retriever.py +++ b/llama_index/retrievers/auto_merging_retriever.py @@ -4,11 +4,10 @@ import logging from collections import defaultdict from typing import Dict, List, Tuple, cast -from llama_index.indices.base_retriever import BaseRetriever -from llama_index.indices.query.schema import QueryBundle +from llama_index.core import BaseRetriever from llama_index.indices.utils import truncate_text from llama_index.indices.vector_store.retrievers.retriever import VectorIndexRetriever -from llama_index.schema import BaseNode, NodeWithScore +from llama_index.schema import BaseNode, NodeWithScore, QueryBundle from llama_index.storage.storage_context import StorageContext logger = logging.getLogger(__name__) diff --git a/llama_index/retrievers/bm25_retriever.py b/llama_index/retrievers/bm25_retriever.py index 5af5689b3c1e21cd25def8ee598cf5b45dd5f1f2..bc59da465a2ffca412617e4506feddac5b9d82f0 100644 --- a/llama_index/retrievers/bm25_retriever.py +++ b/llama_index/retrievers/bm25_retriever.py @@ -2,12 +2,11 @@ import logging from typing import Callable, List, Optional, cast from llama_index.constants import DEFAULT_SIMILARITY_TOP_K -from llama_index.indices.base_retriever import BaseRetriever -from llama_index.indices.query.schema import QueryBundle +from llama_index.core import BaseRetriever from llama_index.indices.vector_store.base import VectorStoreIndex -from llama_index.schema import BaseNode, NodeWithScore +from llama_index.schema import BaseNode, NodeWithScore, QueryBundle from llama_index.storage.docstore.types import BaseDocumentStore -from llama_index.utils import globals_helper +from llama_index.utils import get_tokenizer logger = logging.getLogger(__name__) @@ -54,7 +53,7 @@ class BM25Retriever(BaseRetriever): nodes is not None ), "Please pass exactly one of index, nodes, or docstore." - tokenizer = tokenizer or globals_helper.tokenizer + tokenizer = tokenizer or get_tokenizer() return cls( nodes=nodes, tokenizer=tokenizer, diff --git a/llama_index/retrievers/fusion_retriever.py b/llama_index/retrievers/fusion_retriever.py index c702c17e535671e22d7bf2e34304867f23c997fe..51603ef63c67b50466de428c03cd2c26af3ea549 100644 --- a/llama_index/retrievers/fusion_retriever.py +++ b/llama_index/retrievers/fusion_retriever.py @@ -4,10 +4,9 @@ from typing import Dict, List, Optional, Tuple from llama_index.async_utils import run_async_tasks from llama_index.constants import DEFAULT_SIMILARITY_TOP_K -from llama_index.indices.query.schema import QueryBundle from llama_index.llms.utils import LLMType, resolve_llm from llama_index.retrievers import BaseRetriever -from llama_index.schema import NodeWithScore +from llama_index.schema import NodeWithScore, QueryBundle QUERY_GEN_PROMPT = ( "You are a helpful assistant that generates multiple search queries based on a " diff --git a/llama_index/retrievers/recursive_retriever.py b/llama_index/retrievers/recursive_retriever.py index e5b1158bcb898d2497d2358c260f6fde8634abca..e4e031d8404eb52f79218a1bd60526f76b6e173a 100644 --- a/llama_index/retrievers/recursive_retriever.py +++ b/llama_index/retrievers/recursive_retriever.py @@ -2,10 +2,8 @@ from typing import Dict, List, Optional, Tuple, Union from llama_index.callbacks.base import CallbackManager from llama_index.callbacks.schema import CBEventType, EventPayload -from llama_index.indices.base_retriever import BaseRetriever -from llama_index.indices.query.base import BaseQueryEngine -from llama_index.indices.query.schema import QueryBundle -from llama_index.schema import BaseNode, IndexNode, NodeWithScore, TextNode +from llama_index.core import BaseQueryEngine, BaseRetriever +from llama_index.schema import BaseNode, IndexNode, NodeWithScore, QueryBundle, TextNode from llama_index.utils import print_text DEFAULT_QUERY_RESPONSE_TMPL = "Query: {query_str}\nResponse: {response}" diff --git a/llama_index/retrievers/router_retriever.py b/llama_index/retrievers/router_retriever.py index 8c861afc250745ae7496f5e6f03a692fb7f257bb..a57b43ee358c6eb3212722da56918b7d16bbe8d0 100644 --- a/llama_index/retrievers/router_retriever.py +++ b/llama_index/retrievers/router_retriever.py @@ -5,13 +5,12 @@ import logging from typing import List, Optional, Sequence from llama_index.callbacks.schema import CBEventType, EventPayload -from llama_index.indices.base_retriever import BaseRetriever -from llama_index.indices.query.schema import QueryBundle -from llama_index.indices.service_context import ServiceContext +from llama_index.core import BaseRetriever from llama_index.prompts.mixin import PromptMixinType -from llama_index.schema import NodeWithScore +from llama_index.schema import NodeWithScore, QueryBundle from llama_index.selectors.types import BaseSelector from llama_index.selectors.utils import get_selector_from_context +from llama_index.service_context import ServiceContext from llama_index.tools.retriever_tool import RetrieverTool logger = logging.getLogger(__name__) diff --git a/llama_index/retrievers/transform_retriever.py b/llama_index/retrievers/transform_retriever.py index a7d4a3a5716e476aa728e962d2dc6f954843941d..0dd08176ccf3ac57fcbb38bef94a3678dfba87c8 100644 --- a/llama_index/retrievers/transform_retriever.py +++ b/llama_index/retrievers/transform_retriever.py @@ -1,10 +1,9 @@ from typing import List, Optional -from llama_index.indices.base_retriever import BaseRetriever +from llama_index.core import BaseRetriever from llama_index.indices.query.query_transform.base import BaseQueryTransform -from llama_index.indices.query.schema import QueryBundle from llama_index.prompts.mixin import PromptMixinType -from llama_index.schema import NodeWithScore +from llama_index.schema import NodeWithScore, QueryBundle class TransformRetriever(BaseRetriever): diff --git a/llama_index/retrievers/you_retriever.py b/llama_index/retrievers/you_retriever.py index 21fdfdd1d742a0ec261007e70d2db4500180da73..a964863a185ebc79221f442a47cb3559642f5497 100644 --- a/llama_index/retrievers/you_retriever.py +++ b/llama_index/retrievers/you_retriever.py @@ -6,9 +6,8 @@ from typing import List, Optional import requests -from llama_index.indices.base_retriever import BaseRetriever -from llama_index.indices.query.schema import QueryBundle -from llama_index.schema import NodeWithScore, TextNode +from llama_index.core import BaseRetriever +from llama_index.schema import NodeWithScore, QueryBundle, TextNode logger = logging.getLogger(__name__) diff --git a/llama_index/schema.py b/llama_index/schema.py index 33a737ea3aed871f5f5d94025c157d12e638096a..da104e1cb7404aa6ee804fce252c34e9ec866fe7 100644 --- a/llama_index/schema.py +++ b/llama_index/schema.py @@ -3,11 +3,13 @@ import json import textwrap import uuid from abc import abstractmethod +from dataclasses import dataclass from enum import Enum, auto from hashlib import sha256 from io import BytesIO from typing import TYPE_CHECKING, Any, Dict, List, Optional, Union +from dataclasses_json import DataClassJsonMixin from typing_extensions import Self from llama_index.bridge.pydantic import BaseModel, Field, root_validator @@ -32,8 +34,17 @@ ImageType = Union[str, BytesIO] class BaseComponent(BaseModel): """Base component object to capture class names.""" + class Config: + @staticmethod + def schema_extra(schema: Dict[str, Any], model: "BaseComponent") -> None: + """Add class name to schema.""" + schema["properties"]["class_name"] = { + "title": "Class Name", + "type": "string", + "default": model.class_name(), + } + @classmethod - @abstractmethod def class_name(cls) -> str: """ Get the class name, used as a unique ID in serialization. @@ -41,6 +52,15 @@ class BaseComponent(BaseModel): This provides a key that makes serialization robust against actual class name changes. """ + return "base_component" + + def json(self, **kwargs: Any) -> str: + return self.to_json(**kwargs) + + def dict(self, **kwargs: Any) -> Dict[str, Any]: + data = super().dict(**kwargs) + data["class_name"] = self.class_name() + return data def __getstate__(self) -> Dict[str, Any]: state = super().__getstate__() @@ -85,6 +105,21 @@ class BaseComponent(BaseModel): return cls.from_dict(data, **kwargs) +class TransformComponent(BaseComponent): + """Base class for transform components.""" + + class Config: + arbitrary_types_allowed = True + + @abstractmethod + def __call__(self, nodes: List["BaseNode"], **kwargs: Any) -> List["BaseNode"]: + """Transform nodes.""" + + async def acall(self, nodes: List["BaseNode"], **kwargs: Any) -> List["BaseNode"]: + """Async transform nodes.""" + return self.__call__(nodes, **kwargs) + + class NodeRelationship(str, Enum): """Node relationships used in `BaseNode` class. @@ -658,3 +693,36 @@ class ImageDocument(Document, ImageNode): @classmethod def class_name(cls) -> str: return "ImageDocument" + + +@dataclass +class QueryBundle(DataClassJsonMixin): + """ + Query bundle. + + This dataclass contains the original query string and associated transformations. + + Args: + query_str (str): the original user-specified query string. + This is currently used by all non embedding-based queries. + embedding_strs (list[str]): list of strings used for embedding the query. + This is currently used by all embedding-based queries. + embedding (list[float]): the stored embedding for the query. + """ + + query_str: str + custom_embedding_strs: Optional[List[str]] = None + embedding: Optional[List[float]] = None + + @property + def embedding_strs(self) -> List[str]: + """Use custom embedding strs if specified, otherwise use query str.""" + if self.custom_embedding_strs is None: + if len(self.query_str) == 0: + return [] + return [self.query_str] + else: + return self.custom_embedding_strs + + +QueryType = Union[str, QueryBundle] diff --git a/llama_index/selectors/embedding_selectors.py b/llama_index/selectors/embedding_selectors.py index 9e18f86d96f6aa98fc4dffc9301a622e41ea1e50..a33e8f87159329082f6f7171b71228c599aa0841 100644 --- a/llama_index/selectors/embedding_selectors.py +++ b/llama_index/selectors/embedding_selectors.py @@ -3,8 +3,8 @@ from typing import Any, Dict, Optional, Sequence from llama_index.embeddings.base import BaseEmbedding from llama_index.embeddings.utils import resolve_embed_model from llama_index.indices.query.embedding_utils import get_top_k_embeddings -from llama_index.indices.query.schema import QueryBundle from llama_index.prompts.mixin import PromptDictType +from llama_index.schema import QueryBundle from llama_index.selectors.types import ( BaseSelector, SelectorResult, diff --git a/llama_index/selectors/llm_selectors.py b/llama_index/selectors/llm_selectors.py index 1a56d0ccb50fe660cd6b3822f7214d36dfe24652..e4a2425491e9d7d615176d0cfcc0f66de5143cf5 100644 --- a/llama_index/selectors/llm_selectors.py +++ b/llama_index/selectors/llm_selectors.py @@ -1,12 +1,11 @@ from typing import Any, Dict, List, Optional, Sequence, cast -from llama_index.indices.query.schema import QueryBundle -from llama_index.indices.service_context import ServiceContext from llama_index.llm_predictor.base import BaseLLMPredictor from llama_index.output_parsers.base import StructuredOutput from llama_index.output_parsers.selection import Answer, SelectionOutputParser from llama_index.prompts.mixin import PromptDictType from llama_index.prompts.prompt_type import PromptType +from llama_index.schema import QueryBundle from llama_index.selectors.prompts import ( DEFAULT_MULTI_SELECT_PROMPT_TMPL, DEFAULT_SINGLE_SELECT_PROMPT_TMPL, @@ -14,6 +13,7 @@ from llama_index.selectors.prompts import ( SingleSelectPrompt, ) from llama_index.selectors.types import BaseSelector, SelectorResult, SingleSelection +from llama_index.service_context import ServiceContext from llama_index.tools.types import ToolMetadata from llama_index.types import BaseOutputParser diff --git a/llama_index/selectors/pydantic_selectors.py b/llama_index/selectors/pydantic_selectors.py index b50224088f31e18e2862aa9590623ed32d53e2f1..cefbbb9a31353bd96175e29cb817234aa07cf8cc 100644 --- a/llama_index/selectors/pydantic_selectors.py +++ b/llama_index/selectors/pydantic_selectors.py @@ -1,9 +1,9 @@ from typing import Any, Dict, Optional, Sequence -from llama_index.indices.query.schema import QueryBundle from llama_index.llms.openai import OpenAI from llama_index.program.openai_program import OpenAIPydanticProgram from llama_index.prompts.mixin import PromptDictType +from llama_index.schema import QueryBundle from llama_index.selectors.llm_selectors import _build_choices_text from llama_index.selectors.prompts import ( DEFAULT_MULTI_PYD_SELECT_PROMPT_TMPL, diff --git a/llama_index/selectors/types.py b/llama_index/selectors/types.py index 8ccf7f4607081c9bf7c90f4a325f28d3b702d207..63d10a4fefd2ac39386ae89f91e32f98447a1364 100644 --- a/llama_index/selectors/types.py +++ b/llama_index/selectors/types.py @@ -2,8 +2,8 @@ from abc import abstractmethod from typing import List, Sequence, Union from llama_index.bridge.pydantic import BaseModel -from llama_index.indices.query.schema import QueryBundle, QueryType from llama_index.prompts.mixin import PromptMixin, PromptMixinType +from llama_index.schema import QueryBundle, QueryType from llama_index.tools.types import ToolMetadata MetadataType = Union[str, ToolMetadata] diff --git a/llama_index/selectors/utils.py b/llama_index/selectors/utils.py index e067e95981280385329900e554b2938ce1708106..c9fbbb5a81ad8c4c5b57d2267903a507676b4596 100644 --- a/llama_index/selectors/utils.py +++ b/llama_index/selectors/utils.py @@ -1,12 +1,12 @@ from typing import Optional -from llama_index.indices.service_context import ServiceContext from llama_index.selectors.llm_selectors import LLMMultiSelector, LLMSingleSelector from llama_index.selectors.pydantic_selectors import ( PydanticMultiSelector, PydanticSingleSelector, ) from llama_index.selectors.types import BaseSelector +from llama_index.service_context import ServiceContext def get_selector_from_context( diff --git a/llama_index/service_context.py b/llama_index/service_context.py new file mode 100644 index 0000000000000000000000000000000000000000..b12d64bfdbcf8c38132fda2775d54b360f43b391 --- /dev/null +++ b/llama_index/service_context.py @@ -0,0 +1,360 @@ +import logging +from dataclasses import dataclass +from typing import List, Optional + +import llama_index +from llama_index.bridge.pydantic import BaseModel +from llama_index.callbacks.base import CallbackManager +from llama_index.embeddings.base import BaseEmbedding +from llama_index.embeddings.utils import EmbedType, resolve_embed_model +from llama_index.indices.prompt_helper import PromptHelper +from llama_index.llm_predictor import LLMPredictor +from llama_index.llm_predictor.base import BaseLLMPredictor, LLMMetadata +from llama_index.llms.base import LLM +from llama_index.llms.utils import LLMType, resolve_llm +from llama_index.logger import LlamaLogger +from llama_index.node_parser.interface import NodeParser, TextSplitter +from llama_index.node_parser.text.sentence import ( + DEFAULT_CHUNK_SIZE, + SENTENCE_CHUNK_OVERLAP, + SentenceSplitter, +) +from llama_index.prompts.base import BasePromptTemplate +from llama_index.schema import TransformComponent +from llama_index.types import PydanticProgramMode + +logger = logging.getLogger(__name__) + + +def _get_default_node_parser( + chunk_size: int = DEFAULT_CHUNK_SIZE, + chunk_overlap: int = SENTENCE_CHUNK_OVERLAP, + callback_manager: Optional[CallbackManager] = None, +) -> NodeParser: + """Get default node parser.""" + return SentenceSplitter( + chunk_size=chunk_size, + chunk_overlap=chunk_overlap, + callback_manager=callback_manager or CallbackManager(), + ) + + +def _get_default_prompt_helper( + llm_metadata: LLMMetadata, + context_window: Optional[int] = None, + num_output: Optional[int] = None, +) -> PromptHelper: + """Get default prompt helper.""" + if context_window is not None: + llm_metadata.context_window = context_window + if num_output is not None: + llm_metadata.num_output = num_output + return PromptHelper.from_llm_metadata(llm_metadata=llm_metadata) + + +class ServiceContextData(BaseModel): + llm: dict + llm_predictor: dict + prompt_helper: dict + embed_model: dict + transformations: List[dict] + + +@dataclass +class ServiceContext: + """Service Context container. + + The service context container is a utility container for LlamaIndex + index and query classes. It contains the following: + - llm_predictor: BaseLLMPredictor + - prompt_helper: PromptHelper + - embed_model: BaseEmbedding + - node_parser: NodeParser + - llama_logger: LlamaLogger (deprecated) + - callback_manager: CallbackManager + + """ + + llm_predictor: BaseLLMPredictor + prompt_helper: PromptHelper + embed_model: BaseEmbedding + transformations: List[TransformComponent] + llama_logger: LlamaLogger + callback_manager: CallbackManager + + @classmethod + def from_defaults( + cls, + llm_predictor: Optional[BaseLLMPredictor] = None, + llm: Optional[LLMType] = "default", + prompt_helper: Optional[PromptHelper] = None, + embed_model: Optional[EmbedType] = "default", + node_parser: Optional[NodeParser] = None, + text_splitter: Optional[TextSplitter] = None, + transformations: Optional[List[TransformComponent]] = None, + llama_logger: Optional[LlamaLogger] = None, + callback_manager: Optional[CallbackManager] = None, + system_prompt: Optional[str] = None, + query_wrapper_prompt: Optional[BasePromptTemplate] = None, + # pydantic program mode (used if output_cls is specified) + pydantic_program_mode: PydanticProgramMode = PydanticProgramMode.DEFAULT, + # node parser kwargs + chunk_size: Optional[int] = None, + chunk_overlap: Optional[int] = None, + # prompt helper kwargs + context_window: Optional[int] = None, + num_output: Optional[int] = None, + # deprecated kwargs + chunk_size_limit: Optional[int] = None, + ) -> "ServiceContext": + """Create a ServiceContext from defaults. + If an argument is specified, then use the argument value provided for that + parameter. If an argument is not specified, then use the default value. + + You can change the base defaults by setting llama_index.global_service_context + to a ServiceContext object with your desired settings. + + Args: + llm_predictor (Optional[BaseLLMPredictor]): LLMPredictor + prompt_helper (Optional[PromptHelper]): PromptHelper + embed_model (Optional[BaseEmbedding]): BaseEmbedding + or "local" (use local model) + node_parser (Optional[NodeParser]): NodeParser + llama_logger (Optional[LlamaLogger]): LlamaLogger (deprecated) + chunk_size (Optional[int]): chunk_size + callback_manager (Optional[CallbackManager]): CallbackManager + system_prompt (Optional[str]): System-wide prompt to be prepended + to all input prompts, used to guide system "decision making" + query_wrapper_prompt (Optional[BasePromptTemplate]): A format to wrap + passed-in input queries. + + Deprecated Args: + chunk_size_limit (Optional[int]): renamed to chunk_size + + """ + if chunk_size_limit is not None and chunk_size is None: + logger.warning( + "chunk_size_limit is deprecated, please specify chunk_size instead" + ) + chunk_size = chunk_size_limit + + if llama_index.global_service_context is not None: + return cls.from_service_context( + llama_index.global_service_context, + llm_predictor=llm_predictor, + prompt_helper=prompt_helper, + embed_model=embed_model, + node_parser=node_parser, + text_splitter=text_splitter, + llama_logger=llama_logger, + callback_manager=callback_manager, + chunk_size=chunk_size, + chunk_size_limit=chunk_size_limit, + ) + + callback_manager = callback_manager or CallbackManager([]) + if llm != "default": + if llm_predictor is not None: + raise ValueError("Cannot specify both llm and llm_predictor") + llm = resolve_llm(llm) + llm_predictor = llm_predictor or LLMPredictor( + llm=llm, pydantic_program_mode=pydantic_program_mode + ) + if isinstance(llm_predictor, LLMPredictor): + llm_predictor.llm.callback_manager = callback_manager + if system_prompt: + llm_predictor.system_prompt = system_prompt + if query_wrapper_prompt: + llm_predictor.query_wrapper_prompt = query_wrapper_prompt + + # NOTE: the embed_model isn't used in all indices + # NOTE: embed model should be a transformation, but the way the service + # context works, we can't put in there yet. + embed_model = resolve_embed_model(embed_model) + embed_model.callback_manager = callback_manager + + prompt_helper = prompt_helper or _get_default_prompt_helper( + llm_metadata=llm_predictor.metadata, + context_window=context_window, + num_output=num_output, + ) + + if text_splitter is not None and node_parser is not None: + raise ValueError("Cannot specify both text_splitter and node_parser") + + node_parser = ( + text_splitter # text splitter extends node parser + or node_parser + or _get_default_node_parser( + chunk_size=chunk_size or DEFAULT_CHUNK_SIZE, + chunk_overlap=chunk_overlap or SENTENCE_CHUNK_OVERLAP, + callback_manager=callback_manager, + ) + ) + + transformations = transformations or [node_parser] + + llama_logger = llama_logger or LlamaLogger() + + return cls( + llm_predictor=llm_predictor, + embed_model=embed_model, + prompt_helper=prompt_helper, + transformations=transformations, + llama_logger=llama_logger, # deprecated + callback_manager=callback_manager, + ) + + @classmethod + def from_service_context( + cls, + service_context: "ServiceContext", + llm_predictor: Optional[BaseLLMPredictor] = None, + llm: Optional[LLMType] = "default", + prompt_helper: Optional[PromptHelper] = None, + embed_model: Optional[EmbedType] = "default", + node_parser: Optional[NodeParser] = None, + text_splitter: Optional[TextSplitter] = None, + transformations: Optional[List[TransformComponent]] = None, + llama_logger: Optional[LlamaLogger] = None, + callback_manager: Optional[CallbackManager] = None, + system_prompt: Optional[str] = None, + query_wrapper_prompt: Optional[BasePromptTemplate] = None, + # node parser kwargs + chunk_size: Optional[int] = None, + chunk_overlap: Optional[int] = None, + # prompt helper kwargs + context_window: Optional[int] = None, + num_output: Optional[int] = None, + # deprecated kwargs + chunk_size_limit: Optional[int] = None, + ) -> "ServiceContext": + """Instantiate a new service context using a previous as the defaults.""" + if chunk_size_limit is not None and chunk_size is None: + logger.warning( + "chunk_size_limit is deprecated, please specify chunk_size", + DeprecationWarning, + ) + chunk_size = chunk_size_limit + + callback_manager = callback_manager or service_context.callback_manager + if llm != "default": + if llm_predictor is not None: + raise ValueError("Cannot specify both llm and llm_predictor") + llm = resolve_llm(llm) + llm_predictor = LLMPredictor(llm=llm) + + llm_predictor = llm_predictor or service_context.llm_predictor + if isinstance(llm_predictor, LLMPredictor): + llm_predictor.llm.callback_manager = callback_manager + if system_prompt: + llm_predictor.system_prompt = system_prompt + if query_wrapper_prompt: + llm_predictor.query_wrapper_prompt = query_wrapper_prompt + + # NOTE: the embed_model isn't used in all indices + # default to using the embed model passed from the service context + if embed_model == "default": + embed_model = service_context.embed_model + embed_model = resolve_embed_model(embed_model) + embed_model.callback_manager = callback_manager + + prompt_helper = prompt_helper or service_context.prompt_helper + if context_window is not None or num_output is not None: + prompt_helper = _get_default_prompt_helper( + llm_metadata=llm_predictor.metadata, + context_window=context_window, + num_output=num_output, + ) + + transformations = transformations or [] + node_parser_found = False + for transform in service_context.transformations: + if isinstance(transform, NodeParser): + node_parser_found = True + break + + if text_splitter is not None and node_parser is not None: + raise ValueError("Cannot specify both text_splitter and node_parser") + + if not node_parser_found: + node_parser = ( + text_splitter # text splitter extends node parser + or node_parser + or _get_default_node_parser( + chunk_size=chunk_size or DEFAULT_CHUNK_SIZE, + chunk_overlap=chunk_overlap or SENTENCE_CHUNK_OVERLAP, + callback_manager=callback_manager, + ) + ) + transformations = [node_parser, *transformations] + + llama_logger = llama_logger or service_context.llama_logger + + return cls( + llm_predictor=llm_predictor, + embed_model=embed_model, + prompt_helper=prompt_helper, + transformations=transformations, + llama_logger=llama_logger, # deprecated + callback_manager=callback_manager, + ) + + @property + def llm(self) -> LLM: + if not isinstance(self.llm_predictor, LLMPredictor): + raise ValueError("llm_predictor must be an instance of LLMPredictor") + return self.llm_predictor.llm + + def to_dict(self) -> dict: + """Convert service context to dict.""" + llm_dict = self.llm_predictor.llm.to_dict() + llm_predictor_dict = self.llm_predictor.to_dict() + + embed_model_dict = self.embed_model.to_dict() + + prompt_helper_dict = self.prompt_helper.to_dict() + + tranform_list_dict = [x.to_dict() for x in self.transformations] + + return ServiceContextData( + llm=llm_dict, + llm_predictor=llm_predictor_dict, + prompt_helper=prompt_helper_dict, + embed_model=embed_model_dict, + transformations=tranform_list_dict, + ).dict() + + @classmethod + def from_dict(cls, data: dict) -> "ServiceContext": + from llama_index.embeddings.loading import load_embed_model + from llama_index.extractors.loading import load_extractor + from llama_index.llm_predictor.loading import load_predictor + from llama_index.node_parser.loading import load_parser + + service_context_data = ServiceContextData.parse_obj(data) + + llm_predictor = load_predictor(service_context_data.llm_predictor) + + embed_model = load_embed_model(service_context_data.embed_model) + + prompt_helper = PromptHelper.from_dict(service_context_data.prompt_helper) + + transformations: List[TransformComponent] = [] + for transform in service_context_data.transformations: + try: + transformations.append(load_parser(transform)) + except ValueError: + transformations.append(load_extractor(transform)) + + return cls.from_defaults( + llm_predictor=llm_predictor, + prompt_helper=prompt_helper, + embed_model=embed_model, + transformations=transformations, + ) + + +def set_global_service_context(service_context: Optional[ServiceContext]) -> None: + """Helper function to set the global service context.""" + llama_index.global_service_context = service_context diff --git a/llama_index/storage/docstore/keyval_docstore.py b/llama_index/storage/docstore/keyval_docstore.py index 99a5e5d2231052e5c2fd540dae691591f265cdbe..3cb52479dabb96be5b59193a8e3ac4bc5f7cfd6d 100644 --- a/llama_index/storage/docstore/keyval_docstore.py +++ b/llama_index/storage/docstore/keyval_docstore.py @@ -21,7 +21,7 @@ class KVDocumentStore(BaseDocumentStore): otherwise, each index would create a docstore under the hood. .. code-block:: python - nodes = SimpleNodeParser.get_nodes_from_documents() + nodes = SentenceSplitter().get_nodes_from_documents() docstore = SimpleDocumentStore() docstore.add_documents(nodes) storage_context = StorageContext.from_defaults(docstore=docstore) diff --git a/llama_index/storage/docstore/utils.py b/llama_index/storage/docstore/utils.py index d363beaccff43c8d4cbb79230bc7a74b691dacd3..f59a0cada3c8767bb150605b4b8e3c089186c6f1 100644 --- a/llama_index/storage/docstore/utils.py +++ b/llama_index/storage/docstore/utils.py @@ -2,6 +2,7 @@ from llama_index.constants import DATA_KEY, TYPE_KEY from llama_index.schema import ( BaseNode, Document, + ImageDocument, ImageNode, IndexNode, NodeRelationship, @@ -27,6 +28,8 @@ def json_to_doc(doc_dict: dict) -> BaseNode: else: if doc_type == Document.get_type(): doc = Document.parse_obj(data_dict) + elif doc_type == ImageDocument.get_type(): + doc = ImageDocument.parse_obj(data_dict) elif doc_type == TextNode.get_type(): doc = TextNode.parse_obj(data_dict) elif doc_type == ImageNode.get_type(): diff --git a/llama_index/text_splitter/__init__.py b/llama_index/text_splitter/__init__.py index 9e5f01f6ac07c024029a7f57020a3243cda12bcc..62e8c4a1ae4ddf200b6e6611c6948d4326c36d53 100644 --- a/llama_index/text_splitter/__init__.py +++ b/llama_index/text_splitter/__init__.py @@ -1,35 +1,12 @@ -from typing import Optional - -from llama_index.callbacks.base import CallbackManager -from llama_index.constants import DEFAULT_CHUNK_OVERLAP, DEFAULT_CHUNK_SIZE -from llama_index.text_splitter.code_splitter import CodeSplitter -from llama_index.text_splitter.sentence_splitter import SentenceSplitter -from llama_index.text_splitter.token_splitter import TokenTextSplitter -from llama_index.text_splitter.types import SplitterType, TextSplitter - - -def get_default_text_splitter( - chunk_size: Optional[int] = None, - chunk_overlap: Optional[int] = None, - callback_manager: Optional[CallbackManager] = None, -) -> TextSplitter: - """Get default text splitter.""" - chunk_size = chunk_size or DEFAULT_CHUNK_SIZE - chunk_overlap = ( - chunk_overlap if chunk_overlap is not None else DEFAULT_CHUNK_OVERLAP - ) - - return SentenceSplitter( - chunk_size=chunk_size, - chunk_overlap=chunk_overlap, - callback_manager=callback_manager, - ) - +# TODO: Deprecated import support for old text splitters +from llama_index.node_parser.text.code import CodeSplitter +from llama_index.node_parser.text.sentence import ( + SentenceSplitter, +) +from llama_index.node_parser.text.token import TokenTextSplitter __all__ = [ - "TextSplitter", - "TokenTextSplitter", "SentenceSplitter", + "TokenTextSplitter", "CodeSplitter", - "SplitterType", ] diff --git a/llama_index/text_splitter/loading.py b/llama_index/text_splitter/loading.py deleted file mode 100644 index a6f907971d1f86ca41291aa52f617ea17c9cc701..0000000000000000000000000000000000000000 --- a/llama_index/text_splitter/loading.py +++ /dev/null @@ -1,23 +0,0 @@ -from typing import Dict, Type - -from llama_index.text_splitter.code_splitter import CodeSplitter -from llama_index.text_splitter.sentence_splitter import SentenceSplitter -from llama_index.text_splitter.token_splitter import TokenTextSplitter -from llama_index.text_splitter.types import TextSplitter - -RECOGNIZED_TEXT_SPLITTERS: Dict[str, Type[TextSplitter]] = { - SentenceSplitter.class_name(): SentenceSplitter, - TokenTextSplitter.class_name(): TokenTextSplitter, - CodeSplitter.class_name(): CodeSplitter, -} - - -def load_text_splitter(data: dict) -> TextSplitter: - text_splitter_name = data.get("class_name", None) - if text_splitter_name is None: - raise ValueError("TextSplitter loading requires a class_name") - - if text_splitter_name not in RECOGNIZED_TEXT_SPLITTERS: - raise ValueError(f"Invalid TextSplitter name: {text_splitter_name}") - - return RECOGNIZED_TEXT_SPLITTERS[text_splitter_name].from_dict(data) diff --git a/llama_index/text_splitter/types.py b/llama_index/text_splitter/types.py deleted file mode 100644 index c8c4ea7c03414ee7da52f01f1633ed069f76de91..0000000000000000000000000000000000000000 --- a/llama_index/text_splitter/types.py +++ /dev/null @@ -1,39 +0,0 @@ -"""Text splitter implementations.""" -from abc import ABC, abstractmethod -from typing import List, Union - -from llama_index.bridge.langchain import TextSplitter as LC_TextSplitter -from llama_index.schema import BaseComponent - - -class TextSplitter(ABC, BaseComponent): - class Config: - arbitrary_types_allowed = True - - @abstractmethod - def split_text(self, text: str) -> List[str]: - ... - - def split_texts(self, texts: List[str]) -> List[str]: - nested_texts = [self.split_text(text) for text in texts] - return [item for sublist in nested_texts for item in sublist] - - -class MetadataAwareTextSplitter(TextSplitter): - @abstractmethod - def split_text_metadata_aware(self, text: str, metadata_str: str) -> List[str]: - ... - - def split_texts_metadata_aware( - self, texts: List[str], metadata_strs: List[str] - ) -> List[str]: - if len(texts) != len(metadata_strs): - raise ValueError("Texts and metadata_strs must have the same length") - nested_texts = [ - self.split_text_metadata_aware(text, metadata) - for text, metadata in zip(texts, metadata_strs) - ] - return [item for sublist in nested_texts for item in sublist] - - -SplitterType = Union[TextSplitter, LC_TextSplitter] diff --git a/llama_index/tools/function_tool.py b/llama_index/tools/function_tool.py index 17fbdd8258d272811a3a34545ba18752a22cc0a6..7abbe510429b7599a826d982f13ab70431f8b8ac 100644 --- a/llama_index/tools/function_tool.py +++ b/llama_index/tools/function_tool.py @@ -1,7 +1,8 @@ from inspect import signature -from typing import Any, Awaitable, Callable, Optional, Type +from typing import TYPE_CHECKING, Any, Awaitable, Callable, Optional, Type -from llama_index.bridge.langchain import StructuredTool, Tool +if TYPE_CHECKING: + from llama_index.bridge.langchain import StructuredTool, Tool from llama_index.bridge.pydantic import BaseModel from llama_index.tools.types import AsyncBaseTool, ToolMetadata, ToolOutput from llama_index.tools.utils import create_schema_from_function @@ -99,8 +100,10 @@ class FunctionTool(AsyncBaseTool): def to_langchain_tool( self, **langchain_tool_kwargs: Any, - ) -> Tool: + ) -> "Tool": """To langchain tool.""" + from llama_index.bridge.langchain import Tool + langchain_tool_kwargs = self._process_langchain_tool_kwargs( langchain_tool_kwargs ) @@ -113,8 +116,10 @@ class FunctionTool(AsyncBaseTool): def to_langchain_structured_tool( self, **langchain_tool_kwargs: Any, - ) -> StructuredTool: + ) -> "StructuredTool": """To langchain structured tool.""" + from llama_index.bridge.langchain import StructuredTool + langchain_tool_kwargs = self._process_langchain_tool_kwargs( langchain_tool_kwargs ) diff --git a/llama_index/tools/query_engine.py b/llama_index/tools/query_engine.py index 154bea7473cc9d1d7141d3ebf11f718998e66e74..7ec36b52ca5e49fee00d2a5230bcf631cf11cbd9 100644 --- a/llama_index/tools/query_engine.py +++ b/llama_index/tools/query_engine.py @@ -1,7 +1,11 @@ -from typing import Any, Optional +from typing import TYPE_CHECKING, Any, Optional -from llama_index.indices.query.base import BaseQueryEngine -from llama_index.langchain_helpers.agents.tools import IndexToolConfig, LlamaIndexTool +from llama_index.core import BaseQueryEngine + +if TYPE_CHECKING: + from llama_index.langchain_helpers.agents.tools import ( + LlamaIndexTool, + ) from llama_index.tools.types import AsyncBaseTool, ToolMetadata, ToolOutput DEFAULT_NAME = "query_engine_tool" @@ -81,7 +85,12 @@ class QueryEngineTool(AsyncBaseTool): raw_output=response, ) - def as_langchain_tool(self) -> LlamaIndexTool: + def as_langchain_tool(self) -> "LlamaIndexTool": + from llama_index.langchain_helpers.agents.tools import ( + IndexToolConfig, + LlamaIndexTool, + ) + tool_config = IndexToolConfig( query_engine=self.query_engine, name=self.metadata.name, diff --git a/llama_index/tools/retriever_tool.py b/llama_index/tools/retriever_tool.py index ff5fdd222affde42ab55b83254c7482a5c9099eb..029d320c4660ed3004155d6fc4639ace32314d19 100644 --- a/llama_index/tools/retriever_tool.py +++ b/llama_index/tools/retriever_tool.py @@ -1,9 +1,11 @@ """Retriever tool.""" -from typing import Any, Optional +from typing import TYPE_CHECKING, Any, Optional -from llama_index.indices.base_retriever import BaseRetriever -from llama_index.langchain_helpers.agents.tools import LlamaIndexTool +from llama_index.core import BaseRetriever + +if TYPE_CHECKING: + from llama_index.langchain_helpers.agents.tools import LlamaIndexTool from llama_index.schema import MetadataMode from llama_index.tools.types import AsyncBaseTool, ToolMetadata, ToolOutput @@ -101,5 +103,5 @@ class RetrieverTool(AsyncBaseTool): raw_output=docs, ) - def as_langchain_tool(self) -> LlamaIndexTool: + def as_langchain_tool(self) -> "LlamaIndexTool": raise NotImplementedError("`as_langchain_tool` not implemented here.") diff --git a/llama_index/tools/types.py b/llama_index/tools/types.py index 72666b82c30435ea83d14e98c019df7dd8c2cbd5..28486f87e4b4d2cb6a1c46816234dc4f801747e3 100644 --- a/llama_index/tools/types.py +++ b/llama_index/tools/types.py @@ -1,10 +1,11 @@ from abc import abstractmethod from dataclasses import dataclass -from typing import Any, Dict, Optional, Type +from typing import TYPE_CHECKING, Any, Dict, Optional, Type +if TYPE_CHECKING: + from llama_index.bridge.langchain import StructuredTool, Tool from deprecated import deprecated -from llama_index.bridge.langchain import StructuredTool, Tool from llama_index.bridge.pydantic import BaseModel @@ -106,8 +107,10 @@ class BaseTool: def to_langchain_tool( self, **langchain_tool_kwargs: Any, - ) -> Tool: + ) -> "Tool": """To langchain tool.""" + from llama_index.bridge.langchain import Tool + langchain_tool_kwargs = self._process_langchain_tool_kwargs( langchain_tool_kwargs ) @@ -119,8 +122,10 @@ class BaseTool: def to_langchain_structured_tool( self, **langchain_tool_kwargs: Any, - ) -> StructuredTool: + ) -> "StructuredTool": """To langchain structured tool.""" + from llama_index.bridge.langchain import StructuredTool + langchain_tool_kwargs = self._process_langchain_tool_kwargs( langchain_tool_kwargs ) diff --git a/llama_index/utilities/token_counting.py b/llama_index/utilities/token_counting.py new file mode 100644 index 0000000000000000000000000000000000000000..7884266601a6fea7cce6f19fae6cad5f20453f7e --- /dev/null +++ b/llama_index/utilities/token_counting.py @@ -0,0 +1,82 @@ +# Modified from: +# https://github.com/nyno-ai/openai-token-counter + +from typing import Any, Callable, Dict, List, Optional + +from llama_index.llms import ChatMessage, MessageRole +from llama_index.utils import get_tokenizer + + +class TokenCounter: + """Token counter class. + + Attributes: + model (Optional[str]): The model to use for token counting. + """ + + def __init__(self, tokenizer: Optional[Callable[[str], list]] = None) -> None: + self.tokenizer = tokenizer or get_tokenizer() + + def get_string_tokens(self, string: str) -> int: + """Get the token count for a string. + + Args: + string (str): The string to count. + + Returns: + int: The token count. + """ + return len(self.tokenizer(string)) + + def estimate_tokens_in_messages(self, messages: List[ChatMessage]) -> int: + """Estimate token count for a single message. + + Args: + message (OpenAIMessage): The message to estimate the token count for. + + Returns: + int: The estimated token count. + """ + tokens = 0 + + for message in messages: + if message.role: + tokens += self.get_string_tokens(message.role) + + if message.content: + tokens += self.get_string_tokens(message.content) + + additional_kwargs = {**message.additional_kwargs} + + if "function_call" in additional_kwargs: + function_call = additional_kwargs.pop("function_call") + if function_call.get("name", None) is not None: + tokens += self.get_string_tokens(function_call["name"]) + + if function_call.get("arguments", None) is not None: + tokens += self.get_string_tokens(function_call["arguments"]) + + tokens += 3 # Additional tokens for function call + + tokens += 3 # Add three per message + + if message.role == MessageRole.FUNCTION: + tokens -= 2 # Subtract 2 if role is "function" + + return tokens + + def estimate_tokens_in_functions(self, functions: List[Dict[str, Any]]) -> int: + """Estimate token count for the functions. + + We take here a list of functions created using the `to_openai_spec` function (or similar). + + Args: + function (list[Dict[str, Any]]): The functions to estimate the token count for. + + Returns: + int: The estimated token count. + """ + prompt_definition = str(functions) + tokens = self.get_string_tokens(prompt_definition) + tokens += 9 # Additional tokens for function definition + return tokens diff --git a/llama_index/utils.py b/llama_index/utils.py index d189d135680d00ca9eabe1ba8818db7f381c10f3..32a0e582c8e74fc272add381abf771cbc01a789e 100644 --- a/llama_index/utils.py +++ b/llama_index/utils.py @@ -21,10 +21,12 @@ from typing import ( Iterable, List, Optional, + Protocol, Set, Type, Union, cast, + runtime_checkable, ) @@ -41,7 +43,7 @@ class GlobalsHelper: @property def tokenizer(self) -> Callable[[str], List]: - """Get tokenizer.""" + """Get tokenizer. TODO: Deprecated.""" if self._tokenizer is None: tiktoken_import_err = ( "`tiktoken` package not found, please run `pip install tiktoken`" @@ -87,6 +89,41 @@ class GlobalsHelper: globals_helper = GlobalsHelper() +# Global Tokenizer +@runtime_checkable +class Tokenizer(Protocol): + def encode(self, text: str, *args: Any, **kwargs: Any) -> List[Any]: + ... + + +def set_global_tokenizer(tokenizer: Union[Tokenizer, Callable[[str], list]]) -> None: + import llama_index + + if isinstance(tokenizer, Tokenizer): + llama_index.global_tokenizer = tokenizer.encode + else: + llama_index.global_tokenizer = tokenizer + + +def get_tokenizer() -> Callable[[str], List]: + import llama_index + + if llama_index.global_tokenizer is None: + tiktoken_import_err = ( + "`tiktoken` package not found, please run `pip install tiktoken`" + ) + try: + import tiktoken + except ImportError: + raise ImportError(tiktoken_import_err) + enc = tiktoken.encoding_for_model("gpt-3.5-turbo") + tokenizer = partial(enc.encode, allowed_special="all") + set_global_tokenizer(tokenizer) + + assert llama_index.global_tokenizer is not None + return llama_index.global_tokenizer + + def get_new_id(d: Set) -> str: """Get a new ID.""" while True: @@ -234,7 +271,8 @@ def get_tqdm_iterable(items: Iterable, show_progress: bool, desc: str) -> Iterab def count_tokens(text: str) -> int: - tokens = globals_helper.tokenizer(text) + tokenizer = get_tokenizer() + tokens = tokenizer(text) return len(tokens) diff --git a/llama_index/vector_stores/chroma.py b/llama_index/vector_stores/chroma.py index bb26d937170cfb7ca33c6b6380fd7a81cd166648..e98988c1451a0c988cafa1da8895efce5f2ec8ca 100644 --- a/llama_index/vector_stores/chroma.py +++ b/llama_index/vector_stores/chroma.py @@ -67,6 +67,7 @@ class ChromaVectorStore(BasePydanticVectorStore): stores_text: bool = True flat_metadata: bool = True + collection_name: Optional[str] host: Optional[str] port: Optional[str] ssl: bool @@ -78,7 +79,8 @@ class ChromaVectorStore(BasePydanticVectorStore): def __init__( self, - chroma_collection: Any, + chroma_collection: Optional[Any] = None, + collection_name: Optional[str] = None, host: Optional[str] = None, port: Optional[str] = None, ssl: bool = False, @@ -89,18 +91,25 @@ class ChromaVectorStore(BasePydanticVectorStore): ) -> None: """Init params.""" try: - import chromadb # noqa + import chromadb except ImportError: raise ImportError(import_err_msg) from chromadb.api.models.Collection import Collection - self._collection = cast(Collection, chroma_collection) + if chroma_collection is None: + client = chromadb.HttpClient(host=host, port=port, ssl=ssl, headers=headers) + self._collection = client.get_or_create_collection( + name=collection_name, **collection_kwargs + ) + else: + self._collection = cast(Collection, chroma_collection) super().__init__( host=host, port=port, ssl=ssl, headers=headers, + collection_name=collection_name, persist_dir=persist_dir, collection_kwargs=collection_kwargs or {}, ) diff --git a/llama_index/vector_stores/loading.py b/llama_index/vector_stores/loading.py index 70e96443ad8dd17175bde7cca832461c784cfcc9..f3da5805e251ff49e83d4fe691c084657228832b 100644 --- a/llama_index/vector_stores/loading.py +++ b/llama_index/vector_stores/loading.py @@ -19,6 +19,8 @@ LOADABLE_VECTOR_STORES: Dict[str, Type[BasePydanticVectorStore]] = { def load_vector_store(data: dict) -> BasePydanticVectorStore: + if isinstance(data, BasePydanticVectorStore): + return data class_name = data.pop("class_name", None) if class_name is None: raise ValueError("class_name is required to load a vector store") @@ -49,4 +51,4 @@ def load_vector_store(data: dict) -> BasePydanticVectorStore: data["auth_config"] = auth_config - return LOADABLE_VECTOR_STORES[class_name].from_params(**data) # type: ignore + return LOADABLE_VECTOR_STORES[class_name](**data) # type: ignore diff --git a/llama_index/vector_stores/myscale.py b/llama_index/vector_stores/myscale.py index 0f03ef280eb762ffb7af93cbbaca4e78e3e5db29..29e49c5f445d26305f550ad83fbb85fc4ba37d20 100644 --- a/llama_index/vector_stores/myscale.py +++ b/llama_index/vector_stores/myscale.py @@ -7,7 +7,6 @@ import json import logging from typing import Any, Dict, List, Optional, cast -from llama_index.indices.service_context import ServiceContext from llama_index.readers.myscale import ( MyScaleSettings, escape_str, @@ -20,6 +19,7 @@ from llama_index.schema import ( RelatedNodeInfo, TextNode, ) +from llama_index.service_context import ServiceContext from llama_index.utils import iter_batch from llama_index.vector_stores.types import ( VectorStore, diff --git a/llama_index/vector_stores/qdrant.py b/llama_index/vector_stores/qdrant.py index d615f60d6a1dc00841574c3562da7e30ced82d52..93fde0de900ffc94659463c988f68186f678daaa 100644 --- a/llama_index/vector_stores/qdrant.py +++ b/llama_index/vector_stores/qdrant.py @@ -71,9 +71,13 @@ class QdrantVectorStore(BasePydanticVectorStore): raise ImportError(import_err_msg) if client is None: - raise ValueError("Missing Qdrant client!") + client_kwargs = client_kwargs or {} + self._client = ( + qdrant_client.QdrantClient(url=url, api_key=api_key, **client_kwargs), + ) + else: + self._client = cast(qdrant_client.QdrantClient, client) - self._client = cast(qdrant_client.QdrantClient, client) self._collection_initialized = self._collection_exists(collection_name) super().__init__( @@ -85,37 +89,6 @@ class QdrantVectorStore(BasePydanticVectorStore): client_kwargs=client_kwargs or {}, ) - @classmethod - def from_params( - cls, - collection_name: str, - url: Optional[str] = None, - api_key: Optional[str] = None, - client_kwargs: Optional[dict] = None, - batch_size: int = 100, - prefer_grpc: bool = False, - **kwargs: Any, - ) -> "QdrantVectorStore": - """Create a connection to a remote Qdrant vector store from a config.""" - try: - import qdrant_client - except ImportError: - raise ImportError(import_err_msg) - - client_kwargs = client_kwargs or {} - return cls( - collection_name=collection_name, - client=qdrant_client.QdrantClient( - url=url, api_key=api_key, prefer_grpc=prefer_grpc, **client_kwargs - ), - batch_size=batch_size, - prefer_grpc=prefer_grpc, - client_kwargs=client_kwargs, - url=url, - api_key=api_key, - **kwargs, - ) - @classmethod def class_name(cls) -> str: return "QdrantVectorStore" diff --git a/llama_index/vector_stores/typesense.py b/llama_index/vector_stores/typesense.py index c0a840aa0623ae5a46bef2e1bd860a51babd7d1a..e448aed4b96dfb737b00e9f45d5e7586b07e7754 100644 --- a/llama_index/vector_stores/typesense.py +++ b/llama_index/vector_stores/typesense.py @@ -7,8 +7,8 @@ An index that that is built on top of an existing vector store. import logging from typing import Any, Callable, List, Optional, cast -from llama_index import utils from llama_index.schema import BaseNode, MetadataMode, TextNode +from llama_index.utils import get_tokenizer from llama_index.vector_stores.types import ( MetadataFilters, VectorStore, @@ -75,7 +75,7 @@ class TypesenseVectorStore(VectorStore): f"got {type(client)}" ) self._client = cast(typesense.Client, client) - self._tokenizer = tokenizer or utils.globals_helper.tokenizer + self._tokenizer = tokenizer or get_tokenizer() self._text_key = text_key self._collection_name = collection_name self._collection = self._client.collections[self._collection_name] diff --git a/llama_index/vector_stores/weaviate.py b/llama_index/vector_stores/weaviate.py index a3ba2d3f7ef4a4e12b303dc32992b201fe0bdf15..2364ec05a40cd46a4029f6e5524c1e3fadf02c43 100644 --- a/llama_index/vector_stores/weaviate.py +++ b/llama_index/vector_stores/weaviate.py @@ -92,14 +92,21 @@ class WeaviateVectorStore(BasePydanticVectorStore): """Initialize params.""" try: import weaviate # noqa - from weaviate import Client + from weaviate import AuthApiKey, Client except ImportError: raise ImportError(import_err_msg) if weaviate_client is None: - raise ValueError("Missing Weaviate client!") + if isinstance(auth_config, dict): + auth_config = AuthApiKey(**auth_config) + + client_kwargs = client_kwargs or {} + self._client = Client( + url=url, auth_client_secret=auth_config, **client_kwargs + ) + else: + self._client = cast(Client, weaviate_client) - self._client = cast(Client, weaviate_client) # validate class prefix starts with a capital letter if class_prefix is not None: logger.warning("class_prefix is deprecated, please use index_name") @@ -120,7 +127,7 @@ class WeaviateVectorStore(BasePydanticVectorStore): url=url, index_name=index_name, text_key=text_key, - auth_config=auth_config or {}, + auth_config=auth_config.__dict__ if auth_config else {}, client_kwargs=client_kwargs or {}, ) diff --git a/poetry.lock b/poetry.lock index be875952137dabc157e54eb0881eed6da642e5dc..3546b375d5da56fa29b71d944b0ef7d987b46e9a 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.6.1 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.5.1 and should not be changed by hand. [[package]] name = "accelerate" @@ -33,7 +33,7 @@ testing = ["bitsandbytes", "datasets", "deepspeed", "evaluate", "parameterized", name = "aiohttp" version = "3.8.6" description = "Async http client/server framework (asyncio)" -optional = false +optional = true python-versions = ">=3.6" files = [ {file = "aiohttp-3.8.6-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:41d55fc043954cddbbd82503d9cc3f4814a40bcef30b3569bc7b5e34130718c1"}, @@ -141,7 +141,7 @@ speedups = ["Brotli", "aiodns", "cchardet"] name = "aiosignal" version = "1.3.1" description = "aiosignal: a list of registered asynchronous callbacks" -optional = false +optional = true python-versions = ">=3.7" files = [ {file = "aiosignal-1.3.1-py3-none-any.whl", hash = "sha256:f8376fb07dd1e86a584e4fcdec80b36b7f81aac666ebc724e2c090300dd83b17"}, @@ -339,7 +339,7 @@ typing-extensions = {version = ">=4.0.0", markers = "python_version < \"3.11\""} name = "async-timeout" version = "4.0.3" description = "Timeout context manager for asyncio programs" -optional = false +optional = true python-versions = ">=3.7" files = [ {file = "async-timeout-4.0.3.tar.gz", hash = "sha256:4640d96be84d82d02ed59ea2b7105a0f7b33abe8703703cd0ab0bf87c427522f"}, @@ -1338,7 +1338,7 @@ files = [ name = "frozenlist" version = "1.4.0" description = "A list-like structure which implements collections.abc.MutableSequence" -optional = false +optional = true python-versions = ">=3.8" files = [ {file = "frozenlist-1.4.0-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:764226ceef3125e53ea2cb275000e309c0aa5464d43bd72abd661e27fffc26ab"}, @@ -1494,11 +1494,11 @@ files = [ google-auth = ">=2.14.1,<3.0.dev0" googleapis-common-protos = ">=1.56.2,<2.0.dev0" grpcio = [ - {version = ">=1.33.2,<2.0dev", optional = true, markers = "python_version < \"3.11\" and extra == \"grpc\""}, + {version = ">=1.33.2,<2.0dev", optional = true, markers = "extra == \"grpc\""}, {version = ">=1.49.1,<2.0dev", optional = true, markers = "python_version >= \"3.11\" and extra == \"grpc\""}, ] grpcio-status = [ - {version = ">=1.33.2,<2.0.dev0", optional = true, markers = "python_version < \"3.11\" and extra == \"grpc\""}, + {version = ">=1.33.2,<2.0.dev0", optional = true, markers = "extra == \"grpc\""}, {version = ">=1.49.1,<2.0.dev0", optional = true, markers = "python_version >= \"3.11\" and extra == \"grpc\""}, ] protobuf = ">=3.19.5,<3.20.0 || >3.20.0,<3.20.1 || >3.20.1,<4.21.0 || >4.21.0,<4.21.1 || >4.21.1,<4.21.2 || >4.21.2,<4.21.3 || >4.21.3,<4.21.4 || >4.21.4,<4.21.5 || >4.21.5,<5.0.0.dev0" @@ -2152,7 +2152,7 @@ dev = ["hypothesis"] name = "jsonpatch" version = "1.33" description = "Apply JSON-Patches (RFC 6902)" -optional = false +optional = true python-versions = ">=2.7, !=3.0.*, !=3.1.*, !=3.2.*, !=3.3.*, !=3.4.*, !=3.5.*, !=3.6.*" files = [ {file = "jsonpatch-1.33-py2.py3-none-any.whl", hash = "sha256:0ae28c0cd062bbd8b8ecc26d7d164fbbea9652a1a3693f3b956c1eae5145dade"}, @@ -2527,7 +2527,7 @@ files = [ name = "langchain" version = "0.0.331" description = "Building applications with LLMs through composability" -optional = false +optional = true python-versions = ">=3.8.1,<4.0" files = [ {file = "langchain-0.0.331-py3-none-any.whl", hash = "sha256:64e6e1a57b8deafc1c4e914820b2b8e22a5eed60d49432cadc3b8cca9d613694"}, @@ -2581,7 +2581,7 @@ data = ["language-data (>=1.1,<2.0)"] name = "langsmith" version = "0.0.58" description = "Client library to connect to the LangSmith LLM Tracing and Evaluation Platform." -optional = false +optional = true python-versions = ">=3.8.1,<4.0" files = [ {file = "langsmith-0.0.58-py3-none-any.whl", hash = "sha256:75a82744da2d2fa647d8d8d66a2b3791edc914a640516fa2f46cd5502b697380"}, @@ -2840,16 +2840,6 @@ files = [ {file = "MarkupSafe-2.1.3-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:5bbe06f8eeafd38e5d0a4894ffec89378b6c6a625ff57e3028921f8ff59318ac"}, {file = "MarkupSafe-2.1.3-cp311-cp311-win32.whl", hash = "sha256:dd15ff04ffd7e05ffcb7fe79f1b98041b8ea30ae9234aed2a9168b5797c3effb"}, {file = "MarkupSafe-2.1.3-cp311-cp311-win_amd64.whl", hash = "sha256:134da1eca9ec0ae528110ccc9e48041e0828d79f24121a1a146161103c76e686"}, - {file = "MarkupSafe-2.1.3-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:f698de3fd0c4e6972b92290a45bd9b1536bffe8c6759c62471efaa8acb4c37bc"}, - {file = "MarkupSafe-2.1.3-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:aa57bd9cf8ae831a362185ee444e15a93ecb2e344c8e52e4d721ea3ab6ef1823"}, - {file = "MarkupSafe-2.1.3-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:ffcc3f7c66b5f5b7931a5aa68fc9cecc51e685ef90282f4a82f0f5e9b704ad11"}, - {file = "MarkupSafe-2.1.3-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:47d4f1c5f80fc62fdd7777d0d40a2e9dda0a05883ab11374334f6c4de38adffd"}, - {file = "MarkupSafe-2.1.3-cp312-cp312-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:1f67c7038d560d92149c060157d623c542173016c4babc0c1913cca0564b9939"}, - {file = "MarkupSafe-2.1.3-cp312-cp312-musllinux_1_1_aarch64.whl", hash = "sha256:9aad3c1755095ce347e26488214ef77e0485a3c34a50c5a5e2471dff60b9dd9c"}, - {file = "MarkupSafe-2.1.3-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:14ff806850827afd6b07a5f32bd917fb7f45b046ba40c57abdb636674a8b559c"}, - {file = "MarkupSafe-2.1.3-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8f9293864fe09b8149f0cc42ce56e3f0e54de883a9de90cd427f191c346eb2e1"}, - {file = "MarkupSafe-2.1.3-cp312-cp312-win32.whl", hash = "sha256:715d3562f79d540f251b99ebd6d8baa547118974341db04f5ad06d5ea3eb8007"}, - {file = "MarkupSafe-2.1.3-cp312-cp312-win_amd64.whl", hash = "sha256:1b8dd8c3fd14349433c79fa8abeb573a55fc0fdd769133baac1f5e07abf54aeb"}, {file = "MarkupSafe-2.1.3-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:8e254ae696c88d98da6555f5ace2279cf7cd5b3f52be2b5cf97feafe883b58d2"}, {file = "MarkupSafe-2.1.3-cp37-cp37m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:cb0932dc158471523c9637e807d9bfb93e06a95cbf010f1a38b98623b929ef2b"}, {file = "MarkupSafe-2.1.3-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9402b03f1a1b4dc4c19845e5c749e3ab82d5078d16a2a4c2cd2df62d57bb0707"}, @@ -3008,7 +2998,7 @@ broker = ["pymsalruntime (>=0.13.2,<0.14)"] name = "multidict" version = "6.0.4" description = "multidict implementation" -optional = false +optional = true python-versions = ">=3.7" files = [ {file = "multidict-6.0.4-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:0b1a97283e0c85772d613878028fec909f003993e1007eafa715b24b377cb9b8"}, @@ -3608,7 +3598,6 @@ description = "Optimum Library is an extension of the Hugging Face Transformers optional = true python-versions = ">=3.7.0" files = [ - {file = "optimum-1.14.0-py3-none-any.whl", hash = "sha256:6eea0a8f626912393b80a2cad2b03f9775d798eda33aa6928ae52f45b90cd7fb"}, {file = "optimum-1.14.0.tar.gz", hash = "sha256:3b15d33b84f1cce483138f2ab202e17f39aa1330dbed1bd63e619d2230931b17"}, ] @@ -3619,7 +3608,7 @@ datasets = [ {version = ">=1.2.1", optional = true, markers = "extra == \"onnxruntime\""}, ] evaluate = {version = "*", optional = true, markers = "extra == \"onnxruntime\""} -huggingface-hub = ">=0.8.0" +huggingface_hub = ">=0.8.0" numpy = "*" onnx = {version = "*", optional = true, markers = "extra == \"onnxruntime\""} onnxruntime = {version = ">=1.11.0", optional = true, markers = "extra == \"onnxruntime\""} @@ -3710,7 +3699,7 @@ files = [ [package.dependencies] numpy = [ {version = ">=1.20.3", markers = "python_version < \"3.10\""}, - {version = ">=1.21.0", markers = "python_version >= \"3.10\" and python_version < \"3.11\""}, + {version = ">=1.21.0", markers = "python_version >= \"3.10\""}, {version = ">=1.23.2", markers = "python_version >= \"3.11\""}, ] python-dateutil = ">=2.8.2" @@ -4127,7 +4116,7 @@ tests = ["pytest"] name = "pyarrow" version = "14.0.1" description = "Python library for Apache Arrow" -optional = false +optional = true python-versions = ">=3.8" files = [ {file = "pyarrow-14.0.1-cp310-cp310-macosx_10_14_x86_64.whl", hash = "sha256:96d64e5ba7dceb519a955e5eeb5c9adcfd63f73a56aea4722e2cc81364fc567a"}, @@ -4209,47 +4198,47 @@ files = [ [[package]] name = "pydantic" -version = "1.10.13" +version = "1.10.12" description = "Data validation and settings management using python type hints" optional = false python-versions = ">=3.7" files = [ - {file = "pydantic-1.10.13-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:efff03cc7a4f29d9009d1c96ceb1e7a70a65cfe86e89d34e4a5f2ab1e5693737"}, - {file = "pydantic-1.10.13-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:3ecea2b9d80e5333303eeb77e180b90e95eea8f765d08c3d278cd56b00345d01"}, - {file = "pydantic-1.10.13-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:1740068fd8e2ef6eb27a20e5651df000978edce6da6803c2bef0bc74540f9548"}, - {file = "pydantic-1.10.13-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:84bafe2e60b5e78bc64a2941b4c071a4b7404c5c907f5f5a99b0139781e69ed8"}, - {file = "pydantic-1.10.13-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:bc0898c12f8e9c97f6cd44c0ed70d55749eaf783716896960b4ecce2edfd2d69"}, - {file = "pydantic-1.10.13-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:654db58ae399fe6434e55325a2c3e959836bd17a6f6a0b6ca8107ea0571d2e17"}, - {file = "pydantic-1.10.13-cp310-cp310-win_amd64.whl", hash = "sha256:75ac15385a3534d887a99c713aa3da88a30fbd6204a5cd0dc4dab3d770b9bd2f"}, - {file = "pydantic-1.10.13-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:c553f6a156deb868ba38a23cf0df886c63492e9257f60a79c0fd8e7173537653"}, - {file = "pydantic-1.10.13-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:5e08865bc6464df8c7d61439ef4439829e3ab62ab1669cddea8dd00cd74b9ffe"}, - {file = "pydantic-1.10.13-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:e31647d85a2013d926ce60b84f9dd5300d44535a9941fe825dc349ae1f760df9"}, - {file = "pydantic-1.10.13-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:210ce042e8f6f7c01168b2d84d4c9eb2b009fe7bf572c2266e235edf14bacd80"}, - {file = "pydantic-1.10.13-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:8ae5dd6b721459bfa30805f4c25880e0dd78fc5b5879f9f7a692196ddcb5a580"}, - {file = "pydantic-1.10.13-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:f8e81fc5fb17dae698f52bdd1c4f18b6ca674d7068242b2aff075f588301bbb0"}, - {file = "pydantic-1.10.13-cp311-cp311-win_amd64.whl", hash = "sha256:61d9dce220447fb74f45e73d7ff3b530e25db30192ad8d425166d43c5deb6df0"}, - {file = "pydantic-1.10.13-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:4b03e42ec20286f052490423682016fd80fda830d8e4119f8ab13ec7464c0132"}, - {file = "pydantic-1.10.13-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:f59ef915cac80275245824e9d771ee939133be38215555e9dc90c6cb148aaeb5"}, - {file = "pydantic-1.10.13-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:5a1f9f747851338933942db7af7b6ee8268568ef2ed86c4185c6ef4402e80ba8"}, - {file = "pydantic-1.10.13-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:97cce3ae7341f7620a0ba5ef6cf043975cd9d2b81f3aa5f4ea37928269bc1b87"}, - {file = "pydantic-1.10.13-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:854223752ba81e3abf663d685f105c64150873cc6f5d0c01d3e3220bcff7d36f"}, - {file = "pydantic-1.10.13-cp37-cp37m-win_amd64.whl", hash = "sha256:b97c1fac8c49be29486df85968682b0afa77e1b809aff74b83081cc115e52f33"}, - {file = "pydantic-1.10.13-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:c958d053453a1c4b1c2062b05cd42d9d5c8eb67537b8d5a7e3c3032943ecd261"}, - {file = "pydantic-1.10.13-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:4c5370a7edaac06daee3af1c8b1192e305bc102abcbf2a92374b5bc793818599"}, - {file = "pydantic-1.10.13-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7d6f6e7305244bddb4414ba7094ce910560c907bdfa3501e9db1a7fd7eaea127"}, - {file = "pydantic-1.10.13-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:d3a3c792a58e1622667a2837512099eac62490cdfd63bd407993aaf200a4cf1f"}, - {file = "pydantic-1.10.13-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:c636925f38b8db208e09d344c7aa4f29a86bb9947495dd6b6d376ad10334fb78"}, - {file = "pydantic-1.10.13-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:678bcf5591b63cc917100dc50ab6caebe597ac67e8c9ccb75e698f66038ea953"}, - {file = "pydantic-1.10.13-cp38-cp38-win_amd64.whl", hash = "sha256:6cf25c1a65c27923a17b3da28a0bdb99f62ee04230c931d83e888012851f4e7f"}, - {file = "pydantic-1.10.13-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:8ef467901d7a41fa0ca6db9ae3ec0021e3f657ce2c208e98cd511f3161c762c6"}, - {file = "pydantic-1.10.13-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:968ac42970f57b8344ee08837b62f6ee6f53c33f603547a55571c954a4225691"}, - {file = "pydantic-1.10.13-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:9849f031cf8a2f0a928fe885e5a04b08006d6d41876b8bbd2fc68a18f9f2e3fd"}, - {file = "pydantic-1.10.13-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:56e3ff861c3b9c6857579de282ce8baabf443f42ffba355bf070770ed63e11e1"}, - {file = "pydantic-1.10.13-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:9f00790179497767aae6bcdc36355792c79e7bbb20b145ff449700eb076c5f96"}, - {file = "pydantic-1.10.13-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:75b297827b59bc229cac1a23a2f7a4ac0031068e5be0ce385be1462e7e17a35d"}, - {file = "pydantic-1.10.13-cp39-cp39-win_amd64.whl", hash = "sha256:e70ca129d2053fb8b728ee7d1af8e553a928d7e301a311094b8a0501adc8763d"}, - {file = "pydantic-1.10.13-py3-none-any.whl", hash = "sha256:b87326822e71bd5f313e7d3bfdc77ac3247035ac10b0c0618bd99dcf95b1e687"}, - {file = "pydantic-1.10.13.tar.gz", hash = "sha256:32c8b48dcd3b2ac4e78b0ba4af3a2c2eb6048cb75202f0ea7b34feb740efc340"}, + {file = "pydantic-1.10.12-cp310-cp310-macosx_10_9_x86_64.whl", hash = "sha256:a1fcb59f2f355ec350073af41d927bf83a63b50e640f4dbaa01053a28b7a7718"}, + {file = "pydantic-1.10.12-cp310-cp310-macosx_11_0_arm64.whl", hash = "sha256:b7ccf02d7eb340b216ec33e53a3a629856afe1c6e0ef91d84a4e6f2fb2ca70fe"}, + {file = "pydantic-1.10.12-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:8fb2aa3ab3728d950bcc885a2e9eff6c8fc40bc0b7bb434e555c215491bcf48b"}, + {file = "pydantic-1.10.12-cp310-cp310-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:771735dc43cf8383959dc9b90aa281f0b6092321ca98677c5fb6125a6f56d58d"}, + {file = "pydantic-1.10.12-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:ca48477862372ac3770969b9d75f1bf66131d386dba79506c46d75e6b48c1e09"}, + {file = "pydantic-1.10.12-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:a5e7add47a5b5a40c49b3036d464e3c7802f8ae0d1e66035ea16aa5b7a3923ed"}, + {file = "pydantic-1.10.12-cp310-cp310-win_amd64.whl", hash = "sha256:e4129b528c6baa99a429f97ce733fff478ec955513630e61b49804b6cf9b224a"}, + {file = "pydantic-1.10.12-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:b0d191db0f92dfcb1dec210ca244fdae5cbe918c6050b342d619c09d31eea0cc"}, + {file = "pydantic-1.10.12-cp311-cp311-macosx_11_0_arm64.whl", hash = "sha256:795e34e6cc065f8f498c89b894a3c6da294a936ee71e644e4bd44de048af1405"}, + {file = "pydantic-1.10.12-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:69328e15cfda2c392da4e713443c7dbffa1505bc9d566e71e55abe14c97ddc62"}, + {file = "pydantic-1.10.12-cp311-cp311-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2031de0967c279df0d8a1c72b4ffc411ecd06bac607a212892757db7462fc494"}, + {file = "pydantic-1.10.12-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:ba5b2e6fe6ca2b7e013398bc7d7b170e21cce322d266ffcd57cca313e54fb246"}, + {file = "pydantic-1.10.12-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:2a7bac939fa326db1ab741c9d7f44c565a1d1e80908b3797f7f81a4f86bc8d33"}, + {file = "pydantic-1.10.12-cp311-cp311-win_amd64.whl", hash = "sha256:87afda5539d5140cb8ba9e8b8c8865cb5b1463924d38490d73d3ccfd80896b3f"}, + {file = "pydantic-1.10.12-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:549a8e3d81df0a85226963611950b12d2d334f214436a19537b2efed61b7639a"}, + {file = "pydantic-1.10.12-cp37-cp37m-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:598da88dfa127b666852bef6d0d796573a8cf5009ffd62104094a4fe39599565"}, + {file = "pydantic-1.10.12-cp37-cp37m-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:ba5c4a8552bff16c61882db58544116d021d0b31ee7c66958d14cf386a5b5350"}, + {file = "pydantic-1.10.12-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:c79e6a11a07da7374f46970410b41d5e266f7f38f6a17a9c4823db80dadf4303"}, + {file = "pydantic-1.10.12-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = "sha256:ab26038b8375581dc832a63c948f261ae0aa21f1d34c1293469f135fa92972a5"}, + {file = "pydantic-1.10.12-cp37-cp37m-win_amd64.whl", hash = "sha256:e0a16d274b588767602b7646fa05af2782576a6cf1022f4ba74cbb4db66f6ca8"}, + {file = "pydantic-1.10.12-cp38-cp38-macosx_10_9_x86_64.whl", hash = "sha256:6a9dfa722316f4acf4460afdf5d41d5246a80e249c7ff475c43a3a1e9d75cf62"}, + {file = "pydantic-1.10.12-cp38-cp38-macosx_11_0_arm64.whl", hash = "sha256:a73f489aebd0c2121ed974054cb2759af8a9f747de120acd2c3394cf84176ccb"}, + {file = "pydantic-1.10.12-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6b30bcb8cbfccfcf02acb8f1a261143fab622831d9c0989707e0e659f77a18e0"}, + {file = "pydantic-1.10.12-cp38-cp38-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:2fcfb5296d7877af406ba1547dfde9943b1256d8928732267e2653c26938cd9c"}, + {file = "pydantic-1.10.12-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:2f9a6fab5f82ada41d56b0602606a5506aab165ca54e52bc4545028382ef1c5d"}, + {file = "pydantic-1.10.12-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:dea7adcc33d5d105896401a1f37d56b47d443a2b2605ff8a969a0ed5543f7e33"}, + {file = "pydantic-1.10.12-cp38-cp38-win_amd64.whl", hash = "sha256:1eb2085c13bce1612da8537b2d90f549c8cbb05c67e8f22854e201bde5d98a47"}, + {file = "pydantic-1.10.12-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:ef6c96b2baa2100ec91a4b428f80d8f28a3c9e53568219b6c298c1125572ebc6"}, + {file = "pydantic-1.10.12-cp39-cp39-macosx_11_0_arm64.whl", hash = "sha256:6c076be61cd0177a8433c0adcb03475baf4ee91edf5a4e550161ad57fc90f523"}, + {file = "pydantic-1.10.12-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:2d5a58feb9a39f481eda4d5ca220aa8b9d4f21a41274760b9bc66bfd72595b86"}, + {file = "pydantic-1.10.12-cp39-cp39-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:e5f805d2d5d0a41633651a73fa4ecdd0b3d7a49de4ec3fadf062fe16501ddbf1"}, + {file = "pydantic-1.10.12-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:1289c180abd4bd4555bb927c42ee42abc3aee02b0fb2d1223fb7c6e5bef87dbe"}, + {file = "pydantic-1.10.12-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:5d1197e462e0364906cbc19681605cb7c036f2475c899b6f296104ad42b9f5fb"}, + {file = "pydantic-1.10.12-cp39-cp39-win_amd64.whl", hash = "sha256:fdbdd1d630195689f325c9ef1a12900524dceb503b00a987663ff4f58669b93d"}, + {file = "pydantic-1.10.12-py3-none-any.whl", hash = "sha256:b749a43aa51e32839c9d71dc67eb1e4221bb04af1033a32e3923d46f9effa942"}, + {file = "pydantic-1.10.12.tar.gz", hash = "sha256:0fe8a415cea8f340e7a9af9c54fc71a649b43e8ca3cc732986116b3cb135d303"}, ] [package.dependencies] @@ -4633,7 +4622,6 @@ files = [ {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:69b023b2b4daa7548bcfbd4aa3da05b3a74b772db9e23b982788168117739938"}, {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:81e0b275a9ecc9c0c0c07b4b90ba548307583c125f54d5b6946cfee6360c733d"}, {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ba336e390cd8e4d1739f42dfe9bb83a3cc2e80f567d8805e11b46f4a943f5515"}, - {file = "PyYAML-6.0.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:326c013efe8048858a6d312ddd31d56e468118ad4cdeda36c719bf5bb6192290"}, {file = "PyYAML-6.0.1-cp310-cp310-win32.whl", hash = "sha256:bd4af7373a854424dabd882decdc5579653d7868b8fb26dc7d0e99f823aa5924"}, {file = "PyYAML-6.0.1-cp310-cp310-win_amd64.whl", hash = "sha256:fd1592b3fdf65fff2ad0004b5e363300ef59ced41c2e6b3a99d4089fa8c5435d"}, {file = "PyYAML-6.0.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:6965a7bc3cf88e5a1c3bd2e0b5c22f8d677dc88a455344035f03399034eb3007"}, @@ -4641,15 +4629,8 @@ files = [ {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:42f8152b8dbc4fe7d96729ec2b99c7097d656dc1213a3229ca5383f973a5ed6d"}, {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:062582fca9fabdd2c8b54a3ef1c978d786e0f6b3a1510e0ac93ef59e0ddae2bc"}, {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d2b04aac4d386b172d5b9692e2d2da8de7bfb6c387fa4f801fbf6fb2e6ba4673"}, - {file = "PyYAML-6.0.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:e7d73685e87afe9f3b36c799222440d6cf362062f78be1013661b00c5c6f678b"}, {file = "PyYAML-6.0.1-cp311-cp311-win32.whl", hash = "sha256:1635fd110e8d85d55237ab316b5b011de701ea0f29d07611174a1b42f1444741"}, {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"}, - {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"}, - {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"}, - {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"}, - {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"}, - {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"}, - {file = "PyYAML-6.0.1-cp312-cp312-win_amd64.whl", hash = "sha256:0d3304d8c0adc42be59c5f8a4d9e3d7379e6955ad754aa9d6ab7a398b59dd1df"}, {file = "PyYAML-6.0.1-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:50550eb667afee136e9a77d6dc71ae76a44df8b3e51e41b77f6de2932bfe0f47"}, {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1fe35611261b29bd1de0070f0b2f47cb6ff71fa6595c077e42bd0c419fa27b98"}, {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:704219a11b772aea0d8ecd7058d0082713c3562b4e271b849ad7dc4a5c90c13c"}, @@ -4666,7 +4647,6 @@ files = [ {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a0cd17c15d3bb3fa06978b4e8958dcdc6e0174ccea823003a106c7d4d7899ac5"}, {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:28c119d996beec18c05208a8bd78cbe4007878c6dd15091efb73a30e90539696"}, {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7e07cbde391ba96ab58e532ff4803f79c4129397514e1413a7dc761ccd755735"}, - {file = "PyYAML-6.0.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:49a183be227561de579b4a36efbb21b3eab9651dd81b1858589f796549873dd6"}, {file = "PyYAML-6.0.1-cp38-cp38-win32.whl", hash = "sha256:184c5108a2aca3c5b3d3bf9395d50893a7ab82a38004c8f61c258d4428e80206"}, {file = "PyYAML-6.0.1-cp38-cp38-win_amd64.whl", hash = "sha256:1e2722cc9fbb45d9b87631ac70924c11d3a401b2d7f410cc0e3bbf249f2dca62"}, {file = "PyYAML-6.0.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:9eb6caa9a297fc2c2fb8862bc5370d0303ddba53ba97e71f08023b6cd73d16a8"}, @@ -4674,7 +4654,6 @@ files = [ {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5773183b6446b2c99bb77e77595dd486303b4faab2b086e7b17bc6bef28865f6"}, {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b786eecbdf8499b9ca1d697215862083bd6d2a99965554781d0d8d1ad31e13a0"}, {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bc1bf2925a1ecd43da378f4db9e4f799775d6367bdb94671027b73b393a7c42c"}, - {file = "PyYAML-6.0.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:04ac92ad1925b2cff1db0cfebffb6ffc43457495c9b3c39d3fcae417d7125dc5"}, {file = "PyYAML-6.0.1-cp39-cp39-win32.whl", hash = "sha256:faca3bdcf85b2fc05d06ff3fbc1f83e1391b3e724afa3feba7d13eeab355484c"}, {file = "PyYAML-6.0.1-cp39-cp39-win_amd64.whl", hash = "sha256:510c9deebc5c0225e8c96813043e62b680ba2f9c50a08d3724c7f28a747d1486"}, {file = "PyYAML-6.0.1.tar.gz", hash = "sha256:bfdf460b1736c775f2ba9f6a92bca30bc2095067b8a9d77876d1fad6cc3b4a43"}, @@ -5864,7 +5843,7 @@ files = [ ] [package.dependencies] -greenlet = {version = "!=0.4.17", optional = true, markers = "platform_machine == \"aarch64\" or platform_machine == \"ppc64le\" or platform_machine == \"x86_64\" or platform_machine == \"amd64\" or platform_machine == \"AMD64\" or platform_machine == \"win32\" or platform_machine == \"WIN32\" or extra == \"asyncio\""} +greenlet = {version = "!=0.4.17", optional = true, markers = "platform_machine == \"win32\" or platform_machine == \"WIN32\" or platform_machine == \"AMD64\" or platform_machine == \"amd64\" or platform_machine == \"x86_64\" or platform_machine == \"ppc64le\" or platform_machine == \"aarch64\" or extra == \"asyncio\""} typing-extensions = ">=4.2.0" [package.extras] @@ -7170,7 +7149,7 @@ files = [ name = "yarl" version = "1.9.2" description = "Yet another URL library" -optional = false +optional = true python-versions = ">=3.7" files = [ {file = "yarl-1.9.2-cp310-cp310-macosx_10_9_universal2.whl", hash = "sha256:8c2ad583743d16ddbdf6bb14b5cd76bf43b0d0006e918809d5d4ddf7bde8dd82"}, @@ -7269,6 +7248,7 @@ docs = ["furo", "jaraco.packaging (>=9.3)", "jaraco.tidelift (>=1.4)", "rst.link testing = ["big-O", "jaraco.functools", "jaraco.itertools", "more-itertools", "pytest (>=6)", "pytest-black (>=0.3.7)", "pytest-checkdocs (>=2.4)", "pytest-cov", "pytest-enabler (>=2.2)", "pytest-ignore-flaky", "pytest-mypy (>=0.9.1)", "pytest-ruff"] [extras] +langchain = ["langchain"] local-models = ["optimum", "sentencepiece", "transformers"] postgres = ["asyncpg", "pgvector", "psycopg-binary"] query-tools = ["guidance", "jsonpath-ng", "lm-format-enforcer", "rank-bm25", "scikit-learn", "spacy"] @@ -7276,4 +7256,4 @@ query-tools = ["guidance", "jsonpath-ng", "lm-format-enforcer", "rank-bm25", "sc [metadata] lock-version = "2.0" python-versions = ">=3.8.1,<3.12" -content-hash = "59e20be47c4fa4d4f3a2dce6b2d1ec41e543a92918ca6aba7e68e56e9a4475fa" +content-hash = "f21da92d62393a9ec63be07b4e6563a8a4216567b6435686c468b2b236b4a411" diff --git a/pyproject.toml b/pyproject.toml index 4a5eb34473ddb8027cc9434a810945e53f647bcb..b1cb0071c31b3fe0bf86dce39be015dd1e2c15ba 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -38,14 +38,16 @@ name = "llama-index" packages = [{include = "llama_index"}] readme = "README.md" repository = "https://github.com/run-llama/llama_index" -version = "0.8.69.post2" +version = "0.9.0" [tool.poetry.dependencies] SQLAlchemy = {extras = ["asyncio"], version = ">=1.4.49"} +beautifulsoup4 = "^4.12.2" dataclasses-json = "^0.5.7" deprecated = ">=1.2.9.3" fsspec = ">=2023.5.0" -langchain = ">=0.0.303" +httpx = "*" +langchain = {optional = true, version = ">=0.0.303"} nest-asyncio = "^1.5.8" nltk = "^3.8.1" numpy = "*" @@ -72,6 +74,9 @@ spacy = {optional = true, version = "^3.7.1"} aiostream = "^0.5.2" [tool.poetry.extras] +langchain = [ + "langchain", +] local_models = [ "optimum", "sentencepiece", @@ -92,7 +97,6 @@ query_tools = [ ] [tool.poetry.group.dev.dependencies] -beautifulsoup4 = "^4.12.2" # needed for tests black = {extras = ["jupyter"], version = "<=23.9.1,>=23.7.0"} codespell = {extras = ["toml"], version = ">=v2.2.6"} google-generativeai = {python = ">=3.9,<3.12", version = "^0.2.1"} diff --git a/tests/chat_engine/test_condense_question.py b/tests/chat_engine/test_condense_question.py index b67f1ba2b8d5118eaab952d3344aaaa114409010..349fe686b328703c39651c6bf9adc0de7e6b43ca 100644 --- a/tests/chat_engine/test_condense_question.py +++ b/tests/chat_engine/test_condense_question.py @@ -1,10 +1,10 @@ from unittest.mock import Mock from llama_index.chat_engine.condense_question import CondenseQuestionChatEngine -from llama_index.indices.query.base import BaseQueryEngine -from llama_index.indices.service_context import ServiceContext +from llama_index.core import BaseQueryEngine from llama_index.llms.base import ChatMessage, MessageRole from llama_index.response.schema import Response +from llama_index.service_context import ServiceContext def test_condense_question_chat_engine( diff --git a/tests/chat_engine/test_simple.py b/tests/chat_engine/test_simple.py index c9d0b4241939f609d6eb976e35b20bf8d4cdffee..e84fcdbddebc57692ce664a39667e98b279bda5c 100644 --- a/tests/chat_engine/test_simple.py +++ b/tests/chat_engine/test_simple.py @@ -1,6 +1,6 @@ from llama_index.chat_engine.simple import SimpleChatEngine -from llama_index.indices.service_context import ServiceContext from llama_index.llms.base import ChatMessage, MessageRole +from llama_index.service_context import ServiceContext def test_simple_chat_engine( diff --git a/tests/conftest.py b/tests/conftest.py index 9bdf83dabc734c5a47560138ddf51e78ce02a20d..5791c4d5d1ec2038cfe686af1ae4050bcbc3ea61 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -1,15 +1,15 @@ import os # import socket -from typing import Any, Optional +from typing import Any, List, Optional import openai import pytest -from llama_index.indices.service_context import ServiceContext from llama_index.llm_predictor.base import LLMPredictor from llama_index.llms.base import LLMMetadata from llama_index.llms.mock import MockLLM -from llama_index.text_splitter import SentenceSplitter, TokenTextSplitter +from llama_index.node_parser.text import SentenceSplitter, TokenTextSplitter +from llama_index.service_context import ServiceContext from tests.indices.vector_store.mock_services import MockEmbedding from tests.mock_utils.mock_predict import ( @@ -35,7 +35,9 @@ def allow_networking(monkeypatch: pytest.MonkeyPatch) -> None: def patch_token_text_splitter(monkeypatch: pytest.MonkeyPatch) -> None: monkeypatch.setattr(SentenceSplitter, "split_text", patch_token_splitter_newline) monkeypatch.setattr( - SentenceSplitter, "split_text_metadata_aware", patch_token_splitter_newline + SentenceSplitter, + "split_text_metadata_aware", + patch_token_splitter_newline, ) monkeypatch.setattr(TokenTextSplitter, "split_text", patch_token_splitter_newline) monkeypatch.setattr( @@ -123,3 +125,30 @@ class CachedOpenAIApiKeys: def __exit__(self, *exc: object) -> None: os.environ["OPENAI_API_KEY"] = str(self.api_env_variable_was) os.environ["OPENAI_API_TYPE"] = str(self.api_env_type_was) + openai.api_key = self.openai_api_key_was + openai.api_type = self.openai_api_type_was + + +def pytest_addoption(parser: pytest.Parser) -> None: + parser.addoption( + "--integration", + action="store_true", + default=False, + help="run integration tests", + ) + + +def pytest_configure(config: pytest.Config) -> None: + config.addinivalue_line("markers", "integration: mark test as integration") + + +def pytest_collection_modifyitems( + config: pytest.Config, items: List[pytest.Item] +) -> None: + if config.getoption("--integration"): + # --integration given in cli: do not skip integration tests + return + skip_integration = pytest.mark.skip(reason="need --integration option to run") + for item in items: + if "integration" in item.keywords: + item.add_marker(skip_integration) diff --git a/tests/evaluation/test_dataset_generation.py b/tests/evaluation/test_dataset_generation.py index d22b39a63152982267e331b9ef0244f5e31a314f..1bdd3c6d85abf7356e7857e935997c874842dc8b 100644 --- a/tests/evaluation/test_dataset_generation.py +++ b/tests/evaluation/test_dataset_generation.py @@ -1,10 +1,10 @@ """Test dataset generation.""" from llama_index.evaluation.dataset_generation import DatasetGenerator -from llama_index.indices.service_context import ServiceContext from llama_index.prompts.base import PromptTemplate from llama_index.prompts.prompt_type import PromptType from llama_index.schema import TextNode +from llama_index.service_context import ServiceContext def test_dataset_generation( diff --git a/tests/indices/document_summary/conftest.py b/tests/indices/document_summary/conftest.py index 76e0dda378dbf0e05f3c83efe0c17f59b1a69eac..3a05f996ccd26cb126f69ee0deea6317e7a48cfb 100644 --- a/tests/indices/document_summary/conftest.py +++ b/tests/indices/document_summary/conftest.py @@ -2,9 +2,9 @@ from typing import List import pytest from llama_index.indices.document_summary.base import DocumentSummaryIndex -from llama_index.indices.service_context import ServiceContext from llama_index.response_synthesizers import get_response_synthesizer from llama_index.schema import Document +from llama_index.service_context import ServiceContext from tests.mock_utils.mock_prompts import MOCK_REFINE_PROMPT, MOCK_TEXT_QA_PROMPT diff --git a/tests/indices/empty/test_base.py b/tests/indices/empty/test_base.py index b851964f5e240f402f5fb5343519efba410b5d78..1f0b1ba1f8a37ace7e06040bb428ace5ede2528e 100644 --- a/tests/indices/empty/test_base.py +++ b/tests/indices/empty/test_base.py @@ -2,7 +2,7 @@ from llama_index.data_structs.data_structs import EmptyIndexStruct from llama_index.indices.empty.base import EmptyIndex -from llama_index.indices.service_context import ServiceContext +from llama_index.service_context import ServiceContext def test_empty( diff --git a/tests/indices/keyword_table/test_base.py b/tests/indices/keyword_table/test_base.py index f8033186703ff3ae3410844aa48e461ea2a39430..bcd8f6a82c3e30bdd92aec7ad36541362ad2ab38 100644 --- a/tests/indices/keyword_table/test_base.py +++ b/tests/indices/keyword_table/test_base.py @@ -5,8 +5,8 @@ from unittest.mock import patch import pytest from llama_index.indices.keyword_table.simple_base import SimpleKeywordTableIndex -from llama_index.indices.service_context import ServiceContext from llama_index.schema import Document +from llama_index.service_context import ServiceContext from tests.mock_utils.mock_utils import mock_extract_keywords diff --git a/tests/indices/keyword_table/test_retrievers.py b/tests/indices/keyword_table/test_retrievers.py index 4cc81ecd327b929a250a2d0cdd797a0e458b398b..1b05cf79904186bac10e0da568ef4a2b037e94f6 100644 --- a/tests/indices/keyword_table/test_retrievers.py +++ b/tests/indices/keyword_table/test_retrievers.py @@ -2,9 +2,8 @@ from typing import List from unittest.mock import patch from llama_index.indices.keyword_table.simple_base import SimpleKeywordTableIndex -from llama_index.indices.query.schema import QueryBundle -from llama_index.indices.service_context import ServiceContext -from llama_index.schema import Document +from llama_index.schema import Document, QueryBundle +from llama_index.service_context import ServiceContext from tests.mock_utils.mock_utils import mock_extract_keywords diff --git a/tests/indices/knowledge_graph/test_base.py b/tests/indices/knowledge_graph/test_base.py index fb0b9a008f947aeb863d69daaec6ea6f2b4c972a..ec60c5388cf8654aab7b606cb50936d46e9cfa98 100644 --- a/tests/indices/knowledge_graph/test_base.py +++ b/tests/indices/knowledge_graph/test_base.py @@ -6,8 +6,8 @@ from unittest.mock import patch import pytest from llama_index.embeddings.base import BaseEmbedding from llama_index.indices.knowledge_graph.base import KnowledgeGraphIndex -from llama_index.indices.service_context import ServiceContext from llama_index.schema import Document, TextNode +from llama_index.service_context import ServiceContext from tests.mock_utils.mock_prompts import ( MOCK_KG_TRIPLET_EXTRACT_PROMPT, diff --git a/tests/indices/knowledge_graph/test_retrievers.py b/tests/indices/knowledge_graph/test_retrievers.py index c0e782c8c63f370d0d6e1907f371704dca7a2163..9260ec64dd4f75bd8052864bc677b72d8c8efe4f 100644 --- a/tests/indices/knowledge_graph/test_retrievers.py +++ b/tests/indices/knowledge_graph/test_retrievers.py @@ -4,9 +4,8 @@ from unittest.mock import patch from llama_index.graph_stores import SimpleGraphStore from llama_index.indices.knowledge_graph.base import KnowledgeGraphIndex from llama_index.indices.knowledge_graph.retrievers import KGTableRetriever -from llama_index.indices.query.schema import QueryBundle -from llama_index.indices.service_context import ServiceContext -from llama_index.schema import Document +from llama_index.schema import Document, QueryBundle +from llama_index.service_context import ServiceContext from llama_index.storage.storage_context import StorageContext from tests.indices.knowledge_graph.test_base import MockEmbedding, mock_extract_triplets diff --git a/tests/indices/list/test_index.py b/tests/indices/list/test_index.py index b96023f986eb043cae35d49e7051797be68bee96..3ff7499906c7773332d76b5375c0dd02f38c02e7 100644 --- a/tests/indices/list/test_index.py +++ b/tests/indices/list/test_index.py @@ -2,10 +2,10 @@ from typing import Dict, List, Tuple -from llama_index.indices.base_retriever import BaseRetriever +from llama_index.core import BaseRetriever from llama_index.indices.list.base import ListRetrieverMode, SummaryIndex -from llama_index.indices.service_context import ServiceContext from llama_index.schema import BaseNode, Document +from llama_index.service_context import ServiceContext def test_build_list( diff --git a/tests/indices/list/test_retrievers.py b/tests/indices/list/test_retrievers.py index aa8b595c53bf53c04f6f14e3f1b3a10ab89daad6..5fcbb38200a2e631f4d298e4d54f7fcd24ca4835 100644 --- a/tests/indices/list/test_retrievers.py +++ b/tests/indices/list/test_retrievers.py @@ -3,10 +3,10 @@ from unittest.mock import patch from llama_index.indices.list.base import SummaryIndex from llama_index.indices.list.retrievers import SummaryIndexEmbeddingRetriever -from llama_index.indices.service_context import ServiceContext from llama_index.llm_predictor.base import LLMPredictor from llama_index.prompts import BasePromptTemplate from llama_index.schema import Document +from llama_index.service_context import ServiceContext from tests.indices.list.test_index import _get_embeddings diff --git a/tests/indices/query/query_transform/test_base.py b/tests/indices/query/query_transform/test_base.py index 64d32bdbc6941139537fec28019eabc795912a00..438acd468288b1b96ce256b3620cf85691996903 100644 --- a/tests/indices/query/query_transform/test_base.py +++ b/tests/indices/query/query_transform/test_base.py @@ -2,7 +2,7 @@ from llama_index.indices.query.query_transform.base import DecomposeQueryTransform -from llama_index.indices.service_context import ServiceContext +from llama_index.service_context import ServiceContext from tests.indices.query.query_transform.mock_utils import MOCK_DECOMPOSE_PROMPT diff --git a/tests/indices/query/test_compose.py b/tests/indices/query/test_compose.py index 4bda49729475c23d7fc14207a694fc7a7ef0d2c3..207d09bfb6e37223310c94115d0652e4bfd8ec2b 100644 --- a/tests/indices/query/test_compose.py +++ b/tests/indices/query/test_compose.py @@ -5,9 +5,9 @@ from typing import Dict, List from llama_index.indices.composability.graph import ComposableGraph from llama_index.indices.keyword_table.simple_base import SimpleKeywordTableIndex from llama_index.indices.list.base import SummaryIndex -from llama_index.indices.service_context import ServiceContext from llama_index.indices.tree.base import TreeIndex from llama_index.schema import Document +from llama_index.service_context import ServiceContext def test_recursive_query_list_tree( diff --git a/tests/indices/query/test_compose_vector.py b/tests/indices/query/test_compose_vector.py index 09988b4a990524f63d971b9a4f9df933982054c5..422327598b8c8c9ad8bd553f065eaf959434e5af 100644 --- a/tests/indices/query/test_compose_vector.py +++ b/tests/indices/query/test_compose_vector.py @@ -8,9 +8,9 @@ from llama_index.data_structs.data_structs import IndexStruct from llama_index.embeddings.base import BaseEmbedding from llama_index.indices.composability.graph import ComposableGraph from llama_index.indices.keyword_table.simple_base import SimpleKeywordTableIndex -from llama_index.indices.service_context import ServiceContext from llama_index.indices.vector_store.base import VectorStoreIndex from llama_index.schema import Document +from llama_index.service_context import ServiceContext from tests.indices.vector_store.utils import get_pinecone_storage_context from tests.mock_utils.mock_prompts import MOCK_QUERY_KEYWORD_EXTRACT_PROMPT diff --git a/tests/indices/query/test_query_bundle.py b/tests/indices/query/test_query_bundle.py index 04de984e25a98445e647e9db1c8083329e118bc3..2ee0c22876eeb9c858d88b72e453b713562e0886 100644 --- a/tests/indices/query/test_query_bundle.py +++ b/tests/indices/query/test_query_bundle.py @@ -5,9 +5,8 @@ from typing import Dict, List import pytest from llama_index.embeddings.base import BaseEmbedding from llama_index.indices.list.base import SummaryIndex -from llama_index.indices.query.schema import QueryBundle -from llama_index.indices.service_context import ServiceContext -from llama_index.schema import Document +from llama_index.schema import Document, QueryBundle +from llama_index.service_context import ServiceContext @pytest.fixture() diff --git a/tests/indices/response/test_response_builder.py b/tests/indices/response/test_response_builder.py index 316a4ceb659489f82d19dccd5eec88659b8b535e..a90e7aa4802fa8851844f70664fc47c53486666d 100644 --- a/tests/indices/response/test_response_builder.py +++ b/tests/indices/response/test_response_builder.py @@ -5,11 +5,11 @@ from typing import List from llama_index.constants import DEFAULT_CONTEXT_WINDOW, DEFAULT_NUM_OUTPUTS from llama_index.indices.prompt_helper import PromptHelper -from llama_index.indices.service_context import ServiceContext from llama_index.prompts.base import PromptTemplate from llama_index.prompts.prompt_type import PromptType from llama_index.response_synthesizers import ResponseMode, get_response_synthesizer from llama_index.schema import Document +from llama_index.service_context import ServiceContext from tests.indices.vector_store.mock_services import MockEmbedding from tests.mock_utils.mock_prompts import MOCK_REFINE_PROMPT, MOCK_TEXT_QA_PROMPT from tests.mock_utils.mock_utils import mock_tokenizer diff --git a/tests/indices/response/test_tree_summarize.py b/tests/indices/response/test_tree_summarize.py index a50b3f108bc0441aa48b8a8275189abea407dda4..7f28cab21828da4a5868cd4dc1b699e86b6bc8ee 100644 --- a/tests/indices/response/test_tree_summarize.py +++ b/tests/indices/response/test_tree_summarize.py @@ -5,10 +5,10 @@ from unittest.mock import Mock import pytest from llama_index.indices.prompt_helper import PromptHelper -from llama_index.indices.service_context import ServiceContext from llama_index.prompts.base import PromptTemplate from llama_index.prompts.prompt_type import PromptType from llama_index.response_synthesizers import TreeSummarize +from llama_index.service_context import ServiceContext from pydantic import BaseModel diff --git a/tests/indices/struct_store/test_base.py b/tests/indices/struct_store/test_base.py index 828562734dde661dbaab6422594e53c284469b6f..1e2d9febb25f103b6eef7b167d0409a6a5fb80bb 100644 --- a/tests/indices/struct_store/test_base.py +++ b/tests/indices/struct_store/test_base.py @@ -3,8 +3,6 @@ from typing import Any, Dict, List, Tuple from llama_index.indices.list.base import SummaryIndex -from llama_index.indices.query.schema import QueryBundle -from llama_index.indices.service_context import ServiceContext from llama_index.indices.struct_store.sql import ( SQLContextContainerBuilder, SQLStructStoreIndex, @@ -14,9 +12,11 @@ from llama_index.schema import ( BaseNode, Document, NodeRelationship, + QueryBundle, RelatedNodeInfo, TextNode, ) +from llama_index.service_context import ServiceContext from llama_index.utilities.sql_wrapper import SQLDatabase from sqlalchemy import ( Column, diff --git a/tests/indices/struct_store/test_json_query.py b/tests/indices/struct_store/test_json_query.py index ad4dd7dd47b9893e78ccc6e17e7318c595b7a45a..3b1bc4757cae778dee7939404707be712fab9b2d 100644 --- a/tests/indices/struct_store/test_json_query.py +++ b/tests/indices/struct_store/test_json_query.py @@ -6,10 +6,10 @@ from typing import Any, Dict, Generator, cast from unittest.mock import AsyncMock, MagicMock, patch import pytest -from llama_index.indices.query.schema import QueryBundle -from llama_index.indices.service_context import ServiceContext from llama_index.indices.struct_store.json_query import JSONQueryEngine, JSONType from llama_index.response.schema import Response +from llama_index.schema import QueryBundle +from llama_index.service_context import ServiceContext TEST_PARAMS = [ # synthesize_response, call_apredict diff --git a/tests/indices/struct_store/test_sql_query.py b/tests/indices/struct_store/test_sql_query.py index 2f6c4d8566b38fb4c69230856fce94217db008c0..77ec585df15134fe48282afd7118d2efd3c16f04 100644 --- a/tests/indices/struct_store/test_sql_query.py +++ b/tests/indices/struct_store/test_sql_query.py @@ -2,7 +2,6 @@ import asyncio from typing import Any, Dict, Tuple import pytest -from llama_index.indices.service_context import ServiceContext from llama_index.indices.struct_store.base import default_output_parser from llama_index.indices.struct_store.sql import SQLStructStoreIndex from llama_index.indices.struct_store.sql_query import ( @@ -11,6 +10,7 @@ from llama_index.indices.struct_store.sql_query import ( SQLStructStoreQueryEngine, ) from llama_index.schema import Document +from llama_index.service_context import ServiceContext from llama_index.utilities.sql_wrapper import SQLDatabase from sqlalchemy import Column, Integer, MetaData, String, Table, create_engine from sqlalchemy.exc import OperationalError diff --git a/tests/indices/test_loading.py b/tests/indices/test_loading.py index 038813581e1e32f8aba88937c92c8642180ad526..ec239c6b12dbd6716e08db242e630de6a861d4b5 100644 --- a/tests/indices/test_loading.py +++ b/tests/indices/test_loading.py @@ -7,10 +7,10 @@ from llama_index.indices.loading import ( load_index_from_storage, load_indices_from_storage, ) -from llama_index.indices.service_context import ServiceContext from llama_index.indices.vector_store.base import VectorStoreIndex from llama_index.query_engine.retriever_query_engine import RetrieverQueryEngine from llama_index.schema import BaseNode, Document +from llama_index.service_context import ServiceContext from llama_index.storage.docstore.simple_docstore import SimpleDocumentStore from llama_index.storage.index_store.simple_index_store import SimpleIndexStore from llama_index.storage.storage_context import StorageContext diff --git a/tests/indices/test_loading_graph.py b/tests/indices/test_loading_graph.py index cc661c78e63aae24848874fc55e2ee5060c5df8b..0fd59d3cb5c427ce9492c1c6e0e08b332fab2ae5 100644 --- a/tests/indices/test_loading_graph.py +++ b/tests/indices/test_loading_graph.py @@ -4,9 +4,9 @@ from typing import List from llama_index.indices.composability.graph import ComposableGraph from llama_index.indices.list.base import SummaryIndex from llama_index.indices.loading import load_graph_from_storage -from llama_index.indices.service_context import ServiceContext from llama_index.indices.vector_store.base import VectorStoreIndex from llama_index.schema import Document +from llama_index.service_context import ServiceContext from llama_index.storage.storage_context import StorageContext diff --git a/tests/indices/test_node_utils.py b/tests/indices/test_node_utils.py deleted file mode 100644 index 0e4d52e3a06a38d757da09d0243a7836bab445d7..0000000000000000000000000000000000000000 --- a/tests/indices/test_node_utils.py +++ /dev/null @@ -1,132 +0,0 @@ -"""Test node utils.""" - -from typing import List - -import pytest -import tiktoken -from llama_index.bridge.langchain import RecursiveCharacterTextSplitter -from llama_index.node_parser.node_utils import get_nodes_from_document -from llama_index.schema import ( - Document, - MetadataMode, - NodeRelationship, - ObjectType, - RelatedNodeInfo, -) -from llama_index.text_splitter import TokenTextSplitter - - -@pytest.fixture() -def text_splitter() -> TokenTextSplitter: - """Get text splitter.""" - return TokenTextSplitter(chunk_size=20, chunk_overlap=0) - - -@pytest.fixture() -def documents() -> List[Document]: - """Get documents.""" - # NOTE: one document for now - doc_text = ( - "Hello world.\n" - "This is a test.\n" - "This is another test.\n" - "This is a test v2." - ) - return [ - Document(text=doc_text, id_="test_doc_id", metadata={"test_key": "test_val"}) - ] - - -def test_get_nodes_from_document( - documents: List[Document], text_splitter: TokenTextSplitter -) -> None: - """Test get nodes from document have desired chunk size.""" - nodes = get_nodes_from_document( - documents[0], - text_splitter, - include_metadata=False, - ) - assert len(nodes) == 2 - actual_chunk_sizes = [ - len(text_splitter.tokenizer(node.get_content())) for node in nodes - ] - assert all( - chunk_size <= text_splitter.chunk_size for chunk_size in actual_chunk_sizes - ) - - -def test_get_nodes_from_document_with_metadata( - documents: List[Document], text_splitter: TokenTextSplitter -) -> None: - """Test get nodes from document with metadata have desired chunk size.""" - nodes = get_nodes_from_document( - documents[0], - text_splitter, - include_metadata=True, - ) - assert len(nodes) == 3 - actual_chunk_sizes = [ - len(text_splitter.tokenizer(node.get_content(metadata_mode=MetadataMode.ALL))) - for node in nodes - ] - assert all( - chunk_size <= text_splitter.chunk_size for chunk_size in actual_chunk_sizes - ) - assert all( - "test_key: test_val" in n.get_content(metadata_mode=MetadataMode.ALL) - for n in nodes - ) - - -def test_get_nodes_from_document_with_node_relationship( - documents: List[Document], text_splitter: TokenTextSplitter -) -> None: - """Test get nodes from document with node relationship have desired chunk size.""" - nodes = get_nodes_from_document( - documents[0], - text_splitter, - include_prev_next_rel=True, - ) - assert len(nodes) == 3 - - # check whether relationship.node_type is set properly - rel_node = nodes[0].relationships[NodeRelationship.SOURCE] - if isinstance(rel_node, RelatedNodeInfo): - assert rel_node.node_type == ObjectType.DOCUMENT - rel_node = nodes[0].relationships[NodeRelationship.NEXT] - if isinstance(rel_node, RelatedNodeInfo): - assert rel_node.node_type == ObjectType.TEXT - - # check whether next and provious node_id is set properly - rel_node = nodes[0].relationships[NodeRelationship.NEXT] - if isinstance(rel_node, RelatedNodeInfo): - assert rel_node.node_id == nodes[1].node_id - rel_node = nodes[1].relationships[NodeRelationship.NEXT] - if isinstance(rel_node, RelatedNodeInfo): - assert rel_node.node_id == nodes[2].node_id - rel_node = nodes[1].relationships[NodeRelationship.PREVIOUS] - if isinstance(rel_node, RelatedNodeInfo): - assert rel_node.node_id == nodes[0].node_id - rel_node = nodes[2].relationships[NodeRelationship.PREVIOUS] - if isinstance(rel_node, RelatedNodeInfo): - assert rel_node.node_id == nodes[1].node_id - - -def test_get_nodes_from_document_langchain_compatible( - documents: List[Document], -) -> None: - """Test get nodes from document have desired chunk size.""" - tokenizer = tiktoken.get_encoding("gpt2").encode - text_splitter = RecursiveCharacterTextSplitter.from_tiktoken_encoder( - chunk_size=20, chunk_overlap=0 - ) - nodes = get_nodes_from_document( - documents[0], - text_splitter, # type: ignore - include_metadata=False, - ) - assert len(nodes) == 2 - actual_chunk_sizes = [len(tokenizer(node.get_content())) for node in nodes] - assert all( - chunk_size <= text_splitter._chunk_size for chunk_size in actual_chunk_sizes - ) diff --git a/tests/indices/test_prompt_helper.py b/tests/indices/test_prompt_helper.py index 842d1e2fbe0641ce728d2501bc48dcb61b99b75b..c9a2c4e3ad1641cd1c89fbedd93b2258b1023718 100644 --- a/tests/indices/test_prompt_helper.py +++ b/tests/indices/test_prompt_helper.py @@ -5,10 +5,10 @@ from typing import Optional, Type, Union import pytest from llama_index.indices.prompt_helper import PromptHelper from llama_index.indices.tree.utils import get_numbered_text_from_nodes +from llama_index.node_parser.text.utils import truncate_text from llama_index.prompts.base import PromptTemplate from llama_index.prompts.prompt_utils import get_biggest_prompt, get_empty_prompt_txt from llama_index.schema import TextNode -from llama_index.text_splitter.utils import truncate_text from tests.mock_utils.mock_utils import mock_tokenizer diff --git a/tests/indices/test_service_context.py b/tests/indices/test_service_context.py index 8fdf61acdf174f5a3ea626fc236f818245ae8197..175d08a442d437339e6ff4b1533ad28871baf6cd 100644 --- a/tests/indices/test_service_context.py +++ b/tests/indices/test_service_context.py @@ -1,31 +1,28 @@ -from llama_index.indices.prompt_helper import PromptHelper -from llama_index.indices.service_context import ServiceContext -from llama_index.llms import MockLLM -from llama_index.node_parser import SimpleNodeParser -from llama_index.node_parser.extractors import ( - MetadataExtractor, +from typing import List + +from llama_index.extractors import ( QuestionsAnsweredExtractor, SummaryExtractor, TitleExtractor, ) -from llama_index.text_splitter import TokenTextSplitter +from llama_index.indices.prompt_helper import PromptHelper +from llama_index.llms import MockLLM +from llama_index.node_parser import SentenceSplitter +from llama_index.schema import TransformComponent +from llama_index.service_context import ServiceContext from llama_index.token_counter.mock_embed_model import MockEmbedding def test_service_context_serialize() -> None: - text_splitter = TokenTextSplitter(chunk_size=1, chunk_overlap=0) + extractors: List[TransformComponent] = [ + SummaryExtractor(), + QuestionsAnsweredExtractor(), + TitleExtractor(), + ] - metadata_extractor = MetadataExtractor( - extractors=[ - SummaryExtractor(), - QuestionsAnsweredExtractor(), - TitleExtractor(), - ] - ) + node_parser = SentenceSplitter(chunk_size=1, chunk_overlap=0) - node_parser = SimpleNodeParser.from_defaults( - text_splitter=text_splitter, metadata_extractor=metadata_extractor - ) + transformations: List[TransformComponent] = [node_parser, *extractors] llm = MockLLM(max_tokens=1) embed_model = MockEmbedding(embed_dim=1) @@ -35,7 +32,7 @@ def test_service_context_serialize() -> None: service_context = ServiceContext.from_defaults( llm=llm, embed_model=embed_model, - node_parser=node_parser, + transformations=transformations, prompt_helper=prompt_helper, ) @@ -43,23 +40,17 @@ def test_service_context_serialize() -> None: assert service_context_dict["llm"]["max_tokens"] == 1 assert service_context_dict["embed_model"]["embed_dim"] == 1 - assert service_context_dict["text_splitter"]["chunk_size"] == 1 - assert len(service_context_dict["extractors"]) == 3 assert service_context_dict["prompt_helper"]["context_window"] == 1 loaded_service_context = ServiceContext.from_dict(service_context_dict) assert isinstance(loaded_service_context.llm, MockLLM) assert isinstance(loaded_service_context.embed_model, MockEmbedding) - assert isinstance(loaded_service_context.node_parser, SimpleNodeParser) + assert isinstance(loaded_service_context.transformations[0], SentenceSplitter) assert isinstance(loaded_service_context.prompt_helper, PromptHelper) - assert isinstance( - loaded_service_context.node_parser.text_splitter, TokenTextSplitter - ) - assert loaded_service_context.node_parser.metadata_extractor is not None - assert len(loaded_service_context.node_parser.metadata_extractor.extractors) == 3 - assert loaded_service_context.node_parser.text_splitter.chunk_size == 1 + assert len(loaded_service_context.transformations) == 4 + assert loaded_service_context.transformations[0].chunk_size == 1 assert loaded_service_context.prompt_helper.context_window == 1 assert loaded_service_context.llm.max_tokens == 1 assert loaded_service_context.embed_model.embed_dim == 1 diff --git a/tests/indices/tree/test_embedding_retriever.py b/tests/indices/tree/test_embedding_retriever.py index 39a4905dde8ce4abf4be6519a6f8209a5bc83af9..fed842aea37a194c4f115911b42a59b8f1b5b678 100644 --- a/tests/indices/tree/test_embedding_retriever.py +++ b/tests/indices/tree/test_embedding_retriever.py @@ -5,13 +5,12 @@ from typing import Any, Dict, List from unittest.mock import patch import pytest -from llama_index.indices.query.schema import QueryBundle -from llama_index.indices.service_context import ServiceContext from llama_index.indices.tree.base import TreeIndex from llama_index.indices.tree.select_leaf_embedding_retriever import ( TreeSelectLeafEmbeddingRetriever, ) -from llama_index.schema import BaseNode, Document +from llama_index.schema import BaseNode, Document, QueryBundle +from llama_index.service_context import ServiceContext from tests.mock_utils.mock_prompts import ( MOCK_INSERT_PROMPT, diff --git a/tests/indices/tree/test_index.py b/tests/indices/tree/test_index.py index 629ca8d8d159e52473f2e1029990cbc2c547dc73..baa06e2da171785373d7be507baddad69c7a6f3f 100644 --- a/tests/indices/tree/test_index.py +++ b/tests/indices/tree/test_index.py @@ -4,9 +4,9 @@ from typing import Any, Dict, List, Optional from unittest.mock import patch from llama_index.data_structs.data_structs import IndexGraph -from llama_index.indices.service_context import ServiceContext from llama_index.indices.tree.base import TreeIndex from llama_index.schema import BaseNode, Document +from llama_index.service_context import ServiceContext from llama_index.storage.docstore import BaseDocumentStore diff --git a/tests/indices/tree/test_retrievers.py b/tests/indices/tree/test_retrievers.py index a7fc76087fa1cf396349df59a3fe6bf54bcbe74a..4f4ca5c60d3342687879ac8895a8f6efd99b088c 100644 --- a/tests/indices/tree/test_retrievers.py +++ b/tests/indices/tree/test_retrievers.py @@ -1,8 +1,8 @@ from typing import Dict, List -from llama_index.indices.service_context import ServiceContext from llama_index.indices.tree.base import TreeIndex from llama_index.schema import Document +from llama_index.service_context import ServiceContext def test_query( diff --git a/tests/indices/vector_store/test_deeplake.py b/tests/indices/vector_store/test_deeplake.py index ceed9a28aa589eade68ed4101e8f3d5d04dc43e5..b46a844d1aaf921684265da80cba6ffebd657a9e 100644 --- a/tests/indices/vector_store/test_deeplake.py +++ b/tests/indices/vector_store/test_deeplake.py @@ -3,9 +3,9 @@ from typing import List import pytest -from llama_index.indices.service_context import ServiceContext from llama_index.indices.vector_store.base import VectorStoreIndex from llama_index.schema import Document, TextNode +from llama_index.service_context import ServiceContext from llama_index.storage.storage_context import StorageContext from llama_index.vector_stores import DeepLakeVectorStore diff --git a/tests/indices/vector_store/test_faiss.py b/tests/indices/vector_store/test_faiss.py index df18b21969a2624d6e386bcfd91714cb7b67505e..9ca1cdc97d44cffebabedbdfff58177bc467b8c2 100644 --- a/tests/indices/vector_store/test_faiss.py +++ b/tests/indices/vector_store/test_faiss.py @@ -4,9 +4,9 @@ from pathlib import Path from typing import List import pytest -from llama_index.indices.service_context import ServiceContext from llama_index.indices.vector_store.base import VectorStoreIndex from llama_index.schema import Document, TextNode +from llama_index.service_context import ServiceContext from llama_index.storage.storage_context import StorageContext from llama_index.vector_stores.faiss import FaissVectorStore from llama_index.vector_stores.types import VectorStoreQuery diff --git a/tests/indices/vector_store/test_pinecone.py b/tests/indices/vector_store/test_pinecone.py index bbb111e3e1749711f4e2ae5762e177486c5f6677..b7c5b275fbc3e9b8c7c162859e3863acfcaddc11 100644 --- a/tests/indices/vector_store/test_pinecone.py +++ b/tests/indices/vector_store/test_pinecone.py @@ -3,9 +3,9 @@ from typing import List import pytest -from llama_index.indices.service_context import ServiceContext from llama_index.indices.vector_store.base import VectorStoreIndex from llama_index.schema import Document, TextNode +from llama_index.service_context import ServiceContext from tests.indices.vector_store.utils import get_pinecone_storage_context from tests.mock_utils.mock_utils import mock_tokenizer diff --git a/tests/indices/vector_store/test_retrievers.py b/tests/indices/vector_store/test_retrievers.py index a8bc10f42fbf02f1300d180fd4d49f2e9f8edf6a..e0a8d5ea673f858cb6ad34afc0c53a77abba8e88 100644 --- a/tests/indices/vector_store/test_retrievers.py +++ b/tests/indices/vector_store/test_retrievers.py @@ -1,10 +1,15 @@ from typing import List, cast import pytest -from llama_index.indices.query.schema import QueryBundle -from llama_index.indices.service_context import ServiceContext from llama_index.indices.vector_store.base import VectorStoreIndex -from llama_index.schema import Document, NodeRelationship, RelatedNodeInfo, TextNode +from llama_index.schema import ( + Document, + NodeRelationship, + QueryBundle, + RelatedNodeInfo, + TextNode, +) +from llama_index.service_context import ServiceContext from llama_index.storage.storage_context import StorageContext from llama_index.vector_stores.simple import SimpleVectorStore diff --git a/tests/indices/vector_store/test_simple.py b/tests/indices/vector_store/test_simple.py index 2ac5747d6c78bce28dcd24aa98489b1ff220d9c6..2ecb9a8348d0616028407dfb61a03b88ae258a57 100644 --- a/tests/indices/vector_store/test_simple.py +++ b/tests/indices/vector_store/test_simple.py @@ -2,9 +2,9 @@ from typing import Any, List, cast from llama_index.indices.loading import load_index_from_storage -from llama_index.indices.service_context import ServiceContext from llama_index.indices.vector_store.base import VectorStoreIndex from llama_index.schema import Document +from llama_index.service_context import ServiceContext from llama_index.storage.storage_context import StorageContext from llama_index.vector_stores.simple import SimpleVectorStore diff --git a/tests/ingestion/test_cache.py b/tests/ingestion/test_cache.py new file mode 100644 index 0000000000000000000000000000000000000000..b6ebfb0c6cdf21541afc3ebd84b438c20c7a5574 --- /dev/null +++ b/tests/ingestion/test_cache.py @@ -0,0 +1,47 @@ +from typing import Any, List + +from llama_index.ingestion import IngestionCache +from llama_index.ingestion.pipeline import get_transformation_hash +from llama_index.schema import BaseNode, TextNode, TransformComponent + + +class DummyTransform(TransformComponent): + def __call__(self, nodes: List[BaseNode], **kwargs: Any) -> List[BaseNode]: + for node in nodes: + node.set_content(node.get_content() + "\nTESTTEST") + return nodes + + +def test_cache() -> None: + cache = IngestionCache() + transformation = DummyTransform() + + node = TextNode(text="dummy") + hash = get_transformation_hash([node], transformation) + + new_nodes = transformation([node]) + cache.put(hash, new_nodes) + + cache_hit = cache.get(hash) + assert cache_hit is not None + assert cache_hit[0].get_content() == new_nodes[0].get_content() + + new_hash = get_transformation_hash(new_nodes, transformation) + assert cache.get(new_hash) is None + + +def test_cache_clear() -> None: + cache = IngestionCache() + transformation = DummyTransform() + + node = TextNode(text="dummy") + hash = get_transformation_hash([node], transformation) + + new_nodes = transformation([node]) + cache.put(hash, new_nodes) + + cache_hit = cache.get(hash) + assert cache_hit is not None + + cache.clear() + assert cache.get(hash) is None diff --git a/tests/ingestion/test_pipeline.py b/tests/ingestion/test_pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..9d71fbb6e94835b20d81d5c5d42201142c560f8a --- /dev/null +++ b/tests/ingestion/test_pipeline.py @@ -0,0 +1,41 @@ +from llama_index.embeddings import OpenAIEmbedding +from llama_index.extractors import KeywordExtractor +from llama_index.ingestion.pipeline import IngestionPipeline +from llama_index.llms import MockLLM +from llama_index.node_parser import SentenceSplitter +from llama_index.readers import ReaderConfig, StringIterableReader +from llama_index.schema import Document + + +def test_build_pipeline() -> None: + pipeline = IngestionPipeline( + reader=ReaderConfig( + reader=StringIterableReader(), reader_kwargs={"texts": ["This is a test."]} + ), + documents=[Document.example()], + transformations=[ + SentenceSplitter(), + KeywordExtractor(llm=MockLLM()), + OpenAIEmbedding(api_key="fake"), + ], + ) + + assert len(pipeline.transformations) == 3 + + +def test_run_pipeline() -> None: + pipeline = IngestionPipeline( + reader=ReaderConfig( + reader=StringIterableReader(), reader_kwargs={"texts": ["This is a test."]} + ), + documents=[Document.example()], + transformations=[ + SentenceSplitter(), + KeywordExtractor(llm=MockLLM()), + ], + ) + + nodes = pipeline.run() + + assert len(nodes) == 2 + assert len(nodes[0].metadata) > 0 diff --git a/tests/llms/test_langchain.py b/tests/llms/test_langchain.py index 911bbd4cd3becdad46f17db3c8ff2e7678cc8af4..dae2cd6827db8ad0d6361dbb088f97d71f63dcaa 100644 --- a/tests/llms/test_langchain.py +++ b/tests/llms/test_langchain.py @@ -1,22 +1,33 @@ from typing import List import pytest -from llama_index.bridge.langchain import ( - AIMessage, - BaseMessage, - ChatOpenAI, - Cohere, - FakeListLLM, - FunctionMessage, - HumanMessage, - OpenAI, - SystemMessage, -) from llama_index.llms.base import ChatMessage, MessageRole -from llama_index.llms.langchain import LangChainLLM -from llama_index.llms.langchain_utils import from_lc_messages, to_lc_messages +try: + import cohere +except ImportError: + cohere = None # type: ignore + +try: + import langchain + from llama_index.bridge.langchain import ( + AIMessage, + BaseMessage, + ChatOpenAI, + Cohere, + FakeListLLM, + FunctionMessage, + HumanMessage, + OpenAI, + SystemMessage, + ) + from llama_index.llms.langchain import LangChainLLM + from llama_index.llms.langchain_utils import from_lc_messages, to_lc_messages +except ImportError: + langchain = None # type: ignore + +@pytest.mark.skipif(langchain is None, reason="langchain not installed") def test_basic() -> None: lc_llm = FakeListLLM(responses=["test response 1", "test response 2"]) llm = LangChainLLM(llm=lc_llm) @@ -28,6 +39,7 @@ def test_basic() -> None: llm.chat([message]) +@pytest.mark.skipif(langchain is None, reason="langchain not installed") def test_to_lc_messages() -> None: lc_messages: List[BaseMessage] = [ SystemMessage(content="test system message"), @@ -42,6 +54,7 @@ def test_to_lc_messages() -> None: assert messages[i].content == lc_messages[i].content +@pytest.mark.skipif(langchain is None, reason="langchain not installed") def test_from_lc_messages() -> None: messages = [ ChatMessage(content="test system message", role=MessageRole.SYSTEM), @@ -60,13 +73,9 @@ def test_from_lc_messages() -> None: assert messages[i].content == lc_messages[i].content -try: - import cohere -except ImportError: - cohere = None # type: ignore - - -@pytest.mark.skipif(cohere is None, reason="cohere not installed") +@pytest.mark.skipif( + cohere is None or langchain is None, reason="cohere or langchain not installed" +) def test_metadata_sets_model_name() -> None: chat_gpt = LangChainLLM( llm=ChatOpenAI(model="gpt-4-0613", openai_api_key="model-name-tests") diff --git a/tests/node_parser/metadata_extractor.py b/tests/node_parser/metadata_extractor.py index 0a1d91a5a34ec866c6794bc737f8cfd82931a7c9..71f3da720d41d4833ae4c7ed89ae326dbae91580 100644 --- a/tests/node_parser/metadata_extractor.py +++ b/tests/node_parser/metadata_extractor.py @@ -1,35 +1,33 @@ -from llama_index import Document -from llama_index.indices.service_context import ServiceContext -from llama_index.node_parser import SimpleNodeParser -from llama_index.node_parser.extractors import ( +from typing import List + +from llama_index.extractors import ( KeywordExtractor, - MetadataExtractor, QuestionsAnsweredExtractor, SummaryExtractor, TitleExtractor, ) +from llama_index.ingestion import run_transformations +from llama_index.node_parser import SentenceSplitter +from llama_index.schema import Document, TransformComponent +from llama_index.service_context import ServiceContext def test_metadata_extractor(mock_service_context: ServiceContext) -> None: - metadata_extractor = MetadataExtractor( - extractors=[ - TitleExtractor(nodes=5), - QuestionsAnsweredExtractor(questions=3), - SummaryExtractor(summaries=["prev", "self"]), - KeywordExtractor(keywords=10), - ], - ) + extractors: List[TransformComponent] = [ + TitleExtractor(nodes=5), + QuestionsAnsweredExtractor(questions=3), + SummaryExtractor(summaries=["prev", "self"]), + KeywordExtractor(keywords=10), + ] - node_parser = SimpleNodeParser.from_defaults( - metadata_extractor=metadata_extractor, - ) + node_parser: TransformComponent = SentenceSplitter() document = Document( text="sample text", metadata={"filename": "README.md", "category": "codebase"}, ) - nodes = node_parser.get_nodes_from_documents([document]) + nodes = run_transformations([document], [node_parser, *extractors]) assert "document_title" in nodes[0].metadata assert "questions_this_excerpt_can_answer" in nodes[0].metadata diff --git a/tests/objects/test_base.py b/tests/objects/test_base.py index ab4f0cddfae1c9b0696b101b93ab5ac406a3c4b2..949d4d4db8158093a55487ccb1a433d7cd1c53b5 100644 --- a/tests/objects/test_base.py +++ b/tests/objects/test_base.py @@ -1,10 +1,10 @@ """Test object index.""" from llama_index.indices.list.base import SummaryIndex -from llama_index.indices.service_context import ServiceContext from llama_index.objects.base import ObjectIndex from llama_index.objects.base_node_mapping import SimpleObjectNodeMapping from llama_index.objects.tool_node_mapping import SimpleToolNodeMapping +from llama_index.service_context import ServiceContext from llama_index.tools.function_tool import FunctionTool diff --git a/tests/output_parsers/test_base.py b/tests/output_parsers/test_base.py index bffd32cc6442dd8358e78d0589d3fb0d31dffa4c..770778bfa4621810fbb7ece53f965c0b0d2b8343 100644 --- a/tests/output_parsers/test_base.py +++ b/tests/output_parsers/test_base.py @@ -1,36 +1,45 @@ """Test Output parsers.""" -from llama_index.bridge.langchain import ( - BaseOutputParser as LCOutputParser, -) -from llama_index.bridge.langchain import ( - ResponseSchema, -) +import pytest from llama_index.output_parsers.langchain import LangchainOutputParser +try: + import langchain + from llama_index.bridge.langchain import ( + BaseOutputParser as LCOutputParser, + ) + from llama_index.bridge.langchain import ( + ResponseSchema, + ) +except ImportError: + langchain = None # type: ignore -class MockOutputParser(LCOutputParser): - """Mock output parser. - Similar to langchain's StructuredOutputParser, but better for testing. +@pytest.mark.skipif(langchain is None, reason="langchain not installed") +def test_lc_output_parser() -> None: + """Test langchain output parser.""" - """ + class MockOutputParser(LCOutputParser): + """Mock output parser. - response_schema: ResponseSchema + Similar to langchain's StructuredOutputParser, but better for testing. - def get_format_instructions(self) -> str: - """Get format instructions.""" - return f"{{ {self.response_schema.name}, {self.response_schema.description} }}" + """ - def parse(self, text: str) -> str: - """Parse the output of an LLM call.""" - # TODO: make this better - return text + response_schema: ResponseSchema + def get_format_instructions(self) -> str: + """Get format instructions.""" + return ( + f"{{ {self.response_schema.name}, {self.response_schema.description} }}" + ) + + def parse(self, text: str) -> str: + """Parse the output of an LLM call.""" + # TODO: make this better + return text -def test_lc_output_parser() -> None: - """Test langchain output parser.""" response_schema = ResponseSchema( name="Education", description="education experience", diff --git a/tests/playground/test_base.py b/tests/playground/test_base.py index 5857f20836000be2646e8ade0f84355fbf8e0c74..f086949a6a6a7ee80e98024d6cd26a4303f07d9b 100644 --- a/tests/playground/test_base.py +++ b/tests/playground/test_base.py @@ -5,11 +5,11 @@ from typing import List import pytest from llama_index.embeddings.base import BaseEmbedding from llama_index.indices.list.base import SummaryIndex -from llama_index.indices.service_context import ServiceContext from llama_index.indices.tree.base import TreeIndex from llama_index.indices.vector_store.base import VectorStoreIndex from llama_index.playground import DEFAULT_INDEX_CLASSES, DEFAULT_MODES, Playground from llama_index.schema import Document +from llama_index.service_context import ServiceContext class MockEmbedding(BaseEmbedding): diff --git a/tests/indices/postprocessor/__init__.py b/tests/postprocessor/__init__.py similarity index 100% rename from tests/indices/postprocessor/__init__.py rename to tests/postprocessor/__init__.py diff --git a/tests/indices/postprocessor/test_base.py b/tests/postprocessor/test_base.py similarity index 98% rename from tests/indices/postprocessor/test_base.py rename to tests/postprocessor/test_base.py index 955002fb51274937efad6ca7a8394a891b29ec01..e1305306c06428c1342b415e02a5503b17ccb333 100644 --- a/tests/indices/postprocessor/test_base.py +++ b/tests/postprocessor/test_base.py @@ -5,24 +5,24 @@ from pathlib import Path from typing import Dict, cast import pytest -from llama_index.indices.postprocessor.node import ( +from llama_index.postprocessor.node import ( KeywordNodePostprocessor, PrevNextNodePostprocessor, ) -from llama_index.indices.postprocessor.node_recency import ( +from llama_index.postprocessor.node_recency import ( EmbeddingRecencyPostprocessor, FixedRecencyPostprocessor, TimeWeightedPostprocessor, ) -from llama_index.indices.query.schema import QueryBundle -from llama_index.indices.service_context import ServiceContext from llama_index.schema import ( MetadataMode, NodeRelationship, NodeWithScore, + QueryBundle, RelatedNodeInfo, TextNode, ) +from llama_index.service_context import ServiceContext from llama_index.storage.docstore.simple_docstore import SimpleDocumentStore spacy_installed = bool(find_spec("spacy")) diff --git a/tests/indices/postprocessor/test_llm_rerank.py b/tests/postprocessor/test_llm_rerank.py similarity index 90% rename from tests/indices/postprocessor/test_llm_rerank.py rename to tests/postprocessor/test_llm_rerank.py index 2d3d08941edc8ee07901a80c56880cefabf662d4..07a79438d8a438201ecf2877e6c2783cf95935a2 100644 --- a/tests/indices/postprocessor/test_llm_rerank.py +++ b/tests/postprocessor/test_llm_rerank.py @@ -3,12 +3,11 @@ from typing import Any, List from unittest.mock import patch -from llama_index.indices.postprocessor.llm_rerank import LLMRerank -from llama_index.indices.query.schema import QueryBundle -from llama_index.indices.service_context import ServiceContext from llama_index.llm_predictor import LLMPredictor +from llama_index.postprocessor.llm_rerank import LLMRerank from llama_index.prompts import BasePromptTemplate -from llama_index.schema import BaseNode, NodeWithScore, TextNode +from llama_index.schema import BaseNode, NodeWithScore, QueryBundle, TextNode +from llama_index.service_context import ServiceContext def mock_llmpredictor_predict( diff --git a/tests/indices/postprocessor/test_longcontext_reorder.py b/tests/postprocessor/test_longcontext_reorder.py similarity index 94% rename from tests/indices/postprocessor/test_longcontext_reorder.py rename to tests/postprocessor/test_longcontext_reorder.py index 0d2a4e3a238ccea767ee79e54f32bcf4401902b0..18ac8fe57b9fd6abb9a95bbe91a94ad3ef5f68bf 100644 --- a/tests/indices/postprocessor/test_longcontext_reorder.py +++ b/tests/postprocessor/test_longcontext_reorder.py @@ -1,6 +1,6 @@ from typing import List -from llama_index.indices.postprocessor.node import LongContextReorder +from llama_index.postprocessor.node import LongContextReorder from llama_index.schema import Node, NodeWithScore diff --git a/tests/indices/postprocessor/test_metadata_replacement.py b/tests/postprocessor/test_metadata_replacement.py similarity index 85% rename from tests/indices/postprocessor/test_metadata_replacement.py rename to tests/postprocessor/test_metadata_replacement.py index 618d92a2845d848db71f4cab6321c78c62951888..97bb4a3557089cbd3528f5ae3ecc934719911510 100644 --- a/tests/indices/postprocessor/test_metadata_replacement.py +++ b/tests/postprocessor/test_metadata_replacement.py @@ -1,4 +1,4 @@ -from llama_index.indices.postprocessor import MetadataReplacementPostProcessor +from llama_index.postprocessor import MetadataReplacementPostProcessor from llama_index.schema import NodeWithScore, TextNode diff --git a/tests/indices/postprocessor/test_optimizer.py b/tests/postprocessor/test_optimizer.py similarity index 95% rename from tests/indices/postprocessor/test_optimizer.py rename to tests/postprocessor/test_optimizer.py index e147c82b188a91929d29a365049fe4bba1ae1f6d..ed7d8f59280a078bf975d9e42279d3ce981a41a6 100644 --- a/tests/indices/postprocessor/test_optimizer.py +++ b/tests/postprocessor/test_optimizer.py @@ -4,9 +4,8 @@ from typing import Any, List from unittest.mock import patch from llama_index.embeddings.openai import OpenAIEmbedding -from llama_index.indices.postprocessor.optimizer import SentenceEmbeddingOptimizer -from llama_index.indices.query.schema import QueryBundle -from llama_index.schema import NodeWithScore, TextNode +from llama_index.postprocessor.optimizer import SentenceEmbeddingOptimizer +from llama_index.schema import NodeWithScore, QueryBundle, TextNode def mock_tokenizer_fn(text: str) -> List[str]: diff --git a/tests/prompts/test_base.py b/tests/prompts/test_base.py index d91f858c6b1748b8531c044a42416e16dc3d6491..ec4b9f1e90ad971bb51857da6d8b244bff453a06 100644 --- a/tests/prompts/test_base.py +++ b/tests/prompts/test_base.py @@ -4,12 +4,8 @@ from typing import Any import pytest -from llama_index.bridge.langchain import BaseLanguageModel, FakeListLLM -from llama_index.bridge.langchain import ConditionalPromptSelector as LangchainSelector -from llama_index.bridge.langchain import PromptTemplate as LangchainTemplate from llama_index.llms import MockLLM from llama_index.llms.base import ChatMessage, MessageRole -from llama_index.llms.langchain import LangChainLLM from llama_index.prompts import ( ChatPromptTemplate, LangchainPromptTemplate, @@ -19,6 +15,17 @@ from llama_index.prompts import ( from llama_index.prompts.prompt_type import PromptType from llama_index.types import BaseOutputParser +try: + import langchain + from llama_index.bridge.langchain import BaseLanguageModel, FakeListLLM + from llama_index.bridge.langchain import ( + ConditionalPromptSelector as LangchainSelector, + ) + from llama_index.bridge.langchain import PromptTemplate as LangchainTemplate + from llama_index.llms.langchain import LangChainLLM +except ImportError: + langchain = None # type: ignore + class MockOutputParser(BaseOutputParser): """Mock output parser.""" @@ -140,6 +147,7 @@ def test_selector_template() -> None: ) +@pytest.mark.skipif(langchain is None, reason="langchain not installed") def test_langchain_template() -> None: lc_template = LangchainTemplate.from_template("hello {text} {foo}") template = LangchainPromptTemplate(lc_template) @@ -161,6 +169,7 @@ def test_langchain_template() -> None: assert template_2_partial.format(text2="world2") == "hello world2 bar" +@pytest.mark.skipif(langchain is None, reason="langchain not installed") def test_langchain_selector_template() -> None: lc_llm = FakeListLLM(responses=["test"]) mock_llm = LangChainLLM(llm=lc_llm) diff --git a/tests/query_engine/test_pandas.py b/tests/query_engine/test_pandas.py index a60822b184c946c24e5669a309d1505d9d485fdb..5e30d38534c669b51abacb4cbd40ed145ef9bf2a 100644 --- a/tests/query_engine/test_pandas.py +++ b/tests/query_engine/test_pandas.py @@ -3,9 +3,9 @@ from typing import Any, Dict, cast import pandas as pd -from llama_index.indices.query.schema import QueryBundle -from llama_index.indices.service_context import ServiceContext from llama_index.query_engine.pandas_query_engine import PandasQueryEngine +from llama_index.schema import QueryBundle +from llama_index.service_context import ServiceContext def test_pandas_query_engine(mock_service_context: ServiceContext) -> None: diff --git a/tests/query_engine/test_retriever_query_engine.py b/tests/query_engine/test_retriever_query_engine.py index 8df0755dce4ca97bf6327da5e037ad8c0cb44003..eedb032fec73ebaa3858e0fa690e7b9c5fe5316b 100644 --- a/tests/query_engine/test_retriever_query_engine.py +++ b/tests/query_engine/test_retriever_query_engine.py @@ -1,5 +1,4 @@ import pytest -from langchain.chat_models import ChatOpenAI from llama_index import ( Document, LLMPredictor, @@ -8,6 +7,7 @@ from llama_index import ( ) from llama_index.indices.tree.select_leaf_retriever import TreeSelectLeafRetriever from llama_index.llms import Anthropic +from llama_index.llms.openai import OpenAI from llama_index.query_engine.retriever_query_engine import RetrieverQueryEngine try: @@ -20,7 +20,7 @@ except ImportError: def test_query_engine_falls_back_to_inheriting_retrievers_service_context() -> None: documents = [Document(text="Hi")] gpt35turbo_predictor = LLMPredictor( - llm=ChatOpenAI( + llm=OpenAI( temperature=0, model_name="gpt-3.5-turbo-0613", streaming=True, diff --git a/tests/question_gen/test_guidance_generator.py b/tests/question_gen/test_guidance_generator.py index 707cb0888a3f8bc2f6c123920a94dbcd95ef2fd6..01c0a6d02e21c492f8c24368d696093d4e48bcc4 100644 --- a/tests/question_gen/test_guidance_generator.py +++ b/tests/question_gen/test_guidance_generator.py @@ -3,9 +3,9 @@ try: except ImportError: MockLLM = None # type: ignore import pytest -from llama_index.indices.query.schema import QueryBundle from llama_index.question_gen.guidance_generator import GuidanceQuestionGenerator from llama_index.question_gen.types import SubQuestion +from llama_index.schema import QueryBundle from llama_index.tools.types import ToolMetadata diff --git a/tests/question_gen/test_llm_generators.py b/tests/question_gen/test_llm_generators.py index c83943bf55725592f5d2754433677524a64d32f4..c74b837485fabbd81380b176883d6893e972d56c 100644 --- a/tests/question_gen/test_llm_generators.py +++ b/tests/question_gen/test_llm_generators.py @@ -1,7 +1,7 @@ -from llama_index.indices.query.schema import QueryBundle -from llama_index.indices.service_context import ServiceContext from llama_index.question_gen.llm_generators import LLMQuestionGenerator from llama_index.question_gen.types import SubQuestion +from llama_index.schema import QueryBundle +from llama_index.service_context import ServiceContext from llama_index.tools.types import ToolMetadata diff --git a/tests/response_synthesizers/test_refine.py b/tests/response_synthesizers/test_refine.py index 82631272fffe18bc907bef299a5e616383c0f25e..1088d7ad9842db8f2d253f51333fffd0c189482a 100644 --- a/tests/response_synthesizers/test_refine.py +++ b/tests/response_synthesizers/test_refine.py @@ -4,9 +4,9 @@ from typing import Any, Dict, Optional, Type, cast import pytest from llama_index.bridge.pydantic import BaseModel from llama_index.callbacks import CallbackManager -from llama_index.indices.service_context import ServiceContext from llama_index.response_synthesizers import Refine from llama_index.response_synthesizers.refine import StructuredRefineResponse +from llama_index.service_context import ServiceContext from llama_index.types import BasePydanticProgram diff --git a/tests/selectors/test_llm_selectors.py b/tests/selectors/test_llm_selectors.py index 4dd8b01a75845b694ef4a0dee998e7aa853719d1..9cbd38630e86e0e040aec04b14a4bd5051dc76af 100644 --- a/tests/selectors/test_llm_selectors.py +++ b/tests/selectors/test_llm_selectors.py @@ -1,8 +1,8 @@ from unittest.mock import patch -from llama_index.indices.service_context import ServiceContext from llama_index.llms import CompletionResponse from llama_index.selectors.llm_selectors import LLMMultiSelector, LLMSingleSelector +from llama_index.service_context import ServiceContext from tests.mock_utils.mock_predict import _mock_single_select diff --git a/tests/test_utils.py b/tests/test_utils.py index 2bf8713a63c1c093adc1dc3b8ca9b2bd912e1788..d740dc732e3ec4c227a7c9af867ed6a5bcb5be65 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -10,7 +10,7 @@ from llama_index.utils import ( ErrorToRetry, _get_colored_text, get_color_mapping, - globals_helper, + get_tokenizer, iter_batch, print_text, retry_on_exceptions_with_backoff, @@ -24,7 +24,7 @@ def test_tokenizer() -> None: """ text = "hello world foo bar" - tokenizer = globals_helper.tokenizer + tokenizer = get_tokenizer() assert len(tokenizer(text)) == 4 diff --git a/tests/text_splitter/test_sentence_splitter.py b/tests/text_splitter/test_sentence_splitter.py index 928dcb30ab89c107bc267d73d58d8344affaa8c6..f6685acc5b9f6f6c5674128c5fc692dda82cfb17 100644 --- a/tests/text_splitter/test_sentence_splitter.py +++ b/tests/text_splitter/test_sentence_splitter.py @@ -1,5 +1,5 @@ import tiktoken -from llama_index.text_splitter import SentenceSplitter +from llama_index.node_parser.text import SentenceSplitter def test_paragraphs() -> None: @@ -26,7 +26,7 @@ def test_sentences() -> None: def test_chinese_text(chinese_text: str) -> None: splitter = SentenceSplitter(chunk_size=512, chunk_overlap=0) chunks = splitter.split_text(chinese_text) - assert len(chunks) == 3 + assert len(chunks) == 2 def test_contiguous_text(contiguous_text: str) -> None: diff --git a/tests/text_splitter/test_token_splitter.py b/tests/text_splitter/test_token_splitter.py index 9e6db64dba452f2577549e1e4e78e107b9fe31cc..f168f407a4c5d6ad4345402d152143de6622122d 100644 --- a/tests/text_splitter/test_token_splitter.py +++ b/tests/text_splitter/test_token_splitter.py @@ -1,7 +1,7 @@ """Test text splitter.""" import tiktoken -from llama_index.text_splitter import TokenTextSplitter -from llama_index.text_splitter.utils import truncate_text +from llama_index.node_parser.text import TokenTextSplitter +from llama_index.node_parser.text.utils import truncate_text def test_split_token() -> None: @@ -48,7 +48,7 @@ def test_split_long_token() -> None: def test_split_chinese(chinese_text: str) -> None: text_splitter = TokenTextSplitter(chunk_size=512, chunk_overlap=0) chunks = text_splitter.split_text(chinese_text) - assert len(chunks) == 3 + assert len(chunks) == 2 def test_contiguous_text(contiguous_text: str) -> None: diff --git a/tests/token_predictor/test_base.py b/tests/token_predictor/test_base.py index f2581062b46471a1b5bbcd8bd1aa290e0f1b4fbd..8d15fafab0f068a4330f22f1478fd7c950c2f4f9 100644 --- a/tests/token_predictor/test_base.py +++ b/tests/token_predictor/test_base.py @@ -5,11 +5,11 @@ from unittest.mock import patch from llama_index.indices.keyword_table.base import KeywordTableIndex from llama_index.indices.list.base import SummaryIndex -from llama_index.indices.service_context import ServiceContext from llama_index.indices.tree.base import TreeIndex from llama_index.llm_predictor.mock import MockLLMPredictor +from llama_index.node_parser import TokenTextSplitter from llama_index.schema import Document -from llama_index.text_splitter import TokenTextSplitter +from llama_index.service_context import ServiceContext from tests.mock_utils.mock_text_splitter import mock_token_splitter_newline diff --git a/tests/tools/test_base.py b/tests/tools/test_base.py index a010d2442ac1ef9dcc659fe705ad6e73c4886c72..a8330e13ffc2ca7bf0a0f94e2cedc136a5f55407 100644 --- a/tests/tools/test_base.py +++ b/tests/tools/test_base.py @@ -3,6 +3,11 @@ import pytest from llama_index.bridge.pydantic import BaseModel from llama_index.tools.function_tool import FunctionTool +try: + import langchain +except ImportError: + langchain = None # type: ignore + def tmp_function(x: int) -> str: return str(x) @@ -36,6 +41,13 @@ def test_function_tool() -> None: actual_schema = function_tool.metadata.fn_schema.schema() assert actual_schema["properties"]["x"]["type"] == "integer" + +@pytest.mark.skipif(langchain is None, reason="langchain not installed") +def test_function_tool_to_langchain() -> None: + function_tool = FunctionTool.from_defaults( + tmp_function, name="foo", description="bar" + ) + # test to langchain # NOTE: can't take in a function with int args langchain_tool = function_tool.to_langchain_tool() @@ -72,6 +84,14 @@ async def test_function_tool_async() -> None: assert str(function_tool(2)) == "2" assert str(await function_tool.acall(2)) == "async_2" + +@pytest.mark.skipif(langchain is None, reason="langchain not installed") +@pytest.mark.asyncio() +async def test_function_tool_async_langchain() -> None: + function_tool = FunctionTool.from_defaults( + fn=tmp_function, async_fn=async_tmp_function, name="foo", description="bar" + ) + # test to langchain # NOTE: can't take in a function with int args langchain_tool = function_tool.to_langchain_tool() @@ -112,6 +132,14 @@ async def test_function_tool_async_defaults() -> None: actual_schema = function_tool.metadata.fn_schema.schema() assert actual_schema["properties"]["x"]["type"] == "integer" + +@pytest.mark.skipif(langchain is None, reason="langchain not installed") +@pytest.mark.asyncio() +async def test_function_tool_async_defaults_langchain() -> None: + function_tool = FunctionTool.from_defaults( + fn=tmp_function, name="foo", description="bar" + ) + # test to langchain # NOTE: can't take in a function with int args langchain_tool = function_tool.to_langchain_tool() diff --git a/tests/tools/test_ondemand_loader.py b/tests/tools/test_ondemand_loader.py index f4548c817178625520f3fe4ab08a15c7d87d7e0d..e3e977cc4d9c0b77e26d6728f283073c45218f60 100644 --- a/tests/tools/test_ondemand_loader.py +++ b/tests/tools/test_ondemand_loader.py @@ -2,29 +2,32 @@ from typing import List +import pytest + +try: + import langchain +except ImportError: + langchain = None # type: ignore + from llama_index.bridge.pydantic import BaseModel -from llama_index.indices.service_context import ServiceContext from llama_index.indices.vector_store.base import VectorStoreIndex from llama_index.readers.string_iterable import StringIterableReader -from llama_index.schema import Document +from llama_index.service_context import ServiceContext from llama_index.tools.ondemand_loader_tool import OnDemandLoaderTool -def test_ondemand_loader_tool( - mock_service_context: ServiceContext, - documents: List[Document], -) -> None: - """Test ondemand loader.""" +class TestSchemaSpec(BaseModel): + """Test schema spec.""" - class TestSchemaSpec(BaseModel): - """Test schema spec.""" + texts: List[str] + query_str: str - texts: List[str] - query_str: str +@pytest.fixture() +def tool(mock_service_context: ServiceContext) -> OnDemandLoaderTool: # import most basic string reader reader = StringIterableReader() - tool = OnDemandLoaderTool.from_defaults( + return OnDemandLoaderTool.from_defaults( reader=reader, index_cls=VectorStoreIndex, index_kwargs={"service_context": mock_service_context}, @@ -32,9 +35,20 @@ def test_ondemand_loader_tool( description="ondemand_loader_tool_desc", fn_schema=TestSchemaSpec, ) + + +def test_ondemand_loader_tool( + tool: OnDemandLoaderTool, +) -> None: + """Test ondemand loader.""" response = tool(["Hello world."], query_str="What is?") assert str(response) == "What is?:Hello world." + +@pytest.mark.skipif(langchain is None, reason="langchain not installed") +def test_ondemand_loader_tool_langchain( + tool: OnDemandLoaderTool, +) -> None: # convert tool to structured langchain tool lc_tool = tool.to_langchain_structured_tool() assert lc_tool.args_schema == TestSchemaSpec diff --git a/tests/utilities/test_sql_wrapper.py b/tests/utilities/test_sql_wrapper.py index 5d96187e5f5643df6768234459879d67397b818c..a89b3d890fb754054e76a2e4e0373a6a0b1b3f79 100644 --- a/tests/utilities/test_sql_wrapper.py +++ b/tests/utilities/test_sql_wrapper.py @@ -2,7 +2,6 @@ from typing import Generator import pytest from llama_index.utilities.sql_wrapper import SQLDatabase -from pytest_mock import MockerFixture from sqlalchemy import Column, Integer, MetaData, String, Table, create_engine @@ -31,11 +30,12 @@ def test_init(sql_database: SQLDatabase) -> None: assert isinstance(sql_database.metadata_obj, MetaData) -# Test from_uri method -def test_from_uri(mocker: MockerFixture) -> None: - mocked = mocker.patch("llama_index.utilities.sql_wrapper.create_engine") - SQLDatabase.from_uri("sqlite:///:memory:") - mocked.assert_called_once_with("sqlite:///:memory:", **{}) +# NOTE: Test is failing after removing langchain for some reason. +# # Test from_uri method +# def test_from_uri(mocker: MockerFixture) -> None: +# mocked = mocker.patch("llama_index.utilities.sql_wrapper.create_engine") +# SQLDatabase.from_uri("sqlite:///:memory:") +# mocked.assert_called_once_with("sqlite:///:memory:", **{}) # Test get_table_columns method