{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Define LLMs"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "%reload_ext dotenv\n",
    "%dotenv"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# OpenAI\n",
    "import os\n",
    "import openai\n",
    "from semantic_router.utils.logger import logger\n",
    "\n",
    "\n",
    "# Docs: https://platform.openai.com/docs/api-reference/chat/create\n",
    "def llm_openai(prompt: str, model: str = \"gpt-4\") -> str:\n",
    "    \"\"\"Send `prompt` to an OpenAI chat model and return the reply text.\n",
    "\n",
    "    The whole prompt is sent as a single system message.\n",
    "\n",
    "    Raises:\n",
    "        Exception: \"Failed to call OpenAI API\" if the request fails (the\n",
    "            original error is chained as ``__cause__``), or\n",
    "            \"AI message is empty\" if the model returned no text.\n",
    "    \"\"\"\n",
    "    logger.info(f\"Calling {model} model\")\n",
    "    try:\n",
    "        response = openai.chat.completions.create(\n",
    "            model=model,\n",
    "            messages=[\n",
    "                {\"role\": \"system\", \"content\": prompt},\n",
    "            ],\n",
    "        )\n",
    "    except Exception as e:\n",
    "        raise Exception(\"Failed to call OpenAI API\", e) from e\n",
    "\n",
    "    # Checked outside the try so an empty reply is reported as such instead of\n",
    "    # being re-wrapped as an API failure. `content` may also be None.\n",
    "    ai_message = response.choices[0].message.content\n",
    "    if not ai_message:\n",
    "        raise Exception(\"AI message is empty\", ai_message)\n",
    "    logger.info(f\"AI message: {ai_message}\")\n",
    "    return ai_message"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Mistral\n",
    "import os\n",
    "import requests\n",
    "\n",
    "# Docs https://huggingface.co/docs/transformers/main_classes/text_generation\n",
    "HF_API_TOKEN = os.getenv(\"HF_API_TOKEN\")\n",
    "# Hugging Face Inference Endpoint URL (presumably serving the Mistral model).\n",
    "HF_MISTRAL_ENDPOINT = \"https://z5t4cuhg21uxfmc3.us-east-1.aws.endpoints.huggingface.cloud/\"\n",
    "\n",
    "\n",
    "def llm_mistral(prompt: str, timeout: float = 60.0) -> str:\n",
    "    \"\"\"Generate a reply to `prompt` with the Hugging Face-hosted Mistral model.\n",
    "\n",
    "    Args:\n",
    "        prompt: Text appended to a short \"helpful assistant\" instruction.\n",
    "        timeout: Seconds to wait for the endpoint. Without a timeout,\n",
    "            ``requests`` can block indefinitely on a stalled connection.\n",
    "\n",
    "    Raises:\n",
    "        Exception: if ``HF_API_TOKEN`` is unset, the endpoint does not answer\n",
    "            with HTTP 200, or the generated text is empty.\n",
    "        requests.exceptions.RequestException: on network errors or timeouts.\n",
    "    \"\"\"\n",
    "    if not HF_API_TOKEN:\n",
    "        # Fail fast instead of sending \"Bearer None\" and getting an auth error.\n",
    "        raise Exception(\"HF_API_TOKEN is not set\")\n",
    "\n",
    "    headers = {\n",
    "        \"Authorization\": f\"Bearer {HF_API_TOKEN}\",\n",
    "        \"Content-Type\": \"application/json\",\n",
    "    }\n",
    "\n",
    "    logger.info(\"Calling Mistral model\")\n",
    "    response = requests.post(\n",
    "        HF_MISTRAL_ENDPOINT,\n",
    "        headers=headers,\n",
    "        json={\n",
    "            \"inputs\": f\"You are a helpful assistant, user query: {prompt}\",\n",
    "            \"parameters\": {\n",
    "                \"max_new_tokens\": 200,\n",
    "                \"temperature\": 0.1,\n",
    "            },\n",
    "        },\n",
    "        timeout=timeout,\n",
    "    )\n",
    "    if response.status_code != 200:\n",
    "        raise Exception(\"Failed to call HuggingFace API\", response.text)\n",
    "\n",
    "    ai_message = response.json()[0][\"generated_text\"]\n",
    "    if not ai_message:\n",
    "        raise Exception(\"AI message is empty\", ai_message)\n",
    "    logger.info(f\"AI message: {ai_message}\")\n",
    "    return ai_message"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Generate route configs from function schemas with an LLM"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import inspect\n",
    "from typing import Any\n",
    "\n",
    "\n",
    "def get_function_schema(function) -> dict[str, Any]:\n",
    "    \"\"\"Describe `function` as a flat dict of strings for use in LLM prompts.\n",
    "\n",
    "    Keys, in order: name; description (the cleaned docstring, or the text\n",
    "    None when there is no docstring); signature; and output (the return\n",
    "    annotation rendered with str()).\n",
    "    \"\"\"\n",
    "    sig = inspect.signature(function)\n",
    "    return dict(\n",
    "        name=function.__name__,\n",
    "        description=str(inspect.getdoc(function)),\n",
    "        signature=str(sig),\n",
    "        output=str(sig.return_annotation),\n",
    "    )"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import json\n",
    "\n",
    "from semantic_router.utils.logger import logger\n",
    "\n",
    "\n",
    "def generate_route(function) -> dict:\n",
    "    \"\"\"Ask an LLM to write a semantic-router route config for `function`.\n",
    "\n",
    "    The function's schema (from `get_function_schema`) is shown to the model\n",
    "    alongside a worked example, and the model is asked for a JSON object with\n",
    "    `name` and `utterances` keys.\n",
    "\n",
    "    Returns:\n",
    "        The parsed config dict, or ``{\"error\": \"Failed to generate config\"}``\n",
    "        if the reply does not contain a parseable dict.\n",
    "    \"\"\"\n",
    "    import ast\n",
    "\n",
    "    logger.info(\"Generating config...\")\n",
    "    example_schema = {\n",
    "        \"name\": \"get_weather\",\n",
    "        \"description\": \"Useful to get the weather in a specific location\",\n",
    "        \"signature\": \"(location: str) -> str\",\n",
    "        \"output\": \"<class 'str'>\",\n",
    "    }\n",
    "\n",
    "    example_config = {\n",
    "        \"name\": \"get_weather\",\n",
    "        \"utterances\": [\n",
    "            \"What is the weather like in SF?\",\n",
    "            \"What is the weather in Cyprus?\",\n",
    "            \"weather in London?\",\n",
    "            \"Tell me the weather in New York\",\n",
    "            \"what is the current weather in Paris?\",\n",
    "        ],\n",
    "    }\n",
    "\n",
    "    function_schema = get_function_schema(function)\n",
    "\n",
    "    # Render examples with json.dumps so the model sees, and imitates, valid\n",
    "    # JSON (double quotes) rather than Python dict reprs (single quotes).\n",
    "    prompt = f\"\"\"\n",
    "    You are a helpful assistant designed to output JSON.\n",
    "    Given the following function schema\n",
    "    {json.dumps(function_schema)}\n",
    "    generate a routing config with the format:\n",
    "    {json.dumps(example_config)}\n",
    "\n",
    "    For example:\n",
    "    Input: {json.dumps(example_schema)}\n",
    "    Output: {json.dumps(example_config)}\n",
    "\n",
    "    Input: {json.dumps(function_schema)}\n",
    "    Output:\n",
    "    \"\"\"\n",
    "\n",
    "    ai_message = llm_openai(prompt)\n",
    "\n",
    "    # Parse the first object in the reply: skip any preamble or code fence\n",
    "    # before the `{` and let raw_decode ignore text after the object. A blanket\n",
    "    # replace of ' with \" is NOT used: it corrupts apostrophes in utterances\n",
    "    # such as \"What's the time?\".\n",
    "    start = ai_message.find(\"{\")\n",
    "    candidate = ai_message[start:] if start != -1 else ai_message.strip()\n",
    "    try:\n",
    "        route_config, _ = json.JSONDecoder().raw_decode(candidate)\n",
    "    except json.JSONDecodeError as json_error:\n",
    "        try:\n",
    "            # Fallback for a Python dict repr (single quotes), up to the last `}`.\n",
    "            route_config = ast.literal_eval(candidate[: candidate.rfind(\"}\") + 1])\n",
    "        except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):\n",
    "            logger.error(f\"JSON parsing error {json_error}\")\n",
    "            logger.error(f\"AI message: {ai_message}\")\n",
    "            return {\"error\": \"Failed to generate config\"}\n",
    "\n",
    "    if not isinstance(route_config, dict):\n",
    "        logger.error(f\"Expected a JSON object, got: {route_config!r}\")\n",
    "        return {\"error\": \"Failed to generate config\"}\n",
    "\n",
    "    logger.info(f\"Generated config: {route_config}\")\n",
    "    return route_config"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Extract function parameters with the open-source `Mistral` model"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def extract_parameters(query: str, function) -> dict:\n",
    "    \"\"\"Ask the Mistral LLM to extract `function`'s arguments from `query`.\n",
    "\n",
    "    Returns:\n",
    "        A dict of keyword arguments for `function`, or\n",
    "        ``{\"error\": \"Failed to extract parameters\"}`` if the reply does not\n",
    "        contain a parseable dict.\n",
    "    \"\"\"\n",
    "    import ast\n",
    "\n",
    "    logger.info(\"Extracting parameters...\")\n",
    "    example_query = \"How is the weather in Hawaii right now in International units?\"\n",
    "\n",
    "    example_schema = {\n",
    "        \"name\": \"get_weather\",\n",
    "        \"description\": \"Useful to get the weather in a specific location\",\n",
    "        \"signature\": \"(location: str, degree: str) -> str\",\n",
    "        \"output\": \"<class 'str'>\",\n",
    "    }\n",
    "\n",
    "    # The example answer must match the example query; a mismatched location\n",
    "    # teaches the model to ignore the query it is given.\n",
    "    example_parameters = {\n",
    "        \"location\": \"Hawaii\",\n",
    "        \"degree\": \"Celsius\",\n",
    "    }\n",
    "\n",
    "    function_schema = get_function_schema(function)\n",
    "\n",
    "    # Examples are rendered with json.dumps so the model imitates valid JSON.\n",
    "    prompt = f\"\"\"\n",
    "    You are a helpful assistant designed to output JSON.\n",
    "    Given the following function schema\n",
    "    {json.dumps(function_schema)}\n",
    "    and query\n",
    "    {query}\n",
    "    extract the parameters values from the query, in a valid JSON format.\n",
    "    Example:\n",
    "    Input:\n",
    "    query: {example_query}\n",
    "    schema: {json.dumps(example_schema)}\n",
    "\n",
    "    Output:\n",
    "    parameters: {json.dumps(example_parameters)}\n",
    "\n",
    "    Input:\n",
    "    query: {query}\n",
    "    schema: {json.dumps(function_schema)}\n",
    "    Output:\n",
    "    parameters:\n",
    "    \"\"\"\n",
    "\n",
    "    ai_message = llm_mistral(prompt)\n",
    "\n",
    "    # Parse the first object in the reply: skip anything before the `{` and let\n",
    "    # raw_decode ignore whatever the model generated after the object. A blanket\n",
    "    # replace of ' with \" is NOT used: it corrupts values such as \"O'Hare\".\n",
    "    # NOTE(review): if the endpoint echoes the prompt in generated_text, the\n",
    "    # first `{` belongs to the prompt - confirm the endpoint's behaviour.\n",
    "    start = ai_message.find(\"{\")\n",
    "    candidate = ai_message[start:] if start != -1 else ai_message.strip()\n",
    "    try:\n",
    "        parameters, _ = json.JSONDecoder().raw_decode(candidate)\n",
    "    except json.JSONDecodeError as json_error:\n",
    "        try:\n",
    "            # Fallback for a Python dict repr (single quotes), up to the last `}`.\n",
    "            parameters = ast.literal_eval(candidate[: candidate.rfind(\"}\") + 1])\n",
    "        except (ValueError, TypeError, SyntaxError, MemoryError, RecursionError):\n",
    "            logger.error(f\"JSON parsing error {json_error}\")\n",
    "            return {\"error\": \"Failed to extract parameters\"}\n",
    "\n",
    "    if not isinstance(parameters, dict):\n",
    "        logger.error(f\"Expected a JSON object, got: {parameters!r}\")\n",
    "        return {\"error\": \"Failed to extract parameters\"}\n",
    "\n",
    "    logger.info(f\"Extracted parameters: {parameters}\")\n",
    "    return parameters"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Set up the routing layer"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from semantic_router.schema import Route\n",
    "from semantic_router.encoders import CohereEncoder, OpenAIEncoder\n",
    "from semantic_router.layer import RouteLayer\n",
    "from semantic_router.utils.logger import logger\n",
    "\n",
    "\n",
    "def create_router(routes: list[dict]) -> RouteLayer:\n",
    "    \"\"\"Build a RouteLayer using `OpenAIEncoder` from LLM-generated route configs.\n",
    "\n",
    "    Each config needs a `name` and a list of `utterances`; anything else\n",
    "    (e.g. the error dict returned when route generation fails) is skipped\n",
    "    with a warning.\n",
    "    \"\"\"\n",
    "    logger.info(\"Creating route layer...\")\n",
    "    # Instantiate the encoder; the bare class object is not a usable encoder.\n",
    "    encoder = OpenAIEncoder()\n",
    "\n",
    "    route_list: list[Route] = []\n",
    "    for route in routes:\n",
    "        if \"name\" in route and isinstance(route.get(\"utterances\"), list):\n",
    "            route_list.append(Route(name=route[\"name\"], utterances=route[\"utterances\"]))\n",
    "        else:\n",
    "            logger.warning(f\"Misconfigured route: {route}\")\n",
    "\n",
    "    return RouteLayer(encoder=encoder, routes=route_list)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from semantic_router.schema import Route\n",
    "# OpenAIEncoder is imported here so this cell does not depend on an earlier one.\n",
    "from semantic_router.encoders import CohereEncoder, OpenAIEncoder\n",
    "from semantic_router.layer import RouteLayer\n",
    "from semantic_router.utils.logger import logger\n",
    "\n",
    "\n",
    "def create_router(routes: list[dict]) -> RouteLayer:\n",
    "    \"\"\"Build a RouteLayer using `OpenAIEncoder` from LLM-generated route configs.\n",
    "\n",
    "    Each config needs a `name` and a list of `utterances`; anything else\n",
    "    (e.g. the error dict returned when route generation fails) is skipped\n",
    "    with a warning.\n",
    "    \"\"\"\n",
    "    logger.info(\"Creating route layer...\")\n",
    "    encoder = OpenAIEncoder()\n",
    "\n",
    "    route_list: list[Route] = []\n",
    "    for route in routes:\n",
    "        if \"name\" in route and isinstance(route.get(\"utterances\"), list):\n",
    "            logger.info(f\"Route: {route}\")\n",
    "            route_list.append(Route(name=route[\"name\"], utterances=route[\"utterances\"]))\n",
    "        else:\n",
    "            logger.warning(f\"Misconfigured route: {route}\")\n",
    "\n",
    "    return RouteLayer(encoder=encoder, routes=route_list)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Set up function calling"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "from typing import Callable\n",
    "\n",
    "\n",
    "def call_function(function: Callable, parameters: dict[str, str]):\n",
    "    \"\"\"Call `function` with `parameters` unpacked as keyword arguments.\n",
    "\n",
    "    A TypeError raised by the call (e.g. a missing or unexpected argument) is\n",
    "    logged and swallowed, in which case None is returned.\n",
    "    \"\"\"\n",
    "    try:\n",
    "        result = function(**parameters)\n",
    "    except TypeError as err:\n",
    "        logger.error(f\"Error calling function: {err}\")\n",
    "        return None\n",
    "    return result\n",
    "\n",
    "\n",
    "def call_llm(query: str):\n",
    "    \"\"\"Answer `query` directly with the Mistral model (no function routing).\"\"\"\n",
    "    answer = llm_mistral(query)\n",
    "    return answer\n",
    "\n",
    "\n",
    "def call(query: str, functions: list[Callable], router: RouteLayer):\n",
    "    \"\"\"Route `query` to one of `functions`, falling back to plain LLM chat.\n",
    "\n",
    "    The router's route name is matched against each function's `__name__`;\n",
    "    the first match is called with parameters extracted from the query. If\n",
    "    the router returns nothing, or a name with no matching function, the\n",
    "    query is answered by `call_llm` instead.\n",
    "    \"\"\"\n",
    "    function_name = router(query)\n",
    "    if not function_name:\n",
    "        logger.warning(\"No function found\")\n",
    "        return call_llm(query)\n",
    "\n",
    "    function = next((f for f in functions if f.__name__ == function_name), None)\n",
    "    if function is None:\n",
    "        # Don't fall through and silently return None for an unknown route.\n",
    "        logger.warning(f\"No function named {function_name!r} in `functions`\")\n",
    "        return call_llm(query)\n",
    "\n",
    "    parameters = extract_parameters(query, function)\n",
    "    logger.info(f\"parameters: {parameters}\")\n",
    "    return call_function(function, parameters)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Workflow"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def get_time(location: str) -> str:\n",
    "    \"\"\"Useful to get the time in a specific location\"\"\"\n",
    "    # Stub tool: the docstring above is sent to the LLM as the route description.\n",
    "    message = f\"Calling `get_time` function with location: {location}\"\n",
    "    print(message)\n",
    "    return \"get_time\"\n",
    "\n",
    "\n",
    "def get_news(category: str, country: str) -> str:\n",
    "    \"\"\"Useful to get the news of a given category in a specific country\"\"\"\n",
    "    # Stub tool. The docstring is sent to the LLM as the route description, so\n",
    "    # it names both parameters to get example utterances that mention a category.\n",
    "    print(f\"Calling `get_news` function with category: {category} and country: {country}\")\n",
    "    return \"get_news\"\n",
    "\n",
    "\n",
    "# Tools\n",
    "tools = [get_time, get_news]\n",
    "\n",
    "# Registering functions to the router: one LLM-generated route per tool,\n",
    "# generated in the same order as `tools`.\n",
    "route_get_time, route_get_news = (generate_route(tool) for tool in tools)\n",
    "\n",
    "routes = [route_get_time, route_get_news]\n",
    "router = create_router(routes)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# The cell above already defines the tools and builds `router`; running that\n",
    "# code a second time would redefine the same functions and trigger another,\n",
    "# non-deterministic round of LLM calls. Instead, sanity-check its output:\n",
    "# generate_route returns an {\"error\": ...} dict when a reply can't be parsed.\n",
    "failed_routes = [route for route in routes if \"error\" in route]\n",
    "assert not failed_routes, f\"Route generation failed: {failed_routes}\"\n",
    "routes"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Expected to route to `get_time`\n",
    "call(query=\"What is the time in Stockholm?\", functions=tools, router=router)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Expected to route to `get_news`\n",
    "call(query=\"What is the tech news in Lithuania?\", functions=tools, router=router)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Expected to match no route and fall back to plain LLM chat\n",
    "call(query=\"Hi!\", functions=tools, router=router)"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": ".venv",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.11.5"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 2
}