diff --git a/llama-index-core/llama_index/core/base/llms/types.py b/llama-index-core/llama_index/core/base/llms/types.py
index d771d754674933201f7b3ae7d1d8b8900433c79d..012c3c04bd12cae1420b9272339b5e4b6eebf28a 100644
--- a/llama-index-core/llama_index/core/base/llms/types.py
+++ b/llama-index-core/llama_index/core/base/llms/types.py
@@ -1,5 +1,5 @@
 from enum import Enum
-from typing import Any, AsyncGenerator, Generator, Optional, Union
+from typing import Any, AsyncGenerator, Generator, Optional, Union, List
 
 from llama_index.core.bridge.pydantic import BaseModel, Field
 from llama_index.core.constants import DEFAULT_CONTEXT_WINDOW, DEFAULT_NUM_OUTPUTS
@@ -40,6 +40,14 @@ class ChatMessage(BaseModel):
         return cls(role=role, content=content, **kwargs)
 
 
+class LogProb(BaseModel):
+    """LogProb of a token."""
+
+    token: str = Field(default_factory=str)
+    logprob: float = Field(default_factory=float)
+    bytes: List[int] = Field(default_factory=list)
+
+
 # ===== Generic Model Output - Chat =====
 class ChatResponse(BaseModel):
     """Chat response."""
@@ -47,6 +55,7 @@
     message: ChatMessage
     raw: Optional[dict] = None
     delta: Optional[str] = None
+    logprobs: Optional[List[List[LogProb]]] = None
     additional_kwargs: dict = Field(default_factory=dict)
 
     def __str__(self) -> str:
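
For reviewers, a usage sketch (not part of the diff): the new `logprobs` field on `ChatResponse` is typed `Optional[List[List[LogProb]]]`, which suggests an outer list with one entry per generated token and an inner list of candidate `LogProb`s for that position, mirroring the OpenAI-style logprobs shape. The example below constructs a response by hand under that assumption; the sample tokens and values are illustrative only, not taken from any provider.

```python
import math

# Only ChatMessage, ChatResponse, LogProb, and MessageRole come from the
# module being patched; everything else here is illustrative.
from llama_index.core.base.llms.types import (
    ChatMessage,
    ChatResponse,
    LogProb,
    MessageRole,
)

# Assumed shape: one inner list per generated token, each holding the top
# candidate tokens for that position with their log-probabilities.
response = ChatResponse(
    message=ChatMessage(role=MessageRole.ASSISTANT, content="Paris"),
    logprobs=[
        [
            LogProb(token="Paris", logprob=-0.02, bytes=[80, 97, 114, 105, 115]),
            LogProb(token="Lyon", logprob=-4.10, bytes=[76, 121, 111, 110]),
        ],
    ],
)

# Recover the probability of the chosen first token from its logprob.
print(math.exp(response.logprobs[0][0].logprob))  # ~0.98
```

Keeping `logprobs` optional (default `None`) means existing call sites that never set it are unaffected; LLM integrations that expose token logprobs can populate it when building the `ChatResponse`.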