diff --git a/semantic_router/layer.py b/semantic_router/layer.py index ac55db983d57a96d0e2dff0d5b8b1b13ce5c6b27..81f690b594ba5e3ed5151e206720746de6a4bb0c 100644 --- a/semantic_router/layer.py +++ b/semantic_router/layer.py @@ -303,10 +303,9 @@ class RouteLayer: ) self.llm = OpenAILLM() route.llm = self.llm - return await route.llm.acall(text) # type: ignore else: route.llm = self.llm - return route(text) + return await route.acall(text) elif passed and route is not None and simulate_static: return RouteChoice( name=route.name, diff --git a/semantic_router/llms/openai.py b/semantic_router/llms/openai.py index 24a38ee3c11ac45f8725eff7d40d4d2c3865adf4..f22f409e8c59e830d9c1f20059c4be8d3bc10a27 100644 --- a/semantic_router/llms/openai.py +++ b/semantic_router/llms/openai.py @@ -66,6 +66,23 @@ class OpenAILLM(BaseLLM): ) return tool_calls_info + async def async_extract_tool_calls_info( + self, tool_calls: List[ChatCompletionMessageToolCall] + ) -> List[Dict[str, Any]]: + tool_calls_info = [] + for tool_call in tool_calls: + if tool_call.function.arguments is None: + raise ValueError( + "Invalid output, expected arguments to be specified for each tool call." + ) + tool_calls_info.append( + { + "function_name": tool_call.function.name, + "arguments": json.loads(tool_call.function.arguments), + } + ) + return tool_calls_info + def __call__( self, messages: List[Message], @@ -141,7 +158,7 @@ class OpenAILLM(BaseLLM): # Collecting multiple tool calls information output = str( - self._extract_tool_calls_info(tool_calls) + await self.async_extract_tool_calls_info(tool_calls) ) # str in keeping with base type. else: content = completion.choices[0].message.content @@ -168,6 +185,25 @@ class OpenAILLM(BaseLLM): output = output.replace("'", '"') function_inputs = json.loads(output) logger.info(f"Function inputs: {function_inputs}") + logger.info(f"function_schemas: {function_schemas}") + if not self._is_valid_inputs(function_inputs, function_schemas): + raise ValueError("Invalid inputs") + return function_inputs + + async def async_extract_function_inputs( + self, query: str, function_schemas: List[Dict[str, Any]] + ) -> List[Dict[str, Any]]: + system_prompt = "You are an intelligent AI. Given a command or request from the user, call the function to complete the request." + messages = [ + Message(role="system", content=system_prompt), + Message(role="user", content=query), + ] + output = await self.acall(messages=messages, function_schemas=function_schemas) + if not output: + raise Exception("No output generated for extract function input") + output = output.replace("'", '"') + function_inputs = json.loads(output) + logger.info(f"OpenAI => Function Inputs: {function_inputs}") if not self._is_valid_inputs(function_inputs, function_schemas): raise ValueError("Invalid inputs") return function_inputs diff --git a/semantic_router/route.py b/semantic_router/route.py index a32c778c8506217f542060683bb498989b560c67..66ffb6fb40197bebb6ec3dc41596f58d67aa380b 100644 --- a/semantic_router/route.py +++ b/semantic_router/route.py @@ -76,6 +76,28 @@ class Route(BaseModel): func_call = None return RouteChoice(name=self.name, function_call=func_call) + async def acall(self, query: Optional[str] = None) -> RouteChoice: + if self.function_schemas: + if not self.llm: + raise ValueError( + "LLM is required for dynamic routes. Please ensure the `llm` " + "attribute is set." + ) + elif query is None: + raise ValueError( + "Query is required for dynamic routes. Please ensure the `query` " + "argument is passed." + ) + # if a function schema is provided we generate the inputs + extracted_inputs = await self.llm.async_extract_function_inputs( + query=query, function_schemas=self.function_schemas + ) + func_call = extracted_inputs + else: + # otherwise we just pass None for the call + func_call = None + return RouteChoice(name=self.name, function_call=func_call) + # def to_dict(self) -> Dict[str, Any]: # return self.dict()