Unverified commit 5a18f786, authored by Andrei Fajardo, committed by GitHub

Add LogProb llm type and OpenAI logprobs (#11793)

* add logprob

* typo

* add logprobs to completion api

* bump version
parent 8cf92982
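For context before the diff, a minimal usage sketch of the new knobs (the model name and prompt are illustrative, and `response.logprobs` may be `None` when the API returns no logprobs):

from llama_index.core.base.llms.types import ChatMessage
from llama_index.llms.openai import OpenAI

# Ask OpenAI to return per-token log probabilities with the chat response.
llm = OpenAI(model="gpt-3.5-turbo", logprobs=True, top_logprobs=3)
response = llm.chat([ChatMessage(role="user", content="Say hello.")])

# response.logprobs is a List[List[LogProb]]: one inner list of up to
# `top_logprobs` candidate tokens per generated position.
for position, candidates in enumerate(response.logprobs or []):
    for lp in candidates:
        print(position, lp.token, lp.logprob)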
@@ -48,6 +48,8 @@ from llama_index.core.types import BaseOutputParser, PydanticProgramMode
from llama_index.llms.openai.utils import (
    create_retry_decorator,
    from_openai_message,
    from_openai_token_logprobs,
    from_openai_completion_logprobs,
    is_chat_model,
    is_function_calling_model,
    openai_modelname_to_contextsize,
@@ -96,6 +98,15 @@ class OpenAI(LLM):
        description="The maximum number of tokens to generate.",
        gt=0,
    )
    logprobs: Optional[bool] = Field(
        description="Whether to return logprobs per token."
    )
    top_logprobs: int = Field(
        description="The number of top token log probs to return.",
        default=0,
        ge=0,
        le=20,
    )
    additional_kwargs: Dict[str, Any] = Field(
        default_factory=dict, description="Additional kwargs for the OpenAI API."
    )
@@ -297,6 +308,12 @@ class OpenAI(LLM):
        # https://platform.openai.com/docs/api-reference/chat
        # https://platform.openai.com/docs/api-reference/completions
        base_kwargs["max_tokens"] = self.max_tokens
        if self.logprobs:
            if self.metadata.is_chat_model:
                base_kwargs["logprobs"] = self.logprobs
                base_kwargs["top_logprobs"] = self.top_logprobs
            else:
                # the legacy completions API takes an int here
                base_kwargs["logprobs"] = self.top_logprobs
        return {**base_kwargs, **self.additional_kwargs}
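The branch above encodes an API asymmetry: the chat endpoint takes a boolean `logprobs` plus an integer `top_logprobs`, while the legacy completions endpoint takes a single integer `logprobs`. Illustrative request kwargs, not part of the diff:

# chat model:        {"logprobs": True, "top_logprobs": 3, ...}
# completions model: {"logprobs": 3, ...}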
@llm_retry_decorator
@@ -310,10 +327,15 @@
        )
        openai_message = response.choices[0].message
        message = from_openai_message(openai_message)
        openai_token_logprobs = response.choices[0].logprobs
        logprobs = None
        if openai_token_logprobs:
            logprobs = from_openai_token_logprobs(openai_token_logprobs.content)
        return ChatResponse(
            message=message,
            raw=response,
            logprobs=logprobs,
            additional_kwargs=self._get_response_token_counts(response),
        )
@@ -428,9 +450,16 @@
            **all_kwargs,
        )
        text = response.choices[0].text
        openai_completion_logprobs = response.choices[0].logprobs
        logprobs = None
        if openai_completion_logprobs:
            logprobs = from_openai_completion_logprobs(openai_completion_logprobs)
        return CompletionResponse(
            text=text,
            raw=response,
            logprobs=logprobs,
            additional_kwargs=self._get_response_token_counts(response),
        )
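A corresponding sketch for the completions path (the model name is illustrative and assumed to route through the legacy completions API rather than chat):

from llama_index.llms.openai import OpenAI

llm = OpenAI(model="gpt-3.5-turbo-instruct", logprobs=True, top_logprobs=5)
resp = llm.complete("The capital of France is")
if resp.logprobs:  # None when the API returned no logprobs
    print(resp.logprobs[0])  # top-5 candidates for the first generated token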
@@ -554,6 +583,7 @@
        )
        message_dict = response.choices[0].message
        message = from_openai_message(message_dict)
        logprobs_dict = response.choices[0].logprobs
        return ChatResponse(
            message=message,
......
@@ -3,7 +3,7 @@ import os
from typing import Any, Callable, Dict, List, Optional, Sequence, Tuple, Type, Union
from deprecated import deprecated
-from llama_index.core.base.llms.types import ChatMessage
+from llama_index.core.base.llms.types import ChatMessage, LogProb, CompletionResponse
from llama_index.core.bridge.pydantic import BaseModel
from llama_index.core.base.llms.generic_utils import get_from_param_or_env
from tenacity import (
@@ -21,6 +21,9 @@ import openai
from openai.types.chat import ChatCompletionMessageParam, ChatCompletionMessageToolCall
from openai.types.chat.chat_completion_chunk import ChoiceDeltaToolCall
from openai.types.chat.chat_completion_message import ChatCompletionMessage
from openai.types.chat.chat_completion_token_logprob import ChatCompletionTokenLogprob
from openai.types.completion_choice import Logprobs
from openai.types.completion import Completion
DEFAULT_OPENAI_API_TYPE = "open_ai"
DEFAULT_OPENAI_API_BASE = "https://api.openai.com/v1"
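The helpers below convert into the new LogProb type from llama-index-core (hence the core dependency bump at the bottom of this commit). A sketch of its assumed shape, for reference only:

from typing import List
from llama_index.core.bridge.pydantic import BaseModel, Field

class LogProb(BaseModel):
    """One candidate token and its log probability (assumed shape)."""
    token: str = Field(default_factory=str)
    logprob: float = Field(default_factory=float)
    bytes: List[int] = Field(default_factory=list)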
@@ -260,6 +263,56 @@ def from_openai_message(openai_message: ChatCompletionMessage) -> ChatMessage:
    return ChatMessage(role=role, content=content, additional_kwargs=additional_kwargs)
def from_openai_token_logprob(
    openai_token_logprob: ChatCompletionTokenLogprob,
) -> List[LogProb]:
    """Convert a single OpenAI token logprob to a generic list of LogProbs."""
    return [
        LogProb(token=el.token, logprob=el.logprob, bytes=el.bytes or [])
        for el in openai_token_logprob.top_logprobs
    ]


def from_openai_token_logprobs(
    openai_token_logprobs: Sequence[ChatCompletionTokenLogprob],
) -> List[List[LogProb]]:
    """Convert OpenAI token logprobs to a list of LogProb lists."""
    return [
        from_openai_token_logprob(token_logprob)
        for token_logprob in openai_token_logprobs
    ]
def from_openai_completion_logprob(
    openai_completion_logprob: Dict[str, float],
) -> List[LogProb]:
    """Convert a single OpenAI completion logprob dict to a generic list of LogProbs."""
    return [
        LogProb(token=t, logprob=v, bytes=[])
        for t, v in openai_completion_logprob.items()
    ]


def from_openai_completion_logprobs(
    openai_completion_logprobs: Logprobs,
) -> List[List[LogProb]]:
    """Convert OpenAI completion logprobs to a list of LogProb lists."""
    return [
        from_openai_completion_logprob(completion_logprob)
        for completion_logprob in openai_completion_logprobs.top_logprobs
    ]
def from_openai_completion(openai_completion: Completion) -> CompletionResponse:
    """Convert an OpenAI completion to a CompletionResponse."""
    text = openai_completion.choices[0].text
    logprobs = None
    if openai_completion.choices[0].logprobs:
        logprobs = from_openai_completion_logprobs(openai_completion.choices[0].logprobs)
    return CompletionResponse(text=text, raw=openai_completion, logprobs=logprobs)


def from_openai_messages(
    openai_messages: Sequence[ChatCompletionMessage],
) -> List[ChatMessage]:
......
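To make the two payload shapes concrete, a small demonstration of the converters defined above (sample values invented; the OpenAI types are from the openai>=1.x SDK):

from openai.types.chat.chat_completion_token_logprob import (
    ChatCompletionTokenLogprob,
    TopLogprob,
)
from openai.types.completion_choice import Logprobs
from llama_index.llms.openai.utils import (
    from_openai_completion_logprobs,
    from_openai_token_logprobs,
)

# Chat API: one ChatCompletionTokenLogprob per generated token, each with
# its own list of top_logprobs candidates.
chat_token = ChatCompletionTokenLogprob(
    token="Hello",
    logprob=-0.02,
    bytes=[72, 101, 108, 108, 111],
    top_logprobs=[
        TopLogprob(token="Hello", logprob=-0.02, bytes=None),
        TopLogprob(token="Hi", logprob=-4.1, bytes=None),
    ],
)
print(from_openai_token_logprobs([chat_token]))

# Completions API: parallel lists, with top_logprobs as {token: logprob} dicts.
completion_logprobs = Logprobs(
    tokens=["Hello", " world"],
    token_logprobs=[-0.1, -0.5],
    top_logprobs=[{"Hello": -0.1, "Hi": -3.2}, {" world": -0.5}],
)
print(from_openai_completion_logprobs(completion_logprobs))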
@@ -29,11 +29,11 @@ exclude = ["**/BUILD"]
license = "MIT"
name = "llama-index-llms-openai"
readme = "README.md"
version = "0.1.8"
version = "0.1.9"
[tool.poetry.dependencies]
python = ">=3.8.1,<4.0"
llama-index-core = "^0.10.1"
llama-index-core = "^0.10.19"
[tool.poetry.group.dev.dependencies]
ipython = "8.10.0"
......