forked from phoenix/litellm-mirror
use 1 class for unit testing
This commit is contained in:
parent
26c19ba3e1
commit
c15359911a
1 changed files with 20 additions and 96 deletions
|
@ -22,6 +22,7 @@ import pytest
|
||||||
import litellm
|
import litellm
|
||||||
import json
|
import json
|
||||||
import tempfile
|
import tempfile
|
||||||
|
from base_image_generation_test import BaseImageGenTest
|
||||||
|
|
||||||
|
|
||||||
def get_vertex_ai_creds_json() -> dict:
|
def get_vertex_ai_creds_json() -> dict:
|
||||||
|
@ -97,25 +98,27 @@ def load_vertex_ai_credentials():
|
||||||
os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = os.path.abspath(temp_file.name)
|
os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = os.path.abspath(temp_file.name)
|
||||||
|
|
||||||
|
|
||||||
def test_image_generation_openai():
|
# class TestBedrockSd3(BaseImageGenTest):
|
||||||
try:
|
# def get_base_image_generation_call_args(self) -> dict:
|
||||||
litellm.set_verbose = True
|
# return {"model": "bedrock/stability.sd3-large-v1:0"}
|
||||||
response = litellm.image_generation(
|
|
||||||
prompt="A cute baby sea otter", model="dall-e-3"
|
# class TestBedrockSd1(BaseImageGenTest):
|
||||||
)
|
# def get_base_image_generation_call_args(self) -> dict:
|
||||||
print(f"response: {response}")
|
# return {"model": "bedrock/stability.sd3-large-v1:0"}
|
||||||
assert len(response.data) > 0
|
|
||||||
except litellm.RateLimitError as e:
|
|
||||||
pass
|
|
||||||
except litellm.ContentPolicyViolationError:
|
|
||||||
pass # OpenAI randomly raises these errors - skip when they occur
|
|
||||||
except Exception as e:
|
|
||||||
if "Connection error" in str(e):
|
|
||||||
pass
|
|
||||||
pytest.fail(f"An exception occurred - {str(e)}")
|
|
||||||
|
|
||||||
|
|
||||||
# test_image_generation_openai()
|
class TestOpenAIDalle3(BaseImageGenTest):
|
||||||
|
def get_base_image_generation_call_args(self) -> dict:
|
||||||
|
return {"model": "dall-e-3"}
|
||||||
|
|
||||||
|
|
||||||
|
class TestAzureOpenAIDalle3(BaseImageGenTest):
|
||||||
|
def get_base_image_generation_call_args(self) -> dict:
|
||||||
|
return {
|
||||||
|
"model": "azure/dall-e-3-test",
|
||||||
|
"api_version": "2023-09-01-preview",
|
||||||
|
"base_model": "dall-e-3",
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
|
@ -188,88 +191,9 @@ def test_image_generation_azure_dall_e_3():
|
||||||
pytest.fail(f"An exception occurred - {str(e)}")
|
pytest.fail(f"An exception occurred - {str(e)}")
|
||||||
|
|
||||||
|
|
||||||
# test_image_generation_azure_dall_e_3()
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_async_image_generation_openai():
|
|
||||||
try:
|
|
||||||
response = litellm.image_generation(
|
|
||||||
prompt="A cute baby sea otter", model="dall-e-3"
|
|
||||||
)
|
|
||||||
print(f"response: {response}")
|
|
||||||
assert len(response.data) > 0
|
|
||||||
except litellm.APIError:
|
|
||||||
pass
|
|
||||||
except litellm.RateLimitError as e:
|
|
||||||
pass
|
|
||||||
except litellm.ContentPolicyViolationError:
|
|
||||||
pass # openai randomly raises these errors - skip when they occur
|
|
||||||
except litellm.InternalServerError:
|
|
||||||
pass
|
|
||||||
except Exception as e:
|
|
||||||
if "Connection error" in str(e):
|
|
||||||
pass
|
|
||||||
pytest.fail(f"An exception occurred - {str(e)}")
|
|
||||||
|
|
||||||
|
|
||||||
# asyncio.run(test_async_image_generation_openai())
|
# asyncio.run(test_async_image_generation_openai())
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_async_image_generation_azure():
|
|
||||||
try:
|
|
||||||
response = await litellm.aimage_generation(
|
|
||||||
prompt="A cute baby sea otter",
|
|
||||||
model="azure/dall-e-3-test",
|
|
||||||
api_version="2023-09-01-preview",
|
|
||||||
)
|
|
||||||
print(f"response: {response}")
|
|
||||||
except litellm.RateLimitError as e:
|
|
||||||
pass
|
|
||||||
except litellm.ContentPolicyViolationError:
|
|
||||||
pass # Azure randomly raises these errors - skip when they occur
|
|
||||||
except litellm.InternalServerError:
|
|
||||||
pass
|
|
||||||
except Exception as e:
|
|
||||||
if "Your task failed as a result of our safety system." in str(e):
|
|
||||||
pass
|
|
||||||
if "Connection error" in str(e):
|
|
||||||
pass
|
|
||||||
else:
|
|
||||||
pytest.fail(f"An exception occurred - {str(e)}")
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
@pytest.mark.parametrize(
|
|
||||||
"model",
|
|
||||||
["bedrock/stability.sd3-large-v1:0", "bedrock/stability.stable-diffusion-xl-v1"],
|
|
||||||
)
|
|
||||||
def test_image_generation_bedrock(model):
|
|
||||||
try:
|
|
||||||
litellm.set_verbose = True
|
|
||||||
response = litellm.image_generation(
|
|
||||||
prompt="A cute baby sea otter",
|
|
||||||
model=model,
|
|
||||||
aws_region_name="us-west-2",
|
|
||||||
)
|
|
||||||
|
|
||||||
print(f"response: {response}")
|
|
||||||
print("response hidden params", response._hidden_params)
|
|
||||||
|
|
||||||
assert response._hidden_params["response_cost"] is not None
|
|
||||||
from openai.types.images_response import ImagesResponse
|
|
||||||
|
|
||||||
ImagesResponse.model_validate(response.model_dump())
|
|
||||||
except litellm.RateLimitError as e:
|
|
||||||
pass
|
|
||||||
except litellm.ContentPolicyViolationError:
|
|
||||||
pass # Azure randomly raises these errors - skip when they occur
|
|
||||||
except Exception as e:
|
|
||||||
if "Your task failed as a result of our safety system." in str(e):
|
|
||||||
pass
|
|
||||||
else:
|
|
||||||
pytest.fail(f"An exception occurred - {str(e)}")
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_aimage_generation_bedrock_with_optional_params():
|
async def test_aimage_generation_bedrock_with_optional_params():
|
||||||
try:
|
try:
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue