diff --git a/coverage.xml b/coverage.xml index 3b0a9de9589ec9d0a2cfd42d99e90a5f81bc993a..9af9ebee27365dd1289c5962a87b8451a3feef7c 100644 --- a/coverage.xml +++ b/coverage.xml @@ -1,5 +1,5 @@ <?xml version="1.0" ?> -<coverage version="7.3.3" timestamp="1702893702032" lines-valid="345" lines-covered="345" line-rate="1" branches-covered="0" branches-valid="0" branch-rate="0" complexity="0"> +<coverage version="7.3.3" timestamp="1702894511196" lines-valid="345" lines-covered="345" line-rate="1" branches-covered="0" branches-valid="0" branch-rate="0" complexity="0"> <!-- Generated by coverage.py: https://coverage.readthedocs.io/en/7.3.3 --> <!-- Based on https://raw.githubusercontent.com/cobertura/web/master/htdocs/xml/coverage-04.dtd --> <sources> @@ -349,12 +349,12 @@ <line number="2" hits="1"/> <line number="4" hits="1"/> <line number="5" hits="1"/> - <line number="7" hits="1"/> + <line number="6" hits="1"/> <line number="8" hits="1"/> - <line number="11" hits="1"/> + <line number="9" hits="1"/> <line number="12" hits="1"/> - <line number="14" hits="1"/> - <line number="19" hits="1"/> + <line number="13" hits="1"/> + <line number="15" hits="1"/> <line number="20" hits="1"/> <line number="21" hits="1"/> <line number="22" hits="1"/> @@ -362,12 +362,12 @@ <line number="24" hits="1"/> <line number="25" hits="1"/> <line number="26" hits="1"/> - <line number="28" hits="1"/> + <line number="27" hits="1"/> <line number="29" hits="1"/> <line number="30" hits="1"/> <line number="31" hits="1"/> <line number="32" hits="1"/> - <line number="35" hits="1"/> + <line number="33" hits="1"/> <line number="36" hits="1"/> <line number="37" hits="1"/> <line number="38" hits="1"/> @@ -380,11 +380,11 @@ <line number="45" hits="1"/> <line number="46" hits="1"/> <line number="47" hits="1"/> - <line number="49" hits="1"/> + <line number="48" hits="1"/> <line number="50" hits="1"/> - <line number="51" hits="1"/> - <line number="53" hits="1"/> - <line number="54" hits="1"/> + <line number="55" hits="1"/> + <line number="57" hits="1"/> + <line number="58" hits="1"/> </lines> </class> </classes> diff --git a/docs/examples/function_calling.ipynb b/docs/examples/function_calling.ipynb index a6546eec9ae9450d3e9dc1148e3d792e25c05363..5d3be2fb9072d3e25d50a6c7456500f7d72083bc 100644 --- a/docs/examples/function_calling.ipynb +++ b/docs/examples/function_calling.ipynb @@ -9,18 +9,9 @@ }, { "cell_type": "code", - "execution_count": 1, + "execution_count": 12, "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "/Users/jakit/customers/aurelio/semantic-router/.venv/lib/python3.11/site-packages/tqdm/auto.py:21: TqdmWarning: IProgress not found. Please update jupyter and ipywidgets. 
See https://ipywidgets.readthedocs.io/en/stable/user_install.html\n", - " from .autonotebook import tqdm as notebook_tqdm\n" - ] - } - ], + "outputs": [], "source": [ "# OpenAI\n", "import openai\n", @@ -48,7 +39,7 @@ }, { "cell_type": "code", - "execution_count": 2, + "execution_count": 13, "metadata": {}, "outputs": [], "source": [ @@ -100,7 +91,7 @@ }, { "cell_type": "code", - "execution_count": 3, + "execution_count": 14, "metadata": {}, "outputs": [], "source": [ @@ -122,7 +113,7 @@ }, { "cell_type": "code", - "execution_count": 4, + "execution_count": 15, "metadata": {}, "outputs": [], "source": [ @@ -139,7 +130,7 @@ }, { "cell_type": "code", - "execution_count": 5, + "execution_count": 16, "metadata": {}, "outputs": [], "source": [ @@ -221,7 +212,7 @@ }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 17, "metadata": {}, "outputs": [], "source": [ @@ -237,7 +228,7 @@ }, { "cell_type": "code", - "execution_count": 7, + "execution_count": 18, "metadata": {}, "outputs": [], "source": [ @@ -317,19 +308,19 @@ }, { "cell_type": "code", - "execution_count": 8, + "execution_count": 19, "metadata": {}, "outputs": [], "source": [ "from semantic_router.schema import Route\n", - "from semantic_router.encoders import CohereEncoder, OpenAIEncoder\n", + "from semantic_router.encoders import CohereEncoder\n", "from semantic_router.layer import RouteLayer\n", "from semantic_router.utils.logger import logger\n", "\n", "\n", "def create_router(routes: list[dict]) -> RouteLayer:\n", " logger.info(\"Creating route layer...\")\n", - " encoder = OpenAIEncoder()\n", + " encoder = CohereEncoder()\n", "\n", " route_list: list[Route] = []\n", " for route in routes:\n", @@ -351,7 +342,7 @@ }, { "cell_type": "code", - "execution_count": 9, + "execution_count": 24, "metadata": {}, "outputs": [], "source": [ @@ -366,8 +357,14 @@ " logger.error(f\"Error calling function: {e}\")\n", "\n", "\n", - "def call_llm(query: str):\n", - " return llm_mistral(query)\n", + "def call_llm(query: str) -> str:\n", + " try:\n", + " ai_message = llm_mistral(query)\n", + " except Exception as e:\n", + " logger.error(f\"Mistral failed with error {e}, falling back to OpenAI\")\n", + " ai_message = llm_openai(query)\n", + "\n", + " return ai_message\n", "\n", "\n", "def call(query: str, functions: list[Callable], router: RouteLayer):\n", @@ -392,16 +389,93 @@ }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 21, + "metadata": {}, + "outputs": [ + { + "name": "stderr", + "output_type": "stream", + "text": [ + "\u001b[32m2023-12-18 12:17:58 INFO semantic_router.utils.logger Generating config...\u001b[0m\n", + "\u001b[32m2023-12-18 12:17:58 INFO semantic_router.utils.logger Calling Mistral model\u001b[0m\n", + "\u001b[31m2023-12-18 12:18:00 ERROR semantic_router.utils.logger Fall back to OpenAI failed with error ('Failed to call HuggingFace API', '{\"error\":\"Bad Gateway\"}')\u001b[0m\n", + "\u001b[32m2023-12-18 12:18:00 INFO semantic_router.utils.logger Calling gpt-4 model\u001b[0m\n", + "\u001b[32m2023-12-18 12:18:05 INFO semantic_router.utils.logger AI message: {\n", + " \"name\": \"get_time\",\n", + " \"utterances\": [\n", + " \"what is the time in new york\",\n", + " \"can you tell me the time in london\",\n", + " \"get me the current time in tokyo\",\n", + " \"i need to know the time in sydney\",\n", + " \"please tell me the current time in paris\"\n", + " ]\n", + "}\u001b[0m\n", + "\u001b[32m2023-12-18 12:18:05 INFO semantic_router.utils.logger Generated config: {'name': 'get_time', 
'utterances': ['what is the time in new york', 'can you tell me the time in london', 'get me the current time in tokyo', 'i need to know the time in sydney', 'please tell me the current time in paris']}\u001b[0m\n", + "\u001b[32m2023-12-18 12:18:05 INFO semantic_router.utils.logger Generating config...\u001b[0m\n", + "\u001b[32m2023-12-18 12:18:05 INFO semantic_router.utils.logger Calling Mistral model\u001b[0m\n", + "\u001b[31m2023-12-18 12:18:07 ERROR semantic_router.utils.logger Fall back to OpenAI failed with error ('Failed to call HuggingFace API', '{\"error\":\"Bad Gateway\"}')\u001b[0m\n", + "\u001b[32m2023-12-18 12:18:07 INFO semantic_router.utils.logger Calling gpt-4 model\u001b[0m\n", + "\u001b[32m2023-12-18 12:18:12 INFO semantic_router.utils.logger AI message: {\n", + " \"name\": \"get_news\",\n", + " \"utterances\": [\n", + " \"Can I get the latest news in Canada?\",\n", + " \"Show me the recent news in the US\",\n", + " \"I would like to know about the sports news in England\",\n", + " \"Let's check the technology news in Japan\",\n", + " \"Show me the health related news in Germany\"\n", + " ]\n", + "}\u001b[0m\n", + "\u001b[32m2023-12-18 12:18:12 INFO semantic_router.utils.logger Generated config: {'name': 'get_news', 'utterances': ['Can I get the latest news in Canada?', 'Show me the recent news in the US', 'I would like to know about the sports news in England', \"Let's check the technology news in Japan\", 'Show me the health related news in Germany']}\u001b[0m\n", + "\u001b[32m2023-12-18 12:18:12 INFO semantic_router.utils.logger Creating route layer...\u001b[0m\n" + ] + }, + { + "name": "stdout", + "output_type": "stream", + "text": [ + "Route: {'name': 'get_time', 'utterances': ['what is the time in new york', 'can you tell me the time in london', 'get me the current time in tokyo', 'i need to know the time in sydney', 'please tell me the current time in paris']}\n", + "Route: {'name': 'get_news', 'utterances': ['Can I get the latest news in Canada?', 'Show me the recent news in the US', 'I would like to know about the sports news in England', \"Let's check the technology news in Japan\", 'Show me the health related news in Germany']}\n" + ] + } + ], + "source": [ + "def get_time(location: str) -> str:\n", + " \"\"\"Useful to get the time in a specific location\"\"\"\n", + " print(f\"Calling `get_time` function with location: {location}\")\n", + " return \"get_time\"\n", + "\n", + "\n", + "def get_news(category: str, country: str) -> str:\n", + " \"\"\"Useful to get the news in a specific country\"\"\"\n", + " print(\n", + " f\"Calling `get_news` function with category: {category} and country: {country}\"\n", + " )\n", + " return \"get_news\"\n", + "\n", + "\n", + "# Registering functions to the router\n", + "route_get_time = generate_route(get_time)\n", + "route_get_news = generate_route(get_news)\n", + "\n", + "routes = [route_get_time, route_get_news]\n", + "router = create_router(routes)\n", + "\n", + "# Tools\n", + "tools = [get_time, get_news]" + ] + }, + { + "cell_type": "code", + "execution_count": 27, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "\u001b[32m2023-12-18 11:46:34 INFO semantic_router.utils.logger Generating config...\u001b[0m\n", - "\u001b[32m2023-12-18 11:46:34 INFO semantic_router.utils.logger Calling Mistral model\u001b[0m\n", - "\u001b[32m2023-12-18 11:46:38 INFO semantic_router.utils.logger AI message: \n", + "\u001b[32m2023-12-18 12:20:12 INFO semantic_router.utils.logger Generating 
config...\u001b[0m\n", + "\u001b[32m2023-12-18 12:20:12 INFO semantic_router.utils.logger Calling Mistral model\u001b[0m\n", + "\u001b[32m2023-12-18 12:20:16 INFO semantic_router.utils.logger AI message: \n", " Example output:\n", " {\n", " \"name\": \"get_time\",\n", @@ -413,10 +487,10 @@ " \"Can you tell me the time in Berlin?\"\n", " ]\n", " }\u001b[0m\n", - "\u001b[32m2023-12-18 11:46:38 INFO semantic_router.utils.logger Generated config: {'name': 'get_time', 'utterances': [\"What's the time in New York?\", 'Tell me the time in Tokyo.', 'Can you give me the time in London?', \"What's the current time in Sydney?\", 'Can you tell me the time in Berlin?']}\u001b[0m\n", - "\u001b[32m2023-12-18 11:46:38 INFO semantic_router.utils.logger Generating config...\u001b[0m\n", - "\u001b[32m2023-12-18 11:46:38 INFO semantic_router.utils.logger Calling Mistral model\u001b[0m\n", - "\u001b[32m2023-12-18 11:46:42 INFO semantic_router.utils.logger AI message: \n", + "\u001b[32m2023-12-18 12:20:16 INFO semantic_router.utils.logger Generated config: {'name': 'get_time', 'utterances': [\"What's the time in New York?\", 'Tell me the time in Tokyo.', 'Can you give me the time in London?', \"What's the current time in Sydney?\", 'Can you tell me the time in Berlin?']}\u001b[0m\n", + "\u001b[32m2023-12-18 12:20:16 INFO semantic_router.utils.logger Generating config...\u001b[0m\n", + "\u001b[32m2023-12-18 12:20:16 INFO semantic_router.utils.logger Calling Mistral model\u001b[0m\n", + "\u001b[32m2023-12-18 12:20:20 INFO semantic_router.utils.logger AI message: \n", " Example output:\n", " {\n", " \"name\": \"get_news\",\n", @@ -428,9 +502,8 @@ " \"What's the latest news from Germany?\"\n", " ]\n", " }\u001b[0m\n", - "\u001b[32m2023-12-18 11:46:42 INFO semantic_router.utils.logger Generated config: {'name': 'get_news', 'utterances': ['Tell me the latest news from the US', \"What's happening in India today?\", 'Get me the top stories from Japan', 'Can you give me the breaking news from Brazil?', \"What's the latest news from Germany?\"]}\u001b[0m\n", - "\u001b[32m2023-12-18 11:46:42 INFO semantic_router.utils.logger Creating route layer...\u001b[0m\n", - "\u001b[32m2023-12-18 11:46:42 INFO semantic_router.utils.logger Encoding 10 documents...\u001b[0m\n" + "\u001b[32m2023-12-18 12:20:20 INFO semantic_router.utils.logger Generated config: {'name': 'get_news', 'utterances': ['Tell me the latest news from the US', \"What's happening in India today?\", 'Get me the top stories from Japan', 'Can you give me the breaking news from Brazil?', \"What's the latest news from Germany?\"]}\u001b[0m\n", + "\u001b[32m2023-12-18 12:20:20 INFO semantic_router.utils.logger Creating route layer...\u001b[0m\n" ] }, { @@ -470,22 +543,20 @@ }, { "cell_type": "code", - "execution_count": 11, + "execution_count": 26, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "\u001b[32m2023-12-18 11:46:42 INFO semantic_router.utils.logger Encoding 1 documents...\u001b[0m\n", - "\u001b[32m2023-12-18 11:46:42 INFO semantic_router.utils.logger Extracting parameters...\u001b[0m\n", - "\u001b[32m2023-12-18 11:46:42 INFO semantic_router.utils.logger Calling Mistral model\u001b[0m\n", - "\u001b[32m2023-12-18 11:46:44 INFO semantic_router.utils.logger AI message: \n", + "\u001b[32m2023-12-18 12:20:02 INFO semantic_router.utils.logger Extracting parameters...\u001b[0m\n", + "\u001b[32m2023-12-18 12:20:02 INFO semantic_router.utils.logger Calling Mistral model\u001b[0m\n", + "\u001b[32m2023-12-18 12:20:04 INFO 
semantic_router.utils.logger AI message: \n", " {\n", " \"location\": \"Stockholm\"\n", " }\u001b[0m\n", - "\u001b[32m2023-12-18 11:46:44 INFO semantic_router.utils.logger Extracted parameters: {'location': 'Stockholm'}\u001b[0m\n", - "\u001b[32m2023-12-18 11:46:44 INFO semantic_router.utils.logger Encoding 1 documents...\u001b[0m\n" + "\u001b[32m2023-12-18 12:20:04 INFO semantic_router.utils.logger Extracted parameters: {'location': 'Stockholm'}\u001b[0m\n" ] }, { @@ -500,15 +571,14 @@ "name": "stderr", "output_type": "stream", "text": [ - "\u001b[32m2023-12-18 11:46:44 INFO semantic_router.utils.logger Extracting parameters...\u001b[0m\n", - "\u001b[32m2023-12-18 11:46:44 INFO semantic_router.utils.logger Calling Mistral model\u001b[0m\n", - "\u001b[32m2023-12-18 11:46:45 INFO semantic_router.utils.logger AI message: \n", + "\u001b[32m2023-12-18 12:20:04 INFO semantic_router.utils.logger Extracting parameters...\u001b[0m\n", + "\u001b[32m2023-12-18 12:20:04 INFO semantic_router.utils.logger Calling Mistral model\u001b[0m\n", + "\u001b[32m2023-12-18 12:20:05 INFO semantic_router.utils.logger AI message: \n", " {\n", " \"category\": \"tech\",\n", " \"country\": \"Lithuania\"\n", " }\u001b[0m\n", - "\u001b[32m2023-12-18 11:46:45 INFO semantic_router.utils.logger Extracted parameters: {'category': 'tech', 'country': 'Lithuania'}\u001b[0m\n", - "\u001b[32m2023-12-18 11:46:45 INFO semantic_router.utils.logger Encoding 1 documents...\u001b[0m\n" + "\u001b[32m2023-12-18 12:20:05 INFO semantic_router.utils.logger Extracted parameters: {'category': 'tech', 'country': 'Lithuania'}\u001b[0m\n" ] }, { @@ -523,9 +593,9 @@ "name": "stderr", "output_type": "stream", "text": [ - "\u001b[33m2023-12-18 11:46:46 WARNING semantic_router.utils.logger No function found\u001b[0m\n", - "\u001b[32m2023-12-18 11:46:46 INFO semantic_router.utils.logger Calling Mistral model\u001b[0m\n", - "\u001b[32m2023-12-18 11:46:46 INFO semantic_router.utils.logger AI message: How can I help you today?\u001b[0m\n" + "\u001b[33m2023-12-18 12:20:05 WARNING semantic_router.utils.logger No function found\u001b[0m\n", + "\u001b[32m2023-12-18 12:20:05 INFO semantic_router.utils.logger Calling Mistral model\u001b[0m\n", + "\u001b[32m2023-12-18 12:20:06 INFO semantic_router.utils.logger AI message: How can I help you today?\u001b[0m\n" ] }, { @@ -534,7 +604,7 @@ "' How can I help you today?'" ] }, - "execution_count": 11, + "execution_count": 26, "metadata": {}, "output_type": "execute_result" } @@ -569,7 +639,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.3" + "version": "3.11.5" } }, "nbformat": 4, diff --git a/semantic_router/encoders/openai.py b/semantic_router/encoders/openai.py index 65acbf7d3e088240cd40942bfd853c536f7e7615..c6d4cc962b7b9ac38400f527ac20baa6543490d9 100644 --- a/semantic_router/encoders/openai.py +++ b/semantic_router/encoders/openai.py @@ -3,6 +3,7 @@ from time import sleep import openai from openai import OpenAIError +from openai.types import CreateEmbeddingResponse from semantic_router.encoders import BaseEncoder from semantic_router.utils.logger import logger @@ -36,7 +37,7 @@ class OpenAIEncoder(BaseEncoder): try: logger.info(f"Encoding {len(docs)} documents...") embeds = self.client.embeddings.create(input=docs, model=self.name) - if embeds.data is not None: + if embeds.data: break except OpenAIError as e: sleep(2**j) @@ -46,8 +47,11 @@ class OpenAIEncoder(BaseEncoder): logger.error(f"OpenAI API call failed. 
Error: {error_message}") raise ValueError(f"OpenAI API call failed. Error: {e}") - if embeds is None or embeds.data is None: - logger.error(f"No embeddings returned. Error: {error_message}") + if ( + not embeds + or not isinstance(embeds, CreateEmbeddingResponse) + or not embeds.data + ): raise ValueError(f"No embeddings returned. Error: {error_message}") embeddings = [embeds_obj.embedding for embeds_obj in embeds.data] diff --git a/tests/unit/encoders/test_openai.py b/tests/unit/encoders/test_openai.py index 98ef2c5ea7fffd31aa4389b55b311fc092848b21..cc79d27207f7847439b74b0c29f4fb75d42d5381 100644 --- a/tests/unit/encoders/test_openai.py +++ b/tests/unit/encoders/test_openai.py @@ -1,6 +1,7 @@ import pytest from openai import OpenAIError -from openai.types.embedding import Embedding +from openai.types import CreateEmbeddingResponse, Embedding +from openai.types.create_embedding_response import Usage from semantic_router.encoders import OpenAIEncoder @@ -47,10 +48,20 @@ class TestOpenAIEncoder: ] mocker.patch("os.getenv", return_value="fake-api-key") + mocker.patch("time.sleep", return_value=None) # To speed up the test + + mock_embedding = Embedding(index=0, object="embedding", embedding=[0.1, 0.2]) + # Mock the CreateEmbeddingResponse object + mock_response = CreateEmbeddingResponse( + model="text-embedding-ada-002", + object="list", + usage=Usage(prompt_tokens=0, total_tokens=20), + data=[mock_embedding], + ) + + responses = [OpenAIError("OpenAI error"), mock_response] mocker.patch.object( - openai_encoder.client.embeddings, - "create", - return_value=mock_embeddings, + openai_encoder.client.embeddings, "create", side_effect=responses ) embeddings = openai_encoder(["test document"]) assert embeddings == [[0.1, 0.2]] @@ -88,7 +99,17 @@ class TestOpenAIEncoder: mocker.patch("os.getenv", return_value="fake-api-key") mocker.patch("time.sleep", return_value=None) # To speed up the test - responses = [OpenAIError("Test error"), mock_embeddings] + + mock_embedding = Embedding(index=0, object="embedding", embedding=[0.1, 0.2]) + # Mock the CreateEmbeddingResponse object + mock_response = CreateEmbeddingResponse( + model="text-embedding-ada-002", + object="list", + usage=Usage(prompt_tokens=0, total_tokens=20), + data=[mock_embedding], + ) + + responses = [OpenAIError("OpenAI error"), mock_response] mocker.patch.object( openai_encoder.client.embeddings, "create", side_effect=responses )