Auth checks on invalid fallback models (#7871)

* fix(user_api_key_auth.py): handle clientside fallback model when item in list is dictionary

* fix(auth_checks.py): help user find invalid model names during dev

Ensure fallbacks work in prod

* fix(user_api_key_auth.py): fix linting check

* fix: cleanup unused variables

* fix: fix import

* fix(auth_checks.py): fix auth check
This commit is contained in:
Krish Dholakia 2025-01-19 21:28:10 -08:00 committed by GitHub
parent d6e85f7936
commit c306c2e0fc
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 82 additions and 11 deletions

View file

@@ -10,7 +10,7 @@ Returns a UserAPIKeyAuth object if the API key is valid
import asyncio
import secrets
from datetime import datetime, timezone
from typing import Optional
from typing import Optional, cast
import fastapi
from fastapi import HTTPException, Request, WebSocket, status
@@ -32,6 +32,7 @@ from litellm.proxy.auth.auth_checks import (
get_org_object,
get_team_object,
get_user_object,
is_valid_fallback_model,
)
from litellm.proxy.auth.auth_utils import (
_get_request_ip_address,
@@ -880,10 +881,6 @@ async def _user_api_key_auth_builder( # noqa: PLR0915
verbose_proxy_logger.debug(
f"\n new llm router model list {new_model_list}"
)
if (
len(valid_token.models) == 0
): # assume an empty model list means all models are allowed to be called
pass
elif (
isinstance(valid_token.models, list)
and "all-team-models" in valid_token.models
@@ -893,8 +890,9 @@ async def _user_api_key_auth_builder( # noqa: PLR0915
pass
else:
model = request_data.get("model", None)
fallback_models: Optional[List[str]] = request_data.get(
"fallbacks", None
fallback_models = cast(
Optional[List[ALL_FALLBACK_MODEL_VALUES]],
request_data.get("fallbacks", None),
)
if model is not None:
@@ -908,11 +906,16 @@ async def _user_api_key_auth_builder( # noqa: PLR0915
if fallback_models is not None:
for m in fallback_models:
await can_key_call_model(
model=m,
model=m["model"] if isinstance(m, dict) else m,
llm_model_list=llm_model_list,
valid_token=valid_token,
llm_router=llm_router,
)
await is_valid_fallback_model(
model=m["model"] if isinstance(m, dict) else m,
llm_router=llm_router,
user_model=None,
)
# Check 2. If user_id for this token is in budget - done in common_checks()
if valid_token.user_id is not None: