mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 11:14:04 +00:00
fix(utils.py): fix model name checking
This commit is contained in:
parent
c333216f6e
commit
8d2d51b625
3 changed files with 21 additions and 5 deletions
|
@ -753,6 +753,7 @@ class OpenAIChatCompletion(BaseLLM):
|
||||||
# return response
|
# return response
|
||||||
return convert_to_model_response_object(response_object=response, model_response_object=model_response, response_type="image_generation") # type: ignore
|
return convert_to_model_response_object(response_object=response, model_response_object=model_response, response_type="image_generation") # type: ignore
|
||||||
except OpenAIError as e:
|
except OpenAIError as e:
|
||||||
|
|
||||||
exception_mapping_worked = True
|
exception_mapping_worked = True
|
||||||
## LOGGING
|
## LOGGING
|
||||||
logging_obj.post_call(
|
logging_obj.post_call(
|
||||||
|
|
|
@ -973,6 +973,7 @@ def test_image_generation_openai():
|
||||||
|
|
||||||
print(f"customHandler_success.errors: {customHandler_success.errors}")
|
print(f"customHandler_success.errors: {customHandler_success.errors}")
|
||||||
print(f"customHandler_success.states: {customHandler_success.states}")
|
print(f"customHandler_success.states: {customHandler_success.states}")
|
||||||
|
time.sleep(2)
|
||||||
assert len(customHandler_success.errors) == 0
|
assert len(customHandler_success.errors) == 0
|
||||||
assert len(customHandler_success.states) == 3 # pre, post, success
|
assert len(customHandler_success.states) == 3 # pre, post, success
|
||||||
# test failure callback
|
# test failure callback
|
||||||
|
|
|
@ -1243,7 +1243,7 @@ class Logging:
|
||||||
)
|
)
|
||||||
# print(f"original response in success handler: {self.model_call_details['original_response']}")
|
# print(f"original response in success handler: {self.model_call_details['original_response']}")
|
||||||
try:
|
try:
|
||||||
verbose_logger.debug(f"success callbacks: {litellm.success_callback}")
|
print_verbose(f"success callbacks: {litellm.success_callback}")
|
||||||
## BUILD COMPLETE STREAMED RESPONSE
|
## BUILD COMPLETE STREAMED RESPONSE
|
||||||
complete_streaming_response = None
|
complete_streaming_response = None
|
||||||
if self.stream and isinstance(result, ModelResponse):
|
if self.stream and isinstance(result, ModelResponse):
|
||||||
|
@ -1266,7 +1266,7 @@ class Logging:
|
||||||
self.sync_streaming_chunks.append(result)
|
self.sync_streaming_chunks.append(result)
|
||||||
|
|
||||||
if complete_streaming_response is not None:
|
if complete_streaming_response is not None:
|
||||||
verbose_logger.debug(
|
print_verbose(
|
||||||
f"Logging Details LiteLLM-Success Call streaming complete"
|
f"Logging Details LiteLLM-Success Call streaming complete"
|
||||||
)
|
)
|
||||||
self.model_call_details["complete_streaming_response"] = (
|
self.model_call_details["complete_streaming_response"] = (
|
||||||
|
@ -1613,6 +1613,14 @@ class Logging:
|
||||||
"aembedding", False
|
"aembedding", False
|
||||||
)
|
)
|
||||||
== False
|
== False
|
||||||
|
and self.model_call_details.get("litellm_params", {}).get(
|
||||||
|
"aimage_generation", False
|
||||||
|
)
|
||||||
|
== False
|
||||||
|
and self.model_call_details.get("litellm_params", {}).get(
|
||||||
|
"atranscription", False
|
||||||
|
)
|
||||||
|
== False
|
||||||
): # custom logger class
|
): # custom logger class
|
||||||
if self.stream and complete_streaming_response is None:
|
if self.stream and complete_streaming_response is None:
|
||||||
callback.log_stream_event(
|
callback.log_stream_event(
|
||||||
|
@ -1645,6 +1653,14 @@ class Logging:
|
||||||
"aembedding", False
|
"aembedding", False
|
||||||
)
|
)
|
||||||
== False
|
== False
|
||||||
|
and self.model_call_details.get("litellm_params", {}).get(
|
||||||
|
"aimage_generation", False
|
||||||
|
)
|
||||||
|
== False
|
||||||
|
and self.model_call_details.get("litellm_params", {}).get(
|
||||||
|
"atranscription", False
|
||||||
|
)
|
||||||
|
== False
|
||||||
): # custom logger functions
|
): # custom logger functions
|
||||||
print_verbose(
|
print_verbose(
|
||||||
f"success callbacks: Running Custom Callback Function"
|
f"success callbacks: Running Custom Callback Function"
|
||||||
|
@ -3728,7 +3744,6 @@ def completion_cost(
|
||||||
- If an error occurs during execution, the function returns 0.0 without blocking the user's execution path.
|
- If an error occurs during execution, the function returns 0.0 without blocking the user's execution path.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
|
|
||||||
if (
|
if (
|
||||||
(call_type == "aimage_generation" or call_type == "image_generation")
|
(call_type == "aimage_generation" or call_type == "image_generation")
|
||||||
and model is not None
|
and model is not None
|
||||||
|
@ -3737,7 +3752,6 @@ def completion_cost(
|
||||||
and custom_llm_provider == "azure"
|
and custom_llm_provider == "azure"
|
||||||
):
|
):
|
||||||
model = "dall-e-2" # for dall-e-2, azure expects an empty model name
|
model = "dall-e-2" # for dall-e-2, azure expects an empty model name
|
||||||
|
|
||||||
# Handle Inputs to completion_cost
|
# Handle Inputs to completion_cost
|
||||||
prompt_tokens = 0
|
prompt_tokens = 0
|
||||||
completion_tokens = 0
|
completion_tokens = 0
|
||||||
|
@ -3756,7 +3770,7 @@ def completion_cost(
|
||||||
"model", None
|
"model", None
|
||||||
) # check if user passed an override for model, if it's none check completion_response['model']
|
) # check if user passed an override for model, if it's none check completion_response['model']
|
||||||
if hasattr(completion_response, "_hidden_params"):
|
if hasattr(completion_response, "_hidden_params"):
|
||||||
model = completion_response._hidden_params.get("model", model)
|
model = model or completion_response._hidden_params.get("model", None)
|
||||||
custom_llm_provider = completion_response._hidden_params.get(
|
custom_llm_provider = completion_response._hidden_params.get(
|
||||||
"custom_llm_provider", ""
|
"custom_llm_provider", ""
|
||||||
)
|
)
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue