fix(utils.py): fix formatting

This commit is contained in:
Krrish Dholakia 2024-06-11 15:49:20 -07:00
parent 2d8e4ddfa0
commit 9c7788ae48
2 changed files with 39 additions and 44 deletions

View file

@ -1,8 +1,8 @@
repos:
# - repo: https://github.com/psf/black
# rev: 24.2.0
# hooks:
# - id: black
- repo: https://github.com/psf/black
rev: 24.2.0
hooks:
- id: black
- repo: https://github.com/pycqa/flake8
rev: 7.0.0 # The version of flake8 to use
hooks:

View file

@ -938,7 +938,6 @@ class TextCompletionResponse(OpenAIObject):
object=None,
**params,
):
if stream:
object = "text_completion.chunk"
choices = [TextChoices()]
@ -947,7 +946,6 @@ class TextCompletionResponse(OpenAIObject):
if choices is not None and isinstance(choices, list):
new_choices = []
for choice in choices:
if isinstance(choice, TextChoices):
_new_choice = choice
elif isinstance(choice, dict):
@ -1023,7 +1021,6 @@ class ImageObject(OpenAIObject):
revised_prompt: Optional[str] = None
def __init__(self, b64_json=None, url=None, revised_prompt=None):
super().__init__(b64_json=b64_json, url=url, revised_prompt=revised_prompt)
def __contains__(self, key):
@ -1347,28 +1344,29 @@ class Logging:
)
else:
verbose_logger.debug(f"\033[92m{curl_command}\033[0m\n")
# log raw request to provider (like LangFuse)
try:
# [Non-blocking Extra Debug Information in metadata]
_litellm_params = self.model_call_details.get("litellm_params", {})
_metadata = _litellm_params.get("metadata", {}) or {}
if (
litellm.turn_off_message_logging is not None
and litellm.turn_off_message_logging is True
):
# log raw request to provider (like LangFuse) -- if opted in.
if litellm.log_raw_request_response is True:
try:
# [Non-blocking Extra Debug Information in metadata]
_litellm_params = self.model_call_details.get("litellm_params", {})
_metadata = _litellm_params.get("metadata", {}) or {}
if (
litellm.turn_off_message_logging is not None
and litellm.turn_off_message_logging is True
):
_metadata["raw_request"] = (
"redacted by litellm. \
'litellm.turn_off_message_logging=True'"
)
else:
_metadata["raw_request"] = str(curl_command)
except Exception as e:
_metadata["raw_request"] = (
"redacted by litellm. \
'litellm.turn_off_message_logging=True'"
"Unable to Log \
raw request: {}".format(
str(e)
)
)
else:
_metadata["raw_request"] = str(curl_command)
except Exception as e:
_metadata["raw_request"] = (
"Unable to Log \
raw request: {}".format(
str(e)
)
)
if self.logger_fn and callable(self.logger_fn):
try:
self.logger_fn(
@ -1626,7 +1624,6 @@ class Logging:
end_time=end_time,
)
except Exception as e:
complete_streaming_response = None
else:
self.sync_streaming_chunks.append(result)
@ -2396,7 +2393,6 @@ class Logging:
"async_complete_streaming_response"
in self.model_call_details
):
await customLogger.async_log_event(
kwargs=self.model_call_details,
response_obj=self.model_call_details[
@ -2735,7 +2731,7 @@ class Logging:
only redacts when litellm.turn_off_message_logging == True
"""
# check if user opted out of logging message/response to callbacks
if litellm.turn_off_message_logging == True:
if litellm.turn_off_message_logging is True:
# remove messages, prompts, input, response from logging
self.model_call_details["messages"] = [
{"role": "user", "content": "redacted-by-litellm"}
@ -6171,13 +6167,16 @@ def get_api_base(
if litellm.model_alias_map and model in litellm.model_alias_map:
model = litellm.model_alias_map[model]
try:
model, custom_llm_provider, dynamic_api_key, dynamic_api_base = (
get_llm_provider(
model=model,
custom_llm_provider=_optional_params.custom_llm_provider,
api_base=_optional_params.api_base,
api_key=_optional_params.api_key,
)
(
model,
custom_llm_provider,
dynamic_api_key,
dynamic_api_base,
) = get_llm_provider(
model=model,
custom_llm_provider=_optional_params.custom_llm_provider,
api_base=_optional_params.api_base,
api_key=_optional_params.api_key,
)
except Exception as e:
verbose_logger.debug("Error occurred in getting api base - {}".format(str(e)))
@ -6600,6 +6599,9 @@ def get_llm_provider(
or get_secret("MISTRAL_AZURE_API_KEY") # for Azure AI Mistral
or get_secret("MISTRAL_API_KEY")
)
elif custom_llm_provider == "azure_ai":
api_base = api_base or get_secret("AZURE_AI_API_BASE") # type: ignore
dynamic_api_key = get_secret("AZURE_AI_API_KEY")
elif custom_llm_provider == "voyage":
# voyage is openai compatible, we just need to set this to custom_openai and have the api_base be https://api.voyageai.com/v1
api_base = "https://api.voyageai.com/v1"
@ -6612,12 +6614,6 @@ def get_llm_provider(
or get_secret("TOGETHERAI_API_KEY")
or get_secret("TOGETHER_AI_TOKEN")
)
elif custom_llm_provider == "azure_ai":
api_base = (
api_base
or get_secret("AZURE_AI_API_BASE") # for Azure AI Mistral
) # type: ignore
dynamic_api_key = get_secret("AZURE_AI_API_KEY")
if api_base is not None and not isinstance(api_base, str):
raise Exception(
"api base needs to be a string. api_base={}".format(api_base)
@ -7459,7 +7455,6 @@ def validate_environment(model: Optional[str] = None) -> dict:
def set_callbacks(callback_list, function_id=None):
global sentry_sdk_instance, capture_exception, add_breadcrumb, posthog, slack_app, alerts_channel, traceloopLogger, athinaLogger, heliconeLogger, aispendLogger, berrispendLogger, supabaseClient, liteDebuggerClient, lunaryLogger, promptLayerLogger, langFuseLogger, customLogger, weightsBiasesLogger, langsmithLogger, logfireLogger, dynamoLogger, s3Logger, dataDogLogger, prometheusLogger, greenscaleLogger, openMeterLogger
try: