(feat) Add cost tracking for Azure Dall-e-3 Image Generation + use base class to ensure basic image generation tests pass (#6716)

* add BaseImageGenTest

* use 1 class for unit testing

* add debugging to BaseImageGenTest

* TestAzureOpenAIDalle3

* fix response_cost_calculator

* test_basic_image_generation

* fix img gen basic test

* fix _select_model_name_for_cost_calc

* fix test_aimage_generation_bedrock_with_optional_params

* fix undo changes cost tracking

* fix response_cost_calculator

* fix test_cost_azure_gpt_35
This commit is contained in:
Ishaan Jaff 2024-11-12 20:02:16 -08:00 committed by GitHub
parent 6d4cf2d908
commit 73c7b73aa0
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 139 additions and 186 deletions

View file

@ -22,6 +22,11 @@ import pytest
import litellm
import json
import tempfile
from base_image_generation_test import BaseImageGenTest
import logging
from litellm._logging import verbose_logger
verbose_logger.setLevel(logging.DEBUG)
def get_vertex_ai_creds_json() -> dict:
@ -97,67 +102,49 @@ def load_vertex_ai_credentials():
os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = os.path.abspath(temp_file.name)
def test_image_generation_openai():
try:
class TestVertexImageGeneration(BaseImageGenTest):
    def get_base_image_generation_call_args(self) -> dict:
        """Return the kwargs for a Vertex AI Imagen image-generation call."""
        # comment this when running locally
        load_vertex_ai_credentials()
        # reset the client cache so this test builds a fresh client
        litellm.in_memory_llm_clients_cache = {}
        return {
            "model": "vertex_ai/imagegeneration@006",
            "vertex_ai_project": "adroit-crow-413218",
            "vertex_ai_location": "us-central1",
            "n": 1,
        }
class TestBedrockSd3(BaseImageGenTest):
    def get_base_image_generation_call_args(self) -> dict:
        """Return the kwargs for a Bedrock Stable Diffusion 3 image-generation call."""
        # reset the client cache so this test does not reuse another test's client
        litellm.in_memory_llm_clients_cache = {}
        return {"model": "bedrock/stability.sd3-large-v1:0"}
class TestBedrockSd1(BaseImageGenTest):
    def get_base_image_generation_call_args(self) -> dict:
        """Return the kwargs for a Bedrock Stable Diffusion XL v1 image-generation call.

        BUGFIX: this class previously returned the SD3 model id
        ("bedrock/stability.sd3-large-v1:0"), making it an exact copy-paste
        duplicate of TestBedrockSd3 and leaving the SDXL v1 model untested.
        The intended model id matches the one exercised elsewhere in this file
        (see test_image_generation_bedrock's parametrize list).
        """
        # reset the client cache so this test does not reuse another test's client
        litellm.in_memory_llm_clients_cache = {}
        return {"model": "bedrock/stability.stable-diffusion-xl-v1"}
class TestOpenAIDalle3(BaseImageGenTest):
    def get_base_image_generation_call_args(self) -> dict:
        """Return the kwargs for an OpenAI DALL-E 3 image-generation call."""
        return {"model": "dall-e-3"}
class TestAzureOpenAIDalle3(BaseImageGenTest):
    def get_base_image_generation_call_args(self) -> dict:
        """Return the kwargs for an Azure OpenAI DALL-E 3 image-generation call.

        NOTE(review): the bodies of two deleted tests
        (test_image_generation_openai and test_image_generation_azure) were
        interleaved into the middle of this class by a bad diff apply,
        splitting this method in half. The stray residue has been removed so
        the class parses again; only the class header and its return dict are
        the real content.
        """
        litellm.set_verbose = True
        return {
            "model": "azure/dall-e-3-test",
            "api_version": "2023-09-01-preview",
            # base_model lets the cost calculator map this Azure deployment
            # name onto dall-e-3 pricing
            "metadata": {
                "model_info": {
                    "base_model": "dall-e-3",
                }
            },
        }
@pytest.mark.flaky(retries=3, delay=1)
@ -188,91 +175,13 @@ def test_image_generation_azure_dall_e_3():
pytest.fail(f"An exception occurred - {str(e)}")
# test_image_generation_azure_dall_e_3()
@pytest.mark.asyncio
async def test_async_image_generation_openai():
    """Async smoke test: OpenAI DALL-E 3 returns at least one image.

    Transient provider failures (API errors, rate limits, content-policy
    flags, internal server errors, connection errors) are tolerated; any
    other exception fails the test.
    """
    try:
        # BUGFIX: this async test previously called the *sync*
        # litellm.image_generation; use the async entry point and await it.
        response = await litellm.aimage_generation(
            prompt="A cute baby sea otter", model="dall-e-3"
        )
        print(f"response: {response}")
        assert len(response.data) > 0
    except litellm.APIError:
        pass
    except litellm.RateLimitError:
        pass
    except litellm.ContentPolicyViolationError:
        pass  # openai randomly raises these errors - skip when they occur
    except litellm.InternalServerError:
        pass
    except Exception as e:
        # BUGFIX: the original fell through to pytest.fail even when the
        # "Connection error" guard matched, because it lacked an else branch.
        if "Connection error" in str(e):
            pass
        else:
            pytest.fail(f"An exception occurred - {str(e)}")
# asyncio.run(test_async_image_generation_openai())
@pytest.mark.asyncio
async def test_async_image_generation_azure():
    """Async smoke test: the Azure DALL-E 3 deployment returns a response.

    Transient Azure failures (rate limits, content-policy flags, internal
    server errors, safety-system rejections, connection errors) are
    tolerated; any other exception fails the test.
    """
    try:
        response = await litellm.aimage_generation(
            prompt="A cute baby sea otter",
            model="azure/dall-e-3-test",
            api_version="2023-09-01-preview",
        )
        print(f"response: {response}")
    except litellm.RateLimitError:
        pass
    except litellm.ContentPolicyViolationError:
        pass  # Azure randomly raises these errors - skip when they occur
    except litellm.InternalServerError:
        pass
    except Exception as e:
        # BUGFIX: the original used a bare `if ...: pass` for the
        # safety-system message followed by a second if/else, so a
        # safety-system error that was not also a connection error still
        # reached pytest.fail. Chain the guards with elif instead.
        if "Your task failed as a result of our safety system." in str(e):
            pass
        elif "Connection error" in str(e):
            pass
        else:
            pytest.fail(f"An exception occurred - {str(e)}")
# BUGFIX: removed a stray @pytest.mark.asyncio decorator — this test is a
# plain sync function, and pytest-asyncio's marker on a non-coroutine is a
# misconfiguration (warned/rejected by pytest-asyncio).
@pytest.mark.parametrize(
    "model",
    ["bedrock/stability.sd3-large-v1:0", "bedrock/stability.stable-diffusion-xl-v1"],
)
def test_image_generation_bedrock(model):
    """Sync smoke test: Bedrock image generation returns a cost-tracked,
    OpenAI-schema-valid response for both Stability models.
    """
    try:
        litellm.set_verbose = True
        response = litellm.image_generation(
            prompt="A cute baby sea otter",
            model=model,
            aws_region_name="us-west-2",
        )
        print(f"response: {response}")
        print("response hidden params", response._hidden_params)
        # cost tracking must populate response_cost for image generation
        assert response._hidden_params["response_cost"] is not None

        from openai.types.images_response import ImagesResponse

        # the response must round-trip through the OpenAI response schema
        ImagesResponse.model_validate(response.model_dump())
    except litellm.RateLimitError:
        pass
    except litellm.ContentPolicyViolationError:
        pass  # Azure randomly raises these errors - skip when they occur
    except Exception as e:
        if "Your task failed as a result of our safety system." in str(e):
            pass
        else:
            pytest.fail(f"An exception occurred - {str(e)}")
@pytest.mark.asyncio
async def test_aimage_generation_bedrock_with_optional_params():
try:
litellm.in_memory_llm_clients_cache = {}
response = await litellm.aimage_generation(
prompt="A cute baby sea otter",
model="bedrock/stability.stable-diffusion-xl-v1",
@ -288,47 +197,3 @@ async def test_aimage_generation_bedrock_with_optional_params():
pass
else:
pytest.fail(f"An exception occurred - {str(e)}")
from openai.types.image import Image
@pytest.mark.parametrize("sync_mode", [True, False])
@pytest.mark.asyncio
async def test_aimage_generation_vertex_ai(sync_mode):
    """Vertex AI Imagen smoke test, exercised through both the sync and async
    entry points; every returned item must be an Image with a b64 payload.

    Transient provider failures (service unavailable, rate limits, internal
    server errors, content-policy flags, safety-system rejections) are
    tolerated; any other exception fails the test.
    """
    litellm.set_verbose = True
    load_vertex_ai_credentials()
    call_kwargs = {
        "prompt": "An olympic size swimming pool",
        "model": "vertex_ai/imagegeneration@006",
        "vertex_ai_project": "adroit-crow-413218",
        "vertex_ai_location": "us-central1",
        "n": 1,
    }
    try:
        response = (
            litellm.image_generation(**call_kwargs)
            if sync_mode
            else await litellm.aimage_generation(**call_kwargs)
        )
        assert response.data is not None
        assert len(response.data) > 0
        for item in response.data:
            assert isinstance(item, Image)
            print("data in response.data", item)
            assert item.b64_json is not None
    except litellm.ServiceUnavailableError:
        pass
    except litellm.RateLimitError:
        pass
    except litellm.InternalServerError:
        pass
    except litellm.ContentPolicyViolationError:
        pass  # Azure randomly raises these errors - skip when they occur
    except Exception as e:
        if "Your task failed as a result of our safety system." in str(e):
            pass
        else:
            pytest.fail(f"An exception occurred - {str(e)}")