From 738b7621dda9f3bc29d858a285c75445f16faba9 Mon Sep 17 00:00:00 2001
From: Krish Dholakia
Date: Sat, 19 Apr 2025 22:27:49 -0700
Subject: [PATCH] fix(proxy_server.py): pass llm router to get complete model list (#10176)

allows model auth to work
---
 litellm/proxy/proxy_server.py | 27 +++++++++++++++------------
 litellm/router.py             | 32 ++++++++++++++++----------------
 2 files changed, 31 insertions(+), 28 deletions(-)

diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py
index 7523f4f87d..50662e69d5 100644
--- a/litellm/proxy/proxy_server.py
+++ b/litellm/proxy/proxy_server.py
@@ -804,9 +804,9 @@ model_max_budget_limiter = _PROXY_VirtualKeyModelMaxBudgetLimiter(
     dual_cache=user_api_key_cache
 )
 litellm.logging_callback_manager.add_litellm_callback(model_max_budget_limiter)
-redis_usage_cache: Optional[RedisCache] = (
-    None  # redis cache used for tracking spend, tpm/rpm limits
-)
+redis_usage_cache: Optional[
+    RedisCache
+] = None  # redis cache used for tracking spend, tpm/rpm limits
 user_custom_auth = None
 user_custom_key_generate = None
 user_custom_sso = None
@@ -1132,9 +1132,9 @@ async def update_cache(  # noqa: PLR0915
         _id = "team_id:{}".format(team_id)
         try:
             # Fetch the existing cost for the given user
-            existing_spend_obj: Optional[LiteLLM_TeamTable] = (
-                await user_api_key_cache.async_get_cache(key=_id)
-            )
+            existing_spend_obj: Optional[
+                LiteLLM_TeamTable
+            ] = await user_api_key_cache.async_get_cache(key=_id)
             if existing_spend_obj is None:
                 # do nothing if team not in api key cache
                 return
@@ -2806,9 +2806,9 @@ async def initialize(  # noqa: PLR0915
         user_api_base = api_base
         dynamic_config[user_model]["api_base"] = api_base
     if api_version:
-        os.environ["AZURE_API_VERSION"] = (
-            api_version  # set this for azure - litellm can read this from the env
-        )
+        os.environ[
+            "AZURE_API_VERSION"
+        ] = api_version  # set this for azure - litellm can read this from the env
     if max_tokens:  # model-specific param
         dynamic_config[user_model]["max_tokens"] = max_tokens
     if temperature:  # model-specific param
@@ -6160,6 +6160,7 @@ async def model_info_v1(  # noqa: PLR0915
         proxy_model_list=proxy_model_list,
         user_model=user_model,
         infer_model_from_keys=general_settings.get("infer_model_from_keys", False),
+        llm_router=llm_router,
     )
 
     if len(all_models_str) > 0:
@@ -6184,6 +6185,7 @@ def _get_model_group_info(
     llm_router: Router, all_models_str: List[str], model_group: Optional[str]
 ) -> List[ModelGroupInfo]:
     model_groups: List[ModelGroupInfo] = []
+
     for model in all_models_str:
         if model_group is not None and model_group != model:
             continue
@@ -6393,6 +6395,7 @@ async def model_group_info(
         proxy_model_list=proxy_model_list,
         user_model=user_model,
         infer_model_from_keys=general_settings.get("infer_model_from_keys", False),
+        llm_router=llm_router,
     )
     model_groups: List[ModelGroupInfo] = _get_model_group_info(
         llm_router=llm_router, all_models_str=all_models_str, model_group=model_group
@@ -7755,9 +7758,9 @@ async def get_config_list(
                         hasattr(sub_field_info, "description")
                         and sub_field_info.description is not None
                     ):
-                        nested_fields[idx].field_description = (
-                            sub_field_info.description
-                        )
+                        nested_fields[
+                            idx
+                        ].field_description = sub_field_info.description
                     idx += 1
 
             _stored_in_db = None
diff --git a/litellm/router.py b/litellm/router.py
index 92174adee3..bf1bca0a75 100644
--- a/litellm/router.py
+++ b/litellm/router.py
@@ -339,9 +339,9 @@ class Router:
         )  # names of models under litellm_params. ex. azure/chatgpt-v-2
         self.deployment_latency_map = {}
         ### CACHING ###
-        cache_type: Literal["local", "redis", "redis-semantic", "s3", "disk"] = (
-            "local"  # default to an in-memory cache
-        )
+        cache_type: Literal[
+            "local", "redis", "redis-semantic", "s3", "disk"
+        ] = "local"  # default to an in-memory cache
         redis_cache = None
         cache_config: Dict[str, Any] = {}
 
@@ -562,9 +562,9 @@ class Router:
                 )
             )
 
-        self.model_group_retry_policy: Optional[Dict[str, RetryPolicy]] = (
-            model_group_retry_policy
-        )
+        self.model_group_retry_policy: Optional[
+            Dict[str, RetryPolicy]
+        ] = model_group_retry_policy
 
         self.allowed_fails_policy: Optional[AllowedFailsPolicy] = None
         if allowed_fails_policy is not None:
@@ -3247,11 +3247,11 @@ class Router:
 
                 if isinstance(e, litellm.ContextWindowExceededError):
                     if context_window_fallbacks is not None:
-                        fallback_model_group: Optional[List[str]] = (
-                            self._get_fallback_model_group_from_fallbacks(
-                                fallbacks=context_window_fallbacks,
-                                model_group=model_group,
-                            )
+                        fallback_model_group: Optional[
+                            List[str]
+                        ] = self._get_fallback_model_group_from_fallbacks(
+                            fallbacks=context_window_fallbacks,
+                            model_group=model_group,
                         )
                         if fallback_model_group is None:
                             raise original_exception
@@ -3283,11 +3283,11 @@ class Router:
                         e.message += "\n{}".format(error_message)
                 elif isinstance(e, litellm.ContentPolicyViolationError):
                     if content_policy_fallbacks is not None:
-                        fallback_model_group: Optional[List[str]] = (
-                            self._get_fallback_model_group_from_fallbacks(
-                                fallbacks=content_policy_fallbacks,
-                                model_group=model_group,
-                            )
+                        fallback_model_group: Optional[
+                            List[str]
+                        ] = self._get_fallback_model_group_from_fallbacks(
+                            fallbacks=content_policy_fallbacks,
+                            model_group=model_group,
                         )
                         if fallback_model_group is None:
                             raise original_exception