forked from phoenix/litellm-mirror
fix(proxy_server.py): fix /v1/models
bug where it would return empty list
handle 'all-team-models' being set for a given key
This commit is contained in:
parent
8e5437c8e9
commit
f2766fddbf
2 changed files with 61 additions and 1 deletions
|
@ -227,6 +227,10 @@ class UserAPIKeyCacheTTLEnum(enum.Enum):
|
||||||
global_proxy_spend = 60
|
global_proxy_spend = 60
|
||||||
|
|
||||||
|
|
||||||
|
class SpecialModelNames(enum.Enum):
    """Reserved model-name strings that are treated as sentinels, not real models."""

    # When this value appears in a key's model list, the `/v1/models` handler
    # substitutes the models of the key's team for the key's own list.
    all_team_models = "all-team-models"
|
||||||
|
|
||||||
|
|
||||||
@app.exception_handler(ProxyException)
|
@app.exception_handler(ProxyException)
|
||||||
async def openai_exception_handler(request: Request, exc: ProxyException):
|
async def openai_exception_handler(request: Request, exc: ProxyException):
|
||||||
# NOTE: DO NOT MODIFY THIS, its crucial to map to Openai exceptions
|
# NOTE: DO NOT MODIFY THIS, its crucial to map to Openai exceptions
|
||||||
|
@ -3436,7 +3440,9 @@ def model_list(
|
||||||
all_models = []
|
all_models = []
|
||||||
if len(user_api_key_dict.models) > 0:
|
if len(user_api_key_dict.models) > 0:
|
||||||
all_models = user_api_key_dict.models
|
all_models = user_api_key_dict.models
|
||||||
else:
|
if SpecialModelNames.all_team_models.value in all_models:
|
||||||
|
all_models = user_api_key_dict.team_models
|
||||||
|
if len(all_models) == 0: # has all proxy models
|
||||||
## if no specific model access
|
## if no specific model access
|
||||||
if general_settings.get("infer_model_from_keys", False):
|
if general_settings.get("infer_model_from_keys", False):
|
||||||
all_models = litellm.utils.get_valid_models()
|
all_models = litellm.utils.get_valid_models()
|
||||||
|
|
|
@ -62,6 +62,7 @@ async def generate_key(
|
||||||
models=["azure-models", "gpt-4", "dall-e-3"],
|
models=["azure-models", "gpt-4", "dall-e-3"],
|
||||||
max_parallel_requests: Optional[int] = None,
|
max_parallel_requests: Optional[int] = None,
|
||||||
user_id: Optional[str] = None,
|
user_id: Optional[str] = None,
|
||||||
|
team_id: Optional[str] = None,
|
||||||
calling_key="sk-1234",
|
calling_key="sk-1234",
|
||||||
):
|
):
|
||||||
url = "http://0.0.0.0:4000/key/generate"
|
url = "http://0.0.0.0:4000/key/generate"
|
||||||
|
@ -77,6 +78,7 @@ async def generate_key(
|
||||||
"budget_duration": budget_duration,
|
"budget_duration": budget_duration,
|
||||||
"max_parallel_requests": max_parallel_requests,
|
"max_parallel_requests": max_parallel_requests,
|
||||||
"user_id": user_id,
|
"user_id": user_id,
|
||||||
|
"team_id": team_id,
|
||||||
}
|
}
|
||||||
|
|
||||||
print(f"data: {data}")
|
print(f"data: {data}")
|
||||||
|
@ -355,6 +357,29 @@ async def get_key_info(session, call_key, get_key=None):
|
||||||
return await response.json()
|
return await response.json()
|
||||||
|
|
||||||
|
|
||||||
|
async def get_model_list(session, call_key):
    """
    Fetch `/v1/models` from the local proxy using `call_key` as the bearer token.

    Used to check that only the models the key has access to are returned.

    Args:
        session: HTTP client session exposing an async-context-manager `get`
            (an `aiohttp.ClientSession` in the tests that call this helper).
        call_key: API key sent in the `Authorization: Bearer` header.

    Returns:
        The decoded JSON body of the response.

    Raises:
        Exception: if the proxy answers with any status other than 200; the
            message carries the status code and the raw response text.
    """
    url = "http://0.0.0.0:4000/v1/models"
    headers = {
        "Authorization": f"Bearer {call_key}",
        "Content-Type": "application/json",
    }

    async with session.get(url, headers=headers) as response:
        status = response.status
        # Read the raw text first so it can be printed and included in the
        # error message when the request fails.
        response_text = await response.text()
        print(response_text)
        print()

        if status != 200:
            raise Exception(
                f"Request did not return a 200 status code: {status}. Responses {response_text}"
            )
        return await response.json()
|
||||||
|
|
||||||
|
|
||||||
async def get_model_info(session, call_key):
|
async def get_model_info(session, call_key):
|
||||||
"""
|
"""
|
||||||
Make sure only models user has access to are returned
|
Make sure only models user has access to are returned
|
||||||
|
@ -719,3 +744,32 @@ async def test_key_delete_ui():
|
||||||
get_key=key,
|
get_key=key,
|
||||||
auth_key=admin_ui_key["key"],
|
auth_key=admin_ui_key["key"],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize("model_access", ["all-team-models", "gpt-3.5-turbo"])
@pytest.mark.asyncio
async def test_key_model_list(model_access):
    """
    Test if `/v1/models` works as expected.

    Creates a team and a key on that team, then lists the models visible to
    the key:

    - "all-team-models": the key is generated with an empty model list and the
      first returned entry must have a plain string id (not a nested list).
    - "gpt-3.5-turbo": the key is restricted to that single model and exactly
      that one model must be returned.
    """
    async with aiohttp.ClientSession() as session:
        new_team = await generate_team(session=session)
        team_id = new_team["team_id"]
        # NOTE(review): the "all-team-models" case sends an empty `models` list
        # rather than the sentinel string itself - presumably the proxy maps a
        # team key without explicit models to the team's models; confirm.
        key_gen = await generate_key(
            session=session,
            i=0,
            team_id=team_id,
            models=[] if model_access == "all-team-models" else [model_access],
        )
        key = key_gen["key"]
        print(f"key: {key}")

        model_list = await get_model_list(session=session, call_key=key)
        print(f"model_list: {model_list}")

        if model_access == "all-team-models":
            # Guards against a list of models being returned as a single id.
            assert not isinstance(model_list["data"][0]["id"], list)
            assert isinstance(model_list["data"][0]["id"], str)
        if model_access == "gpt-3.5-turbo":
            assert len(model_list["data"]) == 1
            assert model_list["data"][0]["id"] == model_access
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue