forked from phoenix/litellm-mirror
fix(proxy_server.py): fix model check for /v1/models
endpoint when team has restricted access
This commit is contained in:
parent
3c961136ea
commit
25a2f00db6
5 changed files with 134 additions and 36 deletions
|
@ -2,7 +2,7 @@
|
|||
## Tests /key endpoints.
|
||||
|
||||
import pytest
|
||||
import asyncio, time
|
||||
import asyncio, time, uuid
|
||||
import aiohttp
|
||||
from openai import AsyncOpenAI
|
||||
import sys, os
|
||||
|
@ -14,12 +14,14 @@ sys.path.insert(
|
|||
import litellm
|
||||
|
||||
|
||||
async def generate_team(session):
|
||||
async def generate_team(
|
||||
session, models: Optional[list] = None, team_id: Optional[str] = None
|
||||
):
|
||||
url = "http://0.0.0.0:4000/team/new"
|
||||
headers = {"Authorization": "Bearer sk-1234", "Content-Type": "application/json"}
|
||||
data = {
|
||||
"team_id": "litellm-dashboard",
|
||||
}
|
||||
if team_id is None:
|
||||
team_id = "litellm-dashboard"
|
||||
data = {"team_id": team_id, "models": models}
|
||||
|
||||
async with session.post(url, headers=headers, json=data) as response:
|
||||
status = response.status
|
||||
|
@ -746,19 +748,25 @@ async def test_key_delete_ui():
|
|||
|
||||
|
||||
@pytest.mark.parametrize("model_access", ["all-team-models", "gpt-3.5-turbo"])
|
||||
@pytest.mark.parametrize("model_access_level", ["key", "team"])
|
||||
@pytest.mark.asyncio
|
||||
async def test_key_model_list(model_access):
|
||||
async def test_key_model_list(model_access, model_access_level):
|
||||
"""
|
||||
Test if `/v1/models` works as expected.
|
||||
"""
|
||||
async with aiohttp.ClientSession() as session:
|
||||
new_team = await generate_team(session=session)
|
||||
team_id = "litellm-dashboard"
|
||||
_models = [] if model_access == "all-team-models" else [model_access]
|
||||
team_id = "litellm_dashboard_{}".format(uuid.uuid4())
|
||||
new_team = await generate_team(
|
||||
session=session,
|
||||
models=_models if model_access_level == "team" else None,
|
||||
team_id=team_id,
|
||||
)
|
||||
key_gen = await generate_key(
|
||||
session=session,
|
||||
i=0,
|
||||
team_id=team_id,
|
||||
models=[] if model_access == "all-team-models" else [model_access],
|
||||
models=_models if model_access_level == "key" else [],
|
||||
)
|
||||
key = key_gen["key"]
|
||||
print(f"key: {key}")
|
||||
|
@ -770,5 +778,9 @@ async def test_key_model_list(model_access):
|
|||
assert not isinstance(model_list["data"][0]["id"], list)
|
||||
assert isinstance(model_list["data"][0]["id"], str)
|
||||
if model_access == "gpt-3.5-turbo":
|
||||
assert len(model_list["data"]) == 1
|
||||
assert (
|
||||
len(model_list["data"]) == 1
|
||||
), "model_access={}, model_access_level={}".format(
|
||||
model_access, model_access_level
|
||||
)
|
||||
assert model_list["data"][0]["id"] == model_access
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue