forked from phoenix/litellm-mirror
feat(proxy_server.py): support model access groups
This commit is contained in:
parent
70c4227f5e
commit
98b83fa780
6 changed files with 56 additions and 19 deletions
|
@ -307,9 +307,12 @@ class LiteLLM_UserTable(LiteLLMBase):
|
||||||
max_budget: Optional[float]
|
max_budget: Optional[float]
|
||||||
spend: float = 0.0
|
spend: float = 0.0
|
||||||
user_email: Optional[str]
|
user_email: Optional[str]
|
||||||
|
models: list = []
|
||||||
|
|
||||||
@root_validator(pre=True)
|
@root_validator(pre=True)
|
||||||
def set_model_info(cls, values):
|
def set_model_info(cls, values):
|
||||||
if values.get("spend") is None:
|
if values.get("spend") is None:
|
||||||
values.update({"spend": 0.0})
|
values.update({"spend": 0.0})
|
||||||
|
if values.get("models") is None:
|
||||||
|
values.update({"models", []})
|
||||||
return values
|
return values
|
||||||
|
|
|
@ -171,7 +171,7 @@ class DynamoDBWrapper(CustomDB):
|
||||||
if isinstance(v, datetime):
|
if isinstance(v, datetime):
|
||||||
value[k] = v.isoformat()
|
value[k] = v.isoformat()
|
||||||
|
|
||||||
await table.put_item(item=value)
|
return await table.put_item(item=value)
|
||||||
|
|
||||||
async def get_data(self, key: str, table_name: Literal["user", "key", "config"]):
|
async def get_data(self, key: str, table_name: Literal["user", "key", "config"]):
|
||||||
from aiodynamo.client import Client
|
from aiodynamo.client import Client
|
||||||
|
|
|
@ -325,7 +325,28 @@ async def user_api_key_auth(
|
||||||
model = data.get("model", None)
|
model = data.get("model", None)
|
||||||
if model in litellm.model_alias_map:
|
if model in litellm.model_alias_map:
|
||||||
model = litellm.model_alias_map[model]
|
model = litellm.model_alias_map[model]
|
||||||
if model and model not in valid_token.models:
|
|
||||||
|
## check if model in allowed model names
|
||||||
|
verbose_proxy_logger.debug(
|
||||||
|
f"LLM Model List pre access group check: {llm_model_list}"
|
||||||
|
)
|
||||||
|
access_groups = []
|
||||||
|
for m in llm_model_list:
|
||||||
|
for group in m.get("model_info", {}).get("access_groups", []):
|
||||||
|
access_groups.append((m["model_name"], group))
|
||||||
|
|
||||||
|
allowed_models = valid_token.models
|
||||||
|
if (
|
||||||
|
len(access_groups) > 0
|
||||||
|
): # check if token contains any model access groups
|
||||||
|
for m in valid_token.models:
|
||||||
|
for model_name, group in access_groups:
|
||||||
|
if m == group:
|
||||||
|
allowed_models.append(model_name)
|
||||||
|
verbose_proxy_logger.debug(
|
||||||
|
f"model: {model}; allowed_models: {allowed_models}"
|
||||||
|
)
|
||||||
|
if model is not None and model not in allowed_models:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
f"API Key not allowed to access model. This token can only access models={valid_token.models}. Tried to access {model}"
|
f"API Key not allowed to access model. This token can only access models={valid_token.models}. Tried to access {model}"
|
||||||
)
|
)
|
||||||
|
@ -1057,6 +1078,7 @@ async def generate_key_helper_fn(
|
||||||
"user_email": user_email,
|
"user_email": user_email,
|
||||||
"user_id": user_id,
|
"user_id": user_id,
|
||||||
"spend": spend,
|
"spend": spend,
|
||||||
|
"models": models,
|
||||||
}
|
}
|
||||||
key_data = {
|
key_data = {
|
||||||
"token": token,
|
"token": token,
|
||||||
|
@ -1070,14 +1092,28 @@ async def generate_key_helper_fn(
|
||||||
"metadata": metadata_json,
|
"metadata": metadata_json,
|
||||||
}
|
}
|
||||||
if prisma_client is not None:
|
if prisma_client is not None:
|
||||||
verification_token_data = dict(key_data)
|
## CREATE USER (If necessary)
|
||||||
verification_token_data.update(user_data)
|
verbose_proxy_logger.debug(f"CustomDBClient: Creating User={user_data}")
|
||||||
verbose_proxy_logger.debug("PrismaClient: Before Insert Data")
|
user_row = await prisma_client.insert_data(
|
||||||
await prisma_client.insert_data(data=verification_token_data)
|
data=user_data, table_name="user"
|
||||||
|
)
|
||||||
|
|
||||||
|
## use default user model list if no key-specific model list provided
|
||||||
|
if len(user_row.models) > 0 and len(key_data["models"]) == 0: # type: ignore
|
||||||
|
key_data["models"] = user_row.models
|
||||||
|
## CREATE KEY
|
||||||
|
verbose_proxy_logger.debug(f"CustomDBClient: Creating Key={key_data}")
|
||||||
|
await prisma_client.insert_data(data=key_data, table_name="key")
|
||||||
elif custom_db_client is not None:
|
elif custom_db_client is not None:
|
||||||
## CREATE USER (If necessary)
|
## CREATE USER (If necessary)
|
||||||
verbose_proxy_logger.debug(f"CustomDBClient: Creating User={user_data}")
|
verbose_proxy_logger.debug(f"CustomDBClient: Creating User={user_data}")
|
||||||
await custom_db_client.insert_data(value=user_data, table_name="user")
|
user_row = await custom_db_client.insert_data(
|
||||||
|
value=user_data, table_name="user"
|
||||||
|
)
|
||||||
|
|
||||||
|
## use default user model list if no key-specific model list provided
|
||||||
|
if len(user_row["models"]) > 0 and len(key_data["models"]) == 0: # type: ignore
|
||||||
|
key_data["models"] = user_row["models"]
|
||||||
## CREATE KEY
|
## CREATE KEY
|
||||||
verbose_proxy_logger.debug(f"CustomDBClient: Creating Key={key_data}")
|
verbose_proxy_logger.debug(f"CustomDBClient: Creating Key={key_data}")
|
||||||
await custom_db_client.insert_data(value=key_data, table_name="key")
|
await custom_db_client.insert_data(value=key_data, table_name="key")
|
||||||
|
|
|
@ -12,6 +12,7 @@ model LiteLLM_UserTable {
|
||||||
max_budget Float?
|
max_budget Float?
|
||||||
spend Float @default(0.0)
|
spend Float @default(0.0)
|
||||||
user_email String?
|
user_email String?
|
||||||
|
models String[] @default([])
|
||||||
}
|
}
|
||||||
|
|
||||||
// required for token gen
|
// required for token gen
|
||||||
|
@ -19,7 +20,7 @@ model LiteLLM_VerificationToken {
|
||||||
token String @unique
|
token String @unique
|
||||||
spend Float @default(0.0)
|
spend Float @default(0.0)
|
||||||
expires DateTime?
|
expires DateTime?
|
||||||
models String[]
|
models String[] @default([])
|
||||||
aliases Json @default("{}")
|
aliases Json @default("{}")
|
||||||
config Json @default("{}")
|
config Json @default("{}")
|
||||||
user_id String?
|
user_id String?
|
||||||
|
|
|
@ -409,19 +409,17 @@ class PrismaClient:
|
||||||
on_backoff=on_backoff, # specifying the function to call on backoff
|
on_backoff=on_backoff, # specifying the function to call on backoff
|
||||||
)
|
)
|
||||||
async def insert_data(
|
async def insert_data(
|
||||||
self, data: dict, table_name: Literal["user+key", "config"] = "user+key"
|
self, data: dict, table_name: Literal["user", "key", "config"]
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Add a key to the database. If it already exists, do nothing.
|
Add a key to the database. If it already exists, do nothing.
|
||||||
"""
|
"""
|
||||||
try:
|
try:
|
||||||
if table_name == "user+key":
|
if table_name == "key":
|
||||||
token = data["token"]
|
token = data["token"]
|
||||||
hashed_token = self.hash_token(token=token)
|
hashed_token = self.hash_token(token=token)
|
||||||
db_data = self.jsonify_object(data=data)
|
db_data = self.jsonify_object(data=data)
|
||||||
db_data["token"] = hashed_token
|
db_data["token"] = hashed_token
|
||||||
max_budget = db_data.pop("max_budget", None)
|
|
||||||
user_email = db_data.pop("user_email", None)
|
|
||||||
print_verbose(
|
print_verbose(
|
||||||
"PrismaClient: Before upsert into litellm_verificationtoken"
|
"PrismaClient: Before upsert into litellm_verificationtoken"
|
||||||
)
|
)
|
||||||
|
@ -434,19 +432,17 @@ class PrismaClient:
|
||||||
"update": {}, # don't do anything if it already exists
|
"update": {}, # don't do anything if it already exists
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
return new_verification_token
|
||||||
|
elif table_name == "user":
|
||||||
|
db_data = self.jsonify_object(data=data)
|
||||||
new_user_row = await self.db.litellm_usertable.upsert(
|
new_user_row = await self.db.litellm_usertable.upsert(
|
||||||
where={"user_id": data["user_id"]},
|
where={"user_id": data["user_id"]},
|
||||||
data={
|
data={
|
||||||
"create": {
|
"create": {**db_data}, # type: ignore
|
||||||
"user_id": data["user_id"],
|
|
||||||
"max_budget": max_budget,
|
|
||||||
"user_email": user_email,
|
|
||||||
},
|
|
||||||
"update": {}, # don't do anything if it already exists
|
"update": {}, # don't do anything if it already exists
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
return new_verification_token
|
return new_user_row
|
||||||
elif table_name == "config":
|
elif table_name == "config":
|
||||||
"""
|
"""
|
||||||
For each param,
|
For each param,
|
||||||
|
|
|
@ -12,6 +12,7 @@ model LiteLLM_UserTable {
|
||||||
max_budget Float?
|
max_budget Float?
|
||||||
spend Float @default(0.0)
|
spend Float @default(0.0)
|
||||||
user_email String?
|
user_email String?
|
||||||
|
models String[]
|
||||||
}
|
}
|
||||||
|
|
||||||
// required for token gen
|
// required for token gen
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue