diff --git a/litellm/proxy/_types.py b/litellm/proxy/_types.py index 85214b12d4..9882799fa6 100644 --- a/litellm/proxy/_types.py +++ b/litellm/proxy/_types.py @@ -2116,6 +2116,20 @@ class TeamMemberUpdateResponse(MemberUpdateResponse): max_budget_in_team: Optional[float] = None +class TeamModelAddRequest(BaseModel): + """Request to add models to a team""" + + team_id: str + models: List[str] + + +class TeamModelDeleteRequest(BaseModel): + """Request to delete models from a team""" + + team_id: str + models: List[str] + + # Organization Member Requests class OrganizationMemberAddRequest(OrgMemberAddRequest): organization_id: str diff --git a/litellm/proxy/management_endpoints/model_management_endpoints.py b/litellm/proxy/management_endpoints/model_management_endpoints.py index dbab85fbfc..e91937c0e6 100644 --- a/litellm/proxy/management_endpoints/model_management_endpoints.py +++ b/litellm/proxy/management_endpoints/model_management_endpoints.py @@ -25,18 +25,21 @@ from litellm.proxy._types import ( PrismaCompatibleUpdateDBModel, ProxyErrorTypes, ProxyException, + TeamModelAddRequest, UpdateTeamRequest, UserAPIKeyAuth, ) from litellm.proxy.auth.user_api_key_auth import user_api_key_auth from litellm.proxy.common_utils.encrypt_decrypt_utils import encrypt_value_helper -from litellm.proxy.management_endpoints.team_endpoints import update_team +from litellm.proxy.management_endpoints.team_endpoints import ( + team_model_add, + update_team, +) from litellm.proxy.utils import PrismaClient from litellm.types.router import ( Deployment, DeploymentTypedDict, LiteLLMParamsTypedDict, - ModelInfo, updateDeployment, ) from litellm.utils import get_utc_datetime @@ -88,16 +91,11 @@ def update_db_model( # update model info if updated_patch.model_info: - _updated_model_info_dict = updated_patch.model_info.model_dump( - exclude_unset=True - ) if "model_info" not in merged_deployment_dict: - merged_deployment_dict["model_info"] = ModelInfo() - _original_model_info_dict = merged_deployment_dict["model_info"].model_dump( - exclude_unset=True + merged_deployment_dict["model_info"] = {} + merged_deployment_dict["model_info"].update( + updated_patch.model_info.model_dump(exclude_none=True) ) - _original_model_info_dict.update(_updated_model_info_dict) - merged_deployment_dict["model_info"] = ModelInfo(**_original_model_info_dict) # convert to prisma compatible format @@ -294,6 +292,16 @@ async def _add_team_model_to_db( http_request=Request(scope={"type": "http"}), ) + # add model to team object + await team_model_add( + data=TeamModelAddRequest( + team_id=_team_id, + models=[original_model_name], + ), + http_request=Request(scope={"type": "http"}), + user_api_key_dict=user_api_key_dict, + ) + return model_response diff --git a/litellm/proxy/management_endpoints/team_endpoints.py b/litellm/proxy/management_endpoints/team_endpoints.py index 57d0d0d957..c68aa11352 100644 --- a/litellm/proxy/management_endpoints/team_endpoints.py +++ b/litellm/proxy/management_endpoints/team_endpoints.py @@ -47,6 +47,8 @@ from litellm.proxy._types import ( TeamMemberDeleteRequest, TeamMemberUpdateRequest, TeamMemberUpdateResponse, + TeamModelAddRequest, + TeamModelDeleteRequest, UpdateTeamRequest, UserAPIKeyAuth, ) @@ -1774,3 +1776,149 @@ async def ui_view_teams( except Exception as e: raise HTTPException(status_code=500, detail=f"Error searching teams: {str(e)}") + + +@router.post( + "/team/model/add", + tags=["team management"], + dependencies=[Depends(user_api_key_auth)], +) +@management_endpoint_wrapper +async def team_model_add( + data: TeamModelAddRequest, + http_request: Request, + user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), +): + """ + Add models to a team's allowed model list. Only proxy admin or team admin can add models. + + Parameters: + - team_id: str - Required. The team to add models to + - models: List[str] - Required. List of models to add to the team + + Example Request: + ``` + curl --location 'http://0.0.0.0:4000/team/model/add' \ + --header 'Authorization: Bearer sk-1234' \ + --header 'Content-Type: application/json' \ + --data '{ + "team_id": "team-1234", + "models": ["gpt-4", "claude-2"] + }' + ``` + """ + from litellm.proxy.proxy_server import prisma_client + + if prisma_client is None: + raise HTTPException(status_code=500, detail={"error": "No db connected"}) + + # Get existing team + team_row = await prisma_client.db.litellm_teamtable.find_unique( + where={"team_id": data.team_id} + ) + + if team_row is None: + raise HTTPException( + status_code=404, + detail={"error": f"Team not found, passed team_id={data.team_id}"}, + ) + + team_obj = LiteLLM_TeamTable(**team_row.model_dump()) + + # Authorization check - only proxy admin or team admin can add models + if ( + user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN.value + and not _is_user_team_admin( + user_api_key_dict=user_api_key_dict, team_obj=team_obj + ) + ): + raise HTTPException( + status_code=403, + detail={"error": "Only proxy admin or team admin can modify team models"}, + ) + + # Get current models list + current_models = team_obj.models or [] + + # Add new models (avoid duplicates) + updated_models = list(set(current_models + data.models)) + + # Update team + updated_team = await prisma_client.db.litellm_teamtable.update( + where={"team_id": data.team_id}, data={"models": updated_models} + ) + + return updated_team + + +@router.post( + "/team/model/delete", + tags=["team management"], + dependencies=[Depends(user_api_key_auth)], +) +@management_endpoint_wrapper +async def team_model_delete( + data: TeamModelDeleteRequest, + http_request: Request, + user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), +): + """ + Remove models from a team's allowed model list. Only proxy admin or team admin can remove models. + + Parameters: + - team_id: str - Required. The team to remove models from + - models: List[str] - Required. List of models to remove from the team + + Example Request: + ``` + curl --location 'http://0.0.0.0:4000/team/model/delete' \ + --header 'Authorization: Bearer sk-1234' \ + --header 'Content-Type: application/json' \ + --data '{ + "team_id": "team-1234", + "models": ["gpt-4"] + }' + ``` + """ + from litellm.proxy.proxy_server import prisma_client + + if prisma_client is None: + raise HTTPException(status_code=500, detail={"error": "No db connected"}) + + # Get existing team + team_row = await prisma_client.db.litellm_teamtable.find_unique( + where={"team_id": data.team_id} + ) + + if team_row is None: + raise HTTPException( + status_code=404, + detail={"error": f"Team not found, passed team_id={data.team_id}"}, + ) + + team_obj = LiteLLM_TeamTable(**team_row.model_dump()) + + # Authorization check - only proxy admin or team admin can remove models + if ( + user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN.value + and not _is_user_team_admin( + user_api_key_dict=user_api_key_dict, team_obj=team_obj + ) + ): + raise HTTPException( + status_code=403, + detail={"error": "Only proxy admin or team admin can modify team models"}, + ) + + # Get current models list + current_models = team_obj.models or [] + + # Remove specified models + updated_models = [m for m in current_models if m not in data.models] + + # Update team + updated_team = await prisma_client.db.litellm_teamtable.update( + where={"team_id": data.team_id}, data={"models": updated_models} + ) + + return updated_team diff --git a/litellm/router.py b/litellm/router.py index 443d73ab5e..bd7d101c7a 100644 --- a/litellm/router.py +++ b/litellm/router.py @@ -4897,32 +4897,61 @@ class Router: return returned_models - def get_model_names(self) -> List[str]: + def get_model_names(self, team_id: Optional[str] = None) -> List[str]: """ - Returns all possible model names for router. + Returns all possible model names for the router, including models defined via model_group_alias. - Includes model_group_alias models too. + If a team_id is provided, only deployments configured with that team_id (i.e. team‐specific models) + will yield their team public name. """ - model_list = self.get_model_list() - if model_list is None: - return [] - + deployments = self.get_model_list() or [] model_names = [] - for m in model_list: - model_names.append(self._get_public_model_name(m)) + + for deployment in deployments: + model_info = deployment.get("model_info") + if self._is_team_specific_model(model_info): + team_model_name = self._get_team_specific_model( + deployment=deployment, team_id=team_id + ) + if team_model_name: + model_names.append(team_model_name) + else: + model_names.append(deployment.get("model_name", "")) + return model_names - def _get_public_model_name(self, deployment: DeploymentTypedDict) -> str: + def _get_team_specific_model( + self, deployment: DeploymentTypedDict, team_id: Optional[str] = None + ) -> Optional[str]: """ - Returns the user-friendly model name for public display (e.g., on /models endpoint). + Get the team-specific model name if team_id matches the deployment. - Prioritizes the team's public model name if available, otherwise falls back to the default model name. + Args: + deployment: DeploymentTypedDict - The model deployment + team_id: Optional[str] - If passed, will return router models set with a `team_id` matching the passed `team_id`. + + Returns: + str: The `team_public_model_name` if team_id matches + None: If team_id doesn't match or no team info exists """ - model_info = deployment.get("model_info") - if model_info and model_info.get("team_public_model_name"): - return model_info["team_public_model_name"] + model_info: Optional[Dict] = deployment.get("model_info") or {} + if model_info is None: + return None + if team_id == model_info.get("team_id"): + return model_info.get("team_public_model_name") + return None - return deployment["model_name"] + def _is_team_specific_model(self, model_info: Optional[Dict]) -> bool: + """ + Check if model info contains team-specific configuration. + + Args: + model_info: Model information dictionary + + Returns: + bool: True if model has team-specific configuration + """ + return bool(model_info and model_info.get("team_id")) def get_model_list_from_model_alias( self, model_name: Optional[str] = None diff --git a/litellm/router_strategy/budget_limiter.py b/litellm/router_strategy/budget_limiter.py index 3cb1bb965a..8061aab537 100644 --- a/litellm/router_strategy/budget_limiter.py +++ b/litellm/router_strategy/budget_limiter.py @@ -771,7 +771,7 @@ class RouterBudgetLimiting(CustomLogger): return for _model in model_list: _litellm_params = _model.get("litellm_params", {}) - _model_info: ModelInfo = _model.get("model_info") or ModelInfo() + _model_info: Dict = _model.get("model_info") or {} _model_id = _model_info.get("id") _max_budget = _litellm_params.get("max_budget") _budget_duration = _litellm_params.get("budget_duration") diff --git a/litellm/types/router.py b/litellm/types/router.py index c1c802e538..fc95bbc670 100644 --- a/litellm/types/router.py +++ b/litellm/types/router.py @@ -385,7 +385,7 @@ class LiteLLMParamsTypedDict(TypedDict, total=False): class DeploymentTypedDict(TypedDict, total=False): model_name: Required[str] litellm_params: Required[LiteLLMParamsTypedDict] - model_info: ModelInfo + model_info: dict SPECIAL_MODEL_INFO_PARAMS = [ diff --git a/tests/local_testing/test_router.py b/tests/local_testing/test_router.py index ce33ecc8f3..698f779b78 100644 --- a/tests/local_testing/test_router.py +++ b/tests/local_testing/test_router.py @@ -2601,24 +2601,64 @@ def test_model_group_alias(hidden): assert len(model_names) == len(_model_list) + 1 -def test_get_public_model_name(): +def test_get_team_specific_model(): """ - Test that the _get_public_model_name helper returns the `team_public_model_name` if it exists, otherwise it returns the `model_name`. + Test that _get_team_specific_model returns: + - team_public_model_name when team_id matches + - None when team_id doesn't match + - None when no team_id in model_info """ - _model_list = [ - { - "model_name": "model_name_12299393939_gms", - "litellm_params": {"model": "gpt-4o"}, - "model_info": {"team_public_model_name": "gpt-4o"}, - }, - ] - router = Router( - model_list=_model_list, + router = Router(model_list=[]) + + # Test 1: Matching team_id + deployment = DeploymentTypedDict( + model_name="model-x", + litellm_params={}, + model_info=ModelInfo(team_id="team1", team_public_model_name="public-model-x"), ) + assert router._get_team_specific_model(deployment, "team1") == "public-model-x" - models = router.get_model_list() + # Test 2: Non-matching team_id + assert router._get_team_specific_model(deployment, "team2") is None - assert router._get_public_model_name(models[0]) == "gpt-4o" + # Test 3: No team_id in model_info + deployment = DeploymentTypedDict( + model_name="model-y", + litellm_params={}, + model_info=ModelInfo(team_public_model_name="public-model-y"), + ) + assert router._get_team_specific_model(deployment, "team1") is None + + # Test 4: No model_info + deployment = DeploymentTypedDict( + model_name="model-z", litellm_params={}, model_info=ModelInfo() + ) + assert router._get_team_specific_model(deployment, "team1") is None + + +def test_is_team_specific_model(): + """ + Test that _is_team_specific_model returns: + - True when model_info contains team_id + - False when model_info doesn't contain team_id + - False when model_info is None + """ + router = Router(model_list=[]) + + # Test 1: With team_id + model_info = ModelInfo(team_id="team1", team_public_model_name="public-model-x") + assert router._is_team_specific_model(model_info) is True + + # Test 2: Without team_id + model_info = ModelInfo(team_public_model_name="public-model-y") + assert router._is_team_specific_model(model_info) is False + + # Test 3: Empty model_info + model_info = ModelInfo() + assert router._is_team_specific_model(model_info) is False + + # Test 4: None model_info + assert router._is_team_specific_model(None) is False # @pytest.mark.parametrize("on_error", [True, False]) diff --git a/tests/store_model_in_db_tests/test_team_alias.py b/tests/store_model_in_db_tests/test_team_alias.py deleted file mode 100644 index 11c65dfdc5..0000000000 --- a/tests/store_model_in_db_tests/test_team_alias.py +++ /dev/null @@ -1,86 +0,0 @@ -import pytest -import asyncio -import aiohttp -import json -from openai import AsyncOpenAI -import uuid -from httpx import AsyncClient -import uuid -import os - -TEST_MASTER_KEY = "sk-1234" -PROXY_BASE_URL = "http://0.0.0.0:4000" - - -@pytest.mark.asyncio -async def test_team_model_alias(): - """ - Test model alias functionality with teams: - 1. Add a new model with model_name="gpt-4-team1" and litellm_params.model="gpt-4o" - 2. Create a new team - 3. Update team with model_alias mapping - 4. Generate key for team - 5. Make request with aliased model name - """ - client = AsyncClient(base_url=PROXY_BASE_URL) - headers = {"Authorization": f"Bearer {TEST_MASTER_KEY}"} - - # Add new model - model_response = await client.post( - "/model/new", - json={ - "model_name": "gpt-4o-team1", - "litellm_params": { - "model": "gpt-4o", - "api_key": os.getenv("OPENAI_API_KEY"), - }, - }, - headers=headers, - ) - assert model_response.status_code == 200 - - # Create new team - team_response = await client.post( - "/team/new", - json={ - "models": ["gpt-4o-team1"], - }, - headers=headers, - ) - assert team_response.status_code == 200 - team_data = team_response.json() - team_id = team_data["team_id"] - - # Update team with model alias - update_response = await client.post( - "/team/update", - json={"team_id": team_id, "model_aliases": {"gpt-4o": "gpt-4o-team1"}}, - headers=headers, - ) - assert update_response.status_code == 200 - - # Generate key for team - key_response = await client.post( - "/key/generate", json={"team_id": team_id}, headers=headers - ) - assert key_response.status_code == 200 - key = key_response.json()["key"] - - # Make request with model alias - openai_client = AsyncOpenAI(api_key=key, base_url=f"{PROXY_BASE_URL}/v1") - - response = await openai_client.chat.completions.create( - model="gpt-4o", - messages=[{"role": "user", "content": f"Test message {uuid.uuid4()}"}], - ) - - assert response is not None, "Should get valid response when using model alias" - - # Cleanup - delete the model - model_id = model_response.json()["model_info"]["id"] - delete_response = await client.post( - "/model/delete", - json={"id": model_id}, - headers={"Authorization": f"Bearer {TEST_MASTER_KEY}"}, - ) - assert delete_response.status_code == 200 diff --git a/tests/store_model_in_db_tests/test_team_models.py b/tests/store_model_in_db_tests/test_team_models.py new file mode 100644 index 0000000000..0faa01c8ee --- /dev/null +++ b/tests/store_model_in_db_tests/test_team_models.py @@ -0,0 +1,312 @@ +import pytest +import asyncio +import aiohttp +import json +from openai import AsyncOpenAI +import uuid +from httpx import AsyncClient +import uuid +import os + +TEST_MASTER_KEY = "sk-1234" +PROXY_BASE_URL = "http://0.0.0.0:4000" + + +@pytest.mark.asyncio +async def test_team_model_alias(): + """ + Test model alias functionality with teams: + 1. Add a new model with model_name="gpt-4-team1" and litellm_params.model="gpt-4o" + 2. Create a new team + 3. Update team with model_alias mapping + 4. Generate key for team + 5. Make request with aliased model name + """ + client = AsyncClient(base_url=PROXY_BASE_URL) + headers = {"Authorization": f"Bearer {TEST_MASTER_KEY}"} + + # Add new model + model_response = await client.post( + "/model/new", + json={ + "model_name": "gpt-4o-team1", + "litellm_params": { + "model": "gpt-4o", + "api_key": os.getenv("OPENAI_API_KEY"), + }, + }, + headers=headers, + ) + assert model_response.status_code == 200 + + # Create new team + team_response = await client.post( + "/team/new", + json={ + "models": ["gpt-4o-team1"], + }, + headers=headers, + ) + assert team_response.status_code == 200 + team_data = team_response.json() + team_id = team_data["team_id"] + + # Update team with model alias + update_response = await client.post( + "/team/update", + json={"team_id": team_id, "model_aliases": {"gpt-4o": "gpt-4o-team1"}}, + headers=headers, + ) + assert update_response.status_code == 200 + + # Generate key for team + key_response = await client.post( + "/key/generate", json={"team_id": team_id}, headers=headers + ) + assert key_response.status_code == 200 + key = key_response.json()["key"] + + # Make request with model alias + openai_client = AsyncOpenAI(api_key=key, base_url=f"{PROXY_BASE_URL}/v1") + + response = await openai_client.chat.completions.create( + model="gpt-4o", + messages=[{"role": "user", "content": f"Test message {uuid.uuid4()}"}], + ) + + assert response is not None, "Should get valid response when using model alias" + + # Cleanup - delete the model + model_id = model_response.json()["model_info"]["id"] + delete_response = await client.post( + "/model/delete", + json={"id": model_id}, + headers={"Authorization": f"Bearer {TEST_MASTER_KEY}"}, + ) + assert delete_response.status_code == 200 + + +@pytest.mark.asyncio +async def test_team_model_association(): + """ + Test that models created with a team_id are properly associated with the team: + 1. Create a new team + 2. Add a model with team_id in model_info + 3. Verify the model appears in team info + """ + client = AsyncClient(base_url=PROXY_BASE_URL) + headers = {"Authorization": f"Bearer {TEST_MASTER_KEY}"} + + # Create new team + team_response = await client.post( + "/team/new", + json={ + "models": [], # Start with empty model list + }, + headers=headers, + ) + assert team_response.status_code == 200 + team_data = team_response.json() + team_id = team_data["team_id"] + + # Add new model with team_id + model_response = await client.post( + "/model/new", + json={ + "model_name": "gpt-4-team-test", + "litellm_params": { + "model": "gpt-4", + "custom_llm_provider": "openai", + "api_key": "fake_key", + }, + "model_info": {"team_id": team_id}, + }, + headers=headers, + ) + assert model_response.status_code == 200 + + # Get team info and verify model association + team_info_response = await client.get( + f"/team/info", + headers=headers, + params={"team_id": team_id}, + ) + assert team_info_response.status_code == 200 + team_info = team_info_response.json()["team_info"] + + print("team_info", json.dumps(team_info, indent=4)) + + # Verify the model is in team_models + assert ( + "gpt-4-team-test" in team_info["models"] + ), "Model should be associated with team" + + # Cleanup - delete the model + model_id = model_response.json()["model_info"]["id"] + delete_response = await client.post( + "/model/delete", + json={"id": model_id}, + headers=headers, + ) + assert delete_response.status_code == 200 + + +@pytest.mark.asyncio +async def test_team_model_visibility_in_models_endpoint(): + """ + Test that team-specific models are only visible to the correct team in /models endpoint: + 1. Create two teams + 2. Add a model associated with team1 + 3. Generate keys for both teams + 4. Verify team1's key can see the model in /models + 5. Verify team2's key cannot see the model in /models + """ + client = AsyncClient(base_url=PROXY_BASE_URL) + headers = {"Authorization": f"Bearer {TEST_MASTER_KEY}"} + + # Create team1 + team1_response = await client.post( + "/team/new", + json={"models": []}, + headers=headers, + ) + assert team1_response.status_code == 200 + team1_id = team1_response.json()["team_id"] + + # Create team2 + team2_response = await client.post( + "/team/new", + json={"models": []}, + headers=headers, + ) + assert team2_response.status_code == 200 + team2_id = team2_response.json()["team_id"] + + # Add model associated with team1 + model_response = await client.post( + "/model/new", + json={ + "model_name": "gpt-4-team-test", + "litellm_params": { + "model": "gpt-4", + "custom_llm_provider": "openai", + "api_key": "fake_key", + }, + "model_info": {"team_id": team1_id}, + }, + headers=headers, + ) + assert model_response.status_code == 200 + + # Generate keys for both teams + team1_key = ( + await client.post("/key/generate", json={"team_id": team1_id}, headers=headers) + ).json()["key"] + team2_key = ( + await client.post("/key/generate", json={"team_id": team2_id}, headers=headers) + ).json()["key"] + + # Check models visibility for team1's key + team1_models = await client.get( + "/models", headers={"Authorization": f"Bearer {team1_key}"} + ) + assert team1_models.status_code == 200 + print("team1_models", json.dumps(team1_models.json(), indent=4)) + assert any( + model["id"] == "gpt-4-team-test" for model in team1_models.json()["data"] + ), "Team1 should see their model" + + # Check models visibility for team2's key + team2_models = await client.get( + "/models", headers={"Authorization": f"Bearer {team2_key}"} + ) + assert team2_models.status_code == 200 + print("team2_models", json.dumps(team2_models.json(), indent=4)) + assert not any( + model["id"] == "gpt-4-team-test" for model in team2_models.json()["data"] + ), "Team2 should not see team1's model" + + # Cleanup + model_id = model_response.json()["model_info"]["id"] + await client.post("/model/delete", json={"id": model_id}, headers=headers) + + +@pytest.mark.asyncio +async def test_team_model_visibility_in_model_info_endpoint(): + """ + Test that team-specific models are visible to all users in /v2/model/info endpoint: + Note: /v2/model/info is used by the Admin UI to display model info + 1. Create a team + 2. Add a model associated with the team + 3. Generate a team key + 4. Verify both team key and non-team key can see the model in /v2/model/info + """ + client = AsyncClient(base_url=PROXY_BASE_URL) + headers = {"Authorization": f"Bearer {TEST_MASTER_KEY}"} + + # Create team + team_response = await client.post( + "/team/new", + json={"models": []}, + headers=headers, + ) + assert team_response.status_code == 200 + team_id = team_response.json()["team_id"] + + # Add model associated with team + model_response = await client.post( + "/model/new", + json={ + "model_name": "gpt-4-team-test", + "litellm_params": { + "model": "gpt-4", + "custom_llm_provider": "openai", + "api_key": "fake_key", + }, + "model_info": {"team_id": team_id}, + }, + headers=headers, + ) + assert model_response.status_code == 200 + + # Generate team key + team_key = ( + await client.post("/key/generate", json={"team_id": team_id}, headers=headers) + ).json()["key"] + + # Generate non-team key + non_team_key = ( + await client.post("/key/generate", json={}, headers=headers) + ).json()["key"] + + # Check model info visibility with team key + team_model_info = await client.get( + "/v2/model/info", + headers={"Authorization": f"Bearer {team_key}"}, + params={"model_name": "gpt-4-team-test"}, + ) + assert team_model_info.status_code == 200 + team_model_info = team_model_info.json() + print("Team 1 model info", json.dumps(team_model_info, indent=4)) + assert any( + model["model_info"].get("team_public_model_name") == "gpt-4-team-test" + for model in team_model_info["data"] + ), "Team1 should see their model" + + # Check model info visibility with non-team key + non_team_model_info = await client.get( + "/v2/model/info", + headers={"Authorization": f"Bearer {non_team_key}"}, + params={"model_name": "gpt-4-team-test"}, + ) + assert non_team_model_info.status_code == 200 + non_team_model_info = non_team_model_info.json() + print("Non-team model info", json.dumps(non_team_model_info, indent=4)) + assert any( + model["model_info"].get("team_public_model_name") == "gpt-4-team-test" + for model in non_team_model_info["data"] + ), "Non-team should see the model" + + # Cleanup + model_id = model_response.json()["model_info"]["id"] + await client.post("/model/delete", json={"id": model_id}, headers=headers) diff --git a/ui/litellm-dashboard/src/components/create_key_button.tsx b/ui/litellm-dashboard/src/components/create_key_button.tsx index 14c7fedce9..0464696cf7 100644 --- a/ui/litellm-dashboard/src/components/create_key_button.tsx +++ b/ui/litellm-dashboard/src/components/create_key_button.tsx @@ -293,7 +293,13 @@ const CreateKey: React.FC = ({ initialValue={team ? team.team_id : null} className="mt-8" > - + { + const selectedTeam = teams?.find(t => t.team_id === teamId) || null; + setSelectedCreateKeyTeam(selectedTeam); + }} + /> ; - availableModels: string[]; - onUpdate: () => void; -} - -const ModelAliasesCard: React.FC = ({ - teamId, - accessToken, - currentAliases, - availableModels, - onUpdate, -}) => { - const [isModalVisible, setIsModalVisible] = useState(false); - const [form] = Form.useForm(); - - const handleCreateAlias = async (values: any) => { - try { - if (!accessToken) return; - - const newAliases = { - ...currentAliases, - [values.alias_name]: values.original_model, - }; - - const updateData = { - team_id: teamId, - model_aliases: newAliases, - }; - - await teamUpdateCall(accessToken, updateData); - message.success("Model alias created successfully"); - setIsModalVisible(false); - form.resetFields(); - currentAliases[values.alias_name] = values.original_model; - } catch (error) { - message.error("Failed to create model alias"); - console.error("Error creating model alias:", error); - } - }; - - return ( -
- Team Aliases - - Allow a team to use an alias that points to a specific model deployment. - - - -
-
-
-
-
ALIAS
-
POINTS TO
-
-
- setIsModalVisible(true)} - > - Create Model Alias - -
- -
- {Object.entries(currentAliases).map(([aliasName, originalModel], index) => ( -
-
- - {aliasName} - -
-
- - {originalModel} - -
-
- ))} - {Object.keys(currentAliases).length === 0 && ( -
- No model aliases configured -
- )} -
-
- - { - setIsModalVisible(false); - form.resetFields(); - }} - footer={null} - width={500} - > -
- - - - - - - - -
- { - setIsModalVisible(false); - form.resetFields(); - }} - className="bg-white text-gray-700 border border-gray-300 hover:bg-gray-50" - > - Cancel - - - Create Alias - -
-
-
-
- ); -}; - -export default ModelAliasesCard; \ No newline at end of file diff --git a/ui/litellm-dashboard/src/components/team/team_info.tsx b/ui/litellm-dashboard/src/components/team/team_info.tsx index 3b6fc09e29..a8d13730a4 100644 --- a/ui/litellm-dashboard/src/components/team/team_info.tsx +++ b/ui/litellm-dashboard/src/components/team/team_info.tsx @@ -474,14 +474,6 @@ const TeamInfoView: React.FC = ({ )} - -