fix(model_management_endpoints.py): fix allowing team admins to update team models (#9697)
All checks were successful
Read Version from pyproject.toml / read-version (push) Successful in 17s
Helm unit test / unit-test (push) Successful in 22s

* fix(model_management_endpoints.py): fix allowing team admins to update their models

* test(test_models.py): add e2e test to for team model flow

ensure team admin can always add / edit / delete team models
This commit is contained in:
Krish Dholakia 2025-04-01 22:28:15 -07:00 committed by GitHub
parent 3d0313b15b
commit 6c69ad4c89
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
3 changed files with 115 additions and 33 deletions

File diff suppressed because one or more lines are too long

View file

@ -394,7 +394,7 @@ class ModelManagementAuthChecks:
@staticmethod
async def can_user_make_model_call(
model_params: Union[Deployment, updateDeployment],
model_params: Deployment,
user_api_key_dict: UserAPIKeyAuth,
prisma_client: PrismaClient,
premium_user: bool,
@ -723,8 +723,38 @@ async def update_model(
},
)
_model_id = None
_model_info = getattr(model_params, "model_info", None)
if _model_info is None:
raise Exception("model_info not provided")
_model_id = _model_info.id
if _model_id is None:
raise Exception("model_info.id not provided")
_existing_litellm_params = (
await prisma_client.db.litellm_proxymodeltable.find_unique(
where={"model_id": _model_id}
)
)
if _existing_litellm_params is None:
if (
llm_router is not None
and llm_router.get_deployment(model_id=_model_id) is not None
):
raise HTTPException(
status_code=400,
detail={
"error": "Can't edit model. Model in config. Store model in db via `/model/new`. to edit."
},
)
else:
raise Exception("model not found")
deployment = Deployment(**_existing_litellm_params.model_dump())
await ModelManagementAuthChecks.can_user_make_model_call(
model_params=model_params,
model_params=deployment,
user_api_key_dict=user_api_key_dict,
prisma_client=prisma_client,
premium_user=premium_user,
@ -732,31 +762,6 @@ async def update_model(
# update DB
if store_model_in_db is True:
_model_id = None
_model_info = getattr(model_params, "model_info", None)
if _model_info is None:
raise Exception("model_info not provided")
_model_id = _model_info.id
if _model_id is None:
raise Exception("model_info.id not provided")
_existing_litellm_params = (
await prisma_client.db.litellm_proxymodeltable.find_unique(
where={"model_id": _model_id}
)
)
if _existing_litellm_params is None:
if (
llm_router is not None
and llm_router.get_deployment(model_id=_model_id) is not None
):
raise HTTPException(
status_code=400,
detail={
"error": "Can't edit model. Model in config. Store model in db via `/model/new`. to edit."
},
)
raise Exception("model not found")
_existing_litellm_params_dict = dict(
_existing_litellm_params.litellm_params
)

View file

@ -10,7 +10,6 @@ from dotenv import load_dotenv
load_dotenv()
async def generate_key(session, models=[]):
url = "http://0.0.0.0:4000/key/generate"
headers = {"Authorization": "Bearer sk-1234", "Content-Type": "application/json"}
@ -58,10 +57,43 @@ async def test_get_models():
await get_models(session=session, key=key)
async def add_models(session, model_id="123", model_name="azure-gpt-3.5"):
async def add_models(session, model_id="123", model_name="azure-gpt-3.5", key="sk-1234", team_id=None):
url = "http://0.0.0.0:4000/model/new"
headers = {
"Authorization": f"Bearer sk-1234",
"Authorization": f"Bearer {key}",
"Content-Type": "application/json",
}
data = {
"model_name": model_name,
"litellm_params": {
"model": "azure/chatgpt-v-2",
"api_key": "os.environ/AZURE_API_KEY",
"api_base": "https://openai-gpt-4-test-v-1.openai.azure.com/",
"api_version": "2023-05-15",
},
"model_info": {"id": model_id},
}
if team_id:
data["model_info"]["team_id"] = team_id
async with session.post(url, headers=headers, json=data) as response:
status = response.status
response_text = await response.text()
print(f"Add models {response_text}")
print()
if status != 200:
raise Exception(f"Request did not return a 200 status code: {status}")
response_json = await response.json()
return response_json
async def update_model(session, model_id="123", model_name="azure-gpt-3.5", key="sk-1234"):
url = "http://0.0.0.0:4000/model/update"
headers = {
"Authorization": f"Bearer {key}",
"Content-Type": "application/json",
}
@ -199,13 +231,13 @@ async def test_get_specific_model():
)
async def delete_model(session, model_id="123"):
async def delete_model(session, model_id="123", key="sk-1234"):
"""
Make sure only models user has access to are returned
"""
url = "http://0.0.0.0:4000/model/delete"
headers = {
"Authorization": f"Bearer sk-1234",
"Authorization": f"Bearer {key}",
"Content-Type": "application/json",
}
data = {"id": model_id}
@ -441,3 +473,49 @@ async def test_model_group_info_e2e():
has_anthropic_claude_3_opus = True
assert has_anthropic_claude_3_5_haiku and has_anthropic_claude_3_opus
@pytest.mark.asyncio
async def test_team_model_e2e():
"""
Test team model e2e
- create team
- create user
- add user to team as admin
- add model to team
- update model
- delete model
"""
from test_users import new_user
from test_team import new_team
import uuid
async with aiohttp.ClientSession() as session:
# Creat a user
user_data = await new_user(session=session, i=0)
user_id = user_data["user_id"]
user_api_key = user_data["key"]
# Create a team
member_list = [
{"role": "admin", "user_id": user_id},
]
team_data = await new_team(session=session, member_list=member_list, i=0)
team_id = team_data["team_id"]
model_id = str(uuid.uuid4())
model_name = "my-test-model"
# Add model to team
model_data = await add_models(session=session, model_id=model_id, model_name=model_name, key=user_api_key, team_id=team_id)
model_id = model_data["model_id"]
# Update model
model_data = await update_model(session=session, model_id=model_id, model_name=model_name, key=user_api_key)
model_id = model_data["model_id"]
# Delete model
await delete_model(session=session, model_id=model_id, key=user_api_key)