From 6c261c5f6b8d2f9cd32f3975e7d2124e6eb25912 Mon Sep 17 00:00:00 2001
From: Jerry Liu <jerryjliu98@gmail.com>
Date: Fri, 5 Jan 2024 20:16:02 -0800
Subject: [PATCH] move more classes to core (#9871)

---
 benchmarks/embeddings/bench_embeddings.py | 2 +-
 benchmarks/struct_indices/spider/evaluate.py | 4 +-
 docs/api_reference/response.rst | 2 +-
 .../query_transform_cookbook.ipynb | 2 +-
 llama_index/__init__.py | 6 +-
 .../agent/legacy/context_retriever_agent.py | 4 +-
 llama_index/agent/legacy/openai_agent.py | 2 +-
 llama_index/agent/legacy/react/base.py | 2 +-
 llama_index/agent/openai/step.py | 2 +-
 llama_index/agent/openai_assistant_agent.py | 2 +-
 llama_index/agent/react/base.py | 2 +-
 llama_index/agent/react/formatter.py | 2 +-
 llama_index/agent/react/step.py | 2 +-
 llama_index/agent/react_multimodal/step.py | 2 +-
 llama_index/agent/types.py | 6 +-
 llama_index/agent/utils.py | 2 +-
 llama_index/callbacks/finetuning_handler.py | 4 +-
 .../chat_engine/condense_plus_context.py | 2 +-
 llama_index/chat_engine/condense_question.py | 6 +-
 llama_index/chat_engine/context.py | 4 +-
 llama_index/chat_engine/simple.py | 2 +-
 llama_index/chat_engine/types.py | 8 +-
 llama_index/chat_engine/utils.py | 2 +-
 llama_index/core/__init__.py | 13 -
 llama_index/core/base_query_engine.py | 4 +-
 llama_index/core/base_retriever.py | 7 +-
 llama_index/core/embeddings/__init__.py | 0
 llama_index/core/embeddings/base.py | 354 +++++++++++++++++
 llama_index/core/llms/__init__.py | 0
 llama_index/core/llms/types.py | 110 ++++++
 llama_index/core/response/__init__.py | 0
 llama_index/core/response/schema.py | 142 +++++++
 llama_index/embeddings/adapter.py | 3 +-
 llama_index/embeddings/azure_openai.py | 2 +-
 llama_index/embeddings/base.py | 365 +-----
 llama_index/embeddings/bedrock.py | 7 +-
 llama_index/embeddings/clarifai.py | 3 +-
 llama_index/embeddings/clip.py | 6 +-
 llama_index/embeddings/cohereai.py | 2 +-
 llama_index/embeddings/gemini.py | 2 +-
 llama_index/embeddings/google.py | 2 +-
 llama_index/embeddings/google_palm.py | 2 +-
 llama_index/embeddings/gradient.py | 2 +-
 llama_index/embeddings/huggingface.py | 2 +-
 llama_index/embeddings/huggingface_optimum.py | 2 +-
 llama_index/embeddings/instructor.py | 2 +-
 llama_index/embeddings/jinaai.py | 2 +-
 llama_index/embeddings/langchain.py | 2 +-
 llama_index/embeddings/mistralai.py | 2 +-
 llama_index/embeddings/multi_modal_base.py | 2 +-
 .../embeddings/text_embeddings_inference.py | 2 +-
 llama_index/evaluation/base.py | 2 +-
 llama_index/evaluation/batch_runner.py | 4 +-
 llama_index/evaluation/benchmarks/beir.py | 2 +-
 llama_index/evaluation/benchmarks/hotpotqa.py | 3 +-
 llama_index/evaluation/eval_utils.py | 2 +-
 llama_index/evaluation/retrieval/evaluator.py | 2 +-
 llama_index/evaluation/semantic_similarity.py | 2 +-
 llama_index/indices/base.py | 3 +-
 llama_index/indices/base_retriever.py | 2 +-
 llama_index/indices/composability/graph.py | 2 +-
 llama_index/indices/document_summary/base.py | 4 +-
 .../indices/document_summary/retrievers.py | 2 +-
 llama_index/indices/empty/base.py | 3 +-
 llama_index/indices/empty/retrievers.py | 2 +-
 llama_index/indices/keyword_table/base.py | 2 +-
 .../indices/keyword_table/rake_base.py | 2 +-
 .../indices/keyword_table/retrievers.py | 2 +-
 .../indices/keyword_table/simple_base.py | 2 +-
 llama_index/indices/knowledge_graph/base.py | 2 +-
 .../indices/knowledge_graph/retrievers.py | 2 +-
 llama_index/indices/list/base.py | 2 +-
 llama_index/indices/list/retrievers.py | 2 +-
 llama_index/indices/managed/base.py | 2 +-
 .../indices/managed/colbert_index/base.py | 2 +-
 .../managed/colbert_index/retriever.py | 2 +-
 llama_index/indices/managed/vectara/base.py | 3 +-
 llama_index/indices/managed/vectara/query.py | 5 +-
 .../indices/managed/vectara/retriever.py | 2 +-
 llama_index/indices/managed/zilliz/base.py | 2 +-
 .../indices/managed/zilliz/retriever.py | 2 +-
 llama_index/indices/multi_modal/base.py | 3 +-
 llama_index/indices/multi_modal/retriever.py | 2 +-
 llama_index/indices/prompt_helper.py | 2 +-
 llama_index/indices/query/base.py | 2 +-
 llama_index/indices/query/embedding_utils.py | 2 +-
 .../indices/query/query_transform/base.py | 2 +-
 .../indices/struct_store/json_query.py | 4 +-
 llama_index/indices/struct_store/pandas.py | 3 +-
 llama_index/indices/struct_store/sql.py | 3 +-
 llama_index/indices/struct_store/sql_query.py | 4 +-
 .../indices/struct_store/sql_retriever.py | 2 +-
 .../indices/tree/all_leaf_retriever.py | 2 +-
 llama_index/indices/tree/base.py | 2 +-
 .../indices/tree/select_leaf_retriever.py | 4 +-
 .../indices/tree/tree_root_retriever.py | 2 +-
 llama_index/indices/vector_store/base.py | 2 +-
 .../auto_retriever/auto_retriever.py | 2 +-
 .../vector_store/retrievers/retriever.py | 2 +-
 llama_index/langchain_helpers/agents/tools.py | 4 +-
 llama_index/llama_dataset/base.py | 2 +-
 llama_index/llama_dataset/generator.py | 2 +-
 llama_index/llama_dataset/rag.py | 2 +-
 llama_index/llm_predictor/base.py | 10 +-
 llama_index/llm_predictor/mock.py | 2 +-
 llama_index/llms/__init__.py | 22 +-
 llama_index/llms/ai21.py | 16 +-
 llama_index/llms/anthropic.py | 22 +-
 llama_index/llms/anthropic_utils.py | 2 +-
 llama_index/llms/anyscale.py | 2 +-
 llama_index/llms/anyscale_utils.py | 2 +-
 llama_index/llms/azure_openai.py | 2 +-
 llama_index/llms/base.py | 2 +-
 llama_index/llms/bedrock.py | 20 +-
 llama_index/llms/bedrock_utils.py | 2 +-
 llama_index/llms/clarifai.py | 12 +-
 llama_index/llms/cohere.py | 22 +-
 llama_index/llms/cohere_utils.py | 2 +-
 llama_index/llms/custom.py | 16 +-
 llama_index/llms/everlyai.py | 2 +-
 llama_index/llms/gemini.py | 16 +-
 llama_index/llms/gemini_utils.py | 2 +-
 llama_index/llms/generic_utils.py | 2 +-
 llama_index/llms/gradient.py | 6 +-
 llama_index/llms/huggingface.py | 22 +-
 llama_index/llms/konko.py | 20 +-
 llama_index/llms/konko_utils.py | 2 +-
 llama_index/llms/langchain.py | 14 +-
 llama_index/llms/langchain_utils.py | 2 +-
 llama_index/llms/litellm.py | 20 +-
 llama_index/llms/litellm_utils.py | 2 +-
 llama_index/llms/llama_api.py | 16 +-
 llama_index/llms/llama_cpp.py | 14 +-
 llama_index/llms/llama_utils.py | 2 +-
 llama_index/llms/llm.py | 10 +-
 llama_index/llms/localai.py | 2 +-
 llama_index/llms/mistral.py | 26 +-
 llama_index/llms/mock.py | 6 +-
 llama_index/llms/monsterapi.py | 6 +-
 llama_index/llms/ollama.py | 6 +-
 llama_index/llms/openai.py | 22 +-
 llama_index/llms/openai_like.py | 2 +-
 llama_index/llms/openai_utils.py | 2 +-
 llama_index/llms/openllm.py | 20 +-
 llama_index/llms/openrouter.py | 2 +-
 llama_index/llms/palm.py | 6 +-
 llama_index/llms/perplexity.py | 6 +-
 llama_index/llms/portkey.py | 16 +-
 llama_index/llms/portkey_utils.py | 2 +-
 llama_index/llms/predibase.py | 6 +-
 llama_index/llms/replicate.py | 14 +-
 llama_index/llms/rungpt.py | 6 +-
 llama_index/llms/types.py | 139 ++-----
 llama_index/llms/vertex.py | 12 +-
 llama_index/llms/vertex_utils.py | 2 +-
 llama_index/llms/vllm.py | 20 +-
 llama_index/llms/watsonx.py | 14 +-
 llama_index/llms/xinference.py | 12 +-
 llama_index/llms/xinference_utils.py | 2 +-
 llama_index/memory/chat_memory_buffer.py | 1 +
 llama_index/memory/types.py | 2 +-
 llama_index/multi_modal_llms/base.py | 2 +-
 llama_index/multi_modal_llms/gemini.py | 14 +-
 llama_index/multi_modal_llms/openai.py | 18 +-
 .../multi_modal_llms/replicate_multi_modal.py | 8 +-
 .../node_parser/relational/base_element.py | 2 +-
 llama_index/objects/base.py | 2 +-
 llama_index/prompts/__init__.py | 2 +-
 llama_index/prompts/base.py | 2 +-
 llama_index/prompts/chat_prompts.py | 2 +-
 llama_index/query_engine/__init__.py | 2 +-
 .../query_engine/citation_query_engine.py | 5 +-
 .../query_engine/cogniswitch_query_engine.py | 4 +-
 llama_index/query_engine/custom.py | 4 +-
 llama_index/query_engine/flare/base.py | 4 +-
 .../query_engine/graph_query_engine.py | 4 +-
 .../knowledge_graph_query_engine.py | 4 +-
 llama_index/query_engine/multi_modal.py | 2 +-
 .../query_engine/multistep_query_engine.py | 4 +-
 .../query_engine/pandas_query_engine.py | 4 +-
 .../query_engine/retriever_query_engine.py | 5 +-
 .../query_engine/retry_query_engine.py | 4 +-
 .../query_engine/retry_source_query_engine.py | 4 +-
 .../query_engine/router_query_engine.py | 15 +-
 .../query_engine/sql_join_query_engine.py | 4 +-
 .../query_engine/sub_question_query_engine.py | 4 +-
 .../query_engine/transform_query_engine.py | 4 +-
 llama_index/readers/make_com/wrapper.py | 2 +-
 llama_index/response/__init__.py | 2 +-
 llama_index/response/notebook_utils.py | 2 +-
 llama_index/response/pprint_utils.py | 2 +-
 llama_index/response/schema.py | 148 +------
 llama_index/response_synthesizers/base.py | 4 +-
 .../google/generativeai/base.py | 2 +-
 llama_index/retrievers/__init__.py | 3 +-
 .../retrievers/auto_merging_retriever.py | 2 +-
 llama_index/retrievers/bm25_retriever.py | 2 +-
 llama_index/retrievers/pathway_retriever.py | 2 +-
 llama_index/retrievers/recursive_retriever.py | 3 +-
 llama_index/retrievers/router_retriever.py | 2 +-
 llama_index/retrievers/transform_retriever.py | 2 +-
 llama_index/retrievers/you_retriever.py | 2 +-
 llama_index/schema.py | 4 +
 llama_index/service_context.py | 17 +-
 llama_index/tools/query_engine.py | 2 +-
 llama_index/tools/retriever_tool.py | 2 +-
 llama_index/types.py | 2 +-
 tests/agent/openai/test_openai_agent.py | 2 +-
 tests/agent/react/test_react_agent.py | 4 +-
 tests/chat_engine/test_condense_question.py | 6 +-
 tests/chat_engine/test_simple.py | 2 +-
 tests/conftest.py | 2 +-
 tests/embeddings/test_base.py | 2 +-
 tests/evaluation/test_base.py | 2 +-
 tests/indices/list/test_index.py | 2 +-
 tests/indices/managed/test_google.py | 2 +-
 tests/indices/struct_store/test_json_query.py | 2 +-
 tests/llms/test_anthropic.py | 2 +-
 tests/llms/test_anthropic_utils.py | 2 +-
 tests/llms/test_bedrock.py | 2 +-
 tests/llms/test_cohere.py | 2 +-
 tests/llms/test_custom.py | 4 +-
 tests/llms/test_gradient.py | 2 +-
 tests/llms/test_konko.py | 2 +-
 tests/llms/test_langchain.py | 2 +-
 tests/llms/test_litellm.py | 2 +-
 tests/llms/test_llama_utils.py | 2 +-
 tests/llms/test_localai.py | 2 +-
 tests/llms/test_openai.py | 2 +-
 tests/llms/test_openai_like.py | 2 +-
 tests/llms/test_openai_utils.py | 2 +-
 tests/llms/test_palm.py | 2 +-
 tests/llms/test_rungpt.py | 4 +-
 tests/llms/test_vertex.py | 2 +-
 tests/llms/test_watsonx.py | 2 +-
 tests/llms/test_xinference.py | 2 +-
 tests/program/test_llm_program.py | 2 +-
 tests/program/test_lmformatenforcer.py | 2 +-
 tests/program/test_multi_modal_llm_program.py | 2 +-
 tests/prompts/test_base.py | 2 +-
 .../test_cogniswitch_query_engine.py | 2 +-
 tests/query_engine/test_pandas.py | 2 +-
 242 files changed, 1199 insertions(+), 1119 deletions(-)
 create mode 100644 llama_index/core/embeddings/__init__.py
 create mode 100644 llama_index/core/embeddings/base.py
 create mode 100644 llama_index/core/llms/__init__.py
 create mode 100644 llama_index/core/llms/types.py
 create mode 100644 llama_index/core/response/__init__.py
 create mode 100644 llama_index/core/response/schema.py

diff --git a/benchmarks/embeddings/bench_embeddings.py b/benchmarks/embeddings/bench_embeddings.py
index 1d0320ecb..1ea5ea45b 100644
--- a/benchmarks/embeddings/bench_embeddings.py
+++ b/benchmarks/embeddings/bench_embeddings.py
@@ -5,8 +5,8 @@ from typing import Callable, List, Optional, Tuple
 import pandas as pd
 
 from llama_index import SimpleDirectoryReader
+from llama_index.core.embeddings.base import DEFAULT_EMBED_BATCH_SIZE, BaseEmbedding
 from llama_index.embeddings import OpenAIEmbedding, resolve_embed_model
-from llama_index.embeddings.base import DEFAULT_EMBED_BATCH_SIZE, BaseEmbedding
 
 
 def generate_strings(num_strings: int = 100, string_length: int = 10) -> List[str]:
diff --git a/benchmarks/struct_indices/spider/evaluate.py b/benchmarks/struct_indices/spider/evaluate.py
index a914d1a02..3d65777b6 100644
--- a/benchmarks/struct_indices/spider/evaluate.py
+++ b/benchmarks/struct_indices/spider/evaluate.py
@@ -9,10 +9,10 @@ from typing import Dict, List, Optional
 from spider_utils import create_indexes, load_examples
 from tqdm import tqdm
 
+from llama_index.core.llms.types import ChatMessage, MessageRole
+from llama_index.core.response.schema import Response
 from llama_index.indices.struct_store.sql import SQLQueryMode, SQLStructStoreIndex
 from llama_index.llms.openai import OpenAI
-from llama_index.llms.types import ChatMessage, MessageRole
-from llama_index.response.schema import Response
 
 logging.getLogger("root").setLevel(logging.WARNING)
diff --git a/docs/api_reference/response.rst b/docs/api_reference/response.rst
index c5e98aca7..78b33f041 100644
--- a/docs/api_reference/response.rst
+++ b/docs/api_reference/response.rst
@@ -3,6 +3,6 @@ Response
 =================
 
-.. automodule:: llama_index.response.schema
+.. automodule:: llama_index.core.response.schema
    :members:
    :inherited-members:
diff --git a/docs/examples/query_transformations/query_transform_cookbook.ipynb b/docs/examples/query_transformations/query_transform_cookbook.ipynb
index 5155306bc..7fc0cf0c4 100644
--- a/docs/examples/query_transformations/query_transform_cookbook.ipynb
+++ b/docs/examples/query_transformations/query_transform_cookbook.ipynb
@@ -600,7 +600,7 @@
     "from llama_index.agent.react.formatter import ReActChatFormatter\n",
     "from llama_index.agent.react.output_parser import ReActOutputParser\n",
     "from llama_index.tools import FunctionTool\n",
-    "from llama_index.llms.types import ChatMessage"
+    "from llama_index.core.llms.types import ChatMessage"
    ]
   },
  {
diff --git a/llama_index/__init__.py b/llama_index/__init__.py
index 773409264..0fb73b191 100644
--- a/llama_index/__init__.py
+++ b/llama_index/__init__.py
@@ -11,6 +11,9 @@ from typing import Callable, Optional
 
 # import global eval handler
 from llama_index.callbacks.global_handlers import set_global_handler
+
+# response
+from llama_index.core.response.schema import Response
 from llama_index.data_structs.struct_type import IndexStructType
 
 # embeddings
@@ -63,9 +66,6 @@ from llama_index.prompts import (
 )
 from llama_index.readers import SimpleDirectoryReader, download_loader
 
-# response
-from llama_index.response.schema import Response
-
 # Response Synthesizer
 from llama_index.response_synthesizers.factory import get_response_synthesizer
 from llama_index.schema import Document, QueryBundle
diff --git a/llama_index/agent/legacy/context_retriever_agent.py b/llama_index/agent/legacy/context_retriever_agent.py
index 1636e31a6..2a8c2c031 100644
--- a/llama_index/agent/legacy/context_retriever_agent.py
+++ b/llama_index/agent/legacy/context_retriever_agent.py
@@ -11,11 +11,11 @@ from llama_index.callbacks import CallbackManager
 from llama_index.chat_engine.types import (
     AgentChatResponse,
 )
-from llama_index.core import BaseRetriever
+from llama_index.core.base_retriever import BaseRetriever
+from llama_index.core.llms.types import ChatMessage
 from llama_index.llms.llm import LLM
 from llama_index.llms.openai import OpenAI
 from llama_index.llms.openai_utils import is_function_calling_model
-from llama_index.llms.types import ChatMessage
 from llama_index.memory import BaseMemory, ChatMemoryBuffer
 from llama_index.prompts import PromptTemplate
 from llama_index.schema import NodeWithScore
diff --git a/llama_index/agent/legacy/openai_agent.py b/llama_index/agent/legacy/openai_agent.py
index 28564c584..04de84966 100644
--- a/llama_index/agent/legacy/openai_agent.py
+++ b/llama_index/agent/legacy/openai_agent.py
@@ -19,10 +19,10 @@ from llama_index.chat_engine.types import (
     ChatResponseMode,
     StreamingAgentChatResponse,
 )
+from llama_index.core.llms.types import ChatMessage, ChatResponse, MessageRole
 from llama_index.llms.llm import LLM
 from llama_index.llms.openai import OpenAI
 from llama_index.llms.openai_utils import OpenAIToolCall
-from llama_index.llms.types import ChatMessage, ChatResponse, MessageRole
 from llama_index.memory import BaseMemory, ChatMemoryBuffer
 from llama_index.objects.base import ObjectRetriever
 from llama_index.tools import BaseTool, ToolOutput, adapt_to_async_tool
diff --git a/llama_index/agent/legacy/react/base.py b/llama_index/agent/legacy/react/base.py
index a86a1c443..e3a597279 100644
--- a/llama_index/agent/legacy/react/base.py
+++ b/llama_index/agent/legacy/react/base.py
@@ -30,10 +30,10 @@ from llama_index.callbacks import (
     trace_method,
 )
 from llama_index.chat_engine.types import AgentChatResponse, StreamingAgentChatResponse
+from llama_index.core.llms.types import MessageRole
 from llama_index.llms.base import ChatMessage, ChatResponse
 from llama_index.llms.llm import LLM
 from llama_index.llms.openai import OpenAI
-from llama_index.llms.types import MessageRole
 from llama_index.memory.chat_memory_buffer import ChatMemoryBuffer
 from llama_index.memory.types import BaseMemory
 from llama_index.objects.base import ObjectRetriever
diff --git a/llama_index/agent/openai/step.py b/llama_index/agent/openai/step.py
index 397559a5b..7e1b712c5 100644
--- a/llama_index/agent/openai/step.py
+++ b/llama_index/agent/openai/step.py
@@ -27,11 +27,11 @@ from llama_index.chat_engine.types import (
     ChatResponseMode,
     StreamingAgentChatResponse,
 )
+from llama_index.core.llms.types import MessageRole
 from llama_index.llms.base import ChatMessage, ChatResponse
 from llama_index.llms.llm import LLM
 from llama_index.llms.openai import OpenAI
 from llama_index.llms.openai_utils import OpenAIToolCall
-from llama_index.llms.types import MessageRole
 from llama_index.memory import BaseMemory, ChatMemoryBuffer
 from llama_index.memory.types import BaseMemory
 from llama_index.objects.base import ObjectRetriever
diff --git a/llama_index/agent/openai_assistant_agent.py b/llama_index/agent/openai_assistant_agent.py
index c3a932cf5..15213c74a 100644
--- a/llama_index/agent/openai_assistant_agent.py
+++ b/llama_index/agent/openai_assistant_agent.py
@@ -19,7 +19,7 @@ from llama_index.chat_engine.types import (
     ChatResponseMode,
     StreamingAgentChatResponse,
 )
-from llama_index.llms.types import ChatMessage, MessageRole
+from llama_index.core.llms.types import ChatMessage, MessageRole
 from llama_index.tools import BaseTool, ToolOutput, adapt_to_async_tool
 
 logger = logging.getLogger(__name__)
diff --git a/llama_index/agent/react/base.py b/llama_index/agent/react/base.py
index 89dd1e60b..731b1e2f2 100644
--- a/llama_index/agent/react/base.py
+++ b/llama_index/agent/react/base.py
@@ -23,9 +23,9 @@ from llama_index.agent.runner.base import AgentRunner
 from llama_index.callbacks import (
     CallbackManager,
 )
+from llama_index.core.llms.types import ChatMessage
 from llama_index.llms.llm import LLM
 from llama_index.llms.openai import OpenAI
-from llama_index.llms.types import ChatMessage
 from llama_index.memory.chat_memory_buffer import ChatMemoryBuffer
 from llama_index.memory.types import BaseMemory
 from llama_index.objects.base import ObjectRetriever
diff --git a/llama_index/agent/react/formatter.py b/llama_index/agent/react/formatter.py
index ab39d29fe..f00c21426 100644
--- a/llama_index/agent/react/formatter.py
+++ b/llama_index/agent/react/formatter.py
@@ -6,7 +6,7 @@ from typing import List, Optional, Sequence
 from llama_index.agent.react.prompts import REACT_CHAT_SYSTEM_HEADER
 from llama_index.agent.react.types import BaseReasoningStep, ObservationReasoningStep
 from llama_index.bridge.pydantic import BaseModel
-from llama_index.llms.types import ChatMessage, MessageRole
+from llama_index.core.llms.types import ChatMessage, MessageRole
 from llama_index.tools import BaseTool
diff --git a/llama_index/agent/react/step.py b/llama_index/agent/react/step.py
index 4d33ed475..4b855b783 100644
--- a/llama_index/agent/react/step.py
+++ b/llama_index/agent/react/step.py
@@ -41,10 +41,10 @@ from llama_index.chat_engine.types import (
     AgentChatResponse,
     StreamingAgentChatResponse,
 )
+from llama_index.core.llms.types import MessageRole
 from llama_index.llms.base import ChatMessage, ChatResponse
 from llama_index.llms.llm import LLM
 from llama_index.llms.openai import OpenAI
-from llama_index.llms.types import MessageRole
 from llama_index.memory.chat_memory_buffer import ChatMemoryBuffer
 from llama_index.memory.types import BaseMemory
 from llama_index.objects.base import ObjectRetriever
diff --git a/llama_index/agent/react_multimodal/step.py b/llama_index/agent/react_multimodal/step.py
index b23066410..c961540ff 100644
--- a/llama_index/agent/react_multimodal/step.py
+++ b/llama_index/agent/react_multimodal/step.py
@@ -36,8 +36,8 @@ from llama_index.chat_engine.types import (
     AGENT_CHAT_RESPONSE_TYPE,
     AgentChatResponse,
 )
+from llama_index.core.llms.types import MessageRole
 from llama_index.llms.base import ChatMessage, ChatResponse
-from llama_index.llms.types import MessageRole
 from llama_index.memory.chat_memory_buffer import ChatMemoryBuffer
 from llama_index.memory.types import BaseMemory
 from llama_index.multi_modal_llms.base import MultiModalLLM
diff --git a/llama_index/agent/types.py b/llama_index/agent/types.py
index f8b5b2a27..08630b95f 100644
--- a/llama_index/agent/types.py
+++ b/llama_index/agent/types.py
@@ -6,11 +6,11 @@ from typing import Any, Dict, List, Optional
 from llama_index.bridge.pydantic import BaseModel, Field
 from llama_index.callbacks import trace_method
 from llama_index.chat_engine.types import BaseChatEngine, StreamingAgentChatResponse
-from llama_index.core import BaseQueryEngine
-from llama_index.llms.types import ChatMessage
+from llama_index.core.base_query_engine import BaseQueryEngine
+from llama_index.core.llms.types import ChatMessage
+from llama_index.core.response.schema import RESPONSE_TYPE, Response
 from llama_index.memory.types import BaseMemory
 from llama_index.prompts.mixin import PromptDictType, PromptMixinType
-from llama_index.response.schema import RESPONSE_TYPE, Response
 from llama_index.schema import QueryBundle
diff --git a/llama_index/agent/utils.py b/llama_index/agent/utils.py
index d41dc1cf0..b95e86f22 100644
--- a/llama_index/agent/utils.py
+++ b/llama_index/agent/utils.py
@@ -2,8 +2,8 @@
 
 from llama_index.agent.types import TaskStep
+from llama_index.core.llms.types import MessageRole
 from llama_index.llms.base import ChatMessage
-from llama_index.llms.types import MessageRole
 from llama_index.memory import BaseMemory
diff --git a/llama_index/callbacks/finetuning_handler.py b/llama_index/callbacks/finetuning_handler.py
index 577e1fe10..288a1235b 100644
--- a/llama_index/callbacks/finetuning_handler.py
+++ b/llama_index/callbacks/finetuning_handler.py
@@ -35,7 +35,7 @@ class BaseFinetuningHandler(BaseCallbackHandler):
         **kwargs: Any,
     ) -> str:
         """Run when an event starts and return id of event."""
-        from llama_index.llms.types import ChatMessage, MessageRole
+        from llama_index.core.llms.types import ChatMessage, MessageRole
 
         if event_type == CBEventType.LLM:
             cur_messages = []
@@ -68,7 +68,7 @@ class BaseFinetuningHandler(BaseCallbackHandler):
         **kwargs: Any,
     ) -> None:
         """Run when an event ends."""
-        from llama_index.llms.types import ChatMessage, MessageRole
+        from llama_index.core.llms.types import ChatMessage, MessageRole
 
         if (
             event_type == CBEventType.LLM
diff --git a/llama_index/chat_engine/condense_plus_context.py b/llama_index/chat_engine/condense_plus_context.py
index bc3b14ab4..620a15ef1 100644
--- a/llama_index/chat_engine/condense_plus_context.py
+++ b/llama_index/chat_engine/condense_plus_context.py
@@ -10,12 +10,12 @@ from llama_index.chat_engine.types import (
     StreamingAgentChatResponse,
     ToolOutput,
 )
+from llama_index.core.llms.types import ChatMessage, MessageRole
 from llama_index.indices.base_retriever import BaseRetriever
 from llama_index.indices.query.schema import QueryBundle
 from llama_index.indices.service_context import ServiceContext
 from llama_index.llms.generic_utils import messages_to_history_str
 from llama_index.llms.llm import LLM
-from llama_index.llms.types import ChatMessage, MessageRole
 from llama_index.memory import BaseMemory, ChatMemoryBuffer
 from llama_index.postprocessor.types import BaseNodePostprocessor
 from llama_index.prompts.base import PromptTemplate
diff --git a/llama_index/chat_engine/condense_question.py b/llama_index/chat_engine/condense_question.py
index ef8f2f19a..27430eaf8 100644
--- a/llama_index/chat_engine/condense_question.py
+++ b/llama_index/chat_engine/condense_question.py
@@ -9,13 +9,13 @@ from llama_index.chat_engine.types import (
     StreamingAgentChatResponse,
 )
 from llama_index.chat_engine.utils import response_gen_from_query_engine
-from llama_index.core import BaseQueryEngine
+from llama_index.core.base_query_engine import BaseQueryEngine
+from llama_index.core.llms.types import ChatMessage, MessageRole
+from llama_index.core.response.schema import RESPONSE_TYPE, StreamingResponse
 from llama_index.llm_predictor.base import LLMPredictorType
 from llama_index.llms.generic_utils import messages_to_history_str
-from llama_index.llms.types import ChatMessage, MessageRole
 from llama_index.memory import BaseMemory, ChatMemoryBuffer
 from llama_index.prompts.base import BasePromptTemplate, PromptTemplate
-from llama_index.response.schema import RESPONSE_TYPE, StreamingResponse
 from llama_index.service_context import ServiceContext
 from llama_index.tools import ToolOutput
diff --git a/llama_index/chat_engine/context.py b/llama_index/chat_engine/context.py
index 04b76f136..694b5bac4 100644
--- a/llama_index/chat_engine/context.py
+++ b/llama_index/chat_engine/context.py
@@ -9,9 +9,9 @@ from llama_index.chat_engine.types import (
     StreamingAgentChatResponse,
     ToolOutput,
 )
-from llama_index.core import BaseRetriever
+from llama_index.core.base_retriever import BaseRetriever
+from llama_index.core.llms.types import ChatMessage, MessageRole
 from llama_index.llms.llm import LLM
-from llama_index.llms.types import ChatMessage, MessageRole
 from llama_index.memory import BaseMemory, ChatMemoryBuffer
 from llama_index.postprocessor.types import BaseNodePostprocessor
 from llama_index.schema import MetadataMode, NodeWithScore, QueryBundle
diff --git a/llama_index/chat_engine/simple.py b/llama_index/chat_engine/simple.py
index 4e95aeb5d..a9ea59cb1 100644
--- a/llama_index/chat_engine/simple.py
+++ b/llama_index/chat_engine/simple.py
@@ -8,8 +8,8 @@ from llama_index.chat_engine.types import (
     BaseChatEngine,
     StreamingAgentChatResponse,
 )
+from llama_index.core.llms.types import ChatMessage
 from llama_index.llms.llm import LLM
-from llama_index.llms.types import ChatMessage
 from llama_index.memory import BaseMemory, ChatMemoryBuffer
 from llama_index.service_context import ServiceContext
diff --git a/llama_index/chat_engine/types.py b/llama_index/chat_engine/types.py
index 64ccad799..84d645264 100644
--- a/llama_index/chat_engine/types.py
+++ b/llama_index/chat_engine/types.py
@@ -7,9 +7,13 @@ from enum import Enum
 from threading import Event
 from typing import AsyncGenerator, Generator, List, Optional, Union
 
-from llama_index.llms.types import ChatMessage, ChatResponseAsyncGen, ChatResponseGen
+from llama_index.core.llms.types import (
+    ChatMessage,
+    ChatResponseAsyncGen,
+    ChatResponseGen,
+)
+from llama_index.core.response.schema import Response, StreamingResponse
 from llama_index.memory import BaseMemory
-from llama_index.response.schema import Response, StreamingResponse
 from llama_index.schema import NodeWithScore
 from llama_index.tools import ToolOutput
diff --git a/llama_index/chat_engine/utils.py b/llama_index/chat_engine/utils.py
index b33e8ff6b..a85336e2e 100644
--- a/llama_index/chat_engine/utils.py
+++ b/llama_index/chat_engine/utils.py
@@ -1,4 +1,4 @@
-from llama_index.llms.types import (
+from llama_index.core.llms.types import (
     ChatMessage,
     ChatResponse,
     ChatResponseGen,
diff --git a/llama_index/core/__init__.py b/llama_index/core/__init__.py
index bd2300c87..e69de29bb 100644
--- a/llama_index/core/__init__.py
+++ b/llama_index/core/__init__.py
@@ -1,13 +0,0 @@
-from llama_index.core.base_auto_retriever import BaseAutoRetriever
-from llama_index.core.base_multi_modal_retriever import MultiModalRetriever
-from llama_index.core.base_query_engine import BaseQueryEngine
-from llama_index.core.base_retriever import BaseRetriever
-from llama_index.core.image_retriever import BaseImageRetriever
-
-__all__ = [
-    "BaseRetriever",
-    "BaseAutoRetriever",
-    "BaseQueryEngine",
-    "MultiModalRetriever",
-    "BaseImageRetriever",
-]
diff --git a/llama_index/core/base_query_engine.py b/llama_index/core/base_query_engine.py
index c7546b79f..934b37314 100644
--- a/llama_index/core/base_query_engine.py
+++ b/llama_index/core/base_query_engine.py
@@ -5,14 +5,16 @@ from abc import abstractmethod
 from typing import Any, Dict, List, Optional, Sequence
 
 from llama_index.callbacks.base import CallbackManager
+from llama_index.core.response.schema import RESPONSE_TYPE
 from llama_index.prompts.mixin import PromptDictType, PromptMixin
-from llama_index.response.schema import RESPONSE_TYPE
 from llama_index.schema import NodeWithScore, QueryBundle, QueryType
 
 logger = logging.getLogger(__name__)
 
 
 class BaseQueryEngine(PromptMixin):
+    """Base query engine."""
+
     def __init__(self, callback_manager: Optional[CallbackManager]) -> None:
         self.callback_manager = callback_manager or CallbackManager([])
diff --git a/llama_index/core/base_retriever.py b/llama_index/core/base_retriever.py
index 9e6871c54..e36607c4e 100644
--- a/llama_index/core/base_retriever.py
+++ b/llama_index/core/base_retriever.py
@@ -4,17 +4,16 @@ from typing import List, Optional
 
 from llama_index.callbacks.base import CallbackManager
 from llama_index.callbacks.schema import CBEventType, EventPayload
-from llama_index.indices.query.schema import QueryBundle, QueryType
-from llama_index.indices.service_context import ServiceContext
 from llama_index.prompts.mixin import PromptDictType, PromptMixin, PromptMixinType
-from llama_index.schema import NodeWithScore
+from llama_index.schema import NodeWithScore, QueryBundle, QueryType
+from llama_index.service_context import ServiceContext
 
 
 class BaseRetriever(PromptMixin):
     """Base retriever."""
 
     def __init__(self, callback_manager: Optional[CallbackManager] = None) -> None:
         self.callback_manager = callback_manager or CallbackManager()
 
     def _check_callback_manager(self) -> None:
         """Check callback manager."""
diff --git a/llama_index/core/embeddings/__init__.py b/llama_index/core/embeddings/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/llama_index/core/embeddings/base.py b/llama_index/core/embeddings/base.py
new file mode 100644
index 000000000..474422dfd
--- /dev/null
+++ b/llama_index/core/embeddings/base.py
@@ -0,0 +1,354 @@
+"""Base embeddings file."""
+
+import asyncio
+from abc import abstractmethod
+from enum import Enum
+from typing import Any, Callable, Coroutine, List, Optional, Tuple
+
+import numpy as np
+
+from llama_index.bridge.pydantic import Field, validator
+from llama_index.callbacks.base import CallbackManager
+from llama_index.callbacks.schema import CBEventType, EventPayload
+from llama_index.constants import (
+    DEFAULT_EMBED_BATCH_SIZE,
+)
+from llama_index.schema import BaseNode, MetadataMode, TransformComponent
+from llama_index.utils import get_tqdm_iterable
+
+# TODO: change to numpy array
+Embedding = List[float]
+
+
+class SimilarityMode(str, Enum):
+    """Modes for similarity/distance."""
+
+    DEFAULT = "cosine"
+    DOT_PRODUCT = "dot_product"
+    EUCLIDEAN = "euclidean"
+
+
+def mean_agg(embeddings: List[Embedding]) -> Embedding:
+    """Mean aggregation for embeddings."""
+    return list(np.array(embeddings).mean(axis=0))
+
+
+def similarity(
+    embedding1: Embedding,
+    embedding2: Embedding,
+    mode: SimilarityMode = SimilarityMode.DEFAULT,
+) -> float:
+    """Get embedding similarity."""
+    if mode == SimilarityMode.EUCLIDEAN:
+        # Using -euclidean distance as similarity to achieve same ranking order
+        return -float(np.linalg.norm(np.array(embedding1) - np.array(embedding2)))
+    elif mode == SimilarityMode.DOT_PRODUCT:
+        return np.dot(embedding1, embedding2)
+    else:
+        product = np.dot(embedding1, embedding2)
+        norm = np.linalg.norm(embedding1) * np.linalg.norm(embedding2)
+        return product / norm
+
+
+class BaseEmbedding(TransformComponent):
+    """Base class for embeddings."""
+
+    model_name: str = Field(
+        default="unknown", description="The name of the embedding model."
+    )
+    embed_batch_size: int = Field(
+        default=DEFAULT_EMBED_BATCH_SIZE,
+        description="The batch size for embedding calls.",
+        gt=0,
+        lte=2048,
+    )
+    callback_manager: CallbackManager = Field(
+        default_factory=lambda: CallbackManager([]), exclude=True
+    )
+
+    class Config:
+        arbitrary_types_allowed = True
+
+    @validator("callback_manager", pre=True)
+    def _validate_callback_manager(
+        cls, v: Optional[CallbackManager]
+    ) -> CallbackManager:
+        if v is None:
+            return CallbackManager([])
+        return v
+
+    @abstractmethod
+    def _get_query_embedding(self, query: str) -> Embedding:
+        """
+        Embed the input query synchronously.
+
+        Subclasses should implement this method. Reference get_query_embedding's
+        docstring for more information.
+        """
+
+    @abstractmethod
+    async def _aget_query_embedding(self, query: str) -> Embedding:
+        """
+        Embed the input query asynchronously.
+
+        Subclasses should implement this method. Reference get_query_embedding's
+        docstring for more information.
+        """
+
+    def get_query_embedding(self, query: str) -> Embedding:
+        """
+        Embed the input query.
+
+        When embedding a query, depending on the model, a special instruction
+        can be prepended to the raw query string. For example, "Represent the
+        question for retrieving supporting documents: ". If you're curious,
+        other examples of predefined instructions can be found in
+        embeddings/huggingface_utils.py.
+        """
+        with self.callback_manager.event(
+            CBEventType.EMBEDDING, payload={EventPayload.SERIALIZED: self.to_dict()}
+        ) as event:
+            query_embedding = self._get_query_embedding(query)
+
+            event.on_end(
+                payload={
+                    EventPayload.CHUNKS: [query],
+                    EventPayload.EMBEDDINGS: [query_embedding],
+                },
+            )
+        return query_embedding
+
+    async def aget_query_embedding(self, query: str) -> Embedding:
+        """Get query embedding."""
+        with self.callback_manager.event(
+            CBEventType.EMBEDDING, payload={EventPayload.SERIALIZED: self.to_dict()}
+        ) as event:
+            query_embedding = await self._aget_query_embedding(query)
+
+            event.on_end(
+                payload={
+                    EventPayload.CHUNKS: [query],
+                    EventPayload.EMBEDDINGS: [query_embedding],
+                },
+            )
+        return query_embedding
+
+    def get_agg_embedding_from_queries(
+        self,
+        queries: List[str],
+        agg_fn: Optional[Callable[..., Embedding]] = None,
+    ) -> Embedding:
+        """Get aggregated embedding from multiple queries."""
+        query_embeddings = [self.get_query_embedding(query) for query in queries]
+        agg_fn = agg_fn or mean_agg
+        return agg_fn(query_embeddings)
+
+    async def aget_agg_embedding_from_queries(
+        self,
+        queries: List[str],
+        agg_fn: Optional[Callable[..., Embedding]] = None,
+    ) -> Embedding:
+        """Async get aggregated embedding from multiple queries."""
+        query_embeddings = [await self.aget_query_embedding(query) for query in queries]
+        agg_fn = agg_fn or mean_agg
+        return agg_fn(query_embeddings)
+
+    @abstractmethod
+    def _get_text_embedding(self, text: str) -> Embedding:
+        """
+        Embed the input text synchronously.
+
+        Subclasses should implement this method. Reference get_text_embedding's
+        docstring for more information.
+        """
+
+    async def _aget_text_embedding(self, text: str) -> Embedding:
+        """
+        Embed the input text asynchronously.
+
+        Subclasses can implement this method if there is a true async
+        implementation. Reference get_text_embedding's docstring for more
+        information.
+        """
+        # Default implementation just falls back on _get_text_embedding
+        return self._get_text_embedding(text)
+
+    def _get_text_embeddings(self, texts: List[str]) -> List[Embedding]:
+        """
+        Embed the input sequence of text synchronously.
+
+        Subclasses can implement this method if batch queries are supported.
+        """
+        # Default implementation just loops over _get_text_embedding
+        return [self._get_text_embedding(text) for text in texts]
+
+    async def _aget_text_embeddings(self, texts: List[str]) -> List[Embedding]:
+        """
+        Embed the input sequence of text asynchronously.
+
+        Subclasses can implement this method if batch queries are supported.
+        """
+        return await asyncio.gather(
+            *[self._aget_text_embedding(text) for text in texts]
+        )
+
+    def get_text_embedding(self, text: str) -> Embedding:
+        """
+        Embed the input text.
+
+        When embedding text, depending on the model, a special instruction
+        can be prepended to the raw text string. For example, "Represent the
+        document for retrieval: ". If you're curious, other examples of
+        predefined instructions can be found in embeddings/huggingface_utils.py.
+ """ + with self.callback_manager.event( + CBEventType.EMBEDDING, payload={EventPayload.SERIALIZED: self.to_dict()} + ) as event: + text_embedding = self._get_text_embedding(text) + + event.on_end( + payload={ + EventPayload.CHUNKS: [text], + EventPayload.EMBEDDINGS: [text_embedding], + } + ) + + return text_embedding + + async def aget_text_embedding(self, text: str) -> Embedding: + """Async get text embedding.""" + with self.callback_manager.event( + CBEventType.EMBEDDING, payload={EventPayload.SERIALIZED: self.to_dict()} + ) as event: + text_embedding = await self._aget_text_embedding(text) + + event.on_end( + payload={ + EventPayload.CHUNKS: [text], + EventPayload.EMBEDDINGS: [text_embedding], + } + ) + + return text_embedding + + def get_text_embedding_batch( + self, + texts: List[str], + show_progress: bool = False, + **kwargs: Any, + ) -> List[Embedding]: + """Get a list of text embeddings, with batching.""" + cur_batch: List[str] = [] + result_embeddings: List[Embedding] = [] + + queue_with_progress = enumerate( + get_tqdm_iterable(texts, show_progress, "Generating embeddings") + ) + + for idx, text in queue_with_progress: + cur_batch.append(text) + if idx == len(texts) - 1 or len(cur_batch) == self.embed_batch_size: + # flush + with self.callback_manager.event( + CBEventType.EMBEDDING, + payload={EventPayload.SERIALIZED: self.to_dict()}, + ) as event: + embeddings = self._get_text_embeddings(cur_batch) + result_embeddings.extend(embeddings) + event.on_end( + payload={ + EventPayload.CHUNKS: cur_batch, + EventPayload.EMBEDDINGS: embeddings, + }, + ) + cur_batch = [] + + return result_embeddings + + async def aget_text_embedding_batch( + self, texts: List[str], show_progress: bool = False + ) -> List[Embedding]: + """Asynchronously get a list of text embeddings, with batching.""" + cur_batch: List[str] = [] + callback_payloads: List[Tuple[str, List[str]]] = [] + result_embeddings: List[Embedding] = [] + embeddings_coroutines: List[Coroutine] = [] + for idx, text in enumerate(texts): + cur_batch.append(text) + if idx == len(texts) - 1 or len(cur_batch) == self.embed_batch_size: + # flush + event_id = self.callback_manager.on_event_start( + CBEventType.EMBEDDING, + payload={EventPayload.SERIALIZED: self.to_dict()}, + ) + callback_payloads.append((event_id, cur_batch)) + embeddings_coroutines.append(self._aget_text_embeddings(cur_batch)) + cur_batch = [] + + # flatten the results of asyncio.gather, which is a list of embeddings lists + nested_embeddings = [] + if show_progress: + try: + from tqdm.auto import tqdm + + nested_embeddings = [ + await f + for f in tqdm( + asyncio.as_completed(embeddings_coroutines), + total=len(embeddings_coroutines), + desc="Generating embeddings", + ) + ] + except ImportError: + nested_embeddings = await asyncio.gather(*embeddings_coroutines) + else: + nested_embeddings = await asyncio.gather(*embeddings_coroutines) + + result_embeddings = [ + embedding for embeddings in nested_embeddings for embedding in embeddings + ] + + for (event_id, text_batch), embeddings in zip( + callback_payloads, nested_embeddings + ): + self.callback_manager.on_event_end( + CBEventType.EMBEDDING, + payload={ + EventPayload.CHUNKS: text_batch, + EventPayload.EMBEDDINGS: embeddings, + }, + event_id=event_id, + ) + + return result_embeddings + + def similarity( + self, + embedding1: Embedding, + embedding2: Embedding, + mode: SimilarityMode = SimilarityMode.DEFAULT, + ) -> float: + """Get embedding similarity.""" + return similarity(embedding1=embedding1, embedding2=embedding2, 
mode=mode) + + def __call__(self, nodes: List[BaseNode], **kwargs: Any) -> List[BaseNode]: + embeddings = self.get_text_embedding_batch( + [node.get_content(metadata_mode=MetadataMode.EMBED) for node in nodes], + **kwargs, + ) + + for node, embedding in zip(nodes, embeddings): + node.embedding = embedding + + return nodes + + async def acall(self, nodes: List[BaseNode], **kwargs: Any) -> List[BaseNode]: + embeddings = await self.aget_text_embedding_batch( + [node.get_content(metadata_mode=MetadataMode.EMBED) for node in nodes], + **kwargs, + ) + + for node, embedding in zip(nodes, embeddings): + node.embedding = embedding + + return nodes diff --git a/llama_index/core/llms/__init__.py b/llama_index/core/llms/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/llama_index/core/llms/types.py b/llama_index/core/llms/types.py new file mode 100644 index 000000000..9db785861 --- /dev/null +++ b/llama_index/core/llms/types.py @@ -0,0 +1,110 @@ +from enum import Enum +from typing import Any, AsyncGenerator, Generator, Optional + +from llama_index.bridge.pydantic import BaseModel, Field +from llama_index.constants import DEFAULT_CONTEXT_WINDOW, DEFAULT_NUM_OUTPUTS + + +class MessageRole(str, Enum): + """Message role.""" + + SYSTEM = "system" + USER = "user" + ASSISTANT = "assistant" + FUNCTION = "function" + TOOL = "tool" + + +# ===== Generic Model Input - Chat ===== +class ChatMessage(BaseModel): + """Chat message.""" + + role: MessageRole = MessageRole.USER + content: Optional[Any] = "" + additional_kwargs: dict = Field(default_factory=dict) + + def __str__(self) -> str: + return f"{self.role.value}: {self.content}" + + +# ===== Generic Model Output - Chat ===== +class ChatResponse(BaseModel): + """Chat response.""" + + message: ChatMessage + raw: Optional[dict] = None + delta: Optional[str] = None + additional_kwargs: dict = Field(default_factory=dict) + + def __str__(self) -> str: + return str(self.message) + + +ChatResponseGen = Generator[ChatResponse, None, None] +ChatResponseAsyncGen = AsyncGenerator[ChatResponse, None] + + +# ===== Generic Model Output - Completion ===== +class CompletionResponse(BaseModel): + """ + Completion response. + + Fields: + text: Text content of the response if not streaming, or if streaming, + the current extent of streamed text. + additional_kwargs: Additional information on the response(i.e. token + counts, function calling information). + raw: Optional raw JSON that was parsed to populate text, if relevant. + delta: New text that just streamed in (only relevant when streaming). + """ + + text: str + additional_kwargs: dict = Field(default_factory=dict) + raw: Optional[dict] = None + delta: Optional[str] = None + + def __str__(self) -> str: + return self.text + + +CompletionResponseGen = Generator[CompletionResponse, None, None] +CompletionResponseAsyncGen = AsyncGenerator[CompletionResponse, None] + + +class LLMMetadata(BaseModel): + context_window: int = Field( + default=DEFAULT_CONTEXT_WINDOW, + description=( + "Total number of tokens the model can be input and output for one response." + ), + ) + num_output: int = Field( + default=DEFAULT_NUM_OUTPUTS, + description="Number of tokens the model can output when generating a response.", + ) + is_chat_model: bool = Field( + default=False, + description=( + "Set True if the model exposes a chat interface (i.e. can be passed a" + " sequence of messages, rather than text), like OpenAI's" + " /v1/chat/completions endpoint." 
+        ),
+    )
+    is_function_calling_model: bool = Field(
+        default=False,
+        # SEE: https://openai.com/blog/function-calling-and-other-api-updates
+        description=(
+            "Set True if the model supports function calling messages, similar to"
+            " OpenAI's function calling API. For example, converting 'Email Anya to"
+            " see if she wants to get coffee next Friday' to a function call like"
+            " `send_email(to: string, body: string)`."
+        ),
+    )
+    model_name: str = Field(
+        default="unknown",
+        description=(
+            "The model's name used for logging, testing, and sanity checking. For some"
+            " models this can be automatically discerned. For other models, like"
+            " locally loaded models, this must be manually specified."
+        ),
+    )
diff --git a/llama_index/core/response/__init__.py b/llama_index/core/response/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/llama_index/core/response/schema.py b/llama_index/core/response/schema.py
new file mode 100644
index 000000000..1834b6ccf
--- /dev/null
+++ b/llama_index/core/response/schema.py
@@ -0,0 +1,142 @@
+"""Response schema."""
+
+from dataclasses import dataclass, field
+from typing import Any, Dict, List, Optional, Union
+
+from llama_index.bridge.pydantic import BaseModel
+from llama_index.schema import NodeWithScore
+from llama_index.types import TokenGen
+from llama_index.utils import truncate_text
+
+
+@dataclass
+class Response:
+    """Response object.
+
+    Returned if streaming=False.
+
+    Attributes:
+        response: The response text.
+
+    """
+
+    response: Optional[str]
+    source_nodes: List[NodeWithScore] = field(default_factory=list)
+    metadata: Optional[Dict[str, Any]] = None
+
+    def __str__(self) -> str:
+        """Convert to string representation."""
+        return self.response or "None"
+
+    def get_formatted_sources(self, length: int = 100) -> str:
+        """Get formatted sources text."""
+        texts = []
+        for source_node in self.source_nodes:
+            fmt_text_chunk = truncate_text(source_node.node.get_content(), length)
+            doc_id = source_node.node.node_id or "None"
+            source_text = f"> Source (Doc id: {doc_id}): {fmt_text_chunk}"
+            texts.append(source_text)
+        return "\n\n".join(texts)
+
+
+@dataclass
+class PydanticResponse:
+    """PydanticResponse object.
+
+    Returned if streaming=False.
+
+    Attributes:
+        response: The response text.
+
+    """
+
+    response: Optional[BaseModel]
+    source_nodes: List[NodeWithScore] = field(default_factory=list)
+    metadata: Optional[Dict[str, Any]] = None
+
+    def __str__(self) -> str:
+        """Convert to string representation."""
+        return self.response.json() if self.response else "None"
+
+    def __getattr__(self, name: str) -> Any:
+        """Get attribute, but prioritize the pydantic response object."""
+        if self.response is not None and name in self.response.dict():
+            return getattr(self.response, name)
+        else:
+            return None
+
+    def get_formatted_sources(self, length: int = 100) -> str:
+        """Get formatted sources text."""
+        texts = []
+        for source_node in self.source_nodes:
+            fmt_text_chunk = truncate_text(source_node.node.get_content(), length)
+            doc_id = source_node.node.node_id or "None"
+            source_text = f"> Source (Doc id: {doc_id}): {fmt_text_chunk}"
+            texts.append(source_text)
+        return "\n\n".join(texts)
+
+    def get_response(self) -> Response:
+        """Get a standard response object."""
+        response_txt = self.response.json() if self.response else "None"
+        return Response(response_txt, self.source_nodes, self.metadata)
+
+
+@dataclass
+class StreamingResponse:
+    """StreamingResponse object.
+
+    Returned if streaming=True.
+
+    Attributes:
+        response_gen: The response generator.
+
+    """
+
+    response_gen: TokenGen
+    source_nodes: List[NodeWithScore] = field(default_factory=list)
+    metadata: Optional[Dict[str, Any]] = None
+    response_txt: Optional[str] = None
+
+    def __str__(self) -> str:
+        """Convert to string representation."""
+        if self.response_txt is None and self.response_gen is not None:
+            response_txt = ""
+            for text in self.response_gen:
+                response_txt += text
+            self.response_txt = response_txt
+        return self.response_txt or "None"
+
+    def get_response(self) -> Response:
+        """Get a standard response object."""
+        if self.response_txt is None and self.response_gen is not None:
+            response_txt = ""
+            for text in self.response_gen:
+                response_txt += text
+            self.response_txt = response_txt
+        return Response(self.response_txt, self.source_nodes, self.metadata)
+
+    def print_response_stream(self) -> None:
+        """Print the response stream."""
+        if self.response_txt is None and self.response_gen is not None:
+            response_txt = ""
+            for text in self.response_gen:
+                print(text, end="", flush=True)
+                response_txt += text
+            self.response_txt = response_txt
+        else:
+            print(self.response_txt)
+
+    def get_formatted_sources(self, length: int = 100, trim_text: bool = True) -> str:
+        """Get formatted sources text."""
+        texts = []
+        for source_node in self.source_nodes:
+            fmt_text_chunk = source_node.node.get_content()
+            if trim_text:
+                fmt_text_chunk = truncate_text(fmt_text_chunk, length)
+            node_id = source_node.node.node_id or "None"
+            source_text = f"> Source (Node id: {node_id}): {fmt_text_chunk}"
+            texts.append(source_text)
+        return "\n\n".join(texts)
+
+
+RESPONSE_TYPE = Union[Response, StreamingResponse, PydanticResponse]
diff --git a/llama_index/embeddings/adapter.py b/llama_index/embeddings/adapter.py
index c21fe4be2..5e7b9bfae 100644
--- a/llama_index/embeddings/adapter.py
+++ b/llama_index/embeddings/adapter.py
@@ -5,7 +5,8 @@ from typing import Any, List, Optional, Type, cast
 
 from llama_index.bridge.pydantic import PrivateAttr
 from llama_index.callbacks import CallbackManager
-from llama_index.embeddings.base import DEFAULT_EMBED_BATCH_SIZE, BaseEmbedding
+from llama_index.constants import DEFAULT_EMBED_BATCH_SIZE
+from llama_index.core.embeddings.base import BaseEmbedding
 from llama_index.utils import infer_torch_device
 
 logger = logging.getLogger(__name__)
diff --git a/llama_index/embeddings/azure_openai.py b/llama_index/embeddings/azure_openai.py
index 8c7910df6..efb96a4e6 100644
--- a/llama_index/embeddings/azure_openai.py
+++ b/llama_index/embeddings/azure_openai.py
@@ -5,7 +5,7 @@ from openai import AsyncAzureOpenAI, AzureOpenAI
 
 from llama_index.bridge.pydantic import Field, PrivateAttr, root_validator
 from llama_index.callbacks.base import CallbackManager
-from llama_index.embeddings.base import DEFAULT_EMBED_BATCH_SIZE
+from llama_index.constants import DEFAULT_EMBED_BATCH_SIZE
 from llama_index.embeddings.openai import (
     OpenAIEmbedding,
     OpenAIEmbeddingMode,
diff --git a/llama_index/embeddings/base.py b/llama_index/embeddings/base.py
index 474422dfd..97028437c 100644
--- a/llama_index/embeddings/base.py
+++ b/llama_index/embeddings/base.py
@@ -1,354 +1,23 @@
-"""Base embeddings file."""
+"""Base embeddings file.
 
-import asyncio
-from abc import abstractmethod
-from enum import Enum
-from typing import Any, Callable, Coroutine, List, Optional, Tuple
+Maintain for backwards compatibility.
-import numpy as np
+"""
 
-from llama_index.bridge.pydantic import Field, validator
-from llama_index.callbacks.base import CallbackManager
-from llama_index.callbacks.schema import CBEventType, EventPayload
-from llama_index.constants import (
+from llama_index.core.embeddings.base import (
     DEFAULT_EMBED_BATCH_SIZE,
+    BaseEmbedding,
+    Embedding,
+    SimilarityMode,
+    mean_agg,
+    similarity,
 )
-from llama_index.schema import BaseNode, MetadataMode, TransformComponent
-from llama_index.utils import get_tqdm_iterable
 
-# TODO: change to numpy array
-Embedding = List[float]
-
-
-class SimilarityMode(str, Enum):
-    """Modes for similarity/distance."""
-
-    DEFAULT = "cosine"
-    DOT_PRODUCT = "dot_product"
-    EUCLIDEAN = "euclidean"
-
-
-def mean_agg(embeddings: List[Embedding]) -> Embedding:
-    """Mean aggregation for embeddings."""
-    return list(np.array(embeddings).mean(axis=0))
-
-
-def similarity(
-    embedding1: Embedding,
-    embedding2: Embedding,
-    mode: SimilarityMode = SimilarityMode.DEFAULT,
-) -> float:
-    """Get embedding similarity."""
-    if mode == SimilarityMode.EUCLIDEAN:
-        # Using -euclidean distance as similarity to achieve same ranking order
-        return -float(np.linalg.norm(np.array(embedding1) - np.array(embedding2)))
-    elif mode == SimilarityMode.DOT_PRODUCT:
-        return np.dot(embedding1, embedding2)
-    else:
-        product = np.dot(embedding1, embedding2)
-        norm = np.linalg.norm(embedding1) * np.linalg.norm(embedding2)
-        return product / norm
-
-
-class BaseEmbedding(TransformComponent):
-    """Base class for embeddings."""
-
-    model_name: str = Field(
-        default="unknown", description="The name of the embedding model."
-    )
-    embed_batch_size: int = Field(
-        default=DEFAULT_EMBED_BATCH_SIZE,
-        description="The batch size for embedding calls.",
-        gt=0,
-        lte=2048,
-    )
-    callback_manager: CallbackManager = Field(
-        default_factory=lambda: CallbackManager([]), exclude=True
-    )
-
-    class Config:
-        arbitrary_types_allowed = True
-
-    @validator("callback_manager", pre=True)
-    def _validate_callback_manager(
-        cls, v: Optional[CallbackManager]
-    ) -> CallbackManager:
-        if v is None:
-            return CallbackManager([])
-        return v
-
-    @abstractmethod
-    def _get_query_embedding(self, query: str) -> Embedding:
-        """
-        Embed the input query synchronously.
-
-        Subclasses should implement this method. Reference get_query_embedding's
-        docstring for more information.
-        """
-
-    @abstractmethod
-    async def _aget_query_embedding(self, query: str) -> Embedding:
-        """
-        Embed the input query asynchronously.
-
-        Subclasses should implement this method. Reference get_query_embedding's
-        docstring for more information.
-        """
-
-    def get_query_embedding(self, query: str) -> Embedding:
-        """
-        Embed the input query.
-
-        When embedding a query, depending on the model, a special instruction
-        can be prepended to the raw query string. For example, "Represent the
-        question for retrieving supporting documents: ". If you're curious,
-        other examples of predefined instructions can be found in
-        embeddings/huggingface_utils.py.
-        """
-        with self.callback_manager.event(
-            CBEventType.EMBEDDING, payload={EventPayload.SERIALIZED: self.to_dict()}
-        ) as event:
-            query_embedding = self._get_query_embedding(query)
-
-            event.on_end(
-                payload={
-                    EventPayload.CHUNKS: [query],
-                    EventPayload.EMBEDDINGS: [query_embedding],
-                },
-            )
-        return query_embedding
-
-    async def aget_query_embedding(self, query: str) -> Embedding:
-        """Get query embedding."""
-        with self.callback_manager.event(
-            CBEventType.EMBEDDING, payload={EventPayload.SERIALIZED: self.to_dict()}
-        ) as event:
-            query_embedding = await self._aget_query_embedding(query)
-
-            event.on_end(
-                payload={
-                    EventPayload.CHUNKS: [query],
-                    EventPayload.EMBEDDINGS: [query_embedding],
-                },
-            )
-        return query_embedding
-
-    def get_agg_embedding_from_queries(
-        self,
-        queries: List[str],
-        agg_fn: Optional[Callable[..., Embedding]] = None,
-    ) -> Embedding:
-        """Get aggregated embedding from multiple queries."""
-        query_embeddings = [self.get_query_embedding(query) for query in queries]
-        agg_fn = agg_fn or mean_agg
-        return agg_fn(query_embeddings)
-
-    async def aget_agg_embedding_from_queries(
-        self,
-        queries: List[str],
-        agg_fn: Optional[Callable[..., Embedding]] = None,
-    ) -> Embedding:
-        """Async get aggregated embedding from multiple queries."""
-        query_embeddings = [await self.aget_query_embedding(query) for query in queries]
-        agg_fn = agg_fn or mean_agg
-        return agg_fn(query_embeddings)
-
-    @abstractmethod
-    def _get_text_embedding(self, text: str) -> Embedding:
-        """
-        Embed the input text synchronously.
-
-        Subclasses should implement this method. Reference get_text_embedding's
-        docstring for more information.
-        """
-
-    async def _aget_text_embedding(self, text: str) -> Embedding:
-        """
-        Embed the input text asynchronously.
-
-        Subclasses can implement this method if there is a true async
-        implementation. Reference get_text_embedding's docstring for more
-        information.
-        """
-        # Default implementation just falls back on _get_text_embedding
-        return self._get_text_embedding(text)
-
-    def _get_text_embeddings(self, texts: List[str]) -> List[Embedding]:
-        """
-        Embed the input sequence of text synchronously.
-
-        Subclasses can implement this method if batch queries are supported.
-        """
-        # Default implementation just loops over _get_text_embedding
-        return [self._get_text_embedding(text) for text in texts]
-
-    async def _aget_text_embeddings(self, texts: List[str]) -> List[Embedding]:
-        """
-        Embed the input sequence of text asynchronously.
-
-        Subclasses can implement this method if batch queries are supported.
-        """
-        return await asyncio.gather(
-            *[self._aget_text_embedding(text) for text in texts]
-        )
-
-    def get_text_embedding(self, text: str) -> Embedding:
-        """
-        Embed the input text.
-
-        When embedding text, depending on the model, a special instruction
-        can be prepended to the raw text string. For example, "Represent the
-        document for retrieval: ". If you're curious, other examples of
-        predefined instructions can be found in embeddings/huggingface_utils.py.
- """ - with self.callback_manager.event( - CBEventType.EMBEDDING, payload={EventPayload.SERIALIZED: self.to_dict()} - ) as event: - text_embedding = self._get_text_embedding(text) - - event.on_end( - payload={ - EventPayload.CHUNKS: [text], - EventPayload.EMBEDDINGS: [text_embedding], - } - ) - - return text_embedding - - async def aget_text_embedding(self, text: str) -> Embedding: - """Async get text embedding.""" - with self.callback_manager.event( - CBEventType.EMBEDDING, payload={EventPayload.SERIALIZED: self.to_dict()} - ) as event: - text_embedding = await self._aget_text_embedding(text) - - event.on_end( - payload={ - EventPayload.CHUNKS: [text], - EventPayload.EMBEDDINGS: [text_embedding], - } - ) - - return text_embedding - - def get_text_embedding_batch( - self, - texts: List[str], - show_progress: bool = False, - **kwargs: Any, - ) -> List[Embedding]: - """Get a list of text embeddings, with batching.""" - cur_batch: List[str] = [] - result_embeddings: List[Embedding] = [] - - queue_with_progress = enumerate( - get_tqdm_iterable(texts, show_progress, "Generating embeddings") - ) - - for idx, text in queue_with_progress: - cur_batch.append(text) - if idx == len(texts) - 1 or len(cur_batch) == self.embed_batch_size: - # flush - with self.callback_manager.event( - CBEventType.EMBEDDING, - payload={EventPayload.SERIALIZED: self.to_dict()}, - ) as event: - embeddings = self._get_text_embeddings(cur_batch) - result_embeddings.extend(embeddings) - event.on_end( - payload={ - EventPayload.CHUNKS: cur_batch, - EventPayload.EMBEDDINGS: embeddings, - }, - ) - cur_batch = [] - - return result_embeddings - - async def aget_text_embedding_batch( - self, texts: List[str], show_progress: bool = False - ) -> List[Embedding]: - """Asynchronously get a list of text embeddings, with batching.""" - cur_batch: List[str] = [] - callback_payloads: List[Tuple[str, List[str]]] = [] - result_embeddings: List[Embedding] = [] - embeddings_coroutines: List[Coroutine] = [] - for idx, text in enumerate(texts): - cur_batch.append(text) - if idx == len(texts) - 1 or len(cur_batch) == self.embed_batch_size: - # flush - event_id = self.callback_manager.on_event_start( - CBEventType.EMBEDDING, - payload={EventPayload.SERIALIZED: self.to_dict()}, - ) - callback_payloads.append((event_id, cur_batch)) - embeddings_coroutines.append(self._aget_text_embeddings(cur_batch)) - cur_batch = [] - - # flatten the results of asyncio.gather, which is a list of embeddings lists - nested_embeddings = [] - if show_progress: - try: - from tqdm.auto import tqdm - - nested_embeddings = [ - await f - for f in tqdm( - asyncio.as_completed(embeddings_coroutines), - total=len(embeddings_coroutines), - desc="Generating embeddings", - ) - ] - except ImportError: - nested_embeddings = await asyncio.gather(*embeddings_coroutines) - else: - nested_embeddings = await asyncio.gather(*embeddings_coroutines) - - result_embeddings = [ - embedding for embeddings in nested_embeddings for embedding in embeddings - ] - - for (event_id, text_batch), embeddings in zip( - callback_payloads, nested_embeddings - ): - self.callback_manager.on_event_end( - CBEventType.EMBEDDING, - payload={ - EventPayload.CHUNKS: text_batch, - EventPayload.EMBEDDINGS: embeddings, - }, - event_id=event_id, - ) - - return result_embeddings - - def similarity( - self, - embedding1: Embedding, - embedding2: Embedding, - mode: SimilarityMode = SimilarityMode.DEFAULT, - ) -> float: - """Get embedding similarity.""" - return similarity(embedding1=embedding1, embedding2=embedding2, 
mode=mode) - - def __call__(self, nodes: List[BaseNode], **kwargs: Any) -> List[BaseNode]: - embeddings = self.get_text_embedding_batch( - [node.get_content(metadata_mode=MetadataMode.EMBED) for node in nodes], - **kwargs, - ) - - for node, embedding in zip(nodes, embeddings): - node.embedding = embedding - - return nodes - - async def acall(self, nodes: List[BaseNode], **kwargs: Any) -> List[BaseNode]: - embeddings = await self.aget_text_embedding_batch( - [node.get_content(metadata_mode=MetadataMode.EMBED) for node in nodes], - **kwargs, - ) - - for node, embedding in zip(nodes, embeddings): - node.embedding = embedding - - return nodes +__all__ = [ + "BaseEmbedding", + "similarity", + "SimilarityMode", + "DEFAULT_EMBED_BATCH_SIZE", + "mean_agg", + "Embedding", +] diff --git a/llama_index/embeddings/bedrock.py b/llama_index/embeddings/bedrock.py index a90352aaa..d9f3c7c00 100644 --- a/llama_index/embeddings/bedrock.py +++ b/llama_index/embeddings/bedrock.py @@ -6,11 +6,8 @@ from typing import Any, Dict, List, Literal, Optional from llama_index.bridge.pydantic import PrivateAttr from llama_index.callbacks.base import CallbackManager -from llama_index.embeddings.base import ( - DEFAULT_EMBED_BATCH_SIZE, - BaseEmbedding, - Embedding, -) +from llama_index.constants import DEFAULT_EMBED_BATCH_SIZE +from llama_index.core.embeddings.base import BaseEmbedding, Embedding class PROVIDERS(str, Enum): diff --git a/llama_index/embeddings/clarifai.py b/llama_index/embeddings/clarifai.py index e77bfd2df..3f2c459c2 100644 --- a/llama_index/embeddings/clarifai.py +++ b/llama_index/embeddings/clarifai.py @@ -3,7 +3,8 @@ from typing import Any, List, Optional from llama_index.bridge.pydantic import Field, PrivateAttr from llama_index.callbacks import CallbackManager -from llama_index.embeddings.base import DEFAULT_EMBED_BATCH_SIZE, BaseEmbedding +from llama_index.constants import DEFAULT_EMBED_BATCH_SIZE +from llama_index.core.embeddings.base import BaseEmbedding logger = logging.getLogger(__name__) diff --git a/llama_index/embeddings/clip.py b/llama_index/embeddings/clip.py index 9f905cd6e..1c20bb86a 100644 --- a/llama_index/embeddings/clip.py +++ b/llama_index/embeddings/clip.py @@ -2,10 +2,8 @@ import logging from typing import Any, List from llama_index.bridge.pydantic import Field, PrivateAttr -from llama_index.embeddings.base import ( - DEFAULT_EMBED_BATCH_SIZE, - Embedding, -) +from llama_index.constants import DEFAULT_EMBED_BATCH_SIZE +from llama_index.core.embeddings.base import Embedding from llama_index.embeddings.multi_modal_base import MultiModalEmbedding from llama_index.schema import ImageType diff --git a/llama_index/embeddings/cohereai.py b/llama_index/embeddings/cohereai.py index 94883fdb3..1fd4f19ed 100644 --- a/llama_index/embeddings/cohereai.py +++ b/llama_index/embeddings/cohereai.py @@ -3,7 +3,7 @@ from typing import Any, List, Optional from llama_index.bridge.pydantic import Field from llama_index.callbacks import CallbackManager -from llama_index.embeddings.base import DEFAULT_EMBED_BATCH_SIZE, BaseEmbedding +from llama_index.core.embeddings.base import DEFAULT_EMBED_BATCH_SIZE, BaseEmbedding # Enums for validation and type safety diff --git a/llama_index/embeddings/gemini.py b/llama_index/embeddings/gemini.py index b335528a4..553a2ea68 100644 --- a/llama_index/embeddings/gemini.py +++ b/llama_index/embeddings/gemini.py @@ -4,7 +4,7 @@ from typing import Any, List, Optional from llama_index.bridge.pydantic import Field, PrivateAttr from llama_index.callbacks.base import 
CallbackManager -from llama_index.embeddings.base import DEFAULT_EMBED_BATCH_SIZE, BaseEmbedding +from llama_index.core.embeddings.base import DEFAULT_EMBED_BATCH_SIZE, BaseEmbedding class GeminiEmbedding(BaseEmbedding): diff --git a/llama_index/embeddings/google.py b/llama_index/embeddings/google.py index 64770062d..ef9142a2f 100644 --- a/llama_index/embeddings/google.py +++ b/llama_index/embeddings/google.py @@ -4,7 +4,7 @@ from typing import Any, List, Optional from llama_index.bridge.pydantic import PrivateAttr from llama_index.callbacks import CallbackManager -from llama_index.embeddings.base import DEFAULT_EMBED_BATCH_SIZE, BaseEmbedding +from llama_index.core.embeddings.base import DEFAULT_EMBED_BATCH_SIZE, BaseEmbedding # Google Universal Sentence Encode v5 DEFAULT_HANDLE = "https://tfhub.dev/google/universal-sentence-encoder-large/5" diff --git a/llama_index/embeddings/google_palm.py b/llama_index/embeddings/google_palm.py index 8e01f8fba..7fc3df38b 100644 --- a/llama_index/embeddings/google_palm.py +++ b/llama_index/embeddings/google_palm.py @@ -4,7 +4,7 @@ from typing import Any, List, Optional from llama_index.bridge.pydantic import PrivateAttr from llama_index.callbacks.base import CallbackManager -from llama_index.embeddings.base import DEFAULT_EMBED_BATCH_SIZE, BaseEmbedding +from llama_index.core.embeddings.base import DEFAULT_EMBED_BATCH_SIZE, BaseEmbedding class GooglePaLMEmbedding(BaseEmbedding): diff --git a/llama_index/embeddings/gradient.py b/llama_index/embeddings/gradient.py index 21c607dc5..bc620492e 100644 --- a/llama_index/embeddings/gradient.py +++ b/llama_index/embeddings/gradient.py @@ -2,7 +2,7 @@ import logging from typing import Any, List, Optional from llama_index.bridge.pydantic import Field, PrivateAttr -from llama_index.embeddings.base import ( +from llama_index.core.embeddings.base import ( DEFAULT_EMBED_BATCH_SIZE, BaseEmbedding, Embedding, diff --git a/llama_index/embeddings/huggingface.py b/llama_index/embeddings/huggingface.py index b9e4ccc03..b4f348f9d 100644 --- a/llama_index/embeddings/huggingface.py +++ b/llama_index/embeddings/huggingface.py @@ -3,7 +3,7 @@ from typing import TYPE_CHECKING, Any, List, Optional, Sequence, Union from llama_index.bridge.pydantic import Field, PrivateAttr from llama_index.callbacks import CallbackManager -from llama_index.embeddings.base import ( +from llama_index.core.embeddings.base import ( DEFAULT_EMBED_BATCH_SIZE, BaseEmbedding, Embedding, diff --git a/llama_index/embeddings/huggingface_optimum.py b/llama_index/embeddings/huggingface_optimum.py index 341f060f9..73f0a48eb 100644 --- a/llama_index/embeddings/huggingface_optimum.py +++ b/llama_index/embeddings/huggingface_optimum.py @@ -2,7 +2,7 @@ from typing import Any, List, Optional from llama_index.bridge.pydantic import Field, PrivateAttr from llama_index.callbacks import CallbackManager -from llama_index.embeddings.base import DEFAULT_EMBED_BATCH_SIZE, BaseEmbedding +from llama_index.core.embeddings.base import DEFAULT_EMBED_BATCH_SIZE, BaseEmbedding from llama_index.embeddings.huggingface_utils import format_query, format_text from llama_index.utils import infer_torch_device diff --git a/llama_index/embeddings/instructor.py b/llama_index/embeddings/instructor.py index 513ee3fa4..7cf01c445 100644 --- a/llama_index/embeddings/instructor.py +++ b/llama_index/embeddings/instructor.py @@ -2,7 +2,7 @@ from typing import Any, List, Optional from llama_index.bridge.pydantic import Field, PrivateAttr from llama_index.callbacks import CallbackManager -from 
llama_index.embeddings.base import DEFAULT_EMBED_BATCH_SIZE, BaseEmbedding +from llama_index.core.embeddings.base import DEFAULT_EMBED_BATCH_SIZE, BaseEmbedding from llama_index.embeddings.huggingface_utils import ( DEFAULT_INSTRUCT_MODEL, get_query_instruct_for_model_name, diff --git a/llama_index/embeddings/jinaai.py b/llama_index/embeddings/jinaai.py index ef07f5da0..8a4ed5253 100644 --- a/llama_index/embeddings/jinaai.py +++ b/llama_index/embeddings/jinaai.py @@ -6,7 +6,7 @@ import requests from llama_index.bridge.pydantic import Field, PrivateAttr from llama_index.callbacks.base import CallbackManager -from llama_index.embeddings.base import DEFAULT_EMBED_BATCH_SIZE, BaseEmbedding +from llama_index.core.embeddings.base import DEFAULT_EMBED_BATCH_SIZE, BaseEmbedding from llama_index.llms.generic_utils import get_from_param_or_env MAX_BATCH_SIZE = 2048 diff --git a/llama_index/embeddings/langchain.py b/llama_index/embeddings/langchain.py index 7fda89b84..2318abe8d 100644 --- a/llama_index/embeddings/langchain.py +++ b/llama_index/embeddings/langchain.py @@ -4,7 +4,7 @@ from typing import TYPE_CHECKING, List, Optional from llama_index.bridge.pydantic import PrivateAttr from llama_index.callbacks import CallbackManager -from llama_index.embeddings.base import DEFAULT_EMBED_BATCH_SIZE, BaseEmbedding +from llama_index.core.embeddings.base import DEFAULT_EMBED_BATCH_SIZE, BaseEmbedding if TYPE_CHECKING: from llama_index.bridge.langchain import Embeddings as LCEmbeddings diff --git a/llama_index/embeddings/mistralai.py b/llama_index/embeddings/mistralai.py index 2bd444859..05943cf9f 100644 --- a/llama_index/embeddings/mistralai.py +++ b/llama_index/embeddings/mistralai.py @@ -4,7 +4,7 @@ from typing import Any, List, Optional from llama_index.bridge.pydantic import PrivateAttr from llama_index.callbacks.base import CallbackManager -from llama_index.embeddings.base import DEFAULT_EMBED_BATCH_SIZE, BaseEmbedding +from llama_index.core.embeddings.base import DEFAULT_EMBED_BATCH_SIZE, BaseEmbedding from llama_index.llms.generic_utils import get_from_param_or_env diff --git a/llama_index/embeddings/multi_modal_base.py b/llama_index/embeddings/multi_modal_base.py index 276063ca0..c3adf485f 100644 --- a/llama_index/embeddings/multi_modal_base.py +++ b/llama_index/embeddings/multi_modal_base.py @@ -5,7 +5,7 @@ from abc import abstractmethod from typing import Coroutine, List, Tuple from llama_index.callbacks.schema import CBEventType, EventPayload -from llama_index.embeddings.base import ( +from llama_index.core.embeddings.base import ( BaseEmbedding, Embedding, ) diff --git a/llama_index/embeddings/text_embeddings_inference.py b/llama_index/embeddings/text_embeddings_inference.py index 48ffd861a..ad1a48a2f 100644 --- a/llama_index/embeddings/text_embeddings_inference.py +++ b/llama_index/embeddings/text_embeddings_inference.py @@ -2,7 +2,7 @@ from typing import Callable, List, Optional, Union from llama_index.bridge.pydantic import Field from llama_index.callbacks import CallbackManager -from llama_index.embeddings.base import ( +from llama_index.core.embeddings.base import ( DEFAULT_EMBED_BATCH_SIZE, BaseEmbedding, Embedding, diff --git a/llama_index/evaluation/base.py b/llama_index/evaluation/base.py index 3f2023a0f..9ddf00521 100644 --- a/llama_index/evaluation/base.py +++ b/llama_index/evaluation/base.py @@ -4,8 +4,8 @@ from abc import abstractmethod from typing import Any, Optional, Sequence from llama_index.bridge.pydantic import BaseModel, Field +from llama_index.core.response.schema import 
Response from llama_index.prompts.mixin import PromptMixin, PromptMixinType -from llama_index.response.schema import Response class EvaluationResult(BaseModel): diff --git a/llama_index/evaluation/batch_runner.py b/llama_index/evaluation/batch_runner.py index 306480b71..4b4cb2e6f 100644 --- a/llama_index/evaluation/batch_runner.py +++ b/llama_index/evaluation/batch_runner.py @@ -2,9 +2,9 @@ import asyncio from typing import Any, Dict, List, Optional, Sequence, Tuple, cast from llama_index.async_utils import asyncio_module -from llama_index.core import BaseQueryEngine +from llama_index.core.base_query_engine import BaseQueryEngine +from llama_index.core.response.schema import RESPONSE_TYPE, Response from llama_index.evaluation.base import BaseEvaluator, EvaluationResult -from llama_index.response.schema import RESPONSE_TYPE, Response async def eval_response_worker( diff --git a/llama_index/evaluation/benchmarks/beir.py b/llama_index/evaluation/benchmarks/beir.py index 6bab13b3c..5751f5123 100644 --- a/llama_index/evaluation/benchmarks/beir.py +++ b/llama_index/evaluation/benchmarks/beir.py @@ -4,7 +4,7 @@ from typing import Callable, Dict, List, Optional import tqdm -from llama_index.core import BaseRetriever +from llama_index.core.base_retriever import BaseRetriever from llama_index.postprocessor.types import BaseNodePostprocessor from llama_index.schema import Document, QueryBundle from llama_index.utils import get_cache_dir diff --git a/llama_index/evaluation/benchmarks/hotpotqa.py b/llama_index/evaluation/benchmarks/hotpotqa.py index 2d5ff6bb6..4e7e2cb01 100644 --- a/llama_index/evaluation/benchmarks/hotpotqa.py +++ b/llama_index/evaluation/benchmarks/hotpotqa.py @@ -9,7 +9,8 @@ from typing import Any, Dict, List, Optional, Tuple import requests import tqdm -from llama_index.core import BaseQueryEngine, BaseRetriever +from llama_index.core.base_query_engine import BaseQueryEngine +from llama_index.core.base_retriever import BaseRetriever from llama_index.query_engine.retriever_query_engine import RetrieverQueryEngine from llama_index.schema import NodeWithScore, QueryBundle, TextNode from llama_index.utils import get_cache_dir diff --git a/llama_index/evaluation/eval_utils.py b/llama_index/evaluation/eval_utils.py index 0d401741e..f9432d0d6 100644 --- a/llama_index/evaluation/eval_utils.py +++ b/llama_index/evaluation/eval_utils.py @@ -12,7 +12,7 @@ import numpy as np import pandas as pd from llama_index.async_utils import asyncio_module -from llama_index.core import BaseQueryEngine +from llama_index.core.base_query_engine import BaseQueryEngine from llama_index.evaluation.base import EvaluationResult diff --git a/llama_index/evaluation/retrieval/evaluator.py b/llama_index/evaluation/retrieval/evaluator.py index 7174d80e7..e8b24d308 100644 --- a/llama_index/evaluation/retrieval/evaluator.py +++ b/llama_index/evaluation/retrieval/evaluator.py @@ -3,7 +3,7 @@ from typing import Any, List, Sequence, Tuple from llama_index.bridge.pydantic import Field -from llama_index.core import BaseRetriever +from llama_index.core.base_retriever import BaseRetriever from llama_index.evaluation.retrieval.base import ( BaseRetrievalEvaluator, RetrievalEvalMode, diff --git a/llama_index/evaluation/semantic_similarity.py b/llama_index/evaluation/semantic_similarity.py index c77f2fa08..393b7866a 100644 --- a/llama_index/evaluation/semantic_similarity.py +++ b/llama_index/evaluation/semantic_similarity.py @@ -1,6 +1,6 @@ from typing import Any, Callable, Optional, Sequence -from llama_index.embeddings.base 
import SimilarityMode, similarity +from llama_index.core.embeddings.base import SimilarityMode, similarity from llama_index.evaluation.base import BaseEvaluator, EvaluationResult from llama_index.prompts.mixin import PromptDictType from llama_index.service_context import ServiceContext diff --git a/llama_index/indices/base.py b/llama_index/indices/base.py index be79007ac..c904d657b 100644 --- a/llama_index/indices/base.py +++ b/llama_index/indices/base.py @@ -4,7 +4,8 @@ from abc import ABC, abstractmethod from typing import Any, Dict, Generic, List, Optional, Sequence, Type, TypeVar, cast from llama_index.chat_engine.types import BaseChatEngine, ChatMode -from llama_index.core import BaseQueryEngine, BaseRetriever +from llama_index.core.base_query_engine import BaseQueryEngine +from llama_index.core.base_retriever import BaseRetriever from llama_index.data_structs.data_structs import IndexStruct from llama_index.ingestion import run_transformations from llama_index.llms.openai import OpenAI diff --git a/llama_index/indices/base_retriever.py b/llama_index/indices/base_retriever.py index 22087ac26..0cad9e778 100644 --- a/llama_index/indices/base_retriever.py +++ b/llama_index/indices/base_retriever.py @@ -1,5 +1,5 @@ # for backwards compatibility -from llama_index.core import BaseRetriever +from llama_index.core.base_retriever import BaseRetriever __all__ = [ "BaseRetriever", diff --git a/llama_index/indices/composability/graph.py b/llama_index/indices/composability/graph.py index d7e5e14c3..c3e522d6d 100644 --- a/llama_index/indices/composability/graph.py +++ b/llama_index/indices/composability/graph.py @@ -2,7 +2,7 @@ from typing import Any, Dict, List, Optional, Sequence, Type, cast -from llama_index.core import BaseQueryEngine +from llama_index.core.base_query_engine import BaseQueryEngine from llama_index.data_structs.data_structs import IndexStruct from llama_index.indices.base import BaseIndex from llama_index.schema import IndexNode, NodeRelationship, ObjectType, RelatedNodeInfo diff --git a/llama_index/indices/document_summary/base.py b/llama_index/indices/document_summary/base.py index 1bf97e20a..79e67b506 100644 --- a/llama_index/indices/document_summary/base.py +++ b/llama_index/indices/document_summary/base.py @@ -10,11 +10,11 @@ from collections import defaultdict from enum import Enum from typing import Any, Dict, Optional, Sequence, Union, cast -from llama_index.core import BaseRetriever +from llama_index.core.base_retriever import BaseRetriever +from llama_index.core.response.schema import Response from llama_index.data_structs.document_summary import IndexDocumentSummary from llama_index.indices.base import BaseIndex from llama_index.indices.utils import embed_nodes -from llama_index.response.schema import Response from llama_index.response_synthesizers import ( BaseSynthesizer, ResponseMode, diff --git a/llama_index/indices/document_summary/retrievers.py b/llama_index/indices/document_summary/retrievers.py index fda5c8a84..cbb46d74a 100644 --- a/llama_index/indices/document_summary/retrievers.py +++ b/llama_index/indices/document_summary/retrievers.py @@ -8,7 +8,7 @@ import logging from typing import Any, Callable, List, Optional from llama_index.callbacks.base import CallbackManager -from llama_index.core import BaseRetriever +from llama_index.core.base_retriever import BaseRetriever from llama_index.indices.document_summary.base import DocumentSummaryIndex from llama_index.indices.utils import ( default_format_node_batch_fn, diff --git 
a/llama_index/indices/empty/base.py b/llama_index/indices/empty/base.py index 6f74184f4..295a56bb7 100644 --- a/llama_index/indices/empty/base.py +++ b/llama_index/indices/empty/base.py @@ -7,7 +7,8 @@ pure LLM calls. from typing import Any, Dict, Optional, Sequence -from llama_index.core import BaseQueryEngine, BaseRetriever +from llama_index.core.base_query_engine import BaseQueryEngine +from llama_index.core.base_retriever import BaseRetriever from llama_index.data_structs.data_structs import EmptyIndexStruct from llama_index.indices.base import BaseIndex from llama_index.schema import BaseNode diff --git a/llama_index/indices/empty/retrievers.py b/llama_index/indices/empty/retrievers.py index e79532bc5..19d0eb3db 100644 --- a/llama_index/indices/empty/retrievers.py +++ b/llama_index/indices/empty/retrievers.py @@ -2,7 +2,7 @@ from typing import Any, List, Optional from llama_index.callbacks.base import CallbackManager -from llama_index.core import BaseRetriever +from llama_index.core.base_retriever import BaseRetriever from llama_index.indices.empty.base import EmptyIndex from llama_index.prompts import BasePromptTemplate from llama_index.prompts.default_prompts import DEFAULT_SIMPLE_INPUT_PROMPT diff --git a/llama_index/indices/keyword_table/base.py b/llama_index/indices/keyword_table/base.py index 02a031d67..885243c8c 100644 --- a/llama_index/indices/keyword_table/base.py +++ b/llama_index/indices/keyword_table/base.py @@ -13,7 +13,7 @@ from enum import Enum from typing import Any, Dict, Optional, Sequence, Set, Union from llama_index.async_utils import run_async_tasks -from llama_index.core import BaseRetriever +from llama_index.core.base_retriever import BaseRetriever from llama_index.data_structs.data_structs import KeywordTable from llama_index.indices.base import BaseIndex from llama_index.indices.keyword_table.utils import extract_keywords_given_response diff --git a/llama_index/indices/keyword_table/rake_base.py b/llama_index/indices/keyword_table/rake_base.py index b4188e731..5b5a8c1f9 100644 --- a/llama_index/indices/keyword_table/rake_base.py +++ b/llama_index/indices/keyword_table/rake_base.py @@ -6,7 +6,7 @@ Similar to KeywordTableIndex, but uses RAKE instead of GPT. from typing import Any, Set, Union -from llama_index.core import BaseRetriever +from llama_index.core.base_retriever import BaseRetriever from llama_index.indices.keyword_table.base import ( BaseKeywordTableIndex, KeywordTableRetrieverMode, diff --git a/llama_index/indices/keyword_table/retrievers.py b/llama_index/indices/keyword_table/retrievers.py index 0d687b2fe..05480c051 100644 --- a/llama_index/indices/keyword_table/retrievers.py +++ b/llama_index/indices/keyword_table/retrievers.py @@ -5,7 +5,7 @@ from collections import defaultdict from typing import Any, Dict, List, Optional from llama_index.callbacks.base import CallbackManager -from llama_index.core import BaseRetriever +from llama_index.core.base_retriever import BaseRetriever from llama_index.indices.keyword_table.base import BaseKeywordTableIndex from llama_index.indices.keyword_table.utils import ( extract_keywords_given_response, diff --git a/llama_index/indices/keyword_table/simple_base.py b/llama_index/indices/keyword_table/simple_base.py index f54a57866..c296d9633 100644 --- a/llama_index/indices/keyword_table/simple_base.py +++ b/llama_index/indices/keyword_table/simple_base.py @@ -7,7 +7,7 @@ technique that doesn't involve GPT - just uses regex. 
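Note on the pattern above: modules like llama_index/indices/base_retriever.py are kept only as thin re-export shims, so code written against the old paths keeps importing. A minimal sketch of the shim, assuming nothing beyond what the hunks themselves show:

    # llama_index/indices/base_retriever.py after this patch (sketch)
    # for backwards compatibility: forward the moved class to its new home
    from llama_index.core.base_retriever import BaseRetriever

    __all__ = ["BaseRetriever"]

Old-style imports such as "from llama_index.indices.base_retriever import BaseRetriever" therefore resolve exactly as before.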
from typing import Any, Set, Union -from llama_index.core import BaseRetriever +from llama_index.core.base_retriever import BaseRetriever from llama_index.indices.keyword_table.base import ( BaseKeywordTableIndex, KeywordTableRetrieverMode, diff --git a/llama_index/indices/knowledge_graph/base.py b/llama_index/indices/knowledge_graph/base.py index 00cc76d83..eb7dd65c9 100644 --- a/llama_index/indices/knowledge_graph/base.py +++ b/llama_index/indices/knowledge_graph/base.py @@ -8,7 +8,7 @@ import logging from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple from llama_index.constants import GRAPH_STORE_KEY -from llama_index.core import BaseRetriever +from llama_index.core.base_retriever import BaseRetriever from llama_index.data_structs.data_structs import KG from llama_index.graph_stores.simple import SimpleGraphStore from llama_index.graph_stores.types import GraphStore diff --git a/llama_index/indices/knowledge_graph/retrievers.py b/llama_index/indices/knowledge_graph/retrievers.py index 9ac9cc064..d7118083a 100644 --- a/llama_index/indices/knowledge_graph/retrievers.py +++ b/llama_index/indices/knowledge_graph/retrievers.py @@ -5,7 +5,7 @@ from enum import Enum from typing import Any, Callable, Dict, List, Optional, Set, Tuple from llama_index.callbacks.base import CallbackManager -from llama_index.core import BaseRetriever +from llama_index.core.base_retriever import BaseRetriever from llama_index.indices.keyword_table.utils import extract_keywords_given_response from llama_index.indices.knowledge_graph.base import KnowledgeGraphIndex from llama_index.indices.query.embedding_utils import get_top_k_embeddings diff --git a/llama_index/indices/list/base.py b/llama_index/indices/list/base.py index c4e3321c2..f09f63d6f 100644 --- a/llama_index/indices/list/base.py +++ b/llama_index/indices/list/base.py @@ -8,7 +8,7 @@ in sequence in order to answer a given query. from enum import Enum from typing import Any, Dict, Optional, Sequence, Union -from llama_index.core import BaseRetriever +from llama_index.core.base_retriever import BaseRetriever from llama_index.data_structs.data_structs import IndexList from llama_index.indices.base import BaseIndex from llama_index.schema import BaseNode diff --git a/llama_index/indices/list/retrievers.py b/llama_index/indices/list/retrievers.py index 4f92ee076..5548e4fdd 100644 --- a/llama_index/indices/list/retrievers.py +++ b/llama_index/indices/list/retrievers.py @@ -3,7 +3,7 @@ import logging from typing import Any, Callable, List, Optional, Tuple from llama_index.callbacks.base import CallbackManager -from llama_index.core import BaseRetriever +from llama_index.core.base_retriever import BaseRetriever from llama_index.indices.list.base import SummaryIndex from llama_index.indices.query.embedding_utils import get_top_k_embeddings from llama_index.indices.utils import ( diff --git a/llama_index/indices/managed/base.py b/llama_index/indices/managed/base.py index d192f6d30..92d2475ca 100644 --- a/llama_index/indices/managed/base.py +++ b/llama_index/indices/managed/base.py @@ -6,7 +6,7 @@ An index that that is built on top of a managed service. 
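The other recurring change is at call sites: everything that previously went through the package barrel now imports the concrete module. A before/after sketch; the rationale suggested in the comments (importing one base class should not execute the whole core package __init__, which keeps imports light and avoids cycles) is an inference, not something the diff itself states:

    # before: resolved through the core package __init__ re-exports
    #   from llama_index.core import BaseRetriever
    # after: import the defining module directly
    from llama_index.core.base_retriever import BaseRetriever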
from abc import ABC, abstractmethod from typing import Any, Dict, Optional, Sequence, Type -from llama_index.core import BaseRetriever +from llama_index.core.base_retriever import BaseRetriever from llama_index.data_structs.data_structs import IndexDict from llama_index.indices.base import BaseIndex, IndexType from llama_index.schema import BaseNode, Document diff --git a/llama_index/indices/managed/colbert_index/base.py b/llama_index/indices/managed/colbert_index/base.py index e0ad967f1..d808f297a 100644 --- a/llama_index/indices/managed/colbert_index/base.py +++ b/llama_index/indices/managed/colbert_index/base.py @@ -1,6 +1,6 @@ from typing import Any, Dict, List, Optional, Sequence -from llama_index.core import BaseRetriever +from llama_index.core.base_retriever import BaseRetriever from llama_index.data_structs.data_structs import IndexDict from llama_index.indices.base import BaseIndex from llama_index.schema import BaseNode, NodeWithScore diff --git a/llama_index/indices/managed/colbert_index/retriever.py b/llama_index/indices/managed/colbert_index/retriever.py index 199dfa786..c3e0d0431 100644 --- a/llama_index/indices/managed/colbert_index/retriever.py +++ b/llama_index/indices/managed/colbert_index/retriever.py @@ -2,7 +2,7 @@ from typing import Any, Dict, List, Optional from llama_index.callbacks.base import CallbackManager from llama_index.constants import DEFAULT_SIMILARITY_TOP_K -from llama_index.core import BaseRetriever +from llama_index.core.base_retriever import BaseRetriever from llama_index.schema import NodeWithScore, QueryBundle from llama_index.vector_stores.types import MetadataFilters diff --git a/llama_index/indices/managed/vectara/base.py b/llama_index/indices/managed/vectara/base.py index adc1467f6..7b51e813e 100644 --- a/llama_index/indices/managed/vectara/base.py +++ b/llama_index/indices/managed/vectara/base.py @@ -13,7 +13,8 @@ from typing import Any, Dict, List, Optional, Sequence, Type import requests -from llama_index.core import BaseQueryEngine, BaseRetriever +from llama_index.core.base_query_engine import BaseQueryEngine +from llama_index.core.base_retriever import BaseRetriever from llama_index.data_structs.data_structs import IndexDict, IndexStructType from llama_index.indices.managed.base import BaseManagedIndex, IndexType from llama_index.schema import BaseNode, Document, MetadataMode, TextNode diff --git a/llama_index/indices/managed/vectara/query.py b/llama_index/indices/managed/vectara/query.py index faf2e2a55..d958bae31 100644 --- a/llama_index/indices/managed/vectara/query.py +++ b/llama_index/indices/managed/vectara/query.py @@ -2,11 +2,12 @@ from typing import Any, List, Optional from llama_index.callbacks.base import CallbackManager from llama_index.callbacks.schema import CBEventType, EventPayload -from llama_index.core import BaseQueryEngine, BaseRetriever +from llama_index.core.base_query_engine import BaseQueryEngine +from llama_index.core.base_retriever import BaseRetriever +from llama_index.core.response.schema import RESPONSE_TYPE, Response from llama_index.indices.managed.vectara.retriever import VectaraRetriever from llama_index.postprocessor.types import BaseNodePostprocessor from llama_index.prompts.mixin import PromptDictType, PromptMixinType -from llama_index.response.schema import RESPONSE_TYPE, Response from llama_index.schema import NodeWithScore, QueryBundle diff --git a/llama_index/indices/managed/vectara/retriever.py b/llama_index/indices/managed/vectara/retriever.py index 93b0927d3..fc2fc7741 100644 --- 
a/llama_index/indices/managed/vectara/retriever.py +++ b/llama_index/indices/managed/vectara/retriever.py @@ -8,7 +8,7 @@ from typing import Any, List, Optional, Tuple from llama_index.callbacks.base import CallbackManager from llama_index.constants import DEFAULT_SIMILARITY_TOP_K -from llama_index.core import BaseRetriever +from llama_index.core.base_retriever import BaseRetriever from llama_index.indices.managed.types import ManagedIndexQueryMode from llama_index.indices.managed.vectara.base import VectaraIndex from llama_index.schema import NodeWithScore, QueryBundle, TextNode diff --git a/llama_index/indices/managed/zilliz/base.py b/llama_index/indices/managed/zilliz/base.py index 484d60699..7b31d7d87 100644 --- a/llama_index/indices/managed/zilliz/base.py +++ b/llama_index/indices/managed/zilliz/base.py @@ -10,7 +10,7 @@ from typing import Any, Dict, Optional, Sequence, Type import requests -from llama_index.core import BaseRetriever +from llama_index.core.base_retriever import BaseRetriever from llama_index.data_structs.data_structs import IndexDict, IndexStructType from llama_index.indices.managed.base import BaseManagedIndex, IndexType from llama_index.schema import BaseNode, Document diff --git a/llama_index/indices/managed/zilliz/retriever.py b/llama_index/indices/managed/zilliz/retriever.py index b38068e1e..15cda246a 100644 --- a/llama_index/indices/managed/zilliz/retriever.py +++ b/llama_index/indices/managed/zilliz/retriever.py @@ -5,7 +5,7 @@ import requests from llama_index.callbacks.base import CallbackManager from llama_index.constants import DEFAULT_SIMILARITY_TOP_K -from llama_index.core import BaseRetriever +from llama_index.core.base_retriever import BaseRetriever from llama_index.indices.managed.zilliz.base import ZillizCloudPipelineIndex from llama_index.indices.query.schema import QueryBundle from llama_index.schema import NodeWithScore, QueryBundle, TextNode diff --git a/llama_index/indices/multi_modal/base.py b/llama_index/indices/multi_modal/base.py index 95048aaef..d3ac2d19a 100644 --- a/llama_index/indices/multi_modal/base.py +++ b/llama_index/indices/multi_modal/base.py @@ -6,7 +6,8 @@ An index that that is built on top of multiple vector stores for different modal import logging from typing import Any, List, Optional, Sequence, cast -from llama_index.core import BaseQueryEngine, BaseRetriever +from llama_index.core.base_query_engine import BaseQueryEngine +from llama_index.core.base_retriever import BaseRetriever from llama_index.data_structs.data_structs import IndexDict, MultiModelIndexDict from llama_index.embeddings.multi_modal_base import MultiModalEmbedding from llama_index.embeddings.utils import EmbedType, resolve_embed_model diff --git a/llama_index/indices/multi_modal/retriever.py b/llama_index/indices/multi_modal/retriever.py index ef36e3795..cf3ce560f 100644 --- a/llama_index/indices/multi_modal/retriever.py +++ b/llama_index/indices/multi_modal/retriever.py @@ -5,7 +5,7 @@ from typing import Any, Dict, List, Optional from llama_index.callbacks.base import CallbackManager from llama_index.constants import DEFAULT_SIMILARITY_TOP_K -from llama_index.core import ( +from llama_index.core.base_multi_modal_retriever import ( MultiModalRetriever, ) from llama_index.data_structs.data_structs import IndexDict diff --git a/llama_index/indices/prompt_helper.py b/llama_index/indices/prompt_helper.py index f2c5ca9ef..5e6a25bcd 100644 --- a/llama_index/indices/prompt_helper.py +++ b/llama_index/indices/prompt_helper.py @@ -15,9 +15,9 @@ from typing import 
Callable, List, Optional, Sequence from llama_index.bridge.pydantic import Field, PrivateAttr from llama_index.constants import DEFAULT_CONTEXT_WINDOW, DEFAULT_NUM_OUTPUTS +from llama_index.core.llms.types import ChatMessage from llama_index.llm_predictor.base import LLMMetadata from llama_index.llms.llm import LLM -from llama_index.llms.types import ChatMessage from llama_index.node_parser.text.token import TokenTextSplitter from llama_index.node_parser.text.utils import truncate_text from llama_index.prompts import ( diff --git a/llama_index/indices/query/base.py b/llama_index/indices/query/base.py index 87d179f26..f6fc5a3c1 100644 --- a/llama_index/indices/query/base.py +++ b/llama_index/indices/query/base.py @@ -1,5 +1,5 @@ # for backwards compatibility -from llama_index.core import BaseQueryEngine +from llama_index.core.base_query_engine import BaseQueryEngine __all__ = [ "BaseQueryEngine", diff --git a/llama_index/indices/query/embedding_utils.py b/llama_index/indices/query/embedding_utils.py index 22e80dcad..40031f9cf 100644 --- a/llama_index/indices/query/embedding_utils.py +++ b/llama_index/indices/query/embedding_utils.py @@ -5,7 +5,7 @@ from typing import Any, Callable, List, Optional, Tuple import numpy as np -from llama_index.embeddings.base import similarity as default_similarity_fn +from llama_index.core.embeddings.base import similarity as default_similarity_fn from llama_index.vector_stores.types import VectorStoreQueryMode diff --git a/llama_index/indices/query/query_transform/base.py b/llama_index/indices/query/query_transform/base.py index bb4ee668f..c69d2cc31 100644 --- a/llama_index/indices/query/query_transform/base.py +++ b/llama_index/indices/query/query_transform/base.py @@ -4,6 +4,7 @@ import dataclasses from abc import abstractmethod from typing import Dict, Optional, cast +from llama_index.core.response.schema import Response from llama_index.indices.query.query_transform.prompts import ( DEFAULT_DECOMPOSE_QUERY_TRANSFORM_PROMPT, DEFAULT_IMAGE_OUTPUT_PROMPT, @@ -17,7 +18,6 @@ from llama_index.llms.utils import resolve_llm from llama_index.prompts import BasePromptTemplate from llama_index.prompts.default_prompts import DEFAULT_HYDE_PROMPT from llama_index.prompts.mixin import PromptDictType, PromptMixin, PromptMixinType -from llama_index.response.schema import Response from llama_index.schema import QueryBundle, QueryType from llama_index.utils import print_text diff --git a/llama_index/indices/struct_store/json_query.py b/llama_index/indices/struct_store/json_query.py index 353aff77e..bb7e389db 100644 --- a/llama_index/indices/struct_store/json_query.py +++ b/llama_index/indices/struct_store/json_query.py @@ -2,12 +2,12 @@ import json import logging from typing import Any, Callable, Dict, List, Optional, Union -from llama_index.core import BaseQueryEngine +from llama_index.core.base_query_engine import BaseQueryEngine +from llama_index.core.response.schema import Response from llama_index.prompts import BasePromptTemplate, PromptTemplate from llama_index.prompts.default_prompts import DEFAULT_JSON_PATH_PROMPT from llama_index.prompts.mixin import PromptDictType, PromptMixinType from llama_index.prompts.prompt_type import PromptType -from llama_index.response.schema import Response from llama_index.schema import QueryBundle from llama_index.service_context import ServiceContext from llama_index.utils import print_text diff --git a/llama_index/indices/struct_store/pandas.py b/llama_index/indices/struct_store/pandas.py index 129b6e927..85109e17d 100644 --- 
a/llama_index/indices/struct_store/pandas.py +++ b/llama_index/indices/struct_store/pandas.py @@ -5,7 +5,8 @@ from typing import Any, Optional, Sequence import pandas as pd -from llama_index.core import BaseQueryEngine, BaseRetriever +from llama_index.core.base_query_engine import BaseQueryEngine +from llama_index.core.base_retriever import BaseRetriever from llama_index.data_structs.table import PandasStructTable from llama_index.indices.struct_store.base import BaseStructStoreIndex from llama_index.schema import BaseNode diff --git a/llama_index/indices/struct_store/sql.py b/llama_index/indices/struct_store/sql.py index f59127669..592f19649 100644 --- a/llama_index/indices/struct_store/sql.py +++ b/llama_index/indices/struct_store/sql.py @@ -5,7 +5,8 @@ from typing import Any, Optional, Sequence, Union from sqlalchemy import Table -from llama_index.core import BaseQueryEngine, BaseRetriever +from llama_index.core.base_query_engine import BaseQueryEngine +from llama_index.core.base_retriever import BaseRetriever from llama_index.data_structs.table import SQLStructTable from llama_index.indices.common.struct_store.schema import SQLContextContainer from llama_index.indices.common.struct_store.sql import SQLStructDatapointExtractor diff --git a/llama_index/indices/struct_store/sql_query.py b/llama_index/indices/struct_store/sql_query.py index 53cf0d933..b9d9b4f73 100644 --- a/llama_index/indices/struct_store/sql_query.py +++ b/llama_index/indices/struct_store/sql_query.py @@ -5,7 +5,8 @@ from typing import Any, Dict, List, Optional, Tuple, Union, cast from sqlalchemy import Table -from llama_index.core import BaseQueryEngine +from llama_index.core.base_query_engine import BaseQueryEngine +from llama_index.core.response.schema import Response from llama_index.indices.struct_store.container_builder import ( SQLContextContainerBuilder, ) @@ -20,7 +21,6 @@ from llama_index.prompts.default_prompts import ( ) from llama_index.prompts.mixin import PromptDictType, PromptMixinType from llama_index.prompts.prompt_type import PromptType -from llama_index.response.schema import Response from llama_index.response_synthesizers import ( get_response_synthesizer, ) diff --git a/llama_index/indices/struct_store/sql_retriever.py b/llama_index/indices/struct_store/sql_retriever.py index a90160364..0971cb6c6 100644 --- a/llama_index/indices/struct_store/sql_retriever.py +++ b/llama_index/indices/struct_store/sql_retriever.py @@ -8,7 +8,7 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, Union, cast from sqlalchemy import Table from llama_index.callbacks.base import CallbackManager -from llama_index.core import BaseRetriever +from llama_index.core.base_retriever import BaseRetriever from llama_index.embeddings.base import BaseEmbedding from llama_index.objects.base import ObjectRetriever from llama_index.objects.table_node_mapping import SQLTableSchema diff --git a/llama_index/indices/tree/all_leaf_retriever.py b/llama_index/indices/tree/all_leaf_retriever.py index db831f073..8de60a921 100644 --- a/llama_index/indices/tree/all_leaf_retriever.py +++ b/llama_index/indices/tree/all_leaf_retriever.py @@ -4,7 +4,7 @@ import logging from typing import Any, List, Optional, cast from llama_index.callbacks.base import CallbackManager -from llama_index.core import BaseRetriever +from llama_index.core.base_retriever import BaseRetriever from llama_index.data_structs.data_structs import IndexGraph from llama_index.indices.tree.base import TreeIndex from llama_index.indices.utils import get_sorted_node_list 
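For downstream code, none of the retriever hunks above change behavior, only the import path. A minimal consumer sketch against the new location; the _retrieve hook follows the existing BaseRetriever contract, and the class itself is hypothetical:

    from typing import List

    from llama_index.core.base_retriever import BaseRetriever
    from llama_index.schema import NodeWithScore, QueryBundle

    class StaticRetriever(BaseRetriever):
        """Hypothetical retriever that always returns a fixed node list."""

        def __init__(self, nodes: List[NodeWithScore]) -> None:
            super().__init__()
            self._nodes = nodes

        def _retrieve(self, query_bundle: QueryBundle) -> List[NodeWithScore]:
            # A real retriever would rank against query_bundle; this sketch
            # ignores the query on purpose.
            return self._nodes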
diff --git a/llama_index/indices/tree/base.py b/llama_index/indices/tree/base.py index c1365f09d..4fa5a020a 100644 --- a/llama_index/indices/tree/base.py +++ b/llama_index/indices/tree/base.py @@ -3,7 +3,7 @@ from enum import Enum from typing import Any, Dict, Optional, Sequence, Union -from llama_index.core import BaseRetriever +from llama_index.core.base_retriever import BaseRetriever # from llama_index.data_structs.data_structs import IndexGraph from llama_index.data_structs.data_structs import IndexGraph diff --git a/llama_index/indices/tree/select_leaf_retriever.py b/llama_index/indices/tree/select_leaf_retriever.py index a61a3e5ae..d4ce25fb9 100644 --- a/llama_index/indices/tree/select_leaf_retriever.py +++ b/llama_index/indices/tree/select_leaf_retriever.py @@ -4,7 +4,8 @@ import logging from typing import Any, Dict, List, Optional, cast from llama_index.callbacks.base import CallbackManager -from llama_index.core import BaseRetriever +from llama_index.core.base_retriever import BaseRetriever +from llama_index.core.response.schema import Response from llama_index.indices.query.schema import QueryBundle from llama_index.indices.tree.base import TreeIndex from llama_index.indices.tree.utils import get_numbered_text_from_nodes @@ -19,7 +20,6 @@ from llama_index.prompts.default_prompts import ( DEFAULT_QUERY_PROMPT_MULTIPLE, DEFAULT_TEXT_QA_PROMPT, ) -from llama_index.response.schema import Response from llama_index.response_synthesizers import get_response_synthesizer from llama_index.schema import BaseNode, MetadataMode, NodeWithScore, QueryBundle from llama_index.utils import print_text, truncate_text diff --git a/llama_index/indices/tree/tree_root_retriever.py b/llama_index/indices/tree/tree_root_retriever.py index 58b2ef0cb..fe581e3cc 100644 --- a/llama_index/indices/tree/tree_root_retriever.py +++ b/llama_index/indices/tree/tree_root_retriever.py @@ -3,7 +3,7 @@ import logging from typing import Any, List, Optional from llama_index.callbacks.base import CallbackManager -from llama_index.core import BaseRetriever +from llama_index.core.base_retriever import BaseRetriever from llama_index.indices.query.schema import QueryBundle from llama_index.indices.tree.base import TreeIndex from llama_index.indices.utils import get_sorted_node_list diff --git a/llama_index/indices/vector_store/base.py b/llama_index/indices/vector_store/base.py index 98a7e32a6..918cbc41b 100644 --- a/llama_index/indices/vector_store/base.py +++ b/llama_index/indices/vector_store/base.py @@ -7,7 +7,7 @@ import logging from typing import Any, Dict, List, Optional, Sequence from llama_index.async_utils import run_async_tasks -from llama_index.core import BaseRetriever +from llama_index.core.base_retriever import BaseRetriever from llama_index.data_structs.data_structs import IndexDict from llama_index.indices.base import BaseIndex from llama_index.indices.utils import async_embed_nodes, embed_nodes diff --git a/llama_index/indices/vector_store/retrievers/auto_retriever/auto_retriever.py b/llama_index/indices/vector_store/retrievers/auto_retriever/auto_retriever.py index 9b8db1169..0405db29f 100644 --- a/llama_index/indices/vector_store/retrievers/auto_retriever/auto_retriever.py +++ b/llama_index/indices/vector_store/retrievers/auto_retriever/auto_retriever.py @@ -4,7 +4,7 @@ from typing import Any, List, Optional, Tuple, cast from llama_index.bridge.pydantic import BaseModel from llama_index.callbacks.base import CallbackManager from llama_index.constants import DEFAULT_SIMILARITY_TOP_K -from llama_index.core 
import BaseAutoRetriever +from llama_index.core.base_auto_retriever import BaseAutoRetriever from llama_index.core.base_retriever import BaseRetriever from llama_index.indices.vector_store.base import VectorStoreIndex from llama_index.indices.vector_store.retrievers import VectorIndexRetriever diff --git a/llama_index/indices/vector_store/retrievers/retriever.py b/llama_index/indices/vector_store/retrievers/retriever.py index 2be7db0d5..83b097496 100644 --- a/llama_index/indices/vector_store/retrievers/retriever.py +++ b/llama_index/indices/vector_store/retrievers/retriever.py @@ -5,7 +5,7 @@ from typing import Any, Dict, List, Optional from llama_index.callbacks.base import CallbackManager from llama_index.constants import DEFAULT_SIMILARITY_TOP_K -from llama_index.core import BaseRetriever +from llama_index.core.base_retriever import BaseRetriever from llama_index.data_structs.data_structs import IndexDict from llama_index.indices.utils import log_vector_store_query_result from llama_index.indices.vector_store.base import VectorStoreIndex diff --git a/llama_index/langchain_helpers/agents/tools.py b/llama_index/langchain_helpers/agents/tools.py index 01801486d..f85853d1d 100644 --- a/llama_index/langchain_helpers/agents/tools.py +++ b/llama_index/langchain_helpers/agents/tools.py @@ -4,8 +4,8 @@ from typing import Any, Dict, List from llama_index.bridge.langchain import BaseTool from llama_index.bridge.pydantic import BaseModel, Field -from llama_index.core import BaseQueryEngine -from llama_index.response.schema import RESPONSE_TYPE +from llama_index.core.base_query_engine import BaseQueryEngine +from llama_index.core.response.schema import RESPONSE_TYPE from llama_index.schema import TextNode diff --git a/llama_index/llama_dataset/base.py b/llama_index/llama_dataset/base.py index 0d0c0472b..a7fc03b84 100644 --- a/llama_index/llama_dataset/base.py +++ b/llama_index/llama_dataset/base.py @@ -11,7 +11,7 @@ from pandas import DataFrame as PandasDataFrame from llama_index.async_utils import asyncio_module from llama_index.bridge.pydantic import BaseModel, Field, PrivateAttr -from llama_index.core import BaseQueryEngine +from llama_index.core.base_query_engine import BaseQueryEngine from llama_index.evaluation import BaseEvaluator PredictorType = Union[BaseQueryEngine, BaseEvaluator] diff --git a/llama_index/llama_dataset/generator.py b/llama_index/llama_dataset/generator.py index 6adb38213..e0085a885 100644 --- a/llama_index/llama_dataset/generator.py +++ b/llama_index/llama_dataset/generator.py @@ -7,6 +7,7 @@ from typing import List from llama_index import Document, ServiceContext, SummaryIndex from llama_index.async_utils import DEFAULT_NUM_WORKERS, run_jobs +from llama_index.core.response.schema import RESPONSE_TYPE from llama_index.ingestion import run_transformations from llama_index.llama_dataset import ( CreatedBy, @@ -18,7 +19,6 @@ from llama_index.postprocessor.node import KeywordNodePostprocessor from llama_index.prompts.base import BasePromptTemplate, PromptTemplate from llama_index.prompts.default_prompts import DEFAULT_TEXT_QA_PROMPT from llama_index.prompts.mixin import PromptDictType, PromptMixin, PromptMixinType -from llama_index.response.schema import RESPONSE_TYPE from llama_index.schema import BaseNode, MetadataMode, NodeWithScore DEFAULT_QUESTION_GENERATION_PROMPT = """\ diff --git a/llama_index/llama_dataset/rag.py b/llama_index/llama_dataset/rag.py index 5e9c897f4..3a4182767 100644 --- a/llama_index/llama_dataset/rag.py +++ b/llama_index/llama_dataset/rag.py @@ -7,7 
+7,7 @@ from typing import List, Optional from pandas import DataFrame as PandasDataFrame from llama_index.bridge.pydantic import Field -from llama_index.core import BaseQueryEngine +from llama_index.core.base_query_engine import BaseQueryEngine from llama_index.llama_dataset.base import ( BaseLlamaDataExample, BaseLlamaDataset, diff --git a/llama_index/llm_predictor/base.py b/llama_index/llm_predictor/base.py index d807d221f..0b1ab1693 100644 --- a/llama_index/llm_predictor/base.py +++ b/llama_index/llm_predictor/base.py @@ -10,6 +10,11 @@ from typing_extensions import Self from llama_index.bridge.pydantic import BaseModel, PrivateAttr from llama_index.callbacks.base import CallbackManager from llama_index.callbacks.schema import CBEventType, EventPayload +from llama_index.core.llms.types import ( + ChatMessage, + LLMMetadata, + MessageRole, +) from llama_index.llms.llm import ( LLM, astream_chat_response_to_tokens, @@ -17,11 +22,6 @@ from llama_index.llms.llm import ( stream_chat_response_to_tokens, stream_completion_response_to_tokens, ) -from llama_index.llms.types import ( - ChatMessage, - LLMMetadata, - MessageRole, -) from llama_index.llms.utils import LLMType, resolve_llm from llama_index.prompts.base import BasePromptTemplate, PromptTemplate from llama_index.schema import BaseComponent diff --git a/llama_index/llm_predictor/mock.py b/llama_index/llm_predictor/mock.py index d3a971f18..dc005201d 100644 --- a/llama_index/llm_predictor/mock.py +++ b/llama_index/llm_predictor/mock.py @@ -6,9 +6,9 @@ from deprecated import deprecated from llama_index.bridge.pydantic import Field, PrivateAttr from llama_index.callbacks.base import CallbackManager from llama_index.constants import DEFAULT_NUM_OUTPUTS +from llama_index.core.llms.types import LLMMetadata from llama_index.llm_predictor.base import BaseLLMPredictor from llama_index.llms.llm import LLM -from llama_index.llms.types import LLMMetadata from llama_index.prompts.base import BasePromptTemplate from llama_index.prompts.prompt_type import PromptType from llama_index.token_counter.utils import ( diff --git a/llama_index/llms/__init__.py b/llama_index/llms/__init__.py index bf4d304e5..6a1aa18ca 100644 --- a/llama_index/llms/__init__.py +++ b/llama_index/llms/__init__.py @@ -1,3 +1,14 @@ +from llama_index.core.llms.types import ( + ChatMessage, + ChatResponse, + ChatResponseAsyncGen, + ChatResponseGen, + CompletionResponse, + CompletionResponseAsyncGen, + CompletionResponseGen, + LLMMetadata, + MessageRole, +) from llama_index.llms.ai21 import AI21 from llama_index.llms.anthropic import Anthropic from llama_index.llms.anyscale import Anyscale @@ -30,17 +41,6 @@ from llama_index.llms.perplexity import Perplexity from llama_index.llms.portkey import Portkey from llama_index.llms.predibase import PredibaseLLM from llama_index.llms.replicate import Replicate -from llama_index.llms.types import ( - ChatMessage, - ChatResponse, - ChatResponseAsyncGen, - ChatResponseGen, - CompletionResponse, - CompletionResponseAsyncGen, - CompletionResponseGen, - LLMMetadata, - MessageRole, -) from llama_index.llms.vertex import Vertex from llama_index.llms.vllm import Vllm, VllmServer from llama_index.llms.watsonx import WatsonX diff --git a/llama_index/llms/ai21.py b/llama_index/llms/ai21.py index 0ed8216b6..860e36034 100644 --- a/llama_index/llms/ai21.py +++ b/llama_index/llms/ai21.py @@ -2,14 +2,7 @@ from typing import Any, Callable, Dict, Optional, Sequence from llama_index.bridge.pydantic import Field, PrivateAttr from llama_index.callbacks import 
CallbackManager -from llama_index.llms.ai21_utils import ai21_model_to_context_size -from llama_index.llms.base import llm_chat_callback, llm_completion_callback -from llama_index.llms.custom import CustomLLM -from llama_index.llms.generic_utils import ( - completion_to_chat_decorator, - get_from_param_or_env, -) -from llama_index.llms.types import ( +from llama_index.core.llms.types import ( ChatMessage, ChatResponse, ChatResponseGen, @@ -17,6 +10,13 @@ from llama_index.llms.types import ( CompletionResponseGen, LLMMetadata, ) +from llama_index.llms.ai21_utils import ai21_model_to_context_size +from llama_index.llms.base import llm_chat_callback, llm_completion_callback +from llama_index.llms.custom import CustomLLM +from llama_index.llms.generic_utils import ( + completion_to_chat_decorator, + get_from_param_or_env, +) from llama_index.types import BaseOutputParser, PydanticProgramMode diff --git a/llama_index/llms/anthropic.py b/llama_index/llms/anthropic.py index 86ceff3a5..5cbf2ca48 100644 --- a/llama_index/llms/anthropic.py +++ b/llama_index/llms/anthropic.py @@ -3,6 +3,17 @@ from typing import Any, Callable, Dict, Optional, Sequence from llama_index.bridge.pydantic import Field, PrivateAttr from llama_index.callbacks import CallbackManager from llama_index.constants import DEFAULT_TEMPERATURE +from llama_index.core.llms.types import ( + ChatMessage, + ChatResponse, + ChatResponseAsyncGen, + ChatResponseGen, + CompletionResponse, + CompletionResponseAsyncGen, + CompletionResponseGen, + LLMMetadata, + MessageRole, +) from llama_index.llms.anthropic_utils import ( anthropic_modelname_to_contextsize, messages_to_anthropic_prompt, @@ -18,17 +29,6 @@ from llama_index.llms.generic_utils import ( stream_chat_to_completion_decorator, ) from llama_index.llms.llm import LLM -from llama_index.llms.types import ( - ChatMessage, - ChatResponse, - ChatResponseAsyncGen, - ChatResponseGen, - CompletionResponse, - CompletionResponseAsyncGen, - CompletionResponseGen, - LLMMetadata, - MessageRole, -) from llama_index.types import BaseOutputParser, PydanticProgramMode DEFAULT_ANTHROPIC_MODEL = "claude-2" diff --git a/llama_index/llms/anthropic_utils.py b/llama_index/llms/anthropic_utils.py index f0904bd73..eb2eb23fb 100644 --- a/llama_index/llms/anthropic_utils.py +++ b/llama_index/llms/anthropic_utils.py @@ -1,6 +1,6 @@ from typing import Dict, Sequence -from llama_index.llms.types import ChatMessage, MessageRole +from llama_index.core.llms.types import ChatMessage, MessageRole HUMAN_PREFIX = "\n\nHuman:" ASSISTANT_PREFIX = "\n\nAssistant:" diff --git a/llama_index/llms/anyscale.py b/llama_index/llms/anyscale.py index d9404326d..d17f86ec6 100644 --- a/llama_index/llms/anyscale.py +++ b/llama_index/llms/anyscale.py @@ -2,12 +2,12 @@ from typing import Any, Callable, Dict, Optional, Sequence from llama_index.callbacks import CallbackManager from llama_index.constants import DEFAULT_NUM_OUTPUTS, DEFAULT_TEMPERATURE +from llama_index.core.llms.types import ChatMessage, LLMMetadata from llama_index.llms.anyscale_utils import ( anyscale_modelname_to_contextsize, ) from llama_index.llms.generic_utils import get_from_param_or_env from llama_index.llms.openai import OpenAI -from llama_index.llms.types import ChatMessage, LLMMetadata from llama_index.types import BaseOutputParser, PydanticProgramMode DEFAULT_API_BASE = "https://api.endpoints.anyscale.com/v1" diff --git a/llama_index/llms/anyscale_utils.py b/llama_index/llms/anyscale_utils.py index d86bbf300..b82a1c3bb 100644 --- 
a/llama_index/llms/anyscale_utils.py +++ b/llama_index/llms/anyscale_utils.py @@ -1,6 +1,6 @@ from typing import Any, Dict, List, Sequence -from llama_index.llms.types import ChatMessage, MessageRole +from llama_index.core.llms.types import ChatMessage, MessageRole LLAMA_MODELS = { "meta-llama/Llama-2-7b-chat-hf": 4096, diff --git a/llama_index/llms/azure_openai.py b/llama_index/llms/azure_openai.py index 137f78f86..8caa77f56 100644 --- a/llama_index/llms/azure_openai.py +++ b/llama_index/llms/azure_openai.py @@ -6,13 +6,13 @@ from openai import AzureOpenAI as SyncAzureOpenAI from llama_index.bridge.pydantic import Field, PrivateAttr, root_validator from llama_index.callbacks import CallbackManager +from llama_index.core.llms.types import ChatMessage from llama_index.llms.generic_utils import get_from_param_or_env from llama_index.llms.openai import OpenAI from llama_index.llms.openai_utils import ( refresh_openai_azuread_token, resolve_from_aliases, ) -from llama_index.llms.types import ChatMessage from llama_index.types import BaseOutputParser, PydanticProgramMode diff --git a/llama_index/llms/base.py b/llama_index/llms/base.py index 734143046..e3a0b1b31 100644 --- a/llama_index/llms/base.py +++ b/llama_index/llms/base.py @@ -12,7 +12,7 @@ from typing import ( from llama_index.bridge.pydantic import Field, validator from llama_index.callbacks import CallbackManager, CBEventType, EventPayload -from llama_index.llms.types import ( +from llama_index.core.llms.types import ( ChatMessage, ChatResponse, ChatResponseAsyncGen, diff --git a/llama_index/llms/bedrock.py b/llama_index/llms/bedrock.py index c6c7d4e6f..b76d19b91 100644 --- a/llama_index/llms/bedrock.py +++ b/llama_index/llms/bedrock.py @@ -3,6 +3,16 @@ from typing import Any, Callable, Dict, Optional, Sequence from llama_index.bridge.pydantic import Field, PrivateAttr from llama_index.callbacks import CallbackManager +from llama_index.core.llms.types import ( + ChatMessage, + ChatResponse, + ChatResponseAsyncGen, + ChatResponseGen, + CompletionResponse, + CompletionResponseAsyncGen, + CompletionResponseGen, + LLMMetadata, +) from llama_index.llms.base import ( llm_chat_callback, llm_completion_callback, @@ -20,16 +30,6 @@ from llama_index.llms.generic_utils import ( stream_completion_response_to_chat_response, ) from llama_index.llms.llm import LLM -from llama_index.llms.types import ( - ChatMessage, - ChatResponse, - ChatResponseAsyncGen, - ChatResponseGen, - CompletionResponse, - CompletionResponseAsyncGen, - CompletionResponseGen, - LLMMetadata, -) from llama_index.types import BaseOutputParser, PydanticProgramMode diff --git a/llama_index/llms/bedrock_utils.py b/llama_index/llms/bedrock_utils.py index a43a4a3c4..cf8e9a05c 100644 --- a/llama_index/llms/bedrock_utils.py +++ b/llama_index/llms/bedrock_utils.py @@ -10,6 +10,7 @@ from tenacity import ( wait_exponential, ) +from llama_index.core.llms.types import ChatMessage from llama_index.llms.anthropic_utils import messages_to_anthropic_prompt from llama_index.llms.generic_utils import ( prompt_to_messages, @@ -20,7 +21,6 @@ from llama_index.llms.llama_utils import ( from llama_index.llms.llama_utils import ( messages_to_prompt as messages_to_llama_prompt, ) -from llama_index.llms.types import ChatMessage HUMAN_PREFIX = "\n\nHuman:" ASSISTANT_PREFIX = "\n\nAssistant:" diff --git a/llama_index/llms/clarifai.py b/llama_index/llms/clarifai.py index 3d821023d..88950cc0f 100644 --- a/llama_index/llms/clarifai.py +++ b/llama_index/llms/clarifai.py @@ -2,12 +2,7 @@ from typing import Any, 
Callable, Dict, Optional, Sequence from llama_index.bridge.pydantic import Field, PrivateAttr from llama_index.callbacks import CallbackManager -from llama_index.llms.base import ( - llm_chat_callback, - llm_completion_callback, -) -from llama_index.llms.llm import LLM -from llama_index.llms.types import ( +from llama_index.core.llms.types import ( ChatMessage, ChatResponse, ChatResponseAsyncGen, @@ -17,6 +12,11 @@ from llama_index.llms.types import ( CompletionResponseGen, LLMMetadata, ) +from llama_index.llms.base import ( + llm_chat_callback, + llm_completion_callback, +) +from llama_index.llms.llm import LLM from llama_index.types import BaseOutputParser, PydanticProgramMode EXAMPLE_URL = "https://clarifai.com/anthropic/completion/models/claude-v2" diff --git a/llama_index/llms/cohere.py b/llama_index/llms/cohere.py index 2383a2eae..d83d6a39e 100644 --- a/llama_index/llms/cohere.py +++ b/llama_index/llms/cohere.py @@ -3,6 +3,17 @@ from typing import Any, Callable, Dict, Optional, Sequence from llama_index.bridge.pydantic import Field, PrivateAttr from llama_index.callbacks import CallbackManager +from llama_index.core.llms.types import ( + ChatMessage, + ChatResponse, + ChatResponseAsyncGen, + ChatResponseGen, + CompletionResponse, + CompletionResponseAsyncGen, + CompletionResponseGen, + LLMMetadata, + MessageRole, +) from llama_index.llms.base import ( llm_chat_callback, llm_completion_callback, @@ -15,17 +26,6 @@ from llama_index.llms.cohere_utils import ( messages_to_cohere_history, ) from llama_index.llms.llm import LLM -from llama_index.llms.types import ( - ChatMessage, - ChatResponse, - ChatResponseAsyncGen, - ChatResponseGen, - CompletionResponse, - CompletionResponseAsyncGen, - CompletionResponseGen, - LLMMetadata, - MessageRole, -) from llama_index.types import BaseOutputParser, PydanticProgramMode diff --git a/llama_index/llms/cohere_utils.py b/llama_index/llms/cohere_utils.py index 292102f51..421d9037a 100644 --- a/llama_index/llms/cohere_utils.py +++ b/llama_index/llms/cohere_utils.py @@ -9,7 +9,7 @@ from tenacity import ( wait_exponential, ) -from llama_index.llms.types import ChatMessage +from llama_index.core.llms.types import ChatMessage COMMAND_MODELS = { "command": 4096, diff --git a/llama_index/llms/custom.py b/llama_index/llms/custom.py index 48eee2aee..516a5e08c 100644 --- a/llama_index/llms/custom.py +++ b/llama_index/llms/custom.py @@ -1,5 +1,13 @@ from typing import Any, Sequence +from llama_index.core.llms.types import ( + ChatMessage, + ChatResponse, + ChatResponseAsyncGen, + ChatResponseGen, + CompletionResponse, + CompletionResponseAsyncGen, +) from llama_index.llms.base import ( llm_chat_callback, llm_completion_callback, @@ -9,14 +17,6 @@ from llama_index.llms.generic_utils import ( stream_completion_to_chat_decorator, ) from llama_index.llms.llm import LLM -from llama_index.llms.types import ( - ChatMessage, - ChatResponse, - ChatResponseAsyncGen, - ChatResponseGen, - CompletionResponse, - CompletionResponseAsyncGen, -) class CustomLLM(LLM): diff --git a/llama_index/llms/everlyai.py b/llama_index/llms/everlyai.py index 708b801db..211ff729c 100644 --- a/llama_index/llms/everlyai.py +++ b/llama_index/llms/everlyai.py @@ -2,10 +2,10 @@ from typing import Any, Callable, Dict, Optional, Sequence from llama_index.callbacks import CallbackManager from llama_index.constants import DEFAULT_NUM_OUTPUTS, DEFAULT_TEMPERATURE +from llama_index.core.llms.types import ChatMessage, LLMMetadata from llama_index.llms.everlyai_utils import everlyai_modelname_to_contextsize 
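As with the retriever and response moves, the llms/__init__.py hunk above re-exports the relocated message and response types, so user-facing imports are unchanged. A sketch, assuming only what that hunk shows:

    # Both spellings name the same class after this patch.
    from llama_index.core.llms.types import ChatMessage as CoreChatMessage
    from llama_index.llms import ChatMessage, MessageRole

    assert ChatMessage is CoreChatMessage

    msg = ChatMessage(role=MessageRole.USER, content="hello")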
from llama_index.llms.generic_utils import get_from_param_or_env from llama_index.llms.openai import OpenAI -from llama_index.llms.types import ChatMessage, LLMMetadata from llama_index.types import BaseOutputParser, PydanticProgramMode EVERLYAI_API_BASE = "https://everlyai.xyz/hosted" diff --git a/llama_index/llms/gemini.py b/llama_index/llms/gemini.py index 54195e441..57eaa8d1d 100644 --- a/llama_index/llms/gemini.py +++ b/llama_index/llms/gemini.py @@ -6,6 +6,14 @@ from typing import Any, Dict, Optional, Sequence from llama_index.bridge.pydantic import Field, PrivateAttr from llama_index.callbacks import CallbackManager from llama_index.constants import DEFAULT_NUM_OUTPUTS, DEFAULT_TEMPERATURE +from llama_index.core.llms.types import ( + ChatMessage, + ChatResponse, + ChatResponseGen, + CompletionResponse, + CompletionResponseGen, + LLMMetadata, +) from llama_index.llms.base import ( llm_chat_callback, llm_completion_callback, @@ -18,14 +26,6 @@ from llama_index.llms.gemini_utils import ( completion_from_gemini_response, merge_neighboring_same_role_messages, ) -from llama_index.llms.types import ( - ChatMessage, - ChatResponse, - ChatResponseGen, - CompletionResponse, - CompletionResponseGen, - LLMMetadata, -) if typing.TYPE_CHECKING: import google.generativeai as genai diff --git a/llama_index/llms/gemini_utils.py b/llama_index/llms/gemini_utils.py index f235a0134..b19f9ee19 100644 --- a/llama_index/llms/gemini_utils.py +++ b/llama_index/llms/gemini_utils.py @@ -1,12 +1,12 @@ import typing from typing import Sequence, Union +from llama_index.core.llms.types import MessageRole from llama_index.llms.base import ( ChatMessage, ChatResponse, CompletionResponse, ) -from llama_index.llms.types import MessageRole if typing.TYPE_CHECKING: import google.ai.generativelanguage as glm diff --git a/llama_index/llms/generic_utils.py b/llama_index/llms/generic_utils.py index 3ad12c0c2..3be36a267 100644 --- a/llama_index/llms/generic_utils.py +++ b/llama_index/llms/generic_utils.py @@ -1,7 +1,7 @@ import os from typing import Any, Awaitable, Callable, List, Optional, Sequence -from llama_index.llms.types import ( +from llama_index.core.llms.types import ( ChatMessage, ChatResponse, ChatResponseAsyncGen, diff --git a/llama_index/llms/gradient.py b/llama_index/llms/gradient.py index 6cc754893..992859058 100644 --- a/llama_index/llms/gradient.py +++ b/llama_index/llms/gradient.py @@ -5,14 +5,14 @@ from typing_extensions import override from llama_index.bridge.pydantic import Field, PrivateAttr from llama_index.callbacks import CallbackManager from llama_index.constants import DEFAULT_NUM_OUTPUTS -from llama_index.llms.base import llm_completion_callback -from llama_index.llms.custom import CustomLLM -from llama_index.llms.types import ( +from llama_index.core.llms.types import ( ChatMessage, CompletionResponse, CompletionResponseGen, LLMMetadata, ) +from llama_index.llms.base import llm_completion_callback +from llama_index.llms.custom import CustomLLM from llama_index.types import BaseOutputParser, PydanticProgramMode diff --git a/llama_index/llms/huggingface.py b/llama_index/llms/huggingface.py index 436931539..bb338a1f2 100644 --- a/llama_index/llms/huggingface.py +++ b/llama_index/llms/huggingface.py @@ -8,6 +8,17 @@ from llama_index.constants import ( DEFAULT_CONTEXT_WINDOW, DEFAULT_NUM_OUTPUTS, ) +from llama_index.core.llms.types import ( + ChatMessage, + ChatResponse, + ChatResponseAsyncGen, + ChatResponseGen, + CompletionResponse, + CompletionResponseAsyncGen, + CompletionResponseGen, + 
LLMMetadata, + MessageRole, +) from llama_index.llms.base import ( llm_chat_callback, llm_completion_callback, @@ -20,17 +31,6 @@ from llama_index.llms.generic_utils import ( from llama_index.llms.generic_utils import ( messages_to_prompt as generic_messages_to_prompt, ) -from llama_index.llms.types import ( - ChatMessage, - ChatResponse, - ChatResponseAsyncGen, - ChatResponseGen, - CompletionResponse, - CompletionResponseAsyncGen, - CompletionResponseGen, - LLMMetadata, - MessageRole, -) from llama_index.prompts.base import PromptTemplate from llama_index.types import BaseOutputParser, PydanticProgramMode diff --git a/llama_index/llms/konko.py b/llama_index/llms/konko.py index ecb056258..3ab7cd2bd 100644 --- a/llama_index/llms/konko.py +++ b/llama_index/llms/konko.py @@ -3,6 +3,16 @@ from typing import Any, Awaitable, Callable, Dict, Optional, Sequence from llama_index.bridge.pydantic import Field from llama_index.callbacks import CallbackManager from llama_index.constants import DEFAULT_NUM_OUTPUTS, DEFAULT_TEMPERATURE +from llama_index.core.llms.types import ( + ChatMessage, + ChatResponse, + ChatResponseAsyncGen, + ChatResponseGen, + CompletionResponse, + CompletionResponseAsyncGen, + CompletionResponseGen, + LLMMetadata, +) from llama_index.llms.base import llm_chat_callback, llm_completion_callback from llama_index.llms.generic_utils import ( achat_to_completion_decorator, @@ -24,16 +34,6 @@ from llama_index.llms.konko_utils import ( to_openai_message_dicts, ) from llama_index.llms.llm import LLM -from llama_index.llms.types import ( - ChatMessage, - ChatResponse, - ChatResponseAsyncGen, - ChatResponseGen, - CompletionResponse, - CompletionResponseAsyncGen, - CompletionResponseGen, - LLMMetadata, -) from llama_index.types import BaseOutputParser, PydanticProgramMode DEFAULT_KONKO_MODEL = "meta-llama/Llama-2-13b-chat-hf" diff --git a/llama_index/llms/konko_utils.py b/llama_index/llms/konko_utils.py index a097aab4c..c285e30e9 100644 --- a/llama_index/llms/konko_utils.py +++ b/llama_index/llms/konko_utils.py @@ -11,8 +11,8 @@ from tenacity import ( ) from llama_index.bridge.pydantic import BaseModel +from llama_index.core.llms.types import ChatMessage from llama_index.llms.generic_utils import get_from_param_or_env -from llama_index.llms.types import ChatMessage DEFAULT_KONKO_API_TYPE = "open_ai" DEFAULT_KONKO_API_BASE = "https://api.konko.ai/v1" diff --git a/llama_index/llms/langchain.py b/llama_index/llms/langchain.py index 56b093759..873ee9ab2 100644 --- a/llama_index/llms/langchain.py +++ b/llama_index/llms/langchain.py @@ -6,13 +6,7 @@ if TYPE_CHECKING: from llama_index.bridge.pydantic import PrivateAttr from llama_index.callbacks import CallbackManager -from llama_index.llms.base import llm_chat_callback, llm_completion_callback -from llama_index.llms.generic_utils import ( - completion_response_to_chat_response, - stream_completion_response_to_chat_response, -) -from llama_index.llms.llm import LLM -from llama_index.llms.types import ( +from llama_index.core.llms.types import ( ChatMessage, ChatResponse, ChatResponseAsyncGen, @@ -22,6 +16,12 @@ from llama_index.llms.types import ( CompletionResponseGen, LLMMetadata, ) +from llama_index.llms.base import llm_chat_callback, llm_completion_callback +from llama_index.llms.generic_utils import ( + completion_response_to_chat_response, + stream_completion_response_to_chat_response, +) +from llama_index.llms.llm import LLM from llama_index.types import BaseOutputParser, PydanticProgramMode diff --git a/llama_index/llms/langchain_utils.py 
b/llama_index/llms/langchain_utils.py index accc029b0..90fefcd77 100644 --- a/llama_index/llms/langchain_utils.py +++ b/llama_index/llms/langchain_utils.py @@ -15,9 +15,9 @@ from llama_index.bridge.langchain import ( ) from llama_index.bridge.langchain import BaseMessage as LCMessage from llama_index.constants import AI21_J2_CONTEXT_WINDOW, COHERE_CONTEXT_WINDOW +from llama_index.core.llms.types import ChatMessage, LLMMetadata, MessageRole from llama_index.llms.anyscale_utils import anyscale_modelname_to_contextsize from llama_index.llms.openai_utils import openai_modelname_to_contextsize -from llama_index.llms.types import ChatMessage, LLMMetadata, MessageRole def is_chat_model(llm: BaseLanguageModel) -> bool: diff --git a/llama_index/llms/litellm.py b/llama_index/llms/litellm.py index e1524c630..4ddfea738 100644 --- a/llama_index/llms/litellm.py +++ b/llama_index/llms/litellm.py @@ -3,6 +3,16 @@ from typing import Any, Awaitable, Callable, Dict, Optional, Sequence from llama_index.bridge.pydantic import Field from llama_index.callbacks import CallbackManager from llama_index.constants import DEFAULT_TEMPERATURE +from llama_index.core.llms.types import ( + ChatMessage, + ChatResponse, + ChatResponseAsyncGen, + ChatResponseGen, + CompletionResponse, + CompletionResponseAsyncGen, + CompletionResponseGen, + LLMMetadata, +) from llama_index.llms.base import llm_chat_callback, llm_completion_callback from llama_index.llms.generic_utils import ( achat_to_completion_decorator, @@ -24,16 +34,6 @@ from llama_index.llms.litellm_utils import ( validate_litellm_api_key, ) from llama_index.llms.llm import LLM -from llama_index.llms.types import ( - ChatMessage, - ChatResponse, - ChatResponseAsyncGen, - ChatResponseGen, - CompletionResponse, - CompletionResponseAsyncGen, - CompletionResponseGen, - LLMMetadata, -) from llama_index.types import BaseOutputParser, PydanticProgramMode DEFAULT_LITELLM_MODEL = "gpt-3.5-turbo" diff --git a/llama_index/llms/litellm_utils.py b/llama_index/llms/litellm_utils.py index 2af40dab6..ab4cefe49 100644 --- a/llama_index/llms/litellm_utils.py +++ b/llama_index/llms/litellm_utils.py @@ -11,7 +11,7 @@ from tenacity import ( ) from llama_index.bridge.pydantic import BaseModel -from llama_index.llms.types import ChatMessage +from llama_index.core.llms.types import ChatMessage MISSING_API_KEY_ERROR_MESSAGE = """No API key found for LLM. E.g. 
to use openai Please set the OPENAI_API_KEY environment variable or \ diff --git a/llama_index/llms/llama_api.py b/llama_index/llms/llama_api.py index 9f7e07e13..1364a4a79 100644 --- a/llama_index/llms/llama_api.py +++ b/llama_index/llms/llama_api.py @@ -3,14 +3,7 @@ from typing import Any, Callable, Dict, Optional, Sequence from llama_index.bridge.pydantic import Field, PrivateAttr from llama_index.callbacks import CallbackManager from llama_index.constants import DEFAULT_NUM_OUTPUTS -from llama_index.llms.base import llm_chat_callback, llm_completion_callback -from llama_index.llms.custom import CustomLLM -from llama_index.llms.generic_utils import chat_to_completion_decorator -from llama_index.llms.openai_utils import ( - from_openai_message_dict, - to_openai_message_dicts, -) -from llama_index.llms.types import ( +from llama_index.core.llms.types import ( ChatMessage, ChatResponse, ChatResponseGen, @@ -18,6 +11,13 @@ from llama_index.llms.types import ( CompletionResponseGen, LLMMetadata, ) +from llama_index.llms.base import llm_chat_callback, llm_completion_callback +from llama_index.llms.custom import CustomLLM +from llama_index.llms.generic_utils import chat_to_completion_decorator +from llama_index.llms.openai_utils import ( + from_openai_message_dict, + to_openai_message_dicts, +) from llama_index.types import BaseOutputParser, PydanticProgramMode diff --git a/llama_index/llms/llama_cpp.py b/llama_index/llms/llama_cpp.py index 7ab0bd1cd..124554c92 100644 --- a/llama_index/llms/llama_cpp.py +++ b/llama_index/llms/llama_cpp.py @@ -11,13 +11,7 @@ from llama_index.constants import ( DEFAULT_NUM_OUTPUTS, DEFAULT_TEMPERATURE, ) -from llama_index.llms.base import llm_chat_callback, llm_completion_callback -from llama_index.llms.custom import CustomLLM -from llama_index.llms.generic_utils import ( - completion_response_to_chat_response, - stream_completion_response_to_chat_response, -) -from llama_index.llms.types import ( +from llama_index.core.llms.types import ( ChatMessage, ChatResponse, ChatResponseGen, @@ -25,6 +19,12 @@ from llama_index.llms.types import ( CompletionResponseGen, LLMMetadata, ) +from llama_index.llms.base import llm_chat_callback, llm_completion_callback +from llama_index.llms.custom import CustomLLM +from llama_index.llms.generic_utils import ( + completion_response_to_chat_response, + stream_completion_response_to_chat_response, +) from llama_index.types import BaseOutputParser, PydanticProgramMode from llama_index.utils import get_cache_dir diff --git a/llama_index/llms/llama_utils.py b/llama_index/llms/llama_utils.py index 2ee0e950b..642bd5b7c 100644 --- a/llama_index/llms/llama_utils.py +++ b/llama_index/llms/llama_utils.py @@ -1,6 +1,6 @@ from typing import List, Optional, Sequence -from llama_index.llms.types import ChatMessage, MessageRole +from llama_index.core.llms.types import ChatMessage, MessageRole BOS, EOS = "<s>", "</s>" B_INST, E_INST = "[INST]", "[/INST]" diff --git a/llama_index/llms/llm.py b/llama_index/llms/llm.py index afe5442a5..850d3340c 100644 --- a/llama_index/llms/llm.py +++ b/llama_index/llms/llm.py @@ -3,11 +3,7 @@ from typing import Any, List, Optional, Protocol, Sequence, runtime_checkable from llama_index.bridge.pydantic import BaseModel, Field, validator from llama_index.callbacks import CBEventType, EventPayload -from llama_index.llms.base import BaseLLM -from llama_index.llms.generic_utils import ( - messages_to_prompt as generic_messages_to_prompt, -) -from llama_index.llms.types import ( +from llama_index.core.llms.types import ( 
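llama_utils keeps its Llama-2 prompt constants (BOS/EOS, the [INST] markers) and only swaps the ChatMessage import. A quick sketch of the helper in use, assuming the long-standing messages_to_prompt signature; the rendered markup in the comment is abbreviated:

```python
from llama_index.core.llms.types import ChatMessage, MessageRole
from llama_index.llms.llama_utils import messages_to_prompt

messages = [
    ChatMessage(role=MessageRole.SYSTEM, content="You are terse."),
    ChatMessage(role=MessageRole.USER, content="Define context window."),
]

# Wraps the conversation in Llama-2 chat markup, roughly:
# "<s> [INST] <<SYS>> You are terse. <</SYS>> Define context window. [/INST]"
prompt = messages_to_prompt(messages)
```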
ChatMessage, ChatResponseAsyncGen, ChatResponseGen, @@ -15,6 +11,10 @@ from llama_index.llms.types import ( CompletionResponseGen, MessageRole, ) +from llama_index.llms.base import BaseLLM +from llama_index.llms.generic_utils import ( + messages_to_prompt as generic_messages_to_prompt, +) from llama_index.prompts import BasePromptTemplate, PromptTemplate from llama_index.types import ( BaseOutputParser, diff --git a/llama_index/llms/localai.py b/llama_index/llms/localai.py index 4da2dc03b..15ca2a463 100644 --- a/llama_index/llms/localai.py +++ b/llama_index/llms/localai.py @@ -11,10 +11,10 @@ from typing import Any, Callable, Dict, Optional, Sequence from llama_index.bridge.pydantic import Field from llama_index.constants import DEFAULT_CONTEXT_WINDOW +from llama_index.core.llms.types import ChatMessage, LLMMetadata from llama_index.llms.openai import OpenAI from llama_index.llms.openai_like import OpenAILike from llama_index.llms.openai_utils import is_function_calling_model -from llama_index.llms.types import ChatMessage, LLMMetadata from llama_index.types import BaseOutputParser, PydanticProgramMode # Use these as kwargs for OpenAILike to connect to LocalAIs diff --git a/llama_index/llms/mistral.py b/llama_index/llms/mistral.py index e72f22fc4..e46152537 100644 --- a/llama_index/llms/mistral.py +++ b/llama_index/llms/mistral.py @@ -3,6 +3,19 @@ from typing import Any, Callable, Dict, Optional, Sequence from llama_index.bridge.pydantic import Field, PrivateAttr from llama_index.callbacks import CallbackManager from llama_index.constants import DEFAULT_TEMPERATURE + +# from mistralai.models.chat_completion import ChatMessage +from llama_index.core.llms.types import ( + ChatMessage, + ChatResponse, + ChatResponseAsyncGen, + ChatResponseGen, + CompletionResponse, + CompletionResponseAsyncGen, + CompletionResponseGen, + LLMMetadata, + MessageRole, +) from llama_index.llms.base import ( llm_chat_callback, llm_completion_callback, @@ -18,19 +31,6 @@ from llama_index.llms.llm import LLM from llama_index.llms.mistralai_utils import ( mistralai_modelname_to_contextsize, ) - -# from mistralai.models.chat_completion import ChatMessage -from llama_index.llms.types import ( - ChatMessage, - ChatResponse, - ChatResponseAsyncGen, - ChatResponseGen, - CompletionResponse, - CompletionResponseAsyncGen, - CompletionResponseGen, - LLMMetadata, - MessageRole, -) from llama_index.types import BaseOutputParser, PydanticProgramMode DEFAULT_MISTRALAI_MODEL = "mistral-tiny" diff --git a/llama_index/llms/mock.py b/llama_index/llms/mock.py index 9e3cf32e2..0cce089a8 100644 --- a/llama_index/llms/mock.py +++ b/llama_index/llms/mock.py @@ -1,14 +1,14 @@ from typing import Any, Callable, Optional, Sequence from llama_index.callbacks import CallbackManager -from llama_index.llms.base import llm_completion_callback -from llama_index.llms.custom import CustomLLM -from llama_index.llms.types import ( +from llama_index.core.llms.types import ( ChatMessage, CompletionResponse, CompletionResponseGen, LLMMetadata, ) +from llama_index.llms.base import llm_completion_callback +from llama_index.llms.custom import CustomLLM from llama_index.types import PydanticProgramMode diff --git a/llama_index/llms/monsterapi.py b/llama_index/llms/monsterapi.py index 0e21207cb..aaa1090e5 100644 --- a/llama_index/llms/monsterapi.py +++ b/llama_index/llms/monsterapi.py @@ -3,15 +3,15 @@ from typing import Any, Callable, Dict, Optional, Sequence from llama_index.bridge.pydantic import Field, PrivateAttr from llama_index.callbacks import 
CallbackManager from llama_index.constants import DEFAULT_CONTEXT_WINDOW, DEFAULT_NUM_OUTPUTS -from llama_index.llms.base import llm_chat_callback, llm_completion_callback -from llama_index.llms.custom import CustomLLM -from llama_index.llms.types import ( +from llama_index.core.llms.types import ( ChatMessage, ChatResponse, CompletionResponse, CompletionResponseGen, LLMMetadata, ) +from llama_index.llms.base import llm_chat_callback, llm_completion_callback +from llama_index.llms.custom import CustomLLM from llama_index.types import BaseOutputParser, PydanticProgramMode DEFAULT_MONSTER_TEMP = 0.75 diff --git a/llama_index/llms/ollama.py b/llama_index/llms/ollama.py index a2b3f9773..1801c49b4 100644 --- a/llama_index/llms/ollama.py +++ b/llama_index/llms/ollama.py @@ -6,9 +6,7 @@ from httpx import Timeout from llama_index.bridge.pydantic import Field from llama_index.constants import DEFAULT_CONTEXT_WINDOW, DEFAULT_NUM_OUTPUTS -from llama_index.llms.base import llm_chat_callback, llm_completion_callback -from llama_index.llms.custom import CustomLLM -from llama_index.llms.types import ( +from llama_index.core.llms.types import ( ChatMessage, ChatResponse, ChatResponseGen, @@ -17,6 +15,8 @@ from llama_index.llms.types import ( LLMMetadata, MessageRole, ) +from llama_index.llms.base import llm_chat_callback, llm_completion_callback +from llama_index.llms.custom import CustomLLM DEFAULT_REQUEST_TIMEOUT = 30.0 diff --git a/llama_index/llms/openai.py b/llama_index/llms/openai.py index 5a7ff6946..d30af239b 100644 --- a/llama_index/llms/openai.py +++ b/llama_index/llms/openai.py @@ -26,6 +26,17 @@ from llama_index.callbacks import CallbackManager from llama_index.constants import ( DEFAULT_TEMPERATURE, ) +from llama_index.core.llms.types import ( + ChatMessage, + ChatResponse, + ChatResponseAsyncGen, + ChatResponseGen, + CompletionResponse, + CompletionResponseAsyncGen, + CompletionResponseGen, + LLMMetadata, + MessageRole, +) from llama_index.llms.base import ( llm_chat_callback, llm_completion_callback, @@ -49,17 +60,6 @@ from llama_index.llms.openai_utils import ( resolve_openai_credentials, to_openai_message_dicts, ) -from llama_index.llms.types import ( - ChatMessage, - ChatResponse, - ChatResponseAsyncGen, - ChatResponseGen, - CompletionResponse, - CompletionResponseAsyncGen, - CompletionResponseGen, - LLMMetadata, - MessageRole, -) from llama_index.types import BaseOutputParser, PydanticProgramMode DEFAULT_OPENAI_MODEL = "gpt-3.5-turbo" diff --git a/llama_index/llms/openai_like.py b/llama_index/llms/openai_like.py index ced6bda32..09ef4241e 100644 --- a/llama_index/llms/openai_like.py +++ b/llama_index/llms/openai_like.py @@ -2,8 +2,8 @@ from typing import Optional, Union from llama_index.bridge.pydantic import Field from llama_index.constants import DEFAULT_CONTEXT_WINDOW +from llama_index.core.llms.types import LLMMetadata from llama_index.llms.openai import OpenAI, Tokenizer -from llama_index.llms.types import LLMMetadata class OpenAILike(OpenAI): diff --git a/llama_index/llms/openai_utils.py b/llama_index/llms/openai_utils.py index 099e048cb..830425c6a 100644 --- a/llama_index/llms/openai_utils.py +++ b/llama_index/llms/openai_utils.py @@ -20,8 +20,8 @@ from tenacity import ( from tenacity.stop import stop_base from llama_index.bridge.pydantic import BaseModel +from llama_index.core.llms.types import ChatMessage from llama_index.llms.generic_utils import get_from_param_or_env -from llama_index.llms.types import ChatMessage DEFAULT_OPENAI_API_TYPE = "open_ai" DEFAULT_OPENAI_API_BASE = 
"https://api.openai.com/v1" diff --git a/llama_index/llms/openllm.py b/llama_index/llms/openllm.py index fde68121b..dea6d54d0 100644 --- a/llama_index/llms/openllm.py +++ b/llama_index/llms/openllm.py @@ -13,6 +13,16 @@ from typing import ( from llama_index.bridge.pydantic import Field, PrivateAttr from llama_index.callbacks import CallbackManager +from llama_index.core.llms.types import ( + ChatMessage, + ChatResponse, + ChatResponseAsyncGen, + ChatResponseGen, + CompletionResponse, + CompletionResponseAsyncGen, + CompletionResponseGen, + LLMMetadata, +) from llama_index.llms.base import ( llm_chat_callback, llm_completion_callback, @@ -24,16 +34,6 @@ from llama_index.llms.generic_utils import ( messages_to_prompt as generic_messages_to_prompt, ) from llama_index.llms.llm import LLM -from llama_index.llms.types import ( - ChatMessage, - ChatResponse, - ChatResponseAsyncGen, - ChatResponseGen, - CompletionResponse, - CompletionResponseAsyncGen, - CompletionResponseGen, - LLMMetadata, -) from llama_index.types import PydanticProgramMode logger = logging.getLogger(__name__) diff --git a/llama_index/llms/openrouter.py b/llama_index/llms/openrouter.py index b8ff7024a..77ac299a1 100644 --- a/llama_index/llms/openrouter.py +++ b/llama_index/llms/openrouter.py @@ -6,9 +6,9 @@ from llama_index.constants import ( DEFAULT_NUM_OUTPUTS, DEFAULT_TEMPERATURE, ) +from llama_index.core.llms.types import LLMMetadata from llama_index.llms.generic_utils import get_from_param_or_env from llama_index.llms.openai_like import OpenAILike -from llama_index.llms.types import LLMMetadata DEFAULT_API_BASE = "https://openrouter.ai/api/v1" DEFAULT_MODEL = "gryphe/mythomax-l2-13b" diff --git a/llama_index/llms/palm.py b/llama_index/llms/palm.py index 30b49b6e3..1e0200001 100644 --- a/llama_index/llms/palm.py +++ b/llama_index/llms/palm.py @@ -5,14 +5,14 @@ from typing import Any, Callable, Optional, Sequence from llama_index.bridge.pydantic import Field, PrivateAttr from llama_index.callbacks import CallbackManager from llama_index.constants import DEFAULT_NUM_OUTPUTS -from llama_index.llms.base import llm_completion_callback -from llama_index.llms.custom import CustomLLM -from llama_index.llms.types import ( +from llama_index.core.llms.types import ( ChatMessage, CompletionResponse, CompletionResponseGen, LLMMetadata, ) +from llama_index.llms.base import llm_completion_callback +from llama_index.llms.custom import CustomLLM from llama_index.types import BaseOutputParser, PydanticProgramMode DEFAULT_PALM_MODEL = "models/text-bison-001" diff --git a/llama_index/llms/perplexity.py b/llama_index/llms/perplexity.py index dd36e6bb2..005e010ba 100644 --- a/llama_index/llms/perplexity.py +++ b/llama_index/llms/perplexity.py @@ -6,9 +6,7 @@ import requests from llama_index.bridge.pydantic import Field from llama_index.callbacks import CallbackManager -from llama_index.llms.base import llm_chat_callback, llm_completion_callback -from llama_index.llms.llm import LLM -from llama_index.llms.types import ( +from llama_index.core.llms.types import ( ChatMessage, ChatResponse, ChatResponseAsyncGen, @@ -18,6 +16,8 @@ from llama_index.llms.types import ( CompletionResponseGen, LLMMetadata, ) +from llama_index.llms.base import llm_chat_callback, llm_completion_callback +from llama_index.llms.llm import LLM from llama_index.types import BaseOutputParser, PydanticProgramMode diff --git a/llama_index/llms/portkey.py b/llama_index/llms/portkey.py index 48c92ca63..1c1f1ba23 100644 --- a/llama_index/llms/portkey.py +++ 
b/llama_index/llms/portkey.py @@ -4,6 +4,14 @@ Portkey integration with Llama_index for enhanced monitoring. from typing import TYPE_CHECKING, Any, Callable, List, Optional, Sequence, Union, cast from llama_index.bridge.pydantic import Field, PrivateAttr +from llama_index.core.llms.types import ( + ChatMessage, + ChatResponse, + ChatResponseGen, + CompletionResponse, + CompletionResponseGen, + LLMMetadata, +) from llama_index.llms.base import llm_chat_callback, llm_completion_callback from llama_index.llms.custom import CustomLLM from llama_index.llms.generic_utils import ( @@ -18,14 +26,6 @@ from llama_index.llms.portkey_utils import ( get_llm, is_chat_model, ) -from llama_index.llms.types import ( - ChatMessage, - ChatResponse, - ChatResponseGen, - CompletionResponse, - CompletionResponseGen, - LLMMetadata, -) from llama_index.types import BaseOutputParser, PydanticProgramMode if TYPE_CHECKING: diff --git a/llama_index/llms/portkey_utils.py b/llama_index/llms/portkey_utils.py index e23e6b5ee..e2da09c10 100644 --- a/llama_index/llms/portkey_utils.py +++ b/llama_index/llms/portkey_utils.py @@ -6,6 +6,7 @@ the functionality and usability of the Portkey class """ from typing import TYPE_CHECKING, List +from llama_index.core.llms.types import LLMMetadata from llama_index.llms.anthropic import Anthropic from llama_index.llms.anthropic_utils import CLAUDE_MODELS from llama_index.llms.openai import OpenAI @@ -16,7 +17,6 @@ from llama_index.llms.openai_utils import ( GPT4_MODELS, TURBO_MODELS, ) -from llama_index.llms.types import LLMMetadata if TYPE_CHECKING: from portkey import ( diff --git a/llama_index/llms/predibase.py b/llama_index/llms/predibase.py index 38b86216f..cca2997a3 100644 --- a/llama_index/llms/predibase.py +++ b/llama_index/llms/predibase.py @@ -8,14 +8,14 @@ from llama_index.constants import ( DEFAULT_NUM_OUTPUTS, DEFAULT_TEMPERATURE, ) -from llama_index.llms.base import llm_completion_callback -from llama_index.llms.custom import CustomLLM -from llama_index.llms.types import ( +from llama_index.core.llms.types import ( ChatMessage, CompletionResponse, CompletionResponseGen, LLMMetadata, ) +from llama_index.llms.base import llm_completion_callback +from llama_index.llms.custom import CustomLLM from llama_index.types import BaseOutputParser, PydanticProgramMode diff --git a/llama_index/llms/replicate.py b/llama_index/llms/replicate.py index 16c8adae6..bfbd95eab 100644 --- a/llama_index/llms/replicate.py +++ b/llama_index/llms/replicate.py @@ -2,13 +2,7 @@ from typing import Any, Dict, Sequence from llama_index.bridge.pydantic import Field from llama_index.constants import DEFAULT_CONTEXT_WINDOW, DEFAULT_NUM_OUTPUTS -from llama_index.llms.base import llm_chat_callback, llm_completion_callback -from llama_index.llms.custom import CustomLLM -from llama_index.llms.generic_utils import ( - completion_response_to_chat_response, - stream_completion_response_to_chat_response, -) -from llama_index.llms.types import ( +from llama_index.core.llms.types import ( ChatMessage, ChatResponse, ChatResponseGen, @@ -16,6 +10,12 @@ from llama_index.llms.types import ( CompletionResponseGen, LLMMetadata, ) +from llama_index.llms.base import llm_chat_callback, llm_completion_callback +from llama_index.llms.custom import CustomLLM +from llama_index.llms.generic_utils import ( + completion_response_to_chat_response, + stream_completion_response_to_chat_response, +) DEFAULT_REPLICATE_TEMP = 0.75 diff --git a/llama_index/llms/rungpt.py b/llama_index/llms/rungpt.py index 835163532..e0296ac1f 100644 --- 
a/llama_index/llms/rungpt.py +++ b/llama_index/llms/rungpt.py @@ -4,9 +4,7 @@ from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Union from llama_index.bridge.pydantic import Field from llama_index.callbacks import CallbackManager from llama_index.constants import DEFAULT_CONTEXT_WINDOW, DEFAULT_NUM_OUTPUTS -from llama_index.llms.base import llm_chat_callback, llm_completion_callback -from llama_index.llms.llm import LLM -from llama_index.llms.types import ( +from llama_index.core.llms.types import ( ChatMessage, ChatResponse, ChatResponseAsyncGen, @@ -17,6 +15,8 @@ from llama_index.llms.types import ( LLMMetadata, MessageRole, ) +from llama_index.llms.base import llm_chat_callback, llm_completion_callback +from llama_index.llms.llm import LLM from llama_index.types import BaseOutputParser, PydanticProgramMode DEFAULT_RUNGPT_MODEL = "rungpt" diff --git a/llama_index/llms/types.py b/llama_index/llms/types.py index 9db785861..ebc949983 100644 --- a/llama_index/llms/types.py +++ b/llama_index/llms/types.py @@ -1,110 +1,29 @@ -from enum import Enum -from typing import Any, AsyncGenerator, Generator, Optional - -from llama_index.bridge.pydantic import BaseModel, Field -from llama_index.constants import DEFAULT_CONTEXT_WINDOW, DEFAULT_NUM_OUTPUTS - - -class MessageRole(str, Enum): - """Message role.""" - - SYSTEM = "system" - USER = "user" - ASSISTANT = "assistant" - FUNCTION = "function" - TOOL = "tool" - - -# ===== Generic Model Input - Chat ===== -class ChatMessage(BaseModel): - """Chat message.""" - - role: MessageRole = MessageRole.USER - content: Optional[Any] = "" - additional_kwargs: dict = Field(default_factory=dict) - - def __str__(self) -> str: - return f"{self.role.value}: {self.content}" - - -# ===== Generic Model Output - Chat ===== -class ChatResponse(BaseModel): - """Chat response.""" - - message: ChatMessage - raw: Optional[dict] = None - delta: Optional[str] = None - additional_kwargs: dict = Field(default_factory=dict) - - def __str__(self) -> str: - return str(self.message) - - -ChatResponseGen = Generator[ChatResponse, None, None] -ChatResponseAsyncGen = AsyncGenerator[ChatResponse, None] - - -# ===== Generic Model Output - Completion ===== -class CompletionResponse(BaseModel): - """ - Completion response. - - Fields: - text: Text content of the response if not streaming, or if streaming, - the current extent of streamed text. - additional_kwargs: Additional information on the response(i.e. token - counts, function calling information). - raw: Optional raw JSON that was parsed to populate text, if relevant. - delta: New text that just streamed in (only relevant when streaming). - """ - - text: str - additional_kwargs: dict = Field(default_factory=dict) - raw: Optional[dict] = None - delta: Optional[str] = None - - def __str__(self) -> str: - return self.text - - -CompletionResponseGen = Generator[CompletionResponse, None, None] -CompletionResponseAsyncGen = AsyncGenerator[CompletionResponse, None] - - -class LLMMetadata(BaseModel): - context_window: int = Field( - default=DEFAULT_CONTEXT_WINDOW, - description=( - "Total number of tokens the model can be input and output for one response." - ), - ) - num_output: int = Field( - default=DEFAULT_NUM_OUTPUTS, - description="Number of tokens the model can output when generating a response.", - ) - is_chat_model: bool = Field( - default=False, - description=( - "Set True if the model exposes a chat interface (i.e. 
can be passed a" - " sequence of messages, rather than text), like OpenAI's" - " /v1/chat/completions endpoint." - ), - ) - is_function_calling_model: bool = Field( - default=False, - # SEE: https://openai.com/blog/function-calling-and-other-api-updates - description=( - "Set True if the model supports function calling messages, similar to" - " OpenAI's function calling API. For example, converting 'Email Anya to" - " see if she wants to get coffee next Friday' to a function call like" - " `send_email(to: string, body: string)`." - ), - ) - model_name: str = Field( - default="unknown", - description=( - "The model's name used for logging, testing, and sanity checking. For some" - " models this can be automatically discerned. For other models, like" - " locally loaded models, this must be manually specified." - ), - ) +"""LLM Types. + +Maintain this file for backwards compat. + +""" + +from llama_index.core.llms.types import ( + ChatMessage, + ChatResponse, + ChatResponseAsyncGen, + ChatResponseGen, + CompletionResponse, + CompletionResponseAsyncGen, + CompletionResponseGen, + LLMMetadata, + MessageRole, +) + +__all__ = [ + "ChatMessage", + "ChatResponse", + "ChatResponseAsyncGen", + "ChatResponseGen", + "CompletionResponse", + "CompletionResponseAsyncGen", + "CompletionResponseGen", + "LLMMetadata", + "MessageRole", +] diff --git a/llama_index/llms/vertex.py b/llama_index/llms/vertex.py index f878381bb..9abae84e7 100644 --- a/llama_index/llms/vertex.py +++ b/llama_index/llms/vertex.py @@ -2,12 +2,7 @@ from typing import Any, Callable, Dict, Optional, Sequence from llama_index.bridge.pydantic import Field, PrivateAttr from llama_index.callbacks import CallbackManager -from llama_index.llms.base import ( - llm_chat_callback, - llm_completion_callback, -) -from llama_index.llms.llm import LLM -from llama_index.llms.types import ( +from llama_index.core.llms.types import ( ChatMessage, ChatResponse, ChatResponseAsyncGen, @@ -18,6 +13,11 @@ from llama_index.llms.types import ( LLMMetadata, MessageRole, ) +from llama_index.llms.base import ( + llm_chat_callback, + llm_completion_callback, +) +from llama_index.llms.llm import LLM from llama_index.llms.vertex_gemini_utils import is_gemini_model from llama_index.llms.vertex_utils import ( CHAT_MODELS, diff --git a/llama_index/llms/vertex_utils.py b/llama_index/llms/vertex_utils.py index e25d1a7cd..0ec14f1c3 100644 --- a/llama_index/llms/vertex_utils.py +++ b/llama_index/llms/vertex_utils.py @@ -12,7 +12,7 @@ from tenacity import ( wait_exponential, ) -from llama_index.llms.types import ChatMessage, MessageRole +from llama_index.core.llms.types import ChatMessage, MessageRole CHAT_MODELS = ["chat-bison", "chat-bison-32k", "chat-bison@001"] TEXT_MODELS = ["text-bison", "text-bison-32k", "text-bison@001"] diff --git a/llama_index/llms/vllm.py b/llama_index/llms/vllm.py index e52870e93..25b38c197 100644 --- a/llama_index/llms/vllm.py +++ b/llama_index/llms/vllm.py @@ -3,16 +3,7 @@ from typing import Any, Callable, Dict, List, Optional, Sequence from llama_index.bridge.pydantic import Field, PrivateAttr from llama_index.callbacks import CallbackManager -from llama_index.llms.base import llm_chat_callback, llm_completion_callback -from llama_index.llms.generic_utils import ( - completion_response_to_chat_response, - stream_completion_response_to_chat_response, -) -from llama_index.llms.generic_utils import ( - messages_to_prompt as generic_messages_to_prompt, -) -from llama_index.llms.llm import LLM -from llama_index.llms.types import ( +from 
llama_index.core.llms.types import ( ChatMessage, ChatResponse, ChatResponseAsyncGen, @@ -22,6 +13,15 @@ from llama_index.llms.types import ( CompletionResponseGen, LLMMetadata, ) +from llama_index.llms.base import llm_chat_callback, llm_completion_callback +from llama_index.llms.generic_utils import ( + completion_response_to_chat_response, + stream_completion_response_to_chat_response, +) +from llama_index.llms.generic_utils import ( + messages_to_prompt as generic_messages_to_prompt, +) +from llama_index.llms.llm import LLM from llama_index.llms.vllm_utils import get_response, post_http_request from llama_index.types import BaseOutputParser, PydanticProgramMode diff --git a/llama_index/llms/watsonx.py b/llama_index/llms/watsonx.py index 15c69392c..765cf0f5a 100644 --- a/llama_index/llms/watsonx.py +++ b/llama_index/llms/watsonx.py @@ -2,13 +2,7 @@ from typing import Any, Callable, Dict, Optional, Sequence from llama_index.bridge.pydantic import Field, PrivateAttr from llama_index.callbacks import CallbackManager -from llama_index.llms.base import llm_chat_callback, llm_completion_callback -from llama_index.llms.generic_utils import ( - completion_to_chat_decorator, - stream_completion_to_chat_decorator, -) -from llama_index.llms.llm import LLM -from llama_index.llms.types import ( +from llama_index.core.llms.types import ( ChatMessage, ChatResponse, ChatResponseAsyncGen, @@ -18,6 +12,12 @@ from llama_index.llms.types import ( CompletionResponseGen, LLMMetadata, ) +from llama_index.llms.base import llm_chat_callback, llm_completion_callback +from llama_index.llms.generic_utils import ( + completion_to_chat_decorator, + stream_completion_to_chat_decorator, +) +from llama_index.llms.llm import LLM from llama_index.llms.watsonx_utils import ( WATSONX_MODELS, get_from_param_or_env_without_error, diff --git a/llama_index/llms/xinference.py b/llama_index/llms/xinference.py index 62c02e90f..f4b970bcf 100644 --- a/llama_index/llms/xinference.py +++ b/llama_index/llms/xinference.py @@ -3,12 +3,7 @@ from typing import Any, Callable, Dict, Optional, Sequence, Tuple from llama_index.bridge.pydantic import Field, PrivateAttr from llama_index.callbacks import CallbackManager -from llama_index.llms.base import ( - llm_chat_callback, - llm_completion_callback, -) -from llama_index.llms.custom import CustomLLM -from llama_index.llms.types import ( +from llama_index.core.llms.types import ( ChatMessage, ChatResponse, ChatResponseGen, @@ -17,6 +12,11 @@ from llama_index.llms.types import ( LLMMetadata, MessageRole, ) +from llama_index.llms.base import ( + llm_chat_callback, + llm_completion_callback, +) +from llama_index.llms.custom import CustomLLM from llama_index.llms.xinference_utils import ( xinference_message_to_history, xinference_modelname_to_contextsize, diff --git a/llama_index/llms/xinference_utils.py b/llama_index/llms/xinference_utils.py index bc1be0515..224df573f 100644 --- a/llama_index/llms/xinference_utils.py +++ b/llama_index/llms/xinference_utils.py @@ -2,7 +2,7 @@ from typing import Optional from typing_extensions import NotRequired, TypedDict -from llama_index.llms.types import ChatMessage +from llama_index.core.llms.types import ChatMessage XINFERENCE_MODEL_SIZES = { "baichuan": 2048, diff --git a/llama_index/memory/chat_memory_buffer.py b/llama_index/memory/chat_memory_buffer.py index a8fcb64cd..5aa96189c 100644 --- a/llama_index/memory/chat_memory_buffer.py +++ b/llama_index/memory/chat_memory_buffer.py @@ -2,6 +2,7 @@ import json from typing import Any, Callable, Dict, List, 
Optional from llama_index.bridge.pydantic import Field, root_validator +from llama_index.core.llms.types import ChatMessage, MessageRole from llama_index.llms.llm import LLM from llama_index.llms.types import ChatMessage, MessageRole from llama_index.memory.types import DEFAULT_CHAT_STORE_KEY, BaseMemory diff --git a/llama_index/memory/types.py b/llama_index/memory/types.py index 42ea49cb6..a84a18858 100644 --- a/llama_index/memory/types.py +++ b/llama_index/memory/types.py @@ -1,8 +1,8 @@ from abc import abstractmethod from typing import Any, List, Optional +from llama_index.core.llms.types import ChatMessage from llama_index.llms.llm import LLM -from llama_index.llms.types import ChatMessage from llama_index.schema import BaseComponent DEFAULT_CHAT_STORE_KEY = "chat_history" diff --git a/llama_index/multi_modal_llms/base.py b/llama_index/multi_modal_llms/base.py index fd6f16878..068ae2e07 100644 --- a/llama_index/multi_modal_llms/base.py +++ b/llama_index/multi_modal_llms/base.py @@ -7,7 +7,7 @@ from llama_index.constants import ( DEFAULT_NUM_INPUT_FILES, DEFAULT_NUM_OUTPUTS, ) -from llama_index.llms.types import ( +from llama_index.core.llms.types import ( ChatMessage, ChatResponse, ChatResponseAsyncGen, diff --git a/llama_index/multi_modal_llms/gemini.py b/llama_index/multi_modal_llms/gemini.py index 13935cad2..aa6920a3b 100644 --- a/llama_index/multi_modal_llms/gemini.py +++ b/llama_index/multi_modal_llms/gemini.py @@ -6,13 +6,7 @@ from typing import Any, Dict, Optional, Sequence from llama_index.bridge.pydantic import Field, PrivateAttr from llama_index.callbacks import CallbackManager from llama_index.constants import DEFAULT_NUM_OUTPUTS, DEFAULT_TEMPERATURE -from llama_index.llms.gemini_utils import ( - ROLES_FROM_GEMINI, - chat_from_gemini_response, - chat_message_to_gemini, - completion_from_gemini_response, -) -from llama_index.llms.types import ( +from llama_index.core.llms.types import ( ChatMessage, ChatResponse, ChatResponseAsyncGen, @@ -21,6 +15,12 @@ from llama_index.llms.types import ( CompletionResponseAsyncGen, CompletionResponseGen, ) +from llama_index.llms.gemini_utils import ( + ROLES_FROM_GEMINI, + chat_from_gemini_response, + chat_message_to_gemini, + completion_from_gemini_response, +) from llama_index.multi_modal_llms import ( MultiModalLLM, MultiModalLLMMetadata, diff --git a/llama_index/multi_modal_llms/openai.py b/llama_index/multi_modal_llms/openai.py index e92c8a625..2d37fc119 100644 --- a/llama_index/multi_modal_llms/openai.py +++ b/llama_index/multi_modal_llms/openai.py @@ -16,15 +16,7 @@ from llama_index.constants import ( DEFAULT_NUM_OUTPUTS, DEFAULT_TEMPERATURE, ) -from llama_index.llms.generic_utils import ( - messages_to_prompt as generic_messages_to_prompt, -) -from llama_index.llms.openai_utils import ( - from_openai_message, - resolve_openai_credentials, - to_openai_message_dicts, -) -from llama_index.llms.types import ( +from llama_index.core.llms.types import ( ChatMessage, ChatResponse, ChatResponseAsyncGen, @@ -34,6 +26,14 @@ from llama_index.llms.types import ( CompletionResponseGen, MessageRole, ) +from llama_index.llms.generic_utils import ( + messages_to_prompt as generic_messages_to_prompt, +) +from llama_index.llms.openai_utils import ( + from_openai_message, + resolve_openai_credentials, + to_openai_message_dicts, +) from llama_index.multi_modal_llms import ( MultiModalLLM, MultiModalLLMMetadata, diff --git a/llama_index/multi_modal_llms/replicate_multi_modal.py b/llama_index/multi_modal_llms/replicate_multi_modal.py index 
b0ae63ca6..3cf2a3316 100644 --- a/llama_index/multi_modal_llms/replicate_multi_modal.py +++ b/llama_index/multi_modal_llms/replicate_multi_modal.py @@ -4,10 +4,7 @@ from typing import Any, Callable, Dict, Optional, Sequence from llama_index.bridge.pydantic import Field, PrivateAttr from llama_index.callbacks import CallbackManager from llama_index.constants import DEFAULT_CONTEXT_WINDOW, DEFAULT_NUM_OUTPUTS -from llama_index.llms.generic_utils import ( - messages_to_prompt as generic_messages_to_prompt, -) -from llama_index.llms.types import ( +from llama_index.core.llms.types import ( ChatMessage, ChatResponse, ChatResponseAsyncGen, @@ -16,6 +13,9 @@ from llama_index.llms.types import ( CompletionResponseAsyncGen, CompletionResponseGen, ) +from llama_index.llms.generic_utils import ( + messages_to_prompt as generic_messages_to_prompt, +) from llama_index.multi_modal_llms import ( MultiModalLLM, MultiModalLLMMetadata, diff --git a/llama_index/node_parser/relational/base_element.py b/llama_index/node_parser/relational/base_element.py index 85ac34c67..9758f8eda 100644 --- a/llama_index/node_parser/relational/base_element.py +++ b/llama_index/node_parser/relational/base_element.py @@ -6,10 +6,10 @@ from tqdm import tqdm from llama_index.bridge.pydantic import BaseModel, Field, ValidationError from llama_index.callbacks.base import CallbackManager +from llama_index.core.response.schema import PydanticResponse from llama_index.llms.llm import LLM from llama_index.llms.openai import OpenAI from llama_index.node_parser.interface import NodeParser -from llama_index.response.schema import PydanticResponse from llama_index.schema import BaseNode, Document, IndexNode, TextNode from llama_index.utils import get_tqdm_iterable diff --git a/llama_index/objects/base.py b/llama_index/objects/base.py index 210859b76..e6aeb7678 100644 --- a/llama_index/objects/base.py +++ b/llama_index/objects/base.py @@ -4,7 +4,7 @@ import pickle import warnings from typing import Any, Generic, List, Optional, Sequence, Type, TypeVar -from llama_index.core import BaseRetriever +from llama_index.core.base_retriever import BaseRetriever from llama_index.indices.base import BaseIndex from llama_index.indices.vector_store.base import VectorStoreIndex from llama_index.objects.base_node_mapping import ( diff --git a/llama_index/prompts/__init__.py b/llama_index/prompts/__init__.py index 40955a806..9f9ec2b3b 100644 --- a/llama_index/prompts/__init__.py +++ b/llama_index/prompts/__init__.py @@ -1,6 +1,6 @@ """Prompt class.""" -from llama_index.llms.types import ChatMessage, MessageRole +from llama_index.core.llms.types import ChatMessage, MessageRole from llama_index.prompts.base import ( BasePromptTemplate, ChatPromptTemplate, diff --git a/llama_index/prompts/base.py b/llama_index/prompts/base.py index 18f0532b2..4b3a717f5 100644 --- a/llama_index/prompts/base.py +++ b/llama_index/prompts/base.py @@ -13,6 +13,7 @@ if TYPE_CHECKING: ConditionalPromptSelector as LangchainSelector, ) from llama_index.bridge.pydantic import BaseModel +from llama_index.core.llms.types import ChatMessage from llama_index.llms.base import BaseLLM from llama_index.llms.generic_utils import ( messages_to_prompt as default_messages_to_prompt, @@ -20,7 +21,6 @@ from llama_index.llms.generic_utils import ( from llama_index.llms.generic_utils import ( prompt_to_messages, ) -from llama_index.llms.types import ChatMessage from llama_index.prompts.prompt_type import PromptType from llama_index.prompts.utils import get_template_vars from llama_index.types import 
BaseOutputParser diff --git a/llama_index/prompts/chat_prompts.py b/llama_index/prompts/chat_prompts.py index 3fb855103..f83ac5584 100644 --- a/llama_index/prompts/chat_prompts.py +++ b/llama_index/prompts/chat_prompts.py @@ -1,6 +1,6 @@ """Prompts for ChatGPT.""" -from llama_index.llms.types import ChatMessage, MessageRole +from llama_index.core.llms.types import ChatMessage, MessageRole from llama_index.prompts.base import ChatPromptTemplate # text qa prompt diff --git a/llama_index/query_engine/__init__.py b/llama_index/query_engine/__init__.py index 8aa6632eb..f3bdfd462 100644 --- a/llama_index/query_engine/__init__.py +++ b/llama_index/query_engine/__init__.py @@ -1,4 +1,4 @@ -from llama_index.core import BaseQueryEngine +from llama_index.core.base_query_engine import BaseQueryEngine # SQL from llama_index.indices.struct_store.sql_query import ( diff --git a/llama_index/query_engine/citation_query_engine.py b/llama_index/query_engine/citation_query_engine.py index 2268a866e..6c4aa6af9 100644 --- a/llama_index/query_engine/citation_query_engine.py +++ b/llama_index/query_engine/citation_query_engine.py @@ -2,14 +2,15 @@ from typing import Any, List, Optional, Sequence from llama_index.callbacks.base import CallbackManager from llama_index.callbacks.schema import CBEventType, EventPayload -from llama_index.core import BaseQueryEngine, BaseRetriever +from llama_index.core.base_query_engine import BaseQueryEngine +from llama_index.core.base_retriever import BaseRetriever +from llama_index.core.response.schema import RESPONSE_TYPE from llama_index.indices.base import BaseGPTIndex from llama_index.node_parser import SentenceSplitter, TextSplitter from llama_index.postprocessor.types import BaseNodePostprocessor from llama_index.prompts import PromptTemplate from llama_index.prompts.base import BasePromptTemplate from llama_index.prompts.mixin import PromptMixinType -from llama_index.response.schema import RESPONSE_TYPE from llama_index.response_synthesizers import ( BaseSynthesizer, ResponseMode, diff --git a/llama_index/query_engine/cogniswitch_query_engine.py b/llama_index/query_engine/cogniswitch_query_engine.py index 072c0512f..c6886f275 100644 --- a/llama_index/query_engine/cogniswitch_query_engine.py +++ b/llama_index/query_engine/cogniswitch_query_engine.py @@ -2,8 +2,8 @@ from typing import Any, Dict import requests -from llama_index.core import BaseQueryEngine -from llama_index.response.schema import Response +from llama_index.core.base_query_engine import BaseQueryEngine +from llama_index.core.response.schema import Response from llama_index.schema import QueryBundle diff --git a/llama_index/query_engine/custom.py b/llama_index/query_engine/custom.py index bbee3e9b7..7b534edbb 100644 --- a/llama_index/query_engine/custom.py +++ b/llama_index/query_engine/custom.py @@ -5,9 +5,9 @@ from typing import Union from llama_index.bridge.pydantic import BaseModel, Field from llama_index.callbacks.base import CallbackManager -from llama_index.core import BaseQueryEngine +from llama_index.core.base_query_engine import BaseQueryEngine +from llama_index.core.response.schema import RESPONSE_TYPE, Response from llama_index.prompts.mixin import PromptMixinType -from llama_index.response.schema import RESPONSE_TYPE, Response from llama_index.schema import QueryBundle, QueryType STR_OR_RESPONSE_TYPE = Union[RESPONSE_TYPE, str] diff --git a/llama_index/query_engine/flare/base.py b/llama_index/query_engine/flare/base.py index c83473f8e..89c19e820 100644 --- a/llama_index/query_engine/flare/base.py +++ 
b/llama_index/query_engine/flare/base.py @@ -7,7 +7,8 @@ Active Retrieval Augmented Generation. from typing import Any, Dict, Optional from llama_index.callbacks.base import CallbackManager -from llama_index.core import BaseQueryEngine +from llama_index.core.base_query_engine import BaseQueryEngine +from llama_index.core.response.schema import RESPONSE_TYPE, Response from llama_index.prompts.base import BasePromptTemplate, PromptTemplate from llama_index.prompts.mixin import PromptDictType, PromptMixinType from llama_index.query_engine.flare.answer_inserter import ( @@ -18,7 +19,6 @@ from llama_index.query_engine.flare.output_parser import ( IsDoneOutputParser, QueryTaskOutputParser, ) -from llama_index.response.schema import RESPONSE_TYPE, Response from llama_index.schema import QueryBundle from llama_index.service_context import ServiceContext from llama_index.utils import print_text diff --git a/llama_index/query_engine/graph_query_engine.py b/llama_index/query_engine/graph_query_engine.py index 98b594724..b97ea0add 100644 --- a/llama_index/query_engine/graph_query_engine.py +++ b/llama_index/query_engine/graph_query_engine.py @@ -1,9 +1,9 @@ from typing import Any, Dict, List, Optional, Tuple from llama_index.callbacks.schema import CBEventType, EventPayload -from llama_index.core import BaseQueryEngine +from llama_index.core.base_query_engine import BaseQueryEngine +from llama_index.core.response.schema import RESPONSE_TYPE from llama_index.indices.composability.graph import ComposableGraph -from llama_index.response.schema import RESPONSE_TYPE from llama_index.schema import IndexNode, NodeWithScore, QueryBundle, TextNode diff --git a/llama_index/query_engine/knowledge_graph_query_engine.py b/llama_index/query_engine/knowledge_graph_query_engine.py index 0c156d973..cf0d002af 100644 --- a/llama_index/query_engine/knowledge_graph_query_engine.py +++ b/llama_index/query_engine/knowledge_graph_query_engine.py @@ -4,14 +4,14 @@ import logging from typing import Any, Dict, List, Optional, Sequence from llama_index.callbacks.schema import CBEventType, EventPayload -from llama_index.core import BaseQueryEngine +from llama_index.core.base_query_engine import BaseQueryEngine +from llama_index.core.response.schema import RESPONSE_TYPE from llama_index.graph_stores.registry import ( GRAPH_STORE_CLASS_TO_GRAPH_STORE_TYPE, GraphStoreType, ) from llama_index.prompts.base import BasePromptTemplate, PromptTemplate, PromptType from llama_index.prompts.mixin import PromptDictType, PromptMixinType -from llama_index.response.schema import RESPONSE_TYPE from llama_index.response_synthesizers import BaseSynthesizer, get_response_synthesizer from llama_index.schema import NodeWithScore, QueryBundle, TextNode from llama_index.service_context import ServiceContext diff --git a/llama_index/query_engine/multi_modal.py b/llama_index/query_engine/multi_modal.py index 886b8acb1..a207f3e19 100644 --- a/llama_index/query_engine/multi_modal.py +++ b/llama_index/query_engine/multi_modal.py @@ -2,6 +2,7 @@ from typing import Any, Dict, List, Optional, Sequence, Tuple from llama_index.callbacks.base import CallbackManager from llama_index.callbacks.schema import CBEventType, EventPayload +from llama_index.core.response.schema import RESPONSE_TYPE, Response from llama_index.indices.multi_modal import MultiModalVectorIndexRetriever from llama_index.indices.query.base import BaseQueryEngine from llama_index.indices.query.schema import QueryBundle, QueryType @@ -11,7 +12,6 @@ from llama_index.postprocessor.types import 
BaseNodePostprocessor from llama_index.prompts import BasePromptTemplate from llama_index.prompts.default_prompts import DEFAULT_TEXT_QA_PROMPT from llama_index.prompts.mixin import PromptMixinType -from llama_index.response.schema import RESPONSE_TYPE, Response from llama_index.schema import ImageNode, NodeWithScore diff --git a/llama_index/query_engine/multistep_query_engine.py b/llama_index/query_engine/multistep_query_engine.py index fc875ce65..68e42820a 100644 --- a/llama_index/query_engine/multistep_query_engine.py +++ b/llama_index/query_engine/multistep_query_engine.py @@ -1,10 +1,10 @@ from typing import Any, Callable, Dict, List, Optional, Tuple, cast from llama_index.callbacks.schema import CBEventType, EventPayload -from llama_index.core import BaseQueryEngine +from llama_index.core.base_query_engine import BaseQueryEngine +from llama_index.core.response.schema import RESPONSE_TYPE from llama_index.indices.query.query_transform.base import StepDecomposeQueryTransform from llama_index.prompts.mixin import PromptMixinType -from llama_index.response.schema import RESPONSE_TYPE from llama_index.response_synthesizers import BaseSynthesizer, get_response_synthesizer from llama_index.schema import NodeWithScore, QueryBundle, TextNode diff --git a/llama_index/query_engine/pandas_query_engine.py b/llama_index/query_engine/pandas_query_engine.py index a6ebb9530..24e4ab40d 100644 --- a/llama_index/query_engine/pandas_query_engine.py +++ b/llama_index/query_engine/pandas_query_engine.py @@ -13,13 +13,13 @@ from typing import Any, Callable, Optional import numpy as np import pandas as pd -from llama_index.core import BaseQueryEngine +from llama_index.core.base_query_engine import BaseQueryEngine +from llama_index.core.response.schema import Response from llama_index.exec_utils import safe_eval, safe_exec from llama_index.indices.struct_store.pandas import PandasIndex from llama_index.prompts import BasePromptTemplate from llama_index.prompts.default_prompts import DEFAULT_PANDAS_PROMPT from llama_index.prompts.mixin import PromptMixinType -from llama_index.response.schema import Response from llama_index.schema import QueryBundle from llama_index.service_context import ServiceContext from llama_index.utils import print_text diff --git a/llama_index/query_engine/retriever_query_engine.py b/llama_index/query_engine/retriever_query_engine.py index 1fa0355e7..f3696b391 100644 --- a/llama_index/query_engine/retriever_query_engine.py +++ b/llama_index/query_engine/retriever_query_engine.py @@ -3,11 +3,12 @@ from typing import Any, List, Optional, Sequence from llama_index.bridge.pydantic import BaseModel from llama_index.callbacks.base import CallbackManager from llama_index.callbacks.schema import CBEventType, EventPayload -from llama_index.core import BaseQueryEngine, BaseRetriever +from llama_index.core.base_query_engine import BaseQueryEngine +from llama_index.core.base_retriever import BaseRetriever +from llama_index.core.response.schema import RESPONSE_TYPE from llama_index.postprocessor.types import BaseNodePostprocessor from llama_index.prompts import BasePromptTemplate from llama_index.prompts.mixin import PromptMixinType -from llama_index.response.schema import RESPONSE_TYPE from llama_index.response_synthesizers import ( BaseSynthesizer, ResponseMode, diff --git a/llama_index/query_engine/retry_query_engine.py b/llama_index/query_engine/retry_query_engine.py index 7a7b20fdb..7cdad01f6 100644 --- a/llama_index/query_engine/retry_query_engine.py +++ 
b/llama_index/query_engine/retry_query_engine.py @@ -2,14 +2,14 @@ import logging from typing import Optional from llama_index.callbacks.base import CallbackManager -from llama_index.core import BaseQueryEngine +from llama_index.core.base_query_engine import BaseQueryEngine +from llama_index.core.response.schema import RESPONSE_TYPE, Response from llama_index.evaluation.base import BaseEvaluator from llama_index.evaluation.guideline import GuidelineEvaluator from llama_index.indices.query.query_transform.feedback_transform import ( FeedbackQueryTransformation, ) from llama_index.prompts.mixin import PromptMixinType -from llama_index.response.schema import RESPONSE_TYPE, Response from llama_index.schema import QueryBundle logger = logging.getLogger(__name__) diff --git a/llama_index/query_engine/retry_source_query_engine.py b/llama_index/query_engine/retry_source_query_engine.py index 13be39f13..7ff9eba4a 100644 --- a/llama_index/query_engine/retry_source_query_engine.py +++ b/llama_index/query_engine/retry_source_query_engine.py @@ -2,12 +2,12 @@ import logging from typing import Optional from llama_index.callbacks.base import CallbackManager -from llama_index.core import BaseQueryEngine +from llama_index.core.base_query_engine import BaseQueryEngine +from llama_index.core.response.schema import RESPONSE_TYPE, Response from llama_index.evaluation import BaseEvaluator from llama_index.indices.list.base import SummaryIndex from llama_index.prompts.mixin import PromptMixinType from llama_index.query_engine.retriever_query_engine import RetrieverQueryEngine -from llama_index.response.schema import RESPONSE_TYPE, Response from llama_index.schema import Document, QueryBundle from llama_index.service_context import ServiceContext diff --git a/llama_index/query_engine/router_query_engine.py b/llama_index/query_engine/router_query_engine.py index 0e31b2a3f..a17daa797 100644 --- a/llama_index/query_engine/router_query_engine.py +++ b/llama_index/query_engine/router_query_engine.py @@ -5,18 +5,19 @@ from llama_index.async_utils import run_async_tasks from llama_index.bridge.pydantic import BaseModel from llama_index.callbacks.base import CallbackManager from llama_index.callbacks.schema import CBEventType, EventPayload -from llama_index.core import BaseQueryEngine, BaseRetriever -from llama_index.objects.base import ObjectRetriever -from llama_index.prompts.default_prompt_selectors import ( - DEFAULT_TREE_SUMMARIZE_PROMPT_SEL, -) -from llama_index.prompts.mixin import PromptMixinType -from llama_index.response.schema import ( +from llama_index.core.base_query_engine import BaseQueryEngine +from llama_index.core.base_retriever import BaseRetriever +from llama_index.core.response.schema import ( RESPONSE_TYPE, PydanticResponse, Response, StreamingResponse, ) +from llama_index.objects.base import ObjectRetriever +from llama_index.prompts.default_prompt_selectors import ( + DEFAULT_TREE_SUMMARIZE_PROMPT_SEL, +) +from llama_index.prompts.mixin import PromptMixinType from llama_index.response_synthesizers import TreeSummarize from llama_index.schema import BaseNode, QueryBundle from llama_index.selectors.types import BaseSelector diff --git a/llama_index/query_engine/sql_join_query_engine.py b/llama_index/query_engine/sql_join_query_engine.py index faf5821fa..98bb210c0 100644 --- a/llama_index/query_engine/sql_join_query_engine.py +++ b/llama_index/query_engine/sql_join_query_engine.py @@ -4,7 +4,8 @@ import logging from typing import Callable, Dict, Optional, Union from llama_index.callbacks.base import 
CallbackManager -from llama_index.core import BaseQueryEngine +from llama_index.core.base_query_engine import BaseQueryEngine +from llama_index.core.response.schema import RESPONSE_TYPE, Response from llama_index.indices.query.query_transform.base import BaseQueryTransform from llama_index.indices.struct_store.sql_query import ( BaseSQLTableQueryEngine, @@ -14,7 +15,6 @@ from llama_index.llm_predictor.base import LLMPredictorType from llama_index.llms.utils import resolve_llm from llama_index.prompts.base import BasePromptTemplate, PromptTemplate from llama_index.prompts.mixin import PromptDictType, PromptMixinType -from llama_index.response.schema import RESPONSE_TYPE, Response from llama_index.schema import QueryBundle from llama_index.selectors.llm_selectors import LLMSingleSelector from llama_index.selectors.pydantic_selectors import PydanticSingleSelector diff --git a/llama_index/query_engine/sub_question_query_engine.py b/llama_index/query_engine/sub_question_query_engine.py index 6bf4efd8d..8272e10d3 100644 --- a/llama_index/query_engine/sub_question_query_engine.py +++ b/llama_index/query_engine/sub_question_query_engine.py @@ -6,12 +6,12 @@ from llama_index.async_utils import run_async_tasks from llama_index.bridge.pydantic import BaseModel, Field from llama_index.callbacks.base import CallbackManager from llama_index.callbacks.schema import CBEventType, EventPayload -from llama_index.core import BaseQueryEngine +from llama_index.core.base_query_engine import BaseQueryEngine +from llama_index.core.response.schema import RESPONSE_TYPE from llama_index.prompts.mixin import PromptMixinType from llama_index.question_gen.llm_generators import LLMQuestionGenerator from llama_index.question_gen.openai_generator import OpenAIQuestionGenerator from llama_index.question_gen.types import BaseQuestionGenerator, SubQuestion -from llama_index.response.schema import RESPONSE_TYPE from llama_index.response_synthesizers import BaseSynthesizer, get_response_synthesizer from llama_index.schema import NodeWithScore, QueryBundle, TextNode from llama_index.service_context import ServiceContext diff --git a/llama_index/query_engine/transform_query_engine.py b/llama_index/query_engine/transform_query_engine.py index 219d8ecf7..64f757419 100644 --- a/llama_index/query_engine/transform_query_engine.py +++ b/llama_index/query_engine/transform_query_engine.py @@ -1,10 +1,10 @@ from typing import List, Optional, Sequence from llama_index.callbacks.base import CallbackManager -from llama_index.core import BaseQueryEngine +from llama_index.core.base_query_engine import BaseQueryEngine +from llama_index.core.response.schema import RESPONSE_TYPE from llama_index.indices.query.query_transform.base import BaseQueryTransform from llama_index.prompts.mixin import PromptMixinType -from llama_index.response.schema import RESPONSE_TYPE from llama_index.schema import NodeWithScore, QueryBundle diff --git a/llama_index/readers/make_com/wrapper.py b/llama_index/readers/make_com/wrapper.py index 2cb79e6f3..4c9c9c18b 100644 --- a/llama_index/readers/make_com/wrapper.py +++ b/llama_index/readers/make_com/wrapper.py @@ -8,8 +8,8 @@ from typing import Any, List, Optional import requests +from llama_index.core.response.schema import Response from llama_index.readers.base import BaseReader -from llama_index.response.schema import Response from llama_index.schema import Document, NodeWithScore, TextNode diff --git a/llama_index/response/__init__.py b/llama_index/response/__init__.py index 294a5bc8e..b99207e3d 100644 --- 
a/llama_index/response/__init__.py +++ b/llama_index/response/__init__.py @@ -1,5 +1,5 @@ """Init params.""" -from llama_index.response.schema import Response +from llama_index.core.response.schema import Response __all__ = ["Response"] diff --git a/llama_index/response/notebook_utils.py b/llama_index/response/notebook_utils.py index b037c4a49..fc8b97640 100644 --- a/llama_index/response/notebook_utils.py +++ b/llama_index/response/notebook_utils.py @@ -8,8 +8,8 @@ import requests from IPython.display import Markdown, display from PIL import Image +from llama_index.core.response.schema import Response from llama_index.img_utils import b64_2_img -from llama_index.response.schema import Response from llama_index.schema import ImageNode, MetadataMode, NodeWithScore from llama_index.utils import truncate_text diff --git a/llama_index/response/pprint_utils.py b/llama_index/response/pprint_utils.py index 26a86c9f3..1b047ad85 100644 --- a/llama_index/response/pprint_utils.py +++ b/llama_index/response/pprint_utils.py @@ -3,7 +3,7 @@ import textwrap from pprint import pprint from typing import Any, Dict -from llama_index.response.schema import Response +from llama_index.core.response.schema import Response from llama_index.schema import NodeWithScore from llama_index.utils import truncate_text diff --git a/llama_index/response/schema.py b/llama_index/response/schema.py index 1834b6ccf..b9a6459b3 100644 --- a/llama_index/response/schema.py +++ b/llama_index/response/schema.py @@ -1,142 +1,14 @@ -"""Response schema.""" +"""Response schema. -from dataclasses import dataclass, field -from typing import Any, Dict, List, Optional, Union +Maintain this file for backwards compat. -from llama_index.bridge.pydantic import BaseModel -from llama_index.schema import NodeWithScore -from llama_index.types import TokenGen -from llama_index.utils import truncate_text +""" +from llama_index.core.response.schema import ( + RESPONSE_TYPE, + PydanticResponse, + Response, + StreamingResponse, +) -@dataclass -class Response: - """Response object. - - Returned if streaming=False. - - Attributes: - response: The response text. - - """ - - response: Optional[str] - source_nodes: List[NodeWithScore] = field(default_factory=list) - metadata: Optional[Dict[str, Any]] = None - - def __str__(self) -> str: - """Convert to string representation.""" - return self.response or "None" - - def get_formatted_sources(self, length: int = 100) -> str: - """Get formatted sources text.""" - texts = [] - for source_node in self.source_nodes: - fmt_text_chunk = truncate_text(source_node.node.get_content(), length) - doc_id = source_node.node.node_id or "None" - source_text = f"> Source (Doc id: {doc_id}): {fmt_text_chunk}" - texts.append(source_text) - return "\n\n".join(texts) - - -@dataclass -class PydanticResponse: - """PydanticResponse object. - - Returned if streaming=False. - - Attributes: - response: The response text. 
- - """ - - response: Optional[BaseModel] - source_nodes: List[NodeWithScore] = field(default_factory=list) - metadata: Optional[Dict[str, Any]] = None - - def __str__(self) -> str: - """Convert to string representation.""" - return self.response.json() if self.response else "None" - - def __getattr__(self, name: str) -> Any: - """Get attribute, but prioritize the pydantic response object.""" - if self.response is not None and name in self.response.dict(): - return getattr(self.response, name) - else: - return None - - def get_formatted_sources(self, length: int = 100) -> str: - """Get formatted sources text.""" - texts = [] - for source_node in self.source_nodes: - fmt_text_chunk = truncate_text(source_node.node.get_content(), length) - doc_id = source_node.node.node_id or "None" - source_text = f"> Source (Doc id: {doc_id}): {fmt_text_chunk}" - texts.append(source_text) - return "\n\n".join(texts) - - def get_response(self) -> Response: - """Get a standard response object.""" - response_txt = self.response.json() if self.response else "None" - return Response(response_txt, self.source_nodes, self.metadata) - - -@dataclass -class StreamingResponse: - """StreamingResponse object. - - Returned if streaming=True. - - Attributes: - response_gen: The response generator. - - """ - - response_gen: TokenGen - source_nodes: List[NodeWithScore] = field(default_factory=list) - metadata: Optional[Dict[str, Any]] = None - response_txt: Optional[str] = None - - def __str__(self) -> str: - """Convert to string representation.""" - if self.response_txt is None and self.response_gen is not None: - response_txt = "" - for text in self.response_gen: - response_txt += text - self.response_txt = response_txt - return self.response_txt or "None" - - def get_response(self) -> Response: - """Get a standard response object.""" - if self.response_txt is None and self.response_gen is not None: - response_txt = "" - for text in self.response_gen: - response_txt += text - self.response_txt = response_txt - return Response(self.response_txt, self.source_nodes, self.metadata) - - def print_response_stream(self) -> None: - """Print the response stream.""" - if self.response_txt is None and self.response_gen is not None: - response_txt = "" - for text in self.response_gen: - print(text, end="", flush=True) - response_txt += text - self.response_txt = response_txt - else: - print(self.response_txt) - - def get_formatted_sources(self, length: int = 100, trim_text: int = True) -> str: - """Get formatted sources text.""" - texts = [] - for source_node in self.source_nodes: - fmt_text_chunk = source_node.node.get_content() - if trim_text: - fmt_text_chunk = truncate_text(fmt_text_chunk, length) - node_id = source_node.node.node_id or "None" - source_text = f"> Source (Node id: {node_id}): {fmt_text_chunk}" - texts.append(source_text) - return "\n\n".join(texts) - - -RESPONSE_TYPE = Union[Response, StreamingResponse, PydanticResponse] +__all__ = ["Response", "PydanticResponse", "StreamingResponse", "RESPONSE_TYPE"] diff --git a/llama_index/response_synthesizers/base.py b/llama_index/response_synthesizers/base.py index 9f77d4a58..5790b3313 100644 --- a/llama_index/response_synthesizers/base.py +++ b/llama_index/response_synthesizers/base.py @@ -13,13 +13,13 @@ from typing import Any, Dict, Generator, List, Optional, Sequence, Union from llama_index.bridge.pydantic import BaseModel from llama_index.callbacks.schema import CBEventType, EventPayload -from llama_index.prompts.mixin import PromptMixin -from 
llama_index.response.schema import ( +from llama_index.core.response.schema import ( RESPONSE_TYPE, PydanticResponse, Response, StreamingResponse, ) +from llama_index.prompts.mixin import PromptMixin from llama_index.schema import BaseNode, MetadataMode, NodeWithScore, QueryBundle from llama_index.service_context import ServiceContext from llama_index.types import RESPONSE_TEXT_TYPE diff --git a/llama_index/response_synthesizers/google/generativeai/base.py b/llama_index/response_synthesizers/google/generativeai/base.py index e9daa9cbf..cbc1246cd 100644 --- a/llama_index/response_synthesizers/google/generativeai/base.py +++ b/llama_index/response_synthesizers/google/generativeai/base.py @@ -11,9 +11,9 @@ from typing import TYPE_CHECKING, Any, List, Optional, Sequence, cast from llama_index.bridge.pydantic import BaseModel # type: ignore from llama_index.callbacks.schema import CBEventType, EventPayload +from llama_index.core.response.schema import Response from llama_index.indices.query.schema import QueryBundle from llama_index.prompts.mixin import PromptDictType -from llama_index.response.schema import Response from llama_index.response_synthesizers.base import BaseSynthesizer, QueryTextType from llama_index.schema import MetadataMode, NodeWithScore, TextNode from llama_index.types import RESPONSE_TEXT_TYPE diff --git a/llama_index/retrievers/__init__.py b/llama_index/retrievers/__init__.py index 1e4c66d2a..171679943 100644 --- a/llama_index/retrievers/__init__.py +++ b/llama_index/retrievers/__init__.py @@ -1,4 +1,5 @@ -from llama_index.core import BaseImageRetriever, BaseRetriever +from llama_index.core.base_retriever import BaseRetriever +from llama_index.core.image_retriever import BaseImageRetriever from llama_index.indices.empty.retrievers import EmptyIndexRetriever from llama_index.indices.keyword_table.retrievers import KeywordTableSimpleRetriever from llama_index.indices.knowledge_graph.retrievers import ( diff --git a/llama_index/retrievers/auto_merging_retriever.py b/llama_index/retrievers/auto_merging_retriever.py index f27d4284c..4a1f0a60c 100644 --- a/llama_index/retrievers/auto_merging_retriever.py +++ b/llama_index/retrievers/auto_merging_retriever.py @@ -5,7 +5,7 @@ from collections import defaultdict from typing import Dict, List, Optional, Tuple, cast from llama_index.callbacks.base import CallbackManager -from llama_index.core import BaseRetriever +from llama_index.core.base_retriever import BaseRetriever from llama_index.indices.query.schema import QueryBundle from llama_index.indices.utils import truncate_text from llama_index.indices.vector_store.retrievers.retriever import VectorIndexRetriever diff --git a/llama_index/retrievers/bm25_retriever.py b/llama_index/retrievers/bm25_retriever.py index 61cc5be38..3604c93fc 100644 --- a/llama_index/retrievers/bm25_retriever.py +++ b/llama_index/retrievers/bm25_retriever.py @@ -5,7 +5,7 @@ from nltk.stem import PorterStemmer from llama_index.callbacks.base import CallbackManager from llama_index.constants import DEFAULT_SIMILARITY_TOP_K -from llama_index.core import BaseRetriever +from llama_index.core.base_retriever import BaseRetriever from llama_index.indices.keyword_table.utils import simple_extract_keywords from llama_index.indices.vector_store.base import VectorStoreIndex from llama_index.schema import BaseNode, NodeWithScore, QueryBundle diff --git a/llama_index/retrievers/pathway_retriever.py b/llama_index/retrievers/pathway_retriever.py index fd4040f60..e7b6e311a 100644 --- 
a/llama_index/retrievers/pathway_retriever.py +++ b/llama_index/retrievers/pathway_retriever.py @@ -5,7 +5,7 @@ from typing import Any, Callable, List, Optional, Tuple, Union from llama_index.callbacks.base import CallbackManager from llama_index.constants import DEFAULT_SIMILARITY_TOP_K -from llama_index.core import BaseRetriever +from llama_index.core.base_retriever import BaseRetriever from llama_index.embeddings import BaseEmbedding from llama_index.indices.query.schema import QueryBundle from llama_index.ingestion.pipeline import run_transformations diff --git a/llama_index/retrievers/recursive_retriever.py b/llama_index/retrievers/recursive_retriever.py index bc5817b1f..4ad3cd060 100644 --- a/llama_index/retrievers/recursive_retriever.py +++ b/llama_index/retrievers/recursive_retriever.py @@ -2,7 +2,8 @@ from typing import Dict, List, Optional, Tuple, Union from llama_index.callbacks.base import CallbackManager from llama_index.callbacks.schema import CBEventType, EventPayload -from llama_index.core import BaseQueryEngine, BaseRetriever +from llama_index.core.base_query_engine import BaseQueryEngine +from llama_index.core.base_retriever import BaseRetriever from llama_index.schema import BaseNode, IndexNode, NodeWithScore, QueryBundle, TextNode from llama_index.utils import print_text diff --git a/llama_index/retrievers/router_retriever.py b/llama_index/retrievers/router_retriever.py index b1f964846..72740a88c 100644 --- a/llama_index/retrievers/router_retriever.py +++ b/llama_index/retrievers/router_retriever.py @@ -5,7 +5,7 @@ import logging from typing import List, Optional, Sequence from llama_index.callbacks.schema import CBEventType, EventPayload -from llama_index.core import BaseRetriever +from llama_index.core.base_retriever import BaseRetriever from llama_index.prompts.mixin import PromptMixinType from llama_index.schema import NodeWithScore, QueryBundle from llama_index.selectors.types import BaseSelector diff --git a/llama_index/retrievers/transform_retriever.py b/llama_index/retrievers/transform_retriever.py index f200f7510..df8228aca 100644 --- a/llama_index/retrievers/transform_retriever.py +++ b/llama_index/retrievers/transform_retriever.py @@ -1,7 +1,7 @@ from typing import List, Optional from llama_index.callbacks.base import CallbackManager -from llama_index.core import BaseRetriever +from llama_index.core.base_retriever import BaseRetriever from llama_index.indices.query.query_transform.base import BaseQueryTransform from llama_index.prompts.mixin import PromptMixinType from llama_index.schema import NodeWithScore, QueryBundle diff --git a/llama_index/retrievers/you_retriever.py b/llama_index/retrievers/you_retriever.py index df042b6ce..f29f2ab63 100644 --- a/llama_index/retrievers/you_retriever.py +++ b/llama_index/retrievers/you_retriever.py @@ -7,7 +7,7 @@ from typing import List, Optional import requests from llama_index.callbacks.base import CallbackManager -from llama_index.core import BaseRetriever +from llama_index.core.base_retriever import BaseRetriever from llama_index.indices.query.schema import QueryBundle from llama_index.schema import NodeWithScore, QueryBundle, TextNode diff --git a/llama_index/schema.py b/llama_index/schema.py index 0ce2e62a1..cd6acd19f 100644 --- a/llama_index/schema.py +++ b/llama_index/schema.py @@ -760,5 +760,9 @@ class QueryBundle(DataClassJsonMixin): return [] return [self.image_path] + def __str__(self) -> str: + """Convert to string representation.""" + return self.query_str + QueryType = Union[str, QueryBundle] diff --git 
a/llama_index/service_context.py b/llama_index/service_context.py index c4378f248..13070115b 100644 --- a/llama_index/service_context.py +++ b/llama_index/service_context.py @@ -1,12 +1,11 @@ import logging from dataclasses import dataclass -from typing import List, Optional +from typing import Any, List, Optional, cast import llama_index from llama_index.bridge.pydantic import BaseModel from llama_index.callbacks.base import CallbackManager -from llama_index.embeddings.base import BaseEmbedding -from llama_index.embeddings.utils import EmbedType, resolve_embed_model +from llama_index.core.embeddings.base import BaseEmbedding from llama_index.indices.prompt_helper import PromptHelper from llama_index.llm_predictor import LLMPredictor from llama_index.llm_predictor.base import BaseLLMPredictor, LLMMetadata @@ -88,7 +87,7 @@ class ServiceContext: llm_predictor: Optional[BaseLLMPredictor] = None, llm: Optional[LLMType] = "default", prompt_helper: Optional[PromptHelper] = None, - embed_model: Optional[EmbedType] = "default", + embed_model: Optional[Any] = "default", node_parser: Optional[NodeParser] = None, text_splitter: Optional[TextSplitter] = None, transformations: Optional[List[TransformComponent]] = None, @@ -132,6 +131,10 @@ class ServiceContext: chunk_size_limit (Optional[int]): renamed to chunk_size """ + from llama_index.embeddings.utils import EmbedType, resolve_embed_model + + embed_model = cast(EmbedType, embed_model) + if chunk_size_limit is not None and chunk_size is None: logger.warning( "chunk_size_limit is deprecated, please specify chunk_size instead" @@ -227,7 +230,7 @@ class ServiceContext: llm_predictor: Optional[BaseLLMPredictor] = None, llm: Optional[LLMType] = "default", prompt_helper: Optional[PromptHelper] = None, - embed_model: Optional[EmbedType] = "default", + embed_model: Optional[Any] = "default", node_parser: Optional[NodeParser] = None, text_splitter: Optional[TextSplitter] = None, transformations: Optional[List[TransformComponent]] = None, @@ -245,6 +248,10 @@ class ServiceContext: chunk_size_limit: Optional[int] = None, ) -> "ServiceContext": """Instantiate a new service context using a previous as the defaults.""" + from llama_index.embeddings.utils import EmbedType, resolve_embed_model + + embed_model = cast(EmbedType, embed_model) + if chunk_size_limit is not None and chunk_size is None: logger.warning( "chunk_size_limit is deprecated, please specify chunk_size", diff --git a/llama_index/tools/query_engine.py b/llama_index/tools/query_engine.py index b0b9de791..8e151ac8a 100644 --- a/llama_index/tools/query_engine.py +++ b/llama_index/tools/query_engine.py @@ -1,6 +1,6 @@ from typing import TYPE_CHECKING, Any, Optional -from llama_index.core import BaseQueryEngine +from llama_index.core.base_query_engine import BaseQueryEngine if TYPE_CHECKING: from llama_index.langchain_helpers.agents.tools import ( diff --git a/llama_index/tools/retriever_tool.py b/llama_index/tools/retriever_tool.py index 029d320c4..9d2bbb712 100644 --- a/llama_index/tools/retriever_tool.py +++ b/llama_index/tools/retriever_tool.py @@ -2,7 +2,7 @@ from typing import TYPE_CHECKING, Any, Optional -from llama_index.core import BaseRetriever +from llama_index.core.base_retriever import BaseRetriever if TYPE_CHECKING: from llama_index.langchain_helpers.agents.tools import LlamaIndexTool diff --git a/llama_index/types.py b/llama_index/types.py index e454b18e8..9197d04c2 100644 --- a/llama_index/types.py +++ b/llama_index/types.py @@ -14,7 +14,7 @@ from typing import ( ) from 
llama_index.bridge.pydantic import BaseModel -from llama_index.llms.types import ChatMessage, MessageRole +from llama_index.core.llms.types import ChatMessage, MessageRole Model = TypeVar("Model", bound=BaseModel) diff --git a/tests/agent/openai/test_openai_agent.py b/tests/agent/openai/test_openai_agent.py index 81b60ae24..6e8798266 100644 --- a/tests/agent/openai/test_openai_agent.py +++ b/tests/agent/openai/test_openai_agent.py @@ -5,10 +5,9 @@ import pytest from llama_index.agent.openai.base import OpenAIAgent from llama_index.agent.openai.step import call_tool_with_error_handling from llama_index.chat_engine.types import AgentChatResponse, StreamingAgentChatResponse +from llama_index.core.llms.types import ChatMessage, ChatResponse -from llama_index.llms.base import ChatMessage, ChatResponse from llama_index.llms.mock import MockLLM from llama_index.llms.openai import OpenAI -from llama_index.llms.types import ChatMessage, ChatResponse from llama_index.tools.function_tool import FunctionTool from openai.types.chat.chat_completion import ChatCompletion, Choice from openai.types.chat.chat_completion_chunk import ChatCompletionChunk, ChoiceDelta diff --git a/tests/agent/react/test_react_agent.py b/tests/agent/react/test_react_agent.py index c728a5562..af4563499 100644 --- a/tests/agent/react/test_react_agent.py +++ b/tests/agent/react/test_react_agent.py @@ -7,13 +7,13 @@ from llama_index.agent.react.types import ObservationReasoningStep from llama_index.agent.types import Task from llama_index.bridge.pydantic import PrivateAttr from llama_index.chat_engine.types import AgentChatResponse, StreamingAgentChatResponse -from llama_index.llms.mock import MockLLM -from llama_index.llms.types import ( +from llama_index.core.llms.types import ( ChatMessage, ChatResponse, ChatResponseGen, MessageRole, ) +from llama_index.llms.mock import MockLLM from llama_index.tools.function_tool import FunctionTool from llama_index.tools.types import BaseTool diff --git a/tests/chat_engine/test_condense_question.py b/tests/chat_engine/test_condense_question.py index fb249a5a3..5a7a8d4c8 100644 --- a/tests/chat_engine/test_condense_question.py +++ b/tests/chat_engine/test_condense_question.py @@ -1,9 +1,9 @@ from unittest.mock import Mock from llama_index.chat_engine.condense_question import CondenseQuestionChatEngine -from llama_index.core import BaseQueryEngine -from llama_index.llms.types import ChatMessage, MessageRole -from llama_index.response.schema import Response +from llama_index.core.base_query_engine import BaseQueryEngine +from llama_index.core.llms.types import ChatMessage, MessageRole +from llama_index.core.response.schema import Response from llama_index.service_context import ServiceContext diff --git a/tests/chat_engine/test_simple.py b/tests/chat_engine/test_simple.py index fa6e191b2..f0d38d432 100644 --- a/tests/chat_engine/test_simple.py +++ b/tests/chat_engine/test_simple.py @@ -1,5 +1,5 @@ from llama_index.chat_engine.simple import SimpleChatEngine -from llama_index.llms.types import ChatMessage, MessageRole +from llama_index.core.llms.types import ChatMessage, MessageRole from llama_index.service_context import ServiceContext diff --git a/tests/conftest.py b/tests/conftest.py index 5d6b5e2d1..cbbf5065c 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -5,9 +5,9 @@ from typing import Any, List, Optional import openai import pytest +from llama_index.core.llms.types import LLMMetadata from llama_index.llm_predictor.base import LLMPredictor from llama_index.llms.mock import MockLLM
-from llama_index.llms.types import LLMMetadata from llama_index.node_parser.text import SentenceSplitter, TokenTextSplitter from llama_index.service_context import ServiceContext diff --git a/tests/embeddings/test_base.py b/tests/embeddings/test_base.py index 3eb3d1bec..ce4df5f90 100644 --- a/tests/embeddings/test_base.py +++ b/tests/embeddings/test_base.py @@ -3,7 +3,7 @@ import os from typing import Any, List from unittest.mock import patch -from llama_index.embeddings.base import SimilarityMode, mean_agg +from llama_index.core.embeddings.base import SimilarityMode, mean_agg from llama_index.embeddings.openai import OpenAIEmbedding from tests.conftest import CachedOpenAIApiKeys diff --git a/tests/evaluation/test_base.py b/tests/evaluation/test_base.py index d4ce1d3f8..93c9f1f69 100644 --- a/tests/evaluation/test_base.py +++ b/tests/evaluation/test_base.py @@ -1,9 +1,9 @@ from typing import Any, Optional, Sequence +from llama_index.core.response.schema import NodeWithScore, Response from llama_index.evaluation import BaseEvaluator from llama_index.evaluation.base import EvaluationResult from llama_index.prompts.mixin import PromptDictType -from llama_index.response.schema import NodeWithScore, Response from llama_index.schema import TextNode diff --git a/tests/indices/list/test_index.py b/tests/indices/list/test_index.py index 3ff749990..358e4b934 100644 --- a/tests/indices/list/test_index.py +++ b/tests/indices/list/test_index.py @@ -2,7 +2,7 @@ from typing import Dict, List, Tuple -from llama_index.core import BaseRetriever +from llama_index.core.base_retriever import BaseRetriever from llama_index.indices.list.base import ListRetrieverMode, SummaryIndex from llama_index.schema import BaseNode, Document from llama_index.service_context import ServiceContext diff --git a/tests/indices/managed/test_google.py b/tests/indices/managed/test_google.py index 644ec6043..225ddc77d 100644 --- a/tests/indices/managed/test_google.py +++ b/tests/indices/managed/test_google.py @@ -1,7 +1,7 @@ from unittest.mock import MagicMock, patch import pytest -from llama_index.response.schema import Response +from llama_index.core.response.schema import Response from llama_index.schema import Document try: diff --git a/tests/indices/struct_store/test_json_query.py b/tests/indices/struct_store/test_json_query.py index a84b13ad5..fdca94bf6 100644 --- a/tests/indices/struct_store/test_json_query.py +++ b/tests/indices/struct_store/test_json_query.py @@ -6,11 +6,11 @@ from typing import Any, Dict, cast from unittest.mock import patch import pytest +from llama_index.core.response.schema import Response from llama_index.indices.struct_store.json_query import JSONQueryEngine, JSONType from llama_index.llm_predictor import LLMPredictor from llama_index.llms.mock import MockLLM from llama_index.prompts.base import BasePromptTemplate -from llama_index.response.schema import Response from llama_index.schema import QueryBundle from llama_index.service_context import ServiceContext diff --git a/tests/llms/test_anthropic.py b/tests/llms/test_anthropic.py index c7386ffbd..187a67718 100644 --- a/tests/llms/test_anthropic.py +++ b/tests/llms/test_anthropic.py @@ -1,6 +1,6 @@ import pytest +from llama_index.core.llms.types import ChatMessage from llama_index.llms.anthropic import Anthropic -from llama_index.llms.types import ChatMessage try: import anthropic diff --git a/tests/llms/test_anthropic_utils.py b/tests/llms/test_anthropic_utils.py index 76b4dce62..c0f7c179b 100644 --- a/tests/llms/test_anthropic_utils.py +++ 
b/tests/llms/test_anthropic_utils.py @@ -1,9 +1,9 @@ import pytest +from llama_index.core.llms.types import ChatMessage, MessageRole from llama_index.llms.anthropic_utils import ( anthropic_modelname_to_contextsize, messages_to_anthropic_prompt, ) -from llama_index.llms.types import ChatMessage, MessageRole def test_messages_to_anthropic_prompt() -> None: diff --git a/tests/llms/test_bedrock.py b/tests/llms/test_bedrock.py index 26db8bddb..d26661d40 100644 --- a/tests/llms/test_bedrock.py +++ b/tests/llms/test_bedrock.py @@ -5,8 +5,8 @@ from typing import Any, Generator import pytest from botocore.response import StreamingBody from botocore.stub import Stubber +from llama_index.core.llms.types import ChatMessage from llama_index.llms import Bedrock -from llama_index.llms.types import ChatMessage from pytest import MonkeyPatch diff --git a/tests/llms/test_cohere.py b/tests/llms/test_cohere.py index 1d65c83a3..cd9eff0af 100644 --- a/tests/llms/test_cohere.py +++ b/tests/llms/test_cohere.py @@ -1,7 +1,7 @@ from typing import Any import pytest -from llama_index.llms.types import ChatMessage +from llama_index.core.llms.types import ChatMessage from pytest import MonkeyPatch try: diff --git a/tests/llms/test_custom.py b/tests/llms/test_custom.py index 3cd79eca3..90e874e3d 100644 --- a/tests/llms/test_custom.py +++ b/tests/llms/test_custom.py @@ -1,12 +1,12 @@ from typing import Any -from llama_index.llms.custom import CustomLLM -from llama_index.llms.types import ( +from llama_index.core.llms.types import ( ChatMessage, CompletionResponse, CompletionResponseGen, LLMMetadata, ) +from llama_index.llms.custom import CustomLLM class TestLLM(CustomLLM): diff --git a/tests/llms/test_gradient.py b/tests/llms/test_gradient.py index 87db5148b..1b9817e7b 100644 --- a/tests/llms/test_gradient.py +++ b/tests/llms/test_gradient.py @@ -5,8 +5,8 @@ from typing import Any from unittest.mock import MagicMock, patch import pytest +from llama_index.core.llms.types import CompletionResponse from llama_index.llms.gradient import GradientBaseModelLLM, GradientModelAdapterLLM -from llama_index.llms.types import CompletionResponse class GradientModel(MagicMock): diff --git a/tests/llms/test_konko.py b/tests/llms/test_konko.py index 8b62dd0b6..848ac54cc 100644 --- a/tests/llms/test_konko.py +++ b/tests/llms/test_konko.py @@ -1,8 +1,8 @@ from typing import Any, Generator import pytest +from llama_index.core.llms.types import ChatMessage from llama_index.llms.konko import Konko -from llama_index.llms.types import ChatMessage from pytest import MonkeyPatch try: diff --git a/tests/llms/test_langchain.py b/tests/llms/test_langchain.py index 15b1c03f5..dbbbdc40b 100644 --- a/tests/llms/test_langchain.py +++ b/tests/llms/test_langchain.py @@ -1,7 +1,7 @@ from typing import List import pytest -from llama_index.llms.types import ChatMessage, MessageRole +from llama_index.core.llms.types import ChatMessage, MessageRole try: import cohere diff --git a/tests/llms/test_litellm.py b/tests/llms/test_litellm.py index 8786f7b50..2dc7a5d24 100644 --- a/tests/llms/test_litellm.py +++ b/tests/llms/test_litellm.py @@ -6,8 +6,8 @@ except ImportError: litellm = None # type: ignore import pytest +from llama_index.core.llms.types import ChatMessage from llama_index.llms.litellm import LiteLLM -from llama_index.llms.types import ChatMessage from pytest import MonkeyPatch from tests.conftest import CachedOpenAIApiKeys diff --git a/tests/llms/test_llama_utils.py b/tests/llms/test_llama_utils.py index b8587d7a5..23c8e6ee2 100644 --- 
a/tests/llms/test_llama_utils.py +++ b/tests/llms/test_llama_utils.py @@ -1,6 +1,7 @@ from typing import Sequence import pytest +from llama_index.core.llms.types import ChatMessage, MessageRole from llama_index.llms.llama_utils import ( B_INST, B_SYS, @@ -12,7 +13,6 @@ from llama_index.llms.llama_utils import ( completion_to_prompt, messages_to_prompt, ) -from llama_index.llms.types import ChatMessage, MessageRole @pytest.fixture() diff --git a/tests/llms/test_localai.py b/tests/llms/test_localai.py index eda548c0a..d1035678a 100644 --- a/tests/llms/test_localai.py +++ b/tests/llms/test_localai.py @@ -1,8 +1,8 @@ from unittest.mock import MagicMock, patch import pytest +from llama_index.core.llms.types import ChatMessage from llama_index.llms import LocalAI -from llama_index.llms.types import ChatMessage from openai.types import Completion, CompletionChoice from openai.types.chat.chat_completion import ChatCompletion, Choice from openai.types.chat.chat_completion_message import ChatCompletionMessage diff --git a/tests/llms/test_openai.py b/tests/llms/test_openai.py index ebc42b209..305857581 100644 --- a/tests/llms/test_openai.py +++ b/tests/llms/test_openai.py @@ -3,8 +3,8 @@ from typing import Any, AsyncGenerator, Generator from unittest.mock import AsyncMock, MagicMock, patch import pytest +from llama_index.core.llms.types import ChatMessage from llama_index.llms.openai import OpenAI -from llama_index.llms.types import ChatMessage from openai.types.chat.chat_completion import ( ChatCompletion, ChatCompletionMessage, diff --git a/tests/llms/test_openai_like.py b/tests/llms/test_openai_like.py index 99a96f6f4..f6bbaa83b 100644 --- a/tests/llms/test_openai_like.py +++ b/tests/llms/test_openai_like.py @@ -1,9 +1,9 @@ from typing import List from unittest.mock import MagicMock, call, patch +from llama_index.core.llms.types import ChatMessage, MessageRole from llama_index.llms import LOCALAI_DEFAULTS, OpenAILike from llama_index.llms.openai import Tokenizer -from llama_index.llms.types import ChatMessage, MessageRole from openai.types import Completion, CompletionChoice from openai.types.chat.chat_completion import ChatCompletion, Choice from openai.types.chat.chat_completion_message import ChatCompletionMessage diff --git a/tests/llms/test_openai_utils.py b/tests/llms/test_openai_utils.py index 712b1857b..1acf0b94e 100644 --- a/tests/llms/test_openai_utils.py +++ b/tests/llms/test_openai_utils.py @@ -2,13 +2,13 @@ from typing import List import pytest from llama_index.bridge.pydantic import BaseModel +from llama_index.core.llms.types import ChatMessage, MessageRole from llama_index.llms.openai_utils import ( from_openai_message_dicts, from_openai_messages, to_openai_message_dicts, to_openai_tool, ) -from llama_index.llms.types import ChatMessage, MessageRole from openai.types.chat.chat_completion_assistant_message_param import ( FunctionCall as FunctionCallParam, ) diff --git a/tests/llms/test_palm.py b/tests/llms/test_palm.py index c36f2b7ee..bc221db00 100644 --- a/tests/llms/test_palm.py +++ b/tests/llms/test_palm.py @@ -31,8 +31,8 @@ class MockPalmPackage(MagicMock): return self._mock_models() +from llama_index.core.llms.types import CompletionResponse from llama_index.llms.palm import PaLM -from llama_index.llms.types import CompletionResponse @pytest.mark.skipif( diff --git a/tests/llms/test_rungpt.py b/tests/llms/test_rungpt.py index 475e719c8..163246f7a 100644 --- a/tests/llms/test_rungpt.py +++ b/tests/llms/test_rungpt.py @@ -2,11 +2,11 @@ from typing import Any, Dict, Generator, 
List from unittest.mock import MagicMock, patch import pytest -from llama_index.llms.rungpt import RunGptLLM -from llama_index.llms.types import ( +from llama_index.core.llms.types import ( ChatMessage, MessageRole, ) +from llama_index.llms.rungpt import RunGptLLM try: import sseclient diff --git a/tests/llms/test_vertex.py b/tests/llms/test_vertex.py index 3037ba36c..970398397 100644 --- a/tests/llms/test_vertex.py +++ b/tests/llms/test_vertex.py @@ -1,7 +1,7 @@ from typing import Sequence import pytest -from llama_index.llms.types import ChatMessage, CompletionResponse +from llama_index.core.llms.types import ChatMessage, CompletionResponse from llama_index.llms.vertex import Vertex from llama_index.llms.vertex_utils import init_vertexai diff --git a/tests/llms/test_watsonx.py b/tests/llms/test_watsonx.py index 990028bc9..006eace81 100644 --- a/tests/llms/test_watsonx.py +++ b/tests/llms/test_watsonx.py @@ -3,7 +3,7 @@ from typing import Any, Dict, Generator, Optional from unittest.mock import MagicMock import pytest -from llama_index.llms.types import ChatMessage +from llama_index.core.llms.types import ChatMessage try: import ibm_watson_machine_learning diff --git a/tests/llms/test_xinference.py b/tests/llms/test_xinference.py index 3c2000746..8299e8a10 100644 --- a/tests/llms/test_xinference.py +++ b/tests/llms/test_xinference.py @@ -1,7 +1,7 @@ from typing import Any, Dict, Generator, Iterator, List, Mapping, Sequence, Tuple, Union import pytest -from llama_index.llms.types import ( +from llama_index.core.llms.types import ( ChatMessage, ChatResponse, CompletionResponse, diff --git a/tests/program/test_llm_program.py b/tests/program/test_llm_program.py index ae8d4dcab..fca35966a 100644 --- a/tests/program/test_llm_program.py +++ b/tests/program/test_llm_program.py @@ -4,7 +4,7 @@ import json from unittest.mock import MagicMock from llama_index.bridge.pydantic import BaseModel -from llama_index.llms.types import ( +from llama_index.core.llms.types import ( ChatMessage, ChatResponse, CompletionResponse, diff --git a/tests/program/test_lmformatenforcer.py b/tests/program/test_lmformatenforcer.py index 9b9468c3d..1e3af9124 100644 --- a/tests/program/test_lmformatenforcer.py +++ b/tests/program/test_lmformatenforcer.py @@ -3,8 +3,8 @@ from unittest.mock import MagicMock import pytest from llama_index.bridge.pydantic import BaseModel +from llama_index.core.llms.types import CompletionResponse from llama_index.llms.huggingface import HuggingFaceLLM -from llama_index.llms.types import CompletionResponse from llama_index.program.lmformatenforcer_program import LMFormatEnforcerPydanticProgram has_lmformatenforcer = find_spec("lmformatenforcer") is not None diff --git a/tests/program/test_multi_modal_llm_program.py b/tests/program/test_multi_modal_llm_program.py index 7d1fe9b84..96022532d 100644 --- a/tests/program/test_multi_modal_llm_program.py +++ b/tests/program/test_multi_modal_llm_program.py @@ -5,7 +5,7 @@ from typing import Sequence from unittest.mock import MagicMock from llama_index.bridge.pydantic import BaseModel -from llama_index.llms.types import ( +from llama_index.core.llms.types import ( CompletionResponse, ) from llama_index.multi_modal_llms import MultiModalLLMMetadata diff --git a/tests/prompts/test_base.py b/tests/prompts/test_base.py index 00993b6f6..25d011d5d 100644 --- a/tests/prompts/test_base.py +++ b/tests/prompts/test_base.py @@ -4,8 +4,8 @@ from typing import Any import pytest +from llama_index.core.llms.types import ChatMessage, MessageRole from llama_index.llms 
import MockLLM -from llama_index.llms.types import ChatMessage, MessageRole from llama_index.prompts import ( ChatPromptTemplate, LangchainPromptTemplate, diff --git a/tests/query_engine/test_cogniswitch_query_engine.py b/tests/query_engine/test_cogniswitch_query_engine.py index 842f4d8fd..c5bedeaf0 100644 --- a/tests/query_engine/test_cogniswitch_query_engine.py +++ b/tests/query_engine/test_cogniswitch_query_engine.py @@ -2,8 +2,8 @@ from typing import Any from unittest.mock import patch import pytest +from llama_index.core.response.schema import Response from llama_index.query_engine.cogniswitch_query_engine import CogniswitchQueryEngine -from llama_index.response.schema import Response @pytest.fixture() diff --git a/tests/query_engine/test_pandas.py b/tests/query_engine/test_pandas.py index 0c7acedd5..b95e14961 100644 --- a/tests/query_engine/test_pandas.py +++ b/tests/query_engine/test_pandas.py @@ -7,13 +7,13 @@ from typing import Any, Dict, cast import pandas as pd import pytest +from llama_index.core.response.schema import Response from llama_index.indices.query.schema import QueryBundle from llama_index.indices.service_context import ServiceContext from llama_index.query_engine.pandas_query_engine import ( PandasQueryEngine, default_output_processor, ) -from llama_index.response.schema import Response def test_pandas_query_engine(mock_service_context: ServiceContext) -> None:
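
Note on the resulting import surface: the hunks above make the llama_index.core.* modules the canonical home of the moved classes, while the old modules remain as thin re-export shims (see the rewritten llama_index/response/schema.py, kept "for backwards compat"). Below is a minimal sketch of what downstream code sees once this patch lands; the paths are taken verbatim from the hunks, and a shim only guarantees the names it explicitly re-exports:

    # New canonical locations introduced by this patch:
    from llama_index.core.base_query_engine import BaseQueryEngine
    from llama_index.core.base_retriever import BaseRetriever
    from llama_index.core.llms.types import ChatMessage, MessageRole
    from llama_index.core.response.schema import (
        RESPONSE_TYPE,
        PydanticResponse,
        Response,
        StreamingResponse,
    )

    # Legacy path still resolves through the backwards-compat shim:
    from llama_index.response.schema import Response as LegacyResponse
    assert LegacyResponse is Response  # same class object, just re-exported

    # QueryBundle now stringifies to its query text (per the schema.py hunk):
    from llama_index.schema import QueryBundle
    print(QueryBundle(query_str="where did Response move?"))  # prints the query_str

The ServiceContext hunk applies the same decoupling in the other direction: the embeddings utilities are imported inside from_defaults, and the embed_model argument is cast back to EmbedType at runtime, a standard way to break an import cycle between the core module and the embeddings package.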