mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-26 11:14:04 +00:00
(Polish/Fixes) - Fixes for Adding Team Specific Models (#8645)
* refactor get model info for team models * allow adding a model to a team when creating team specific model * ui update selected Team on Team Dropdown * test_team_model_association * testing for team specific models * test_get_team_specific_model * test: skip on internal server error * remove model alias card on teams page * linting fix _get_team_specific_model * fix DeploymentTypedDict * fix linting error * fix code quality * fix model info checks --------- Co-authored-by: Krrish Dholakia <krrishdholakia@gmail.com>
This commit is contained in:
parent
e08e8eda47
commit
e5f29c3f7d
12 changed files with 599 additions and 310 deletions
|
@ -2116,6 +2116,20 @@ class TeamMemberUpdateResponse(MemberUpdateResponse):
|
||||||
max_budget_in_team: Optional[float] = None
|
max_budget_in_team: Optional[float] = None
|
||||||
|
|
||||||
|
|
||||||
|
class TeamModelAddRequest(BaseModel):
|
||||||
|
"""Request to add models to a team"""
|
||||||
|
|
||||||
|
team_id: str
|
||||||
|
models: List[str]
|
||||||
|
|
||||||
|
|
||||||
|
class TeamModelDeleteRequest(BaseModel):
|
||||||
|
"""Request to delete models from a team"""
|
||||||
|
|
||||||
|
team_id: str
|
||||||
|
models: List[str]
|
||||||
|
|
||||||
|
|
||||||
# Organization Member Requests
|
# Organization Member Requests
|
||||||
class OrganizationMemberAddRequest(OrgMemberAddRequest):
|
class OrganizationMemberAddRequest(OrgMemberAddRequest):
|
||||||
organization_id: str
|
organization_id: str
|
||||||
|
|
|
@ -25,18 +25,21 @@ from litellm.proxy._types import (
|
||||||
PrismaCompatibleUpdateDBModel,
|
PrismaCompatibleUpdateDBModel,
|
||||||
ProxyErrorTypes,
|
ProxyErrorTypes,
|
||||||
ProxyException,
|
ProxyException,
|
||||||
|
TeamModelAddRequest,
|
||||||
UpdateTeamRequest,
|
UpdateTeamRequest,
|
||||||
UserAPIKeyAuth,
|
UserAPIKeyAuth,
|
||||||
)
|
)
|
||||||
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
|
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
|
||||||
from litellm.proxy.common_utils.encrypt_decrypt_utils import encrypt_value_helper
|
from litellm.proxy.common_utils.encrypt_decrypt_utils import encrypt_value_helper
|
||||||
from litellm.proxy.management_endpoints.team_endpoints import update_team
|
from litellm.proxy.management_endpoints.team_endpoints import (
|
||||||
|
team_model_add,
|
||||||
|
update_team,
|
||||||
|
)
|
||||||
from litellm.proxy.utils import PrismaClient
|
from litellm.proxy.utils import PrismaClient
|
||||||
from litellm.types.router import (
|
from litellm.types.router import (
|
||||||
Deployment,
|
Deployment,
|
||||||
DeploymentTypedDict,
|
DeploymentTypedDict,
|
||||||
LiteLLMParamsTypedDict,
|
LiteLLMParamsTypedDict,
|
||||||
ModelInfo,
|
|
||||||
updateDeployment,
|
updateDeployment,
|
||||||
)
|
)
|
||||||
from litellm.utils import get_utc_datetime
|
from litellm.utils import get_utc_datetime
|
||||||
|
@ -88,16 +91,11 @@ def update_db_model(
|
||||||
|
|
||||||
# update model info
|
# update model info
|
||||||
if updated_patch.model_info:
|
if updated_patch.model_info:
|
||||||
_updated_model_info_dict = updated_patch.model_info.model_dump(
|
|
||||||
exclude_unset=True
|
|
||||||
)
|
|
||||||
if "model_info" not in merged_deployment_dict:
|
if "model_info" not in merged_deployment_dict:
|
||||||
merged_deployment_dict["model_info"] = ModelInfo()
|
merged_deployment_dict["model_info"] = {}
|
||||||
_original_model_info_dict = merged_deployment_dict["model_info"].model_dump(
|
merged_deployment_dict["model_info"].update(
|
||||||
exclude_unset=True
|
updated_patch.model_info.model_dump(exclude_none=True)
|
||||||
)
|
)
|
||||||
_original_model_info_dict.update(_updated_model_info_dict)
|
|
||||||
merged_deployment_dict["model_info"] = ModelInfo(**_original_model_info_dict)
|
|
||||||
|
|
||||||
# convert to prisma compatible format
|
# convert to prisma compatible format
|
||||||
|
|
||||||
|
@ -294,6 +292,16 @@ async def _add_team_model_to_db(
|
||||||
http_request=Request(scope={"type": "http"}),
|
http_request=Request(scope={"type": "http"}),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# add model to team object
|
||||||
|
await team_model_add(
|
||||||
|
data=TeamModelAddRequest(
|
||||||
|
team_id=_team_id,
|
||||||
|
models=[original_model_name],
|
||||||
|
),
|
||||||
|
http_request=Request(scope={"type": "http"}),
|
||||||
|
user_api_key_dict=user_api_key_dict,
|
||||||
|
)
|
||||||
|
|
||||||
return model_response
|
return model_response
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -47,6 +47,8 @@ from litellm.proxy._types import (
|
||||||
TeamMemberDeleteRequest,
|
TeamMemberDeleteRequest,
|
||||||
TeamMemberUpdateRequest,
|
TeamMemberUpdateRequest,
|
||||||
TeamMemberUpdateResponse,
|
TeamMemberUpdateResponse,
|
||||||
|
TeamModelAddRequest,
|
||||||
|
TeamModelDeleteRequest,
|
||||||
UpdateTeamRequest,
|
UpdateTeamRequest,
|
||||||
UserAPIKeyAuth,
|
UserAPIKeyAuth,
|
||||||
)
|
)
|
||||||
|
@ -1774,3 +1776,149 @@ async def ui_view_teams(
|
||||||
|
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
raise HTTPException(status_code=500, detail=f"Error searching teams: {str(e)}")
|
raise HTTPException(status_code=500, detail=f"Error searching teams: {str(e)}")
|
||||||
|
|
||||||
|
|
||||||
|
@router.post(
|
||||||
|
"/team/model/add",
|
||||||
|
tags=["team management"],
|
||||||
|
dependencies=[Depends(user_api_key_auth)],
|
||||||
|
)
|
||||||
|
@management_endpoint_wrapper
|
||||||
|
async def team_model_add(
|
||||||
|
data: TeamModelAddRequest,
|
||||||
|
http_request: Request,
|
||||||
|
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Add models to a team's allowed model list. Only proxy admin or team admin can add models.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
- team_id: str - Required. The team to add models to
|
||||||
|
- models: List[str] - Required. List of models to add to the team
|
||||||
|
|
||||||
|
Example Request:
|
||||||
|
```
|
||||||
|
curl --location 'http://0.0.0.0:4000/team/model/add' \
|
||||||
|
--header 'Authorization: Bearer sk-1234' \
|
||||||
|
--header 'Content-Type: application/json' \
|
||||||
|
--data '{
|
||||||
|
"team_id": "team-1234",
|
||||||
|
"models": ["gpt-4", "claude-2"]
|
||||||
|
}'
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
from litellm.proxy.proxy_server import prisma_client
|
||||||
|
|
||||||
|
if prisma_client is None:
|
||||||
|
raise HTTPException(status_code=500, detail={"error": "No db connected"})
|
||||||
|
|
||||||
|
# Get existing team
|
||||||
|
team_row = await prisma_client.db.litellm_teamtable.find_unique(
|
||||||
|
where={"team_id": data.team_id}
|
||||||
|
)
|
||||||
|
|
||||||
|
if team_row is None:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=404,
|
||||||
|
detail={"error": f"Team not found, passed team_id={data.team_id}"},
|
||||||
|
)
|
||||||
|
|
||||||
|
team_obj = LiteLLM_TeamTable(**team_row.model_dump())
|
||||||
|
|
||||||
|
# Authorization check - only proxy admin or team admin can add models
|
||||||
|
if (
|
||||||
|
user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN.value
|
||||||
|
and not _is_user_team_admin(
|
||||||
|
user_api_key_dict=user_api_key_dict, team_obj=team_obj
|
||||||
|
)
|
||||||
|
):
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=403,
|
||||||
|
detail={"error": "Only proxy admin or team admin can modify team models"},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get current models list
|
||||||
|
current_models = team_obj.models or []
|
||||||
|
|
||||||
|
# Add new models (avoid duplicates)
|
||||||
|
updated_models = list(set(current_models + data.models))
|
||||||
|
|
||||||
|
# Update team
|
||||||
|
updated_team = await prisma_client.db.litellm_teamtable.update(
|
||||||
|
where={"team_id": data.team_id}, data={"models": updated_models}
|
||||||
|
)
|
||||||
|
|
||||||
|
return updated_team
|
||||||
|
|
||||||
|
|
||||||
|
@router.post(
|
||||||
|
"/team/model/delete",
|
||||||
|
tags=["team management"],
|
||||||
|
dependencies=[Depends(user_api_key_auth)],
|
||||||
|
)
|
||||||
|
@management_endpoint_wrapper
|
||||||
|
async def team_model_delete(
|
||||||
|
data: TeamModelDeleteRequest,
|
||||||
|
http_request: Request,
|
||||||
|
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Remove models from a team's allowed model list. Only proxy admin or team admin can remove models.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
- team_id: str - Required. The team to remove models from
|
||||||
|
- models: List[str] - Required. List of models to remove from the team
|
||||||
|
|
||||||
|
Example Request:
|
||||||
|
```
|
||||||
|
curl --location 'http://0.0.0.0:4000/team/model/delete' \
|
||||||
|
--header 'Authorization: Bearer sk-1234' \
|
||||||
|
--header 'Content-Type: application/json' \
|
||||||
|
--data '{
|
||||||
|
"team_id": "team-1234",
|
||||||
|
"models": ["gpt-4"]
|
||||||
|
}'
|
||||||
|
```
|
||||||
|
"""
|
||||||
|
from litellm.proxy.proxy_server import prisma_client
|
||||||
|
|
||||||
|
if prisma_client is None:
|
||||||
|
raise HTTPException(status_code=500, detail={"error": "No db connected"})
|
||||||
|
|
||||||
|
# Get existing team
|
||||||
|
team_row = await prisma_client.db.litellm_teamtable.find_unique(
|
||||||
|
where={"team_id": data.team_id}
|
||||||
|
)
|
||||||
|
|
||||||
|
if team_row is None:
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=404,
|
||||||
|
detail={"error": f"Team not found, passed team_id={data.team_id}"},
|
||||||
|
)
|
||||||
|
|
||||||
|
team_obj = LiteLLM_TeamTable(**team_row.model_dump())
|
||||||
|
|
||||||
|
# Authorization check - only proxy admin or team admin can remove models
|
||||||
|
if (
|
||||||
|
user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN.value
|
||||||
|
and not _is_user_team_admin(
|
||||||
|
user_api_key_dict=user_api_key_dict, team_obj=team_obj
|
||||||
|
)
|
||||||
|
):
|
||||||
|
raise HTTPException(
|
||||||
|
status_code=403,
|
||||||
|
detail={"error": "Only proxy admin or team admin can modify team models"},
|
||||||
|
)
|
||||||
|
|
||||||
|
# Get current models list
|
||||||
|
current_models = team_obj.models or []
|
||||||
|
|
||||||
|
# Remove specified models
|
||||||
|
updated_models = [m for m in current_models if m not in data.models]
|
||||||
|
|
||||||
|
# Update team
|
||||||
|
updated_team = await prisma_client.db.litellm_teamtable.update(
|
||||||
|
where={"team_id": data.team_id}, data={"models": updated_models}
|
||||||
|
)
|
||||||
|
|
||||||
|
return updated_team
|
||||||
|
|
|
@ -4897,32 +4897,61 @@ class Router:
|
||||||
|
|
||||||
return returned_models
|
return returned_models
|
||||||
|
|
||||||
def get_model_names(self) -> List[str]:
|
def get_model_names(self, team_id: Optional[str] = None) -> List[str]:
|
||||||
"""
|
"""
|
||||||
Returns all possible model names for router.
|
Returns all possible model names for the router, including models defined via model_group_alias.
|
||||||
|
|
||||||
Includes model_group_alias models too.
|
If a team_id is provided, only deployments configured with that team_id (i.e. team‐specific models)
|
||||||
|
will yield their team public name.
|
||||||
"""
|
"""
|
||||||
model_list = self.get_model_list()
|
deployments = self.get_model_list() or []
|
||||||
if model_list is None:
|
|
||||||
return []
|
|
||||||
|
|
||||||
model_names = []
|
model_names = []
|
||||||
for m in model_list:
|
|
||||||
model_names.append(self._get_public_model_name(m))
|
for deployment in deployments:
|
||||||
|
model_info = deployment.get("model_info")
|
||||||
|
if self._is_team_specific_model(model_info):
|
||||||
|
team_model_name = self._get_team_specific_model(
|
||||||
|
deployment=deployment, team_id=team_id
|
||||||
|
)
|
||||||
|
if team_model_name:
|
||||||
|
model_names.append(team_model_name)
|
||||||
|
else:
|
||||||
|
model_names.append(deployment.get("model_name", ""))
|
||||||
|
|
||||||
return model_names
|
return model_names
|
||||||
|
|
||||||
def _get_public_model_name(self, deployment: DeploymentTypedDict) -> str:
|
def _get_team_specific_model(
|
||||||
|
self, deployment: DeploymentTypedDict, team_id: Optional[str] = None
|
||||||
|
) -> Optional[str]:
|
||||||
"""
|
"""
|
||||||
Returns the user-friendly model name for public display (e.g., on /models endpoint).
|
Get the team-specific model name if team_id matches the deployment.
|
||||||
|
|
||||||
Prioritizes the team's public model name if available, otherwise falls back to the default model name.
|
Args:
|
||||||
|
deployment: DeploymentTypedDict - The model deployment
|
||||||
|
team_id: Optional[str] - If passed, will return router models set with a `team_id` matching the passed `team_id`.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
str: The `team_public_model_name` if team_id matches
|
||||||
|
None: If team_id doesn't match or no team info exists
|
||||||
"""
|
"""
|
||||||
model_info = deployment.get("model_info")
|
model_info: Optional[Dict] = deployment.get("model_info") or {}
|
||||||
if model_info and model_info.get("team_public_model_name"):
|
if model_info is None:
|
||||||
return model_info["team_public_model_name"]
|
return None
|
||||||
|
if team_id == model_info.get("team_id"):
|
||||||
|
return model_info.get("team_public_model_name")
|
||||||
|
return None
|
||||||
|
|
||||||
return deployment["model_name"]
|
def _is_team_specific_model(self, model_info: Optional[Dict]) -> bool:
|
||||||
|
"""
|
||||||
|
Check if model info contains team-specific configuration.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
model_info: Model information dictionary
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
bool: True if model has team-specific configuration
|
||||||
|
"""
|
||||||
|
return bool(model_info and model_info.get("team_id"))
|
||||||
|
|
||||||
def get_model_list_from_model_alias(
|
def get_model_list_from_model_alias(
|
||||||
self, model_name: Optional[str] = None
|
self, model_name: Optional[str] = None
|
||||||
|
|
|
@ -771,7 +771,7 @@ class RouterBudgetLimiting(CustomLogger):
|
||||||
return
|
return
|
||||||
for _model in model_list:
|
for _model in model_list:
|
||||||
_litellm_params = _model.get("litellm_params", {})
|
_litellm_params = _model.get("litellm_params", {})
|
||||||
_model_info: ModelInfo = _model.get("model_info") or ModelInfo()
|
_model_info: Dict = _model.get("model_info") or {}
|
||||||
_model_id = _model_info.get("id")
|
_model_id = _model_info.get("id")
|
||||||
_max_budget = _litellm_params.get("max_budget")
|
_max_budget = _litellm_params.get("max_budget")
|
||||||
_budget_duration = _litellm_params.get("budget_duration")
|
_budget_duration = _litellm_params.get("budget_duration")
|
||||||
|
|
|
@ -385,7 +385,7 @@ class LiteLLMParamsTypedDict(TypedDict, total=False):
|
||||||
class DeploymentTypedDict(TypedDict, total=False):
|
class DeploymentTypedDict(TypedDict, total=False):
|
||||||
model_name: Required[str]
|
model_name: Required[str]
|
||||||
litellm_params: Required[LiteLLMParamsTypedDict]
|
litellm_params: Required[LiteLLMParamsTypedDict]
|
||||||
model_info: ModelInfo
|
model_info: dict
|
||||||
|
|
||||||
|
|
||||||
SPECIAL_MODEL_INFO_PARAMS = [
|
SPECIAL_MODEL_INFO_PARAMS = [
|
||||||
|
|
|
@ -2601,24 +2601,64 @@ def test_model_group_alias(hidden):
|
||||||
assert len(model_names) == len(_model_list) + 1
|
assert len(model_names) == len(_model_list) + 1
|
||||||
|
|
||||||
|
|
||||||
def test_get_public_model_name():
|
def test_get_team_specific_model():
|
||||||
"""
|
"""
|
||||||
Test that the _get_public_model_name helper returns the `team_public_model_name` if it exists, otherwise it returns the `model_name`.
|
Test that _get_team_specific_model returns:
|
||||||
|
- team_public_model_name when team_id matches
|
||||||
|
- None when team_id doesn't match
|
||||||
|
- None when no team_id in model_info
|
||||||
"""
|
"""
|
||||||
_model_list = [
|
router = Router(model_list=[])
|
||||||
{
|
|
||||||
"model_name": "model_name_12299393939_gms",
|
# Test 1: Matching team_id
|
||||||
"litellm_params": {"model": "gpt-4o"},
|
deployment = DeploymentTypedDict(
|
||||||
"model_info": {"team_public_model_name": "gpt-4o"},
|
model_name="model-x",
|
||||||
},
|
litellm_params={},
|
||||||
]
|
model_info=ModelInfo(team_id="team1", team_public_model_name="public-model-x"),
|
||||||
router = Router(
|
|
||||||
model_list=_model_list,
|
|
||||||
)
|
)
|
||||||
|
assert router._get_team_specific_model(deployment, "team1") == "public-model-x"
|
||||||
|
|
||||||
models = router.get_model_list()
|
# Test 2: Non-matching team_id
|
||||||
|
assert router._get_team_specific_model(deployment, "team2") is None
|
||||||
|
|
||||||
assert router._get_public_model_name(models[0]) == "gpt-4o"
|
# Test 3: No team_id in model_info
|
||||||
|
deployment = DeploymentTypedDict(
|
||||||
|
model_name="model-y",
|
||||||
|
litellm_params={},
|
||||||
|
model_info=ModelInfo(team_public_model_name="public-model-y"),
|
||||||
|
)
|
||||||
|
assert router._get_team_specific_model(deployment, "team1") is None
|
||||||
|
|
||||||
|
# Test 4: No model_info
|
||||||
|
deployment = DeploymentTypedDict(
|
||||||
|
model_name="model-z", litellm_params={}, model_info=ModelInfo()
|
||||||
|
)
|
||||||
|
assert router._get_team_specific_model(deployment, "team1") is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_is_team_specific_model():
|
||||||
|
"""
|
||||||
|
Test that _is_team_specific_model returns:
|
||||||
|
- True when model_info contains team_id
|
||||||
|
- False when model_info doesn't contain team_id
|
||||||
|
- False when model_info is None
|
||||||
|
"""
|
||||||
|
router = Router(model_list=[])
|
||||||
|
|
||||||
|
# Test 1: With team_id
|
||||||
|
model_info = ModelInfo(team_id="team1", team_public_model_name="public-model-x")
|
||||||
|
assert router._is_team_specific_model(model_info) is True
|
||||||
|
|
||||||
|
# Test 2: Without team_id
|
||||||
|
model_info = ModelInfo(team_public_model_name="public-model-y")
|
||||||
|
assert router._is_team_specific_model(model_info) is False
|
||||||
|
|
||||||
|
# Test 3: Empty model_info
|
||||||
|
model_info = ModelInfo()
|
||||||
|
assert router._is_team_specific_model(model_info) is False
|
||||||
|
|
||||||
|
# Test 4: None model_info
|
||||||
|
assert router._is_team_specific_model(None) is False
|
||||||
|
|
||||||
|
|
||||||
# @pytest.mark.parametrize("on_error", [True, False])
|
# @pytest.mark.parametrize("on_error", [True, False])
|
||||||
|
|
|
@ -1,86 +0,0 @@
|
||||||
import pytest
|
|
||||||
import asyncio
|
|
||||||
import aiohttp
|
|
||||||
import json
|
|
||||||
from openai import AsyncOpenAI
|
|
||||||
import uuid
|
|
||||||
from httpx import AsyncClient
|
|
||||||
import uuid
|
|
||||||
import os
|
|
||||||
|
|
||||||
TEST_MASTER_KEY = "sk-1234"
|
|
||||||
PROXY_BASE_URL = "http://0.0.0.0:4000"
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
|
||||||
async def test_team_model_alias():
|
|
||||||
"""
|
|
||||||
Test model alias functionality with teams:
|
|
||||||
1. Add a new model with model_name="gpt-4-team1" and litellm_params.model="gpt-4o"
|
|
||||||
2. Create a new team
|
|
||||||
3. Update team with model_alias mapping
|
|
||||||
4. Generate key for team
|
|
||||||
5. Make request with aliased model name
|
|
||||||
"""
|
|
||||||
client = AsyncClient(base_url=PROXY_BASE_URL)
|
|
||||||
headers = {"Authorization": f"Bearer {TEST_MASTER_KEY}"}
|
|
||||||
|
|
||||||
# Add new model
|
|
||||||
model_response = await client.post(
|
|
||||||
"/model/new",
|
|
||||||
json={
|
|
||||||
"model_name": "gpt-4o-team1",
|
|
||||||
"litellm_params": {
|
|
||||||
"model": "gpt-4o",
|
|
||||||
"api_key": os.getenv("OPENAI_API_KEY"),
|
|
||||||
},
|
|
||||||
},
|
|
||||||
headers=headers,
|
|
||||||
)
|
|
||||||
assert model_response.status_code == 200
|
|
||||||
|
|
||||||
# Create new team
|
|
||||||
team_response = await client.post(
|
|
||||||
"/team/new",
|
|
||||||
json={
|
|
||||||
"models": ["gpt-4o-team1"],
|
|
||||||
},
|
|
||||||
headers=headers,
|
|
||||||
)
|
|
||||||
assert team_response.status_code == 200
|
|
||||||
team_data = team_response.json()
|
|
||||||
team_id = team_data["team_id"]
|
|
||||||
|
|
||||||
# Update team with model alias
|
|
||||||
update_response = await client.post(
|
|
||||||
"/team/update",
|
|
||||||
json={"team_id": team_id, "model_aliases": {"gpt-4o": "gpt-4o-team1"}},
|
|
||||||
headers=headers,
|
|
||||||
)
|
|
||||||
assert update_response.status_code == 200
|
|
||||||
|
|
||||||
# Generate key for team
|
|
||||||
key_response = await client.post(
|
|
||||||
"/key/generate", json={"team_id": team_id}, headers=headers
|
|
||||||
)
|
|
||||||
assert key_response.status_code == 200
|
|
||||||
key = key_response.json()["key"]
|
|
||||||
|
|
||||||
# Make request with model alias
|
|
||||||
openai_client = AsyncOpenAI(api_key=key, base_url=f"{PROXY_BASE_URL}/v1")
|
|
||||||
|
|
||||||
response = await openai_client.chat.completions.create(
|
|
||||||
model="gpt-4o",
|
|
||||||
messages=[{"role": "user", "content": f"Test message {uuid.uuid4()}"}],
|
|
||||||
)
|
|
||||||
|
|
||||||
assert response is not None, "Should get valid response when using model alias"
|
|
||||||
|
|
||||||
# Cleanup - delete the model
|
|
||||||
model_id = model_response.json()["model_info"]["id"]
|
|
||||||
delete_response = await client.post(
|
|
||||||
"/model/delete",
|
|
||||||
json={"id": model_id},
|
|
||||||
headers={"Authorization": f"Bearer {TEST_MASTER_KEY}"},
|
|
||||||
)
|
|
||||||
assert delete_response.status_code == 200
|
|
312
tests/store_model_in_db_tests/test_team_models.py
Normal file
312
tests/store_model_in_db_tests/test_team_models.py
Normal file
|
@ -0,0 +1,312 @@
|
||||||
|
import pytest
|
||||||
|
import asyncio
|
||||||
|
import aiohttp
|
||||||
|
import json
|
||||||
|
from openai import AsyncOpenAI
|
||||||
|
import uuid
|
||||||
|
from httpx import AsyncClient
|
||||||
|
import uuid
|
||||||
|
import os
|
||||||
|
|
||||||
|
TEST_MASTER_KEY = "sk-1234"
|
||||||
|
PROXY_BASE_URL = "http://0.0.0.0:4000"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_team_model_alias():
|
||||||
|
"""
|
||||||
|
Test model alias functionality with teams:
|
||||||
|
1. Add a new model with model_name="gpt-4-team1" and litellm_params.model="gpt-4o"
|
||||||
|
2. Create a new team
|
||||||
|
3. Update team with model_alias mapping
|
||||||
|
4. Generate key for team
|
||||||
|
5. Make request with aliased model name
|
||||||
|
"""
|
||||||
|
client = AsyncClient(base_url=PROXY_BASE_URL)
|
||||||
|
headers = {"Authorization": f"Bearer {TEST_MASTER_KEY}"}
|
||||||
|
|
||||||
|
# Add new model
|
||||||
|
model_response = await client.post(
|
||||||
|
"/model/new",
|
||||||
|
json={
|
||||||
|
"model_name": "gpt-4o-team1",
|
||||||
|
"litellm_params": {
|
||||||
|
"model": "gpt-4o",
|
||||||
|
"api_key": os.getenv("OPENAI_API_KEY"),
|
||||||
|
},
|
||||||
|
},
|
||||||
|
headers=headers,
|
||||||
|
)
|
||||||
|
assert model_response.status_code == 200
|
||||||
|
|
||||||
|
# Create new team
|
||||||
|
team_response = await client.post(
|
||||||
|
"/team/new",
|
||||||
|
json={
|
||||||
|
"models": ["gpt-4o-team1"],
|
||||||
|
},
|
||||||
|
headers=headers,
|
||||||
|
)
|
||||||
|
assert team_response.status_code == 200
|
||||||
|
team_data = team_response.json()
|
||||||
|
team_id = team_data["team_id"]
|
||||||
|
|
||||||
|
# Update team with model alias
|
||||||
|
update_response = await client.post(
|
||||||
|
"/team/update",
|
||||||
|
json={"team_id": team_id, "model_aliases": {"gpt-4o": "gpt-4o-team1"}},
|
||||||
|
headers=headers,
|
||||||
|
)
|
||||||
|
assert update_response.status_code == 200
|
||||||
|
|
||||||
|
# Generate key for team
|
||||||
|
key_response = await client.post(
|
||||||
|
"/key/generate", json={"team_id": team_id}, headers=headers
|
||||||
|
)
|
||||||
|
assert key_response.status_code == 200
|
||||||
|
key = key_response.json()["key"]
|
||||||
|
|
||||||
|
# Make request with model alias
|
||||||
|
openai_client = AsyncOpenAI(api_key=key, base_url=f"{PROXY_BASE_URL}/v1")
|
||||||
|
|
||||||
|
response = await openai_client.chat.completions.create(
|
||||||
|
model="gpt-4o",
|
||||||
|
messages=[{"role": "user", "content": f"Test message {uuid.uuid4()}"}],
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response is not None, "Should get valid response when using model alias"
|
||||||
|
|
||||||
|
# Cleanup - delete the model
|
||||||
|
model_id = model_response.json()["model_info"]["id"]
|
||||||
|
delete_response = await client.post(
|
||||||
|
"/model/delete",
|
||||||
|
json={"id": model_id},
|
||||||
|
headers={"Authorization": f"Bearer {TEST_MASTER_KEY}"},
|
||||||
|
)
|
||||||
|
assert delete_response.status_code == 200
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_team_model_association():
|
||||||
|
"""
|
||||||
|
Test that models created with a team_id are properly associated with the team:
|
||||||
|
1. Create a new team
|
||||||
|
2. Add a model with team_id in model_info
|
||||||
|
3. Verify the model appears in team info
|
||||||
|
"""
|
||||||
|
client = AsyncClient(base_url=PROXY_BASE_URL)
|
||||||
|
headers = {"Authorization": f"Bearer {TEST_MASTER_KEY}"}
|
||||||
|
|
||||||
|
# Create new team
|
||||||
|
team_response = await client.post(
|
||||||
|
"/team/new",
|
||||||
|
json={
|
||||||
|
"models": [], # Start with empty model list
|
||||||
|
},
|
||||||
|
headers=headers,
|
||||||
|
)
|
||||||
|
assert team_response.status_code == 200
|
||||||
|
team_data = team_response.json()
|
||||||
|
team_id = team_data["team_id"]
|
||||||
|
|
||||||
|
# Add new model with team_id
|
||||||
|
model_response = await client.post(
|
||||||
|
"/model/new",
|
||||||
|
json={
|
||||||
|
"model_name": "gpt-4-team-test",
|
||||||
|
"litellm_params": {
|
||||||
|
"model": "gpt-4",
|
||||||
|
"custom_llm_provider": "openai",
|
||||||
|
"api_key": "fake_key",
|
||||||
|
},
|
||||||
|
"model_info": {"team_id": team_id},
|
||||||
|
},
|
||||||
|
headers=headers,
|
||||||
|
)
|
||||||
|
assert model_response.status_code == 200
|
||||||
|
|
||||||
|
# Get team info and verify model association
|
||||||
|
team_info_response = await client.get(
|
||||||
|
f"/team/info",
|
||||||
|
headers=headers,
|
||||||
|
params={"team_id": team_id},
|
||||||
|
)
|
||||||
|
assert team_info_response.status_code == 200
|
||||||
|
team_info = team_info_response.json()["team_info"]
|
||||||
|
|
||||||
|
print("team_info", json.dumps(team_info, indent=4))
|
||||||
|
|
||||||
|
# Verify the model is in team_models
|
||||||
|
assert (
|
||||||
|
"gpt-4-team-test" in team_info["models"]
|
||||||
|
), "Model should be associated with team"
|
||||||
|
|
||||||
|
# Cleanup - delete the model
|
||||||
|
model_id = model_response.json()["model_info"]["id"]
|
||||||
|
delete_response = await client.post(
|
||||||
|
"/model/delete",
|
||||||
|
json={"id": model_id},
|
||||||
|
headers=headers,
|
||||||
|
)
|
||||||
|
assert delete_response.status_code == 200
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_team_model_visibility_in_models_endpoint():
|
||||||
|
"""
|
||||||
|
Test that team-specific models are only visible to the correct team in /models endpoint:
|
||||||
|
1. Create two teams
|
||||||
|
2. Add a model associated with team1
|
||||||
|
3. Generate keys for both teams
|
||||||
|
4. Verify team1's key can see the model in /models
|
||||||
|
5. Verify team2's key cannot see the model in /models
|
||||||
|
"""
|
||||||
|
client = AsyncClient(base_url=PROXY_BASE_URL)
|
||||||
|
headers = {"Authorization": f"Bearer {TEST_MASTER_KEY}"}
|
||||||
|
|
||||||
|
# Create team1
|
||||||
|
team1_response = await client.post(
|
||||||
|
"/team/new",
|
||||||
|
json={"models": []},
|
||||||
|
headers=headers,
|
||||||
|
)
|
||||||
|
assert team1_response.status_code == 200
|
||||||
|
team1_id = team1_response.json()["team_id"]
|
||||||
|
|
||||||
|
# Create team2
|
||||||
|
team2_response = await client.post(
|
||||||
|
"/team/new",
|
||||||
|
json={"models": []},
|
||||||
|
headers=headers,
|
||||||
|
)
|
||||||
|
assert team2_response.status_code == 200
|
||||||
|
team2_id = team2_response.json()["team_id"]
|
||||||
|
|
||||||
|
# Add model associated with team1
|
||||||
|
model_response = await client.post(
|
||||||
|
"/model/new",
|
||||||
|
json={
|
||||||
|
"model_name": "gpt-4-team-test",
|
||||||
|
"litellm_params": {
|
||||||
|
"model": "gpt-4",
|
||||||
|
"custom_llm_provider": "openai",
|
||||||
|
"api_key": "fake_key",
|
||||||
|
},
|
||||||
|
"model_info": {"team_id": team1_id},
|
||||||
|
},
|
||||||
|
headers=headers,
|
||||||
|
)
|
||||||
|
assert model_response.status_code == 200
|
||||||
|
|
||||||
|
# Generate keys for both teams
|
||||||
|
team1_key = (
|
||||||
|
await client.post("/key/generate", json={"team_id": team1_id}, headers=headers)
|
||||||
|
).json()["key"]
|
||||||
|
team2_key = (
|
||||||
|
await client.post("/key/generate", json={"team_id": team2_id}, headers=headers)
|
||||||
|
).json()["key"]
|
||||||
|
|
||||||
|
# Check models visibility for team1's key
|
||||||
|
team1_models = await client.get(
|
||||||
|
"/models", headers={"Authorization": f"Bearer {team1_key}"}
|
||||||
|
)
|
||||||
|
assert team1_models.status_code == 200
|
||||||
|
print("team1_models", json.dumps(team1_models.json(), indent=4))
|
||||||
|
assert any(
|
||||||
|
model["id"] == "gpt-4-team-test" for model in team1_models.json()["data"]
|
||||||
|
), "Team1 should see their model"
|
||||||
|
|
||||||
|
# Check models visibility for team2's key
|
||||||
|
team2_models = await client.get(
|
||||||
|
"/models", headers={"Authorization": f"Bearer {team2_key}"}
|
||||||
|
)
|
||||||
|
assert team2_models.status_code == 200
|
||||||
|
print("team2_models", json.dumps(team2_models.json(), indent=4))
|
||||||
|
assert not any(
|
||||||
|
model["id"] == "gpt-4-team-test" for model in team2_models.json()["data"]
|
||||||
|
), "Team2 should not see team1's model"
|
||||||
|
|
||||||
|
# Cleanup
|
||||||
|
model_id = model_response.json()["model_info"]["id"]
|
||||||
|
await client.post("/model/delete", json={"id": model_id}, headers=headers)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_team_model_visibility_in_model_info_endpoint():
|
||||||
|
"""
|
||||||
|
Test that team-specific models are visible to all users in /v2/model/info endpoint:
|
||||||
|
Note: /v2/model/info is used by the Admin UI to display model info
|
||||||
|
1. Create a team
|
||||||
|
2. Add a model associated with the team
|
||||||
|
3. Generate a team key
|
||||||
|
4. Verify both team key and non-team key can see the model in /v2/model/info
|
||||||
|
"""
|
||||||
|
client = AsyncClient(base_url=PROXY_BASE_URL)
|
||||||
|
headers = {"Authorization": f"Bearer {TEST_MASTER_KEY}"}
|
||||||
|
|
||||||
|
# Create team
|
||||||
|
team_response = await client.post(
|
||||||
|
"/team/new",
|
||||||
|
json={"models": []},
|
||||||
|
headers=headers,
|
||||||
|
)
|
||||||
|
assert team_response.status_code == 200
|
||||||
|
team_id = team_response.json()["team_id"]
|
||||||
|
|
||||||
|
# Add model associated with team
|
||||||
|
model_response = await client.post(
|
||||||
|
"/model/new",
|
||||||
|
json={
|
||||||
|
"model_name": "gpt-4-team-test",
|
||||||
|
"litellm_params": {
|
||||||
|
"model": "gpt-4",
|
||||||
|
"custom_llm_provider": "openai",
|
||||||
|
"api_key": "fake_key",
|
||||||
|
},
|
||||||
|
"model_info": {"team_id": team_id},
|
||||||
|
},
|
||||||
|
headers=headers,
|
||||||
|
)
|
||||||
|
assert model_response.status_code == 200
|
||||||
|
|
||||||
|
# Generate team key
|
||||||
|
team_key = (
|
||||||
|
await client.post("/key/generate", json={"team_id": team_id}, headers=headers)
|
||||||
|
).json()["key"]
|
||||||
|
|
||||||
|
# Generate non-team key
|
||||||
|
non_team_key = (
|
||||||
|
await client.post("/key/generate", json={}, headers=headers)
|
||||||
|
).json()["key"]
|
||||||
|
|
||||||
|
# Check model info visibility with team key
|
||||||
|
team_model_info = await client.get(
|
||||||
|
"/v2/model/info",
|
||||||
|
headers={"Authorization": f"Bearer {team_key}"},
|
||||||
|
params={"model_name": "gpt-4-team-test"},
|
||||||
|
)
|
||||||
|
assert team_model_info.status_code == 200
|
||||||
|
team_model_info = team_model_info.json()
|
||||||
|
print("Team 1 model info", json.dumps(team_model_info, indent=4))
|
||||||
|
assert any(
|
||||||
|
model["model_info"].get("team_public_model_name") == "gpt-4-team-test"
|
||||||
|
for model in team_model_info["data"]
|
||||||
|
), "Team1 should see their model"
|
||||||
|
|
||||||
|
# Check model info visibility with non-team key
|
||||||
|
non_team_model_info = await client.get(
|
||||||
|
"/v2/model/info",
|
||||||
|
headers={"Authorization": f"Bearer {non_team_key}"},
|
||||||
|
params={"model_name": "gpt-4-team-test"},
|
||||||
|
)
|
||||||
|
assert non_team_model_info.status_code == 200
|
||||||
|
non_team_model_info = non_team_model_info.json()
|
||||||
|
print("Non-team model info", json.dumps(non_team_model_info, indent=4))
|
||||||
|
assert any(
|
||||||
|
model["model_info"].get("team_public_model_name") == "gpt-4-team-test"
|
||||||
|
for model in non_team_model_info["data"]
|
||||||
|
), "Non-team should see the model"
|
||||||
|
|
||||||
|
# Cleanup
|
||||||
|
model_id = model_response.json()["model_info"]["id"]
|
||||||
|
await client.post("/model/delete", json={"id": model_id}, headers=headers)
|
|
@ -293,7 +293,13 @@ const CreateKey: React.FC<CreateKeyProps> = ({
|
||||||
initialValue={team ? team.team_id : null}
|
initialValue={team ? team.team_id : null}
|
||||||
className="mt-8"
|
className="mt-8"
|
||||||
>
|
>
|
||||||
<TeamDropdown teams={teams} />
|
<TeamDropdown
|
||||||
|
teams={teams}
|
||||||
|
onChange={(teamId) => {
|
||||||
|
const selectedTeam = teams?.find(t => t.team_id === teamId) || null;
|
||||||
|
setSelectedCreateKeyTeam(selectedTeam);
|
||||||
|
}}
|
||||||
|
/>
|
||||||
</Form.Item>
|
</Form.Item>
|
||||||
|
|
||||||
<Form.Item
|
<Form.Item
|
||||||
|
|
|
@ -1,174 +0,0 @@
|
||||||
import React, { useState } from "react";
|
|
||||||
import {
|
|
||||||
Card,
|
|
||||||
Title,
|
|
||||||
Text,
|
|
||||||
Button as TremorButton,
|
|
||||||
} from "@tremor/react";
|
|
||||||
import { Modal, Form, Select, Input, message } from "antd";
|
|
||||||
import { teamUpdateCall } from "@/components/networking";
|
|
||||||
|
|
||||||
interface ModelAliasesCardProps {
|
|
||||||
teamId: string;
|
|
||||||
accessToken: string | null;
|
|
||||||
currentAliases: Record<string, string>;
|
|
||||||
availableModels: string[];
|
|
||||||
onUpdate: () => void;
|
|
||||||
}
|
|
||||||
|
|
||||||
const ModelAliasesCard: React.FC<ModelAliasesCardProps> = ({
|
|
||||||
teamId,
|
|
||||||
accessToken,
|
|
||||||
currentAliases,
|
|
||||||
availableModels,
|
|
||||||
onUpdate,
|
|
||||||
}) => {
|
|
||||||
const [isModalVisible, setIsModalVisible] = useState(false);
|
|
||||||
const [form] = Form.useForm();
|
|
||||||
|
|
||||||
const handleCreateAlias = async (values: any) => {
|
|
||||||
try {
|
|
||||||
if (!accessToken) return;
|
|
||||||
|
|
||||||
const newAliases = {
|
|
||||||
...currentAliases,
|
|
||||||
[values.alias_name]: values.original_model,
|
|
||||||
};
|
|
||||||
|
|
||||||
const updateData = {
|
|
||||||
team_id: teamId,
|
|
||||||
model_aliases: newAliases,
|
|
||||||
};
|
|
||||||
|
|
||||||
await teamUpdateCall(accessToken, updateData);
|
|
||||||
message.success("Model alias created successfully");
|
|
||||||
setIsModalVisible(false);
|
|
||||||
form.resetFields();
|
|
||||||
currentAliases[values.alias_name] = values.original_model;
|
|
||||||
} catch (error) {
|
|
||||||
message.error("Failed to create model alias");
|
|
||||||
console.error("Error creating model alias:", error);
|
|
||||||
}
|
|
||||||
};
|
|
||||||
|
|
||||||
return (
|
|
||||||
<div className="mt-8">
|
|
||||||
<Title>Team Aliases</Title>
|
|
||||||
<Text className="text-gray-600 mb-4">
|
|
||||||
Allow a team to use an alias that points to a specific model deployment.
|
|
||||||
|
|
||||||
</Text>
|
|
||||||
|
|
||||||
<div className="bg-white rounded-lg p-6 border border-gray-200">
|
|
||||||
<div className="flex justify-between items-center mb-6">
|
|
||||||
<div>
|
|
||||||
<div className="flex space-x-4 text-gray-600">
|
|
||||||
<div className="w-64">ALIAS</div>
|
|
||||||
<div>POINTS TO</div>
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
<TremorButton
|
|
||||||
size="md"
|
|
||||||
variant="primary"
|
|
||||||
onClick={() => setIsModalVisible(true)}
|
|
||||||
>
|
|
||||||
Create Model Alias
|
|
||||||
</TremorButton>
|
|
||||||
</div>
|
|
||||||
|
|
||||||
<div className="space-y-4">
|
|
||||||
{Object.entries(currentAliases).map(([aliasName, originalModel], index) => (
|
|
||||||
<div key={index} className="flex space-x-4 border-t border-gray-200 pt-4">
|
|
||||||
<div className="w-64">
|
|
||||||
<span className="bg-gray-100 px-2 py-1 rounded font-mono text-sm text-gray-700">
|
|
||||||
{aliasName}
|
|
||||||
</span>
|
|
||||||
</div>
|
|
||||||
<div>
|
|
||||||
<span className="bg-gray-100 px-2 py-1 rounded font-mono text-sm text-gray-700">
|
|
||||||
{originalModel}
|
|
||||||
</span>
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
))}
|
|
||||||
{Object.keys(currentAliases).length === 0 && (
|
|
||||||
<div className="text-gray-500 text-center py-4 border-t border-gray-200">
|
|
||||||
No model aliases configured
|
|
||||||
</div>
|
|
||||||
)}
|
|
||||||
</div>
|
|
||||||
</div>
|
|
||||||
|
|
||||||
<Modal
|
|
||||||
title="Create Model Alias"
|
|
||||||
open={isModalVisible}
|
|
||||||
onCancel={() => {
|
|
||||||
setIsModalVisible(false);
|
|
||||||
form.resetFields();
|
|
||||||
}}
|
|
||||||
footer={null}
|
|
||||||
width={500}
|
|
||||||
>
|
|
||||||
<Form
|
|
||||||
form={form}
|
|
||||||
onFinish={handleCreateAlias}
|
|
||||||
layout="vertical"
|
|
||||||
className="mt-4"
|
|
||||||
>
|
|
||||||
<Form.Item
|
|
||||||
label="Alias Name"
|
|
||||||
name="alias_name"
|
|
||||||
rules={[{ required: true, message: "Please enter an alias name" }]}
|
|
||||||
>
|
|
||||||
<Input
|
|
||||||
placeholder="Enter the model alias (e.g., gpt-4o)"
|
|
||||||
type=""
|
|
||||||
/>
|
|
||||||
</Form.Item>
|
|
||||||
|
|
||||||
<Form.Item
|
|
||||||
label="Points To"
|
|
||||||
name="original_model"
|
|
||||||
rules={[{ required: true, message: "Please select a model" }]}
|
|
||||||
>
|
|
||||||
<Select
|
|
||||||
placeholder="Select model version"
|
|
||||||
className="w-full font-mono"
|
|
||||||
showSearch
|
|
||||||
optionFilterProp="children"
|
|
||||||
>
|
|
||||||
{availableModels.map((model) => (
|
|
||||||
<Select.Option key={model} value={model} className="font-mono">
|
|
||||||
{model}
|
|
||||||
</Select.Option>
|
|
||||||
))}
|
|
||||||
</Select>
|
|
||||||
</Form.Item>
|
|
||||||
|
|
||||||
<div className="flex justify-end gap-2 mt-6">
|
|
||||||
<TremorButton
|
|
||||||
size="md"
|
|
||||||
variant="secondary"
|
|
||||||
onClick={() => {
|
|
||||||
setIsModalVisible(false);
|
|
||||||
form.resetFields();
|
|
||||||
}}
|
|
||||||
className="bg-white text-gray-700 border border-gray-300 hover:bg-gray-50"
|
|
||||||
>
|
|
||||||
Cancel
|
|
||||||
</TremorButton>
|
|
||||||
<TremorButton
|
|
||||||
size="md"
|
|
||||||
variant="secondary"
|
|
||||||
type="submit"
|
|
||||||
>
|
|
||||||
Create Alias
|
|
||||||
</TremorButton>
|
|
||||||
</div>
|
|
||||||
</Form>
|
|
||||||
</Modal>
|
|
||||||
</div>
|
|
||||||
);
|
|
||||||
};
|
|
||||||
|
|
||||||
export default ModelAliasesCard;
|
|
|
@ -474,14 +474,6 @@ const TeamInfoView: React.FC<TeamInfoProps> = ({
|
||||||
</div>
|
</div>
|
||||||
)}
|
)}
|
||||||
</Card>
|
</Card>
|
||||||
|
|
||||||
<ModelAliasesCard
|
|
||||||
teamId={teamId}
|
|
||||||
accessToken={accessToken}
|
|
||||||
currentAliases={teamData?.team_info?.litellm_model_table?.model_aliases || {}}
|
|
||||||
availableModels={userModels}
|
|
||||||
onUpdate={fetchTeamInfo}
|
|
||||||
/>
|
|
||||||
</TabPanel>
|
</TabPanel>
|
||||||
</TabPanels>
|
</TabPanels>
|
||||||
</TabGroup>
|
</TabGroup>
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue