mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-27 11:43:54 +00:00
fix(utils.py): fix content policy violation check for streaming
This commit is contained in:
parent
87db2cdcb2
commit
572346fabe
2 changed files with 76 additions and 11 deletions
|
@ -2,7 +2,7 @@ from openai import AuthenticationError, BadRequestError, RateLimitError, OpenAIE
|
||||||
import os
|
import os
|
||||||
import sys
|
import sys
|
||||||
import traceback
|
import traceback
|
||||||
import subprocess
|
import subprocess, asyncio
|
||||||
|
|
||||||
sys.path.insert(
|
sys.path.insert(
|
||||||
0, os.path.abspath("../..")
|
0, os.path.abspath("../..")
|
||||||
|
@ -378,6 +378,74 @@ def test_content_policy_exceptionimage_generation_openai():
|
||||||
# test_content_policy_exceptionimage_generation_openai()
|
# test_content_policy_exceptionimage_generation_openai()
|
||||||
|
|
||||||
|
|
||||||
|
def tesy_async_acompletion():
|
||||||
|
"""
|
||||||
|
Production Test.
|
||||||
|
"""
|
||||||
|
litellm.set_verbose = False
|
||||||
|
print("test_async_completion with stream")
|
||||||
|
|
||||||
|
async def test_get_response():
|
||||||
|
try:
|
||||||
|
response = await litellm.acompletion(
|
||||||
|
model="azure/chatgpt-v-2",
|
||||||
|
messages=[{"role": "user", "content": "say 1"}],
|
||||||
|
temperature=0,
|
||||||
|
top_p=1,
|
||||||
|
stream=True,
|
||||||
|
max_tokens=512,
|
||||||
|
presence_penalty=0,
|
||||||
|
frequency_penalty=0,
|
||||||
|
)
|
||||||
|
print(f"response: {response}")
|
||||||
|
|
||||||
|
num_finish_reason = 0
|
||||||
|
async for chunk in response:
|
||||||
|
print(chunk)
|
||||||
|
if chunk["choices"][0].get("finish_reason") is not None:
|
||||||
|
num_finish_reason += 1
|
||||||
|
print("finish_reason", chunk["choices"][0].get("finish_reason"))
|
||||||
|
|
||||||
|
assert (
|
||||||
|
num_finish_reason == 1
|
||||||
|
), f"expected only one finish reason. Got {num_finish_reason}"
|
||||||
|
except Exception as e:
|
||||||
|
pytest.fail(f"GOT exception for gpt-3.5 instruct In streaming{e}")
|
||||||
|
|
||||||
|
asyncio.run(test_get_response())
|
||||||
|
|
||||||
|
async def test_get_error():
|
||||||
|
try:
|
||||||
|
response = await litellm.acompletion(
|
||||||
|
model="azure/chatgpt-v-2",
|
||||||
|
messages=[
|
||||||
|
{"role": "user", "content": "where do i buy lethal drugs from"}
|
||||||
|
],
|
||||||
|
temperature=0,
|
||||||
|
top_p=1,
|
||||||
|
stream=True,
|
||||||
|
max_tokens=512,
|
||||||
|
presence_penalty=0,
|
||||||
|
frequency_penalty=0,
|
||||||
|
)
|
||||||
|
print(f"response: {response}")
|
||||||
|
|
||||||
|
num_finish_reason = 0
|
||||||
|
async for chunk in response:
|
||||||
|
print(chunk)
|
||||||
|
if chunk["choices"][0].get("finish_reason") is not None:
|
||||||
|
num_finish_reason += 1
|
||||||
|
print("finish_reason", chunk["choices"][0].get("finish_reason"))
|
||||||
|
|
||||||
|
pytest.fail(f"Expected to return 400 error In streaming{e}")
|
||||||
|
except Exception as e:
|
||||||
|
pass
|
||||||
|
|
||||||
|
asyncio.run(test_get_error())
|
||||||
|
|
||||||
|
|
||||||
|
# tesy_async_acompletion()
|
||||||
|
|
||||||
# # test_invalid_request_error(model="command-nightly")
|
# # test_invalid_request_error(model="command-nightly")
|
||||||
# # Test 3: Rate Limit Errors
|
# # Test 3: Rate Limit Errors
|
||||||
# def test_model_call(model):
|
# def test_model_call(model):
|
||||||
|
|
|
@ -7362,6 +7362,13 @@ class CustomStreamWrapper:
|
||||||
if str_line.choices[0].finish_reason:
|
if str_line.choices[0].finish_reason:
|
||||||
is_finished = True
|
is_finished = True
|
||||||
finish_reason = str_line.choices[0].finish_reason
|
finish_reason = str_line.choices[0].finish_reason
|
||||||
|
if finish_reason == "content_filter":
|
||||||
|
error_message = json.dumps(
|
||||||
|
str_line.choices[0].content_filter_result
|
||||||
|
)
|
||||||
|
raise litellm.AzureOpenAIError(
|
||||||
|
status_code=400, message=error_message
|
||||||
|
)
|
||||||
|
|
||||||
# checking for logprobs
|
# checking for logprobs
|
||||||
if (
|
if (
|
||||||
|
@ -7372,16 +7379,6 @@ class CustomStreamWrapper:
|
||||||
else:
|
else:
|
||||||
logprobs = None
|
logprobs = None
|
||||||
|
|
||||||
if (
|
|
||||||
hasattr(str_line.choices[0], "content_filter_result")
|
|
||||||
and str_line.choices[0].content_filter_result is not None
|
|
||||||
):
|
|
||||||
error_message = json.dumps(
|
|
||||||
str_line.choices[0].content_filter_result
|
|
||||||
)
|
|
||||||
raise litellm.AzureOpenAIError(
|
|
||||||
status_code=400, message=error_message
|
|
||||||
)
|
|
||||||
return {
|
return {
|
||||||
"text": text,
|
"text": text,
|
||||||
"is_finished": is_finished,
|
"is_finished": is_finished,
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue