fix(proxy_server.py): fix /v1/models bug where it would return empty list

handle 'all-team-models' being set for a given key
This commit is contained in:
Krrish Dholakia 2024-05-07 13:42:58 -07:00
parent 8e5437c8e9
commit f2766fddbf
2 changed files with 61 additions and 1 deletions

View file

@ -62,6 +62,7 @@ async def generate_key(
models=["azure-models", "gpt-4", "dall-e-3"],
max_parallel_requests: Optional[int] = None,
user_id: Optional[str] = None,
team_id: Optional[str] = None,
calling_key="sk-1234",
):
url = "http://0.0.0.0:4000/key/generate"
@ -77,6 +78,7 @@ async def generate_key(
"budget_duration": budget_duration,
"max_parallel_requests": max_parallel_requests,
"user_id": user_id,
"team_id": team_id,
}
print(f"data: {data}")
@ -355,6 +357,29 @@ async def get_key_info(session, call_key, get_key=None):
return await response.json()
async def get_model_list(session, call_key):
"""
Make sure only models user has access to are returned
"""
url = "http://0.0.0.0:4000/v1/models"
headers = {
"Authorization": f"Bearer {call_key}",
"Content-Type": "application/json",
}
async with session.get(url, headers=headers) as response:
status = response.status
response_text = await response.text()
print(response_text)
print()
if status != 200:
raise Exception(
f"Request did not return a 200 status code: {status}. Responses {response_text}"
)
return await response.json()
async def get_model_info(session, call_key):
"""
Make sure only models user has access to are returned
@ -719,3 +744,32 @@ async def test_key_delete_ui():
get_key=key,
auth_key=admin_ui_key["key"],
)
@pytest.mark.parametrize("model_access", ["all-team-models", "gpt-3.5-turbo"])
@pytest.mark.asyncio
async def test_key_model_list(model_access):
"""
Test if `/v1/models` works as expected.
"""
async with aiohttp.ClientSession() as session:
new_team = await generate_team(session=session)
team_id = new_team["team_id"]
key_gen = await generate_key(
session=session,
i=0,
team_id=team_id,
models=[] if model_access == "all-team-models" else [model_access],
)
key = key_gen["key"]
print(f"key: {key}")
model_list = await get_model_list(session=session, call_key=key)
print(f"model_list: {model_list}")
if model_access == "all-team-models":
assert not isinstance(model_list["data"][0]["id"], list)
assert isinstance(model_list["data"][0]["id"], str)
if model_access == "gpt-3.5-turbo":
assert len(model_list["data"]) == 1
assert model_list["data"][0]["id"] == model_access