Skip to content
GitLab
Explore
Sign in
Primary navigation
Search or go to…
Project
Semantic Router
Manage
Activity
Members
Labels
Plan
Issues
Issue boards
Milestones
Iterations
Wiki
Code
Merge requests
Repository
Branches
Commits
Tags
Repository graph
Compare revisions
Snippets
Locked files
Build
Pipelines
Jobs
Pipeline schedules
Artifacts
Deploy
Releases
Package registry
Container Registry
Model registry
Operate
Environments
Terraform modules
Monitor
Incidents
Service Desk
Analyze
Value stream analytics
Contributor analytics
CI/CD analytics
Repository analytics
Code review analytics
Issue analytics
Model experiments
Help
Help
Support
GitLab documentation
Compare GitLab plans
Community forum
Contribute to GitLab
Provide feedback
Terms and privacy
Keyboard shortcuts
?
Snippets
Groups
Projects
Show more breadcrumbs
mirrored_repos
MachineLearning
aurelio-labs
Semantic Router
Commits
bcaa22ec
Commit
bcaa22ec
authored
1 month ago
by
James Briggs
Browse files
Options
Downloads
Patches
Plain Diff
feat: further docstrings and cleanup
parent
f16f620f
No related branches found
Branches containing commit
No related tags found
Tags containing commit
No related merge requests found
Changes
23
Hide whitespace changes
Inline
Side-by-side
Showing
3 changed files
semantic_router/utils/function_call.py
+75
-0
75 additions, 0 deletions
semantic_router/utils/function_call.py
semantic_router/utils/llm.py
+0
-65
0 additions, 65 deletions
semantic_router/utils/llm.py
semantic_router/utils/logger.py
+7
-0
7 additions, 0 deletions
semantic_router/utils/logger.py
with
82 additions
and
65 deletions
semantic_router/utils/function_call.py
+
75
−
0
View file @
bcaa22ec
...
...
@@ -9,6 +9,20 @@ from semantic_router.utils.logger import logger
class
Parameter
(
BaseModel
):
"""
Parameter for a function.
:param name: The name of the parameter.
:type name: str
:param description: The description of the parameter.
:type description: Optional[str]
:param type: The type of the parameter.
:type type: str
:param default: The default value of the parameter.
:type default: Any
:param required: Whether the parameter is required.
:type required: bool
"""
class
Config
:
arbitrary_types_allowed
=
True
...
...
@@ -21,6 +35,11 @@ class Parameter(BaseModel):
required
:
bool
=
Field
(
description
=
"
Whether the parameter is required
"
)
def
to_ollama
(
self
):
"""
Convert the parameter to a dictionary for an Ollama-compatible function schema.
:return: The parameter in dictionary format.
:rtype: Dict[str, Any]
"""
return
{
self
.
name
:
{
"
description
"
:
self
.
description
,
...
...
@@ -41,6 +60,11 @@ class FunctionSchema:
parameters
:
List
[
Parameter
]
=
Field
(
description
=
"
The parameters of the function
"
)
def
__init__
(
self
,
function
:
Union
[
Callable
,
BaseModel
]):
"""
Initialize the FunctionSchema.
:param function: The function to consume.
:type function: Union[Callable, BaseModel]
"""
self
.
function
=
function
if
callable
(
function
):
self
.
_process_function
(
function
)
...
...
@@ -50,6 +74,11 @@ class FunctionSchema:
raise
TypeError
(
"
Function must be a Callable or BaseModel
"
)
def
_process_function
(
self
,
function
:
Callable
):
"""
Process the function to get the name, description, signature, and output.
:param function: The function to process.
:type function: Callable
"""
self
.
name
=
function
.
__name__
self
.
description
=
str
(
inspect
.
getdoc
(
function
))
self
.
signature
=
str
(
inspect
.
signature
(
function
))
...
...
@@ -67,6 +96,11 @@ class FunctionSchema:
self
.
parameters
=
parameters
def
to_ollama
(
self
):
"""
Convert the FunctionSchema to an Ollama-compatible function schema dictionary.
:return: The function schema in dictionary format.
:rtype: Dict[str, Any]
"""
schema_dict
=
{
"
type
"
:
"
function
"
,
"
function
"
:
{
...
...
@@ -94,6 +128,13 @@ class FunctionSchema:
return
schema_dict
def
_ollama_type_mapping
(
self
,
param_type
:
str
)
->
str
:
"""
Map the parameter type to an Ollama-compatible type.
:param param_type: The type of the parameter.
:type param_type: str
:return: The Ollama-compatible type.
:rtype: str
"""
if
param_type
==
"
int
"
:
return
"
number
"
elif
param_type
==
"
float
"
:
...
...
@@ -107,6 +148,13 @@ class FunctionSchema:
def
get_schema_list
(
items
:
List
[
Union
[
BaseModel
,
Callable
]])
->
List
[
Dict
[
str
,
Any
]]:
"""
Get a list of function schemas from a list of functions or Pydantic BaseModels.
:param items: The functions or BaseModels to get the schemas for.
:type items: List[Union[BaseModel, Callable]]
:return: A list of function schemas.
:rtype: List[Dict[str, Any]]
"""
schemas
=
[]
for
item
in
items
:
schema
=
get_schema
(
item
)
...
...
@@ -115,6 +163,13 @@ def get_schema_list(items: List[Union[BaseModel, Callable]]) -> List[Dict[str, A
def
get_schema
(
item
:
Union
[
BaseModel
,
Callable
])
->
Dict
[
str
,
Any
]:
"""
Get a function schema from a function or Pydantic BaseModel.
:param item: The function or BaseModel to get the schema for.
:type item: Union[BaseModel, Callable]
:return: The function schema.
:rtype: Dict[str, Any]
"""
if
isinstance
(
item
,
BaseModel
):
signature_parts
=
[]
for
field_name
,
field_model
in
item
.
__annotations__
.
items
():
...
...
@@ -147,6 +202,13 @@ def get_schema(item: Union[BaseModel, Callable]) -> Dict[str, Any]:
def
convert_python_type_to_json_type
(
param_type
:
str
)
->
str
:
"""
Convert a Python type to a JSON type.
:param param_type: The type of the parameter.
:type param_type: str
:return: The JSON type.
:rtype: str
"""
if
param_type
==
"
int
"
:
return
"
number
"
if
param_type
==
"
float
"
:
...
...
@@ -167,6 +229,19 @@ def convert_python_type_to_json_type(param_type: str) -> str:
async
def
route_and_execute
(
query
:
str
,
llm
:
BaseLLM
,
functions
:
List
[
Callable
],
layer
)
->
Any
:
"""
Route and execute a function.
:param query: The query to route and execute.
:type query: str
:param llm: The LLM to use.
:type llm: BaseLLM
:param functions: The functions to execute.
:type functions: List[Callable]
:param layer: The layer to use.
:type layer: Layer
:return: The result of the function.
:rtype: Any
"""
route_choice
:
RouteChoice
=
layer
(
query
)
for
function
in
functions
:
...
...
This diff is collapsed.
Click to expand it.
semantic_router/utils/llm.py
deleted
100644 → 0
+
0
−
65
View file @
f16f620f
import
os
from
typing
import
Optional
import
openai
from
semantic_router.utils.logger
import
logger
def
llm
(
prompt
:
str
)
->
Optional
[
str
]:
try
:
client
=
openai
.
OpenAI
(
base_url
=
"
https://openrouter.ai/api/v1
"
,
api_key
=
os
.
getenv
(
"
OPENROUTER_API_KEY
"
),
)
completion
=
client
.
chat
.
completions
.
create
(
model
=
"
mistralai/mistral-7b-instruct
"
,
messages
=
[
{
"
role
"
:
"
user
"
,
"
content
"
:
prompt
,
},
],
temperature
=
0.01
,
max_tokens
=
200
,
)
output
=
completion
.
choices
[
0
].
message
.
content
if
not
output
:
raise
Exception
(
"
No output generated
"
)
return
output
except
Exception
as
e
:
logger
.
error
(
f
"
LLM error:
{
e
}
"
)
raise
Exception
(
f
"
LLM error:
{
e
}
"
)
from
e
# TODO integrate async LLM function
# async def allm(prompt: str) -> Optional[str]:
# try:
# client = openai.AsyncOpenAI(
# base_url="https://openrouter.ai/api/v1",
# api_key=os.getenv("OPENROUTER_API_KEY"),
# )
# completion = await client.chat.completions.create(
# model="mistralai/mistral-7b-instruct",
# messages=[
# {
# "role": "user",
# "content": prompt,
# },
# ],
# temperature=0.01,
# max_tokens=200,
# )
# output = completion.choices[0].message.content
# if not output:
# raise Exception("No output generated")
# return output
# except Exception as e:
# logger.error(f"LLM error: {e}")
# raise Exception(f"LLM error: {e}") from e
This diff is collapsed.
Click to expand it.
semantic_router/utils/logger.py
+
7
−
0
View file @
bcaa22ec
...
...
@@ -4,6 +4,9 @@ import colorlog
class
CustomFormatter
(
colorlog
.
ColoredFormatter
):
"""
Custom formatter for the logger.
"""
def
__init__
(
self
):
super
().
__init__
(
"
%(log_color)s%(asctime)s %(levelname)s %(name)s %(message)s
"
,
...
...
@@ -21,6 +24,8 @@ class CustomFormatter(colorlog.ColoredFormatter):
def
add_coloured_handler
(
logger
):
"""
Add a coloured handler to the logger.
"""
formatter
=
CustomFormatter
()
console_handler
=
logging
.
StreamHandler
()
console_handler
.
setFormatter
(
formatter
)
...
...
@@ -29,6 +34,8 @@ def add_coloured_handler(logger):
def
setup_custom_logger
(
name
):
"""
Setup a custom logger.
"""
logger
=
logging
.
getLogger
(
name
)
if
not
logger
.
hasHandlers
():
...
...
This diff is collapsed.
Click to expand it.
Prev
1
2
Next
Preview
0%
Loading
Try again
or
attach a new file
.
Cancel
You are about to add
0
people
to the discussion. Proceed with caution.
Finish editing this message first!
Save comment
Cancel
Please
register
or
sign in
to comment