(feat) add cost tracking stable diffusion 3 on Bedrock (#6676)

* add cost tracking for sd3

* test_image_generation_bedrock

* fix get model info for image cost

* add cost_calculator for stability 1 models

* add unit testing for bedrock image cost calc

* test_cost_calculator_with_no_optional_params

* add test_cost_calculator_basic

* correctly allow size Optional

* fix cost_calculator

* sd3 unit tests cost calc
This commit is contained in:
Ishaan Jaff 2024-11-11 20:21:44 -08:00 committed by GitHub
parent e5051a93a8
commit 25bae4cc23
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 146 additions and 8 deletions

View file

@ -9,12 +9,14 @@ from openai.types.image import Image
logging.basicConfig(level=logging.DEBUG)
load_dotenv()
import asyncio
import os
sys.path.insert(
0, os.path.abspath("../..")
) # Adds the parent directory to the system path
import pytest
from litellm.llms.bedrock.image.cost_calculator import cost_calculator
from litellm.types.utils import ImageResponse, ImageObject
import os
import litellm
from litellm.llms.bedrock.image.amazon_stability3_transformation import (
@ -27,7 +29,6 @@ from litellm.types.llms.bedrock import (
AmazonStability3TextToImageRequest,
AmazonStability3TextToImageResponse,
)
from litellm.types.utils import ImageResponse
from unittest.mock import MagicMock, patch
from litellm.llms.bedrock.image.image_handler import (
BedrockImageGeneration,
@ -149,7 +150,7 @@ def test_get_request_body_stability():
handler = BedrockImageGeneration()
prompt = "A beautiful sunset"
optional_params = {"cfg_scale": 7}
model = "stability.stable-diffusion-xl"
model = "stability.stable-diffusion-xl-v1"
result = handler._get_request_body(
model=model, prompt=prompt, optional_params=optional_params
@ -185,3 +186,80 @@ def test_transform_response_dict_to_openai_response_stability3():
assert len(result.data) == 2
assert all(hasattr(img, "b64_json") for img in result.data)
assert [img.b64_json for img in result.data] == ["base64_image_1", "base64_image_2"]
def test_cost_calculator_stability3():
# Mock image response
image_response = ImageResponse(
data=[
ImageObject(b64_json="base64_image_1"),
ImageObject(b64_json="base64_image_2"),
]
)
cost = cost_calculator(
model="stability.sd3-large-v1:0",
size="1024-x-1024",
image_response=image_response,
)
print("cost", cost)
# Assert cost is calculated correctly for 2 images
assert isinstance(cost, float)
assert cost > 0
def test_cost_calculator_stability1():
# Mock image response
image_response = ImageResponse(data=[ImageObject(b64_json="base64_image_1")])
# Test with different step configurations
cost_default_steps = cost_calculator(
model="stability.stable-diffusion-xl-v1",
size="1024-x-1024",
image_response=image_response,
optional_params={"steps": 50},
)
cost_max_steps = cost_calculator(
model="stability.stable-diffusion-xl-v1",
size="1024-x-1024",
image_response=image_response,
optional_params={"steps": 51},
)
# Assert costs are calculated correctly
assert isinstance(cost_default_steps, float)
assert isinstance(cost_max_steps, float)
assert cost_default_steps > 0
assert cost_max_steps > 0
# Max steps should be more expensive
assert cost_max_steps > cost_default_steps
def test_cost_calculator_with_no_optional_params():
image_response = ImageResponse(data=[ImageObject(b64_json="base64_image_1")])
cost = cost_calculator(
model="stability.stable-diffusion-xl-v0",
size="512-x-512",
image_response=image_response,
optional_params=None,
)
assert isinstance(cost, float)
assert cost > 0
def test_cost_calculator_basic():
image_response = ImageResponse(data=[ImageObject(b64_json="base64_image_1")])
cost = cost_calculator(
model="stability.stable-diffusion-xl-v1",
image_response=image_response,
optional_params=None,
)
assert isinstance(cost, float)
assert cost > 0

View file

@ -253,6 +253,9 @@ def test_image_generation_bedrock(model):
)
print(f"response: {response}")
print("response hidden params", response._hidden_params)
assert response._hidden_params["response_cost"] is not None
from openai.types.images_response import ImagesResponse
ImagesResponse.model_validate(response.model_dump())