fix(proxy_server.py): fix model list returned for /model/info when team has restricted access

This commit is contained in:
Krrish Dholakia 2024-05-25 13:21:33 -07:00
parent 25a2f00db6
commit 95566dc249
2 changed files with 46 additions and 16 deletions

View file

@ -359,11 +359,11 @@ async def get_key_info(session, call_key, get_key=None):
return await response.json()
async def get_model_list(session, call_key):
async def get_model_list(session, call_key, endpoint: str = "/v1/models"):
"""
Make sure only models user has access to are returned
"""
url = "http://0.0.0.0:4000/v1/models"
url = "http://0.0.0.0:4000" + endpoint
headers = {
"Authorization": f"Bearer {call_key}",
"Content-Type": "application/json",
@ -749,8 +749,9 @@ async def test_key_delete_ui():
@pytest.mark.parametrize("model_access", ["all-team-models", "gpt-3.5-turbo"])
@pytest.mark.parametrize("model_access_level", ["key", "team"])
@pytest.mark.parametrize("model_endpoint", ["/v1/models", "/model/info"])
@pytest.mark.asyncio
async def test_key_model_list(model_access, model_access_level):
async def test_key_model_list(model_access, model_access_level, model_endpoint):
"""
Test if `/v1/models` works as expected.
"""
@ -771,16 +772,26 @@ async def test_key_model_list(model_access, model_access_level):
key = key_gen["key"]
print(f"key: {key}")
model_list = await get_model_list(session=session, call_key=key)
model_list = await get_model_list(
session=session, call_key=key, endpoint=model_endpoint
)
print(f"model_list: {model_list}")
if model_access == "all-team-models":
assert not isinstance(model_list["data"][0]["id"], list)
assert isinstance(model_list["data"][0]["id"], str)
if model_endpoint == "/v1/models":
assert not isinstance(model_list["data"][0]["id"], list)
assert isinstance(model_list["data"][0]["id"], str)
elif model_endpoint == "/model/info":
assert isinstance(model_list["data"], list)
assert len(model_list["data"]) > 0
if model_access == "gpt-3.5-turbo":
assert (
len(model_list["data"]) == 1
), "model_access={}, model_access_level={}".format(
model_access, model_access_level
)
assert model_list["data"][0]["id"] == model_access
if model_endpoint == "/v1/models":
assert (
len(model_list["data"]) == 1
), "model_access={}, model_access_level={}".format(
model_access, model_access_level
)
assert model_list["data"][0]["id"] == model_access
elif model_endpoint == "/model/info":
assert isinstance(model_list["data"], list)
assert len(model_list["data"]) == 1