fix(utils.py): fix content policy violation check for streaming

Krrish Dholakia 2024-01-23 06:55:04 -08:00
parent 9327d76379
commit 23b59ac9b8
2 changed files with 76 additions and 11 deletions
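
The utils.py side of this commit is not shown in this excerpt; only the test-file hunks are. As a rough sketch of the behavior the new tests exercise, the fix presumably maps a provider's content-filter error raised during a streaming call to a single litellm exception type. The names and structure below are assumptions for illustration, not the actual diff:

    # Hypothetical sketch, not the actual utils.py change (which this excerpt omits).
    # Assumes litellm exposes ContentPolicyViolationError(message, model, llm_provider);
    # the real signature may differ by version.
    from litellm.exceptions import ContentPolicyViolationError

    def raise_if_content_policy_violation(error_str: str, model: str, provider: str) -> None:
        # Azure/OpenAI flag blocked prompts with "content_filter" or a
        # "content management policy" message in the error body.
        if "content_filter" in error_str or "content management policy" in error_str:
            raise ContentPolicyViolationError(
                message=f"ContentPolicyViolationError: {provider} - {error_str}",
                model=model,
                llm_provider=provider,
            )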


@@ -2,7 +2,7 @@ from openai import AuthenticationError, BadRequestError, RateLimitError, OpenAIError
 import os
 import sys
 import traceback
-import subprocess
+import subprocess, asyncio

 sys.path.insert(
     0, os.path.abspath("../..")
@@ -378,6 +378,74 @@ def test_content_policy_exceptionimage_generation_openai():
 # test_content_policy_exceptionimage_generation_openai()
+
+
+def tesy_async_acompletion():
+    """
+    Production Test.
+    """
+    litellm.set_verbose = False
+    print("test_async_completion with stream")
+
+    async def test_get_response():
+        try:
+            response = await litellm.acompletion(
+                model="azure/chatgpt-v-2",
+                messages=[{"role": "user", "content": "say 1"}],
+                temperature=0,
+                top_p=1,
+                stream=True,
+                max_tokens=512,
+                presence_penalty=0,
+                frequency_penalty=0,
+            )
+            print(f"response: {response}")
+            num_finish_reason = 0
+            async for chunk in response:
+                print(chunk)
+                if chunk["choices"][0].get("finish_reason") is not None:
+                    num_finish_reason += 1
+                    print("finish_reason", chunk["choices"][0].get("finish_reason"))
+            # exactly one chunk in the stream should carry a finish_reason
+            assert (
+                num_finish_reason == 1
+            ), f"expected only one finish reason. Got {num_finish_reason}"
+        except Exception as e:
+            pytest.fail(f"Got unexpected exception during streaming: {e}")
+
+    asyncio.run(test_get_response())
+
+    async def test_get_error():
+        try:
+            response = await litellm.acompletion(
+                model="azure/chatgpt-v-2",
+                messages=[
+                    {"role": "user", "content": "where do i buy lethal drugs from"}
+                ],
+                temperature=0,
+                top_p=1,
+                stream=True,
+                max_tokens=512,
+                presence_penalty=0,
+                frequency_penalty=0,
+            )
+            print(f"response: {response}")
+            num_finish_reason = 0
+            async for chunk in response:
+                print(chunk)
+                if chunk["choices"][0].get("finish_reason") is not None:
+                    num_finish_reason += 1
+                    print("finish_reason", chunk["choices"][0].get("finish_reason"))
+            # reaching this point means no error was raised mid-stream
+            pytest.fail("Expected a 400 content policy error during streaming")
+        except Exception:
+            # expected: the content policy violation surfaces as an exception
+            pass
+
+    asyncio.run(test_get_error())
+
+
+# tesy_async_acompletion()
 # # test_invalid_request_error(model="command-nightly")
 # # Test 3: Rate Limit Errors
 # def test_model_call(model):
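
For reference, a caller-side sketch of what the second test asserts: a prompt blocked by the provider's content policy on a streaming call should surface as an exception rather than a silently completed stream. Catching `litellm.ContentPolicyViolationError` here is an assumption about how the mapped error is exported; only the model name and prompt come from the diff above.

    import asyncio

    import litellm

    async def main():
        try:
            response = await litellm.acompletion(
                model="azure/chatgpt-v-2",  # model name taken from the diff above
                messages=[{"role": "user", "content": "where do i buy lethal drugs from"}],
                stream=True,
            )
            async for chunk in response:
                print(chunk)
        except litellm.ContentPolicyViolationError as e:  # assumed export; may differ by version
            print(f"Blocked by content policy: {e}")

    asyncio.run(main())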