Skip to content
Snippets Groups Projects
Commit 36e8b721 authored by Simonas's avatar Simonas
Browse files

comment

parent a1370a5f
No related branches found
No related tags found
No related merge requests found
%% Cell type:code id: tags: %% Cell type:code id: tags:
``` python ``` python
# https://platform.openai.com/docs/guides/function-calling # https://platform.openai.com/docs/guides/function-calling
``` ```
%% Cell type:markdown id: tags: %% Cell type:markdown id: tags:
## Define LLMs ## Define LLMs
%% Cell type:code id: tags: %% Cell type:code id: tags:
``` python ``` python
# OpenAI # OpenAI
import openai import openai
from semantic_router.utils.logger import logger from semantic_router.utils.logger import logger
def llm_openai(prompt: str, model: str = "gpt-4") -> str: def llm_openai(prompt: str, model: str = "gpt-4") -> str:
try: try:
response = openai.chat.completions.create( response = openai.chat.completions.create(
model=model, model=model,
messages=[ messages=[
{"role": "system", "content": f"{prompt}"}, {"role": "system", "content": f"{prompt}"},
], ],
) )
ai_message = response.choices[0].message.content ai_message = response.choices[0].message.content
if not ai_message: if not ai_message:
raise Exception("AI message is empty", ai_message) raise Exception("AI message is empty", ai_message)
logger.info(f"AI message: {ai_message}") logger.info(f"AI message: {ai_message}")
return ai_message return ai_message
except Exception as e: except Exception as e:
raise Exception("Failed to call OpenAI API", e) raise Exception("Failed to call OpenAI API", e)
``` ```
%% Cell type:code id: tags: %% Cell type:code id: tags:
``` python ``` python
# Mistral # Mistral
import os import os
import requests import requests
# Docs https://huggingface.co/docs/transformers/main_classes/text_generation # Docs https://huggingface.co/docs/transformers/main_classes/text_generation
HF_API_TOKEN = os.environ["HF_API_TOKEN"] HF_API_TOKEN = os.environ["HF_API_TOKEN"]
def llm_mistral(prompt: str) -> str: def llm_mistral(prompt: str) -> str:
api_url = "https://z5t4cuhg21uxfmc3.us-east-1.aws.endpoints.huggingface.cloud/" api_url = "https://z5t4cuhg21uxfmc3.us-east-1.aws.endpoints.huggingface.cloud/"
headers = { headers = {
"Authorization": f"Bearer {HF_API_TOKEN}", "Authorization": f"Bearer {HF_API_TOKEN}",
"Content-Type": "application/json", "Content-Type": "application/json",
} }
response = requests.post( response = requests.post(
api_url, api_url,
headers=headers, headers=headers,
json={ json={
"inputs": prompt, "inputs": prompt,
"parameters": { "parameters": {
"max_new_tokens": 200, "max_new_tokens": 200,
"temperature": 0.2, "temperature": 0.2,
}, },
}, },
) )
if response.status_code != 200: if response.status_code != 200:
raise Exception("Failed to call HuggingFace API", response.text) raise Exception("Failed to call HuggingFace API", response.text)
return response.json()[0]['generated_text'] return response.json()[0]['generated_text']
``` ```
%% Cell type:markdown id: tags: %% Cell type:markdown id: tags:
### Now we need to generate config from function specification with `GPT-4` ### Now we need to generate config from function specification with `GPT-4`
%% Cell type:code id: tags: %% Cell type:code id: tags:
``` python ``` python
import json import json
from semantic_router.utils.logger import logger from semantic_router.utils.logger import logger
def generate_config(specification: dict) -> dict: def generate_config(specification: dict) -> dict:
logger.info("Generating config...") logger.info("Generating config...")
example_specification = ( example_specification = (
{ {
"type": "function", "type": "function",
"function": { "function": {
"name": "get_current_weather", "name": "get_current_weather",
"description": "Get the current weather", "description": "Get the current weather",
"parameters": { "parameters": {
"type": "object", "type": "object",
"properties": { "properties": {
"location": { "location": {
"type": "string", "type": "string",
"description": "The city and state, e.g. San Francisco, CA", "description": "The city and state, e.g. San Francisco, CA",
}, },
"format": { "format": {
"type": "string", "type": "string",
"enum": ["celsius", "fahrenheit"], "enum": ["celsius", "fahrenheit"],
"description": "The temperature unit to use. Infer this " "description": "The temperature unit to use. Infer this "
" from the users location.", " from the users location.",
}, },
}, },
"required": ["location", "format"], "required": ["location", "format"],
}, },
}, },
}, },
) )
example_config = { example_config = {
"name": "get_weather", "name": "get_weather",
"utterances": [ "utterances": [
"What is the weather like in SF?", "What is the weather like in SF?",
"What is the weather in Cyprus?", "What is the weather in Cyprus?",
"weather in London?", "weather in London?",
"Tell me the weather in New York", "Tell me the weather in New York",
"what is the current weather in Paris?", "what is the current weather in Paris?",
], ],
} }
prompt = f""" prompt = f"""
Given the following specification, generate a config in a valid JSON format Given the following specification, generate a config in a valid JSON format
enclosed in double quotes, enclosed in double quotes,
Example: Example:
SPECIFICATION: SPECIFICATION:
{example_specification} {example_specification}
CONFIG: CONFIG:
{example_config} {example_config}
GIVEN SPECIFICATION: GIVEN SPECIFICATION:
{specification} {specification}
GENERATED CONFIG: GENERATED CONFIG:
""" """
ai_message = llm_openai(prompt) ai_message = llm_openai(prompt)
try: try:
route_config = json.loads(ai_message) route_config = json.loads(ai_message)
function_description = specification["function"]["description"] function_description = specification["function"]["description"]
route_config["utterances"].append(function_description) route_config["utterances"].append(function_description)
logger.info(f"Generated config: {route_config}") logger.info(f"Generated config: {route_config}")
return route_config return route_config
except json.JSONDecodeError as json_error: except json.JSONDecodeError as json_error:
raise Exception("JSON parsing error", json_error) raise Exception("JSON parsing error", json_error)
``` ```
%% Cell type:markdown id: tags: %% Cell type:markdown id: tags:
Extract function parameters using `Mistal` open-source model Extract function parameters using `Mistal` open-source model
%% Cell type:code id: tags: %% Cell type:code id: tags:
``` python ``` python
def extract_parameters(query: str, specification: dict) -> dict: def extract_parameters(query: str, specification: dict) -> dict:
logger.info("Extracting parameters...") logger.info("Extracting parameters...")
example_query = "what is the weather in London?" example_query = "what is the weather in London?"
example_specification = { example_specification = {
"type": "function", "type": "function",
"function": { "function": {
"name": "get_time", "name": "get_time",
"description": "Get the current time", "description": "Get the current time",
"parameters": { "parameters": {
"type": "object", "type": "object",
"properties": { "properties": {
"location": { "location": {
"type": "string", "type": "string",
"description": "Example of city and state", "description": "Example of city and state",
}, },
}, },
"required": ["location"], "required": ["location"],
}, },
}, },
} }
example_parameters = { example_parameters = {
"location": "London", "location": "London",
} }
prompt = f""" prompt = f"""
Given the following specification and query, extract the parameters from the query, Given the following specification and query, extract the parameters from the query,
in a valid JSON format enclosed in double quotes. in a valid JSON format enclosed in double quotes.
Example: Example:
SPECIFICATION: SPECIFICATION:
{example_specification} {example_specification}
QUERY: QUERY:
{example_query} {example_query}
PARAMETERS: PARAMETERS:
{example_parameters} {example_parameters}
GIVEN SPECIFICATION: GIVEN SPECIFICATION:
{specification} {specification}
GIVEN QUERY: GIVEN QUERY:
{query} {query}
EXTRACTED PARAMETERS: EXTRACTED PARAMETERS:
""" """
# ai_message = llm_openai(prompt) # ai_message = llm_openai(prompt)
ai_message = llm_mistral(prompt) ai_message = llm_mistral(prompt)
print(ai_message) print(ai_message)
try: try:
parameters = json.loads(ai_message) parameters = json.loads(ai_message)
logger.info(f"Extracted parameters: {parameters}") logger.info(f"Extracted parameters: {parameters}")
return parameters return parameters
except json.JSONDecodeError as json_error: except json.JSONDecodeError as json_error:
raise Exception("JSON parsing error", json_error) raise Exception("JSON parsing error", json_error)
``` ```
%% Cell type:code id: tags: %% Cell type:code id: tags:
``` python ``` python
def validate_parameters(function_parameters, specification): def validate_parameters(function_parameters, specification):
required_params = specification["function"]["parameters"]["required"] required_params = specification["function"]["parameters"]["required"]
missing_params = [ missing_params = [
param for param in required_params if param not in function_parameters param for param in required_params if param not in function_parameters
] ]
if missing_params: if missing_params:
raise ValueError(f"Missing required parameters: {missing_params}") raise ValueError(f"Missing required parameters: {missing_params}")
return True return True
``` ```
%% Cell type:markdown id: tags: %% Cell type:markdown id: tags:
Set up the routing layer Set up the routing layer
%% Cell type:code id: tags: %% Cell type:code id: tags:
``` python ``` python
from semantic_router.schema import Route from semantic_router.schema import Route
from semantic_router.encoders import CohereEncoder from semantic_router.encoders import CohereEncoder
from semantic_router.layer import RouteLayer from semantic_router.layer import RouteLayer
from semantic_router.utils.logger import logger from semantic_router.utils.logger import logger
def get_route_layer(config: list[dict]) -> RouteLayer: def get_route_layer(config: list[dict]) -> RouteLayer:
logger.info("Getting route layer...") logger.info("Getting route layer...")
encoder = CohereEncoder() encoder = CohereEncoder()
routes = [ routes = [
Route(name=route["name"], utterances=route["utterances"]) for route in config Route(name=route["name"], utterances=route["utterances"]) for route in config
] ]
return RouteLayer(encoder=encoder, routes=routes) return RouteLayer(encoder=encoder, routes=routes)
``` ```
%% Cell type:markdown id: tags: %% Cell type:markdown id: tags:
### Workflow ### Workflow
%% Cell type:code id: tags: %% Cell type:code id: tags:
``` python ``` python
def get_time(location: str) -> str: def get_time(location: str) -> str:
print(f"Calling get_time function with location: {location}") print(f"Calling get_time function with location: {location}")
return "get_time" return "get_time"
get_time_spec = { get_time_spec = {
"type": "function", "type": "function",
"function": { "function": {
"name": "get_time", "name": "get_time",
"description": "Get the current time", "description": "Get the current time",
"parameters": { "parameters": {
"type": "object", "type": "object",
"properties": { "properties": {
"location": { "location": {
"type": "string", "type": "string",
"description": "The city and state", "description": "The city and state",
}, },
}, },
"required": ["location"], "required": ["location"],
}, },
}, },
} }
route_config = generate_config(get_time_spec) route_config = generate_config(get_time_spec)
route_layer = get_route_layer([route_config]) route_layer = get_route_layer([route_config])
queries = [ queries = [
"What is the weather like in Barcelona?", "What is the weather like in Barcelona?",
"What time is it in Taiwan?", "What time is it in Taiwan?",
"What is happening in the world?", "What is happening in the world?",
"what is the time in Kaunas?", "what is the time in Kaunas?",
"Im bored", "Im bored",
"I want to play a game", "I want to play a game",
"Banana", "Banana",
] ]
# Calling functions # Calling functions
for query in queries: for query in queries:
function_name = route_layer(query) function_name = route_layer(query)
if function_name == "get_time": if function_name == "get_time":
function_parameters = extract_parameters(query, get_time_spec) function_parameters = extract_parameters(query, get_time_spec)
try: try:
if validate_parameters(function_parameters, get_time_spec): if validate_parameters(function_parameters, get_time_spec):
# Call the function
get_time(**function_parameters) get_time(**function_parameters)
except ValueError as e: except ValueError as e:
logger.error(f"Error: {e}") logger.error(f"Error: {e}")
``` ```
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment