diff --git a/docs/examples/function_calling.ipynb b/docs/examples/function_calling.ipynb
index 9d992f204a01ab15af97fbe12fb8ff35176c9a73..2fa202a1e941757462a2a7e9a09027be61a4d522 100644
--- a/docs/examples/function_calling.ipynb
+++ b/docs/examples/function_calling.ipynb
@@ -1,14 +1,5 @@
 {
  "cells": [
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": [
-    "# https://platform.openai.com/docs/guides/function-calling"
-   ]
-  },
   {
    "cell_type": "markdown",
    "metadata": {},
@@ -18,7 +9,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 73,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -26,6 +17,7 @@
    "import openai\n",
    "from semantic_router.utils.logger import logger\n",
    "\n",
+    "# Docs: https://platform.openai.com/docs/guides/function-calling\n",
    "def llm_openai(prompt: str, model: str = \"gpt-4\") -> str:\n",
    "    try:\n",
    "        response = openai.chat.completions.create(\n",
@@ -45,7 +37,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 74,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -77,25 +69,30 @@
    "    if response.status_code != 200:\n",
    "        raise Exception(\"Failed to call HuggingFace API\", response.text)\n",
    "\n",
-    "    return response.json()[0]['generated_text']"
+    "    ai_message = response.json()[0]['generated_text']\n",
+    "    if not ai_message:\n",
+    "        raise Exception(\"AI message is empty\", ai_message)\n",
+    "    logger.info(f\"AI message: {ai_message}\")\n",
+    "    return ai_message"
    ]
   },
   {
    "cell_type": "markdown",
    "metadata": {},
    "source": [
-    "### Now we need to generate config from function specification with `GPT-4`"
+    "### Now we need to generate a config from the function specification using an LLM"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 75,
    "metadata": {},
    "outputs": [],
    "source": [
    "import json\n",
    "from semantic_router.utils.logger import logger\n",
    "\n",
+    "\n",
    "def generate_config(specification: dict) -> dict:\n",
    "    logger.info(\"Generating config...\")\n",
    "    example_specification = (\n",
@@ -136,8 +133,7 @@
    "    }\n",
    "\n",
    "    prompt = f\"\"\"\n",
-    "    Given the following specification, generate a config in a valid JSON format\n",
-    "    enclosed in double quotes,\n",
+    "    Given the following specification, generate ONLY a config in valid JSON format.\n",
    "    Example:\n",
    "    SPECIFICATION:\n",
    "    {example_specification}\n",
@@ -151,7 +147,11 @@
    "    GENERATED CONFIG:\n",
    "    \"\"\"\n",
    "\n",
-    "    ai_message = llm_openai(prompt)\n",
+    "    # ai_message = llm_openai(prompt)\n",
+    "    ai_message = llm_mistral(prompt)\n",
+    "\n",
+    "    # Clean up the Mistral output so it parses as JSON\n",
+    "    ai_message = ai_message.replace(\"CONFIG:\", \"\").replace(\"'\", '\"').strip()\n",
    "\n",
    "    try:\n",
    "        route_config = json.loads(ai_message)\n",
@@ -160,7 +160,8 @@
    "        logger.info(f\"Generated config: {route_config}\")\n",
    "        return route_config\n",
    "    except json.JSONDecodeError as json_error:\n",
-    "        raise Exception(\"JSON parsing error\", json_error)"
+    "        logger.error(f\"JSON parsing error: {json_error}\")\n",
+    "        return {}"
    ]
   },
   {
@@ -172,7 +173,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 76,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -221,19 +222,21 @@
    "\n",
    "    # ai_message = llm_openai(prompt)\n",
    "    ai_message = llm_mistral(prompt)\n",
-    "    print(ai_message)\n",
+    "    # Normalize quotes so the extracted parameters parse as JSON\n",
+    "    ai_message = ai_message.replace(\"'\", '\"').strip()\n",
    "\n",
    "    try:\n",
    "        parameters = json.loads(ai_message)\n",
    "        logger.info(f\"Extracted parameters: {parameters}\")\n",
    "        return parameters\n",
    "    except json.JSONDecodeError as json_error:\n",
-    "        raise Exception(\"JSON parsing error\", json_error)"
+    "        logger.error(f\"JSON parsing error: {json_error}\")\n",
+    "        return {}"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 77,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -256,7 +259,7 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 63,
    "metadata": {},
    "outputs": [],
    "source": [
@@ -283,9 +286,79 @@
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 72,
    "metadata": {},
-   "outputs": [],
+   "outputs": [
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "\u001b[32m2023-12-14 16:22:59 INFO semantic_router.utils.logger Generating config...\u001b[0m\n",
+      "\u001b[32m2023-12-14 16:23:02 INFO semantic_router.utils.logger AI message: \n",
+      "    {\n",
+      "    \"name\": \"get_time\",\n",
+      "    \"utterances\": [\n",
+      "    \"What is the current time in New York?\",\n",
+      "    \"Tell me the time in Los Angeles\",\n",
+      "    \"What is the current time in Chicago?\",\n",
+      "    \"The time in Houston?\",\n",
+      "    \"What is the current time in Philadelphia?\"\n",
+      "    ]\n",
+      "    }\u001b[0m\n",
+      "\u001b[32m2023-12-14 16:23:02 INFO semantic_router.utils.logger Generated config: {'name': 'get_time', 'utterances': ['What is the current time in New York?', 'Tell me the time in Los Angeles', 'What is the current time in Chicago?', 'The time in Houston?', 'What is the current time in Philadelphia?', 'Get the current time']}\u001b[0m\n",
+      "\u001b[32m2023-12-14 16:23:02 INFO semantic_router.utils.logger Getting route layer...\u001b[0m\n",
+      "\u001b[32m2023-12-14 16:23:03 INFO semantic_router.utils.logger Extracting parameters...\u001b[0m\n",
+      "\u001b[32m2023-12-14 16:23:04 INFO semantic_router.utils.logger AI message: \n",
+      "    {\"location\": \"Taiwan\"}\u001b[0m\n",
+      "\u001b[32m2023-12-14 16:23:04 INFO semantic_router.utils.logger Extracted parameters: {'location': 'Taiwan'}\u001b[0m\n"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Calling get_time function with location: Taiwan\n"
+     ]
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "\u001b[32m2023-12-14 16:23:04 INFO semantic_router.utils.logger Extracting parameters...\u001b[0m\n",
+      "\u001b[32m2023-12-14 16:23:05 INFO semantic_router.utils.logger AI message: \n",
+      "    {\n",
+      "    \"location\": \"London\"\n",
+      "    }\u001b[0m\n",
+      "\u001b[32m2023-12-14 16:23:05 INFO semantic_router.utils.logger Extracted parameters: {'location': 'London'}\u001b[0m\n"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Calling get_time function with location: London\n"
+     ]
+    },
+    {
+     "name": "stderr",
+     "output_type": "stream",
+     "text": [
+      "\u001b[32m2023-12-14 16:23:06 INFO semantic_router.utils.logger Extracting parameters...\u001b[0m\n",
+      "\u001b[32m2023-12-14 16:23:07 INFO semantic_router.utils.logger AI message: \n",
+      "    {\n",
+      "    \"location\": \"Kaunas\"\n",
+      "    }\u001b[0m\n",
+      "\u001b[32m2023-12-14 16:23:07 INFO semantic_router.utils.logger Extracted parameters: {'location': 'Kaunas'}\u001b[0m\n"
+     ]
+    },
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Calling get_time function with location: Kaunas\n"
+     ]
+    }
+   ],
    "source": [
    "def get_time(location: str) -> str:\n",
    "    print(f\"Calling get_time function with location: {location}\")\n",
@@ -333,7 +406,7 @@
    "\n",
    "        # Call the function\n",
    "        get_time(**function_parameters)\n",
-    "        \n",
+    "\n",
    "    except ValueError as e:\n",
    "        logger.error(f\"Error: {e}\")"
    ]
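
Note on the two quote-normalisation hacks added above (the .replace("'", '"') calls in generate_config and in the parameter extractor): swapping every single quote for a double quote breaks as soon as the model's output contains an apostrophe (e.g. "What's the time in London?"). Below is a minimal, stdlib-only sketch of a more defensive parser; parse_llm_json is a hypothetical helper shown for illustration, not part of this patch.

    import ast
    import json
    import re

    def parse_llm_json(ai_message: str) -> dict:
        # Best-effort extraction of a JSON object from raw LLM output.
        # Keep only the outermost {...} block, dropping prose such as "CONFIG:".
        match = re.search(r"\{.*\}", ai_message, re.DOTALL)
        if not match:
            return {}
        candidate = match.group(0)
        try:
            return json.loads(candidate)  # strict JSON first
        except json.JSONDecodeError:
            try:
                # Fall back to Python literal syntax (single-quoted strings).
                result = ast.literal_eval(candidate)
                return result if isinstance(result, dict) else {}
            except (ValueError, SyntaxError):
                return {}

On failure it returns an empty dict, matching the logger.error / return {} behaviour this patch introduces in both functions.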