fix(proxy_server.py): fix model check for /v1/models endpoint when team has restricted access

Krrish Dholakia 2024-05-25 13:02:03 -07:00
parent 3c961136ea
commit 25a2f00db6
5 changed files with 134 additions and 36 deletions

View file

@@ -1112,3 +1112,8 @@ class WebhookEvent(CallInfo):
     ]
     event_group: Literal["user", "key", "team", "proxy"]
     event_message: str  # human-readable description of event
+
+
+class SpecialModelNames(enum.Enum):
+    all_team_models = "all-team-models"
+    all_proxy_models = "all-proxy-models"
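
The new `all-proxy-models` value joins `all-team-models` as a sentinel that a key or team can carry in its model list. A minimal sketch of how the enum is meant to be used (the key's model list below is illustrative, not from the commit):

from litellm.proxy._types import SpecialModelNames

key_models = ["all-proxy-models"]

# enum members hold plain strings, so membership checks against a model list work directly
if SpecialModelNames.all_proxy_models.value in key_models:
    print("key may use every model configured on the proxy")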

View file

@@ -0,0 +1,76 @@
+# What is this?
+## Common checks for /v1/models and `/model/info`
+
+from typing import List, Optional
+
+from litellm.proxy._types import UserAPIKeyAuth, SpecialModelNames
+from litellm.utils import get_valid_models
+from litellm._logging import verbose_proxy_logger
+
+
+def get_key_models(
+    user_api_key_dict: UserAPIKeyAuth, proxy_model_list: List[str]
+) -> List[str]:
+    """
+    Returns:
+    - List of model name strings
+    - Empty list if no models set
+    """
+    all_models = []
+    if len(user_api_key_dict.models) > 0:
+        all_models = user_api_key_dict.models
+        if SpecialModelNames.all_team_models.value in all_models:
+            all_models = user_api_key_dict.team_models
+        if SpecialModelNames.all_proxy_models.value in all_models:
+            all_models = proxy_model_list
+
+    verbose_proxy_logger.debug("ALL KEY MODELS - {}".format(len(all_models)))
+    return all_models
+
+
+def get_team_models(
+    user_api_key_dict: UserAPIKeyAuth, proxy_model_list: List[str]
+) -> List[str]:
+    """
+    Returns:
+    - List of model name strings
+    - Empty list if no models set
+    """
+    all_models = []
+    if len(user_api_key_dict.team_models) > 0:
+        all_models = user_api_key_dict.team_models
+        if SpecialModelNames.all_team_models.value in all_models:
+            all_models = user_api_key_dict.team_models
+        if SpecialModelNames.all_proxy_models.value in all_models:
+            all_models = proxy_model_list
+
+    verbose_proxy_logger.debug("ALL TEAM MODELS - {}".format(len(all_models)))
+    return all_models
+
+
+def get_complete_model_list(
+    key_models: List[str],
+    team_models: List[str],
+    proxy_model_list: List[str],
+    user_model: Optional[str],
+    infer_model_from_keys: Optional[bool],
+) -> List[str]:
+    """
+    Logic for returning the complete model list for a given key + team pair
+
+    - If key list is non-empty -> return it
+    - If key list is empty -> defer to team list
+    - If team list is empty -> defer to proxy model list
+    """
+    if len(key_models) > 0:
+        return key_models
+
+    if len(team_models) > 0:
+        return team_models
+
+    returned_models = proxy_model_list
+    if user_model is not None:  # set via `litellm --model ollama/llama3`
+        returned_models.append(user_model)
+
+    if infer_model_from_keys is not None and infer_model_from_keys == True:
+        valid_models = get_valid_models()
+        returned_models.extend(valid_models)
+    return returned_models
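
Taken together, these helpers make key-level restrictions take precedence over team-level ones, which in turn take precedence over the full proxy list. A rough usage sketch of that precedence (model names are made up for illustration):

from litellm.proxy.auth.model_checks import get_complete_model_list

# the key list is non-empty, so it wins over the team and proxy lists
models = get_complete_model_list(
    key_models=["gpt-3.5-turbo"],
    team_models=["gpt-3.5-turbo", "gpt-4"],
    proxy_model_list=["gpt-3.5-turbo", "gpt-4", "claude-3-opus"],
    user_model=None,
    infer_model_from_keys=False,
)
print(models)  # -> ["gpt-3.5-turbo"]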

View file

@@ -112,6 +112,11 @@ from litellm.router import ModelInfo as RouterModelInfo
 from litellm._logging import verbose_router_logger, verbose_proxy_logger
 from litellm.proxy.auth.handle_jwt import JWTHandler
 from litellm.proxy.auth.litellm_license import LicenseCheck
+from litellm.proxy.auth.model_checks import (
+    get_complete_model_list,
+    get_key_models,
+    get_team_models,
+)
 from litellm.proxy.hooks.prompt_injection_detection import (
     _OPTIONAL_PromptInjectionDetection,
 )
@@ -266,10 +271,6 @@ class UserAPIKeyCacheTTLEnum(enum.Enum):
     in_memory_cache_ttl = 60  # 1 min ttl ## configure via `general_settings::user_api_key_cache_ttl: <your-value>`
 
 
-class SpecialModelNames(enum.Enum):
-    all_team_models = "all-team-models"
-
-
 class CommonProxyErrors(enum.Enum):
     db_not_connected_error = "DB not connected"
     no_llm_router = "No models configured on proxy"
@@ -3778,21 +3779,24 @@ def model_list(
 ):
     global llm_model_list, general_settings
     all_models = []
-    if len(user_api_key_dict.models) > 0:
-        all_models = user_api_key_dict.models
-        if SpecialModelNames.all_team_models.value in all_models:
-            all_models = user_api_key_dict.team_models
-    if len(all_models) == 0:  # has all proxy models
-        ## if no specific model access
-        if general_settings.get("infer_model_from_keys", False):
-            all_models = litellm.utils.get_valid_models()
-        if llm_model_list:
-            all_models = list(
-                set(all_models + [m["model_name"] for m in llm_model_list])
-            )
-        if user_model is not None:
-            all_models += [user_model]
-    verbose_proxy_logger.debug("all_models: %s", all_models)
+    ## CHECK IF MODEL RESTRICTIONS ARE SET AT KEY/TEAM LEVEL ##
+    if llm_model_list is None:
+        proxy_model_list = []
+    else:
+        proxy_model_list = [m["model_name"] for m in llm_model_list]
+    key_models = get_key_models(
+        user_api_key_dict=user_api_key_dict, proxy_model_list=proxy_model_list
+    )
+    team_models = get_team_models(
+        user_api_key_dict=user_api_key_dict, proxy_model_list=proxy_model_list
+    )
+    all_models = get_complete_model_list(
+        key_models=key_models,
+        team_models=team_models,
+        proxy_model_list=proxy_model_list,
+        user_model=user_model,
+        infer_model_from_keys=general_settings.get("infer_model_from_keys", False),
+    )
     return dict(
         data=[
             {
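
With this wiring, a key restricted to a single model should only see that model from `/v1/models`. A hedged way to check the behaviour against a locally running proxy (the port matches the test setup below; the key value is a placeholder):

import openai

client = openai.OpenAI(
    api_key="sk-...",  # a proxy key restricted to gpt-3.5-turbo
    base_url="http://0.0.0.0:4000",
)
model_ids = [m.id for m in client.models.list().data]
print(model_ids)  # expected: ["gpt-3.5-turbo"]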

View file

@@ -11441,11 +11441,8 @@ class CustomStreamWrapper:
                     self.response_id = original_chunk.id
                     if len(original_chunk.choices) > 0:
                         delta = original_chunk.choices[0].delta
-                        if (
-                            delta is not None and (
-                                delta.function_call is not None
-                                or delta.tool_calls is not None
-                            )
+                        if delta is not None and (
+                            delta.function_call is not None or delta.tool_calls is not None
                         ):
                             try:
                                 model_response.system_fingerprint = (
@@ -11506,7 +11503,11 @@ class CustomStreamWrapper:
                             model_response.choices[0].delta = Delta()
                         else:
                             try:
-                                delta = dict() if original_chunk.choices[0].delta is None else dict(original_chunk.choices[0].delta)
+                                delta = (
+                                    dict()
+                                    if original_chunk.choices[0].delta is None
+                                    else dict(original_chunk.choices[0].delta)
+                                )
                                 print_verbose(f"original delta: {delta}")
                                 model_response.choices[0].delta = Delta(**delta)
                                 print_verbose(
@@ -12256,7 +12257,7 @@ def trim_messages(
         return messages
 
 
-def get_valid_models():
+def get_valid_models() -> List[str]:
     """
     Returns a list of valid LLMs based on the set environment variables
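
The only change to `get_valid_models` itself is the added return-type annotation; per its docstring it infers usable models from the provider keys set in the environment. A small sketch (the environment value is a placeholder):

import os
import litellm

os.environ["OPENAI_API_KEY"] = "sk-..."  # placeholder provider key
print(litellm.utils.get_valid_models())  # e.g. a list of OpenAI model names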

View file

@@ -2,7 +2,7 @@
 ## Tests /key endpoints.
 
 import pytest
-import asyncio, time
+import asyncio, time, uuid
 import aiohttp
 from openai import AsyncOpenAI
 import sys, os
@@ -14,12 +14,14 @@ sys.path.insert(
 import litellm
 
 
-async def generate_team(session):
+async def generate_team(
+    session, models: Optional[list] = None, team_id: Optional[str] = None
+):
     url = "http://0.0.0.0:4000/team/new"
     headers = {"Authorization": "Bearer sk-1234", "Content-Type": "application/json"}
-    data = {
-        "team_id": "litellm-dashboard",
-    }
+    if team_id is None:
+        team_id = "litellm-dashboard"
+    data = {"team_id": team_id, "models": models}
 
     async with session.post(url, headers=headers, json=data) as response:
         status = response.status
@@ -746,19 +748,25 @@ async def test_key_delete_ui():
 
 
 @pytest.mark.parametrize("model_access", ["all-team-models", "gpt-3.5-turbo"])
+@pytest.mark.parametrize("model_access_level", ["key", "team"])
 @pytest.mark.asyncio
-async def test_key_model_list(model_access):
+async def test_key_model_list(model_access, model_access_level):
     """
     Test if `/v1/models` works as expected.
     """
     async with aiohttp.ClientSession() as session:
-        new_team = await generate_team(session=session)
-        team_id = "litellm-dashboard"
+        _models = [] if model_access == "all-team-models" else [model_access]
+        team_id = "litellm_dashboard_{}".format(uuid.uuid4())
+        new_team = await generate_team(
+            session=session,
+            models=_models if model_access_level == "team" else None,
+            team_id=team_id,
+        )
         key_gen = await generate_key(
             session=session,
             i=0,
             team_id=team_id,
-            models=[] if model_access == "all-team-models" else [model_access],
+            models=_models if model_access_level == "key" else [],
         )
         key = key_gen["key"]
         print(f"key: {key}")
@@ -770,5 +778,9 @@ async def test_key_model_list(model_access):
         assert not isinstance(model_list["data"][0]["id"], list)
         assert isinstance(model_list["data"][0]["id"], str)
         if model_access == "gpt-3.5-turbo":
-            assert len(model_list["data"]) == 1
+            assert (
+                len(model_list["data"]) == 1
+            ), "model_access={}, model_access_level={}".format(
+                model_access, model_access_level
+            )
             assert model_list["data"][0]["id"] == model_access