forked from phoenix/litellm-mirror
(feat) Add cost tracking for Azure Dall-e-3 Image Generation + use base class to ensure basic image generation tests pass (#6716)
* add BaseImageGenTest
* use 1 class for unit testing
* add debugging to BaseImageGenTest
* TestAzureOpenAIDalle3
* fix response_cost_calculator
* test_basic_image_generation
* fix img gen basic test
* fix _select_model_name_for_cost_calc
* fix test_aimage_generation_bedrock_with_optional_params
* fix undo changes cost tracking
* fix response_cost_calculator
* fix test_cost_azure_gpt_35
This commit is contained in:
parent 6d4cf2d908
commit 73c7b73aa0

3 changed files with 139 additions and 186 deletions
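In practice, the feature works like this (a minimal sketch assembled from the test arguments in this diff; the deployment alias "azure/dall-e-3-test" and the API version are the test's own values, so substitute your deployment):

import litellm

# Azure deployments have user-chosen names, so cost tracking is told which
# OpenAI model the deployment actually serves via metadata.model_info.base_model.
response = litellm.image_generation(
    prompt="A cute baby sea otter",
    model="azure/dall-e-3-test",  # deployment alias, not a pricing key
    api_version="2023-09-01-preview",
    metadata={"model_info": {"base_model": "dall-e-3"}},  # pricing lookup key
)

# With base_model set, the ImageResponse is costed against "dall-e-3" pricing.
print(response._hidden_params["response_cost"])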
@@ -171,7 +171,6 @@ def cost_per_token(  # noqa: PLR0915
             model_with_provider = model_with_provider_and_region
     else:
         _, custom_llm_provider, _, _ = litellm.get_llm_provider(model=model)
-
     model_without_prefix = model
     model_parts = model.split("/", 1)
     if len(model_parts) > 1:
@@ -454,7 +453,6 @@ def _select_model_name_for_cost_calc(
 
     if base_model is not None:
         return base_model
-
     return_model = model
     if isinstance(completion_response, str):
         return return_model
@@ -620,7 +618,8 @@ def completion_cost(  # noqa: PLR0915
             f"completion_response response ms: {getattr(completion_response, '_response_ms', None)} "
         )
         model = _select_model_name_for_cost_calc(
-            model=model, completion_response=completion_response
+            model=model,
+            completion_response=completion_response,
         )
         hidden_params = getattr(completion_response, "_hidden_params", None)
         if hidden_params is not None:
@@ -853,6 +852,8 @@ def response_cost_calculator(
             if isinstance(response_object, BaseModel):
                 response_object._hidden_params["optional_params"] = optional_params
                 if isinstance(response_object, ImageResponse):
+                    if base_model is not None:
+                        model = base_model
                     response_cost = completion_cost(
                         completion_response=response_object,
                         model=model,
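Read together, the two added lines in response_cost_calculator mean an ImageResponse from a custom-named deployment is costed against its base_model rather than the deployment alias. A simplified sketch of the resulting control flow (illustrative only; the real functions take many more parameters than shown here):

from litellm import completion_cost
from litellm.types.utils import ImageResponse

def _image_response_cost_sketch(response_object, model: str, base_model=None):
    # New behavior: for image responses, prefer the base_model pricing entry.
    if isinstance(response_object, ImageResponse):
        if base_model is not None:
            model = base_model  # e.g. "dall-e-3-test" deployment -> "dall-e-3"
        return completion_cost(completion_response=response_object, model=model)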
tests/image_gen_tests/base_image_generation_test.py (new file, 87 additions)
@@ -0,0 +1,87 @@
+import asyncio
+import httpx
+import json
+import pytest
+import sys
+from typing import Any, Dict, List, Optional
+from unittest.mock import MagicMock, Mock, patch
+import os
+
+sys.path.insert(
+    0, os.path.abspath("../..")
+)  # Adds the parent directory to the system path
+import litellm
+from litellm.exceptions import BadRequestError
+from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
+from litellm.utils import CustomStreamWrapper
+from openai.types.image import Image
+from litellm.integrations.custom_logger import CustomLogger
+from litellm.types.utils import StandardLoggingPayload
+
+
+class TestCustomLogger(CustomLogger):
+    def __init__(self):
+        super().__init__()
+        self.standard_logging_payload: Optional[StandardLoggingPayload] = None
+
+    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
+        self.standard_logging_payload = kwargs.get("standard_logging_object")
+        pass
+
+
+# test_example.py
+from abc import ABC, abstractmethod
+
+
+class BaseImageGenTest(ABC):
+    """
+    Abstract base test class that enforces a common test across all test classes.
+    """
+
+    @abstractmethod
+    def get_base_image_generation_call_args(self) -> dict:
+        """Must return the base image generation call args"""
+        pass
+
+    @pytest.mark.asyncio(scope="module")
+    async def test_basic_image_generation(self):
+        """Test basic image generation"""
+        try:
+            custom_logger = TestCustomLogger()
+            litellm.callbacks = [custom_logger]
+            base_image_generation_call_args = self.get_base_image_generation_call_args()
+            litellm.set_verbose = True
+            response = await litellm.aimage_generation(
+                **base_image_generation_call_args, prompt="A image of a otter"
+            )
+            print(response)
+
+            await asyncio.sleep(1)
+
+            assert response._hidden_params["response_cost"] is not None
+            assert response._hidden_params["response_cost"] > 0
+            print("response_cost", response._hidden_params["response_cost"])
+
+            logged_standard_logging_payload = custom_logger.standard_logging_payload
+            print("logged_standard_logging_payload", logged_standard_logging_payload)
+            assert logged_standard_logging_payload is not None
+            assert logged_standard_logging_payload["response_cost"] is not None
+            assert logged_standard_logging_payload["response_cost"] > 0
+
+            from openai.types.images_response import ImagesResponse
+
+            ImagesResponse.model_validate(response.model_dump())
+
+            for d in response.data:
+                assert isinstance(d, Image)
+                print("data in response.data", d)
+                assert d.b64_json is not None or d.url is not None
+        except litellm.RateLimitError as e:
+            pass
+        except litellm.ContentPolicyViolationError:
+            pass  # Azure randomly raises these errors - skip when they occur
+        except Exception as e:
+            if "Your task failed as a result of our safety system." in str(e):
+                pass
+            else:
+                pytest.fail(f"An exception occurred - {str(e)}")
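With the base class in place, each provider's test reduces to supplying call args; test_basic_image_generation (including the cost assertions) is inherited. A hypothetical subclass for illustration ("my_provider/my-image-model" is a placeholder; the concrete subclasses appear in the diff below):

from base_image_generation_test import BaseImageGenTest


class TestMyProviderImageGen(BaseImageGenTest):
    def get_base_image_generation_call_args(self) -> dict:
        # Placeholder args; everything here is splatted into litellm.aimage_generation().
        return {"model": "my_provider/my-image-model"}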
@@ -22,6 +22,11 @@ import pytest
 import litellm
 import json
 import tempfile
+from base_image_generation_test import BaseImageGenTest
+import logging
+from litellm._logging import verbose_logger
+
+verbose_logger.setLevel(logging.DEBUG)
 
 
 def get_vertex_ai_creds_json() -> dict:
@@ -97,67 +102,49 @@ def load_vertex_ai_credentials():
     os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = os.path.abspath(temp_file.name)
 
 
-def test_image_generation_openai():
-    try:
+class TestVertexImageGeneration(BaseImageGenTest):
+    def get_base_image_generation_call_args(self) -> dict:
+        # comment this when running locally
+        load_vertex_ai_credentials()
+
+        litellm.in_memory_llm_clients_cache = {}
+        return {
+            "model": "vertex_ai/imagegeneration@006",
+            "vertex_ai_project": "adroit-crow-413218",
+            "vertex_ai_location": "us-central1",
+            "n": 1,
+        }
+
+
+class TestBedrockSd3(BaseImageGenTest):
+    def get_base_image_generation_call_args(self) -> dict:
+        litellm.in_memory_llm_clients_cache = {}
+        return {"model": "bedrock/stability.sd3-large-v1:0"}
+
+
+class TestBedrockSd1(BaseImageGenTest):
+    def get_base_image_generation_call_args(self) -> dict:
+        litellm.in_memory_llm_clients_cache = {}
+        return {"model": "bedrock/stability.sd3-large-v1:0"}
+
+
+class TestOpenAIDalle3(BaseImageGenTest):
+    def get_base_image_generation_call_args(self) -> dict:
+        return {"model": "dall-e-3"}
+
+
+class TestAzureOpenAIDalle3(BaseImageGenTest):
+    def get_base_image_generation_call_args(self) -> dict:
         litellm.set_verbose = True
-        response = litellm.image_generation(
-            prompt="A cute baby sea otter", model="dall-e-3"
-        )
-        print(f"response: {response}")
-        assert len(response.data) > 0
-    except litellm.RateLimitError as e:
-        pass
-    except litellm.ContentPolicyViolationError:
-        pass  # OpenAI randomly raises these errors - skip when they occur
-    except Exception as e:
-        if "Connection error" in str(e):
-            pass
-        pytest.fail(f"An exception occurred - {str(e)}")
-
-
-# test_image_generation_openai()
-
-
-@pytest.mark.parametrize(
-    "sync_mode",
-    [
-        True,
-    ],  # False
-)  #
-@pytest.mark.asyncio
-@pytest.mark.flaky(retries=3, delay=1)
-async def test_image_generation_azure(sync_mode):
-    try:
-        if sync_mode:
-            response = litellm.image_generation(
-                prompt="A cute baby sea otter",
-                model="azure/",
-                api_version="2023-06-01-preview",
-            )
-        else:
-            response = await litellm.aimage_generation(
-                prompt="A cute baby sea otter",
-                model="azure/",
-                api_version="2023-06-01-preview",
-            )
-        print(f"response: {response}")
-        assert len(response.data) > 0
-    except litellm.RateLimitError as e:
-        pass
-    except litellm.ContentPolicyViolationError:
-        pass  # Azure randomly raises these errors - skip when they occur
-    except litellm.InternalServerError:
-        pass
-    except Exception as e:
-        if "Your task failed as a result of our safety system." in str(e):
-            pass
-        if "Connection error" in str(e):
-            pass
-        else:
-            pytest.fail(f"An exception occurred - {str(e)}")
-
-
-# test_image_generation_azure()
+        return {
+            "model": "azure/dall-e-3-test",
+            "api_version": "2023-09-01-preview",
+            "metadata": {
+                "model_info": {
+                    "base_model": "dall-e-3",
+                }
+            },
+        }
 
 
 @pytest.mark.flaky(retries=3, delay=1)
@@ -188,91 +175,13 @@ def test_image_generation_azure_dall_e_3():
             pytest.fail(f"An exception occurred - {str(e)}")
 
 
-# test_image_generation_azure_dall_e_3()
-@pytest.mark.asyncio
-async def test_async_image_generation_openai():
-    try:
-        response = litellm.image_generation(
-            prompt="A cute baby sea otter", model="dall-e-3"
-        )
-        print(f"response: {response}")
-        assert len(response.data) > 0
-    except litellm.APIError:
-        pass
-    except litellm.RateLimitError as e:
-        pass
-    except litellm.ContentPolicyViolationError:
-        pass  # openai randomly raises these errors - skip when they occur
-    except litellm.InternalServerError:
-        pass
-    except Exception as e:
-        if "Connection error" in str(e):
-            pass
-        pytest.fail(f"An exception occurred - {str(e)}")
-
-
 # asyncio.run(test_async_image_generation_openai())
 
 
-@pytest.mark.asyncio
-async def test_async_image_generation_azure():
-    try:
-        response = await litellm.aimage_generation(
-            prompt="A cute baby sea otter",
-            model="azure/dall-e-3-test",
-            api_version="2023-09-01-preview",
-        )
-        print(f"response: {response}")
-    except litellm.RateLimitError as e:
-        pass
-    except litellm.ContentPolicyViolationError:
-        pass  # Azure randomly raises these errors - skip when they occur
-    except litellm.InternalServerError:
-        pass
-    except Exception as e:
-        if "Your task failed as a result of our safety system." in str(e):
-            pass
-        if "Connection error" in str(e):
-            pass
-        else:
-            pytest.fail(f"An exception occurred - {str(e)}")
-
-
-@pytest.mark.asyncio
-@pytest.mark.parametrize(
-    "model",
-    ["bedrock/stability.sd3-large-v1:0", "bedrock/stability.stable-diffusion-xl-v1"],
-)
-def test_image_generation_bedrock(model):
-    try:
-        litellm.set_verbose = True
-        response = litellm.image_generation(
-            prompt="A cute baby sea otter",
-            model=model,
-            aws_region_name="us-west-2",
-        )
-
-        print(f"response: {response}")
-        print("response hidden params", response._hidden_params)
-
-        assert response._hidden_params["response_cost"] is not None
-        from openai.types.images_response import ImagesResponse
-
-        ImagesResponse.model_validate(response.model_dump())
-    except litellm.RateLimitError as e:
-        pass
-    except litellm.ContentPolicyViolationError:
-        pass  # Azure randomly raises these errors - skip when they occur
-    except Exception as e:
-        if "Your task failed as a result of our safety system." in str(e):
-            pass
-        else:
-            pytest.fail(f"An exception occurred - {str(e)}")
-
-
 @pytest.mark.asyncio
 async def test_aimage_generation_bedrock_with_optional_params():
     try:
+        litellm.in_memory_llm_clients_cache = {}
         response = await litellm.aimage_generation(
             prompt="A cute baby sea otter",
             model="bedrock/stability.stable-diffusion-xl-v1",
@@ -288,47 +197,3 @@ async def test_aimage_generation_bedrock_with_optional_params():
             pass
         else:
             pytest.fail(f"An exception occurred - {str(e)}")
-
-
-from openai.types.image import Image
-
-
-@pytest.mark.parametrize("sync_mode", [True, False])
-@pytest.mark.asyncio
-async def test_aimage_generation_vertex_ai(sync_mode):
-
-    litellm.set_verbose = True
-
-    load_vertex_ai_credentials()
-    data = {
-        "prompt": "An olympic size swimming pool",
-        "model": "vertex_ai/imagegeneration@006",
-        "vertex_ai_project": "adroit-crow-413218",
-        "vertex_ai_location": "us-central1",
-        "n": 1,
-    }
-    try:
-        if sync_mode:
-            response = litellm.image_generation(**data)
-        else:
-            response = await litellm.aimage_generation(**data)
-        assert response.data is not None
-        assert len(response.data) > 0
-
-        for d in response.data:
-            assert isinstance(d, Image)
-            print("data in response.data", d)
-            assert d.b64_json is not None
-    except litellm.ServiceUnavailableError as e:
-        pass
-    except litellm.RateLimitError as e:
-        pass
-    except litellm.InternalServerError as e:
-        pass
-    except litellm.ContentPolicyViolationError:
-        pass  # Azure randomly raises these errors - skip when they occur
-    except Exception as e:
-        if "Your task failed as a result of our safety system." in str(e):
-            pass
-        else:
-            pytest.fail(f"An exception occurred - {str(e)}")