mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-25 18:54:30 +00:00
fix(user_api_key_auth.py): ensure user has access to fallback models
for client side fallbacks, checks if user has access to fallback models
This commit is contained in:
parent
14da2d5ade
commit
5729eb5168
4 changed files with 150 additions and 53 deletions
|
@ -47,6 +47,7 @@ from litellm._logging import verbose_logger, verbose_proxy_logger
|
|||
from litellm.proxy._types import *
|
||||
from litellm.proxy.auth.auth_checks import (
|
||||
allowed_routes_check,
|
||||
can_key_call_model,
|
||||
common_checks,
|
||||
get_actual_routes,
|
||||
get_end_user_object,
|
||||
|
@ -494,6 +495,7 @@ async def user_api_key_auth(
|
|||
# Got Valid Token from Cache, DB
|
||||
# Run checks for
|
||||
# 1. If token can call model
|
||||
## 1a. If token can call fallback models (if client-side fallbacks given)
|
||||
# 2. If user_id for this token is in budget
|
||||
# 3. If the user spend within their own team is within budget
|
||||
# 4. If 'user' passed to /chat/completions, /embeddings endpoint is in budget
|
||||
|
@ -540,55 +542,22 @@ async def user_api_key_auth(
|
|||
except json.JSONDecodeError:
|
||||
data = {} # Provide a default value, such as an empty dictionary
|
||||
model = data.get("model", None)
|
||||
if model in litellm.model_alias_map:
|
||||
model = litellm.model_alias_map[model]
|
||||
fallback_models: Optional[List[str]] = data.get("fallbacks", None)
|
||||
|
||||
## check if model in allowed model names
|
||||
verbose_proxy_logger.debug(
|
||||
f"LLM Model List pre access group check: {llm_model_list}"
|
||||
)
|
||||
from collections import defaultdict
|
||||
|
||||
access_groups = defaultdict(list)
|
||||
if llm_model_list is not None:
|
||||
for m in llm_model_list:
|
||||
for group in m.get("model_info", {}).get("access_groups", []):
|
||||
model_name = m["model_name"]
|
||||
access_groups[group].append(model_name)
|
||||
|
||||
models_in_current_access_groups = []
|
||||
if (
|
||||
len(access_groups) > 0
|
||||
): # check if token contains any model access groups
|
||||
for idx, m in enumerate(
|
||||
valid_token.models
|
||||
): # loop token models, if any of them are an access group add the access group
|
||||
if m in access_groups:
|
||||
# if it is an access group we need to remove it from valid_token.models
|
||||
models_in_group = access_groups[m]
|
||||
models_in_current_access_groups.extend(models_in_group)
|
||||
|
||||
# Filter out models that are access_groups
|
||||
filtered_models = [
|
||||
m for m in valid_token.models if m not in access_groups
|
||||
]
|
||||
|
||||
filtered_models += models_in_current_access_groups
|
||||
verbose_proxy_logger.debug(
|
||||
f"model: {model}; allowed_models: {filtered_models}"
|
||||
)
|
||||
if (
|
||||
model is not None
|
||||
and model not in filtered_models
|
||||
and "*" not in filtered_models
|
||||
):
|
||||
raise ValueError(
|
||||
f"API Key not allowed to access model. This token can only access models={valid_token.models}. Tried to access {model}"
|
||||
if model is not None:
|
||||
await can_key_call_model(
|
||||
model=model,
|
||||
llm_model_list=llm_model_list,
|
||||
valid_token=valid_token,
|
||||
)
|
||||
valid_token.models = filtered_models
|
||||
verbose_proxy_logger.debug(
|
||||
f"filtered allowed_models: {filtered_models}; valid_token.models: {valid_token.models}"
|
||||
)
|
||||
|
||||
if fallback_models is not None:
|
||||
for m in fallback_models:
|
||||
await can_key_call_model(
|
||||
model=m,
|
||||
llm_model_list=llm_model_list,
|
||||
valid_token=valid_token,
|
||||
)
|
||||
|
||||
# Check 2. If user_id for this token is in budget
|
||||
if valid_token.user_id is not None:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue