From d63fec1c69a2e1e51bf884a805b9fd31ad8d1ee9 Mon Sep 17 00:00:00 2001 From: Andrei Fajardo <92402603+nerdai@users.noreply.github.com> Date: Sun, 10 Mar 2024 16:16:12 -0400 Subject: [PATCH] Add LogProb type to core.base.llms module (#11795) add LogProb --- llama-index-core/llama_index/core/base/llms/types.py | 11 ++++++++++- 1 file changed, 10 insertions(+), 1 deletion(-) diff --git a/llama-index-core/llama_index/core/base/llms/types.py b/llama-index-core/llama_index/core/base/llms/types.py index d771d75467..012c3c04bd 100644 --- a/llama-index-core/llama_index/core/base/llms/types.py +++ b/llama-index-core/llama_index/core/base/llms/types.py @@ -1,5 +1,5 @@ from enum import Enum -from typing import Any, AsyncGenerator, Generator, Optional, Union +from typing import Any, AsyncGenerator, Generator, Optional, Union, List from llama_index.core.bridge.pydantic import BaseModel, Field from llama_index.core.constants import DEFAULT_CONTEXT_WINDOW, DEFAULT_NUM_OUTPUTS @@ -40,6 +40,14 @@ class ChatMessage(BaseModel): return cls(role=role, content=content, **kwargs) +class LogProb(BaseModel): + """LogProb of a token.""" + + token: str = Field(default_factory=str) + logprob: float = Field(default_factory=float) + bytes: List[int] = Field(default_factory=list) + + # ===== Generic Model Output - Chat ===== class ChatResponse(BaseModel): """Chat response.""" @@ -47,6 +55,7 @@ class ChatResponse(BaseModel): message: ChatMessage raw: Optional[dict] = None delta: Optional[str] = None + logprobs: Optional[List[List[LogProb]]] = None additional_kwargs: dict = Field(default_factory=dict) def __str__(self) -> str: -- GitLab