diff --git a/semantic_router/schema.py b/semantic_router/schema.py index 8d479ec901866cb2ff5c4d8ceb014b2e5d2a3519..c7912fa1c888f79ebf0b3c32af1836f6948698f4 100644 --- a/semantic_router/schema.py +++ b/semantic_router/schema.py @@ -55,12 +55,12 @@ class Message(BaseModel): role: str content: str - def to_openai(self) -> dict[str, str]: + def to_openai(self): if self.role.lower() not in ["user", "assistant", "system"]: raise ValueError("Role must be either 'user', 'assistant' or 'system'") return {"role": self.role, "content": self.content} - def to_cohere(self) -> dict[str, str]: + def to_cohere(self): return {"role": self.role, "message": self.content} @@ -68,12 +68,12 @@ class Conversation(BaseModel): messages: list[Message] def split_by_topic( - self, - encoder: BaseEncoder, - threshold: float = 0.5, - split_method: Literal[ - "consecutive_similarity_drop", "cumulative_similarity_drop" - ] = "consecutive_similarity_drop", + self, + encoder: BaseEncoder, + threshold: float = 0.5, + split_method: Literal[ + "consecutive_similarity_drop", "cumulative_similarity_drop" + ] = "consecutive_similarity_drop", ) -> dict[str, list[str]]: docs = [f"{m.role}: {m.content}" for m in self.messages] return semantic_splitter(