mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-28 04:04:31 +00:00
fix(utils.py): return logprobs as an object not dict
This commit is contained in:
parent
157dd819f6
commit
fc75fe2d05
2 changed files with 69 additions and 3 deletions
|
@ -1291,6 +1291,7 @@ def test_completion_logprobs_stream():
|
||||||
for chunk in response:
|
for chunk in response:
|
||||||
# check if atleast one chunk has log probs
|
# check if atleast one chunk has log probs
|
||||||
print(chunk)
|
print(chunk)
|
||||||
|
print(f"chunk.choices[0]: {chunk.choices[0]}")
|
||||||
if "logprobs" in chunk.choices[0]:
|
if "logprobs" in chunk.choices[0]:
|
||||||
# assert we got a valid logprob in the choices
|
# assert we got a valid logprob in the choices
|
||||||
assert len(chunk.choices[0].logprobs.content[0].top_logprobs) == 3
|
assert len(chunk.choices[0].logprobs.content[0].top_logprobs) == 3
|
||||||
|
|
|
@ -220,6 +220,61 @@ def map_finish_reason(
|
||||||
return finish_reason
|
return finish_reason
|
||||||
|
|
||||||
|
|
||||||
|
class TopLogprob(OpenAIObject):
|
||||||
|
token: str
|
||||||
|
"""The token."""
|
||||||
|
|
||||||
|
bytes: Optional[List[int]] = None
|
||||||
|
"""A list of integers representing the UTF-8 bytes representation of the token.
|
||||||
|
|
||||||
|
Useful in instances where characters are represented by multiple tokens and
|
||||||
|
their byte representations must be combined to generate the correct text
|
||||||
|
representation. Can be `null` if there is no bytes representation for the token.
|
||||||
|
"""
|
||||||
|
|
||||||
|
logprob: float
|
||||||
|
"""The log probability of this token, if it is within the top 20 most likely
|
||||||
|
tokens.
|
||||||
|
|
||||||
|
Otherwise, the value `-9999.0` is used to signify that the token is very
|
||||||
|
unlikely.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
class ChatCompletionTokenLogprob(OpenAIObject):
|
||||||
|
token: str
|
||||||
|
"""The token."""
|
||||||
|
|
||||||
|
bytes: Optional[List[int]] = None
|
||||||
|
"""A list of integers representing the UTF-8 bytes representation of the token.
|
||||||
|
|
||||||
|
Useful in instances where characters are represented by multiple tokens and
|
||||||
|
their byte representations must be combined to generate the correct text
|
||||||
|
representation. Can be `null` if there is no bytes representation for the token.
|
||||||
|
"""
|
||||||
|
|
||||||
|
logprob: float
|
||||||
|
"""The log probability of this token, if it is within the top 20 most likely
|
||||||
|
tokens.
|
||||||
|
|
||||||
|
Otherwise, the value `-9999.0` is used to signify that the token is very
|
||||||
|
unlikely.
|
||||||
|
"""
|
||||||
|
|
||||||
|
top_logprobs: List[TopLogprob]
|
||||||
|
"""List of the most likely tokens and their log probability, at this token
|
||||||
|
position.
|
||||||
|
|
||||||
|
In rare cases, there may be fewer than the number of requested `top_logprobs`
|
||||||
|
returned.
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
class ChoiceLogprobs(OpenAIObject):
|
||||||
|
content: Optional[List[ChatCompletionTokenLogprob]] = None
|
||||||
|
"""A list of message content tokens with log probability information."""
|
||||||
|
|
||||||
|
|
||||||
class FunctionCall(OpenAIObject):
|
class FunctionCall(OpenAIObject):
|
||||||
arguments: str
|
arguments: str
|
||||||
name: Optional[str] = None
|
name: Optional[str] = None
|
||||||
|
@ -330,7 +385,7 @@ class Message(OpenAIObject):
|
||||||
self.tool_calls.append(ChatCompletionMessageToolCall(**tool_call))
|
self.tool_calls.append(ChatCompletionMessageToolCall(**tool_call))
|
||||||
|
|
||||||
if logprobs is not None:
|
if logprobs is not None:
|
||||||
self._logprobs = logprobs
|
self._logprobs = ChoiceLogprobs(**logprobs)
|
||||||
|
|
||||||
def get(self, key, default=None):
|
def get(self, key, default=None):
|
||||||
# Custom .get() method to access attributes with a default value if the attribute doesn't exist
|
# Custom .get() method to access attributes with a default value if the attribute doesn't exist
|
||||||
|
@ -354,11 +409,17 @@ class Message(OpenAIObject):
|
||||||
|
|
||||||
class Delta(OpenAIObject):
|
class Delta(OpenAIObject):
|
||||||
def __init__(
|
def __init__(
|
||||||
self, content=None, role=None, function_call=None, tool_calls=None, **params
|
self,
|
||||||
|
content=None,
|
||||||
|
role=None,
|
||||||
|
function_call=None,
|
||||||
|
tool_calls=None,
|
||||||
|
**params,
|
||||||
):
|
):
|
||||||
super(Delta, self).__init__(**params)
|
super(Delta, self).__init__(**params)
|
||||||
self.content = content
|
self.content = content
|
||||||
self.role = role
|
self.role = role
|
||||||
|
|
||||||
if function_call is not None and isinstance(function_call, dict):
|
if function_call is not None and isinstance(function_call, dict):
|
||||||
self.function_call = FunctionCall(**function_call)
|
self.function_call = FunctionCall(**function_call)
|
||||||
else:
|
else:
|
||||||
|
@ -490,7 +551,11 @@ class StreamingChoices(OpenAIObject):
|
||||||
self.delta = Delta()
|
self.delta = Delta()
|
||||||
if enhancements is not None:
|
if enhancements is not None:
|
||||||
self.enhancements = enhancements
|
self.enhancements = enhancements
|
||||||
self.logprobs = logprobs
|
|
||||||
|
if logprobs is not None and isinstance(logprobs, dict):
|
||||||
|
self.logprobs = ChoiceLogprobs(**logprobs)
|
||||||
|
else:
|
||||||
|
self.logprobs = logprobs # type: ignore
|
||||||
|
|
||||||
def __contains__(self, key):
|
def __contains__(self, key):
|
||||||
# Define custom behavior for the 'in' operator
|
# Define custom behavior for the 'in' operator
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue