(Polish/Fixes) - Fixes for Adding Team Specific Models (#8645)

* refactor get model info for team models

* allow adding a model to a team when creating team specific model

* ui update selected Team on Team Dropdown

* test_team_model_association

* testing for team specific models

* test_get_team_specific_model

* test: skip on internal server error

* remove model alias card on teams page

* linting fix _get_team_specific_model

* fix DeploymentTypedDict

* fix linting error

* fix code quality

* fix model info checks

---------

Co-authored-by: Krrish Dholakia <krrishdholakia@gmail.com>
This commit is contained in:
Ishaan Jaff 2025-02-18 21:11:57 -08:00 committed by GitHub
parent e08e8eda47
commit e5f29c3f7d
12 changed files with 599 additions and 310 deletions

View file

@ -4897,32 +4897,61 @@ class Router:
return returned_models
def get_model_names(self) -> List[str]:
def get_model_names(self, team_id: Optional[str] = None) -> List[str]:
"""
Returns all possible model names for router.
Returns all possible model names for the router, including models defined via model_group_alias.
Includes model_group_alias models too.
If a team_id is provided, only deployments configured with that team_id (i.e. teamspecific models)
will yield their team public name.
"""
model_list = self.get_model_list()
if model_list is None:
return []
deployments = self.get_model_list() or []
model_names = []
for m in model_list:
model_names.append(self._get_public_model_name(m))
for deployment in deployments:
model_info = deployment.get("model_info")
if self._is_team_specific_model(model_info):
team_model_name = self._get_team_specific_model(
deployment=deployment, team_id=team_id
)
if team_model_name:
model_names.append(team_model_name)
else:
model_names.append(deployment.get("model_name", ""))
return model_names
def _get_public_model_name(self, deployment: DeploymentTypedDict) -> str:
def _get_team_specific_model(
self, deployment: DeploymentTypedDict, team_id: Optional[str] = None
) -> Optional[str]:
"""
Returns the user-friendly model name for public display (e.g., on /models endpoint).
Get the team-specific model name if team_id matches the deployment.
Prioritizes the team's public model name if available, otherwise falls back to the default model name.
Args:
deployment: DeploymentTypedDict - The model deployment
team_id: Optional[str] - If passed, will return router models set with a `team_id` matching the passed `team_id`.
Returns:
str: The `team_public_model_name` if team_id matches
None: If team_id doesn't match or no team info exists
"""
model_info = deployment.get("model_info")
if model_info and model_info.get("team_public_model_name"):
return model_info["team_public_model_name"]
model_info: Optional[Dict] = deployment.get("model_info") or {}
if model_info is None:
return None
if team_id == model_info.get("team_id"):
return model_info.get("team_public_model_name")
return None
return deployment["model_name"]
def _is_team_specific_model(self, model_info: Optional[Dict]) -> bool:
"""
Check if model info contains team-specific configuration.
Args:
model_info: Model information dictionary
Returns:
bool: True if model has team-specific configuration
"""
return bool(model_info and model_info.get("team_id"))
def get_model_list_from_model_alias(
self, model_name: Optional[str] = None