From d96ec2ea3ac8cbe88b3376cad3ecea0e1fe3e47c Mon Sep 17 00:00:00 2001
From: James Briggs <35938317+jamescalam@users.noreply.github.com>
Date: Thu, 28 Dec 2023 09:46:10 +0100
Subject: [PATCH] added tests to cover route layer save/load

---
 semantic_router/utils/llm.py | 56 ++++++++++++++++++------------------
 test_output.json             |  1 -
 test_output.yaml             |  4 ---
 tests/unit/test_layer.py     | 26 +++++++++++++++++
 4 files changed, 54 insertions(+), 33 deletions(-)
 delete mode 100644 test_output.json
 delete mode 100644 test_output.yaml

diff --git a/semantic_router/utils/llm.py b/semantic_router/utils/llm.py
index 0d22b9a6..6ce28ff8 100644
--- a/semantic_router/utils/llm.py
+++ b/semantic_router/utils/llm.py
@@ -33,31 +33,31 @@ def llm(prompt: str) -> str | None:
         logger.error(f"LLM error: {e}")
         raise Exception(f"LLM error: {e}")
 
-
-async def allm(prompt: str) -> str | None:
-    try:
-        client = openai.AsyncOpenAI(
-            base_url="https://openrouter.ai/api/v1",
-            api_key=os.getenv("OPENROUTER_API_KEY"),
-        )
-
-        completion = await client.chat.completions.create(
-            model="mistralai/mistral-7b-instruct",
-            messages=[
-                {
-                    "role": "user",
-                    "content": prompt,
-                },
-            ],
-            temperature=0.01,
-            max_tokens=200,
-        )
-
-        output = completion.choices[0].message.content
-
-        if not output:
-            raise Exception("No output generated")
-        return output
-    except Exception as e:
-        logger.error(f"LLM error: {e}")
-        raise Exception(f"LLM error: {e}")
+# TODO integrate async LLM function
+# async def allm(prompt: str) -> str | None:
+#     try:
+#         client = openai.AsyncOpenAI(
+#             base_url="https://openrouter.ai/api/v1",
+#             api_key=os.getenv("OPENROUTER_API_KEY"),
+#         )
+
+#         completion = await client.chat.completions.create(
+#             model="mistralai/mistral-7b-instruct",
+#             messages=[
+#                 {
+#                     "role": "user",
+#                     "content": prompt,
+#                 },
+#             ],
+#             temperature=0.01,
+#             max_tokens=200,
+#         )
+
+#         output = completion.choices[0].message.content
+
+#         if not output:
+#             raise Exception("No output generated")
+#         return output
+#     except Exception as e:
+#         logger.error(f"LLM error: {e}")
+#         raise Exception(f"LLM error: {e}")
diff --git a/test_output.json b/test_output.json
deleted file mode 100644
index 1f930085..00000000
--- a/test_output.json
+++ /dev/null
@@ -1 +0,0 @@
-[{"name": "test", "utterances": ["utterance"], "description": null}]
diff --git a/test_output.yaml b/test_output.yaml
deleted file mode 100644
index b7167647..00000000
--- a/test_output.yaml
+++ /dev/null
@@ -1,4 +0,0 @@
-- description: null
-  name: test
-  utterances:
-  - utterance
diff --git a/tests/unit/test_layer.py b/tests/unit/test_layer.py
index 386edf6d..c6898235 100644
--- a/tests/unit/test_layer.py
+++ b/tests/unit/test_layer.py
@@ -173,6 +173,32 @@ class TestRouteLayer:
         route_layer = RouteLayer(encoder=base_encoder)
         assert route_layer.score_threshold == 0.82
 
+    def test_json(self, openai_encoder, routes):
+        route_layer = RouteLayer(encoder=openai_encoder, routes=routes)
+        route_layer.to_json("test_output.json")
+        assert os.path.exists("test_output.json")
+        route_layer_from_file = RouteLayer.from_json("test_output.json")
+        assert route_layer_from_file.index is not None and route_layer_from_file.categories is not None
+        os.remove("test_output.json")
+
+    def test_yaml(self, openai_encoder, routes):
+        route_layer = RouteLayer(encoder=openai_encoder, routes=routes)
+        route_layer.to_yaml("test_output.yaml")
+        assert os.path.exists("test_output.yaml")
+        route_layer_from_file = RouteLayer.from_yaml("test_output.yaml")
+        assert route_layer_from_file.index is not None and route_layer_from_file.categories is not None
+        os.remove("test_output.yaml")
+
+    def test_config(self, openai_encoder, routes):
+        route_layer = RouteLayer(encoder=openai_encoder, routes=routes)
+        # confirm route creation functions as expected
+        layer_config = route_layer.to_config()
+        assert layer_config.routes == routes
+        # now load from config and confirm it's the same
+        route_layer_from_config = RouteLayer.from_config(layer_config)
+        assert (route_layer_from_config.index == route_layer.index).all()
+        assert (route_layer_from_config.categories == route_layer.categories).all()
+        assert route_layer_from_config.score_threshold == route_layer.score_threshold
 
 
 # Add more tests for edge cases and error handling as needed.
--
GitLab