Skip to content
Snippets Groups Projects
Commit 74acb782 authored by Luca Mannini's avatar Luca Mannini
Browse files

format

parent 3ab633d0
Branches
Tags
No related merge requests found
%% Cell type:markdown id: tags:
## Define LLMs
%% Cell type:code id: tags:
``` python
%reload_ext dotenv
%dotenv
```
%% Cell type:code id: tags:
``` python
# OpenAI
import os
import openai
from semantic_router.utils.logger import logger
# Docs # https://platform.openai.com/docs/guides/function-calling
def llm_openai(prompt: str, model: str = "gpt-4") -> str:
try:
logger.info(f"Calling {model} model")
response = openai.chat.completions.create(
model=model,
messages=[
{"role": "system", "content": f"{prompt}"},
],
)
ai_message = response.choices[0].message.content
if not ai_message:
raise Exception("AI message is empty", ai_message)
logger.info(f"AI message: {ai_message}")
return ai_message
except Exception as e:
raise Exception("Failed to call OpenAI API", e)
```
%% Cell type:code id: tags:
``` python
# Mistral
import os
import requests
# Docs https://huggingface.co/docs/transformers/main_classes/text_generation
HF_API_TOKEN = os.getenv("HF_API_TOKEN")
def llm_mistral(prompt: str) -> str:
api_url = "https://z5t4cuhg21uxfmc3.us-east-1.aws.endpoints.huggingface.cloud/"
headers = {
"Authorization": f"Bearer {HF_API_TOKEN}",
"Content-Type": "application/json",
}
logger.info("Calling Mistral model")
response = requests.post(
api_url,
headers=headers,
json={
"inputs": f"You are a helpful assistant, user query: {prompt}",
"parameters": {
"max_new_tokens": 200,
"temperature": 0.1,
},
},
)
if response.status_code != 200:
raise Exception("Failed to call HuggingFace API", response.text)
ai_message = response.json()[0]["generated_text"]
if not ai_message:
raise Exception("AI message is empty", ai_message)
logger.info(f"AI message: {ai_message}")
return ai_message
```
%% Cell type:markdown id: tags:
### Now we need to generate config from function schema using LLM
%% Cell type:code id: tags:
``` python
import inspect
from typing import Any
def get_function_schema(function) -> dict[str, Any]:
schema = {
"name": function.__name__,
"description": str(inspect.getdoc(function)),
"signature": str(inspect.signature(function)),
"output": str(
inspect.signature(function).return_annotation,
),
}
return schema
```
%% Cell type:code id: tags:
``` python
import json
from semantic_router.utils.logger import logger
def generate_route(function) -> dict:
logger.info("Generating config...")
example_schema = {
"name": "get_weather",
"description": "Useful to get the weather in a specific location",
"signature": "(location: str) -> str",
"output": "<class 'str'>",
}
example_config = {
"name": "get_weather",
"utterances": [
"What is the weather like in SF?",
"What is the weather in Cyprus?",
"weather in London?",
"Tell me the weather in New York",
"what is the current weather in Paris?",
],
}
function_schema = get_function_schema(function)
prompt = f"""
You are a helpful assistant designed to output JSON.
Given the following function schema
{function_schema}
generate a routing config with the format:
{example_config}
For example:
Input: {example_schema}
Output: {example_config}
Input: {function_schema}
Output:
"""
ai_message = llm_openai(prompt)
ai_message = ai_message.replace("CONFIG:", "").replace("'", '"').strip().rstrip(",")
try:
route_config = json.loads(ai_message)
logger.info(f"Generated config: {route_config}")
return route_config
except json.JSONDecodeError as json_error:
logger.error(f"JSON parsing error {json_error}")
print(f"AI message: {ai_message}")
return {"error": "Failed to generate config"}
```
%% Cell type:markdown id: tags:
Extract function parameters using `Mistral` open-source model
%% Cell type:code id: tags:
``` python
def extract_parameters(query: str, function) -> dict:
logger.info("Extracting parameters...")
example_query = "How is the weather in Hawaii right now in International units?"
example_schema = {
"name": "get_weather",
"description": "Useful to get the weather in a specific location",
"signature": "(location: str, degree: str) -> str",
"output": "<class 'str'>",
}
example_parameters = {
"location": "London",
"degree": "Celsius",
}
prompt = f"""
You are a helpful assistant designed to output JSON.
Given the following function schema
{get_function_schema(function)}
and query
{query}
extract the parameters values from the query, in a valid JSON format.
Example:
Input:
query: {example_query}
schema: {example_schema}
Output:
parameters: {example_parameters}
Input:
query: {query}
schema: {get_function_schema(function)}
Output:
parameters:
"""
ai_message = llm_mistral(prompt)
ai_message = ai_message.replace("CONFIG:", "").replace("'", '"').strip().rstrip(",")
try:
parameters = json.loads(ai_message)
logger.info(f"Extracted parameters: {parameters}")
return parameters
except json.JSONDecodeError as json_error:
logger.error(f"JSON parsing error {json_error}")
return {"error": "Failed to extract parameters"}
```
%% Cell type:markdown id: tags:
Set up the routing layer
%% Cell type:code id: tags:
``` python
from semantic_router.schema import Route
from semantic_router.encoders import CohereEncoder, OpenAIEncoder
from semantic_router.layer import RouteLayer
from semantic_router.utils.logger import logger
def create_router(routes: list[dict]) -> RouteLayer:
logger.info("Creating route layer...")
encoder = OpenAIEncoder
```
%% Cell type:code id: tags:
``` python
from semantic_router.schema import Route
from semantic_router.encoders import CohereEncoder
from semantic_router.layer import RouteLayer
from semantic_router.utils.logger import logger
def create_router(routes: list[dict]) -> RouteLayer:
logger.info("Creating route layer...")
encoder = OpenAIEncoder()
route_list: list[Route] = []
for route in routes:
if "name" in route and "utterances" in route:
print(f"Route: {route}")
route_list.append(Route(name=route["name"], utterances=route["utterances"]))
else:
logger.warning(f"Misconfigured route: {route}")
return RouteLayer(encoder=encoder, routes=route_list)
```
%% Cell type:markdown id: tags:
Set up calling functions
%% Cell type:code id: tags:
``` python
from typing import Callable
def call_function(function: Callable, parameters: dict[str, str]):
try:
return function(**parameters)
except TypeError as e:
logger.error(f"Error calling function: {e}")
def call_llm(query: str):
return llm_mistral(query)
def call(query: str, functions: list[Callable], router: RouteLayer):
function_name = router(query)
if not function_name:
logger.warning("No function found")
return call_llm(query)
for function in functions:
if function.__name__ == function_name:
parameters = extract_parameters(query, function)
print(f"parameters: {parameters}")
return call_function(function, parameters)
```
%% Cell type:markdown id: tags:
### Workflow
%% Cell type:code id: tags:
``` python
def get_time(location: str) -> str:
"""Useful to get the time in a specific location"""
print(f"Calling `get_time` function with location: {location}")
return "get_time"
def get_news(category: str, country: str) -> str:
"""Useful to get the news in a specific country"""
print(
f"Calling `get_news` function with category: {category} and country: {country}"
)
return "get_news"
# Registering functions to the router
route_get_time = generate_route(get_time)
route_get_news = generate_route(get_news)
routes = [route_get_time, route_get_news]
router = create_router(routes)
# Tools
tools = [get_time, get_news]
```
%% Cell type:code id: tags:
``` python
def get_time(location: str) -> str:
"""Useful to get the time in a specific location"""
print(f"Calling `get_time` function with location: {location}")
return "get_time"
def get_news(category: str, country: str) -> str:
"""Useful to get the news in a specific country"""
print(
f"Calling `get_news` function with category: {category} and country: {country}"
)
return "get_news"
# Registering functions to the router
route_get_time = generate_route(get_time)
route_get_news = generate_route(get_news)
routes = [route_get_time, route_get_news]
router = create_router(routes)
# Tools
tools = [get_time, get_news]
```
%% Cell type:markdown id: tags:
call(query="What is the time in Stockholm?", functions=tools, router=router)
call(query="What is the tech news in the Lithuania?", functions=tools, router=router)
call(query="Hi!", functions=tools, router=router)
......
......@@ -4,6 +4,7 @@ from time import sleep
import openai
from openai import OpenAIError
from openai.types import CreateEmbeddingResponse
from semantic_router.encoders import BaseEncoder
from semantic_router.utils.logger import logger
......
......@@ -71,7 +71,7 @@ class TestOpenAIEncoder:
)
with pytest.raises(ValueError) as e:
openai_encoder(["test document"])
assert "OpenAI API call failed. Error: Non-OpenAIError" in str(e.value)
def test_openai_encoder_call_successful_retry(self, openai_encoder, mocker):
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment