forked from phoenix/litellm-mirror
Merge pull request #2377 from BerriAI/litellm_team_level_model_groups
feat(proxy_server.py): team based model aliases
This commit is contained in:
commit
38612ddd34
4 changed files with 62 additions and 13 deletions
|
@ -212,6 +212,12 @@ class KeyRequest(LiteLLMBase):
|
||||||
keys: List[str]
|
keys: List[str]
|
||||||
|
|
||||||
|
|
||||||
|
class LiteLLM_ModelTable(LiteLLMBase):
|
||||||
|
model_aliases: Optional[str] = None # json dump the dict
|
||||||
|
created_by: str
|
||||||
|
updated_by: str
|
||||||
|
|
||||||
|
|
||||||
class NewUserRequest(GenerateKeyRequest):
|
class NewUserRequest(GenerateKeyRequest):
|
||||||
max_budget: Optional[float] = None
|
max_budget: Optional[float] = None
|
||||||
user_email: Optional[str] = None
|
user_email: Optional[str] = None
|
||||||
|
@ -251,7 +257,7 @@ class Member(LiteLLMBase):
|
||||||
return values
|
return values
|
||||||
|
|
||||||
|
|
||||||
class NewTeamRequest(LiteLLMBase):
|
class TeamBase(LiteLLMBase):
|
||||||
team_alias: Optional[str] = None
|
team_alias: Optional[str] = None
|
||||||
team_id: Optional[str] = None
|
team_id: Optional[str] = None
|
||||||
organization_id: Optional[str] = None
|
organization_id: Optional[str] = None
|
||||||
|
@ -265,6 +271,10 @@ class NewTeamRequest(LiteLLMBase):
|
||||||
models: list = []
|
models: list = []
|
||||||
|
|
||||||
|
|
||||||
|
class NewTeamRequest(TeamBase):
|
||||||
|
model_aliases: Optional[dict] = None
|
||||||
|
|
||||||
|
|
||||||
class GlobalEndUsersSpend(LiteLLMBase):
|
class GlobalEndUsersSpend(LiteLLMBase):
|
||||||
api_key: Optional[str] = None
|
api_key: Optional[str] = None
|
||||||
|
|
||||||
|
@ -299,11 +309,12 @@ class DeleteTeamRequest(LiteLLMBase):
|
||||||
team_ids: List[str] # required
|
team_ids: List[str] # required
|
||||||
|
|
||||||
|
|
||||||
class LiteLLM_TeamTable(NewTeamRequest):
|
class LiteLLM_TeamTable(TeamBase):
|
||||||
spend: Optional[float] = None
|
spend: Optional[float] = None
|
||||||
max_parallel_requests: Optional[int] = None
|
max_parallel_requests: Optional[int] = None
|
||||||
budget_duration: Optional[str] = None
|
budget_duration: Optional[str] = None
|
||||||
budget_reset_at: Optional[datetime] = None
|
budget_reset_at: Optional[datetime] = None
|
||||||
|
model_id: Optional[int] = None
|
||||||
|
|
||||||
@root_validator(pre=True)
|
@root_validator(pre=True)
|
||||||
def set_model_info(cls, values):
|
def set_model_info(cls, values):
|
||||||
|
@ -313,6 +324,7 @@ class LiteLLM_TeamTable(NewTeamRequest):
|
||||||
"config",
|
"config",
|
||||||
"permissions",
|
"permissions",
|
||||||
"model_max_budget",
|
"model_max_budget",
|
||||||
|
"model_aliases",
|
||||||
]
|
]
|
||||||
for field in dict_fields:
|
for field in dict_fields:
|
||||||
value = values.get(field)
|
value = values.get(field)
|
||||||
|
@ -542,6 +554,7 @@ class LiteLLM_VerificationTokenView(LiteLLM_VerificationToken):
|
||||||
team_rpm_limit: Optional[int] = None
|
team_rpm_limit: Optional[int] = None
|
||||||
team_max_budget: Optional[float] = None
|
team_max_budget: Optional[float] = None
|
||||||
soft_budget: Optional[float] = None
|
soft_budget: Optional[float] = None
|
||||||
|
team_model_aliases: Optional[Dict] = None
|
||||||
|
|
||||||
|
|
||||||
class UserAPIKeyAuth(
|
class UserAPIKeyAuth(
|
||||||
|
|
|
@ -406,7 +406,15 @@ async def user_api_key_auth(
|
||||||
) # request data, used across all checks. Making this easily available
|
) # request data, used across all checks. Making this easily available
|
||||||
|
|
||||||
# Check 1. If token can call model
|
# Check 1. If token can call model
|
||||||
litellm.model_alias_map = valid_token.aliases
|
_model_alias_map = {}
|
||||||
|
if valid_token.team_model_aliases is not None:
|
||||||
|
_model_alias_map = {
|
||||||
|
**valid_token.aliases,
|
||||||
|
**valid_token.team_model_aliases,
|
||||||
|
}
|
||||||
|
else:
|
||||||
|
_model_alias_map = {**valid_token.aliases}
|
||||||
|
litellm.model_alias_map = _model_alias_map
|
||||||
config = valid_token.config
|
config = valid_token.config
|
||||||
if config != {}:
|
if config != {}:
|
||||||
model_list = config.get("model_list", [])
|
model_list = config.get("model_list", [])
|
||||||
|
@ -824,7 +832,10 @@ async def user_api_key_auth(
|
||||||
raise Exception(
|
raise Exception(
|
||||||
f"This key is made for LiteLLM UI, Tried to access route: {route}. Not allowed"
|
f"This key is made for LiteLLM UI, Tried to access route: {route}. Not allowed"
|
||||||
)
|
)
|
||||||
|
if valid_token_dict is not None:
|
||||||
return UserAPIKeyAuth(api_key=api_key, **valid_token_dict)
|
return UserAPIKeyAuth(api_key=api_key, **valid_token_dict)
|
||||||
|
else:
|
||||||
|
raise Exception()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
# verbose_proxy_logger.debug(f"An exception occurred - {traceback.format_exc()}")
|
# verbose_proxy_logger.debug(f"An exception occurred - {traceback.format_exc()}")
|
||||||
traceback.print_exc()
|
traceback.print_exc()
|
||||||
|
@ -4923,11 +4934,27 @@ async def new_team(
|
||||||
Member(role="admin", user_id=user_api_key_dict.user_id)
|
Member(role="admin", user_id=user_api_key_dict.user_id)
|
||||||
)
|
)
|
||||||
|
|
||||||
|
## ADD TO MODEL TABLE
|
||||||
|
_model_id = None
|
||||||
|
if data.model_aliases is not None and isinstance(data.model_aliases, dict):
|
||||||
|
litellm_modeltable = LiteLLM_ModelTable(
|
||||||
|
model_aliases=json.dumps(data.model_aliases),
|
||||||
|
created_by=user_api_key_dict.user_id or litellm_proxy_admin_name,
|
||||||
|
updated_by=user_api_key_dict.user_id or litellm_proxy_admin_name,
|
||||||
|
)
|
||||||
|
model_dict = await prisma_client.db.litellm_modeltable.create(
|
||||||
|
{**litellm_modeltable.json(exclude_none=True)} # type: ignore
|
||||||
|
) # type: ignore
|
||||||
|
|
||||||
|
_model_id = model_dict.id
|
||||||
|
|
||||||
|
## ADD TO TEAM TABLE
|
||||||
complete_team_data = LiteLLM_TeamTable(
|
complete_team_data = LiteLLM_TeamTable(
|
||||||
**data.json(),
|
**data.json(),
|
||||||
max_parallel_requests=user_api_key_dict.max_parallel_requests,
|
max_parallel_requests=user_api_key_dict.max_parallel_requests,
|
||||||
budget_duration=user_api_key_dict.budget_duration,
|
budget_duration=user_api_key_dict.budget_duration,
|
||||||
budget_reset_at=user_api_key_dict.budget_reset_at,
|
budget_reset_at=user_api_key_dict.budget_reset_at,
|
||||||
|
model_id=_model_id,
|
||||||
)
|
)
|
||||||
|
|
||||||
team_row = await prisma_client.insert_data(
|
team_row = await prisma_client.insert_data(
|
||||||
|
@ -5398,7 +5425,7 @@ async def new_organization(
|
||||||
- `organization_alias`: *str* = The name of the organization.
|
- `organization_alias`: *str* = The name of the organization.
|
||||||
- `models`: *List* = The models the organization has access to.
|
- `models`: *List* = The models the organization has access to.
|
||||||
- `budget_id`: *Optional[str]* = The id for a budget (tpm/rpm/max budget) for the organization.
|
- `budget_id`: *Optional[str]* = The id for a budget (tpm/rpm/max budget) for the organization.
|
||||||
### IF NO BUDGET - CREATE ONE WITH THESE PARAMS ###
|
### IF NO BUDGET ID - CREATE ONE WITH THESE PARAMS ###
|
||||||
- `max_budget`: *Optional[float]* = Max budget for org
|
- `max_budget`: *Optional[float]* = Max budget for org
|
||||||
- `tpm_limit`: *Optional[int]* = Max tpm limit for org
|
- `tpm_limit`: *Optional[int]* = Max tpm limit for org
|
||||||
- `rpm_limit`: *Optional[int]* = Max rpm limit for org
|
- `rpm_limit`: *Optional[int]* = Max rpm limit for org
|
||||||
|
|
|
@ -965,12 +965,21 @@ class PrismaClient:
|
||||||
)
|
)
|
||||||
|
|
||||||
sql_query = f"""
|
sql_query = f"""
|
||||||
SELECT *
|
SELECT
|
||||||
FROM "LiteLLM_VerificationTokenView"
|
v.*,
|
||||||
WHERE token = '{token}'
|
t.spend AS team_spend,
|
||||||
|
t.max_budget AS team_max_budget,
|
||||||
|
t.tpm_limit AS team_tpm_limit,
|
||||||
|
t.rpm_limit AS team_rpm_limit,
|
||||||
|
m.aliases as team_model_aliases
|
||||||
|
FROM "LiteLLM_VerificationToken" AS v
|
||||||
|
LEFT JOIN "LiteLLM_TeamTable" AS t ON v.team_id = t.team_id
|
||||||
|
LEFT JOIN "LiteLLM_ModelTable" m ON t.model_id = m.id
|
||||||
|
WHERE v.token = '{token}'
|
||||||
"""
|
"""
|
||||||
|
|
||||||
response = await self.db.query_first(query=sql_query)
|
response = await self.db.query_first(query=sql_query)
|
||||||
|
|
||||||
if response is not None:
|
if response is not None:
|
||||||
response = LiteLLM_VerificationTokenView(**response)
|
response = LiteLLM_VerificationTokenView(**response)
|
||||||
# for prisma we need to cast the expires time to str
|
# for prisma we need to cast the expires time to str
|
||||||
|
|
|
@ -4043,7 +4043,7 @@ def get_optional_params_embeddings(
|
||||||
keys = list(non_default_params.keys())
|
keys = list(non_default_params.keys())
|
||||||
for k in keys:
|
for k in keys:
|
||||||
non_default_params.pop(k, None)
|
non_default_params.pop(k, None)
|
||||||
return non_default_params
|
else:
|
||||||
raise UnsupportedParamsError(
|
raise UnsupportedParamsError(
|
||||||
status_code=500,
|
status_code=500,
|
||||||
message=f"Setting user/encoding format is not supported by {custom_llm_provider}. To drop it from the call, set `litellm.drop_params = True`.",
|
message=f"Setting user/encoding format is not supported by {custom_llm_provider}. To drop it from the call, set `litellm.drop_params = True`.",
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue