{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "## Define LLMs" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "%reload_ext dotenv\n", "%dotenv" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# OpenAI\n", "import os\n", "import openai\n", "from semantic_router.utils.logger import logger\n", "\n", "\n", "# Docs # https://platform.openai.com/docs/guides/function-calling\n", "def llm_openai(prompt: str, model: str = \"gpt-4\") -> str:\n", "    \"\"\"Send `prompt` to an OpenAI chat model and return the assistant text.\n", "\n", "    Raises Exception (chained to the underlying error) when the API call\n", "    fails or the model returns an empty message.\n", "    \"\"\"\n", "    try:\n", "        logger.info(f\"Calling {model} model\")\n", "        response = openai.chat.completions.create(\n", "            model=model,\n", "            messages=[\n", "                {\"role\": \"system\", \"content\": f\"{prompt}\"},\n", "            ],\n", "        )\n", "        ai_message = response.choices[0].message.content\n", "        if not ai_message:\n", "            raise Exception(\"AI message is empty\", ai_message)\n", "        logger.info(f\"AI message: {ai_message}\")\n", "        return ai_message\n", "    except Exception as e:\n", "        # Chain the original error so the root cause stays in the traceback.\n", "        raise Exception(\"Failed to call OpenAI API\") from e" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "# Mistral\n", "import os\n", "import requests\n", "\n", "from semantic_router.utils.logger import logger\n", "\n", "# Docs https://huggingface.co/docs/transformers/main_classes/text_generation\n", "HF_API_TOKEN = os.getenv(\"HF_API_TOKEN\")\n", "\n", "\n", "def llm_mistral(prompt: str) -> str:\n", "    \"\"\"Send `prompt` to a Mistral model hosted on a HuggingFace inference endpoint.\"\"\"\n", "    api_url = \"https://z5t4cuhg21uxfmc3.us-east-1.aws.endpoints.huggingface.cloud/\"\n", "    headers = {\n", "        \"Authorization\": f\"Bearer {HF_API_TOKEN}\",\n", "        \"Content-Type\": \"application/json\",\n", "    }\n", "\n", "    logger.info(\"Calling Mistral model\")\n", "    response = requests.post(\n", "        api_url,\n", "        headers=headers,\n", "        json={\n", "            \"inputs\": f\"You are a helpful assistant, user query: {prompt}\",\n", "            \"parameters\": {\n", "                \"max_new_tokens\": 200,\n", "                \"temperature\": 0.1,\n", "            },\n", "        },\n", "        timeout=30,  # avoid hanging the kernel if the endpoint is unresponsive\n", "    )\n", "    if response.status_code != 200:\n", "        raise Exception(\"Failed to call HuggingFace API\", response.text)\n", "\n", "    ai_message = response.json()[0][\"generated_text\"]\n", "    if not ai_message:\n", "        raise Exception(\"AI message is empty\", ai_message)\n", "    logger.info(f\"AI message: {ai_message}\")\n", "    return ai_message" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Now we need to generate config from function schema using LLM" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import inspect\n", "from typing import Any\n", "\n", "\n", "def get_function_schema(function) -> dict[str, Any]:\n", "    \"\"\"Describe a callable (name, docstring, signature, return type) for prompting.\"\"\"\n", "    schema = {\n", "        \"name\": function.__name__,\n", "        \"description\": str(inspect.getdoc(function)),\n", "        \"signature\": str(inspect.signature(function)),\n", "        \"output\": str(\n", "            inspect.signature(function).return_annotation,\n", "        ),\n", "    }\n", "    return schema" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "import json\n", "\n", "from semantic_router.utils.logger import logger\n", "\n", "\n", "def generate_route(function) -> dict:\n", "    logger.info(\"Generating config...\")\n", "    example_schema = {\n", "        \"name\": \"get_weather\",\n", "        \"description\": \"Useful to get the weather in a specific location\",\n", "        \"signature\": \"(location: str) -> str\",\n", "        \"output\": \"<class 'str'>\",\n", "    }\n", "\n", "    example_config = {\n", "        \"name\": \"get_weather\",\n", "        \"utterances\": [\n", "            \"What is the weather like in SF?\",\n", "            \"What is the weather in Cyprus?\",\n", "            \"weather in London?\",\n", "            \"Tell me the weather in New York\",\n", "            \"what is the current weather in Paris?\",\n", "        ],\n", "    }\n", "\n", "    function_schema = get_function_schema(function)\n", "\n", "    prompt = f\"\"\"\n", "    You are a helpful assistant designed to output JSON.\n", "    Given the following function schema\n", "    {function_schema}\n", "    generate a routing config with the format:\n", "    {example_config}\n", "\n", 
"    For example:\n", "    Input: {example_schema}\n", "    Output: {example_config}\n", "\n", "    Input: {function_schema}\n", "    Output:\n", "    \"\"\"\n", "\n", "    ai_message = llm_openai(prompt)\n", "\n", "    # Strip label/quote noise from the raw completion so json.loads can parse it.\n", "    ai_message = ai_message.replace(\"CONFIG:\", \"\").replace(\"'\", '\"').strip().rstrip(\",\")\n", "\n", "    try:\n", "        route_config = json.loads(ai_message)\n", "        logger.info(f\"Generated config: {route_config}\")\n", "        return route_config\n", "    except json.JSONDecodeError as json_error:\n", "        logger.error(f\"JSON parsing error {json_error}\")\n", "        print(f\"AI message: {ai_message}\")\n", "        return {\"error\": \"Failed to generate config\"}" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Extract function parameters using `Mistral` open-source model" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def extract_parameters(query: str, function) -> dict:\n", "    \"\"\"Ask the Mistral LLM to pull argument values for `function` out of `query`.\"\"\"\n", "    logger.info(\"Extracting parameters...\")\n", "    example_query = \"How is the weather in Hawaii right now in International units?\"\n", "\n", "    example_schema = {\n", "        \"name\": \"get_weather\",\n", "        \"description\": \"Useful to get the weather in a specific location\",\n", "        \"signature\": \"(location: str, degree: str) -> str\",\n", "        \"output\": \"<class 'str'>\",\n", "    }\n", "\n", "    example_parameters = {\n", "        \"location\": \"London\",\n", "        \"degree\": \"Celsius\",\n", "    }\n", "\n", "    prompt = f\"\"\"\n", "    You are a helpful assistant designed to output JSON.\n", "    Given the following function schema\n", "    {get_function_schema(function)}\n", "    and query\n", "    {query}\n", "    extract the parameters values from the query, in a valid JSON format.\n", "    Example:\n", "    Input:\n", "    query: {example_query}\n", "    schema: {example_schema}\n", "\n", "    Output:\n", "    parameters: {example_parameters}\n", "\n", "    Input:\n", "    query: {query}\n", "    schema: {get_function_schema(function)}\n", "    Output:\n", "    parameters:\n", "    \"\"\"\n", "\n", "    ai_message = 
llm_mistral(prompt)\n", "\n", "    ai_message = ai_message.replace(\"CONFIG:\", \"\").replace(\"'\", '\"').strip().rstrip(\",\")\n", "\n", "    try:\n", "        parameters = json.loads(ai_message)\n", "        logger.info(f\"Extracted parameters: {parameters}\")\n", "        return parameters\n", "    except json.JSONDecodeError as json_error:\n", "        logger.error(f\"JSON parsing error {json_error}\")\n", "        return {\"error\": \"Failed to extract parameters\"}" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Set up the routing layer" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from semantic_router.schema import Route\n", "from semantic_router.encoders import OpenAIEncoder\n", "from semantic_router.layer import RouteLayer\n", "from semantic_router.utils.logger import logger\n", "\n", "\n", "def create_router(routes: list[dict]) -> RouteLayer:\n", "    \"\"\"Build a RouteLayer from generated route configs, skipping malformed ones.\"\"\"\n", "    logger.info(\"Creating route layer...\")\n", "    encoder = OpenAIEncoder()\n", "\n", "    route_list: list[Route] = []\n", "    for route in routes:\n", "        if \"name\" in route and \"utterances\" in route:\n", "            print(f\"Route: {route}\")\n", "            route_list.append(Route(name=route[\"name\"], utterances=route[\"utterances\"]))\n", "        else:\n", "            # Configs missing required keys (e.g. the LLM returned an error dict)\n", "            # are logged and dropped rather than crashing router construction.\n", "            logger.warning(f\"Misconfigured route: {route}\")\n", "\n", "    return RouteLayer(encoder=encoder, routes=route_list)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "Set up calling functions" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "from 
typing import Callable\n", "\n", "\n", "def call_function(function: Callable, parameters: dict[str, str]):\n", "    \"\"\"Invoke `function` with the extracted parameters; log (not raise) on bad args.\"\"\"\n", "    try:\n", "        return function(**parameters)\n", "    except TypeError as e:\n", "        logger.error(f\"Error calling function: {e}\")\n", "\n", "\n", "def call_llm(query: str):\n", "    \"\"\"Fallback: answer the query directly with the Mistral LLM.\"\"\"\n", "    return llm_mistral(query)\n", "\n", "\n", "def call(query: str, functions: list[Callable], router: RouteLayer):\n", "    \"\"\"Route `query` to the matching registered function, or fall back to the LLM.\"\"\"\n", "    function_name = router(query)\n", "    if not function_name:\n", "        logger.warning(\"No function found\")\n", "        return call_llm(query)\n", "\n", "    for function in functions:\n", "        if function.__name__ == function_name:\n", "            parameters = extract_parameters(query, function)\n", "            print(f\"parameters: {parameters}\")\n", "            return call_function(function, parameters)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "### Workflow" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "def get_time(location: str) -> str:\n", "    \"\"\"Useful to get the time in a specific location\"\"\"\n", "    print(f\"Calling `get_time` function with location: {location}\")\n", "    return \"get_time\"\n", "\n", "\n", "def get_news(category: str, country: str) -> str:\n", "    \"\"\"Useful to get the news in a specific country\"\"\"\n", "    print(\n", "        f\"Calling `get_news` function with category: {category} and country: {country}\"\n", "    )\n", "    return \"get_news\"\n", "\n", "\n", "# Registering functions to the router\n", "route_get_time = generate_route(get_time)\n", "route_get_news = generate_route(get_news)\n", "\n", "routes = [route_get_time, route_get_news]\n", "router = create_router(routes)\n", "\n", "# Tools\n", "tools = [get_time, get_news]" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [ "call(query=\"What is the time in Stockholm?\", functions=tools, router=router)\n", "call(query=\"What is the tech news in the Lithuania?\", functions=tools, router=router)\n", "call(query=\"Hi!\", functions=tools, router=router)" ] } ], "metadata": { "kernelspec": { "display_name": ".venv", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.11.5" } }, "nbformat": 4, "nbformat_minor": 2 }