fix(utils.py): fix model name checking

Krrish Dholakia 2024-03-09 18:22:26 -08:00
parent c333216f6e
commit 8d2d51b625
3 changed files with 21 additions and 5 deletions


@@ -753,6 +753,7 @@ class OpenAIChatCompletion(BaseLLM):
             # return response
             return convert_to_model_response_object(response_object=response, model_response_object=model_response, response_type="image_generation")  # type: ignore
         except OpenAIError as e:
+            exception_mapping_worked = True
             ## LOGGING
             logging_obj.post_call(
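
Note on the added line: `exception_mapping_worked` is the flag litellm's handlers use to record that an error has already been mapped to a provider-specific type, so outer handlers re-raise it instead of wrapping it a second time. A minimal, self-contained sketch of that wrap-once pattern (the names below are illustrative, not litellm's actual code):

# Hedged sketch of a wrap-once flag similar to exception_mapping_worked.
# MappedError stands in for a provider error type such as OpenAIError.
class MappedError(Exception):
    pass

def call_provider() -> None:
    raise TimeoutError("upstream timeout")  # stand-in for the real API call

def completion() -> None:
    exception_mapping_worked = False
    try:
        try:
            call_provider()
        except TimeoutError as e:
            exception_mapping_worked = True  # set before logging/cleanup runs
            raise MappedError(str(e)) from e
    except Exception as e:
        if exception_mapping_worked:
            raise  # already mapped once; re-raise unchanged
        raise MappedError(f"unmapped error: {e}") from e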


@@ -973,6 +973,7 @@ def test_image_generation_openai():
         print(f"customHandler_success.errors: {customHandler_success.errors}")
         print(f"customHandler_success.states: {customHandler_success.states}")
         time.sleep(2)
+        assert len(customHandler_success.errors) == 0
         assert len(customHandler_success.states) == 3  # pre, post, success
 
         # test failure callback
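
For context on what the test counts: `customHandler_success` records one state per callback hook and one entry per failure hook, and the new assertion pins the error list to empty on the success path. A hypothetical handler with that shape, assuming litellm's `CustomLogger` hook names (`log_pre_api_call`, `log_post_api_call`, `log_success_event`, `log_failure_event`); this is a sketch, not the test's actual `CompletionCustomHandler`:

# Hypothetical state-tracking handler; hook signatures follow litellm's
# CustomLogger base class, but the bookkeeping here is illustrative.
from litellm.integrations.custom_logger import CustomLogger

class TrackingHandler(CustomLogger):
    def __init__(self):
        self.states = []
        self.errors = []

    def log_pre_api_call(self, model, messages, kwargs):
        self.states.append("pre_api_call")

    def log_post_api_call(self, kwargs, response_obj, start_time, end_time):
        self.states.append("post_api_call")

    def log_success_event(self, kwargs, response_obj, start_time, end_time):
        self.states.append("success")

    def log_failure_event(self, kwargs, response_obj, start_time, end_time):
        self.errors.append(str(kwargs.get("exception", "unknown")))

# On a successful image generation the test expects:
#   errors == [] and states == ["pre_api_call", "post_api_call", "success"]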


@@ -1243,7 +1243,7 @@ class Logging:
         )
         # print(f"original response in success handler: {self.model_call_details['original_response']}")
         try:
-            verbose_logger.debug(f"success callbacks: {litellm.success_callback}")
+            print_verbose(f"success callbacks: {litellm.success_callback}")
             ## BUILD COMPLETE STREAMED RESPONSE
             complete_streaming_response = None
             if self.stream and isinstance(result, ModelResponse):
@@ -1266,7 +1266,7 @@
                     self.sync_streaming_chunks.append(result)
 
             if complete_streaming_response is not None:
-                verbose_logger.debug(
+                print_verbose(
                     f"Logging Details LiteLLM-Success Call streaming complete"
                 )
                 self.model_call_details["complete_streaming_response"] = (
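
Both hunks above swap `verbose_logger.debug(...)` for `print_verbose(...)`. In litellm, `print_verbose` is the helper gated on the global `litellm.set_verbose` flag, whereas `verbose_logger` is a standard `logging.Logger`. A simplified stand-in for the gating behavior (not litellm's exact implementation):

# Simplified stand-in for litellm's print_verbose helper: output is
# gated on a user-facing verbosity flag rather than on logger levels.
set_verbose = False

def print_verbose(print_statement: str) -> None:
    if set_verbose:  # emit only when the user opted into verbose mode
        print(print_statement)

print_verbose("hidden while set_verbose is False")  # prints nothing
set_verbose = True
print_verbose("visible once verbosity is enabled")  # printed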
@@ -1613,6 +1613,14 @@
                         "aembedding", False
                     )
                     == False
+                    and self.model_call_details.get("litellm_params", {}).get(
+                        "aimage_generation", False
+                    )
+                    == False
+                    and self.model_call_details.get("litellm_params", {}).get(
+                        "atranscription", False
+                    )
+                    == False
                 ):  # custom logger class
                     if self.stream and complete_streaming_response is None:
                         callback.log_stream_event(
@@ -1645,6 +1653,14 @@
                         "aembedding", False
                     )
                     == False
+                    and self.model_call_details.get("litellm_params", {}).get(
+                        "aimage_generation", False
+                    )
+                    == False
+                    and self.model_call_details.get("litellm_params", {}).get(
+                        "atranscription", False
+                    )
+                    == False
                 ):  # custom logger functions
                     print_verbose(
                         f"success callbacks: Running Custom Callback Function"
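
Both guard hunks extend the same condition with two more flags: the synchronous success handler now also skips custom logger classes and functions when the call entered through `aimage_generation` or `atranscription`, matching the existing `acompletion`/`aembedding` checks, so async calls are left to the async success handler rather than logged twice. Condensed into a helper (a sketch of the guard's logic, assuming `litellm_params` carries one boolean per async entry point, as the diff shows):

# Sketch of the extended guard: a sync callback runs only if no async
# entry-point flag is set in the call's litellm_params.
ASYNC_FLAGS = ("acompletion", "aembedding", "aimage_generation", "atranscription")

def should_run_sync_callback(model_call_details: dict) -> bool:
    litellm_params = model_call_details.get("litellm_params", {})
    return not any(litellm_params.get(flag, False) for flag in ASYNC_FLAGS)

# An async image generation is now skipped by the sync handler:
assert should_run_sync_callback({"litellm_params": {"aimage_generation": True}}) is False
assert should_run_sync_callback({"litellm_params": {}}) is True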
@@ -3728,7 +3744,6 @@ def completion_cost(
         - If an error occurs during execution, the function returns 0.0 without blocking the user's execution path.
     """
     try:
         if (
             (call_type == "aimage_generation" or call_type == "image_generation")
             and model is not None
@@ -3737,7 +3752,6 @@
             and custom_llm_provider == "azure"
         ):
             model = "dall-e-2"  # for dall-e-2, azure expects an empty model name
         # Handle Inputs to completion_cost
         prompt_tokens = 0
         completion_tokens = 0
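
The two hunks above cover the Azure image-generation special case in `completion_cost`: when Azure returns no usable model name for a DALL·E call, cost lookup falls back to `"dall-e-2"` pricing. A reduced sketch of that branch (the conditions elided between the two hunks, such as an empty-string check on `model`, are assumptions here):

# Reduced sketch of the Azure DALL·E fallback in completion_cost; the
# empty-model check is an assumption standing in for the elided lines.
def resolve_image_cost_model(call_type, model, custom_llm_provider):
    if (
        (call_type == "aimage_generation" or call_type == "image_generation")
        and model is not None
        and model == ""  # assumed elided condition: Azure sends no model name
        and custom_llm_provider == "azure"
    ):
        model = "dall-e-2"  # price the call against dall-e-2 rates
    return model

assert resolve_image_cost_model("image_generation", "", "azure") == "dall-e-2"
assert resolve_image_cost_model("image_generation", "dall-e-3", "azure") == "dall-e-3"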
@@ -3756,7 +3770,7 @@
                 "model", None
             )  # check if user passed an override for model, if it's none check completion_response['model']
             if hasattr(completion_response, "_hidden_params"):
-                model = completion_response._hidden_params.get("model", model)
+                model = model or completion_response._hidden_params.get("model", None)
                 custom_llm_provider = completion_response._hidden_params.get(
                     "custom_llm_provider", ""
                 )
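
The final one-line change is the model-name fix named in the commit title: previously `_hidden_params["model"]` overwrote whatever model the caller passed to `completion_cost`, with the caller's value used only as a fallback; now the caller's `model` takes precedence and hidden params are the fallback. Shown on plain values (`hidden_params` is a stand-in dict for `_hidden_params`):

# Before vs. after the precedence flip, with a stand-in hidden_params dict.
hidden_params = {"model": "gpt-3.5-turbo"}
model = "azure/my-dalle-deployment"  # explicit caller override

old_model = hidden_params.get("model", model)          # caller value lost
new_model = model or hidden_params.get("model", None)  # caller value wins

assert old_model == "gpt-3.5-turbo"
assert new_model == "azure/my-dalle-deployment"

# When the caller passes nothing, both forms fall back to hidden params:
model = None
assert (model or hidden_params.get("model", None)) == "gpt-3.5-turbo"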