fix(utils.py): Break out of infinite streaming loop
Fixes https://github.com/BerriAI/litellm/issues/5158
parent d0a68ab123
commit fdd9a07051

4 changed files with 190 additions and 29 deletions
@@ -1,3 +1,6 @@
+import Tabs from '@theme/Tabs';
+import TabItem from '@theme/TabItem';
+
 # Streaming + Async
 
 - [Streaming Responses](#streaming-responses)
@@ -74,3 +77,72 @@ async def completion_call():
 asyncio.run(completion_call())
 ```
 
+## Error Handling - Infinite Loops
+
+Sometimes a model can enter an infinite loop and keep repeating the same chunks - [e.g. issue](https://github.com/BerriAI/litellm/issues/5158)
+
+Break out of it with:
+
+```python
+litellm.REPEATED_STREAMING_CHUNK_LIMIT = 100  # catch if model starts looping the same chunk while streaming. Uses a high default to prevent false positives.
+```
+
+LiteLLM provides error handling for this by checking whether a chunk is repeated 'n' times (default is 100). If that limit is exceeded, it raises a `litellm.InternalServerError` so that retry logic can kick in.
+
+<Tabs>
+<TabItem value="sdk" label="SDK">
+
+```python
+import litellm
+import os
+import time
+
+litellm.set_verbose = False
+loop_amount = litellm.REPEATED_STREAMING_CHUNK_LIMIT + 1
+chunks = [
+    litellm.ModelResponse(**{
+        "id": "chatcmpl-123",
+        "object": "chat.completion.chunk",
+        "created": 1694268190,
+        "model": "gpt-3.5-turbo-0125",
+        "system_fingerprint": "fp_44709d6fcb",
+        "choices": [
+            {"index": 0, "delta": {"content": "How are you?"}, "finish_reason": "stop"}
+        ],
+    }, stream=True)
+] * loop_amount
+completion_stream = litellm.ModelResponseListIterator(model_responses=chunks)
+
+response = litellm.CustomStreamWrapper(
+    completion_stream=completion_stream,
+    model="gpt-3.5-turbo",
+    custom_llm_provider="cached_response",
+    logging_obj=litellm.Logging(
+        model="gpt-3.5-turbo",
+        messages=[{"role": "user", "content": "Hey"}],
+        stream=True,
+        call_type="completion",
+        start_time=time.time(),
+        litellm_call_id="12345",
+        function_id="1245",
+    ),
+)
+
+for chunk in response:
+    continue  # expect to raise InternalServerError
+```
+
+</TabItem>
+<TabItem value="proxy" label="PROXY">
+
+Define this in your config.yaml on the proxy.
+
+```yaml
+litellm_settings:
+  REPEATED_STREAMING_CHUNK_LIMIT: 100 # this overrides the litellm default
+```
+
+The proxy uses the litellm SDK. To validate that this works, try the 'SDK' code snippet.
+
+</TabItem>
+</Tabs>
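Because the guard surfaces a looping stream as `litellm.InternalServerError`, callers can wrap the streaming call in ordinary retry logic. Below is a minimal sketch, not part of this commit: `stream_with_retries` and `MAX_RETRIES` are illustrative names, and it assumes text-only deltas.

```python
# Illustrative sketch only - not a litellm API. Shows one way to retry a
# streaming call when the repeated-chunk guard raises InternalServerError.
import litellm

MAX_RETRIES = 3  # hypothetical retry budget


def stream_with_retries(**kwargs) -> str:
    for attempt in range(MAX_RETRIES):
        try:
            text = ""
            for chunk in litellm.completion(stream=True, **kwargs):
                # accumulate text deltas; chunks follow the OpenAI chunk shape
                text += chunk.choices[0].delta.content or ""
            return text
        except litellm.InternalServerError:
            # raised by the repeated-chunk guard (and other retryable errors)
            if attempt == MAX_RETRIES - 1:
                raise
    return ""


print(stream_with_retries(model="gpt-3.5-turbo",
                          messages=[{"role": "user", "content": "Hey"}]))
```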
@@ -267,6 +267,7 @@ max_end_user_budget: Optional[float] = None
 #### REQUEST PRIORITIZATION ####
 priority_reservation: Optional[Dict[str, float]] = None
 #### RELIABILITY ####
+REPEATED_STREAMING_CHUNK_LIMIT = 100  # catch if model starts looping the same chunk while streaming. Uses high default to prevent false positives.
 request_timeout: float = 6000
 module_level_aclient = AsyncHTTPHandler(timeout=request_timeout)
 module_level_client = HTTPHandler(timeout=request_timeout)
@@ -827,6 +828,7 @@ from .utils import (
     TranscriptionResponse,
     TextCompletionResponse,
     get_provider_fields,
+    ModelResponseListIterator,
 )
 
 ALL_LITELLM_RESPONSE_TYPES = [
@@ -15,6 +15,7 @@ from pydantic import BaseModel
 
 import litellm.litellm_core_utils
 import litellm.litellm_core_utils.litellm_logging
+from litellm.utils import ModelResponseListIterator
 
 sys.path.insert(
     0, os.path.abspath("../..")
@@ -3201,34 +3202,6 @@ class ModelResponseIterator:
         return self.model_response
 
 
-class ModelResponseListIterator:
-    def __init__(self, model_responses):
-        self.model_responses = model_responses
-        self.index = 0
-
-    # Sync iterator
-    def __iter__(self):
-        return self
-
-    def __next__(self):
-        if self.index >= len(self.model_responses):
-            raise StopIteration
-        model_response = self.model_responses[self.index]
-        self.index += 1
-        return model_response
-
-    # Async iterator
-    def __aiter__(self):
-        return self
-
-    async def __anext__(self):
-        if self.index >= len(self.model_responses):
-            raise StopAsyncIteration
-        model_response = self.model_responses[self.index]
-        self.index += 1
-        return model_response
-
-
 def test_unit_test_custom_stream_wrapper():
     """
     Test if last streaming chunk ends with '?', if the message repeats itself.
@@ -3271,6 +3244,65 @@ def test_unit_test_custom_stream_wrapper():
     assert freq == 1
 
 
+@pytest.mark.parametrize(
+    "loop_amount",
+    [
+        litellm.REPEATED_STREAMING_CHUNK_LIMIT + 1,
+        litellm.REPEATED_STREAMING_CHUNK_LIMIT - 1,
+    ],
+)
+def test_unit_test_custom_stream_wrapper_repeating_chunk(loop_amount):
+    """
+    Test if InternalServerError raised if model enters infinite loop
+
+    Test if request passes if model loop is below accepted limit
+    """
+    litellm.set_verbose = False
+    chunks = [
+        litellm.ModelResponse(
+            **{
+                "id": "chatcmpl-123",
+                "object": "chat.completion.chunk",
+                "created": 1694268190,
+                "model": "gpt-3.5-turbo-0125",
+                "system_fingerprint": "fp_44709d6fcb",
+                "choices": [
+                    {
+                        "index": 0,
+                        "delta": {"content": "How are you?"},
+                        "finish_reason": "stop",
+                    }
+                ],
+            },
+            stream=True,
+        )
+    ] * loop_amount
+    completion_stream = ModelResponseListIterator(model_responses=chunks)
+
+    response = litellm.CustomStreamWrapper(
+        completion_stream=completion_stream,
+        model="gpt-3.5-turbo",
+        custom_llm_provider="cached_response",
+        logging_obj=litellm.Logging(
+            model="gpt-3.5-turbo",
+            messages=[{"role": "user", "content": "Hey"}],
+            stream=True,
+            call_type="completion",
+            start_time=time.time(),
+            litellm_call_id="12345",
+            function_id="1245",
+        ),
+    )
+
+    if loop_amount > litellm.REPEATED_STREAMING_CHUNK_LIMIT:
+        with pytest.raises(litellm.InternalServerError):
+            for chunk in response:
+                continue
+    else:
+        for chunk in response:
+            continue
+
+
 def test_unit_test_custom_stream_wrapper_openai():
     """
     Test if last streaming chunk ends with '?', if the message repeats itself.
@@ -8638,6 +8638,32 @@ class CustomStreamWrapper:
         except Exception as e:
             raise e
 
+    def safety_checker(self) -> None:
+        """
+        Fixes - https://github.com/BerriAI/litellm/issues/5158
+
+        If the model enters a loop and starts repeating the same chunk, break out of the loop and raise an InternalServerError - allows for retries.
+
+        Raises - InternalServerError, if LLM enters infinite loop while streaming
+        """
+        if len(self.chunks) >= litellm.REPEATED_STREAMING_CHUNK_LIMIT:
+            # Get the last n chunks
+            last_chunks = self.chunks[-litellm.REPEATED_STREAMING_CHUNK_LIMIT :]
+
+            # Extract the relevant content from the chunks
+            last_contents = [chunk.choices[0].delta.content for chunk in last_chunks]
+
+            # Check if all extracted contents are identical
+            if all(content == last_contents[0] for content in last_contents):
+                # All last n chunks are identical
+                raise litellm.InternalServerError(
+                    message="The model is repeating the same chunk = {}.".format(
+                        last_contents[0]
+                    ),
+                    model="",
+                    llm_provider="",
+                )
+
     def check_special_tokens(self, chunk: str, finish_reason: Optional[str]):
         """
         Output parse <s> / </s> special tokens for sagemaker + hf streaming.
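For intuition, the check in `safety_checker` reduces to "are the last N streamed contents identical?". Here is a minimal, self-contained sketch of that check on plain strings (illustrative only; `is_looping` and `LIMIT` are not litellm names):

```python
# Illustrative sketch - mirrors the safety_checker logic above on plain strings.
LIMIT = 100  # stands in for litellm.REPEATED_STREAMING_CHUNK_LIMIT


def is_looping(contents: list, limit: int = LIMIT) -> bool:
    """Return True if the last `limit` chunk contents are all identical."""
    if len(contents) < limit:
        return False  # not enough chunks yet to call it a loop
    last = contents[-limit:]
    return all(c == last[0] for c in last)


assert is_looping(["How are you?"] * 100) is True    # stuck stream
assert is_looping(["How are you?"] * 99) is False    # below the limit
assert is_looping(["How", " are", " you?"] * 40) is False  # varied content
```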
@@ -10074,6 +10100,7 @@ class CustomStreamWrapper:
                 and len(completion_obj["tool_calls"]) > 0
             )
         ):  # cannot set content of an OpenAI Object to be an empty string
+            self.safety_checker()
             hold, model_response_str = self.check_special_tokens(
                 chunk=completion_obj["content"],
                 finish_reason=model_response.choices[0].finish_reason,
@@ -11257,6 +11284,34 @@ class ModelResponseIterator:
         return self.model_response
 
 
+class ModelResponseListIterator:
+    def __init__(self, model_responses):
+        self.model_responses = model_responses
+        self.index = 0
+
+    # Sync iterator
+    def __iter__(self):
+        return self
+
+    def __next__(self):
+        if self.index >= len(self.model_responses):
+            raise StopIteration
+        model_response = self.model_responses[self.index]
+        self.index += 1
+        return model_response
+
+    # Async iterator
+    def __aiter__(self):
+        return self
+
+    async def __anext__(self):
+        if self.index >= len(self.model_responses):
+            raise StopAsyncIteration
+        model_response = self.model_responses[self.index]
+        self.index += 1
+        return model_response
+
+
 class CustomModelResponseIterator(Iterable):
     def __init__(self) -> None:
         super().__init__()
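`ModelResponseListIterator` implements both the sync and async iterator protocols over a canned list, which is what lets the new unit test feed fake chunks into `CustomStreamWrapper` without a network call. A small usage sketch, not part of the diff (plain strings stand in for real `ModelResponse` chunks, and it assumes the post-commit import path `litellm.utils`):

```python
# Usage sketch only - plain strings stand in for ModelResponse chunks.
import asyncio

from litellm.utils import ModelResponseListIterator

fake_chunks = ["chunk-1", "chunk-2", "chunk-3"]

# Sync consumption
for chunk in ModelResponseListIterator(model_responses=fake_chunks):
    print("sync:", chunk)


# Async consumption (a fresh iterator, since the index is per-instance state)
async def consume() -> None:
    async for chunk in ModelResponseListIterator(model_responses=fake_chunks):
        print("async:", chunk)


asyncio.run(consume())
```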