fix(utils.py): don't raise error on openai content filter during streaming - return as is

Fixes an issue where we raised an error, whereas OpenAI returns the chunk with finish_reason set to 'content_filter'.
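
A minimal sketch of the caller-side pattern this change enables (assumes configured OpenAI/Azure credentials; the model name and print handling are illustrative, not part of this commit): instead of catching litellm.ContentPolicyViolationError during streaming, inspect finish_reason on the streamed chunks.

import litellm

def stream_and_check_content_filter(messages):
    # Stream a completion; after this change a content-filtered response is
    # yielded as a normal chunk with finish_reason == "content_filter"
    # instead of raising an exception mid-stream.
    response = litellm.completion(
        model="azure/gpt-35-turbo",  # illustrative deployment name
        messages=messages,
        stream=True,
    )
    finish_reason = None
    for chunk in response:
        if chunk.choices[0].delta.content is not None:
            print(chunk.choices[0].delta.content, end="")
        if chunk.choices[0].finish_reason is not None:
            finish_reason = chunk.choices[0].finish_reason
    if finish_reason == "content_filter":
        # No exception is raised anymore; the caller decides how to react.
        print("\n[response blocked by the provider's content filter]")
    return finish_reason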
Krrish Dholakia 2024-07-25 19:50:07 -07:00
parent 5bec2bf513
commit a2fd8459fc
2 changed files with 50 additions and 15 deletions

@@ -3248,6 +3248,56 @@ def test_unit_test_custom_stream_wrapper():
    assert freq == 1


def test_unit_test_custom_stream_wrapper_openai():
    """
    Test that a streaming chunk with finish_reason == "content_filter"
    is returned as-is instead of raising an error.
    """
    litellm.set_verbose = False
    chunk = {
        "id": "chatcmpl-9mWtyDnikZZoB75DyfUzWUxiiE2Pi",
        "choices": [
            litellm.utils.StreamingChoices(
                delta=litellm.utils.Delta(
                    content=None, function_call=None, role=None, tool_calls=None
                ),
                finish_reason="content_filter",
                index=0,
                logprobs=None,
            )
        ],
        "created": 1721353246,
        "model": "gpt-3.5-turbo-0613",
        "object": "chat.completion.chunk",
        "system_fingerprint": None,
        "usage": None,
    }

    chunk = litellm.ModelResponse(**chunk, stream=True)

    completion_stream = ModelResponseIterator(model_response=chunk)

    response = litellm.CustomStreamWrapper(
        completion_stream=completion_stream,
        model="gpt-3.5-turbo",
        custom_llm_provider="azure",
        logging_obj=litellm.Logging(
            model="gpt-3.5-turbo",
            messages=[{"role": "user", "content": "Hey"}],
            stream=True,
            call_type="completion",
            start_time=time.time(),
            litellm_call_id="12345",
            function_id="1245",
        ),
    )

    stream_finish_reason: Optional[str] = None
    for chunk in response:
        assert chunk.choices[0].delta.content is None
        if chunk.choices[0].finish_reason is not None:
            stream_finish_reason = chunk.choices[0].finish_reason
    assert stream_finish_reason == "content_filter"


def test_aamazing_unit_test_custom_stream_wrapper_n():
    """
    Test if the translated output maps exactly to the received openai input

@@ -8840,21 +8840,6 @@ class CustomStreamWrapper:
                if str_line.choices[0].finish_reason:
                    is_finished = True
                    finish_reason = str_line.choices[0].finish_reason
                    if finish_reason == "content_filter":
                        if hasattr(str_line.choices[0], "content_filter_result"):
                            error_message = json.dumps(
                                str_line.choices[0].content_filter_result
                            )
                        else:
                            error_message = "{} Response={}".format(
                                self.custom_llm_provider, str(dict(str_line))
                            )
                        raise litellm.ContentPolicyViolationError(
                            message=error_message,
                            llm_provider=self.custom_llm_provider,
                            model=self.model,
                        )

                # checking for logprobs
                if (