Litellm ruff linting enforcement (#5992)

* ci(config.yml): add a 'check_code_quality' step

Addresses https://github.com/BerriAI/litellm/issues/5991

* ci(config.yml): check why circle ci doesn't pick up this test

* ci(config.yml): fix to run 'check_code_quality' tests

* fix(__init__.py): fix unprotected import (see the import-guard sketch after this list)

* fix(__init__.py): don't remove unused imports

* build(ruff.toml): update ruff.toml to ignore unused imports

* fix: ruff + pyright - fix linting + type-checking errors

* fix: fix linting errors

* fix(lago.py): fix module init error

* fix: fix linting errors

* ci(config.yml): cd into correct dir for checks

* fix(proxy_server.py): fix linting error

* fix(utils.py): fix bare except

a bare `except:` clause causes ruff linting errors (see the lint-fix sketch after this list)

* fix: ruff - fix remaining linting errors

* fix(clickhouse.py): use standard logging object (see the logging sketch after this list)

* fix(__init__.py): fix unprotected import

* fix: ruff - fix linting errors

* fix: fix linting errors

* ci(config.yml): cleanup code qa step (formatting handled in local_testing)

* fix(_health_endpoints.py): fix ruff linting errors

* ci(config.yml): just use ruff in check_code_quality pipeline for now

* build(custom_guardrail.py): include missing file

* style(embedding_handler.py): fix ruff check
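
The two "fix unprotected import" commits refer to guarding imports of optional dependencies so that a missing package does not break `import litellm`. A minimal sketch of that general pattern; the dependency name below is hypothetical and not taken from this commit:

# Hypothetical sketch of an import guard; `orjson` stands in for whatever
# optional dependency the commit actually protects.
try:
    import orjson as _fast_json
except ImportError:  # ruff flags a bare `except:` here as E722
    _fast_json = None

def dumps(obj) -> str:
    # Fall back to the standard library when the optional package is missing.
    if _fast_json is not None:
        return _fast_json.dumps(obj).decode()
    import json
    return json.dumps(obj)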
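
Most of the diff below is mechanical: unused assignments such as `timeout = kwargs.get("request_timeout", self.timeout)` lose their left-hand side, bare `except:` clauses become `except Exception:`, `== True`/`== False` comparisons become `is True`/`is False`, and `type(x) == dict` becomes `isinstance(x, dict)`. A condensed before/after sketch of those patterns (the function and variable names are invented for illustration and do not appear in the repository):

def lookup(cache, key: str, stream: bool):
    # was: `timeout = compute_timeout()` with `timeout` never read (ruff F841);
    # in the diff the call is kept and only the unused assignment is dropped
    if stream is True:  # was: `stream == True` (ruff E712)
        key = f"{key}_stream"
    if not isinstance(cache, dict):  # was: `type(cache) == dict` (ruff E721)
        return None
    try:
        return cache[key]
    except Exception:  # was: bare `except:` (ruff E722)
        return None

The ruff.toml change that "ignores unused imports" presumably turns off rule F401 (unused import) so that re-exports in `__init__.py` are not flagged.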
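
For the clickhouse.py change, "use standard logging object" means routing diagnostic output through a logger rather than ad-hoc prints. Whether it uses the standard library directly or litellm's own verbose logger is not shown here, so this sketch assumes plain `logging`:

import logging

# Module-level logger; the actual commit may use litellm's internal verbose
# logger instead of the standard library directly.
logger = logging.getLogger(__name__)

def flush_events(events: list) -> None:
    logger.debug("flushing %d events to ClickHouse", len(events))
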
Krish Dholakia, 2024-10-01 16:44:20 -07:00, committed by GitHub
parent 4fa8991a90
commit 94a05ca5d0
263 changed files with 1687 additions and 3320 deletions


@@ -32,6 +32,7 @@ from openai import AsyncOpenAI
from typing_extensions import overload
import litellm
from litellm import get_secret_str
from litellm._logging import verbose_router_logger
from litellm.assistants.main import AssistantDeleted
from litellm.caching import DualCache, InMemoryCache, RedisCache
@@ -599,7 +600,7 @@ class Router:
kwargs["model"] = model
kwargs["messages"] = messages
kwargs["original_function"] = self._completion
timeout = kwargs.get("request_timeout", self.timeout)
kwargs.get("request_timeout", self.timeout)
kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
kwargs.setdefault("metadata", {}).update({"model_group": model})
response = self.function_with_fallbacks(**kwargs)
@@ -1194,7 +1195,7 @@ class Router:
kwargs["prompt"] = prompt
kwargs["original_function"] = self._image_generation
kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
timeout = kwargs.get("request_timeout", self.timeout)
kwargs.get("request_timeout", self.timeout)
kwargs.setdefault("metadata", {}).update({"model_group": model})
response = self.function_with_fallbacks(**kwargs)
@@ -1277,7 +1278,7 @@ class Router:
kwargs["prompt"] = prompt
kwargs["original_function"] = self._aimage_generation
kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
timeout = kwargs.get("request_timeout", self.timeout)
kwargs.get("request_timeout", self.timeout)
kwargs.setdefault("metadata", {}).update({"model_group": model})
response = await self.async_function_with_fallbacks(**kwargs)
@@ -1410,7 +1411,7 @@ class Router:
kwargs["file"] = file
kwargs["original_function"] = self._atranscription
kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
timeout = kwargs.get("request_timeout", self.timeout)
kwargs.get("request_timeout", self.timeout)
kwargs.setdefault("metadata", {}).update({"model_group": model})
response = await self.async_function_with_fallbacks(**kwargs)
@@ -1564,7 +1565,7 @@ class Router:
)
kwargs["model_info"] = deployment.get("model_info", {})
data = deployment["litellm_params"].copy()
model_name = data["model"]
data["model"]
for k, v in self.default_litellm_params.items():
if (
k not in kwargs
@@ -1583,9 +1584,9 @@ class Router:
and potential_model_client is not None
and dynamic_api_key != potential_model_client.api_key
):
model_client = None
pass
else:
model_client = potential_model_client
pass
response = await litellm.aspeech(**data, **kwargs)
@@ -1607,7 +1608,7 @@ class Router:
kwargs["input"] = input
kwargs["original_function"] = self._amoderation
kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
timeout = kwargs.get("request_timeout", self.timeout)
kwargs.get("request_timeout", self.timeout)
kwargs.setdefault("metadata", {}).update({"model_group": model})
response = await self.async_function_with_fallbacks(**kwargs)
@@ -1707,7 +1708,7 @@ class Router:
kwargs["input"] = input
kwargs["original_function"] = self._arerank
kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
timeout = kwargs.get("request_timeout", self.timeout)
kwargs.get("request_timeout", self.timeout)
kwargs.setdefault("metadata", {}).update({"model_group": model})
response = await self.async_function_with_fallbacks(**kwargs)
@@ -1814,7 +1815,7 @@ class Router:
kwargs["prompt"] = prompt
kwargs["original_function"] = self._acompletion
kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
timeout = kwargs.get("request_timeout", self.timeout)
kwargs.get("request_timeout", self.timeout)
kwargs.setdefault("metadata", {}).update({"model_group": model})
# pick the one that is available (lowest TPM/RPM)
@@ -1858,7 +1859,7 @@ class Router:
kwargs["prompt"] = prompt
kwargs["original_function"] = self._atext_completion
kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
timeout = kwargs.get("request_timeout", self.timeout)
kwargs.get("request_timeout", self.timeout)
kwargs.setdefault("metadata", {}).update({"model_group": model})
response = await self.async_function_with_fallbacks(**kwargs)
@@ -1977,7 +1978,7 @@ class Router:
kwargs["adapter_id"] = adapter_id
kwargs["original_function"] = self._aadapter_completion
kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
timeout = kwargs.get("request_timeout", self.timeout)
kwargs.get("request_timeout", self.timeout)
kwargs.setdefault("metadata", {}).update({"model_group": model})
response = await self.async_function_with_fallbacks(**kwargs)
@@ -2094,7 +2095,7 @@ class Router:
kwargs["input"] = input
kwargs["original_function"] = self._embedding
kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
timeout = kwargs.get("request_timeout", self.timeout)
kwargs.get("request_timeout", self.timeout)
kwargs.setdefault("metadata", {}).update({"model_group": model})
response = self.function_with_fallbacks(**kwargs)
return response
@@ -2182,7 +2183,7 @@ class Router:
kwargs["input"] = input
kwargs["original_function"] = self._aembedding
kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
timeout = kwargs.get("request_timeout", self.timeout)
kwargs.get("request_timeout", self.timeout)
kwargs.setdefault("metadata", {}).update({"model_group": model})
response = await self.async_function_with_fallbacks(**kwargs)
return response
@@ -2298,7 +2299,7 @@ class Router:
kwargs["model"] = model
kwargs["original_function"] = self._acreate_file
kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
timeout = kwargs.get("request_timeout", self.timeout)
kwargs.get("request_timeout", self.timeout)
kwargs.setdefault("metadata", {}).update({"model_group": model})
response = await self.async_function_with_fallbacks(**kwargs)
@@ -2424,7 +2425,7 @@ class Router:
kwargs["model"] = model
kwargs["original_function"] = self._acreate_batch
kwargs["num_retries"] = kwargs.get("num_retries", self.num_retries)
timeout = kwargs.get("request_timeout", self.timeout)
kwargs.get("request_timeout", self.timeout)
kwargs.setdefault("metadata", {}).update({"model_group": model})
response = await self.async_function_with_fallbacks(**kwargs)
@@ -2624,7 +2625,7 @@ class Router:
return await litellm.alist_batches(
**{**model["litellm_params"], **kwargs}
)
except Exception as e:
except Exception:
return None
# Check all models in parallel
@@ -3265,7 +3266,7 @@ class Router:
)
try:
if mock_testing_fallbacks is not None and mock_testing_fallbacks == True:
if mock_testing_fallbacks is not None and mock_testing_fallbacks is True:
raise Exception(
f"This is a mock exception for model={model_group}, to trigger a fallback. Fallbacks={fallbacks}"
)
@@ -3615,25 +3616,12 @@ class Router:
):
try:
exception = kwargs.get("exception", None)
exception_type = type(exception)
exception_status = getattr(exception, "status_code", "")
exception_cause = getattr(exception, "__cause__", "")
exception_message = getattr(exception, "message", "")
exception_str = (
str(exception_type)
+ "Status: "
+ str(exception_status)
+ "Message: "
+ str(exception_cause)
+ str(exception_message)
+ "Full exception"
+ str(exception)
)
model_name = kwargs.get("model", None) # i.e. gpt35turbo
custom_llm_provider = kwargs.get("litellm_params", {}).get(
"custom_llm_provider", None
) # i.e. azure
metadata = kwargs.get("litellm_params", {}).get("metadata", None)
kwargs.get("litellm_params", {}).get("metadata", None)
_model_info = kwargs.get("litellm_params", {}).get("model_info", {})
exception_headers = litellm.utils._get_litellm_response_headers(
@@ -3782,7 +3770,7 @@ class Router:
# should cool down for all other errors
return True
except:
except Exception:
# Catch all - if any exceptions default to cooling down
return True
@@ -3826,9 +3814,9 @@ class Router:
_, _all_deployments = self._common_checks_available_deployment( # type: ignore
model=model,
)
if type(_all_deployments) == dict:
if isinstance(_all_deployments, dict):
return []
except:
except Exception:
pass
unhealthy_deployments = _get_cooldown_deployments(litellm_router_instance=self)
@@ -3855,7 +3843,7 @@ class Router:
_, _all_deployments = self._common_checks_available_deployment( # type: ignore
model=model,
)
if type(_all_deployments) == dict:
if isinstance(_all_deployments, dict):
return [], _all_deployments
except Exception:
pass
@@ -3885,7 +3873,7 @@ class Router:
"""
for _callback in litellm.callbacks:
if isinstance(_callback, CustomLogger):
response = _callback.pre_call_check(deployment)
_callback.pre_call_check(deployment)
async def async_routing_strategy_pre_call_checks(
self, deployment: dict, logging_obj: Optional[LiteLLMLogging] = None
@@ -4022,10 +4010,10 @@ class Router:
- ValueError: If LITELLM_ENVIRONMENT is not set in .env or not one of the valid values
- ValueError: If supported_environments is not set in model_info or not one of the valid values
"""
litellm_environment = litellm.get_secret_str(secret_name="LITELLM_ENVIRONMENT")
litellm_environment = get_secret_str(secret_name="LITELLM_ENVIRONMENT")
if litellm_environment is None:
raise ValueError(
f"Set 'supported_environments' for model but not 'LITELLM_ENVIRONMENT' set in .env"
"Set 'supported_environments' for model but not 'LITELLM_ENVIRONMENT' set in .env"
)
if litellm_environment not in VALID_LITELLM_ENVIRONMENTS:
@@ -4172,7 +4160,7 @@ class Router:
)
# set region (if azure model) ## PREVIEW FEATURE ##
if litellm.enable_preview_features == True:
if litellm.enable_preview_features is True:
print("Auto inferring region") # noqa
"""
Hiding behind a feature flag
@@ -4277,7 +4265,7 @@ class Router:
return item
else:
return None
except:
except Exception:
return None
def get_deployment(self, model_id: str) -> Optional[Deployment]:
@@ -4870,7 +4858,7 @@ class Router:
client = self.cache.get_cache(key=cache_key, local_only=True)
return client
elif client_type == "async":
if kwargs.get("stream") == True:
if kwargs.get("stream") is True:
cache_key = f"{model_id}_stream_async_client"
client = self.cache.get_cache(key=cache_key, local_only=True)
if client is None:
@@ -4891,7 +4879,7 @@ class Router:
client = self.cache.get_cache(key=cache_key, local_only=True)
return client
else:
if kwargs.get("stream") == True:
if kwargs.get("stream") is True:
cache_key = f"{model_id}_stream_client"
client = self.cache.get_cache(key=cache_key)
if client is None:
@@ -5031,7 +5019,7 @@ class Router:
# check if in allowed_model_region
if (
_is_region_eu(litellm_params=LiteLLM_Params(**_litellm_params))
== False
is False
):
invalid_model_indices.append(idx)
continue
@@ -5046,7 +5034,7 @@ class Router:
continue
## INVALID PARAMS ## -> catch 'gpt-3.5-turbo-16k' not supporting 'response_format' param
if request_kwargs is not None and litellm.drop_params == False:
if request_kwargs is not None and litellm.drop_params is False:
# get supported params
model, custom_llm_provider, _, _ = litellm.get_llm_provider(
model=model, litellm_params=LiteLLM_Params(**_litellm_params)
@@ -5170,7 +5158,7 @@ class Router:
dep["litellm_params"]["model"] = model
provider_deployments.append(dep)
return model, provider_deployments
except:
except Exception:
# get_llm_provider raises exception when provider is unknown
pass
@@ -5660,7 +5648,7 @@ class Router:
"num_successes": 1,
"avg_latency": response_ms,
}
if self.set_verbose == True and self.debug_level == "DEBUG":
if self.set_verbose is True and self.debug_level == "DEBUG":
from pprint import pformat
# Assuming self.deployment_stats is your dictionary