From a1370a5f63d04bf7020f04bcdad4305cf3734590 Mon Sep 17 00:00:00 2001 From: Simonas <20096648+simjak@users.noreply.github.com> Date: Thu, 14 Dec 2023 15:50:59 +0200 Subject: [PATCH] working function calling with Mistral --- docs/examples/function_calling.ipynb | 211 +++++++++++++-------------- 1 file changed, 99 insertions(+), 112 deletions(-) diff --git a/docs/examples/function_calling.ipynb b/docs/examples/function_calling.ipynb index 88759cb3..1bdb8bf8 100644 --- a/docs/examples/function_calling.ipynb +++ b/docs/examples/function_calling.ipynb @@ -2,23 +2,30 @@ "cells": [ { "cell_type": "code", - "execution_count": 3, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# https://platform.openai.com/docs/guides/function-calling" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "## Define LLMs" + ] + }, { "cell_type": "code", - "execution_count": 58, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ + "# OpenAI\n", "import openai\n", "from semantic_router.utils.logger import logger\n", "\n", - "\n", "def llm_openai(prompt: str, model: str = \"gpt-4\") -> str:\n", " try:\n", " response = openai.chat.completions.create(\n", @@ -38,14 +45,57 @@ }, { "cell_type": "code", - "execution_count": 64, + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "# Mistral\n", + "import os\n", + "import requests\n", + "\n", + "# Docs https://huggingface.co/docs/transformers/main_classes/text_generation\n", + "HF_API_TOKEN = os.environ[\"HF_API_TOKEN\"]\n", + "\n", + "def llm_mistral(prompt: str) -> str:\n", + " api_url = \"https://z5t4cuhg21uxfmc3.us-east-1.aws.endpoints.huggingface.cloud/\"\n", + " headers = {\n", + " \"Authorization\": f\"Bearer {HF_API_TOKEN}\",\n", + " \"Content-Type\": \"application/json\",\n", + " }\n", + "\n", + " response = requests.post(\n", + " api_url,\n", + " headers=headers,\n", + " json={\n", + " \"inputs\": prompt,\n", + " \"parameters\": {\n", + " \"max_new_tokens\": 200,\n", + " \"temperature\": 0.2,\n", + " },\n", + " },\n", + " )\n", + " if response.status_code != 200:\n", + " raise Exception(\"Failed to call HuggingFace API\", response.text)\n", + "\n", + " return response.json()[0]['generated_text']" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "### Now we need to generate config from function specification with `GPT-4`" + ] + }, + { + "cell_type": "code", + "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import json\n", "from semantic_router.utils.logger import logger\n", "\n", - "\n", "def generate_config(specification: dict) -> dict:\n", " logger.info(\"Generating config...\")\n", " example_specification = (\n", @@ -113,9 +163,16 @@ " raise Exception(\"JSON parsing error\", json_error)" ] }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Extract function parameters using `Mistal` open-source model" + ] + }, { "cell_type": "code", - "execution_count": 65, + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -162,7 +219,9 @@ " EXTRACTED PARAMETERS:\n", " \"\"\"\n", "\n", - " ai_message = llm_openai(prompt)\n", + " # ai_message = llm_openai(prompt)\n", + " ai_message = llm_mistral(prompt)\n", + " print(ai_message)\n", "\n", " try:\n", " parameters = json.loads(ai_message)\n", @@ -174,7 +233,30 @@ }, { "cell_type": "code", - "execution_count": 66, + "execution_count": null, + "metadata": {}, + "outputs": [], + "source": [ + "def validate_parameters(function_parameters, specification):\n", + " required_params = specification[\"function\"][\"parameters\"][\"required\"]\n", + " missing_params = [\n", + " param for param in required_params if param not in function_parameters\n", + " ]\n", + " if missing_params:\n", + " raise ValueError(f\"Missing required parameters: {missing_params}\")\n", + " return True" + ] + }, + { + "cell_type": "markdown", + "metadata": {}, + "source": [ + "Set up the routing layer" + ] + }, + { + "cell_type": "code", + "execution_count": null, "metadata": {}, "outputs": [], "source": [ @@ -193,107 +275,23 @@ ] }, { - "cell_type": "code", - "execution_count": 73, + "cell_type": "markdown", "metadata": {}, - "outputs": [], "source": [ - "def validate_parameters(function_parameters, specification):\n", - " required_params = specification[\"function\"][\"parameters\"][\"required\"]\n", - " missing_params = [\n", - " param for param in required_params if param not in function_parameters\n", - " ]\n", - " if missing_params:\n", - " raise ValueError(f\"Missing required parameters: {missing_params}\")\n", - " return True" + "### Workflow" ] }, { "cell_type": "code", - "execution_count": 74, + "execution_count": null, "metadata": {}, - "outputs": [ - { - "name": "stderr", - "output_type": "stream", - "text": [ - "\u001b[32m2023-12-14 13:16:49 INFO semantic_router.utils.logger Generating config...\u001b[0m\n", - "\u001b[32m2023-12-14 13:16:54 INFO semantic_router.utils.logger AI message: {\"name\": \"get_time\", \"utterances\": [\"What is the current time in London?\", \"Tell me the time in New York\", \"What's happening now in Paris?\", \"time in San Francisco?\", \"Tell me the time in Sydney\"]}\u001b[0m\n", - "\u001b[32m2023-12-14 13:16:54 INFO semantic_router.utils.logger Generated config: {'name': 'get_time', 'utterances': ['What is the current time in London?', 'Tell me the time in New York', \"What's happening now in Paris?\", 'time in San Francisco?', 'Tell me the time in Sydney', 'Get the current time']}\u001b[0m\n", - "\u001b[32m2023-12-14 13:16:54 INFO semantic_router.utils.logger Getting route layer...\u001b[0m\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Getting function name for queries:\n", - "\n", - "What is the weather like in Barcelona? None {}\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "\u001b[32m2023-12-14 13:16:55 INFO semantic_router.utils.logger Extracting parameters...\u001b[0m\n", - "\u001b[32m2023-12-14 13:16:56 INFO semantic_router.utils.logger AI message: {\"location\": \"Taiwan\"}\u001b[0m\n", - "\u001b[32m2023-12-14 13:16:56 INFO semantic_router.utils.logger Extracted parameters: {'location': 'Taiwan'}\u001b[0m\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "What time is it in Taiwan? get_time {'location': 'Taiwan'}\n", - "Calling get_time function with location: Taiwan\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "\u001b[32m2023-12-14 13:16:56 INFO semantic_router.utils.logger Extracting parameters...\u001b[0m\n", - "\u001b[32m2023-12-14 13:16:58 INFO semantic_router.utils.logger AI message: {\"location\": \"the world\"}\u001b[0m\n", - "\u001b[32m2023-12-14 13:16:58 INFO semantic_router.utils.logger Extracted parameters: {'location': 'the world'}\u001b[0m\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "What is happening in the world? get_time {'location': 'the world'}\n", - "Calling get_time function with location: the world\n" - ] - }, - { - "name": "stderr", - "output_type": "stream", - "text": [ - "\u001b[32m2023-12-14 13:16:58 INFO semantic_router.utils.logger Extracting parameters...\u001b[0m\n", - "\u001b[32m2023-12-14 13:17:00 INFO semantic_router.utils.logger AI message: {\"location\": \"Kaunas\"}\u001b[0m\n", - "\u001b[32m2023-12-14 13:17:00 INFO semantic_router.utils.logger Extracted parameters: {'location': 'Kaunas'}\u001b[0m\n" - ] - }, - { - "name": "stdout", - "output_type": "stream", - "text": [ - "what is the time in Kaunas? get_time {'location': 'Kaunas'}\n", - "Calling get_time function with location: Kaunas\n", - "Im bored None {}\n", - "I want to play a game None {}\n", - "Banana None {}\n" - ] - } - ], + "outputs": [], "source": [ "def get_time(location: str) -> str:\n", " print(f\"Calling get_time function with location: {location}\")\n", " return \"get_time\"\n", "\n", - "\n", - "specification = {\n", + "get_time_spec = {\n", " \"type\": \"function\",\n", " \"function\": {\n", " \"name\": \"get_time\",\n", @@ -311,7 +309,7 @@ " },\n", "}\n", "\n", - "route_config = generate_config(specification)\n", + "route_config = generate_config(get_time_spec)\n", "route_layer = get_route_layer([route_config])\n", "\n", "queries = [\n", @@ -324,29 +322,18 @@ " \"Banana\",\n", "]\n", "\n", - "print(\"Getting function name for queries:\\n\")\n", + "# Calling functions\n", "for query in queries:\n", " function_name = route_layer(query)\n", "\n", - " function_parameters = {}\n", - " if function_name:\n", - " function_parameters = extract_parameters(query, specification)\n", - " print(query, function_name, function_parameters)\n", - "\n", " if function_name == \"get_time\":\n", + " function_parameters = extract_parameters(query, get_time_spec)\n", " try:\n", - " if validate_parameters(function_parameters, specification):\n", + " if validate_parameters(function_parameters, get_time_spec):\n", " get_time(**function_parameters)\n", " except ValueError as e:\n", " logger.error(f\"Error: {e}\")" ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] } ], "metadata": { -- GitLab