fix(utils.py): fix content policy violation check for streaming

Krrish Dholakia 2024-01-23 06:55:04 -08:00
parent 9327d76379
commit 23b59ac9b8
2 changed files with 76 additions and 11 deletions
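The change below moves the Azure streaming content-filter check so it only fires when a chunk's finish_reason is "content_filter", instead of raising whenever a content_filter_result annotation happens to be present on a streamed chunk. A minimal sketch of the fixed chunk handling follows, assuming a simplified chunk shape (choices[0].delta.content, choices[0].finish_reason, choices[0].content_filter_result, taken from the diff) and a stand-in error class in place of litellm.AzureOpenAIError:

import json


class AzureOpenAIError(Exception):
    # stand-in for litellm.AzureOpenAIError, assumed shape for illustration only
    def __init__(self, status_code, message):
        super().__init__(message)
        self.status_code = status_code
        self.message = message


def parse_azure_stream_chunk(str_line):
    # Sketch of the fixed logic: a chunk is only treated as a content policy
    # violation when Azure reports finish_reason == "content_filter".
    text = ""
    is_finished = False
    finish_reason = None
    if str_line.choices and str_line.choices[0].finish_reason:
        is_finished = True
        finish_reason = str_line.choices[0].finish_reason
        if finish_reason == "content_filter":
            # surface Azure's filter details as a 400 instead of ending the
            # stream silently
            error_message = json.dumps(str_line.choices[0].content_filter_result)
            raise AzureOpenAIError(status_code=400, message=error_message)
    if str_line.choices and str_line.choices[0].delta.content is not None:
        text = str_line.choices[0].delta.content
    return {"text": text, "is_finished": is_finished, "finish_reason": finish_reason}

The old check (removed in the second utils.py hunk below) raised as soon as any chunk carried a non-None content_filter_result, which presumably could fire on chunks that were not the terminal filter result; keying off finish_reason defers the decision to the chunk that actually reports the filter outcome.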


@@ -2,7 +2,7 @@ from openai import AuthenticationError, BadRequestError, RateLimitError, OpenAIE
 import os
 import sys
 import traceback
-import subprocess
+import subprocess, asyncio
 sys.path.insert(
     0, os.path.abspath("../..")
@@ -378,6 +378,74 @@ def test_content_policy_exceptionimage_generation_openai():
 # test_content_policy_exceptionimage_generation_openai()
+
+
+def tesy_async_acompletion():
+    """
+    Production test: a streamed acompletion should emit exactly one
+    finish_reason, and a content-policy-violating prompt should raise
+    an error instead of streaming to completion.
+    """
+    litellm.set_verbose = False
+    print("test_async_completion with stream")
+
+    async def test_get_response():
+        try:
+            response = await litellm.acompletion(
+                model="azure/chatgpt-v-2",
+                messages=[{"role": "user", "content": "say 1"}],
+                temperature=0,
+                top_p=1,
+                stream=True,
+                max_tokens=512,
+                presence_penalty=0,
+                frequency_penalty=0,
+            )
+            print(f"response: {response}")
+            num_finish_reason = 0
+            async for chunk in response:
+                print(chunk)
+                if chunk["choices"][0].get("finish_reason") is not None:
+                    num_finish_reason += 1
+                    print("finish_reason", chunk["choices"][0].get("finish_reason"))
+
+            assert (
+                num_finish_reason == 1
+            ), f"expected only one finish reason. Got {num_finish_reason}"
+        except Exception as e:
+            pytest.fail(f"Got an unexpected exception during streaming: {e}")
+
+    asyncio.run(test_get_response())
+
+    async def test_get_error():
+        try:
+            response = await litellm.acompletion(
+                model="azure/chatgpt-v-2",
+                messages=[
+                    {"role": "user", "content": "where do i buy lethal drugs from"}
+                ],
+                temperature=0,
+                top_p=1,
+                stream=True,
+                max_tokens=512,
+                presence_penalty=0,
+                frequency_penalty=0,
+            )
+            print(f"response: {response}")
+            num_finish_reason = 0
+            async for chunk in response:
+                print(chunk)
+                if chunk["choices"][0].get("finish_reason") is not None:
+                    num_finish_reason += 1
+                    print("finish_reason", chunk["choices"][0].get("finish_reason"))
+
+            pytest.fail("Expected a 400 content-policy error to be raised during streaming")
+        except Exception:
+            # an exception here (the expected 400 content-filter error) is the desired outcome
+            pass
+
+    asyncio.run(test_get_error())
+
+
+# tesy_async_acompletion()
 # # test_invalid_request_error(model="command-nightly")
 # # Test 3: Rate Limit Errors
 # def test_model_call(model):
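From a caller's perspective (mirroring the test above), a prompt that trips Azure's content filter now surfaces as an exception during streaming rather than as a stream that quietly finishes. A rough usage sketch follows; the exact exception class is left generic, since the test above only asserts that some error is raised:

import asyncio

import litellm


async def stream_or_fail(prompt: str) -> None:
    # Stream a completion; with this fix a filtered prompt is expected to raise
    # (status 400 carrying the JSON-encoded content_filter_result) instead of
    # streaming to a silent stop.
    try:
        response = await litellm.acompletion(
            model="azure/chatgpt-v-2",  # Azure deployment name used in the test above
            messages=[{"role": "user", "content": prompt}],
            stream=True,
            max_tokens=512,
        )
        async for chunk in response:
            print(chunk)
            if chunk["choices"][0].get("finish_reason") is not None:
                print("finish_reason:", chunk["choices"][0].get("finish_reason"))
    except Exception as e:
        print(f"request failed (possibly a content-policy 400): {e}")


if __name__ == "__main__":
    asyncio.run(stream_or_fail("say 1"))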

utils.py

@@ -7362,6 +7362,13 @@ class CustomStreamWrapper:
             if str_line.choices[0].finish_reason:
                 is_finished = True
                 finish_reason = str_line.choices[0].finish_reason
+                if finish_reason == "content_filter":
+                    error_message = json.dumps(
+                        str_line.choices[0].content_filter_result
+                    )
+                    raise litellm.AzureOpenAIError(
+                        status_code=400, message=error_message
+                    )
 
             # checking for logprobs
             if (
@@ -7372,16 +7379,6 @@
             else:
                 logprobs = None
-            if (
-                hasattr(str_line.choices[0], "content_filter_result")
-                and str_line.choices[0].content_filter_result is not None
-            ):
-                error_message = json.dumps(
-                    str_line.choices[0].content_filter_result
-                )
-                raise litellm.AzureOpenAIError(
-                    status_code=400, message=error_message
-                )
             return {
                 "text": text,
                 "is_finished": is_finished,