fix(utils.py): enable cost tracking for image gen models on proxy

This commit is contained in:
Krrish Dholakia 2024-01-26 20:51:13 -08:00
parent 511510a1cc
commit a299ac2328
3 changed files with 43 additions and 42 deletions

View file

@@ -3067,7 +3067,7 @@ def image_generation(
custom_llm_provider=custom_llm_provider,
**non_default_params,
)
logging = litellm_logging_obj
logging: Logging = litellm_logging_obj
logging.update_environment_variables(
model=model,
user=user,

View file

@@ -819,44 +819,44 @@ async def test_async_embedding_azure_caching():
# Image Generation
# ## Test OpenAI + Sync
def test_image_generation_openai():
    """Verify custom callbacks fire for sync OpenAI image generation.

    Success path: a dall-e-3 generation should record no handler errors and
    exactly three states (pre, post, success) on the success handler.
    Failure path: a bogus model name (dall-e-4) should record no handler
    errors and exactly three states (pre, post, failure) on the failure
    handler.

    RateLimitError / ContentPolicyViolationError are tolerated because
    OpenAI raises them nondeterministically; any other exception fails
    the test.
    """
    try:
        customHandler_success = CompletionCustomHandler()
        customHandler_failure = CompletionCustomHandler()
        litellm.callbacks = [customHandler_success]
        litellm.set_verbose = True
        response = litellm.image_generation(
            prompt="A cute baby sea otter", model="dall-e-3"
        )
        print(f"response: {response}")
        assert len(response.data) > 0
        print(f"customHandler_success.errors: {customHandler_success.errors}")
        print(f"customHandler_success.states: {customHandler_success.states}")
        assert len(customHandler_success.errors) == 0
        assert len(customHandler_success.states) == 3  # pre, post, success
        # test failure callback
        litellm.callbacks = [customHandler_failure]
        try:
            response = litellm.image_generation(
                prompt="A cute baby sea otter", model="dall-e-4"
            )
        except Exception:
            # Expected: "dall-e-4" is not a real model. We only care that
            # the failure callback fired; the raised error itself is noise.
            pass
        print(f"customHandler_failure.errors: {customHandler_failure.errors}")
        print(f"customHandler_failure.states: {customHandler_failure.states}")
        assert len(customHandler_failure.errors) == 0
        assert len(customHandler_failure.states) == 3  # pre, post, failure
    except litellm.RateLimitError:
        pass
    except litellm.ContentPolicyViolationError:
        pass  # OpenAI randomly raises these errors - skip when they occur
    except Exception as e:
        pytest.fail(f"An exception occurred - {str(e)}")


# test_image_generation_openai()

View file

@@ -2029,14 +2029,15 @@ def client(original_function):
start_time=start_time,
)
## check if metadata is passed in
litellm_params = {}
if "metadata" in kwargs:
litellm_params = {"metadata": kwargs["metadata"]}
logging_obj.update_environment_variables(
model=model,
user="",
optional_params={},
litellm_params=litellm_params,
)
litellm_params["metadata"] = kwargs["metadata"]
logging_obj.update_environment_variables(
model=model,
user="",
optional_params={},
litellm_params=litellm_params,
)
return logging_obj
except Exception as e:
import logging