Skip to content
Snippets Groups Projects
Unverified Commit e65ec815 authored by Siraj R Aizlewood's avatar Siraj R Aizlewood
Browse files

Linting.

parent f963e6d7
No related branches found
No related tags found
No related merge requests found
...@@ -2,7 +2,10 @@ import pytest ...@@ -2,7 +2,10 @@ import pytest
from semantic_router.llms import OpenAILLM from semantic_router.llms import OpenAILLM
from semantic_router.schema import Message from semantic_router.schema import Message
from semantic_router.utils.function_call import get_schema_openai, convert_param_type_to_json_type from semantic_router.utils.function_call import (
get_schema_openai,
convert_param_type_to_json_type,
)
@pytest.fixture @pytest.fixture
...@@ -168,7 +171,6 @@ class TestOpenAILLM: ...@@ -168,7 +171,6 @@ class TestOpenAILLM:
expected_error_message in actual_error_message expected_error_message in actual_error_message
), f"Expected error message: '{expected_error_message}', but got: '{actual_error_message}'" ), f"Expected error message: '{expected_error_message}', but got: '{actual_error_message}'"
def test_convert_param_type_to_json_type(self): def test_convert_param_type_to_json_type(self):
# Test conversion of basic types # Test conversion of basic types
assert convert_param_type_to_json_type("int") == "number" assert convert_param_type_to_json_type("int") == "number"
...@@ -185,15 +187,22 @@ class TestOpenAILLM: ...@@ -185,15 +187,22 @@ class TestOpenAILLM:
function_schema = {"function": "get_user_data", "args": ["user_id"]} function_schema = {"function": "get_user_data", "args": ["user_id"]}
# Mock the __call__ method to return a JSON string as expected # Mock the __call__ method to return a JSON string as expected
mocker.patch.object(OpenAILLM, '__call__', return_value='{"user_id": "123"}') mocker.patch.object(OpenAILLM, "__call__", return_value='{"user_id": "123"}')
result = openai_llm.extract_function_inputs(query, function_schema) result = openai_llm.extract_function_inputs(query, function_schema)
# Ensure the __call__ method is called with the correct parameters # Ensure the __call__ method is called with the correct parameters
expected_messages = [ expected_messages = [
Message(role="system", content="You are an intelligent AI. Given a command or request from the user, call the function to complete the request."), Message(
Message(role="user", content=query) role="system",
content="You are an intelligent AI. Given a command or request from the user, call the function to complete the request.",
),
Message(role="user", content=query),
] ]
openai_llm.__call__.assert_called_once_with(messages=expected_messages, function_schema=function_schema) openai_llm.__call__.assert_called_once_with(
messages=expected_messages, function_schema=function_schema
)
# Check if the result is as expected # Check if the result is as expected
assert result == {"user_id": "123"}, "The function inputs should match the expected dictionary." assert result == {
\ No newline at end of file "user_id": "123"
}, "The function inputs should match the expected dictionary."
...@@ -5,6 +5,7 @@ import pytest ...@@ -5,6 +5,7 @@ import pytest
from semantic_router.llms import BaseLLM from semantic_router.llms import BaseLLM
from semantic_router.route import Route, is_valid from semantic_router.route import Route, is_valid
# Is valid test: # Is valid test:
def test_is_valid_with_valid_json(): def test_is_valid_with_valid_json():
valid_json = '{"name": "test_route", "utterances": ["hello", "hi"]}' valid_json = '{"name": "test_route", "utterances": ["hello", "hi"]}'
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Finish editing this message first!
Please register or to comment