From 418a0c8bb1ac5103c8c26190082a7d5890eaf7ac Mon Sep 17 00:00:00 2001
From: Simonas <20096648+simjak@users.noreply.github.com>
Date: Thu, 14 Dec 2023 11:55:36 +0200
Subject: [PATCH] WIP: config generation

---
 docs/examples/function_calling.ipynb | 314 ++++++++++-----------------
 1 file changed, 117 insertions(+), 197 deletions(-)

diff --git a/docs/examples/function_calling.ipynb b/docs/examples/function_calling.ipynb
index 8e65e71e..ef61ca64 100644
--- a/docs/examples/function_calling.ipynb
+++ b/docs/examples/function_calling.ipynb
@@ -6,260 +6,180 @@
    "metadata": {},
    "outputs": [],
    "source": [
-    "# https://platform.openai.com/docs/guides/function-calling\n"
+    "# https://platform.openai.com/docs/guides/function-calling"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 6,
+   "execution_count": 21,
    "metadata": {},
-   "outputs": [
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "get_weather\n",
-      "get_time\n",
-      "get_news\n"
-     ]
-    }
-   ],
+   "outputs": [],
    "source": [
-    "from semantic_router.schema import Route\n",
-    "\n",
-    "from semantic_router.encoders import CohereEncoder\n",
-    "from semantic_router.layer import RouteLayer\n",
-    "\n",
-    "encoder = CohereEncoder()\n",
+    "import json\n",
+    "import openai\n",
+    "\n",
+    "\n",
+    "def generate_config(specification: dict) -> dict:\n",
+    "    \"\"\"Ask GPT-4 to turn an OpenAI function `specification` into a route config dict.\n",
+    "\n",
+    "    Raises Exception if the request fails or the reply is missing / not valid JSON.\n",
+    "    \"\"\"\n",
+    "    print(\"Generating config...\")\n",
+    "    example_specification = {\n",
+    "        \"type\": \"function\",\n",
+    "        \"function\": {\n",
+    "            \"name\": \"get_current_weather\",\n",
+    "            \"description\": \"Get the current weather\",\n",
+    "            \"parameters\": {\n",
+    "                \"type\": \"object\",\n",
+    "                \"properties\": {\n",
+    "                    \"location\": {\n",
+    "                        \"type\": \"string\",\n",
+    "                        \"description\": \"The city and state, e.g. San Francisco, CA\",\n",
+    "                    },\n",
+    "                    \"format\": {\n",
+    "                        \"type\": \"string\",\n",
+    "                        \"enum\": [\"celsius\", \"fahrenheit\"],\n",
+    "                        \"description\": \"The temperature unit to use. Infer this \"\n",
+    "                        \"from the users location.\",\n",
+    "                    },\n",
+    "                },\n",
+    "                \"required\": [\"location\", \"format\"],\n",
+    "            },\n",
+    "        },\n",
+    "    }\n",
     "\n",
-    "config = [\n",
-    "    {\n",
+    "    example_config = {\n",
     "        \"name\": \"get_weather\",\n",
     "        \"utterances\": [\n",
     "            \"What is the weather like in SF?\",\n",
     "            \"What is the weather in Cyprus?\",\n",
     "            \"weather in London?\",\n",
+    "            \"Tell me the weather in New York\",\n",
+    "            \"what is the current weather in Paris?\",\n",
     "        ],\n",
-    "    },\n",
-    "    {\n",
-    "        \"name\": \"get_time\",\n",
-    "        \"utterances\": [\n",
-    "            \"What time is it in New York?\",\n",
-    "            \"What time is it in London?\",\n",
-    "            \"What is the time in Paris?\",\n",
-    "        ],\n",
-    "    },\n",
-    "    {\n",
-    "        \"name\": \"get_news\",\n",
-    "        \"utterances\": [\n",
-    "            \"What is happening in the world?\",\n",
-    "            \"What is the latest news?\",\n",
-    "            \"What is the latest news in the US?\",\n",
-    "        ],\n",
-    "    },\n",
-    "]\n",
+    "    }\n",
     "\n",
-    "routes = [Route(name=route[\"name\"], utterances=route[\"utterances\"]) for route in config]\n",
+    "    prompt = f\"\"\"\n",
+    "    Given the following specification, generate a config in a valid JSON format,\n",
+    "    Example:\n",
+    "    SPECIFICATION:\n",
+    "    {json.dumps(example_specification)}\n",
     "\n",
-    "route_layer = RouteLayer(encoder=encoder, routes=routes)\n",
+    "    CONFIG:\n",
+    "    {json.dumps(example_config)}\n",
     "\n",
-    "queries = [\n",
-    "    \"What is the weather like in Barcelona?\",\n",
-    "    \"What time is it in Taiwan?\",\n",
-    "    \"What is happening in the world?\",\n",
-    "]\n",
+    "    GIVEN SPECIFICATION:\n",
+    "    {json.dumps(specification)}\n",
     "\n",
-    "for query in queries:\n",
-    "    function_name = route_layer(query)\n",
-    "    print(function_name)"
+    "    GENERATED CONFIG:\n",
+    "    \"\"\"\n",
+    "\n",
+    "    try:\n",
+    "        response = openai.chat.completions.create(\n",
+    "            model=\"gpt-4\",\n",
+    "            messages=[{\"role\": \"system\", \"content\": prompt}],\n",
+    "        )\n",
+    "        ai_message = response.choices[0].message.content\n",
+    "        print(\"AI message:\", ai_message)\n",
+    "        if ai_message is None:  # the API may return no text content\n",
+    "            raise ValueError(\"model returned no content\")\n",
+    "        return json.loads(ai_message)\n",
+    "    except json.JSONDecodeError as json_error:\n",
+    "        raise Exception(\"JSON parsing error\", json_error) from json_error\n",
+    "    except Exception as e:\n",
+    "        raise Exception(\"Error generating config from Openai\", e) from e"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": null,
+   "execution_count": 22,
    "metadata": {},
    "outputs": [],
    "source": [
-    "def get_weather(location: str):\n",
-    "    print(f\"getting weather for {location}\")\n",
-    "\n",
-    "\n",
-    "def extract_function_parameters(query: str, function: Callable):\n",
-    "    # llm(\n",
-    "    #     query=query,\n",
-    "    #     function=function,\n",
-    "    #     prompt=\"What are the parameters for this function?\",\n",
-    "    # )\n",
-    "    print(\"Extracting function parameters..\")\n",
-    "\n",
+    "from semantic_router.schema import Route\n",
+    "from semantic_router.encoders import CohereEncoder\n",
+    "from semantic_router.layer import RouteLayer\n",
     "\n",
-    "if category == \"get_weather\":\n",
-    "    print(f\"Category is `{category}`\")\n",
-    "    params = extract_function_parameters(query, get_weather)\n",
-    "    print(\"Getting weather..\")\n",
-    "    # get_weather(**params)"
+    "def get_route_layer(config: list[dict]) -> RouteLayer:\n",
+    "    \"\"\"Build a RouteLayer with one Route per `name`/`utterances` dict in `config`.\"\"\"\n",
+    "    print(\"Getting route layer...\")\n",
+    "    cohere_encoder = CohereEncoder()\n",
+    "    # Each config entry becomes exactly one Route.\n",
+    "    route_list = [Route(name=c[\"name\"], utterances=c[\"utterances\"]) for c in config]\n",
+    "    return RouteLayer(encoder=cohere_encoder, routes=route_list)"
    ]
   },
   {
    "cell_type": "code",
-   "execution_count": 10,
+   "execution_count": 25,
    "metadata": {},
    "outputs": [
     {
      "name": "stdout",
      "output_type": "stream",
      "text": [
-      "None\n"
+      "Generating config...\n",
+      "AI message: {\n",
+      "    \"name\": \"get_time\",\n",
+      "    \"utterances\": [\n",
+      "        \"What is the current time in SF?\",\n",
+      "        \"Tell me the time in London\",\n",
+      "        \"Could you tell me the time in New York?\",\n",
+      "        \"May I know the current time in Paris?\",\n",
+      "        \"Can you tell me what time is it in Singapore?\"\n",
+      "    ]\n",
+      "}\n",
+      "Getting route layer...\n",
+      "Getting function name for queries:\n",
+      "\n",
+      "(None, 'What is the weather like in Barcelona?')\n",
+      "('get_time', 'What time is it in Taiwan?')\n",
+      "(None, 'What is happening in the world?')\n",
+      "('get_time', 'what is the time in Kaunas?')\n",
+      "(None, 'Im bored')\n",
+      "(None, 'I want to play a game')\n",
+      "(None, 'Banana')\n"
      ]
     }
    ],
    "source": [
-    "print(generated_config)"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 21,
-   "metadata": {},
-   "outputs": [
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "Route config: {'name': 'get_time', 'utterances': ['What is the current time in San Francisco?', 'What time is it in New York?', 'Current time in London?']}\n"
-     ]
-    }
-   ],
-   "source": [
-    "import json\n",
-    "\n",
-    "example_specification = (\n",
-    "    {\n",
-    "        \"type\": \"function\",\n",
-    "        \"function\": {\n",
-    "            \"name\": \"get_current_weather\",\n",
-    "            \"description\": \"Get the current weather\",\n",
-    "            \"parameters\": {\n",
-    "                \"type\": \"object\",\n",
-    "                \"properties\": {\n",
-    "                    \"location\": {\n",
-    "                        \"type\": \"string\",\n",
-    "                        \"description\": \"The city and state, e.g. San Francisco, CA\",\n",
-    "                    },\n",
-    "                    \"format\": {\n",
-    "                        \"type\": \"string\",\n",
-    "                        \"enum\": [\"celsius\", \"fahrenheit\"],\n",
-    "                        \"description\": \"The temperature unit to use. Infer this from the users location.\",\n",
-    "                    },\n",
+    "specification = {\n",
+    "    \"type\": \"function\",\n",
+    "    \"function\": {\n",
+    "        \"name\": \"get_time\",\n",
+    "        \"description\": \"Get the current time\",\n",
+    "        \"parameters\": {\n",
+    "            \"type\": \"object\",\n",
+    "            \"properties\": {\n",
+    "                \"location\": {\n",
+    "                    \"type\": \"string\",\n",
+    "                    \"description\": \"The city and state\",\n",
     "                },\n",
-    "                \"required\": [\"location\", \"format\"],\n",
     "            },\n",
+    "            \"required\": [\"location\"],\n",
     "        },\n",
     "    },\n",
-    ")\n",
-    "\n",
-    "example_config = {\n",
-    "    \"name\": \"get_weather\",\n",
-    "    \"utterances\": [\n",
-    "        \"What is the weather like in SF?\",\n",
-    "        \"What is the weather in Cyprus?\",\n",
-    "        \"weather in London?\",\n",
-    "    ],\n",
     "}\n",
     "\n",
-    "specification = (\n",
-    "    {\n",
-    "        \"type\": \"function\",\n",
-    "        \"function\": {\n",
-    "            \"name\": \"get_time\",\n",
-    "            \"description\": \"Get the current time\",\n",
-    "            \"parameters\": {\n",
-    "                \"type\": \"object\",\n",
-    "                \"properties\": {\n",
-    "                    \"location\": {\n",
-    "                        \"type\": \"string\",\n",
-    "                        \"description\": \"The city and state, e.g. San Francisco, CA\",\n",
-    "                    },\n",
-    "                },\n",
-    "                \"required\": [\"location\"],\n",
-    "            },\n",
-    "        },\n",
-    "    },\n",
-    ")\n",
-    "\n",
-    "prompt = f\"\"\"\n",
-    "    Given the following specification, generate a config in JSON format\n",
-    "    Example:\n",
-    "    SPECIFICATION:\n",
-    "    {example_specification}\n",
-    "\n",
-    "    CONFIG:\n",
-    "    {example_config}\n",
-    "\n",
-    "    GIVEN SPECIFICATION:\n",
-    "    {specification}\n",
-    "\n",
-    "    GENERATED CONFIG:\n",
-    "\"\"\"\n",
-    "\n",
-    "\n",
-    "response = openai.chat.completions.create(\n",
-    "    model=\"gpt-4\",\n",
-    "    messages=[\n",
-    "        {\"role\": \"system\", \"content\": f\"{prompt}\"},\n",
-    "    ],\n",
-    ")\n",
-    "\n",
-    "ai_message = response.choices[0].message.content\n",
-    "if ai_message:\n",
-    "    route_config = json.loads(ai_message)\n",
-    "    print(f\"Route config: {route_config}\")\n"
-   ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": 23,
-   "metadata": {},
-   "outputs": [
-    {
-     "name": "stdout",
-     "output_type": "stream",
-     "text": [
-      "None\n",
-      "get_time\n",
-      "get_time\n",
-      "get_time\n",
-      "None\n",
-      "None\n"
-     ]
-    }
-   ],
-   "source": [
-    "routes = [Route(name=route[\"name\"], utterances=route[\"utterances\"]) for route in [route_config]]\n",
-    "\n",
-    "route_layer = RouteLayer(encoder=encoder, routes=routes)\n",
+    "route_config = generate_config(specification)  # performs one OpenAI chat request\n",
+    "route_layer = get_route_layer([route_config])  # helper takes a list of route dicts\n",
     "\n",
     "queries = [\n",
     "    \"What is the weather like in Barcelona?\",\n",
     "    \"What time is it in Taiwan?\",\n",
     "    \"What is happening in the world?\",\n",
-    "    \"what is the time in Kaunas?\"\n",
+    "    \"what is the time in Kaunas?\",\n",
     "    \"Im bored\",\n",
     "    \"I want to play a game\",\n",
-    "    \"Banana\"\n",
+    "    \"Banana\",\n",
     "]\n",
     "\n",
+    "print(\"Getting function name for queries:\\n\")\n",
     "for query in queries:\n",
     "    function_name = route_layer(query)\n",
-    "    print(function_name)"
+    "    print((function_name, query))  # (matched route name or None, query)"
    ]
-  },
-  {
-   "cell_type": "code",
-   "execution_count": null,
-   "metadata": {},
-   "outputs": [],
-   "source": []
   }
  ],
  "metadata": {
-- 
GitLab