Merge pull request #2377 from BerriAI/litellm_team_level_model_groups

feat(proxy_server.py): team-based model aliases
This commit is contained in:
Krish Dholakia 2024-03-06 21:03:53 -08:00 committed by GitHub
commit 38612ddd34
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 62 additions and 13 deletions

View file

@ -406,7 +406,15 @@ async def user_api_key_auth(
) # request data, used across all checks. Making this easily available
# Check 1. If token can call model
litellm.model_alias_map = valid_token.aliases
_model_alias_map = {}
if valid_token.team_model_aliases is not None:
_model_alias_map = {
**valid_token.aliases,
**valid_token.team_model_aliases,
}
else:
_model_alias_map = {**valid_token.aliases}
litellm.model_alias_map = _model_alias_map
config = valid_token.config
if config != {}:
model_list = config.get("model_list", [])
@ -824,7 +832,10 @@ async def user_api_key_auth(
raise Exception(
f"This key is made for LiteLLM UI, Tried to access route: {route}. Not allowed"
)
return UserAPIKeyAuth(api_key=api_key, **valid_token_dict)
if valid_token_dict is not None:
return UserAPIKeyAuth(api_key=api_key, **valid_token_dict)
else:
raise Exception()
except Exception as e:
# verbose_proxy_logger.debug(f"An exception occurred - {traceback.format_exc()}")
traceback.print_exc()
@ -4923,11 +4934,27 @@ async def new_team(
Member(role="admin", user_id=user_api_key_dict.user_id)
)
## ADD TO MODEL TABLE
_model_id = None
if data.model_aliases is not None and isinstance(data.model_aliases, dict):
litellm_modeltable = LiteLLM_ModelTable(
model_aliases=json.dumps(data.model_aliases),
created_by=user_api_key_dict.user_id or litellm_proxy_admin_name,
updated_by=user_api_key_dict.user_id or litellm_proxy_admin_name,
)
model_dict = await prisma_client.db.litellm_modeltable.create(
{**litellm_modeltable.json(exclude_none=True)} # type: ignore
) # type: ignore
_model_id = model_dict.id
## ADD TO TEAM TABLE
complete_team_data = LiteLLM_TeamTable(
**data.json(),
max_parallel_requests=user_api_key_dict.max_parallel_requests,
budget_duration=user_api_key_dict.budget_duration,
budget_reset_at=user_api_key_dict.budget_reset_at,
model_id=_model_id,
)
team_row = await prisma_client.insert_data(
@ -5398,7 +5425,7 @@ async def new_organization(
- `organization_alias`: *str* = The name of the organization.
- `models`: *List* = The models the organization has access to.
- `budget_id`: *Optional[str]* = The id for a budget (tpm/rpm/max budget) for the organization.
### IF NO BUDGET - CREATE ONE WITH THESE PARAMS ###
### IF NO BUDGET ID - CREATE ONE WITH THESE PARAMS ###
- `max_budget`: *Optional[float]* = Max budget for org
- `tpm_limit`: *Optional[int]* = Max tpm limit for org
- `rpm_limit`: *Optional[int]* = Max rpm limit for org