Skip to content
Snippets Groups Projects
route.py 9.17 KiB
Newer Older
  • Learn to ignore specific revisions
  • import json
    import re
    
    James Briggs's avatar
    James Briggs committed
    from typing import Any, Callable, Dict, List, Optional, Union
    
    from pydantic import BaseModel
    
    Ismail Ashraq's avatar
    Ismail Ashraq committed
    from semantic_router.llms import BaseLLM
    
    from semantic_router.schema import Message, RouteChoice
    
    Simonas's avatar
    Simonas committed
    from semantic_router.utils import function_call
    
    from semantic_router.utils.logger import logger
    
    
    
    def is_valid(route_config: str) -> bool:
        """Check if the route config is valid.

        The config must be a JSON document that decodes to either a single
        object or a list of objects, where every object contains the ``name``
        and ``utterances`` keys. An empty list is considered valid.

        :param route_config: The route config to check.
        :type route_config: str
        :return: Whether the route config is valid.
        :rtype: bool
        """
        required_keys = ["name", "utterances"]

        # Keep the try body minimal: only the decode step can raise JSONDecodeError.
        try:
            output_json = json.loads(route_config)
        except json.JSONDecodeError as e:
            logger.error(e)
            return False

        # Treat a single object like a one-element list so both shapes share
        # exactly the same validation path.
        items = output_json if isinstance(output_json, list) else [output_json]
        for item in items:
            # Without this check a JSON string would be validated by substring
            # match and a JSON number/null would raise TypeError on `in`.
            if not isinstance(item, dict):
                logger.warning(
                    "Route config entries must be JSON objects, got "
                    f"{type(item).__name__}"
                )
                return False
            missing_keys = [key for key in required_keys if key not in item]
            if missing_keys:
                logger.warning(
                    f"Missing keys in route config: {', '.join(missing_keys)}"
                )
                return False
        return True
    
    
    James Briggs's avatar
    James Briggs committed
    
    
    class Route(BaseModel):
        """A route for the semantic router.

        A route is "dynamic" when ``function_schemas`` is set: calling it then
        uses the attached LLM to extract function inputs from the query.
        Otherwise the route is static and calling it only reports its name.

        :param name: The name of the route.
        :type name: str
        :param utterances: The utterances of the route.
        :type utterances: Union[List[str], List[Any]]
        :param description: The description of the route.
        :type description: Optional[str]
        :param function_schemas: The function schemas of the route.
        :type function_schemas: Optional[List[Dict[str, Any]]]
        :param llm: The LLM to use.
        :type llm: Optional[BaseLLM]
        :param score_threshold: The score threshold of the route.
        :type score_threshold: Optional[float]
        :param metadata: The metadata of the route.
        :type metadata: Optional[Dict[str, Any]]
        """

        name: str
        utterances: Union[List[str], List[Any]]
        description: Optional[str] = None
        function_schemas: Optional[List[Dict[str, Any]]] = None
        llm: Optional[BaseLLM] = None
        score_threshold: Optional[float] = None
        metadata: Optional[Dict[str, Any]] = {}

        class Config:
            # Needed so the non-pydantic BaseLLM type can be used as a field.
            arbitrary_types_allowed = True

        def __call__(self, query: Optional[str] = None) -> RouteChoice:
            """Call the route. If dynamic routes have been provided the query must have been
            provided and the llm attribute must be set.

            :param query: The query to pass to the route.
            :type query: Optional[str]
            :return: The route choice.
            :rtype: RouteChoice
            :raises ValueError: If the route is dynamic and either the ``llm``
                attribute is not set or ``query`` is None.
            """
            if self.function_schemas:
                if not self.llm:
                    raise ValueError(
                        "LLM is required for dynamic routes. Please ensure the `llm` "
                        "attribute is set."
                    )
                elif query is None:
                    raise ValueError(
                        "Query is required for dynamic routes. Please ensure the `query` "
                        "argument is passed."
                    )
                # if a function schema is provided we generate the inputs
                extracted_inputs = self.llm.extract_function_inputs(
                    query=query, function_schemas=self.function_schemas
                )
                func_call = extracted_inputs
            else:
                # otherwise we just pass None for the call
                func_call = None
            return RouteChoice(name=self.name, function_call=func_call)

        async def acall(self, query: Optional[str] = None) -> RouteChoice:
            """Asynchronous call the route. If dynamic routes have been provided the query
            must have been provided and the llm attribute must be set.

            :param query: The query to pass to the route.
            :type query: Optional[str]
            :return: The route choice.
            :rtype: RouteChoice
            :raises ValueError: If the route is dynamic and either the ``llm``
                attribute is not set or ``query`` is None.
            """
            if self.function_schemas:
                if not self.llm:
                    raise ValueError(
                        "LLM is required for dynamic routes. Please ensure the `llm` "
                        "attribute is set."
                    )
                elif query is None:
                    raise ValueError(
                        "Query is required for dynamic routes. Please ensure the `query` "
                        "argument is passed."
                    )
                # if a function schema is provided we generate the inputs
                extracted_inputs = await self.llm.async_extract_function_inputs(  # type: ignore # openai-llm
                    query=query, function_schemas=self.function_schemas
                )
                func_call = extracted_inputs
            else:
                # otherwise we just pass None for the call
                func_call = None
            return RouteChoice(name=self.name, function_call=func_call)

        def to_dict(self) -> Dict[str, Any]:
            """Convert the route to a dictionary.

            When an LLM is attached, the ``llm`` entry is replaced by a small
            descriptor holding the LLM's module, class name and model name.

            :return: The dictionary representation of the route.
            :rtype: Dict[str, Any]
            """
            data = self.dict()
            if self.llm is not None:
                data["llm"] = {
                    "module": self.llm.__module__,
                    "class": self.llm.__class__.__name__,
                    "model": self.llm.name,
                }
            return data

        @classmethod
        def from_dict(cls, data: Dict[str, Any]) -> "Route":
            """Create a Route object from a dictionary.

            :param data: The dictionary to create the route from.
            :type data: Dict[str, Any]
            :return: The created route.
            :rtype: Route
            """
            return cls(**data)

        @classmethod
        def from_dynamic_route(
            cls, llm: BaseLLM, entities: List[Union[BaseModel, Callable]], route_name: str
        ) -> "Route":
            """Generate a dynamic Route object from a list of functions or Pydantic models
            using an LLM.

            :param llm: The LLM to use.
            :type llm: BaseLLM
            :param entities: The entities to use.
            :type entities: List[Union[BaseModel, Callable]]
            :param route_name: The name of the route.
            :type route_name: str
            :return: The generated route, with ``function_schemas`` set to the
                schemas derived from ``entities``.
            :rtype: Route
            """
            schemas = function_call.get_schema_list(items=entities)
            dynamic_route = cls._generate_dynamic_route(
                llm=llm, function_schemas=schemas, route_name=route_name
            )
            # The LLM-generated config only carries name and utterances, so the
            # schemas are attached afterwards to make the route dynamic.
            dynamic_route.function_schemas = schemas
            return dynamic_route

        @classmethod
        def _parse_route_config(cls, config: str) -> str:
            """Parse the route config from the LLM output using regex. Expects the output
            content to be wrapped in <config></config> tags.

            :param config: The LLM output.
            :type config: str
            :return: The parsed route config.
            :rtype: str
            :raises ValueError: If no <config></config> tags are found.
            """
            # Regular expression to match content inside <config></config>
            config_pattern = r"<config>(.*?)</config>"
            match = re.search(config_pattern, config, re.DOTALL)

            if match:
                config_content = match.group(1).strip()  # Get the matched content
                return config_content
            else:
                raise ValueError("No <config></config> tags found in the output.")

        @classmethod
        def _generate_dynamic_route(
            cls, llm: BaseLLM, function_schemas: List[Dict[str, Any]], route_name: str
        ) -> "Route":
            """Generate a dynamic Route object from a list of function schemas using an LLM.

            :param llm: The LLM to use.
            :type llm: BaseLLM
            :param function_schemas: The function schemas to use.
            :type function_schemas: List[Dict[str, Any]]
            :param route_name: The name of the route.
            :type route_name: str
            :return: The route built from the LLM-generated config, with ``llm``
                attached.
            :rtype: Route
            :raises ValueError: If the LLM output has no <config></config> tags.
            :raises Exception: If the LLM returns no output or the generated
                config is not valid.
            """
            formatted_schemas = "\n".join(
                [json.dumps(schema, indent=4) for schema in function_schemas]
            )
            # Doubled braces are literal braces in the f-string; only
            # {route_name} and {formatted_schemas} are interpolated.
            prompt = f"""
            You are tasked to generate a single JSON configuration for multiple function schemas. 
            Each function schema should contribute five example utterances. 
            Please follow the template below, no other tokens allowed:

            <config>
            {{
                "name": "{route_name}",
                "utterances": [
                    "<example_utterance_1>",
                    "<example_utterance_2>",
                    "<example_utterance_3>",
                    "<example_utterance_4>",
                    "<example_utterance_5>"]
            }}
            </config>

            Only include the "name" and "utterances" keys in your answer.
            The "name" should match the provided route name and the "utterances"
            should comprise a list of 5 example phrases for each function schema that could be used to invoke
            the functions. Use real values instead of placeholders.

            Input schemas:
            {formatted_schemas}
            """

            llm_input = [Message(role="user", content=prompt)]
            output = llm(llm_input)
            if not output:
                raise Exception("No output generated for dynamic route")

            route_config = cls._parse_route_config(config=output)

            if is_valid(route_config):
                route_config_dict = json.loads(route_config)
                route_config_dict["llm"] = llm
                return Route.from_dict(route_config_dict)

            raise Exception("No config generated")