forked from phoenix/litellm-mirror
test_basic_image_generation
This commit is contained in:
parent
1c2f7f792b
commit
d2a1ac804b
2 changed files with 47 additions and 93 deletions
|
@ -3,7 +3,7 @@ import httpx
|
|||
import json
|
||||
import pytest
|
||||
import sys
|
||||
from typing import Any, Dict, List
|
||||
from typing import Any, Dict, List, Optional
|
||||
from unittest.mock import MagicMock, Mock, patch
|
||||
import os
|
||||
|
||||
|
@ -15,6 +15,19 @@ from litellm.exceptions import BadRequestError
|
|||
from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
|
||||
from litellm.utils import CustomStreamWrapper
|
||||
from openai.types.image import Image
|
||||
from litellm.integrations.custom_logger import CustomLogger
|
||||
from litellm.types.utils import StandardLoggingPayload
|
||||
|
||||
|
||||
class TestCustomLogger(CustomLogger):
|
||||
def __init__(self):
|
||||
super().__init__()
|
||||
self.standard_logging_payload: Optional[StandardLoggingPayload] = None
|
||||
|
||||
async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
|
||||
self.standard_logging_payload = kwargs.get("standard_logging_object")
|
||||
pass
|
||||
|
||||
|
||||
# test_example.py
|
||||
from abc import ABC, abstractmethod
|
||||
|
@ -34,6 +47,8 @@ class BaseImageGenTest(ABC):
|
|||
async def test_basic_image_generation(self):
|
||||
"""Test basic image generation"""
|
||||
try:
|
||||
custom_logger = TestCustomLogger()
|
||||
litellm.callbacks = [custom_logger]
|
||||
base_image_generation_call_args = self.get_base_image_generation_call_args()
|
||||
litellm.set_verbose = True
|
||||
response = await litellm.aimage_generation(
|
||||
|
@ -41,8 +56,18 @@ class BaseImageGenTest(ABC):
|
|||
)
|
||||
print(response)
|
||||
|
||||
await asyncio.sleep(1)
|
||||
|
||||
assert response._hidden_params["response_cost"] is not None
|
||||
assert response._hidden_params["response_cost"] > 0
|
||||
print("response_cost", response._hidden_params["response_cost"])
|
||||
|
||||
logged_standard_logging_payload = custom_logger.standard_logging_payload
|
||||
print("logged_standard_logging_payload", logged_standard_logging_payload)
|
||||
assert logged_standard_logging_payload is not None
|
||||
assert logged_standard_logging_payload["response_cost"] is not None
|
||||
assert logged_standard_logging_payload["response_cost"] > 0
|
||||
|
||||
from openai.types.images_response import ImagesResponse
|
||||
|
||||
ImagesResponse.model_validate(response.model_dump())
|
||||
|
|
|
@ -102,13 +102,28 @@ def load_vertex_ai_credentials():
|
|||
os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = os.path.abspath(temp_file.name)
|
||||
|
||||
|
||||
# class TestBedrockSd3(BaseImageGenTest):
|
||||
# def get_base_image_generation_call_args(self) -> dict:
|
||||
# return {"model": "bedrock/stability.sd3-large-v1:0"}
|
||||
class TestVertexImageGeneration(BaseImageGenTest):
|
||||
def get_base_image_generation_call_args(self) -> dict:
|
||||
# load_vertex_ai_credentials()
|
||||
litellm.in_memory_llm_clients_cache = {}
|
||||
return {
|
||||
"model": "vertex_ai/imagegeneration@006",
|
||||
"vertex_ai_project": "adroit-crow-413218",
|
||||
"vertex_ai_location": "us-central1",
|
||||
"n": 1,
|
||||
}
|
||||
|
||||
# class TestBedrockSd1(BaseImageGenTest):
|
||||
# def get_base_image_generation_call_args(self) -> dict:
|
||||
# return {"model": "bedrock/stability.sd3-large-v1:0"}
|
||||
|
||||
class TestBedrockSd3(BaseImageGenTest):
|
||||
def get_base_image_generation_call_args(self) -> dict:
|
||||
litellm.in_memory_llm_clients_cache = {}
|
||||
return {"model": "bedrock/stability.sd3-large-v1:0"}
|
||||
|
||||
|
||||
class TestBedrockSd1(BaseImageGenTest):
|
||||
def get_base_image_generation_call_args(self) -> dict:
|
||||
litellm.in_memory_llm_clients_cache = {}
|
||||
return {"model": "bedrock/stability.sd3-large-v1:0"}
|
||||
|
||||
|
||||
class TestOpenAIDalle3(BaseImageGenTest):
|
||||
|
@ -130,48 +145,6 @@ class TestAzureOpenAIDalle3(BaseImageGenTest):
|
|||
}
|
||||
|
||||
|
||||
@pytest.mark.parametrize(
|
||||
"sync_mode",
|
||||
[
|
||||
True,
|
||||
], # False
|
||||
) #
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.flaky(retries=3, delay=1)
|
||||
async def test_image_generation_azure(sync_mode):
|
||||
try:
|
||||
if sync_mode:
|
||||
response = litellm.image_generation(
|
||||
prompt="A cute baby sea otter",
|
||||
model="azure/",
|
||||
api_version="2023-06-01-preview",
|
||||
)
|
||||
else:
|
||||
response = await litellm.aimage_generation(
|
||||
prompt="A cute baby sea otter",
|
||||
model="azure/",
|
||||
api_version="2023-06-01-preview",
|
||||
)
|
||||
print(f"response: {response}")
|
||||
assert len(response.data) > 0
|
||||
except litellm.RateLimitError as e:
|
||||
pass
|
||||
except litellm.ContentPolicyViolationError:
|
||||
pass # Azure randomly raises these errors - skip when they occur
|
||||
except litellm.InternalServerError:
|
||||
pass
|
||||
except Exception as e:
|
||||
if "Your task failed as a result of our safety system." in str(e):
|
||||
pass
|
||||
if "Connection error" in str(e):
|
||||
pass
|
||||
else:
|
||||
pytest.fail(f"An exception occurred - {str(e)}")
|
||||
|
||||
|
||||
# test_image_generation_azure()
|
||||
|
||||
|
||||
@pytest.mark.flaky(retries=3, delay=1)
|
||||
def test_image_generation_azure_dall_e_3():
|
||||
try:
|
||||
|
@ -221,47 +194,3 @@ async def test_aimage_generation_bedrock_with_optional_params():
|
|||
pass
|
||||
else:
|
||||
pytest.fail(f"An exception occurred - {str(e)}")
|
||||
|
||||
|
||||
from openai.types.image import Image
|
||||
|
||||
|
||||
@pytest.mark.parametrize("sync_mode", [True, False])
|
||||
@pytest.mark.asyncio
|
||||
async def test_aimage_generation_vertex_ai(sync_mode):
|
||||
|
||||
litellm.set_verbose = True
|
||||
|
||||
load_vertex_ai_credentials()
|
||||
data = {
|
||||
"prompt": "An olympic size swimming pool",
|
||||
"model": "vertex_ai/imagegeneration@006",
|
||||
"vertex_ai_project": "adroit-crow-413218",
|
||||
"vertex_ai_location": "us-central1",
|
||||
"n": 1,
|
||||
}
|
||||
try:
|
||||
if sync_mode:
|
||||
response = litellm.image_generation(**data)
|
||||
else:
|
||||
response = await litellm.aimage_generation(**data)
|
||||
assert response.data is not None
|
||||
assert len(response.data) > 0
|
||||
|
||||
for d in response.data:
|
||||
assert isinstance(d, Image)
|
||||
print("data in response.data", d)
|
||||
assert d.b64_json is not None
|
||||
except litellm.ServiceUnavailableError as e:
|
||||
pass
|
||||
except litellm.RateLimitError as e:
|
||||
pass
|
||||
except litellm.InternalServerError as e:
|
||||
pass
|
||||
except litellm.ContentPolicyViolationError:
|
||||
pass # Azure randomly raises these errors - skip when they occur
|
||||
except Exception as e:
|
||||
if "Your task failed as a result of our safety system." in str(e):
|
||||
pass
|
||||
else:
|
||||
pytest.fail(f"An exception occurred - {str(e)}")
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue