import os
from typing import List, Optional, Any

import openai
from openai._types import NotGiven

from semantic_router.llms import BaseLLM
from semantic_router.schema import Message
from semantic_router.utils.defaults import EncoderDefault
from semantic_router.utils.logger import logger
import json
from semantic_router.utils.function_call import get_schema, convert_python_type_to_json_type
import inspect
from typing import Callable, Dict
import re

class OpenAILLM(BaseLLM):
    """LLM wrapper around the OpenAI chat-completions API.

    Produces plain text completions, or - when function schemas are supplied -
    requests tool calls via the ``tools`` parameter and returns their decoded
    arguments.
    """

    client: Optional[openai.OpenAI]
    temperature: Optional[float]
    max_tokens: Optional[int]

    def __init__(
        self,
        name: Optional[str] = None,
        openai_api_key: Optional[str] = None,
        temperature: float = 0.01,
        max_tokens: int = 200,
    ):
        """Create the OpenAI client.

        :param name: Model name; defaults to
            ``EncoderDefault.OPENAI.value["language_model"]``.
        :param openai_api_key: API key; falls back to the ``OPENAI_API_KEY``
            environment variable when not given.
        :param temperature: Sampling temperature sent with every request.
        :param max_tokens: Completion token limit sent with every request.
        :raises ValueError: If no API key is available or the client fails
            to initialize.
        """
        if name is None:
            name = EncoderDefault.OPENAI.value["language_model"]
        super().__init__(name=name)
        api_key = openai_api_key or os.getenv("OPENAI_API_KEY")
        if api_key is None:
            raise ValueError("OpenAI API key cannot be 'None'.")
        try:
            self.client = openai.OpenAI(api_key=api_key)
        except Exception as e:
            raise ValueError(
                f"OpenAI API client failed to initialize. Error: {e}"
            ) from e
        self.temperature = temperature
        self.max_tokens = max_tokens

    def _extract_tool_calls_info(self, tool_calls: list[Any]) -> list[dict[str, Any]]:
        """Convert tool-call objects from a completion into plain dicts.

        :param tool_calls: Tool-call objects taken from
            ``completion.choices[0].message.tool_calls``; each must expose
            ``.function.name`` and ``.function.arguments`` (a JSON string).
        :return: One ``{"function_name": ..., "arguments": ...}`` dict per tool
            call, in order, with ``arguments`` decoded by ``json.loads``.
        :raises ValueError: If a tool call has no arguments (``json.loads`` may
            also raise ``json.JSONDecodeError`` on malformed arguments).
        """
        tool_calls_info = []
        for tool_call in tool_calls:
            if tool_call.function.arguments is None:
                raise ValueError(
                    "Invalid output, expected arguments to be specified for each tool call."
                )
            tool_calls_info.append({
                "function_name": tool_call.function.name,
                "arguments": json.loads(tool_call.function.arguments)
            })
        return tool_calls_info

    def __call__(
        self,
        messages: List[Message],
        function_schemas: Optional[list[dict[str, Any]]] = None,
    ) -> str:
        """Send ``messages`` to the chat-completions endpoint.

        :param messages: Conversation to send; each message is converted with
            ``Message.to_openai()``.
        :param function_schemas: Optional OpenAI ``tools`` definitions. When
            non-empty, the response is expected to contain tool calls.
        :return: The response content as a string, or - when
            ``function_schemas`` is non-empty - the list of tool-call dicts
            produced by ``_extract_tool_calls_info``.
            NOTE(review): the ``-> str`` annotation does not cover the list
            case; left unchanged to keep the signature stable.
        :raises ValueError: If the client is not initialized.
        :raises Exception: Wrapping any error raised while calling the API or
            validating/parsing its response.
        """
        if self.client is None:
            raise ValueError("OpenAI client is not initialized.")
        try:
            # The SDK omits a parameter only when it receives a ``NotGiven``
            # *instance* (its ``NOT_GIVEN`` sentinel); the bare class object
            # is not recognized as "omitted".
            tools = function_schemas if function_schemas else NotGiven()

            completion = self.client.chat.completions.create(
                model=self.name,
                messages=[m.to_openai() for m in messages],
                temperature=self.temperature,
                max_tokens=self.max_tokens,
                tools=tools,
            )

            if function_schemas:
                tool_calls = completion.choices[0].message.tool_calls
                if tool_calls is None:
                    raise ValueError("Invalid output, expected a tool call.")
                if not tool_calls:
                    raise ValueError(
                        "Invalid output, expected at least one tool to be specified."
                    )

                # Collect information for every tool call, not just the first.
                output = self._extract_tool_calls_info(tool_calls)
            else:
                content = completion.choices[0].message.content
                if content is None:
                    raise ValueError("Invalid output, expected content.")
                output = str(content)  # str to keep MyPy happy.

            return output

        except Exception as e:
            logger.error(f"LLM error: {e}")
            raise Exception(f"LLM error: {e}") from e

    def extract_function_inputs(
        self, query: str, function_schemas: list[dict[str, Any]]
    ) -> dict:
        """Ask the model to call one of ``function_schemas`` for ``query``.

        :param query: The user's request.
        :param function_schemas: OpenAI ``tools`` definitions to choose from.
        :return: Whatever ``__call__`` returns with these schemas, i.e. the
            list of tool-call dicts when ``function_schemas`` is non-empty.
            NOTE(review): the ``-> dict`` annotation does not match that list;
            left unchanged to keep the signature stable.
        """
        system_prompt = "You are an intelligent AI. Given a command or request from the user, call the function to complete the request."
        messages = [
            Message(role="system", content=system_prompt),
            Message(role="user", content=query),
        ]
        return self(messages=messages, function_schemas=function_schemas)

def get_schemas_openai(items: List[Callable]) -> List[Dict[str, Any]]:
    """Build OpenAI ``tools`` definitions for a list of callables.

    The function name and description come from ``get_schema``. Parameter
    types come from each callable's signature (mapped through
    ``convert_python_type_to_json_type``), and parameter descriptions from
    ``:param name: description`` fields in its docstring. Parameters without
    a default value are listed as required.

    :param items: Callables to describe.
    :return: One ``{"type": "function", "function": {...}}`` dict per item,
        in input order.
    :raises ValueError: If any item is not callable.
    """
    # A ``:param`` description runs until the next ``:field`` line or the end
    # of the docstring. ``inspect.getdoc`` strips trailing newlines, so the end
    # must be matched with ``\Z`` instead of requiring a final ``\n`` -
    # otherwise the last ``:param`` of a docstring is never captured.
    param_doc_regex = re.compile(r":param (\w+):(.*?)(?=\n:\w|\Z)", re.S)

    schemas = []
    for item in items:
        if not callable(item):
            raise ValueError("Provided item must be a callable function.")

        # Use the existing get_schema function to get the basic schema.
        basic_schema = get_schema(item)

        # Map parameter name -> description, keeping the first occurrence of
        # a name (matching the original first-match linear search).
        doc_descriptions: Dict[str, str] = {}
        docstring = inspect.getdoc(item)
        if docstring:
            for doc_param_name, doc_param_desc in param_doc_regex.findall(docstring):
                doc_descriptions.setdefault(doc_param_name, doc_param_desc.strip())

        properties: Dict[str, Any] = {}
        required: List[str] = []
        for param_name, param in inspect.signature(item).parameters.items():
            annotation = param.annotation
            if annotation is inspect.Parameter.empty:
                param_type = "Any"
            else:
                # Annotations without ``__name__`` (e.g. string annotations
                # under ``from __future__ import annotations``) fall back to
                # their string form instead of raising AttributeError.
                param_type = getattr(annotation, "__name__", str(annotation))

            properties[param_name] = {
                "type": convert_python_type_to_json_type(param_type),
                "description": doc_descriptions.get(
                    param_name, "No description available."
                ),
            }

            if param.default is inspect.Parameter.empty:
                required.append(param_name)

        schemas.append({
            "type": "function",
            "function": {
                "name": basic_schema["name"],
                "description": basic_schema["description"],
                "parameters": {
                    "type": "object",
                    "properties": properties,
                    "required": required,
                },
            },
        })

    return schemas