diff --git a/llama-index-integrations/llms/llama-index-llms-openai/llama_index/llms/openai/base.py b/llama-index-integrations/llms/llama-index-llms-openai/llama_index/llms/openai/base.py
index 598e29bad6fd468215c3815421546cd46c088fb9..a23d015090d4cb75ec9bdd24cef4bec267e84b4e 100644
--- a/llama-index-integrations/llms/llama-index-llms-openai/llama_index/llms/openai/base.py
+++ b/llama-index-integrations/llms/llama-index-llms-openai/llama_index/llms/openai/base.py
@@ -48,6 +48,8 @@ from llama_index.core.types import BaseOutputParser, PydanticProgramMode
 from llama_index.llms.openai.utils import (
     create_retry_decorator,
     from_openai_message,
+    from_openai_token_logprobs,
+    from_openai_completion_logprobs,
     is_chat_model,
     is_function_calling_model,
     openai_modelname_to_contextsize,
@@ -96,6 +98,15 @@ class OpenAI(LLM):
         description="The maximum number of tokens to generate.",
         gt=0,
     )
+    logprobs: Optional[bool] = Field(
+        description="Whether to return logprobs per token."
+    )
+    top_logprobs: int = Field(
+        description="The number of top token log probs to return.",
+        default=0,
+        gte=0,
+        lte=20,
+    )
     additional_kwargs: Dict[str, Any] = Field(
         default_factory=dict, description="Additional kwargs for the OpenAI API."
     )
@@ -297,6 +308,12 @@ class OpenAI(LLM):
         # https://platform.openai.com/docs/api-reference/chat
         # https://platform.openai.com/docs/api-reference/completions
         base_kwargs["max_tokens"] = self.max_tokens
+        if self.logprobs is not None and self.logprobs is True:
+            if self.metadata.is_chat_model:
+                base_kwargs["logprobs"] = self.logprobs
+                base_kwargs["top_logprobs"] = self.top_logprobs
+            else:
+                base_kwargs["logprobs"] = self.top_logprobs  # int in this case
         return {**base_kwargs, **self.additional_kwargs}
 
     @llm_retry_decorator
@@ -310,10 +327,15 @@ class OpenAI(LLM):
         )
         openai_message = response.choices[0].message
         message = from_openai_message(openai_message)
+        openai_token_logprobs = response.choices[0].logprobs
+        logprobs = None
+        if openai_token_logprobs:
+            logprobs = from_openai_token_logprobs(openai_token_logprobs.content)
 
         return ChatResponse(
             message=message,
             raw=response,
+            logprobs=logprobs,
             additional_kwargs=self._get_response_token_counts(response),
         )
 
@@ -428,9 +450,16 @@ class OpenAI(LLM):
             **all_kwargs,
         )
         text = response.choices[0].text
+
+        openai_completion_logprobs = response.choices[0].logprobs
+        logprobs = None
+        if openai_completion_logprobs:
+            logprobs = from_openai_completion_logprobs(openai_completion_logprobs)
+
         return CompletionResponse(
             text=text,
             raw=response,
+            logprobs=logprobs,
             additional_kwargs=self._get_response_token_counts(response),
         )
 
@@ -554,6 +583,7 @@ class OpenAI(LLM):
         )
         message_dict = response.choices[0].message
         message = from_openai_message(message_dict)
+        logprobs_dict = response.choices[0].logprobs
 
         return ChatResponse(
             message=message,
diff --git a/llama-index-integrations/llms/llama-index-llms-openai/llama_index/llms/openai/utils.py b/llama-index-integrations/llms/llama-index-llms-openai/llama_index/llms/openai/utils.py
index 2adeee0082737fd92f1c418024e372d513b83b19..dc6025ceb2227e64f395ec76bcc533dae526ad71 100644
--- a/llama-index-integrations/llms/llama-index-llms-openai/llama_index/llms/openai/utils.py
+++ b/llama-index-integrations/llms/llama-index-llms-openai/llama_index/llms/openai/utils.py
@@ -3,7 +3,7 @@ import os
 from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Type, Union
 
 from deprecated import deprecated
-from llama_index.core.base.llms.types import ChatMessage
+from llama_index.core.base.llms.types import ChatMessage, LogProb, CompletionResponse
 from llama_index.core.bridge.pydantic import BaseModel
 from llama_index.core.base.llms.generic_utils import get_from_param_or_env
 from tenacity import (
@@ -21,6 +21,9 @@ import openai
 from openai.types.chat import ChatCompletionMessageParam, ChatCompletionMessageToolCall
 from openai.types.chat.chat_completion_chunk import ChoiceDeltaToolCall
 from openai.types.chat.chat_completion_message import ChatCompletionMessage
+from openai.types.chat.chat_completion_token_logprob import ChatCompletionTokenLogprob
+from openai.types.completion_choice import Logprobs
+from openai.types.completion import Completion
 
 DEFAULT_OPENAI_API_TYPE = "open_ai"
 DEFAULT_OPENAI_API_BASE = "https://api.openai.com/v1"
@@ -260,6 +263,56 @@ def from_openai_message(openai_message: ChatCompletionMessage) -> ChatMessage:
     return ChatMessage(role=role, content=content, additional_kwargs=additional_kwargs)
 
 
+def from_openai_token_logprob(
+    openai_token_logprob: ChatCompletionTokenLogprob,
+) -> List[LogProb]:
+    """Convert a single openai token logprob to generic list of logprobs."""
+    try:
+        result = [
+            LogProb(token=el.token, logprob=el.logprob, bytes=el.bytes or [])
+            for el in openai_token_logprob.top_logprobs
+        ]
+    except Exception as e:
+        print(openai_token_logprob)
+        raise
+    return result
+
+
+def from_openai_token_logprobs(
+    openai_token_logprobs: Sequence[ChatCompletionTokenLogprob],
+) -> List[List[LogProb]]:
+    """Convert openai token logprobs to generic list of LogProb."""
+    return [
+        from_openai_token_logprob(token_logprob)
+        for token_logprob in openai_token_logprobs
+    ]
+
+
+def from_openai_completion_logprob(
+    openai_completion_logprob: Dict[str, float]
+) -> List[LogProb]:
+    """Convert openai completion logprobs to generic list of LogProb."""
+    return [
+        LogProb(token=t, logprob=v, bytes=[])
+        for t, v in openai_completion_logprob.items()
+    ]
+
+
+def from_openai_completion_logprobs(
+    openai_completion_logprobs: Logprobs,
+) -> List[List[LogProb]]:
+    """Convert openai completion logprobs to generic list of LogProb."""
+    return [
+        from_openai_completion_logprob(completion_logprob)
+        for completion_logprob in openai_completion_logprobs.top_logprobs
+    ]
+
+
+def from_openai_completion(openai_completion: Completion) -> CompletionResponse:
+    """Convert openai completion to CompletionResponse."""
+    text = openai_completion.choices[0].text
+
+
 def from_openai_messages(
     openai_messages: Sequence[ChatCompletionMessage],
 ) -> List[ChatMessage]:
diff --git a/llama-index-integrations/llms/llama-index-llms-openai/pyproject.toml b/llama-index-integrations/llms/llama-index-llms-openai/pyproject.toml
index 3612f5ca70078abe0852ecb50d128f459e181043..4a08f91b1cf269e8a02a5a55e6d90ce95d40a1f0 100644
--- a/llama-index-integrations/llms/llama-index-llms-openai/pyproject.toml
+++ b/llama-index-integrations/llms/llama-index-llms-openai/pyproject.toml
@@ -29,11 +29,11 @@ exclude = ["**/BUILD"]
 license = "MIT"
 name = "llama-index-llms-openai"
 readme = "README.md"
-version = "0.1.8"
+version = "0.1.9"
 
 [tool.poetry.dependencies]
 python = ">=3.8.1,<4.0"
-llama-index-core = "^0.10.1"
+llama-index-core = "^0.10.19"
 
 [tool.poetry.group.dev.dependencies]
 ipython = "8.10.0"
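
For reference, below is a minimal usage sketch (not part of the patch) exercising the logprobs/top_logprobs options this change adds, assuming an OPENAI_API_KEY in the environment and the bumped llama-index-core release that exposes a logprobs field on ChatResponse; the model name and prompt are arbitrary placeholders.

from llama_index.core.base.llms.types import ChatMessage
from llama_index.llms.openai import OpenAI

# Ask for per-token logprobs with the top 3 candidates at each position.
llm = OpenAI(model="gpt-3.5-turbo", logprobs=True, top_logprobs=3)
response = llm.chat([ChatMessage(role="user", content="Say hello.")])

# With this patch, response.logprobs holds one list per generated token,
# each containing the top-N LogProb(token, logprob, bytes) candidates.
if response.logprobs is not None:
    for token_candidates in response.logprobs:
        for candidate in token_candidates:
            print(candidate.token, candidate.logprob)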