diff --git a/litellm/cost_calculator.py b/litellm/cost_calculator.py index 2aff3b04c..0aa8a8e36 100644 --- a/litellm/cost_calculator.py +++ b/litellm/cost_calculator.py @@ -171,7 +171,6 @@ def cost_per_token( # noqa: PLR0915 model_with_provider = model_with_provider_and_region else: _, custom_llm_provider, _, _ = litellm.get_llm_provider(model=model) - model_without_prefix = model model_parts = model.split("/", 1) if len(model_parts) > 1: @@ -454,7 +453,6 @@ def _select_model_name_for_cost_calc( if base_model is not None: return base_model - return_model = model if isinstance(completion_response, str): return return_model @@ -620,7 +618,8 @@ def completion_cost( # noqa: PLR0915 f"completion_response response ms: {getattr(completion_response, '_response_ms', None)} " ) model = _select_model_name_for_cost_calc( - model=model, completion_response=completion_response + model=model, + completion_response=completion_response, ) hidden_params = getattr(completion_response, "_hidden_params", None) if hidden_params is not None: @@ -853,6 +852,8 @@ def response_cost_calculator( if isinstance(response_object, BaseModel): response_object._hidden_params["optional_params"] = optional_params if isinstance(response_object, ImageResponse): + if base_model is not None: + model = base_model response_cost = completion_cost( completion_response=response_object, model=model, diff --git a/tests/image_gen_tests/base_image_generation_test.py b/tests/image_gen_tests/base_image_generation_test.py new file mode 100644 index 000000000..e0652114d --- /dev/null +++ b/tests/image_gen_tests/base_image_generation_test.py @@ -0,0 +1,87 @@ +import asyncio +import httpx +import json +import pytest +import sys +from typing import Any, Dict, List, Optional +from unittest.mock import MagicMock, Mock, patch +import os + +sys.path.insert( + 0, os.path.abspath("../..") +) # Adds the parent directory to the system path +import litellm +from litellm.exceptions import BadRequestError +from litellm.llms.custom_httpx.http_handler import AsyncHTTPHandler, HTTPHandler +from litellm.utils import CustomStreamWrapper +from openai.types.image import Image +from litellm.integrations.custom_logger import CustomLogger +from litellm.types.utils import StandardLoggingPayload + + +class TestCustomLogger(CustomLogger): + def __init__(self): + super().__init__() + self.standard_logging_payload: Optional[StandardLoggingPayload] = None + + async def async_log_success_event(self, kwargs, response_obj, start_time, end_time): + self.standard_logging_payload = kwargs.get("standard_logging_object") + pass + + +# test_example.py +from abc import ABC, abstractmethod + + +class BaseImageGenTest(ABC): + """ + Abstract base test class that enforces a common test across all test classes. + """ + + @abstractmethod + def get_base_image_generation_call_args(self) -> dict: + """Must return the base image generation call args""" + pass + + @pytest.mark.asyncio(scope="module") + async def test_basic_image_generation(self): + """Test basic image generation""" + try: + custom_logger = TestCustomLogger() + litellm.callbacks = [custom_logger] + base_image_generation_call_args = self.get_base_image_generation_call_args() + litellm.set_verbose = True + response = await litellm.aimage_generation( + **base_image_generation_call_args, prompt="A image of a otter" + ) + print(response) + + await asyncio.sleep(1) + + assert response._hidden_params["response_cost"] is not None + assert response._hidden_params["response_cost"] > 0 + print("response_cost", response._hidden_params["response_cost"]) + + logged_standard_logging_payload = custom_logger.standard_logging_payload + print("logged_standard_logging_payload", logged_standard_logging_payload) + assert logged_standard_logging_payload is not None + assert logged_standard_logging_payload["response_cost"] is not None + assert logged_standard_logging_payload["response_cost"] > 0 + + from openai.types.images_response import ImagesResponse + + ImagesResponse.model_validate(response.model_dump()) + + for d in response.data: + assert isinstance(d, Image) + print("data in response.data", d) + assert d.b64_json is not None or d.url is not None + except litellm.RateLimitError as e: + pass + except litellm.ContentPolicyViolationError: + pass # Azure randomly raises these errors - skip when they occur + except Exception as e: + if "Your task failed as a result of our safety system." in str(e): + pass + else: + pytest.fail(f"An exception occurred - {str(e)}") diff --git a/tests/image_gen_tests/test_image_generation.py b/tests/image_gen_tests/test_image_generation.py index e94d62c1f..692a0e4e9 100644 --- a/tests/image_gen_tests/test_image_generation.py +++ b/tests/image_gen_tests/test_image_generation.py @@ -22,6 +22,11 @@ import pytest import litellm import json import tempfile +from base_image_generation_test import BaseImageGenTest +import logging +from litellm._logging import verbose_logger + +verbose_logger.setLevel(logging.DEBUG) def get_vertex_ai_creds_json() -> dict: @@ -97,67 +102,49 @@ def load_vertex_ai_credentials(): os.environ["GOOGLE_APPLICATION_CREDENTIALS"] = os.path.abspath(temp_file.name) -def test_image_generation_openai(): - try: +class TestVertexImageGeneration(BaseImageGenTest): + def get_base_image_generation_call_args(self) -> dict: + # comment this when running locally + load_vertex_ai_credentials() + + litellm.in_memory_llm_clients_cache = {} + return { + "model": "vertex_ai/imagegeneration@006", + "vertex_ai_project": "adroit-crow-413218", + "vertex_ai_location": "us-central1", + "n": 1, + } + + +class TestBedrockSd3(BaseImageGenTest): + def get_base_image_generation_call_args(self) -> dict: + litellm.in_memory_llm_clients_cache = {} + return {"model": "bedrock/stability.sd3-large-v1:0"} + + +class TestBedrockSd1(BaseImageGenTest): + def get_base_image_generation_call_args(self) -> dict: + litellm.in_memory_llm_clients_cache = {} + return {"model": "bedrock/stability.sd3-large-v1:0"} + + +class TestOpenAIDalle3(BaseImageGenTest): + def get_base_image_generation_call_args(self) -> dict: + return {"model": "dall-e-3"} + + +class TestAzureOpenAIDalle3(BaseImageGenTest): + def get_base_image_generation_call_args(self) -> dict: litellm.set_verbose = True - response = litellm.image_generation( - prompt="A cute baby sea otter", model="dall-e-3" - ) - print(f"response: {response}") - assert len(response.data) > 0 - except litellm.RateLimitError as e: - pass - except litellm.ContentPolicyViolationError: - pass # OpenAI randomly raises these errors - skip when they occur - except Exception as e: - if "Connection error" in str(e): - pass - pytest.fail(f"An exception occurred - {str(e)}") - - -# test_image_generation_openai() - - -@pytest.mark.parametrize( - "sync_mode", - [ - True, - ], # False -) # -@pytest.mark.asyncio -@pytest.mark.flaky(retries=3, delay=1) -async def test_image_generation_azure(sync_mode): - try: - if sync_mode: - response = litellm.image_generation( - prompt="A cute baby sea otter", - model="azure/", - api_version="2023-06-01-preview", - ) - else: - response = await litellm.aimage_generation( - prompt="A cute baby sea otter", - model="azure/", - api_version="2023-06-01-preview", - ) - print(f"response: {response}") - assert len(response.data) > 0 - except litellm.RateLimitError as e: - pass - except litellm.ContentPolicyViolationError: - pass # Azure randomly raises these errors - skip when they occur - except litellm.InternalServerError: - pass - except Exception as e: - if "Your task failed as a result of our safety system." in str(e): - pass - if "Connection error" in str(e): - pass - else: - pytest.fail(f"An exception occurred - {str(e)}") - - -# test_image_generation_azure() + return { + "model": "azure/dall-e-3-test", + "api_version": "2023-09-01-preview", + "metadata": { + "model_info": { + "base_model": "dall-e-3", + } + }, + } @pytest.mark.flaky(retries=3, delay=1) @@ -188,91 +175,13 @@ def test_image_generation_azure_dall_e_3(): pytest.fail(f"An exception occurred - {str(e)}") -# test_image_generation_azure_dall_e_3() -@pytest.mark.asyncio -async def test_async_image_generation_openai(): - try: - response = litellm.image_generation( - prompt="A cute baby sea otter", model="dall-e-3" - ) - print(f"response: {response}") - assert len(response.data) > 0 - except litellm.APIError: - pass - except litellm.RateLimitError as e: - pass - except litellm.ContentPolicyViolationError: - pass # openai randomly raises these errors - skip when they occur - except litellm.InternalServerError: - pass - except Exception as e: - if "Connection error" in str(e): - pass - pytest.fail(f"An exception occurred - {str(e)}") - - # asyncio.run(test_async_image_generation_openai()) -@pytest.mark.asyncio -async def test_async_image_generation_azure(): - try: - response = await litellm.aimage_generation( - prompt="A cute baby sea otter", - model="azure/dall-e-3-test", - api_version="2023-09-01-preview", - ) - print(f"response: {response}") - except litellm.RateLimitError as e: - pass - except litellm.ContentPolicyViolationError: - pass # Azure randomly raises these errors - skip when they occur - except litellm.InternalServerError: - pass - except Exception as e: - if "Your task failed as a result of our safety system." in str(e): - pass - if "Connection error" in str(e): - pass - else: - pytest.fail(f"An exception occurred - {str(e)}") - - -@pytest.mark.asyncio -@pytest.mark.parametrize( - "model", - ["bedrock/stability.sd3-large-v1:0", "bedrock/stability.stable-diffusion-xl-v1"], -) -def test_image_generation_bedrock(model): - try: - litellm.set_verbose = True - response = litellm.image_generation( - prompt="A cute baby sea otter", - model=model, - aws_region_name="us-west-2", - ) - - print(f"response: {response}") - print("response hidden params", response._hidden_params) - - assert response._hidden_params["response_cost"] is not None - from openai.types.images_response import ImagesResponse - - ImagesResponse.model_validate(response.model_dump()) - except litellm.RateLimitError as e: - pass - except litellm.ContentPolicyViolationError: - pass # Azure randomly raises these errors - skip when they occur - except Exception as e: - if "Your task failed as a result of our safety system." in str(e): - pass - else: - pytest.fail(f"An exception occurred - {str(e)}") - - @pytest.mark.asyncio async def test_aimage_generation_bedrock_with_optional_params(): try: + litellm.in_memory_llm_clients_cache = {} response = await litellm.aimage_generation( prompt="A cute baby sea otter", model="bedrock/stability.stable-diffusion-xl-v1", @@ -288,47 +197,3 @@ async def test_aimage_generation_bedrock_with_optional_params(): pass else: pytest.fail(f"An exception occurred - {str(e)}") - - -from openai.types.image import Image - - -@pytest.mark.parametrize("sync_mode", [True, False]) -@pytest.mark.asyncio -async def test_aimage_generation_vertex_ai(sync_mode): - - litellm.set_verbose = True - - load_vertex_ai_credentials() - data = { - "prompt": "An olympic size swimming pool", - "model": "vertex_ai/imagegeneration@006", - "vertex_ai_project": "adroit-crow-413218", - "vertex_ai_location": "us-central1", - "n": 1, - } - try: - if sync_mode: - response = litellm.image_generation(**data) - else: - response = await litellm.aimage_generation(**data) - assert response.data is not None - assert len(response.data) > 0 - - for d in response.data: - assert isinstance(d, Image) - print("data in response.data", d) - assert d.b64_json is not None - except litellm.ServiceUnavailableError as e: - pass - except litellm.RateLimitError as e: - pass - except litellm.InternalServerError as e: - pass - except litellm.ContentPolicyViolationError: - pass # Azure randomly raises these errors - skip when they occur - except Exception as e: - if "Your task failed as a result of our safety system." in str(e): - pass - else: - pytest.fail(f"An exception occurred - {str(e)}")