Merge pull request #1646 from BerriAI/litellm_image_gen_cost_tracking_proxy

Litellm image gen cost tracking proxy
commit ba4089824d
Author: Krish Dholakia <2024-01-26 22:30:14 -08:00>, committed by GitHub
4 changed files with 60 additions and 43 deletions


@@ -718,8 +718,22 @@ class OpenAIChatCompletion(BaseLLM):
             return convert_to_model_response_object(response_object=response, model_response_object=model_response, response_type="image_generation")  # type: ignore
         except OpenAIError as e:
             exception_mapping_worked = True
+            ## LOGGING
+            logging_obj.post_call(
+                input=prompt,
+                api_key=api_key,
+                additional_args={"complete_input_dict": data},
+                original_response=str(e),
+            )
             raise e
         except Exception as e:
+            ## LOGGING
+            logging_obj.post_call(
+                input=prompt,
+                api_key=api_key,
+                additional_args={"complete_input_dict": data},
+                original_response=str(e),
+            )
             if hasattr(e, "status_code"):
                 raise OpenAIError(status_code=e.status_code, message=str(e))
             else:
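With post_call now firing on both image-generation error paths, failure callbacks see the stringified error instead of nothing. A minimal sketch of a handler that would observe this, using litellm's CustomLogger interface (the class name and the "exception" key lookup are illustrative assumptions, not part of this PR):

import litellm
from litellm.integrations.custom_logger import CustomLogger

class ImageGenFailureLogger(CustomLogger):  # hypothetical handler, for illustration
    def log_failure_event(self, kwargs, response_obj, start_time, end_time):
        # litellm surfaces the raised error through the callback kwargs;
        # the "exception" key is an assumption based on common litellm usage
        print(f"image generation failed: {kwargs.get('exception')}")

litellm.callbacks = [ImageGenFailureLogger()]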


@@ -3076,7 +3076,7 @@ def image_generation(
         custom_llm_provider=custom_llm_provider,
         **non_default_params,
     )
-    logging = litellm_logging_obj
+    logging: Logging = litellm_logging_obj
     logging.update_environment_variables(
         model=model,
         user=user,
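The explicit Logging annotation pins the type that the cost-tracking path relies on. For context, the caller-side invocation this hunk sits under is simply (prompt and model mirror the test file; nothing here is new API):

import litellm

response = litellm.image_generation(
    prompt="A cute baby sea otter",
    model="dall-e-3",
)
# response.data holds the generated images, per the test's assertion
print(len(response.data))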


@@ -819,47 +819,49 @@ async def test_async_embedding_azure_caching():
 # Image Generation
-# ## Test OpenAI + Sync
-# def test_image_generation_openai():
-#     try:
-#         customHandler_success = CompletionCustomHandler()
-#         customHandler_failure = CompletionCustomHandler()
-#         litellm.callbacks = [customHandler_success]
-#         litellm.set_verbose = True
-#         response = litellm.image_generation(
-#             prompt="A cute baby sea otter", model="dall-e-3"
-#         )
-#         print(f"response: {response}")
-#         assert len(response.data) > 0
-#         print(f"customHandler_success.errors: {customHandler_success.errors}")
-#         print(f"customHandler_success.states: {customHandler_success.states}")
-#         assert len(customHandler_success.errors) == 0
-#         assert len(customHandler_success.states) == 3  # pre, post, success
-#         # test failure callback
-#         litellm.callbacks = [customHandler_failure]
-#         try:
-#             response = litellm.image_generation(
-#                 prompt="A cute baby sea otter", model="dall-e-4"
-#             )
-#         except:
-#             pass
-#         print(f"customHandler_failure.errors: {customHandler_failure.errors}")
-#         print(f"customHandler_failure.states: {customHandler_failure.states}")
-#         assert len(customHandler_failure.errors) == 0
-#         assert len(customHandler_failure.states) == 3  # pre, post, failure
-#     except litellm.RateLimitError as e:
-#         pass
-#     except litellm.ContentPolicyViolationError:
-#         pass  # OpenAI randomly raises these errors - skip when they occur
-#     except Exception as e:
-#         pytest.fail(f"An exception occurred - {str(e)}")
-# test_image_generation_openai()
+## Test OpenAI + Sync
+def test_image_generation_openai():
+    try:
+        customHandler_success = CompletionCustomHandler()
+        customHandler_failure = CompletionCustomHandler()
+        # litellm.callbacks = [customHandler_success]
+        # litellm.set_verbose = True
+        # response = litellm.image_generation(
+        #     prompt="A cute baby sea otter", model="dall-e-3"
+        # )
+        # print(f"response: {response}")
+        # assert len(response.data) > 0
+        # print(f"customHandler_success.errors: {customHandler_success.errors}")
+        # print(f"customHandler_success.states: {customHandler_success.states}")
+        # assert len(customHandler_success.errors) == 0
+        # assert len(customHandler_success.states) == 3  # pre, post, success
+        # test failure callback
+        litellm.callbacks = [customHandler_failure]
+        try:
+            response = litellm.image_generation(
+                prompt="A cute baby sea otter",
+                model="dall-e-2",
+                api_key="my-bad-api-key",
+            )
+        except:
+            pass
+        print(f"customHandler_failure.errors: {customHandler_failure.errors}")
+        print(f"customHandler_failure.states: {customHandler_failure.states}")
+        assert len(customHandler_failure.errors) == 0
+        assert len(customHandler_failure.states) == 3  # pre, post, failure
+    except litellm.RateLimitError as e:
+        pass
+    except litellm.ContentPolicyViolationError:
+        pass  # OpenAI randomly raises these errors - skip when they occur
+    except Exception as e:
+        pytest.fail(f"An exception occurred - {str(e)}")
+test_image_generation_openai()
 ## Test OpenAI + Async
 ## Test Azure + Sync


@@ -2030,8 +2030,9 @@ def client(original_function):
             start_time=start_time,
         )
         ## check if metadata is passed in
+        litellm_params = {}
         if "metadata" in kwargs:
-            litellm_params = {"metadata": kwargs["metadata"]}
+            litellm_params["metadata"] = kwargs["metadata"]
         logging_obj.update_environment_variables(
             model=model,
             user="",