fix(utils.py): Break out of infinite streaming loop

Fixes https://github.com/BerriAI/litellm/issues/5158
Krrish Dholakia 2024-08-12 14:00:23 -07:00
parent d0a68ab123
commit fdd9a07051
4 changed files with 190 additions and 29 deletions


@@ -8638,6 +8638,32 @@ class CustomStreamWrapper:
        except Exception as e:
            raise e

    def safety_checker(self) -> None:
        """
        Fixes - https://github.com/BerriAI/litellm/issues/5158

        If the model enters a loop and starts repeating the same chunk, break out of the loop and raise an InternalServerError - this allows for retries.

        Raises - InternalServerError, if the LLM enters an infinite loop while streaming.
        """
        if len(self.chunks) >= litellm.REPEATED_STREAMING_CHUNK_LIMIT:
            # Get the last n chunks
            last_chunks = self.chunks[-litellm.REPEATED_STREAMING_CHUNK_LIMIT :]
            # Extract the relevant content from the chunks
            last_contents = [chunk.choices[0].delta.content for chunk in last_chunks]
            # Check if all extracted contents are identical
            if all(content == last_contents[0] for content in last_contents):
                # All last n chunks are identical
                raise litellm.InternalServerError(
                    message="The model is repeating the same chunk = {}.".format(
                        last_contents[0]
                    ),
                    model="",
                    llm_provider="",
                )

    def check_special_tokens(self, chunk: str, finish_reason: Optional[str]):
        """
        Output parse <s> / </s> special tokens for sagemaker + hf streaming.
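A minimal standalone sketch of the detection logic above, using hypothetical dataclass stand-ins for litellm's chunk objects (Delta, Choice, and Chunk below are illustrative, not litellm types; the limit value is an assumption, not litellm's actual default):

from dataclasses import dataclass

REPEATED_STREAMING_CHUNK_LIMIT = 100  # assumed value for illustration


@dataclass
class Delta:
    content: str


@dataclass
class Choice:
    delta: Delta


@dataclass
class Chunk:
    choices: list


def is_repeating(chunks: list) -> bool:
    """Return True once the last N chunks all carry identical content."""
    if len(chunks) < REPEATED_STREAMING_CHUNK_LIMIT:
        return False
    last = [c.choices[0].delta.content for c in chunks[-REPEATED_STREAMING_CHUNK_LIMIT:]]
    return all(content == last[0] for content in last)


# A stream stuck emitting the same token trips the guard:
stuck = [Chunk([Choice(Delta("How are you?"))]) for _ in range(REPEATED_STREAMING_CHUNK_LIMIT)]
assert is_repeating(stuck)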
@@ -10074,6 +10100,7 @@ class CustomStreamWrapper:
                    and len(completion_obj["tool_calls"]) > 0
                )
            ):  # cannot set content of an OpenAI Object to be an empty string
                self.safety_checker()
                hold, model_response_str = self.check_special_tokens(
                    chunk=completion_obj["content"],
                    finish_reason=model_response.choices[0].finish_reason,
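Because the guard raises before yielding yet another duplicate chunk, a caller can catch the error and retry the request. A sketch of that caller-side pattern, assuming the error surfaces as litellm.InternalServerError (per the hunk above); the model name and retry count are placeholders:

import litellm


def stream_with_retry(prompt: str, max_attempts: int = 3) -> str:
    for attempt in range(max_attempts):
        try:
            out = []
            response = litellm.completion(
                model="gpt-3.5-turbo",  # placeholder model
                messages=[{"role": "user", "content": prompt}],
                stream=True,
            )
            # Accumulate streamed content; safety_checker() raises mid-stream
            # if the model starts looping.
            for chunk in response:
                out.append(chunk.choices[0].delta.content or "")
            return "".join(out)
        except litellm.InternalServerError:
            continue  # the model looped; retry the request
    raise RuntimeError("model kept looping on every attempt")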
@@ -11257,6 +11284,34 @@ class ModelResponseIterator:
        return self.model_response

class ModelResponseListIterator:
    def __init__(self, model_responses):
        self.model_responses = model_responses
        self.index = 0

    # Sync iterator
    def __iter__(self):
        return self

    def __next__(self):
        if self.index >= len(self.model_responses):
            raise StopIteration
        model_response = self.model_responses[self.index]
        self.index += 1
        return model_response

    # Async iterator
    def __aiter__(self):
        return self

    async def __anext__(self):
        if self.index >= len(self.model_responses):
            raise StopAsyncIteration
        model_response = self.model_responses[self.index]
        self.index += 1
        return model_response

class CustomModelResponseIterator(Iterable):
    def __init__(self) -> None:
        super().__init__()
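ModelResponseListIterator replays a canned sequence of responses, which is how a looping stream can be simulated in tests of this fix. A sketch of its sync and async usage; the string responses below are stand-ins for real ModelResponse objects, and note the class shares one index between both protocols, so a fresh iterator is created for the async pass:

import asyncio

responses = [f"chunk-{i}" for i in range(3)]  # stand-ins for ModelResponse objects

# Sync: replay the canned responses in order.
for r in ModelResponseListIterator(model_responses=responses):
    print(r)


# Async: the same class also drives `async for` consumers.
async def consume():
    async for r in ModelResponseListIterator(model_responses=responses):
        print(r)


asyncio.run(consume())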