From 26c19ba3e162278d0ae1a7f1b8f3f2d994f300a6 Mon Sep 17 00:00:00 2001 From: Ishaan Jaff Date: Tue, 12 Nov 2024 15:47:27 -0800 Subject: [PATCH] add BaseImageGenTest --- .../base_image_generation_test.py | 61 +++++++++++++++++++ 1 file changed, 61 insertions(+) create mode 100644 tests/image_gen_tests/base_image_generation_test.py diff --git a/tests/image_gen_tests/base_image_generation_test.py b/tests/image_gen_tests/base_image_generation_test.py new file mode 100644 index 000000000..56c4557cc --- /dev/null +++ b/tests/image_gen_tests/base_image_generation_test.py @@ -0,0 +1,61 @@ +import asyncio +import httpx +import json +import pytest +import sys +from typing import Any, Dict, List +from unittest.mock import MagicMock, Mock, patch +import os + +sys.path.insert( + 0, os.path.abspath("../..") +) # Adds the parent directory to the system path +import litellm +from litellm.exceptions import BadRequestError +from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler +from litellm.utils import CustomStreamWrapper +from openai.types.image import Image + +# test_example.py +from abc import ABC, abstractmethod + + +class BaseImageGenTest(ABC): + """ + Abstract base test class that enforces a common test across all test classes. + """ + + @abstractmethod + def get_base_image_generation_call_args(self) -> dict: + """Must return the base image generation call args""" + pass + + @pytest.mark.asyncio(scope="module") + async def test_basic_image_generation(self): + """Test basic image generation""" + try: + base_image_generation_call_args = self.get_base_image_generation_call_args() + litellm.set_verbose = True + response = await litellm.aimage_generation( + **base_image_generation_call_args, prompt="A image of a otter" + ) + print(response) + + assert response._hidden_params["response_cost"] is not None + from openai.types.images_response import ImagesResponse + + ImagesResponse.model_validate(response.model_dump()) + + for d in response.data: + assert isinstance(d, Image) + print("data in response.data", d) + assert d.b64_json is not None or d.url is not None + except litellm.RateLimitError as e: + pass + except litellm.ContentPolicyViolationError: + pass # Azure randomly raises these errors - skip when they occur + except Exception as e: + if "Your task failed as a result of our safety system." in str(e): + pass + else: + pytest.fail(f"An exception occurred - {str(e)}")