diff --git a/litellm/main.py b/litellm/main.py index f9f1139f6..c809f49d6 100644 --- a/litellm/main.py +++ b/litellm/main.py @@ -3067,7 +3067,7 @@ def image_generation( custom_llm_provider=custom_llm_provider, **non_default_params, ) - logging = litellm_logging_obj + logging: Logging = litellm_logging_obj logging.update_environment_variables( model=model, user=user, diff --git a/litellm/tests/test_custom_callback_input.py b/litellm/tests/test_custom_callback_input.py index a61cc843e..641343e7a 100644 --- a/litellm/tests/test_custom_callback_input.py +++ b/litellm/tests/test_custom_callback_input.py @@ -819,44 +819,44 @@ async def test_async_embedding_azure_caching(): # Image Generation -# ## Test OpenAI + Sync -# def test_image_generation_openai(): -# try: -# customHandler_success = CompletionCustomHandler() -# customHandler_failure = CompletionCustomHandler() -# litellm.callbacks = [customHandler_success] +## Test OpenAI + Sync +def test_image_generation_openai(): + try: + customHandler_success = CompletionCustomHandler() + customHandler_failure = CompletionCustomHandler() + litellm.callbacks = [customHandler_success] -# litellm.set_verbose = True + litellm.set_verbose = True -# response = litellm.image_generation( -# prompt="A cute baby sea otter", model="dall-e-3" -# ) + response = litellm.image_generation( + prompt="A cute baby sea otter", model="dall-e-3" + ) -# print(f"response: {response}") -# assert len(response.data) > 0 + print(f"response: {response}") + assert len(response.data) > 0 -# print(f"customHandler_success.errors: {customHandler_success.errors}") -# print(f"customHandler_success.states: {customHandler_success.states}") -# assert len(customHandler_success.errors) == 0 -# assert len(customHandler_success.states) == 3 # pre, post, success -# # test failure callback -# litellm.callbacks = [customHandler_failure] -# try: -# response = litellm.image_generation( -# prompt="A cute baby sea otter", model="dall-e-4" -# ) -# except: -# pass -# print(f"customHandler_failure.errors: {customHandler_failure.errors}") -# print(f"customHandler_failure.states: {customHandler_failure.states}") -# assert len(customHandler_failure.errors) == 0 -# assert len(customHandler_failure.states) == 3 # pre, post, failure -# except litellm.RateLimitError as e: -# pass -# except litellm.ContentPolicyViolationError: -# pass # OpenAI randomly raises these errors - skip when they occur -# except Exception as e: -# pytest.fail(f"An exception occurred - {str(e)}") + print(f"customHandler_success.errors: {customHandler_success.errors}") + print(f"customHandler_success.states: {customHandler_success.states}") + assert len(customHandler_success.errors) == 0 + assert len(customHandler_success.states) == 3 # pre, post, success + # test failure callback + litellm.callbacks = [customHandler_failure] + try: + response = litellm.image_generation( + prompt="A cute baby sea otter", model="dall-e-4" + ) + except: + pass + print(f"customHandler_failure.errors: {customHandler_failure.errors}") + print(f"customHandler_failure.states: {customHandler_failure.states}") + assert len(customHandler_failure.errors) == 0 + assert len(customHandler_failure.states) == 3 # pre, post, failure + except litellm.RateLimitError as e: + pass + except litellm.ContentPolicyViolationError: + pass # OpenAI randomly raises these errors - skip when they occur + except Exception as e: + pytest.fail(f"An exception occurred - {str(e)}") # test_image_generation_openai() diff --git a/litellm/utils.py b/litellm/utils.py index b0e48bbc6..613d9d90a 100644 --- a/litellm/utils.py +++ b/litellm/utils.py @@ -2029,14 +2029,15 @@ def client(original_function): start_time=start_time, ) ## check if metadata is passed in + litellm_params = {} if "metadata" in kwargs: - litellm_params = {"metadata": kwargs["metadata"]} - logging_obj.update_environment_variables( - model=model, - user="", - optional_params={}, - litellm_params=litellm_params, - ) + litellm_params["metadata"] = kwargs["metadata"] + logging_obj.update_environment_variables( + model=model, + user="", + optional_params={}, + litellm_params=litellm_params, + ) return logging_obj except Exception as e: import logging