{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Define LLMs"
]
},
"from semantic_router.utils.logger import logger\n",
"\n",
"# Docs # https://platform.openai.com/docs/guides/function-calling\n",
"def llm_openai(prompt: str, model: str = \"gpt-4\") -> str:\n",
" try:\n",
" response = openai.chat.completions.create(\n",
" model=model,\n",
" messages=[\n",
" {\"role\": \"system\", \"content\": f\"{prompt}\"},\n",
" ],\n",
" )\n",
" ai_message = response.choices[0].message.content\n",
" if not ai_message:\n",
" raise Exception(\"AI message is empty\", ai_message)\n",
" logger.info(f\"AI message: {ai_message}\")\n",
" return ai_message\n",
" except Exception as e:\n",
" raise Exception(\"Failed to call OpenAI API\", e)"
]
},
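{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick smoke test of the OpenAI wrapper (assumes `OPENAI_API_KEY` is set in the environment):"
]
},
{
"cell_type": "code",
"metadata": {},
"outputs": [],
"source": [
"# should log the call and return a one-word completion\n",
"llm_openai(\"Reply with the single word: pong\")"
]
},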
{
"cell_type": "code",
"metadata": {},
"outputs": [],
"source": [
"# Mistral\n",
"import os\n",
"import requests\n",
"\n",
"# Docs https://huggingface.co/docs/transformers/main_classes/text_generation\n",
"def llm_mistral(prompt: str) -> str:\n",
" api_url = \"https://z5t4cuhg21uxfmc3.us-east-1.aws.endpoints.huggingface.cloud/\"\n",
" headers = {\n",
" \"Authorization\": f\"Bearer {HF_API_TOKEN}\",\n",
" \"Content-Type\": \"application/json\",\n",
" }\n",
"\n",
" response = requests.post(\n",
" api_url,\n",
" headers=headers,\n",
" json={\n",
" \"inputs\": f\"You are a helpful assistant, user query: {prompt}\",\n",
" \"temperature\": 0.01,\n",
" \"num_beams\": 5,\n",
" \"num_return_sequences\": 1,\n",
" },\n",
" },\n",
" )\n",
" if response.status_code != 200:\n",
" raise Exception(\"Failed to call HuggingFace API\", response.text)\n",
"\n",
" logger.info(f\"AI message: {ai_message}\")\n",
" return ai_message"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"### Now we need to generate config from function schema using LLM"
"metadata": {},
"outputs": [],
"source": [
"import inspect\n",
"from typing import Any\n",
"\n",
"def get_function_schema(function) -> dict[str, Any]:\n",
" schema = {\n",
" \"name\": function.__name__,\n",
" \"description\": str(inspect.getdoc(function)),\n",
" \"signature\": str(inspect.signature(function)),\n",
" \"output\": str(\n",
" inspect.signature(function).return_annotation,\n",
" ),\n",
" }\n",
" return schema"
]
},
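{
"cell_type": "markdown",
"metadata": {},
"source": [
"A quick sanity check of `get_function_schema` on a toy function (the `add` function below is purely illustrative):"
]
},
{
"cell_type": "code",
"metadata": {},
"outputs": [],
"source": [
"def add(a: int, b: int) -> int:\n",
"    \"\"\"Add two integers\"\"\"\n",
"    return a + b\n",
"\n",
"\n",
"# expected: {'name': 'add', 'description': 'Add two integers',\n",
"#            'signature': '(a: int, b: int) -> int', 'output': \"<class 'int'>\"}\n",
"get_function_schema(add)"
]
},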
{
"cell_type": "code",
"execution_count": 15,
"metadata": {},
"outputs": [],
"source": [
"import json\n",
"\n",
"\n",
"def is_valid_config(route_config_str: str) -> bool:\n",
" try:\n",
" output_json = json.loads(route_config_str)\n",
" return all(key in output_json for key in [\"name\", \"utterances\"])\n",
" except json.JSONDecodeError:\n",
" return False"
]
},
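{
"cell_type": "markdown",
"metadata": {},
"source": [
"`is_valid_config` only checks that the string parses as JSON and contains both the `name` and `utterances` keys:"
]
},
{
"cell_type": "code",
"metadata": {},
"outputs": [],
"source": [
"assert is_valid_config('{\"name\": \"get_time\", \"utterances\": [\"time in tokyo\"]}')\n",
"assert not is_valid_config('{\"name\": \"get_time\"}')  # missing \"utterances\"\n",
"assert not is_valid_config(\"not json at all\")"
]
},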
{
"cell_type": "code",
"execution_count": 16,
"metadata": {},
"outputs": [],
"source": [
"import json\n",
"from semantic_router.utils.logger import logger\n",
" function_schema = get_function_schema(function)\n",
"\n",
" You are tasked to generate a JSON configuration based on the provided\n",
" function schema. Please follow the template below:\n",
" {{\n",
" \"name\": \"<function_name>\",\n",
" \"utterances\": [\n",
" \"<example_utterance_1>\",\n",
" \"<example_utterance_2>\",\n",
" \"<example_utterance_3>\",\n",
" \"<example_utterance_4>\",\n",
" \"<example_utterance_5>\"]\n",
" }}\n",
"\n",
" Only include the \"name\" and \"utterances\" keys in your answer.\n",
" The \"name\" should match the function name and the \"utterances\"\n",
" should comprise a list of 5 example phrases that could be used to invoke\n",
" the function.\n",
"\n",
" Input schema:\n",
" {function_schema}\n",
" try:\n",
" ai_message = llm_mistral(prompt)\n",
"\n",
" # Parse the response\n",
" ai_message = ai_message[ai_message.find(\"{\") :]\n",
" ai_message = (\n",
" ai_message.replace(\"'\", '\"')\n",
" .replace('\"s', \"'s\")\n",
" .strip()\n",
" .rstrip(\",\")\n",
" .replace(\"}\", \"}\")\n",
" )\n",
"\n",
" valid_config = is_valid_config(ai_message)\n",
" if not valid_config:\n",
" logger.warning(f\"Mistral failed with error, falling back to OpenAI\")\n",
" ai_message = llm_openai(prompt)\n",
" if not is_valid_config(ai_message):\n",
" raise Exception(\"Invalid config generated\")\n",
" except Exception as e:\n",
" logger.error(f\"Fall back to OpenAI failed with error {e}\")\n",
" ai_message = llm_openai(prompt)\n",
" if not is_valid_config(ai_message):\n",
" raise Exception(\"Failed to generate config\")\n",
" try:\n",
" route_config = json.loads(ai_message)\n",
" logger.info(f\"Generated config: {route_config}\")\n",
" except json.JSONDecodeError as json_error:\n",
" logger.error(f\"JSON parsing error {json_error}\")\n",
" print(f\"AI message: {ai_message}\")\n",
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Extract function parameters using `Mistral` open-source model"
"execution_count": 17,
"metadata": {},
"outputs": [],
"source": [
"def validate_parameters(function, parameters):\n",
" sig = inspect.signature(function)\n",
" for name, param in sig.parameters.items():\n",
" if name not in parameters:\n",
" return False, f\"Parameter {name} missing from query\"\n",
" if not isinstance(parameters[name], param.annotation):\n",
" return False, f\"Parameter {name} is not of type {param.annotation}\"\n",
" return True, \"Parameters are valid\""
]
},
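{
"cell_type": "markdown",
"metadata": {},
"source": [
"A small illustration of `validate_parameters` (the `get_weather` stub below is hypothetical and mirrors the example schema used in the next cell):"
]
},
{
"cell_type": "code",
"metadata": {},
"outputs": [],
"source": [
"def get_weather(location: str, degree: str) -> str:\n",
"    \"\"\"Useful to get the weather in a specific location\"\"\"\n",
"    return \"get_weather\"\n",
"\n",
"\n",
"print(validate_parameters(get_weather, {\"location\": \"Hawaii\", \"degree\": \"Celsius\"}))\n",
"# (True, 'Parameters are valid')\n",
"print(validate_parameters(get_weather, {\"location\": \"Hawaii\"}))\n",
"# (False, 'Parameter degree missing from query')\n",
"print(validate_parameters(get_weather, {\"location\": 42, \"degree\": \"Celsius\"}))\n",
"# (False, \"Parameter location is not of type <class 'str'>\")"
]
},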
{
"cell_type": "code",
"execution_count": 18,
"metadata": {},
"outputs": [],
"source": [
" logger.info(\"Extracting parameters...\")\n",
" example_query = \"How is the weather in Hawaii right now in International units?\"\n",
" example_schema = {\n",
" \"name\": \"get_weather\",\n",
" \"description\": \"Useful to get the weather in a specific location\",\n",
" \"signature\": \"(location: str, degree: str) -> str\",\n",
" \"output\": \"<class 'str'>\",\n",
" }\n",
"\n",
" example_parameters = {\n",
" \"location\": \"London\",\n",
" You are a helpful assistant designed to output JSON.\n",
" Given the following function schema\n",
" extract the parameters values from the query, in a valid JSON format.\n",
" Example:\n",
" Input:\n",
" query: {example_query}\n",
" schema: {example_schema}\n",
"\n",
"\n",
" Input:\n",
" query: {query}\n",
" schema: {get_function_schema(function)}\n",
" try:\n",
" ai_message = llm_mistral(prompt)\n",
" ai_message = (\n",
" ai_message.replace(\"Output:\", \"\").replace(\"'\", '\"').strip().rstrip(\",\")\n",
" )\n",
" except Exception as e:\n",
" logger.error(f\"Mistral failed with error {e}, falling back to OpenAI\")\n",
" ai_message = llm_openai(prompt)\n",
" try:\n",
" parameters = json.loads(ai_message)\n",
" valid, message = validate_parameters(function, parameters)\n",
"\n",
" if not valid:\n",
" logger.warning(\n",
" f\"Invalid parameters from Mistral, falling back to OpenAI: {message}\"\n",
" )\n",
" # Fall back to OpenAI\n",
" ai_message = llm_openai(prompt)\n",
" parameters = json.loads(ai_message)\n",
" valid, message = validate_parameters(function, parameters)\n",
" if not valid:\n",
" raise ValueError(message)\n",
"\n",
" logger.info(f\"Extracted parameters: {parameters}\")\n",
" return parameters\n",
" except ValueError as e:\n",
" logger.error(f\"Parameter validation error: {str(e)}\")\n",
" return {\"error\": \"Failed to validate parameters\"}"
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Set up the routing layer"
]
},
"from semantic_router.schema import Route\n",
"from semantic_router.encoders import CohereEncoder\n",
"from semantic_router.layer import RouteLayer\n",
"from semantic_router.utils.logger import logger\n",
"def create_router(routes: list[dict]) -> RouteLayer:\n",
" logger.info(\"Creating route layer...\")\n",
" route_list: list[Route] = []\n",
" for route in routes:\n",
" if \"name\" in route and \"utterances\" in route:\n",
" print(f\"Route: {route}\")\n",
" route_list.append(Route(name=route[\"name\"], utterances=route[\"utterances\"]))\n",
" else:\n",
" logger.warning(f\"Misconfigured route: {route}\")\n",
"\n",
" return RouteLayer(encoder=encoder, routes=route_list)"
]
},
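{
"cell_type": "markdown",
"metadata": {},
"source": [
"`create_router` also accepts hand-written configs, which is handy for testing the routing layer without calling Mistral or OpenAI (the `greeting` route below is a hypothetical example):"
]
},
{
"cell_type": "code",
"metadata": {},
"outputs": [],
"source": [
"test_router = create_router(\n",
"    [{\"name\": \"greeting\", \"utterances\": [\"hi\", \"hello\", \"hey there\"]}]\n",
")\n",
"# the route layer returns the matching route name for a query\n",
"test_router(\"hello!\")"
]
},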
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Set up calling functions"
]
},
{
"cell_type": "code",
"metadata": {},
"outputs": [],
"source": [
"from typing import Callable\n",
"def call_function(function: Callable, parameters: dict[str, str]):\n",
" try:\n",
" return function(**parameters)\n",
" except TypeError as e:\n",
" logger.error(f\"Error calling function: {e}\")\n",
"\n",
"\n",
"def call_llm(query: str) -> str:\n",
" try:\n",
" ai_message = llm_mistral(query)\n",
" except Exception as e:\n",
" logger.error(f\"Mistral failed with error {e}, falling back to OpenAI\")\n",
" ai_message = llm_openai(query)\n",
"\n",
" return ai_message\n",
"\n",
"\n",
"def call(query: str, functions: list[Callable], router: RouteLayer):\n",
" function_name = router(query)\n",
" if not function_name:\n",
" logger.warning(\"No function found\")\n",
" return call_llm(query)\n",
"\n",
" for function in functions:\n",
" if function.__name__ == function_name:\n",
" parameters = extract_parameters(query, function)\n",
" print(f\"parameters: {parameters}\")\n",
" return call_function(function, parameters)"
]
},
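{
"cell_type": "markdown",
"metadata": {},
"source": [
"To exercise the fallback path without waiting for a real outage, one can temporarily stub `llm_mistral` so it raises (a hypothetical test, not part of the original flow):"
]
},
{
"cell_type": "code",
"metadata": {},
"outputs": [],
"source": [
"def _failing_mistral(prompt: str) -> str:\n",
"    raise Exception(\"simulated Mistral outage\")\n",
"\n",
"\n",
"_original_llm_mistral = llm_mistral\n",
"llm_mistral = _failing_mistral\n",
"try:\n",
"    # call_llm should log the Mistral error and answer via OpenAI instead\n",
"    print(call_llm(\"Say hello in one word\"))\n",
"finally:\n",
"    llm_mistral = _original_llm_mistral"
]
},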
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"\u001b[32m2023-12-18 12:17:58 INFO semantic_router.utils.logger Generating config...\u001b[0m\n",
"\u001b[32m2023-12-18 12:17:58 INFO semantic_router.utils.logger Calling Mistral model\u001b[0m\n",
"\u001b[31m2023-12-18 12:18:00 ERROR semantic_router.utils.logger Fall back to OpenAI failed with error ('Failed to call HuggingFace API', '{\"error\":\"Bad Gateway\"}')\u001b[0m\n",
"\u001b[32m2023-12-18 12:18:00 INFO semantic_router.utils.logger Calling gpt-4 model\u001b[0m\n",
"\u001b[32m2023-12-18 12:18:05 INFO semantic_router.utils.logger AI message: {\n",
" \"name\": \"get_time\",\n",
" \"utterances\": [\n",
" \"what is the time in new york\",\n",
" \"can you tell me the time in london\",\n",
" \"get me the current time in tokyo\",\n",
" \"i need to know the time in sydney\",\n",
" \"please tell me the current time in paris\"\n",
" ]\n",
"}\u001b[0m\n",
"\u001b[32m2023-12-18 12:18:05 INFO semantic_router.utils.logger Generated config: {'name': 'get_time', 'utterances': ['what is the time in new york', 'can you tell me the time in london', 'get me the current time in tokyo', 'i need to know the time in sydney', 'please tell me the current time in paris']}\u001b[0m\n",
"\u001b[32m2023-12-18 12:18:05 INFO semantic_router.utils.logger Generating config...\u001b[0m\n",
"\u001b[32m2023-12-18 12:18:05 INFO semantic_router.utils.logger Calling Mistral model\u001b[0m\n",
"\u001b[31m2023-12-18 12:18:07 ERROR semantic_router.utils.logger Fall back to OpenAI failed with error ('Failed to call HuggingFace API', '{\"error\":\"Bad Gateway\"}')\u001b[0m\n",
"\u001b[32m2023-12-18 12:18:07 INFO semantic_router.utils.logger Calling gpt-4 model\u001b[0m\n",
"\u001b[32m2023-12-18 12:18:12 INFO semantic_router.utils.logger AI message: {\n",
" \"name\": \"get_news\",\n",
" \"utterances\": [\n",
" \"Can I get the latest news in Canada?\",\n",
" \"Show me the recent news in the US\",\n",
" \"I would like to know about the sports news in England\",\n",
" \"Let's check the technology news in Japan\",\n",
" \"Show me the health related news in Germany\"\n",
" ]\n",
"}\u001b[0m\n",
"\u001b[32m2023-12-18 12:18:12 INFO semantic_router.utils.logger Generated config: {'name': 'get_news', 'utterances': ['Can I get the latest news in Canada?', 'Show me the recent news in the US', 'I would like to know about the sports news in England', \"Let's check the technology news in Japan\", 'Show me the health related news in Germany']}\u001b[0m\n",
"\u001b[32m2023-12-18 12:18:12 INFO semantic_router.utils.logger Creating route layer...\u001b[0m\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Route: {'name': 'get_time', 'utterances': ['what is the time in new york', 'can you tell me the time in london', 'get me the current time in tokyo', 'i need to know the time in sydney', 'please tell me the current time in paris']}\n",
"Route: {'name': 'get_news', 'utterances': ['Can I get the latest news in Canada?', 'Show me the recent news in the US', 'I would like to know about the sports news in England', \"Let's check the technology news in Japan\", 'Show me the health related news in Germany']}\n"
]
}
],
441
442
443
444
445
446
447
448
449
450
451
452
453
454
455
456
457
458
459
460
461
462
463
464
465
466
467
468
"source": [
"def get_time(location: str) -> str:\n",
" \"\"\"Useful to get the time in a specific location\"\"\"\n",
" print(f\"Calling `get_time` function with location: {location}\")\n",
" return \"get_time\"\n",
"\n",
"\n",
"def get_news(category: str, country: str) -> str:\n",
" \"\"\"Useful to get the news in a specific country\"\"\"\n",
" print(\n",
" f\"Calling `get_news` function with category: {category} and country: {country}\"\n",
" )\n",
" return \"get_news\"\n",
"\n",
"\n",
"# Registering functions to the router\n",
"route_get_time = generate_route(get_time)\n",
"route_get_news = generate_route(get_news)\n",
"\n",
"routes = [route_get_time, route_get_news]\n",
"router = create_router(routes)\n",
"\n",
"# Tools\n",
"tools = [get_time, get_news]"
]
},
{
"cell_type": "code",
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"\u001b[32m2023-12-18 12:20:12 INFO semantic_router.utils.logger Generating config...\u001b[0m\n",
"\u001b[32m2023-12-18 12:20:12 INFO semantic_router.utils.logger Calling Mistral model\u001b[0m\n",
"\u001b[32m2023-12-18 12:20:16 INFO semantic_router.utils.logger AI message: \n",
" Example output:\n",
" {\n",
" \"name\": \"get_time\",\n",
" \"utterances\": [\n",
" \"What's the time in New York?\",\n",
" \"Tell me the time in Tokyo.\",\n",
" \"Can you give me the time in London?\",\n",
" \"What's the current time in Sydney?\",\n",
" \"Can you tell me the time in Berlin?\"\n",
" ]\n",
" }\u001b[0m\n",
"\u001b[32m2023-12-18 12:20:16 INFO semantic_router.utils.logger Generated config: {'name': 'get_time', 'utterances': [\"What's the time in New York?\", 'Tell me the time in Tokyo.', 'Can you give me the time in London?', \"What's the current time in Sydney?\", 'Can you tell me the time in Berlin?']}\u001b[0m\n",
"\u001b[32m2023-12-18 12:20:16 INFO semantic_router.utils.logger Generating config...\u001b[0m\n",
"\u001b[32m2023-12-18 12:20:16 INFO semantic_router.utils.logger Calling Mistral model\u001b[0m\n",
"\u001b[32m2023-12-18 12:20:20 INFO semantic_router.utils.logger AI message: \n",
" Example output:\n",
" {\n",
" \"name\": \"get_news\",\n",
" \"utterances\": [\n",
" \"Tell me the latest news from the US\",\n",
" \"What's happening in India today?\",\n",
" \"Get me the top stories from Japan\",\n",
" \"Can you give me the breaking news from Brazil?\",\n",
" \"What's the latest news from Germany?\"\n",
" ]\n",
" }\u001b[0m\n",
"\u001b[32m2023-12-18 12:20:20 INFO semantic_router.utils.logger Generated config: {'name': 'get_news', 'utterances': ['Tell me the latest news from the US', \"What's happening in India today?\", 'Get me the top stories from Japan', 'Can you give me the breaking news from Brazil?', \"What's the latest news from Germany?\"]}\u001b[0m\n",
"\u001b[32m2023-12-18 12:20:20 INFO semantic_router.utils.logger Creating route layer...\u001b[0m\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"Route: {'name': 'get_time', 'utterances': [\"What's the time in New York?\", 'Tell me the time in Tokyo.', 'Can you give me the time in London?', \"What's the current time in Sydney?\", 'Can you tell me the time in Berlin?']}\n",
"Route: {'name': 'get_news', 'utterances': ['Tell me the latest news from the US', \"What's happening in India today?\", 'Get me the top stories from Japan', 'Can you give me the breaking news from Brazil?', \"What's the latest news from Germany?\"]}\n"
]
}
],
"source": [
"def get_time(location: str) -> str:\n",
" \"\"\"Useful to get the time in a specific location\"\"\"\n",
" print(f\"Calling `get_time` function with location: {location}\")\n",
" return \"get_time\"\n",
"\n",
"\n",
"def get_news(category: str, country: str) -> str:\n",
" \"\"\"Useful to get the news in a specific country\"\"\"\n",
" print(\n",
" f\"Calling `get_news` function with category: {category} and country: {country}\"\n",
" )\n",
" return \"get_news\"\n",
"\n",
"\n",
"# Registering functions to the router\n",
"route_get_time = generate_route(get_time)\n",
"route_get_news = generate_route(get_news)\n",
"\n",
"routes = [route_get_time, route_get_news]\n",
"router = create_router(routes)\n",
"\n",
"# Tools\n",
"tools = [get_time, get_news]"
]
},
{
"cell_type": "code",
"execution_count": 26,
"metadata": {},
"outputs": [
{
"name": "stderr",
"output_type": "stream",
"text": [
"\u001b[32m2023-12-18 12:20:02 INFO semantic_router.utils.logger Extracting parameters...\u001b[0m\n",
"\u001b[32m2023-12-18 12:20:02 INFO semantic_router.utils.logger Calling Mistral model\u001b[0m\n",
"\u001b[32m2023-12-18 12:20:04 INFO semantic_router.utils.logger AI message: \n",
" {\n",
" \"location\": \"Stockholm\"\n",
" }\u001b[0m\n",
"\u001b[32m2023-12-18 12:20:04 INFO semantic_router.utils.logger Extracted parameters: {'location': 'Stockholm'}\u001b[0m\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"parameters: {'location': 'Stockholm'}\n",
"Calling `get_time` function with location: Stockholm\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\u001b[32m2023-12-18 12:20:04 INFO semantic_router.utils.logger Extracting parameters...\u001b[0m\n",
"\u001b[32m2023-12-18 12:20:04 INFO semantic_router.utils.logger Calling Mistral model\u001b[0m\n",
"\u001b[32m2023-12-18 12:20:05 INFO semantic_router.utils.logger AI message: \n",
" {\n",
" \"category\": \"tech\",\n",
" \"country\": \"Lithuania\"\n",
" }\u001b[0m\n",
"\u001b[32m2023-12-18 12:20:05 INFO semantic_router.utils.logger Extracted parameters: {'category': 'tech', 'country': 'Lithuania'}\u001b[0m\n"
]
},
{
"name": "stdout",
"output_type": "stream",
"text": [
"parameters: {'category': 'tech', 'country': 'Lithuania'}\n",
"Calling `get_news` function with category: tech and country: Lithuania\n"
]
},
{
"name": "stderr",
"output_type": "stream",
"text": [
"\u001b[33m2023-12-18 12:20:05 WARNING semantic_router.utils.logger No function found\u001b[0m\n",
"\u001b[32m2023-12-18 12:20:05 INFO semantic_router.utils.logger Calling Mistral model\u001b[0m\n",
"\u001b[32m2023-12-18 12:20:06 INFO semantic_router.utils.logger AI message: How can I help you today?\u001b[0m\n"
]
},
{
"data": {
"text/plain": [
"' How can I help you today?'"
]
},
"execution_count": 26,
"metadata": {},
"output_type": "execute_result"
}
],
"call(query=\"What is the time in Stockholm?\", functions=tools, router=router)\n",
"call(query=\"What is the tech news in the Lithuania?\", functions=tools, router=router)\n",
"call(query=\"Hi!\", functions=tools, router=router)"
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": ".venv",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",