diff --git a/tests/image_gen_tests/test_image_generation.py b/tests/image_gen_tests/test_image_generation.py index e94d62c1f..4e02b493c 100644 --- a/tests/image_gen_tests/test_image_generation.py +++ b/tests/image_gen_tests/test_image_generation.py @@ -22,6 +22,7 @@ import pytest import litellm import json import tempfile +from base_image_generation_test import BaseImageGenTest def get_vertex_ai_creds_json() -> dict: @@ -97,25 +98,27 @@ def load_vertex_ai_credentials(): os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = os.path.abspath(temp_file.name) -def test_image_generation_openai(): - try: - litellm.set_verbose = True - response = litellm.image_generation( - prompt="A cute baby sea otter", model="dall-e-3" - ) - print(f"response: {response}") - assert len(response.data) > 0 - except litellm.RateLimitError as e: - pass - except litellm.ContentPolicyViolationError: - pass # OpenAI randomly raises these errors - skip when they occur - except Exception as e: - if "Connection error" in str(e): - pass - pytest.fail(f"An exception occurred - {str(e)}") +# class TestBedrockSd3(BaseImageGenTest): +# def get_base_image_generation_call_args(self) -> dict: +# return {"model": "bedrock/stability.sd3-large-v1:0"} + +# class TestBedrockSd1(BaseImageGenTest): +# def get_base_image_generation_call_args(self) -> dict: +# return {"model": "bedrock/stability.sd3-large-v1:0"} -# test_image_generation_openai() +class TestOpenAIDalle3(BaseImageGenTest): + def get_base_image_generation_call_args(self) -> dict: + return {"model": "dall-e-3"} + + +class TestAzureOpenAIDalle3(BaseImageGenTest): + def get_base_image_generation_call_args(self) -> dict: + return { + "model": "azure/dall-e-3-test", + "api_version": "2023-09-01-preview", + "base_model": "dall-e-3", + } @pytest.mark.parametrize( @@ -188,88 +191,9 @@ def test_image_generation_azure_dall_e_3(): pytest.fail(f"An exception occurred - {str(e)}") -# test_image_generation_azure_dall_e_3() -@pytest.mark.asyncio -async def test_async_image_generation_openai(): - try: - response = litellm.image_generation( - prompt="A cute baby sea otter", model="dall-e-3" - ) - print(f"response: {response}") - assert len(response.data) > 0 - except litellm.APIError: - pass - except litellm.RateLimitError as e: - pass - except litellm.ContentPolicyViolationError: - pass # openai randomly raises these errors - skip when they occur - except litellm.InternalServerError: - pass - except Exception as e: - if "Connection error" in str(e): - pass - pytest.fail(f"An exception occurred - {str(e)}") - - # asyncio.run(test_async_image_generation_openai()) -@pytest.mark.asyncio -async def test_async_image_generation_azure(): - try: - response = await litellm.aimage_generation( - prompt="A cute baby sea otter", - model="azure/dall-e-3-test", - api_version="2023-09-01-preview", - ) - print(f"response: {response}") - except litellm.RateLimitError as e: - pass - except litellm.ContentPolicyViolationError: - pass # Azure randomly raises these errors - skip when they occur - except litellm.InternalServerError: - pass - except Exception as e: - if "Your task failed as a result of our safety system." in str(e): - pass - if "Connection error" in str(e): - pass - else: - pytest.fail(f"An exception occurred - {str(e)}") - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - "model", - ["bedrock/stability.sd3-large-v1:0", "bedrock/stability.stable-diffusion-xl-v1"], -) -def test_image_generation_bedrock(model): - try: - litellm.set_verbose = True - response = litellm.image_generation( - prompt="A cute baby sea otter", - model=model, - aws_region_name="us-west-2", - ) - - print(f"response: {response}") - print("response hidden params", response._hidden_params) - - assert response._hidden_params["response_cost"] is not None - from openai.types.images_response import ImagesResponse - - ImagesResponse.model_validate(response.model_dump()) - except litellm.RateLimitError as e: - pass - except litellm.ContentPolicyViolationError: - pass # Azure randomly raises these errors - skip when they occur - except Exception as e: - if "Your task failed as a result of our safety system." in str(e): - pass - else: - pytest.fail(f"An exception occurred - {str(e)}") - - @pytest.mark.asyncio async def test_aimage_generation_bedrock_with_optional_params(): try: