fix(utils.py): fix formatting

Krrish Dholakia 2024-06-11 15:49:20 -07:00
parent 0f1c40d698
commit caae69c18f
2 changed files with 39 additions and 44 deletions

--- a/.pre-commit-config.yaml
+++ b/.pre-commit-config.yaml

@@ -1,8 +1,8 @@
 repos:
-# - repo: https://github.com/psf/black
-#   rev: 24.2.0
-#   hooks:
-#   - id: black
+- repo: https://github.com/psf/black
+  rev: 24.2.0
+  hooks:
+  - id: black
 - repo: https://github.com/pycqa/flake8
   rev: 7.0.0  # The version of flake8 to use
   hooks:
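The first file is the pre-commit config: the black hook, previously commented out, is re-enabled, so black now runs alongside flake8 on every commit. The utils.py hunks below (blank-line removals, re-wrapped unpacking) are exactly the kind of cleanup it enforces. A minimal sketch of running the same hooks from a script, assuming the `pre-commit` package is installed in the environment:

```python
import subprocess

# Run all configured hooks against the full tree, as the git hook would.
# Assumes the pre-commit package is installed (pip install pre-commit).
result = subprocess.run(
    ["pre-commit", "run", "--all-files"],
    capture_output=True,
    text=True,
)
print(result.stdout)
```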

--- a/litellm/utils.py
+++ b/litellm/utils.py

@@ -938,7 +938,6 @@ class TextCompletionResponse(OpenAIObject):
         object=None,
         **params,
     ):
-
         if stream:
             object = "text_completion.chunk"
             choices = [TextChoices()]
@@ -947,7 +946,6 @@ class TextCompletionResponse(OpenAIObject):
         if choices is not None and isinstance(choices, list):
             new_choices = []
             for choice in choices:
-
                 if isinstance(choice, TextChoices):
                     _new_choice = choice
                 elif isinstance(choice, dict):
@@ -1023,7 +1021,6 @@ class ImageObject(OpenAIObject):
     revised_prompt: Optional[str] = None

     def __init__(self, b64_json=None, url=None, revised_prompt=None):
-
         super().__init__(b64_json=b64_json, url=url, revised_prompt=revised_prompt)

     def __contains__(self, key):
@@ -1347,28 +1344,29 @@ class Logging:
                 )
             else:
                 verbose_logger.debug(f"\033[92m{curl_command}\033[0m\n")
-            # log raw request to provider (like LangFuse)
-            try:
-                # [Non-blocking Extra Debug Information in metadata]
-                _litellm_params = self.model_call_details.get("litellm_params", {})
-                _metadata = _litellm_params.get("metadata", {}) or {}
-                if (
-                    litellm.turn_off_message_logging is not None
-                    and litellm.turn_off_message_logging is True
-                ):
-                    _metadata["raw_request"] = (
-                        "redacted by litellm. \
-                        'litellm.turn_off_message_logging=True'"
-                    )
-                else:
-                    _metadata["raw_request"] = str(curl_command)
-            except Exception as e:
-                _metadata["raw_request"] = (
-                    "Unable to Log \
-                    raw request: {}".format(
-                        str(e)
-                    )
-                )
+            # log raw request to provider (like LangFuse) -- if opted in.
+            if litellm.log_raw_request_response is True:
+                try:
+                    # [Non-blocking Extra Debug Information in metadata]
+                    _litellm_params = self.model_call_details.get("litellm_params", {})
+                    _metadata = _litellm_params.get("metadata", {}) or {}
+                    if (
+                        litellm.turn_off_message_logging is not None
+                        and litellm.turn_off_message_logging is True
+                    ):
+                        _metadata["raw_request"] = (
+                            "redacted by litellm. \
+                            'litellm.turn_off_message_logging=True'"
+                        )
+                    else:
+                        _metadata["raw_request"] = str(curl_command)
+                except Exception as e:
+                    _metadata["raw_request"] = (
+                        "Unable to Log \
+                        raw request: {}".format(
+                            str(e)
+                        )
+                    )
             if self.logger_fn and callable(self.logger_fn):
                 try:
                     self.logger_fn(
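This hunk makes raw-request logging opt-in: the curl-style request is only copied into callback metadata when the `litellm.log_raw_request_response` flag is set, and it is still redacted when message logging is turned off. A minimal sketch of opting in, with the flag names taken from the diff and an illustrative model name:

```python
import litellm

# Opt in to forwarding the raw provider request to logging callbacks.
litellm.log_raw_request_response = True

# With message logging turned off, the raw request would be redacted:
# litellm.turn_off_message_logging = True

response = litellm.completion(
    model="gpt-3.5-turbo",  # illustrative model name
    messages=[{"role": "user", "content": "Hello"}],
)
```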
@@ -1626,7 +1624,6 @@ class Logging:
                         end_time=end_time,
                     )
                 except Exception as e:
-
                     complete_streaming_response = None
             else:
                 self.sync_streaming_chunks.append(result)
@@ -2396,7 +2393,6 @@ class Logging:
                             "async_complete_streaming_response"
                             in self.model_call_details
                         ):
-
                             await customLogger.async_log_event(
                                 kwargs=self.model_call_details,
                                 response_obj=self.model_call_details[
@@ -2735,7 +2731,7 @@ class Logging:
         only redacts when litellm.turn_off_message_logging == True
         """
         # check if user opted out of logging message/response to callbacks
-        if litellm.turn_off_message_logging == True:
+        if litellm.turn_off_message_logging is True:
             # remove messages, prompts, input, response from logging
             self.model_call_details["messages"] = [
                 {"role": "user", "content": "redacted-by-litellm"}
@@ -6171,13 +6167,16 @@ def get_api_base(
     if litellm.model_alias_map and model in litellm.model_alias_map:
         model = litellm.model_alias_map[model]
     try:
-        model, custom_llm_provider, dynamic_api_key, dynamic_api_base = (
-            get_llm_provider(
-                model=model,
-                custom_llm_provider=_optional_params.custom_llm_provider,
-                api_base=_optional_params.api_base,
-                api_key=_optional_params.api_key,
-            )
+        (
+            model,
+            custom_llm_provider,
+            dynamic_api_key,
+            dynamic_api_base,
+        ) = get_llm_provider(
+            model=model,
+            custom_llm_provider=_optional_params.custom_llm_provider,
+            api_base=_optional_params.api_base,
+            api_key=_optional_params.api_key,
         )
     except Exception as e:
         verbose_logger.debug("Error occurred in getting api base - {}".format(str(e)))
@@ -6600,6 +6599,9 @@ def get_llm_provider(
                 or get_secret("MISTRAL_AZURE_API_KEY")  # for Azure AI Mistral
                 or get_secret("MISTRAL_API_KEY")
             )
+        elif custom_llm_provider == "azure_ai":
+            api_base = api_base or get_secret("AZURE_AI_API_BASE")  # type: ignore
+            dynamic_api_key = get_secret("AZURE_AI_API_KEY")
         elif custom_llm_provider == "voyage":
             # voyage is openai compatible, we just need to set this to custom_openai and have the api_base be https://api.voyageai.com/v1
             api_base = "https://api.voyageai.com/v1"
@@ -6612,12 +6614,6 @@ def get_llm_provider(
                 or get_secret("TOGETHERAI_API_KEY")
                 or get_secret("TOGETHER_AI_TOKEN")
             )
-        elif custom_llm_provider == "azure_ai":
-            api_base = (
-                api_base
-                or get_secret("AZURE_AI_API_BASE")  # for Azure AI Mistral
-            )  # type: ignore
-            dynamic_api_key = get_secret("AZURE_AI_API_KEY")
     if api_base is not None and not isinstance(api_base, str):
         raise Exception(
             "api base needs to be a string. api_base={}".format(api_base)
@@ -7459,7 +7455,6 @@ def validate_environment(model: Optional[str] = None) -> dict:


 def set_callbacks(callback_list, function_id=None):
-
     global sentry_sdk_instance, capture_exception, add_breadcrumb, posthog, slack_app, alerts_channel, traceloopLogger, athinaLogger, heliconeLogger, aispendLogger, berrispendLogger, supabaseClient, liteDebuggerClient, lunaryLogger, promptLayerLogger, langFuseLogger, customLogger, weightsBiasesLogger, langsmithLogger, logfireLogger, dynamoLogger, s3Logger, dataDogLogger, prometheusLogger, greenscaleLogger, openMeterLogger
     try: