From 68c70823de137051e433acdbae65e4d5ee026e2d Mon Sep 17 00:00:00 2001
From: Simonas <20096648+simjak@users.noreply.github.com>
Date: Thu, 14 Dec 2023 11:17:20 +0200
Subject: [PATCH] WIP

---
 docs/examples/function_calling.ipynb | 286 +++++++++++++++++++++++++++
 pyproject.toml                       |   2 +-
 2 files changed, 287 insertions(+), 1 deletion(-)
 create mode 100644 docs/examples/function_calling.ipynb

diff --git a/docs/examples/function_calling.ipynb b/docs/examples/function_calling.ipynb
new file mode 100644
index 00000000..8e65e71e
--- /dev/null
+++ b/docs/examples/function_calling.ipynb
@@ -0,0 +1,286 @@
+{
+ "cells": [
+  {
+   "cell_type": "code",
+   "execution_count": 3,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "# https://platform.openai.com/docs/guides/function-calling\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 6,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "get_weather\n",
+      "get_time\n",
+      "get_news\n"
+     ]
+    }
+   ],
+   "source": [
+    "from semantic_router.schema import Route\n",
+    "\n",
+    "from semantic_router.encoders import CohereEncoder\n",
+    "from semantic_router.layer import RouteLayer\n",
+    "\n",
+    "encoder = CohereEncoder()\n",
+    "\n",
+    "config = [\n",
+    "    {\n",
+    "        \"name\": \"get_weather\",\n",
+    "        \"utterances\": [\n",
+    "            \"What is the weather like in SF?\",\n",
+    "            \"What is the weather in Cyprus?\",\n",
+    "            \"weather in London?\",\n",
+    "        ],\n",
+    "    },\n",
+    "    {\n",
+    "        \"name\": \"get_time\",\n",
+    "        \"utterances\": [\n",
+    "            \"What time is it in New York?\",\n",
+    "            \"What time is it in London?\",\n",
+    "            \"What is the time in Paris?\",\n",
+    "        ],\n",
+    "    },\n",
+    "    {\n",
+    "        \"name\": \"get_news\",\n",
+    "        \"utterances\": [\n",
+    "            \"What is happening in the world?\",\n",
+    "            \"What is the latest news?\",\n",
+    "            \"What is the latest news in the US?\",\n",
+    "        ],\n",
+    "    },\n",
+    "]\n",
+    "\n",
+    "routes = [Route(name=route[\"name\"], utterances=route[\"utterances\"]) for route in config]\n",
+    "\n",
+    "route_layer = RouteLayer(encoder=encoder, routes=routes)\n",
+    "\n",
+    "queries = [\n",
+    "    \"What is the weather like in Barcelona?\",\n",
+    "    \"What time is it in Taiwan?\",\n",
+    "    \"What is happening in the world?\",\n",
+    "]\n",
+    "\n",
+    "for query in queries:\n",
+    "    function_name = route_layer(query)\n",
+    "    print(function_name)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": [
+    "def get_weather(location: str):\n",
+    "    print(f\"getting weather for {location}\")\n",
+    "\n",
+    "\n",
+    "def extract_function_parameters(query: str, function: Callable):\n",
+    "    # llm(\n",
+    "    #     query=query,\n",
+    "    #     function=function,\n",
+    "    #     prompt=\"What are the parameters for this function?\",\n",
+    "    # )\n",
+    "    print(\"Extracting function parameters..\")\n",
+    "\n",
+    "\n",
+    "if category == \"get_weather\":\n",
+    "    print(f\"Category is `{category}`\")\n",
+    "    params = extract_function_parameters(query, get_weather)\n",
+    "    print(\"Getting weather..\")\n",
+    "    # get_weather(**params)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 10,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "None\n"
+     ]
+    }
+   ],
+   "source": [
+    "print(generated_config)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 21,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "Route config: {'name': 'get_time', 'utterances': ['What is the current time in San Francisco?', 'What time is it in New York?', 'Current time in London?']}\n"
+     ]
+    }
+   ],
+   "source": [
+    "import json\n",
+    "\n",
+    "example_specification = (\n",
+    "    {\n",
+    "        \"type\": \"function\",\n",
+    "        \"function\": {\n",
+    "            \"name\": \"get_current_weather\",\n",
+    "            \"description\": \"Get the current weather\",\n",
+    "            \"parameters\": {\n",
+    "                \"type\": \"object\",\n",
+    "                \"properties\": {\n",
+    "                    \"location\": {\n",
+    "                        \"type\": \"string\",\n",
+    "                        \"description\": \"The city and state, e.g. San Francisco, CA\",\n",
+    "                    },\n",
+    "                    \"format\": {\n",
+    "                        \"type\": \"string\",\n",
+    "                        \"enum\": [\"celsius\", \"fahrenheit\"],\n",
+    "                        \"description\": \"The temperature unit to use. Infer this from the users location.\",\n",
+    "                    },\n",
+    "                },\n",
+    "                \"required\": [\"location\", \"format\"],\n",
+    "            },\n",
+    "        },\n",
+    "    },\n",
+    ")\n",
+    "\n",
+    "example_config = {\n",
+    "    \"name\": \"get_weather\",\n",
+    "    \"utterances\": [\n",
+    "        \"What is the weather like in SF?\",\n",
+    "        \"What is the weather in Cyprus?\",\n",
+    "        \"weather in London?\",\n",
+    "    ],\n",
+    "}\n",
+    "\n",
+    "specification = (\n",
+    "    {\n",
+    "        \"type\": \"function\",\n",
+    "        \"function\": {\n",
+    "            \"name\": \"get_time\",\n",
+    "            \"description\": \"Get the current time\",\n",
+    "            \"parameters\": {\n",
+    "                \"type\": \"object\",\n",
+    "                \"properties\": {\n",
+    "                    \"location\": {\n",
+    "                        \"type\": \"string\",\n",
+    "                        \"description\": \"The city and state, e.g. San Francisco, CA\",\n",
+    "                    },\n",
+    "                },\n",
+    "                \"required\": [\"location\"],\n",
+    "            },\n",
+    "        },\n",
+    "    },\n",
+    ")\n",
+    "\n",
+    "prompt = f\"\"\"\n",
+    "    Given the following specification, generate a config in JSON format\n",
+    "    Example:\n",
+    "    SPECIFICATION:\n",
+    "    {example_specification}\n",
+    "\n",
+    "    CONFIG:\n",
+    "    {example_config}\n",
+    "\n",
+    "    GIVEN SPECIFICATION:\n",
+    "    {specification}\n",
+    "\n",
+    "    GENERATED CONFIG:\n",
+    "\"\"\"\n",
+    "\n",
+    "\n",
+    "response = openai.chat.completions.create(\n",
+    "    model=\"gpt-4\",\n",
+    "    messages=[\n",
+    "        {\"role\": \"system\", \"content\": f\"{prompt}\"},\n",
+    "    ],\n",
+    ")\n",
+    "\n",
+    "ai_message = response.choices[0].message.content\n",
+    "if ai_message:\n",
+    "    route_config = json.loads(ai_message)\n",
+    "    print(f\"Route config: {route_config}\")\n"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": 23,
+   "metadata": {},
+   "outputs": [
+    {
+     "name": "stdout",
+     "output_type": "stream",
+     "text": [
+      "None\n",
+      "get_time\n",
+      "get_time\n",
+      "get_time\n",
+      "None\n",
+      "None\n"
+     ]
+    }
+   ],
+   "source": [
+    "routes = [Route(name=route[\"name\"], utterances=route[\"utterances\"]) for route in [route_config]]\n",
+    "\n",
+    "route_layer = RouteLayer(encoder=encoder, routes=routes)\n",
+    "\n",
+    "queries = [\n",
+    "    \"What is the weather like in Barcelona?\",\n",
+    "    \"What time is it in Taiwan?\",\n",
+    "    \"What is happening in the world?\",\n",
+    "    \"what is the time in Kaunas?\"\n",
+    "    \"Im bored\",\n",
+    "    \"I want to play a game\",\n",
+    "    \"Banana\"\n",
+    "]\n",
+    "\n",
+    "for query in queries:\n",
+    "    function_name = route_layer(query)\n",
+    "    print(function_name)"
+   ]
+  },
+  {
+   "cell_type": "code",
+   "execution_count": null,
+   "metadata": {},
+   "outputs": [],
+   "source": []
+  }
+ ],
+ "metadata": {
+  "kernelspec": {
+   "display_name": ".venv",
+   "language": "python",
+   "name": "python3"
+  },
+  "language_info": {
+   "codemirror_mode": {
+    "name": "ipython",
+    "version": 3
+   },
+   "file_extension": ".py",
+   "mimetype": "text/x-python",
+   "name": "python",
+   "nbconvert_exporter": "python",
+   "pygments_lexer": "ipython3",
+   "version": "3.11.3"
+  }
+ },
+ "nbformat": 4,
+ "nbformat_minor": 2
+}
diff --git a/pyproject.toml b/pyproject.toml
index 03a35cf6..b2024988 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -36,7 +36,7 @@ requires = ["poetry-core"]
 build-backend = "poetry.core.masonry.api"
 
 [tool.ruff.per-file-ignores]
-"*.ipynb" = ["E402"]
+"*.ipynb" = ["ALL"]
 
 [tool.mypy]
 ignore_missing_imports = true
-- 
GitLab