fix(proxy_server.py): pass llm router to get complete model list (#10176)

allows model auth to work
Krish Dholakia 2025-04-19 22:27:49 -07:00 committed by GitHub
parent 5ca589f344
commit 738b7621dd
2 changed files with 31 additions and 28 deletions
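
The only functional change in this commit is threading llm_router into the helper that assembles the proxy's complete model list (the two added llm_router=llm_router, lines in proxy_server.py below); the remaining hunks are formatting-only reflows. As a rough illustration of why the router has to be passed, here is a minimal sketch, not the actual helper: the name build_complete_model_list and the get_model_names() call are assumptions, but the idea is that models defined only as router deployments otherwise never reach the list that model auth checks a key against.

from typing import List, Optional

def build_complete_model_list(
    proxy_model_list: List[str],
    user_model: Optional[str],
    llm_router=None,  # assumed: the proxy's Router instance, as passed at the call sites below
) -> List[str]:
    """Sketch: union of config-declared models, the CLI --model, and router deployments."""
    models = set(proxy_model_list)
    if user_model is not None:
        models.add(user_model)
    if llm_router is not None:
        # Without this step, deployments known only to the router are missing from
        # the list that model-level auth validates a key's allowed models against.
        models.update(llm_router.get_model_names())
    return sorted(models)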

litellm/proxy/proxy_server.py

@@ -804,9 +804,9 @@ model_max_budget_limiter = _PROXY_VirtualKeyModelMaxBudgetLimiter(
     dual_cache=user_api_key_cache
 )
 litellm.logging_callback_manager.add_litellm_callback(model_max_budget_limiter)
-redis_usage_cache: Optional[RedisCache] = (
-    None  # redis cache used for tracking spend, tpm/rpm limits
-)
+redis_usage_cache: Optional[
+    RedisCache
+] = None  # redis cache used for tracking spend, tpm/rpm limits
 user_custom_auth = None
 user_custom_key_generate = None
 user_custom_sso = None
@@ -1132,9 +1132,9 @@ async def update_cache(  # noqa: PLR0915
         _id = "team_id:{}".format(team_id)
         try:
             # Fetch the existing cost for the given user
-            existing_spend_obj: Optional[LiteLLM_TeamTable] = (
-                await user_api_key_cache.async_get_cache(key=_id)
-            )
+            existing_spend_obj: Optional[
+                LiteLLM_TeamTable
+            ] = await user_api_key_cache.async_get_cache(key=_id)
             if existing_spend_obj is None:
                 # do nothing if team not in api key cache
                 return
@@ -2806,9 +2806,9 @@ async def initialize(  # noqa: PLR0915
         user_api_base = api_base
         dynamic_config[user_model]["api_base"] = api_base
     if api_version:
-        os.environ["AZURE_API_VERSION"] = (
-            api_version  # set this for azure - litellm can read this from the env
-        )
+        os.environ[
+            "AZURE_API_VERSION"
+        ] = api_version  # set this for azure - litellm can read this from the env
     if max_tokens:  # model-specific param
         dynamic_config[user_model]["max_tokens"] = max_tokens
     if temperature:  # model-specific param
@@ -6160,6 +6160,7 @@ async def model_info_v1(  # noqa: PLR0915
         proxy_model_list=proxy_model_list,
         user_model=user_model,
         infer_model_from_keys=general_settings.get("infer_model_from_keys", False),
+        llm_router=llm_router,
     )
 
     if len(all_models_str) > 0:
@@ -6184,6 +6185,7 @@ def _get_model_group_info(
     llm_router: Router, all_models_str: List[str], model_group: Optional[str]
 ) -> List[ModelGroupInfo]:
     model_groups: List[ModelGroupInfo] = []
+
     for model in all_models_str:
         if model_group is not None and model_group != model:
             continue
@@ -6393,6 +6395,7 @@ async def model_group_info(
         proxy_model_list=proxy_model_list,
         user_model=user_model,
         infer_model_from_keys=general_settings.get("infer_model_from_keys", False),
+        llm_router=llm_router,
     )
     model_groups: List[ModelGroupInfo] = _get_model_group_info(
         llm_router=llm_router, all_models_str=all_models_str, model_group=model_group
@@ -7755,9 +7758,9 @@ async def get_config_list(
                         hasattr(sub_field_info, "description")
                         and sub_field_info.description is not None
                     ):
-                        nested_fields[idx].field_description = (
-                            sub_field_info.description
-                        )
+                        nested_fields[
+                            idx
+                        ].field_description = sub_field_info.description
                     idx += 1
 
             _stored_in_db = None
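
With the router now passed through, deployments that exist only on the router should appear in the proxy's model listing endpoints. A quick way to sanity-check against a locally running proxy; the URL, port, and key below are placeholders, and the data/model_name response shape is assumed from the /model/info endpoint:

import requests

resp = requests.get(
    "http://localhost:4000/model/info",           # default proxy port; adjust as needed
    headers={"Authorization": "Bearer sk-1234"},  # placeholder virtual key
)
resp.raise_for_status()
for entry in resp.json().get("data", []):
    print(entry.get("model_name"))                # router-only models should now be listed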

litellm/router.py

@@ -339,9 +339,9 @@ class Router:
         )  # names of models under litellm_params. ex. azure/chatgpt-v-2
         self.deployment_latency_map = {}
         ### CACHING ###
-        cache_type: Literal["local", "redis", "redis-semantic", "s3", "disk"] = (
-            "local"  # default to an in-memory cache
-        )
+        cache_type: Literal[
+            "local", "redis", "redis-semantic", "s3", "disk"
+        ] = "local"  # default to an in-memory cache
         redis_cache = None
         cache_config: Dict[str, Any] = {}
@@ -562,9 +562,9 @@
                 )
             )
 
-        self.model_group_retry_policy: Optional[Dict[str, RetryPolicy]] = (
-            model_group_retry_policy
-        )
+        self.model_group_retry_policy: Optional[
+            Dict[str, RetryPolicy]
+        ] = model_group_retry_policy
 
         self.allowed_fails_policy: Optional[AllowedFailsPolicy] = None
         if allowed_fails_policy is not None:
@@ -3247,11 +3247,11 @@
             if isinstance(e, litellm.ContextWindowExceededError):
                 if context_window_fallbacks is not None:
-                    fallback_model_group: Optional[List[str]] = (
-                        self._get_fallback_model_group_from_fallbacks(
-                            fallbacks=context_window_fallbacks,
-                            model_group=model_group,
-                        )
-                    )
+                    fallback_model_group: Optional[
+                        List[str]
+                    ] = self._get_fallback_model_group_from_fallbacks(
+                        fallbacks=context_window_fallbacks,
+                        model_group=model_group,
+                    )
                     if fallback_model_group is None:
                         raise original_exception
@@ -3283,11 +3283,11 @@
                 e.message += "\n{}".format(error_message)
             elif isinstance(e, litellm.ContentPolicyViolationError):
                 if content_policy_fallbacks is not None:
-                    fallback_model_group: Optional[List[str]] = (
-                        self._get_fallback_model_group_from_fallbacks(
-                            fallbacks=content_policy_fallbacks,
-                            model_group=model_group,
-                        )
-                    )
+                    fallback_model_group: Optional[
+                        List[str]
+                    ] = self._get_fallback_model_group_from_fallbacks(
+                        fallbacks=content_policy_fallbacks,
+                        model_group=model_group,
+                    )
                     if fallback_model_group is None:
                         raise original_exception
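
The router.py hunks above only reflow existing code around the cache type annotation, the per-model-group retry policy, and the fallback lookups; behavior is unchanged. For readers unfamiliar with the fallback maps being looked up there, a minimal sketch of how they are typically supplied to the Router (model names, keys, and endpoints are placeholders):

from litellm import Router

router = Router(
    model_list=[
        {
            "model_name": "gpt-3.5-turbo",
            "litellm_params": {
                "model": "azure/chatgpt-v-2",  # ex. from the comment in the first hunk
                "api_key": "os.environ/AZURE_API_KEY",
                "api_base": "https://my-endpoint.openai.azure.com",  # placeholder
                "api_version": "2024-02-01",                         # placeholder
            },
        },
        {
            "model_name": "gpt-4",
            "litellm_params": {"model": "gpt-4", "api_key": "os.environ/OPENAI_API_KEY"},
        },
    ],
    # consulted when a ContextWindowExceededError is raised
    context_window_fallbacks=[{"gpt-3.5-turbo": ["gpt-4"]}],
    # consulted when a ContentPolicyViolationError is raised
    content_policy_fallbacks=[{"gpt-3.5-turbo": ["gpt-4"]}],
)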