fix(proxy_server.py): fix /v1/models bug where it would return empty list

handle 'all-team-models' being set for a given key
This commit is contained in:
Krrish Dholakia 2024-05-07 13:42:58 -07:00
parent 8e5437c8e9
commit f2766fddbf
2 changed files with 61 additions and 1 deletions

View file

@ -227,6 +227,10 @@ class UserAPIKeyCacheTTLEnum(enum.Enum):
global_proxy_spend = 60
class SpecialModelNames(enum.Enum):
all_team_models = "all-team-models"
@app.exception_handler(ProxyException)
async def openai_exception_handler(request: Request, exc: ProxyException):
# NOTE: DO NOT MODIFY THIS, its crucial to map to Openai exceptions
@ -3436,7 +3440,9 @@ def model_list(
all_models = []
if len(user_api_key_dict.models) > 0:
all_models = user_api_key_dict.models
else:
if SpecialModelNames.all_team_models.value in all_models:
all_models = user_api_key_dict.team_models
if len(all_models) == 0: # has all proxy models
## if no specific model access
if general_settings.get("infer_model_from_keys", False):
all_models = litellm.utils.get_valid_models()

View file

@ -62,6 +62,7 @@ async def generate_key(
models=["azure-models", "gpt-4", "dall-e-3"],
max_parallel_requests: Optional[int] = None,
user_id: Optional[str] = None,
team_id: Optional[str] = None,
calling_key="sk-1234",
):
url = "http://0.0.0.0:4000/key/generate"
@ -77,6 +78,7 @@ async def generate_key(
"budget_duration": budget_duration,
"max_parallel_requests": max_parallel_requests,
"user_id": user_id,
"team_id": team_id,
}
print(f"data: {data}")
@ -355,6 +357,29 @@ async def get_key_info(session, call_key, get_key=None):
return await response.json()
async def get_model_list(session, call_key):
"""
Make sure only models user has access to are returned
"""
url = "http://0.0.0.0:4000/v1/models"
headers = {
"Authorization": f"Bearer {call_key}",
"Content-Type": "application/json",
}
async with session.get(url, headers=headers) as response:
status = response.status
response_text = await response.text()
print(response_text)
print()
if status != 200:
raise Exception(
f"Request did not return a 200 status code: {status}. Responses {response_text}"
)
return await response.json()
async def get_model_info(session, call_key):
"""
Make sure only models user has access to are returned
@ -719,3 +744,32 @@ async def test_key_delete_ui():
get_key=key,
auth_key=admin_ui_key["key"],
)
@pytest.mark.parametrize("model_access", ["all-team-models", "gpt-3.5-turbo"])
@pytest.mark.asyncio
async def test_key_model_list(model_access):
"""
Test if `/v1/models` works as expected.
"""
async with aiohttp.ClientSession() as session:
new_team = await generate_team(session=session)
team_id = new_team["team_id"]
key_gen = await generate_key(
session=session,
i=0,
team_id=team_id,
models=[] if model_access == "all-team-models" else [model_access],
)
key = key_gen["key"]
print(f"key: {key}")
model_list = await get_model_list(session=session, call_key=key)
print(f"model_list: {model_list}")
if model_access == "all-team-models":
assert not isinstance(model_list["data"][0]["id"], list)
assert isinstance(model_list["data"][0]["id"], str)
if model_access == "gpt-3.5-turbo":
assert len(model_list["data"]) == 1
assert model_list["data"][0]["id"] == model_access