diff --git a/docs/examples/function_calling.ipynb b/docs/examples/function_calling.ipynb
index 502c0ae459653936b392d0b858971719f3c87f0d..e05686b383417c09237c13a1deb9867ba6ea2f73 100644
--- a/docs/examples/function_calling.ipynb
+++ b/docs/examples/function_calling.ipynb
@@ -9,7 +9,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 73,
+   "execution_count": 213,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -20,6 +20,7 @@
     "# Docs # https://platform.openai.com/docs/guides/function-calling\n",
     "def llm_openai(prompt: str, model: str = \"gpt-4\") -> str:\n",
     "    try:\n",
+    "        logger.info(f\"Calling {model} model\")\n",
     "        response = openai.chat.completions.create(\n",
     "            model=model,\n",
     "            messages=[\n",
@@ -37,7 +38,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 102,
+   "execution_count": 214,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -55,14 +56,15 @@
     "        \"Content-Type\": \"application/json\",\n",
     "    }\n",
     "\n",
+    "    logger.info(\"Calling Mistral model\")\n",
     "    response = requests.post(\n",
     "        api_url,\n",
     "        headers=headers,\n",
     "        json={\n",
-    "            \"inputs\": prompt,\n",
+    "            \"inputs\": f\"You are a helpful assistant, user query: {prompt}\",\n",
     "            \"parameters\": {\n",
-    "                \"max_new_tokens\": 1000,\n",
-    "                \"temperature\": 0.2,\n",
+    "                \"max_new_tokens\": 200,\n",
+    "                \"temperature\": 0.1,\n",
     "            },\n",
     "        },\n",
     "    )\n",
@@ -80,30 +82,48 @@
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "### Now we need to generate config from function specification using LLM"
+    "### Now we need to generate a config from the function schema using an LLM"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 134,
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "import inspect\n",
+    "from typing import Any\n",
+    "\n",
+    "def get_function_schema(function) -> dict[str, Any]:\n",
+    "    schema = {\n",
+    "        \"name\": function.__name__,\n",
+    "        \"description\": str(inspect.getdoc(function)),\n",
+    "        \"signature\": str(inspect.signature(function)),\n",
+    "        \"output\": str(\n",
+    "            inspect.signature(function).return_annotation,\n",
+    "        ),\n",
+    "    }\n",
+    "    return schema"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
    "metadata": {},
    "outputs": [],
    "source": [
     "import json\n",
     "\n",
-    "from pydantic import BaseModel\n",
     "from semantic_router.utils.logger import logger\n",
     "\n",
-    "def generate_config(schema: dict) -> dict:\n",
+    "def generate_route(function) -> dict:\n",
     "    logger.info(\"Generating config...\")\n",
-    "\n",
-    "    class GetWeatherSchema(BaseModel):\n",
-    "        location: str\n",
-    "\n",
-    "        class Config:\n",
-    "            name = \"get_weather\"\n",
-    "\n",
-    "    example_schema = GetWeatherSchema.schema()\n",
+    "    example_schema = {\n",
+    "        \"name\": \"get_weather\",\n",
+    "        \"description\": \"Useful to get the weather in a specific location\",\n",
+    "        \"signature\": \"(location: str) -> str\",\n",
+    "        \"output\": \"<class 'str'>\",\n",
+    "    }\n",
     "\n",
     "    example_config = {\n",
     "        \"name\": \"get_weather\",\n",
@@ -116,23 +136,26 @@
     "        ],\n",
     "    }\n",
     "\n",
+    "    function_schema = get_function_schema(function)\n",
+    "\n",
     "    prompt = f\"\"\"\n",
-    "    Given the following Pydantic function schema,\n",
-    "    generate a config ONLY in a valid JSON format.\n",
-    "    For example:\n",
-    "    SCHEMA: {example_schema}\n",
-    "    CONFIG: {example_config}\n",
+    "    You are a helpful assistant designed to output JSON.\n",
+    "    Given the following function schema\n",
+    "    {function_schema}\n",
+    "    generate a routing config with the format:\n",
+    "    {example_config}\n",
     "\n",
+    "    For example:\n",
+    "    Input: {example_schema}\n",
+    "    Output: {example_config}\n",
     "\n",
-    "    GIVEN SCHEMA: {schema}\n",
-    "    GENERATED CONFIG: <generated_response_in_json>\n",
+    "    Input: {function_schema}\n",
+    "    Output:\n",
     "    \"\"\"\n",
     "\n",
     "    ai_message = llm_openai(prompt)\n",
-    "    print(f\"AI message: {ai_message}\")\n",
     "\n",
-    "    # Parsing for Mistral model\n",
-    "    ai_message = ai_message.replace(\"CONFIG:\", \"\").replace(\"'\", '\"').strip()\n",
+    "    ai_message = ai_message.replace(\"CONFIG:\", \"\").replace(\"'\", '\"').strip().rstrip(\",\")\n",
     "\n",
     "    try:\n",
     "        route_config = json.loads(ai_message)\n",
@@ -141,7 +164,7 @@
     "    except json.JSONDecodeError as json_error:\n",
     "        logger.error(f\"JSON parsing error {json_error}\")\n",
     "        print(f\"AI message: {ai_message}\")\n",
-    "        return {}"
+    "        return {\"error\": \"Failed to generate config\"}"
    ]
   },
   {
@@ -153,46 +176,51 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 138,
+   "execution_count": 217,
    "metadata": {},
    "outputs": [],
    "source": [
-    "def extract_parameters(query: str, schema: dict) -> dict:\n",
+    "def extract_parameters(query: str, function) -> dict:\n",
     "    logger.info(\"Extracting parameters...\")\n",
-    "    example_query = \"what is the weather in London?\"\n",
-    "\n",
-    "    class GetWeatherSchema(BaseModel):\n",
-    "        location: str\n",
-    "\n",
-    "        class Config:\n",
-    "            name = \"get_weather\"\n",
+    "    example_query = \"How is the weather in Hawaii right now in International units?\"\n",
     "\n",
-    "    example_schema = GetWeatherSchema.schema()\n",
+    "    example_schema = {\n",
+    "        \"name\": \"get_weather\",\n",
+    "        \"description\": \"Useful to get the weather in a specific location\",\n",
+    "        \"signature\": \"(location: str, degree: str) -> str\",\n",
+    "        \"output\": \"<class 'str'>\",\n",
+    "    }\n",
     "\n",
     "    example_parameters = {\n",
     "        \"location\": \"London\",\n",
+    "        \"degree\": \"Celsius\",\n",
     "    }\n",
     "\n",
     "    prompt = f\"\"\"\n",
-    "    Given the following function schema and query, extract the parameters from the\n",
-    "    query, in a valid JSON format.\n",
-    "    Example:\n",
-    "    SCHEMA:\n",
-    "    {example_schema}\n",
-    "    QUERY:\n",
-    "    {example_query}\n",
-    "    PARAMETERS:\n",
-    "    {example_parameters}\n",
-    "    GIVEN SCHEMA:\n",
-    "    {schema}\n",
-    "    GIVEN QUERY:\n",
+    "    You are a helpful assistant designed to output JSON.\n",
+    "    Given the following function schema\n",
+    "    {get_function_schema(function)}\n",
+    "    and query\n",
     "    {query}\n",
-    "    EXTRACTED PARAMETERS:\n",
+    "    extract the parameter values from the query, in a valid JSON format.\n",
+    "    Example:\n",
+    "    Input:\n",
+    "    query: {example_query}\n",
+    "    schema: {example_schema}\n",
+    "\n",
+    "    Output:\n",
+    "    parameters: {example_parameters}\n",
+    "\n",
+    "    Input:\n",
+    "    query: {query}\n",
+    "    schema: {get_function_schema(function)}\n",
+    "    Output:\n",
+    "    parameters:\n",
     "    \"\"\"\n",
     "\n",
-    "    ai_message = llm_openai(prompt)\n",
+    "    ai_message = llm_mistral(prompt)\n",
     "\n",
-    "    ai_message = ai_message.replace(\"'\", '\"').strip()\n",
+    "    ai_message = ai_message.replace(\"CONFIG:\", \"\").replace(\"'\", '\"').strip().rstrip(\",\")\n",
     "\n",
     "    try:\n",
     "        parameters = json.loads(ai_message)\n",
@@ -200,7 +228,7 @@
     "        return parameters\n",
     "    except json.JSONDecodeError as json_error:\n",
     "        logger.error(f\"JSON parsing error {json_error}\")\n",
-    "        return {}"
+    "        return {\"error\": \"Failed to extract parameters\"}"
    ]
   },
   {
@@ -212,7 +240,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 139,
+   "execution_count": null,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -221,21 +249,58 @@
     "from semantic_router.encoders import CohereEncoder\n",
     "from semantic_router.layer import RouteLayer\n",
     "from semantic_router.utils.logger import logger\n",
     "\n",
-    "\n",
-    "def get_route_layer(config: list[dict]) -> RouteLayer:\n",
-    "    logger.info(\"Getting route layer...\")\n",
+    "def create_router(routes: list[dict]) -> RouteLayer:\n",
+    "    logger.info(\"Creating route layer...\")\n",
     "    encoder = CohereEncoder()\n",
     "\n",
-    "    routes = []\n",
-    "    print(f\"Config: {config}\")\n",
-    "    for route in config:\n",
+    "    route_list: list[Route] = []\n",
+    "    for route in routes:\n",
     "        if \"name\" in route and \"utterances\" in route:\n",
     "            print(f\"Route: {route}\")\n",
-    "            routes.append(Route(name=route[\"name\"], utterances=route[\"utterances\"]))\n",
+    "            route_list.append(Route(name=route[\"name\"], utterances=route[\"utterances\"]))\n",
     "        else:\n",
     "            logger.warning(f\"Misconfigured route: {route}\")\n",
     "\n",
-    "    return RouteLayer(encoder=encoder, routes=routes)"
+    "    return RouteLayer(encoder=encoder, routes=route_list)"
    ]
   },
+  {
+   "cell_type": "markdown",
+   "metadata": {},
+   "source": [
+    "Set up the function-calling helpers"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 219,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "from typing import Callable\n",
+    "\n",
+    "def call_function(function: Callable, parameters: dict[str, str]):\n",
+    "    try:\n",
+    "        return function(**parameters)\n",
+    "    except TypeError as e:\n",
+    "        logger.error(f\"Error calling function: {e}\")\n",
+    "\n",
+    "\n",
+    "def call_llm(query: str):\n",
+    "    return llm_mistral(query)\n",
+    "\n",
+    "\n",
+    "def call(query: str, functions: list[Callable], router: RouteLayer):\n",
+    "    function_name = router(query)\n",
+    "    if not function_name:\n",
+    "        logger.warning(\"No function found\")\n",
+    "        return call_llm(query)\n",
+    "\n",
+    "    for function in functions:\n",
+    "        if function.__name__ == function_name:\n",
+    "            parameters = extract_parameters(query, function)\n",
+    "            print(f\"parameters: {parameters}\")\n",
+    "            return call_function(function, parameters)"
+   ]
+  },
@@ -247,96 +312,116 @@
   },
   {
    "cell_type": "code",
-   "execution_count": 140,
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def get_time(location: str) -> str:\n",
+    "    \"\"\"Useful to get the time in a specific location\"\"\"\n",
+    "    print(f\"Calling `get_time` function with location: {location}\")\n",
+    "    return \"get_time\"\n",
+    "\n",
+    "\n",
+    "def get_news(category: str, country: str) -> str:\n",
+    "    \"\"\"Useful to get the news for a specific category and country\"\"\"\n",
+    "    print(\n",
+    "        f\"Calling `get_news` function with category: {category} and country: {country}\"\n",
+    "    )\n",
+    "    return \"get_news\"\n",
+    "\n",
+    "\n",
+    "# Register the functions with the router\n",
+    "route_get_time = generate_route(get_time)\n",
+    "route_get_news = generate_route(get_news)\n",
+    "\n",
+    "routes = [route_get_time, route_get_news]\n",
+    "router = create_router(routes)\n",
+    "\n",
+    "# Tools\n",
+    "tools = [get_time, get_news]"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 220,
    "metadata": {},
    "outputs": [
     {
      "name": "stderr",
      "output_type": "stream",
      "text": [
-      "\u001b[32m2023-12-14 17:28:22 INFO semantic_router.utils.logger Generating config...\u001b[0m\n",
-      "\u001b[32m2023-12-14 17:28:28 INFO semantic_router.utils.logger AI message: {\"name\": \"get_time\", \"utterances\": [\"What is the time in SF?\", \"What is the current time in London?\", \"Time in Tokyo?\", \"Tell me the time in New York\", \"What is the time now in Paris?\"]}\u001b[0m\n",
-      "\u001b[32m2023-12-14 17:28:28 INFO semantic_router.utils.logger Generated config: {'name': 'get_time', 'utterances': ['What is the time in SF?', 'What is the current time in London?', 'Time in Tokyo?', 'Tell me the time in New York', 'What is the time now in Paris?']}\u001b[0m\n",
-      "\u001b[32m2023-12-14 17:28:28 INFO semantic_router.utils.logger Getting route layer...\u001b[0m\n"
+      "\u001b[32m2023-12-15 11:41:54 INFO semantic_router.utils.logger Extracting parameters...\u001b[0m\n",
+      "\u001b[32m2023-12-15 11:41:54 INFO semantic_router.utils.logger Calling Mistral model\u001b[0m\n",
+      "\u001b[32m2023-12-15 11:41:55 INFO semantic_router.utils.logger AI message: \n",
+      "    {\n",
+      "        'location': 'Stockholm'\n",
+      "    }\u001b[0m\n",
+      "\u001b[32m2023-12-15 11:41:55 INFO semantic_router.utils.logger Extracted parameters: {'location': 'Stockholm'}\u001b[0m\n"
      ]
     },
     {
      "name": "stdout",
      "output_type": "stream",
      "text": [
-      "AI message: {\"name\": \"get_time\", \"utterances\": [\"What is the time in SF?\", \"What is the current time in London?\", \"Time in Tokyo?\", \"Tell me the time in New York\", \"What is the time now in Paris?\"]}\n",
-      "Config: [{'name': 'get_time', 'utterances': ['What is the time in SF?', 'What is the current time in London?', 'Time in Tokyo?', 'Tell me the time in New York', 'What is the time now in Paris?']}]\n",
-      "Route: {'name': 'get_time', 'utterances': ['What is the time in SF?', 'What is the current time in London?', 'Time in Tokyo?', 'Tell me the time in New York', 'What is the time now in Paris?']}\n",
-      "None What is the weather like in Barcelona?\n"
+      "parameters: {'location': 'Stockholm'}\n",
+      "Calling `get_time` function with location: Stockholm\n"
      ]
     },
     {
      "name": "stderr",
      "output_type": "stream",
      "text": [
-      "\u001b[32m2023-12-14 17:28:29 INFO semantic_router.utils.logger Extracting parameters...\u001b[0m\n"
+      "\u001b[32m2023-12-15 11:41:55 INFO semantic_router.utils.logger Extracting parameters...\u001b[0m\n",
+      "\u001b[32m2023-12-15 11:41:55 INFO semantic_router.utils.logger Calling Mistral model\u001b[0m\n",
+      "\u001b[32m2023-12-15 11:41:56 INFO semantic_router.utils.logger AI message: \n",
+      "    {\n",
+      "        'category': 'tech',\n",
+      "        'country': 'Lithuania'\n",
+      "    }\u001b[0m\n",
+      "\u001b[32m2023-12-15 11:41:56 INFO semantic_router.utils.logger Extracted parameters: {'category': 'tech', 'country': 'Lithuania'}\u001b[0m\n"
      ]
     },
     {
      "name": "stdout",
      "output_type": "stream",
      "text": [
-      "get_time What time is it in Taiwan?\n"
+      "parameters: {'category': 'tech', 'country': 'Lithuania'}\n",
+      "Calling `get_news` function with category: tech and country: Lithuania\n"
      ]
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "\u001b[33m2023-12-15 11:41:57 WARNING semantic_router.utils.logger No function found\u001b[0m\n",
+      "\u001b[32m2023-12-15 11:41:57 INFO semantic_router.utils.logger Calling Mistral model\u001b[0m\n",
+      "\u001b[32m2023-12-15 11:41:57 INFO semantic_router.utils.logger AI message: How can I help you today?\u001b[0m\n"
+     ]
+    },
+    {
+     "data": {
+      "text/plain": [
+       "' How can I help you today?'"
+      ]
+     },
+     "execution_count": 220,
+     "metadata": {},
+     "output_type": "execute_result"
     }
    ],
    "source": [
-    "from pydantic import BaseModel\n",
-    "\n",
-    "class GetTimeSchema(BaseModel):\n",
-    "    location: str\n",
-    "\n",
-    "    class Config:\n",
-    "        name = \"get_time\"\n",
-    "\n",
-    "get_time_schema = GetTimeSchema.schema()\n",
-    "\n",
-    "def get_time(location: str) -> str:\n",
-    "    # Validate parameters\n",
-    "    GetTimeSchema(location=location)\n",
-    "\n",
-    "    print(f\"Calling get_time function with location: {location}\")\n",
-    "    return \"get_time\"\n",
- "\n", - "\n", - "route_config = generate_config(get_time_schema)\n", - "route_layer = get_route_layer([route_config])\n", - "\n", - "queries = [\n", - " \"What is the weather like in Barcelona?\",\n", - " \"What time is it in Taiwan?\",\n", - " \"What is happening in the world?\",\n", - " \"what is the time in Kaunas?\",\n", - " \"Im bored\",\n", - " \"I want to play a game\",\n", - " \"Banana\",\n", - "]\n", - "\n", - "# Calling functions\n", - "for query in queries:\n", - " function_name = route_layer(query)\n", - " print(function_name, query)\n", - "\n", - " if function_name == \"get_time\":\n", - " function_parameters = extract_parameters(query, get_time_schema)\n", - " try:\n", - " # Call the function\n", - " get_time(**function_parameters)\n", - " except ValueError as e:\n", - " logger.error(f\"Error: {e}\")" + "call(query=\"What is the time in Stockholm?\",\n", + " functions=tools,\n", + " router=router)\n", + "call(\n", + " query=\"What is the tech news in the Lithuania?\",\n", + " functions=tools,\n", + " router=router)\n", + "call(\n", + " query=\"Hi!\",\n", + " functions=tools,\n", + " router=router)\n" ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] } ], "metadata": {