diff --git a/coverage.xml b/coverage.xml
index 52f00b34c6122ea9b73d3f3cf386ee7b09a1c64a..9bdbfa3735a680c28c2ffcced9bfe29ca71f5977 100644
--- a/coverage.xml
+++ b/coverage.xml
@@ -1,5 +1,5 @@
 <?xml version="1.0" ?>
-<coverage version="7.3.2" timestamp="1702538160019" lines-valid="344" lines-covered="344" line-rate="1" branches-covered="0" branches-valid="0" branch-rate="0" complexity="0">
+<coverage version="7.3.2" timestamp="1702633916069" lines-valid="344" lines-covered="344" line-rate="1" branches-covered="0" branches-valid="0" branch-rate="0" complexity="0">
 	<!-- Generated by coverage.py: https://coverage.readthedocs.io/en/7.3.2 -->
 	<!-- Based on https://raw.githubusercontent.com/cobertura/web/master/htdocs/xml/coverage-04.dtd -->
 	<sources>
diff --git a/docs/examples/function_calling.ipynb b/docs/examples/function_calling.ipynb
new file mode 100644
index 0000000000000000000000000000000000000000..650b5b9426fd9af7733cae3584540cb52a6d64b3
--- /dev/null
+++ b/docs/examples/function_calling.ipynb
@@ -0,0 +1,446 @@
+{
+ "cells": [
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "## Define LLMs"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 213,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# OpenAI\n",
+    "import openai\n",
+    "from semantic_router.utils.logger import logger\n",
+    "\n",
+    "\n",
+    "# Docs: https://platform.openai.com/docs/guides/function-calling\n",
+    "def llm_openai(prompt: str, model: str = \"gpt-4\") -> str:\n",
+    "    try:\n",
+    "        logger.info(f\"Calling {model} model\")\n",
+    "        response = openai.chat.completions.create(\n",
+    "            model=model,\n",
+    "            messages=[\n",
+    "                {\"role\": \"system\", \"content\": prompt},\n",
+    "            ],\n",
+    "        )\n",
+    "        ai_message = response.choices[0].message.content\n",
+    "        if not ai_message:\n",
+    "            raise Exception(\"AI message is empty\", ai_message)\n",
+    "        logger.info(f\"AI message: {ai_message}\")\n",
+    "        return ai_message\n",
+    "    except Exception as e:\n",
+    "        raise Exception(\"Failed to call OpenAI API\") from e"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 214,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# Mistral\n",
+    "import os\n",
+    "import requests\n",
+    "\n",
+    "# Docs: https://huggingface.co/docs/transformers/main_classes/text_generation\n",
+    "HF_API_TOKEN = os.environ[\"HF_API_TOKEN\"]\n",
+    "\n",
+    "\n",
+    "def llm_mistral(prompt: str) -> str:\n",
+    "    api_url = \"https://z5t4cuhg21uxfmc3.us-east-1.aws.endpoints.huggingface.cloud/\"\n",
+    "    headers = {\n",
+    "        \"Authorization\": f\"Bearer {HF_API_TOKEN}\",\n",
+    "        \"Content-Type\": \"application/json\",\n",
+    "    }\n",
+    "\n",
+    "    logger.info(\"Calling Mistral model\")\n",
+    "    response = requests.post(\n",
+    "        api_url,\n",
+    "        headers=headers,\n",
+    "        json={\n",
+    "            \"inputs\": f\"You are a helpful assistant, user query: {prompt}\",\n",
+    "            \"parameters\": {\n",
+    "                \"max_new_tokens\": 200,\n",
+    "                \"temperature\": 0.1,\n",
+    "            },\n",
+    "        },\n",
+    "    )\n",
+    "    if response.status_code != 200:\n",
+    "        raise Exception(\"Failed to call HuggingFace API\", response.text)\n",
+    "\n",
+    "    ai_message = response.json()[0][\"generated_text\"]\n",
+    "    if not ai_message:\n",
+    "        raise Exception(\"AI message is empty\", ai_message)\n",
+    "    logger.info(f\"AI message: {ai_message}\")\n",
+    "    return ai_message"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Now we need to generate a route config from the function schema using an LLM"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
"source": [ + "import inspect\n", + "from typing import Any\n", + "\n", + "\n", + "def get_function_schema(function) -> dict[str, Any]:\n", + " schema = {\n", + " \"name\": function.__name__,\n", + " \"description\": str(inspect.getdoc(function)),\n", + " \"signature\": str(inspect.signature(function)),\n", + " \"output\": str(\n", + " inspect.signature(function).return_annotation,\n", + " ),\n", + " }\n", + " return schema" + ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "import json\n", + "\n", + "from semantic_router.utils.logger import logger\n", + "\n", + "\n", + "def generate_route(function) -> dict:\n", + " logger.info(\"Generating config...\")\n", + " example_schema = {\n", + " \"name\": \"get_weather\",\n", + " \"description\": \"Useful to get the weather in a specific location\",\n", + " \"signature\": \"(location: str) -> str\",\n", + " \"output\": \"<class 'str'>\",\n", + " }\n", + "\n", + " example_config = {\n", + " \"name\": \"get_weather\",\n", + " \"utterances\": [\n", + " \"What is the weather like in SF?\",\n", + " \"What is the weather in Cyprus?\",\n", + " \"weather in London?\",\n", + " \"Tell me the weather in New York\",\n", + " \"what is the current weather in Paris?\",\n", + " ],\n", + " }\n", + "\n", + " function_schema = get_function_schema(function)\n", + "\n", + " prompt = f\"\"\"\n", + " You are a helpful assistant designed to output JSON.\n", + " Given the following function schema\n", + " {function_schema}\n", + " generate a routing config with the format:\n", + " {example_config}\n", + "\n", + " For example:\n", + " Input: {example_schema}\n", + " Output: {example_config}\n", + "\n", + " Input: {function_schema}\n", + " Output:\n", + " \"\"\"\n", + "\n", + " ai_message = llm_openai(prompt)\n", + "\n", + " ai_message = ai_message.replace(\"CONFIG:\", \"\").replace(\"'\", '\"').strip().rstrip(\",\")\n", + "\n", + " try:\n", + " route_config = json.loads(ai_message)\n", + " logger.info(f\"Generated config: {route_config}\")\n", + " return route_config\n", + " except json.JSONDecodeError as json_error:\n", + " logger.error(f\"JSON parsing error {json_error}\")\n", + " print(f\"AI message: {ai_message}\")\n", + " return {\"error\": \"Failed to generate config\"}" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Extract function parameters using `Mistral` open-source model" + ] + }, + { + "cell_type": "code", + "execution_count": 217, + "metadata": {}, + "outputs": [], + "source": [ + "def extract_parameters(query: str, function) -> dict:\n", + " logger.info(\"Extracting parameters...\")\n", + " example_query = \"How is the weather in Hawaii right now in International units?\"\n", + "\n", + " example_schema = {\n", + " \"name\": \"get_weather\",\n", + " \"description\": \"Useful to get the weather in a specific location\",\n", + " \"signature\": \"(location: str, degree: str) -> str\",\n", + " \"output\": \"<class 'str'>\",\n", + " }\n", + "\n", + " example_parameters = {\n", + " \"location\": \"London\",\n", + " \"degree\": \"Celsius\",\n", + " }\n", + "\n", + " prompt = f\"\"\"\n", + " You are a helpful assistant designed to output JSON.\n", + " Given the following function schema\n", + " {get_function_schema(function)}\n", + " and query\n", + " {query}\n", + " extract the parameters values from the query, in a valid JSON format.\n", + " Example:\n", + " Input:\n", + " query: {example_query}\n", + " schema: {example_schema}\n", + "\n", + " Output:\n", + " parameters: 
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import json\n",
+    "\n",
+    "from semantic_router.utils.logger import logger\n",
+    "\n",
+    "\n",
+    "def generate_route(function) -> dict:\n",
+    "    logger.info(\"Generating config...\")\n",
+    "    example_schema = {\n",
+    "        \"name\": \"get_weather\",\n",
+    "        \"description\": \"Useful to get the weather in a specific location\",\n",
+    "        \"signature\": \"(location: str) -> str\",\n",
+    "        \"output\": \"<class 'str'>\",\n",
+    "    }\n",
+    "\n",
+    "    example_config = {\n",
+    "        \"name\": \"get_weather\",\n",
+    "        \"utterances\": [\n",
+    "            \"What is the weather like in SF?\",\n",
+    "            \"What is the weather in Cyprus?\",\n",
+    "            \"weather in London?\",\n",
+    "            \"Tell me the weather in New York\",\n",
+    "            \"what is the current weather in Paris?\",\n",
+    "        ],\n",
+    "    }\n",
+    "\n",
+    "    function_schema = get_function_schema(function)\n",
+    "\n",
+    "    prompt = f\"\"\"\n",
+    "    You are a helpful assistant designed to output JSON.\n",
+    "    Given the following function schema\n",
+    "    {function_schema}\n",
+    "    generate a routing config with the format:\n",
+    "    {example_config}\n",
+    "\n",
+    "    For example:\n",
+    "    Input: {example_schema}\n",
+    "    Output: {example_config}\n",
+    "\n",
+    "    Input: {function_schema}\n",
+    "    Output:\n",
+    "    \"\"\"\n",
+    "\n",
+    "    ai_message = llm_openai(prompt)\n",
+    "\n",
+    "    ai_message = ai_message.replace(\"CONFIG:\", \"\").replace(\"'\", '\"').strip().rstrip(\",\")\n",
+    "\n",
+    "    try:\n",
+    "        route_config = json.loads(ai_message)\n",
+    "        logger.info(f\"Generated config: {route_config}\")\n",
+    "        return route_config\n",
+    "    except json.JSONDecodeError as json_error:\n",
+    "        logger.error(f\"JSON parsing error {json_error}\")\n",
+    "        print(f\"AI message: {ai_message}\")\n",
+    "        return {\"error\": \"Failed to generate config\"}"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Extract function parameters using the open-source `Mistral` model"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 217,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def extract_parameters(query: str, function) -> dict:\n",
+    "    logger.info(\"Extracting parameters...\")\n",
+    "    example_query = \"How is the weather in Hawaii right now in International units?\"\n",
+    "\n",
+    "    example_schema = {\n",
+    "        \"name\": \"get_weather\",\n",
+    "        \"description\": \"Useful to get the weather in a specific location\",\n",
+    "        \"signature\": \"(location: str, degree: str) -> str\",\n",
+    "        \"output\": \"<class 'str'>\",\n",
+    "    }\n",
+    "\n",
+    "    example_parameters = {\n",
+    "        \"location\": \"London\",\n",
+    "        \"degree\": \"Celsius\",\n",
+    "    }\n",
+    "\n",
+    "    prompt = f\"\"\"\n",
+    "    You are a helpful assistant designed to output JSON.\n",
+    "    Given the following function schema\n",
+    "    {get_function_schema(function)}\n",
+    "    and query\n",
+    "    {query}\n",
+    "    extract the parameter values from the query, in a valid JSON format.\n",
+    "    Example:\n",
+    "    Input:\n",
+    "    query: {example_query}\n",
+    "    schema: {example_schema}\n",
+    "\n",
+    "    Output:\n",
+    "    parameters: {example_parameters}\n",
+    "\n",
+    "    Input:\n",
+    "    query: {query}\n",
+    "    schema: {get_function_schema(function)}\n",
+    "    Output:\n",
+    "    parameters:\n",
+    "    \"\"\"\n",
+    "\n",
+    "    ai_message = llm_mistral(prompt)\n",
+    "\n",
+    "    ai_message = ai_message.replace(\"CONFIG:\", \"\").replace(\"'\", '\"').strip().rstrip(\",\")\n",
+    "\n",
+    "    try:\n",
+    "        parameters = json.loads(ai_message)\n",
+    "        logger.info(f\"Extracted parameters: {parameters}\")\n",
+    "        return parameters\n",
+    "    except json.JSONDecodeError as json_error:\n",
+    "        logger.error(f\"JSON parsing error {json_error}\")\n",
+    "        return {\"error\": \"Failed to extract parameters\"}"
+   ]
+  },
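+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Note: the `.replace(\"'\", '\"')` cleanup above is fragile; it breaks on apostrophes inside values or extra text around the JSON. A slightly more defensive sketch (not wired into the functions above) could pull out the first `{...}` block and fall back to `ast.literal_eval` for Python-style dicts:"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import ast\n",
+    "import json\n",
+    "import re\n",
+    "\n",
+    "\n",
+    "def parse_llm_json(message: str) -> dict:\n",
+    "    # Grab the first {...} block, in case the model added surrounding text\n",
+    "    match = re.search(r\"\\{.*\\}\", message, re.DOTALL)\n",
+    "    if not match:\n",
+    "        raise ValueError(f\"No JSON object found in: {message}\")\n",
+    "    block = match.group()\n",
+    "    try:\n",
+    "        return json.loads(block)\n",
+    "    except json.JSONDecodeError:\n",
+    "        # Fall back to Python literal syntax (single quotes etc.)\n",
+    "        return ast.literal_eval(block)\n",
+    "\n",
+    "\n",
+    "parse_llm_json(\"parameters: {'location': 'Stockholm'}\")"
+   ]
+  },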
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Set up the routing layer"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from semantic_router.schema import Route\n",
+    "from semantic_router.encoders import CohereEncoder\n",
+    "from semantic_router.layer import RouteLayer\n",
+    "from semantic_router.utils.logger import logger\n",
+    "\n",
+    "\n",
+    "def create_router(routes: list[dict]) -> RouteLayer:\n",
+    "    logger.info(\"Creating route layer...\")\n",
+    "    encoder = CohereEncoder()\n",
+    "\n",
+    "    route_list: list[Route] = []\n",
+    "    for route in routes:\n",
+    "        if \"name\" in route and \"utterances\" in route:\n",
+    "            print(f\"Route: {route}\")\n",
+    "            route_list.append(Route(name=route[\"name\"], utterances=route[\"utterances\"]))\n",
+    "        else:\n",
+    "            logger.warning(f\"Misconfigured route: {route}\")\n",
+    "\n",
+    "    return RouteLayer(encoder=encoder, routes=route_list)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Set up calling functions"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 219,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from typing import Callable\n",
+    "\n",
+    "\n",
+    "def call_function(function: Callable, parameters: dict[str, str]):\n",
+    "    try:\n",
+    "        return function(**parameters)\n",
+    "    except TypeError as e:\n",
+    "        logger.error(f\"Error calling function: {e}\")\n",
+    "\n",
+    "\n",
+    "def call_llm(query: str):\n",
+    "    return llm_mistral(query)\n",
+    "\n",
+    "\n",
+    "def call(query: str, functions: list[Callable], router: RouteLayer):\n",
+    "    function_name = router(query)\n",
+    "    if not function_name:\n",
+    "        logger.warning(\"No function found\")\n",
+    "        return call_llm(query)\n",
+    "\n",
+    "    for function in functions:\n",
+    "        if function.__name__ == function_name:\n",
+    "            parameters = extract_parameters(query, function)\n",
+    "            print(f\"parameters: {parameters}\")\n",
+    "            return call_function(function, parameters)\n",
+    "\n",
+    "    # Fallback: the router matched a route with no registered function\n",
+    "    logger.warning(f\"No function named {function_name} registered\")\n",
+    "    return call_llm(query)"
+   ]
+  },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "### Workflow"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def get_time(location: str) -> str:\n",
+    "    \"\"\"Useful to get the time in a specific location\"\"\"\n",
+    "    print(f\"Calling `get_time` function with location: {location}\")\n",
+    "    return \"get_time\"\n",
+    "\n",
+    "\n",
+    "def get_news(category: str, country: str) -> str:\n",
+    "    \"\"\"Useful to get the news in a specific country\"\"\"\n",
+    "    print(\n",
+    "        f\"Calling `get_news` function with category: {category} and country: {country}\"\n",
+    "    )\n",
+    "    return \"get_news\"\n",
+    "\n",
+    "\n",
+    "# Registering functions to the router\n",
+    "route_get_time = generate_route(get_time)\n",
+    "route_get_news = generate_route(get_news)\n",
+    "\n",
+    "routes = [route_get_time, route_get_news]\n",
+    "router = create_router(routes)\n",
+    "\n",
+    "# Tools\n",
+    "tools = [get_time, get_news]"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 220,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "\u001b[32m2023-12-15 11:41:54 INFO semantic_router.utils.logger Extracting parameters...\u001b[0m\n",
+      "\u001b[32m2023-12-15 11:41:54 INFO semantic_router.utils.logger Calling Mistral model\u001b[0m\n",
+      "\u001b[32m2023-12-15 11:41:55 INFO semantic_router.utils.logger AI message: \n",
+      "    {\n",
+      "        'location': 'Stockholm'\n",
+      "    }\u001b[0m\n",
+      "\u001b[32m2023-12-15 11:41:55 INFO semantic_router.utils.logger Extracted parameters: {'location': 'Stockholm'}\u001b[0m\n"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "parameters: {'location': 'Stockholm'}\n",
+      "Calling `get_time` function with location: Stockholm\n"
+     ]
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "\u001b[32m2023-12-15 11:41:55 INFO semantic_router.utils.logger Extracting parameters...\u001b[0m\n",
+      "\u001b[32m2023-12-15 11:41:55 INFO semantic_router.utils.logger Calling Mistral model\u001b[0m\n",
+      "\u001b[32m2023-12-15 11:41:56 INFO semantic_router.utils.logger AI message: \n",
+      "    {\n",
+      "        'category': 'tech',\n",
+      "        'country': 'Lithuania'\n",
+      "    }\u001b[0m\n",
+      "\u001b[32m2023-12-15 11:41:56 INFO semantic_router.utils.logger Extracted parameters: {'category': 'tech', 'country': 'Lithuania'}\u001b[0m\n"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "parameters: {'category': 'tech', 'country': 'Lithuania'}\n",
+      "Calling `get_news` function with category: tech and country: Lithuania\n"
+     ]
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "\u001b[33m2023-12-15 11:41:57 WARNING semantic_router.utils.logger No function found\u001b[0m\n",
+      "\u001b[32m2023-12-15 11:41:57 INFO semantic_router.utils.logger Calling Mistral model\u001b[0m\n",
+      "\u001b[32m2023-12-15 11:41:57 INFO semantic_router.utils.logger AI message: How can I help you today?\u001b[0m\n"
+     ]
+    },
+    {
+     "data": {
+      "text/plain": [
+       "' How can I help you today?'"
+      ]
+     },
+     "execution_count": 220,
+     "metadata": {},
+     "output_type": "execute_result"
+    }
+   ],
+   "source": [
+    "call(query=\"What is the time in Stockholm?\", functions=tools, router=router)\n",
+    "call(query=\"What is the tech news in Lithuania?\", functions=tools, router=router)\n",
+    "call(query=\"Hi!\", functions=tools, router=router)"
+   ]
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": ".venv",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.11.3"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/pyproject.toml b/pyproject.toml
index 5152d8937c55949b6540178ca262757db8ef6f3e..8a92ec383d86701f6aca5e9908e5306d016e1bbf 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -36,7 +36,7 @@ requires = ["poetry-core"]
 build-backend = "poetry.core.masonry.api"
 
 [tool.ruff.per-file-ignores]
-"*.ipynb" = ["E402"]
+"*.ipynb" = ["ALL"]
 
 [tool.mypy]
 ignore_missing_imports = true