Merge pull request #3308 from BerriAI/litellm_fix_streaming_n

fix(utils.py): fix the response object returned when n>1 for stream=true
Krish Dholakia 2024-04-25 18:36:54 -07:00 committed by GitHub
commit 69280177a3
4 changed files with 384 additions and 60 deletions
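
For context, a minimal sketch (model name and prompt are illustrative) of the call pattern this fix targets: with n>1 and stream=True, every chunk's choices must keep their original index so callers can tell the parallel completions apart.

import litellm

response = litellm.completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "Hey, how's it going?"}],
    n=2,
    stream=True,
)
for chunk in response:
    for choice in chunk.choices:
        # with the fix, choice.index distinguishes the two completions
        print(choice.index, choice.delta.content)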


@@ -1,51 +1,18 @@
-environment_variables:
-  SLACK_WEBHOOK_URL: SQD2/FQHvDuj6Q9/Umyqi+EKLNKKLRCXETX2ncO0xCIQp6EHCKiYD7jPW0+1QdrsQ+pnEzhsfVY2r21SiQV901n/9iyJ2tSnEyWViP7FKQVtTvwutsAqSqbiVHxLHbpjPCu03fhS/idjZrtK7dJLbLBB3RgudjNjHg==
-general_settings:
-  alerting:
-  - slack
-  alerting_threshold: 300
-  database_connection_pool_limit: 100
-  database_connection_timeout: 60
-  health_check_interval: 300
-  proxy_batch_write_at: 10
-  ui_access_mode: all
-litellm_settings:
-  allowed_fails: 3
-  failure_callback:
-  - prometheus
-  fallbacks:
-  - gpt-3.5-turbo:
-    - fake-openai-endpoint
-    - gpt-4
-  num_retries: 3
-  service_callback:
-  - prometheus_system
-  success_callback:
-  - prometheus
-model_list:
-- litellm_params:
-    api_base: https://openai-function-calling-workers.tasslexyz.workers.dev/
-    api_key: my-fake-key
-    model: openai/my-fake-model
-  model_name: fake-openai-endpoint
-- litellm_params:
-    model: gpt-3.5-turbo
-  model_name: gpt-3.5-turbo
-- model_name: llama-3
-  litellm_params:
-    model: replicate/meta/meta-llama-3-8b-instruct
-router_settings:
-  allowed_fails: 3
-  context_window_fallbacks: null
-  cooldown_time: 1
-  fallbacks:
-  - gpt-3.5-turbo:
-    - fake-openai-endpoint
-    - gpt-4
-  - gpt-3.5-turbo-3:
-    - fake-openai-endpoint
-  num_retries: 3
-  retry_after: 0
-  routing_strategy: simple-shuffle
-  routing_strategy_args: {}
-  timeout: 6000
+model_list:
+- model_name: text-embedding-3-small
+  litellm_params:
+    model: text-embedding-3-small
+- model_name: whisper
+  litellm_params:
+    model: azure/azure-whisper
+    api_version: 2024-02-15-preview
+    api_base: os.environ/AZURE_EUROPE_API_BASE
+    api_key: os.environ/AZURE_EUROPE_API_KEY
+  model_info:
+    mode: audio_transcription
+- litellm_params:
+    model: gpt-4
+  model_name: gpt-4
+
+# litellm_settings:
+#   cache: True
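
As a quick sanity check, a hedged sketch of calling two of the models defined above through the proxy (the base_url, port, and api_key below are placeholders for however the proxy is actually deployed):

import openai

client = openai.OpenAI(api_key="sk-1234", base_url="http://0.0.0.0:4000")

# embedding model from the new config
embedding = client.embeddings.create(
    model="text-embedding-3-small", input=["hello world"]
)

# chat model from the new config
chat = client.chat.completions.create(
    model="gpt-4", messages=[{"role": "user", "content": "Hey"}]
)
print(len(embedding.data), chat.choices[0].message.content)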


@@ -1291,6 +1291,7 @@ def test_completion_logprobs_stream():
         for chunk in response:
             # check if atleast one chunk has log probs
             print(chunk)
+            print(f"chunk.choices[0]: {chunk.choices[0]}")
             if "logprobs" in chunk.choices[0]:
                 # assert we got a valid logprob in the choices
                 assert len(chunk.choices[0].logprobs.content[0].top_logprobs) == 3
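
The test above exercises streamed logprobs; a hedged sketch of the kind of request it makes (the parameters mirror the OpenAI logprobs options that litellm passes through, and the model name is illustrative):

import litellm

response = litellm.completion(
    model="gpt-3.5-turbo",
    messages=[{"role": "user", "content": "Hey, how's it going?"}],
    logprobs=True,
    top_logprobs=3,
    stream=True,
)
for chunk in response:
    if chunk.choices[0].logprobs is not None and chunk.choices[0].logprobs.content:
        # each content token carries its top-3 alternatives
        print(chunk.choices[0].logprobs.content[0].top_logprobs)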


@@ -2446,6 +2446,34 @@ class ModelResponseIterator:
        return self.model_response


class ModelResponseListIterator:
    def __init__(self, model_responses):
        self.model_responses = model_responses
        self.index = 0

    # Sync iterator
    def __iter__(self):
        return self

    def __next__(self):
        if self.index >= len(self.model_responses):
            raise StopIteration
        model_response = self.model_responses[self.index]
        self.index += 1
        return model_response

    # Async iterator
    def __aiter__(self):
        return self

    async def __anext__(self):
        if self.index >= len(self.model_responses):
            raise StopAsyncIteration
        model_response = self.model_responses[self.index]
        self.index += 1
        return model_response


def test_unit_test_custom_stream_wrapper():
    """
    Test if last streaming chunk ends with '?', if the message repeats itself.
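
A small usage sketch (not part of the diff): the list iterator added above can stand in for a provider stream on both the sync and async paths; chunk_list here refers to a list of litellm.ModelResponse objects such as the one built in the test further below.

import asyncio

async def replay(model_responses):
    stream = ModelResponseListIterator(model_responses=model_responses)
    collected = []
    async for resp in stream:  # drives __aiter__/__anext__
        collected.append(resp)
    return collected

# sync path:  list(ModelResponseListIterator(model_responses=chunk_list))
# async path: asyncio.run(replay(chunk_list))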
@@ -2486,3 +2514,259 @@ def test_unit_test_custom_stream_wrapper():
        if "How are you?" in chunk.choices[0].delta.content:
            freq += 1
    assert freq == 1
chunks = [
{
"id": "chatcmpl-9HzZIMCtVq7CbTmdwEZrktiTeoiYe",
"object": "chat.completion.chunk",
"created": 1714075272,
"model": "gpt-4-0613",
"system_fingerprint": None,
"choices": [
{
"index": 0,
"delta": {"content": "It"},
"logprobs": {
"content": [
{
"token": "It",
"logprob": -1.5952516,
"bytes": [73, 116],
"top_logprobs": [
{
"token": "Brown",
"logprob": -0.7358765,
"bytes": [66, 114, 111, 119, 110],
}
],
}
]
},
"finish_reason": None,
}
],
},
{
"id": "chatcmpl-9HzZIMCtVq7CbTmdwEZrktiTeoiYe",
"object": "chat.completion.chunk",
"created": 1714075272,
"model": "gpt-4-0613",
"system_fingerprint": None,
"choices": [
{
"index": 1,
"delta": {"content": "Brown"},
"logprobs": {
"content": [
{
"token": "Brown",
"logprob": -0.7358765,
"bytes": [66, 114, 111, 119, 110],
"top_logprobs": [
{
"token": "Brown",
"logprob": -0.7358765,
"bytes": [66, 114, 111, 119, 110],
}
],
}
]
},
"finish_reason": None,
}
],
},
{
"id": "chatcmpl-9HzZIMCtVq7CbTmdwEZrktiTeoiYe",
"object": "chat.completion.chunk",
"created": 1714075272,
"model": "gpt-4-0613",
"system_fingerprint": None,
"choices": [
{
"index": 0,
"delta": {"content": "'s"},
"logprobs": {
"content": [
{
"token": "'s",
"logprob": -0.006786893,
"bytes": [39, 115],
"top_logprobs": [
{
"token": "'s",
"logprob": -0.006786893,
"bytes": [39, 115],
}
],
}
]
},
"finish_reason": None,
}
],
},
{
"id": "chatcmpl-9HzZIMCtVq7CbTmdwEZrktiTeoiYe",
"object": "chat.completion.chunk",
"created": 1714075272,
"model": "gpt-4-0613",
"system_fingerprint": None,
"choices": [
{
"index": 0,
"delta": {"content": " impossible"},
"logprobs": {
"content": [
{
"token": " impossible",
"logprob": -0.06528423,
"bytes": [
32,
105,
109,
112,
111,
115,
115,
105,
98,
108,
101,
],
"top_logprobs": [
{
"token": " impossible",
"logprob": -0.06528423,
"bytes": [
32,
105,
109,
112,
111,
115,
115,
105,
98,
108,
101,
],
}
],
}
]
},
"finish_reason": None,
}
],
},
{
"id": "chatcmpl-9HzZIMCtVq7CbTmdwEZrktiTeoiYe",
"object": "chat.completion.chunk",
"created": 1714075272,
"model": "gpt-4-0613",
"system_fingerprint": None,
"choices": [
{
"index": 0,
"delta": {"content": "—even"},
"logprobs": {
"content": [
{
"token": "—even",
"logprob": -9999.0,
"bytes": [226, 128, 148, 101, 118, 101, 110],
"top_logprobs": [
{
"token": " to",
"logprob": -0.12302828,
"bytes": [32, 116, 111],
}
],
}
]
},
"finish_reason": None,
}
],
},
{
"id": "chatcmpl-9HzZIMCtVq7CbTmdwEZrktiTeoiYe",
"object": "chat.completion.chunk",
"created": 1714075272,
"model": "gpt-4-0613",
"system_fingerprint": None,
"choices": [
{"index": 0, "delta": {}, "logprobs": None, "finish_reason": "length"}
],
},
{
"id": "chatcmpl-9HzZIMCtVq7CbTmdwEZrktiTeoiYe",
"object": "chat.completion.chunk",
"created": 1714075272,
"model": "gpt-4-0613",
"system_fingerprint": None,
"choices": [
{"index": 1, "delta": {}, "logprobs": None, "finish_reason": "stop"}
],
},
]
def test_unit_test_custom_stream_wrapper_n():
    """
    Test if the translated output maps exactly to the received openai input

    Relevant issue: https://github.com/BerriAI/litellm/issues/3276
    """
    litellm.set_verbose = False
    chunk_list = []
    for chunk in chunks:
        _chunk = litellm.ModelResponse(**chunk, stream=True)
        chunk_list.append(_chunk)

    completion_stream = ModelResponseListIterator(model_responses=chunk_list)

    response = litellm.CustomStreamWrapper(
        completion_stream=completion_stream,
        model="gpt-4-0613",
        custom_llm_provider="cached_response",
        logging_obj=litellm.Logging(
            model="gpt-4-0613",
            messages=[{"role": "user", "content": "Hey"}],
            stream=True,
            call_type="completion",
            start_time=time.time(),
            litellm_call_id="12345",
            function_id="1245",
        ),
    )

    for idx, chunk in enumerate(response):
        chunk_dict = {}
        try:
            chunk_dict = chunk.model_dump(exclude_none=True)
        except:
            chunk_dict = chunk.dict(exclude_none=True)

        chunk_dict.pop("created")
        chunks[idx].pop("created")
        if chunks[idx]["system_fingerprint"] is None:
            chunks[idx].pop("system_fingerprint", None)

        if idx == 0:
            for choice in chunk_dict["choices"]:
                if "role" in choice["delta"]:
                    choice["delta"].pop("role")

        for choice in chunks[idx]["choices"]:
            # ignore finish reason None - since our pydantic object is set to exclude_none = true
            if "finish_reason" in choice and choice["finish_reason"] is None:
                choice.pop("finish_reason")
            if "logprobs" in choice and choice["logprobs"] is None:
                choice.pop("logprobs")

        assert (
            chunk_dict == chunks[idx]
        ), f"idx={idx} translated chunk = {chunk_dict} != openai chunk = {chunks[idx]}"


@@ -19,6 +19,7 @@ from functools import wraps
 import datetime, time
 import tiktoken
 import uuid
+from pydantic import BaseModel
 import aiohttp
 import textwrap
 import logging
@@ -219,6 +220,61 @@ def map_finish_reason(
     return finish_reason


+class TopLogprob(OpenAIObject):
+    token: str
+    """The token."""
+
+    bytes: Optional[List[int]] = None
+    """A list of integers representing the UTF-8 bytes representation of the token.
+
+    Useful in instances where characters are represented by multiple tokens and
+    their byte representations must be combined to generate the correct text
+    representation. Can be `null` if there is no bytes representation for the token.
+    """
+
+    logprob: float
+    """The log probability of this token, if it is within the top 20 most likely
+    tokens.
+
+    Otherwise, the value `-9999.0` is used to signify that the token is very
+    unlikely.
+    """
+
+
+class ChatCompletionTokenLogprob(OpenAIObject):
+    token: str
+    """The token."""
+
+    bytes: Optional[List[int]] = None
+    """A list of integers representing the UTF-8 bytes representation of the token.
+
+    Useful in instances where characters are represented by multiple tokens and
+    their byte representations must be combined to generate the correct text
+    representation. Can be `null` if there is no bytes representation for the token.
+    """
+
+    logprob: float
+    """The log probability of this token, if it is within the top 20 most likely
+    tokens.
+
+    Otherwise, the value `-9999.0` is used to signify that the token is very
+    unlikely.
+    """
+
+    top_logprobs: List[TopLogprob]
+    """List of the most likely tokens and their log probability, at this token
+    position.
+
+    In rare cases, there may be fewer than the number of requested `top_logprobs`
+    returned.
+    """
+
+
+class ChoiceLogprobs(OpenAIObject):
+    content: Optional[List[ChatCompletionTokenLogprob]] = None
+    """A list of message content tokens with log probability information."""
+
+
 class FunctionCall(OpenAIObject):
     arguments: str
     name: Optional[str] = None
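
A hedged sketch of how the new types parse a raw OpenAI logprobs payload (the dict shape mirrors the streaming chunks used in the tests above; nested coercion is what the code below relies on when it calls ChoiceLogprobs(**logprobs)):

raw_logprobs = {
    "content": [
        {
            "token": "It",
            "logprob": -1.5952516,
            "bytes": [73, 116],
            "top_logprobs": [
                {"token": "Brown", "logprob": -0.7358765, "bytes": [66, 114, 111, 119, 110]}
            ],
        }
    ]
}
parsed = ChoiceLogprobs(**raw_logprobs)  # nested dicts coerce into the typed models
assert isinstance(parsed.content[0], ChatCompletionTokenLogprob)
assert parsed.content[0].top_logprobs[0].token == "Brown"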
@@ -329,7 +385,7 @@ class Message(OpenAIObject):
                 self.tool_calls.append(ChatCompletionMessageToolCall(**tool_call))

         if logprobs is not None:
-            self._logprobs = logprobs
+            self._logprobs = ChoiceLogprobs(**logprobs)

     def get(self, key, default=None):
         # Custom .get() method to access attributes with a default value if the attribute doesn't exist
@@ -353,11 +409,17 @@ class Message(OpenAIObject):

 class Delta(OpenAIObject):
     def __init__(
-        self, content=None, role=None, function_call=None, tool_calls=None, **params
+        self,
+        content=None,
+        role=None,
+        function_call=None,
+        tool_calls=None,
+        **params,
     ):
         super(Delta, self).__init__(**params)
         self.content = content
         self.role = role
         if function_call is not None and isinstance(function_call, dict):
             self.function_call = FunctionCall(**function_call)
         else:
@@ -489,7 +551,11 @@ class StreamingChoices(OpenAIObject):
             self.delta = Delta()
         if enhancements is not None:
             self.enhancements = enhancements
-        self.logprobs = logprobs
+
+        if logprobs is not None and isinstance(logprobs, dict):
+            self.logprobs = ChoiceLogprobs(**logprobs)
+        else:
+            self.logprobs = logprobs  # type: ignore

     def __contains__(self, key):
         # Define custom behavior for the 'in' operator
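
A minimal sketch of the behavior change above (assuming the constructor's index/logprobs keyword arguments and leaving the others at their defaults): a raw logprobs dict passed to StreamingChoices is now promoted to a typed ChoiceLogprobs rather than being stored as a plain dict.

choice = StreamingChoices(
    index=1,
    logprobs={
        "content": [
            {
                "token": "Brown",
                "logprob": -0.7358765,
                "bytes": [66, 114, 111, 119, 110],
                "top_logprobs": [],
            }
        ]
    },
)
assert isinstance(choice.logprobs, ChoiceLogprobs)  # previously this stayed a plain dict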
@@ -10122,12 +10188,18 @@ class CustomStreamWrapper:
                     model_response.id = original_chunk.id
                     self.response_id = original_chunk.id
                     if len(original_chunk.choices) > 0:
-                        try:
-                            delta = dict(original_chunk.choices[0].delta)
-                            print_verbose(f"original delta: {delta}")
-                            model_response.choices[0].delta = Delta(**delta)
-                        except Exception as e:
-                            model_response.choices[0].delta = Delta()
+                        choices = []
+                        for idx, choice in enumerate(original_chunk.choices):
+                            try:
+                                if isinstance(choice, BaseModel):
+                                    choice_json = choice.model_dump()
+                                    choice_json.pop(
+                                        "finish_reason", None
+                                    )  # for mistral etc. which return a value in their last chunk (not-openai compatible).
+                                    choices.append(StreamingChoices(**choice_json))
+                            except Exception as e:
+                                choices.append(StreamingChoices())
+                        model_response.choices = choices
                     else:
                         return
                     model_response.system_fingerprint = (
@@ -10172,11 +10244,11 @@ class CustomStreamWrapper:
                     )
                     self.holding_chunk = ""
             # if delta is None
-            is_delta_empty = self.is_delta_empty(
+            _is_delta_empty = self.is_delta_empty(
                 delta=model_response.choices[0].delta
             )

-            if is_delta_empty:
+            if _is_delta_empty:
                 # get any function call arguments
                 model_response.choices[0].finish_reason = map_finish_reason(
                     finish_reason=self.received_finish_reason