From 87f57abf8045c2f44e9fddf4a44e0ed2f83ca063 Mon Sep 17 00:00:00 2001 From: Simonas <20096648+simjak@users.noreply.github.com> Date: Thu, 14 Dec 2023 17:29:45 +0200 Subject: [PATCH] wip: function call with pydantic schema --- docs/examples/function_calling.ipynb | 253 ++++++++++----------------- 1 file changed, 90 insertions(+), 163 deletions(-) diff --git a/docs/examples/function_calling.ipynb b/docs/examples/function_calling.ipynb index 2fa202a1..502c0ae4 100644 --- a/docs/examples/function_calling.ipynb +++ b/docs/examples/function_calling.ipynb @@ -37,7 +37,7 @@ }, { "cell_type": "code", - "execution_count": 74, + "execution_count": 102, "metadata": {}, "outputs": [], "source": [ @@ -61,7 +61,7 @@ " json={\n", " \"inputs\": prompt,\n", " \"parameters\": {\n", - " \"max_new_tokens\": 200,\n", + " \"max_new_tokens\": 1000,\n", " \"temperature\": 0.2,\n", " },\n", " },\n", @@ -85,41 +85,25 @@ }, { "cell_type": "code", - "execution_count": 75, + "execution_count": 134, "metadata": {}, "outputs": [], "source": [ "import json\n", - "from semantic_router.utils.logger import logger\n", "\n", + "from pydantic import BaseModel\n", + "from semantic_router.utils.logger import logger\n", "\n", - "def generate_config(specification: dict) -> dict:\n", + "def generate_config(schema: dict) -> dict:\n", " logger.info(\"Generating config...\")\n", - " example_specification = (\n", - " {\n", - " \"type\": \"function\",\n", - " \"function\": {\n", - " \"name\": \"get_current_weather\",\n", - " \"description\": \"Get the current weather\",\n", - " \"parameters\": {\n", - " \"type\": \"object\",\n", - " \"properties\": {\n", - " \"location\": {\n", - " \"type\": \"string\",\n", - " \"description\": \"The city and state, e.g. San Francisco, CA\",\n", - " },\n", - " \"format\": {\n", - " \"type\": \"string\",\n", - " \"enum\": [\"celsius\", \"fahrenheit\"],\n", - " \"description\": \"The temperature unit to use. Infer this \"\n", - " \" from the users location.\",\n", - " },\n", - " },\n", - " \"required\": [\"location\", \"format\"],\n", - " },\n", - " },\n", - " },\n", - " )\n", + "\n", + " class GetWeatherSchema(BaseModel):\n", + " location: str\n", + "\n", + " class Config:\n", + " name = \"get_weather\"\n", + "\n", + " example_schema = GetWeatherSchema.schema()\n", "\n", " example_config = {\n", " \"name\": \"get_weather\",\n", @@ -133,34 +117,30 @@ " }\n", "\n", " prompt = f\"\"\"\n", - " Given the following specification, generate a config ONLY in a valid JSON format.\n", - " Example:\n", - " SPECIFICATION:\n", - " {example_specification}\n", - "\n", - " CONFIG:\n", - " {example_config}\n", + " Given the following Pydantic function schema,\n", + " generate a config ONLY in a valid JSON format.\n", + " For example:\n", + " SCHEMA: {example_schema}\n", + " CONFIG: {example_config}\n", "\n", - " GIVEN SPECIFICATION:\n", - " {specification}\n", "\n", - " GENERATED CONFIG:\n", + " GIVEN SCHEMA: {schema}\n", + " GENERATED CONFIG: <generated_response_in_json>\n", " \"\"\"\n", "\n", - " # ai_message = llm_openai(prompt)\n", - " ai_message = llm_mistral(prompt)\n", + " ai_message = llm_openai(prompt)\n", + " print(f\"AI message: {ai_message}\")\n", "\n", - " # Mistral parsing\n", + " # Parsing for Mistral model\n", " ai_message = ai_message.replace(\"CONFIG:\", \"\").replace(\"'\", '\"').strip()\n", "\n", " try:\n", " route_config = json.loads(ai_message)\n", - " function_description = specification[\"function\"][\"description\"]\n", - " route_config[\"utterances\"].append(function_description)\n", " logger.info(f\"Generated config: {route_config}\")\n", " return route_config\n", " except json.JSONDecodeError as json_error:\n", " logger.error(f\"JSON parsing error {json_error}\")\n", + " print(f\"AI message: {ai_message}\")\n", " return {}" ] }, @@ -168,60 +148,49 @@ "cell_type": "markdown", "metadata": {}, "source": [ - "Extract function parameters using `Mistal` open-source model" + "Extract function parameters using `Mistral` open-source model" ] }, { "cell_type": "code", - "execution_count": 76, + "execution_count": 138, "metadata": {}, "outputs": [], "source": [ - "def extract_parameters(query: str, specification: dict) -> dict:\n", + "def extract_parameters(query: str, schema: dict) -> dict:\n", " logger.info(\"Extracting parameters...\")\n", " example_query = \"what is the weather in London?\"\n", "\n", - " example_specification = {\n", - " \"type\": \"function\",\n", - " \"function\": {\n", - " \"name\": \"get_time\",\n", - " \"description\": \"Get the current time\",\n", - " \"parameters\": {\n", - " \"type\": \"object\",\n", - " \"properties\": {\n", - " \"location\": {\n", - " \"type\": \"string\",\n", - " \"description\": \"Example of city and state\",\n", - " },\n", - " },\n", - " \"required\": [\"location\"],\n", - " },\n", - " },\n", - " }\n", + " class GetWeatherSchema(BaseModel):\n", + " location: str\n", + "\n", + " class Config:\n", + " name = \"get_weather\"\n", + "\n", + " example_schema = GetWeatherSchema.schema()\n", "\n", " example_parameters = {\n", " \"location\": \"London\",\n", " }\n", "\n", " prompt = f\"\"\"\n", - " Given the following specification and query, extract the parameters from the query,\n", - " in a valid JSON format enclosed in double quotes.\n", + " Given the following function schema and query, extract the parameters from the\n", + " query, in a valid JSON format.\n", " Example:\n", - " SPECIFICATION:\n", - " {example_specification}\n", + " SCHEMA:\n", + " {example_schema}\n", " QUERY:\n", " {example_query}\n", " PARAMETERS:\n", " {example_parameters}\n", - " GIVEN SPECIFICATION:\n", - " {specification}\n", + " GIVEN SCHEMA:\n", + " {schema}\n", " GIVEN QUERY:\n", " {query}\n", " EXTRACTED PARAMETERS:\n", " \"\"\"\n", "\n", - " # ai_message = llm_openai(prompt)\n", - " ai_message = llm_mistral(prompt)\n", + " ai_message = llm_openai(prompt)\n", "\n", " ai_message = ai_message.replace(\"'\", '\"').strip()\n", "\n", @@ -234,22 +203,6 @@ " return {}" ] }, - { - "cell_type": "code", - "execution_count": 77, - "metadata": {}, - "outputs": [], - "source": [ - "def validate_parameters(function_parameters, specification):\n", - " required_params = specification[\"function\"][\"parameters\"][\"required\"]\n", - " missing_params = [\n", - " param for param in required_params if param not in function_parameters\n", - " ]\n", - " if missing_params:\n", - " raise ValueError(f\"Missing required parameters: {missing_params}\")\n", - " return True" - ] - }, { "cell_type": "markdown", "metadata": {}, @@ -259,7 +212,7 @@ }, { "cell_type": "code", - "execution_count": 63, + "execution_count": 139, "metadata": {}, "outputs": [], "source": [ @@ -268,12 +221,20 @@ "from semantic_router.layer import RouteLayer\n", "from semantic_router.utils.logger import logger\n", "\n", + "\n", "def get_route_layer(config: list[dict]) -> RouteLayer:\n", " logger.info(\"Getting route layer...\")\n", " encoder = CohereEncoder()\n", - " routes = [\n", - " Route(name=route[\"name\"], utterances=route[\"utterances\"]) for route in config\n", - " ]\n", + "\n", + " routes = []\n", + " print(f\"Config: {config}\")\n", + " for route in config:\n", + " if \"name\" in route and \"utterances\" in route:\n", + " print(f\"Route: {route}\")\n", + " routes.append(Route(name=route[\"name\"], utterances=route[\"utterances\"]))\n", + " else:\n", + " logger.warning(f\"Misconfigured route: {route}\")\n", + "\n", " return RouteLayer(encoder=encoder, routes=routes)" ] }, @@ -286,103 +247,64 @@ }, { "cell_type": "code", - "execution_count": 72, + "execution_count": 140, "metadata": {}, "outputs": [ { "name": "stderr", "output_type": "stream", "text": [ - "\u001b[32m2023-12-14 16:22:59 INFO semantic_router.utils.logger Generating config...\u001b[0m\n", - "\u001b[32m2023-12-14 16:23:02 INFO semantic_router.utils.logger AI message: \n", - " {\n", - " \"name\": \"get_time\",\n", - " \"utterances\": [\n", - " \"What is the current time in New York?\",\n", - " \"Tell me the time in Los Angeles\",\n", - " \"What is the current time in Chicago?\",\n", - " \"The time in Houston?\",\n", - " \"What is the current time in Philadelphia?\"\n", - " ]\n", - " }\u001b[0m\n", - "\u001b[32m2023-12-14 16:23:02 INFO semantic_router.utils.logger Generated config: {'name': 'get_time', 'utterances': ['What is the current time in New York?', 'Tell me the time in Los Angeles', 'What is the current time in Chicago?', 'The time in Houston?', 'What is the current time in Philadelphia?', 'Get the current time']}\u001b[0m\n", - "\u001b[32m2023-12-14 16:23:02 INFO semantic_router.utils.logger Getting route layer...\u001b[0m\n", - "\u001b[32m2023-12-14 16:23:03 INFO semantic_router.utils.logger Extracting parameters...\u001b[0m\n", - "\u001b[32m2023-12-14 16:23:04 INFO semantic_router.utils.logger AI message: \n", - " {\"location\": \"Taiwan\"}\u001b[0m\n", - "\u001b[32m2023-12-14 16:23:04 INFO semantic_router.utils.logger Extracted parameters: {'location': 'Taiwan'}\u001b[0m\n" + "\u001b[32m2023-12-14 17:28:22 INFO semantic_router.utils.logger Generating config...\u001b[0m\n", + "\u001b[32m2023-12-14 17:28:28 INFO semantic_router.utils.logger AI message: {\"name\": \"get_time\", \"utterances\": [\"What is the time in SF?\", \"What is the current time in London?\", \"Time in Tokyo?\", \"Tell me the time in New York\", \"What is the time now in Paris?\"]}\u001b[0m\n", + "\u001b[32m2023-12-14 17:28:28 INFO semantic_router.utils.logger Generated config: {'name': 'get_time', 'utterances': ['What is the time in SF?', 'What is the current time in London?', 'Time in Tokyo?', 'Tell me the time in New York', 'What is the time now in Paris?']}\u001b[0m\n", + "\u001b[32m2023-12-14 17:28:28 INFO semantic_router.utils.logger Getting route layer...\u001b[0m\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "Calling get_time function with location: Taiwan\n" + "AI message: {\"name\": \"get_time\", \"utterances\": [\"What is the time in SF?\", \"What is the current time in London?\", \"Time in Tokyo?\", \"Tell me the time in New York\", \"What is the time now in Paris?\"]}\n", + "Config: [{'name': 'get_time', 'utterances': ['What is the time in SF?', 'What is the current time in London?', 'Time in Tokyo?', 'Tell me the time in New York', 'What is the time now in Paris?']}]\n", + "Route: {'name': 'get_time', 'utterances': ['What is the time in SF?', 'What is the current time in London?', 'Time in Tokyo?', 'Tell me the time in New York', 'What is the time now in Paris?']}\n", + "None What is the weather like in Barcelona?\n" ] }, { "name": "stderr", "output_type": "stream", "text": [ - "\u001b[32m2023-12-14 16:23:04 INFO semantic_router.utils.logger Extracting parameters...\u001b[0m\n", - "\u001b[32m2023-12-14 16:23:05 INFO semantic_router.utils.logger AI message: \n", - " {\n", - " \"location\": \"London\"\n", - " }\u001b[0m\n", - "\u001b[32m2023-12-14 16:23:05 INFO semantic_router.utils.logger Extracted parameters: {'location': 'London'}\u001b[0m\n" + "\u001b[32m2023-12-14 17:28:29 INFO semantic_router.utils.logger Extracting parameters...\u001b[0m\n" ] }, { "name": "stdout", "output_type": "stream", "text": [ - "Calling get_time function with location: London\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "\u001b[32m2023-12-14 16:23:06 INFO semantic_router.utils.logger Extracting parameters...\u001b[0m\n", - "\u001b[32m2023-12-14 16:23:07 INFO semantic_router.utils.logger AI message: \n", - " {\n", - " \"location\": \"Kaunas\"\n", - " }\u001b[0m\n", - "\u001b[32m2023-12-14 16:23:07 INFO semantic_router.utils.logger Extracted parameters: {'location': 'Kaunas'}\u001b[0m\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Calling get_time function with location: Kaunas\n" + "get_time What time is it in Taiwan?\n" ] } ], "source": [ + "from pydantic import BaseModel\n", + "\n", + "class GetTimeSchema(BaseModel):\n", + " location: str\n", + "\n", + " class Config:\n", + " name = \"get_time\"\n", + "\n", + "get_time_schema = GetTimeSchema.schema()\n", + "\n", "def get_time(location: str) -> str:\n", + " # Validate parameters\n", + " GetTimeSchema(location=location)\n", + "\n", " print(f\"Calling get_time function with location: {location}\")\n", " return \"get_time\"\n", "\n", - "get_time_spec = {\n", - " \"type\": \"function\",\n", - " \"function\": {\n", - " \"name\": \"get_time\",\n", - " \"description\": \"Get the current time\",\n", - " \"parameters\": {\n", - " \"type\": \"object\",\n", - " \"properties\": {\n", - " \"location\": {\n", - " \"type\": \"string\",\n", - " \"description\": \"The city and state\",\n", - " },\n", - " },\n", - " \"required\": [\"location\"],\n", - " },\n", - " },\n", - "}\n", "\n", - "route_config = generate_config(get_time_spec)\n", + "route_config = generate_config(get_time_schema)\n", "route_layer = get_route_layer([route_config])\n", "\n", "queries = [\n", @@ -398,18 +320,23 @@ "# Calling functions\n", "for query in queries:\n", " function_name = route_layer(query)\n", + " print(function_name, query)\n", "\n", " if function_name == \"get_time\":\n", - " function_parameters = extract_parameters(query, get_time_spec)\n", + " function_parameters = extract_parameters(query, get_time_schema)\n", " try:\n", - " if validate_parameters(function_parameters, get_time_spec):\n", - "\n", - " # Call the function\n", - " get_time(**function_parameters)\n", - "\n", + " # Call the function\n", + " get_time(**function_parameters)\n", " except ValueError as e:\n", " logger.error(f\"Error: {e}\")" ] + }, + { + "cell_type": "code", + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [] } ], "metadata": { -- GitLab