Skip to content
Snippets Groups Projects
Commit 8d0a38b4 authored by James Briggs's avatar James Briggs
Browse files

chore: lint

parent 15497bc5
Branches james/litellm
No related tags found
No related merge requests found
......@@ -5,29 +5,20 @@ import time
from datetime import datetime
from functools import wraps
from platform import python_version
from typing import Any, List, Optional
from unittest.mock import mock_open, patch
from typing import Optional
import numpy as np
import pytest
from semantic_router.encoders import (
CohereEncoder,
DenseEncoder,
OpenAIEncoder,
)
from semantic_router.encoders.base import (
AsymmetricDenseMixin,
AsymmetricSparseMixin,
SparseEncoder,
)
from semantic_router.index.local import LocalIndex
from semantic_router.index.pinecone import PineconeIndex
from semantic_router.index.qdrant import QdrantIndex
from semantic_router.llms import BaseLLM, OpenAILLM
from semantic_router.route import Route
from semantic_router.routers import HybridRouter, RouterConfig, SemanticRouter
from semantic_router.schema import RouteChoice, SparseEmbedding
from semantic_router.routers import HybridRouter, SemanticRouter
from semantic_router.schema import RouteChoice
from semantic_router.utils.logger import logger
PINECONE_SLEEP = 8
......@@ -150,6 +141,7 @@ def routes():
Route(name="Route 2", utterances=["Goodbye", "Bye", "Au revoir"]),
]
@pytest.fixture
def routes_2():
return [
......@@ -157,6 +149,7 @@ def routes_2():
Route(name="Route 2", utterances=["Hi"]),
]
@pytest.fixture
def routes_3():
return [
......@@ -164,6 +157,7 @@ def routes_3():
Route(name="Route 2", utterances=["Asparagus"]),
]
@pytest.fixture
def routes_4():
return [
......@@ -171,12 +165,14 @@ def routes_4():
Route(name="Route 2", utterances=["Asparagus"]),
]
@pytest.fixture
def route_single_utterance():
return [
Route(name="Route 3", utterances=["Hello"]),
]
@pytest.fixture
def dynamic_routes():
return [
......@@ -192,6 +188,7 @@ def dynamic_routes():
),
]
@pytest.fixture
def test_data():
return [
......@@ -202,6 +199,7 @@ def test_data():
("tell me an interesting fact", None),
]
def get_test_indexes():
indexes = [LocalIndex]
if importlib.util.find_spec("qdrant_client") is not None:
......@@ -210,16 +208,19 @@ def get_test_indexes():
indexes.append(PineconeIndex)
return indexes
def get_test_encoders():
encoders = [OpenAIEncoder]
if importlib.util.find_spec("cohere") is not None:
encoders.append(CohereEncoder)
return encoders
def get_test_routers():
routers = [SemanticRouter, HybridRouter]
return routers
@pytest.mark.parametrize(
"index_cls,encoder_cls,router_cls",
[
......@@ -246,9 +247,11 @@ class TestIndexEncoders:
else:
assert score_threshold == encoder.score_threshold
assert route_layer.top_k == 10
@retry(max_retries=RETRY_COUNT, delay=PINECONE_SLEEP)
def check_index_populated():
assert len(route_layer.index) == 5
check_index_populated()
assert (
len(set(route_layer._get_route_names()))
......@@ -276,6 +279,7 @@ class TestIndexEncoders:
else:
assert score_threshold == 0.3
@pytest.mark.parametrize(
"index_cls,encoder_cls,router_cls",
[
......
0% Loading or .
You are about to add 0 people to the discussion. Proceed with caution.
Please register or to comment