From 418a0c8bb1ac5103c8c26190082a7d5890eaf7ac Mon Sep 17 00:00:00 2001 From: Simonas <20096648+simjak@users.noreply.github.com> Date: Thu, 14 Dec 2023 11:55:36 +0200 Subject: [PATCH] WIP: config generation --- docs/examples/function_calling.ipynb | 314 ++++++++++----------------- 1 file changed, 117 insertions(+), 197 deletions(-) diff --git a/docs/examples/function_calling.ipynb b/docs/examples/function_calling.ipynb index 8e65e71e..ef61ca64 100644 --- a/docs/examples/function_calling.ipynb +++ b/docs/examples/function_calling.ipynb @@ -6,260 +6,180 @@ "metadata": {}, "outputs": [], "source": [ - "# https://platform.openai.com/docs/guides/function-calling\n" + "# https://platform.openai.com/docs/guides/function-calling" ] }, { "cell_type": "code", - "execution_count": 6, + "execution_count": 21, "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "get_weather\n", - "get_time\n", - "get_news\n" - ] - } - ], + "outputs": [], "source": [ - "from semantic_router.schema import Route\n", - "\n", - "from semantic_router.encoders import CohereEncoder\n", - "from semantic_router.layer import RouteLayer\n", - "\n", - "encoder = CohereEncoder()\n", + "import json\n", + "import openai\n", + "\n", + "\n", + "def generate_config(specification: dict) -> dict:\n", + " print(\"Generating config...\")\n", + " example_specification = (\n", + " {\n", + " \"type\": \"function\",\n", + " \"function\": {\n", + " \"name\": \"get_current_weather\",\n", + " \"description\": \"Get the current weather\",\n", + " \"parameters\": {\n", + " \"type\": \"object\",\n", + " \"properties\": {\n", + " \"location\": {\n", + " \"type\": \"string\",\n", + " \"description\": \"The city and state, e.g. San Francisco, CA\",\n", + " },\n", + " \"format\": {\n", + " \"type\": \"string\",\n", + " \"enum\": [\"celsius\", \"fahrenheit\"],\n", + " \"description\": \"The temperature unit to use. Infer this \"\n", + " \" from the users location.\",\n", + " },\n", + " },\n", + " \"required\": [\"location\", \"format\"],\n", + " },\n", + " },\n", + " },\n", + " )\n", "\n", - "config = [\n", - " {\n", + " example_config = {\n", " \"name\": \"get_weather\",\n", " \"utterances\": [\n", " \"What is the weather like in SF?\",\n", " \"What is the weather in Cyprus?\",\n", " \"weather in London?\",\n", + " \"Tell me the weather in New York\",\n", + " \"what is the current weather in Paris?\",\n", " ],\n", - " },\n", - " {\n", - " \"name\": \"get_time\",\n", - " \"utterances\": [\n", - " \"What time is it in New York?\",\n", - " \"What time is it in London?\",\n", - " \"What is the time in Paris?\",\n", - " ],\n", - " },\n", - " {\n", - " \"name\": \"get_news\",\n", - " \"utterances\": [\n", - " \"What is happening in the world?\",\n", - " \"What is the latest news?\",\n", - " \"What is the latest news in the US?\",\n", - " ],\n", - " },\n", - "]\n", + " }\n", "\n", - "routes = [Route(name=route[\"name\"], utterances=route[\"utterances\"]) for route in config]\n", + " prompt = f\"\"\"\n", + " Given the following specification, generate a config in a valid JSON format,\n", + " Example:\n", + " SPECIFICATION:\n", + " {example_specification}\n", "\n", - "route_layer = RouteLayer(encoder=encoder, routes=routes)\n", + " CONFIG:\n", + " {example_config}\n", "\n", - "queries = [\n", - " \"What is the weather like in Barcelona?\",\n", - " \"What time is it in Taiwan?\",\n", - " \"What is happening in the world?\",\n", - "]\n", + " GIVEN SPECIFICATION:\n", + " {specification}\n", "\n", - "for query in queries:\n", - " function_name = route_layer(query)\n", - " print(function_name)" + " GENERATED CONFIG:\n", + " \"\"\"\n", + "\n", + " try:\n", + " response = openai.chat.completions.create(\n", + " model=\"gpt-4\",\n", + " messages=[\n", + " {\"role\": \"system\", \"content\": f\"{prompt}\"},\n", + " ],\n", + " )\n", + " ai_message = response.choices[0].message.content\n", + " print(\"AI message:\", ai_message)\n", + " route_config = json.loads(ai_message)\n", + " return route_config\n", + "\n", + " except json.JSONDecodeError as json_error:\n", + " raise Exception(\"JSON parsing error\", json_error)\n", + " except Exception as e:\n", + " raise Exception(\"Error generating config from Openai\", e)" ] }, { "cell_type": "code", - "execution_count": null, + "execution_count": 22, "metadata": {}, "outputs": [], "source": [ - "def get_weather(location: str):\n", - " print(f\"getting weather for {location}\")\n", - "\n", - "\n", - "def extract_function_parameters(query: str, function: Callable):\n", - " # llm(\n", - " # query=query,\n", - " # function=function,\n", - " # prompt=\"What are the parameters for this function?\",\n", - " # )\n", - " print(\"Extracting function parameters..\")\n", - "\n", + "from semantic_router.schema import Route\n", + "from semantic_router.encoders import CohereEncoder\n", + "from semantic_router.layer import RouteLayer\n", "\n", - "if category == \"get_weather\":\n", - " print(f\"Category is `{category}`\")\n", - " params = extract_function_parameters(query, get_weather)\n", - " print(\"Getting weather..\")\n", - " # get_weather(**params)" + "def get_route_layer(config: list[dict]) -> RouteLayer:\n", + " print(\"Getting route layer...\")\n", + " encoder = CohereEncoder()\n", + " routes = [\n", + " Route(name=route[\"name\"], utterances=route[\"utterances\"]) for route in config\n", + " ]\n", + " return RouteLayer(encoder=encoder, routes=routes)" ] }, { "cell_type": "code", - "execution_count": 10, + "execution_count": 25, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ - "None\n" + "Generating config...\n", + "AI message: {\n", + " \"name\": \"get_time\",\n", + " \"utterances\": [\n", + " \"What is the current time in SF?\",\n", + " \"Tell me the time in London\",\n", + " \"Could you tell me the time in New York?\",\n", + " \"May I know the current time in Paris?\",\n", + " \"Can you tell me what time is it in Singapore?\"\n", + " ]\n", + "}\n", + "Getting route layer...\n", + "Getting function name for queries:\n", + "\n", + "(None, 'What is the weather like in Barcelona?')\n", + "('get_time', 'What time is it in Taiwan?')\n", + "(None, 'What is happening in the world?')\n", + "('get_time', 'what is the time in Kaunas?')\n", + "(None, 'Im bored')\n", + "(None, 'I want to play a game')\n", + "(None, 'Banana')\n" ] } ], "source": [ - "print(generated_config)" - ] - }, - { - "cell_type": "code", - "execution_count": 21, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "Route config: {'name': 'get_time', 'utterances': ['What is the current time in San Francisco?', 'What time is it in New York?', 'Current time in London?']}\n" - ] - } - ], - "source": [ - "import json\n", - "\n", - "example_specification = (\n", - " {\n", - " \"type\": \"function\",\n", - " \"function\": {\n", - " \"name\": \"get_current_weather\",\n", - " \"description\": \"Get the current weather\",\n", - " \"parameters\": {\n", - " \"type\": \"object\",\n", - " \"properties\": {\n", - " \"location\": {\n", - " \"type\": \"string\",\n", - " \"description\": \"The city and state, e.g. San Francisco, CA\",\n", - " },\n", - " \"format\": {\n", - " \"type\": \"string\",\n", - " \"enum\": [\"celsius\", \"fahrenheit\"],\n", - " \"description\": \"The temperature unit to use. Infer this from the users location.\",\n", - " },\n", + "specification = {\n", + " \"type\": \"function\",\n", + " \"function\": {\n", + " \"name\": \"get_time\",\n", + " \"description\": \"Get the current time\",\n", + " \"parameters\": {\n", + " \"type\": \"object\",\n", + " \"properties\": {\n", + " \"location\": {\n", + " \"type\": \"string\",\n", + " \"description\": \"The city and state\",\n", " },\n", - " \"required\": [\"location\", \"format\"],\n", " },\n", + " \"required\": [\"location\"],\n", " },\n", " },\n", - ")\n", - "\n", - "example_config = {\n", - " \"name\": \"get_weather\",\n", - " \"utterances\": [\n", - " \"What is the weather like in SF?\",\n", - " \"What is the weather in Cyprus?\",\n", - " \"weather in London?\",\n", - " ],\n", "}\n", "\n", - "specification = (\n", - " {\n", - " \"type\": \"function\",\n", - " \"function\": {\n", - " \"name\": \"get_time\",\n", - " \"description\": \"Get the current time\",\n", - " \"parameters\": {\n", - " \"type\": \"object\",\n", - " \"properties\": {\n", - " \"location\": {\n", - " \"type\": \"string\",\n", - " \"description\": \"The city and state, e.g. San Francisco, CA\",\n", - " },\n", - " },\n", - " \"required\": [\"location\"],\n", - " },\n", - " },\n", - " },\n", - ")\n", - "\n", - "prompt = f\"\"\"\n", - " Given the following specification, generate a config in JSON format\n", - " Example:\n", - " SPECIFICATION:\n", - " {example_specification}\n", - "\n", - " CONFIG:\n", - " {example_config}\n", - "\n", - " GIVEN SPECIFICATION:\n", - " {specification}\n", - "\n", - " GENERATED CONFIG:\n", - "\"\"\"\n", - "\n", - "\n", - "response = openai.chat.completions.create(\n", - " model=\"gpt-4\",\n", - " messages=[\n", - " {\"role\": \"system\", \"content\": f\"{prompt}\"},\n", - " ],\n", - ")\n", - "\n", - "ai_message = response.choices[0].message.content\n", - "if ai_message:\n", - " route_config = json.loads(ai_message)\n", - " print(f\"Route config: {route_config}\")\n" - ] - }, - { - "cell_type": "code", - "execution_count": 23, - "metadata": {}, - "outputs": [ - { - "name": "stdout", - "output_type": "stream", - "text": [ - "None\n", - "get_time\n", - "get_time\n", - "get_time\n", - "None\n", - "None\n" - ] - } - ], - "source": [ - "routes = [Route(name=route[\"name\"], utterances=route[\"utterances\"]) for route in [route_config]]\n", - "\n", - "route_layer = RouteLayer(encoder=encoder, routes=routes)\n", + "route_config = generate_config(specification)\n", + "route_layer = get_route_layer([route_config])\n", "\n", "queries = [\n", " \"What is the weather like in Barcelona?\",\n", " \"What time is it in Taiwan?\",\n", " \"What is happening in the world?\",\n", - " \"what is the time in Kaunas?\"\n", + " \"what is the time in Kaunas?\",\n", " \"Im bored\",\n", " \"I want to play a game\",\n", - " \"Banana\"\n", + " \"Banana\",\n", "]\n", "\n", + "print(\"Getting function name for queries:\\n\")\n", "for query in queries:\n", " function_name = route_layer(query)\n", - " print(function_name)" + " print((function_name, query))" ] - }, - { - "cell_type": "code", - "execution_count": null, - "metadata": {}, - "outputs": [], - "source": [] } ], "metadata": { -- GitLab