%% Cell type:markdown id: tags:
## Define LLMs
%% Cell type:code id: tags:
``` python
# OpenAI
import openai

from semantic_router.utils.logger import logger

# Docs: https://platform.openai.com/docs/guides/function-calling


def llm_openai(prompt: str, model: str = "gpt-4") -> str:
    try:
        logger.info(f"Calling {model} model")
        response = openai.chat.completions.create(
            model=model,
            messages=[
                {"role": "system", "content": f"{prompt}"},
            ],
        )
        ai_message = response.choices[0].message.content
        if not ai_message:
            raise Exception("AI message is empty", ai_message)
        logger.info(f"AI message: {ai_message}")
        return ai_message
    except Exception as e:
        raise Exception("Failed to call OpenAI API", e)
```
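%% Cell type:markdown id: tags:
As a quick sanity check, the helper can be called directly (a minimal sketch; it assumes an `OPENAI_API_KEY` environment variable is set for the `openai` client):
%% Cell type:code id: tags:
``` python
# Hypothetical smoke test; requires OPENAI_API_KEY in the environment.
print(llm_openai("Reply with the single word: pong"))
```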
%% Cell type:code id: tags:
``` python
# Mistral
import os

import requests

# Docs: https://huggingface.co/docs/transformers/main_classes/text_generation
HF_API_TOKEN = os.environ["HF_API_TOKEN"]


def llm_mistral(prompt: str) -> str:
    api_url = "https://z5t4cuhg21uxfmc3.us-east-1.aws.endpoints.huggingface.cloud/"
    headers = {
        "Authorization": f"Bearer {HF_API_TOKEN}",
        "Content-Type": "application/json",
    }

    logger.info("Calling Mistral model")
    response = requests.post(
        api_url,
        headers=headers,
        json={
            "inputs": f"You are a helpful assistant, user query: {prompt}",
            "parameters": {
                "max_new_tokens": 200,
                "temperature": 0.1,
            },
        },
    )
    if response.status_code != 200:
        raise Exception("Failed to call HuggingFace API", response.text)

    ai_message = response.json()[0]["generated_text"]
    if not ai_message:
        raise Exception("AI message is empty", ai_message)
    logger.info(f"AI message: {ai_message}")
    return ai_message
```
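%% Cell type:markdown id: tags:
Note that `api_url` above points to a private HuggingFace Inference Endpoint; to run this cell, substitute the URL of your own deployed endpoint and set `HF_API_TOKEN` accordingly.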
%% Cell type:markdown id: tags:
### Now we need to generate a config from the function schema using an LLM
%% Cell type:code id: tags:
``` python
import inspect
from typing import Any


def get_function_schema(function) -> dict[str, Any]:
    schema = {
        "name": function.__name__,
        "description": str(inspect.getdoc(function)),
        "signature": str(inspect.signature(function)),
        "output": str(
            inspect.signature(function).return_annotation,
        ),
    }
    return schema
```
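%% Cell type:markdown id: tags:
For instance, for an illustrative `get_weather` function the helper returns a plain dict; the expected output (shown as a comment) mirrors the `example_schema` used below:
%% Cell type:code id: tags:
``` python
def get_weather(location: str) -> str:
    """Useful to get the weather in a specific location"""
    return "get_weather"


get_function_schema(get_weather)
# {'name': 'get_weather',
#  'description': 'Useful to get the weather in a specific location',
#  'signature': '(location: str) -> str',
#  'output': "<class 'str'>"}
```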
%% Cell type:code id: tags:
``` python
import json

from semantic_router.utils.logger import logger


def generate_route(function) -> dict:
    logger.info("Generating config...")

    example_schema = {
        "name": "get_weather",
        "description": "Useful to get the weather in a specific location",
        "signature": "(location: str) -> str",
        "output": "<class 'str'>",
    }

    example_config = {
        "name": "get_weather",
        "utterances": [
            "What is the weather like in SF?",
            "What is the weather in Cyprus?",
            "weather in London?",
            "Tell me the weather in New York",
            "what is the current weather in Paris?",
        ],
    }

    function_schema = get_function_schema(function)

    prompt = f"""
    You are a helpful assistant designed to output JSON.
    Given the following function schema
    {function_schema}
    generate a routing config with the format:
    {example_config}

    For example:
    Input: {example_schema}
    Output: {example_config}

    Input: {function_schema}
    Output:
    """

    ai_message = llm_openai(prompt)
    ai_message = ai_message.replace("CONFIG:", "").replace("'", '"').strip().rstrip(",")

    try:
        route_config = json.loads(ai_message)
        logger.info(f"Generated config: {route_config}")
        return route_config
    except json.JSONDecodeError as json_error:
        logger.error(f"JSON parsing error {json_error}")
        print(f"AI message: {ai_message}")
        return {"error": "Failed to generate config"}
```
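%% Cell type:markdown id: tags:
For reference, a run of `generate_route` over the `get_time(location: str)` function defined later in this notebook produced a config of this shape:
``` python
{
    "name": "get_time",
    "utterances": [
        "What is the time in SF?",
        "What is the current time in London?",
        "Time in Tokyo?",
        "Tell me the time in New York",
        "What is the time now in Paris?",
    ],
}
```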
%% Cell type:markdown id: tags:
Extract function parameters using the `Mistral` open-source model
%% Cell type:code id: tags:
``` python
def extract_parameters(query: str, function) -> dict:
    logger.info("Extracting parameters...")

    example_query = "How is the weather in Hawaii right now in International units?"

    example_schema = {
        "name": "get_weather",
        "description": "Useful to get the weather in a specific location",
        "signature": "(location: str, degree: str) -> str",
        "output": "<class 'str'>",
    }

    example_parameters = {
        "location": "Hawaii",
        "degree": "Celsius",
    }

    prompt = f"""
    You are a helpful assistant designed to output JSON.
    Given the following function schema
    {get_function_schema(function)}
    and query
    {query}
    extract the parameter values from the query, in a valid JSON format.
    Example:
    Input:
    query: {example_query}
    schema: {example_schema}
    Output:
    parameters: {example_parameters}

    Input:
    query: {query}
    schema: {get_function_schema(function)}
    Output:
    parameters:
    """

    ai_message = llm_mistral(prompt)
    ai_message = ai_message.replace("CONFIG:", "").replace("'", '"').strip().rstrip(",")

    try:
        parameters = json.loads(ai_message)
        logger.info(f"Extracted parameters: {parameters}")
        return parameters
    except json.JSONDecodeError as json_error:
        logger.error(f"JSON parsing error {json_error}")
        return {"error": "Failed to extract parameters"}
```
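%% Cell type:markdown id: tags:
As a sketch, parameters can be extracted for the illustrative `get_weather` function defined earlier (this assumes the Mistral endpoint above is reachable):
%% Cell type:code id: tags:
``` python
# Expected shape (illustrative): {'location': 'Porto'}
extract_parameters("What is the weather in Porto?", get_weather)
```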
%% Cell type:markdown id: tags:
Set up the routing layer
%% Cell type:code id: tags:
``` python
from semantic_router.schema import Route
from semantic_router.encoders import CohereEncoder
from semantic_router.layer import RouteLayer
from semantic_router.utils.logger import logger
def create_router(routes: list[dict]) -> RouteLayer:
    logger.info("Creating route layer...")
    encoder = CohereEncoder()

    route_list: list[Route] = []
    for route in routes:
        if "name" in route and "utterances" in route:
            print(f"Route: {route}")
            route_list.append(Route(name=route["name"], utterances=route["utterances"]))
        else:
            logger.warning(f"Misconfigured route: {route}")

    return RouteLayer(encoder=encoder, routes=route_list)
```
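%% Cell type:markdown id: tags:
A minimal sketch of building a router from a hand-written config (`CohereEncoder` assumes a `COHERE_API_KEY` environment variable is set):
%% Cell type:code id: tags:
``` python
# Hypothetical manual route config, using the same {"name", "utterances"} format
# that generate_route produces.
demo_router = create_router(
    [
        {
            "name": "get_weather",
            "utterances": ["What is the weather like in SF?", "weather in London?"],
        }
    ]
)
```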
%% Cell type:markdown id: tags:
Set up calling functions
%% Cell type:code id: tags:
``` python
from typing import Callable


def call_function(function: Callable, parameters: dict[str, str]):
    try:
        return function(**parameters)
    except TypeError as e:
        logger.error(f"Error calling function: {e}")


def call_llm(query: str):
    return llm_mistral(query)


def call(query: str, functions: list[Callable], router: RouteLayer):
    function_name = router(query)
    if not function_name:
        logger.warning("No function found")
        return call_llm(query)

    for function in functions:
        if function.__name__ == function_name:
            parameters = extract_parameters(query, function)
            print(f"parameters: {parameters}")
            return call_function(function, parameters)
```
%% Cell type:markdown id: tags:
### Workflow
%% Cell type:code id: tags:
``` python
def get_time(location: str) -> str:
    """Useful to get the time in a specific location"""
    print(f"Calling `get_time` function with location: {location}")
    return "get_time"


def get_news(category: str, country: str) -> str:
    """Useful to get the news in a specific country"""
    print(
        f"Calling `get_news` function with category: {category} and country: {country}"
    )
    return "get_news"


# Registering functions to the router
route_get_time = generate_route(get_time)
route_get_news = generate_route(get_news)

routes = [route_get_time, route_get_news]
router = create_router(routes)

# Tools
tools = [get_time, get_news]
```
%% Cell type:code id: tags:
``` python
call(query="What is the time in Stockholm?", functions=tools, router=router)

call(query="What is the tech news in Lithuania?", functions=tools, router=router)

call(query="Hi!", functions=tools, router=router)
```
%% Output
2023-12-15 11:41:54 INFO semantic_router.utils.logger Extracting parameters...
2023-12-15 11:41:54 INFO semantic_router.utils.logger Calling Mistral model
2023-12-15 11:41:55 INFO semantic_router.utils.logger AI message:
{
'location': 'Stockholm'
}
2023-12-15 11:41:55 INFO semantic_router.utils.logger Extracted parameters: {'location': 'Stockholm'}
parameters: {'location': 'Stockholm'}
Calling `get_time` function with location: Stockholm
2023-12-15 11:41:55 INFO semantic_router.utils.logger Extracting parameters...
2023-12-15 11:41:55 INFO semantic_router.utils.logger Calling Mistral model
2023-12-15 11:41:56 INFO semantic_router.utils.logger AI message:
{
'category': 'tech',
'country': 'Lithuania'
}
2023-12-15 11:41:56 INFO semantic_router.utils.logger Extracted parameters: {'category': 'tech', 'country': 'Lithuania'}
parameters: {'category': 'tech', 'country': 'Lithuania'}
Calling `get_news` function with category: tech and country: Lithuania
2023-12-15 11:41:57 WARNING semantic_router.utils.logger No function found
2023-12-15 11:41:57 INFO semantic_router.utils.logger Calling Mistral model
2023-12-15 11:41:57 INFO semantic_router.utils.logger AI message: How can I help you today?
' How can I help you today?'