diff --git a/poetry.lock b/poetry.lock index 7efeda7e28a550d04e2e108fa1901478a7483589..63248ed235772f93d252b60125797bf96601e1a3 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.5.1 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.7.1 and should not be changed by hand. [[package]] name = "aiohttp" @@ -1716,6 +1716,7 @@ files = [ {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:69b023b2b4daa7548bcfbd4aa3da05b3a74b772db9e23b982788168117739938"}, {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:81e0b275a9ecc9c0c0c07b4b90ba548307583c125f54d5b6946cfee6360c733d"}, {file = "PyYAML-6.0.1-cp310-cp310-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:ba336e390cd8e4d1739f42dfe9bb83a3cc2e80f567d8805e11b46f4a943f5515"}, + {file = "PyYAML-6.0.1-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:326c013efe8048858a6d312ddd31d56e468118ad4cdeda36c719bf5bb6192290"}, {file = "PyYAML-6.0.1-cp310-cp310-win32.whl", hash = "sha256:bd4af7373a854424dabd882decdc5579653d7868b8fb26dc7d0e99f823aa5924"}, {file = "PyYAML-6.0.1-cp310-cp310-win_amd64.whl", hash = "sha256:fd1592b3fdf65fff2ad0004b5e363300ef59ced41c2e6b3a99d4089fa8c5435d"}, {file = "PyYAML-6.0.1-cp311-cp311-macosx_10_9_x86_64.whl", hash = "sha256:6965a7bc3cf88e5a1c3bd2e0b5c22f8d677dc88a455344035f03399034eb3007"}, @@ -1723,8 +1724,15 @@ files = [ {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:42f8152b8dbc4fe7d96729ec2b99c7097d656dc1213a3229ca5383f973a5ed6d"}, {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:062582fca9fabdd2c8b54a3ef1c978d786e0f6b3a1510e0ac93ef59e0ddae2bc"}, {file = "PyYAML-6.0.1-cp311-cp311-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:d2b04aac4d386b172d5b9692e2d2da8de7bfb6c387fa4f801fbf6fb2e6ba4673"}, + {file = "PyYAML-6.0.1-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:e7d73685e87afe9f3b36c799222440d6cf362062f78be1013661b00c5c6f678b"}, {file = "PyYAML-6.0.1-cp311-cp311-win32.whl", hash = "sha256:1635fd110e8d85d55237ab316b5b011de701ea0f29d07611174a1b42f1444741"}, {file = "PyYAML-6.0.1-cp311-cp311-win_amd64.whl", hash = "sha256:bf07ee2fef7014951eeb99f56f39c9bb4af143d8aa3c21b1677805985307da34"}, + {file = "PyYAML-6.0.1-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:855fb52b0dc35af121542a76b9a84f8d1cd886ea97c84703eaa6d88e37a2ad28"}, + {file = "PyYAML-6.0.1-cp312-cp312-macosx_11_0_arm64.whl", hash = "sha256:40df9b996c2b73138957fe23a16a4f0ba614f4c0efce1e9406a184b6d07fa3a9"}, + {file = "PyYAML-6.0.1-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:6c22bec3fbe2524cde73d7ada88f6566758a8f7227bfbf93a408a9d86bcc12a0"}, + {file = "PyYAML-6.0.1-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:8d4e9c88387b0f5c7d5f281e55304de64cf7f9c0021a3525bd3b1c542da3b0e4"}, + {file = "PyYAML-6.0.1-cp312-cp312-win32.whl", hash = "sha256:d483d2cdf104e7c9fa60c544d92981f12ad66a457afae824d146093b8c294c54"}, + {file = "PyYAML-6.0.1-cp312-cp312-win_amd64.whl", hash = "sha256:0d3304d8c0adc42be59c5f8a4d9e3d7379e6955ad754aa9d6ab7a398b59dd1df"}, {file = "PyYAML-6.0.1-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:50550eb667afee136e9a77d6dc71ae76a44df8b3e51e41b77f6de2932bfe0f47"}, {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:1fe35611261b29bd1de0070f0b2f47cb6ff71fa6595c077e42bd0c419fa27b98"}, {file = "PyYAML-6.0.1-cp36-cp36m-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:704219a11b772aea0d8ecd7058d0082713c3562b4e271b849ad7dc4a5c90c13c"}, @@ -1741,6 +1749,7 @@ files = [ {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:a0cd17c15d3bb3fa06978b4e8958dcdc6e0174ccea823003a106c7d4d7899ac5"}, {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:28c119d996beec18c05208a8bd78cbe4007878c6dd15091efb73a30e90539696"}, {file = "PyYAML-6.0.1-cp38-cp38-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:7e07cbde391ba96ab58e532ff4803f79c4129397514e1413a7dc761ccd755735"}, + {file = "PyYAML-6.0.1-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:49a183be227561de579b4a36efbb21b3eab9651dd81b1858589f796549873dd6"}, {file = "PyYAML-6.0.1-cp38-cp38-win32.whl", hash = "sha256:184c5108a2aca3c5b3d3bf9395d50893a7ab82a38004c8f61c258d4428e80206"}, {file = "PyYAML-6.0.1-cp38-cp38-win_amd64.whl", hash = "sha256:1e2722cc9fbb45d9b87631ac70924c11d3a401b2d7f410cc0e3bbf249f2dca62"}, {file = "PyYAML-6.0.1-cp39-cp39-macosx_10_9_x86_64.whl", hash = "sha256:9eb6caa9a297fc2c2fb8862bc5370d0303ddba53ba97e71f08023b6cd73d16a8"}, @@ -1748,6 +1757,7 @@ files = [ {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:5773183b6446b2c99bb77e77595dd486303b4faab2b086e7b17bc6bef28865f6"}, {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_s390x.manylinux2014_s390x.whl", hash = "sha256:b786eecbdf8499b9ca1d697215862083bd6d2a99965554781d0d8d1ad31e13a0"}, {file = "PyYAML-6.0.1-cp39-cp39-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:bc1bf2925a1ecd43da378f4db9e4f799775d6367bdb94671027b73b393a7c42c"}, + {file = "PyYAML-6.0.1-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:04ac92ad1925b2cff1db0cfebffb6ffc43457495c9b3c39d3fcae417d7125dc5"}, {file = "PyYAML-6.0.1-cp39-cp39-win32.whl", hash = "sha256:faca3bdcf85b2fc05d06ff3fbc1f83e1391b3e724afa3feba7d13eeab355484c"}, {file = "PyYAML-6.0.1-cp39-cp39-win_amd64.whl", hash = "sha256:510c9deebc5c0225e8c96813043e62b680ba2f9c50a08d3724c7f28a747d1486"}, {file = "PyYAML-6.0.1.tar.gz", hash = "sha256:bfdf460b1736c775f2ba9f6a92bca30bc2095067b8a9d77876d1fad6cc3b4a43"}, diff --git a/pyproject.toml b/pyproject.toml index 0741dac08d7b3961d12a6e361286d1ef9b131b1a..71ef163be7594118fe0b5b6b6e5d8ad272bcbd04 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "semantic-router" -version = "0.0.12" +version = "0.0.13" description = "Super fast semantic router for AI decision making" authors = [ "James Briggs <james@aurelio.ai>", diff --git a/semantic_router/encoders/cohere.py b/semantic_router/encoders/cohere.py index ae0db7a24cb1d427a1bbe002d40d31fb1515e86b..f7aef0e6227938ef174d867f12e05ac19f58524d 100644 --- a/semantic_router/encoders/cohere.py +++ b/semantic_router/encoders/cohere.py @@ -11,9 +11,11 @@ class CohereEncoder(BaseEncoder): def __init__( self, - name: str = os.getenv("COHERE_MODEL_NAME", "embed-english-v3.0"), + name: str | None = None, cohere_api_key: str | None = None, ): + if name is None: + name = os.getenv("COHERE_MODEL_NAME", "embed-english-v3.0") super().__init__(name=name) cohere_api_key = cohere_api_key or os.getenv("COHERE_API_KEY") if cohere_api_key is None: diff --git a/semantic_router/encoders/openai.py b/semantic_router/encoders/openai.py index 9744401fdfbb7426bcc640b2d7c45bb29565f8e7..173fe94a493fe750156ef8f26bac8d9c9346d30d 100644 --- a/semantic_router/encoders/openai.py +++ b/semantic_router/encoders/openai.py @@ -15,9 +15,11 @@ class OpenAIEncoder(BaseEncoder): def __init__( self, - name: str = os.getenv("OPENAI_MODEL_NAME", "text-embedding-ada-002"), + name: str | None = None, openai_api_key: str | None = None, ): + if name is None: + name = os.getenv("OPENAI_MODEL_NAME", "text-embedding-ada-002") super().__init__(name=name) api_key = openai_api_key or os.getenv("OPENAI_API_KEY") if api_key is None: diff --git a/semantic_router/route.py b/semantic_router/route.py index 06ebf8f39a6ad6b7f5fb71f298ad691621925a50..30c20887ebcd7073fe2a32630d0d566ada3c945c 100644 --- a/semantic_router/route.py +++ b/semantic_router/route.py @@ -64,12 +64,12 @@ class Route(BaseModel): return cls(**data) @classmethod - async def from_dynamic_route(cls, entity: Union[BaseModel, Callable]): + def from_dynamic_route(cls, entity: Union[BaseModel, Callable]): """ Generate a dynamic Route object from a function or Pydantic model using LLM """ schema = function_call.get_schema(item=entity) - dynamic_route = await cls._generate_dynamic_route(function_schema=schema) + dynamic_route = cls._generate_dynamic_route(function_schema=schema) return dynamic_route @classmethod @@ -85,7 +85,7 @@ class Route(BaseModel): raise ValueError("No <config></config> tags found in the output.") @classmethod - async def _generate_dynamic_route(cls, function_schema: dict[str, Any]): + def _generate_dynamic_route(cls, function_schema: dict[str, Any]): logger.info("Generating dynamic route...") prompt = f""" diff --git a/tests/unit/encoders/test_openai.py b/tests/unit/encoders/test_openai.py index cc79d27207f7847439b74b0c29f4fb75d42d5381..fb0e604f1b085d5e039f7dcc9263d4cd5d22e4ec 100644 --- a/tests/unit/encoders/test_openai.py +++ b/tests/unit/encoders/test_openai.py @@ -22,7 +22,6 @@ class TestOpenAIEncoder: mocker.patch("os.getenv", return_value=None) with pytest.raises(ValueError) as e: OpenAIEncoder() - assert "OpenAI API key cannot be 'None'." in str(e.value) def test_openai_encoder_call_uninitialized_client(self, openai_encoder): # Set the client to None to simulate an uninitialized client diff --git a/tests/unit/test_route.py b/tests/unit/test_route.py index 2843ae4065a2ac75498a4e4f29f056b5be21fb00..0a9a6eba6a9cdd7e456cbd891ab872521f17dc60 100644 --- a/tests/unit/test_route.py +++ b/tests/unit/test_route.py @@ -1,4 +1,4 @@ -from unittest.mock import AsyncMock, patch +from unittest.mock import Mock, AsyncMock, patch import pytest @@ -43,9 +43,8 @@ def test_is_valid_with_invalid_json(): class TestRoute: - @pytest.mark.asyncio - @patch("semantic_router.route.llm", new_callable=AsyncMock) - async def test_generate_dynamic_route(self, mock_llm): + @patch("semantic_router.route.llm", new_callable=Mock) + def test_generate_dynamic_route(self, mock_llm): print(f"mock_llm: {mock_llm}") mock_llm.return_value = """ <config> @@ -61,7 +60,7 @@ class TestRoute: </config> """ function_schema = {"name": "test_function", "type": "function"} - route = await Route._generate_dynamic_route(function_schema) + route = Route._generate_dynamic_route(function_schema) assert route.name == "test_function" assert route.utterances == [ "example_utterance_1", @@ -71,6 +70,35 @@ class TestRoute: "example_utterance_5", ] + # TODO add async version + # @pytest.mark.asyncio + # @patch("semantic_router.route.allm", new_callable=Mock) + # async def test_generate_dynamic_route_async(self, mock_llm): + # print(f"mock_llm: {mock_llm}") + # mock_llm.return_value = """ + # <config> + # { + # "name": "test_function", + # "utterances": [ + # "example_utterance_1", + # "example_utterance_2", + # "example_utterance_3", + # "example_utterance_4", + # "example_utterance_5"] + # } + # </config> + # """ + # function_schema = {"name": "test_function", "type": "function"} + # route = await Route._generate_dynamic_route(function_schema) + # assert route.name == "test_function" + # assert route.utterances == [ + # "example_utterance_1", + # "example_utterance_2", + # "example_utterance_3", + # "example_utterance_4", + # "example_utterance_5", + # ] + def test_to_dict(self): route = Route(name="test", utterances=["utterance"]) expected_dict = { @@ -87,9 +115,8 @@ class TestRoute: assert route.name == "test" assert route.utterances == ["utterance"] - @pytest.mark.asyncio - @patch("semantic_router.route.llm", new_callable=AsyncMock) - async def test_from_dynamic_route(self, mock_llm): + @patch("semantic_router.route.llm", new_callable=Mock) + def test_from_dynamic_route(self, mock_llm): # Mock the llm function mock_llm.return_value = """ <config> @@ -109,7 +136,7 @@ class TestRoute: """Test function docstring""" pass - dynamic_route = await Route.from_dynamic_route(test_function) + dynamic_route = Route.from_dynamic_route(test_function) assert dynamic_route.name == "test_function" assert dynamic_route.utterances == [ @@ -120,6 +147,40 @@ class TestRoute: "example_utterance_5", ] + # TODO add async functions + # @pytest.mark.asyncio + # @patch("semantic_router.route.allm", new_callable=AsyncMock) + # async def test_from_dynamic_route_async(self, mock_llm): + # # Mock the llm function + # mock_llm.return_value = """ + # <config> + # { + # "name": "test_function", + # "utterances": [ + # "example_utterance_1", + # "example_utterance_2", + # "example_utterance_3", + # "example_utterance_4", + # "example_utterance_5"] + # } + # </config> + # """ + + # def test_function(input: str): + # """Test function docstring""" + # pass + + # dynamic_route = await Route.from_dynamic_route(test_function) + + # assert dynamic_route.name == "test_function" + # assert dynamic_route.utterances == [ + # "example_utterance_1", + # "example_utterance_2", + # "example_utterance_3", + # "example_utterance_4", + # "example_utterance_5", + # ] + def test_parse_route_config(self): config = """ <config>