Compare commits

12 commits

Author       SHA1        Message                                                   Date
Ishaan Jaff  855761d041  fix test_cost_azure_gpt_35                                2024-11-12 17:47:31 -08:00
Ishaan Jaff  96ec2a5a19  fix response_cost_calculator                              2024-11-12 17:46:27 -08:00
Ishaan Jaff  0a73fa6d01  fix undo changes cost tracking                            2024-11-12 17:45:12 -08:00
Ishaan Jaff  84603763cc  fix test_aimage_generation_bedrock_with_optional_params  2024-11-12 16:58:07 -08:00
Ishaan Jaff  501ced09ba  fix _select_model_name_for_cost_calc                      2024-11-12 16:53:36 -08:00
Ishaan Jaff  e0f1f18339  fix img gen basic test                                    2024-11-12 16:46:21 -08:00
Ishaan Jaff  d2a1ac804b  test_basic_image_generation                               2024-11-12 16:34:47 -08:00
Ishaan Jaff  1c2f7f792b  fix response_cost_calculator                              2024-11-12 16:19:43 -08:00
Ishaan Jaff  2415e7d0a6  TestAzureOpenAIDalle3                                     2024-11-12 16:18:49 -08:00
Ishaan Jaff  a6ce4254f0  add debugging to BaseImageGenTest                         2024-11-12 16:18:22 -08:00
Ishaan Jaff  c15359911a  use 1 class for unit testing                              2024-11-12 15:48:05 -08:00
Ishaan Jaff  26c19ba3e1  add BaseImageGenTest                                      2024-11-12 15:47:27 -08:00
3 changed files with 139 additions and 186 deletions

litellm/cost_calculator.py

@@ -171,7 +171,6 @@ def cost_per_token(  # noqa: PLR0915
         model_with_provider = model_with_provider_and_region
     else:
         _, custom_llm_provider, _, _ = litellm.get_llm_provider(model=model)
-
     model_without_prefix = model
     model_parts = model.split("/", 1)
     if len(model_parts) > 1:
@@ -454,7 +453,6 @@ def _select_model_name_for_cost_calc(
     if base_model is not None:
         return base_model
-
     return_model = model
     if isinstance(completion_response, str):
         return return_model
@@ -620,7 +618,8 @@ def completion_cost(  # noqa: PLR0915
            f"completion_response response ms: {getattr(completion_response, '_response_ms', None)} "
        )
        model = _select_model_name_for_cost_calc(
-            model=model, completion_response=completion_response
+            model=model,
+            completion_response=completion_response,
        )
        hidden_params = getattr(completion_response, "_hidden_params", None)
        if hidden_params is not None:
@@ -853,6 +852,8 @@ def response_cost_calculator(
         if isinstance(response_object, BaseModel):
             response_object._hidden_params["optional_params"] = optional_params
         if isinstance(response_object, ImageResponse):
+            if base_model is not None:
+                model = base_model
             response_cost = completion_cost(
                 completion_response=response_object,
                 model=model,
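
Note: the two added lines above are the heart of the cost fix: for image generation, an explicit base_model now overrides the deployment alias before pricing. A minimal, self-contained sketch of that precedence rule (the pricing table and function below are illustrative stand-ins, not litellm's actual internals):

from typing import Optional

# Hypothetical per-image pricing table; litellm's real table lives elsewhere.
COST_PER_IMAGE = {
    "dall-e-3": 0.040,
    "vertex_ai/imagegeneration@006": 0.020,
}


def image_generation_cost(model: str, base_model: Optional[str] = None, n: int = 1) -> float:
    # Prefer the explicit base model (e.g. "dall-e-3") over an Azure
    # deployment alias (e.g. "azure/dall-e-3-test") when looking up cost.
    cost_model = base_model if base_model is not None else model
    if cost_model not in COST_PER_IMAGE:
        raise ValueError(f"no pricing entry for {cost_model}")
    return COST_PER_IMAGE[cost_model] * n


# The alias alone has no pricing entry; the base_model override resolves it.
assert image_generation_cost("azure/dall-e-3-test", base_model="dall-e-3") == 0.040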

base_image_generation_test.py (new file)

@@ -0,0 +1,87 @@
+import asyncio
+import httpx
+import json
+import pytest
+import sys
+from typing import Any, Dict, List, Optional
+from unittest.mock import MagicMock, Mock, patch
+import os
+
+sys.path.insert(
+    0, os.path.abspath("../..")
+)  # Adds the parent directory to the system path
+import litellm
+from litellm.exceptions import BadRequestError
+from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler
+from litellm.utils import CustomStreamWrapper
+from openai.types.image import Image
+from litellm.integrations.custom_logger import CustomLogger
+from litellm.types.utils import StandardLoggingPayload
+
+
+class TestCustomLogger(CustomLogger):
+    def __init__(self):
+        super().__init__()
+        self.standard_logging_payload: Optional[StandardLoggingPayload] = None
+
+    async def async_log_success_event(self, kwargs, response_obj, start_time, end_time):
+        self.standard_logging_payload = kwargs.get("standard_logging_object")
+
+
+from abc import ABC, abstractmethod
+
+
+class BaseImageGenTest(ABC):
+    """
+    Abstract base test class that enforces a common test across all test classes.
+    """
+
+    @abstractmethod
+    def get_base_image_generation_call_args(self) -> dict:
+        """Must return the base image generation call args"""
+        pass
+
+    @pytest.mark.asyncio(scope="module")
+    async def test_basic_image_generation(self):
+        """Test basic image generation"""
+        try:
+            custom_logger = TestCustomLogger()
+            litellm.callbacks = [custom_logger]
+            base_image_generation_call_args = self.get_base_image_generation_call_args()
+            litellm.set_verbose = True
+            response = await litellm.aimage_generation(
+                **base_image_generation_call_args, prompt="An image of an otter"
+            )
+            print(response)
+            await asyncio.sleep(1)
+
+            assert response._hidden_params["response_cost"] is not None
+            assert response._hidden_params["response_cost"] > 0
+            print("response_cost", response._hidden_params["response_cost"])
+
+            logged_standard_logging_payload = custom_logger.standard_logging_payload
+            print("logged_standard_logging_payload", logged_standard_logging_payload)
+            assert logged_standard_logging_payload is not None
+            assert logged_standard_logging_payload["response_cost"] is not None
+            assert logged_standard_logging_payload["response_cost"] > 0
+
+            from openai.types.images_response import ImagesResponse
+
+            ImagesResponse.model_validate(response.model_dump())
+
+            for d in response.data:
+                assert isinstance(d, Image)
+                print("data in response.data", d)
+                assert d.b64_json is not None or d.url is not None
+        except litellm.RateLimitError:
+            pass
+        except litellm.ContentPolicyViolationError:
+            pass  # Azure randomly raises these errors - skip when they occur
+        except Exception as e:
+            if "Your task failed as a result of our safety system." in str(e):
+                pass
+            else:
+                pytest.fail(f"An exception occurred - {str(e)}")
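
For reference, a hypothetical subclass showing how BaseImageGenTest is meant to be consumed: a provider-specific class supplies only its call args and inherits test_basic_image_generation from the base. The class name below is a placeholder, and the snippet assumes it lives next to base_image_generation_test.py:

from base_image_generation_test import BaseImageGenTest


class TestMyProviderImageGen(BaseImageGenTest):
    def get_base_image_generation_call_args(self) -> dict:
        # Any litellm-supported image model works; "dall-e-3" is just an example.
        return {"model": "dall-e-3"}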

test_image_generation.py

@@ -22,6 +22,11 @@ import pytest
 import litellm
 import json
 import tempfile
+from base_image_generation_test import BaseImageGenTest
+import logging
+from litellm._logging import verbose_logger
+
+verbose_logger.setLevel(logging.DEBUG)
 
 
 def get_vertex_ai_creds_json() -> dict:
@@ -97,67 +102,49 @@ def load_vertex_ai_credentials():
     os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = os.path.abspath(temp_file.name)
 
 
-def test_image_generation_openai():
-    try:
-        litellm.set_verbose = True
-        response = litellm.image_generation(
-            prompt="A cute baby sea otter", model="dall-e-3"
-        )
-        print(f"response: {response}")
-        assert len(response.data) > 0
-    except litellm.RateLimitError as e:
-        pass
-    except litellm.ContentPolicyViolationError:
-        pass  # OpenAI randomly raises these errors - skip when they occur
-    except Exception as e:
-        if "Connection error" in str(e):
-            pass
-        pytest.fail(f"An exception occurred - {str(e)}")
-
-
-# test_image_generation_openai()
-
-
-@pytest.mark.parametrize(
-    "sync_mode",
-    [
-        True,
-    ],  # False
-)  #
-@pytest.mark.asyncio
-@pytest.mark.flaky(retries=3, delay=1)
-async def test_image_generation_azure(sync_mode):
-    try:
-        if sync_mode:
-            response = litellm.image_generation(
-                prompt="A cute baby sea otter",
-                model="azure/",
-                api_version="2023-06-01-preview",
-            )
-        else:
-            response = await litellm.aimage_generation(
-                prompt="A cute baby sea otter",
-                model="azure/",
-                api_version="2023-06-01-preview",
-            )
-        print(f"response: {response}")
-        assert len(response.data) > 0
-    except litellm.RateLimitError as e:
-        pass
-    except litellm.ContentPolicyViolationError:
-        pass  # Azure randomly raises these errors - skip when they occur
-    except litellm.InternalServerError:
-        pass
-    except Exception as e:
-        if "Your task failed as a result of our safety system." in str(e):
-            pass
-        if "Connection error" in str(e):
-            pass
-        else:
-            pytest.fail(f"An exception occurred - {str(e)}")
-
-
-# test_image_generation_azure()
+class TestVertexImageGeneration(BaseImageGenTest):
+    def get_base_image_generation_call_args(self) -> dict:
+        # comment this out when running locally
+        load_vertex_ai_credentials()
+        litellm.in_memory_llm_clients_cache = {}
+        return {
+            "model": "vertex_ai/imagegeneration@006",
+            "vertex_ai_project": "adroit-crow-413218",
+            "vertex_ai_location": "us-central1",
+            "n": 1,
+        }
+
+
+class TestBedrockSd3(BaseImageGenTest):
+    def get_base_image_generation_call_args(self) -> dict:
+        litellm.in_memory_llm_clients_cache = {}
+        return {"model": "bedrock/stability.sd3-large-v1:0"}
+
+
+class TestBedrockSd1(BaseImageGenTest):
+    def get_base_image_generation_call_args(self) -> dict:
+        litellm.in_memory_llm_clients_cache = {}
+        return {"model": "bedrock/stability.sd3-large-v1:0"}
+
+
+class TestOpenAIDalle3(BaseImageGenTest):
+    def get_base_image_generation_call_args(self) -> dict:
+        return {"model": "dall-e-3"}
+
+
+class TestAzureOpenAIDalle3(BaseImageGenTest):
+    def get_base_image_generation_call_args(self) -> dict:
+        return {
+            "model": "azure/dall-e-3-test",
+            "api_version": "2023-09-01-preview",
+            "metadata": {
+                "model_info": {
+                    "base_model": "dall-e-3",
+                }
+            },
+        }
 
 
 @pytest.mark.flaky(retries=3, delay=1)
@@ -188,91 +175,13 @@ def test_image_generation_azure_dall_e_3():
            pytest.fail(f"An exception occurred - {str(e)}")
 
 
 # test_image_generation_azure_dall_e_3()
 
 
-@pytest.mark.asyncio
-async def test_async_image_generation_openai():
-    try:
-        response = litellm.image_generation(
-            prompt="A cute baby sea otter", model="dall-e-3"
-        )
-        print(f"response: {response}")
-        assert len(response.data) > 0
-    except litellm.APIError:
-        pass
-    except litellm.RateLimitError as e:
-        pass
-    except litellm.ContentPolicyViolationError:
-        pass  # openai randomly raises these errors - skip when they occur
-    except litellm.InternalServerError:
-        pass
-    except Exception as e:
-        if "Connection error" in str(e):
-            pass
-        pytest.fail(f"An exception occurred - {str(e)}")
-
-
-# asyncio.run(test_async_image_generation_openai())
-
-
-@pytest.mark.asyncio
-async def test_async_image_generation_azure():
-    try:
-        response = await litellm.aimage_generation(
-            prompt="A cute baby sea otter",
-            model="azure/dall-e-3-test",
-            api_version="2023-09-01-preview",
-        )
-        print(f"response: {response}")
-    except litellm.RateLimitError as e:
-        pass
-    except litellm.ContentPolicyViolationError:
-        pass  # Azure randomly raises these errors - skip when they occur
-    except litellm.InternalServerError:
-        pass
-    except Exception as e:
-        if "Your task failed as a result of our safety system." in str(e):
-            pass
-        if "Connection error" in str(e):
-            pass
-        else:
-            pytest.fail(f"An exception occurred - {str(e)}")
-
-
-@pytest.mark.asyncio
-@pytest.mark.parametrize(
-    "model",
-    ["bedrock/stability.sd3-large-v1:0", "bedrock/stability.stable-diffusion-xl-v1"],
-)
-def test_image_generation_bedrock(model):
-    try:
-        litellm.set_verbose = True
-        response = litellm.image_generation(
-            prompt="A cute baby sea otter",
-            model=model,
-            aws_region_name="us-west-2",
-        )
-        print(f"response: {response}")
-        print("response hidden params", response._hidden_params)
-
-        assert response._hidden_params["response_cost"] is not None
-
-        from openai.types.images_response import ImagesResponse
-
-        ImagesResponse.model_validate(response.model_dump())
-    except litellm.RateLimitError as e:
-        pass
-    except litellm.ContentPolicyViolationError:
-        pass  # Azure randomly raises these errors - skip when they occur
-    except Exception as e:
-        if "Your task failed as a result of our safety system." in str(e):
-            pass
-        else:
-            pytest.fail(f"An exception occurred - {str(e)}")
-
-
 @pytest.mark.asyncio
 async def test_aimage_generation_bedrock_with_optional_params():
     try:
+        litellm.in_memory_llm_clients_cache = {}
         response = await litellm.aimage_generation(
             prompt="A cute baby sea otter",
             model="bedrock/stability.stable-diffusion-xl-v1",
@@ -288,47 +197,3 @@ async def test_aimage_generation_bedrock_with_optional_params():
            pass
        else:
            pytest.fail(f"An exception occurred - {str(e)}")
-
-
-from openai.types.image import Image
-
-
-@pytest.mark.parametrize("sync_mode", [True, False])
-@pytest.mark.asyncio
-async def test_aimage_generation_vertex_ai(sync_mode):
-    litellm.set_verbose = True
-    load_vertex_ai_credentials()
-    data = {
-        "prompt": "An olympic size swimming pool",
-        "model": "vertex_ai/imagegeneration@006",
-        "vertex_ai_project": "adroit-crow-413218",
-        "vertex_ai_location": "us-central1",
-        "n": 1,
-    }
-    try:
-        if sync_mode:
-            response = litellm.image_generation(**data)
-        else:
-            response = await litellm.aimage_generation(**data)
-        assert response.data is not None
-        assert len(response.data) > 0
-
-        for d in response.data:
-            assert isinstance(d, Image)
-            print("data in response.data", d)
-            assert d.b64_json is not None
-    except litellm.ServiceUnavailableError as e:
-        pass
-    except litellm.RateLimitError as e:
-        pass
-    except litellm.InternalServerError as e:
-        pass
-    except litellm.ContentPolicyViolationError:
-        pass  # Azure randomly raises these errors - skip when they occur
-    except Exception as e:
-        if "Your task failed as a result of our safety system." in str(e):
-            pass
-        else:
-            pytest.fail(f"An exception occurred - {str(e)}")
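
Taken together with the cost_calculator.py hunks, the TestAzureOpenAIDalle3 args above show the intended flow: the Azure deployment alias is what the API call uses, while metadata.model_info.base_model supplies the pricing key. A standalone sketch of that call, assuming AZURE_API_KEY and AZURE_API_BASE are set in the environment and the deployment name (taken from the test) exists:

import asyncio

import litellm


async def main():
    response = await litellm.aimage_generation(
        model="azure/dall-e-3-test",  # deployment alias used for the API call
        api_version="2023-09-01-preview",
        metadata={"model_info": {"base_model": "dall-e-3"}},  # pricing key
        prompt="An image of an otter",
    )
    # With the fix, cost is computed against "dall-e-3" pricing.
    print(response._hidden_params["response_cost"])


asyncio.run(main())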