feat(proxy_server.py): team based model aliases

allow setting model aliases at a team level (e.g. route all 'gpt-3.5-turbo' requests from team-1 to model-deployment-group-2)
This commit is contained in:
Krrish Dholakia 2024-03-06 17:42:08 -08:00
parent cdb960eb34
commit ca97ea8acd
5 changed files with 79 additions and 7 deletions

View file

@ -212,6 +212,12 @@ class KeyRequest(LiteLLMBase):
keys: List[str] keys: List[str]
class LiteLLM_ModelTable(LiteLLMBase):
model_aliases: Optional[str] = None # json dump the dict
created_by: str
updated_by: str
class NewUserRequest(GenerateKeyRequest): class NewUserRequest(GenerateKeyRequest):
max_budget: Optional[float] = None max_budget: Optional[float] = None
user_email: Optional[str] = None user_email: Optional[str] = None
@ -251,7 +257,7 @@ class Member(LiteLLMBase):
return values return values
class NewTeamRequest(LiteLLMBase): class TeamBase(LiteLLMBase):
team_alias: Optional[str] = None team_alias: Optional[str] = None
team_id: Optional[str] = None team_id: Optional[str] = None
organization_id: Optional[str] = None organization_id: Optional[str] = None
@ -265,6 +271,10 @@ class NewTeamRequest(LiteLLMBase):
models: list = [] models: list = []
class NewTeamRequest(TeamBase):
model_aliases: Optional[dict] = None
class GlobalEndUsersSpend(LiteLLMBase): class GlobalEndUsersSpend(LiteLLMBase):
api_key: Optional[str] = None api_key: Optional[str] = None
@ -299,11 +309,12 @@ class DeleteTeamRequest(LiteLLMBase):
team_ids: List[str] # required team_ids: List[str] # required
class LiteLLM_TeamTable(NewTeamRequest): class LiteLLM_TeamTable(TeamBase):
spend: Optional[float] = None spend: Optional[float] = None
max_parallel_requests: Optional[int] = None max_parallel_requests: Optional[int] = None
budget_duration: Optional[str] = None budget_duration: Optional[str] = None
budget_reset_at: Optional[datetime] = None budget_reset_at: Optional[datetime] = None
model_id: Optional[int] = None
@root_validator(pre=True) @root_validator(pre=True)
def set_model_info(cls, values): def set_model_info(cls, values):
@ -313,6 +324,7 @@ class LiteLLM_TeamTable(NewTeamRequest):
"config", "config",
"permissions", "permissions",
"model_max_budget", "model_max_budget",
"model_aliases",
] ]
for field in dict_fields: for field in dict_fields:
value = values.get(field) value = values.get(field)
@ -542,6 +554,7 @@ class LiteLLM_VerificationTokenView(LiteLLM_VerificationToken):
team_rpm_limit: Optional[int] = None team_rpm_limit: Optional[int] = None
team_max_budget: Optional[float] = None team_max_budget: Optional[float] = None
soft_budget: Optional[float] = None soft_budget: Optional[float] = None
team_model_aliases: Optional[Dict] = None
class UserAPIKeyAuth( class UserAPIKeyAuth(

View file

@ -405,7 +405,15 @@ async def user_api_key_auth(
) # request data, used across all checks. Making this easily available ) # request data, used across all checks. Making this easily available
# Check 1. If token can call model # Check 1. If token can call model
litellm.model_alias_map = valid_token.aliases _model_alias_map = {}
if valid_token.team_model_aliases is not None:
_model_alias_map = {
**valid_token.aliases,
**valid_token.team_model_aliases,
}
else:
_model_alias_map = {**valid_token.aliases}
litellm.model_alias_map = _model_alias_map
config = valid_token.config config = valid_token.config
if config != {}: if config != {}:
model_list = config.get("model_list", []) model_list = config.get("model_list", [])
@ -5020,11 +5028,27 @@ async def new_team(
Member(role="admin", user_id=user_api_key_dict.user_id) Member(role="admin", user_id=user_api_key_dict.user_id)
) )
## ADD TO MODEL TABLE
_model_id = None
if data.model_aliases is not None and isinstance(data.model_aliases, dict):
litellm_modeltable = LiteLLM_ModelTable(
model_aliases=json.dumps(data.model_aliases),
created_by=user_api_key_dict.user_id or litellm_proxy_admin_name,
updated_by=user_api_key_dict.user_id or litellm_proxy_admin_name,
)
model_dict = await prisma_client.db.litellm_modeltable.create(
{**litellm_modeltable.json(exclude_none=True)} # type: ignore
) # type: ignore
_model_id = model_dict.id
## ADD TO TEAM TABLE
complete_team_data = LiteLLM_TeamTable( complete_team_data = LiteLLM_TeamTable(
**data.json(), **data.json(),
max_parallel_requests=user_api_key_dict.max_parallel_requests, max_parallel_requests=user_api_key_dict.max_parallel_requests,
budget_duration=user_api_key_dict.budget_duration, budget_duration=user_api_key_dict.budget_duration,
budget_reset_at=user_api_key_dict.budget_reset_at, budget_reset_at=user_api_key_dict.budget_reset_at,
model_id=_model_id,
) )
team_row = await prisma_client.insert_data( team_row = await prisma_client.insert_data(
@ -5495,7 +5519,7 @@ async def new_organization(
- `organization_alias`: *str* = The name of the organization. - `organization_alias`: *str* = The name of the organization.
- `models`: *List* = The models the organization has access to. - `models`: *List* = The models the organization has access to.
- `budget_id`: *Optional[str]* = The id for a budget (tpm/rpm/max budget) for the organization. - `budget_id`: *Optional[str]* = The id for a budget (tpm/rpm/max budget) for the organization.
### IF NO BUDGET - CREATE ONE WITH THESE PARAMS ### ### IF NO BUDGET ID - CREATE ONE WITH THESE PARAMS ###
- `max_budget`: *Optional[float]* = Max budget for org - `max_budget`: *Optional[float]* = Max budget for org
- `tpm_limit`: *Optional[int]* = Max tpm limit for org - `tpm_limit`: *Optional[int]* = Max tpm limit for org
- `rpm_limit`: *Optional[int]* = Max rpm limit for org - `rpm_limit`: *Optional[int]* = Max rpm limit for org

View file

@ -42,6 +42,17 @@ model LiteLLM_OrganizationTable {
teams LiteLLM_TeamTable[] teams LiteLLM_TeamTable[]
} }
// Model info for teams, just has model aliases for now.
model LiteLLM_ModelTable {
id Int @id @default(autoincrement())
model_aliases Json? @map("aliases")
created_at DateTime @default(now()) @map("created_at")
created_by String
updated_at DateTime @default(now()) @updatedAt @map("updated_at")
updated_by String
team LiteLLM_TeamTable?
}
// Assign prod keys to groups, not individuals // Assign prod keys to groups, not individuals
model LiteLLM_TeamTable { model LiteLLM_TeamTable {
team_id String @id @default(uuid()) team_id String @id @default(uuid())
@ -63,7 +74,9 @@ model LiteLLM_TeamTable {
updated_at DateTime @default(now()) @updatedAt @map("updated_at") updated_at DateTime @default(now()) @updatedAt @map("updated_at")
model_spend Json @default("{}") model_spend Json @default("{}")
model_max_budget Json @default("{}") model_max_budget Json @default("{}")
model_id Int? @unique
litellm_organization_table LiteLLM_OrganizationTable? @relation(fields: [organization_id], references: [organization_id]) litellm_organization_table LiteLLM_OrganizationTable? @relation(fields: [organization_id], references: [organization_id])
litellm_model_table LiteLLM_ModelTable? @relation(fields: [model_id], references: [id])
} }
// Track spend, rate limit, budget Users // Track spend, rate limit, budget Users

View file

@ -965,12 +965,21 @@ class PrismaClient:
) )
sql_query = f""" sql_query = f"""
SELECT * SELECT
FROM "LiteLLM_VerificationTokenView" v.*,
WHERE token = '{token}' t.spend AS team_spend,
t.max_budget AS team_max_budget,
t.tpm_limit AS team_tpm_limit,
t.rpm_limit AS team_rpm_limit,
m.aliases as team_model_aliases
FROM "LiteLLM_VerificationToken" AS v
INNER JOIN "LiteLLM_TeamTable" AS t ON v.team_id = t.team_id
LEFT JOIN "LiteLLM_ModelTable" m ON t.model_id = m.id
WHERE v.token = '{token}'
""" """
response = await self.db.query_first(query=sql_query) response = await self.db.query_first(query=sql_query)
if response is not None: if response is not None:
response = LiteLLM_VerificationTokenView(**response) response = LiteLLM_VerificationTokenView(**response)
# for prisma we need to cast the expires time to str # for prisma we need to cast the expires time to str

View file

@ -42,6 +42,17 @@ model LiteLLM_OrganizationTable {
teams LiteLLM_TeamTable[] teams LiteLLM_TeamTable[]
} }
// Model info for teams, just has model aliases for now.
model LiteLLM_ModelTable {
id Int @id @default(autoincrement())
model_aliases Json? @map("aliases")
created_at DateTime @default(now()) @map("created_at")
created_by String
updated_at DateTime @default(now()) @updatedAt @map("updated_at")
updated_by String
team LiteLLM_TeamTable?
}
// Assign prod keys to groups, not individuals // Assign prod keys to groups, not individuals
model LiteLLM_TeamTable { model LiteLLM_TeamTable {
team_id String @id @default(uuid()) team_id String @id @default(uuid())
@ -63,7 +74,9 @@ model LiteLLM_TeamTable {
updated_at DateTime @default(now()) @updatedAt @map("updated_at") updated_at DateTime @default(now()) @updatedAt @map("updated_at")
model_spend Json @default("{}") model_spend Json @default("{}")
model_max_budget Json @default("{}") model_max_budget Json @default("{}")
model_id Int? @unique
litellm_organization_table LiteLLM_OrganizationTable? @relation(fields: [organization_id], references: [organization_id]) litellm_organization_table LiteLLM_OrganizationTable? @relation(fields: [organization_id], references: [organization_id])
litellm_model_table LiteLLM_ModelTable? @relation(fields: [model_id], references: [id])
} }
// Track spend, rate limit, budget Users // Track spend, rate limit, budget Users