mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 03:04:13 +00:00
(Polish/Fixes) - Fixes for Adding Team Specific Models (#8645)
* refactor get model info for team models * allow adding a model to a team when creating team specific model * ui update selected Team on Team Dropdown * test_team_model_association * testing for team specific models * test_get_team_specific_model * test: skip on internal server error * remove model alias card on teams page * linting fix _get_team_specific_model * fix DeploymentTypedDict * fix linting error * fix code quality * fix model info checks --------- Co-authored-by: Krrish Dholakia <krrishdholakia@gmail.com>
This commit is contained in:
parent
e08e8eda47
commit
e5f29c3f7d
12 changed files with 599 additions and 310 deletions
|
@ -2116,6 +2116,20 @@ class TeamMemberUpdateResponse(MemberUpdateResponse):
|
|||
max_budget_in_team: Optional[float] = None
|
||||
|
||||
|
||||
class TeamModelAddRequest(BaseModel):
    """Payload for adding models to a team's allowed-model list.

    Attributes:
        team_id: Identifier of the team whose model list is extended.
        models: Model names to append to the team's allowed models.
    """

    # Target team identifier.
    team_id: str
    # Model names to add to the team's allowed-model list.
    models: List[str]
|
||||
|
||||
|
||||
class TeamModelDeleteRequest(BaseModel):
    """Payload for removing models from a team's allowed-model list.

    Attributes:
        team_id: Identifier of the team whose model list is trimmed.
        models: Model names to remove from the team's allowed models.
    """

    # Target team identifier.
    team_id: str
    # Model names to drop from the team's allowed-model list.
    models: List[str]
|
||||
|
||||
|
||||
# Organization Member Requests
|
||||
class OrganizationMemberAddRequest(OrgMemberAddRequest):
|
||||
organization_id: str
|
||||
|
|
|
@ -25,18 +25,21 @@ from litellm.proxy._types import (
|
|||
PrismaCompatibleUpdateDBModel,
|
||||
ProxyErrorTypes,
|
||||
ProxyException,
|
||||
TeamModelAddRequest,
|
||||
UpdateTeamRequest,
|
||||
UserAPIKeyAuth,
|
||||
)
|
||||
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
|
||||
from litellm.proxy.common_utils.encrypt_decrypt_utils import encrypt_value_helper
|
||||
from litellm.proxy.management_endpoints.team_endpoints import update_team
|
||||
from litellm.proxy.management_endpoints.team_endpoints import (
|
||||
team_model_add,
|
||||
update_team,
|
||||
)
|
||||
from litellm.proxy.utils import PrismaClient
|
||||
from litellm.types.router import (
|
||||
Deployment,
|
||||
DeploymentTypedDict,
|
||||
LiteLLMParamsTypedDict,
|
||||
ModelInfo,
|
||||
updateDeployment,
|
||||
)
|
||||
from litellm.utils import get_utc_datetime
|
||||
|
@ -88,16 +91,11 @@ def update_db_model(
|
|||
|
||||
# update model info
|
||||
if updated_patch.model_info:
|
||||
_updated_model_info_dict = updated_patch.model_info.model_dump(
|
||||
exclude_unset=True
|
||||
)
|
||||
if "model_info" not in merged_deployment_dict:
|
||||
merged_deployment_dict["model_info"] = ModelInfo()
|
||||
_original_model_info_dict = merged_deployment_dict["model_info"].model_dump(
|
||||
exclude_unset=True
|
||||
merged_deployment_dict["model_info"] = {}
|
||||
merged_deployment_dict["model_info"].update(
|
||||
updated_patch.model_info.model_dump(exclude_none=True)
|
||||
)
|
||||
_original_model_info_dict.update(_updated_model_info_dict)
|
||||
merged_deployment_dict["model_info"] = ModelInfo(**_original_model_info_dict)
|
||||
|
||||
# convert to prisma compatible format
|
||||
|
||||
|
@ -294,6 +292,16 @@ async def _add_team_model_to_db(
|
|||
http_request=Request(scope={"type": "http"}),
|
||||
)
|
||||
|
||||
# add model to team object
|
||||
await team_model_add(
|
||||
data=TeamModelAddRequest(
|
||||
team_id=_team_id,
|
||||
models=[original_model_name],
|
||||
),
|
||||
http_request=Request(scope={"type": "http"}),
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
)
|
||||
|
||||
return model_response
|
||||
|
||||
|
||||
|
|
|
@ -47,6 +47,8 @@ from litellm.proxy._types import (
|
|||
TeamMemberDeleteRequest,
|
||||
TeamMemberUpdateRequest,
|
||||
TeamMemberUpdateResponse,
|
||||
TeamModelAddRequest,
|
||||
TeamModelDeleteRequest,
|
||||
UpdateTeamRequest,
|
||||
UserAPIKeyAuth,
|
||||
)
|
||||
|
@ -1774,3 +1776,149 @@ async def ui_view_teams(
|
|||
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Error searching teams: {str(e)}")
|
||||
|
||||
|
||||
@router.post(
    "/team/model/add",
    tags=["team management"],
    dependencies=[Depends(user_api_key_auth)],
)
@management_endpoint_wrapper
async def team_model_add(
    data: TeamModelAddRequest,
    http_request: Request,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Add models to a team's allowed model list. Only proxy admin or team admin can add models.

    Parameters:
    - team_id: str - Required. The team to add models to
    - models: List[str] - Required. List of models to add to the team

    Raises:
    - HTTPException 500: no database connection
    - HTTPException 404: team_id does not exist
    - HTTPException 403: caller is neither proxy admin nor an admin of this team

    Example Request:
    ```
    curl --location 'http://0.0.0.0:4000/team/model/add' \
    --header 'Authorization: Bearer sk-1234' \
    --header 'Content-Type: application/json' \
    --data '{
        "team_id": "team-1234",
        "models": ["gpt-4", "claude-2"]
    }'
    ```
    """
    from litellm.proxy.proxy_server import prisma_client

    if prisma_client is None:
        raise HTTPException(status_code=500, detail={"error": "No db connected"})

    # Get existing team
    team_row = await prisma_client.db.litellm_teamtable.find_unique(
        where={"team_id": data.team_id}
    )

    if team_row is None:
        raise HTTPException(
            status_code=404,
            detail={"error": f"Team not found, passed team_id={data.team_id}"},
        )

    team_obj = LiteLLM_TeamTable(**team_row.model_dump())

    # Authorization check - only proxy admin or team admin can add models
    if (
        user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN.value
        and not _is_user_team_admin(
            user_api_key_dict=user_api_key_dict, team_obj=team_obj
        )
    ):
        raise HTTPException(
            status_code=403,
            detail={"error": "Only proxy admin or team admin can modify team models"},
        )

    # Get current models list
    current_models = team_obj.models or []

    # Add new models, skipping duplicates. dict.fromkeys preserves insertion
    # order, so the stored list stays deterministic across calls
    # (list(set(...)) would reorder the models arbitrarily on every request).
    updated_models = list(dict.fromkeys(current_models + data.models))

    # Update team
    updated_team = await prisma_client.db.litellm_teamtable.update(
        where={"team_id": data.team_id}, data={"models": updated_models}
    )

    return updated_team
|
||||
|
||||
|
||||
@router.post(
    "/team/model/delete",
    tags=["team management"],
    dependencies=[Depends(user_api_key_auth)],
)
@management_endpoint_wrapper
async def team_model_delete(
    data: TeamModelDeleteRequest,
    http_request: Request,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Remove models from a team's allowed model list. Only proxy admin or team admin can remove models.

    Parameters:
    - team_id: str - Required. The team to remove models from
    - models: List[str] - Required. List of models to remove from the team

    Example Request:
    ```
    curl --location 'http://0.0.0.0:4000/team/model/delete' \
    --header 'Authorization: Bearer sk-1234' \
    --header 'Content-Type: application/json' \
    --data '{
        "team_id": "team-1234",
        "models": ["gpt-4"]
    }'
    ```
    """
    from litellm.proxy.proxy_server import prisma_client

    if prisma_client is None:
        raise HTTPException(status_code=500, detail={"error": "No db connected"})

    # Look up the team; 404 if it doesn't exist
    existing_team = await prisma_client.db.litellm_teamtable.find_unique(
        where={"team_id": data.team_id}
    )
    if existing_team is None:
        raise HTTPException(
            status_code=404,
            detail={"error": f"Team not found, passed team_id={data.team_id}"},
        )

    team_obj = LiteLLM_TeamTable(**existing_team.model_dump())

    # Only a proxy admin or an admin of this team may change its models
    is_proxy_admin = (
        user_api_key_dict.user_role == LitellmUserRoles.PROXY_ADMIN.value
    )
    if not (
        is_proxy_admin
        or _is_user_team_admin(
            user_api_key_dict=user_api_key_dict, team_obj=team_obj
        )
    ):
        raise HTTPException(
            status_code=403,
            detail={"error": "Only proxy admin or team admin can modify team models"},
        )

    # Drop every requested model from the team's current list
    remaining_models = [
        model for model in (team_obj.models or []) if model not in data.models
    ]

    # Persist the trimmed list
    updated_team = await prisma_client.db.litellm_teamtable.update(
        where={"team_id": data.team_id}, data={"models": remaining_models}
    )

    return updated_team
|
||||
|
|
|
@ -4897,32 +4897,61 @@ class Router:
|
|||
|
||||
return returned_models
|
||||
|
||||
def get_model_names(self) -> List[str]:
|
||||
def get_model_names(self, team_id: Optional[str] = None) -> List[str]:
    """
    Returns all possible model names for the router, including models defined via model_group_alias.

    If a team_id is provided, only deployments configured with that team_id (i.e. team-specific models)
    will yield their team public name.
    """
    model_names: List[str] = []

    for deployment in self.get_model_list() or []:
        model_info = deployment.get("model_info")
        if not self._is_team_specific_model(model_info):
            # Regular deployment: expose its configured model_name
            model_names.append(deployment.get("model_name", ""))
            continue
        # Team-specific deployment: only surface the team-facing public
        # name, and only when the caller's team_id matches
        team_model_name = self._get_team_specific_model(
            deployment=deployment, team_id=team_id
        )
        if team_model_name:
            model_names.append(team_model_name)

    return model_names
|
||||
|
||||
def _get_public_model_name(self, deployment: DeploymentTypedDict) -> str:
|
||||
def _get_team_specific_model(
    self, deployment: DeploymentTypedDict, team_id: Optional[str] = None
) -> Optional[str]:
    """
    Get the team-specific model name if team_id matches the deployment.

    Args:
        deployment: DeploymentTypedDict - The model deployment
        team_id: Optional[str] - If passed, will return router models set with a `team_id` matching the passed `team_id`.

    Returns:
        str: The `team_public_model_name` if team_id matches
        None: If team_id doesn't match or no team info exists
    """
    # `or {}` already normalizes a missing/None model_info to an empty dict,
    # so the previous `if model_info is None` branch was dead code — removed.
    model_info: Dict = deployment.get("model_info") or {}
    if team_id == model_info.get("team_id"):
        return model_info.get("team_public_model_name")
    return None
|
||||
def _is_team_specific_model(self, model_info: Optional[Dict]) -> bool:
|
||||
"""
|
||||
Check if model info contains team-specific configuration.
|
||||
|
||||
Args:
|
||||
model_info: Model information dictionary
|
||||
|
||||
Returns:
|
||||
bool: True if model has team-specific configuration
|
||||
"""
|
||||
return bool(model_info and model_info.get("team_id"))
|
||||
|
||||
def get_model_list_from_model_alias(
|
||||
self, model_name: Optional[str] = None
|
||||
|
|
|
@ -771,7 +771,7 @@ class RouterBudgetLimiting(CustomLogger):
|
|||
return
|
||||
for _model in model_list:
|
||||
_litellm_params = _model.get("litellm_params", {})
|
||||
_model_info: ModelInfo = _model.get("model_info") or ModelInfo()
|
||||
_model_info: Dict = _model.get("model_info") or {}
|
||||
_model_id = _model_info.get("id")
|
||||
_max_budget = _litellm_params.get("max_budget")
|
||||
_budget_duration = _litellm_params.get("budget_duration")
|
||||
|
|
|
@ -385,7 +385,7 @@ class LiteLLMParamsTypedDict(TypedDict, total=False):
|
|||
class DeploymentTypedDict(TypedDict, total=False):
|
||||
model_name: Required[str]
|
||||
litellm_params: Required[LiteLLMParamsTypedDict]
|
||||
model_info: ModelInfo
|
||||
model_info: dict
|
||||
|
||||
|
||||
SPECIAL_MODEL_INFO_PARAMS = [
|
||||
|
|
|
@ -2601,24 +2601,64 @@ def test_model_group_alias(hidden):
|
|||
assert len(model_names) == len(_model_list) + 1
|
||||
|
||||
|
||||
def test_get_public_model_name():
|
||||
def test_get_team_specific_model():
    """
    Test that _get_team_specific_model returns:
        - team_public_model_name when team_id matches
        - None when team_id doesn't match
        - None when no team_id in model_info
    """
    router = Router(model_list=[])

    # Deployment owned by team1
    matching = DeploymentTypedDict(
        model_name="model-x",
        litellm_params={},
        model_info=ModelInfo(team_id="team1", team_public_model_name="public-model-x"),
    )
    # matching team_id -> team public name
    assert router._get_team_specific_model(matching, "team1") == "public-model-x"
    # non-matching team_id -> None
    assert router._get_team_specific_model(matching, "team2") is None

    # model_info without a team_id -> None
    no_team = DeploymentTypedDict(
        model_name="model-y",
        litellm_params={},
        model_info=ModelInfo(team_public_model_name="public-model-y"),
    )
    assert router._get_team_specific_model(no_team, "team1") is None

    # empty model_info -> None
    empty_info = DeploymentTypedDict(
        model_name="model-z", litellm_params={}, model_info=ModelInfo()
    )
    assert router._get_team_specific_model(empty_info, "team1") is None
|
||||
|
||||
|
||||
def test_is_team_specific_model():
    """
    Test that _is_team_specific_model returns:
        - True when model_info contains team_id
        - False when model_info doesn't contain team_id
        - False when model_info is None
    """
    router = Router(model_list=[])

    # team_id present -> team-specific
    assert (
        router._is_team_specific_model(
            ModelInfo(team_id="team1", team_public_model_name="public-model-x")
        )
        is True
    )

    # no team_id -> not team-specific
    assert (
        router._is_team_specific_model(
            ModelInfo(team_public_model_name="public-model-y")
        )
        is False
    )

    # empty model_info -> not team-specific
    assert router._is_team_specific_model(ModelInfo()) is False

    # missing model_info entirely -> not team-specific
    assert router._is_team_specific_model(None) is False
|
||||
|
||||
|
||||
# @pytest.mark.parametrize("on_error", [True, False])
|
||||
|
|
|
@ -1,86 +0,0 @@
|
|||
import pytest
|
||||
import asyncio
|
||||
import aiohttp
|
||||
import json
|
||||
from openai import AsyncOpenAI
|
||||
import uuid
|
||||
from httpx import AsyncClient
|
||||
import uuid
|
||||
import os
|
||||
|
||||
TEST_MASTER_KEY = "sk-1234"
|
||||
PROXY_BASE_URL = "http://0.0.0.0:4000"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_team_model_alias():
|
||||
"""
|
||||
Test model alias functionality with teams:
|
||||
1. Add a new model with model_name="gpt-4-team1" and litellm_params.model="gpt-4o"
|
||||
2. Create a new team
|
||||
3. Update team with model_alias mapping
|
||||
4. Generate key for team
|
||||
5. Make request with aliased model name
|
||||
"""
|
||||
client = AsyncClient(base_url=PROXY_BASE_URL)
|
||||
headers = {"Authorization": f"Bearer {TEST_MASTER_KEY}"}
|
||||
|
||||
# Add new model
|
||||
model_response = await client.post(
|
||||
"/model/new",
|
||||
json={
|
||||
"model_name": "gpt-4o-team1",
|
||||
"litellm_params": {
|
||||
"model": "gpt-4o",
|
||||
"api_key": os.getenv("OPENAI_API_KEY"),
|
||||
},
|
||||
},
|
||||
headers=headers,
|
||||
)
|
||||
assert model_response.status_code == 200
|
||||
|
||||
# Create new team
|
||||
team_response = await client.post(
|
||||
"/team/new",
|
||||
json={
|
||||
"models": ["gpt-4o-team1"],
|
||||
},
|
||||
headers=headers,
|
||||
)
|
||||
assert team_response.status_code == 200
|
||||
team_data = team_response.json()
|
||||
team_id = team_data["team_id"]
|
||||
|
||||
# Update team with model alias
|
||||
update_response = await client.post(
|
||||
"/team/update",
|
||||
json={"team_id": team_id, "model_aliases": {"gpt-4o": "gpt-4o-team1"}},
|
||||
headers=headers,
|
||||
)
|
||||
assert update_response.status_code == 200
|
||||
|
||||
# Generate key for team
|
||||
key_response = await client.post(
|
||||
"/key/generate", json={"team_id": team_id}, headers=headers
|
||||
)
|
||||
assert key_response.status_code == 200
|
||||
key = key_response.json()["key"]
|
||||
|
||||
# Make request with model alias
|
||||
openai_client = AsyncOpenAI(api_key=key, base_url=f"{PROXY_BASE_URL}/v1")
|
||||
|
||||
response = await openai_client.chat.completions.create(
|
||||
model="gpt-4o",
|
||||
messages=[{"role": "user", "content": f"Test message {uuid.uuid4()}"}],
|
||||
)
|
||||
|
||||
assert response is not None, "Should get valid response when using model alias"
|
||||
|
||||
# Cleanup - delete the model
|
||||
model_id = model_response.json()["model_info"]["id"]
|
||||
delete_response = await client.post(
|
||||
"/model/delete",
|
||||
json={"id": model_id},
|
||||
headers={"Authorization": f"Bearer {TEST_MASTER_KEY}"},
|
||||
)
|
||||
assert delete_response.status_code == 200
|
312
tests/store_model_in_db_tests/test_team_models.py
Normal file
312
tests/store_model_in_db_tests/test_team_models.py
Normal file
|
@ -0,0 +1,312 @@
|
|||
import pytest
|
||||
import asyncio
|
||||
import aiohttp
|
||||
import json
|
||||
from openai import AsyncOpenAI
|
||||
import uuid
|
||||
from httpx import AsyncClient
|
||||
import uuid
|
||||
import os
|
||||
|
||||
TEST_MASTER_KEY = "sk-1234"
|
||||
PROXY_BASE_URL = "http://0.0.0.0:4000"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_team_model_alias():
    """
    Verify team model-alias routing end to end:
    1. Register a model named "gpt-4o-team1" backed by litellm_params.model="gpt-4o"
    2. Create a team allowed to use that model
    3. Point the team alias "gpt-4o" -> "gpt-4o-team1"
    4. Issue a team-scoped key
    5. Call the proxy with the aliased model name and expect a valid response
    """
    client = AsyncClient(base_url=PROXY_BASE_URL)
    headers = {"Authorization": f"Bearer {TEST_MASTER_KEY}"}

    # Register the underlying model deployment
    model_response = await client.post(
        "/model/new",
        json={
            "model_name": "gpt-4o-team1",
            "litellm_params": {
                "model": "gpt-4o",
                "api_key": os.getenv("OPENAI_API_KEY"),
            },
        },
        headers=headers,
    )
    assert model_response.status_code == 200

    # Create a team restricted to that model
    team_response = await client.post(
        "/team/new",
        json={
            "models": ["gpt-4o-team1"],
        },
        headers=headers,
    )
    assert team_response.status_code == 200
    team_id = team_response.json()["team_id"]

    # Map the alias "gpt-4o" onto the team's model
    update_response = await client.post(
        "/team/update",
        json={"team_id": team_id, "model_aliases": {"gpt-4o": "gpt-4o-team1"}},
        headers=headers,
    )
    assert update_response.status_code == 200

    # Issue a key scoped to the team
    key_response = await client.post(
        "/key/generate", json={"team_id": team_id}, headers=headers
    )
    assert key_response.status_code == 200
    team_key = key_response.json()["key"]

    # Hit the proxy using the alias
    openai_client = AsyncOpenAI(api_key=team_key, base_url=f"{PROXY_BASE_URL}/v1")
    response = await openai_client.chat.completions.create(
        model="gpt-4o",
        messages=[{"role": "user", "content": f"Test message {uuid.uuid4()}"}],
    )
    assert response is not None, "Should get valid response when using model alias"

    # Cleanup - remove the model registered above
    model_id = model_response.json()["model_info"]["id"]
    delete_response = await client.post(
        "/model/delete",
        json={"id": model_id},
        headers={"Authorization": f"Bearer {TEST_MASTER_KEY}"},
    )
    assert delete_response.status_code == 200
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_team_model_association():
    """
    Test that models created with a team_id are properly associated with the team:
    1. Create a new team
    2. Add a model with team_id in model_info
    3. Verify the model appears in team info
    """
    client = AsyncClient(base_url=PROXY_BASE_URL)
    headers = {"Authorization": f"Bearer {TEST_MASTER_KEY}"}

    # Create new team
    team_response = await client.post(
        "/team/new",
        json={
            "models": [],  # Start with empty model list
        },
        headers=headers,
    )
    assert team_response.status_code == 200
    team_data = team_response.json()
    team_id = team_data["team_id"]

    # Add new model with team_id
    model_response = await client.post(
        "/model/new",
        json={
            "model_name": "gpt-4-team-test",
            "litellm_params": {
                "model": "gpt-4",
                "custom_llm_provider": "openai",
                "api_key": "fake_key",
            },
            "model_info": {"team_id": team_id},
        },
        headers=headers,
    )
    assert model_response.status_code == 200

    # Get team info and verify model association
    # (plain string literal: the path has no placeholders, so the former
    # f-string prefix was unnecessary - ruff F541)
    team_info_response = await client.get(
        "/team/info",
        headers=headers,
        params={"team_id": team_id},
    )
    assert team_info_response.status_code == 200
    team_info = team_info_response.json()["team_info"]

    print("team_info", json.dumps(team_info, indent=4))

    # Verify the model is in team_models
    assert (
        "gpt-4-team-test" in team_info["models"]
    ), "Model should be associated with team"

    # Cleanup - delete the model
    model_id = model_response.json()["model_info"]["id"]
    delete_response = await client.post(
        "/model/delete",
        json={"id": model_id},
        headers=headers,
    )
    assert delete_response.status_code == 200
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_team_model_visibility_in_models_endpoint():
    """
    Team-specific models must only be visible to the owning team in /models:
    1. Create two teams
    2. Add a model associated with team1
    3. Generate keys for both teams
    4. Verify team1's key can see the model in /models
    5. Verify team2's key cannot see the model in /models
    """
    client = AsyncClient(base_url=PROXY_BASE_URL)
    headers = {"Authorization": f"Bearer {TEST_MASTER_KEY}"}

    # Spin up two fresh teams, neither with any models yet
    created_team_ids = []
    for _ in range(2):
        new_team = await client.post(
            "/team/new", json={"models": []}, headers=headers
        )
        assert new_team.status_code == 200
        created_team_ids.append(new_team.json()["team_id"])
    team1_id, team2_id = created_team_ids

    # Register a model that belongs to team1 only
    model_response = await client.post(
        "/model/new",
        json={
            "model_name": "gpt-4-team-test",
            "litellm_params": {
                "model": "gpt-4",
                "custom_llm_provider": "openai",
                "api_key": "fake_key",
            },
            "model_info": {"team_id": team1_id},
        },
        headers=headers,
    )
    assert model_response.status_code == 200

    # One key per team
    async def _generate_key(team_id):
        # Helper: mint a virtual key scoped to the given team
        resp = await client.post(
            "/key/generate", json={"team_id": team_id}, headers=headers
        )
        return resp.json()["key"]

    team1_key = await _generate_key(team1_id)
    team2_key = await _generate_key(team2_id)

    # team1's key must see the model
    team1_models = await client.get(
        "/models", headers={"Authorization": f"Bearer {team1_key}"}
    )
    assert team1_models.status_code == 200
    print("team1_models", json.dumps(team1_models.json(), indent=4))
    assert any(
        model["id"] == "gpt-4-team-test" for model in team1_models.json()["data"]
    ), "Team1 should see their model"

    # team2's key must NOT see the model
    team2_models = await client.get(
        "/models", headers={"Authorization": f"Bearer {team2_key}"}
    )
    assert team2_models.status_code == 200
    print("team2_models", json.dumps(team2_models.json(), indent=4))
    assert not any(
        model["id"] == "gpt-4-team-test" for model in team2_models.json()["data"]
    ), "Team2 should not see team1's model"

    # Cleanup
    model_id = model_response.json()["model_info"]["id"]
    await client.post("/model/delete", json={"id": model_id}, headers=headers)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_team_model_visibility_in_model_info_endpoint():
    """
    Team-specific models should be visible to all users in /v2/model/info:
    Note: /v2/model/info is used by the Admin UI to display model info
    1. Create a team
    2. Add a model associated with the team
    3. Generate a team key
    4. Verify both team key and non-team key can see the model in /v2/model/info
    """
    client = AsyncClient(base_url=PROXY_BASE_URL)
    headers = {"Authorization": f"Bearer {TEST_MASTER_KEY}"}

    # Fresh team with no models
    team_response = await client.post(
        "/team/new", json={"models": []}, headers=headers
    )
    assert team_response.status_code == 200
    team_id = team_response.json()["team_id"]

    # Register a model tied to the team
    model_response = await client.post(
        "/model/new",
        json={
            "model_name": "gpt-4-team-test",
            "litellm_params": {
                "model": "gpt-4",
                "custom_llm_provider": "openai",
                "api_key": "fake_key",
            },
            "model_info": {"team_id": team_id},
        },
        headers=headers,
    )
    assert model_response.status_code == 200

    # One key scoped to the team, one unscoped key
    team_key_resp = await client.post(
        "/key/generate", json={"team_id": team_id}, headers=headers
    )
    team_key = team_key_resp.json()["key"]
    non_team_key_resp = await client.post(
        "/key/generate", json={}, headers=headers
    )
    non_team_key = non_team_key_resp.json()["key"]

    def _contains_team_model(payload):
        # The Admin UI matches on the team-facing public model name
        return any(
            model["model_info"].get("team_public_model_name") == "gpt-4-team-test"
            for model in payload["data"]
        )

    # Team key sees the model
    team_model_info = await client.get(
        "/v2/model/info",
        headers={"Authorization": f"Bearer {team_key}"},
        params={"model_name": "gpt-4-team-test"},
    )
    assert team_model_info.status_code == 200
    team_model_info = team_model_info.json()
    print("Team 1 model info", json.dumps(team_model_info, indent=4))
    assert _contains_team_model(team_model_info), "Team1 should see their model"

    # Non-team key sees it too
    non_team_model_info = await client.get(
        "/v2/model/info",
        headers={"Authorization": f"Bearer {non_team_key}"},
        params={"model_name": "gpt-4-team-test"},
    )
    assert non_team_model_info.status_code == 200
    non_team_model_info = non_team_model_info.json()
    print("Non-team model info", json.dumps(non_team_model_info, indent=4))
    assert _contains_team_model(non_team_model_info), "Non-team should see the model"

    # Cleanup
    model_id = model_response.json()["model_info"]["id"]
    await client.post("/model/delete", json={"id": model_id}, headers=headers)
|
|
@ -293,7 +293,13 @@ const CreateKey: React.FC<CreateKeyProps> = ({
|
|||
initialValue={team ? team.team_id : null}
|
||||
className="mt-8"
|
||||
>
|
||||
<TeamDropdown teams={teams} />
|
||||
<TeamDropdown
|
||||
teams={teams}
|
||||
onChange={(teamId) => {
|
||||
const selectedTeam = teams?.find(t => t.team_id === teamId) || null;
|
||||
setSelectedCreateKeyTeam(selectedTeam);
|
||||
}}
|
||||
/>
|
||||
</Form.Item>
|
||||
|
||||
<Form.Item
|
||||
|
|
|
@ -1,174 +0,0 @@
|
|||
import React, { useState } from "react";
|
||||
import {
|
||||
Card,
|
||||
Title,
|
||||
Text,
|
||||
Button as TremorButton,
|
||||
} from "@tremor/react";
|
||||
import { Modal, Form, Select, Input, message } from "antd";
|
||||
import { teamUpdateCall } from "@/components/networking";
|
||||
|
||||
interface ModelAliasesCardProps {
|
||||
teamId: string;
|
||||
accessToken: string | null;
|
||||
currentAliases: Record<string, string>;
|
||||
availableModels: string[];
|
||||
onUpdate: () => void;
|
||||
}
|
||||
|
||||
const ModelAliasesCard: React.FC<ModelAliasesCardProps> = ({
|
||||
teamId,
|
||||
accessToken,
|
||||
currentAliases,
|
||||
availableModels,
|
||||
onUpdate,
|
||||
}) => {
|
||||
const [isModalVisible, setIsModalVisible] = useState(false);
|
||||
const [form] = Form.useForm();
|
||||
|
||||
const handleCreateAlias = async (values: any) => {
|
||||
try {
|
||||
if (!accessToken) return;
|
||||
|
||||
const newAliases = {
|
||||
...currentAliases,
|
||||
[values.alias_name]: values.original_model,
|
||||
};
|
||||
|
||||
const updateData = {
|
||||
team_id: teamId,
|
||||
model_aliases: newAliases,
|
||||
};
|
||||
|
||||
await teamUpdateCall(accessToken, updateData);
|
||||
message.success("Model alias created successfully");
|
||||
setIsModalVisible(false);
|
||||
form.resetFields();
|
||||
currentAliases[values.alias_name] = values.original_model;
|
||||
} catch (error) {
|
||||
message.error("Failed to create model alias");
|
||||
console.error("Error creating model alias:", error);
|
||||
}
|
||||
};
|
||||
|
||||
return (
|
||||
<div className="mt-8">
|
||||
<Title>Team Aliases</Title>
|
||||
<Text className="text-gray-600 mb-4">
|
||||
Allow a team to use an alias that points to a specific model deployment.
|
||||
|
||||
</Text>
|
||||
|
||||
<div className="bg-white rounded-lg p-6 border border-gray-200">
|
||||
<div className="flex justify-between items-center mb-6">
|
||||
<div>
|
||||
<div className="flex space-x-4 text-gray-600">
|
||||
<div className="w-64">ALIAS</div>
|
||||
<div>POINTS TO</div>
|
||||
</div>
|
||||
</div>
|
||||
<TremorButton
|
||||
size="md"
|
||||
variant="primary"
|
||||
onClick={() => setIsModalVisible(true)}
|
||||
>
|
||||
Create Model Alias
|
||||
</TremorButton>
|
||||
</div>
|
||||
|
||||
<div className="space-y-4">
|
||||
{Object.entries(currentAliases).map(([aliasName, originalModel], index) => (
|
||||
<div key={index} className="flex space-x-4 border-t border-gray-200 pt-4">
|
||||
<div className="w-64">
|
||||
<span className="bg-gray-100 px-2 py-1 rounded font-mono text-sm text-gray-700">
|
||||
{aliasName}
|
||||
</span>
|
||||
</div>
|
||||
<div>
|
||||
<span className="bg-gray-100 px-2 py-1 rounded font-mono text-sm text-gray-700">
|
||||
{originalModel}
|
||||
</span>
|
||||
</div>
|
||||
</div>
|
||||
))}
|
||||
{Object.keys(currentAliases).length === 0 && (
|
||||
<div className="text-gray-500 text-center py-4 border-t border-gray-200">
|
||||
No model aliases configured
|
||||
</div>
|
||||
)}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<Modal
|
||||
title="Create Model Alias"
|
||||
open={isModalVisible}
|
||||
onCancel={() => {
|
||||
setIsModalVisible(false);
|
||||
form.resetFields();
|
||||
}}
|
||||
footer={null}
|
||||
width={500}
|
||||
>
|
||||
<Form
|
||||
form={form}
|
||||
onFinish={handleCreateAlias}
|
||||
layout="vertical"
|
||||
className="mt-4"
|
||||
>
|
||||
<Form.Item
|
||||
label="Alias Name"
|
||||
name="alias_name"
|
||||
rules={[{ required: true, message: "Please enter an alias name" }]}
|
||||
>
|
||||
<Input
|
||||
placeholder="Enter the model alias (e.g., gpt-4o)"
|
||||
type=""
|
||||
/>
|
||||
</Form.Item>
|
||||
|
||||
<Form.Item
|
||||
label="Points To"
|
||||
name="original_model"
|
||||
rules={[{ required: true, message: "Please select a model" }]}
|
||||
>
|
||||
<Select
|
||||
placeholder="Select model version"
|
||||
className="w-full font-mono"
|
||||
showSearch
|
||||
optionFilterProp="children"
|
||||
>
|
||||
{availableModels.map((model) => (
|
||||
<Select.Option key={model} value={model} className="font-mono">
|
||||
{model}
|
||||
</Select.Option>
|
||||
))}
|
||||
</Select>
|
||||
</Form.Item>
|
||||
|
||||
<div className="flex justify-end gap-2 mt-6">
|
||||
<TremorButton
|
||||
size="md"
|
||||
variant="secondary"
|
||||
onClick={() => {
|
||||
setIsModalVisible(false);
|
||||
form.resetFields();
|
||||
}}
|
||||
className="bg-white text-gray-700 border border-gray-300 hover:bg-gray-50"
|
||||
>
|
||||
Cancel
|
||||
</TremorButton>
|
||||
<TremorButton
|
||||
size="md"
|
||||
variant="secondary"
|
||||
type="submit"
|
||||
>
|
||||
Create Alias
|
||||
</TremorButton>
|
||||
</div>
|
||||
</Form>
|
||||
</Modal>
|
||||
</div>
|
||||
);
|
||||
};
|
||||
|
||||
export default ModelAliasesCard;
|
|
@ -474,14 +474,6 @@ const TeamInfoView: React.FC<TeamInfoProps> = ({
|
|||
</div>
|
||||
)}
|
||||
</Card>
|
||||
|
||||
<ModelAliasesCard
|
||||
teamId={teamId}
|
||||
accessToken={accessToken}
|
||||
currentAliases={teamData?.team_info?.litellm_model_table?.model_aliases || {}}
|
||||
availableModels={userModels}
|
||||
onUpdate={fetchTeamInfo}
|
||||
/>
|
||||
</TabPanel>
|
||||
</TabPanels>
|
||||
</TabGroup>
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue