Skip to content
Snippets Groups Projects
schema.py 1015 B
Newer Older
  • Learn to ignore specific revisions
  • James Briggs's avatar
    James Briggs committed
    from enum import Enum
    
    Simonas's avatar
    Simonas committed
    
    
    from pydantic import BaseModel
    
    from pydantic.dataclasses import dataclass
    
    Simonas's avatar
    Simonas committed
    
    
    James Briggs's avatar
    James Briggs committed
    from semantic_router.encoders import (
        BaseEncoder,
        CohereEncoder,
        OpenAIEncoder,
    
    James Briggs's avatar
    James Briggs committed
    class EncoderType(Enum):
        """Closed set of encoder backends, keyed by their lowercase string value.

        Look a member up from its string with ``EncoderType("openai")``; an
        unknown string raises ``ValueError`` (standard ``Enum`` behaviour).
        """

        # Declared but currently rejected by ``Encoder`` (NotImplementedError).
        HUGGINGFACE = "huggingface"
        # Backed by ``OpenAIEncoder``.
        OPENAI = "openai"
        # Backed by ``CohereEncoder``.
        COHERE = "cohere"
    
    
    Simonas's avatar
    Simonas committed
    
    
    class RouteChoice(BaseModel):
        """Pydantic model describing the outcome of a routing decision.

        Both fields are optional and default to ``None``.
        """

        # Name of the chosen route, or None when no route was chosen.
        name: str | None = None
        # Optional function-call payload; its schema is not defined here.
        function_call: dict | None = None
    
    
    
    James Briggs's avatar
    James Briggs committed
    @dataclass
    class Encoder:
        """Bind an ``EncoderType`` and model name to a concrete encoder instance.

        Construction resolves ``type`` to an ``EncoderType`` and instantiates
        the matching encoder class; calling the instance delegates to it.
        """

        type: EncoderType
        name: str | None
        model: BaseEncoder

        def __init__(self, type: str | EncoderType, name: str | None):
            """Resolve the encoder type and build the underlying model.

            Args:
                type: An ``EncoderType`` member or its string value
                    (``"huggingface"``, ``"openai"`` or ``"cohere"``).
                name: Model name passed straight to the encoder constructor.

            Raises:
                ValueError: If ``type`` is not a valid ``EncoderType`` value.
                NotImplementedError: If ``type`` is ``"huggingface"``, which
                    has no encoder implementation here.
            """
            # EncoderType(...) raises ValueError for unknown strings and
            # returns the member unchanged when given one.
            self.type = EncoderType(type)
            self.name = name
            if self.type == EncoderType.HUGGINGFACE:
                raise NotImplementedError(
                    "HuggingFace encoders are not supported yet"
                )
            elif self.type == EncoderType.OPENAI:
                self.model = OpenAIEncoder(name)
            elif self.type == EncoderType.COHERE:
                self.model = CohereEncoder(name)
            else:
                # Defensive: only reachable if EncoderType gains a member
                # without a matching branch above.
                raise ValueError(f"Unsupported encoder type: {self.type!r}")

        def __call__(self, texts: list[str]) -> list[list[float]]:
            """Return whatever the underlying model produces for ``texts``."""
            return self.model(texts)