From f8dc93b715d8ed1da077063633b1624efe518068 Mon Sep 17 00:00:00 2001 From: Simon Suo <simonsdsuo@gmail.com> Date: Fri, 23 Jun 2023 23:52:14 -0700 Subject: [PATCH] Track langchain dependency via bridge module. (#6573) * wip * wip * wip * wip --- benchmarks/struct_indices/spider/evaluate.py | 4 +- .../struct_indices/spider/generate_sql.py | 5 +- .../struct_indices/spider/spider_utils.py | 4 +- experimental/cli/configuration.py | 4 +- llama_index/agent/openai_agent.py | 4 +- llama_index/agent/retriever_openai_agent.py | 3 +- llama_index/bridge/__init__.py | 0 llama_index/bridge/langchain.py | 106 ++++++++++++++++++ llama_index/chat_engine/react.py | 3 +- llama_index/chat_engine/simple.py | 3 +- llama_index/chat_engine/utils.py | 3 +- llama_index/embeddings/langchain.py | 2 +- llama_index/evaluation/dataset_generation.py | 2 +- llama_index/evaluation/guideline_eval.py | 2 +- .../indices/query/query_transform/base.py | 2 +- llama_index/indices/service_context.py | 2 +- .../indices/struct_store/json_query.py | 2 +- .../indices/tree/select_leaf_retriever.py | 2 +- .../langchain_helpers/agents/agents.py | 11 +- .../langchain_helpers/agents/toolkits.py | 3 +- llama_index/langchain_helpers/agents/tools.py | 2 +- .../langchain_helpers/memory_wrapper.py | 11 +- llama_index/langchain_helpers/sql_wrapper.py | 2 +- llama_index/langchain_helpers/streaming.py | 3 +- .../langchain_helpers/text_splitter.py | 2 +- llama_index/llm_predictor/base.py | 8 +- llama_index/llm_predictor/chatgpt.py | 13 ++- llama_index/output_parsers/guardrails.py | 2 +- llama_index/output_parsers/langchain.py | 2 +- llama_index/playground/base.py | 2 +- llama_index/program/openai_program.py | 3 +- llama_index/prompts/base.py | 7 +- llama_index/prompts/chat_prompts.py | 2 +- .../prompts/default_prompt_selectors.py | 2 +- llama_index/query_engine/flare/base.py | 2 +- .../query_engine/pandas_query_engine.py | 2 +- .../query_engine/sql_join_query_engine.py | 2 +- .../query_engine/sub_question_query_engine.py | 2 +- llama_index/readers/base.py | 2 +- llama_index/readers/obsidian.py | 2 +- llama_index/readers/schema/base.py | 2 +- llama_index/selectors/pydantic_selectors.py | 2 +- .../token_counter/mock_chain_wrapper.py | 2 +- llama_index/tools/function_tool.py | 2 +- llama_index/tools/query_plan.py | 2 +- llama_index/tools/types.py | 2 +- setup.py | 1 + tests/indices/test_prompt_helper.py | 2 +- tests/llm_predictor/test_base.py | 4 +- tests/output_parsers/test_base.py | 6 +- tests/prompts/test_base.py | 10 +- tests/token_predictor/test_base.py | 2 +- 52 files changed, 191 insertions(+), 86 deletions(-) create mode 100644 llama_index/bridge/__init__.py create mode 100644 llama_index/bridge/langchain.py diff --git a/benchmarks/struct_indices/spider/evaluate.py b/benchmarks/struct_indices/spider/evaluate.py index 3a6630a662..ebed59cb17 100644 --- a/benchmarks/struct_indices/spider/evaluate.py +++ b/benchmarks/struct_indices/spider/evaluate.py @@ -6,8 +6,8 @@ import logging import os from typing import Dict, List, Optional -from langchain.chat_models import ChatOpenAI -from langchain.schema import HumanMessage +from llama_index.bridge.langchain import ChatOpenAI +from llama_index.bridge.langchain import HumanMessage from llama_index.response.schema import Response from spider_utils import create_indexes, load_examples from tqdm import tqdm diff --git a/benchmarks/struct_indices/spider/generate_sql.py b/benchmarks/struct_indices/spider/generate_sql.py index 29a68af366..cf42f2054c 100644 --- a/benchmarks/struct_indices/spider/generate_sql.py +++ b/benchmarks/struct_indices/spider/generate_sql.py @@ -5,9 +5,8 @@ import logging import os import re -from langchain.chat_models import ChatOpenAI -from langchain.llms import OpenAI -from langchain.base_language import BaseLanguageModel +from llama_index.bridge.langchain import ChatOpenAI, OpenAI +from llama_index.bridge.langchain import BaseLanguageModel from sqlalchemy import create_engine, text from tqdm import tqdm diff --git a/benchmarks/struct_indices/spider/spider_utils.py b/benchmarks/struct_indices/spider/spider_utils.py index aa37e10e35..e29fc02e68 100644 --- a/benchmarks/struct_indices/spider/spider_utils.py +++ b/benchmarks/struct_indices/spider/spider_utils.py @@ -4,8 +4,8 @@ import json import os from typing import Dict, Tuple, Union -from langchain import OpenAI -from langchain.chat_models import ChatOpenAI +from llama_index.bridge.langchain import OpenAI +from llama_index.bridge.langchain import ChatOpenAI from sqlalchemy import create_engine, text from llama_index import SQLStructStoreIndex, LLMPredictor, SQLDatabase diff --git a/experimental/cli/configuration.py b/experimental/cli/configuration.py index b9850d5931..a13c98e6cd 100644 --- a/experimental/cli/configuration.py +++ b/experimental/cli/configuration.py @@ -2,8 +2,8 @@ import os from configparser import ConfigParser, SectionProxy from typing import Any, Type from llama_index.embeddings.openai import OpenAIEmbedding -from langchain import OpenAI -from langchain.base_language import BaseLanguageModel +from llama_index.bridge.langchain import OpenAI +from llama_index.bridge.langchain import BaseLanguageModel from llama_index.indices.base import BaseIndex from llama_index.embeddings.base import BaseEmbedding from llama_index import ( diff --git a/llama_index/agent/openai_agent.py b/llama_index/agent/openai_agent.py index 9d847a123a..555354da36 100644 --- a/llama_index/agent/openai_agent.py +++ b/llama_index/agent/openai_agent.py @@ -2,9 +2,7 @@ import json from abc import abstractmethod from typing import Callable, List, Optional -from langchain.chat_models import ChatOpenAI -from langchain.memory import ChatMessageHistory -from langchain.schema import FunctionMessage +from llama_index.bridge.langchain import FunctionMessage, ChatMessageHistory, ChatOpenAI from llama_index.callbacks.base import CallbackManager from llama_index.chat_engine.types import BaseChatEngine diff --git a/llama_index/agent/retriever_openai_agent.py b/llama_index/agent/retriever_openai_agent.py index 8cd27f6efe..e03f0e6b80 100644 --- a/llama_index/agent/retriever_openai_agent.py +++ b/llama_index/agent/retriever_openai_agent.py @@ -4,9 +4,8 @@ from llama_index.agent.openai_agent import BaseOpenAIAgent from llama_index.objects.base import ObjectRetriever from llama_index.tools.types import BaseTool from typing import Optional, List -from langchain.chat_models import ChatOpenAI +from llama_index.bridge.langchain import ChatOpenAI, ChatMessageHistory from llama_index.callbacks.base import CallbackManager -from langchain.memory import ChatMessageHistory from llama_index.agent.openai_agent import ( SUPPORTED_MODEL_NAMES, DEFAULT_MAX_FUNCTION_CALLS, diff --git a/llama_index/bridge/__init__.py b/llama_index/bridge/__init__.py new file mode 100644 index 0000000000..e69de29bb2 diff --git a/llama_index/bridge/langchain.py b/llama_index/bridge/langchain.py new file mode 100644 index 0000000000..dd1ceed09e --- /dev/null +++ b/llama_index/bridge/langchain.py @@ -0,0 +1,106 @@ +import langchain + +# LLMs +from langchain.llms import BaseLLM, FakeListLLM, OpenAI, AI21, Cohere +from langchain.chat_models.base import BaseChatModel +from langchain.chat_models import ChatOpenAI +from langchain.base_language import BaseLanguageModel + +# embeddings +from langchain.embeddings.base import Embeddings + +# prompts +from langchain import PromptTemplate, BasePromptTemplate +from langchain.chains.prompt_selector import ConditionalPromptSelector, is_chat_model +from langchain.prompts.chat import ( + AIMessagePromptTemplate, + ChatPromptTemplate, + HumanMessagePromptTemplate, + BaseMessagePromptTemplate, +) + +# chain +from langchain import LLMChain + +# chat and memory +from langchain.memory.chat_memory import BaseChatMemory +from langchain.memory import ConversationBufferMemory, ChatMessageHistory + +# agents and tools +from langchain.agents.agent_toolkits.base import BaseToolkit +from langchain.agents import AgentType +from langchain.agents import AgentExecutor, initialize_agent +from langchain.tools import StructuredTool, Tool, BaseTool + +# input & output +from langchain.text_splitter import TextSplitter +from langchain.output_parsers import ResponseSchema +from langchain.output_parsers import PydanticOutputParser +from langchain.input import print_text, get_color_mapping + +# callback +from langchain.callbacks.base import BaseCallbackHandler, BaseCallbackManager + +# schema +from langchain.schema import AIMessage, FunctionMessage, BaseMessage, HumanMessage +from langchain.schema import BaseMemory +from langchain.schema import BaseOutputParser, LLMResult +from langchain.schema import ChatGeneration + +# misc +from langchain.sql_database import SQLDatabase +from langchain.cache import GPTCache, BaseCache +from langchain.docstore.document import Document + +__all__ = [ + "langchain", + "BaseLLM", + "FakeListLLM", + "OpenAI", + "AI21", + "Cohere", + "BaseChatModel", + "ChatOpenAI", + "BaseLanguageModel", + "Embeddings", + "PromptTemplate", + "BasePromptTemplate", + "ConditionalPromptSelector", + "is_chat_model", + "AIMessagePromptTemplate", + "ChatPromptTemplate", + "HumanMessagePromptTemplate", + "BaseMessagePromptTemplate", + "LLMChain", + "BaseChatMemory", + "ConversationBufferMemory", + "ChatMessageHistory", + "BaseToolkit", + "AgentType", + "AgentExecutor", + "initialize_agent", + "StructuredTool", + "Tool", + "BaseTool", + "TextSplitter", + "ResponseSchema", + "PydanticOutputParser", + "print_text", + "get_color_mapping", + "BaseCallbackHandler", + "BaseCallbackManager", + "AIMessage", + "FunctionMessage", + "BaseMessage", + "HumanMessage", + "BaseMemory", + "BaseOutputParser", + "HumanMessage", + "BaseMessage", + "LLMResult", + "ChatGeneration", + "SQLDatabase", + "GPTCache", + "BaseCache", + "Document", +] diff --git a/llama_index/chat_engine/react.py b/llama_index/chat_engine/react.py index e69965a8a1..2f3b839acd 100644 --- a/llama_index/chat_engine/react.py +++ b/llama_index/chat_engine/react.py @@ -1,7 +1,6 @@ from typing import Any, Optional, Sequence -from langchain.memory import ConversationBufferMemory -from langchain.memory.chat_memory import BaseChatMemory +from llama_index.bridge.langchain import ConversationBufferMemory, BaseChatMemory from llama_index.chat_engine.types import BaseChatEngine, ChatHistoryType from llama_index.chat_engine.utils import is_chat_model, to_langchain_chat_history diff --git a/llama_index/chat_engine/simple.py b/llama_index/chat_engine/simple.py index 531bac8c29..983bc4b3d3 100644 --- a/llama_index/chat_engine/simple.py +++ b/llama_index/chat_engine/simple.py @@ -1,7 +1,6 @@ from typing import Any, Optional -from langchain.chat_models.base import BaseChatModel -from langchain.schema import ChatGeneration +from llama_index.bridge.langchain import BaseChatModel, ChatGeneration from llama_index.chat_engine.types import BaseChatEngine, ChatHistoryType from llama_index.chat_engine.utils import ( diff --git a/llama_index/chat_engine/utils.py b/llama_index/chat_engine/utils.py index 25bb2899d8..ba137abdfc 100644 --- a/llama_index/chat_engine/utils.py +++ b/llama_index/chat_engine/utils.py @@ -1,7 +1,6 @@ from typing import Optional -from langchain.chat_models.base import BaseChatModel -from langchain.memory import ChatMessageHistory +from llama_index.bridge.langchain import BaseChatModel, ChatMessageHistory from llama_index.chat_engine.types import ChatHistoryType from llama_index.indices.service_context import ServiceContext diff --git a/llama_index/embeddings/langchain.py b/llama_index/embeddings/langchain.py index 0b89aa7dbc..e5b743fc51 100644 --- a/llama_index/embeddings/langchain.py +++ b/llama_index/embeddings/langchain.py @@ -3,7 +3,7 @@ from typing import Any, List -from langchain.embeddings.base import Embeddings as LCEmbeddings +from llama_index.bridge.langchain import Embeddings as LCEmbeddings from llama_index.embeddings.base import BaseEmbedding diff --git a/llama_index/evaluation/dataset_generation.py b/llama_index/evaluation/dataset_generation.py index a837cc7a81..aa71d2f316 100644 --- a/llama_index/evaluation/dataset_generation.py +++ b/llama_index/evaluation/dataset_generation.py @@ -4,7 +4,7 @@ from __future__ import annotations import re from typing import List, Optional -from langchain.chat_models import ChatOpenAI +from llama_index.bridge.langchain import ChatOpenAI from llama_index import ( Document, diff --git a/llama_index/evaluation/guideline_eval.py b/llama_index/evaluation/guideline_eval.py index 26a6de704f..5eab879df3 100644 --- a/llama_index/evaluation/guideline_eval.py +++ b/llama_index/evaluation/guideline_eval.py @@ -1,7 +1,7 @@ import logging from typing import Optional -from langchain.output_parsers import PydanticOutputParser +from llama_index.bridge.langchain import PydanticOutputParser from pydantic import BaseModel, Field from llama_index.evaluation.base import BaseEvaluator, Evaluation diff --git a/llama_index/indices/query/query_transform/base.py b/llama_index/indices/query/query_transform/base.py index 8551c9b57a..9fb4a68836 100644 --- a/llama_index/indices/query/query_transform/base.py +++ b/llama_index/indices/query/query_transform/base.py @@ -4,7 +4,7 @@ import dataclasses from abc import abstractmethod from typing import Dict, Optional, cast -from langchain.input import print_text +from llama_index.bridge.langchain import print_text from llama_index.indices.query.query_transform.prompts import ( DEFAULT_DECOMPOSE_QUERY_TRANSFORM_PROMPT, diff --git a/llama_index/indices/service_context.py b/llama_index/indices/service_context.py index 6a591e725c..ddbb282903 100644 --- a/llama_index/indices/service_context.py +++ b/llama_index/indices/service_context.py @@ -3,7 +3,7 @@ import logging from dataclasses import dataclass from typing import Optional -from langchain.base_language import BaseLanguageModel +from llama_index.bridge.langchain import BaseLanguageModel import llama_index from llama_index.callbacks.base import CallbackManager diff --git a/llama_index/indices/struct_store/json_query.py b/llama_index/indices/struct_store/json_query.py index cf2ebc533e..af9d588264 100644 --- a/llama_index/indices/struct_store/json_query.py +++ b/llama_index/indices/struct_store/json_query.py @@ -2,7 +2,7 @@ import json import logging from typing import Any, Callable, Dict, List, Optional, Union -from langchain.input import print_text +from llama_index.bridge.langchain import print_text from llama_index.indices.query.base import BaseQueryEngine from llama_index.indices.query.schema import QueryBundle diff --git a/llama_index/indices/tree/select_leaf_retriever.py b/llama_index/indices/tree/select_leaf_retriever.py index 2230e9e134..84af7e548f 100644 --- a/llama_index/indices/tree/select_leaf_retriever.py +++ b/llama_index/indices/tree/select_leaf_retriever.py @@ -3,7 +3,7 @@ import logging from typing import Any, Dict, List, Optional, cast -from langchain.input import print_text +from llama_index.bridge.langchain import print_text from llama_index.data_structs.node import Node, NodeWithScore from llama_index.indices.base_retriever import BaseRetriever diff --git a/llama_index/langchain_helpers/agents/agents.py b/llama_index/langchain_helpers/agents/agents.py index 2bc4b3b5c5..dc077c46b0 100644 --- a/llama_index/langchain_helpers/agents/agents.py +++ b/llama_index/langchain_helpers/agents/agents.py @@ -2,10 +2,13 @@ from typing import Any, Optional -from langchain.agents import AgentExecutor, initialize_agent -from langchain.callbacks.base import BaseCallbackManager -from langchain.llms.base import BaseLLM -from langchain.agents.agent_types import AgentType +from llama_index.bridge.langchain import ( + BaseLLM, + AgentType, + AgentExecutor, + initialize_agent, + BaseCallbackManager, +) from llama_index.langchain_helpers.agents.toolkits import LlamaToolkit diff --git a/llama_index/langchain_helpers/agents/toolkits.py b/llama_index/langchain_helpers/agents/toolkits.py index d821f9f53a..968556b18f 100644 --- a/llama_index/langchain_helpers/agents/toolkits.py +++ b/llama_index/langchain_helpers/agents/toolkits.py @@ -2,8 +2,7 @@ from typing import List -from langchain.agents.agent_toolkits.base import BaseToolkit -from langchain.tools import BaseTool +from llama_index.bridge.langchain import BaseTool, BaseToolkit from pydantic import Field from llama_index.langchain_helpers.agents.tools import ( diff --git a/llama_index/langchain_helpers/agents/tools.py b/llama_index/langchain_helpers/agents/tools.py index c4c4f7f57a..4e49407864 100644 --- a/llama_index/langchain_helpers/agents/tools.py +++ b/llama_index/langchain_helpers/agents/tools.py @@ -2,7 +2,7 @@ from typing import Dict -from langchain.tools import BaseTool +from llama_index.bridge.langchain import BaseTool from pydantic import BaseModel, Field from llama_index.indices.query.base import BaseQueryEngine diff --git a/llama_index/langchain_helpers/memory_wrapper.py b/llama_index/langchain_helpers/memory_wrapper.py index 9fcfde0209..04a401caad 100644 --- a/llama_index/langchain_helpers/memory_wrapper.py +++ b/llama_index/langchain_helpers/memory_wrapper.py @@ -2,10 +2,13 @@ from typing import Any, Dict, List, Optional -from langchain.memory.chat_memory import BaseChatMemory -from langchain.schema import AIMessage -from langchain.schema import BaseMemory as Memory -from langchain.schema import BaseMessage, HumanMessage +from llama_index.bridge.langchain import ( + BaseChatMemory, + AIMessage, + BaseMemory as Memory, + BaseMessage, + HumanMessage, +) from pydantic import Field from llama_index.indices.base import BaseIndex diff --git a/llama_index/langchain_helpers/sql_wrapper.py b/llama_index/langchain_helpers/sql_wrapper.py index 179753aceb..d9d5d1e6a9 100644 --- a/llama_index/langchain_helpers/sql_wrapper.py +++ b/llama_index/langchain_helpers/sql_wrapper.py @@ -1,7 +1,7 @@ """SQL wrapper around SQLDatabase in langchain.""" from typing import Any, Dict, List, Tuple, Optional -from langchain.sql_database import SQLDatabase as LangchainSQLDatabase +from llama_index.bridge.langchain import SQLDatabase as LangchainSQLDatabase from sqlalchemy import MetaData, create_engine, insert, text from sqlalchemy.engine import Engine diff --git a/llama_index/langchain_helpers/streaming.py b/llama_index/langchain_helpers/streaming.py index 3e59a2fbd8..ba5dda56cd 100644 --- a/llama_index/langchain_helpers/streaming.py +++ b/llama_index/langchain_helpers/streaming.py @@ -2,8 +2,7 @@ from queue import Queue from threading import Event from typing import Any, Generator, Union -from langchain.callbacks.base import BaseCallbackHandler -from langchain.schema import LLMResult +from llama_index.bridge.langchain import BaseCallbackHandler, LLMResult class StreamingGeneratorCallbackHandler(BaseCallbackHandler): diff --git a/llama_index/langchain_helpers/text_splitter.py b/llama_index/langchain_helpers/text_splitter.py index aa4131ab4d..f524de4dfc 100644 --- a/llama_index/langchain_helpers/text_splitter.py +++ b/llama_index/langchain_helpers/text_splitter.py @@ -2,8 +2,8 @@ from dataclasses import dataclass from typing import Callable, List, Optional -from langchain.text_splitter import TextSplitter from llama_index.constants import DEFAULT_CHUNK_OVERLAP, DEFAULT_CHUNK_SIZE +from llama_index.bridge.langchain import TextSplitter from llama_index.callbacks.base import CallbackManager from llama_index.callbacks.schema import CBEventType, EventPayload diff --git a/llama_index/llm_predictor/base.py b/llama_index/llm_predictor/base.py index c6b0107ab9..0a87e6d097 100644 --- a/llama_index/llm_predictor/base.py +++ b/llama_index/llm_predictor/base.py @@ -6,12 +6,10 @@ from dataclasses import dataclass from threading import Thread from typing import Any, Generator, Optional, Protocol, Tuple, runtime_checkable -import langchain import openai -from langchain import BaseCache, Cohere, LLMChain, OpenAI -from langchain.base_language import BaseLanguageModel -from langchain.chat_models import ChatOpenAI -from langchain.llms import AI21 +from llama_index.bridge.langchain import langchain +from llama_index.bridge.langchain import BaseCache, Cohere, LLMChain, OpenAI +from llama_index.bridge.langchain import ChatOpenAI, AI21, BaseLanguageModel from llama_index.callbacks.base import CallbackManager from llama_index.callbacks.schema import CBEventType, EventPayload diff --git a/llama_index/llm_predictor/chatgpt.py b/llama_index/llm_predictor/chatgpt.py index 8d91f6d74f..2d3f934798 100644 --- a/llama_index/llm_predictor/chatgpt.py +++ b/llama_index/llm_predictor/chatgpt.py @@ -4,16 +4,17 @@ import logging from typing import Any, List, Optional, Union import openai -from langchain import LLMChain -from langchain.chat_models import ChatOpenAI -from langchain.prompts.base import BasePromptTemplate -from langchain.prompts.chat import ( +from llama_index.bridge.langchain import ( + LLMChain, + ChatOpenAI, BaseMessagePromptTemplate, ChatPromptTemplate, HumanMessagePromptTemplate, + BaseLanguageModel, + BaseMessage, + PromptTemplate, + BasePromptTemplate, ) -from langchain.prompts.prompt import PromptTemplate -from langchain.base_language import BaseLanguageModel, BaseMessage from llama_index.llm_predictor.base import LLMPredictor from llama_index.prompts.base import Prompt diff --git a/llama_index/output_parsers/guardrails.py b/llama_index/output_parsers/guardrails.py index 94d0f9d291..131a43c632 100644 --- a/llama_index/output_parsers/guardrails.py +++ b/llama_index/output_parsers/guardrails.py @@ -12,7 +12,7 @@ except ImportError: from copy import deepcopy from typing import Any, Callable, Optional -from langchain.llms.base import BaseLLM +from llama_index.bridge.langchain import BaseLLM from llama_index.output_parsers.base import BaseOutputParser diff --git a/llama_index/output_parsers/langchain.py b/llama_index/output_parsers/langchain.py index c63414bde9..e10762633e 100644 --- a/llama_index/output_parsers/langchain.py +++ b/llama_index/output_parsers/langchain.py @@ -3,7 +3,7 @@ from string import Formatter from typing import Any, Optional -from langchain.schema import BaseOutputParser as LCOutputParser +from llama_index.bridge.langchain import BaseOutputParser as LCOutputParser from llama_index.output_parsers.base import BaseOutputParser diff --git a/llama_index/playground/base.py b/llama_index/playground/base.py index 8ead94d5c4..acd317230d 100644 --- a/llama_index/playground/base.py +++ b/llama_index/playground/base.py @@ -5,7 +5,7 @@ import time from typing import Any, Dict, List, Optional, Type, Union import pandas as pd -from langchain.input import get_color_mapping, print_text +from llama_index.bridge.langchain import get_color_mapping, print_text from llama_index.indices.base import BaseIndex from llama_index.indices.list.base import ListIndex, ListRetrieverMode diff --git a/llama_index/program/openai_program.py b/llama_index/program/openai_program.py index da804e419e..6ca5a02f80 100644 --- a/llama_index/program/openai_program.py +++ b/llama_index/program/openai_program.py @@ -1,7 +1,6 @@ from typing import Any, Dict, Generic, Optional, Type, Union -from langchain.chat_models import ChatOpenAI -from langchain.schema import HumanMessage +from llama_index.bridge.langchain import ChatOpenAI, HumanMessage from llama_index.program.base_program import BasePydanticProgram, Model from llama_index.prompts.base import Prompt diff --git a/llama_index/prompts/base.py b/llama_index/prompts/base.py index 8dd649e4bc..77383a730d 100644 --- a/llama_index/prompts/base.py +++ b/llama_index/prompts/base.py @@ -2,10 +2,9 @@ from copy import deepcopy from typing import Any, Dict, Optional -from langchain import BasePromptTemplate as BaseLangchainPrompt -from langchain import PromptTemplate as LangchainPrompt -from langchain.base_language import BaseLanguageModel -from langchain.chains.prompt_selector import ConditionalPromptSelector +from llama_index.bridge.langchain import BasePromptTemplate as BaseLangchainPrompt +from llama_index.bridge.langchain import PromptTemplate as LangchainPrompt +from llama_index.bridge.langchain import BaseLanguageModel, ConditionalPromptSelector from llama_index.output_parsers.base import BaseOutputParser from llama_index.prompts.prompt_type import PromptType diff --git a/llama_index/prompts/chat_prompts.py b/llama_index/prompts/chat_prompts.py index 8986850679..2d02c83fea 100644 --- a/llama_index/prompts/chat_prompts.py +++ b/llama_index/prompts/chat_prompts.py @@ -1,6 +1,6 @@ """Prompts for ChatGPT.""" -from langchain.prompts.chat import ( +from llama_index.bridge.langchain import ( AIMessagePromptTemplate, ChatPromptTemplate, HumanMessagePromptTemplate, diff --git a/llama_index/prompts/default_prompt_selectors.py b/llama_index/prompts/default_prompt_selectors.py index 9f8d3d1a90..8aef6c153c 100644 --- a/llama_index/prompts/default_prompt_selectors.py +++ b/llama_index/prompts/default_prompt_selectors.py @@ -1,5 +1,5 @@ """Prompt selectors.""" -from langchain.chains.prompt_selector import ConditionalPromptSelector, is_chat_model +from llama_index.bridge.langchain import ConditionalPromptSelector, is_chat_model from llama_index.prompts.chat_prompts import ( CHAT_REFINE_PROMPT, diff --git a/llama_index/query_engine/flare/base.py b/llama_index/query_engine/flare/base.py index 27f8054f41..6b3858d142 100644 --- a/llama_index/query_engine/flare/base.py +++ b/llama_index/query_engine/flare/base.py @@ -4,7 +4,7 @@ Active Retrieval Augmented Generation. """ -from langchain.input import print_text +from llama_index.bridge.langchain import print_text from typing import Optional from llama_index.indices.query.base import BaseQueryEngine from llama_index.indices.service_context import ServiceContext diff --git a/llama_index/query_engine/pandas_query_engine.py b/llama_index/query_engine/pandas_query_engine.py index cc6466d16b..580f1b7538 100644 --- a/llama_index/query_engine/pandas_query_engine.py +++ b/llama_index/query_engine/pandas_query_engine.py @@ -4,7 +4,7 @@ import logging from typing import Any, Callable, Optional import pandas as pd -from langchain.input import print_text +from llama_index.bridge.langchain import print_text from llama_index.indices.query.base import BaseQueryEngine from llama_index.indices.query.schema import QueryBundle diff --git a/llama_index/query_engine/sql_join_query_engine.py b/llama_index/query_engine/sql_join_query_engine.py index 57e513a766..da823c6917 100644 --- a/llama_index/query_engine/sql_join_query_engine.py +++ b/llama_index/query_engine/sql_join_query_engine.py @@ -1,6 +1,6 @@ """SQL Join query engine.""" -from langchain.input import print_text +from llama_index.bridge.langchain import print_text from typing import Optional, cast, Dict, Callable from llama_index.indices.query.base import BaseQueryEngine from llama_index.indices.struct_store.sql_query import NLStructStoreQueryEngine diff --git a/llama_index/query_engine/sub_question_query_engine.py b/llama_index/query_engine/sub_question_query_engine.py index 9e94c2c583..37adc3a21b 100644 --- a/llama_index/query_engine/sub_question_query_engine.py +++ b/llama_index/query_engine/sub_question_query_engine.py @@ -2,7 +2,7 @@ import asyncio import logging from typing import List, Optional, Sequence, cast -from langchain.input import get_color_mapping, print_text +from llama_index.bridge.langchain import get_color_mapping, print_text from llama_index.async_utils import run_async_tasks from llama_index.callbacks.base import CallbackManager diff --git a/llama_index/readers/base.py b/llama_index/readers/base.py index a7056e69b3..6822822285 100644 --- a/llama_index/readers/base.py +++ b/llama_index/readers/base.py @@ -2,7 +2,7 @@ from abc import abstractmethod from typing import Any, List -from langchain.docstore.document import Document as LCDocument +from llama_index.bridge.langchain import Document as LCDocument from llama_index.readers.schema.base import Document diff --git a/llama_index/readers/obsidian.py b/llama_index/readers/obsidian.py index 3f4cc3c4f9..7503335f35 100644 --- a/llama_index/readers/obsidian.py +++ b/llama_index/readers/obsidian.py @@ -9,7 +9,7 @@ import os from pathlib import Path from typing import Any, List -from langchain.docstore.document import Document as LCDocument +from llama_index.bridge.langchain import Document as LCDocument from llama_index.readers.base import BaseReader from llama_index.readers.file.markdown_reader import MarkdownReader diff --git a/llama_index/readers/schema/base.py b/llama_index/readers/schema/base.py index 6444570181..0c8d12c0a4 100644 --- a/llama_index/readers/schema/base.py +++ b/llama_index/readers/schema/base.py @@ -2,7 +2,7 @@ from dataclasses import dataclass from typing import Optional -from langchain.docstore.document import Document as LCDocument +from llama_index.bridge.langchain import Document as LCDocument from llama_index.schema import BaseDocument diff --git a/llama_index/selectors/pydantic_selectors.py b/llama_index/selectors/pydantic_selectors.py index 34de757b20..6172f9db8f 100644 --- a/llama_index/selectors/pydantic_selectors.py +++ b/llama_index/selectors/pydantic_selectors.py @@ -1,4 +1,4 @@ -from langchain.chat_models import ChatOpenAI +from llama_index.bridge.langchain import ChatOpenAI from typing import Any, Optional, Sequence from llama_index.indices.query.schema import QueryBundle diff --git a/llama_index/token_counter/mock_chain_wrapper.py b/llama_index/token_counter/mock_chain_wrapper.py index 82b6ed958d..b7642702a1 100644 --- a/llama_index/token_counter/mock_chain_wrapper.py +++ b/llama_index/token_counter/mock_chain_wrapper.py @@ -2,7 +2,7 @@ from typing import Any, Dict, Optional -from langchain.llms.base import BaseLLM +from llama_index.bridge.langchain import BaseLLM from llama_index.constants import DEFAULT_NUM_OUTPUTS from llama_index.langchain_helpers.chain_wrapper import LLMPredictor diff --git a/llama_index/tools/function_tool.py b/llama_index/tools/function_tool.py index 0e51f1d802..698f53a781 100644 --- a/llama_index/tools/function_tool.py +++ b/llama_index/tools/function_tool.py @@ -2,7 +2,7 @@ from typing import Any, Optional, Callable, Type from pydantic import BaseModel from llama_index.tools.types import BaseTool, ToolMetadata -from langchain.tools import Tool, StructuredTool +from llama_index.bridge.langchain import Tool, StructuredTool from inspect import signature from llama_index.tools.utils import create_schema_from_function diff --git a/llama_index/tools/query_plan.py b/llama_index/tools/query_plan.py index ddc9c60bd5..92d1420507 100644 --- a/llama_index/tools/query_plan.py +++ b/llama_index/tools/query_plan.py @@ -8,7 +8,7 @@ from llama_index.data_structs.node import NodeWithScore, Node from typing import Dict, List, Any, Optional from pydantic import BaseModel, Field from llama_index.indices.query.schema import QueryBundle -from langchain.input import print_text +from llama_index.bridge.langchain import print_text DEFAULT_NAME = "query_plan_tool" diff --git a/llama_index/tools/types.py b/llama_index/tools/types.py index 0064f1cd01..db20598aed 100644 --- a/llama_index/tools/types.py +++ b/llama_index/tools/types.py @@ -2,7 +2,7 @@ from abc import abstractmethod from dataclasses import dataclass from typing import Any, Dict, Optional, Type -from langchain.tools import StructuredTool, Tool +from llama_index.bridge.langchain import StructuredTool, Tool from pydantic import BaseModel diff --git a/setup.py b/setup.py index 9b9bd1588c..4cc64f5516 100644 --- a/setup.py +++ b/setup.py @@ -25,6 +25,7 @@ install_requires = [ "fsspec>=2023.5.0", "typing-inspect==0.8.0", "typing_extensions==4.5.0", + "bs4", # hotfix for langchain 0.0.212 bug ] # NOTE: if python version >= 3.9, install tiktoken diff --git a/tests/indices/test_prompt_helper.py b/tests/indices/test_prompt_helper.py index 92da0662af..4e6c08bff2 100644 --- a/tests/indices/test_prompt_helper.py +++ b/tests/indices/test_prompt_helper.py @@ -1,7 +1,7 @@ """Test PromptHelper.""" from typing import cast -from langchain import PromptTemplate as LangchainPrompt +from llama_index.bridge.langchain import PromptTemplate as LangchainPrompt from llama_index.data_structs.node import Node from llama_index.indices.prompt_helper import PromptHelper diff --git a/tests/llm_predictor/test_base.py b/tests/llm_predictor/test_base.py index 0f5166364f..fa0935d76e 100644 --- a/tests/llm_predictor/test_base.py +++ b/tests/llm_predictor/test_base.py @@ -4,7 +4,7 @@ from typing import Any, Tuple from unittest.mock import patch import pytest -from langchain.llms.fake import FakeListLLM +from llama_index.bridge.langchain import FakeListLLM from llama_index.llm_predictor.structured import LLMPredictor, StructuredLLMPredictor from llama_index.output_parsers.base import BaseOutputParser @@ -61,7 +61,7 @@ def test_struct_llm_predictor_with_cache() -> None: """Test LLM predictor.""" from gptcache.processor.pre import get_prompt from gptcache.manager.factory import get_data_manager - from langchain.cache import GPTCache + from llama_index.bridge.langchain import GPTCache def init_gptcache_map(cache_obj: Cache) -> None: cache_path = "test" diff --git a/tests/output_parsers/test_base.py b/tests/output_parsers/test_base.py index b07f5aecc2..18e53dd741 100644 --- a/tests/output_parsers/test_base.py +++ b/tests/output_parsers/test_base.py @@ -1,8 +1,10 @@ """Test Output parsers.""" -from langchain.output_parsers import ResponseSchema -from langchain.schema import BaseOutputParser as LCOutputParser +from llama_index.bridge.langchain import ( + ResponseSchema, + BaseOutputParser as LCOutputParser, +) from llama_index.output_parsers.langchain import LangchainOutputParser diff --git a/tests/prompts/test_base.py b/tests/prompts/test_base.py index 1adcbef31a..ec722e6287 100644 --- a/tests/prompts/test_base.py +++ b/tests/prompts/test_base.py @@ -3,10 +3,12 @@ from unittest.mock import MagicMock import pytest -from langchain import PromptTemplate -from langchain.chains.prompt_selector import ConditionalPromptSelector -from langchain.chat_models.base import BaseChatModel -from langchain.chat_models.openai import ChatOpenAI +from llama_index.bridge.langchain import ( + PromptTemplate, + ConditionalPromptSelector, + BaseChatModel, + ChatOpenAI, +) from llama_index.prompts.base import Prompt diff --git a/tests/token_predictor/test_base.py b/tests/token_predictor/test_base.py index 252a8c5f88..d3893975c3 100644 --- a/tests/token_predictor/test_base.py +++ b/tests/token_predictor/test_base.py @@ -3,7 +3,7 @@ from typing import Any from unittest.mock import MagicMock, patch -from langchain.llms.base import BaseLLM +from llama_index.bridge.langchain import BaseLLM from llama_index.indices.keyword_table.base import KeywordTableIndex from llama_index.indices.list.base import ListIndex -- GitLab