fix(proxy_server.py): fix model check for /v1/models endpoint when team has restricted access

This commit is contained in:
Krrish Dholakia 2024-05-25 13:02:03 -07:00
parent 3c961136ea
commit 25a2f00db6
5 changed files with 134 additions and 36 deletions

View file

@ -2,7 +2,7 @@
## Tests /key endpoints.
import pytest
import asyncio, time
import asyncio, time, uuid
import aiohttp
from openai import AsyncOpenAI
import sys, os
@ -14,12 +14,14 @@ sys.path.insert(
import litellm
async def generate_team(session):
async def generate_team(
session, models: Optional[list] = None, team_id: Optional[str] = None
):
url = "http://0.0.0.0:4000/team/new"
headers = {"Authorization": "Bearer sk-1234", "Content-Type": "application/json"}
data = {
"team_id": "litellm-dashboard",
}
if team_id is None:
team_id = "litellm-dashboard"
data = {"team_id": team_id, "models": models}
async with session.post(url, headers=headers, json=data) as response:
status = response.status
@ -746,19 +748,25 @@ async def test_key_delete_ui():
@pytest.mark.parametrize("model_access", ["all-team-models", "gpt-3.5-turbo"])
@pytest.mark.parametrize("model_access_level", ["key", "team"])
@pytest.mark.asyncio
async def test_key_model_list(model_access):
async def test_key_model_list(model_access, model_access_level):
"""
Test if `/v1/models` works as expected.
"""
async with aiohttp.ClientSession() as session:
new_team = await generate_team(session=session)
team_id = "litellm-dashboard"
_models = [] if model_access == "all-team-models" else [model_access]
team_id = "litellm_dashboard_{}".format(uuid.uuid4())
new_team = await generate_team(
session=session,
models=_models if model_access_level == "team" else None,
team_id=team_id,
)
key_gen = await generate_key(
session=session,
i=0,
team_id=team_id,
models=[] if model_access == "all-team-models" else [model_access],
models=_models if model_access_level == "key" else [],
)
key = key_gen["key"]
print(f"key: {key}")
@ -770,5 +778,9 @@ async def test_key_model_list(model_access):
assert not isinstance(model_list["data"][0]["id"], list)
assert isinstance(model_list["data"][0]["id"], str)
if model_access == "gpt-3.5-turbo":
assert len(model_list["data"]) == 1
assert (
len(model_list["data"]) == 1
), "model_access={}, model_access_level={}".format(
model_access, model_access_level
)
assert model_list["data"][0]["id"] == model_access