diff --git a/.gitignore b/.gitignore index 807674fa1ab9fee059c2942a22139cdd77f60eca..219091bdd82b6db50d7e9b2d03630d28c53ac1d1 100644 --- a/.gitignore +++ b/.gitignore @@ -17,3 +17,4 @@ mac.env .coverage .coverage.* .pytest_cache +test.py \ No newline at end of file diff --git a/docs/examples/function_calling.ipynb b/docs/examples/function_calling.ipynb index 650b5b9426fd9af7733cae3584540cb52a6d64b3..a418d05f951ff0ac6f5abbff6dbce1907291737e 100644 --- a/docs/examples/function_calling.ipynb +++ b/docs/examples/function_calling.ipynb @@ -9,11 +9,22 @@ }, { "cell_type": "code", - "execution_count": 213, + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "%reload_ext dotenv\n", + "%dotenv" + ] + }, + { + "cell_type": "code", + "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# OpenAI\n", + "import os\n", "import openai\n", "from semantic_router.utils.logger import logger\n", "\n", @@ -39,7 +50,7 @@ }, { "cell_type": "code", - "execution_count": 214, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -48,7 +59,7 @@ "import requests\n", "\n", "# Docs https://huggingface.co/docs/transformers/main_classes/text_generation\n", - "HF_API_TOKEN = os.environ[\"HF_API_TOKEN\"]\n", + "HF_API_TOKEN = os.getenv(\"HF_API_TOKEN\")\n", "\n", "\n", "def llm_mistral(prompt: str) -> str:\n", @@ -180,7 +191,7 @@ }, { "cell_type": "code", - "execution_count": 217, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -242,6 +253,23 @@ "Set up the routing layer" ] }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "from semantic_router.schema import Route\n", + "from semantic_router.encoders import CohereEncoder, OpenAIEncoder\n", + "from semantic_router.layer import RouteLayer\n", + "from semantic_router.utils.logger import logger\n", + "\n", + "\n", + "def create_router(routes: list[dict]) -> RouteLayer:\n", + " logger.info(\"Creating route 
layer...\")\n", + " encoder = OpenAIEncoder" + ] + }, { "cell_type": "code", "execution_count": null, @@ -256,7 +284,7 @@ "\n", "def create_router(routes: list[dict]) -> RouteLayer:\n", " logger.info(\"Creating route layer...\")\n", - " encoder = CohereEncoder()\n", + " encoder = OpenAIEncoder()\n", "\n", " route_list: list[Route] = []\n", " for route in routes:\n", @@ -278,7 +306,7 @@ }, { "cell_type": "code", - "execution_count": 219, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -349,72 +377,38 @@ }, { "cell_type": "code", - "execution_count": 220, + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def get_time(location: str) -> str:\n", + " \"\"\"Useful to get the time in a specific location\"\"\"\n", + " print(f\"Calling `get_time` function with location: {location}\")\n", + " return \"get_time\"\n", + "\n", + "\n", + "def get_news(category: str, country: str) -> str:\n", + " \"\"\"Useful to get the news in a specific country\"\"\"\n", + " print(\n", + " f\"Calling `get_news` function with category: {category} and country: {country}\"\n", + " )\n", + " return \"get_news\"\n", + "\n", + "\n", + "# Registering functions to the router\n", + "route_get_time = generate_route(get_time)\n", + "route_get_news = generate_route(get_news)\n", + "\n", + "routes = [route_get_time, route_get_news]\n", + "router = create_router(routes)\n", + "\n", + "# Tools\n", + "tools = [get_time, get_news]" + ] + }, + { + "cell_type": "markdown", "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "\u001b[32m2023-12-15 11:41:54 INFO semantic_router.utils.logger Extracting parameters...\u001b[0m\n", - "\u001b[32m2023-12-15 11:41:54 INFO semantic_router.utils.logger Calling Mistral model\u001b[0m\n", - "\u001b[32m2023-12-15 11:41:55 INFO semantic_router.utils.logger AI message: \n", - " {\n", - " 'location': 'Stockholm'\n", - " }\u001b[0m\n", - "\u001b[32m2023-12-15 11:41:55 INFO 
semantic_router.utils.logger Extracted parameters: {'location': 'Stockholm'}\u001b[0m\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "parameters: {'location': 'Stockholm'}\n", - "Calling `get_time` function with location: Stockholm\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "\u001b[32m2023-12-15 11:41:55 INFO semantic_router.utils.logger Extracting parameters...\u001b[0m\n", - "\u001b[32m2023-12-15 11:41:55 INFO semantic_router.utils.logger Calling Mistral model\u001b[0m\n", - "\u001b[32m2023-12-15 11:41:56 INFO semantic_router.utils.logger AI message: \n", - " {\n", - " 'category': 'tech',\n", - " 'country': 'Lithuania'\n", - " }\u001b[0m\n", - "\u001b[32m2023-12-15 11:41:56 INFO semantic_router.utils.logger Extracted parameters: {'category': 'tech', 'country': 'Lithuania'}\u001b[0m\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "parameters: {'category': 'tech', 'country': 'Lithuania'}\n", - "Calling `get_news` function with category: tech and country: Lithuania\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "\u001b[33m2023-12-15 11:41:57 WARNING semantic_router.utils.logger No function found\u001b[0m\n", - "\u001b[32m2023-12-15 11:41:57 INFO semantic_router.utils.logger Calling Mistral model\u001b[0m\n", - "\u001b[32m2023-12-15 11:41:57 INFO semantic_router.utils.logger AI message: How can I help you today?\u001b[0m\n" - ] - }, - { - "data": { - "text/plain": [ - "' How can I help you today?'" - ] - }, - "execution_count": 220, - "metadata": {}, - "output_type": "execute_result" - } - ], "source": [ "call(query=\"What is the time in Stockholm?\", functions=tools, router=router)\n", "call(query=\"What is the tech news in the Lithuania?\", functions=tools, router=router)\n", @@ -438,7 +432,7 @@ "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", - "version": "3.11.3" + "version": "3.11.5" } }, "nbformat": 4, diff 
--git a/pyproject.toml b/pyproject.toml index b041516c70c5fac4d3adfd7b857e495bc64c372f..030a8f72d6f05151ddd88df4fb4bb4427d28950d 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "semantic-router" -version = "0.0.9" +version = "0.0.10" description = "Super fast semantic router for AI decision making" authors = [ "James Briggs <james@aurelio.ai>", diff --git a/semantic_router/encoders/openai.py b/semantic_router/encoders/openai.py index 1375c57ac3abbd40d6a92fa4350ee108483138c6..d02787a38fa8b73d055b8dea16ea540242826ecc 100644 --- a/semantic_router/encoders/openai.py +++ b/semantic_router/encoders/openai.py @@ -36,7 +36,7 @@ class OpenAIEncoder(BaseEncoder): try: logger.info(f"Encoding {len(docs)} documents...") embeds = self.client.embeddings.create(input=docs, model=self.name) - if isinstance(embeds, dict) and "data" in embeds: + if getattr(embeds, "data", None) or (isinstance(embeds, dict) and "data" in embeds): break except OpenAIError as e: sleep(2**j) diff --git a/tests/unit/encoders/test_openai.py b/tests/unit/encoders/test_openai.py index f67bce1e3137d8e9fa0cf7db0042bfa404cbc9d3..501a9b04f5d12c59e014edb6bc5798d2aa9d31b5 100644 --- a/tests/unit/encoders/test_openai.py +++ b/tests/unit/encoders/test_openai.py @@ -71,6 +71,7 @@ class TestOpenAIEncoder: ) with pytest.raises(ValueError) as e: openai_encoder(["test document"]) + assert "OpenAI API call failed. Error: Non-OpenAIError" in str(e.value) def test_openai_encoder_call_successful_retry(self, openai_encoder, mocker):