test_basic_image_generation

Ishaan Jaff 2024-11-12 16:34:47 -08:00
parent 1c2f7f792b
commit d2a1ac804b
2 changed files with 47 additions and 93 deletions


@@ -3,7 +3,7 @@ import httpx
 import json
 import pytest
 import sys
-from typing import Any, Dict, List
+from typing import Any, Dict, List, Optional
 from unittest.mock import MagicMock, Mock, patch
 import os
@@ -15,6 +15,19 @@ from litellm.exceptions import BadRequestError
 from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
 from litellm.utils import CustomStreamWrapper
 from openai.types.image import Image
+from litellm.integrations.custom_logger import CustomLogger
+from litellm.types.utils import StandardLoggingPayload
+
+
+class TestCustomLogger(CustomLogger):
+    def __init__(self):
+        super().__init__()
+        self.standard_logging_payload: Optional[StandardLoggingPayload] = None
+
+    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
+        self.standard_logging_payload = kwargs.get("standard_logging_object")
+        pass
+
 # test_example.py
 from abc import ABC, abstractmethod
@@ -34,6 +47,8 @@ class BaseImageGenTest(ABC):
     async def test_basic_image_generation(self):
         """Test basic image generation"""
         try:
+            custom_logger = TestCustomLogger()
+            litellm.callbacks = [custom_logger]
             base_image_generation_call_args = self.get_base_image_generation_call_args()
             litellm.set_verbose = True
             response = await litellm.aimage_generation(
@@ -41,8 +56,18 @@ class BaseImageGenTest(ABC):
             )
             print(response)
+            await asyncio.sleep(1)
 
             assert response._hidden_params["response_cost"] is not None
+            assert response._hidden_params["response_cost"] > 0
             print("response_cost", response._hidden_params["response_cost"])
+
+            logged_standard_logging_payload = custom_logger.standard_logging_payload
+            print("logged_standard_logging_payload", logged_standard_logging_payload)
+            assert logged_standard_logging_payload is not None
+            assert logged_standard_logging_payload["response_cost"] is not None
+            assert logged_standard_logging_payload["response_cost"] > 0
+
             from openai.types.images_response import ImagesResponse
 
             ImagesResponse.model_validate(response.model_dump())
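The substance of this first file's change: `test_basic_image_generation` now registers a `TestCustomLogger` callback, waits briefly for the async success callback, and asserts that the captured `standard_logging_object` reports a positive `response_cost`. Below is a minimal sketch of the same pattern outside the test suite; it relies only on the litellm hooks visible in the diff, while the model name, prompt, and the need for an OpenAI API key are illustrative assumptions rather than part of the commit.

```python
# Minimal sketch (assumptions noted above): capture litellm's standard logging
# payload via a custom callback and read the logged response_cost.
import asyncio
from typing import Optional

import litellm
from litellm.integrations.custom_logger import CustomLogger
from litellm.types.utils import StandardLoggingPayload


class CostCapturingLogger(CustomLogger):
    def __init__(self):
        super().__init__()
        self.standard_logging_payload: Optional[StandardLoggingPayload] = None

    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
        # litellm passes the fully built standard logging payload in kwargs
        self.standard_logging_payload = kwargs.get("standard_logging_object")


async def main():
    logger = CostCapturingLogger()
    litellm.callbacks = [logger]

    # Illustrative call; any provider/model supported by aimage_generation works.
    response = await litellm.aimage_generation(
        prompt="A cute baby sea otter", model="dall-e-3"
    )

    # Success callbacks run asynchronously; wait briefly before reading the
    # captured payload, just as the updated test does.
    await asyncio.sleep(1)

    print("response_cost:", response._hidden_params["response_cost"])
    if logger.standard_logging_payload is not None:
        print("logged cost:", logger.standard_logging_payload["response_cost"])


if __name__ == "__main__":
    asyncio.run(main())
```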


@@ -102,13 +102,28 @@ def load_vertex_ai_credentials():
     os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = os.path.abspath(temp_file.name)
 
-# class TestBedrockSd3(BaseImageGenTest):
-#     def get_base_image_generation_call_args(self) -> dict:
-#         return {"model": "bedrock/stability.sd3-large-v1:0"}
+class TestVertexImageGeneration(BaseImageGenTest):
+    def get_base_image_generation_call_args(self) -> dict:
+        # load_vertex_ai_credentials()
+        litellm.in_memory_llm_clients_cache = {}
+        return {
+            "model": "vertex_ai/imagegeneration@006",
+            "vertex_ai_project": "adroit-crow-413218",
+            "vertex_ai_location": "us-central1",
+            "n": 1,
+        }
 
-# class TestBedrockSd1(BaseImageGenTest):
-#     def get_base_image_generation_call_args(self) -> dict:
-#         return {"model": "bedrock/stability.sd3-large-v1:0"}
+
+class TestBedrockSd3(BaseImageGenTest):
+    def get_base_image_generation_call_args(self) -> dict:
+        litellm.in_memory_llm_clients_cache = {}
+        return {"model": "bedrock/stability.sd3-large-v1:0"}
+
+
+class TestBedrockSd1(BaseImageGenTest):
+    def get_base_image_generation_call_args(self) -> dict:
+        litellm.in_memory_llm_clients_cache = {}
+        return {"model": "bedrock/stability.sd3-large-v1:0"}
 
 
 class TestOpenAIDalle3(BaseImageGenTest):
@@ -130,48 +145,6 @@ class TestAzureOpenAIDalle3(BaseImageGenTest):
         }
 
-
-@pytest.mark.parametrize(
-    "sync_mode",
-    [
-        True,
-    ],  # False
-)  #
-@pytest.mark.asyncio
-@pytest.mark.flaky(retries=3, delay=1)
-async def test_image_generation_azure(sync_mode):
-    try:
-        if sync_mode:
-            response = litellm.image_generation(
-                prompt="A cute baby sea otter",
-                model="azure/",
-                api_version="2023-06-01-preview",
-            )
-        else:
-            response = await litellm.aimage_generation(
-                prompt="A cute baby sea otter",
-                model="azure/",
-                api_version="2023-06-01-preview",
-            )
-        print(f"response: {response}")
-        assert len(response.data) > 0
-    except litellm.RateLimitError as e:
-        pass
-    except litellm.ContentPolicyViolationError:
-        pass  # Azure randomly raises these errors - skip when they occur
-    except litellm.InternalServerError:
-        pass
-    except Exception as e:
-        if "Your task failed as a result of our safety system." in str(e):
-            pass
-        if "Connection error" in str(e):
-            pass
-        else:
-            pytest.fail(f"An exception occurred - {str(e)}")
-
-
-# test_image_generation_azure()
 @pytest.mark.flaky(retries=3, delay=1)
 def test_image_generation_azure_dall_e_3():
     try:
@@ -221,47 +194,3 @@ async def test_aimage_generation_bedrock_with_optional_params():
             pass
         else:
             pytest.fail(f"An exception occurred - {str(e)}")
-
-
-from openai.types.image import Image
-
-
-@pytest.mark.parametrize("sync_mode", [True, False])
-@pytest.mark.asyncio
-async def test_aimage_generation_vertex_ai(sync_mode):
-    litellm.set_verbose = True
-    load_vertex_ai_credentials()
-    data = {
-        "prompt": "An olympic size swimming pool",
-        "model": "vertex_ai/imagegeneration@006",
-        "vertex_ai_project": "adroit-crow-413218",
-        "vertex_ai_location": "us-central1",
-        "n": 1,
-    }
-    try:
-        if sync_mode:
-            response = litellm.image_generation(**data)
-        else:
-            response = await litellm.aimage_generation(**data)
-
-        assert response.data is not None
-        assert len(response.data) > 0
-
-        for d in response.data:
-            assert isinstance(d, Image)
-            print("data in response.data", d)
-            assert d.b64_json is not None
-    except litellm.ServiceUnavailableError as e:
-        pass
-    except litellm.RateLimitError as e:
-        pass
-    except litellm.InternalServerError as e:
-        pass
-    except litellm.ContentPolicyViolationError:
-        pass  # Azure randomly raises these errors - skip when they occur
-    except Exception as e:
-        if "Your task failed as a result of our safety system." in str(e):
-            pass
-        else:
-            pytest.fail(f"An exception occurred - {str(e)}")