diff --git a/tests/image_gen_tests/base_image_generation_test.py b/tests/image_gen_tests/base_image_generation_test.py index be0ecb1b4..e0652114d 100644 --- a/tests/image_gen_tests/base_image_generation_test.py +++ b/tests/image_gen_tests/base_image_generation_test.py @@ -3,7 +3,7 @@ import httpx import json import pytest import sys -from typing import Any, Dict, List +from typing import Any, Dict, List, Optional from unittest.mock import MagicMock, Mock, patch import os @@ -15,6 +15,19 @@ from litellm.exceptions import BadRequestError from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler from litellm.utils import CustomStreamWrapper from openai.types.image import Image +from litellm.integrations.custom_logger import CustomLogger +from litellm.types.utils import StandardLoggingPayload + + +class TestCustomLogger(CustomLogger): + def __init__(self): + super().__init__() + self.standard_logging_payload: Optional[StandardLoggingPayload] = None + + async def async_log_success_event(self, kwargs, response_obj, start_time, end_time): + self.standard_logging_payload = kwargs.get("standard_logging_object") + pass + # test_example.py from abc import ABC, abstractmethod @@ -34,6 +47,8 @@ class BaseImageGenTest(ABC): async def test_basic_image_generation(self): """Test basic image generation""" try: + custom_logger = TestCustomLogger() + litellm.callbacks = [custom_logger] base_image_generation_call_args = self.get_base_image_generation_call_args() litellm.set_verbose = True response = await litellm.aimage_generation( @@ -41,8 +56,18 @@ class BaseImageGenTest(ABC): ) print(response) + await asyncio.sleep(1) + assert response._hidden_params["response_cost"] is not None + assert response._hidden_params["response_cost"] > 0 print("response_cost", response._hidden_params["response_cost"]) + + logged_standard_logging_payload = custom_logger.standard_logging_payload + print("logged_standard_logging_payload", logged_standard_logging_payload) + assert logged_standard_logging_payload is not None + assert logged_standard_logging_payload["response_cost"] is not None + assert logged_standard_logging_payload["response_cost"] > 0 + from openai.types.images_response import ImagesResponse ImagesResponse.model_validate(response.model_dump()) diff --git a/tests/image_gen_tests/test_image_generation.py b/tests/image_gen_tests/test_image_generation.py index 411aab475..14493c3f0 100644 --- a/tests/image_gen_tests/test_image_generation.py +++ b/tests/image_gen_tests/test_image_generation.py @@ -102,13 +102,28 @@ def load_vertex_ai_credentials(): os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = os.path.abspath(temp_file.name) -# class TestBedrockSd3(BaseImageGenTest): -# def get_base_image_generation_call_args(self) -> dict: -# return {"model": "bedrock/stability.sd3-large-v1:0"} +class TestVertexImageGeneration(BaseImageGenTest): + def get_base_image_generation_call_args(self) -> dict: + # load_vertex_ai_credentials() + litellm.in_memory_llm_clients_cache = {} + return { + "model": "vertex_ai/imagegeneration@006", + "vertex_ai_project": "adroit-crow-413218", + "vertex_ai_location": "us-central1", + "n": 1, + } -# class TestBedrockSd1(BaseImageGenTest): -# def get_base_image_generation_call_args(self) -> dict: -# return {"model": "bedrock/stability.sd3-large-v1:0"} + +class TestBedrockSd3(BaseImageGenTest): + def get_base_image_generation_call_args(self) -> dict: + litellm.in_memory_llm_clients_cache = {} + return {"model": "bedrock/stability.sd3-large-v1:0"} + + +class TestBedrockSd1(BaseImageGenTest): + def get_base_image_generation_call_args(self) -> dict: + litellm.in_memory_llm_clients_cache = {} + return {"model": "bedrock/stability.sd3-large-v1:0"} class TestOpenAIDalle3(BaseImageGenTest): @@ -130,48 +145,6 @@ class TestAzureOpenAIDalle3(BaseImageGenTest): } -@pytest.mark.parametrize( - "sync_mode", - [ - True, - ], # False -) # -@pytest.mark.asyncio -@pytest.mark.flaky(retries=3, delay=1) -async def test_image_generation_azure(sync_mode): - try: - if sync_mode: - response = litellm.image_generation( - prompt="A cute baby sea otter", - model="azure/", - api_version="2023-06-01-preview", - ) - else: - response = await litellm.aimage_generation( - prompt="A cute baby sea otter", - model="azure/", - api_version="2023-06-01-preview", - ) - print(f"response: {response}") - assert len(response.data) > 0 - except litellm.RateLimitError as e: - pass - except litellm.ContentPolicyViolationError: - pass # Azure randomly raises these errors - skip when they occur - except litellm.InternalServerError: - pass - except Exception as e: - if "Your task failed as a result of our safety system." in str(e): - pass - if "Connection error" in str(e): - pass - else: - pytest.fail(f"An exception occurred - {str(e)}") - - -# test_image_generation_azure() - - @pytest.mark.flaky(retries=3, delay=1) def test_image_generation_azure_dall_e_3(): try: @@ -221,47 +194,3 @@ async def test_aimage_generation_bedrock_with_optional_params(): pass else: pytest.fail(f"An exception occurred - {str(e)}") - - -from openai.types.image import Image - - -@pytest.mark.parametrize("sync_mode", [True, False]) -@pytest.mark.asyncio -async def test_aimage_generation_vertex_ai(sync_mode): - - litellm.set_verbose = True - - load_vertex_ai_credentials() - data = { - "prompt": "An olympic size swimming pool", - "model": "vertex_ai/imagegeneration@006", - "vertex_ai_project": "adroit-crow-413218", - "vertex_ai_location": "us-central1", - "n": 1, - } - try: - if sync_mode: - response = litellm.image_generation(**data) - else: - response = await litellm.aimage_generation(**data) - assert response.data is not None - assert len(response.data) > 0 - - for d in response.data: - assert isinstance(d, Image) - print("data in response.data", d) - assert d.b64_json is not None - except litellm.ServiceUnavailableError as e: - pass - except litellm.RateLimitError as e: - pass - except litellm.InternalServerError as e: - pass - except litellm.ContentPolicyViolationError: - pass # Azure randomly raises these errors - skip when they occur - except Exception as e: - if "Your task failed as a result of our safety system." in str(e): - pass - else: - pytest.fail(f"An exception occurred - {str(e)}")