mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 11:14:04 +00:00
feat(proxy_server.py): support retrieving models for a team, if user is a member - via /models?team_id
Allows user to see team models on UI when creating a key
This commit is contained in:
parent
621d193727
commit
26226d475f
4 changed files with 32 additions and 20 deletions
File diff suppressed because one or more lines are too long
|
@ -85,7 +85,7 @@ def get_key_models(
|
||||||
|
|
||||||
|
|
||||||
def get_team_models(
|
def get_team_models(
|
||||||
user_api_key_dict: UserAPIKeyAuth,
|
team_models: List[str],
|
||||||
proxy_model_list: List[str],
|
proxy_model_list: List[str],
|
||||||
model_access_groups: Dict[str, List[str]],
|
model_access_groups: Dict[str, List[str]],
|
||||||
) -> List[str]:
|
) -> List[str]:
|
||||||
|
@ -96,10 +96,10 @@ def get_team_models(
|
||||||
- If model_access_groups is provided, only return models that are in the access groups
|
- If model_access_groups is provided, only return models that are in the access groups
|
||||||
"""
|
"""
|
||||||
all_models = []
|
all_models = []
|
||||||
if len(user_api_key_dict.team_models) > 0:
|
if len(team_models) > 0:
|
||||||
all_models = user_api_key_dict.team_models
|
all_models = team_models
|
||||||
if SpecialModelNames.all_team_models.value in all_models:
|
if SpecialModelNames.all_team_models.value in all_models:
|
||||||
all_models = user_api_key_dict.team_models
|
all_models = team_models
|
||||||
if SpecialModelNames.all_proxy_models.value in all_models:
|
if SpecialModelNames.all_proxy_models.value in all_models:
|
||||||
all_models = proxy_model_list
|
all_models = proxy_model_list
|
||||||
|
|
||||||
|
|
|
@ -1352,9 +1352,9 @@ async def team_info(
|
||||||
else:
|
else:
|
||||||
_team_info = LiteLLM_TeamTable()
|
_team_info = LiteLLM_TeamTable()
|
||||||
|
|
||||||
## UNFURL 'all-proxy-models' into the team_info.models list ##
|
# ## UNFURL 'all-proxy-models' into the team_info.models list ##
|
||||||
if llm_router is not None:
|
# if llm_router is not None:
|
||||||
_team_info = _unfurl_all_proxy_models(_team_info, llm_router)
|
# _team_info = _unfurl_all_proxy_models(_team_info, llm_router)
|
||||||
response_object = TeamInfoResponseObject(
|
response_object = TeamInfoResponseObject(
|
||||||
team_id=team_id,
|
team_id=team_id,
|
||||||
team_info=_team_info,
|
team_info=_team_info,
|
||||||
|
@ -1615,11 +1615,6 @@ async def list_team(
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# unfurl all-proxy-models
|
|
||||||
if llm_router is not None:
|
|
||||||
team = _unfurl_all_proxy_models(
|
|
||||||
LiteLLM_TeamTable(**team.model_dump()), llm_router
|
|
||||||
)
|
|
||||||
returned_responses.append(
|
returned_responses.append(
|
||||||
TeamListResponseObject(
|
TeamListResponseObject(
|
||||||
**team.model_dump(),
|
**team.model_dump(),
|
||||||
|
|
|
@ -122,7 +122,7 @@ from litellm.proxy.analytics_endpoints.analytics_endpoints import (
|
||||||
router as analytics_router,
|
router as analytics_router,
|
||||||
)
|
)
|
||||||
from litellm.proxy.anthropic_endpoints.endpoints import router as anthropic_router
|
from litellm.proxy.anthropic_endpoints.endpoints import router as anthropic_router
|
||||||
from litellm.proxy.auth.auth_checks import log_db_metrics
|
from litellm.proxy.auth.auth_checks import get_team_object, log_db_metrics
|
||||||
from litellm.proxy.auth.auth_utils import check_response_size_is_safe
|
from litellm.proxy.auth.auth_utils import check_response_size_is_safe
|
||||||
from litellm.proxy.auth.handle_jwt import JWTHandler
|
from litellm.proxy.auth.handle_jwt import JWTHandler
|
||||||
from litellm.proxy.auth.litellm_license import LicenseCheck
|
from litellm.proxy.auth.litellm_license import LicenseCheck
|
||||||
|
@ -213,7 +213,10 @@ from litellm.proxy.management_endpoints.team_callback_endpoints import (
|
||||||
router as team_callback_router,
|
router as team_callback_router,
|
||||||
)
|
)
|
||||||
from litellm.proxy.management_endpoints.team_endpoints import router as team_router
|
from litellm.proxy.management_endpoints.team_endpoints import router as team_router
|
||||||
from litellm.proxy.management_endpoints.team_endpoints import update_team
|
from litellm.proxy.management_endpoints.team_endpoints import (
|
||||||
|
update_team,
|
||||||
|
validate_membership,
|
||||||
|
)
|
||||||
from litellm.proxy.management_endpoints.ui_sso import (
|
from litellm.proxy.management_endpoints.ui_sso import (
|
||||||
get_disabled_non_admin_personal_key_creation,
|
get_disabled_non_admin_personal_key_creation,
|
||||||
)
|
)
|
||||||
|
@ -3380,13 +3383,14 @@ class ProxyStartupEvent:
|
||||||
async def model_list(
|
async def model_list(
|
||||||
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
||||||
return_wildcard_routes: Optional[bool] = False,
|
return_wildcard_routes: Optional[bool] = False,
|
||||||
|
team_id: Optional[str] = None,
|
||||||
):
|
):
|
||||||
"""
|
"""
|
||||||
Use `/model/info` - to get detailed model information, example - pricing, mode, etc.
|
Use `/model/info` - to get detailed model information, example - pricing, mode, etc.
|
||||||
|
|
||||||
This is just for compatibility with openai projects like aider.
|
This is just for compatibility with openai projects like aider.
|
||||||
"""
|
"""
|
||||||
global llm_model_list, general_settings, llm_router
|
global llm_model_list, general_settings, llm_router, prisma_client, user_api_key_cache, proxy_logging_obj
|
||||||
all_models = []
|
all_models = []
|
||||||
model_access_groups: Dict[str, List[str]] = defaultdict(list)
|
model_access_groups: Dict[str, List[str]] = defaultdict(list)
|
||||||
## CHECK IF MODEL RESTRICTIONS ARE SET AT KEY/TEAM LEVEL ##
|
## CHECK IF MODEL RESTRICTIONS ARE SET AT KEY/TEAM LEVEL ##
|
||||||
|
@ -3401,19 +3405,33 @@ async def model_list(
|
||||||
model_access_groups=model_access_groups,
|
model_access_groups=model_access_groups,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
team_models: List[str] = user_api_key_dict.team_models
|
||||||
|
|
||||||
|
if team_id:
|
||||||
|
team_object = await get_team_object(
|
||||||
|
team_id=team_id,
|
||||||
|
prisma_client=prisma_client,
|
||||||
|
user_api_key_cache=user_api_key_cache,
|
||||||
|
proxy_logging_obj=proxy_logging_obj,
|
||||||
|
)
|
||||||
|
validate_membership(user_api_key_dict=user_api_key_dict, team_table=team_object)
|
||||||
|
team_models = team_object.models
|
||||||
|
|
||||||
team_models = get_team_models(
|
team_models = get_team_models(
|
||||||
user_api_key_dict=user_api_key_dict,
|
team_models=team_models,
|
||||||
proxy_model_list=proxy_model_list,
|
proxy_model_list=proxy_model_list,
|
||||||
model_access_groups=model_access_groups,
|
model_access_groups=model_access_groups,
|
||||||
)
|
)
|
||||||
|
|
||||||
all_models = get_complete_model_list(
|
all_models = get_complete_model_list(
|
||||||
key_models=key_models,
|
key_models=key_models if not team_models else [],
|
||||||
team_models=team_models,
|
team_models=team_models,
|
||||||
proxy_model_list=proxy_model_list,
|
proxy_model_list=proxy_model_list,
|
||||||
user_model=user_model,
|
user_model=user_model,
|
||||||
infer_model_from_keys=general_settings.get("infer_model_from_keys", False),
|
infer_model_from_keys=general_settings.get("infer_model_from_keys", False),
|
||||||
return_wildcard_routes=return_wildcard_routes,
|
return_wildcard_routes=return_wildcard_routes,
|
||||||
)
|
)
|
||||||
|
|
||||||
return dict(
|
return dict(
|
||||||
data=[
|
data=[
|
||||||
{
|
{
|
||||||
|
@ -6117,7 +6135,7 @@ async def model_info_v1( # noqa: PLR0915
|
||||||
model_access_groups=model_access_groups,
|
model_access_groups=model_access_groups,
|
||||||
)
|
)
|
||||||
team_models = get_team_models(
|
team_models = get_team_models(
|
||||||
user_api_key_dict=user_api_key_dict,
|
team_models=user_api_key_dict.team_models,
|
||||||
proxy_model_list=proxy_model_list,
|
proxy_model_list=proxy_model_list,
|
||||||
model_access_groups=model_access_groups,
|
model_access_groups=model_access_groups,
|
||||||
)
|
)
|
||||||
|
@ -6344,7 +6362,7 @@ async def model_group_info(
|
||||||
model_access_groups=model_access_groups,
|
model_access_groups=model_access_groups,
|
||||||
)
|
)
|
||||||
team_models = get_team_models(
|
team_models = get_team_models(
|
||||||
user_api_key_dict=user_api_key_dict,
|
team_models=user_api_key_dict.team_models,
|
||||||
proxy_model_list=proxy_model_list,
|
proxy_model_list=proxy_model_list,
|
||||||
model_access_groups=model_access_groups,
|
model_access_groups=model_access_groups,
|
||||||
)
|
)
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue