Merge pull request #1483 from BerriAI/litellm_model_access_groups_feature

feat(proxy_server.py): support model access groups
This commit is contained in:
Krish Dholakia 2024-01-17 18:16:53 -08:00 committed by GitHub
commit e9ac001005
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
7 changed files with 76 additions and 32 deletions

View file

@ -303,9 +303,12 @@ class LiteLLM_UserTable(LiteLLMBase):
max_budget: Optional[float] max_budget: Optional[float]
spend: float = 0.0 spend: float = 0.0
user_email: Optional[str] user_email: Optional[str]
models: list = []
@root_validator(pre=True) @root_validator(pre=True)
def set_model_info(cls, values): def set_model_info(cls, values):
if values.get("spend") is None: if values.get("spend") is None:
values.update({"spend": 0.0}) values.update({"spend": 0.0})
if values.get("models") is None:
values.update({"models": []})
return values return values

View file

@ -171,7 +171,7 @@ class DynamoDBWrapper(CustomDB):
if isinstance(v, datetime): if isinstance(v, datetime):
value[k] = v.isoformat() value[k] = v.isoformat()
await table.put_item(item=value) return await table.put_item(item=value, return_values=ReturnValues.all_old)
async def get_data(self, key: str, table_name: Literal["user", "key", "config"]): async def get_data(self, key: str, table_name: Literal["user", "key", "config"]):
from aiodynamo.client import Client from aiodynamo.client import Client

View file

@ -405,6 +405,7 @@ def run_server(
is_prisma_runnable = False is_prisma_runnable = False
if is_prisma_runnable: if is_prisma_runnable:
for _ in range(4):
# run prisma db push, before starting server # run prisma db push, before starting server
# Save the current working directory # Save the current working directory
original_dir = os.getcwd() original_dir = os.getcwd()
@ -413,9 +414,10 @@ def run_server(
dname = os.path.dirname(abspath) dname = os.path.dirname(abspath)
os.chdir(dname) os.chdir(dname)
try: try:
subprocess.run( subprocess.run(["prisma", "db", "push", "--accept-data-loss"])
["prisma", "db", "push", "--accept-data-loss"] break # Exit the loop if the subprocess succeeds
) # this looks like a weird edge case when prisma just wont start on render. we need to have the --accept-data-loss except subprocess.CalledProcessError as e:
print(f"Error: {e}")
finally: finally:
os.chdir(original_dir) os.chdir(original_dir)
else: else:

View file

@ -325,7 +325,28 @@ async def user_api_key_auth(
model = data.get("model", None) model = data.get("model", None)
if model in litellm.model_alias_map: if model in litellm.model_alias_map:
model = litellm.model_alias_map[model] model = litellm.model_alias_map[model]
if model and model not in valid_token.models:
## check if model in allowed model names
verbose_proxy_logger.debug(
f"LLM Model List pre access group check: {llm_model_list}"
)
access_groups = []
for m in llm_model_list:
for group in m.get("model_info", {}).get("access_groups", []):
access_groups.append((m["model_name"], group))
allowed_models = valid_token.models
if (
len(access_groups) > 0
): # check if token contains any model access groups
for m in valid_token.models:
for model_name, group in access_groups:
if m == group:
allowed_models.append(model_name)
verbose_proxy_logger.debug(
f"model: {model}; allowed_models: {allowed_models}"
)
if model is not None and model not in allowed_models:
raise ValueError( raise ValueError(
f"API Key not allowed to access model. This token can only access models={valid_token.models}. Tried to access {model}" f"API Key not allowed to access model. This token can only access models={valid_token.models}. Tried to access {model}"
) )
@ -1057,6 +1078,7 @@ async def generate_key_helper_fn(
"user_email": user_email, "user_email": user_email,
"user_id": user_id, "user_id": user_id,
"spend": spend, "spend": spend,
"models": models,
} }
key_data = { key_data = {
"token": token, "token": token,
@ -1070,14 +1092,33 @@ async def generate_key_helper_fn(
"metadata": metadata_json, "metadata": metadata_json,
} }
if prisma_client is not None: if prisma_client is not None:
verification_token_data = dict(key_data) ## CREATE USER (If necessary)
verification_token_data.update(user_data) verbose_proxy_logger.debug(f"CustomDBClient: Creating User={user_data}")
verbose_proxy_logger.debug("PrismaClient: Before Insert Data") user_row = await prisma_client.insert_data(
await prisma_client.insert_data(data=verification_token_data) data=user_data, table_name="user"
)
## use default user model list if no key-specific model list provided
if len(user_row.models) > 0 and len(key_data["models"]) == 0: # type: ignore
key_data["models"] = user_row.models
## CREATE KEY
verbose_proxy_logger.debug(f"CustomDBClient: Creating Key={key_data}")
await prisma_client.insert_data(data=key_data, table_name="key")
elif custom_db_client is not None: elif custom_db_client is not None:
## CREATE USER (If necessary) ## CREATE USER (If necessary)
verbose_proxy_logger.debug(f"CustomDBClient: Creating User={user_data}") verbose_proxy_logger.debug(f"CustomDBClient: Creating User={user_data}")
await custom_db_client.insert_data(value=user_data, table_name="user") user_row = await custom_db_client.insert_data(
value=user_data, table_name="user"
)
if user_row is None:
# GET USER ROW
user_row = await custom_db_client.get_data(
key=user_id, table_name="user"
)
## use default user model list if no key-specific model list provided
if len(user_row.models) > 0 and len(key_data["models"]) == 0: # type: ignore
key_data["models"] = user_row.models
## CREATE KEY ## CREATE KEY
verbose_proxy_logger.debug(f"CustomDBClient: Creating Key={key_data}") verbose_proxy_logger.debug(f"CustomDBClient: Creating Key={key_data}")
await custom_db_client.insert_data(value=key_data, table_name="key") await custom_db_client.insert_data(value=key_data, table_name="key")

View file

@ -12,6 +12,7 @@ model LiteLLM_UserTable {
max_budget Float? max_budget Float?
spend Float @default(0.0) spend Float @default(0.0)
user_email String? user_email String?
models String[] @default([])
} }
// required for token gen // required for token gen
@ -19,7 +20,7 @@ model LiteLLM_VerificationToken {
token String @unique token String @unique
spend Float @default(0.0) spend Float @default(0.0)
expires DateTime? expires DateTime?
models String[] models String[] @default([])
aliases Json @default("{}") aliases Json @default("{}")
config Json @default("{}") config Json @default("{}")
user_id String? user_id String?

View file

@ -412,19 +412,17 @@ class PrismaClient:
on_backoff=on_backoff, # specifying the function to call on backoff on_backoff=on_backoff, # specifying the function to call on backoff
) )
async def insert_data( async def insert_data(
self, data: dict, table_name: Literal["user+key", "config"] = "user+key" self, data: dict, table_name: Literal["user", "key", "config"]
): ):
""" """
Add a key to the database. If it already exists, do nothing. Add a key to the database. If it already exists, do nothing.
""" """
try: try:
if table_name == "user+key": if table_name == "key":
token = data["token"] token = data["token"]
hashed_token = self.hash_token(token=token) hashed_token = self.hash_token(token=token)
db_data = self.jsonify_object(data=data) db_data = self.jsonify_object(data=data)
db_data["token"] = hashed_token db_data["token"] = hashed_token
max_budget = db_data.pop("max_budget", None)
user_email = db_data.pop("user_email", None)
print_verbose( print_verbose(
"PrismaClient: Before upsert into litellm_verificationtoken" "PrismaClient: Before upsert into litellm_verificationtoken"
) )
@ -437,19 +435,17 @@ class PrismaClient:
"update": {}, # don't do anything if it already exists "update": {}, # don't do anything if it already exists
}, },
) )
return new_verification_token
elif table_name == "user":
db_data = self.jsonify_object(data=data)
new_user_row = await self.db.litellm_usertable.upsert( new_user_row = await self.db.litellm_usertable.upsert(
where={"user_id": data["user_id"]}, where={"user_id": data["user_id"]},
data={ data={
"create": { "create": {**db_data}, # type: ignore
"user_id": data["user_id"],
"max_budget": max_budget,
"user_email": user_email,
},
"update": {}, # don't do anything if it already exists "update": {}, # don't do anything if it already exists
}, },
) )
return new_verification_token return new_user_row
elif table_name == "config": elif table_name == "config":
""" """
For each param, For each param,

View file

@ -12,6 +12,7 @@ model LiteLLM_UserTable {
max_budget Float? max_budget Float?
spend Float @default(0.0) spend Float @default(0.0)
user_email String? user_email String?
models String[]
} }
// required for token gen // required for token gen