mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-24 18:24:20 +00:00
fix(model_management_endpoints.py): fix allowing team admins to update team models (#9697)
* fix(model_management_endpoints.py): fix allowing team admins to update their models * test(test_models.py): add e2e test to for team model flow ensure team admin can always add / edit / delete team models
This commit is contained in:
parent
3d0313b15b
commit
6c69ad4c89
3 changed files with 115 additions and 33 deletions
File diff suppressed because one or more lines are too long
|
@ -394,7 +394,7 @@ class ModelManagementAuthChecks:
|
|||
|
||||
@staticmethod
|
||||
async def can_user_make_model_call(
|
||||
model_params: Union[Deployment, updateDeployment],
|
||||
model_params: Deployment,
|
||||
user_api_key_dict: UserAPIKeyAuth,
|
||||
prisma_client: PrismaClient,
|
||||
premium_user: bool,
|
||||
|
@ -723,8 +723,38 @@ async def update_model(
|
|||
},
|
||||
)
|
||||
|
||||
_model_id = None
|
||||
_model_info = getattr(model_params, "model_info", None)
|
||||
if _model_info is None:
|
||||
raise Exception("model_info not provided")
|
||||
|
||||
_model_id = _model_info.id
|
||||
if _model_id is None:
|
||||
raise Exception("model_info.id not provided")
|
||||
|
||||
_existing_litellm_params = (
|
||||
await prisma_client.db.litellm_proxymodeltable.find_unique(
|
||||
where={"model_id": _model_id}
|
||||
)
|
||||
)
|
||||
|
||||
if _existing_litellm_params is None:
|
||||
if (
|
||||
llm_router is not None
|
||||
and llm_router.get_deployment(model_id=_model_id) is not None
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail={
|
||||
"error": "Can't edit model. Model in config. Store model in db via `/model/new`. to edit."
|
||||
},
|
||||
)
|
||||
else:
|
||||
raise Exception("model not found")
|
||||
deployment = Deployment(**_existing_litellm_params.model_dump())
|
||||
|
||||
await ModelManagementAuthChecks.can_user_make_model_call(
|
||||
model_params=model_params,
|
||||
model_params=deployment,
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
prisma_client=prisma_client,
|
||||
premium_user=premium_user,
|
||||
|
@ -732,31 +762,6 @@ async def update_model(
|
|||
|
||||
# update DB
|
||||
if store_model_in_db is True:
|
||||
_model_id = None
|
||||
_model_info = getattr(model_params, "model_info", None)
|
||||
if _model_info is None:
|
||||
raise Exception("model_info not provided")
|
||||
|
||||
_model_id = _model_info.id
|
||||
if _model_id is None:
|
||||
raise Exception("model_info.id not provided")
|
||||
_existing_litellm_params = (
|
||||
await prisma_client.db.litellm_proxymodeltable.find_unique(
|
||||
where={"model_id": _model_id}
|
||||
)
|
||||
)
|
||||
if _existing_litellm_params is None:
|
||||
if (
|
||||
llm_router is not None
|
||||
and llm_router.get_deployment(model_id=_model_id) is not None
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail={
|
||||
"error": "Can't edit model. Model in config. Store model in db via `/model/new`. to edit."
|
||||
},
|
||||
)
|
||||
raise Exception("model not found")
|
||||
_existing_litellm_params_dict = dict(
|
||||
_existing_litellm_params.litellm_params
|
||||
)
|
||||
|
|
|
@ -10,7 +10,6 @@ from dotenv import load_dotenv
|
|||
|
||||
load_dotenv()
|
||||
|
||||
|
||||
async def generate_key(session, models=[]):
|
||||
url = "http://0.0.0.0:4000/key/generate"
|
||||
headers = {"Authorization": "Bearer sk-1234", "Content-Type": "application/json"}
|
||||
|
@ -58,10 +57,43 @@ async def test_get_models():
|
|||
await get_models(session=session, key=key)
|
||||
|
||||
|
||||
async def add_models(session, model_id="123", model_name="azure-gpt-3.5"):
|
||||
async def add_models(session, model_id="123", model_name="azure-gpt-3.5", key="sk-1234", team_id=None):
|
||||
url = "http://0.0.0.0:4000/model/new"
|
||||
headers = {
|
||||
"Authorization": f"Bearer sk-1234",
|
||||
"Authorization": f"Bearer {key}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
data = {
|
||||
"model_name": model_name,
|
||||
"litellm_params": {
|
||||
"model": "azure/chatgpt-v-2",
|
||||
"api_key": "os.environ/AZURE_API_KEY",
|
||||
"api_base": "https://openai-gpt-4-test-v-1.openai.azure.com/",
|
||||
"api_version": "2023-05-15",
|
||||
},
|
||||
"model_info": {"id": model_id},
|
||||
}
|
||||
|
||||
if team_id:
|
||||
data["model_info"]["team_id"] = team_id
|
||||
|
||||
async with session.post(url, headers=headers, json=data) as response:
|
||||
status = response.status
|
||||
response_text = await response.text()
|
||||
print(f"Add models {response_text}")
|
||||
print()
|
||||
|
||||
if status != 200:
|
||||
raise Exception(f"Request did not return a 200 status code: {status}")
|
||||
|
||||
response_json = await response.json()
|
||||
return response_json
|
||||
|
||||
async def update_model(session, model_id="123", model_name="azure-gpt-3.5", key="sk-1234"):
|
||||
url = "http://0.0.0.0:4000/model/update"
|
||||
headers = {
|
||||
"Authorization": f"Bearer {key}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
|
||||
|
@ -199,13 +231,13 @@ async def test_get_specific_model():
|
|||
)
|
||||
|
||||
|
||||
async def delete_model(session, model_id="123"):
|
||||
async def delete_model(session, model_id="123", key="sk-1234"):
|
||||
"""
|
||||
Make sure only models user has access to are returned
|
||||
"""
|
||||
url = "http://0.0.0.0:4000/model/delete"
|
||||
headers = {
|
||||
"Authorization": f"Bearer sk-1234",
|
||||
"Authorization": f"Bearer {key}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
data = {"id": model_id}
|
||||
|
@ -441,3 +473,49 @@ async def test_model_group_info_e2e():
|
|||
has_anthropic_claude_3_opus = True
|
||||
|
||||
assert has_anthropic_claude_3_5_haiku and has_anthropic_claude_3_opus
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_team_model_e2e():
|
||||
"""
|
||||
Test team model e2e
|
||||
|
||||
- create team
|
||||
- create user
|
||||
- add user to team as admin
|
||||
- add model to team
|
||||
- update model
|
||||
- delete model
|
||||
"""
|
||||
from test_users import new_user
|
||||
from test_team import new_team
|
||||
import uuid
|
||||
async with aiohttp.ClientSession() as session:
|
||||
# Creat a user
|
||||
user_data = await new_user(session=session, i=0)
|
||||
user_id = user_data["user_id"]
|
||||
user_api_key = user_data["key"]
|
||||
|
||||
# Create a team
|
||||
member_list = [
|
||||
{"role": "admin", "user_id": user_id},
|
||||
]
|
||||
team_data = await new_team(session=session, member_list=member_list, i=0)
|
||||
team_id = team_data["team_id"]
|
||||
|
||||
model_id = str(uuid.uuid4())
|
||||
model_name = "my-test-model"
|
||||
# Add model to team
|
||||
model_data = await add_models(session=session, model_id=model_id, model_name=model_name, key=user_api_key, team_id=team_id)
|
||||
model_id = model_data["model_id"]
|
||||
|
||||
# Update model
|
||||
model_data = await update_model(session=session, model_id=model_id, model_name=model_name, key=user_api_key)
|
||||
model_id = model_data["model_id"]
|
||||
|
||||
# Delete model
|
||||
await delete_model(session=session, model_id=model_id, key=user_api_key)
|
||||
|
||||
|
||||
|
||||
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue