forked from phoenix/litellm-mirror
feat(proxy_server.py): team based model aliases
allow setting model aliases at a team level (e.g. route all 'gpt-3.5-turbo' requests from team-1 to model-deployment-group-2)
This commit is contained in:
parent
cdb960eb34
commit
ca97ea8acd
5 changed files with 79 additions and 7 deletions
|
@ -212,6 +212,12 @@ class KeyRequest(LiteLLMBase):
|
||||||
keys: List[str]
|
keys: List[str]
|
||||||
|
|
||||||
|
|
||||||
|
class LiteLLM_ModelTable(LiteLLMBase):
|
||||||
|
model_aliases: Optional[str] = None # json dump the dict
|
||||||
|
created_by: str
|
||||||
|
updated_by: str
|
||||||
|
|
||||||
|
|
||||||
class NewUserRequest(GenerateKeyRequest):
|
class NewUserRequest(GenerateKeyRequest):
|
||||||
max_budget: Optional[float] = None
|
max_budget: Optional[float] = None
|
||||||
user_email: Optional[str] = None
|
user_email: Optional[str] = None
|
||||||
|
@ -251,7 +257,7 @@ class Member(LiteLLMBase):
|
||||||
return values
|
return values
|
||||||
|
|
||||||
|
|
||||||
class NewTeamRequest(LiteLLMBase):
|
class TeamBase(LiteLLMBase):
|
||||||
team_alias: Optional[str] = None
|
team_alias: Optional[str] = None
|
||||||
team_id: Optional[str] = None
|
team_id: Optional[str] = None
|
||||||
organization_id: Optional[str] = None
|
organization_id: Optional[str] = None
|
||||||
|
@ -265,6 +271,10 @@ class NewTeamRequest(LiteLLMBase):
|
||||||
models: list = []
|
models: list = []
|
||||||
|
|
||||||
|
|
||||||
|
class NewTeamRequest(TeamBase):
|
||||||
|
model_aliases: Optional[dict] = None
|
||||||
|
|
||||||
|
|
||||||
class GlobalEndUsersSpend(LiteLLMBase):
|
class GlobalEndUsersSpend(LiteLLMBase):
|
||||||
api_key: Optional[str] = None
|
api_key: Optional[str] = None
|
||||||
|
|
||||||
|
@ -299,11 +309,12 @@ class DeleteTeamRequest(LiteLLMBase):
|
||||||
team_ids: List[str] # required
|
team_ids: List[str] # required
|
||||||
|
|
||||||
|
|
||||||
class LiteLLM_TeamTable(NewTeamRequest):
|
class LiteLLM_TeamTable(TeamBase):
|
||||||
spend: Optional[float] = None
|
spend: Optional[float] = None
|
||||||
max_parallel_requests: Optional[int] = None
|
max_parallel_requests: Optional[int] = None
|
||||||
budget_duration: Optional[str] = None
|
budget_duration: Optional[str] = None
|
||||||
budget_reset_at: Optional[datetime] = None
|
budget_reset_at: Optional[datetime] = None
|
||||||
|
model_id: Optional[int] = None
|
||||||
|
|
||||||
@root_validator(pre=True)
|
@root_validator(pre=True)
|
||||||
def set_model_info(cls, values):
|
def set_model_info(cls, values):
|
||||||
|
@ -313,6 +324,7 @@ class LiteLLM_TeamTable(NewTeamRequest):
|
||||||
"config",
|
"config",
|
||||||
"permissions",
|
"permissions",
|
||||||
"model_max_budget",
|
"model_max_budget",
|
||||||
|
"model_aliases",
|
||||||
]
|
]
|
||||||
for field in dict_fields:
|
for field in dict_fields:
|
||||||
value = values.get(field)
|
value = values.get(field)
|
||||||
|
@ -542,6 +554,7 @@ class LiteLLM_VerificationTokenView(LiteLLM_VerificationToken):
|
||||||
team_rpm_limit: Optional[int] = None
|
team_rpm_limit: Optional[int] = None
|
||||||
team_max_budget: Optional[float] = None
|
team_max_budget: Optional[float] = None
|
||||||
soft_budget: Optional[float] = None
|
soft_budget: Optional[float] = None
|
||||||
|
team_model_aliases: Optional[Dict] = None
|
||||||
|
|
||||||
|
|
||||||
class UserAPIKeyAuth(
|
class UserAPIKeyAuth(
|
||||||
|
|
|
@ -405,7 +405,15 @@ async def user_api_key_auth(
|
||||||
) # request data, used across all checks. Making this easily available
|
) # request data, used across all checks. Making this easily available
|
||||||
|
|
||||||
# Check 1. If token can call model
|
# Check 1. If token can call model
|
||||||
litellm.model_alias_map = valid_token.aliases
|
_model_alias_map = {}
|
||||||
|
if valid_token.team_model_aliases is not None:
|
||||||
|
_model_alias_map = {
|
||||||
|
**valid_token.aliases,
|
||||||
|
**valid_token.team_model_aliases,
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
_model_alias_map = {**valid_token.aliases}
|
||||||
|
litellm.model_alias_map = _model_alias_map
|
||||||
config = valid_token.config
|
config = valid_token.config
|
||||||
if config != {}:
|
if config != {}:
|
||||||
model_list = config.get("model_list", [])
|
model_list = config.get("model_list", [])
|
||||||
|
@ -5020,11 +5028,27 @@ async def new_team(
|
||||||
Member(role="admin", user_id=user_api_key_dict.user_id)
|
Member(role="admin", user_id=user_api_key_dict.user_id)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
## ADD TO MODEL TABLE
|
||||||
|
_model_id = None
|
||||||
|
if data.model_aliases is not None and isinstance(data.model_aliases, dict):
|
||||||
|
litellm_modeltable = LiteLLM_ModelTable(
|
||||||
|
model_aliases=json.dumps(data.model_aliases),
|
||||||
|
created_by=user_api_key_dict.user_id or litellm_proxy_admin_name,
|
||||||
|
updated_by=user_api_key_dict.user_id or litellm_proxy_admin_name,
|
||||||
|
)
|
||||||
|
model_dict = await prisma_client.db.litellm_modeltable.create(
|
||||||
|
{**litellm_modeltable.json(exclude_none=True)} # type: ignore
|
||||||
|
) # type: ignore
|
||||||
|
|
||||||
|
_model_id = model_dict.id
|
||||||
|
|
||||||
|
## ADD TO TEAM TABLE
|
||||||
complete_team_data = LiteLLM_TeamTable(
|
complete_team_data = LiteLLM_TeamTable(
|
||||||
**data.json(),
|
**data.json(),
|
||||||
max_parallel_requests=user_api_key_dict.max_parallel_requests,
|
max_parallel_requests=user_api_key_dict.max_parallel_requests,
|
||||||
budget_duration=user_api_key_dict.budget_duration,
|
budget_duration=user_api_key_dict.budget_duration,
|
||||||
budget_reset_at=user_api_key_dict.budget_reset_at,
|
budget_reset_at=user_api_key_dict.budget_reset_at,
|
||||||
|
model_id=_model_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
team_row = await prisma_client.insert_data(
|
team_row = await prisma_client.insert_data(
|
||||||
|
@ -5495,7 +5519,7 @@ async def new_organization(
|
||||||
- `organization_alias`: *str* = The name of the organization.
|
- `organization_alias`: *str* = The name of the organization.
|
||||||
- `models`: *List* = The models the organization has access to.
|
- `models`: *List* = The models the organization has access to.
|
||||||
- `budget_id`: *Optional[str]* = The id for a budget (tpm/rpm/max budget) for the organization.
|
- `budget_id`: *Optional[str]* = The id for a budget (tpm/rpm/max budget) for the organization.
|
||||||
### IF NO BUDGET - CREATE ONE WITH THESE PARAMS ###
|
### IF NO BUDGET ID - CREATE ONE WITH THESE PARAMS ###
|
||||||
- `max_budget`: *Optional[float]* = Max budget for org
|
- `max_budget`: *Optional[float]* = Max budget for org
|
||||||
- `tpm_limit`: *Optional[int]* = Max tpm limit for org
|
- `tpm_limit`: *Optional[int]* = Max tpm limit for org
|
||||||
- `rpm_limit`: *Optional[int]* = Max rpm limit for org
|
- `rpm_limit`: *Optional[int]* = Max rpm limit for org
|
||||||
|
|
|
@ -42,6 +42,17 @@ model LiteLLM_OrganizationTable {
|
||||||
teams LiteLLM_TeamTable[]
|
teams LiteLLM_TeamTable[]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Model info for teams, just has model aliases for now.
|
||||||
|
model LiteLLM_ModelTable {
|
||||||
|
id Int @id @default(autoincrement())
|
||||||
|
model_aliases Json? @map("aliases")
|
||||||
|
created_at DateTime @default(now()) @map("created_at")
|
||||||
|
created_by String
|
||||||
|
updated_at DateTime @default(now()) @updatedAt @map("updated_at")
|
||||||
|
updated_by String
|
||||||
|
team LiteLLM_TeamTable?
|
||||||
|
}
|
||||||
|
|
||||||
// Assign prod keys to groups, not individuals
|
// Assign prod keys to groups, not individuals
|
||||||
model LiteLLM_TeamTable {
|
model LiteLLM_TeamTable {
|
||||||
team_id String @id @default(uuid())
|
team_id String @id @default(uuid())
|
||||||
|
@ -63,7 +74,9 @@ model LiteLLM_TeamTable {
|
||||||
updated_at DateTime @default(now()) @updatedAt @map("updated_at")
|
updated_at DateTime @default(now()) @updatedAt @map("updated_at")
|
||||||
model_spend Json @default("{}")
|
model_spend Json @default("{}")
|
||||||
model_max_budget Json @default("{}")
|
model_max_budget Json @default("{}")
|
||||||
|
model_id Int? @unique
|
||||||
litellm_organization_table LiteLLM_OrganizationTable? @relation(fields: [organization_id], references: [organization_id])
|
litellm_organization_table LiteLLM_OrganizationTable? @relation(fields: [organization_id], references: [organization_id])
|
||||||
|
litellm_model_table LiteLLM_ModelTable? @relation(fields: [model_id], references: [id])
|
||||||
}
|
}
|
||||||
|
|
||||||
// Track spend, rate limit, budget Users
|
// Track spend, rate limit, budget Users
|
||||||
|
|
|
@ -965,12 +965,21 @@ class PrismaClient:
|
||||||
)
|
)
|
||||||
|
|
||||||
sql_query = f"""
|
sql_query = f"""
|
||||||
SELECT *
|
SELECT
|
||||||
FROM "LiteLLM_VerificationTokenView"
|
v.*,
|
||||||
WHERE token = '{token}'
|
t.spend AS team_spend,
|
||||||
|
t.max_budget AS team_max_budget,
|
||||||
|
t.tpm_limit AS team_tpm_limit,
|
||||||
|
t.rpm_limit AS team_rpm_limit,
|
||||||
|
m.aliases as team_model_aliases
|
||||||
|
FROM "LiteLLM_VerificationToken" AS v
|
||||||
|
INNER JOIN "LiteLLM_TeamTable" AS t ON v.team_id = t.team_id
|
||||||
|
LEFT JOIN "LiteLLM_ModelTable" m ON t.model_id = m.id
|
||||||
|
WHERE v.token = '{token}'
|
||||||
"""
|
"""
|
||||||
|
|
||||||
response = await self.db.query_first(query=sql_query)
|
response = await self.db.query_first(query=sql_query)
|
||||||
|
|
||||||
if response is not None:
|
if response is not None:
|
||||||
response = LiteLLM_VerificationTokenView(**response)
|
response = LiteLLM_VerificationTokenView(**response)
|
||||||
# for prisma we need to cast the expires time to str
|
# for prisma we need to cast the expires time to str
|
||||||
|
|
|
@ -42,6 +42,17 @@ model LiteLLM_OrganizationTable {
|
||||||
teams LiteLLM_TeamTable[]
|
teams LiteLLM_TeamTable[]
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// Model info for teams, just has model aliases for now.
|
||||||
|
model LiteLLM_ModelTable {
|
||||||
|
id Int @id @default(autoincrement())
|
||||||
|
model_aliases Json? @map("aliases")
|
||||||
|
created_at DateTime @default(now()) @map("created_at")
|
||||||
|
created_by String
|
||||||
|
updated_at DateTime @default(now()) @updatedAt @map("updated_at")
|
||||||
|
updated_by String
|
||||||
|
team LiteLLM_TeamTable?
|
||||||
|
}
|
||||||
|
|
||||||
// Assign prod keys to groups, not individuals
|
// Assign prod keys to groups, not individuals
|
||||||
model LiteLLM_TeamTable {
|
model LiteLLM_TeamTable {
|
||||||
team_id String @id @default(uuid())
|
team_id String @id @default(uuid())
|
||||||
|
@ -63,7 +74,9 @@ model LiteLLM_TeamTable {
|
||||||
updated_at DateTime @default(now()) @updatedAt @map("updated_at")
|
updated_at DateTime @default(now()) @updatedAt @map("updated_at")
|
||||||
model_spend Json @default("{}")
|
model_spend Json @default("{}")
|
||||||
model_max_budget Json @default("{}")
|
model_max_budget Json @default("{}")
|
||||||
|
model_id Int? @unique
|
||||||
litellm_organization_table LiteLLM_OrganizationTable? @relation(fields: [organization_id], references: [organization_id])
|
litellm_organization_table LiteLLM_OrganizationTable? @relation(fields: [organization_id], references: [organization_id])
|
||||||
|
litellm_model_table LiteLLM_ModelTable? @relation(fields: [model_id], references: [id])
|
||||||
}
|
}
|
||||||
|
|
||||||
// Track spend, rate limit, budget Users
|
// Track spend, rate limit, budget Users
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue