fix(utils.py): fix model name checking

This commit is contained in:
Krrish Dholakia 2024-03-09 18:22:26 -08:00
parent c333216f6e
commit 8d2d51b625
3 changed files with 21 additions and 5 deletions

View file

@ -753,6 +753,7 @@ class OpenAIChatCompletion(BaseLLM):
# return response # return response
return convert_to_model_response_object(response_object=response, model_response_object=model_response, response_type="image_generation") # type: ignore return convert_to_model_response_object(response_object=response, model_response_object=model_response, response_type="image_generation") # type: ignore
except OpenAIError as e: except OpenAIError as e:
exception_mapping_worked = True exception_mapping_worked = True
## LOGGING ## LOGGING
logging_obj.post_call( logging_obj.post_call(

View file

@ -973,6 +973,7 @@ def test_image_generation_openai():
print(f"customHandler_success.errors: {customHandler_success.errors}") print(f"customHandler_success.errors: {customHandler_success.errors}")
print(f"customHandler_success.states: {customHandler_success.states}") print(f"customHandler_success.states: {customHandler_success.states}")
time.sleep(2)
assert len(customHandler_success.errors) == 0 assert len(customHandler_success.errors) == 0
assert len(customHandler_success.states) == 3 # pre, post, success assert len(customHandler_success.states) == 3 # pre, post, success
# test failure callback # test failure callback

View file

@ -1243,7 +1243,7 @@ class Logging:
) )
# print(f"original response in success handler: {self.model_call_details['original_response']}") # print(f"original response in success handler: {self.model_call_details['original_response']}")
try: try:
verbose_logger.debug(f"success callbacks: {litellm.success_callback}") print_verbose(f"success callbacks: {litellm.success_callback}")
## BUILD COMPLETE STREAMED RESPONSE ## BUILD COMPLETE STREAMED RESPONSE
complete_streaming_response = None complete_streaming_response = None
if self.stream and isinstance(result, ModelResponse): if self.stream and isinstance(result, ModelResponse):
@ -1266,7 +1266,7 @@ class Logging:
self.sync_streaming_chunks.append(result) self.sync_streaming_chunks.append(result)
if complete_streaming_response is not None: if complete_streaming_response is not None:
verbose_logger.debug( print_verbose(
f"Logging Details LiteLLM-Success Call streaming complete" f"Logging Details LiteLLM-Success Call streaming complete"
) )
self.model_call_details["complete_streaming_response"] = ( self.model_call_details["complete_streaming_response"] = (
@ -1613,6 +1613,14 @@ class Logging:
"aembedding", False "aembedding", False
) )
== False == False
and self.model_call_details.get("litellm_params", {}).get(
"aimage_generation", False
)
== False
and self.model_call_details.get("litellm_params", {}).get(
"atranscription", False
)
== False
): # custom logger class ): # custom logger class
if self.stream and complete_streaming_response is None: if self.stream and complete_streaming_response is None:
callback.log_stream_event( callback.log_stream_event(
@ -1645,6 +1653,14 @@ class Logging:
"aembedding", False "aembedding", False
) )
== False == False
and self.model_call_details.get("litellm_params", {}).get(
"aimage_generation", False
)
== False
and self.model_call_details.get("litellm_params", {}).get(
"atranscription", False
)
== False
): # custom logger functions ): # custom logger functions
print_verbose( print_verbose(
f"success callbacks: Running Custom Callback Function" f"success callbacks: Running Custom Callback Function"
@ -3728,7 +3744,6 @@ def completion_cost(
- If an error occurs during execution, the function returns 0.0 without blocking the user's execution path. - If an error occurs during execution, the function returns 0.0 without blocking the user's execution path.
""" """
try: try:
if ( if (
(call_type == "aimage_generation" or call_type == "image_generation") (call_type == "aimage_generation" or call_type == "image_generation")
and model is not None and model is not None
@ -3737,7 +3752,6 @@ def completion_cost(
and custom_llm_provider == "azure" and custom_llm_provider == "azure"
): ):
model = "dall-e-2" # for dall-e-2, azure expects an empty model name model = "dall-e-2" # for dall-e-2, azure expects an empty model name
# Handle Inputs to completion_cost # Handle Inputs to completion_cost
prompt_tokens = 0 prompt_tokens = 0
completion_tokens = 0 completion_tokens = 0
@ -3756,7 +3770,7 @@ def completion_cost(
"model", None "model", None
) # check if user passed an override for model, if it's none check completion_response['model'] ) # check if user passed an override for model, if it's none check completion_response['model']
if hasattr(completion_response, "_hidden_params"): if hasattr(completion_response, "_hidden_params"):
model = completion_response._hidden_params.get("model", model) model = model or completion_response._hidden_params.get("model", None)
custom_llm_provider = completion_response._hidden_params.get( custom_llm_provider = completion_response._hidden_params.get(
"custom_llm_provider", "" "custom_llm_provider", ""
) )