build(pyproject.toml): add new dev dependencies - for type checking (#9631)

* build(pyproject.toml): add new dev dependencies - for type checking

* build: reformat files to fit black

* ci: reformat to fit black

* ci(test-litellm.yml): make test runs clearer

* build(pyproject.toml): add ruff

* fix: fix ruff checks

* build(mypy/): fix mypy linting errors

* fix(hashicorp_secret_manager.py): fix passing cert for tls auth

* build(mypy/): resolve all mypy errors

* test: update test

* fix: fix black formatting

* build(pre-commit-config.yaml): use poetry run black

* fix(proxy_server.py): fix linting error

* fix: fix ruff safe representation error
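
The hashicorp_secret_manager.py fix in the list above is about presenting a client certificate when authenticating to Vault over TLS. As a rough standalone sketch only (it follows Vault's documented cert-auth login endpoint and httpx's cert parameter; the file paths and role name are placeholders, and this is not litellm's actual client code):

import os

import httpx

# Placeholder values for illustration; the real code reads its own settings.
VAULT_ADDR = os.environ.get("VAULT_ADDR", "https://vault.example.com:8200")
CLIENT_CERT = "/etc/litellm/vault-client.crt"
CLIENT_KEY = "/etc/litellm/vault-client.key"

# httpx accepts a (cert, key) tuple; the certificate is presented during the
# TLS handshake, which is what Vault's cert auth method verifies.
client = httpx.Client(cert=(CLIENT_CERT, CLIENT_KEY))
resp = client.post(f"{VAULT_ADDR}/v1/auth/cert/login", json={"name": "litellm"})
resp.raise_for_status()
vault_token = resp.json()["auth"]["client_token"]
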
Krish Dholakia authored on 2025-03-29 11:02:13 -07:00, committed by GitHub
parent 72198737f8
commit d7b294dd0a
214 changed files with 1553 additions and 1433 deletions


@@ -148,7 +148,7 @@ from .router_utils.pattern_match_deployments import PatternMatchRouter
if TYPE_CHECKING:
from opentelemetry.trace import Span as _Span
Span = _Span
Span = Union[_Span, Any]
else:
Span = Any
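
This first hunk is one of the mypy fixes rather than a black reformat: the OpenTelemetry Span is imported only for type checkers, and widening the alias to Union[_Span, Any] keeps annotations permissive for callers that only hold an untyped span object at runtime. A minimal standalone sketch of the pattern (record_attempt is a made-up helper, not part of litellm):

from typing import TYPE_CHECKING, Any, Optional, Union

if TYPE_CHECKING:
    from opentelemetry.trace import Span as _Span

    Span = Union[_Span, Any]
else:
    Span = Any


def record_attempt(span: Optional[Span], attempt: int) -> None:
    # set_attribute exists on real OpenTelemetry spans; with the alias above,
    # mypy also accepts callers that pass an untyped object.
    if span is not None:
        span.set_attribute("retry.attempt", attempt)
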
@@ -333,9 +333,9 @@ class Router:
) # names of models under litellm_params. ex. azure/chatgpt-v-2
self.deployment_latency_map = {}
### CACHING ###
cache_type: Literal["local", "redis", "redis-semantic", "s3", "disk"] = (
"local" # default to an in-memory cache
)
cache_type: Literal[
"local", "redis", "redis-semantic", "s3", "disk"
] = "local" # default to an in-memory cache
redis_cache = None
cache_config: Dict[str, Any] = {}
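
This hunk, like most of the remaining ones, is a pure black reformat: the same annotated assignment with the value wrapped in parentheses versus the Literal subscript split across lines, with no behavioral change. A throwaway check (identifier and literal values are arbitrary) that the two spellings parse to the same AST:

import ast

old = 'cache_type: Literal["local", "redis"] = (\n    "local"\n)\n'
new = 'cache_type: Literal[\n    "local", "redis"\n] = "local"\n'

# ast.dump omits line and column info by default, so cosmetic differences
# disappear and the two parse trees compare equal.
print(ast.dump(ast.parse(old)) == ast.dump(ast.parse(new)))  # True
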
@@ -556,9 +556,9 @@ class Router:
)
)
self.model_group_retry_policy: Optional[Dict[str, RetryPolicy]] = (
model_group_retry_policy
)
self.model_group_retry_policy: Optional[
Dict[str, RetryPolicy]
] = model_group_retry_policy
self.allowed_fails_policy: Optional[AllowedFailsPolicy] = None
if allowed_fails_policy is not None:
@@ -1093,9 +1093,9 @@ class Router:
"""
Adds default litellm params to kwargs, if set.
"""
self.default_litellm_params[metadata_variable_name] = (
self.default_litellm_params.pop("metadata", {})
)
self.default_litellm_params[
metadata_variable_name
] = self.default_litellm_params.pop("metadata", {})
for k, v in self.default_litellm_params.items():
if (
k not in kwargs and v is not None
@@ -1678,14 +1678,16 @@ class Router:
f"Prompt variables is set but not a dictionary. Got={prompt_variables}, type={type(prompt_variables)}"
)
model, messages, optional_params = (
litellm_logging_object.get_chat_completion_prompt(
model=litellm_model,
messages=messages,
non_default_params=get_non_default_completion_params(kwargs=kwargs),
prompt_id=prompt_id,
prompt_variables=prompt_variables,
)
(
model,
messages,
optional_params,
) = litellm_logging_object.get_chat_completion_prompt(
model=litellm_model,
messages=messages,
non_default_params=get_non_default_completion_params(kwargs=kwargs),
prompt_id=prompt_id,
prompt_variables=prompt_variables,
)
kwargs = {**kwargs, **optional_params}
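
get_chat_completion_prompt resolves prompt_id and substitutes prompt_variables (which, per the check above, must be a dictionary) before the request is sent. Purely as an illustration of the substitution step (this is not litellm's prompt-management code, which also resolves prompt_id against a configured prompt store), filling variables into message templates can look like:

from typing import Dict, List


def render_messages(
    messages: List[Dict[str, str]], prompt_variables: Dict[str, str]
) -> List[Dict[str, str]]:
    if not isinstance(prompt_variables, dict):
        raise ValueError(
            f"Prompt variables is set but not a dictionary. Got={prompt_variables}"
        )
    # Fill {placeholders} in each message's content from the variables dict.
    return [
        {**m, "content": m["content"].format(**prompt_variables)} for m in messages
    ]


print(render_messages([{"role": "user", "content": "Hi {name}"}], {"name": "Ada"}))
# [{'role': 'user', 'content': 'Hi Ada'}]
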
@@ -2924,7 +2926,6 @@ class Router:
Future Improvement - cache the result.
"""
try:
filtered_model_list = self.get_model_list()
if filtered_model_list is None:
raise Exception("Router not yet initialized.")
@@ -3211,11 +3212,11 @@ class Router:
if isinstance(e, litellm.ContextWindowExceededError):
if context_window_fallbacks is not None:
fallback_model_group: Optional[List[str]] = (
self._get_fallback_model_group_from_fallbacks(
fallbacks=context_window_fallbacks,
model_group=model_group,
)
fallback_model_group: Optional[
List[str]
] = self._get_fallback_model_group_from_fallbacks(
fallbacks=context_window_fallbacks,
model_group=model_group,
)
if fallback_model_group is None:
raise original_exception
@@ -3247,11 +3248,11 @@ class Router:
e.message += "\n{}".format(error_message)
elif isinstance(e, litellm.ContentPolicyViolationError):
if content_policy_fallbacks is not None:
fallback_model_group: Optional[List[str]] = (
self._get_fallback_model_group_from_fallbacks(
fallbacks=content_policy_fallbacks,
model_group=model_group,
)
fallback_model_group: Optional[
List[str]
] = self._get_fallback_model_group_from_fallbacks(
fallbacks=content_policy_fallbacks,
model_group=model_group,
)
if fallback_model_group is None:
raise original_exception
@@ -3282,11 +3283,12 @@ class Router:
e.message += "\n{}".format(error_message)
if fallbacks is not None and model_group is not None:
verbose_router_logger.debug(f"inside model fallbacks: {fallbacks}")
fallback_model_group, generic_fallback_idx = (
get_fallback_model_group(
fallbacks=fallbacks, # if fallbacks = [{"gpt-3.5-turbo": ["claude-3-haiku"]}]
model_group=cast(str, model_group),
)
(
fallback_model_group,
generic_fallback_idx,
) = get_fallback_model_group(
fallbacks=fallbacks, # if fallbacks = [{"gpt-3.5-turbo": ["claude-3-haiku"]}]
model_group=cast(str, model_group),
)
## if none, check for generic fallback
if (
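
The inline comment in this hunk documents the fallback shape: a list of single-key dicts mapping a model group to its ordered fallback groups, with generic_fallback_idx remembering a catch-all entry. A simplified standalone lookup over that shape (not litellm's get_fallback_model_group; it assumes a "*" key marks the generic entry):

from typing import Dict, List, Optional, Tuple


def lookup_fallbacks(
    fallbacks: List[Dict[str, List[str]]], model_group: str
) -> Tuple[Optional[List[str]], Optional[int]]:
    generic_fallback_idx: Optional[int] = None
    for idx, entry in enumerate(fallbacks):
        if model_group in entry:
            return entry[model_group], None
        if "*" in entry:
            generic_fallback_idx = idx
    # No exact match; the caller may fall back to the generic entry, if any.
    return None, generic_fallback_idx


print(lookup_fallbacks([{"gpt-3.5-turbo": ["claude-3-haiku"]}], "gpt-3.5-turbo"))
# (['claude-3-haiku'], None)
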
@@ -3444,11 +3446,12 @@ class Router:
"""
Retry Logic
"""
_healthy_deployments, _all_deployments = (
await self._async_get_healthy_deployments(
model=kwargs.get("model") or "",
parent_otel_span=parent_otel_span,
)
(
_healthy_deployments,
_all_deployments,
) = await self._async_get_healthy_deployments(
model=kwargs.get("model") or "",
parent_otel_span=parent_otel_span,
)
# raises an exception if this error should not be retried
@@ -3513,11 +3516,12 @@ class Router:
remaining_retries = num_retries - current_attempt
_model: Optional[str] = kwargs.get("model") # type: ignore
if _model is not None:
_healthy_deployments, _ = (
await self._async_get_healthy_deployments(
model=_model,
parent_otel_span=parent_otel_span,
)
(
_healthy_deployments,
_,
) = await self._async_get_healthy_deployments(
model=_model,
parent_otel_span=parent_otel_span,
)
else:
_healthy_deployments = []
@@ -3884,7 +3888,6 @@ class Router:
)
if exception_headers is not None:
_time_to_cooldown = (
litellm.utils._get_retry_after_from_exception_header(
response_headers=exception_headers
@@ -6131,7 +6134,6 @@ class Router:
try:
model_id = deployment.get("model_info", {}).get("id", None)
if response is None:
# update self.deployment_stats
if model_id is not None:
self._update_usage(