Merge pull request #3308 from BerriAI/litellm_fix_streaming_n

fix(utils.py): fix the response object returned when n>1 for stream=true
This commit is contained in:
Krish Dholakia 2024-04-25 18:36:54 -07:00 committed by GitHub
commit 69280177a3
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 384 additions and 60 deletions

View file

@ -19,6 +19,7 @@ from functools import wraps
import datetime, time
import tiktoken
import uuid
from pydantic import BaseModel
import aiohttp
import textwrap
import logging
@ -219,6 +220,61 @@ def map_finish_reason(
return finish_reason
class TopLogprob(OpenAIObject):
token: str
"""The token."""
bytes: Optional[List[int]] = None
"""A list of integers representing the UTF-8 bytes representation of the token.
Useful in instances where characters are represented by multiple tokens and
their byte representations must be combined to generate the correct text
representation. Can be `null` if there is no bytes representation for the token.
"""
logprob: float
"""The log probability of this token, if it is within the top 20 most likely
tokens.
Otherwise, the value `-9999.0` is used to signify that the token is very
unlikely.
"""
class ChatCompletionTokenLogprob(OpenAIObject):
token: str
"""The token."""
bytes: Optional[List[int]] = None
"""A list of integers representing the UTF-8 bytes representation of the token.
Useful in instances where characters are represented by multiple tokens and
their byte representations must be combined to generate the correct text
representation. Can be `null` if there is no bytes representation for the token.
"""
logprob: float
"""The log probability of this token, if it is within the top 20 most likely
tokens.
Otherwise, the value `-9999.0` is used to signify that the token is very
unlikely.
"""
top_logprobs: List[TopLogprob]
"""List of the most likely tokens and their log probability, at this token
position.
In rare cases, there may be fewer than the number of requested `top_logprobs`
returned.
"""
class ChoiceLogprobs(OpenAIObject):
content: Optional[List[ChatCompletionTokenLogprob]] = None
"""A list of message content tokens with log probability information."""
class FunctionCall(OpenAIObject):
arguments: str
name: Optional[str] = None
@ -329,7 +385,7 @@ class Message(OpenAIObject):
self.tool_calls.append(ChatCompletionMessageToolCall(**tool_call))
if logprobs is not None:
self._logprobs = logprobs
self._logprobs = ChoiceLogprobs(**logprobs)
def get(self, key, default=None):
# Custom .get() method to access attributes with a default value if the attribute doesn't exist
@ -353,11 +409,17 @@ class Message(OpenAIObject):
class Delta(OpenAIObject):
def __init__(
self, content=None, role=None, function_call=None, tool_calls=None, **params
self,
content=None,
role=None,
function_call=None,
tool_calls=None,
**params,
):
super(Delta, self).__init__(**params)
self.content = content
self.role = role
if function_call is not None and isinstance(function_call, dict):
self.function_call = FunctionCall(**function_call)
else:
@ -489,7 +551,11 @@ class StreamingChoices(OpenAIObject):
self.delta = Delta()
if enhancements is not None:
self.enhancements = enhancements
self.logprobs = logprobs
if logprobs is not None and isinstance(logprobs, dict):
self.logprobs = ChoiceLogprobs(**logprobs)
else:
self.logprobs = logprobs # type: ignore
def __contains__(self, key):
# Define custom behavior for the 'in' operator
@ -10122,12 +10188,18 @@ class CustomStreamWrapper:
model_response.id = original_chunk.id
self.response_id = original_chunk.id
if len(original_chunk.choices) > 0:
try:
delta = dict(original_chunk.choices[0].delta)
print_verbose(f"original delta: {delta}")
model_response.choices[0].delta = Delta(**delta)
except Exception as e:
model_response.choices[0].delta = Delta()
choices = []
for idx, choice in enumerate(original_chunk.choices):
try:
if isinstance(choice, BaseModel):
choice_json = choice.model_dump()
choice_json.pop(
"finish_reason", None
) # for mistral etc. which return a value in their last chunk (not-openai compatible).
choices.append(StreamingChoices(**choice_json))
except Exception as e:
choices.append(StreamingChoices())
model_response.choices = choices
else:
return
model_response.system_fingerprint = (
@ -10172,11 +10244,11 @@ class CustomStreamWrapper:
)
self.holding_chunk = ""
# if delta is None
is_delta_empty = self.is_delta_empty(
_is_delta_empty = self.is_delta_empty(
delta=model_response.choices[0].delta
)
if is_delta_empty:
if _is_delta_empty:
# get any function call arguments
model_response.choices[0].finish_reason = map_finish_reason(
finish_reason=self.received_finish_reason