forked from phoenix/litellm-mirror
fix(proxy_server.py): fix model list returned for /model/info
when team has restricted access
This commit is contained in:
parent
25a2f00db6
commit
95566dc249
2 changed files with 46 additions and 16 deletions
|
@ -9395,12 +9395,31 @@ async def model_info_v1(
|
||||||
status_code=500, detail={"error": "LLM Model List not loaded in"}
|
status_code=500, detail={"error": "LLM Model List not loaded in"}
|
||||||
)
|
)
|
||||||
|
|
||||||
if len(user_api_key_dict.models) > 0:
|
all_models: List[dict] = []
|
||||||
model_names = user_api_key_dict.models
|
## CHECK IF MODEL RESTRICTIONS ARE SET AT KEY/TEAM LEVEL ##
|
||||||
|
if llm_model_list is None:
|
||||||
|
proxy_model_list = []
|
||||||
|
else:
|
||||||
|
proxy_model_list = [m["model_name"] for m in llm_model_list]
|
||||||
|
key_models = get_key_models(
|
||||||
|
user_api_key_dict=user_api_key_dict, proxy_model_list=proxy_model_list
|
||||||
|
)
|
||||||
|
team_models = get_team_models(
|
||||||
|
user_api_key_dict=user_api_key_dict, proxy_model_list=proxy_model_list
|
||||||
|
)
|
||||||
|
all_models_str = get_complete_model_list(
|
||||||
|
key_models=key_models,
|
||||||
|
team_models=team_models,
|
||||||
|
proxy_model_list=proxy_model_list,
|
||||||
|
user_model=user_model,
|
||||||
|
infer_model_from_keys=general_settings.get("infer_model_from_keys", False),
|
||||||
|
)
|
||||||
|
|
||||||
|
if len(all_models_str) > 0:
|
||||||
|
model_names = all_models_str
|
||||||
_relevant_models = [m for m in llm_model_list if m["model_name"] in model_names]
|
_relevant_models = [m for m in llm_model_list if m["model_name"] in model_names]
|
||||||
all_models = copy.deepcopy(_relevant_models)
|
all_models = copy.deepcopy(_relevant_models)
|
||||||
else:
|
|
||||||
all_models = copy.deepcopy(llm_model_list)
|
|
||||||
for model in all_models:
|
for model in all_models:
|
||||||
# provided model_info in config.yaml
|
# provided model_info in config.yaml
|
||||||
model_info = model.get("model_info", {})
|
model_info = model.get("model_info", {})
|
||||||
|
|
|
@ -359,11 +359,11 @@ async def get_key_info(session, call_key, get_key=None):
|
||||||
return await response.json()
|
return await response.json()
|
||||||
|
|
||||||
|
|
||||||
async def get_model_list(session, call_key):
|
async def get_model_list(session, call_key, endpoint: str = "/v1/models"):
|
||||||
"""
|
"""
|
||||||
Make sure only models user has access to are returned
|
Make sure only models user has access to are returned
|
||||||
"""
|
"""
|
||||||
url = "http://0.0.0.0:4000/v1/models"
|
url = "http://0.0.0.0:4000" + endpoint
|
||||||
headers = {
|
headers = {
|
||||||
"Authorization": f"Bearer {call_key}",
|
"Authorization": f"Bearer {call_key}",
|
||||||
"Content-Type": "application/json",
|
"Content-Type": "application/json",
|
||||||
|
@ -749,8 +749,9 @@ async def test_key_delete_ui():
|
||||||
|
|
||||||
@pytest.mark.parametrize("model_access", ["all-team-models", "gpt-3.5-turbo"])
|
@pytest.mark.parametrize("model_access", ["all-team-models", "gpt-3.5-turbo"])
|
||||||
@pytest.mark.parametrize("model_access_level", ["key", "team"])
|
@pytest.mark.parametrize("model_access_level", ["key", "team"])
|
||||||
|
@pytest.mark.parametrize("model_endpoint", ["/v1/models", "/model/info"])
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_key_model_list(model_access, model_access_level):
|
async def test_key_model_list(model_access, model_access_level, model_endpoint):
|
||||||
"""
|
"""
|
||||||
Test if `/v1/models` works as expected.
|
Test if `/v1/models` works as expected.
|
||||||
"""
|
"""
|
||||||
|
@ -771,16 +772,26 @@ async def test_key_model_list(model_access, model_access_level):
|
||||||
key = key_gen["key"]
|
key = key_gen["key"]
|
||||||
print(f"key: {key}")
|
print(f"key: {key}")
|
||||||
|
|
||||||
model_list = await get_model_list(session=session, call_key=key)
|
model_list = await get_model_list(
|
||||||
|
session=session, call_key=key, endpoint=model_endpoint
|
||||||
|
)
|
||||||
print(f"model_list: {model_list}")
|
print(f"model_list: {model_list}")
|
||||||
|
|
||||||
if model_access == "all-team-models":
|
if model_access == "all-team-models":
|
||||||
|
if model_endpoint == "/v1/models":
|
||||||
assert not isinstance(model_list["data"][0]["id"], list)
|
assert not isinstance(model_list["data"][0]["id"], list)
|
||||||
assert isinstance(model_list["data"][0]["id"], str)
|
assert isinstance(model_list["data"][0]["id"], str)
|
||||||
|
elif model_endpoint == "/model/info":
|
||||||
|
assert isinstance(model_list["data"], list)
|
||||||
|
assert len(model_list["data"]) > 0
|
||||||
if model_access == "gpt-3.5-turbo":
|
if model_access == "gpt-3.5-turbo":
|
||||||
|
if model_endpoint == "/v1/models":
|
||||||
assert (
|
assert (
|
||||||
len(model_list["data"]) == 1
|
len(model_list["data"]) == 1
|
||||||
), "model_access={}, model_access_level={}".format(
|
), "model_access={}, model_access_level={}".format(
|
||||||
model_access, model_access_level
|
model_access, model_access_level
|
||||||
)
|
)
|
||||||
assert model_list["data"][0]["id"] == model_access
|
assert model_list["data"][0]["id"] == model_access
|
||||||
|
elif model_endpoint == "/model/info":
|
||||||
|
assert isinstance(model_list["data"], list)
|
||||||
|
assert len(model_list["data"]) == 1
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue