mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-25 10:44:24 +00:00
Merge pull request #1483 from BerriAI/litellm_model_access_groups_feature
feat(proxy_server.py): support model access groups
This commit is contained in:
commit
e9ac001005
7 changed files with 76 additions and 32 deletions
|
@ -325,7 +325,28 @@ async def user_api_key_auth(
|
|||
model = data.get("model", None)
|
||||
if model in litellm.model_alias_map:
|
||||
model = litellm.model_alias_map[model]
|
||||
if model and model not in valid_token.models:
|
||||
|
||||
## check if model in allowed model names
|
||||
verbose_proxy_logger.debug(
|
||||
f"LLM Model List pre access group check: {llm_model_list}"
|
||||
)
|
||||
access_groups = []
|
||||
for m in llm_model_list:
|
||||
for group in m.get("model_info", {}).get("access_groups", []):
|
||||
access_groups.append((m["model_name"], group))
|
||||
|
||||
allowed_models = valid_token.models
|
||||
if (
|
||||
len(access_groups) > 0
|
||||
): # check if token contains any model access groups
|
||||
for m in valid_token.models:
|
||||
for model_name, group in access_groups:
|
||||
if m == group:
|
||||
allowed_models.append(model_name)
|
||||
verbose_proxy_logger.debug(
|
||||
f"model: {model}; allowed_models: {allowed_models}"
|
||||
)
|
||||
if model is not None and model not in allowed_models:
|
||||
raise ValueError(
|
||||
f"API Key not allowed to access model. This token can only access models={valid_token.models}. Tried to access {model}"
|
||||
)
|
||||
|
@ -1057,6 +1078,7 @@ async def generate_key_helper_fn(
|
|||
"user_email": user_email,
|
||||
"user_id": user_id,
|
||||
"spend": spend,
|
||||
"models": models,
|
||||
}
|
||||
key_data = {
|
||||
"token": token,
|
||||
|
@ -1070,14 +1092,33 @@ async def generate_key_helper_fn(
|
|||
"metadata": metadata_json,
|
||||
}
|
||||
if prisma_client is not None:
|
||||
verification_token_data = dict(key_data)
|
||||
verification_token_data.update(user_data)
|
||||
verbose_proxy_logger.debug("PrismaClient: Before Insert Data")
|
||||
await prisma_client.insert_data(data=verification_token_data)
|
||||
## CREATE USER (If necessary)
|
||||
verbose_proxy_logger.debug(f"CustomDBClient: Creating User={user_data}")
|
||||
user_row = await prisma_client.insert_data(
|
||||
data=user_data, table_name="user"
|
||||
)
|
||||
|
||||
## use default user model list if no key-specific model list provided
|
||||
if len(user_row.models) > 0 and len(key_data["models"]) == 0: # type: ignore
|
||||
key_data["models"] = user_row.models
|
||||
## CREATE KEY
|
||||
verbose_proxy_logger.debug(f"CustomDBClient: Creating Key={key_data}")
|
||||
await prisma_client.insert_data(data=key_data, table_name="key")
|
||||
elif custom_db_client is not None:
|
||||
## CREATE USER (If necessary)
|
||||
verbose_proxy_logger.debug(f"CustomDBClient: Creating User={user_data}")
|
||||
await custom_db_client.insert_data(value=user_data, table_name="user")
|
||||
user_row = await custom_db_client.insert_data(
|
||||
value=user_data, table_name="user"
|
||||
)
|
||||
if user_row is None:
|
||||
# GET USER ROW
|
||||
user_row = await custom_db_client.get_data(
|
||||
key=user_id, table_name="user"
|
||||
)
|
||||
|
||||
## use default user model list if no key-specific model list provided
|
||||
if len(user_row.models) > 0 and len(key_data["models"]) == 0: # type: ignore
|
||||
key_data["models"] = user_row.models
|
||||
## CREATE KEY
|
||||
verbose_proxy_logger.debug(f"CustomDBClient: Creating Key={key_data}")
|
||||
await custom_db_client.insert_data(value=key_data, table_name="key")
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue