diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py index 7ae67bdc6..fd85280dd 100644 --- a/litellm/proxy/_types.py +++ b/litellm/proxy/_types.py @@ -212,6 +212,12 @@ class KeyRequest(LiteLLMBase): keys: List[str] +class LiteLLM_ModelTable(LiteLLMBase): + model_aliases: Optional[str] = None # json dump the dict + created_by: str + updated_by: str + + class NewUserRequest(GenerateKeyRequest): max_budget: Optional[float] = None user_email: Optional[str] = None @@ -251,7 +257,7 @@ class Member(LiteLLMBase): return values -class NewTeamRequest(LiteLLMBase): +class TeamBase(LiteLLMBase): team_alias: Optional[str] = None team_id: Optional[str] = None organization_id: Optional[str] = None @@ -265,6 +271,10 @@ class NewTeamRequest(LiteLLMBase): models: list = [] +class NewTeamRequest(TeamBase): + model_aliases: Optional[dict] = None + + class GlobalEndUsersSpend(LiteLLMBase): api_key: Optional[str] = None @@ -299,11 +309,12 @@ class DeleteTeamRequest(LiteLLMBase): team_ids: List[str] # required -class LiteLLM_TeamTable(NewTeamRequest): +class LiteLLM_TeamTable(TeamBase): spend: Optional[float] = None max_parallel_requests: Optional[int] = None budget_duration: Optional[str] = None budget_reset_at: Optional[datetime] = None + model_id: Optional[int] = None @root_validator(pre=True) def set_model_info(cls, values): @@ -313,6 +324,7 @@ class LiteLLM_TeamTable(NewTeamRequest): "config", "permissions", "model_max_budget", + "model_aliases", ] for field in dict_fields: value = values.get(field) @@ -542,6 +554,7 @@ class LiteLLM_VerificationTokenView(LiteLLM_VerificationToken): team_rpm_limit: Optional[int] = None team_max_budget: Optional[float] = None soft_budget: Optional[float] = None + team_model_aliases: Optional[Dict] = None class UserAPIKeyAuth( diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 409cf63d5..5ee3b751f 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -405,7 +405,15 @@ async def user_api_key_auth( ) # request data, used across all checks. Making this easily available # Check 1. If token can call model - litellm.model_alias_map = valid_token.aliases + _model_alias_map = {} + if valid_token.team_model_aliases is not None: + _model_alias_map = { + **valid_token.aliases, + **valid_token.team_model_aliases, + } + else: + _model_alias_map = {**valid_token.aliases} + litellm.model_alias_map = _model_alias_map config = valid_token.config if config != {}: model_list = config.get("model_list", []) @@ -5020,11 +5028,27 @@ async def new_team( Member(role="admin", user_id=user_api_key_dict.user_id) ) + ## ADD TO MODEL TABLE + _model_id = None + if data.model_aliases is not None and isinstance(data.model_aliases, dict): + litellm_modeltable = LiteLLM_ModelTable( + model_aliases=json.dumps(data.model_aliases), + created_by=user_api_key_dict.user_id or litellm_proxy_admin_name, + updated_by=user_api_key_dict.user_id or litellm_proxy_admin_name, + ) + model_dict = await prisma_client.db.litellm_modeltable.create( + {**litellm_modeltable.json(exclude_none=True)} # type: ignore + ) # type: ignore + + _model_id = model_dict.id + + ## ADD TO TEAM TABLE complete_team_data = LiteLLM_TeamTable( **data.json(), max_parallel_requests=user_api_key_dict.max_parallel_requests, budget_duration=user_api_key_dict.budget_duration, budget_reset_at=user_api_key_dict.budget_reset_at, + model_id=_model_id, ) team_row = await prisma_client.insert_data( @@ -5495,7 +5519,7 @@ async def new_organization( - `organization_alias`: *str* = The name of the organization. - `models`: *List* = The models the organization has access to. - `budget_id`: *Optional[str]* = The id for a budget (tpm/rpm/max budget) for the organization. - ### IF NO BUDGET - CREATE ONE WITH THESE PARAMS ### + ### IF NO BUDGET ID - CREATE ONE WITH THESE PARAMS ### - `max_budget`: *Optional[float]* = Max budget for org - `tpm_limit`: *Optional[int]* = Max tpm limit for org - `rpm_limit`: *Optional[int]* = Max rpm limit for org diff --git a/litellm/proxy/schema.prisma b/litellm/proxy/schema.prisma index 265bf32c0..d8c8faf16 100644 --- a/litellm/proxy/schema.prisma +++ b/litellm/proxy/schema.prisma @@ -42,6 +42,17 @@ model LiteLLM_OrganizationTable { teams LiteLLM_TeamTable[] } +// Model info for teams, just has model aliases for now. +model LiteLLM_ModelTable { + id Int @id @default(autoincrement()) + model_aliases Json? @map("aliases") + created_at DateTime @default(now()) @map("created_at") + created_by String + updated_at DateTime @default(now()) @updatedAt @map("updated_at") + updated_by String + team LiteLLM_TeamTable? +} + // Assign prod keys to groups, not individuals model LiteLLM_TeamTable { team_id String @id @default(uuid()) @@ -63,7 +74,9 @@ model LiteLLM_TeamTable { updated_at DateTime @default(now()) @updatedAt @map("updated_at") model_spend Json @default("{}") model_max_budget Json @default("{}") + model_id Int? @unique litellm_organization_table LiteLLM_OrganizationTable? @relation(fields: [organization_id], references: [organization_id]) + litellm_model_table LiteLLM_ModelTable? @relation(fields: [model_id], references: [id]) } // Track spend, rate limit, budget Users diff --git a/litellm/proxy/utils.py b/litellm/proxy/utils.py index 1e701515e..ee5e323e8 100644 --- a/litellm/proxy/utils.py +++ b/litellm/proxy/utils.py @@ -965,12 +965,21 @@ class PrismaClient: ) sql_query = f""" - SELECT * - FROM "LiteLLM_VerificationTokenView" - WHERE token = '{token}' + SELECT + v.*, + t.spend AS team_spend, + t.max_budget AS team_max_budget, + t.tpm_limit AS team_tpm_limit, + t.rpm_limit AS team_rpm_limit, + m.aliases as team_model_aliases + FROM "LiteLLM_VerificationToken" AS v + INNER JOIN "LiteLLM_TeamTable" AS t ON v.team_id = t.team_id + LEFT JOIN "LiteLLM_ModelTable" m ON t.model_id = m.id + WHERE v.token = '{token}' """ response = await self.db.query_first(query=sql_query) + if response is not None: response = LiteLLM_VerificationTokenView(**response) # for prisma we need to cast the expires time to str diff --git a/schema.prisma b/schema.prisma index 265bf32c0..d8c8faf16 100644 --- a/schema.prisma +++ b/schema.prisma @@ -42,6 +42,17 @@ model LiteLLM_OrganizationTable { teams LiteLLM_TeamTable[] } +// Model info for teams, just has model aliases for now. +model LiteLLM_ModelTable { + id Int @id @default(autoincrement()) + model_aliases Json? @map("aliases") + created_at DateTime @default(now()) @map("created_at") + created_by String + updated_at DateTime @default(now()) @updatedAt @map("updated_at") + updated_by String + team LiteLLM_TeamTable? +} + // Assign prod keys to groups, not individuals model LiteLLM_TeamTable { team_id String @id @default(uuid()) @@ -63,7 +74,9 @@ model LiteLLM_TeamTable { updated_at DateTime @default(now()) @updatedAt @map("updated_at") model_spend Json @default("{}") model_max_budget Json @default("{}") + model_id Int? @unique litellm_organization_table LiteLLM_OrganizationTable? @relation(fields: [organization_id], references: [organization_id]) + litellm_model_table LiteLLM_ModelTable? @relation(fields: [model_id], references: [id]) } // Track spend, rate limit, budget Users