From 6f3d4b51696e6882796656770e2d262b6c825f6b Mon Sep 17 00:00:00 2001 From: Logan <logan.markewich@live.com> Date: Mon, 11 Dec 2023 10:53:12 -0600 Subject: [PATCH] Merge LLM + LLMPredictor, reorganize types (#9388) --- CHANGELOG.md | 5 + benchmarks/agent/agent_utils.py | 2 +- benchmarks/struct_indices/spider/evaluate.py | 2 +- docs/api_reference/service_context.rst | 2 +- docs/changes/deprecated_terms.md | 2 +- docs/community/faq/llms.md | 6 +- docs/community/integrations/deepeval.md | 1 - .../agent/openai_agent_query_plan.ipynb | 1 - .../callbacks/HoneyHiveLlamaIndexTracer.ipynb | 21 +- .../callbacks/WandbCallbackHandler.ipynb | 19 +- .../citation/pdf_page_reference.ipynb | 2 +- .../City_Analysis-Decompose.ipynb | 2 +- .../City_Analysis-Unified-Query.ipynb | 5 +- .../city_analysis/City_Analysis.ipynb | 1 - .../PineconeDemo-CityAnalysis.ipynb | 1 - .../DeepLakeDemo-FinancialData.ipynb | 5 +- .../llms/SimpleIndexDemo-ChatGPT.ipynb | 1 - .../help_channel_dump_05_25_23.json | 4 +- .../help_channel_dump_06_02_23.json | 4 +- docs/examples/docstore/DocstoreDemo.ipynb | 2 +- docs/examples/docstore/FirestoreDemo.ipynb | 1 - .../examples/docstore/MongoDocstoreDemo.ipynb | 1 - .../RedisDocstoreIndexStoreDemo.ipynb | 1 - docs/examples/evaluation/Deepeval.ipynb | 1 - .../evaluation/QuestionGeneration.ipynb | 1 - docs/examples/evaluation/batch_eval.ipynb | 180 ++++++++-- docs/examples/evaluation/relevancy_eval.ipynb | 1 - .../gradient/gradient_structured.ipynb | 2 +- ...llm_judge_single_grading_correctness.ipynb | 3 +- .../pairwise/finetune_llm_judge.ipynb | 3 +- .../knowledge_graph/KnowledgeGraphDemo.ipynb | 1 - .../NebulaGraphKGIndexDemo.ipynb | 5 +- .../knowledge_graph/Neo4jKGIndexDemo.ipynb | 5 +- docs/examples/llm/Konko.ipynb | 2 +- docs/examples/llm/anyscale.ipynb | 2 +- docs/examples/llm/everlyai.ipynb | 2 +- docs/examples/llm/litellm.ipynb | 1 - docs/examples/llm/llm_predictor.ipynb | 17 +- docs/examples/llm/monsterapi.ipynb | 2 +- docs/examples/llm/openllm.ipynb | 2 +- docs/examples/llm/perplexity.ipynb | 2 +- docs/examples/llm/rungpt.ipynb | 2 +- docs/examples/llm/vertex.ipynb | 2 +- docs/examples/llm/vllm.ipynb | 2 +- .../low_level/response_synthesis.ipynb | 2 +- .../LLMReranker-Gatsby.ipynb | 1 - .../LLMReranker-Lyft-10k.ipynb | 1 - .../output_parsing/GuardrailsDemo.ipynb | 128 +++++--- .../LangchainOutputParserDemo.ipynb | 130 ++++---- .../output_parsing/evaporate_program.ipynb | 2 +- .../guidance_sub_question.ipynb | 1 - .../query_engine/JointQASummary.ipynb | 2 +- .../SQLAutoVectorQueryEngine.ipynb | 2 +- .../query_engine/SQLJoinQueryEngine.ipynb | 2 +- .../query_engine/citation_query_engine.ipynb | 1 - .../query_engine/flare_query_engine.ipynb | 1 - .../knowledge_graph_query_engine.ipynb | 2 - .../knowledge_graph_rag_query_engine.ipynb | 2 - .../pdf_tables/recursive_retriever.ipynb | 2 +- .../SimpleIndexDemo-multistep.ipynb | 12 +- ...City_Analysis-Decompose-KeywordTable.ipynb | 1 - .../vector_stores/SimpleIndexDemoMMR.ipynb | 1 - .../module_guides/models/llms/usage_custom.md | 1 - docs/module_guides/querying/output_parser.md | 61 ++-- .../supporting_modules/service_context.md | 1 - .../query_transformations.md | 10 +- .../evaluating/cost_analysis/root.md | 2 +- .../apps/fullstack_with_delphic.md | 2 +- .../putting_it_all_together/q_and_a.md | 2 +- .../q_and_a/terms_definitions_tutorial.md | 3 +- .../q_and_a/unified_query.md | 5 +- .../async/AsyncComposableIndicesSEC.ipynb | 2 +- examples/async/AsyncLLMPredictorDemo.ipynb | 6 +- examples/experimental/Evaporate.ipynb | 2 +- .../paul_graham_essay/GPT4Comparison.ipynb | 1 - .../test_wiki/TestNYC-Benchmark-GPT4.ipynb | 17 +- examples/test_wiki/TestNYC-Tree-GPT4.ipynb | 2 +- experimental/classifier/utils.py | 8 +- experimental/cli/configuration.py | 2 +- llama_index/agent/context_retriever_agent.py | 3 +- llama_index/agent/openai_agent.py | 3 +- llama_index/agent/openai_assistant_agent.py | 2 +- llama_index/agent/react/base.py | 3 +- llama_index/agent/react/formatter.py | 2 +- llama_index/agent/types.py | 2 +- llama_index/callbacks/finetuning_handler.py | 4 +- .../chat_engine/condense_plus_context.py | 16 +- llama_index/chat_engine/condense_question.py | 18 +- llama_index/chat_engine/context.py | 8 +- llama_index/chat_engine/simple.py | 8 +- llama_index/chat_engine/types.py | 2 +- llama_index/chat_engine/utils.py | 2 +- llama_index/evaluation/correctness.py | 2 +- llama_index/evaluation/guideline.py | 2 +- llama_index/evaluation/pairwise.py | 2 +- llama_index/extractors/interface.py | 9 +- llama_index/extractors/metadata_extractors.py | 75 ++--- .../finetuning/cross_encoders/dataset_gen.py | 2 +- llama_index/finetuning/openai/base.py | 2 +- llama_index/finetuning/types.py | 2 +- llama_index/indices/base.py | 2 +- .../indices/common/struct_store/base.py | 8 +- .../indices/common/struct_store/sql.py | 6 +- llama_index/indices/common_tree/base.py | 6 +- .../indices/document_summary/retrievers.py | 2 +- llama_index/indices/keyword_table/base.py | 4 +- .../indices/keyword_table/retrievers.py | 2 +- llama_index/indices/knowledge_graph/base.py | 2 +- .../indices/knowledge_graph/retrievers.py | 6 +- llama_index/indices/list/retrievers.py | 2 +- llama_index/indices/prompt_helper.py | 3 +- .../indices/query/query_transform/base.py | 30 +- .../query_transform/feedback_transform.py | 12 +- .../indices/struct_store/json_query.py | 8 +- llama_index/indices/struct_store/sql.py | 4 +- llama_index/indices/struct_store/sql_query.py | 6 +- .../indices/struct_store/sql_retriever.py | 4 +- llama_index/indices/tree/inserter.py | 8 +- .../indices/tree/select_leaf_retriever.py | 10 +- .../auto_retriever/auto_retriever.py | 4 +- llama_index/llm_predictor/base.py | 14 +- llama_index/llm_predictor/mock.py | 6 +- llama_index/llm_predictor/structured.py | 3 + llama_index/llm_predictor/utils.py | 55 ---- llama_index/llm_predictor/vellum/predictor.py | 3 + llama_index/llms/__init__.py | 24 +- llama_index/llms/ai21.py | 28 +- llama_index/llms/anthropic.py | 33 +- llama_index/llms/anthropic_utils.py | 2 +- llama_index/llms/anyscale.py | 17 +- llama_index/llms/anyscale_utils.py | 2 +- llama_index/llms/azure_openai.py | 15 +- llama_index/llms/base.py | 132 ++------ llama_index/llms/bedrock.py | 35 +- llama_index/llms/bedrock_utils.py | 10 +- llama_index/llms/clarifai.py | 21 +- llama_index/llms/cohere.py | 35 +- llama_index/llms/cohere_utils.py | 2 +- llama_index/llms/custom.py | 20 +- llama_index/llms/everlyai.py | 15 +- llama_index/llms/generic_utils.py | 2 +- llama_index/llms/gradient.py | 20 +- llama_index/llms/huggingface.py | 45 ++- llama_index/llms/konko.py | 36 +- llama_index/llms/konko_utils.py | 2 +- llama_index/llms/langchain.py | 24 +- llama_index/llms/langchain_utils.py | 2 +- llama_index/llms/litellm.py | 36 +- llama_index/llms/litellm_utils.py | 2 +- llama_index/llms/llama_api.py | 30 +- llama_index/llms/llama_cpp.py | 42 +-- llama_index/llms/llama_utils.py | 2 +- llama_index/llms/llm.py | 310 ++++++++++++++++++ llama_index/llms/loading.py | 2 +- llama_index/llms/localai.py | 21 +- llama_index/llms/mock.py | 23 +- llama_index/llms/monsterapi.py | 35 +- llama_index/llms/ollama.py | 71 ++-- llama_index/llms/openai.py | 34 +- llama_index/llms/openai_like.py | 2 +- llama_index/llms/openai_utils.py | 2 +- llama_index/llms/openllm.py | 39 ++- llama_index/llms/palm.py | 20 +- llama_index/llms/perplexity.py | 20 +- llama_index/llms/portkey.py | 32 +- llama_index/llms/portkey_utils.py | 2 +- llama_index/llms/predibase.py | 20 +- llama_index/llms/replicate.py | 70 ++-- llama_index/llms/rungpt.py | 20 +- llama_index/llms/types.py | 110 +++++++ llama_index/llms/utils.py | 2 +- llama_index/llms/vertex.py | 21 +- llama_index/llms/vertex_utils.py | 2 +- llama_index/llms/vllm.py | 53 ++- llama_index/llms/watsonx.py | 28 +- llama_index/llms/xinference.py | 21 +- llama_index/llms/xinference_utils.py | 2 +- llama_index/memory/chat_memory_buffer.py | 3 +- llama_index/memory/types.py | 3 +- llama_index/multi_modal_llms/base.py | 2 +- llama_index/multi_modal_llms/openai.py | 10 +- .../multi_modal_llms/replicate_multi_modal.py | 8 +- .../relational/unstructured_element.py | 3 +- llama_index/node_parser/text/token.py | 1 + llama_index/playground/base.py | 2 +- llama_index/postprocessor/llm_rerank.py | 2 +- llama_index/postprocessor/node.py | 3 +- llama_index/postprocessor/pii.py | 2 +- llama_index/program/llm_program.py | 2 +- llama_index/program/openai_program.py | 2 +- .../program/predefined/evaporate/extractor.py | 4 +- llama_index/program/utils.py | 2 +- llama_index/prompts/__init__.py | 2 +- llama_index/prompts/base.py | 71 ++-- llama_index/prompts/chat_prompts.py | 2 +- llama_index/prompts/lmformatenforcer_utils.py | 2 +- llama_index/prompts/utils.py | 4 +- .../query_engine/flare/answer_inserter.py | 2 +- llama_index/query_engine/flare/base.py | 2 +- .../knowledge_graph_query_engine.py | 4 +- .../query_engine/pandas_query_engine.py | 2 +- .../query_engine/sql_join_query_engine.py | 18 +- llama_index/question_gen/llm_generators.py | 12 +- llama_index/question_gen/openai_generator.py | 2 +- .../response_synthesizers/accumulate.py | 45 ++- .../response_synthesizers/generation.py | 8 +- llama_index/response_synthesizers/refine.py | 54 +-- .../response_synthesizers/simple_summarize.py | 8 +- .../response_synthesizers/tree_summarize.py | 157 +++++---- llama_index/schema.py | 4 +- llama_index/selectors/llm_selectors.py | 26 +- llama_index/selectors/utils.py | 4 +- llama_index/service_context.py | 12 +- llama_index/types.py | 4 +- tests/agent/openai/test_openai_agent.py | 2 +- tests/agent/react/test_react_agent.py | 4 +- .../chat_engine/test_condense_plus_context.py | 25 +- tests/chat_engine/test_condense_question.py | 2 +- tests/chat_engine/test_simple.py | 2 +- tests/conftest.py | 21 +- tests/indices/list/test_retrievers.py | 4 +- .../query/query_transform/test_base.py | 2 +- tests/indices/response/test_tree_summarize.py | 26 +- tests/indices/struct_store/test_json_query.py | 39 ++- tests/llms/test_anthropic.py | 2 +- tests/llms/test_anthropic_utils.py | 2 +- tests/llms/test_bedrock.py | 2 +- tests/llms/test_cohere.py | 2 +- tests/llms/test_custom.py | 4 +- tests/llms/test_konko.py | 2 +- tests/llms/test_langchain.py | 2 +- tests/llms/test_litellm.py | 2 +- tests/llms/test_llama_utils.py | 2 +- tests/llms/test_localai.py | 2 +- tests/llms/test_openai.py | 2 +- tests/llms/test_openai_like.py | 2 +- tests/llms/test_openai_utils.py | 2 +- tests/llms/test_palm.py | 2 +- tests/llms/test_rungpt.py | 4 +- tests/llms/test_vertex.py | 2 +- tests/llms/test_watsonx.py | 2 +- tests/llms/test_xinference.py | 2 +- tests/postprocessor/test_llm_rerank.py | 4 +- tests/program/test_llm_program.py | 2 +- tests/program/test_lmformatenforcer.py | 2 +- tests/program/test_multi_modal_llm_program.py | 2 +- tests/prompts/test_base.py | 2 +- .../test_retriever_query_engine.py | 35 +- tests/token_predictor/test_base.py | 6 +- 249 files changed, 2062 insertions(+), 1376 deletions(-) delete mode 100644 llama_index/llm_predictor/utils.py create mode 100644 llama_index/llms/llm.py create mode 100644 llama_index/llms/types.py diff --git a/CHANGELOG.md b/CHANGELOG.md index 6f5327167b..cbc06b381b 100644 --- a/CHANGELOG.md +++ b/CHANGELOG.md @@ -12,6 +12,11 @@ - Change more than one image input for Replicate Multi-modal models from error to warning (#9360) - Removed GPT-Licensed `aiostream` dependency (#9403) +### Breaking Changes + +- Updated the base `LLM` interface to match `LLMPredictor` (#9388) +- Deprecated `LLMPredictor` (#9388) + ## [0.9.13] - 2023-12-06 ### New Features diff --git a/benchmarks/agent/agent_utils.py b/benchmarks/agent/agent_utils.py index d35fa8635d..74a8625189 100644 --- a/benchmarks/agent/agent_utils.py +++ b/benchmarks/agent/agent_utils.py @@ -3,8 +3,8 @@ from typing import Dict, List, Type from llama_index.agent import OpenAIAgent, ReActAgent from llama_index.agent.types import BaseAgent from llama_index.llms import Anthropic, OpenAI -from llama_index.llms.base import LLM from llama_index.llms.llama_utils import messages_to_prompt +from llama_index.llms.llm import LLM from llama_index.llms.replicate import Replicate OPENAI_MODELS = [ diff --git a/benchmarks/struct_indices/spider/evaluate.py b/benchmarks/struct_indices/spider/evaluate.py index 0ebcf7d56a..a914d1a026 100644 --- a/benchmarks/struct_indices/spider/evaluate.py +++ b/benchmarks/struct_indices/spider/evaluate.py @@ -10,8 +10,8 @@ from spider_utils import create_indexes, load_examples from tqdm import tqdm from llama_index.indices.struct_store.sql import SQLQueryMode, SQLStructStoreIndex -from llama_index.llms.base import ChatMessage, MessageRole from llama_index.llms.openai import OpenAI +from llama_index.llms.types import ChatMessage, MessageRole from llama_index.response.schema import Response logging.getLogger("root").setLevel(logging.WARNING) diff --git a/docs/api_reference/service_context.rst b/docs/api_reference/service_context.rst index 772e2c9d32..e0f88cc918 100644 --- a/docs/api_reference/service_context.rst +++ b/docs/api_reference/service_context.rst @@ -6,7 +6,7 @@ Service Context The service context container is a utility container for LlamaIndex index and query classes. The container contains the following objects that are commonly used for configuring every index and -query, such as the LLMPredictor (for configuring the LLM), +query, such as the LLM, the PromptHelper (for configuring input size/chunk size), the BaseEmbedding (for configuring the embedding model), and more. diff --git a/docs/changes/deprecated_terms.md b/docs/changes/deprecated_terms.md index b4e78f4971..def9fc8ee0 100644 --- a/docs/changes/deprecated_terms.md +++ b/docs/changes/deprecated_terms.md @@ -24,7 +24,7 @@ This has been renamed to `VectorStoreIndex`, but it is only a cosmetic change. P ## LLMPredictor -The `LLMPredictor` object is no longer intended to be used by users. Instead, you can setup an LLM directly and pass it into the `ServiceContext`. +The `LLMPredictor` object is no longer intended to be used by users. Instead, you can setup an LLM directly and pass it into the `ServiceContext`. THe `LLM` class itself has similar attributes and methods as the `LLMPredictor`. - [LLMs in LlamaIndex](/module_guides/models/llms.md) - [Setting LLMs in the ServiceContext](/module_guides/supporting_modules/service_context.md) diff --git a/docs/community/faq/llms.md b/docs/community/faq/llms.md index 58aff0e0ba..4202f33801 100644 --- a/docs/community/faq/llms.md +++ b/docs/community/faq/llms.md @@ -46,12 +46,12 @@ response = query_engine.query("Rest of your query... \nRespond in Italian") Alternatively: ```py -from llama_index import LLMPredictor, ServiceContext +from llama_index import ServiceContext from llama_index.llms import OpenAI -llm_predictor = LLMPredictor(system_prompt="Always respond in Italian.") +llm = OpenAI(system_prompt="Always respond in Italian.") -service_context = ServiceContext.from_defaults(llm_predictor=llm_predictor) +service_context = ServiceContext.from_defaults(llm=llm) query_engine = load_index_from_storage( storage_context, service_context=service_context diff --git a/docs/community/integrations/deepeval.md b/docs/community/integrations/deepeval.md index d927efd6c4..ed8bc8ecdd 100644 --- a/docs/community/integrations/deepeval.md +++ b/docs/community/integrations/deepeval.md @@ -58,7 +58,6 @@ from llama_index import ( TreeIndex, VectorStoreIndex, SimpleDirectoryReader, - LLMPredictor, ServiceContext, Response, ) diff --git a/docs/examples/agent/openai_agent_query_plan.ipynb b/docs/examples/agent/openai_agent_query_plan.ipynb index 383fa4f1bc..e356c5a111 100644 --- a/docs/examples/agent/openai_agent_query_plan.ipynb +++ b/docs/examples/agent/openai_agent_query_plan.ipynb @@ -80,7 +80,6 @@ "source": [ "from llama_index import (\n", " SimpleDirectoryReader,\n", - " LLMPredictor,\n", " ServiceContext,\n", " GPTVectorStoreIndex,\n", ")\n", diff --git a/docs/examples/callbacks/HoneyHiveLlamaIndexTracer.ipynb b/docs/examples/callbacks/HoneyHiveLlamaIndexTracer.ipynb index bf81db607b..2117dfef15 100644 --- a/docs/examples/callbacks/HoneyHiveLlamaIndexTracer.ipynb +++ b/docs/examples/callbacks/HoneyHiveLlamaIndexTracer.ipynb @@ -118,16 +118,13 @@ "metadata": {}, "outputs": [], "source": [ - "from llama_index.callbacks import CallbackManager, CBEventType\n", - "from llama_index.callbacks import LlamaDebugHandler, WandbCallbackHandler\n", + "from llama_index.callbacks import CallbackManager\n", + "from llama_index.callbacks import LlamaDebugHandler\n", "from llama_index import (\n", - " SummaryIndex,\n", - " GPTTreeIndex,\n", - " GPTVectorStoreIndex,\n", + " VectorStoreIndex,\n", " ServiceContext,\n", " SimpleDirectoryReader,\n", - " LLMPredictor,\n", - " GPTSimpleKeywordTableIndex,\n", + " SimpleKeywordTableIndex,\n", " StorageContext,\n", ")\n", "from llama_index.indices.composability import ComposableGraph\n", @@ -289,9 +286,7 @@ } ], "source": [ - "index = GPTVectorStoreIndex.from_documents(\n", - " docs, service_context=service_context\n", - ")" + "index = VectorStoreIndex.from_documents(docs, service_context=service_context)" ] }, { @@ -421,7 +416,7 @@ ], "source": [ "# build NYC index\n", - "nyc_index = GPTVectorStoreIndex.from_documents(\n", + "nyc_index = VectorStoreIndex.from_documents(\n", " nyc_documents,\n", " service_context=service_context,\n", " storage_context=storage_context,\n", @@ -450,7 +445,7 @@ ], "source": [ "# build essay index\n", - "essay_index = GPTVectorStoreIndex.from_documents(\n", + "essay_index = VectorStoreIndex.from_documents(\n", " essay_documents,\n", " service_context=service_context,\n", " storage_context=storage_context,\n", @@ -529,7 +524,7 @@ "from llama_index import StorageContext, load_graph_from_storage\n", "\n", "graph = ComposableGraph.from_indices(\n", - " GPTSimpleKeywordTableIndex,\n", + " SimpleKeywordTableIndex,\n", " [nyc_index, essay_index],\n", " index_summaries=[nyc_index_summary, essay_index_summary],\n", " max_keywords_per_chunk=50,\n", diff --git a/docs/examples/callbacks/WandbCallbackHandler.ipynb b/docs/examples/callbacks/WandbCallbackHandler.ipynb index 664a545f35..cfea37a36c 100644 --- a/docs/examples/callbacks/WandbCallbackHandler.ipynb +++ b/docs/examples/callbacks/WandbCallbackHandler.ipynb @@ -57,16 +57,13 @@ "metadata": {}, "outputs": [], "source": [ - "from llama_index.callbacks import CallbackManager, CBEventType\n", + "from llama_index.callbacks import CallbackManager\n", "from llama_index.callbacks import LlamaDebugHandler, WandbCallbackHandler\n", "from llama_index import (\n", - " SummaryIndex,\n", - " GPTTreeIndex,\n", - " GPTVectorStoreIndex,\n", + " VectorStoreIndex,\n", " ServiceContext,\n", " SimpleDirectoryReader,\n", - " LLMPredictor,\n", - " GPTSimpleKeywordTableIndex,\n", + " SimpleKeywordTableIndex,\n", " StorageContext,\n", ")\n", "from llama_index.indices.composability import ComposableGraph\n", @@ -238,9 +235,7 @@ } ], "source": [ - "index = GPTVectorStoreIndex.from_documents(\n", - " docs, service_context=service_context\n", - ")" + "index = VectorStoreIndex.from_documents(docs, service_context=service_context)" ] }, { @@ -457,7 +452,7 @@ ], "source": [ "# build NYC index\n", - "nyc_index = GPTVectorStoreIndex.from_documents(\n", + "nyc_index = VectorStoreIndex.from_documents(\n", " nyc_documents,\n", " service_context=service_context,\n", " storage_context=storage_context,\n", @@ -493,7 +488,7 @@ ], "source": [ "# build essay index\n", - "essay_index = GPTVectorStoreIndex.from_documents(\n", + "essay_index = VectorStoreIndex.from_documents(\n", " essay_documents,\n", " service_context=service_context,\n", " storage_context=storage_context,\n", @@ -572,7 +567,7 @@ "from llama_index import StorageContext, load_graph_from_storage\n", "\n", "graph = ComposableGraph.from_indices(\n", - " GPTSimpleKeywordTableIndex,\n", + " SimpleKeywordTableIndex,\n", " [nyc_index, essay_index],\n", " index_summaries=[nyc_index_summary, essay_index_summary],\n", " max_keywords_per_chunk=50,\n", diff --git a/docs/examples/citation/pdf_page_reference.ipynb b/docs/examples/citation/pdf_page_reference.ipynb index 2c25638451..82d6fbb909 100644 --- a/docs/examples/citation/pdf_page_reference.ipynb +++ b/docs/examples/citation/pdf_page_reference.ipynb @@ -57,7 +57,7 @@ "metadata": {}, "outputs": [], "source": [ - "from llama_index import LLMPredictor, ServiceContext\n", + "from llama_index import ServiceContext\n", "from llama_index.llms import OpenAI\n", "\n", "service_context = ServiceContext.from_defaults(\n", diff --git a/docs/examples/composable_indices/city_analysis/City_Analysis-Decompose.ipynb b/docs/examples/composable_indices/city_analysis/City_Analysis-Decompose.ipynb index d685ebd7ec..0659a82858 100644 --- a/docs/examples/composable_indices/city_analysis/City_Analysis-Decompose.ipynb +++ b/docs/examples/composable_indices/city_analysis/City_Analysis-Decompose.ipynb @@ -309,7 +309,7 @@ ")\n", "\n", "decompose_transform = DecomposeQueryTransform(\n", - " service_context.llm_predictor, verbose=True\n", + " service_context.llm, verbose=True\n", ")" ] }, diff --git a/docs/examples/composable_indices/city_analysis/City_Analysis-Unified-Query.ipynb b/docs/examples/composable_indices/city_analysis/City_Analysis-Unified-Query.ipynb index 800cae7b77..c1a806320b 100644 --- a/docs/examples/composable_indices/city_analysis/City_Analysis-Unified-Query.ipynb +++ b/docs/examples/composable_indices/city_analysis/City_Analysis-Unified-Query.ipynb @@ -371,11 +371,8 @@ "from llama_index.indices.query.query_transform.base import (\n", " DecomposeQueryTransform,\n", ")\n", - "from llama_index import LLMPredictor\n", "\n", - "decompose_transform = DecomposeQueryTransform(\n", - " LLMPredictor(llm=chatgpt), verbose=True\n", - ")" + "decompose_transform = DecomposeQueryTransform(llm=chatgpt, verbose=True)" ] }, { diff --git a/docs/examples/composable_indices/city_analysis/City_Analysis.ipynb b/docs/examples/composable_indices/city_analysis/City_Analysis.ipynb index 68255f8526..9f39f9a413 100644 --- a/docs/examples/composable_indices/city_analysis/City_Analysis.ipynb +++ b/docs/examples/composable_indices/city_analysis/City_Analysis.ipynb @@ -206,7 +206,6 @@ " SimpleKeywordTableIndex,\n", " SummaryIndex,\n", " SimpleDirectoryReader,\n", - " LLMPredictor,\n", " ServiceContext,\n", ")\n", "from llama_index.llms import OpenAI\n", diff --git a/docs/examples/composable_indices/city_analysis/PineconeDemo-CityAnalysis.ipynb b/docs/examples/composable_indices/city_analysis/PineconeDemo-CityAnalysis.ipynb index bc4e37269d..4a6e00c072 100644 --- a/docs/examples/composable_indices/city_analysis/PineconeDemo-CityAnalysis.ipynb +++ b/docs/examples/composable_indices/city_analysis/PineconeDemo-CityAnalysis.ipynb @@ -67,7 +67,6 @@ " VectorStoreIndex,\n", " SimpleKeywordTableIndex,\n", " SimpleDirectoryReader,\n", - " LLMPredictor,\n", " ServiceContext,\n", ")\n", "from llama_index.vector_stores import PineconeVectorStore\n", diff --git a/docs/examples/composable_indices/financial_data_analysis/DeepLakeDemo-FinancialData.ipynb b/docs/examples/composable_indices/financial_data_analysis/DeepLakeDemo-FinancialData.ipynb index 1c3bf18483..a955f829bc 100644 --- a/docs/examples/composable_indices/financial_data_analysis/DeepLakeDemo-FinancialData.ipynb +++ b/docs/examples/composable_indices/financial_data_analysis/DeepLakeDemo-FinancialData.ipynb @@ -152,7 +152,6 @@ " VectorStoreIndex,\n", " SimpleKeywordTableIndex,\n", " SimpleDirectoryReader,\n", - " LLMPredictor,\n", " ServiceContext,\n", " download_loader,\n", " Document,\n", @@ -818,7 +817,7 @@ ")\n", "\n", "decompose_transform = DecomposeQueryTransform(\n", - " service_context.llm_predictor, verbose=True\n", + " service_context.llm, verbose=True\n", ")" ] }, @@ -879,7 +878,7 @@ ")\n", "\n", "decompose_transform = DecomposeQueryTransform(\n", - " service_context.llm_predictor, verbose=True\n", + " service_context.llm, verbose=True\n", ")" ] }, diff --git a/docs/examples/customization/llms/SimpleIndexDemo-ChatGPT.ipynb b/docs/examples/customization/llms/SimpleIndexDemo-ChatGPT.ipynb index da846a6389..565b5cfefb 100644 --- a/docs/examples/customization/llms/SimpleIndexDemo-ChatGPT.ipynb +++ b/docs/examples/customization/llms/SimpleIndexDemo-ChatGPT.ipynb @@ -51,7 +51,6 @@ "from llama_index import (\n", " VectorStoreIndex,\n", " SimpleDirectoryReader,\n", - " LLMPredictor,\n", " ServiceContext,\n", ")\n", "from llama_index.llms import OpenAI\n", diff --git a/docs/examples/discover_llamaindex/document_management/discord_dumps/help_channel_dump_05_25_23.json b/docs/examples/discover_llamaindex/document_management/discord_dumps/help_channel_dump_05_25_23.json index 54bb7a9d3a..70b35ad139 100644 --- a/docs/examples/discover_llamaindex/document_management/discord_dumps/help_channel_dump_05_25_23.json +++ b/docs/examples/discover_llamaindex/document_management/discord_dumps/help_channel_dump_05_25_23.json @@ -79388,7 +79388,7 @@ "timestampEdited": null, "callEndedTimestamp": null, "isPinned": false, - "content": "`\n for cur_text_chunk in text_chunks:\n if not self._streaming:\n (\n response,\n formatted_prompt,\n ) = self._service_context.llm_predictor.predict(\n refine_template,\n context_msg=cur_text_chunk,\n )\n else:\n response, formatted_prompt = self._service_context.llm_predictor.stream(\n refine_template,\n context_msg=cur_text_chunk,\n )\n self._log_prompt_and_response(\n formatted_prompt, response, log_prefix=\"Refined\"\n )\n`\nA code snippet in refine_response_single seems to have no effect on the subsequent llm query after being split into cur_text_chunk. It appears that the response is entirely based on the result of the last text chunk, and previous chunks are essentially discarded. I feel that there may be some issues with this logic, perhaps it's just my understanding problem. I hope you can clarify this for me, thank you.", + "content": "`\n for cur_text_chunk in text_chunks:\n if not self._streaming:\n (\n response,\n formatted_prompt,\n ) = self._service_context.llm.predict(\n refine_template,\n context_msg=cur_text_chunk,\n )\n else:\n response, formatted_prompt = self._service_context.llm.stream(\n refine_template,\n context_msg=cur_text_chunk,\n )\n self._log_prompt_and_response(\n formatted_prompt, response, log_prefix=\"Refined\"\n )\n`\nA code snippet in refine_response_single seems to have no effect on the subsequent llm query after being split into cur_text_chunk. It appears that the response is entirely based on the result of the last text chunk, and previous chunks are essentially discarded. I feel that there may be some issues with this logic, perhaps it's just my understanding problem. I hope you can clarify this for me, thank you.", "author": { "id": "937548610885791806", "name": "noequal", @@ -105504,7 +105504,7 @@ "timestampEdited": null, "callEndedTimestamp": null, "isPinned": false, - "content": "`index._service_context.llm_predictor.last_token_usage()`\n\n`index._service_context.embed_model.last_token_usage()`", + "content": "`index._service_context.llm.last_token_usage()`\n\n`index._service_context.embed_model.last_token_usage()`", "author": { "id": "334536717648265216", "name": "Logan M", diff --git a/docs/examples/discover_llamaindex/document_management/discord_dumps/help_channel_dump_06_02_23.json b/docs/examples/discover_llamaindex/document_management/discord_dumps/help_channel_dump_06_02_23.json index c29bb79f5d..b59cd12b67 100644 --- a/docs/examples/discover_llamaindex/document_management/discord_dumps/help_channel_dump_06_02_23.json +++ b/docs/examples/discover_llamaindex/document_management/discord_dumps/help_channel_dump_06_02_23.json @@ -79388,7 +79388,7 @@ "timestampEdited": null, "callEndedTimestamp": null, "isPinned": false, - "content": "`\n for cur_text_chunk in text_chunks:\n if not self._streaming:\n (\n response,\n formatted_prompt,\n ) = self._service_context.llm_predictor.predict(\n refine_template,\n context_msg=cur_text_chunk,\n )\n else:\n response, formatted_prompt = self._service_context.llm_predictor.stream(\n refine_template,\n context_msg=cur_text_chunk,\n )\n self._log_prompt_and_response(\n formatted_prompt, response, log_prefix=\"Refined\"\n )\n`\nA code snippet in refine_response_single seems to have no effect on the subsequent llm query after being split into cur_text_chunk. It appears that the response is entirely based on the result of the last text chunk, and previous chunks are essentially discarded. I feel that there may be some issues with this logic, perhaps it's just my understanding problem. I hope you can clarify this for me, thank you.", + "content": "`\n for cur_text_chunk in text_chunks:\n if not self._streaming:\n (\n response,\n formatted_prompt,\n ) = self._service_context.llm.predict(\n refine_template,\n context_msg=cur_text_chunk,\n )\n else:\n response, formatted_prompt = self._service_context.llm.stream(\n refine_template,\n context_msg=cur_text_chunk,\n )\n self._log_prompt_and_response(\n formatted_prompt, response, log_prefix=\"Refined\"\n )\n`\nA code snippet in refine_response_single seems to have no effect on the subsequent llm query after being split into cur_text_chunk. It appears that the response is entirely based on the result of the last text chunk, and previous chunks are essentially discarded. I feel that there may be some issues with this logic, perhaps it's just my understanding problem. I hope you can clarify this for me, thank you.", "author": { "id": "937548610885791806", "name": "noequal", @@ -105504,7 +105504,7 @@ "timestampEdited": null, "callEndedTimestamp": null, "isPinned": false, - "content": "`index._service_context.llm_predictor.last_token_usage()`\n\n`index._service_context.embed_model.last_token_usage()`", + "content": "`index._service_context.llm.last_token_usage()`\n\n`index._service_context.embed_model.last_token_usage()`", "author": { "id": "334536717648265216", "name": "Logan M", diff --git a/docs/examples/docstore/DocstoreDemo.ipynb b/docs/examples/docstore/DocstoreDemo.ipynb index 788c2a3970..df8eb22891 100644 --- a/docs/examples/docstore/DocstoreDemo.ipynb +++ b/docs/examples/docstore/DocstoreDemo.ipynb @@ -59,7 +59,7 @@ "metadata": {}, "outputs": [], "source": [ - "from llama_index import SimpleDirectoryReader, ServiceContext, LLMPredictor\n", + "from llama_index import SimpleDirectoryReader, ServiceContext\n", "from llama_index import VectorStoreIndex, SummaryIndex, SimpleKeywordTableIndex\n", "from llama_index.composability import ComposableGraph\n", "from llama_index.llms import OpenAI" diff --git a/docs/examples/docstore/FirestoreDemo.ipynb b/docs/examples/docstore/FirestoreDemo.ipynb index f2c9a9343a..dfbfaddbe4 100644 --- a/docs/examples/docstore/FirestoreDemo.ipynb +++ b/docs/examples/docstore/FirestoreDemo.ipynb @@ -56,7 +56,6 @@ "from llama_index import (\n", " SimpleDirectoryReader,\n", " ServiceContext,\n", - " LLMPredictor,\n", " StorageContext,\n", ")\n", "from llama_index import VectorStoreIndex, SummaryIndex, SimpleKeywordTableIndex\n", diff --git a/docs/examples/docstore/MongoDocstoreDemo.ipynb b/docs/examples/docstore/MongoDocstoreDemo.ipynb index 2f45117a54..e4f5cc8555 100644 --- a/docs/examples/docstore/MongoDocstoreDemo.ipynb +++ b/docs/examples/docstore/MongoDocstoreDemo.ipynb @@ -63,7 +63,6 @@ "from llama_index import (\n", " SimpleDirectoryReader,\n", " ServiceContext,\n", - " LLMPredictor,\n", " StorageContext,\n", ")\n", "from llama_index import VectorStoreIndex, SummaryIndex, SimpleKeywordTableIndex\n", diff --git a/docs/examples/docstore/RedisDocstoreIndexStoreDemo.ipynb b/docs/examples/docstore/RedisDocstoreIndexStoreDemo.ipynb index 3ed734ba02..74b9a196a7 100644 --- a/docs/examples/docstore/RedisDocstoreIndexStoreDemo.ipynb +++ b/docs/examples/docstore/RedisDocstoreIndexStoreDemo.ipynb @@ -91,7 +91,6 @@ "from llama_index import (\n", " SimpleDirectoryReader,\n", " ServiceContext,\n", - " LLMPredictor,\n", " StorageContext,\n", ")\n", "from llama_index import VectorStoreIndex, SummaryIndex, SimpleKeywordTableIndex\n", diff --git a/docs/examples/evaluation/Deepeval.ipynb b/docs/examples/evaluation/Deepeval.ipynb index fe3a1ba316..17a525f6aa 100644 --- a/docs/examples/evaluation/Deepeval.ipynb +++ b/docs/examples/evaluation/Deepeval.ipynb @@ -100,7 +100,6 @@ " TreeIndex,\n", " VectorStoreIndex,\n", " SimpleDirectoryReader,\n", - " LLMPredictor,\n", " ServiceContext,\n", " Response,\n", ")\n", diff --git a/docs/examples/evaluation/QuestionGeneration.ipynb b/docs/examples/evaluation/QuestionGeneration.ipynb index b723dc20fe..05a2371349 100644 --- a/docs/examples/evaluation/QuestionGeneration.ipynb +++ b/docs/examples/evaluation/QuestionGeneration.ipynb @@ -64,7 +64,6 @@ " SimpleDirectoryReader,\n", " VectorStoreIndex,\n", " ServiceContext,\n", - " LLMPredictor,\n", " Response,\n", ")\n", "from llama_index.llms import OpenAI" diff --git a/docs/examples/evaluation/batch_eval.ipynb b/docs/examples/evaluation/batch_eval.ipynb index 1ce5842cb9..8e3123ebe3 100644 --- a/docs/examples/evaluation/batch_eval.ipynb +++ b/docs/examples/evaluation/batch_eval.ipynb @@ -35,8 +35,8 @@ "import os\n", "import openai\n", "\n", - "os.environ[\"OPENAI_API_KEY\"] = \"YOUR_API_KEY\"\n", - "openai.api_key = os.environ[\"OPENAI_API_KEY\"]" + "os.environ[\"OPENAI_API_KEY\"] = \"sk-...\"\n", + "# openai.api_key = os.environ[\"OPENAI_API_KEY\"]" ] }, { @@ -124,12 +124,126 @@ "First, we can generate some questions and then run evaluation on them." ] }, + { + "cell_type": "code", + "execution_count": null, + "id": "976e0a93", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Requirement already satisfied: spacy in /home/loganm/.cache/pypoetry/virtualenvs/llama-index-4a-wkI5X-py3.11/lib/python3.11/site-packages (3.7.2)\n", + "Requirement already satisfied: datasets in /home/loganm/.cache/pypoetry/virtualenvs/llama-index-4a-wkI5X-py3.11/lib/python3.11/site-packages (2.15.0)\n", + "Requirement already satisfied: span-marker in /home/loganm/.cache/pypoetry/virtualenvs/llama-index-4a-wkI5X-py3.11/lib/python3.11/site-packages (1.5.0)\n", + "Requirement already satisfied: scikit-learn in /home/loganm/.cache/pypoetry/virtualenvs/llama-index-4a-wkI5X-py3.11/lib/python3.11/site-packages (1.3.2)\n", + "Requirement already satisfied: spacy-legacy<3.1.0,>=3.0.11 in /home/loganm/.cache/pypoetry/virtualenvs/llama-index-4a-wkI5X-py3.11/lib/python3.11/site-packages (from spacy) (3.0.12)\n", + "Requirement already satisfied: spacy-loggers<2.0.0,>=1.0.0 in /home/loganm/.cache/pypoetry/virtualenvs/llama-index-4a-wkI5X-py3.11/lib/python3.11/site-packages (from spacy) (1.0.5)\n", + "Requirement already satisfied: murmurhash<1.1.0,>=0.28.0 in /home/loganm/.cache/pypoetry/virtualenvs/llama-index-4a-wkI5X-py3.11/lib/python3.11/site-packages (from spacy) (1.0.10)\n", + "Requirement already satisfied: cymem<2.1.0,>=2.0.2 in /home/loganm/.cache/pypoetry/virtualenvs/llama-index-4a-wkI5X-py3.11/lib/python3.11/site-packages (from spacy) (2.0.8)\n", + "Requirement already satisfied: preshed<3.1.0,>=3.0.2 in /home/loganm/.cache/pypoetry/virtualenvs/llama-index-4a-wkI5X-py3.11/lib/python3.11/site-packages (from spacy) (3.0.9)\n", + "Requirement already satisfied: thinc<8.3.0,>=8.1.8 in /home/loganm/.cache/pypoetry/virtualenvs/llama-index-4a-wkI5X-py3.11/lib/python3.11/site-packages (from spacy) (8.2.1)\n", + "Requirement already satisfied: wasabi<1.2.0,>=0.9.1 in /home/loganm/.cache/pypoetry/virtualenvs/llama-index-4a-wkI5X-py3.11/lib/python3.11/site-packages (from spacy) (1.1.2)\n", + "Requirement already satisfied: srsly<3.0.0,>=2.4.3 in /home/loganm/.cache/pypoetry/virtualenvs/llama-index-4a-wkI5X-py3.11/lib/python3.11/site-packages (from spacy) (2.4.8)\n", + "Requirement already satisfied: catalogue<2.1.0,>=2.0.6 in /home/loganm/.cache/pypoetry/virtualenvs/llama-index-4a-wkI5X-py3.11/lib/python3.11/site-packages (from spacy) (2.0.10)\n", + "Requirement already satisfied: weasel<0.4.0,>=0.1.0 in /home/loganm/.cache/pypoetry/virtualenvs/llama-index-4a-wkI5X-py3.11/lib/python3.11/site-packages (from spacy) (0.3.4)\n", + "Requirement already satisfied: typer<0.10.0,>=0.3.0 in /home/loganm/.cache/pypoetry/virtualenvs/llama-index-4a-wkI5X-py3.11/lib/python3.11/site-packages (from spacy) (0.9.0)\n", + "Requirement already satisfied: smart-open<7.0.0,>=5.2.1 in /home/loganm/.cache/pypoetry/virtualenvs/llama-index-4a-wkI5X-py3.11/lib/python3.11/site-packages (from spacy) (6.4.0)\n", + "Requirement already satisfied: tqdm<5.0.0,>=4.38.0 in /home/loganm/.cache/pypoetry/virtualenvs/llama-index-4a-wkI5X-py3.11/lib/python3.11/site-packages (from spacy) (4.66.1)\n", + "Requirement already satisfied: requests<3.0.0,>=2.13.0 in /home/loganm/.cache/pypoetry/virtualenvs/llama-index-4a-wkI5X-py3.11/lib/python3.11/site-packages (from spacy) (2.31.0)\n", + "Requirement already satisfied: pydantic!=1.8,!=1.8.1,<3.0.0,>=1.7.4 in /home/loganm/.cache/pypoetry/virtualenvs/llama-index-4a-wkI5X-py3.11/lib/python3.11/site-packages (from spacy) (1.10.12)\n", + "Requirement already satisfied: jinja2 in /home/loganm/.cache/pypoetry/virtualenvs/llama-index-4a-wkI5X-py3.11/lib/python3.11/site-packages (from spacy) (3.1.2)\n", + "Requirement already satisfied: setuptools in /home/loganm/.cache/pypoetry/virtualenvs/llama-index-4a-wkI5X-py3.11/lib/python3.11/site-packages (from spacy) (69.0.2)\n", + "Requirement already satisfied: packaging>=20.0 in /home/loganm/.cache/pypoetry/virtualenvs/llama-index-4a-wkI5X-py3.11/lib/python3.11/site-packages (from spacy) (23.2)\n", + "Requirement already satisfied: langcodes<4.0.0,>=3.2.0 in /home/loganm/.cache/pypoetry/virtualenvs/llama-index-4a-wkI5X-py3.11/lib/python3.11/site-packages (from spacy) (3.3.0)\n", + "Requirement already satisfied: numpy>=1.19.0 in /home/loganm/.cache/pypoetry/virtualenvs/llama-index-4a-wkI5X-py3.11/lib/python3.11/site-packages (from spacy) (1.24.4)\n", + "Requirement already satisfied: pyarrow>=8.0.0 in /home/loganm/.cache/pypoetry/virtualenvs/llama-index-4a-wkI5X-py3.11/lib/python3.11/site-packages (from datasets) (14.0.1)\n", + "Requirement already satisfied: pyarrow-hotfix in /home/loganm/.cache/pypoetry/virtualenvs/llama-index-4a-wkI5X-py3.11/lib/python3.11/site-packages (from datasets) (0.6)\n", + "Requirement already satisfied: dill<0.3.8,>=0.3.0 in /home/loganm/.cache/pypoetry/virtualenvs/llama-index-4a-wkI5X-py3.11/lib/python3.11/site-packages (from datasets) (0.3.7)\n", + "Requirement already satisfied: pandas in /home/loganm/.cache/pypoetry/virtualenvs/llama-index-4a-wkI5X-py3.11/lib/python3.11/site-packages (from datasets) (2.0.3)\n", + "Requirement already satisfied: xxhash in /home/loganm/.cache/pypoetry/virtualenvs/llama-index-4a-wkI5X-py3.11/lib/python3.11/site-packages (from datasets) (3.4.1)\n", + "Requirement already satisfied: multiprocess in /home/loganm/.cache/pypoetry/virtualenvs/llama-index-4a-wkI5X-py3.11/lib/python3.11/site-packages (from datasets) (0.70.15)\n", + "Requirement already satisfied: fsspec[http]<=2023.10.0,>=2023.1.0 in /home/loganm/.cache/pypoetry/virtualenvs/llama-index-4a-wkI5X-py3.11/lib/python3.11/site-packages (from datasets) (2023.10.0)\n", + "Requirement already satisfied: aiohttp in /home/loganm/.cache/pypoetry/virtualenvs/llama-index-4a-wkI5X-py3.11/lib/python3.11/site-packages (from datasets) (3.9.1)\n", + "Requirement already satisfied: huggingface-hub>=0.18.0 in /home/loganm/.cache/pypoetry/virtualenvs/llama-index-4a-wkI5X-py3.11/lib/python3.11/site-packages (from datasets) (0.19.4)\n", + "Requirement already satisfied: pyyaml>=5.1 in /home/loganm/.cache/pypoetry/virtualenvs/llama-index-4a-wkI5X-py3.11/lib/python3.11/site-packages (from datasets) (6.0.1)\n", + "Requirement already satisfied: torch in /home/loganm/.cache/pypoetry/virtualenvs/llama-index-4a-wkI5X-py3.11/lib/python3.11/site-packages (from span-marker) (2.1.1)\n", + "Requirement already satisfied: accelerate in /home/loganm/.cache/pypoetry/virtualenvs/llama-index-4a-wkI5X-py3.11/lib/python3.11/site-packages (from span-marker) (0.25.0)\n", + "Requirement already satisfied: transformers>=4.19.0 in /home/loganm/.cache/pypoetry/virtualenvs/llama-index-4a-wkI5X-py3.11/lib/python3.11/site-packages (from span-marker) (4.35.2)\n", + "Requirement already satisfied: evaluate in /home/loganm/.cache/pypoetry/virtualenvs/llama-index-4a-wkI5X-py3.11/lib/python3.11/site-packages (from span-marker) (0.4.1)\n", + "Requirement already satisfied: seqeval in /home/loganm/.cache/pypoetry/virtualenvs/llama-index-4a-wkI5X-py3.11/lib/python3.11/site-packages (from span-marker) (1.2.2)\n", + "Requirement already satisfied: scipy>=1.5.0 in /home/loganm/.cache/pypoetry/virtualenvs/llama-index-4a-wkI5X-py3.11/lib/python3.11/site-packages (from scikit-learn) (1.11.4)\n", + "Requirement already satisfied: joblib>=1.1.1 in /home/loganm/.cache/pypoetry/virtualenvs/llama-index-4a-wkI5X-py3.11/lib/python3.11/site-packages (from scikit-learn) (1.3.2)\n", + "Requirement already satisfied: threadpoolctl>=2.0.0 in /home/loganm/.cache/pypoetry/virtualenvs/llama-index-4a-wkI5X-py3.11/lib/python3.11/site-packages (from scikit-learn) (3.2.0)\n", + "Requirement already satisfied: attrs>=17.3.0 in /home/loganm/.cache/pypoetry/virtualenvs/llama-index-4a-wkI5X-py3.11/lib/python3.11/site-packages (from aiohttp->datasets) (23.1.0)\n", + "Requirement already satisfied: multidict<7.0,>=4.5 in /home/loganm/.cache/pypoetry/virtualenvs/llama-index-4a-wkI5X-py3.11/lib/python3.11/site-packages (from aiohttp->datasets) (6.0.4)\n", + "Requirement already satisfied: yarl<2.0,>=1.0 in /home/loganm/.cache/pypoetry/virtualenvs/llama-index-4a-wkI5X-py3.11/lib/python3.11/site-packages (from aiohttp->datasets) (1.9.3)\n", + "Requirement already satisfied: frozenlist>=1.1.1 in /home/loganm/.cache/pypoetry/virtualenvs/llama-index-4a-wkI5X-py3.11/lib/python3.11/site-packages (from aiohttp->datasets) (1.4.0)\n", + "Requirement already satisfied: aiosignal>=1.1.2 in /home/loganm/.cache/pypoetry/virtualenvs/llama-index-4a-wkI5X-py3.11/lib/python3.11/site-packages (from aiohttp->datasets) (1.3.1)\n", + "Requirement already satisfied: filelock in /home/loganm/.cache/pypoetry/virtualenvs/llama-index-4a-wkI5X-py3.11/lib/python3.11/site-packages (from huggingface-hub>=0.18.0->datasets) (3.13.1)\n", + "Requirement already satisfied: typing-extensions>=3.7.4.3 in /home/loganm/.cache/pypoetry/virtualenvs/llama-index-4a-wkI5X-py3.11/lib/python3.11/site-packages (from huggingface-hub>=0.18.0->datasets) (4.8.0)\n", + "Requirement already satisfied: charset-normalizer<4,>=2 in /home/loganm/.cache/pypoetry/virtualenvs/llama-index-4a-wkI5X-py3.11/lib/python3.11/site-packages (from requests<3.0.0,>=2.13.0->spacy) (3.3.2)\n", + "Requirement already satisfied: idna<4,>=2.5 in /home/loganm/.cache/pypoetry/virtualenvs/llama-index-4a-wkI5X-py3.11/lib/python3.11/site-packages (from requests<3.0.0,>=2.13.0->spacy) (3.6)\n", + "Requirement already satisfied: urllib3<3,>=1.21.1 in /home/loganm/.cache/pypoetry/virtualenvs/llama-index-4a-wkI5X-py3.11/lib/python3.11/site-packages (from requests<3.0.0,>=2.13.0->spacy) (1.26.18)\n", + "Requirement already satisfied: certifi>=2017.4.17 in /home/loganm/.cache/pypoetry/virtualenvs/llama-index-4a-wkI5X-py3.11/lib/python3.11/site-packages (from requests<3.0.0,>=2.13.0->spacy) (2023.11.17)\n", + "Requirement already satisfied: blis<0.8.0,>=0.7.8 in /home/loganm/.cache/pypoetry/virtualenvs/llama-index-4a-wkI5X-py3.11/lib/python3.11/site-packages (from thinc<8.3.0,>=8.1.8->spacy) (0.7.11)\n", + "Requirement already satisfied: confection<1.0.0,>=0.0.1 in /home/loganm/.cache/pypoetry/virtualenvs/llama-index-4a-wkI5X-py3.11/lib/python3.11/site-packages (from thinc<8.3.0,>=8.1.8->spacy) (0.1.4)\n", + "Requirement already satisfied: regex!=2019.12.17 in /home/loganm/.cache/pypoetry/virtualenvs/llama-index-4a-wkI5X-py3.11/lib/python3.11/site-packages (from transformers>=4.19.0->span-marker) (2023.10.3)\n", + "Requirement already satisfied: tokenizers<0.19,>=0.14 in /home/loganm/.cache/pypoetry/virtualenvs/llama-index-4a-wkI5X-py3.11/lib/python3.11/site-packages (from transformers>=4.19.0->span-marker) (0.15.0)\n", + "Requirement already satisfied: safetensors>=0.3.1 in /home/loganm/.cache/pypoetry/virtualenvs/llama-index-4a-wkI5X-py3.11/lib/python3.11/site-packages (from transformers>=4.19.0->span-marker) (0.4.1)\n", + "Requirement already satisfied: click<9.0.0,>=7.1.1 in /home/loganm/.cache/pypoetry/virtualenvs/llama-index-4a-wkI5X-py3.11/lib/python3.11/site-packages (from typer<0.10.0,>=0.3.0->spacy) (8.1.7)\n", + "Requirement already satisfied: cloudpathlib<0.17.0,>=0.7.0 in /home/loganm/.cache/pypoetry/virtualenvs/llama-index-4a-wkI5X-py3.11/lib/python3.11/site-packages (from weasel<0.4.0,>=0.1.0->spacy) (0.16.0)\n", + "Requirement already satisfied: psutil in /home/loganm/.cache/pypoetry/virtualenvs/llama-index-4a-wkI5X-py3.11/lib/python3.11/site-packages (from accelerate->span-marker) (5.9.6)\n", + "Requirement already satisfied: sympy in /home/loganm/.cache/pypoetry/virtualenvs/llama-index-4a-wkI5X-py3.11/lib/python3.11/site-packages (from torch->span-marker) (1.12)\n", + "Requirement already satisfied: networkx in /home/loganm/.cache/pypoetry/virtualenvs/llama-index-4a-wkI5X-py3.11/lib/python3.11/site-packages (from torch->span-marker) (3.2.1)\n", + "Requirement already satisfied: nvidia-cuda-nvrtc-cu12==12.1.105 in /home/loganm/.cache/pypoetry/virtualenvs/llama-index-4a-wkI5X-py3.11/lib/python3.11/site-packages (from torch->span-marker) (12.1.105)\n", + "Requirement already satisfied: nvidia-cuda-runtime-cu12==12.1.105 in /home/loganm/.cache/pypoetry/virtualenvs/llama-index-4a-wkI5X-py3.11/lib/python3.11/site-packages (from torch->span-marker) (12.1.105)\n", + "Requirement already satisfied: nvidia-cuda-cupti-cu12==12.1.105 in /home/loganm/.cache/pypoetry/virtualenvs/llama-index-4a-wkI5X-py3.11/lib/python3.11/site-packages (from torch->span-marker) (12.1.105)\n", + "Requirement already satisfied: nvidia-cudnn-cu12==8.9.2.26 in /home/loganm/.cache/pypoetry/virtualenvs/llama-index-4a-wkI5X-py3.11/lib/python3.11/site-packages (from torch->span-marker) (8.9.2.26)\n", + "Requirement already satisfied: nvidia-cublas-cu12==12.1.3.1 in /home/loganm/.cache/pypoetry/virtualenvs/llama-index-4a-wkI5X-py3.11/lib/python3.11/site-packages (from torch->span-marker) (12.1.3.1)\n", + "Requirement already satisfied: nvidia-cufft-cu12==11.0.2.54 in /home/loganm/.cache/pypoetry/virtualenvs/llama-index-4a-wkI5X-py3.11/lib/python3.11/site-packages (from torch->span-marker) (11.0.2.54)\n", + "Requirement already satisfied: nvidia-curand-cu12==10.3.2.106 in /home/loganm/.cache/pypoetry/virtualenvs/llama-index-4a-wkI5X-py3.11/lib/python3.11/site-packages (from torch->span-marker) (10.3.2.106)\n", + "Requirement already satisfied: nvidia-cusolver-cu12==11.4.5.107 in /home/loganm/.cache/pypoetry/virtualenvs/llama-index-4a-wkI5X-py3.11/lib/python3.11/site-packages (from torch->span-marker) (11.4.5.107)\n", + "Requirement already satisfied: nvidia-cusparse-cu12==12.1.0.106 in /home/loganm/.cache/pypoetry/virtualenvs/llama-index-4a-wkI5X-py3.11/lib/python3.11/site-packages (from torch->span-marker) (12.1.0.106)\n", + "Requirement already satisfied: nvidia-nccl-cu12==2.18.1 in /home/loganm/.cache/pypoetry/virtualenvs/llama-index-4a-wkI5X-py3.11/lib/python3.11/site-packages (from torch->span-marker) (2.18.1)\n", + "Requirement already satisfied: nvidia-nvtx-cu12==12.1.105 in /home/loganm/.cache/pypoetry/virtualenvs/llama-index-4a-wkI5X-py3.11/lib/python3.11/site-packages (from torch->span-marker) (12.1.105)\n", + "Requirement already satisfied: triton==2.1.0 in /home/loganm/.cache/pypoetry/virtualenvs/llama-index-4a-wkI5X-py3.11/lib/python3.11/site-packages (from torch->span-marker) (2.1.0)\n", + "Requirement already satisfied: nvidia-nvjitlink-cu12 in /home/loganm/.cache/pypoetry/virtualenvs/llama-index-4a-wkI5X-py3.11/lib/python3.11/site-packages (from nvidia-cusolver-cu12==11.4.5.107->torch->span-marker) (12.3.101)\n", + "Requirement already satisfied: responses<0.19 in /home/loganm/.cache/pypoetry/virtualenvs/llama-index-4a-wkI5X-py3.11/lib/python3.11/site-packages (from evaluate->span-marker) (0.18.0)\n", + "Requirement already satisfied: MarkupSafe>=2.0 in /home/loganm/.cache/pypoetry/virtualenvs/llama-index-4a-wkI5X-py3.11/lib/python3.11/site-packages (from jinja2->spacy) (2.1.3)\n", + "Requirement already satisfied: python-dateutil>=2.8.2 in /home/loganm/.cache/pypoetry/virtualenvs/llama-index-4a-wkI5X-py3.11/lib/python3.11/site-packages (from pandas->datasets) (2.8.2)\n", + "Requirement already satisfied: pytz>=2020.1 in /home/loganm/.cache/pypoetry/virtualenvs/llama-index-4a-wkI5X-py3.11/lib/python3.11/site-packages (from pandas->datasets) (2023.3.post1)\n", + "Requirement already satisfied: tzdata>=2022.1 in /home/loganm/.cache/pypoetry/virtualenvs/llama-index-4a-wkI5X-py3.11/lib/python3.11/site-packages (from pandas->datasets) (2023.3)\n", + "Requirement already satisfied: six>=1.5 in /home/loganm/.cache/pypoetry/virtualenvs/llama-index-4a-wkI5X-py3.11/lib/python3.11/site-packages (from python-dateutil>=2.8.2->pandas->datasets) (1.16.0)\n", + "Requirement already satisfied: mpmath>=0.19 in /home/loganm/.cache/pypoetry/virtualenvs/llama-index-4a-wkI5X-py3.11/lib/python3.11/site-packages (from sympy->torch->span-marker) (1.3.0)\n", + "\n", + "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m A new release of pip is available: \u001b[0m\u001b[31;49m23.2.1\u001b[0m\u001b[39;49m -> \u001b[0m\u001b[32;49m23.3.1\u001b[0m\n", + "\u001b[1m[\u001b[0m\u001b[34;49mnotice\u001b[0m\u001b[1;39;49m]\u001b[0m\u001b[39;49m To update, run: \u001b[0m\u001b[32;49mpip install --upgrade pip\u001b[0m\n" + ] + } + ], + "source": [ + "!pip install spacy datasets span-marker scikit-learn" + ] + }, { "cell_type": "code", "execution_count": null, "id": "e31e10e6", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "/home/loganm/llama_index_proper/llama_index/llama_index/evaluation/dataset_generation.py:187: DeprecationWarning: Call to deprecated class DatasetGenerator. (Deprecated in favor of `RagDatasetGenerator` which should be used instead.)\n", + " return cls(\n", + "/home/loganm/llama_index_proper/llama_index/llama_index/evaluation/dataset_generation.py:282: DeprecationWarning: Call to deprecated class QueryResponseDataset. (Deprecated in favor of `LabelledRagDataset` which should be used instead.)\n", + " return QueryResponseDataset(queries=queries, responses=responses_dict)\n" + ] + } + ], "source": [ "from llama_index.evaluation import DatasetGenerator\n", "\n", @@ -137,7 +251,7 @@ " documents, service_context=service_context\n", ")\n", "\n", - "questions = dataset_generator.generate_questions_from_nodes(num=25)" + "qas = dataset_generator.generate_dataset_from_nodes(num=3)" ] }, { @@ -165,7 +279,7 @@ ")\n", "\n", "eval_results = await runner.aevaluate_queries(\n", - " vector_index.as_query_engine(), queries=questions\n", + " vector_index.as_query_engine(), queries=qas.questions\n", ")\n", "\n", "# If we had ground-truth answers, we could also include the correctness evaluator like below.\n", @@ -174,17 +288,35 @@ "#\n", "\n", "# runner = BatchEvalRunner(\n", - "# {'faithfulness': faithfulness_gpt4, 'relevancy': relevancy_gpt4, 'correctness': correctness_gpt4},\n", - "# workers=8,\n", + "# {\"correctness\": correctness_gpt4},\n", + "# workers=8,\n", "# )\n", - "#\n", + "\n", "# eval_results = await runner.aevaluate_queries(\n", - "# vector_index.as_query_engine(),\n", - "# queries=questions,\n", - "# query_kwargs={'question': {'reference': 'ground-truth answer', ...}}\n", + "# vector_index.as_query_engine(),\n", + "# queries=qas.queries,\n", + "# reference=[qr[1] for qr in qas.qr_pairs],\n", "# )" ] }, + { + "cell_type": "code", + "execution_count": null, + "id": "0eff6823", + "metadata": {}, + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "3\n" + ] + } + ], + "source": [ + "print(len([qr for qr in qas.qr_pairs]))" + ] + }, { "cell_type": "markdown", "id": "b256b98c", @@ -203,22 +335,22 @@ "name": "stdout", "output_type": "stream", "text": [ - "dict_keys(['faithfulness', 'relevancy'])\n", - "dict_keys(['query', 'contexts', 'response', 'passing', 'feedback', 'score'])\n", - "True\n", - "The population of New York City as of 2020 is 8,804,190.\n", - "[\"== Demographics ==\\n\\nNew York City is the most populous city in the United States, with 8,804,190 residents incorporating more immigration into the city than outmigration since the 2010 United States census. More than twice as many people live in New York City as compared to Los Angeles, the second-most populous U.S. city; and New York has more than three times the population of Chicago, the third-most populous U.S. city. New York City gained more residents between 2010 and 2020 (629,000) than any other U.S. city, and a greater amount than the total sum of the gains over the same decade of the next four largest U.S. cities, Los Angeles, Chicago, Houston, and Phoenix, Arizona combined. New York City's population is about 44% of New York State's population, and about 39% of the population of the New York metropolitan area. The majority of New York City residents in 2020 (5,141,538, or 58.4%) were living on Long Island, in Brooklyn, or in Queens. The New York City metropolitan statistical area, has the largest foreign-born population of any metropolitan region in the world. The New York region continues to be by far the leading metropolitan gateway for legal immigrants admitted into the United States, substantially exceeding the combined totals of Los Angeles and Miami.\\n\\n\\n=== Population density ===\\n\\nIn 2020, the city had an estimated population density of 29,302.37 inhabitants per square mile (11,313.71/km2), rendering it the nation's most densely populated of all larger municipalities (those with more than 100,000 residents), with several small cities (of fewer than 100,000) in adjacent Hudson County, New Jersey having greater density, as per the 2010 census. Geographically co-extensive with New York County, the borough of Manhattan's 2017 population density of 72,918 inhabitants per square mile (28,154/km2) makes it the highest of any county in the United States and higher than the density of any individual American city. The next three densest counties in the United States, placing second through fourth, are also New York boroughs: Brooklyn, the Bronx, and Queens respectively.\", \"New York, often called New York City or NYC, is the most populous city in the United States. With a 2020 population of 8,804,190 distributed over 300.46 square miles (778.2 km2), New York City is the most densely populated major city in the United States and more than twice as populous as Los Angeles, the nation's second-largest city. New York City is located at the southern tip of New York State. It constitutes the geographical and demographic center of both the Northeast megalopolis and the New York metropolitan area, the largest metropolitan area in the U.S. by both population and urban area. With over 20.1 million people in its metropolitan statistical area and 23.5 million in its combined statistical area as of 2020, New York is one of the world's most populous megacities, and over 58 million people live within 250 mi (400 km) of the city. New York City is a global cultural, financial, entertainment, and media center with a significant influence on commerce, health care and life sciences, research, technology, education, politics, tourism, dining, art, fashion, and sports. Home to the headquarters of the United Nations, New York is an important center for international diplomacy, and is sometimes described as the capital of the world.Situated on one of the world's largest natural harbors and extending into the Atlantic Ocean, New York City comprises five boroughs, each of which is coextensive with a respective county of the state of New York. The five boroughs, which were created in 1898 when local governments were consolidated into a single municipal entity, are: Brooklyn (in Kings County), Queens (in Queens County), Manhattan (in New York County), The Bronx (in Bronx County), and Staten Island (in Richmond County).As of 2021, the New York metropolitan area is the largest metropolitan economy in the world with a gross metropolitan product of over $2.4 trillion. If the New York metropolitan area were a sovereign state, it would have the eighth-largest economy in the world. New York City is an established safe haven for global investors. New York is home to the highest number of billionaires, individuals of ultra-high net worth (greater than US$30 million), and millionaires of any city in the world.\\nThe city and its metropolitan area constitute the premier gateway for legal immigration to the United States.\"]\n" + "dict_keys(['correctness'])\n", + "dict_keys(['query', 'contexts', 'response', 'passing', 'feedback', 'score', 'pairwise_source'])\n", + "False\n", + "The context information does not provide any information related to the query. Therefore, I cannot provide an answer based on the given context.\n", + "None\n" ] } ], "source": [ "print(eval_results.keys())\n", "\n", - "print(eval_results[\"faithfulness\"][0].dict().keys())\n", + "print(eval_results[\"correctness\"][0].dict().keys())\n", "\n", - "print(eval_results[\"faithfulness\"][0].passing)\n", - "print(eval_results[\"faithfulness\"][0].response)\n", - "print(eval_results[\"faithfulness\"][0].contexts)" + "print(eval_results[\"correctness\"][0].passing)\n", + "print(eval_results[\"correctness\"][0].response)\n", + "print(eval_results[\"correctness\"][0].contexts)" ] }, { @@ -257,12 +389,12 @@ "name": "stdout", "output_type": "stream", "text": [ - "faithfulness Score: 1.0\n" + "correctness Score: 0.0\n" ] } ], "source": [ - "score = get_eval_results(\"faithfulness\", eval_results)" + "score = get_eval_results(\"correctness\", eval_results)" ] }, { @@ -286,9 +418,9 @@ ], "metadata": { "kernelspec": { - "display_name": "llama_index_v2", + "display_name": "llama-index-4a-wkI5X-py3.11", "language": "python", - "name": "llama_index_v2" + "name": "python3" }, "language_info": { "codemirror_mode": { diff --git a/docs/examples/evaluation/relevancy_eval.ipynb b/docs/examples/evaluation/relevancy_eval.ipynb index e44df8e784..652e9783a9 100644 --- a/docs/examples/evaluation/relevancy_eval.ipynb +++ b/docs/examples/evaluation/relevancy_eval.ipynb @@ -36,7 +36,6 @@ " TreeIndex,\n", " VectorStoreIndex,\n", " SimpleDirectoryReader,\n", - " LLMPredictor,\n", " ServiceContext,\n", " Response,\n", ")\n", diff --git a/docs/examples/finetuning/gradient/gradient_structured.ipynb b/docs/examples/finetuning/gradient/gradient_structured.ipynb index fbb581771f..efee83cbfd 100644 --- a/docs/examples/finetuning/gradient/gradient_structured.ipynb +++ b/docs/examples/finetuning/gradient/gradient_structured.ipynb @@ -103,7 +103,7 @@ " is_chat_model=True,\n", ")\n", "# HACK: set chat model\n", - "# from llama_index.llms.base import LLMMetadata\n", + "# from llama_index.llms.types import LLMMetadata\n", "# gradient_llm.metadata = LLMMetadata(\n", "# context_window=1024,\n", "# num_output=gradient_llm.max_tokens or 20,\n", diff --git a/docs/examples/finetuning/llm_judge/correctness/finetune_llm_judge_single_grading_correctness.ipynb b/docs/examples/finetuning/llm_judge/correctness/finetune_llm_judge_single_grading_correctness.ipynb index bc383124b5..26334e8bff 100644 --- a/docs/examples/finetuning/llm_judge/correctness/finetune_llm_judge_single_grading_correctness.ipynb +++ b/docs/examples/finetuning/llm_judge/correctness/finetune_llm_judge_single_grading_correctness.ipynb @@ -257,14 +257,13 @@ " RetrieverQueryEngine,\n", ")\n", "from llama_index.llms import HuggingFaceInferenceAPI\n", - "from llama_index.llm_predictor import LLMPredictor\n", "\n", "llm = HuggingFaceInferenceAPI(\n", " model_name=\"meta-llama/Llama-2-7b-chat-hf\",\n", " context_window=2048, # to use refine\n", " token=HUGGING_FACE_TOKEN,\n", ")\n", - "context = ServiceContext.from_defaults(llm_predictor=LLMPredictor(llm=llm))\n", + "context = ServiceContext.from_defaults(llm=llm)\n", "query_engine = RetrieverQueryEngine.from_args(\n", " retriever=the_retriever, service_context=context\n", ")" diff --git a/docs/examples/finetuning/llm_judge/pairwise/finetune_llm_judge.ipynb b/docs/examples/finetuning/llm_judge/pairwise/finetune_llm_judge.ipynb index c72f69131b..542413c8fd 100644 --- a/docs/examples/finetuning/llm_judge/pairwise/finetune_llm_judge.ipynb +++ b/docs/examples/finetuning/llm_judge/pairwise/finetune_llm_judge.ipynb @@ -402,7 +402,6 @@ " RetrieverQueryEngine,\n", ")\n", "from llama_index.llms import HuggingFaceInferenceAPI\n", - "from llama_index.llm_predictor import LLMPredictor\n", "\n", "\n", "def create_query_engine(\n", @@ -416,7 +415,7 @@ " context_window=2048, # to use refine\n", " token=HUGGING_FACE_TOKEN,\n", " )\n", - " context = ServiceContext.from_defaults(llm_predictor=LLMPredictor(llm=llm))\n", + " context = ServiceContext.from_defaults(llm=llm)\n", " return RetrieverQueryEngine.from_args(\n", " retriever=retriever, service_context=context\n", " )" diff --git a/docs/examples/index_structs/knowledge_graph/KnowledgeGraphDemo.ipynb b/docs/examples/index_structs/knowledge_graph/KnowledgeGraphDemo.ipynb index bb238dec75..b467fb951a 100644 --- a/docs/examples/index_structs/knowledge_graph/KnowledgeGraphDemo.ipynb +++ b/docs/examples/index_structs/knowledge_graph/KnowledgeGraphDemo.ipynb @@ -76,7 +76,6 @@ "source": [ "from llama_index import (\n", " SimpleDirectoryReader,\n", - " LLMPredictor,\n", " ServiceContext,\n", " KnowledgeGraphIndex,\n", ")\n", diff --git a/docs/examples/index_structs/knowledge_graph/NebulaGraphKGIndexDemo.ipynb b/docs/examples/index_structs/knowledge_graph/NebulaGraphKGIndexDemo.ipynb index 68fa6426c8..d64cd170e1 100644 --- a/docs/examples/index_structs/knowledge_graph/NebulaGraphKGIndexDemo.ipynb +++ b/docs/examples/index_structs/knowledge_graph/NebulaGraphKGIndexDemo.ipynb @@ -51,7 +51,6 @@ " VectorStoreIndex,\n", " SimpleDirectoryReader,\n", " KnowledgeGraphIndex,\n", - " LLMPredictor,\n", " ServiceContext,\n", ")\n", "\n", @@ -83,7 +82,6 @@ " api_base=openai.api_base,\n", " api_version=openai.api_version,\n", ")\n", - "llm_predictor = LLMPredictor(llm=llm)\n", "\n", "# You need to deploy your own embedding model as well as your own chat completion model\n", "embedding_model = OpenAIEmbedding(\n", @@ -96,7 +94,7 @@ ")\n", "\n", "service_context = ServiceContext.from_defaults(\n", - " llm_predictor=llm_predictor,\n", + " llm=llm,\n", " embed_model=embedding_model,\n", ")" ] @@ -128,7 +126,6 @@ "source": [ "from llama_index import (\n", " KnowledgeGraphIndex,\n", - " LLMPredictor,\n", " ServiceContext,\n", " SimpleDirectoryReader,\n", ")\n", diff --git a/docs/examples/index_structs/knowledge_graph/Neo4jKGIndexDemo.ipynb b/docs/examples/index_structs/knowledge_graph/Neo4jKGIndexDemo.ipynb index 82272d9a40..ad4710ad31 100644 --- a/docs/examples/index_structs/knowledge_graph/Neo4jKGIndexDemo.ipynb +++ b/docs/examples/index_structs/knowledge_graph/Neo4jKGIndexDemo.ipynb @@ -50,7 +50,6 @@ " VectorStoreIndex,\n", " SimpleDirectoryReader,\n", " KnowledgeGraphIndex,\n", - " LLMPredictor,\n", " ServiceContext,\n", ")\n", "\n", @@ -81,7 +80,6 @@ " \"api_version\": openai.api_version,\n", " },\n", ")\n", - "llm_predictor = LLMPredictor(llm=llm)\n", "\n", "# You need to deploy your own embedding model as well as your own chat completion model\n", "embedding_llm = OpenAIEmbedding(\n", @@ -94,7 +92,7 @@ ")\n", "\n", "service_context = ServiceContext.from_defaults(\n", - " llm_predictor=llm_predictor,\n", + " llm=llm,\n", " embed_model=embedding_llm,\n", ")" ] @@ -124,7 +122,6 @@ "source": [ "from llama_index import (\n", " KnowledgeGraphIndex,\n", - " LLMPredictor,\n", " ServiceContext,\n", " SimpleDirectoryReader,\n", ")\n", diff --git a/docs/examples/llm/Konko.ipynb b/docs/examples/llm/Konko.ipynb index 5259a1552a..d1bb6b0c7e 100644 --- a/docs/examples/llm/Konko.ipynb +++ b/docs/examples/llm/Konko.ipynb @@ -44,7 +44,7 @@ "outputs": [], "source": [ "from llama_index.llms import Konko\n", - "from llama_index.llms.base import ChatMessage" + "from llama_index.llms import ChatMessage" ] }, { diff --git a/docs/examples/llm/anyscale.ipynb b/docs/examples/llm/anyscale.ipynb index 7b12d0c98b..8ee0f60818 100644 --- a/docs/examples/llm/anyscale.ipynb +++ b/docs/examples/llm/anyscale.ipynb @@ -42,7 +42,7 @@ "outputs": [], "source": [ "from llama_index.llms import Anyscale\n", - "from llama_index.llms.base import ChatMessage" + "from llama_index.llms import ChatMessage" ] }, { diff --git a/docs/examples/llm/everlyai.ipynb b/docs/examples/llm/everlyai.ipynb index b99f50eb43..3c5062d2da 100644 --- a/docs/examples/llm/everlyai.ipynb +++ b/docs/examples/llm/everlyai.ipynb @@ -44,7 +44,7 @@ "outputs": [], "source": [ "from llama_index.llms import EverlyAI\n", - "from llama_index.llms.base import ChatMessage" + "from llama_index.llms import ChatMessage" ] }, { diff --git a/docs/examples/llm/litellm.ipynb b/docs/examples/llm/litellm.ipynb index d40af5d4fd..2049334b60 100755 --- a/docs/examples/llm/litellm.ipynb +++ b/docs/examples/llm/litellm.ipynb @@ -55,7 +55,6 @@ "source": [ "import os\n", "from llama_index.llms import LiteLLM, ChatMessage\n", - "from llama_index.llms.base import \n", "\n", "# set env variable\n", "os.environ[\"OPENAI_API_KEY\"] = \"your-api-key\"\n", diff --git a/docs/examples/llm/llm_predictor.ipynb b/docs/examples/llm/llm_predictor.ipynb index bc6fb4cdd6..dc352872d6 100644 --- a/docs/examples/llm/llm_predictor.ipynb +++ b/docs/examples/llm/llm_predictor.ipynb @@ -50,7 +50,7 @@ "outputs": [], "source": [ "from langchain.chat_models import ChatAnyscale, ChatOpenAI\n", - "from llama_index import LLMPredictor\n", + "from llama_index.llms import LangChainLLM\n", "from llama_index.prompts import PromptTemplate" ] }, @@ -61,7 +61,7 @@ "metadata": {}, "outputs": [], "source": [ - "llm_predictor = LLMPredictor(ChatOpenAI())" + "llm = LangChainLLM(ChatOpenAI())" ] }, { @@ -71,7 +71,7 @@ "metadata": {}, "outputs": [], "source": [ - "stream = await llm_predictor.astream(PromptTemplate(\"Hi, write a short story\"))" + "stream = await llm.astream(PromptTemplate(\"Hi, write a short story\"))" ] }, { @@ -127,7 +127,7 @@ "outputs": [], "source": [ "## Test with ChatAnyscale\n", - "llm_predictor = LLMPredictor(ChatAnyscale())" + "llm = LangChainLLM(ChatAnyscale())" ] }, { @@ -145,7 +145,7 @@ } ], "source": [ - "stream = llm_predictor.stream(\n", + "stream = llm.stream(\n", " PromptTemplate(\"Hi, Which NFL team have most Super Bowl wins\")\n", ")\n", "for token in stream:\n", @@ -167,8 +167,7 @@ "metadata": {}, "outputs": [], "source": [ - "from llama_index.llms import OpenAI\n", - "from llama_index import LLMPredictor" + "from llama_index.llms import OpenAI" ] }, { @@ -178,7 +177,7 @@ "metadata": {}, "outputs": [], "source": [ - "llm_predictor = LLMPredictor(OpenAI())" + "llm = OpenAI()" ] }, { @@ -188,7 +187,7 @@ "metadata": {}, "outputs": [], "source": [ - "stream = await llm_predictor.astream(\"Hi, write a short story\")" + "stream = await llm.astream(\"Hi, write a short story\")" ] }, { diff --git a/docs/examples/llm/monsterapi.ipynb b/docs/examples/llm/monsterapi.ipynb index cc295e7f86..a38aea8154 100644 --- a/docs/examples/llm/monsterapi.ipynb +++ b/docs/examples/llm/monsterapi.ipynb @@ -160,7 +160,7 @@ } ], "source": [ - "from llama_index.llms.base import ChatMessage\n", + "from llama_index.llms import ChatMessage\n", "\n", "# Construct mock Chat history\n", "history_message = ChatMessage(\n", diff --git a/docs/examples/llm/openllm.ipynb b/docs/examples/llm/openllm.ipynb index 5076857c00..290a23c3fe 100644 --- a/docs/examples/llm/openllm.ipynb +++ b/docs/examples/llm/openllm.ipynb @@ -87,7 +87,7 @@ "from typing import List, Optional\n", "\n", "from llama_index.llms import OpenLLM, OpenLLMAPI\n", - "from llama_index.llms.base import ChatMessage" + "from llama_index.llms import ChatMessage" ] }, { diff --git a/docs/examples/llm/perplexity.ipynb b/docs/examples/llm/perplexity.ipynb index 868ec878bd..5ca5ba5bb1 100644 --- a/docs/examples/llm/perplexity.ipynb +++ b/docs/examples/llm/perplexity.ipynb @@ -65,7 +65,7 @@ "metadata": {}, "outputs": [], "source": [ - "from llama_index.llms.base import ChatMessage\n", + "from llama_index.llms import ChatMessage\n", "\n", "messages_dict = [\n", " {\"role\": \"system\", \"content\": \"Be precise and concise.\"},\n", diff --git a/docs/examples/llm/rungpt.ipynb b/docs/examples/llm/rungpt.ipynb index 922d7b153a..4122438a09 100644 --- a/docs/examples/llm/rungpt.ipynb +++ b/docs/examples/llm/rungpt.ipynb @@ -122,7 +122,7 @@ "metadata": {}, "outputs": [], "source": [ - "from llama_index.llms.base import ChatMessage, MessageRole\n", + "from llama_index.llms import ChatMessage, MessageRole\n", "from llama_index.llms.rungpt import RunGptLLM\n", "\n", "messages = [\n", diff --git a/docs/examples/llm/vertex.ipynb b/docs/examples/llm/vertex.ipynb index 6384f7ba0d..297550d6dc 100644 --- a/docs/examples/llm/vertex.ipynb +++ b/docs/examples/llm/vertex.ipynb @@ -60,7 +60,7 @@ ], "source": [ "from llama_index.llms.vertex import Vertex\n", - "from llama_index.llms.base import ChatMessage, MessageRole\n", + "from llama_index.llms import ChatMessage, MessageRole\n", "\n", "llm = Vertex(model=\"text-bison\", temperature=0, additional_kwargs={})\n", "llm.complete(\"Hello this is a sample text\").text" diff --git a/docs/examples/llm/vllm.ipynb b/docs/examples/llm/vllm.ipynb index 4ece58e24c..8cda81b3b3 100644 --- a/docs/examples/llm/vllm.ipynb +++ b/docs/examples/llm/vllm.ipynb @@ -491,7 +491,7 @@ "outputs": [], "source": [ "from llama_index.llms.vllm import VllmServer\n", - "from llama_index.llms.base import ChatMessage" + "from llama_index.llms import ChatMessage" ] }, { diff --git a/docs/examples/low_level/response_synthesis.ipynb b/docs/examples/low_level/response_synthesis.ipynb index 99e068809e..feb1c7a604 100644 --- a/docs/examples/low_level/response_synthesis.ipynb +++ b/docs/examples/low_level/response_synthesis.ipynb @@ -926,7 +926,7 @@ "outputs": [], "source": [ "from llama_index.retrievers import BaseRetriever\n", - "from llama_index.llms.base import LLM\n", + "from llama_index.llms.llm import LLM\n", "from dataclasses import dataclass\n", "from typing import Optional, List\n", "\n", diff --git a/docs/examples/node_postprocessor/LLMReranker-Gatsby.ipynb b/docs/examples/node_postprocessor/LLMReranker-Gatsby.ipynb index 7e9e6c6130..4089b1063c 100644 --- a/docs/examples/node_postprocessor/LLMReranker-Gatsby.ipynb +++ b/docs/examples/node_postprocessor/LLMReranker-Gatsby.ipynb @@ -41,7 +41,6 @@ " VectorStoreIndex,\n", " SimpleDirectoryReader,\n", " ServiceContext,\n", - " LLMPredictor,\n", ")\n", "from llama_index.postprocessor import LLMRerank\n", "from llama_index.llms import OpenAI\n", diff --git a/docs/examples/node_postprocessor/LLMReranker-Lyft-10k.ipynb b/docs/examples/node_postprocessor/LLMReranker-Lyft-10k.ipynb index 00aec9e230..c4e1d08bbd 100644 --- a/docs/examples/node_postprocessor/LLMReranker-Lyft-10k.ipynb +++ b/docs/examples/node_postprocessor/LLMReranker-Lyft-10k.ipynb @@ -50,7 +50,6 @@ " VectorStoreIndex,\n", " SimpleDirectoryReader,\n", " ServiceContext,\n", - " LLMPredictor,\n", ")\n", "from llama_index.postprocessor import LLMRerank\n", "\n", diff --git a/docs/examples/output_parsing/GuardrailsDemo.ipynb b/docs/examples/output_parsing/GuardrailsDemo.ipynb index e0e0ed10b4..5d32d49ca1 100644 --- a/docs/examples/output_parsing/GuardrailsDemo.ipynb +++ b/docs/examples/output_parsing/GuardrailsDemo.ipynb @@ -26,6 +26,16 @@ "If you're opening this Notebook on colab, you will probably need to install LlamaIndex 🦙." ] }, + { + "cell_type": "code", + "execution_count": null, + "id": "e716f66f", + "metadata": {}, + "outputs": [], + "source": [ + "!pip install guardrails-ai" + ] + }, { "attachments": {}, "cell_type": "markdown", @@ -40,7 +50,27 @@ "execution_count": null, "id": "649bea0c", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Will not apply HSTS. The HSTS database must be a regular and non-world-writable file.\n", + "ERROR: could not open HSTS store at '/home/loganm/.wget-hsts'. HSTS will be disabled.\n", + "--2023-12-11 10:18:02-- https://raw.githubusercontent.com/run-llama/llama_index/main/docs/examples/data/paul_graham/paul_graham_essay.txt\n", + "Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.110.133, 185.199.109.133, 185.199.108.133, ...\n", + "Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.110.133|:443... connected.\n", + "HTTP request sent, awaiting response... 200 OK\n", + "Length: 75042 (73K) [text/plain]\n", + "Saving to: ‘data/paul_graham/paul_graham_essay.txt’\n", + "\n", + "data/paul_graham/pa 100%[===================>] 73.28K --.-KB/s in 0.04s \n", + "\n", + "2023-12-11 10:18:02 (1.70 MB/s) - ‘data/paul_graham/paul_graham_essay.txt’ saved [75042/75042]\n", + "\n" + ] + } + ], "source": [ "!mkdir -p 'data/paul_graham/'\n", "!wget 'https://raw.githubusercontent.com/run-llama/llama_index/main/docs/examples/data/paul_graham/paul_graham_essay.txt' -O 'data/paul_graham/paul_graham_essay.txt'" @@ -59,7 +89,35 @@ "execution_count": null, "id": "690a6918-7c75-4f95-9ccc-d2c4a1fe00d7", "metadata": {}, - "outputs": [], + "outputs": [ + { + "ename": "AttributeError", + "evalue": "module 'openai' has no attribute 'error'", + "output_type": "error", + "traceback": [ + "\u001b[0;31m---------------------------------------------------------------------------\u001b[0m", + "\u001b[0;31mAttributeError\u001b[0m Traceback (most recent call last)", + "\u001b[1;32m/home/loganm/llama_index_proper/llama_index/docs/examples/output_parsing/GuardrailsDemo.ipynb Cell 8\u001b[0m line \u001b[0;36m7\n\u001b[1;32m <a href='vscode-notebook-cell://wsl%2Bubuntu/home/loganm/llama_index_proper/llama_index/docs/examples/output_parsing/GuardrailsDemo.ipynb#W6sdnNjb2RlLXJlbW90ZQ%3D%3D?line=3'>4</a>\u001b[0m logging\u001b[39m.\u001b[39mbasicConfig(stream\u001b[39m=\u001b[39msys\u001b[39m.\u001b[39mstdout, level\u001b[39m=\u001b[39mlogging\u001b[39m.\u001b[39mINFO)\n\u001b[1;32m <a href='vscode-notebook-cell://wsl%2Bubuntu/home/loganm/llama_index_proper/llama_index/docs/examples/output_parsing/GuardrailsDemo.ipynb#W6sdnNjb2RlLXJlbW90ZQ%3D%3D?line=4'>5</a>\u001b[0m logging\u001b[39m.\u001b[39mgetLogger()\u001b[39m.\u001b[39maddHandler(logging\u001b[39m.\u001b[39mStreamHandler(stream\u001b[39m=\u001b[39msys\u001b[39m.\u001b[39mstdout))\n\u001b[0;32m----> <a href='vscode-notebook-cell://wsl%2Bubuntu/home/loganm/llama_index_proper/llama_index/docs/examples/output_parsing/GuardrailsDemo.ipynb#W6sdnNjb2RlLXJlbW90ZQ%3D%3D?line=6'>7</a>\u001b[0m \u001b[39mfrom\u001b[39;00m \u001b[39mllama_index\u001b[39;00m \u001b[39mimport\u001b[39;00m VectorStoreIndex, SimpleDirectoryReader\n\u001b[1;32m <a href='vscode-notebook-cell://wsl%2Bubuntu/home/loganm/llama_index_proper/llama_index/docs/examples/output_parsing/GuardrailsDemo.ipynb#W6sdnNjb2RlLXJlbW90ZQ%3D%3D?line=7'>8</a>\u001b[0m \u001b[39mfrom\u001b[39;00m \u001b[39mIPython\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mdisplay\u001b[39;00m \u001b[39mimport\u001b[39;00m Markdown, display\n\u001b[1;32m <a href='vscode-notebook-cell://wsl%2Bubuntu/home/loganm/llama_index_proper/llama_index/docs/examples/output_parsing/GuardrailsDemo.ipynb#W6sdnNjb2RlLXJlbW90ZQ%3D%3D?line=9'>10</a>\u001b[0m \u001b[39mimport\u001b[39;00m \u001b[39mos\u001b[39;00m\n", + "File \u001b[0;32m~/llama_index_proper/llama_index/llama_index/__init__.py:21\u001b[0m\n\u001b[1;32m 17\u001b[0m \u001b[39mfrom\u001b[39;00m \u001b[39mllama_index\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39membeddings\u001b[39;00m \u001b[39mimport\u001b[39;00m OpenAIEmbedding\n\u001b[1;32m 19\u001b[0m \u001b[39m# indices\u001b[39;00m\n\u001b[1;32m 20\u001b[0m \u001b[39m# loading\u001b[39;00m\n\u001b[0;32m---> 21\u001b[0m \u001b[39mfrom\u001b[39;00m \u001b[39mllama_index\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mindices\u001b[39;00m \u001b[39mimport\u001b[39;00m (\n\u001b[1;32m 22\u001b[0m ComposableGraph,\n\u001b[1;32m 23\u001b[0m DocumentSummaryIndex,\n\u001b[1;32m 24\u001b[0m GPTDocumentSummaryIndex,\n\u001b[1;32m 25\u001b[0m GPTKeywordTableIndex,\n\u001b[1;32m 26\u001b[0m GPTKnowledgeGraphIndex,\n\u001b[1;32m 27\u001b[0m GPTListIndex,\n\u001b[1;32m 28\u001b[0m GPTRAKEKeywordTableIndex,\n\u001b[1;32m 29\u001b[0m GPTSimpleKeywordTableIndex,\n\u001b[1;32m 30\u001b[0m GPTTreeIndex,\n\u001b[1;32m 31\u001b[0m GPTVectorStoreIndex,\n\u001b[1;32m 32\u001b[0m KeywordTableIndex,\n\u001b[1;32m 33\u001b[0m KnowledgeGraphIndex,\n\u001b[1;32m 34\u001b[0m ListIndex,\n\u001b[1;32m 35\u001b[0m RAKEKeywordTableIndex,\n\u001b[1;32m 36\u001b[0m SimpleKeywordTableIndex,\n\u001b[1;32m 37\u001b[0m SummaryIndex,\n\u001b[1;32m 38\u001b[0m TreeIndex,\n\u001b[1;32m 39\u001b[0m VectorStoreIndex,\n\u001b[1;32m 40\u001b[0m load_graph_from_storage,\n\u001b[1;32m 41\u001b[0m load_index_from_storage,\n\u001b[1;32m 42\u001b[0m load_indices_from_storage,\n\u001b[1;32m 43\u001b[0m )\n\u001b[1;32m 45\u001b[0m \u001b[39m# structured\u001b[39;00m\n\u001b[1;32m 46\u001b[0m \u001b[39mfrom\u001b[39;00m \u001b[39mllama_index\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mindices\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mcommon\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mstruct_store\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mbase\u001b[39;00m \u001b[39mimport\u001b[39;00m SQLDocumentContextBuilder\n", + "File \u001b[0;32m~/llama_index_proper/llama_index/llama_index/indices/__init__.py:29\u001b[0m\n\u001b[1;32m 27\u001b[0m \u001b[39mfrom\u001b[39;00m \u001b[39mllama_index\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mindices\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mlist\u001b[39;00m \u001b[39mimport\u001b[39;00m GPTListIndex, ListIndex, SummaryIndex\n\u001b[1;32m 28\u001b[0m \u001b[39mfrom\u001b[39;00m \u001b[39mllama_index\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mindices\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mlist\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mbase\u001b[39;00m \u001b[39mimport\u001b[39;00m GPTListIndex, ListIndex, SummaryIndex\n\u001b[0;32m---> 29\u001b[0m \u001b[39mfrom\u001b[39;00m \u001b[39mllama_index\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mindices\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mloading\u001b[39;00m \u001b[39mimport\u001b[39;00m (\n\u001b[1;32m 30\u001b[0m load_graph_from_storage,\n\u001b[1;32m 31\u001b[0m load_index_from_storage,\n\u001b[1;32m 32\u001b[0m load_indices_from_storage,\n\u001b[1;32m 33\u001b[0m )\n\u001b[1;32m 34\u001b[0m \u001b[39mfrom\u001b[39;00m \u001b[39mllama_index\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mindices\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mmanaged\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mvectara\u001b[39;00m \u001b[39mimport\u001b[39;00m VectaraIndex\n\u001b[1;32m 35\u001b[0m \u001b[39mfrom\u001b[39;00m \u001b[39mllama_index\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mindices\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mmulti_modal\u001b[39;00m \u001b[39mimport\u001b[39;00m MultiModalVectorStoreIndex\n", + "File \u001b[0;32m~/llama_index_proper/llama_index/llama_index/indices/loading.py:6\u001b[0m\n\u001b[1;32m 4\u001b[0m \u001b[39mfrom\u001b[39;00m \u001b[39mllama_index\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mindices\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mbase\u001b[39;00m \u001b[39mimport\u001b[39;00m BaseIndex\n\u001b[1;32m 5\u001b[0m \u001b[39mfrom\u001b[39;00m \u001b[39mllama_index\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mindices\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mcomposability\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mgraph\u001b[39;00m \u001b[39mimport\u001b[39;00m ComposableGraph\n\u001b[0;32m----> 6\u001b[0m \u001b[39mfrom\u001b[39;00m \u001b[39mllama_index\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mindices\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mregistry\u001b[39;00m \u001b[39mimport\u001b[39;00m INDEX_STRUCT_TYPE_TO_INDEX_CLASS\n\u001b[1;32m 7\u001b[0m \u001b[39mfrom\u001b[39;00m \u001b[39mllama_index\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mstorage\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mstorage_context\u001b[39;00m \u001b[39mimport\u001b[39;00m StorageContext\n\u001b[1;32m 9\u001b[0m logger \u001b[39m=\u001b[39m logging\u001b[39m.\u001b[39mgetLogger(\u001b[39m__name__\u001b[39m)\n", + "File \u001b[0;32m~/llama_index_proper/llama_index/llama_index/indices/registry.py:12\u001b[0m\n\u001b[1;32m 10\u001b[0m \u001b[39mfrom\u001b[39;00m \u001b[39mllama_index\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mindices\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mknowledge_graph\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mbase\u001b[39;00m \u001b[39mimport\u001b[39;00m KnowledgeGraphIndex\n\u001b[1;32m 11\u001b[0m \u001b[39mfrom\u001b[39;00m \u001b[39mllama_index\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mindices\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mlist\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mbase\u001b[39;00m \u001b[39mimport\u001b[39;00m SummaryIndex\n\u001b[0;32m---> 12\u001b[0m \u001b[39mfrom\u001b[39;00m \u001b[39mllama_index\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mindices\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mmulti_modal\u001b[39;00m \u001b[39mimport\u001b[39;00m MultiModalVectorStoreIndex\n\u001b[1;32m 13\u001b[0m \u001b[39mfrom\u001b[39;00m \u001b[39mllama_index\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mindices\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mstruct_store\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mpandas\u001b[39;00m \u001b[39mimport\u001b[39;00m PandasIndex\n\u001b[1;32m 14\u001b[0m \u001b[39mfrom\u001b[39;00m \u001b[39mllama_index\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mindices\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mstruct_store\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39msql\u001b[39;00m \u001b[39mimport\u001b[39;00m SQLStructStoreIndex\n", + "File \u001b[0;32m~/llama_index_proper/llama_index/llama_index/indices/multi_modal/__init__.py:3\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[39m\"\"\"Vector-store based data structures.\"\"\"\u001b[39;00m\n\u001b[0;32m----> 3\u001b[0m \u001b[39mfrom\u001b[39;00m \u001b[39mllama_index\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mindices\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mmulti_modal\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mbase\u001b[39;00m \u001b[39mimport\u001b[39;00m MultiModalVectorStoreIndex\n\u001b[1;32m 4\u001b[0m \u001b[39mfrom\u001b[39;00m \u001b[39mllama_index\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mindices\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mmulti_modal\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mretriever\u001b[39;00m \u001b[39mimport\u001b[39;00m MultiModalVectorIndexRetriever\n\u001b[1;32m 6\u001b[0m __all__ \u001b[39m=\u001b[39m [\n\u001b[1;32m 7\u001b[0m \u001b[39m\"\u001b[39m\u001b[39mMultiModalVectorStoreIndex\u001b[39m\u001b[39m\"\u001b[39m,\n\u001b[1;32m 8\u001b[0m \u001b[39m\"\u001b[39m\u001b[39mMultiModalVectorIndexRetriever\u001b[39m\u001b[39m\"\u001b[39m,\n\u001b[1;32m 9\u001b[0m ]\n", + "File \u001b[0;32m~/llama_index_proper/llama_index/llama_index/indices/multi_modal/base.py:19\u001b[0m\n\u001b[1;32m 12\u001b[0m \u001b[39mfrom\u001b[39;00m \u001b[39mllama_index\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39membeddings\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mutils\u001b[39;00m \u001b[39mimport\u001b[39;00m EmbedType, resolve_embed_model\n\u001b[1;32m 13\u001b[0m \u001b[39mfrom\u001b[39;00m \u001b[39mllama_index\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mindices\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mutils\u001b[39;00m \u001b[39mimport\u001b[39;00m (\n\u001b[1;32m 14\u001b[0m async_embed_image_nodes,\n\u001b[1;32m 15\u001b[0m async_embed_nodes,\n\u001b[1;32m 16\u001b[0m embed_image_nodes,\n\u001b[1;32m 17\u001b[0m embed_nodes,\n\u001b[1;32m 18\u001b[0m )\n\u001b[0;32m---> 19\u001b[0m \u001b[39mfrom\u001b[39;00m \u001b[39mllama_index\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mindices\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mvector_store\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mbase\u001b[39;00m \u001b[39mimport\u001b[39;00m VectorStoreIndex\n\u001b[1;32m 20\u001b[0m \u001b[39mfrom\u001b[39;00m \u001b[39mllama_index\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mschema\u001b[39;00m \u001b[39mimport\u001b[39;00m BaseNode, ImageNode\n\u001b[1;32m 21\u001b[0m \u001b[39mfrom\u001b[39;00m \u001b[39mllama_index\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mservice_context\u001b[39;00m \u001b[39mimport\u001b[39;00m ServiceContext\n", + "File \u001b[0;32m~/llama_index_proper/llama_index/llama_index/indices/vector_store/__init__.py:4\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[39m\"\"\"Vector-store based data structures.\"\"\"\u001b[39;00m\n\u001b[1;32m 3\u001b[0m \u001b[39mfrom\u001b[39;00m \u001b[39mllama_index\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mindices\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mvector_store\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mbase\u001b[39;00m \u001b[39mimport\u001b[39;00m GPTVectorStoreIndex, VectorStoreIndex\n\u001b[0;32m----> 4\u001b[0m \u001b[39mfrom\u001b[39;00m \u001b[39mllama_index\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mindices\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mvector_store\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mretrievers\u001b[39;00m \u001b[39mimport\u001b[39;00m (\n\u001b[1;32m 5\u001b[0m VectorIndexAutoRetriever,\n\u001b[1;32m 6\u001b[0m VectorIndexRetriever,\n\u001b[1;32m 7\u001b[0m )\n\u001b[1;32m 9\u001b[0m __all__ \u001b[39m=\u001b[39m [\n\u001b[1;32m 10\u001b[0m \u001b[39m\"\u001b[39m\u001b[39mVectorStoreIndex\u001b[39m\u001b[39m\"\u001b[39m,\n\u001b[1;32m 11\u001b[0m \u001b[39m\"\u001b[39m\u001b[39mVectorIndexRetriever\u001b[39m\u001b[39m\"\u001b[39m,\n\u001b[0;32m (...)\u001b[0m\n\u001b[1;32m 14\u001b[0m \u001b[39m\"\u001b[39m\u001b[39mGPTVectorStoreIndex\u001b[39m\u001b[39m\"\u001b[39m,\n\u001b[1;32m 15\u001b[0m ]\n", + "File \u001b[0;32m~/llama_index_proper/llama_index/llama_index/indices/vector_store/retrievers/__init__.py:4\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[39mfrom\u001b[39;00m \u001b[39mllama_index\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mindices\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mvector_store\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mretrievers\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mretriever\u001b[39;00m \u001b[39mimport\u001b[39;00m ( \u001b[39m# noqa: I001\u001b[39;00m\n\u001b[1;32m 2\u001b[0m VectorIndexRetriever,\n\u001b[1;32m 3\u001b[0m )\n\u001b[0;32m----> 4\u001b[0m \u001b[39mfrom\u001b[39;00m \u001b[39mllama_index\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mindices\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mvector_store\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mretrievers\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mauto_retriever\u001b[39;00m \u001b[39mimport\u001b[39;00m (\n\u001b[1;32m 5\u001b[0m VectorIndexAutoRetriever,\n\u001b[1;32m 6\u001b[0m )\n\u001b[1;32m 8\u001b[0m __all__ \u001b[39m=\u001b[39m [\n\u001b[1;32m 9\u001b[0m \u001b[39m\"\u001b[39m\u001b[39mVectorIndexRetriever\u001b[39m\u001b[39m\"\u001b[39m,\n\u001b[1;32m 10\u001b[0m \u001b[39m\"\u001b[39m\u001b[39mVectorIndexAutoRetriever\u001b[39m\u001b[39m\"\u001b[39m,\n\u001b[1;32m 11\u001b[0m ]\n", + "File \u001b[0;32m~/llama_index_proper/llama_index/llama_index/indices/vector_store/retrievers/auto_retriever/__init__.py:1\u001b[0m\n\u001b[0;32m----> 1\u001b[0m \u001b[39mfrom\u001b[39;00m \u001b[39mllama_index\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mindices\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mvector_store\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mretrievers\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mauto_retriever\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mauto_retriever\u001b[39;00m \u001b[39mimport\u001b[39;00m (\n\u001b[1;32m 2\u001b[0m VectorIndexAutoRetriever,\n\u001b[1;32m 3\u001b[0m )\n\u001b[1;32m 5\u001b[0m __all__ \u001b[39m=\u001b[39m [\n\u001b[1;32m 6\u001b[0m \u001b[39m\"\u001b[39m\u001b[39mVectorIndexAutoRetriever\u001b[39m\u001b[39m\"\u001b[39m,\n\u001b[1;32m 7\u001b[0m ]\n", + "File \u001b[0;32m~/llama_index_proper/llama_index/llama_index/indices/vector_store/retrievers/auto_retriever/auto_retriever.py:9\u001b[0m\n\u001b[1;32m 7\u001b[0m \u001b[39mfrom\u001b[39;00m \u001b[39mllama_index\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mindices\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mvector_store\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mbase\u001b[39;00m \u001b[39mimport\u001b[39;00m VectorStoreIndex\n\u001b[1;32m 8\u001b[0m \u001b[39mfrom\u001b[39;00m \u001b[39mllama_index\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mindices\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mvector_store\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mretrievers\u001b[39;00m \u001b[39mimport\u001b[39;00m VectorIndexRetriever\n\u001b[0;32m----> 9\u001b[0m \u001b[39mfrom\u001b[39;00m \u001b[39mllama_index\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mindices\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mvector_store\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mretrievers\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mauto_retriever\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39moutput_parser\u001b[39;00m \u001b[39mimport\u001b[39;00m (\n\u001b[1;32m 10\u001b[0m VectorStoreQueryOutputParser,\n\u001b[1;32m 11\u001b[0m )\n\u001b[1;32m 12\u001b[0m \u001b[39mfrom\u001b[39;00m \u001b[39mllama_index\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mindices\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mvector_store\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mretrievers\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mauto_retriever\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mprompts\u001b[39;00m \u001b[39mimport\u001b[39;00m (\n\u001b[1;32m 13\u001b[0m DEFAULT_VECTOR_STORE_QUERY_PROMPT_TMPL,\n\u001b[1;32m 14\u001b[0m VectorStoreQueryPrompt,\n\u001b[1;32m 15\u001b[0m )\n\u001b[1;32m 16\u001b[0m \u001b[39mfrom\u001b[39;00m \u001b[39mllama_index\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39moutput_parsers\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mbase\u001b[39;00m \u001b[39mimport\u001b[39;00m OutputParserException, StructuredOutput\n", + "File \u001b[0;32m~/llama_index_proper/llama_index/llama_index/indices/vector_store/retrievers/auto_retriever/output_parser.py:3\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[39mfrom\u001b[39;00m \u001b[39mtyping\u001b[39;00m \u001b[39mimport\u001b[39;00m Any\n\u001b[0;32m----> 3\u001b[0m \u001b[39mfrom\u001b[39;00m \u001b[39mllama_index\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39moutput_parsers\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mbase\u001b[39;00m \u001b[39mimport\u001b[39;00m StructuredOutput\n\u001b[1;32m 4\u001b[0m \u001b[39mfrom\u001b[39;00m \u001b[39mllama_index\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39moutput_parsers\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mutils\u001b[39;00m \u001b[39mimport\u001b[39;00m parse_json_markdown\n\u001b[1;32m 5\u001b[0m \u001b[39mfrom\u001b[39;00m \u001b[39mllama_index\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mtypes\u001b[39;00m \u001b[39mimport\u001b[39;00m BaseOutputParser\n", + "File \u001b[0;32m~/llama_index_proper/llama_index/llama_index/output_parsers/__init__.py:3\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[39m\"\"\"Output parsers.\"\"\"\u001b[39;00m\n\u001b[0;32m----> 3\u001b[0m \u001b[39mfrom\u001b[39;00m \u001b[39mllama_index\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39moutput_parsers\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mguardrails\u001b[39;00m \u001b[39mimport\u001b[39;00m GuardrailsOutputParser\n\u001b[1;32m 4\u001b[0m \u001b[39mfrom\u001b[39;00m \u001b[39mllama_index\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39moutput_parsers\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mlangchain\u001b[39;00m \u001b[39mimport\u001b[39;00m LangchainOutputParser\n\u001b[1;32m 5\u001b[0m \u001b[39mfrom\u001b[39;00m \u001b[39mllama_index\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39moutput_parsers\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mpydantic\u001b[39;00m \u001b[39mimport\u001b[39;00m PydanticOutputParser\n", + "File \u001b[0;32m~/llama_index_proper/llama_index/llama_index/output_parsers/guardrails.py:9\u001b[0m\n\u001b[1;32m 6\u001b[0m \u001b[39mfrom\u001b[39;00m \u001b[39mdeprecated\u001b[39;00m \u001b[39mimport\u001b[39;00m deprecated\n\u001b[1;32m 8\u001b[0m \u001b[39mtry\u001b[39;00m:\n\u001b[0;32m----> 9\u001b[0m \u001b[39mfrom\u001b[39;00m \u001b[39mguardrails\u001b[39;00m \u001b[39mimport\u001b[39;00m Guard\n\u001b[1;32m 10\u001b[0m \u001b[39mexcept\u001b[39;00m \u001b[39mImportError\u001b[39;00m:\n\u001b[1;32m 11\u001b[0m Guard \u001b[39m=\u001b[39m \u001b[39mNone\u001b[39;00m\n", + "File \u001b[0;32m~/.cache/pypoetry/virtualenvs/llama-index-4a-wkI5X-py3.11/lib/python3.11/site-packages/guardrails/__init__.py:3\u001b[0m\n\u001b[1;32m 1\u001b[0m \u001b[39m# Set up __init__.py so that users can do from guardrails import Response, Schema, etc.\u001b[39;00m\n\u001b[0;32m----> 3\u001b[0m \u001b[39mfrom\u001b[39;00m \u001b[39mguardrails\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mguard\u001b[39;00m \u001b[39mimport\u001b[39;00m Guard\n\u001b[1;32m 4\u001b[0m \u001b[39mfrom\u001b[39;00m \u001b[39mguardrails\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mllm_providers\u001b[39;00m \u001b[39mimport\u001b[39;00m PromptCallableBase\n\u001b[1;32m 5\u001b[0m \u001b[39mfrom\u001b[39;00m \u001b[39mguardrails\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mlogging_utils\u001b[39;00m \u001b[39mimport\u001b[39;00m configure_logging\n", + "File \u001b[0;32m~/.cache/pypoetry/virtualenvs/llama-index-4a-wkI5X-py3.11/lib/python3.11/site-packages/guardrails/guard.py:10\u001b[0m\n\u001b[1;32m 7\u001b[0m \u001b[39mfrom\u001b[39;00m \u001b[39meliot\u001b[39;00m \u001b[39mimport\u001b[39;00m add_destinations, start_action\n\u001b[1;32m 8\u001b[0m \u001b[39mfrom\u001b[39;00m \u001b[39mpydantic\u001b[39;00m \u001b[39mimport\u001b[39;00m BaseModel\n\u001b[0;32m---> 10\u001b[0m \u001b[39mfrom\u001b[39;00m \u001b[39mguardrails\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mllm_providers\u001b[39;00m \u001b[39mimport\u001b[39;00m get_async_llm_ask, get_llm_ask\n\u001b[1;32m 11\u001b[0m \u001b[39mfrom\u001b[39;00m \u001b[39mguardrails\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mprompt\u001b[39;00m \u001b[39mimport\u001b[39;00m Instructions, Prompt\n\u001b[1;32m 12\u001b[0m \u001b[39mfrom\u001b[39;00m \u001b[39mguardrails\u001b[39;00m\u001b[39m.\u001b[39;00m\u001b[39mrail\u001b[39;00m \u001b[39mimport\u001b[39;00m Rail\n", + "File \u001b[0;32m~/.cache/pypoetry/virtualenvs/llama-index-4a-wkI5X-py3.11/lib/python3.11/site-packages/guardrails/llm_providers.py:24\u001b[0m\n\u001b[1;32m 19\u001b[0m \u001b[39mexcept\u001b[39;00m \u001b[39mImportError\u001b[39;00m:\n\u001b[1;32m 20\u001b[0m cohere \u001b[39m=\u001b[39m \u001b[39mNone\u001b[39;00m\n\u001b[1;32m 23\u001b[0m OPENAI_RETRYABLE_ERRORS \u001b[39m=\u001b[39m [\n\u001b[0;32m---> 24\u001b[0m openai\u001b[39m.\u001b[39;49merror\u001b[39m.\u001b[39mAPIConnectionError,\n\u001b[1;32m 25\u001b[0m openai\u001b[39m.\u001b[39merror\u001b[39m.\u001b[39mAPIError,\n\u001b[1;32m 26\u001b[0m openai\u001b[39m.\u001b[39merror\u001b[39m.\u001b[39mTryAgain,\n\u001b[1;32m 27\u001b[0m openai\u001b[39m.\u001b[39merror\u001b[39m.\u001b[39mTimeout,\n\u001b[1;32m 28\u001b[0m openai\u001b[39m.\u001b[39merror\u001b[39m.\u001b[39mRateLimitError,\n\u001b[1;32m 29\u001b[0m openai\u001b[39m.\u001b[39merror\u001b[39m.\u001b[39mServiceUnavailableError,\n\u001b[1;32m 30\u001b[0m ]\n\u001b[1;32m 31\u001b[0m RETRYABLE_ERRORS \u001b[39m=\u001b[39m \u001b[39mtuple\u001b[39m(OPENAI_RETRYABLE_ERRORS)\n\u001b[1;32m 34\u001b[0m \u001b[39mclass\u001b[39;00m \u001b[39mPromptCallableException\u001b[39;00m(\u001b[39mException\u001b[39;00m):\n", + "\u001b[0;31mAttributeError\u001b[0m: module 'openai' has no attribute 'error'" + ] + } + ], "source": [ "import logging\n", "import sys\n", @@ -70,9 +128,9 @@ "from llama_index import VectorStoreIndex, SimpleDirectoryReader\n", "from IPython.display import Markdown, display\n", "\n", - "import openai\n", + "import os\n", "\n", - "openai.api_key = \"<YOUR_OPENAI_API_KEY>\"" + "os.environ[\"OPENAI_API_KEY\"] = \"sk-...\"" ] }, { @@ -122,18 +180,7 @@ "metadata": {}, "outputs": [], "source": [ - "from llama_index.output_parsers import GuardrailsOutputParser\n", - "from llama_index.llm_predictor import StructuredLLMPredictor" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "057139d2-09e8-4b8d-83a1-a2356a1475a8", - "metadata": {}, - "outputs": [], - "source": [ - "llm_predictor = StructuredLLMPredictor()" + "from llama_index.output_parsers import GuardrailsOutputParser" ] }, { @@ -144,20 +191,6 @@ "**Define custom QA and Refine Prompts**\n" ] }, - { - "cell_type": "code", - "execution_count": null, - "id": "2833d086-d240-4798-b3c5-a83ac4593b0e", - "metadata": {}, - "outputs": [], - "source": [ - "from llama_index.prompts import PromptTemplate\n", - "from llama_index.prompts.default_prompts import (\n", - " DEFAULT_TEXT_QA_PROMPT_TMPL,\n", - " DEFAULT_REFINE_PROMPT_TMPL,\n", - ")" - ] - }, { "cell_type": "markdown", "id": "dba8513e", @@ -216,28 +249,16 @@ "metadata": {}, "outputs": [], "source": [ + "from llama_index.llms import OpenAI\n", + "\n", "# Create a guard object\n", "guard = gd.Guard.from_pydantic(output_class=BulletPoints, prompt=prompt)\n", "\n", "# Create output parse object\n", - "output_parser = GuardrailsOutputParser(guard, llm=llm_predictor.llm)" - ] - }, - { - "cell_type": "code", - "execution_count": null, - "id": "a9b440d4-6fb4-46e6-973f-44207b432d3f", - "metadata": {}, - "outputs": [], - "source": [ - "# NOTE: we use the same output parser for both prompts, though you can choose to use different parsers\n", - "# NOTE: here we add formatting instructions to the prompts.\n", - "\n", - "fmt_qa_tmpl = output_parser.format(DEFAULT_TEXT_QA_PROMPT_TMPL)\n", - "fmt_refine_tmpl = output_parser.format(DEFAULT_REFINE_PROMPT_TMPL)\n", + "output_parser = GuardrailsOutputParser(guard, llm=OpenAI())\n", "\n", - "qa_prompt = PromptTemplate(fmt_qa_tmpl, output_parser=output_parser)\n", - "refine_prompt = PromptTemplate(fmt_refine_tmpl, output_parser=output_parser)" + "# attach to an llm object\n", + "llm = OpenAI(output_parser=output_parser)" ] }, { @@ -281,7 +302,12 @@ } ], "source": [ + "from llama_index.prompts.default_prompts import (\n", + " DEFAULT_TEXT_QA_PROMPT_TMPL,\n", + ")\n", + "\n", "# take a look at the new QA template!\n", + "fmt_qa_tmpl = output_parser.format(DEFAULT_TEXT_QA_PROMPT_TMPL)\n", "print(fmt_qa_tmpl)" ] }, @@ -311,10 +337,12 @@ } ], "source": [ + "from llama_index import ServiceContext\n", + "\n", + "ctx = ServiceContext.from_defaults(llm=llm)\n", + "\n", "query_engine = index.as_query_engine(\n", - " text_qa_template=qa_prompt,\n", - " refine_template=refine_prompt,\n", - " llm_predictor=llm_predictor,\n", + " service_context=ctx,\n", ")\n", "response = query_engine.query(\n", " \"What are the three items the author did growing up?\",\n", diff --git a/docs/examples/output_parsing/LangchainOutputParserDemo.ipynb b/docs/examples/output_parsing/LangchainOutputParserDemo.ipynb index 9c7ee10231..c375090164 100644 --- a/docs/examples/output_parsing/LangchainOutputParserDemo.ipynb +++ b/docs/examples/output_parsing/LangchainOutputParserDemo.ipynb @@ -32,7 +32,27 @@ "execution_count": null, "id": "b9635dc3", "metadata": {}, - "outputs": [], + "outputs": [ + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Will not apply HSTS. The HSTS database must be a regular and non-world-writable file.\n", + "ERROR: could not open HSTS store at '/home/loganm/.wget-hsts'. HSTS will be disabled.\n", + "--2023-12-11 10:24:04-- https://raw.githubusercontent.com/run-llama/llama_index/main/docs/examples/data/paul_graham/paul_graham_essay.txt\n", + "Resolving raw.githubusercontent.com (raw.githubusercontent.com)... 185.199.110.133, 185.199.109.133, 185.199.108.133, ...\n", + "Connecting to raw.githubusercontent.com (raw.githubusercontent.com)|185.199.110.133|:443... connected.\n", + "HTTP request sent, awaiting response... 200 OK\n", + "Length: 75042 (73K) [text/plain]\n", + "Saving to: ‘data/paul_graham/paul_graham_essay.txt’\n", + "\n", + "data/paul_graham/pa 100%[===================>] 73.28K --.-KB/s in 0.04s \n", + "\n", + "2023-12-11 10:24:04 (1.74 MB/s) - ‘data/paul_graham/paul_graham_essay.txt’ saved [75042/75042]\n", + "\n" + ] + } + ], "source": [ "!mkdir -p 'data/paul_graham/'\n", "!wget 'https://raw.githubusercontent.com/run-llama/llama_index/main/docs/examples/data/paul_graham/paul_graham_essay.txt' -O 'data/paul_graham/paul_graham_essay.txt'" @@ -60,7 +80,11 @@ "logging.getLogger().addHandler(logging.StreamHandler(stream=sys.stdout))\n", "\n", "from llama_index import VectorStoreIndex, SimpleDirectoryReader\n", - "from IPython.display import Markdown, display" + "from IPython.display import Markdown, display\n", + "\n", + "import os\n", + "\n", + "os.environ[\"OPENAI_API_KEY\"] = \"sk-...\"" ] }, { @@ -84,10 +108,12 @@ "name": "stdout", "output_type": "stream", "text": [ - "INFO:llama_index.token_counter.token_counter:> [build_index_from_documents] Total LLM token usage: 0 tokens\n", - "> [build_index_from_documents] Total LLM token usage: 0 tokens\n", - "INFO:llama_index.token_counter.token_counter:> [build_index_from_documents] Total embedding token usage: 18579 tokens\n", - "> [build_index_from_documents] Total embedding token usage: 18579 tokens\n" + "INFO:httpx:HTTP Request: POST https://api.openai.com/v1/embeddings \"HTTP/1.1 200 OK\"\n", + "HTTP Request: POST https://api.openai.com/v1/embeddings \"HTTP/1.1 200 OK\"\n", + "INFO:httpx:HTTP Request: POST https://api.openai.com/v1/embeddings \"HTTP/1.1 200 OK\"\n", + "HTTP Request: POST https://api.openai.com/v1/embeddings \"HTTP/1.1 200 OK\"\n", + "INFO:httpx:HTTP Request: POST https://api.openai.com/v1/embeddings \"HTTP/1.1 200 OK\"\n", + "HTTP Request: POST https://api.openai.com/v1/embeddings \"HTTP/1.1 200 OK\"\n" ] } ], @@ -111,20 +137,9 @@ "outputs": [], "source": [ "from llama_index.output_parsers import LangchainOutputParser\n", - "from llama_index.llm_predictor import StructuredLLMPredictor\n", "from langchain.output_parsers import StructuredOutputParser, ResponseSchema" ] }, - { - "cell_type": "code", - "execution_count": null, - "id": "057139d2-09e8-4b8d-83a1-a2356a1475a8", - "metadata": {}, - "outputs": [], - "source": [ - "llm_predictor = StructuredLLMPredictor()" - ] - }, { "cell_type": "markdown", "id": "bc25edf7-9343-4e82-a3f1-eec4281a9371", @@ -133,20 +148,6 @@ "**Define custom QA and Refine Prompts**" ] }, - { - "cell_type": "code", - "execution_count": null, - "id": "2833d086-d240-4798-b3c5-a83ac4593b0e", - "metadata": {}, - "outputs": [], - "source": [ - "from llama_index.prompts import PromptTemplate\n", - "from llama_index.prompts.default_prompts import (\n", - " DEFAULT_TEXT_QA_PROMPT_TMPL,\n", - " DEFAULT_REFINE_PROMPT_TMPL,\n", - ")" - ] - }, { "cell_type": "code", "execution_count": null, @@ -181,23 +182,6 @@ "output_parser = LangchainOutputParser(lc_output_parser)" ] }, - { - "cell_type": "code", - "execution_count": null, - "id": "a9b440d4-6fb4-46e6-973f-44207b432d3f", - "metadata": {}, - "outputs": [], - "source": [ - "# NOTE: we use the same output parser for both prompts, though you can choose to use different parsers\n", - "# NOTE: here we add formatting instructions to the prompts.\n", - "\n", - "fmt_qa_tmpl = output_parser.format(DEFAULT_TEXT_QA_PROMPT_TMPL)\n", - "fmt_refine_tmpl = output_parser.format(DEFAULT_REFINE_PROMPT_TMPL)\n", - "\n", - "qa_prompt = PromptTemplate(fmt_qa_tmpl, output_parser=output_parser)\n", - "refine_prompt = PromptTemplate(fmt_refine_tmpl, output_parser=output_parser)" - ] - }, { "cell_type": "code", "execution_count": null, @@ -208,14 +192,15 @@ "name": "stdout", "output_type": "stream", "text": [ - "Context information is below. \n", + "Context information is below.\n", "---------------------\n", "{context_str}\n", "---------------------\n", - "Given the context information and not prior knowledge, answer the question: {query_str}\n", + "Given the context information and not prior knowledge, answer the query.\n", + "Query: {query_str}\n", + "Answer: \n", "\n", - "\n", - "The output should be a markdown code snippet formatted in the following schema:\n", + "The output should be a markdown code snippet formatted in the following schema, including the leading and trailing \"```json\" and \"```\":\n", "\n", "```json\n", "{{\n", @@ -227,7 +212,12 @@ } ], "source": [ + "from llama_index.prompts.default_prompts import (\n", + " DEFAULT_TEXT_QA_PROMPT_TMPL,\n", + ")\n", + "\n", "# take a look at the new QA template!\n", + "fmt_qa_tmpl = output_parser.format(DEFAULT_TEXT_QA_PROMPT_TMPL)\n", "print(fmt_qa_tmpl)" ] }, @@ -245,40 +235,32 @@ "id": "fb9cdf43-0f31-4c36-869b-df9fa50aebdb", "metadata": {}, "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "INFO:llama_index.token_counter.token_counter:> [query] Total LLM token usage: 609 tokens\n" - ] - }, { "name": "stdout", "output_type": "stream", "text": [ - "> [query] Total LLM token usage: 609 tokens\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "INFO:llama_index.token_counter.token_counter:> [query] Total embedding token usage: 11 tokens\n" + "INFO:httpx:HTTP Request: POST https://api.openai.com/v1/embeddings \"HTTP/1.1 200 OK\"\n", + "HTTP Request: POST https://api.openai.com/v1/embeddings \"HTTP/1.1 200 OK\"\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "> [query] Total embedding token usage: 11 tokens\n" + "INFO:httpx:HTTP Request: POST https://api.openai.com/v1/chat/completions \"HTTP/1.1 200 OK\"\n", + "HTTP Request: POST https://api.openai.com/v1/chat/completions \"HTTP/1.1 200 OK\"\n" ] } ], "source": [ + "from llama_index import ServiceContext\n", + "from llama_index.llms import OpenAI\n", + "\n", + "llm = OpenAI(output_parser=output_parser)\n", + "ctx = ServiceContext.from_defaults(llm=llm)\n", + "\n", "query_engine = index.as_query_engine(\n", - " text_qa_template=qa_prompt,\n", - " refine_template=refine_prompt,\n", - " llm_predictor=llm_predictor,\n", + " service_context=ctx,\n", ")\n", "response = query_engine.query(\n", " \"What are a few things the author did growing up?\",\n", @@ -295,7 +277,7 @@ "name": "stdout", "output_type": "stream", "text": [ - "{'Education': 'Before college, the author wrote short stories and experimented with programming on an IBM 1401.', 'Work': 'The author worked on writing and programming outside of school.'}\n" + "{'Education': 'The author did not plan to study programming in college, but initially planned to study philosophy.', 'Work': 'Growing up, the author worked on writing short stories and programming. They wrote simple games, a program to predict rocket heights, and a word processor.'}\n" ] } ], @@ -306,9 +288,9 @@ ], "metadata": { "kernelspec": { - "display_name": "llama_index", + "display_name": "llama-index-4a-wkI5X-py3.11", "language": "python", - "name": "llama_index" + "name": "python3" }, "language_info": { "codemirror_mode": { diff --git a/docs/examples/output_parsing/evaporate_program.ipynb b/docs/examples/output_parsing/evaporate_program.ipynb index be9dbafcd1..a2e68be329 100644 --- a/docs/examples/output_parsing/evaporate_program.ipynb +++ b/docs/examples/output_parsing/evaporate_program.ipynb @@ -60,7 +60,7 @@ "metadata": {}, "outputs": [], "source": [ - "from llama_index import SimpleDirectoryReader, ServiceContext, LLMPredictor\n", + "from llama_index import SimpleDirectoryReader, ServiceContext\n", "from llama_index.program.predefined import (\n", " DFEvaporateProgram,\n", " EvaporateExtractor,\n", diff --git a/docs/examples/output_parsing/guidance_sub_question.ipynb b/docs/examples/output_parsing/guidance_sub_question.ipynb index 39e63135d3..9220523f93 100644 --- a/docs/examples/output_parsing/guidance_sub_question.ipynb +++ b/docs/examples/output_parsing/guidance_sub_question.ipynb @@ -204,7 +204,6 @@ "source": [ "from llama_index import (\n", " SimpleDirectoryReader,\n", - " LLMPredictor,\n", " ServiceContext,\n", " VectorStoreIndex,\n", ")\n", diff --git a/docs/examples/query_engine/JointQASummary.ipynb b/docs/examples/query_engine/JointQASummary.ipynb index a4e1fd48b8..68952479f2 100644 --- a/docs/examples/query_engine/JointQASummary.ipynb +++ b/docs/examples/query_engine/JointQASummary.ipynb @@ -71,7 +71,7 @@ "from llama_index.composability.joint_qa_summary import (\n", " QASummaryQueryEngineBuilder,\n", ")\n", - "from llama_index import SimpleDirectoryReader, ServiceContext, LLMPredictor\n", + "from llama_index import SimpleDirectoryReader, ServiceContext\n", "from llama_index.response.notebook_utils import display_response\n", "from llama_index.llms import OpenAI" ] diff --git a/docs/examples/query_engine/SQLAutoVectorQueryEngine.ipynb b/docs/examples/query_engine/SQLAutoVectorQueryEngine.ipynb index 33c692a01d..0c78a15289 100644 --- a/docs/examples/query_engine/SQLAutoVectorQueryEngine.ipynb +++ b/docs/examples/query_engine/SQLAutoVectorQueryEngine.ipynb @@ -170,7 +170,7 @@ "metadata": {}, "outputs": [], "source": [ - "from llama_index import ServiceContext, LLMPredictor\n", + "from llama_index import ServiceContext\n", "from llama_index.storage import StorageContext\n", "from llama_index.vector_stores import PineconeVectorStore\n", "from llama_index.node_parser import TokenTextSplitter\n", diff --git a/docs/examples/query_engine/SQLJoinQueryEngine.ipynb b/docs/examples/query_engine/SQLJoinQueryEngine.ipynb index 07d19dc598..dff62f347f 100644 --- a/docs/examples/query_engine/SQLJoinQueryEngine.ipynb +++ b/docs/examples/query_engine/SQLJoinQueryEngine.ipynb @@ -137,7 +137,7 @@ "metadata": {}, "outputs": [], "source": [ - "from llama_index import ServiceContext, LLMPredictor\n", + "from llama_index import ServiceContext\n", "from llama_index.storage import StorageContext\n", "from llama_index.vector_stores import PineconeVectorStore\n", "from llama_index.node_parser import TokenTextSplitter\n", diff --git a/docs/examples/query_engine/citation_query_engine.ipynb b/docs/examples/query_engine/citation_query_engine.ipynb index 034359448d..57b0b196c9 100644 --- a/docs/examples/query_engine/citation_query_engine.ipynb +++ b/docs/examples/query_engine/citation_query_engine.ipynb @@ -73,7 +73,6 @@ " SimpleDirectoryReader,\n", " StorageContext,\n", " load_index_from_storage,\n", - " LLMPredictor,\n", " ServiceContext,\n", ")" ] diff --git a/docs/examples/query_engine/flare_query_engine.ipynb b/docs/examples/query_engine/flare_query_engine.ipynb index 983273bf43..85f4cc106c 100644 --- a/docs/examples/query_engine/flare_query_engine.ipynb +++ b/docs/examples/query_engine/flare_query_engine.ipynb @@ -62,7 +62,6 @@ " VectorStoreIndex,\n", " SimpleDirectoryReader,\n", " StorageContext,\n", - " LLMPredictor,\n", " ServiceContext,\n", ")" ] diff --git a/docs/examples/query_engine/knowledge_graph_query_engine.ipynb b/docs/examples/query_engine/knowledge_graph_query_engine.ipynb index e25557e9bb..912f121ac0 100644 --- a/docs/examples/query_engine/knowledge_graph_query_engine.ipynb +++ b/docs/examples/query_engine/knowledge_graph_query_engine.ipynb @@ -80,7 +80,6 @@ "\n", "from llama_index import (\n", " KnowledgeGraphIndex,\n", - " LLMPredictor,\n", " ServiceContext,\n", " SimpleDirectoryReader,\n", ")\n", @@ -113,7 +112,6 @@ " VectorStoreIndex,\n", " SimpleDirectoryReader,\n", " KnowledgeGraphIndex,\n", - " LLMPredictor,\n", " ServiceContext,\n", ")\n", "\n", diff --git a/docs/examples/query_engine/knowledge_graph_rag_query_engine.ipynb b/docs/examples/query_engine/knowledge_graph_rag_query_engine.ipynb index afca05e20c..8e710bd14a 100644 --- a/docs/examples/query_engine/knowledge_graph_rag_query_engine.ipynb +++ b/docs/examples/query_engine/knowledge_graph_rag_query_engine.ipynb @@ -88,7 +88,6 @@ "\n", "from llama_index import (\n", " KnowledgeGraphIndex,\n", - " LLMPredictor,\n", " ServiceContext,\n", " SimpleDirectoryReader,\n", ")\n", @@ -121,7 +120,6 @@ " VectorStoreIndex,\n", " SimpleDirectoryReader,\n", " KnowledgeGraphIndex,\n", - " LLMPredictor,\n", " ServiceContext,\n", ")\n", "\n", diff --git a/docs/examples/query_engine/pdf_tables/recursive_retriever.ipynb b/docs/examples/query_engine/pdf_tables/recursive_retriever.ipynb index 511a283197..77039150dd 100644 --- a/docs/examples/query_engine/pdf_tables/recursive_retriever.ipynb +++ b/docs/examples/query_engine/pdf_tables/recursive_retriever.ipynb @@ -34,7 +34,7 @@ "from llama_index import Document, SummaryIndex\n", "\n", "# https://en.wikipedia.org/wiki/The_World%27s_Billionaires\n", - "from llama_index import VectorStoreIndex, ServiceContext, LLMPredictor\n", + "from llama_index import VectorStoreIndex, ServiceContext\n", "from llama_index.query_engine import PandasQueryEngine, RetrieverQueryEngine\n", "from llama_index.retrievers import RecursiveRetriever\n", "from llama_index.schema import IndexNode\n", diff --git a/docs/examples/query_transformations/SimpleIndexDemo-multistep.ipynb b/docs/examples/query_transformations/SimpleIndexDemo-multistep.ipynb index 1540419935..18b69a28a7 100644 --- a/docs/examples/query_transformations/SimpleIndexDemo-multistep.ipynb +++ b/docs/examples/query_transformations/SimpleIndexDemo-multistep.ipynb @@ -82,7 +82,6 @@ "from llama_index import (\n", " VectorStoreIndex,\n", " SimpleDirectoryReader,\n", - " LLMPredictor,\n", " ServiceContext,\n", ")\n", "from llama_index.llms import OpenAI\n", @@ -96,11 +95,11 @@ "metadata": {}, "outputs": [], "source": [ - "# LLM Predictor (gpt-3)\n", + "# LLM (gpt-3)\n", "gpt3 = OpenAI(temperature=0, model=\"text-davinci-003\")\n", "service_context_gpt3 = ServiceContext.from_defaults(llm=gpt3)\n", "\n", - "# LLMPredictor (gpt-4)\n", + "# LLM (gpt-4)\n", "gpt4 = OpenAI(temperature=0, model=\"gpt-4\")\n", "service_context_gpt4 = ServiceContext.from_defaults(llm=gpt4)" ] @@ -145,16 +144,13 @@ "from llama_index.indices.query.query_transform.base import (\n", " StepDecomposeQueryTransform,\n", ")\n", - "from llama_index import LLMPredictor\n", "\n", "# gpt-4\n", - "step_decompose_transform = StepDecomposeQueryTransform(\n", - " LLMPredictor(llm=gpt4), verbose=True\n", - ")\n", + "step_decompose_transform = StepDecomposeQueryTransform(llm=gpt4, verbose=True)\n", "\n", "# gpt-3\n", "step_decompose_transform_gpt3 = StepDecomposeQueryTransform(\n", - " LLMPredictor(llm=gpt3), verbose=True\n", + " llm=gpt3, verbose=True\n", ")" ] }, diff --git a/docs/examples/usecases/City_Analysis-Decompose-KeywordTable.ipynb b/docs/examples/usecases/City_Analysis-Decompose-KeywordTable.ipynb index 463cc42a8e..53d547b3b3 100644 --- a/docs/examples/usecases/City_Analysis-Decompose-KeywordTable.ipynb +++ b/docs/examples/usecases/City_Analysis-Decompose-KeywordTable.ipynb @@ -80,7 +80,6 @@ " SimpleKeywordTableIndex,\n", " SummaryIndex,\n", " SimpleDirectoryReader,\n", - " LLMPredictor,\n", " ServiceContext,\n", ")\n", "import requests" diff --git a/docs/examples/vector_stores/SimpleIndexDemoMMR.ipynb b/docs/examples/vector_stores/SimpleIndexDemoMMR.ipynb index 96efbb40d2..58e2f570e6 100644 --- a/docs/examples/vector_stores/SimpleIndexDemoMMR.ipynb +++ b/docs/examples/vector_stores/SimpleIndexDemoMMR.ipynb @@ -192,7 +192,6 @@ " VectorStoreIndex,\n", " SimpleDirectoryReader,\n", " ServiceContext,\n", - " LLMPredictor,\n", ")\n", "from llama_index.response.notebook_utils import display_source_node\n", "from llama_index.llms import OpenAI\n", diff --git a/docs/module_guides/models/llms/usage_custom.md b/docs/module_guides/models/llms/usage_custom.md index 75acbd1f9c..103ebe1f98 100644 --- a/docs/module_guides/models/llms/usage_custom.md +++ b/docs/module_guides/models/llms/usage_custom.md @@ -24,7 +24,6 @@ you may also plug in any LLM shown on Langchain's from llama_index import ( KeywordTableIndex, SimpleDirectoryReader, - LLMPredictor, ServiceContext, ) from llama_index.llms import OpenAI diff --git a/docs/module_guides/querying/output_parser.md b/docs/module_guides/querying/output_parser.md index 60fc812f63..7225674738 100644 --- a/docs/module_guides/querying/output_parser.md +++ b/docs/module_guides/querying/output_parser.md @@ -13,22 +13,12 @@ Guardrails is an open-source Python package for specification/validation/correct ```python from llama_index import VectorStoreIndex, SimpleDirectoryReader from llama_index.output_parsers import GuardrailsOutputParser -from llama_index.llm_predictor import StructuredLLMPredictor -from llama_index.prompts import PromptTemplate -from llama_index.prompts.default_prompts import ( - DEFAULT_TEXT_QA_PROMPT_TMPL, - DEFAULT_REFINE_PROMPT_TMPL, -) +from llama_index.llms import OpenAI # load documents, build index documents = SimpleDirectoryReader("../paul_graham_essay/data").load_data() index = VectorStoreIndex(documents, chunk_size=512) -llm_predictor = StructuredLLMPredictor() - - -# specify StructuredLLMPredictor -# this is a special LLMPredictor that allows for structured outputs # define query / output spec rail_spec = """ @@ -59,22 +49,18 @@ Query string here. # define output parser output_parser = GuardrailsOutputParser.from_rail_string( - rail_spec, llm=llm_predictor.llm + rail_spec, llm=OpenAI() ) -# format each prompt with output parser instructions -fmt_qa_tmpl = output_parser.format(DEFAULT_TEXT_QA_PROMPT_TMPL) -fmt_refine_tmpl = output_parser.format(DEFAULT_REFINE_PROMPT_TMPL) - -qa_prompt = PromptTemplate(fmt_qa_tmpl, output_parser=output_parser) -refine_prompt = PromptTemplate(fmt_refine_tmpl, output_parser=output_parser) +# Attach output parser to LLM +llm = OpenAI(output_parser=output_parser) # obtain a structured response -query_engine = index.as_query_engine( - service_context=ServiceContext.from_defaults(llm_predictor=llm_predictor), - text_qa_template=qa_prompt, - refine_template=refine_prompt, -) +from llama_index import ServiceContext + +ctx = ServiceContext.from_defaults(llm=llm) + +query_engine = index.as_query_engine(service_context=ctx) response = query_engine.query( "What are the three items the author did growing up?", ) @@ -94,19 +80,13 @@ Langchain also offers output parsing modules that you can use within LlamaIndex. ```python from llama_index import VectorStoreIndex, SimpleDirectoryReader from llama_index.output_parsers import LangchainOutputParser -from llama_index.llm_predictor import StructuredLLMPredictor -from llama_index.prompts import PromptTemplate -from llama_index.prompts.default_prompts import ( - DEFAULT_TEXT_QA_PROMPT_TMPL, - DEFAULT_REFINE_PROMPT_TMPL, -) +from llama_index.llms import OpenAI from langchain.output_parsers import StructuredOutputParser, ResponseSchema # load documents, build index documents = SimpleDirectoryReader("../paul_graham_essay/data").load_data() index = VectorStoreIndex.from_documents(documents) -llm_predictor = StructuredLLMPredictor() # define output schema response_schemas = [ @@ -126,18 +106,15 @@ lc_output_parser = StructuredOutputParser.from_response_schemas( ) output_parser = LangchainOutputParser(lc_output_parser) -# format each prompt with output parser instructions -fmt_qa_tmpl = output_parser.format(DEFAULT_TEXT_QA_PROMPT_TMPL) -fmt_refine_tmpl = output_parser.format(DEFAULT_REFINE_PROMPT_TMPL) -qa_prompt = PromptTemplate(fmt_qa_tmpl, output_parser=output_parser) -refine_prompt = PromptTemplate(fmt_refine_tmpl, output_parser=output_parser) - -# query index -query_engine = index.as_query_engine( - service_context=ServiceContext.from_defaults(llm_predictor=llm_predictor), - text_qa_template=qa_prompt, - refine_template=refine_prompt, -) +# Attach output parser to LLM +llm = OpenAI(output_parser=output_parser) + +# obtain a structured response +from llama_index import ServiceContext + +ctx = ServiceContext.from_defaults(llm=llm) + +query_engine = index.as_query_engine(service_context=ctx) response = query_engine.query( "What are a few things the author did growing up?", ) diff --git a/docs/module_guides/supporting_modules/service_context.md b/docs/module_guides/supporting_modules/service_context.md index 1f36eca077..eff6ea8bf4 100644 --- a/docs/module_guides/supporting_modules/service_context.md +++ b/docs/module_guides/supporting_modules/service_context.md @@ -69,7 +69,6 @@ Here's a complete example that sets up all objects using their default settings: ```python from llama_index import ( ServiceContext, - LLMPredictor, OpenAIEmbedding, PromptHelper, ) diff --git a/docs/optimizing/advanced_retrieval/query_transformations.md b/docs/optimizing/advanced_retrieval/query_transformations.md index c559fe2fad..d10af98c38 100644 --- a/docs/optimizing/advanced_retrieval/query_transformations.md +++ b/docs/optimizing/advanced_retrieval/query_transformations.md @@ -66,14 +66,12 @@ Here's a corresponding example code snippet over a composed graph. ```python # Setting: a summary index composed over multiple vector indices -# llm_predictor_chatgpt corresponds to the ChatGPT LLM interface +# llm_chatgpt corresponds to the ChatGPT LLM interface from llama_index.indices.query.query_transform.base import ( DecomposeQueryTransform, ) -decompose_transform = DecomposeQueryTransform( - llm_predictor_chatgpt, verbose=True -) +decompose_transform = DecomposeQueryTransform(llm_chatgpt, verbose=True) # initialize indexes and graph ... @@ -117,9 +115,7 @@ from llama_index.indices.query.query_transform.base import ( ) # gpt-4 -step_decompose_transform = StepDecomposeQueryTransform( - llm_predictor, verbose=True -) +step_decompose_transform = StepDecomposeQueryTransform(llm, verbose=True) query_engine = index.as_query_engine() query_engine = MultiStepQueryEngine( diff --git a/docs/understanding/evaluating/cost_analysis/root.md b/docs/understanding/evaluating/cost_analysis/root.md index 119e91435e..fe0cac0ef0 100644 --- a/docs/understanding/evaluating/cost_analysis/root.md +++ b/docs/understanding/evaluating/cost_analysis/root.md @@ -81,7 +81,7 @@ You may also predict the token usage of embedding calls with `MockEmbedding`. from llama_index import ServiceContext, set_global_service_context from llama_index import MockEmbedding -# specify a MockLLMPredictor +# specify a MockEmbedding embed_model = MockEmbedding(embed_dim=1536) service_context = ServiceContext.from_defaults(embed_model=embed_model) diff --git a/docs/understanding/putting_it_all_together/apps/fullstack_with_delphic.md b/docs/understanding/putting_it_all_together/apps/fullstack_with_delphic.md index 1ddea915af..f4ff376c40 100644 --- a/docs/understanding/putting_it_all_together/apps/fullstack_with_delphic.md +++ b/docs/understanding/putting_it_all_together/apps/fullstack_with_delphic.md @@ -329,7 +329,7 @@ async def receive(self, text_data): To load the collection model, the `load_collection_model` function is used, which can be found in [`delphic/utils/collections.py`](https://github.com/JSv4/Delphic/blob/main/delphic/utils/collections.py). This function retrieves the collection object with the given collection ID, checks if a JSON file for the collection model -exists, and if not, creates one. Then, it sets up the `LLMPredictor` and `ServiceContext` before loading +exists, and if not, creates one. Then, it sets up the `LLM` and `ServiceContext` before loading the `VectorStoreIndex` using the cache file. ```python diff --git a/docs/understanding/putting_it_all_together/q_and_a.md b/docs/understanding/putting_it_all_together/q_and_a.md index ac6bee7ffc..942c968459 100644 --- a/docs/understanding/putting_it_all_together/q_and_a.md +++ b/docs/understanding/putting_it_all_together/q_and_a.md @@ -140,7 +140,7 @@ from llama_index.indices.query.query_transform.base import ( ) decompose_transform = DecomposeQueryTransform( - service_context.llm_predictor, verbose=True + service_context.llm, verbose=True ) ``` diff --git a/docs/understanding/putting_it_all_together/q_and_a/terms_definitions_tutorial.md b/docs/understanding/putting_it_all_together/q_and_a/terms_definitions_tutorial.md index 36f4789d3f..d4ceb76526 100644 --- a/docs/understanding/putting_it_all_together/q_and_a/terms_definitions_tutorial.md +++ b/docs/understanding/putting_it_all_together/q_and_a/terms_definitions_tutorial.md @@ -87,7 +87,6 @@ We can add the following functions to both initialize our LLM, as well as use it from llama_index import ( Document, SummaryIndex, - LLMPredictor, ServiceContext, load_index_from_storage, ) @@ -358,7 +357,7 @@ from llama_index.prompts import ( ChatPromptTemplate, ) from llama_index.prompts.utils import is_chat_model -from llama_index.llms.base import ChatMessage, MessageRole +from llama_index.llms import ChatMessage, MessageRole # Text QA templates DEFAULT_TEXT_QA_PROMPT_TMPL = ( diff --git a/docs/understanding/putting_it_all_together/q_and_a/unified_query.md b/docs/understanding/putting_it_all_together/q_and_a/unified_query.md index 8de57a3746..17f82eed99 100644 --- a/docs/understanding/putting_it_all_together/q_and_a/unified_query.md +++ b/docs/understanding/putting_it_all_together/q_and_a/unified_query.md @@ -151,14 +151,11 @@ An example is shown below. ```python # define decompose_transform -from llama_index import LLMPredictor from llama_index.indices.query.query_transform.base import ( DecomposeQueryTransform, ) -decompose_transform = DecomposeQueryTransform( - LLMPredictor(llm=llm_gpt4), verbose=True -) +decompose_transform = DecomposeQueryTransform(llm=llm_gpt4, verbose=True) # define custom query engines from llama_index.query_engine.transform_query_engine import ( diff --git a/examples/async/AsyncComposableIndicesSEC.ipynb b/examples/async/AsyncComposableIndicesSEC.ipynb index acbae9f749..b8275b6e8a 100644 --- a/examples/async/AsyncComposableIndicesSEC.ipynb +++ b/examples/async/AsyncComposableIndicesSEC.ipynb @@ -143,7 +143,7 @@ "metadata": {}, "outputs": [], "source": [ - "from llama_index import SummaryIndex, LLMPredictor\n", + "from llama_index import SummaryIndex\n", "from llama_index.llms import OpenAI\n", "from llama_index.composability import ComposableGraph" ] diff --git a/examples/async/AsyncLLMPredictorDemo.ipynb b/examples/async/AsyncLLMPredictorDemo.ipynb index 3d66da8f49..09a8f458e1 100644 --- a/examples/async/AsyncLLMPredictorDemo.ipynb +++ b/examples/async/AsyncLLMPredictorDemo.ipynb @@ -5,7 +5,7 @@ "id": "43cea4f8", "metadata": {}, "source": [ - "# Async LLMPredictor Demo" + "# Async LLM Demo" ] }, { @@ -15,7 +15,7 @@ "metadata": {}, "outputs": [], "source": [ - "from llama_index.langchain_helpers.chain_wrapper import LLMPredictor\n", + "from llama_index.llms import OpenAI\n", "from llama_index.prompts.default_prompts import DEFAULT_SUMMARY_PROMPT\n", "import asyncio\n", "import time" @@ -44,7 +44,7 @@ "metadata": {}, "outputs": [], "source": [ - "llm = LLMPredictor()" + "llm = OpenAI()" ] }, { diff --git a/examples/experimental/Evaporate.ipynb b/examples/experimental/Evaporate.ipynb index b09d543fac..ea1f892c34 100644 --- a/examples/experimental/Evaporate.ipynb +++ b/examples/experimental/Evaporate.ipynb @@ -16,7 +16,7 @@ "metadata": {}, "outputs": [], "source": [ - "from llama_index import SimpleDirectoryReader, ServiceContext, LLMPredictor\n", + "from llama_index import SimpleDirectoryReader, ServiceContext\n", "from llama_index.experimental.evaporate import EvaporateExtractor\n", "from llama_index.llms import OpenAI\n", "import requests" diff --git a/examples/paul_graham_essay/GPT4Comparison.ipynb b/examples/paul_graham_essay/GPT4Comparison.ipynb index 710996ed16..fae5c47be3 100644 --- a/examples/paul_graham_essay/GPT4Comparison.ipynb +++ b/examples/paul_graham_essay/GPT4Comparison.ipynb @@ -10,7 +10,6 @@ "from llama_index import (\n", " SummaryIndex,\n", " SimpleDirectoryReader,\n", - " LLMPredictor,\n", " ServiceContext,\n", ")\n", "from llama_index.response.notebook_utils import display_response\n", diff --git a/examples/test_wiki/TestNYC-Benchmark-GPT4.ipynb b/examples/test_wiki/TestNYC-Benchmark-GPT4.ipynb index 40ec9c491c..ac3d7ab364 100644 --- a/examples/test_wiki/TestNYC-Benchmark-GPT4.ipynb +++ b/examples/test_wiki/TestNYC-Benchmark-GPT4.ipynb @@ -52,14 +52,13 @@ "from llama_index import (\n", " TreeIndex,\n", " SimpleDirectoryReader,\n", - " LLMPredictor,\n", " VectorStoreIndex,\n", " SummaryIndex,\n", " PromptTemplate,\n", " ServiceContext,\n", ")\n", "from llama_index.indices.base import BaseIndex\n", - "from llama_index.llms.base import LLM\n", + "from llama_index.llms.llm import LLM\n", "from llama_index.llms import OpenAI\n", "from llama_index.response.schema import Response\n", "import pandas as pd\n", @@ -346,14 +345,12 @@ "metadata": {}, "outputs": [], "source": [ - "def analyze_outcome_llm_single(\n", - " outcome: TestOutcome, llm_predictor: LLMPredictor\n", - ") -> Tuple[bool, bool]:\n", + "def analyze_outcome_llm_single(outcome: TestOutcome, llm: LLM) -> Tuple[bool, bool]:\n", " try:\n", " source_text = outcome.response.source_nodes[0].text\n", " except:\n", " source_text = \"Failed to retrieve any context\"\n", - " result_str, _ = llm_predictor.predict(\n", + " result_str, _ = llm.predict(\n", " DEFAULT_EVAL_PROMPT,\n", " query_str=outcome.test.query,\n", " context_str=source_text,\n", @@ -363,13 +360,11 @@ " return is_answer_correct, is_context_relevant, result_str\n", "\n", "\n", - "def analyze_outcome_llm(\n", - " outcomes: List[TestOutcome], llm_predictor: LLMPredictor\n", - ") -> None:\n", + "def analyze_outcome_llm(outcomes: List[TestOutcome], llm: LLM) -> None:\n", " rows = []\n", " for outcome in outcomes:\n", " is_correct_response, is_correct_source, result_str = analyze_outcome_llm_single(\n", - " outcome, llm_predictor\n", + " outcome, llm\n", " )\n", " row = [outcome.test.query, is_correct_response, is_correct_source, result_str]\n", " rows.append(row)\n", @@ -434,7 +429,7 @@ "id": "5b2e7fdd", "metadata": {}, "source": [ - "# Create LLMPredictors" + "# Create LLMs" ] }, { diff --git a/examples/test_wiki/TestNYC-Tree-GPT4.ipynb b/examples/test_wiki/TestNYC-Tree-GPT4.ipynb index 7f8bc21a1f..a940438438 100644 --- a/examples/test_wiki/TestNYC-Tree-GPT4.ipynb +++ b/examples/test_wiki/TestNYC-Tree-GPT4.ipynb @@ -70,7 +70,7 @@ "metadata": {}, "outputs": [], "source": [ - "from llama_index import TreeIndex, SimpleDirectoryReader, LLMPredictor, ServiceContext\n", + "from llama_index import TreeIndex, SimpleDirectoryReader, ServiceContext\n", "from llama_index.logger import LlamaLogger\n", "from llama_index.llms import OpenAI" ] diff --git a/experimental/classifier/utils.py b/experimental/classifier/utils.py index 43b989de80..877dfd7adb 100644 --- a/experimental/classifier/utils.py +++ b/experimental/classifier/utils.py @@ -8,7 +8,7 @@ import pandas as pd from sklearn.model_selection import train_test_split from llama_index.indices.utils import extract_numbers_given_response -from llama_index.llm_predictor import LLMPredictor +from llama_index.llms import OpenAI from llama_index.prompts import BasePromptTemplate, PromptTemplate @@ -77,13 +77,11 @@ def get_eval_preds( train_prompt: BasePromptTemplate, train_str: str, eval_df: pd.DataFrame, n: int = 20 ) -> List: """Get eval preds.""" - llm_predictor = LLMPredictor() + llm = OpenAI() eval_preds = [] for i in range(n): eval_str = get_sorted_dict_str(eval_df.iloc[i].to_dict()) - response = llm_predictor.predict( - train_prompt, train_str=train_str, eval_str=eval_str - ) + response = llm.predict(train_prompt, train_str=train_str, eval_str=eval_str) pred = extract_float_given_response(response) print(f"Getting preds: {i}/{n}: {pred}") if pred is None: diff --git a/experimental/cli/configuration.py b/experimental/cli/configuration.py index a78cde46bf..33a72e240d 100644 --- a/experimental/cli/configuration.py +++ b/experimental/cli/configuration.py @@ -13,7 +13,7 @@ from llama_index.indices import SimpleKeywordTableIndex from llama_index.indices.base import BaseIndex from llama_index.indices.loading import load_index_from_storage from llama_index.llm_predictor import StructuredLLMPredictor -from llama_index.llms.base import LLM +from llama_index.llms.llm import LLM from llama_index.llms.openai import OpenAI from llama_index.storage.storage_context import StorageContext diff --git a/llama_index/agent/context_retriever_agent.py b/llama_index/agent/context_retriever_agent.py index f2a463b14b..c54253ac32 100644 --- a/llama_index/agent/context_retriever_agent.py +++ b/llama_index/agent/context_retriever_agent.py @@ -12,9 +12,10 @@ from llama_index.chat_engine.types import ( AgentChatResponse, ) from llama_index.core import BaseRetriever -from llama_index.llms.base import LLM, ChatMessage +from llama_index.llms.llm import LLM from llama_index.llms.openai import OpenAI from llama_index.llms.openai_utils import is_function_calling_model +from llama_index.llms.types import ChatMessage from llama_index.memory import BaseMemory, ChatMemoryBuffer from llama_index.prompts import PromptTemplate from llama_index.schema import NodeWithScore diff --git a/llama_index/agent/openai_agent.py b/llama_index/agent/openai_agent.py index 2ea23ad605..9f7dd54fb3 100644 --- a/llama_index/agent/openai_agent.py +++ b/llama_index/agent/openai_agent.py @@ -18,9 +18,10 @@ from llama_index.chat_engine.types import ( ChatResponseMode, StreamingAgentChatResponse, ) -from llama_index.llms.base import LLM, ChatMessage, ChatResponse, MessageRole +from llama_index.llms.llm import LLM from llama_index.llms.openai import OpenAI from llama_index.llms.openai_utils import OpenAIToolCall +from llama_index.llms.types import ChatMessage, ChatResponse, MessageRole from llama_index.memory import BaseMemory, ChatMemoryBuffer from llama_index.objects.base import ObjectRetriever from llama_index.tools import BaseTool, ToolOutput, adapt_to_async_tool diff --git a/llama_index/agent/openai_assistant_agent.py b/llama_index/agent/openai_assistant_agent.py index d38f4cec54..81015b1b88 100644 --- a/llama_index/agent/openai_assistant_agent.py +++ b/llama_index/agent/openai_assistant_agent.py @@ -18,7 +18,7 @@ from llama_index.chat_engine.types import ( ChatResponseMode, StreamingAgentChatResponse, ) -from llama_index.llms.base import ChatMessage, MessageRole +from llama_index.llms.types import ChatMessage, MessageRole from llama_index.tools import BaseTool, ToolOutput logger = logging.getLogger(__name__) diff --git a/llama_index/agent/react/base.py b/llama_index/agent/react/base.py index 18759c6c6a..84c484d37e 100644 --- a/llama_index/agent/react/base.py +++ b/llama_index/agent/react/base.py @@ -30,8 +30,9 @@ from llama_index.callbacks import ( trace_method, ) from llama_index.chat_engine.types import AgentChatResponse, StreamingAgentChatResponse -from llama_index.llms.base import LLM, ChatMessage, ChatResponse, MessageRole +from llama_index.llms.llm import LLM from llama_index.llms.openai import OpenAI +from llama_index.llms.types import ChatMessage, ChatResponse, MessageRole from llama_index.memory.chat_memory_buffer import ChatMemoryBuffer from llama_index.memory.types import BaseMemory from llama_index.objects.base import ObjectRetriever diff --git a/llama_index/agent/react/formatter.py b/llama_index/agent/react/formatter.py index e1733846b3..ab39d29fe5 100644 --- a/llama_index/agent/react/formatter.py +++ b/llama_index/agent/react/formatter.py @@ -6,7 +6,7 @@ from typing import List, Optional, Sequence from llama_index.agent.react.prompts import REACT_CHAT_SYSTEM_HEADER from llama_index.agent.react.types import BaseReasoningStep, ObservationReasoningStep from llama_index.bridge.pydantic import BaseModel -from llama_index.llms.base import ChatMessage, MessageRole +from llama_index.llms.types import ChatMessage, MessageRole from llama_index.tools import BaseTool diff --git a/llama_index/agent/types.py b/llama_index/agent/types.py index 6595cf2bfa..422516796c 100644 --- a/llama_index/agent/types.py +++ b/llama_index/agent/types.py @@ -4,7 +4,7 @@ from typing import List, Optional from llama_index.callbacks import trace_method from llama_index.chat_engine.types import BaseChatEngine, StreamingAgentChatResponse from llama_index.core import BaseQueryEngine -from llama_index.llms.base import ChatMessage +from llama_index.llms.types import ChatMessage from llama_index.prompts.mixin import PromptDictType, PromptMixinType from llama_index.response.schema import RESPONSE_TYPE, Response from llama_index.schema import QueryBundle diff --git a/llama_index/callbacks/finetuning_handler.py b/llama_index/callbacks/finetuning_handler.py index 811c7be818..577e1fe104 100644 --- a/llama_index/callbacks/finetuning_handler.py +++ b/llama_index/callbacks/finetuning_handler.py @@ -35,7 +35,7 @@ class BaseFinetuningHandler(BaseCallbackHandler): **kwargs: Any, ) -> str: """Run when an event starts and return id of event.""" - from llama_index.llms.base import ChatMessage, MessageRole + from llama_index.llms.types import ChatMessage, MessageRole if event_type == CBEventType.LLM: cur_messages = [] @@ -68,7 +68,7 @@ class BaseFinetuningHandler(BaseCallbackHandler): **kwargs: Any, ) -> None: """Run when an event ends.""" - from llama_index.llms.base import ChatMessage, MessageRole + from llama_index.llms.types import ChatMessage, MessageRole if ( event_type == CBEventType.LLM diff --git a/llama_index/chat_engine/condense_plus_context.py b/llama_index/chat_engine/condense_plus_context.py index 0399af7d2b..bc3b14ab4e 100644 --- a/llama_index/chat_engine/condense_plus_context.py +++ b/llama_index/chat_engine/condense_plus_context.py @@ -13,9 +13,9 @@ from llama_index.chat_engine.types import ( from llama_index.indices.base_retriever import BaseRetriever from llama_index.indices.query.schema import QueryBundle from llama_index.indices.service_context import ServiceContext -from llama_index.llm_predictor.base import LLMPredictor -from llama_index.llms.base import LLM, ChatMessage, MessageRole from llama_index.llms.generic_utils import messages_to_history_str +from llama_index.llms.llm import LLM +from llama_index.llms.types import ChatMessage, MessageRole from llama_index.memory import BaseMemory, ChatMemoryBuffer from llama_index.postprocessor.types import BaseNodePostprocessor from llama_index.prompts.base import PromptTemplate @@ -60,7 +60,6 @@ class CondensePlusContextChatEngine(BaseChatEngine): self, retriever: BaseRetriever, llm: LLM, - llm_predictor: LLMPredictor, memory: BaseMemory, context_prompt: Optional[str] = None, condense_prompt: Optional[str] = None, @@ -72,7 +71,6 @@ class CondensePlusContextChatEngine(BaseChatEngine): ): self._retriever = retriever self._llm = llm - self._llm_predictor = llm_predictor self._memory = memory self._context_prompt_template = ( context_prompt or DEFAULT_CONTEXT_PROMPT_TEMPLATE @@ -106,10 +104,7 @@ class CondensePlusContextChatEngine(BaseChatEngine): ) -> "CondensePlusContextChatEngine": """Initialize a CondensePlusContextChatEngine from default parameters.""" service_context = service_context or ServiceContext.from_defaults() - if not isinstance(service_context.llm_predictor, LLMPredictor): - raise ValueError("llm_predictor must be a LLMPredictor instance") - llm_predictor = service_context.llm_predictor - llm = llm_predictor.llm + llm = service_context.llm chat_history = chat_history or [] memory = memory or ChatMemoryBuffer.from_defaults( chat_history=chat_history, token_limit=llm.metadata.context_window - 256 @@ -118,7 +113,6 @@ class CondensePlusContextChatEngine(BaseChatEngine): return cls( retriever=retriever, llm=llm, - llm_predictor=llm_predictor, memory=memory, context_prompt=context_prompt, condense_prompt=condense_prompt, @@ -139,7 +133,7 @@ class CondensePlusContextChatEngine(BaseChatEngine): chat_history_str = messages_to_history_str(chat_history) logger.debug(chat_history_str) - return self._llm_predictor.predict( + return self._llm.predict( self._condense_prompt_template, question=latest_message, chat_history=chat_history_str, @@ -155,7 +149,7 @@ class CondensePlusContextChatEngine(BaseChatEngine): chat_history_str = messages_to_history_str(chat_history) logger.debug(chat_history_str) - return await self._llm_predictor.apredict( + return await self._llm.apredict( self._condense_prompt_template, question=latest_message, chat_history=chat_history_str, diff --git a/llama_index/chat_engine/condense_question.py b/llama_index/chat_engine/condense_question.py index e8c9500701..ef8f2f19a1 100644 --- a/llama_index/chat_engine/condense_question.py +++ b/llama_index/chat_engine/condense_question.py @@ -10,9 +10,9 @@ from llama_index.chat_engine.types import ( ) from llama_index.chat_engine.utils import response_gen_from_query_engine from llama_index.core import BaseQueryEngine -from llama_index.llm_predictor.base import LLMPredictor -from llama_index.llms.base import ChatMessage, MessageRole +from llama_index.llm_predictor.base import LLMPredictorType from llama_index.llms.generic_utils import messages_to_history_str +from llama_index.llms.types import ChatMessage, MessageRole from llama_index.memory import BaseMemory, ChatMemoryBuffer from llama_index.prompts.base import BasePromptTemplate, PromptTemplate from llama_index.response.schema import RESPONSE_TYPE, StreamingResponse @@ -51,14 +51,14 @@ class CondenseQuestionChatEngine(BaseChatEngine): query_engine: BaseQueryEngine, condense_question_prompt: BasePromptTemplate, memory: BaseMemory, - service_context: ServiceContext, + llm: LLMPredictorType, verbose: bool = False, callback_manager: Optional[CallbackManager] = None, ) -> None: self._query_engine = query_engine self._condense_question_prompt = condense_question_prompt self._memory = memory - self._service_context = service_context + self._llm = llm self._verbose = verbose self.callback_manager = callback_manager or CallbackManager([]) @@ -80,9 +80,7 @@ class CondenseQuestionChatEngine(BaseChatEngine): condense_question_prompt = condense_question_prompt or DEFAULT_PROMPT service_context = service_context or ServiceContext.from_defaults() - if not isinstance(service_context.llm_predictor, LLMPredictor): - raise ValueError("llm_predictor must be a LLMPredictor instance") - llm = service_context.llm_predictor.llm + llm = service_context.llm chat_history = chat_history or [] memory = memory or memory_cls.from_defaults(chat_history=chat_history, llm=llm) @@ -100,7 +98,7 @@ class CondenseQuestionChatEngine(BaseChatEngine): query_engine, condense_question_prompt, memory, - service_context, + llm, verbose=verbose, callback_manager=service_context.callback_manager, ) @@ -114,7 +112,7 @@ class CondenseQuestionChatEngine(BaseChatEngine): chat_history_str = messages_to_history_str(chat_history) logger.debug(chat_history_str) - return self._service_context.llm_predictor.predict( + return self._llm.predict( self._condense_question_prompt, question=last_message, chat_history=chat_history_str, @@ -129,7 +127,7 @@ class CondenseQuestionChatEngine(BaseChatEngine): chat_history_str = messages_to_history_str(chat_history) logger.debug(chat_history_str) - return await self._service_context.llm_predictor.apredict( + return await self._llm.apredict( self._condense_question_prompt, question=last_message, chat_history=chat_history_str, diff --git a/llama_index/chat_engine/context.py b/llama_index/chat_engine/context.py index 9758df9188..04b76f1363 100644 --- a/llama_index/chat_engine/context.py +++ b/llama_index/chat_engine/context.py @@ -10,8 +10,8 @@ from llama_index.chat_engine.types import ( ToolOutput, ) from llama_index.core import BaseRetriever -from llama_index.llm_predictor.base import LLMPredictor -from llama_index.llms.base import LLM, ChatMessage, MessageRole +from llama_index.llms.llm import LLM +from llama_index.llms.types import ChatMessage, MessageRole from llama_index.memory import BaseMemory, ChatMemoryBuffer from llama_index.postprocessor.types import BaseNodePostprocessor from llama_index.schema import MetadataMode, NodeWithScore, QueryBundle @@ -68,9 +68,7 @@ class ContextChatEngine(BaseChatEngine): ) -> "ContextChatEngine": """Initialize a ContextChatEngine from default parameters.""" service_context = service_context or ServiceContext.from_defaults() - if not isinstance(service_context.llm_predictor, LLMPredictor): - raise ValueError("llm_predictor must be a LLMPredictor instance") - llm = service_context.llm_predictor.llm + llm = service_context.llm chat_history = chat_history or [] memory = memory or ChatMemoryBuffer.from_defaults( diff --git a/llama_index/chat_engine/simple.py b/llama_index/chat_engine/simple.py index 3109de0219..4e95aeb5da 100644 --- a/llama_index/chat_engine/simple.py +++ b/llama_index/chat_engine/simple.py @@ -8,8 +8,8 @@ from llama_index.chat_engine.types import ( BaseChatEngine, StreamingAgentChatResponse, ) -from llama_index.llm_predictor.base import LLMPredictor -from llama_index.llms.base import LLM, ChatMessage +from llama_index.llms.llm import LLM +from llama_index.llms.types import ChatMessage from llama_index.memory import BaseMemory, ChatMemoryBuffer from llama_index.service_context import ServiceContext @@ -46,9 +46,7 @@ class SimpleChatEngine(BaseChatEngine): ) -> "SimpleChatEngine": """Initialize a SimpleChatEngine from default parameters.""" service_context = service_context or ServiceContext.from_defaults() - if not isinstance(service_context.llm_predictor, LLMPredictor): - raise ValueError("llm_predictor must be a LLMPredictor instance") - llm = service_context.llm_predictor.llm + llm = service_context.llm chat_history = chat_history or [] memory = memory or memory_cls.from_defaults(chat_history=chat_history, llm=llm) diff --git a/llama_index/chat_engine/types.py b/llama_index/chat_engine/types.py index 19c21e1937..64ccad7998 100644 --- a/llama_index/chat_engine/types.py +++ b/llama_index/chat_engine/types.py @@ -7,7 +7,7 @@ from enum import Enum from threading import Event from typing import AsyncGenerator, Generator, List, Optional, Union -from llama_index.llms.base import ChatMessage, ChatResponseAsyncGen, ChatResponseGen +from llama_index.llms.types import ChatMessage, ChatResponseAsyncGen, ChatResponseGen from llama_index.memory import BaseMemory from llama_index.response.schema import Response, StreamingResponse from llama_index.schema import NodeWithScore diff --git a/llama_index/chat_engine/utils.py b/llama_index/chat_engine/utils.py index 44e25bac34..b33e8ff6be 100644 --- a/llama_index/chat_engine/utils.py +++ b/llama_index/chat_engine/utils.py @@ -1,4 +1,4 @@ -from llama_index.llms.base import ( +from llama_index.llms.types import ( ChatMessage, ChatResponse, ChatResponseGen, diff --git a/llama_index/evaluation/correctness.py b/llama_index/evaluation/correctness.py index 5079f682d5..7afe9dd662 100644 --- a/llama_index/evaluation/correctness.py +++ b/llama_index/evaluation/correctness.py @@ -126,7 +126,7 @@ class CorrectnessEvaluator(BaseEvaluator): print(query, response, reference, flush=True) raise ValueError("query, response, and reference must be provided") - eval_response = await self._service_context.llm_predictor.apredict( + eval_response = await self._service_context.llm.apredict( prompt=self._eval_template, query=query, generated_answer=response, diff --git a/llama_index/evaluation/guideline.py b/llama_index/evaluation/guideline.py index 67946011c1..71d8105c0c 100644 --- a/llama_index/evaluation/guideline.py +++ b/llama_index/evaluation/guideline.py @@ -102,7 +102,7 @@ class GuidelineEvaluator(BaseEvaluator): await asyncio.sleep(sleep_time_in_seconds) - eval_response = await self._service_context.llm_predictor.apredict( + eval_response = await self._service_context.llm.apredict( self._eval_template, query=query, response=response, diff --git a/llama_index/evaluation/pairwise.py b/llama_index/evaluation/pairwise.py index 885b335795..a271db4057 100644 --- a/llama_index/evaluation/pairwise.py +++ b/llama_index/evaluation/pairwise.py @@ -117,7 +117,7 @@ class PairwiseComparisonEvaluator(BaseEvaluator): reference: Optional[str], ) -> EvaluationResult: """Get evaluation result.""" - eval_response = await self._service_context.llm_predictor.apredict( + eval_response = await self._service_context.llm.apredict( prompt=self._eval_template, query=query, answer_1=response, diff --git a/llama_index/extractors/interface.py b/llama_index/extractors/interface.py index 44c11c05d7..375e4bf27d 100644 --- a/llama_index/extractors/interface.py +++ b/llama_index/extractors/interface.py @@ -54,7 +54,14 @@ class BaseExtractor(TransformComponent): from llama_index.llm_predictor.loading import load_predictor llm_predictor = load_predictor(llm_predictor) - data["llm_predictor"] = llm_predictor + data["llm_predictor"] = llm_predictor + + llm = data.get("llm", None) + if llm: + from llama_index.llms.loading import load_llm + + llm = load_llm(llm) + data["llm"] = llm return cls(**data) diff --git a/llama_index/extractors/metadata_extractors.py b/llama_index/extractors/metadata_extractors.py index 1184f9771b..a259250bf8 100644 --- a/llama_index/extractors/metadata_extractors.py +++ b/llama_index/extractors/metadata_extractors.py @@ -25,8 +25,9 @@ from typing import Any, Callable, Dict, Iterable, List, Optional, Sequence, cast from llama_index.async_utils import DEFAULT_NUM_WORKERS, run_jobs from llama_index.bridge.pydantic import Field, PrivateAttr from llama_index.extractors.interface import BaseExtractor -from llama_index.llm_predictor.base import LLMPredictor -from llama_index.llms.base import LLM +from llama_index.llm_predictor.base import LLMPredictorType +from llama_index.llms.llm import LLM +from llama_index.llms.utils import resolve_llm from llama_index.prompts import PromptTemplate from llama_index.schema import BaseNode, TextNode from llama_index.types import BasePydanticProgram @@ -47,7 +48,7 @@ class TitleExtractor(BaseExtractor): metadata field. Args: - llm_predictor (Optional[LLMPredictor]): LLM predictor + llm (Optional[LLM]): LLM nodes (int): number of nodes from front to use for title extraction node_template (str): template for node-level title clues extraction combine_template (str): template for combining node-level clues into @@ -55,9 +56,7 @@ class TitleExtractor(BaseExtractor): """ is_text_node_only: bool = False # can work for mixture of text and non-text nodes - llm_predictor: LLMPredictor = Field( - description="The LLMPredictor to use for generation." - ) + llm: LLMPredictorType = Field(description="The LLM to use for generation.") nodes: int = Field( default=5, description="The number of nodes to extract titles from.", @@ -76,7 +75,7 @@ class TitleExtractor(BaseExtractor): self, llm: Optional[LLM] = None, # TODO: llm_predictor arg is deprecated - llm_predictor: Optional[LLMPredictor] = None, + llm_predictor: Optional[LLMPredictorType] = None, nodes: int = 5, node_template: str = DEFAULT_TITLE_NODE_TEMPLATE, combine_template: str = DEFAULT_TITLE_COMBINE_TEMPLATE, @@ -87,13 +86,8 @@ class TitleExtractor(BaseExtractor): if nodes < 1: raise ValueError("num_nodes must be >= 1") - if llm is not None: - llm_predictor = LLMPredictor(llm=llm) - elif llm_predictor is None and llm is None: - llm_predictor = LLMPredictor() - super().__init__( - llm_predictor=llm_predictor, + llm=llm or llm_predictor or resolve_llm("default"), nodes=nodes, node_template=node_template, combine_template=combine_template, @@ -120,7 +114,7 @@ class TitleExtractor(BaseExtractor): return [] title_jobs = [ - self.llm_predictor.apredict( + self.llm.apredict( PromptTemplate(template=self.node_template), context_str=cast(TextNode, node).text, ) @@ -135,7 +129,7 @@ class TitleExtractor(BaseExtractor): lambda x, y: x + "," + y, title_candidates[1:], title_candidates[0] ) - title = await self.llm_predictor.apredict( + title = await self.llm.apredict( PromptTemplate(template=self.combine_template), context_str=titles, ) @@ -152,13 +146,11 @@ class KeywordExtractor(BaseExtractor): `excerpt_keywords` metadata field. Args: - llm_predictor (Optional[LLMPredictor]): LLM predictor + llm (Optional[LLM]): LLM keywords (int): number of keywords to extract """ - llm_predictor: LLMPredictor = Field( - description="The LLMPredictor to use for generation." - ) + llm: LLMPredictorType = Field(description="The LLM to use for generation.") keywords: int = Field( default=5, description="The number of keywords to extract.", gt=0 ) @@ -167,7 +159,7 @@ class KeywordExtractor(BaseExtractor): self, llm: Optional[LLM] = None, # TODO: llm_predictor arg is deprecated - llm_predictor: Optional[LLMPredictor] = None, + llm_predictor: Optional[LLMPredictorType] = None, keywords: int = 5, num_workers: int = DEFAULT_NUM_WORKERS, **kwargs: Any, @@ -176,13 +168,8 @@ class KeywordExtractor(BaseExtractor): if keywords < 1: raise ValueError("num_keywords must be >= 1") - if llm is not None: - llm_predictor = LLMPredictor(llm=llm) - elif llm_predictor is None and llm is None: - llm_predictor = LLMPredictor() - super().__init__( - llm_predictor=llm_predictor, + llm=llm or llm_predictor or resolve_llm("default"), keywords=keywords, num_workers=num_workers, **kwargs, @@ -198,7 +185,7 @@ class KeywordExtractor(BaseExtractor): return {} # TODO: figure out a good way to allow users to customize keyword template - keywords = await self.llm_predictor.apredict( + keywords = await self.llm.apredict( PromptTemplate( template=f"""\ {{context_str}}. Give {self.keywords} unique keywords for this \ @@ -242,15 +229,13 @@ class QuestionsAnsweredExtractor(BaseExtractor): Extracts `questions_this_excerpt_can_answer` metadata field. Args: - llm_predictor (Optional[LLMPredictor]): LLM predictor + llm (Optional[LLM]): LLM questions (int): number of questions to extract prompt_template (str): template for question extraction, embedding_only (bool): whether to use embedding only """ - llm_predictor: LLMPredictor = Field( - description="The LLMPredictor to use for generation." - ) + llm: LLMPredictorType = Field(description="The LLM to use for generation.") questions: int = Field( default=5, description="The number of questions to generate.", @@ -268,7 +253,7 @@ class QuestionsAnsweredExtractor(BaseExtractor): self, llm: Optional[LLM] = None, # TODO: llm_predictor arg is deprecated - llm_predictor: Optional[LLMPredictor] = None, + llm_predictor: Optional[LLMPredictorType] = None, questions: int = 5, prompt_template: str = DEFAULT_QUESTION_GEN_TMPL, embedding_only: bool = True, @@ -279,13 +264,8 @@ class QuestionsAnsweredExtractor(BaseExtractor): if questions < 1: raise ValueError("questions must be >= 1") - if llm is not None: - llm_predictor = LLMPredictor(llm=llm) - elif llm_predictor is None and llm is None: - llm_predictor = LLMPredictor() - super().__init__( - llm_predictor=llm_predictor, + llm=llm or llm_predictor or resolve_llm("default"), questions=questions, prompt_template=prompt_template, embedding_only=embedding_only, @@ -304,7 +284,7 @@ class QuestionsAnsweredExtractor(BaseExtractor): context_str = node.get_content(metadata_mode=self.metadata_mode) prompt = PromptTemplate(template=self.prompt_template) - questions = await self.llm_predictor.apredict( + questions = await self.llm.apredict( prompt, num_questions=self.questions, context_str=context_str ) @@ -338,14 +318,12 @@ class SummaryExtractor(BaseExtractor): metadata fields. Args: - llm_predictor (Optional[LLMPredictor]): LLM predictor + llm (Optional[LLM]): LLM summaries (List[str]): list of summaries to extract: 'self', 'prev', 'next' prompt_template (str): template for summary extraction """ - llm_predictor: LLMPredictor = Field( - description="The LLMPredictor to use for generation." - ) + llm: LLMPredictorType = Field(description="The LLM to use for generation.") summaries: List[str] = Field( description="List of summaries to extract: 'self', 'prev', 'next'" ) @@ -362,17 +340,12 @@ class SummaryExtractor(BaseExtractor): self, llm: Optional[LLM] = None, # TODO: llm_predictor arg is deprecated - llm_predictor: Optional[LLMPredictor] = None, + llm_predictor: Optional[LLMPredictorType] = None, summaries: List[str] = ["self"], prompt_template: str = DEFAULT_SUMMARY_EXTRACT_TEMPLATE, num_workers: int = DEFAULT_NUM_WORKERS, **kwargs: Any, ): - if llm is not None: - llm_predictor = LLMPredictor(llm=llm) - elif llm_predictor is None and llm is None: - llm_predictor = LLMPredictor() - # validation if not all(s in ["self", "prev", "next"] for s in summaries): raise ValueError("summaries must be one of ['self', 'prev', 'next']") @@ -381,7 +354,7 @@ class SummaryExtractor(BaseExtractor): self._next_summary = "next" in summaries super().__init__( - llm_predictor=llm_predictor, + llm=llm or llm_predictor or resolve_llm("default"), summaries=summaries, prompt_template=prompt_template, num_workers=num_workers, @@ -398,7 +371,7 @@ class SummaryExtractor(BaseExtractor): return "" context_str = node.get_content(metadata_mode=self.metadata_mode) - summary = await self.llm_predictor.apredict( + summary = await self.llm.apredict( PromptTemplate(template=self.prompt_template), context_str=context_str ) diff --git a/llama_index/finetuning/cross_encoders/dataset_gen.py b/llama_index/finetuning/cross_encoders/dataset_gen.py index 3abb04f383..a594c221ab 100644 --- a/llama_index/finetuning/cross_encoders/dataset_gen.py +++ b/llama_index/finetuning/cross_encoders/dataset_gen.py @@ -8,7 +8,7 @@ from tqdm.auto import tqdm from llama_index import VectorStoreIndex from llama_index.llms import ChatMessage, OpenAI -from llama_index.llms.base import LLM +from llama_index.llms.llm import LLM from llama_index.node_parser import TokenTextSplitter from llama_index.schema import Document, MetadataMode diff --git a/llama_index/finetuning/openai/base.py b/llama_index/finetuning/openai/base.py index 8d16936041..9ea0a3e785 100644 --- a/llama_index/finetuning/openai/base.py +++ b/llama_index/finetuning/openai/base.py @@ -13,7 +13,7 @@ from llama_index.callbacks import OpenAIFineTuningHandler from llama_index.finetuning.openai.validate_json import validate_json from llama_index.finetuning.types import BaseLLMFinetuneEngine from llama_index.llms import OpenAI -from llama_index.llms.base import LLM +from llama_index.llms.llm import LLM logger = logging.getLogger(__name__) diff --git a/llama_index/finetuning/types.py b/llama_index/finetuning/types.py index 730027da51..bc2106cd64 100644 --- a/llama_index/finetuning/types.py +++ b/llama_index/finetuning/types.py @@ -4,7 +4,7 @@ from abc import ABC, abstractmethod from typing import Any from llama_index.embeddings.base import BaseEmbedding -from llama_index.llms.base import LLM +from llama_index.llms.llm import LLM from llama_index.postprocessor import CohereRerank, SentenceTransformerRerank diff --git a/llama_index/indices/base.py b/llama_index/indices/base.py index c13f95471b..be79007aca 100644 --- a/llama_index/indices/base.py +++ b/llama_index/indices/base.py @@ -27,7 +27,7 @@ class BaseIndex(Generic[IS], ABC): nodes (List[Node]): List of nodes to index show_progress (bool): Whether to show tqdm progress bars. Defaults to False. service_context (ServiceContext): Service context container (contains - components like LLMPredictor, PromptHelper, etc.). + components like LLM, Embeddings, etc.). """ diff --git a/llama_index/indices/common/struct_store/base.py b/llama_index/indices/common/struct_store/base.py index 4437ec1009..b594d7dbf2 100644 --- a/llama_index/indices/common/struct_store/base.py +++ b/llama_index/indices/common/struct_store/base.py @@ -6,7 +6,7 @@ from typing import Any, Callable, Dict, List, Optional, Sequence, cast from llama_index.callbacks.schema import CBEventType, EventPayload from llama_index.data_structs.table import StructDatapoint -from llama_index.llm_predictor.base import BaseLLMPredictor +from llama_index.llm_predictor.base import LLMPredictorType from llama_index.node_parser.interface import TextSplitter from llama_index.prompts import BasePromptTemplate from llama_index.prompts.default_prompt_selectors import ( @@ -135,12 +135,12 @@ class BaseStructDatapointExtractor: def __init__( self, - llm_predictor: BaseLLMPredictor, + llm: LLMPredictorType, schema_extract_prompt: BasePromptTemplate, output_parser: OUTPUT_PARSER_TYPE, ) -> None: """Initialize params.""" - self._llm_predictor = llm_predictor + self._llm = llm self._schema_extract_prompt = schema_extract_prompt self._output_parser = output_parser @@ -195,7 +195,7 @@ class BaseStructDatapointExtractor: logger.info(f"> Adding chunk {i}: {fmt_text_chunk}") # if embedding specified in document, pass it to the Node schema_text = self._get_schema_text() - response_str = self._llm_predictor.predict( + response_str = self._llm.predict( self._schema_extract_prompt, text=text_chunk, schema=schema_text, diff --git a/llama_index/indices/common/struct_store/sql.py b/llama_index/indices/common/struct_store/sql.py index 939bc12a43..781a337757 100644 --- a/llama_index/indices/common/struct_store/sql.py +++ b/llama_index/indices/common/struct_store/sql.py @@ -9,7 +9,7 @@ from llama_index.indices.common.struct_store.base import ( OUTPUT_PARSER_TYPE, BaseStructDatapointExtractor, ) -from llama_index.llm_predictor.base import BaseLLMPredictor +from llama_index.llm_predictor.base import LLMPredictorType from llama_index.prompts import BasePromptTemplate from llama_index.utilities.sql_wrapper import SQLDatabase @@ -19,7 +19,7 @@ class SQLStructDatapointExtractor(BaseStructDatapointExtractor): def __init__( self, - llm_predictor: BaseLLMPredictor, + llm: LLMPredictorType, schema_extract_prompt: BasePromptTemplate, output_parser: OUTPUT_PARSER_TYPE, sql_database: SQLDatabase, @@ -28,7 +28,7 @@ class SQLStructDatapointExtractor(BaseStructDatapointExtractor): ref_doc_id_column: Optional[str] = None, ) -> None: """Initialize params.""" - super().__init__(llm_predictor, schema_extract_prompt, output_parser) + super().__init__(llm, schema_extract_prompt, output_parser) self._sql_database = sql_database # currently the user must specify a table info if table_name is None and table is None: diff --git a/llama_index/indices/common_tree/base.py b/llama_index/indices/common_tree/base.py index f43986b1c0..b0a7acb663 100644 --- a/llama_index/indices/common_tree/base.py +++ b/llama_index/indices/common_tree/base.py @@ -149,7 +149,7 @@ class GPTTreeIndexBuilder: ) as event: if self._use_async: tasks = [ - self._service_context.llm_predictor.apredict( + self._service_context.llm.apredict( self.summary_prompt, context_str=text_chunk ) for text_chunk in text_chunks @@ -167,7 +167,7 @@ class GPTTreeIndexBuilder: desc="Generating summaries", ) summaries = [ - self._service_context.llm_predictor.predict( + self._service_context.llm.predict( self.summary_prompt, context_str=text_chunk ) for text_chunk in text_chunks_progress @@ -217,7 +217,7 @@ class GPTTreeIndexBuilder: desc="Generating summaries", ) tasks = [ - self._service_context.llm_predictor.apredict( + self._service_context.llm.apredict( self.summary_prompt, context_str=text_chunk ) for text_chunk in text_chunks_progress diff --git a/llama_index/indices/document_summary/retrievers.py b/llama_index/indices/document_summary/retrievers.py index 5c17522165..fda5c8a845 100644 --- a/llama_index/indices/document_summary/retrievers.py +++ b/llama_index/indices/document_summary/retrievers.py @@ -80,7 +80,7 @@ class DocumentSummaryIndexLLMRetriever(BaseRetriever): query_str = query_bundle.query_str fmt_batch_str = self._format_node_batch_fn(summary_nodes) # call each batch independently - raw_response = self._service_context.llm_predictor.predict( + raw_response = self._service_context.llm.predict( self._choice_select_prompt, context_str=fmt_batch_str, query_str=query_str, diff --git a/llama_index/indices/keyword_table/base.py b/llama_index/indices/keyword_table/base.py index fb6b17156b..02a031d674 100644 --- a/llama_index/indices/keyword_table/base.py +++ b/llama_index/indices/keyword_table/base.py @@ -219,7 +219,7 @@ class KeywordTableIndex(BaseKeywordTableIndex): def _extract_keywords(self, text: str) -> Set[str]: """Extract keywords from text.""" - response = self._service_context.llm_predictor.predict( + response = self._service_context.llm.predict( self.keyword_extract_template, text=text, ) @@ -227,7 +227,7 @@ class KeywordTableIndex(BaseKeywordTableIndex): async def _async_extract_keywords(self, text: str) -> Set[str]: """Extract keywords from text.""" - response = await self._service_context.llm_predictor.apredict( + response = await self._service_context.llm.apredict( self.keyword_extract_template, text=text, ) diff --git a/llama_index/indices/keyword_table/retrievers.py b/llama_index/indices/keyword_table/retrievers.py index af216c3550..0d687b2fea 100644 --- a/llama_index/indices/keyword_table/retrievers.py +++ b/llama_index/indices/keyword_table/retrievers.py @@ -118,7 +118,7 @@ class KeywordTableGPTRetriever(BaseKeywordTableRetriever): def _get_keywords(self, query_str: str) -> List[str]: """Extract keywords.""" - response = self._service_context.llm_predictor.predict( + response = self._service_context.llm.predict( self.query_keyword_extract_template, max_keywords=self.max_keywords_per_query, question=query_str, diff --git a/llama_index/indices/knowledge_graph/base.py b/llama_index/indices/knowledge_graph/base.py index faccad7833..00cc76d836 100644 --- a/llama_index/indices/knowledge_graph/base.py +++ b/llama_index/indices/knowledge_graph/base.py @@ -119,7 +119,7 @@ class KnowledgeGraphIndex(BaseIndex[KG]): def _llm_extract_triplets(self, text: str) -> List[Tuple[str, str, str]]: """Extract keywords from text.""" - response = self._service_context.llm_predictor.predict( + response = self._service_context.llm.predict( self.kg_triple_extract_template, text=text, ) diff --git a/llama_index/indices/knowledge_graph/retrievers.py b/llama_index/indices/knowledge_graph/retrievers.py index 72fe5efe56..9ac9cc064e 100644 --- a/llama_index/indices/knowledge_graph/retrievers.py +++ b/llama_index/indices/knowledge_graph/retrievers.py @@ -124,7 +124,7 @@ class KGTableRetriever(BaseRetriever): def _get_keywords(self, query_str: str) -> List[str]: """Extract keywords.""" - response = self._service_context.llm_predictor.predict( + response = self._service_context.llm.predict( self.query_keyword_extract_template, max_keywords=self.max_keywords_per_query, question=query_str, @@ -524,7 +524,7 @@ class KnowledgeGraphRAGRetriever(BaseRetriever): if handle_fn is not None: enitities_fn = handle_fn(query_str) if handle_llm_prompt_template is not None: - response = self._service_context.llm_predictor.predict( + response = self._service_context.llm.predict( handle_llm_prompt_template, max_keywords=max_items, question=query_str, @@ -574,7 +574,7 @@ class KnowledgeGraphRAGRetriever(BaseRetriever): if handle_fn is not None: enitities_fn = handle_fn(query_str) if handle_llm_prompt_template is not None: - response = await self._service_context.llm_predictor.apredict( + response = await self._service_context.llm.apredict( handle_llm_prompt_template, max_keywords=max_items, question=query_str, diff --git a/llama_index/indices/list/retrievers.py b/llama_index/indices/list/retrievers.py index c3f6dd6e96..4f92ee0763 100644 --- a/llama_index/indices/list/retrievers.py +++ b/llama_index/indices/list/retrievers.py @@ -178,7 +178,7 @@ class SummaryIndexLLMRetriever(BaseRetriever): query_str = query_bundle.query_str fmt_batch_str = self._format_node_batch_fn(nodes_batch) # call each batch independently - raw_response = self._service_context.llm_predictor.predict( + raw_response = self._service_context.llm.predict( self._choice_select_prompt, context_str=fmt_batch_str, query_str=query_str, diff --git a/llama_index/indices/prompt_helper.py b/llama_index/indices/prompt_helper.py index aedab1ba52..f2c5ca9efb 100644 --- a/llama_index/indices/prompt_helper.py +++ b/llama_index/indices/prompt_helper.py @@ -16,7 +16,8 @@ from typing import Callable, List, Optional, Sequence from llama_index.bridge.pydantic import Field, PrivateAttr from llama_index.constants import DEFAULT_CONTEXT_WINDOW, DEFAULT_NUM_OUTPUTS from llama_index.llm_predictor.base import LLMMetadata -from llama_index.llms.base import LLM, ChatMessage +from llama_index.llms.llm import LLM +from llama_index.llms.types import ChatMessage from llama_index.node_parser.text.token import TokenTextSplitter from llama_index.node_parser.text.utils import truncate_text from llama_index.prompts import ( diff --git a/llama_index/indices/query/query_transform/base.py b/llama_index/indices/query/query_transform/base.py index 6313cde6ee..bb4ee668fe 100644 --- a/llama_index/indices/query/query_transform/base.py +++ b/llama_index/indices/query/query_transform/base.py @@ -12,8 +12,8 @@ from llama_index.indices.query.query_transform.prompts import ( ImageOutputQueryTransformPrompt, StepDecomposeQueryTransformPrompt, ) -from llama_index.llm_predictor import LLMPredictor -from llama_index.llm_predictor.base import BaseLLMPredictor +from llama_index.llm_predictor.base import LLMPredictorType +from llama_index.llms.utils import resolve_llm from llama_index.prompts import BasePromptTemplate from llama_index.prompts.default_prompts import DEFAULT_HYDE_PROMPT from llama_index.prompts.mixin import PromptDictType, PromptMixin, PromptMixinType @@ -98,14 +98,14 @@ class HyDEQueryTransform(BaseQueryTransform): def __init__( self, - llm_predictor: Optional[BaseLLMPredictor] = None, + llm: Optional[LLMPredictorType] = None, hyde_prompt: Optional[BasePromptTemplate] = None, include_original: bool = True, ) -> None: """Initialize HyDEQueryTransform. Args: - llm_predictor (Optional[LLMPredictor]): LLM for generating + llm_predictor (Optional[LLM]): LLM for generating hypothetical documents hyde_prompt (Optional[BasePromptTemplate]): Custom prompt for HyDE include_original (bool): Whether to include original query @@ -113,7 +113,7 @@ class HyDEQueryTransform(BaseQueryTransform): """ super().__init__() - self._llm_predictor = llm_predictor or LLMPredictor() + self._llm = llm or resolve_llm("default") self._hyde_prompt = hyde_prompt or DEFAULT_HYDE_PROMPT self._include_original = include_original @@ -130,9 +130,7 @@ class HyDEQueryTransform(BaseQueryTransform): """Run query transform.""" # TODO: support generating multiple hypothetical docs query_str = query_bundle.query_str - hypothetical_doc = self._llm_predictor.predict( - self._hyde_prompt, context_str=query_str - ) + hypothetical_doc = self._llm.predict(self._hyde_prompt, context_str=query_str) embedding_strs = [hypothetical_doc] if self._include_original: embedding_strs.extend(query_bundle.embedding_strs) @@ -149,20 +147,20 @@ class DecomposeQueryTransform(BaseQueryTransform): Performs a single step transformation. Args: - llm_predictor (Optional[LLMPredictor]): LLM for generating + llm_predictor (Optional[LLM]): LLM for generating hypothetical documents """ def __init__( self, - llm_predictor: Optional[BaseLLMPredictor] = None, + llm: Optional[LLMPredictorType] = None, decompose_query_prompt: Optional[DecomposeQueryTransformPrompt] = None, verbose: bool = False, ) -> None: """Init params.""" super().__init__() - self._llm_predictor = llm_predictor or LLMPredictor() + self._llm = llm or resolve_llm("default") self._decompose_query_prompt = ( decompose_query_prompt or DEFAULT_DECOMPOSE_QUERY_TRANSFORM_PROMPT ) @@ -185,7 +183,7 @@ class DecomposeQueryTransform(BaseQueryTransform): # given the text from the index, we can use the query bundle to generate # a new query bundle query_str = query_bundle.query_str - new_query_str = self._llm_predictor.predict( + new_query_str = self._llm.predict( self._decompose_query_prompt, query_str=query_str, context_str=index_summary, @@ -251,20 +249,20 @@ class StepDecomposeQueryTransform(BaseQueryTransform): NOTE: doesn't work yet. Args: - llm_predictor (Optional[LLMPredictor]): LLM for generating + llm_predictor (Optional[LLM]): LLM for generating hypothetical documents """ def __init__( self, - llm_predictor: Optional[BaseLLMPredictor] = None, + llm: Optional[LLMPredictorType] = None, step_decompose_query_prompt: Optional[StepDecomposeQueryTransformPrompt] = None, verbose: bool = False, ) -> None: """Init params.""" super().__init__() - self._llm_predictor = llm_predictor or LLMPredictor() + self._llm = llm or resolve_llm("default") self._step_decompose_query_prompt = ( step_decompose_query_prompt or DEFAULT_STEP_DECOMPOSE_QUERY_TRANSFORM_PROMPT ) @@ -291,7 +289,7 @@ class StepDecomposeQueryTransform(BaseQueryTransform): # given the text from the index, we can use the query bundle to generate # a new query bundle query_str = query_bundle.query_str - new_query_str = self._llm_predictor.predict( + new_query_str = self._llm.predict( self._step_decompose_query_prompt, prev_reasoning=fmt_prev_reasoning, query_str=query_str, diff --git a/llama_index/indices/query/query_transform/feedback_transform.py b/llama_index/indices/query/query_transform/feedback_transform.py index 0e8342b054..950f284514 100644 --- a/llama_index/indices/query/query_transform/feedback_transform.py +++ b/llama_index/indices/query/query_transform/feedback_transform.py @@ -3,8 +3,8 @@ from typing import Dict, Optional from llama_index.evaluation.base import Evaluation from llama_index.indices.query.query_transform.base import BaseQueryTransform -from llama_index.llm_predictor import LLMPredictor -from llama_index.llm_predictor.base import BaseLLMPredictor +from llama_index.llm_predictor.base import LLMPredictorType +from llama_index.llms.utils import resolve_llm from llama_index.prompts.base import BasePromptTemplate, PromptTemplate from llama_index.prompts.mixin import PromptDictType from llama_index.schema import QueryBundle @@ -30,7 +30,7 @@ class FeedbackQueryTransformation(BaseQueryTransform): Args: eval(Evaluation): An evaluation object. - llm_predictor(BaseLLMPredictor): An LLM predictor. + llm(LLM): An LLM. resynthesize_query(bool): Whether to resynthesize the query. resynthesis_prompt(BasePromptTemplate): A prompt for resynthesizing the query. @@ -38,12 +38,12 @@ class FeedbackQueryTransformation(BaseQueryTransform): def __init__( self, - llm_predictor: Optional[BaseLLMPredictor] = None, + llm: Optional[LLMPredictorType] = None, resynthesize_query: bool = False, resynthesis_prompt: Optional[BasePromptTemplate] = None, ) -> None: super().__init__() - self.llm_predictor = llm_predictor or LLMPredictor() + self.llm = llm or resolve_llm("default") self.should_resynthesize_query = resynthesize_query self.resynthesis_prompt = resynthesis_prompt or DEFAULT_RESYNTHESIS_PROMPT @@ -106,7 +106,7 @@ class FeedbackQueryTransformation(BaseQueryTransform): if feedback is None: return query_str else: - new_query_str = self.llm_predictor.predict( + new_query_str = self.llm.predict( self.resynthesis_prompt, query_str=query_str, response=response, diff --git a/llama_index/indices/struct_store/json_query.py b/llama_index/indices/struct_store/json_query.py index 943d6b68ff..353aff77ed 100644 --- a/llama_index/indices/struct_store/json_query.py +++ b/llama_index/indices/struct_store/json_query.py @@ -136,7 +136,7 @@ class JSONQueryEngine(BaseQueryEngine): """Answer a query.""" schema = self._get_schema_context() - json_path_response_str = self._service_context.llm_predictor.predict( + json_path_response_str = self._service_context.llm.predict( self._json_path_prompt, schema=schema, query_str=query_bundle.query_str, @@ -157,7 +157,7 @@ class JSONQueryEngine(BaseQueryEngine): print_text(f"> JSONPath Output: {json_path_output}\n") if self._synthesize_response: - response_str = self._service_context.llm_predictor.predict( + response_str = self._service_context.llm.predict( self._response_synthesis_prompt, query_str=query_bundle.query_str, json_schema=self._json_schema, @@ -176,7 +176,7 @@ class JSONQueryEngine(BaseQueryEngine): async def _aquery(self, query_bundle: QueryBundle) -> Response: schema = self._get_schema_context() - json_path_response_str = await self._service_context.llm_predictor.apredict( + json_path_response_str = await self._service_context.llm.apredict( self._json_path_prompt, schema=schema, query_str=query_bundle.query_str, @@ -197,7 +197,7 @@ class JSONQueryEngine(BaseQueryEngine): print_text(f"> JSONPath Output: {json_path_output}\n") if self._synthesize_response: - response_str = await self._service_context.llm_predictor.apredict( + response_str = await self._service_context.llm.apredict( self._response_synthesis_prompt, query_str=query_bundle.query_str, json_schema=self._json_schema, diff --git a/llama_index/indices/struct_store/sql.py b/llama_index/indices/struct_store/sql.py index 32ca4425a0..f59127669c 100644 --- a/llama_index/indices/struct_store/sql.py +++ b/llama_index/indices/struct_store/sql.py @@ -107,7 +107,7 @@ class SQLStructStoreIndex(BaseStructStoreIndex[SQLStructTable]): return index_struct else: data_extractor = SQLStructDatapointExtractor( - self._service_context.llm_predictor, + self._service_context.llm, self.schema_extract_prompt, self.output_parser, self.sql_database, @@ -127,7 +127,7 @@ class SQLStructStoreIndex(BaseStructStoreIndex[SQLStructTable]): def _insert(self, nodes: Sequence[BaseNode], **insert_kwargs: Any) -> None: """Insert a document.""" data_extractor = SQLStructDatapointExtractor( - self._service_context.llm_predictor, + self._service_context.llm, self.schema_extract_prompt, self.output_parser, self.sql_database, diff --git a/llama_index/indices/struct_store/sql_query.py b/llama_index/indices/struct_store/sql_query.py index a543d18d49..2b57021c15 100644 --- a/llama_index/indices/struct_store/sql_query.py +++ b/llama_index/indices/struct_store/sql_query.py @@ -192,7 +192,7 @@ class NLStructStoreQueryEngine(BaseQueryEngine): table_desc_str = self._get_table_context(query_bundle) logger.info(f"> Table desc str: {table_desc_str}") - response_str = self._service_context.llm_predictor.predict( + response_str = self._service_context.llm.predict( self._text_to_sql_prompt, query_str=query_bundle.query_str, schema=table_desc_str, @@ -207,7 +207,7 @@ class NLStructStoreQueryEngine(BaseQueryEngine): metadata["sql_query"] = sql_query_str if self._synthesize_response: - response_str = self._service_context.llm_predictor.predict( + response_str = self._service_context.llm.predict( self._response_synthesis_prompt, query_str=query_bundle.query_str, sql_query=sql_query_str, @@ -223,7 +223,7 @@ class NLStructStoreQueryEngine(BaseQueryEngine): table_desc_str = self._get_table_context(query_bundle) logger.info(f"> Table desc str: {table_desc_str}") - response_str = await self._service_context.llm_predictor.apredict( + response_str = await self._service_context.llm.apredict( self._text_to_sql_prompt, query_str=query_bundle.query_str, schema=table_desc_str, diff --git a/llama_index/indices/struct_store/sql_retriever.py b/llama_index/indices/struct_store/sql_retriever.py index 6ae50a0f5b..68ebfc66ef 100644 --- a/llama_index/indices/struct_store/sql_retriever.py +++ b/llama_index/indices/struct_store/sql_retriever.py @@ -265,7 +265,7 @@ class NLSQLRetriever(BaseRetriever, PromptMixin): table_desc_str = self._get_table_context(query_bundle) logger.info(f"> Table desc str: {table_desc_str}") - response_str = self._service_context.llm_predictor.predict( + response_str = self._service_context.llm.predict( self._text_to_sql_prompt, query_str=query_bundle.query_str, schema=table_desc_str, @@ -303,7 +303,7 @@ class NLSQLRetriever(BaseRetriever, PromptMixin): table_desc_str = self._get_table_context(query_bundle) logger.info(f"> Table desc str: {table_desc_str}") - response_str = await self._service_context.llm_predictor.apredict( + response_str = await self._service_context.llm.apredict( self._text_to_sql_prompt, query_str=query_bundle.query_str, schema=table_desc_str, diff --git a/llama_index/indices/tree/inserter.py b/llama_index/indices/tree/inserter.py index 1e8eb526e2..9b99b0a99d 100644 --- a/llama_index/indices/tree/inserter.py +++ b/llama_index/indices/tree/inserter.py @@ -75,7 +75,7 @@ class TreeIndexInserter: ) text_chunk1 = "\n".join(truncated_chunks) - summary1 = self._service_context.llm_predictor.predict( + summary1 = self._service_context.llm.predict( self.summary_prompt, context_str=text_chunk1 ) node1 = TextNode(text=summary1) @@ -88,7 +88,7 @@ class TreeIndexInserter: ], ) text_chunk2 = "\n".join(truncated_chunks) - summary2 = self._service_context.llm_predictor.predict( + summary2 = self._service_context.llm.predict( self.summary_prompt, context_str=text_chunk2 ) node2 = TextNode(text=summary2) @@ -134,7 +134,7 @@ class TreeIndexInserter: numbered_text = get_numbered_text_from_nodes( cur_graph_node_list, text_splitter=text_splitter ) - response = self._service_context.llm_predictor.predict( + response = self._service_context.llm.predict( self.insert_prompt, new_chunk_text=node.get_content(metadata_mode=MetadataMode.LLM), num_chunks=len(cur_graph_node_list), @@ -166,7 +166,7 @@ class TreeIndexInserter: ], ) text_chunk = "\n".join(truncated_chunks) - new_summary = self._service_context.llm_predictor.predict( + new_summary = self._service_context.llm.predict( self.summary_prompt, context_str=text_chunk ) diff --git a/llama_index/indices/tree/select_leaf_retriever.py b/llama_index/indices/tree/select_leaf_retriever.py index 8606f61499..a61a3e5ae9 100644 --- a/llama_index/indices/tree/select_leaf_retriever.py +++ b/llama_index/indices/tree/select_leaf_retriever.py @@ -129,7 +129,7 @@ class TreeSelectLeafRetriever(BaseRetriever): return cur_response else: context_msg = selected_node.get_content(metadata_mode=MetadataMode.LLM) - cur_response = self._service_context.llm_predictor.predict( + cur_response = self._service_context.llm.predict( self._refine_template, query_str=query_str, existing_answer=prev_response, @@ -172,7 +172,7 @@ class TreeSelectLeafRetriever(BaseRetriever): cur_node_list, text_splitter=text_splitter ) - response = self._service_context.llm_predictor.predict( + response = self._service_context.llm.predict( query_template, context_list=numbered_node_text, ) @@ -193,7 +193,7 @@ class TreeSelectLeafRetriever(BaseRetriever): cur_node_list, text_splitter=text_splitter ) - response = self._service_context.llm_predictor.predict( + response = self._service_context.llm.predict( query_template_multiple, context_list=numbered_node_text, ) @@ -290,7 +290,7 @@ class TreeSelectLeafRetriever(BaseRetriever): cur_node_list, text_splitter=text_splitter ) - response = self._service_context.llm_predictor.predict( + response = self._service_context.llm.predict( query_template, context_list=numbered_node_text, ) @@ -311,7 +311,7 @@ class TreeSelectLeafRetriever(BaseRetriever): cur_node_list, text_splitter=text_splitter ) - response = self._service_context.llm_predictor.predict( + response = self._service_context.llm.predict( query_template_multiple, context_list=numbered_node_text, ) diff --git a/llama_index/indices/vector_store/retrievers/auto_retriever/auto_retriever.py b/llama_index/indices/vector_store/retrievers/auto_retriever/auto_retriever.py index 4e25b5ca86..e8e8e33789 100644 --- a/llama_index/indices/vector_store/retrievers/auto_retriever/auto_retriever.py +++ b/llama_index/indices/vector_store/retrievers/auto_retriever/auto_retriever.py @@ -40,7 +40,7 @@ class VectorIndexAutoRetriever(BaseRetriever): parameters. prompt_template_str: custom prompt template string for LLM. Uses default template string if None. - service_context: service context containing reference to LLMPredictor. + service_context: service context containing reference to an LLM. Uses service context from index be default if None. similarity_top_k (int): number of top k results to return. max_top_k (int): @@ -89,7 +89,7 @@ class VectorIndexAutoRetriever(BaseRetriever): schema_str = VectorStoreQuerySpec.schema_json(indent=4) # call LLM - output = self._service_context.llm_predictor.predict( + output = self._service_context.llm.predict( self._prompt, schema_str=schema_str, info_str=info_str, diff --git a/llama_index/llm_predictor/base.py b/llama_index/llm_predictor/base.py index 79444c07d7..d807d221f6 100644 --- a/llama_index/llm_predictor/base.py +++ b/llama_index/llm_predictor/base.py @@ -3,20 +3,25 @@ import logging from abc import ABC, abstractmethod from collections import ChainMap -from typing import Any, Dict, List, Optional +from typing import Any, Dict, List, Optional, Union from typing_extensions import Self from llama_index.bridge.pydantic import BaseModel, PrivateAttr from llama_index.callbacks.base import CallbackManager from llama_index.callbacks.schema import CBEventType, EventPayload -from llama_index.llm_predictor.utils import ( +from llama_index.llms.llm import ( + LLM, astream_chat_response_to_tokens, astream_completion_response_to_tokens, stream_chat_response_to_tokens, stream_completion_response_to_tokens, ) -from llama_index.llms.base import LLM, ChatMessage, LLMMetadata, MessageRole +from llama_index.llms.types import ( + ChatMessage, + LLMMetadata, + MessageRole, +) from llama_index.llms.utils import LLMType, resolve_llm from llama_index.prompts.base import BasePromptTemplate, PromptTemplate from llama_index.schema import BaseComponent @@ -326,3 +331,6 @@ class LLMPredictor(BaseLLMPredictor): *messages, ] return messages + + +LLMPredictorType = Union[LLMPredictor, LLM] diff --git a/llama_index/llm_predictor/mock.py b/llama_index/llm_predictor/mock.py index 7ddaf99f4d..d3a971f18e 100644 --- a/llama_index/llm_predictor/mock.py +++ b/llama_index/llm_predictor/mock.py @@ -1,11 +1,14 @@ """Mock LLM Predictor.""" from typing import Any, Dict +from deprecated import deprecated + from llama_index.bridge.pydantic import Field, PrivateAttr from llama_index.callbacks.base import CallbackManager from llama_index.constants import DEFAULT_NUM_OUTPUTS from llama_index.llm_predictor.base import BaseLLMPredictor -from llama_index.llms.base import LLM, LLMMetadata +from llama_index.llms.llm import LLM +from llama_index.llms.types import LLMMetadata from llama_index.prompts.base import BasePromptTemplate from llama_index.prompts.prompt_type import PromptType from llama_index.token_counter.utils import ( @@ -82,6 +85,7 @@ def _mock_knowledge_graph_triplet_extract(prompt_args: Dict, max_triplets: int) ) +@deprecated("MockLLMPredictor is deprecated. Use MockLLM instead.") class MockLLMPredictor(BaseLLMPredictor): """Mock LLM Predictor.""" diff --git a/llama_index/llm_predictor/structured.py b/llama_index/llm_predictor/structured.py index d9f86ce202..a3e07303a6 100644 --- a/llama_index/llm_predictor/structured.py +++ b/llama_index/llm_predictor/structured.py @@ -4,6 +4,8 @@ import logging from typing import Any, Optional +from deprecated import deprecated + from llama_index.llm_predictor.base import LLMPredictor from llama_index.prompts.base import BasePromptTemplate from llama_index.types import TokenGen @@ -11,6 +13,7 @@ from llama_index.types import TokenGen logger = logging.getLogger(__name__) +@deprecated("StructuredLLMPredictor is deprecated. Use llm.structured_predict().") class StructuredLLMPredictor(LLMPredictor): """Structured LLM predictor class. diff --git a/llama_index/llm_predictor/utils.py b/llama_index/llm_predictor/utils.py deleted file mode 100644 index 6c35125fbb..0000000000 --- a/llama_index/llm_predictor/utils.py +++ /dev/null @@ -1,55 +0,0 @@ -from llama_index.llms.base import ( - ChatResponseAsyncGen, - ChatResponseGen, - CompletionResponseAsyncGen, - CompletionResponseGen, -) -from llama_index.types import TokenAsyncGen, TokenGen - - -def stream_completion_response_to_tokens( - completion_response_gen: CompletionResponseGen, -) -> TokenGen: - """Convert a stream completion response to a stream of tokens.""" - - def gen() -> TokenGen: - for response in completion_response_gen: - yield response.delta or "" - - return gen() - - -def stream_chat_response_to_tokens( - chat_response_gen: ChatResponseGen, -) -> TokenGen: - """Convert a stream completion response to a stream of tokens.""" - - def gen() -> TokenGen: - for response in chat_response_gen: - yield response.delta or "" - - return gen() - - -async def astream_completion_response_to_tokens( - completion_response_gen: CompletionResponseAsyncGen, -) -> TokenAsyncGen: - """Convert a stream completion response to a stream of tokens.""" - - async def gen() -> TokenAsyncGen: - async for response in completion_response_gen: - yield response.delta or "" - - return gen() - - -async def astream_chat_response_to_tokens( - chat_response_gen: ChatResponseAsyncGen, -) -> TokenAsyncGen: - """Convert a stream completion response to a stream of tokens.""" - - async def gen() -> TokenAsyncGen: - async for response in chat_response_gen: - yield response.delta or "" - - return gen() diff --git a/llama_index/llm_predictor/vellum/predictor.py b/llama_index/llm_predictor/vellum/predictor.py index f3ebf20631..7c3287c49b 100644 --- a/llama_index/llm_predictor/vellum/predictor.py +++ b/llama_index/llm_predictor/vellum/predictor.py @@ -2,6 +2,8 @@ from __future__ import annotations from typing import Any, Tuple, cast +from deprecated import deprecated + from llama_index.bridge.pydantic import PrivateAttr from llama_index.callbacks import CallbackManager from llama_index.callbacks.schema import CBEventType, EventPayload @@ -16,6 +18,7 @@ from llama_index.prompts import BasePromptTemplate from llama_index.types import TokenAsyncGen, TokenGen +@deprecated("VellumPredictor is deprecated and will be removed in a future release.") class VellumPredictor(BaseLLMPredictor): _callback_manager: CallbackManager = PrivateAttr(default_factory=CallbackManager) diff --git a/llama_index/llms/__init__.py b/llama_index/llms/__init__.py index 870e1cf2ff..901c0d24b1 100644 --- a/llama_index/llms/__init__.py +++ b/llama_index/llms/__init__.py @@ -2,18 +2,6 @@ from llama_index.llms.ai21 import AI21 from llama_index.llms.anthropic import Anthropic from llama_index.llms.anyscale import Anyscale from llama_index.llms.azure_openai import AzureOpenAI -from llama_index.llms.base import ( - LLM, - ChatMessage, - ChatResponse, - ChatResponseAsyncGen, - ChatResponseGen, - CompletionResponse, - CompletionResponseAsyncGen, - CompletionResponseGen, - LLMMetadata, - MessageRole, -) from llama_index.llms.bedrock import Bedrock from llama_index.llms.clarifai import Clarifai from llama_index.llms.cohere import Cohere @@ -25,6 +13,7 @@ from llama_index.llms.konko import Konko from llama_index.llms.langchain import LangChainLLM from llama_index.llms.litellm import LiteLLM from llama_index.llms.llama_cpp import LlamaCPP +from llama_index.llms.llm import LLM from llama_index.llms.localai import LOCALAI_DEFAULTS, LocalAI from llama_index.llms.mock import MockLLM from llama_index.llms.monsterapi import MonsterLLM @@ -37,6 +26,17 @@ from llama_index.llms.perplexity import Perplexity from llama_index.llms.portkey import Portkey from llama_index.llms.predibase import PredibaseLLM from llama_index.llms.replicate import Replicate +from llama_index.llms.types import ( + ChatMessage, + ChatResponse, + ChatResponseAsyncGen, + ChatResponseGen, + CompletionResponse, + CompletionResponseAsyncGen, + CompletionResponseGen, + LLMMetadata, + MessageRole, +) from llama_index.llms.vertex import Vertex from llama_index.llms.vllm import Vllm, VllmServer from llama_index.llms.watsonx import WatsonX diff --git a/llama_index/llms/ai21.py b/llama_index/llms/ai21.py index 305aadbcaf..0ed8216b69 100644 --- a/llama_index/llms/ai21.py +++ b/llama_index/llms/ai21.py @@ -1,23 +1,23 @@ -from typing import Any, Dict, Optional, Sequence +from typing import Any, Callable, Dict, Optional, Sequence from llama_index.bridge.pydantic import Field, PrivateAttr from llama_index.callbacks import CallbackManager from llama_index.llms.ai21_utils import ai21_model_to_context_size -from llama_index.llms.base import ( +from llama_index.llms.base import llm_chat_callback, llm_completion_callback +from llama_index.llms.custom import CustomLLM +from llama_index.llms.generic_utils import ( + completion_to_chat_decorator, + get_from_param_or_env, +) +from llama_index.llms.types import ( ChatMessage, ChatResponse, ChatResponseGen, CompletionResponse, CompletionResponseGen, LLMMetadata, - llm_chat_callback, - llm_completion_callback, -) -from llama_index.llms.custom import CustomLLM -from llama_index.llms.generic_utils import ( - completion_to_chat_decorator, - get_from_param_or_env, ) +from llama_index.types import BaseOutputParser, PydanticProgramMode class AI21(CustomLLM): @@ -41,6 +41,11 @@ class AI21(CustomLLM): temperature: Optional[float] = 0.1, additional_kwargs: Optional[Dict[str, Any]] = None, callback_manager: Optional[CallbackManager] = None, + system_prompt: Optional[str] = None, + messages_to_prompt: Optional[Callable[[Sequence[ChatMessage]], str]] = None, + completion_to_prompt: Optional[Callable[[str], str]] = None, + pydantic_program_mode: PydanticProgramMode = PydanticProgramMode.DEFAULT, + output_parser: Optional[BaseOutputParser] = None, ) -> None: """Initialize params.""" try: @@ -63,6 +68,11 @@ class AI21(CustomLLM): temperature=temperature, additional_kwargs=additional_kwargs, callback_manager=callback_manager, + system_prompt=system_prompt, + messages_to_prompt=messages_to_prompt, + completion_to_prompt=completion_to_prompt, + pydantic_program_mode=pydantic_program_mode, + output_parser=output_parser, ) @classmethod diff --git a/llama_index/llms/anthropic.py b/llama_index/llms/anthropic.py index 0cf10ad7c8..86ceff3a58 100644 --- a/llama_index/llms/anthropic.py +++ b/llama_index/llms/anthropic.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, Optional, Sequence +from typing import Any, Callable, Dict, Optional, Sequence from llama_index.bridge.pydantic import Field, PrivateAttr from llama_index.callbacks import CallbackManager @@ -8,7 +8,17 @@ from llama_index.llms.anthropic_utils import ( messages_to_anthropic_prompt, ) from llama_index.llms.base import ( - LLM, + llm_chat_callback, + llm_completion_callback, +) +from llama_index.llms.generic_utils import ( + achat_to_completion_decorator, + astream_chat_to_completion_decorator, + chat_to_completion_decorator, + stream_chat_to_completion_decorator, +) +from llama_index.llms.llm import LLM +from llama_index.llms.types import ( ChatMessage, ChatResponse, ChatResponseAsyncGen, @@ -18,15 +28,8 @@ from llama_index.llms.base import ( CompletionResponseGen, LLMMetadata, MessageRole, - llm_chat_callback, - llm_completion_callback, -) -from llama_index.llms.generic_utils import ( - achat_to_completion_decorator, - astream_chat_to_completion_decorator, - chat_to_completion_decorator, - stream_chat_to_completion_decorator, ) +from llama_index.types import BaseOutputParser, PydanticProgramMode DEFAULT_ANTHROPIC_MODEL = "claude-2" DEFAULT_ANTHROPIC_MAX_TOKENS = 512 @@ -73,6 +76,11 @@ class Anthropic(LLM): api_key: Optional[str] = None, additional_kwargs: Optional[Dict[str, Any]] = None, callback_manager: Optional[CallbackManager] = None, + system_prompt: Optional[str] = None, + messages_to_prompt: Optional[Callable[[Sequence[ChatMessage]], str]] = None, + completion_to_prompt: Optional[Callable[[str], str]] = None, + pydantic_program_mode: PydanticProgramMode = PydanticProgramMode.DEFAULT, + output_parser: Optional[BaseOutputParser] = None, ) -> None: try: import anthropic @@ -101,6 +109,11 @@ class Anthropic(LLM): max_retries=max_retries, model=model, callback_manager=callback_manager, + system_prompt=system_prompt, + messages_to_prompt=messages_to_prompt, + completion_to_prompt=completion_to_prompt, + pydantic_program_mode=pydantic_program_mode, + output_parser=output_parser, ) @classmethod diff --git a/llama_index/llms/anthropic_utils.py b/llama_index/llms/anthropic_utils.py index 7ee85c684d..f0904bd73a 100644 --- a/llama_index/llms/anthropic_utils.py +++ b/llama_index/llms/anthropic_utils.py @@ -1,6 +1,6 @@ from typing import Dict, Sequence -from llama_index.llms.base import ChatMessage, MessageRole +from llama_index.llms.types import ChatMessage, MessageRole HUMAN_PREFIX = "\n\nHuman:" ASSISTANT_PREFIX = "\n\nAssistant:" diff --git a/llama_index/llms/anyscale.py b/llama_index/llms/anyscale.py index 714aa86858..d9404326d1 100644 --- a/llama_index/llms/anyscale.py +++ b/llama_index/llms/anyscale.py @@ -1,15 +1,14 @@ -from typing import Any, Dict, Optional +from typing import Any, Callable, Dict, Optional, Sequence from llama_index.callbacks import CallbackManager from llama_index.constants import DEFAULT_NUM_OUTPUTS, DEFAULT_TEMPERATURE from llama_index.llms.anyscale_utils import ( anyscale_modelname_to_contextsize, ) -from llama_index.llms.base import ( - LLMMetadata, -) from llama_index.llms.generic_utils import get_from_param_or_env from llama_index.llms.openai import OpenAI +from llama_index.llms.types import ChatMessage, LLMMetadata +from llama_index.types import BaseOutputParser, PydanticProgramMode DEFAULT_API_BASE = "https://api.endpoints.anyscale.com/v1" DEFAULT_MODEL = "meta-llama/Llama-2-70b-chat-hf" @@ -26,6 +25,11 @@ class Anyscale(OpenAI): api_base: Optional[str] = DEFAULT_API_BASE, api_key: Optional[str] = None, callback_manager: Optional[CallbackManager] = None, + system_prompt: Optional[str] = None, + messages_to_prompt: Optional[Callable[[Sequence[ChatMessage]], str]] = None, + completion_to_prompt: Optional[Callable[[str], str]] = None, + pydantic_program_mode: PydanticProgramMode = PydanticProgramMode.DEFAULT, + output_parser: Optional[BaseOutputParser] = None, ) -> None: additional_kwargs = additional_kwargs or {} callback_manager = callback_manager or CallbackManager([]) @@ -42,6 +46,11 @@ class Anyscale(OpenAI): additional_kwargs=additional_kwargs, max_retries=max_retries, callback_manager=callback_manager, + system_prompt=system_prompt, + messages_to_prompt=messages_to_prompt, + completion_to_prompt=completion_to_prompt, + pydantic_program_mode=pydantic_program_mode, + output_parser=output_parser, ) @classmethod diff --git a/llama_index/llms/anyscale_utils.py b/llama_index/llms/anyscale_utils.py index 6c7953c8c2..334f2c2591 100644 --- a/llama_index/llms/anyscale_utils.py +++ b/llama_index/llms/anyscale_utils.py @@ -1,6 +1,6 @@ from typing import Any, Dict, List, Sequence -from llama_index.llms.base import ChatMessage, MessageRole +from llama_index.llms.types import ChatMessage, MessageRole LLAMA_MODELS = { "meta-llama/Llama-2-7b-chat-hf": 4096, diff --git a/llama_index/llms/azure_openai.py b/llama_index/llms/azure_openai.py index 5b7a2c2920..0ffa3d32f6 100644 --- a/llama_index/llms/azure_openai.py +++ b/llama_index/llms/azure_openai.py @@ -1,4 +1,4 @@ -from typing import Any, Dict, Optional +from typing import Any, Callable, Dict, Optional, Sequence import httpx from openai import AsyncAzureOpenAI @@ -12,6 +12,8 @@ from llama_index.llms.openai_utils import ( refresh_openai_azuread_token, resolve_from_aliases, ) +from llama_index.llms.types import ChatMessage +from llama_index.types import BaseOutputParser, PydanticProgramMode class AzureOpenAI(OpenAI): @@ -77,6 +79,12 @@ class AzureOpenAI(OpenAI): deployment: Optional[str] = None, # custom httpx client http_client: Optional[httpx.Client] = None, + # base class + system_prompt: Optional[str] = None, + messages_to_prompt: Optional[Callable[[Sequence[ChatMessage]], str]] = None, + completion_to_prompt: Optional[Callable[[str], str]] = None, + pydantic_program_mode: PydanticProgramMode = PydanticProgramMode.DEFAULT, + output_parser: Optional[BaseOutputParser] = None, **kwargs: Any, ) -> None: engine = resolve_from_aliases( @@ -109,6 +117,11 @@ class AzureOpenAI(OpenAI): use_azure_ad=use_azure_ad, api_version=api_version, callback_manager=callback_manager, + system_prompt=system_prompt, + messages_to_prompt=messages_to_prompt, + completion_to_prompt=completion_to_prompt, + pydantic_program_mode=pydantic_program_mode, + output_parser=output_parser, **kwargs, ) diff --git a/llama_index/llms/base.py b/llama_index/llms/base.py index 8aee9c90f9..734143046f 100644 --- a/llama_index/llms/base.py +++ b/llama_index/llms/base.py @@ -1,120 +1,30 @@ import asyncio from abc import abstractmethod from contextlib import contextmanager -from enum import Enum -from typing import Any, AsyncGenerator, Callable, Generator, Optional, Sequence, cast - -from llama_index.bridge.pydantic import BaseModel, Field, validator +from typing import ( + Any, + AsyncGenerator, + Callable, + Generator, + Sequence, + cast, +) + +from llama_index.bridge.pydantic import Field, validator from llama_index.callbacks import CallbackManager, CBEventType, EventPayload -from llama_index.constants import DEFAULT_CONTEXT_WINDOW, DEFAULT_NUM_OUTPUTS +from llama_index.llms.types import ( + ChatMessage, + ChatResponse, + ChatResponseAsyncGen, + ChatResponseGen, + CompletionResponse, + CompletionResponseAsyncGen, + CompletionResponseGen, + LLMMetadata, +) from llama_index.schema import BaseComponent -class MessageRole(str, Enum): - """Message role.""" - - SYSTEM = "system" - USER = "user" - ASSISTANT = "assistant" - FUNCTION = "function" - TOOL = "tool" - - -# ===== Generic Model Input - Chat ===== -class ChatMessage(BaseModel): - """Chat message.""" - - role: MessageRole = MessageRole.USER - content: Optional[Any] = "" - additional_kwargs: dict = Field(default_factory=dict) - - def __str__(self) -> str: - return f"{self.role.value}: {self.content}" - - -# ===== Generic Model Output - Chat ===== -class ChatResponse(BaseModel): - """Chat response.""" - - message: ChatMessage - raw: Optional[dict] = None - delta: Optional[str] = None - additional_kwargs: dict = Field(default_factory=dict) - - def __str__(self) -> str: - return str(self.message) - - -ChatResponseGen = Generator[ChatResponse, None, None] -ChatResponseAsyncGen = AsyncGenerator[ChatResponse, None] - - -# ===== Generic Model Output - Completion ===== -class CompletionResponse(BaseModel): - """ - Completion response. - - Fields: - text: Text content of the response if not streaming, or if streaming, - the current extent of streamed text. - additional_kwargs: Additional information on the response(i.e. token - counts, function calling information). - raw: Optional raw JSON that was parsed to populate text, if relevant. - delta: New text that just streamed in (only relevant when streaming). - """ - - text: str - additional_kwargs: dict = Field(default_factory=dict) - raw: Optional[dict] = None - delta: Optional[str] = None - - def __str__(self) -> str: - return self.text - - -CompletionResponseGen = Generator[CompletionResponse, None, None] -CompletionResponseAsyncGen = AsyncGenerator[CompletionResponse, None] - - -class LLMMetadata(BaseModel): - context_window: int = Field( - default=DEFAULT_CONTEXT_WINDOW, - description=( - "Total number of tokens the model can be input and output for one response." - ), - ) - num_output: int = Field( - default=DEFAULT_NUM_OUTPUTS, - description="Number of tokens the model can output when generating a response.", - ) - is_chat_model: bool = Field( - default=False, - description=( - "Set True if the model exposes a chat interface (i.e. can be passed a" - " sequence of messages, rather than text), like OpenAI's" - " /v1/chat/completions endpoint." - ), - ) - is_function_calling_model: bool = Field( - default=False, - # SEE: https://openai.com/blog/function-calling-and-other-api-updates - description=( - "Set True if the model supports function calling messages, similar to" - " OpenAI's function calling API. For example, converting 'Email Anya to" - " see if she wants to get coffee next Friday' to a function call like" - " `send_email(to: string, body: string)`." - ), - ) - model_name: str = Field( - default="unknown", - description=( - "The model's name used for logging, testing, and sanity checking. For some" - " models this can be automatically discerned. For other models, like" - " locally loaded models, this must be manually specified." - ), - ) - - def llm_chat_callback() -> Callable: def wrap(f: Callable) -> Callable: @contextmanager @@ -366,7 +276,7 @@ def llm_completion_callback() -> Callable: return wrap -class LLM(BaseComponent): +class BaseLLM(BaseComponent): """LLM interface.""" callback_manager: CallbackManager = Field( diff --git a/llama_index/llms/bedrock.py b/llama_index/llms/bedrock.py index e92961f33d..327fc7d485 100644 --- a/llama_index/llms/bedrock.py +++ b/llama_index/llms/bedrock.py @@ -1,18 +1,9 @@ import json -from typing import Any, Dict, Optional, Sequence +from typing import Any, Callable, Dict, Optional, Sequence from llama_index.bridge.pydantic import Field, PrivateAttr from llama_index.callbacks import CallbackManager from llama_index.llms.base import ( - LLM, - ChatMessage, - ChatResponse, - ChatResponseAsyncGen, - ChatResponseGen, - CompletionResponse, - CompletionResponseAsyncGen, - CompletionResponseGen, - LLMMetadata, llm_chat_callback, llm_completion_callback, ) @@ -26,6 +17,18 @@ from llama_index.llms.bedrock_utils import ( get_text_from_response, stream_completion_to_chat_decorator, ) +from llama_index.llms.llm import LLM +from llama_index.llms.types import ( + ChatMessage, + ChatResponse, + ChatResponseAsyncGen, + ChatResponseGen, + CompletionResponse, + CompletionResponseAsyncGen, + CompletionResponseGen, + LLMMetadata, +) +from llama_index.types import BaseOutputParser, PydanticProgramMode class Bedrock(LLM): @@ -70,6 +73,11 @@ class Bedrock(LLM): max_retries: Optional[int] = 10, additional_kwargs: Optional[Dict[str, Any]] = None, callback_manager: Optional[CallbackManager] = None, + system_prompt: Optional[str] = None, + messages_to_prompt: Optional[Callable[[Sequence[ChatMessage]], str]] = None, + completion_to_prompt: Optional[Callable[[str], str]] = None, + pydantic_program_mode: PydanticProgramMode = PydanticProgramMode.DEFAULT, + output_parser: Optional[BaseOutputParser] = None, ) -> None: if context_size is None and model not in BEDROCK_FOUNDATION_LLMS: raise ValueError( @@ -124,6 +132,11 @@ class Bedrock(LLM): max_retries=max_retries, additional_kwargs=additional_kwargs, callback_manager=callback_manager, + system_prompt=system_prompt, + messages_to_prompt=messages_to_prompt, + completion_to_prompt=completion_to_prompt, + pydantic_program_mode=pydantic_program_mode, + output_parser=output_parser, ) @classmethod @@ -136,7 +149,7 @@ class Bedrock(LLM): return LLMMetadata( context_window=self.context_size, num_output=self.max_tokens, - is_chat_model=True, + is_chat_model=False, model_name=self.model, ) diff --git a/llama_index/llms/bedrock_utils.py b/llama_index/llms/bedrock_utils.py index c6389c37e7..9bc756049c 100644 --- a/llama_index/llms/bedrock_utils.py +++ b/llama_index/llms/bedrock_utils.py @@ -9,7 +9,11 @@ from tenacity import ( wait_exponential, ) -from llama_index.llms.base import ( +from llama_index.llms.generic_utils import ( + completion_response_to_chat_response, + stream_completion_response_to_chat_response, +) +from llama_index.llms.types import ( ChatMessage, ChatResponse, ChatResponseGen, @@ -17,10 +21,6 @@ from llama_index.llms.base import ( CompletionResponseGen, MessageRole, ) -from llama_index.llms.generic_utils import ( - completion_response_to_chat_response, - stream_completion_response_to_chat_response, -) HUMAN_PREFIX = "\n\nHuman:" ASSISTANT_PREFIX = "\n\nAssistant:" diff --git a/llama_index/llms/clarifai.py b/llama_index/llms/clarifai.py index aa4c23a5f8..28e5d3dad4 100644 --- a/llama_index/llms/clarifai.py +++ b/llama_index/llms/clarifai.py @@ -1,9 +1,13 @@ -from typing import Any, Dict, Optional, Sequence +from typing import Any, Callable, Dict, Optional, Sequence from llama_index.bridge.pydantic import Field, PrivateAttr from llama_index.callbacks import CallbackManager from llama_index.llms.base import ( - LLM, + llm_chat_callback, + llm_completion_callback, +) +from llama_index.llms.llm import LLM +from llama_index.llms.types import ( ChatMessage, ChatResponse, ChatResponseAsyncGen, @@ -12,9 +16,8 @@ from llama_index.llms.base import ( CompletionResponseAsyncGen, CompletionResponseGen, LLMMetadata, - llm_chat_callback, - llm_completion_callback, ) +from llama_index.types import BaseOutputParser, PydanticProgramMode EXAMPLE_URL = "https://clarifai.com/anthropic/completion/models/claude-v2" @@ -41,6 +44,11 @@ class Clarifai(LLM): max_tokens: int = 512, additional_kwargs: Optional[Dict[str, Any]] = None, callback_manager: Optional[CallbackManager] = None, + system_prompt: Optional[str] = None, + messages_to_prompt: Optional[Callable[[Sequence[ChatMessage]], str]] = None, + completion_to_prompt: Optional[Callable[[str], str]] = None, + pydantic_program_mode: PydanticProgramMode = PydanticProgramMode.DEFAULT, + output_parser: Optional[BaseOutputParser] = None, ): try: from clarifai.client.model import Model @@ -81,6 +89,11 @@ class Clarifai(LLM): additional_kwargs=additional_kwargs, callback_manager=callback_manager, model_name=model_name, + system_prompt=system_prompt, + messages_to_prompt=messages_to_prompt, + completion_to_prompt=completion_to_prompt, + pydantic_program_mode=pydantic_program_mode, + output_parser=output_parser, ) @classmethod diff --git a/llama_index/llms/cohere.py b/llama_index/llms/cohere.py index e5de8acd40..2383a2eae0 100644 --- a/llama_index/llms/cohere.py +++ b/llama_index/llms/cohere.py @@ -1,19 +1,9 @@ import warnings -from typing import Any, Dict, Optional, Sequence +from typing import Any, Callable, Dict, Optional, Sequence from llama_index.bridge.pydantic import Field, PrivateAttr from llama_index.callbacks import CallbackManager from llama_index.llms.base import ( - LLM, - ChatMessage, - ChatResponse, - ChatResponseAsyncGen, - ChatResponseGen, - CompletionResponse, - CompletionResponseAsyncGen, - CompletionResponseGen, - LLMMetadata, - MessageRole, llm_chat_callback, llm_completion_callback, ) @@ -24,6 +14,19 @@ from llama_index.llms.cohere_utils import ( completion_with_retry, messages_to_cohere_history, ) +from llama_index.llms.llm import LLM +from llama_index.llms.types import ( + ChatMessage, + ChatResponse, + ChatResponseAsyncGen, + ChatResponseGen, + CompletionResponse, + CompletionResponseAsyncGen, + CompletionResponseGen, + LLMMetadata, + MessageRole, +) +from llama_index.types import BaseOutputParser, PydanticProgramMode class Cohere(LLM): @@ -50,6 +53,11 @@ class Cohere(LLM): api_key: Optional[str] = None, additional_kwargs: Optional[Dict[str, Any]] = None, callback_manager: Optional[CallbackManager] = None, + system_prompt: Optional[str] = None, + messages_to_prompt: Optional[Callable[[Sequence[ChatMessage]], str]] = None, + completion_to_prompt: Optional[Callable[[str], str]] = None, + pydantic_program_mode: PydanticProgramMode = PydanticProgramMode.DEFAULT, + output_parser: Optional[BaseOutputParser] = None, ) -> None: try: import cohere @@ -72,6 +80,11 @@ class Cohere(LLM): model=model, callback_manager=callback_manager, max_tokens=max_tokens, + system_prompt=system_prompt, + messages_to_prompt=messages_to_prompt, + completion_to_prompt=completion_to_prompt, + pydantic_program_mode=pydantic_program_mode, + output_parser=output_parser, ) @classmethod diff --git a/llama_index/llms/cohere_utils.py b/llama_index/llms/cohere_utils.py index 183564be20..292102f510 100644 --- a/llama_index/llms/cohere_utils.py +++ b/llama_index/llms/cohere_utils.py @@ -9,7 +9,7 @@ from tenacity import ( wait_exponential, ) -from llama_index.llms.base import ChatMessage +from llama_index.llms.types import ChatMessage COMMAND_MODELS = { "command": 4096, diff --git a/llama_index/llms/custom.py b/llama_index/llms/custom.py index 90a70dbee6..48eee2aee4 100644 --- a/llama_index/llms/custom.py +++ b/llama_index/llms/custom.py @@ -1,13 +1,6 @@ from typing import Any, Sequence from llama_index.llms.base import ( - LLM, - ChatMessage, - ChatResponse, - ChatResponseAsyncGen, - ChatResponseGen, - CompletionResponse, - CompletionResponseAsyncGen, llm_chat_callback, llm_completion_callback, ) @@ -15,13 +8,22 @@ from llama_index.llms.generic_utils import ( completion_to_chat_decorator, stream_completion_to_chat_decorator, ) +from llama_index.llms.llm import LLM +from llama_index.llms.types import ( + ChatMessage, + ChatResponse, + ChatResponseAsyncGen, + ChatResponseGen, + CompletionResponse, + CompletionResponseAsyncGen, +) class CustomLLM(LLM): """Simple abstract base class for custom LLMs. - Subclasses must implement the `__init__`, `complete`, - `stream_complete`, and `metadata` methods. + Subclasses must implement the `__init__`, `_complete`, + `_stream_complete`, and `metadata` methods. """ @llm_chat_callback() diff --git a/llama_index/llms/everlyai.py b/llama_index/llms/everlyai.py index 1ff6404b59..708b801db4 100644 --- a/llama_index/llms/everlyai.py +++ b/llama_index/llms/everlyai.py @@ -1,11 +1,12 @@ -from typing import Any, Dict, Optional +from typing import Any, Callable, Dict, Optional, Sequence from llama_index.callbacks import CallbackManager from llama_index.constants import DEFAULT_NUM_OUTPUTS, DEFAULT_TEMPERATURE -from llama_index.llms.base import LLMMetadata from llama_index.llms.everlyai_utils import everlyai_modelname_to_contextsize from llama_index.llms.generic_utils import get_from_param_or_env from llama_index.llms.openai import OpenAI +from llama_index.llms.types import ChatMessage, LLMMetadata +from llama_index.types import BaseOutputParser, PydanticProgramMode EVERLYAI_API_BASE = "https://everlyai.xyz/hosted" DEFAULT_MODEL = "meta-llama/Llama-2-7b-chat-hf" @@ -21,6 +22,11 @@ class EverlyAI(OpenAI): max_retries: int = 10, api_key: Optional[str] = None, callback_manager: Optional[CallbackManager] = None, + system_prompt: Optional[str] = None, + messages_to_prompt: Optional[Callable[[Sequence[ChatMessage]], str]] = None, + completion_to_prompt: Optional[Callable[[str], str]] = None, + pydantic_program_mode: PydanticProgramMode = PydanticProgramMode.DEFAULT, + output_parser: Optional[BaseOutputParser] = None, ) -> None: additional_kwargs = additional_kwargs or {} callback_manager = callback_manager or CallbackManager([]) @@ -36,6 +42,11 @@ class EverlyAI(OpenAI): additional_kwargs=additional_kwargs, max_retries=max_retries, callback_manager=callback_manager, + system_prompt=system_prompt, + messages_to_prompt=messages_to_prompt, + completion_to_prompt=completion_to_prompt, + pydantic_program_mode=pydantic_program_mode, + output_parser=output_parser, ) @classmethod diff --git a/llama_index/llms/generic_utils.py b/llama_index/llms/generic_utils.py index fd501dc12d..3ad12c0c2b 100644 --- a/llama_index/llms/generic_utils.py +++ b/llama_index/llms/generic_utils.py @@ -1,7 +1,7 @@ import os from typing import Any, Awaitable, Callable, List, Optional, Sequence -from llama_index.llms.base import ( +from llama_index.llms.types import ( ChatMessage, ChatResponse, ChatResponseAsyncGen, diff --git a/llama_index/llms/gradient.py b/llama_index/llms/gradient.py index cadcdcf6ef..c2091523be 100644 --- a/llama_index/llms/gradient.py +++ b/llama_index/llms/gradient.py @@ -1,17 +1,19 @@ -from typing import Any, Optional +from typing import Any, Callable, Optional, Sequence from typing_extensions import override from llama_index.bridge.pydantic import Field, PrivateAttr from llama_index.callbacks import CallbackManager from llama_index.constants import DEFAULT_NUM_OUTPUTS -from llama_index.llms import ( +from llama_index.llms.base import llm_completion_callback +from llama_index.llms.custom import CustomLLM +from llama_index.llms.types import ( + ChatMessage, CompletionResponse, CompletionResponseGen, - CustomLLM, LLMMetadata, ) -from llama_index.llms.base import llm_completion_callback +from llama_index.types import BaseOutputParser, PydanticProgramMode class _BaseGradientLLM(CustomLLM): @@ -49,6 +51,11 @@ class _BaseGradientLLM(CustomLLM): workspace_id: Optional[str] = None, callback_manager: Optional[CallbackManager] = None, is_chat_model: bool = False, + system_prompt: Optional[str] = None, + messages_to_prompt: Optional[Callable[[Sequence[ChatMessage]], str]] = None, + completion_to_prompt: Optional[Callable[[str], str]] = None, + pydantic_program_mode: PydanticProgramMode = PydanticProgramMode.DEFAULT, + output_parser: Optional[BaseOutputParser] = None, **kwargs: Any, ) -> None: super().__init__( @@ -58,6 +65,11 @@ class _BaseGradientLLM(CustomLLM): workspace_id=workspace_id, callback_manager=callback_manager, is_chat_model=is_chat_model, + system_prompt=system_prompt, + messages_to_prompt=messages_to_prompt, + completion_to_prompt=completion_to_prompt, + pydantic_program_mode=pydantic_program_mode, + output_parser=output_parser, **kwargs, ) try: diff --git a/llama_index/llms/huggingface.py b/llama_index/llms/huggingface.py index fb1fc67b8e..acd9fe9ea0 100644 --- a/llama_index/llms/huggingface.py +++ b/llama_index/llms/huggingface.py @@ -8,16 +8,7 @@ from llama_index.constants import ( DEFAULT_CONTEXT_WINDOW, DEFAULT_NUM_OUTPUTS, ) -from llama_index.llms import ChatResponseAsyncGen, CompletionResponseAsyncGen from llama_index.llms.base import ( - LLM, - ChatMessage, - ChatResponse, - ChatResponseGen, - CompletionResponse, - CompletionResponseGen, - LLMMetadata, - MessageRole, llm_chat_callback, llm_completion_callback, ) @@ -29,7 +20,19 @@ from llama_index.llms.generic_utils import ( from llama_index.llms.generic_utils import ( messages_to_prompt as generic_messages_to_prompt, ) +from llama_index.llms.types import ( + ChatMessage, + ChatResponse, + ChatResponseAsyncGen, + ChatResponseGen, + CompletionResponse, + CompletionResponseAsyncGen, + CompletionResponseGen, + LLMMetadata, + MessageRole, +) from llama_index.prompts.base import PromptTemplate +from llama_index.types import BaseOutputParser, PydanticProgramMode DEFAULT_HUGGINGFACE_MODEL = "StabilityAI/stablelm-tuned-alpha-3b" if TYPE_CHECKING: @@ -73,8 +76,8 @@ class HuggingFaceLLM(CustomLLM): "The model card on HuggingFace should specify if this is needed." ), ) - query_wrapper_prompt: str = Field( - default="{query_str}", + query_wrapper_prompt: PromptTemplate = Field( + default=PromptTemplate("{query_str}"), description=( "The query wrapper prompt, containing the query placeholder. " "The model card on HuggingFace should specify if this is needed. " @@ -129,13 +132,11 @@ class HuggingFaceLLM(CustomLLM): _model: Any = PrivateAttr() _tokenizer: Any = PrivateAttr() _stopping_criteria: Any = PrivateAttr() - _messages_to_prompt: Callable = PrivateAttr() def __init__( self, context_window: int = DEFAULT_CONTEXT_WINDOW, max_new_tokens: int = DEFAULT_NUM_OUTPUTS, - system_prompt: str = "", query_wrapper_prompt: Union[str, PromptTemplate] = "{query_str}", tokenizer_name: str = DEFAULT_HUGGINGFACE_MODEL, model_name: str = DEFAULT_HUGGINGFACE_MODEL, @@ -148,8 +149,12 @@ class HuggingFaceLLM(CustomLLM): model_kwargs: Optional[dict] = None, generate_kwargs: Optional[dict] = None, is_chat_model: Optional[bool] = False, - messages_to_prompt: Optional[Callable] = None, callback_manager: Optional[CallbackManager] = None, + system_prompt: Optional[str] = None, + messages_to_prompt: Optional[Callable[[Sequence[ChatMessage]], str]] = None, + completion_to_prompt: Optional[Callable[[str], str]] = None, + pydantic_program_mode: PydanticProgramMode = PydanticProgramMode.DEFAULT, + output_parser: Optional[BaseOutputParser] = None, ) -> None: """Initialize params.""" try: @@ -215,8 +220,8 @@ class HuggingFaceLLM(CustomLLM): self._stopping_criteria = StoppingCriteriaList([StopOnTokens()]) - if isinstance(query_wrapper_prompt, PromptTemplate): - query_wrapper_prompt = query_wrapper_prompt.template + if isinstance(query_wrapper_prompt, str): + query_wrapper_prompt = PromptTemplate(query_wrapper_prompt) self._messages_to_prompt = ( messages_to_prompt or self._tokenizer_messages_to_prompt @@ -225,7 +230,6 @@ class HuggingFaceLLM(CustomLLM): super().__init__( context_window=context_window, max_new_tokens=max_new_tokens, - system_prompt=system_prompt, query_wrapper_prompt=query_wrapper_prompt, tokenizer_name=tokenizer_name, model_name=model_name, @@ -237,6 +241,11 @@ class HuggingFaceLLM(CustomLLM): generate_kwargs=generate_kwargs or {}, is_chat_model=is_chat_model, callback_manager=callback_manager, + system_prompt=system_prompt, + messages_to_prompt=messages_to_prompt, + completion_to_prompt=completion_to_prompt, + pydantic_program_mode=pydantic_program_mode, + output_parser=output_parser, ) @classmethod @@ -386,7 +395,7 @@ def chat_messages_to_conversational_kwargs( return kwargs -class HuggingFaceInferenceAPI(LLM): +class HuggingFaceInferenceAPI(CustomLLM): """ Wrapper on the Hugging Face's Inference API. diff --git a/llama_index/llms/konko.py b/llama_index/llms/konko.py index 10af19d63f..ecb0562585 100644 --- a/llama_index/llms/konko.py +++ b/llama_index/llms/konko.py @@ -3,19 +3,7 @@ from typing import Any, Awaitable, Callable, Dict, Optional, Sequence from llama_index.bridge.pydantic import Field from llama_index.callbacks import CallbackManager from llama_index.constants import DEFAULT_NUM_OUTPUTS, DEFAULT_TEMPERATURE -from llama_index.llms.base import ( - LLM, - ChatMessage, - ChatResponse, - ChatResponseAsyncGen, - ChatResponseGen, - CompletionResponse, - CompletionResponseAsyncGen, - CompletionResponseGen, - LLMMetadata, - llm_chat_callback, - llm_completion_callback, -) +from llama_index.llms.base import llm_chat_callback, llm_completion_callback from llama_index.llms.generic_utils import ( achat_to_completion_decorator, acompletion_to_chat_decorator, @@ -35,6 +23,18 @@ from llama_index.llms.konko_utils import ( resolve_konko_credentials, to_openai_message_dicts, ) +from llama_index.llms.llm import LLM +from llama_index.llms.types import ( + ChatMessage, + ChatResponse, + ChatResponseAsyncGen, + ChatResponseGen, + CompletionResponse, + CompletionResponseAsyncGen, + CompletionResponseGen, + LLMMetadata, +) +from llama_index.types import BaseOutputParser, PydanticProgramMode DEFAULT_KONKO_MODEL = "meta-llama/Llama-2-13b-chat-hf" @@ -80,6 +80,11 @@ class Konko(LLM): api_base: Optional[str] = None, api_version: Optional[str] = None, callback_manager: Optional[CallbackManager] = None, + system_prompt: Optional[str] = None, + messages_to_prompt: Optional[Callable[[Sequence[ChatMessage]], str]] = None, + completion_to_prompt: Optional[Callable[[str], str]] = None, + pydantic_program_mode: PydanticProgramMode = PydanticProgramMode.DEFAULT, + output_parser: Optional[BaseOutputParser] = None, **kwargs: Any, ) -> None: additional_kwargs = additional_kwargs or {} @@ -110,6 +115,11 @@ class Konko(LLM): api_type=api_type, api_version=api_version, api_base=api_base, + system_prompt=system_prompt, + messages_to_prompt=messages_to_prompt, + completion_to_prompt=completion_to_prompt, + pydantic_program_mode=pydantic_program_mode, + output_parser=output_parser, **kwargs, ) diff --git a/llama_index/llms/konko_utils.py b/llama_index/llms/konko_utils.py index 6859f96990..a097aab4cb 100644 --- a/llama_index/llms/konko_utils.py +++ b/llama_index/llms/konko_utils.py @@ -11,8 +11,8 @@ from tenacity import ( ) from llama_index.bridge.pydantic import BaseModel -from llama_index.llms.base import ChatMessage from llama_index.llms.generic_utils import get_from_param_or_env +from llama_index.llms.types import ChatMessage DEFAULT_KONKO_API_TYPE = "open_ai" DEFAULT_KONKO_API_BASE = "https://api.konko.ai/v1" diff --git a/llama_index/llms/langchain.py b/llama_index/llms/langchain.py index 145f60a5f9..b0f49ae9de 100644 --- a/llama_index/llms/langchain.py +++ b/llama_index/llms/langchain.py @@ -1,13 +1,14 @@ from threading import Thread -from typing import TYPE_CHECKING, Any, Generator, Optional, Sequence +from typing import TYPE_CHECKING, Any, Callable, Generator, Optional, Sequence if TYPE_CHECKING: from langchain.base_language import BaseLanguageModel from llama_index.bridge.pydantic import PrivateAttr from llama_index.callbacks import CallbackManager -from llama_index.llms.base import ( - LLM, +from llama_index.llms.base import llm_chat_callback, llm_completion_callback +from llama_index.llms.llm import LLM +from llama_index.llms.types import ( ChatMessage, ChatResponse, ChatResponseAsyncGen, @@ -16,9 +17,8 @@ from llama_index.llms.base import ( CompletionResponseAsyncGen, CompletionResponseGen, LLMMetadata, - llm_chat_callback, - llm_completion_callback, ) +from llama_index.types import BaseOutputParser, PydanticProgramMode class LangChainLLM(LLM): @@ -30,9 +30,21 @@ class LangChainLLM(LLM): self, llm: "BaseLanguageModel", callback_manager: Optional[CallbackManager] = None, + system_prompt: Optional[str] = None, + messages_to_prompt: Optional[Callable[[Sequence[ChatMessage]], str]] = None, + completion_to_prompt: Optional[Callable[[str], str]] = None, + pydantic_program_mode: PydanticProgramMode = PydanticProgramMode.DEFAULT, + output_parser: Optional[BaseOutputParser] = None, ) -> None: self._llm = llm - super().__init__(callback_manager=callback_manager) + super().__init__( + callback_manager=callback_manager, + system_prompt=system_prompt, + messages_to_prompt=messages_to_prompt, + completion_to_prompt=completion_to_prompt, + pydantic_program_mode=pydantic_program_mode, + output_parser=output_parser, + ) @classmethod def class_name(cls) -> str: diff --git a/llama_index/llms/langchain_utils.py b/llama_index/llms/langchain_utils.py index ef47ff9a4c..accc029b08 100644 --- a/llama_index/llms/langchain_utils.py +++ b/llama_index/llms/langchain_utils.py @@ -16,8 +16,8 @@ from llama_index.bridge.langchain import ( from llama_index.bridge.langchain import BaseMessage as LCMessage from llama_index.constants import AI21_J2_CONTEXT_WINDOW, COHERE_CONTEXT_WINDOW from llama_index.llms.anyscale_utils import anyscale_modelname_to_contextsize -from llama_index.llms.base import ChatMessage, LLMMetadata, MessageRole from llama_index.llms.openai_utils import openai_modelname_to_contextsize +from llama_index.llms.types import ChatMessage, LLMMetadata, MessageRole def is_chat_model(llm: BaseLanguageModel) -> bool: diff --git a/llama_index/llms/litellm.py b/llama_index/llms/litellm.py index b0acaaa8e0..e1524c6306 100644 --- a/llama_index/llms/litellm.py +++ b/llama_index/llms/litellm.py @@ -3,19 +3,7 @@ from typing import Any, Awaitable, Callable, Dict, Optional, Sequence from llama_index.bridge.pydantic import Field from llama_index.callbacks import CallbackManager from llama_index.constants import DEFAULT_TEMPERATURE -from llama_index.llms.base import ( - LLM, - ChatMessage, - ChatResponse, - ChatResponseAsyncGen, - ChatResponseGen, - CompletionResponse, - CompletionResponseAsyncGen, - CompletionResponseGen, - LLMMetadata, - llm_chat_callback, - llm_completion_callback, -) +from llama_index.llms.base import llm_chat_callback, llm_completion_callback from llama_index.llms.generic_utils import ( achat_to_completion_decorator, acompletion_to_chat_decorator, @@ -35,6 +23,18 @@ from llama_index.llms.litellm_utils import ( to_openai_message_dicts, validate_litellm_api_key, ) +from llama_index.llms.llm import LLM +from llama_index.llms.types import ( + ChatMessage, + ChatResponse, + ChatResponseAsyncGen, + ChatResponseGen, + CompletionResponse, + CompletionResponseAsyncGen, + CompletionResponseGen, + LLMMetadata, +) +from llama_index.types import BaseOutputParser, PydanticProgramMode DEFAULT_LITELLM_MODEL = "gpt-3.5-turbo" @@ -77,6 +77,11 @@ class LiteLLM(LLM): api_type: Optional[str] = None, api_base: Optional[str] = None, callback_manager: Optional[CallbackManager] = None, + system_prompt: Optional[str] = None, + messages_to_prompt: Optional[Callable[[Sequence[ChatMessage]], str]] = None, + completion_to_prompt: Optional[Callable[[str], str]] = None, + pydantic_program_mode: PydanticProgramMode = PydanticProgramMode.DEFAULT, + output_parser: Optional[BaseOutputParser] = None, **kwargs: Any, ) -> None: if "custom_llm_provider" in kwargs: @@ -103,6 +108,11 @@ class LiteLLM(LLM): additional_kwargs=additional_kwargs, max_retries=max_retries, callback_manager=callback_manager, + system_prompt=system_prompt, + messages_to_prompt=messages_to_prompt, + completion_to_prompt=completion_to_prompt, + pydantic_program_mode=pydantic_program_mode, + output_parser=output_parser, **kwargs, ) diff --git a/llama_index/llms/litellm_utils.py b/llama_index/llms/litellm_utils.py index 1696367946..2af40dab67 100644 --- a/llama_index/llms/litellm_utils.py +++ b/llama_index/llms/litellm_utils.py @@ -11,7 +11,7 @@ from tenacity import ( ) from llama_index.bridge.pydantic import BaseModel -from llama_index.llms.base import ChatMessage +from llama_index.llms.types import ChatMessage MISSING_API_KEY_ERROR_MESSAGE = """No API key found for LLM. E.g. to use openai Please set the OPENAI_API_KEY environment variable or \ diff --git a/llama_index/llms/llama_api.py b/llama_index/llms/llama_api.py index b8d879de0a..9f7e07e137 100644 --- a/llama_index/llms/llama_api.py +++ b/llama_index/llms/llama_api.py @@ -1,24 +1,24 @@ -from typing import Any, Dict, Optional, Sequence +from typing import Any, Callable, Dict, Optional, Sequence from llama_index.bridge.pydantic import Field, PrivateAttr from llama_index.callbacks import CallbackManager from llama_index.constants import DEFAULT_NUM_OUTPUTS -from llama_index.llms.base import ( +from llama_index.llms.base import llm_chat_callback, llm_completion_callback +from llama_index.llms.custom import CustomLLM +from llama_index.llms.generic_utils import chat_to_completion_decorator +from llama_index.llms.openai_utils import ( + from_openai_message_dict, + to_openai_message_dicts, +) +from llama_index.llms.types import ( ChatMessage, ChatResponse, ChatResponseGen, CompletionResponse, CompletionResponseGen, LLMMetadata, - llm_chat_callback, - llm_completion_callback, -) -from llama_index.llms.custom import CustomLLM -from llama_index.llms.generic_utils import chat_to_completion_decorator -from llama_index.llms.openai_utils import ( - from_openai_message_dict, - to_openai_message_dicts, ) +from llama_index.types import BaseOutputParser, PydanticProgramMode class LlamaAPI(CustomLLM): @@ -39,6 +39,11 @@ class LlamaAPI(CustomLLM): additional_kwargs: Optional[Dict[str, Any]] = None, api_key: Optional[str] = None, callback_manager: Optional[CallbackManager] = None, + system_prompt: Optional[str] = None, + messages_to_prompt: Optional[Callable[[Sequence[ChatMessage]], str]] = None, + completion_to_prompt: Optional[Callable[[str], str]] = None, + pydantic_program_mode: PydanticProgramMode = PydanticProgramMode.DEFAULT, + output_parser: Optional[BaseOutputParser] = None, ) -> None: try: from llamaapi import LlamaAPI as Client @@ -56,6 +61,11 @@ class LlamaAPI(CustomLLM): max_tokens=max_tokens, additional_kwargs=additional_kwargs or {}, callback_manager=callback_manager, + system_prompt=system_prompt, + messages_to_prompt=messages_to_prompt, + completion_to_prompt=completion_to_prompt, + pydantic_program_mode=pydantic_program_mode, + output_parser=output_parser, ) @classmethod diff --git a/llama_index/llms/llama_cpp.py b/llama_index/llms/llama_cpp.py index 1ff5edec52..7ab0bd1cda 100644 --- a/llama_index/llms/llama_cpp.py +++ b/llama_index/llms/llama_cpp.py @@ -11,24 +11,21 @@ from llama_index.constants import ( DEFAULT_NUM_OUTPUTS, DEFAULT_TEMPERATURE, ) -from llama_index.llms.base import ( +from llama_index.llms.base import llm_chat_callback, llm_completion_callback +from llama_index.llms.custom import CustomLLM +from llama_index.llms.generic_utils import ( + completion_response_to_chat_response, + stream_completion_response_to_chat_response, +) +from llama_index.llms.types import ( ChatMessage, ChatResponse, ChatResponseGen, CompletionResponse, CompletionResponseGen, LLMMetadata, - llm_chat_callback, - llm_completion_callback, -) -from llama_index.llms.custom import CustomLLM -from llama_index.llms.generic_utils import ( - completion_response_to_chat_response, - stream_completion_response_to_chat_response, -) -from llama_index.llms.generic_utils import ( - messages_to_prompt as generic_messages_to_prompt, ) +from llama_index.types import BaseOutputParser, PydanticProgramMode from llama_index.utils import get_cache_dir DEFAULT_LLAMA_CPP_GGML_MODEL = ( @@ -65,12 +62,6 @@ class LlamaCPP(CustomLLM): description="The maximum number of context tokens for the model.", gt=0, ) - messages_to_prompt: Callable = Field( - description="The function to convert messages to a prompt.", exclude=True - ) - completion_to_prompt: Callable = Field( - description="The function to convert a completion to a prompt.", exclude=True - ) generate_kwargs: Dict[str, Any] = Field( default_factory=dict, description="Kwargs used for generation." ) @@ -91,12 +82,15 @@ class LlamaCPP(CustomLLM): temperature: float = DEFAULT_TEMPERATURE, max_new_tokens: int = DEFAULT_NUM_OUTPUTS, context_window: int = DEFAULT_CONTEXT_WINDOW, - messages_to_prompt: Optional[Callable] = None, - completion_to_prompt: Optional[Callable] = None, callback_manager: Optional[CallbackManager] = None, generate_kwargs: Optional[Dict[str, Any]] = None, model_kwargs: Optional[Dict[str, Any]] = None, verbose: bool = DEFAULT_LLAMA_CPP_MODEL_VERBOSITY, + system_prompt: Optional[str] = None, + messages_to_prompt: Optional[Callable[[Sequence[ChatMessage]], str]] = None, + completion_to_prompt: Optional[Callable[[str], str]] = None, + pydantic_program_mode: PydanticProgramMode = PydanticProgramMode.DEFAULT, + output_parser: Optional[BaseOutputParser] = None, ) -> None: try: from llama_cpp import Llama @@ -135,9 +129,6 @@ class LlamaCPP(CustomLLM): self._model = Llama(model_path=model_path, **model_kwargs) model_path = model_path - messages_to_prompt = messages_to_prompt or generic_messages_to_prompt - completion_to_prompt = completion_to_prompt or (lambda x: x) - generate_kwargs = generate_kwargs or {} generate_kwargs.update( {"temperature": temperature, "max_tokens": max_new_tokens} @@ -149,12 +140,15 @@ class LlamaCPP(CustomLLM): temperature=temperature, context_window=context_window, max_new_tokens=max_new_tokens, - messages_to_prompt=messages_to_prompt, - completion_to_prompt=completion_to_prompt, callback_manager=callback_manager, generate_kwargs=generate_kwargs, model_kwargs=model_kwargs, verbose=verbose, + system_prompt=system_prompt, + messages_to_prompt=messages_to_prompt, + completion_to_prompt=completion_to_prompt, + pydantic_program_mode=pydantic_program_mode, + output_parser=output_parser, ) @classmethod diff --git a/llama_index/llms/llama_utils.py b/llama_index/llms/llama_utils.py index d0ad1d1b5e..2ee0e950b7 100644 --- a/llama_index/llms/llama_utils.py +++ b/llama_index/llms/llama_utils.py @@ -1,6 +1,6 @@ from typing import List, Optional, Sequence -from llama_index.llms.base import ChatMessage, MessageRole +from llama_index.llms.types import ChatMessage, MessageRole BOS, EOS = "<s>", "</s>" B_INST, E_INST = "[INST]", "[/INST]" diff --git a/llama_index/llms/llm.py b/llama_index/llms/llm.py new file mode 100644 index 0000000000..6d8e0aa193 --- /dev/null +++ b/llama_index/llms/llm.py @@ -0,0 +1,310 @@ +from collections import ChainMap +from typing import Any, List, Optional, Protocol, Sequence, runtime_checkable + +from llama_index.bridge.pydantic import BaseModel, Field, validator +from llama_index.callbacks import CBEventType, EventPayload +from llama_index.llms.base import BaseLLM +from llama_index.llms.generic_utils import ( + messages_to_prompt as generic_messages_to_prompt, +) +from llama_index.llms.types import ( + ChatMessage, + ChatResponseAsyncGen, + ChatResponseGen, + CompletionResponseAsyncGen, + CompletionResponseGen, + MessageRole, +) +from llama_index.prompts import BasePromptTemplate, PromptTemplate +from llama_index.types import ( + BaseOutputParser, + PydanticProgramMode, + TokenAsyncGen, + TokenGen, +) + + +# NOTE: These two protocols are needed to appease mypy +@runtime_checkable +class MessagesToPromptType(Protocol): + def __call__(self, messages: Sequence[ChatMessage]) -> str: + pass + + +@runtime_checkable +class CompletionToPromptType(Protocol): + def __call__(self, prompt: str) -> str: + pass + + +def stream_completion_response_to_tokens( + completion_response_gen: CompletionResponseGen, +) -> TokenGen: + """Convert a stream completion response to a stream of tokens.""" + + def gen() -> TokenGen: + for response in completion_response_gen: + yield response.delta or "" + + return gen() + + +def stream_chat_response_to_tokens( + chat_response_gen: ChatResponseGen, +) -> TokenGen: + """Convert a stream completion response to a stream of tokens.""" + + def gen() -> TokenGen: + for response in chat_response_gen: + yield response.delta or "" + + return gen() + + +async def astream_completion_response_to_tokens( + completion_response_gen: CompletionResponseAsyncGen, +) -> TokenAsyncGen: + """Convert a stream completion response to a stream of tokens.""" + + async def gen() -> TokenAsyncGen: + async for response in completion_response_gen: + yield response.delta or "" + + return gen() + + +async def astream_chat_response_to_tokens( + chat_response_gen: ChatResponseAsyncGen, +) -> TokenAsyncGen: + """Convert a stream completion response to a stream of tokens.""" + + async def gen() -> TokenAsyncGen: + async for response in chat_response_gen: + yield response.delta or "" + + return gen() + + +class LLM(BaseLLM): + system_prompt: Optional[str] = Field(description="System prompt for LLM calls.") + messages_to_prompt: MessagesToPromptType = Field( + description="Function to convert a list of messages to an LLM prompt.", + default=generic_messages_to_prompt, + exclude=True, + ) + completion_to_prompt: CompletionToPromptType = Field( + description="Function to convert a completion to an LLM prompt.", + default=lambda x: x, + exclude=True, + ) + output_parser: Optional[BaseOutputParser] = Field( + description="Output parser to parse, validate, and correct errors programmatically.", + default=None, + exclude=True, + ) + pydantic_program_mode: PydanticProgramMode = PydanticProgramMode.DEFAULT + + # deprecated + query_wrapper_prompt: Optional[BasePromptTemplate] = Field( + description="Query wrapper prompt for LLM calls.", + default=None, + exclude=True, + ) + + @validator("messages_to_prompt", pre=True) + def set_messages_to_prompt( + cls, messages_to_prompt: Optional[MessagesToPromptType] + ) -> MessagesToPromptType: + return messages_to_prompt or generic_messages_to_prompt + + @validator("completion_to_prompt", pre=True) + def set_completion_to_prompt( + cls, completion_to_prompt: Optional[CompletionToPromptType] + ) -> CompletionToPromptType: + return completion_to_prompt or (lambda x: x) + + def _log_template_data( + self, prompt: BasePromptTemplate, **prompt_args: Any + ) -> None: + template_vars = { + k: v + for k, v in ChainMap(prompt.kwargs, prompt_args).items() + if k in prompt.template_vars + } + with self.callback_manager.event( + CBEventType.TEMPLATING, + payload={ + EventPayload.TEMPLATE: prompt.get_template(llm=self), + EventPayload.TEMPLATE_VARS: template_vars, + EventPayload.SYSTEM_PROMPT: self.system_prompt, + EventPayload.QUERY_WRAPPER_PROMPT: self.query_wrapper_prompt, + }, + ): + pass + + def _get_prompt(self, prompt: BasePromptTemplate, **prompt_args: Any) -> str: + formatted_prompt = prompt.format( + llm=self, + messages_to_prompt=self.messages_to_prompt, + completion_to_prompt=self.completion_to_prompt, + **prompt_args, + ) + if self.output_parser is not None: + formatted_prompt = self.output_parser.format(formatted_prompt) + return self._extend_prompt(formatted_prompt) + + def _get_messages( + self, prompt: BasePromptTemplate, **prompt_args: Any + ) -> List[ChatMessage]: + messages = prompt.format_messages(llm=self, **prompt_args) + if self.output_parser is not None: + messages = self.output_parser.format_messages(messages) + return self._extend_messages(messages) + + def structured_predict( + self, + output_cls: BaseModel, + prompt: PromptTemplate, + **prompt_args: Any, + ) -> BaseModel: + from llama_index.program.utils import get_program_for_llm + + program = get_program_for_llm( + output_cls, + prompt, + self, + pydantic_program_mode=self.pydantic_program_mode, + ) + + return program(**prompt_args) + + async def astructured_predict( + self, + output_cls: BaseModel, + prompt: PromptTemplate, + **prompt_args: Any, + ) -> BaseModel: + from llama_index.program.utils import get_program_for_llm + + program = get_program_for_llm( + output_cls, + prompt, + self, + pydantic_program_mode=self.pydantic_program_mode, + ) + + return await program.acall(**prompt_args) + + def _parse_output(self, output: str) -> str: + if self.output_parser is not None: + return str(self.output_parser.parse(output)) + + return output + + def predict( + self, + prompt: BasePromptTemplate, + **prompt_args: Any, + ) -> str: + """Predict.""" + self._log_template_data(prompt, **prompt_args) + + if self.metadata.is_chat_model: + messages = self._get_messages(prompt, **prompt_args) + chat_response = self.chat(messages) + output = chat_response.message.content or "" + else: + formatted_prompt = self._get_prompt(prompt, **prompt_args) + response = self.complete(formatted_prompt) + output = response.text + + return self._parse_output(output) + + def stream( + self, + prompt: BasePromptTemplate, + **prompt_args: Any, + ) -> TokenGen: + """Stream.""" + self._log_template_data(prompt, **prompt_args) + + if self.metadata.is_chat_model: + messages = self._get_messages(prompt, **prompt_args) + chat_response = self.stream_chat(messages) + stream_tokens = stream_chat_response_to_tokens(chat_response) + else: + formatted_prompt = self._get_prompt(prompt, **prompt_args) + stream_response = self.stream_complete(formatted_prompt) + stream_tokens = stream_completion_response_to_tokens(stream_response) + + if prompt.output_parser is not None or self.output_parser is not None: + raise NotImplementedError("Output parser is not supported for streaming.") + + return stream_tokens + + async def apredict( + self, + prompt: BasePromptTemplate, + **prompt_args: Any, + ) -> str: + """Async predict.""" + self._log_template_data(prompt, **prompt_args) + + if self.metadata.is_chat_model: + messages = self._get_messages(prompt, **prompt_args) + chat_response = await self.achat(messages) + output = chat_response.message.content or "" + else: + formatted_prompt = self._get_prompt(prompt, **prompt_args) + response = await self.acomplete(formatted_prompt) + output = response.text + + return self._parse_output(output) + + async def astream( + self, + prompt: BasePromptTemplate, + **prompt_args: Any, + ) -> TokenAsyncGen: + """Async stream.""" + self._log_template_data(prompt, **prompt_args) + + if self.metadata.is_chat_model: + messages = self._get_messages(prompt, **prompt_args) + chat_response = await self.astream_chat(messages) + stream_tokens = await astream_chat_response_to_tokens(chat_response) + else: + formatted_prompt = self._get_prompt(prompt, **prompt_args) + stream_response = await self.astream_complete(formatted_prompt) + stream_tokens = await astream_completion_response_to_tokens(stream_response) + + if prompt.output_parser is not None or self.output_parser is not None: + raise NotImplementedError("Output parser is not supported for streaming.") + + return stream_tokens + + def _extend_prompt( + self, + formatted_prompt: str, + ) -> str: + """Add system and query wrapper prompts to base prompt.""" + extended_prompt = formatted_prompt + + if self.system_prompt: + extended_prompt = self.system_prompt + "\n\n" + extended_prompt + + if self.query_wrapper_prompt: + extended_prompt = self.query_wrapper_prompt.format( + query_str=extended_prompt + ) + + return extended_prompt + + def _extend_messages(self, messages: List[ChatMessage]) -> List[ChatMessage]: + """Add system prompt to chat message list.""" + if self.system_prompt: + messages = [ + ChatMessage(role=MessageRole.SYSTEM, content=self.system_prompt), + *messages, + ] + return messages diff --git a/llama_index/llms/loading.py b/llama_index/llms/loading.py index 2cbb9e74ad..a91e352c2d 100644 --- a/llama_index/llms/loading.py +++ b/llama_index/llms/loading.py @@ -1,12 +1,12 @@ from typing import Dict, Type -from llama_index.llms.base import LLM from llama_index.llms.bedrock import Bedrock from llama_index.llms.custom import CustomLLM from llama_index.llms.gradient import GradientBaseModelLLM, GradientModelAdapterLLM from llama_index.llms.huggingface import HuggingFaceLLM from llama_index.llms.langchain import LangChainLLM from llama_index.llms.llama_cpp import LlamaCPP +from llama_index.llms.llm import LLM from llama_index.llms.mock import MockLLM from llama_index.llms.openai import OpenAI from llama_index.llms.palm import PaLM diff --git a/llama_index/llms/localai.py b/llama_index/llms/localai.py index d3baedf551..9f307254a6 100644 --- a/llama_index/llms/localai.py +++ b/llama_index/llms/localai.py @@ -7,14 +7,15 @@ Source: https://github.com/go-skynet/LocalAI import warnings from types import MappingProxyType -from typing import Any, Dict, Mapping, Optional +from typing import Any, Callable, Dict, Mapping, Optional, Sequence from llama_index.bridge.pydantic import Field from llama_index.constants import DEFAULT_CONTEXT_WINDOW -from llama_index.llms.base import LLMMetadata from llama_index.llms.openai import OpenAI from llama_index.llms.openai_like import OpenAILike from llama_index.llms.openai_utils import is_function_calling_model +from llama_index.llms.types import ChatMessage, LLMMetadata +from llama_index.types import BaseOutputParser, PydanticProgramMode # Use these as kwargs for OpenAILike to connect to LocalAIs DEFAULT_LOCALAI_PORT = 8080 @@ -47,9 +48,23 @@ class LocalAI(OpenAI): self, api_key: Optional[str] = LOCALAI_DEFAULTS["api_key"], api_base: Optional[str] = LOCALAI_DEFAULTS["api_base"], + system_prompt: Optional[str] = None, + messages_to_prompt: Optional[Callable[[Sequence[ChatMessage]], str]] = None, + completion_to_prompt: Optional[Callable[[str], str]] = None, + pydantic_program_mode: PydanticProgramMode = PydanticProgramMode.DEFAULT, + output_parser: Optional[BaseOutputParser] = None, **kwargs: Any, ) -> None: - super().__init__(api_key=api_key, api_base=api_base, **kwargs) + super().__init__( + api_key=api_key, + api_base=api_base, + system_prompt=system_prompt, + messages_to_prompt=messages_to_prompt, + completion_to_prompt=completion_to_prompt, + pydantic_program_mode=pydantic_program_mode, + output_parser=output_parser, + **kwargs, + ) warnings.warn( ( f"{type(self).__name__} subclass is deprecated in favor of" diff --git a/llama_index/llms/mock.py b/llama_index/llms/mock.py index 41b9ffd9af..9e3cf32e2d 100644 --- a/llama_index/llms/mock.py +++ b/llama_index/llms/mock.py @@ -1,13 +1,15 @@ -from typing import Any, Optional +from typing import Any, Callable, Optional, Sequence from llama_index.callbacks import CallbackManager -from llama_index.llms.base import ( +from llama_index.llms.base import llm_completion_callback +from llama_index.llms.custom import CustomLLM +from llama_index.llms.types import ( + ChatMessage, CompletionResponse, CompletionResponseGen, LLMMetadata, - llm_completion_callback, ) -from llama_index.llms.custom import CustomLLM +from llama_index.types import PydanticProgramMode class MockLLM(CustomLLM): @@ -17,8 +19,19 @@ class MockLLM(CustomLLM): self, max_tokens: Optional[int] = None, callback_manager: Optional[CallbackManager] = None, + system_prompt: Optional[str] = None, + messages_to_prompt: Optional[Callable[[Sequence[ChatMessage]], str]] = None, + completion_to_prompt: Optional[Callable[[str], str]] = None, + pydantic_program_mode: PydanticProgramMode = PydanticProgramMode.DEFAULT, ) -> None: - super().__init__(max_tokens=max_tokens, callback_manager=callback_manager) + super().__init__( + max_tokens=max_tokens, + callback_manager=callback_manager, + system_prompt=system_prompt, + messages_to_prompt=messages_to_prompt, + completion_to_prompt=completion_to_prompt, + pydantic_program_mode=pydantic_program_mode, + ) @classmethod def class_name(cls) -> str: diff --git a/llama_index/llms/monsterapi.py b/llama_index/llms/monsterapi.py index c7f12759ee..0e21207cb0 100644 --- a/llama_index/llms/monsterapi.py +++ b/llama_index/llms/monsterapi.py @@ -3,19 +3,16 @@ from typing import Any, Callable, Dict, Optional, Sequence from llama_index.bridge.pydantic import Field, PrivateAttr from llama_index.callbacks import CallbackManager from llama_index.constants import DEFAULT_CONTEXT_WINDOW, DEFAULT_NUM_OUTPUTS -from llama_index.llms.base import ( +from llama_index.llms.base import llm_chat_callback, llm_completion_callback +from llama_index.llms.custom import CustomLLM +from llama_index.llms.types import ( ChatMessage, ChatResponse, CompletionResponse, CompletionResponseGen, LLMMetadata, - llm_chat_callback, - llm_completion_callback, -) -from llama_index.llms.custom import CustomLLM -from llama_index.llms.generic_utils import ( - messages_to_prompt as generic_messages_to_prompt, ) +from llama_index.types import BaseOutputParser, PydanticProgramMode DEFAULT_MONSTER_TEMP = 0.75 @@ -40,13 +37,6 @@ class MonsterLLM(CustomLLM): gt=0, ) - messages_to_prompt: Callable = Field( - description="The function to convert messages to a prompt.", exclude=True - ) - completion_to_prompt: Callable = Field( - description="The function to convert a completion to a prompt.", exclude=True - ) - _client: Any = PrivateAttr() def __init__( @@ -57,14 +47,14 @@ class MonsterLLM(CustomLLM): temperature: float = DEFAULT_MONSTER_TEMP, context_window: int = DEFAULT_CONTEXT_WINDOW, callback_manager: Optional[CallbackManager] = None, - messages_to_prompt: Optional[Callable] = None, - completion_to_prompt: Optional[Callable] = None, + system_prompt: Optional[str] = None, + messages_to_prompt: Optional[Callable[[Sequence[ChatMessage]], str]] = None, + completion_to_prompt: Optional[Callable[[str], str]] = None, + pydantic_program_mode: PydanticProgramMode = PydanticProgramMode.DEFAULT, + output_parser: Optional[BaseOutputParser] = None, ) -> None: self._client, available_llms = self.initialize_client(monster_api_key) - _messages_to_prompt = messages_to_prompt or generic_messages_to_prompt - _completion_to_prompt = completion_to_prompt or (lambda x: x) - # Check if provided model is supported if model not in available_llms: error_message = ( @@ -82,8 +72,11 @@ class MonsterLLM(CustomLLM): temperature=temperature, context_window=context_window, callback_manager=callback_manager, - messages_to_prompt=_messages_to_prompt, - completion_to_prompt=_completion_to_prompt, + system_prompt=system_prompt, + messages_to_prompt=messages_to_prompt, + completion_to_prompt=completion_to_prompt, + pydantic_program_mode=pydantic_program_mode, + output_parser=output_parser, ) def initialize_client(self, monster_api_key: Optional[str]) -> Any: diff --git a/llama_index/llms/ollama.py b/llama_index/llms/ollama.py index bcc37c32c9..078e2f7a3c 100644 --- a/llama_index/llms/ollama.py +++ b/llama_index/llms/ollama.py @@ -1,38 +1,35 @@ import json -from typing import Any, Callable, Dict, Iterator, Optional, Sequence +from typing import Any, Dict, Iterator, Sequence -from llama_index.bridge.pydantic import Field, PrivateAttr -from llama_index.callbacks import CallbackManager +from llama_index.bridge.pydantic import Field from llama_index.constants import ( DEFAULT_CONTEXT_WINDOW, DEFAULT_NUM_OUTPUTS, - DEFAULT_TEMPERATURE, ) -from llama_index.llms.base import ( +from llama_index.llms.base import llm_chat_callback, llm_completion_callback +from llama_index.llms.custom import CustomLLM +from llama_index.llms.generic_utils import ( + completion_response_to_chat_response, + stream_completion_response_to_chat_response, +) +from llama_index.llms.types import ( ChatMessage, ChatResponse, ChatResponseGen, CompletionResponse, CompletionResponseGen, LLMMetadata, - llm_chat_callback, - llm_completion_callback, -) -from llama_index.llms.custom import CustomLLM -from llama_index.llms.generic_utils import ( - completion_response_to_chat_response, - stream_completion_response_to_chat_response, -) -from llama_index.llms.generic_utils import ( - messages_to_prompt as generic_messages_to_prompt, ) class Ollama(CustomLLM): - base_url: str = Field(description="Base url the model is hosted under.") + base_url: str = Field( + default="http://localhost:11434", + description="Base url the model is hosted under.", + ) model: str = Field(description="The Ollama model to use.") temperature: float = Field( - default=DEFAULT_TEMPERATURE, + default=0.75, description="The temperature to use for sampling.", gte=0.0, lte=1.0, @@ -49,34 +46,6 @@ class Ollama(CustomLLM): default_factory=dict, description="Additional kwargs for the Ollama API." ) - _messages_to_prompt: Callable = PrivateAttr() - _completion_to_prompt: Callable = PrivateAttr() - - def __init__( - self, - model: str, - base_url: str = "http://localhost:11434", - temperature: float = 0.75, - additional_kwargs: Optional[Dict[str, Any]] = None, - context_window: int = DEFAULT_CONTEXT_WINDOW, - prompt_key: str = "prompt", - messages_to_prompt: Optional[Callable] = None, - completion_to_prompt: Optional[Callable] = None, - callback_manager: Optional[CallbackManager] = None, - ) -> None: - self._messages_to_prompt = messages_to_prompt or generic_messages_to_prompt - self._completion_to_prompt = completion_to_prompt or (lambda x: x) - - super().__init__( - model=model, - temperature=temperature, - base_url=base_url, - additional_kwargs=additional_kwargs or {}, - context_window=context_window, - prompt_key=prompt_key, - callback_manager=callback_manager, - ) - @classmethod def class_name(cls) -> str: return "Ollama_llm" @@ -112,16 +81,16 @@ class Ollama(CustomLLM): @llm_chat_callback() def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse: - prompt = self._messages_to_prompt(messages) - completion_response = self.complete(prompt, **kwargs) + prompt = self.messages_to_prompt(messages) + completion_response = self.complete(prompt, formatted=True, **kwargs) return completion_response_to_chat_response(completion_response) @llm_chat_callback() def stream_chat( self, messages: Sequence[ChatMessage], **kwargs: Any ) -> ChatResponseGen: - prompt = self._messages_to_prompt(messages) - completion_response = self.stream_complete(prompt, **kwargs) + prompt = self.messages_to_prompt(messages) + completion_response = self.stream_complete(prompt, formatted=True, **kwargs) return stream_completion_response_to_chat_response(completion_response) @llm_completion_callback() @@ -142,7 +111,9 @@ class Ollama(CustomLLM): "Please install requests with `pip install requests`" ) all_kwargs = self._get_all_kwargs(**kwargs) - prompt = self._completion_to_prompt(prompt) + + if not kwargs.get("formatted", False): + prompt = self.completion_to_prompt(prompt) response = requests.post( url=f"{self.base_url}/api/generate/", headers={"Content-Type": "application/json"}, diff --git a/llama_index/llms/openai.py b/llama_index/llms/openai.py index e2b35c9145..5a7ff6946b 100644 --- a/llama_index/llms/openai.py +++ b/llama_index/llms/openai.py @@ -27,16 +27,6 @@ from llama_index.constants import ( DEFAULT_TEMPERATURE, ) from llama_index.llms.base import ( - LLM, - ChatMessage, - ChatResponse, - ChatResponseAsyncGen, - ChatResponseGen, - CompletionResponse, - CompletionResponseAsyncGen, - CompletionResponseGen, - LLMMetadata, - MessageRole, llm_chat_callback, llm_completion_callback, ) @@ -50,6 +40,7 @@ from llama_index.llms.generic_utils import ( stream_chat_to_completion_decorator, stream_completion_to_chat_decorator, ) +from llama_index.llms.llm import LLM from llama_index.llms.openai_utils import ( from_openai_message, is_chat_model, @@ -58,6 +49,18 @@ from llama_index.llms.openai_utils import ( resolve_openai_credentials, to_openai_message_dicts, ) +from llama_index.llms.types import ( + ChatMessage, + ChatResponse, + ChatResponseAsyncGen, + ChatResponseGen, + CompletionResponse, + CompletionResponseAsyncGen, + CompletionResponseGen, + LLMMetadata, + MessageRole, +) +from llama_index.types import BaseOutputParser, PydanticProgramMode DEFAULT_OPENAI_MODEL = "gpt-3.5-turbo" @@ -131,6 +134,12 @@ class OpenAI(LLM): callback_manager: Optional[CallbackManager] = None, default_headers: Optional[Dict[str, str]] = None, http_client: Optional[httpx.Client] = None, + # base class + system_prompt: Optional[str] = None, + messages_to_prompt: Optional[Callable[[Sequence[ChatMessage]], str]] = None, + completion_to_prompt: Optional[Callable[[str], str]] = None, + pydantic_program_mode: PydanticProgramMode = PydanticProgramMode.DEFAULT, + output_parser: Optional[BaseOutputParser] = None, **kwargs: Any, ) -> None: additional_kwargs = additional_kwargs or {} @@ -154,6 +163,11 @@ class OpenAI(LLM): timeout=timeout, reuse_client=reuse_client, default_headers=default_headers, + system_prompt=system_prompt, + messages_to_prompt=messages_to_prompt, + completion_to_prompt=completion_to_prompt, + pydantic_program_mode=pydantic_program_mode, + output_parser=output_parser, **kwargs, ) diff --git a/llama_index/llms/openai_like.py b/llama_index/llms/openai_like.py index 72bdd778bd..ced6bda325 100644 --- a/llama_index/llms/openai_like.py +++ b/llama_index/llms/openai_like.py @@ -2,8 +2,8 @@ from typing import Optional, Union from llama_index.bridge.pydantic import Field from llama_index.constants import DEFAULT_CONTEXT_WINDOW -from llama_index.llms.base import LLMMetadata from llama_index.llms.openai import OpenAI, Tokenizer +from llama_index.llms.types import LLMMetadata class OpenAILike(OpenAI): diff --git a/llama_index/llms/openai_utils.py b/llama_index/llms/openai_utils.py index 3ba53e4650..83365c77e6 100644 --- a/llama_index/llms/openai_utils.py +++ b/llama_index/llms/openai_utils.py @@ -20,8 +20,8 @@ from tenacity import ( from tenacity.stop import stop_base from llama_index.bridge.pydantic import BaseModel -from llama_index.llms.base import ChatMessage from llama_index.llms.generic_utils import get_from_param_or_env +from llama_index.llms.types import ChatMessage DEFAULT_OPENAI_API_TYPE = "open_ai" DEFAULT_OPENAI_API_BASE = "https://api.openai.com/v1" diff --git a/llama_index/llms/openllm.py b/llama_index/llms/openllm.py index 5fcf83c546..4596ec11a2 100644 --- a/llama_index/llms/openllm.py +++ b/llama_index/llms/openllm.py @@ -13,16 +13,7 @@ from typing import ( from llama_index.bridge.pydantic import Field, PrivateAttr from llama_index.callbacks import CallbackManager -from llama_index.llms import ChatResponseAsyncGen from llama_index.llms.base import ( - LLM, - ChatMessage, - ChatResponse, - ChatResponseGen, - CompletionResponse, - CompletionResponseAsyncGen, - CompletionResponseGen, - LLMMetadata, llm_chat_callback, llm_completion_callback, ) @@ -32,6 +23,18 @@ from llama_index.llms.generic_utils import ( from llama_index.llms.generic_utils import ( messages_to_prompt as generic_messages_to_prompt, ) +from llama_index.llms.llm import LLM +from llama_index.llms.types import ( + ChatMessage, + ChatResponse, + ChatResponseAsyncGen, + ChatResponseGen, + CompletionResponse, + CompletionResponseAsyncGen, + CompletionResponseGen, + LLMMetadata, +) +from llama_index.types import PydanticProgramMode logger = logging.getLogger(__name__) @@ -58,9 +61,6 @@ class OpenLLM(LLM): prompt_template: Optional[str] = Field( description="Optional prompt template to pass for this LLM." ) - system_message: Optional[str] = Field( - description="Optional system message to pass for this LLM." - ) backend: Optional[Literal["vllm", "pt"]] = Field( description="Optional backend to pass for this LLM. By default, it will use vLLM if vLLM is available in local system. Otherwise, it will fallback to PyTorch." ) @@ -85,22 +85,22 @@ class OpenLLM(LLM): else: _llm: Any = PrivateAttr() - _messages_to_prompt: Callable[[Sequence[ChatMessage]], Any] = PrivateAttr() - def __init__( self, model_id: str, model_version: Optional[str] = None, model_tag: Optional[str] = None, prompt_template: Optional[str] = None, - system_message: Optional[str] = None, backend: Optional[Literal["vllm", "pt"]] = None, *args: Any, quantize: Optional[Literal["awq", "gptq", "int8", "int4", "squeezellm"]] = None, serialization: Literal["safetensors", "legacy"] = "safetensors", trust_remote_code: bool = False, callback_manager: Optional[CallbackManager] = None, - messages_to_prompt: Optional[Callable[..., Any]] = None, + system_prompt: Optional[str] = None, + messages_to_prompt: Optional[Callable[[Sequence[ChatMessage]], str]] = None, + completion_to_prompt: Optional[Callable[[str], str]] = None, + pydantic_program_mode: PydanticProgramMode = PydanticProgramMode.DEFAULT, **attrs: Any, ): try: @@ -114,7 +114,7 @@ class OpenLLM(LLM): model_version=model_version, model_tag=model_tag, prompt_template=prompt_template, - system_message=system_message, + system_message=system_prompt, backend=backend, quantize=quantize, serialisation=serialization, @@ -133,12 +133,15 @@ class OpenLLM(LLM): model_version=self._llm.revision, model_tag=str(self._llm.tag), prompt_template=prompt_template, - system_message=system_message, backend=self._llm.__llm_backend__, quantize=self._llm.quantise, serialization=self._llm._serialisation, trust_remote_code=self._llm.trust_remote_code, callback_manager=callback_manager, + system_prompt=system_prompt, + messages_to_prompt=messages_to_prompt, + completion_to_prompt=completion_to_prompt, + pydantic_program_mode=pydantic_program_mode, ) @classmethod diff --git a/llama_index/llms/palm.py b/llama_index/llms/palm.py index d907acb0e1..30b49b6e32 100644 --- a/llama_index/llms/palm.py +++ b/llama_index/llms/palm.py @@ -1,17 +1,19 @@ """Palm API.""" import os -from typing import Any, Optional +from typing import Any, Callable, Optional, Sequence from llama_index.bridge.pydantic import Field, PrivateAttr from llama_index.callbacks import CallbackManager from llama_index.constants import DEFAULT_NUM_OUTPUTS -from llama_index.llms.base import ( +from llama_index.llms.base import llm_completion_callback +from llama_index.llms.custom import CustomLLM +from llama_index.llms.types import ( + ChatMessage, CompletionResponse, CompletionResponseGen, LLMMetadata, - llm_completion_callback, ) -from llama_index.llms.custom import CustomLLM +from llama_index.types import BaseOutputParser, PydanticProgramMode DEFAULT_PALM_MODEL = "models/text-bison-001" @@ -39,6 +41,11 @@ class PaLM(CustomLLM): model_name: Optional[str] = DEFAULT_PALM_MODEL, num_output: Optional[int] = None, callback_manager: Optional[CallbackManager] = None, + system_prompt: Optional[str] = None, + messages_to_prompt: Optional[Callable[[Sequence[ChatMessage]], str]] = None, + completion_to_prompt: Optional[Callable[[str], str]] = None, + pydantic_program_mode: PydanticProgramMode = PydanticProgramMode.DEFAULT, + output_parser: Optional[BaseOutputParser] = None, **generate_kwargs: Any, ) -> None: """Initialize params.""" @@ -71,6 +78,11 @@ class PaLM(CustomLLM): num_output=num_output, generate_kwargs=generate_kwargs, callback_manager=callback_manager, + system_prompt=system_prompt, + messages_to_prompt=messages_to_prompt, + completion_to_prompt=completion_to_prompt, + pydantic_program_mode=pydantic_program_mode, + output_parser=output_parser, ) @classmethod diff --git a/llama_index/llms/perplexity.py b/llama_index/llms/perplexity.py index aa5db405e1..dd36e6bb22 100644 --- a/llama_index/llms/perplexity.py +++ b/llama_index/llms/perplexity.py @@ -1,13 +1,14 @@ import json -from typing import Any, Dict, Optional, Sequence +from typing import Any, Callable, Dict, Optional, Sequence import httpx import requests from llama_index.bridge.pydantic import Field from llama_index.callbacks import CallbackManager -from llama_index.llms.base import ( - LLM, +from llama_index.llms.base import llm_chat_callback, llm_completion_callback +from llama_index.llms.llm import LLM +from llama_index.llms.types import ( ChatMessage, ChatResponse, ChatResponseAsyncGen, @@ -16,9 +17,8 @@ from llama_index.llms.base import ( CompletionResponseAsyncGen, CompletionResponseGen, LLMMetadata, - llm_chat_callback, - llm_completion_callback, ) +from llama_index.types import BaseOutputParser, PydanticProgramMode class Perplexity(LLM): @@ -60,6 +60,11 @@ class Perplexity(LLM): max_retries: int = 10, context_window: Optional[int] = None, callback_manager: Optional[CallbackManager] = None, + system_prompt: Optional[str] = None, + messages_to_prompt: Optional[Callable[[Sequence[ChatMessage]], str]] = None, + completion_to_prompt: Optional[Callable[[str], str]] = None, + pydantic_program_mode: PydanticProgramMode = PydanticProgramMode.DEFAULT, + output_parser: Optional[BaseOutputParser] = None, **kwargs: Any, ) -> None: additional_kwargs = additional_kwargs or {} @@ -79,6 +84,11 @@ class Perplexity(LLM): api_base=api_base, headers=headers, context_window=context_window, + system_prompt=system_prompt, + messages_to_prompt=messages_to_prompt, + completion_to_prompt=completion_to_prompt, + pydantic_program_mode=pydantic_program_mode, + output_parser=output_parser, **kwargs, ) diff --git a/llama_index/llms/portkey.py b/llama_index/llms/portkey.py index 5acf56b481..48c92ca638 100644 --- a/llama_index/llms/portkey.py +++ b/llama_index/llms/portkey.py @@ -1,19 +1,10 @@ """ Portkey integration with Llama_index for enhanced monitoring. """ -from typing import TYPE_CHECKING, Any, List, Optional, Sequence, Union, cast +from typing import TYPE_CHECKING, Any, Callable, List, Optional, Sequence, Union, cast from llama_index.bridge.pydantic import Field, PrivateAttr -from llama_index.llms.base import ( - ChatMessage, - ChatResponse, - ChatResponseGen, - CompletionResponse, - CompletionResponseGen, - LLMMetadata, - llm_chat_callback, - llm_completion_callback, -) +from llama_index.llms.base import llm_chat_callback, llm_completion_callback from llama_index.llms.custom import CustomLLM from llama_index.llms.generic_utils import ( chat_to_completion_decorator, @@ -27,6 +18,15 @@ from llama_index.llms.portkey_utils import ( get_llm, is_chat_model, ) +from llama_index.llms.types import ( + ChatMessage, + ChatResponse, + ChatResponseGen, + CompletionResponse, + CompletionResponseGen, + LLMMetadata, +) +from llama_index.types import BaseOutputParser, PydanticProgramMode if TYPE_CHECKING: from portkey import ( @@ -63,6 +63,11 @@ class Portkey(CustomLLM): mode: Union["Modes", "ModesLiteral"], api_key: Optional[str] = None, base_url: Optional[str] = None, + system_prompt: Optional[str] = None, + messages_to_prompt: Optional[Callable[[Sequence[ChatMessage]], str]] = None, + completion_to_prompt: Optional[Callable[[str], str]] = None, + pydantic_program_mode: PydanticProgramMode = PydanticProgramMode.DEFAULT, + output_parser: Optional[BaseOutputParser] = None, ) -> None: """ Initialize a Portkey instance. @@ -82,6 +87,11 @@ class Portkey(CustomLLM): super().__init__( base_url=base_url, api_key=api_key, + system_prompt=system_prompt, + messages_to_prompt=messages_to_prompt, + completion_to_prompt=completion_to_prompt, + pydantic_program_mode=pydantic_program_mode, + output_parser=output_parser, ) if api_key is not None: portkey.api_key = api_key diff --git a/llama_index/llms/portkey_utils.py b/llama_index/llms/portkey_utils.py index b328ea402f..e23e6b5ee5 100644 --- a/llama_index/llms/portkey_utils.py +++ b/llama_index/llms/portkey_utils.py @@ -8,7 +8,6 @@ from typing import TYPE_CHECKING, List from llama_index.llms.anthropic import Anthropic from llama_index.llms.anthropic_utils import CLAUDE_MODELS -from llama_index.llms.base import LLMMetadata from llama_index.llms.openai import OpenAI from llama_index.llms.openai_utils import ( AZURE_TURBO_MODELS, @@ -17,6 +16,7 @@ from llama_index.llms.openai_utils import ( GPT4_MODELS, TURBO_MODELS, ) +from llama_index.llms.types import LLMMetadata if TYPE_CHECKING: from portkey import ( diff --git a/llama_index/llms/predibase.py b/llama_index/llms/predibase.py index c7993044cb..38b86216f1 100644 --- a/llama_index/llms/predibase.py +++ b/llama_index/llms/predibase.py @@ -1,5 +1,5 @@ import os -from typing import Any, Optional +from typing import Any, Callable, Optional, Sequence from llama_index.bridge.pydantic import Field, PrivateAttr from llama_index.callbacks import CallbackManager @@ -8,13 +8,15 @@ from llama_index.constants import ( DEFAULT_NUM_OUTPUTS, DEFAULT_TEMPERATURE, ) -from llama_index.llms.base import ( +from llama_index.llms.base import llm_completion_callback +from llama_index.llms.custom import CustomLLM +from llama_index.llms.types import ( + ChatMessage, CompletionResponse, CompletionResponseGen, LLMMetadata, - llm_completion_callback, ) -from llama_index.llms.custom import CustomLLM +from llama_index.types import BaseOutputParser, PydanticProgramMode class PredibaseLLM(CustomLLM): @@ -49,6 +51,11 @@ class PredibaseLLM(CustomLLM): temperature: float = DEFAULT_TEMPERATURE, context_window: int = DEFAULT_CONTEXT_WINDOW, callback_manager: Optional[CallbackManager] = None, + system_prompt: Optional[str] = None, + messages_to_prompt: Optional[Callable[[Sequence[ChatMessage]], str]] = None, + completion_to_prompt: Optional[Callable[[str], str]] = None, + pydantic_program_mode: PydanticProgramMode = PydanticProgramMode.DEFAULT, + output_parser: Optional[BaseOutputParser] = None, ) -> None: predibase_api_key = ( predibase_api_key @@ -66,6 +73,11 @@ class PredibaseLLM(CustomLLM): temperature=temperature, context_window=context_window, callback_manager=callback_manager, + system_prompt=system_prompt, + messages_to_prompt=messages_to_prompt, + completion_to_prompt=completion_to_prompt, + pydantic_program_mode=pydantic_program_mode, + output_parser=output_parser, ) @staticmethod diff --git a/llama_index/llms/replicate.py b/llama_index/llms/replicate.py index 8283ffe540..16c8adae63 100644 --- a/llama_index/llms/replicate.py +++ b/llama_index/llms/replicate.py @@ -1,25 +1,20 @@ -from typing import Any, Callable, Dict, Optional, Sequence +from typing import Any, Dict, Sequence -from llama_index.bridge.pydantic import Field, PrivateAttr -from llama_index.callbacks import CallbackManager +from llama_index.bridge.pydantic import Field from llama_index.constants import DEFAULT_CONTEXT_WINDOW, DEFAULT_NUM_OUTPUTS -from llama_index.llms.base import ( +from llama_index.llms.base import llm_chat_callback, llm_completion_callback +from llama_index.llms.custom import CustomLLM +from llama_index.llms.generic_utils import ( + completion_response_to_chat_response, + stream_completion_response_to_chat_response, +) +from llama_index.llms.types import ( ChatMessage, ChatResponse, ChatResponseGen, CompletionResponse, CompletionResponseGen, LLMMetadata, - llm_chat_callback, - llm_completion_callback, -) -from llama_index.llms.custom import CustomLLM -from llama_index.llms.generic_utils import ( - completion_response_to_chat_response, - stream_completion_response_to_chat_response, -) -from llama_index.llms.generic_utils import ( - messages_to_prompt as generic_messages_to_prompt, ) DEFAULT_REPLICATE_TEMP = 0.75 @@ -34,14 +29,16 @@ class Replicate(CustomLLM): lte=1.0, ) image: str = Field( - description="The image file for multimodal model to use. (optional)" + default="", description="The image file for multimodal model to use. (optional)" ) context_window: int = Field( default=DEFAULT_CONTEXT_WINDOW, description="The maximum number of context tokens for the model.", gt=0, ) - prompt_key: str = Field(description="The key to use for the prompt in API calls.") + prompt_key: str = Field( + default="prompt", description="The key to use for the prompt in API calls." + ) additional_kwargs: Dict[str, Any] = Field( default_factory=dict, description="Additional kwargs for the Replicate API." ) @@ -49,36 +46,6 @@ class Replicate(CustomLLM): default=False, description="Whether the model is a chat model." ) - _messages_to_prompt: Callable = PrivateAttr() - _completion_to_prompt: Callable = PrivateAttr() - - def __init__( - self, - model: str, - temperature: float = DEFAULT_REPLICATE_TEMP, - image: Optional[str] = "", - additional_kwargs: Optional[Dict[str, Any]] = None, - context_window: int = DEFAULT_CONTEXT_WINDOW, - prompt_key: str = "prompt", - messages_to_prompt: Optional[Callable] = None, - completion_to_prompt: Optional[Callable] = None, - callback_manager: Optional[CallbackManager] = None, - is_chat_model: bool = False, - ) -> None: - self._messages_to_prompt = messages_to_prompt or generic_messages_to_prompt - self._completion_to_prompt = completion_to_prompt or (lambda x: x) - - super().__init__( - model=model, - temperature=temperature, - image=image, - additional_kwargs=additional_kwargs or {}, - context_window=context_window, - prompt_key=prompt_key, - callback_manager=callback_manager, - is_chat_model=is_chat_model, - ) - @classmethod def class_name(cls) -> str: return "Replicate_llm" @@ -116,16 +83,16 @@ class Replicate(CustomLLM): @llm_chat_callback() def chat(self, messages: Sequence[ChatMessage], **kwargs: Any) -> ChatResponse: - prompt = self._messages_to_prompt(messages) - completion_response = self.complete(prompt, **kwargs) + prompt = self.messages_to_prompt(messages) + completion_response = self.complete(prompt, formatted=True, **kwargs) return completion_response_to_chat_response(completion_response) @llm_chat_callback() def stream_chat( self, messages: Sequence[ChatMessage], **kwargs: Any ) -> ChatResponseGen: - prompt = self._messages_to_prompt(messages) - completion_response = self.stream_complete(prompt, **kwargs) + prompt = self.messages_to_prompt(messages) + completion_response = self.stream_complete(prompt, formatted=True, **kwargs) return stream_completion_response_to_chat_response(completion_response) @llm_completion_callback() @@ -146,7 +113,8 @@ class Replicate(CustomLLM): "Please install replicate with `pip install replicate`" ) - prompt = self._completion_to_prompt(prompt) + if not kwargs.get("formatted", False): + prompt = self.completion_to_prompt(prompt) input_dict = self._get_input_dict(prompt, **kwargs) response_iter = replicate.run(self.model, input=input_dict) diff --git a/llama_index/llms/rungpt.py b/llama_index/llms/rungpt.py index 65ff1e91e8..8351635325 100644 --- a/llama_index/llms/rungpt.py +++ b/llama_index/llms/rungpt.py @@ -1,11 +1,12 @@ import json -from typing import Any, Dict, List, Optional, Sequence, Tuple, Union +from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union from llama_index.bridge.pydantic import Field from llama_index.callbacks import CallbackManager from llama_index.constants import DEFAULT_CONTEXT_WINDOW, DEFAULT_NUM_OUTPUTS -from llama_index.llms.base import ( - LLM, +from llama_index.llms.base import llm_chat_callback, llm_completion_callback +from llama_index.llms.llm import LLM +from llama_index.llms.types import ( ChatMessage, ChatResponse, ChatResponseAsyncGen, @@ -15,9 +16,8 @@ from llama_index.llms.base import ( CompletionResponseGen, LLMMetadata, MessageRole, - llm_chat_callback, - llm_completion_callback, ) +from llama_index.types import BaseOutputParser, PydanticProgramMode DEFAULT_RUNGPT_MODEL = "rungpt" DEFAULT_RUNGPT_TEMP = 0.75 @@ -62,6 +62,11 @@ class RunGptLLM(LLM): context_window: int = DEFAULT_CONTEXT_WINDOW, additional_kwargs: Optional[Dict[str, Any]] = None, callback_manager: Optional[CallbackManager] = None, + system_prompt: Optional[str] = None, + messages_to_prompt: Optional[Callable[[Sequence[ChatMessage]], str]] = None, + completion_to_prompt: Optional[Callable[[str], str]] = None, + pydantic_program_mode: PydanticProgramMode = PydanticProgramMode.DEFAULT, + output_parser: Optional[BaseOutputParser] = None, ): if endpoint.startswith("http://"): base_url = endpoint @@ -76,6 +81,11 @@ class RunGptLLM(LLM): additional_kwargs=additional_kwargs or {}, callback_manager=callback_manager or CallbackManager([]), base_url=base_url, + system_prompt=system_prompt, + messages_to_prompt=messages_to_prompt, + completion_to_prompt=completion_to_prompt, + pydantic_program_mode=pydantic_program_mode, + output_parser=output_parser, ) @classmethod diff --git a/llama_index/llms/types.py b/llama_index/llms/types.py new file mode 100644 index 0000000000..9db785861d --- /dev/null +++ b/llama_index/llms/types.py @@ -0,0 +1,110 @@ +from enum import Enum +from typing import Any, AsyncGenerator, Generator, Optional + +from llama_index.bridge.pydantic import BaseModel, Field +from llama_index.constants import DEFAULT_CONTEXT_WINDOW, DEFAULT_NUM_OUTPUTS + + +class MessageRole(str, Enum): + """Message role.""" + + SYSTEM = "system" + USER = "user" + ASSISTANT = "assistant" + FUNCTION = "function" + TOOL = "tool" + + +# ===== Generic Model Input - Chat ===== +class ChatMessage(BaseModel): + """Chat message.""" + + role: MessageRole = MessageRole.USER + content: Optional[Any] = "" + additional_kwargs: dict = Field(default_factory=dict) + + def __str__(self) -> str: + return f"{self.role.value}: {self.content}" + + +# ===== Generic Model Output - Chat ===== +class ChatResponse(BaseModel): + """Chat response.""" + + message: ChatMessage + raw: Optional[dict] = None + delta: Optional[str] = None + additional_kwargs: dict = Field(default_factory=dict) + + def __str__(self) -> str: + return str(self.message) + + +ChatResponseGen = Generator[ChatResponse, None, None] +ChatResponseAsyncGen = AsyncGenerator[ChatResponse, None] + + +# ===== Generic Model Output - Completion ===== +class CompletionResponse(BaseModel): + """ + Completion response. + + Fields: + text: Text content of the response if not streaming, or if streaming, + the current extent of streamed text. + additional_kwargs: Additional information on the response(i.e. token + counts, function calling information). + raw: Optional raw JSON that was parsed to populate text, if relevant. + delta: New text that just streamed in (only relevant when streaming). + """ + + text: str + additional_kwargs: dict = Field(default_factory=dict) + raw: Optional[dict] = None + delta: Optional[str] = None + + def __str__(self) -> str: + return self.text + + +CompletionResponseGen = Generator[CompletionResponse, None, None] +CompletionResponseAsyncGen = AsyncGenerator[CompletionResponse, None] + + +class LLMMetadata(BaseModel): + context_window: int = Field( + default=DEFAULT_CONTEXT_WINDOW, + description=( + "Total number of tokens the model can be input and output for one response." + ), + ) + num_output: int = Field( + default=DEFAULT_NUM_OUTPUTS, + description="Number of tokens the model can output when generating a response.", + ) + is_chat_model: bool = Field( + default=False, + description=( + "Set True if the model exposes a chat interface (i.e. can be passed a" + " sequence of messages, rather than text), like OpenAI's" + " /v1/chat/completions endpoint." + ), + ) + is_function_calling_model: bool = Field( + default=False, + # SEE: https://openai.com/blog/function-calling-and-other-api-updates + description=( + "Set True if the model supports function calling messages, similar to" + " OpenAI's function calling API. For example, converting 'Email Anya to" + " see if she wants to get coffee next Friday' to a function call like" + " `send_email(to: string, body: string)`." + ), + ) + model_name: str = Field( + default="unknown", + description=( + "The model's name used for logging, testing, and sanity checking. For some" + " models this can be automatically discerned. For other models, like" + " locally loaded models, this must be manually specified." + ), + ) diff --git a/llama_index/llms/utils.py b/llama_index/llms/utils.py index 53b1bcb3b1..076a97c496 100644 --- a/llama_index/llms/utils.py +++ b/llama_index/llms/utils.py @@ -3,9 +3,9 @@ from typing import TYPE_CHECKING, Optional, Union if TYPE_CHECKING: from langchain.base_language import BaseLanguageModel -from llama_index.llms.base import LLM from llama_index.llms.llama_cpp import LlamaCPP from llama_index.llms.llama_utils import completion_to_prompt, messages_to_prompt +from llama_index.llms.llm import LLM from llama_index.llms.mock import MockLLM from llama_index.llms.openai import OpenAI from llama_index.llms.openai_utils import validate_openai_api_key diff --git a/llama_index/llms/vertex.py b/llama_index/llms/vertex.py index e74964c3d3..b6f2764a33 100644 --- a/llama_index/llms/vertex.py +++ b/llama_index/llms/vertex.py @@ -1,9 +1,13 @@ -from typing import Any, Dict, Optional, Sequence +from typing import Any, Callable, Dict, Optional, Sequence from llama_index.bridge.pydantic import Field, PrivateAttr from llama_index.callbacks import CallbackManager from llama_index.llms.base import ( - LLM, + llm_chat_callback, + llm_completion_callback, +) +from llama_index.llms.llm import LLM +from llama_index.llms.types import ( ChatMessage, ChatResponse, ChatResponseAsyncGen, @@ -13,8 +17,6 @@ from llama_index.llms.base import ( CompletionResponseGen, LLMMetadata, MessageRole, - llm_chat_callback, - llm_completion_callback, ) from llama_index.llms.vertex_utils import ( CHAT_MODELS, @@ -27,6 +29,7 @@ from llama_index.llms.vertex_utils import ( completion_with_retry, init_vertexai, ) +from llama_index.types import BaseOutputParser, PydanticProgramMode class Vertex(LLM): @@ -60,6 +63,11 @@ class Vertex(LLM): iscode: bool = False, additional_kwargs: Optional[Dict[str, Any]] = None, callback_manager: Optional[CallbackManager] = None, + system_prompt: Optional[str] = None, + messages_to_prompt: Optional[Callable[[Sequence[ChatMessage]], str]] = None, + completion_to_prompt: Optional[Callable[[str], str]] = None, + pydantic_program_mode: PydanticProgramMode = PydanticProgramMode.DEFAULT, + output_parser: Optional[BaseOutputParser] = None, ) -> None: init_vertexai(project=project, location=location, credentials=credentials) @@ -96,6 +104,11 @@ class Vertex(LLM): examples=examples, iscode=iscode, callback_manager=callback_manager, + system_prompt=system_prompt, + messages_to_prompt=messages_to_prompt, + completion_to_prompt=completion_to_prompt, + pydantic_program_mode=pydantic_program_mode, + output_parser=output_parser, ) @classmethod diff --git a/llama_index/llms/vertex_utils.py b/llama_index/llms/vertex_utils.py index 7f5e8fba57..e87ff3e4ba 100644 --- a/llama_index/llms/vertex_utils.py +++ b/llama_index/llms/vertex_utils.py @@ -12,7 +12,7 @@ from tenacity import ( wait_exponential, ) -from llama_index.llms.base import MessageRole +from llama_index.llms.types import MessageRole CHAT_MODELS = ["chat-bison", "chat-bison-32k", "chat-bison@001"] TEXT_MODELS = ["text-bison", "text-bison-32k", "text-bison@001"] diff --git a/llama_index/llms/vllm.py b/llama_index/llms/vllm.py index 0fd30f90df..f102596633 100644 --- a/llama_index/llms/vllm.py +++ b/llama_index/llms/vllm.py @@ -3,8 +3,16 @@ from typing import Any, Callable, Dict, List, Optional, Sequence from llama_index.bridge.pydantic import Field, PrivateAttr from llama_index.callbacks import CallbackManager -from llama_index.llms.base import ( - LLM, +from llama_index.llms.base import llm_chat_callback, llm_completion_callback +from llama_index.llms.generic_utils import ( + completion_response_to_chat_response, + stream_completion_response_to_chat_response, +) +from llama_index.llms.generic_utils import ( + messages_to_prompt as generic_messages_to_prompt, +) +from llama_index.llms.llm import LLM +from llama_index.llms.types import ( ChatMessage, ChatResponse, ChatResponseAsyncGen, @@ -13,17 +21,9 @@ from llama_index.llms.base import ( CompletionResponseAsyncGen, CompletionResponseGen, LLMMetadata, - llm_chat_callback, - llm_completion_callback, -) -from llama_index.llms.generic_utils import ( - completion_response_to_chat_response, - stream_completion_response_to_chat_response, -) -from llama_index.llms.generic_utils import ( - messages_to_prompt as generic_messages_to_prompt, ) from llama_index.llms.vllm_utils import get_response, post_http_request +from llama_index.types import BaseOutputParser, PydanticProgramMode class Vllm(LLM): @@ -100,14 +100,6 @@ class Vllm(LLM): description="The data type for the model weights and activations.", ) - messages_to_prompt: Callable = Field( - description="The function to convert messages to a prompt.", exclude=True - ) - - completion_to_prompt: Callable = Field( - description="The function to convert a completion to a prompt.", exclude=True - ) - download_dir: Optional[str] = Field( default=None, description="Directory to download and load the weights. (Default to the default cache dir of huggingface)", @@ -126,8 +118,8 @@ class Vllm(LLM): self, model: str = "facebook/opt-125m", temperature: float = 1.0, - tensor_parallel_size: Optional[int] = 1, - trust_remote_code: Optional[bool] = True, + tensor_parallel_size: int = 1, + trust_remote_code: bool = True, n: int = 1, best_of: Optional[int] = None, presence_penalty: float = 0.0, @@ -143,9 +135,12 @@ class Vllm(LLM): download_dir: Optional[str] = None, vllm_kwargs: Dict[str, Any] = {}, api_url: Optional[str] = "", - messages_to_prompt: Optional[Callable] = None, - completion_to_prompt: Optional[Callable] = None, callback_manager: Optional[CallbackManager] = None, + system_prompt: Optional[str] = None, + messages_to_prompt: Optional[Callable[[Sequence[ChatMessage]], str]] = None, + completion_to_prompt: Optional[Callable[[str], str]] = None, + pydantic_program_mode: PydanticProgramMode = PydanticProgramMode.DEFAULT, + output_parser: Optional[BaseOutputParser] = None, ) -> None: try: from vllm import LLM as VLLModel @@ -166,8 +161,6 @@ class Vllm(LLM): else: self._client = None callback_manager = callback_manager or CallbackManager([]) - messages_to_prompt = messages_to_prompt or generic_messages_to_prompt - completion_to_prompt = completion_to_prompt or (lambda x: x) super().__init__( model=model, temperature=temperature, @@ -184,10 +177,13 @@ class Vllm(LLM): logprobs=logprobs, dtype=dtype, download_dir=download_dir, - messages_to_prompt=messages_to_prompt, - completion_to_prompt=completion_to_prompt, vllm_kwargs=vllm_kwargs, api_url=api_url, + system_prompt=system_prompt, + messages_to_prompt=messages_to_prompt, + completion_to_prompt=completion_to_prompt, + pydantic_program_mode=pydantic_program_mode, + output_parser=output_parser, ) @classmethod @@ -304,6 +300,7 @@ class VllmServer(Vllm): completion_to_prompt: Optional[Callable] = None, vllm_kwargs: Dict[str, Any] = {}, callback_manager: Optional[CallbackManager] = None, + output_parser: Optional[BaseOutputParser] = None, ) -> None: self._client = None messages_to_prompt = messages_to_prompt or generic_messages_to_prompt @@ -331,6 +328,8 @@ class VllmServer(Vllm): completion_to_prompt=completion_to_prompt, vllm_kwargs=vllm_kwargs, api_url=api_url, + callback_manager=callback_manager, + output_parser=output_parser, ) @classmethod diff --git a/llama_index/llms/watsonx.py b/llama_index/llms/watsonx.py index 5f1ad84198..15c69392c0 100644 --- a/llama_index/llms/watsonx.py +++ b/llama_index/llms/watsonx.py @@ -1,9 +1,14 @@ -from typing import Any, Dict, Optional, Sequence +from typing import Any, Callable, Dict, Optional, Sequence from llama_index.bridge.pydantic import Field, PrivateAttr from llama_index.callbacks import CallbackManager -from llama_index.llms.base import ( - LLM, +from llama_index.llms.base import llm_chat_callback, llm_completion_callback +from llama_index.llms.generic_utils import ( + completion_to_chat_decorator, + stream_completion_to_chat_decorator, +) +from llama_index.llms.llm import LLM +from llama_index.llms.types import ( ChatMessage, ChatResponse, ChatResponseAsyncGen, @@ -12,18 +17,13 @@ from llama_index.llms.base import ( CompletionResponseAsyncGen, CompletionResponseGen, LLMMetadata, - llm_chat_callback, - llm_completion_callback, -) -from llama_index.llms.generic_utils import ( - completion_to_chat_decorator, - stream_completion_to_chat_decorator, ) from llama_index.llms.watsonx_utils import ( WATSONX_MODELS, get_from_param_or_env_without_error, watsonx_model_to_context_size, ) +from llama_index.types import BaseOutputParser, PydanticProgramMode class WatsonX(LLM): @@ -51,6 +51,11 @@ class WatsonX(LLM): temperature: Optional[float] = 0.1, additional_kwargs: Optional[Dict[str, Any]] = None, callback_manager: Optional[CallbackManager] = None, + system_prompt: Optional[str] = None, + messages_to_prompt: Optional[Callable[[Sequence[ChatMessage]], str]] = None, + completion_to_prompt: Optional[Callable[[str], str]] = None, + pydantic_program_mode: PydanticProgramMode = PydanticProgramMode.DEFAULT, + output_parser: Optional[BaseOutputParser] = None, ) -> None: """Initialize params.""" if model_id not in WATSONX_MODELS: @@ -94,6 +99,11 @@ class WatsonX(LLM): additional_kwargs=additional_kwargs, model_info=self._model.get_details(), callback_manager=callback_manager, + system_prompt=system_prompt, + messages_to_prompt=messages_to_prompt, + completion_to_prompt=completion_to_prompt, + pydantic_program_mode=pydantic_program_mode, + output_parser=output_parser, ) @classmethod diff --git a/llama_index/llms/xinference.py b/llama_index/llms/xinference.py index d0f15dd0bb..62c02e90f4 100644 --- a/llama_index/llms/xinference.py +++ b/llama_index/llms/xinference.py @@ -1,9 +1,14 @@ import warnings -from typing import Any, Dict, Optional, Sequence, Tuple +from typing import Any, Callable, Dict, Optional, Sequence, Tuple from llama_index.bridge.pydantic import Field, PrivateAttr from llama_index.callbacks import CallbackManager from llama_index.llms.base import ( + llm_chat_callback, + llm_completion_callback, +) +from llama_index.llms.custom import CustomLLM +from llama_index.llms.types import ( ChatMessage, ChatResponse, ChatResponseGen, @@ -11,14 +16,12 @@ from llama_index.llms.base import ( CompletionResponseGen, LLMMetadata, MessageRole, - llm_chat_callback, - llm_completion_callback, ) -from llama_index.llms.custom import CustomLLM from llama_index.llms.xinference_utils import ( xinference_message_to_history, xinference_modelname_to_contextsize, ) +from llama_index.types import BaseOutputParser, PydanticProgramMode # an approximation of the ratio between llama and GPT2 tokens TOKEN_RATIO = 2.5 @@ -50,6 +53,11 @@ class Xinference(CustomLLM): temperature: float = DEFAULT_XINFERENCE_TEMP, max_tokens: Optional[int] = None, callback_manager: Optional[CallbackManager] = None, + system_prompt: Optional[str] = None, + messages_to_prompt: Optional[Callable[[Sequence[ChatMessage]], str]] = None, + completion_to_prompt: Optional[Callable[[str], str]] = None, + pydantic_program_mode: PydanticProgramMode = PydanticProgramMode.DEFAULT, + output_parser: Optional[BaseOutputParser] = None, ) -> None: generator, context_window, model_description = self.load_model( model_uid, endpoint @@ -71,6 +79,11 @@ class Xinference(CustomLLM): max_tokens=max_tokens, model_description=model_description, callback_manager=callback_manager, + system_prompt=system_prompt, + messages_to_prompt=messages_to_prompt, + completion_to_prompt=completion_to_prompt, + pydantic_program_mode=pydantic_program_mode, + output_parser=output_parser, ) def load_model(self, model_uid: str, endpoint: str) -> Tuple[Any, int, dict]: diff --git a/llama_index/llms/xinference_utils.py b/llama_index/llms/xinference_utils.py index 56ad201bc0..bc1be05157 100644 --- a/llama_index/llms/xinference_utils.py +++ b/llama_index/llms/xinference_utils.py @@ -2,7 +2,7 @@ from typing import Optional from typing_extensions import NotRequired, TypedDict -from llama_index.llms.base import ChatMessage +from llama_index.llms.types import ChatMessage XINFERENCE_MODEL_SIZES = { "baichuan": 2048, diff --git a/llama_index/memory/chat_memory_buffer.py b/llama_index/memory/chat_memory_buffer.py index fc6014368c..baa22c299d 100644 --- a/llama_index/memory/chat_memory_buffer.py +++ b/llama_index/memory/chat_memory_buffer.py @@ -1,7 +1,8 @@ from typing import Any, Callable, Dict, List, Optional, cast from llama_index.bridge.pydantic import Field, root_validator -from llama_index.llms.base import LLM, ChatMessage, MessageRole +from llama_index.llms.llm import LLM +from llama_index.llms.types import ChatMessage, MessageRole from llama_index.memory.types import BaseMemory from llama_index.utils import GlobalsHelper diff --git a/llama_index/memory/types.py b/llama_index/memory/types.py index 5375f5d408..2c00253dfc 100644 --- a/llama_index/memory/types.py +++ b/llama_index/memory/types.py @@ -2,7 +2,8 @@ from abc import abstractmethod from typing import Any, List, Optional from llama_index.bridge.pydantic import BaseModel -from llama_index.llms.base import LLM, ChatMessage +from llama_index.llms.llm import LLM +from llama_index.llms.types import ChatMessage class BaseMemory(BaseModel): diff --git a/llama_index/multi_modal_llms/base.py b/llama_index/multi_modal_llms/base.py index 9bd70c1e1a..fe537352ba 100644 --- a/llama_index/multi_modal_llms/base.py +++ b/llama_index/multi_modal_llms/base.py @@ -7,7 +7,7 @@ from llama_index.constants import ( DEFAULT_NUM_INPUT_FILES, DEFAULT_NUM_OUTPUTS, ) -from llama_index.llms.base import ( +from llama_index.llms.types import ( ChatMessage, ChatResponse, ChatResponseAsyncGen, diff --git a/llama_index/multi_modal_llms/openai.py b/llama_index/multi_modal_llms/openai.py index ca04fffdea..2228c349e9 100644 --- a/llama_index/multi_modal_llms/openai.py +++ b/llama_index/multi_modal_llms/openai.py @@ -16,7 +16,11 @@ from llama_index.constants import ( DEFAULT_NUM_OUTPUTS, DEFAULT_TEMPERATURE, ) -from llama_index.llms.base import ( +from llama_index.llms.generic_utils import ( + messages_to_prompt as generic_messages_to_prompt, +) +from llama_index.llms.openai_utils import from_openai_message, to_openai_message_dicts +from llama_index.llms.types import ( ChatMessage, ChatResponse, ChatResponseAsyncGen, @@ -26,10 +30,6 @@ from llama_index.llms.base import ( CompletionResponseGen, MessageRole, ) -from llama_index.llms.generic_utils import ( - messages_to_prompt as generic_messages_to_prompt, -) -from llama_index.llms.openai_utils import from_openai_message, to_openai_message_dicts from llama_index.multi_modal_llms import ( MultiModalLLM, MultiModalLLMMetadata, diff --git a/llama_index/multi_modal_llms/replicate_multi_modal.py b/llama_index/multi_modal_llms/replicate_multi_modal.py index 23fe2a2f6e..b0ae63ca6a 100644 --- a/llama_index/multi_modal_llms/replicate_multi_modal.py +++ b/llama_index/multi_modal_llms/replicate_multi_modal.py @@ -4,7 +4,10 @@ from typing import Any, Callable, Dict, Optional, Sequence from llama_index.bridge.pydantic import Field, PrivateAttr from llama_index.callbacks import CallbackManager from llama_index.constants import DEFAULT_CONTEXT_WINDOW, DEFAULT_NUM_OUTPUTS -from llama_index.llms.base import ( +from llama_index.llms.generic_utils import ( + messages_to_prompt as generic_messages_to_prompt, +) +from llama_index.llms.types import ( ChatMessage, ChatResponse, ChatResponseAsyncGen, @@ -13,9 +16,6 @@ from llama_index.llms.base import ( CompletionResponseAsyncGen, CompletionResponseGen, ) -from llama_index.llms.generic_utils import ( - messages_to_prompt as generic_messages_to_prompt, -) from llama_index.multi_modal_llms import ( MultiModalLLM, MultiModalLLMMetadata, diff --git a/llama_index/node_parser/relational/unstructured_element.py b/llama_index/node_parser/relational/unstructured_element.py index ad601def64..ca4415e57a 100644 --- a/llama_index/node_parser/relational/unstructured_element.py +++ b/llama_index/node_parser/relational/unstructured_element.py @@ -7,7 +7,8 @@ from tqdm import tqdm from llama_index.bridge.pydantic import BaseModel, Field, ValidationError from llama_index.callbacks.base import CallbackManager -from llama_index.llms.openai import LLM, OpenAI +from llama_index.llms.llm import LLM +from llama_index.llms.openai import OpenAI from llama_index.node_parser.interface import NodeParser from llama_index.response.schema import PydanticResponse from llama_index.schema import BaseNode, Document, IndexNode, TextNode diff --git a/llama_index/node_parser/text/token.py b/llama_index/node_parser/text/token.py index f6bff59fd2..12a9e12e0c 100644 --- a/llama_index/node_parser/text/token.py +++ b/llama_index/node_parser/text/token.py @@ -104,6 +104,7 @@ class TokenTextSplitter(MetadataAwareTextSplitter): """Split text into chunks, reserving space required for metadata str.""" metadata_len = len(self._tokenizer(metadata_str)) + DEFAULT_METADATA_FORMAT_LEN effective_chunk_size = self.chunk_size - metadata_len + print(effective_chunk_size, flush=True) if effective_chunk_size <= 0: raise ValueError( f"Metadata length ({metadata_len}) is longer than chunk size " diff --git a/llama_index/playground/base.py b/llama_index/playground/base.py index 89fa2259c3..4f51efa817 100644 --- a/llama_index/playground/base.py +++ b/llama_index/playground/base.py @@ -11,7 +11,7 @@ from llama_index.indices.base import BaseIndex from llama_index.indices.list.base import ListRetrieverMode, SummaryIndex from llama_index.indices.tree.base import TreeIndex, TreeRetrieverMode from llama_index.indices.vector_store import VectorStoreIndex -from llama_index.llm_predictor import LLMPredictor +from llama_index.llm_predictor.base import LLMPredictor from llama_index.schema import Document from llama_index.utils import get_color_mapping, print_text diff --git a/llama_index/postprocessor/llm_rerank.py b/llama_index/postprocessor/llm_rerank.py index 000d10aad7..63f30825aa 100644 --- a/llama_index/postprocessor/llm_rerank.py +++ b/llama_index/postprocessor/llm_rerank.py @@ -84,7 +84,7 @@ class LLMRerank(BaseNodePostprocessor): query_str = query_bundle.query_str fmt_batch_str = self._format_node_batch_fn(nodes_batch) # call each batch independently - raw_response = self.service_context.llm_predictor.predict( + raw_response = self.service_context.llm.predict( self.choice_select_prompt, context_str=fmt_batch_str, query_str=query_str, diff --git a/llama_index/postprocessor/node.py b/llama_index/postprocessor/node.py index a0b8d9b66b..a75ac53c4e 100644 --- a/llama_index/postprocessor/node.py +++ b/llama_index/postprocessor/node.py @@ -268,7 +268,6 @@ class AutoPrevNextNodePostprocessor(BaseNodePostprocessor): Args: docstore (BaseDocumentStore): The document store. - llm_predictor (LLMPredictor): The LLM predictor. num_nodes (int): The number of nodes to return (default: 1) infer_prev_next_tmpl (str): The template to use for inference. Required fields are {context_str} and {query_str}. @@ -319,7 +318,7 @@ class AutoPrevNextNodePostprocessor(BaseNodePostprocessor): all_nodes: Dict[str, NodeWithScore] = {} for node in nodes: all_nodes[node.node.node_id] = node - # use response builder instead of llm_predictor directly + # use response builder instead of llm directly # to be more robust to handling long context response_builder = get_response_synthesizer( service_context=self.service_context, diff --git a/llama_index/postprocessor/pii.py b/llama_index/postprocessor/pii.py index 83eb2d6ada..ae30b6b726 100644 --- a/llama_index/postprocessor/pii.py +++ b/llama_index/postprocessor/pii.py @@ -65,7 +65,7 @@ class PIINodePostprocessor(BaseNodePostprocessor): "Return the mapping in JSON." ) - response = self.service_context.llm_predictor.predict( + response = self.service_context.llm.predict( pii_prompt, context_str=text, query_str=task_str ) splits = response.split("Output Mapping:") diff --git a/llama_index/program/llm_program.py b/llama_index/program/llm_program.py index 13123da24f..878c96c267 100644 --- a/llama_index/program/llm_program.py +++ b/llama_index/program/llm_program.py @@ -1,7 +1,7 @@ from typing import Any, Optional, Type, cast from llama_index.bridge.pydantic import BaseModel -from llama_index.llms.base import LLM +from llama_index.llms.llm import LLM from llama_index.llms.openai import OpenAI from llama_index.output_parsers.pydantic import PydanticOutputParser from llama_index.prompts.base import BasePromptTemplate, PromptTemplate diff --git a/llama_index/program/openai_program.py b/llama_index/program/openai_program.py index e4e2adfc68..764ba93f21 100644 --- a/llama_index/program/openai_program.py +++ b/llama_index/program/openai_program.py @@ -2,7 +2,7 @@ import logging from typing import Any, Dict, Generator, List, Optional, Tuple, Type, Union, cast from llama_index.agent.openai_agent import resolve_tool_choice -from llama_index.llms.base import LLM +from llama_index.llms.llm import LLM from llama_index.llms.openai import OpenAI from llama_index.llms.openai_utils import OpenAIToolCall, to_openai_tool from llama_index.program.llm_prompt_program import BaseLLMFunctionProgram diff --git a/llama_index/program/predefined/evaporate/extractor.py b/llama_index/program/predefined/evaporate/extractor.py index bf8afb7411..8bbc4925ed 100644 --- a/llama_index/program/predefined/evaporate/extractor.py +++ b/llama_index/program/predefined/evaporate/extractor.py @@ -130,8 +130,8 @@ class EvaporateExtractor: """ field2count: dict = defaultdict(int) for node in nodes: - llm_predictor = self._service_context.llm_predictor - result = llm_predictor.predict( + llm = self._service_context.llm + result = llm.predict( self._schema_id_prompt, topic=topic, chunk=node.get_content(metadata_mode=MetadataMode.LLM), diff --git a/llama_index/program/utils.py b/llama_index/program/utils.py index f188f7e5fa..df8393ee40 100644 --- a/llama_index/program/utils.py +++ b/llama_index/program/utils.py @@ -3,7 +3,7 @@ from typing import Any, List, Type from llama_index.bridge.pydantic import BaseModel, Field, create_model -from llama_index.llms.base import LLM +from llama_index.llms.llm import LLM from llama_index.output_parsers.pydantic import PydanticOutputParser from llama_index.prompts.base import PromptTemplate from llama_index.types import BasePydanticProgram, PydanticProgramMode diff --git a/llama_index/prompts/__init__.py b/llama_index/prompts/__init__.py index 3f120cd5a6..3de7bf26dc 100644 --- a/llama_index/prompts/__init__.py +++ b/llama_index/prompts/__init__.py @@ -1,6 +1,6 @@ """Prompt class.""" -from llama_index.llms.base import ChatMessage, MessageRole +from llama_index.llms.types import ChatMessage, MessageRole from llama_index.prompts.base import ( BasePromptTemplate, ChatPromptTemplate, diff --git a/llama_index/prompts/base.py b/llama_index/prompts/base.py index 86fde16447..18f0532b25 100644 --- a/llama_index/prompts/base.py +++ b/llama_index/prompts/base.py @@ -3,7 +3,7 @@ from abc import ABC, abstractmethod from copy import deepcopy -from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Tuple +from typing import TYPE_CHECKING, Any, Callable, Dict, List, Optional, Sequence, Tuple from llama_index.bridge.pydantic import Field @@ -13,8 +13,14 @@ if TYPE_CHECKING: ConditionalPromptSelector as LangchainSelector, ) from llama_index.bridge.pydantic import BaseModel -from llama_index.llms.base import LLM, ChatMessage -from llama_index.llms.generic_utils import messages_to_prompt, prompt_to_messages +from llama_index.llms.base import BaseLLM +from llama_index.llms.generic_utils import ( + messages_to_prompt as default_messages_to_prompt, +) +from llama_index.llms.generic_utils import ( + prompt_to_messages, +) +from llama_index.llms.types import ChatMessage from llama_index.prompts.prompt_type import PromptType from llama_index.prompts.utils import get_template_vars from llama_index.types import BaseOutputParser @@ -88,17 +94,17 @@ class BasePromptTemplate(BaseModel, ABC): ... @abstractmethod - def format(self, llm: Optional[LLM] = None, **kwargs: Any) -> str: + def format(self, llm: Optional[BaseLLM] = None, **kwargs: Any) -> str: ... @abstractmethod def format_messages( - self, llm: Optional[LLM] = None, **kwargs: Any + self, llm: Optional[BaseLLM] = None, **kwargs: Any ) -> List[ChatMessage]: ... @abstractmethod - def get_template(self, llm: Optional[LLM] = None) -> str: + def get_template(self, llm: Optional[BaseLLM] = None) -> str: ... @@ -147,7 +153,12 @@ class PromptTemplate(BasePromptTemplate): self.output_parser = output_parser return prompt - def format(self, llm: Optional[LLM] = None, **kwargs: Any) -> str: + def format( + self, + llm: Optional[BaseLLM] = None, + completion_to_prompt: Optional[Callable[[str], str]] = None, + **kwargs: Any, + ) -> str: """Format the prompt into a string.""" del llm # unused all_kwargs = { @@ -157,19 +168,24 @@ class PromptTemplate(BasePromptTemplate): mapped_all_kwargs = self._map_all_vars(all_kwargs) prompt = self.template.format(**mapped_all_kwargs) + if self.output_parser is not None: prompt = self.output_parser.format(prompt) + + if completion_to_prompt is not None: + prompt = completion_to_prompt(prompt) + return prompt def format_messages( - self, llm: Optional[LLM] = None, **kwargs: Any + self, llm: Optional[BaseLLM] = None, **kwargs: Any ) -> List[ChatMessage]: """Format the prompt into a list of chat messages.""" del llm # unused prompt = self.format(**kwargs) return prompt_to_messages(prompt) - def get_template(self, llm: Optional[LLM] = None) -> str: + def get_template(self, llm: Optional[BaseLLM] = None) -> str: return self.template @@ -209,13 +225,22 @@ class ChatPromptTemplate(BasePromptTemplate): prompt.kwargs.update(kwargs) return prompt - def format(self, llm: Optional[LLM] = None, **kwargs: Any) -> str: + def format( + self, + llm: Optional[BaseLLM] = None, + messages_to_prompt: Optional[Callable[[Sequence[ChatMessage]], str]] = None, + **kwargs: Any, + ) -> str: del llm # unused messages = self.format_messages(**kwargs) - return messages_to_prompt(messages) + + if messages_to_prompt is not None: + return messages_to_prompt(messages) + + return default_messages_to_prompt(messages) def format_messages( - self, llm: Optional[LLM] = None, **kwargs: Any + self, llm: Optional[BaseLLM] = None, **kwargs: Any ) -> List[ChatMessage]: del llm # unused """Format the prompt into a list of chat messages.""" @@ -245,21 +270,21 @@ class ChatPromptTemplate(BasePromptTemplate): return messages - def get_template(self, llm: Optional[LLM] = None) -> str: - return messages_to_prompt(self.message_templates) + def get_template(self, llm: Optional[BaseLLM] = None) -> str: + return default_messages_to_prompt(self.message_templates) class SelectorPromptTemplate(BasePromptTemplate): default_template: BasePromptTemplate conditionals: Optional[ - List[Tuple[Callable[[LLM], bool], BasePromptTemplate]] + List[Tuple[Callable[[BaseLLM], bool], BasePromptTemplate]] ] = None def __init__( self, default_template: BasePromptTemplate, conditionals: Optional[ - List[Tuple[Callable[[LLM], bool], BasePromptTemplate]] + List[Tuple[Callable[[BaseLLM], bool], BasePromptTemplate]] ] = None, ): metadata = default_template.metadata @@ -275,7 +300,7 @@ class SelectorPromptTemplate(BasePromptTemplate): output_parser=output_parser, ) - def select(self, llm: Optional[LLM] = None) -> BasePromptTemplate: + def select(self, llm: Optional[BaseLLM] = None) -> BasePromptTemplate: # ensure output parser is up to date self.default_template.output_parser = self.output_parser @@ -304,19 +329,19 @@ class SelectorPromptTemplate(BasePromptTemplate): default_template=default_template, conditionals=conditionals ) - def format(self, llm: Optional[LLM] = None, **kwargs: Any) -> str: + def format(self, llm: Optional[BaseLLM] = None, **kwargs: Any) -> str: """Format the prompt into a string.""" prompt = self.select(llm=llm) return prompt.format(**kwargs) def format_messages( - self, llm: Optional[LLM] = None, **kwargs: Any + self, llm: Optional[BaseLLM] = None, **kwargs: Any ) -> List[ChatMessage]: """Format the prompt into a list of chat messages.""" prompt = self.select(llm=llm) return prompt.format_messages(**kwargs) - def get_template(self, llm: Optional[LLM] = None) -> str: + def get_template(self, llm: Optional[BaseLLM] = None) -> str: prompt = self.select(llm=llm) return prompt.get_template(llm=llm) @@ -392,7 +417,7 @@ class LangchainPromptTemplate(BasePromptTemplate): lc_prompt.selector = lc_selector return lc_prompt - def format(self, llm: Optional[LLM] = None, **kwargs: Any) -> str: + def format(self, llm: Optional[BaseLLM] = None, **kwargs: Any) -> str: """Format the prompt into a string.""" from llama_index.llms.langchain import LangChainLLM @@ -414,7 +439,7 @@ class LangchainPromptTemplate(BasePromptTemplate): return lc_template.format(**mapped_kwargs) def format_messages( - self, llm: Optional[LLM] = None, **kwargs: Any + self, llm: Optional[BaseLLM] = None, **kwargs: Any ) -> List[ChatMessage]: """Format the prompt into a list of chat messages.""" from llama_index.llms.langchain import LangChainLLM @@ -439,7 +464,7 @@ class LangchainPromptTemplate(BasePromptTemplate): lc_messages = lc_prompt_value.to_messages() return from_lc_messages(lc_messages) - def get_template(self, llm: Optional[LLM] = None) -> str: + def get_template(self, llm: Optional[BaseLLM] = None) -> str: from llama_index.llms.langchain import LangChainLLM if llm is not None: diff --git a/llama_index/prompts/chat_prompts.py b/llama_index/prompts/chat_prompts.py index d6d41afc7e..3fb8551031 100644 --- a/llama_index/prompts/chat_prompts.py +++ b/llama_index/prompts/chat_prompts.py @@ -1,6 +1,6 @@ """Prompts for ChatGPT.""" -from llama_index.llms.base import ChatMessage, MessageRole +from llama_index.llms.types import ChatMessage, MessageRole from llama_index.prompts.base import ChatPromptTemplate # text qa prompt diff --git a/llama_index/prompts/lmformatenforcer_utils.py b/llama_index/prompts/lmformatenforcer_utils.py index 8b9f25d70c..c34618dc6d 100644 --- a/llama_index/prompts/lmformatenforcer_utils.py +++ b/llama_index/prompts/lmformatenforcer_utils.py @@ -1,9 +1,9 @@ from contextlib import contextmanager from typing import TYPE_CHECKING, Callable, Iterator -from llama_index.llms.base import LLM from llama_index.llms.huggingface import HuggingFaceLLM from llama_index.llms.llama_cpp import LlamaCPP +from llama_index.llms.llm import LLM if TYPE_CHECKING: from lmformatenforcer import CharacterLevelParser diff --git a/llama_index/prompts/utils.py b/llama_index/prompts/utils.py index b81c8c6496..bf129ea8c6 100644 --- a/llama_index/prompts/utils.py +++ b/llama_index/prompts/utils.py @@ -1,7 +1,7 @@ from string import Formatter from typing import List -from llama_index.llms.base import LLM +from llama_index.llms.base import BaseLLM def get_template_vars(template_str: str) -> List[str]: @@ -16,5 +16,5 @@ def get_template_vars(template_str: str) -> List[str]: return variables -def is_chat_model(llm: LLM) -> bool: +def is_chat_model(llm: BaseLLM) -> bool: return llm.metadata.is_chat_model diff --git a/llama_index/query_engine/flare/answer_inserter.py b/llama_index/query_engine/flare/answer_inserter.py index 1d996e6109..2434b910e1 100644 --- a/llama_index/query_engine/flare/answer_inserter.py +++ b/llama_index/query_engine/flare/answer_inserter.py @@ -173,7 +173,7 @@ class LLMLookaheadAnswerInserter(BaseLookaheadAnswerInserter): for query_task, answer in zip(query_tasks, answers): query_answer_pairs += f"Query: {query_task.query_str}\nAnswer: {answer}\n" - return self._service_context.llm_predictor.predict( + return self._service_context.llm.predict( self._answer_insert_prompt, lookahead_response=response, query_answer_pairs=query_answer_pairs, diff --git a/llama_index/query_engine/flare/base.py b/llama_index/query_engine/flare/base.py index bd429b75a5..c83473f8e2 100644 --- a/llama_index/query_engine/flare/base.py +++ b/llama_index/query_engine/flare/base.py @@ -193,7 +193,7 @@ class FLAREInstructQueryEngine(BaseQueryEngine): # e.g. # The colors on the flag of Ghana have the following meanings. Red is # for [Search(Ghana flag meaning)],... - lookahead_resp = self._service_context.llm_predictor.predict( + lookahead_resp = self._service_context.llm.predict( self._instruct_prompt, query_str=query_bundle.query_str, existing_answer=cur_response, diff --git a/llama_index/query_engine/knowledge_graph_query_engine.py b/llama_index/query_engine/knowledge_graph_query_engine.py index afee41f0f5..0c156d9736 100644 --- a/llama_index/query_engine/knowledge_graph_query_engine.py +++ b/llama_index/query_engine/knowledge_graph_query_engine.py @@ -183,7 +183,7 @@ class KnowledgeGraphQueryEngine(BaseQueryEngine): """Generate a Graph Store Query from a query bundle.""" # Get the query engine query string - graph_store_query: str = self._service_context.llm_predictor.predict( + graph_store_query: str = self._service_context.llm.predict( self._graph_query_synthesis_prompt, query_str=query_str, schema=self._graph_schema, @@ -195,7 +195,7 @@ class KnowledgeGraphQueryEngine(BaseQueryEngine): """Generate a Graph Store Query from a query bundle.""" # Get the query engine query string - graph_store_query: str = await self._service_context.llm_predictor.apredict( + graph_store_query: str = await self._service_context.llm.apredict( self._graph_query_synthesis_prompt, query_str=query_str, schema=self._graph_schema, diff --git a/llama_index/query_engine/pandas_query_engine.py b/llama_index/query_engine/pandas_query_engine.py index 862dd5f4ed..a6ebb95307 100644 --- a/llama_index/query_engine/pandas_query_engine.py +++ b/llama_index/query_engine/pandas_query_engine.py @@ -157,7 +157,7 @@ class PandasQueryEngine(BaseQueryEngine): """Answer a query.""" context = self._get_table_context() - pandas_response_str = self._service_context.llm_predictor.predict( + pandas_response_str = self._service_context.llm.predict( self._pandas_prompt, df_str=context, query_str=query_bundle.query_str, diff --git a/llama_index/query_engine/sql_join_query_engine.py b/llama_index/query_engine/sql_join_query_engine.py index 30d06ee2b5..faf5821fad 100644 --- a/llama_index/query_engine/sql_join_query_engine.py +++ b/llama_index/query_engine/sql_join_query_engine.py @@ -10,8 +10,8 @@ from llama_index.indices.struct_store.sql_query import ( BaseSQLTableQueryEngine, NLSQLTableQueryEngine, ) -from llama_index.llm_predictor import LLMPredictor -from llama_index.llm_predictor.base import BaseLLMPredictor +from llama_index.llm_predictor.base import LLMPredictorType +from llama_index.llms.utils import resolve_llm from llama_index.prompts.base import BasePromptTemplate, PromptTemplate from llama_index.prompts.mixin import PromptDictType, PromptMixinType from llama_index.response.schema import RESPONSE_TYPE, Response @@ -110,7 +110,7 @@ class SQLAugmentQueryTransform(BaseQueryTransform): after augmenting with SQL results. Args: - llm_predictor (LLMPredictor): LLM predictor to use for query transformation. + llm (LLM): LLM to use for query transformation. sql_augment_transform_prompt (BasePromptTemplate): PromptTemplate to use for query transformation. check_stop_parser (Optional[Callable[[str], bool]]): Check stop function. @@ -119,12 +119,12 @@ class SQLAugmentQueryTransform(BaseQueryTransform): def __init__( self, - llm_predictor: Optional[BaseLLMPredictor] = None, + llm: Optional[LLMPredictorType] = None, sql_augment_transform_prompt: Optional[BasePromptTemplate] = None, check_stop_parser: Optional[Callable[[QueryBundle], bool]] = None, ) -> None: """Initialize params.""" - self._llm_predictor = llm_predictor or LLMPredictor() + self._llm = llm or resolve_llm("default") self._sql_augment_transform_prompt = ( sql_augment_transform_prompt or DEFAULT_SQL_AUGMENT_TRANSFORM_PROMPT @@ -145,7 +145,7 @@ class SQLAugmentQueryTransform(BaseQueryTransform): query_str = query_bundle.query_str sql_query = metadata["sql_query"] sql_query_response = metadata["sql_query_response"] - new_query_str = self._llm_predictor.predict( + new_query_str = self._llm.predict( self._sql_augment_transform_prompt, query_str=query_str, sql_query_str=sql_query, @@ -224,9 +224,7 @@ class SQLJoinQueryEngine(BaseQueryEngine): ) self._sql_augment_query_transform = ( sql_augment_query_transform - or SQLAugmentQueryTransform( - llm_predictor=self._service_context.llm_predictor - ) + or SQLAugmentQueryTransform(llm=self._service_context.llm) ) self._use_sql_join_synthesis = use_sql_join_synthesis self._verbose = verbose @@ -284,7 +282,7 @@ class SQLJoinQueryEngine(BaseQueryEngine): print_text(f"query engine response: {other_response}\n", color="pink") logger.info(f"> query engine response: {other_response}") - response_str = self._service_context.llm_predictor.predict( + response_str = self._service_context.llm.predict( self._sql_join_synthesis_prompt, query_str=query_bundle.query_str, sql_query_str=sql_query, diff --git a/llama_index/question_gen/llm_generators.py b/llama_index/question_gen/llm_generators.py index 18b68fd9fa..068cfabc46 100644 --- a/llama_index/question_gen/llm_generators.py +++ b/llama_index/question_gen/llm_generators.py @@ -1,6 +1,6 @@ from typing import List, Optional, Sequence, cast -from llama_index.llm_predictor.base import BaseLLMPredictor +from llama_index.llm_predictor.base import LLMPredictorType from llama_index.output_parsers.base import StructuredOutput from llama_index.prompts.base import BasePromptTemplate, PromptTemplate from llama_index.prompts.mixin import PromptDictType @@ -20,10 +20,10 @@ from llama_index.types import BaseOutputParser class LLMQuestionGenerator(BaseQuestionGenerator): def __init__( self, - llm_predictor: BaseLLMPredictor, + llm: LLMPredictorType, prompt: BasePromptTemplate, ) -> None: - self._llm_predictor = llm_predictor + self._llm = llm self._prompt = prompt if self._prompt.output_parser is None: @@ -47,7 +47,7 @@ class LLMQuestionGenerator(BaseQuestionGenerator): output_parser=output_parser, prompt_type=PromptType.SUB_QUESTION, ) - return cls(service_context.llm_predictor, prompt) + return cls(service_context.llm, prompt) def _get_prompts(self) -> PromptDictType: """Get prompts.""" @@ -63,7 +63,7 @@ class LLMQuestionGenerator(BaseQuestionGenerator): ) -> List[SubQuestion]: tools_str = build_tools_text(tools) query_str = query.query_str - prediction = self._llm_predictor.predict( + prediction = self._llm.predict( prompt=self._prompt, tools_str=tools_str, query_str=query_str, @@ -79,7 +79,7 @@ class LLMQuestionGenerator(BaseQuestionGenerator): ) -> List[SubQuestion]: tools_str = build_tools_text(tools) query_str = query.query_str - prediction = await self._llm_predictor.apredict( + prediction = await self._llm.apredict( prompt=self._prompt, tools_str=tools_str, query_str=query_str, diff --git a/llama_index/question_gen/openai_generator.py b/llama_index/question_gen/openai_generator.py index c461d32c1c..8b96bbaf90 100644 --- a/llama_index/question_gen/openai_generator.py +++ b/llama_index/question_gen/openai_generator.py @@ -1,6 +1,6 @@ from typing import List, Optional, Sequence, cast -from llama_index.llms.base import LLM +from llama_index.llms.llm import LLM from llama_index.llms.openai import OpenAI from llama_index.program.openai_program import OpenAIPydanticProgram from llama_index.prompts.mixin import PromptDictType diff --git a/llama_index/response_synthesizers/accumulate.py b/llama_index/response_synthesizers/accumulate.py index b56fc21f2d..6ff0fd357e 100644 --- a/llama_index/response_synthesizers/accumulate.py +++ b/llama_index/response_synthesizers/accumulate.py @@ -1,5 +1,5 @@ import asyncio -from typing import Any, List, Optional, Sequence +from typing import Any, Callable, List, Optional, Sequence from llama_index.async_utils import run_async_tasks from llama_index.prompts import BasePromptTemplate @@ -114,18 +114,35 @@ class Accumulate(BaseSynthesizer): text_qa_template, [text_chunk] ) - predictor = ( - self._service_context.llm_predictor.apredict - if use_async - else self._service_context.llm_predictor.predict - ) + predictor: Callable + if self._output_cls is None: + predictor = ( + self._service_context.llm.apredict + if use_async + else self._service_context.llm.predict + ) - return [ - predictor( - text_qa_template, - context_str=cur_text_chunk, - output_cls=self._output_cls, - **response_kwargs, + return [ + predictor( + text_qa_template, + context_str=cur_text_chunk, + **response_kwargs, + ) + for cur_text_chunk in text_chunks + ] + else: + predictor = ( + self._service_context.llm.astructured_predict + if use_async + else self._service_context.llm.structured_predict ) - for cur_text_chunk in text_chunks - ] + + return [ + predictor( + self._output_cls, + text_qa_template, + context_str=cur_text_chunk, + **response_kwargs, + ) + for cur_text_chunk in text_chunks + ] diff --git a/llama_index/response_synthesizers/generation.py b/llama_index/response_synthesizers/generation.py index 825c282f75..128b958b09 100644 --- a/llama_index/response_synthesizers/generation.py +++ b/llama_index/response_synthesizers/generation.py @@ -37,13 +37,13 @@ class Generation(BaseSynthesizer): del text_chunks if not self._streaming: - return await self._service_context.llm_predictor.apredict( + return await self._service_context.llm.apredict( self._input_prompt, query_str=query_str, **response_kwargs, ) else: - return self._service_context.llm_predictor.stream( + return self._service_context.llm.stream( self._input_prompt, query_str=query_str, **response_kwargs, @@ -59,13 +59,13 @@ class Generation(BaseSynthesizer): del text_chunks if not self._streaming: - return self._service_context.llm_predictor.predict( + return self._service_context.llm.predict( self._input_prompt, query_str=query_str, **response_kwargs, ) else: - return self._service_context.llm_predictor.stream( + return self._service_context.llm.stream( self._input_prompt, query_str=query_str, **response_kwargs, diff --git a/llama_index/response_synthesizers/refine.py b/llama_index/response_synthesizers/refine.py index b031a758ed..9b575269f8 100644 --- a/llama_index/response_synthesizers/refine.py +++ b/llama_index/response_synthesizers/refine.py @@ -3,7 +3,7 @@ from typing import Any, Callable, Generator, Optional, Sequence, Type, cast from llama_index.bridge.pydantic import BaseModel, Field, ValidationError from llama_index.indices.utils import truncate_text -from llama_index.llm_predictor.base import BaseLLMPredictor +from llama_index.llm_predictor.base import LLMPredictorType from llama_index.prompts.base import BasePromptTemplate, PromptTemplate from llama_index.prompts.default_prompt_selectors import ( DEFAULT_REFINE_PROMPT_SEL, @@ -41,26 +41,45 @@ class DefaultRefineProgram(BasePydanticProgram): query_satisfied=True. In effect, doesn't do any answer filtering. """ - def __init__(self, prompt: BasePromptTemplate, llm_predictor: BaseLLMPredictor): + def __init__( + self, prompt: BasePromptTemplate, llm: LLMPredictorType, output_cls: BaseModel + ): self._prompt = prompt - self._llm_predictor = llm_predictor + self._llm = llm + self._output_cls = output_cls @property def output_cls(self) -> Type[BaseModel]: return StructuredRefineResponse def __call__(self, *args: Any, **kwds: Any) -> StructuredRefineResponse: - answer = self._llm_predictor.predict( - self._prompt, - **kwds, - ) + if self._output_cls is not None: + answer = self._llm.structured_predict( + self._output_cls, + self._prompt, + **kwds, + ) + answer = answer.json() + else: + answer = self._llm.predict( + self._prompt, + **kwds, + ) return StructuredRefineResponse(answer=answer, query_satisfied=True) async def acall(self, *args: Any, **kwds: Any) -> StructuredRefineResponse: - answer = await self._llm_predictor.apredict( - self._prompt, - **kwds, - ) + if self._output_cls is not None: + answer = await self._llm.astructured_predict( + self._output_cls, + self._prompt, + **kwds, + ) + answer = answer.json() + else: + answer = await self._llm.apredict( + self._prompt, + **kwds, + ) return StructuredRefineResponse(answer=answer, query_satisfied=True) @@ -155,7 +174,8 @@ class Refine(BaseSynthesizer): else: return DefaultRefineProgram( prompt=prompt, - llm_predictor=self._service_context.llm_predictor, + llm=self._service_context.llm, + output_cls=self._output_cls, ) def _give_response_single( @@ -181,7 +201,6 @@ class Refine(BaseSynthesizer): StructuredRefineResponse, program( context_str=cur_text_chunk, - output_cls=self._output_cls, **response_kwargs, ), ) @@ -193,10 +212,9 @@ class Refine(BaseSynthesizer): f"Validation error on structured response: {e}", exc_info=True ) elif response is None and self._streaming: - response = self._service_context.llm_predictor.stream( + response = self._service_context.llm.stream( text_qa_template, context_str=cur_text_chunk, - output_cls=self._output_cls, **response_kwargs, ) query_satisfied = True @@ -265,7 +283,6 @@ class Refine(BaseSynthesizer): StructuredRefineResponse, program( context_msg=cur_text_chunk, - output_cls=self._output_cls, **response_kwargs, ), ) @@ -285,10 +302,9 @@ class Refine(BaseSynthesizer): query_str=query_str, existing_answer=response ) - response = self._service_context.llm_predictor.stream( + response = self._service_context.llm.stream( refine_template, context_msg=cur_text_chunk, - output_cls=self._output_cls, **response_kwargs, ) @@ -371,7 +387,6 @@ class Refine(BaseSynthesizer): try: structured_response = await program.acall( context_msg=cur_text_chunk, - output_cls=self._output_cls, **response_kwargs, ) structured_response = cast( @@ -414,7 +429,6 @@ class Refine(BaseSynthesizer): try: structured_response = await program.acall( context_str=cur_text_chunk, - output_cls=self._output_cls, **response_kwargs, ) structured_response = cast( diff --git a/llama_index/response_synthesizers/simple_summarize.py b/llama_index/response_synthesizers/simple_summarize.py index 0930729a27..07f5d2db4a 100644 --- a/llama_index/response_synthesizers/simple_summarize.py +++ b/llama_index/response_synthesizers/simple_summarize.py @@ -42,13 +42,13 @@ class SimpleSummarize(BaseSynthesizer): response: RESPONSE_TEXT_TYPE if not self._streaming: - response = await self._service_context.llm_predictor.apredict( + response = await self._service_context.llm.apredict( text_qa_template, context_str=node_text, **response_kwargs, ) else: - response = self._service_context.llm_predictor.stream( + response = self._service_context.llm.stream( text_qa_template, context_str=node_text, **response_kwargs, @@ -76,13 +76,13 @@ class SimpleSummarize(BaseSynthesizer): response: RESPONSE_TEXT_TYPE if not self._streaming: - response = self._service_context.llm_predictor.predict( + response = self._service_context.llm.predict( text_qa_template, context_str=node_text, **kwargs, ) else: - response = self._service_context.llm_predictor.stream( + response = self._service_context.llm.stream( text_qa_template, context_str=node_text, **kwargs, diff --git a/llama_index/response_synthesizers/tree_summarize.py b/llama_index/response_synthesizers/tree_summarize.py index 773726d70f..b85a80349f 100644 --- a/llama_index/response_synthesizers/tree_summarize.py +++ b/llama_index/response_synthesizers/tree_summarize.py @@ -1,5 +1,5 @@ import asyncio -from typing import Any, List, Optional, Sequence +from typing import Any, Optional, Sequence from llama_index.async_utils import run_async_tasks from llama_index.prompts import BasePromptTemplate @@ -70,37 +70,54 @@ class TreeSummarize(BaseSynthesizer): if len(text_chunks) == 1: response: RESPONSE_TEXT_TYPE if self._streaming: - response = self._service_context.llm_predictor.stream( + response = self._service_context.llm.stream( summary_template, context_str=text_chunks[0], **response_kwargs ) else: - response = await self._service_context.llm_predictor.apredict( - summary_template, - output_cls=self._output_cls, - context_str=text_chunks[0], - **response_kwargs, - ) + if self._output_cls is None: + response = await self._service_context.llm.apredict( + summary_template, + context_str=text_chunks[0], + **response_kwargs, + ) + else: + response = await self._service_context.llm.astructured_predict( + self._output_cls, + summary_template, + context_str=text_chunks[0], + **response_kwargs, + ) # return pydantic object if output_cls is specified - return ( - response - if self._output_cls is None - else self._output_cls.parse_raw(response) - ) + return response else: # summarize each chunk - tasks = [ - self._service_context.llm_predictor.apredict( - summary_template, - output_cls=self._output_cls, - context_str=text_chunk, - **response_kwargs, - ) - for text_chunk in text_chunks - ] + if self._output_cls is None: + tasks = [ + self._service_context.llm.apredict( + summary_template, + context_str=text_chunk, + **response_kwargs, + ) + for text_chunk in text_chunks + ] + else: + tasks = [ + self._service_context.llm.astructured_predict( + self._output_cls, + summary_template, + context_str=text_chunk, + **response_kwargs, + ) + for text_chunk in text_chunks + ] - summaries: List[str] = await asyncio.gather(*tasks) + summary_responses = await asyncio.gather(*tasks) + if self._output_cls is not None: + summaries = [summary.json() for summary in summary_responses] + else: + summaries = summary_responses # recursively summarize the summaries return await self.aget_response( @@ -129,48 +146,76 @@ class TreeSummarize(BaseSynthesizer): if len(text_chunks) == 1: response: RESPONSE_TEXT_TYPE if self._streaming: - response = self._service_context.llm_predictor.stream( + response = self._service_context.llm.stream( summary_template, context_str=text_chunks[0], **response_kwargs ) else: - response = self._service_context.llm_predictor.predict( - summary_template, - output_cls=self._output_cls, - context_str=text_chunks[0], - **response_kwargs, - ) - - # return pydantic object if output_cls is specified - return ( - response - if self._output_cls is None - else self._output_cls.parse_raw(response) - ) - - else: - # summarize each chunk - if self._use_async: - tasks = [ - self._service_context.llm_predictor.apredict( + if self._output_cls is None: + response = self._service_context.llm.predict( summary_template, - output_cls=self._output_cls, - context_str=text_chunk, + context_str=text_chunks[0], **response_kwargs, ) - for text_chunk in text_chunks - ] - - summaries: List[str] = run_async_tasks(tasks) - else: - summaries = [ - self._service_context.llm_predictor.predict( + else: + response = self._service_context.llm.structured_predict( + self._output_cls, summary_template, - output_cls=self._output_cls, - context_str=text_chunk, + context_str=text_chunks[0], **response_kwargs, ) - for text_chunk in text_chunks - ] + + return response + + else: + # summarize each chunk + if self._use_async: + if self._output_cls is None: + tasks = [ + self._service_context.llm.apredict( + summary_template, + context_str=text_chunk, + **response_kwargs, + ) + for text_chunk in text_chunks + ] + else: + tasks = [ + self._service_context.llm.astructured_predict( + self._output_cls, + summary_template, + context_str=text_chunk, + **response_kwargs, + ) + for text_chunk in text_chunks + ] + + summary_responses = run_async_tasks(tasks) + + if self._output_cls is not None: + summaries = [summary.json() for summary in summary_responses] + else: + summaries = summary_responses + else: + if self._output_cls is None: + summaries = [ + self._service_context.llm.predict( + summary_template, + context_str=text_chunk, + **response_kwargs, + ) + for text_chunk in text_chunks + ] + else: + summaries = [ + self._service_context.llm.structured_predict( + self._output_cls, + summary_template, + context_str=text_chunk, + **response_kwargs, + ) + for text_chunk in text_chunks + ] + summaries = [summary.json() for summary in summaries] # recursively summarize the summaries return self.get_response( diff --git a/llama_index/schema.py b/llama_index/schema.py index d595fe6245..aa6cf7a468 100644 --- a/llama_index/schema.py +++ b/llama_index/schema.py @@ -71,9 +71,11 @@ class BaseComponent(BaseModel): # remove local functions keys_to_remove = [] - for key in state["__dict__"]: + for key, val in state["__dict__"].items(): if key.endswith("_fn"): keys_to_remove.append(key) + if "function <lambda>" in str(val): + keys_to_remove.append(key) for key in keys_to_remove: state["__dict__"].pop(key, None) diff --git a/llama_index/selectors/llm_selectors.py b/llama_index/selectors/llm_selectors.py index e4a2425491..eb3cf6945c 100644 --- a/llama_index/selectors/llm_selectors.py +++ b/llama_index/selectors/llm_selectors.py @@ -1,6 +1,6 @@ from typing import Any, Dict, List, Optional, Sequence, cast -from llama_index.llm_predictor.base import BaseLLMPredictor +from llama_index.llm_predictor.base import LLMPredictorType from llama_index.output_parsers.base import StructuredOutput from llama_index.output_parsers.selection import Answer, SelectionOutputParser from llama_index.prompts.mixin import PromptDictType @@ -47,16 +47,16 @@ class LLMSingleSelector(BaseSelector): LLM-based selector that chooses one out of many options. Args: - llm_predictor (BaseLLMPredictor): An LLM predictor. + LLM (LLM): An LLM. prompt (SingleSelectPrompt): A LLM prompt for selecting one out of many options. """ def __init__( self, - llm_predictor: BaseLLMPredictor, + llm: LLMPredictorType, prompt: SingleSelectPrompt, ) -> None: - self._llm_predictor = llm_predictor + self._llm = llm self._prompt = prompt if self._prompt.output_parser is None: @@ -80,7 +80,7 @@ class LLMSingleSelector(BaseSelector): output_parser=output_parser, prompt_type=PromptType.SINGLE_SELECT, ) - return cls(service_context.llm_predictor, prompt) + return cls(service_context.llm, prompt) def _get_prompts(self) -> Dict[str, Any]: """Get prompts.""" @@ -98,7 +98,7 @@ class LLMSingleSelector(BaseSelector): choices_text = _build_choices_text(choices) # predict - prediction = self._llm_predictor.predict( + prediction = self._llm.predict( prompt=self._prompt, num_choices=len(choices), context_list=choices_text, @@ -117,7 +117,7 @@ class LLMSingleSelector(BaseSelector): choices_text = _build_choices_text(choices) # predict - prediction = await self._llm_predictor.apredict( + prediction = await self._llm.apredict( prompt=self._prompt, num_choices=len(choices), context_list=choices_text, @@ -136,18 +136,18 @@ class LLMMultiSelector(BaseSelector): LLM-based selector that chooses multiple out of many options. Args: - llm_predictor (LLMPredictor): An LLM predictor. + llm (LLM): An LLM. prompt (SingleSelectPrompt): A LLM prompt for selecting multiple out of many options. """ def __init__( self, - llm_predictor: BaseLLMPredictor, + llm: LLMPredictorType, prompt: MultiSelectPrompt, max_outputs: Optional[int] = None, ) -> None: - self._llm_predictor = llm_predictor + self._llm = llm self._prompt = prompt self._max_outputs = max_outputs @@ -175,7 +175,7 @@ class LLMMultiSelector(BaseSelector): output_parser=output_parser, prompt_type=PromptType.MULTI_SELECT, ) - return cls(service_context.llm_predictor, prompt, max_outputs) + return cls(service_context.llm, prompt, max_outputs) def _get_prompts(self) -> Dict[str, Any]: """Get prompts.""" @@ -193,7 +193,7 @@ class LLMMultiSelector(BaseSelector): context_list = _build_choices_text(choices) max_outputs = self._max_outputs or len(choices) - prediction = self._llm_predictor.predict( + prediction = self._llm.predict( prompt=self._prompt, num_choices=len(choices), max_outputs=max_outputs, @@ -212,7 +212,7 @@ class LLMMultiSelector(BaseSelector): context_list = _build_choices_text(choices) max_outputs = self._max_outputs or len(choices) - prediction = await self._llm_predictor.apredict( + prediction = await self._llm.apredict( prompt=self._prompt, num_choices=len(choices), max_outputs=max_outputs, diff --git a/llama_index/selectors/utils.py b/llama_index/selectors/utils.py index c9fbbb5a81..651d7b2aec 100644 --- a/llama_index/selectors/utils.py +++ b/llama_index/selectors/utils.py @@ -17,13 +17,13 @@ def get_selector_from_context( if is_multi: try: - llm = service_context.llm_predictor.llm + llm = service_context.llm selector = PydanticMultiSelector.from_defaults(llm=llm) # type: ignore except ValueError: selector = LLMMultiSelector.from_defaults(service_context=service_context) else: try: - llm = service_context.llm_predictor.llm + llm = service_context.llm selector = PydanticSingleSelector.from_defaults(llm=llm) # type: ignore except ValueError: selector = LLMSingleSelector.from_defaults(service_context=service_context) diff --git a/llama_index/service_context.py b/llama_index/service_context.py index 2a9006d572..c4378f2489 100644 --- a/llama_index/service_context.py +++ b/llama_index/service_context.py @@ -10,7 +10,7 @@ from llama_index.embeddings.utils import EmbedType, resolve_embed_model from llama_index.indices.prompt_helper import PromptHelper from llama_index.llm_predictor import LLMPredictor from llama_index.llm_predictor.base import BaseLLMPredictor, LLMMetadata -from llama_index.llms.base import LLM +from llama_index.llms.llm import LLM from llama_index.llms.utils import LLMType, resolve_llm from llama_index.logger import LlamaLogger from llama_index.node_parser.interface import NodeParser, TextSplitter @@ -164,6 +164,14 @@ class ServiceContext: if llm_predictor is not None: raise ValueError("Cannot specify both llm and llm_predictor") llm = resolve_llm(llm) + llm.system_prompt = llm.system_prompt or system_prompt + llm.query_wrapper_prompt = llm.query_wrapper_prompt or query_wrapper_prompt + llm.pydantic_program_mode = ( + llm.pydantic_program_mode or pydantic_program_mode + ) + + if llm_predictor is not None: + print("LLMPredictor is deprecated, please use LLM instead.") llm_predictor = llm_predictor or LLMPredictor( llm=llm, pydantic_program_mode=pydantic_program_mode ) @@ -311,8 +319,6 @@ class ServiceContext: @property def llm(self) -> LLM: - if not isinstance(self.llm_predictor, LLMPredictor): - raise ValueError("llm_predictor must be an instance of LLMPredictor") return self.llm_predictor.llm @property diff --git a/llama_index/types.py b/llama_index/types.py index c0c9b65efe..e454b18e8e 100644 --- a/llama_index/types.py +++ b/llama_index/types.py @@ -14,7 +14,7 @@ from typing import ( ) from llama_index.bridge.pydantic import BaseModel -from llama_index.llms.base import ChatMessage, MessageRole +from llama_index.llms.types import ChatMessage, MessageRole Model = TypeVar("Model", bound=BaseModel) @@ -33,9 +33,9 @@ class BaseOutputParser(Protocol): def parse(self, output: str) -> Any: """Parse, validate, and correct errors programmatically.""" - @abstractmethod def format(self, query: str) -> str: """Format a query with structured output formatting instructions.""" + return query def format_messages(self, messages: List[ChatMessage]) -> List[ChatMessage]: """Format a list of messages with structured output formatting instructions.""" diff --git a/tests/agent/openai/test_openai_agent.py b/tests/agent/openai/test_openai_agent.py index 36fda875a5..a6859dac74 100644 --- a/tests/agent/openai/test_openai_agent.py +++ b/tests/agent/openai/test_openai_agent.py @@ -4,9 +4,9 @@ from unittest.mock import MagicMock, patch import pytest from llama_index.agent.openai_agent import OpenAIAgent, call_tool_with_error_handling from llama_index.chat_engine.types import AgentChatResponse -from llama_index.llms.base import ChatMessage, ChatResponse from llama_index.llms.mock import MockLLM from llama_index.llms.openai import OpenAI +from llama_index.llms.types import ChatMessage, ChatResponse from llama_index.tools.function_tool import FunctionTool from openai.types.chat.chat_completion import ChatCompletion, Choice from openai.types.chat.chat_completion_message import ChatCompletionMessage diff --git a/tests/agent/react/test_react_agent.py b/tests/agent/react/test_react_agent.py index 75e21f00b8..308d55a966 100644 --- a/tests/agent/react/test_react_agent.py +++ b/tests/agent/react/test_react_agent.py @@ -5,13 +5,13 @@ import pytest from llama_index.agent.react.base import ReActAgent from llama_index.bridge.pydantic import PrivateAttr from llama_index.chat_engine.types import AgentChatResponse, StreamingAgentChatResponse -from llama_index.llms.base import ( +from llama_index.llms.mock import MockLLM +from llama_index.llms.types import ( ChatMessage, ChatResponse, ChatResponseGen, MessageRole, ) -from llama_index.llms.mock import MockLLM from llama_index.tools.function_tool import FunctionTool diff --git a/tests/chat_engine/test_condense_plus_context.py b/tests/chat_engine/test_condense_plus_context.py index 6d61d61642..4f246e437e 100644 --- a/tests/chat_engine/test_condense_plus_context.py +++ b/tests/chat_engine/test_condense_plus_context.py @@ -1,15 +1,24 @@ from typing import Any, List -from unittest.mock import Mock +from unittest.mock import Mock, patch from llama_index.chat_engine.condense_plus_context import CondensePlusContextChatEngine from llama_index.indices.base_retriever import BaseRetriever from llama_index.indices.service_context import ServiceContext -from llama_index.llm_predictor.base import LLMPredictor +from llama_index.llms.mock import MockLLM from llama_index.memory.chat_memory_buffer import ChatMemoryBuffer -from llama_index.prompts.base import BasePromptTemplate +from llama_index.prompts import BasePromptTemplate from llama_index.schema import NodeWithScore, TextNode +def override_predict(self: Any, prompt: BasePromptTemplate, **prompt_args: Any) -> str: + return prompt.format(**prompt_args) + + +@patch.object( + MockLLM, + "predict", + override_predict, +) def test_condense_plus_context_chat_engine( mock_service_context: ServiceContext, ) -> None: @@ -39,13 +48,6 @@ def test_condense_plus_context_chat_engine( mock_retriever.retrieve.side_effect = override_retrieve - mock_llm_predictor = Mock(spec=LLMPredictor) - - def override_predict(prompt: BasePromptTemplate, **prompt_args: Any) -> str: - return prompt.format(**prompt_args) - - mock_llm_predictor.predict.side_effect = override_predict - context_prompt = "Context information: {context_str}" condense_prompt = ( @@ -56,8 +58,7 @@ def test_condense_plus_context_chat_engine( engine = CondensePlusContextChatEngine( retriever=mock_retriever, - llm=mock_service_context.llm, - llm_predictor=mock_llm_predictor, + llm=MockLLM(), memory=ChatMemoryBuffer.from_defaults( chat_history=[], llm=mock_service_context.llm ), diff --git a/tests/chat_engine/test_condense_question.py b/tests/chat_engine/test_condense_question.py index 349fe686b3..fb249a5a3d 100644 --- a/tests/chat_engine/test_condense_question.py +++ b/tests/chat_engine/test_condense_question.py @@ -2,7 +2,7 @@ from unittest.mock import Mock from llama_index.chat_engine.condense_question import CondenseQuestionChatEngine from llama_index.core import BaseQueryEngine -from llama_index.llms.base import ChatMessage, MessageRole +from llama_index.llms.types import ChatMessage, MessageRole from llama_index.response.schema import Response from llama_index.service_context import ServiceContext diff --git a/tests/chat_engine/test_simple.py b/tests/chat_engine/test_simple.py index e84fcdbdde..fa6e191b25 100644 --- a/tests/chat_engine/test_simple.py +++ b/tests/chat_engine/test_simple.py @@ -1,5 +1,5 @@ from llama_index.chat_engine.simple import SimpleChatEngine -from llama_index.llms.base import ChatMessage, MessageRole +from llama_index.llms.types import ChatMessage, MessageRole from llama_index.service_context import ServiceContext diff --git a/tests/conftest.py b/tests/conftest.py index 5791c4d5d1..5d6b5e2d15 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -6,8 +6,8 @@ from typing import Any, List, Optional import openai import pytest from llama_index.llm_predictor.base import LLMPredictor -from llama_index.llms.base import LLMMetadata from llama_index.llms.mock import MockLLM +from llama_index.llms.types import LLMMetadata from llama_index.node_parser.text import SentenceSplitter, TokenTextSplitter from llama_index.service_context import ServiceContext @@ -68,10 +68,27 @@ def patch_llm_predictor(monkeypatch: pytest.MonkeyPatch) -> None: LLMMetadata(), ) + monkeypatch.setattr( + MockLLM, + "predict", + patch_llmpredictor_predict, + ) + monkeypatch.setattr( + MockLLM, + "apredict", + patch_llmpredictor_apredict, + ) + monkeypatch.setattr( + MockLLM, + "metadata", + LLMMetadata(), + ) + @pytest.fixture() def mock_service_context( - patch_token_text_splitter: Any, patch_llm_predictor: Any + patch_token_text_splitter: Any, + patch_llm_predictor: Any, ) -> ServiceContext: return ServiceContext.from_defaults(embed_model=MockEmbedding()) diff --git a/tests/indices/list/test_retrievers.py b/tests/indices/list/test_retrievers.py index 5fcbb38200..903c85a73d 100644 --- a/tests/indices/list/test_retrievers.py +++ b/tests/indices/list/test_retrievers.py @@ -3,7 +3,7 @@ from unittest.mock import patch from llama_index.indices.list.base import SummaryIndex from llama_index.indices.list.retrievers import SummaryIndexEmbeddingRetriever -from llama_index.llm_predictor.base import LLMPredictor +from llama_index.llms.mock import MockLLM from llama_index.prompts import BasePromptTemplate from llama_index.schema import Document from llama_index.service_context import ServiceContext @@ -55,7 +55,7 @@ def mock_llmpredictor_predict( @patch.object( - LLMPredictor, + MockLLM, "predict", mock_llmpredictor_predict, ) diff --git a/tests/indices/query/query_transform/test_base.py b/tests/indices/query/query_transform/test_base.py index 438acd4682..b1c71c0fcf 100644 --- a/tests/indices/query/query_transform/test_base.py +++ b/tests/indices/query/query_transform/test_base.py @@ -11,7 +11,7 @@ def test_decompose_query_transform(mock_service_context: ServiceContext) -> None """Test decompose query transform.""" query_transform = DecomposeQueryTransform( decompose_query_prompt=MOCK_DECOMPOSE_PROMPT, - llm_predictor=mock_service_context.llm_predictor, + llm=mock_service_context.llm, ) query_str = "What is?" diff --git a/tests/indices/response/test_tree_summarize.py b/tests/indices/response/test_tree_summarize.py index 8130df5246..93432e2fa4 100644 --- a/tests/indices/response/test_tree_summarize.py +++ b/tests/indices/response/test_tree_summarize.py @@ -1,11 +1,13 @@ """Test tree summarize.""" -from typing import List, Sequence -from unittest.mock import Mock +from typing import Any, List, Sequence +from unittest.mock import Mock, patch import pytest from llama_index.bridge.pydantic import BaseModel from llama_index.indices.prompt_helper import PromptHelper +from llama_index.llm_predictor import LLMPredictor +from llama_index.llms.mock import MockLLM from llama_index.prompts.base import PromptTemplate from llama_index.prompts.prompt_type import PromptType from llama_index.response_synthesizers import TreeSummarize @@ -53,11 +55,19 @@ def test_tree_summarize(mock_service_context_merge_chunks: ServiceContext) -> No assert str(response) == "Text chunk 1\nText chunk 2\nText chunk 3\nText chunk 4" +class TestModel(BaseModel): + hello: str + + +def mock_return_class(*args: Any, **kwargs: Any) -> TestModel: + return TestModel(hello="Test Chunk 5") + + +@patch.object(MockLLM, "structured_predict", mock_return_class) def test_tree_summarize_output_cls( mock_service_context_merge_chunks: ServiceContext, ) -> None: - class TestModel(BaseModel): - hello: str + mock_service_context_merge_chunks.llm_predictor = LLMPredictor(MockLLM()) mock_summary_prompt_tmpl = "{context_str}{query_str}" mock_summary_prompt = PromptTemplate( @@ -71,9 +81,7 @@ def test_tree_summarize_output_cls( '{"hello":"Test Chunk 3"}', '{"hello":"Test Chunk 4"}', ] - response_rtr = {"hello": "Test Chunk 5"} - TestModel.parse_raw = Mock(name="parse_raw") # type: ignore - TestModel.parse_raw.return_value = response_rtr + response_dict = {"hello": "Test Chunk 5"} # test sync tree_summarize = TreeSummarize( @@ -83,8 +91,8 @@ def test_tree_summarize_output_cls( ) full_response = "\n".join(texts) response = tree_summarize.get_response(text_chunks=texts, query_str=query_str) - TestModel.parse_raw.assert_called_once_with(full_response) - assert response == response_rtr + assert isinstance(response, TestModel) + assert response.dict() == response_dict def test_tree_summarize_use_async( diff --git a/tests/indices/struct_store/test_json_query.py b/tests/indices/struct_store/test_json_query.py index 3b1bc4757c..a84b13ad5c 100644 --- a/tests/indices/struct_store/test_json_query.py +++ b/tests/indices/struct_store/test_json_query.py @@ -2,11 +2,14 @@ import asyncio import json -from typing import Any, Dict, Generator, cast -from unittest.mock import AsyncMock, MagicMock, patch +from typing import Any, Dict, cast +from unittest.mock import patch import pytest from llama_index.indices.struct_store.json_query import JSONQueryEngine, JSONType +from llama_index.llm_predictor import LLMPredictor +from llama_index.llms.mock import MockLLM +from llama_index.prompts.base import BasePromptTemplate from llama_index.response.schema import Response from llama_index.schema import QueryBundle from llama_index.service_context import ServiceContext @@ -21,23 +24,35 @@ TEST_PARAMS = [ TEST_LLM_OUTPUT = "test_llm_output" -@pytest.fixture() -def mock_json_service_ctx( - mock_service_context: ServiceContext, -) -> Generator[ServiceContext, None, None]: - with patch.object(mock_service_context, "llm_predictor") as mock_llm_predictor: - mock_llm_predictor.apredict = AsyncMock(return_value=TEST_LLM_OUTPUT) - mock_llm_predictor.predict = MagicMock(return_value=TEST_LLM_OUTPUT) - yield mock_service_context +def mock_predict(self: Any, prompt: BasePromptTemplate, **prompt_args: Any) -> str: + return TEST_LLM_OUTPUT + + +async def amock_predict( + self: Any, prompt: BasePromptTemplate, **prompt_args: Any +) -> str: + return TEST_LLM_OUTPUT @pytest.mark.parametrize(("synthesize_response", "call_apredict"), TEST_PARAMS) +@patch.object( + MockLLM, + "predict", + mock_predict, +) +@patch.object( + MockLLM, + "apredict", + amock_predict, +) def test_json_query_engine( synthesize_response: bool, call_apredict: bool, - mock_json_service_ctx: ServiceContext, + mock_service_context: ServiceContext, ) -> None: """Test GPTNLJSONQueryEngine.""" + mock_service_context.llm_predictor = LLMPredictor(MockLLM()) + # Test on some sample data json_val = cast(JSONType, {}) json_schema = cast(JSONType, {}) @@ -53,7 +68,7 @@ def test_json_query_engine( query_engine = JSONQueryEngine( json_value=json_val, json_schema=json_schema, - service_context=mock_json_service_ctx, + service_context=mock_service_context, output_processor=test_output_processor, verbose=True, synthesize_response=synthesize_response, diff --git a/tests/llms/test_anthropic.py b/tests/llms/test_anthropic.py index 5366d3fc39..c7386ffbd4 100644 --- a/tests/llms/test_anthropic.py +++ b/tests/llms/test_anthropic.py @@ -1,6 +1,6 @@ import pytest from llama_index.llms.anthropic import Anthropic -from llama_index.llms.base import ChatMessage +from llama_index.llms.types import ChatMessage try: import anthropic diff --git a/tests/llms/test_anthropic_utils.py b/tests/llms/test_anthropic_utils.py index ed3d978568..76b4dce62c 100644 --- a/tests/llms/test_anthropic_utils.py +++ b/tests/llms/test_anthropic_utils.py @@ -3,7 +3,7 @@ from llama_index.llms.anthropic_utils import ( anthropic_modelname_to_contextsize, messages_to_anthropic_prompt, ) -from llama_index.llms.base import ChatMessage, MessageRole +from llama_index.llms.types import ChatMessage, MessageRole def test_messages_to_anthropic_prompt() -> None: diff --git a/tests/llms/test_bedrock.py b/tests/llms/test_bedrock.py index f462ce5b9a..169efc742b 100644 --- a/tests/llms/test_bedrock.py +++ b/tests/llms/test_bedrock.py @@ -5,7 +5,7 @@ from typing import Any, Generator from botocore.response import StreamingBody from botocore.stub import Stubber from llama_index.llms import Bedrock -from llama_index.llms.base import ChatMessage +from llama_index.llms.types import ChatMessage from pytest import MonkeyPatch diff --git a/tests/llms/test_cohere.py b/tests/llms/test_cohere.py index 99d3afe3cc..1d65c83a36 100644 --- a/tests/llms/test_cohere.py +++ b/tests/llms/test_cohere.py @@ -1,7 +1,7 @@ from typing import Any import pytest -from llama_index.llms.base import ChatMessage +from llama_index.llms.types import ChatMessage from pytest import MonkeyPatch try: diff --git a/tests/llms/test_custom.py b/tests/llms/test_custom.py index cd957cb98f..3cd79eca33 100644 --- a/tests/llms/test_custom.py +++ b/tests/llms/test_custom.py @@ -1,12 +1,12 @@ from typing import Any -from llama_index.llms.base import ( +from llama_index.llms.custom import CustomLLM +from llama_index.llms.types import ( ChatMessage, CompletionResponse, CompletionResponseGen, LLMMetadata, ) -from llama_index.llms.custom import CustomLLM class TestLLM(CustomLLM): diff --git a/tests/llms/test_konko.py b/tests/llms/test_konko.py index 019494ac60..8b62dd0b61 100644 --- a/tests/llms/test_konko.py +++ b/tests/llms/test_konko.py @@ -1,8 +1,8 @@ from typing import Any, Generator import pytest -from llama_index.llms.base import ChatMessage from llama_index.llms.konko import Konko +from llama_index.llms.types import ChatMessage from pytest import MonkeyPatch try: diff --git a/tests/llms/test_langchain.py b/tests/llms/test_langchain.py index dae2cd6827..15b1c03f58 100644 --- a/tests/llms/test_langchain.py +++ b/tests/llms/test_langchain.py @@ -1,7 +1,7 @@ from typing import List import pytest -from llama_index.llms.base import ChatMessage, MessageRole +from llama_index.llms.types import ChatMessage, MessageRole try: import cohere diff --git a/tests/llms/test_litellm.py b/tests/llms/test_litellm.py index 3d6643f18d..8786f7b506 100644 --- a/tests/llms/test_litellm.py +++ b/tests/llms/test_litellm.py @@ -6,8 +6,8 @@ except ImportError: litellm = None # type: ignore import pytest -from llama_index.llms.base import ChatMessage from llama_index.llms.litellm import LiteLLM +from llama_index.llms.types import ChatMessage from pytest import MonkeyPatch from tests.conftest import CachedOpenAIApiKeys diff --git a/tests/llms/test_llama_utils.py b/tests/llms/test_llama_utils.py index 84326c8f6b..b8587d7a5d 100644 --- a/tests/llms/test_llama_utils.py +++ b/tests/llms/test_llama_utils.py @@ -1,7 +1,6 @@ from typing import Sequence import pytest -from llama_index.llms.base import ChatMessage, MessageRole from llama_index.llms.llama_utils import ( B_INST, B_SYS, @@ -13,6 +12,7 @@ from llama_index.llms.llama_utils import ( completion_to_prompt, messages_to_prompt, ) +from llama_index.llms.types import ChatMessage, MessageRole @pytest.fixture() diff --git a/tests/llms/test_localai.py b/tests/llms/test_localai.py index e8e93e7ff9..eda548c0ab 100644 --- a/tests/llms/test_localai.py +++ b/tests/llms/test_localai.py @@ -2,7 +2,7 @@ from unittest.mock import MagicMock, patch import pytest from llama_index.llms import LocalAI -from llama_index.llms.base import ChatMessage +from llama_index.llms.types import ChatMessage from openai.types import Completion, CompletionChoice from openai.types.chat.chat_completion import ChatCompletion, Choice from openai.types.chat.chat_completion_message import ChatCompletionMessage diff --git a/tests/llms/test_openai.py b/tests/llms/test_openai.py index 9921739e81..ebc42b2092 100644 --- a/tests/llms/test_openai.py +++ b/tests/llms/test_openai.py @@ -3,8 +3,8 @@ from typing import Any, AsyncGenerator, Generator from unittest.mock import AsyncMock, MagicMock, patch import pytest -from llama_index.llms.base import ChatMessage from llama_index.llms.openai import OpenAI +from llama_index.llms.types import ChatMessage from openai.types.chat.chat_completion import ( ChatCompletion, ChatCompletionMessage, diff --git a/tests/llms/test_openai_like.py b/tests/llms/test_openai_like.py index 463a2d97b9..99a96f6f44 100644 --- a/tests/llms/test_openai_like.py +++ b/tests/llms/test_openai_like.py @@ -2,8 +2,8 @@ from typing import List from unittest.mock import MagicMock, call, patch from llama_index.llms import LOCALAI_DEFAULTS, OpenAILike -from llama_index.llms.base import ChatMessage, MessageRole from llama_index.llms.openai import Tokenizer +from llama_index.llms.types import ChatMessage, MessageRole from openai.types import Completion, CompletionChoice from openai.types.chat.chat_completion import ChatCompletion, Choice from openai.types.chat.chat_completion_message import ChatCompletionMessage diff --git a/tests/llms/test_openai_utils.py b/tests/llms/test_openai_utils.py index d5dcd76c28..9a7005f5c4 100644 --- a/tests/llms/test_openai_utils.py +++ b/tests/llms/test_openai_utils.py @@ -1,12 +1,12 @@ from typing import List import pytest -from llama_index.llms.base import ChatMessage, MessageRole from llama_index.llms.openai_utils import ( from_openai_message_dicts, from_openai_messages, to_openai_message_dicts, ) +from llama_index.llms.types import ChatMessage, MessageRole from openai.types.chat.chat_completion_assistant_message_param import ( FunctionCall as FunctionCallParam, ) diff --git a/tests/llms/test_palm.py b/tests/llms/test_palm.py index 6068008b59..c145f24ba6 100644 --- a/tests/llms/test_palm.py +++ b/tests/llms/test_palm.py @@ -36,8 +36,8 @@ sys.modules["google.generativeai"] = MockPalmPackage() from typing import Any -from llama_index.llms.base import CompletionResponse from llama_index.llms.palm import PaLM +from llama_index.llms.types import CompletionResponse @pytest.mark.skipif( diff --git a/tests/llms/test_rungpt.py b/tests/llms/test_rungpt.py index ea286272a2..475e719c8a 100644 --- a/tests/llms/test_rungpt.py +++ b/tests/llms/test_rungpt.py @@ -2,11 +2,11 @@ from typing import Any, Dict, Generator, List from unittest.mock import MagicMock, patch import pytest -from llama_index.llms.base import ( +from llama_index.llms.rungpt import RunGptLLM +from llama_index.llms.types import ( ChatMessage, MessageRole, ) -from llama_index.llms.rungpt import RunGptLLM try: import sseclient diff --git a/tests/llms/test_vertex.py b/tests/llms/test_vertex.py index 2f93a47f63..4d6f006863 100644 --- a/tests/llms/test_vertex.py +++ b/tests/llms/test_vertex.py @@ -1,5 +1,5 @@ import pytest -from llama_index.llms.base import CompletionResponse +from llama_index.llms.types import CompletionResponse from llama_index.llms.vertex import Vertex from llama_index.llms.vertex_utils import init_vertexai diff --git a/tests/llms/test_watsonx.py b/tests/llms/test_watsonx.py index 0176e501b9..990028bc91 100644 --- a/tests/llms/test_watsonx.py +++ b/tests/llms/test_watsonx.py @@ -3,7 +3,7 @@ from typing import Any, Dict, Generator, Optional from unittest.mock import MagicMock import pytest -from llama_index.llms.base import ChatMessage +from llama_index.llms.types import ChatMessage try: import ibm_watson_machine_learning diff --git a/tests/llms/test_xinference.py b/tests/llms/test_xinference.py index 9787961140..3c2000746a 100644 --- a/tests/llms/test_xinference.py +++ b/tests/llms/test_xinference.py @@ -1,7 +1,7 @@ from typing import Any, Dict, Generator, Iterator, List, Mapping, Sequence, Tuple, Union import pytest -from llama_index.llms.base import ( +from llama_index.llms.types import ( ChatMessage, ChatResponse, CompletionResponse, diff --git a/tests/postprocessor/test_llm_rerank.py b/tests/postprocessor/test_llm_rerank.py index 07a79438d8..f50216c575 100644 --- a/tests/postprocessor/test_llm_rerank.py +++ b/tests/postprocessor/test_llm_rerank.py @@ -3,7 +3,7 @@ from typing import Any, List from unittest.mock import patch -from llama_index.llm_predictor import LLMPredictor +from llama_index.llms.mock import MockLLM from llama_index.postprocessor.llm_rerank import LLMRerank from llama_index.prompts import BasePromptTemplate from llama_index.schema import BaseNode, NodeWithScore, QueryBundle, TextNode @@ -42,7 +42,7 @@ def mock_format_node_batch_fn(nodes: List[BaseNode]) -> str: @patch.object( - LLMPredictor, + MockLLM, "predict", mock_llmpredictor_predict, ) diff --git a/tests/program/test_llm_program.py b/tests/program/test_llm_program.py index 5b03af7117..ae8d4dcab1 100644 --- a/tests/program/test_llm_program.py +++ b/tests/program/test_llm_program.py @@ -4,7 +4,7 @@ import json from unittest.mock import MagicMock from llama_index.bridge.pydantic import BaseModel -from llama_index.llms.base import ( +from llama_index.llms.types import ( ChatMessage, ChatResponse, CompletionResponse, diff --git a/tests/program/test_lmformatenforcer.py b/tests/program/test_lmformatenforcer.py index daecf975fd..9b9468c3d6 100644 --- a/tests/program/test_lmformatenforcer.py +++ b/tests/program/test_lmformatenforcer.py @@ -3,8 +3,8 @@ from unittest.mock import MagicMock import pytest from llama_index.bridge.pydantic import BaseModel -from llama_index.llms.base import CompletionResponse from llama_index.llms.huggingface import HuggingFaceLLM +from llama_index.llms.types import CompletionResponse from llama_index.program.lmformatenforcer_program import LMFormatEnforcerPydanticProgram has_lmformatenforcer = find_spec("lmformatenforcer") is not None diff --git a/tests/program/test_multi_modal_llm_program.py b/tests/program/test_multi_modal_llm_program.py index c0f7d1977a..7d1fe9b848 100644 --- a/tests/program/test_multi_modal_llm_program.py +++ b/tests/program/test_multi_modal_llm_program.py @@ -5,7 +5,7 @@ from typing import Sequence from unittest.mock import MagicMock from llama_index.bridge.pydantic import BaseModel -from llama_index.llms.base import ( +from llama_index.llms.types import ( CompletionResponse, ) from llama_index.multi_modal_llms import MultiModalLLMMetadata diff --git a/tests/prompts/test_base.py b/tests/prompts/test_base.py index ec4b9f1e90..00993b6f61 100644 --- a/tests/prompts/test_base.py +++ b/tests/prompts/test_base.py @@ -5,7 +5,7 @@ from typing import Any import pytest from llama_index.llms import MockLLM -from llama_index.llms.base import ChatMessage, MessageRole +from llama_index.llms.types import ChatMessage, MessageRole from llama_index.prompts import ( ChatPromptTemplate, LangchainPromptTemplate, diff --git a/tests/query_engine/test_retriever_query_engine.py b/tests/query_engine/test_retriever_query_engine.py index eedb032fec..431837d9cb 100644 --- a/tests/query_engine/test_retriever_query_engine.py +++ b/tests/query_engine/test_retriever_query_engine.py @@ -1,7 +1,6 @@ import pytest from llama_index import ( Document, - LLMPredictor, ServiceContext, TreeIndex, ) @@ -19,16 +18,14 @@ except ImportError: @pytest.mark.skipif(anthropic is None, reason="anthropic not installed") def test_query_engine_falls_back_to_inheriting_retrievers_service_context() -> None: documents = [Document(text="Hi")] - gpt35turbo_predictor = LLMPredictor( - llm=OpenAI( - temperature=0, - model_name="gpt-3.5-turbo-0613", - streaming=True, - openai_api_key="test-test-test", - ), + gpt35turbo_predictor = OpenAI( + temperature=0, + model_name="gpt-3.5-turbo-0613", + streaming=True, + openai_api_key="test-test-test", ) gpt35_sc = ServiceContext.from_defaults( - llm_predictor=gpt35turbo_predictor, + llm=gpt35turbo_predictor, chunk_size=512, ) @@ -37,21 +34,21 @@ def test_query_engine_falls_back_to_inheriting_retrievers_service_context() -> N query_engine = RetrieverQueryEngine(retriever=retriever) assert ( - retriever._service_context.llm_predictor.metadata.model_name - == gpt35turbo_predictor._llm.metadata.model_name + retriever._service_context.llm.metadata.model_name + == gpt35turbo_predictor.metadata.model_name ) assert ( - query_engine._response_synthesizer.service_context.llm_predictor.metadata.model_name - == retriever._service_context.llm_predictor.metadata.model_name + query_engine._response_synthesizer.service_context.llm.metadata.model_name + == retriever._service_context.llm.metadata.model_name ) assert ( query_engine._response_synthesizer.service_context == retriever._service_context ) documents = [Document(text="Hi")] - claude_predictor = LLMPredictor(llm=Anthropic(model="claude-2")) + claude_predictor = Anthropic(model="claude-2") claude_sc = ServiceContext.from_defaults( - llm_predictor=claude_predictor, + llm=claude_predictor, chunk_size=512, ) @@ -60,12 +57,12 @@ def test_query_engine_falls_back_to_inheriting_retrievers_service_context() -> N query_engine = RetrieverQueryEngine(retriever=retriever) assert ( - retriever._service_context.llm_predictor.metadata.model_name - == claude_predictor._llm.metadata.model_name + retriever._service_context.llm.metadata.model_name + == claude_predictor.metadata.model_name ) assert ( - query_engine._response_synthesizer.service_context.llm_predictor.metadata.model_name - == retriever._service_context.llm_predictor.metadata.model_name + query_engine._response_synthesizer.service_context.llm.metadata.model_name + == retriever._service_context.llm.metadata.model_name ) assert ( query_engine._response_synthesizer.service_context == retriever._service_context diff --git a/tests/token_predictor/test_base.py b/tests/token_predictor/test_base.py index 8d15fafab0..1a397eea60 100644 --- a/tests/token_predictor/test_base.py +++ b/tests/token_predictor/test_base.py @@ -6,7 +6,7 @@ from unittest.mock import patch from llama_index.indices.keyword_table.base import KeywordTableIndex from llama_index.indices.list.base import SummaryIndex from llama_index.indices.tree.base import TreeIndex -from llama_index.llm_predictor.mock import MockLLMPredictor +from llama_index.llms.mock import MockLLM from llama_index.node_parser import TokenTextSplitter from llama_index.schema import Document from llama_index.service_context import ServiceContext @@ -26,8 +26,8 @@ def test_token_predictor(mock_split: Any) -> None: "This is a test v2." ) document = Document(text=doc_text) - llm_predictor = MockLLMPredictor(max_tokens=256) - service_context = ServiceContext.from_defaults(llm_predictor=llm_predictor) + llm = MockLLM(max_tokens=256) + service_context = ServiceContext.from_defaults(llm=llm) # test tree index index = TreeIndex.from_documents([document], service_context=service_context) -- GitLab