Merge pull request #3839 from BerriAI/litellm_fix_models_endpoint

fix(proxy_server.py): fix model check for `/v1/models` + `/model/info` endpoint when team has restricted access
Krish Dholakia 2024-05-25 14:23:01 -07:00 committed by GitHub
commit 02f2d67808
5 changed files with 174 additions and 46 deletions
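Before this change, a key whose team restricted model access could get an unfiltered model list back from these endpoints. A minimal end-to-end sketch of the fixed behavior — this assumes a locally running proxy on port 4000 with master key sk-1234 and more than one configured model, mirroring the setup in this PR's e2e tests; the team name is illustrative:

import asyncio

import aiohttp

PROXY = "http://0.0.0.0:4000"
ADMIN = {"Authorization": "Bearer sk-1234", "Content-Type": "application/json"}


async def main():
    async with aiohttp.ClientSession() as session:
        # Team restricted to a single model (team-level restriction).
        async with session.post(
            PROXY + "/team/new",
            headers=ADMIN,
            json={"team_id": "demo-restricted-team", "models": ["gpt-3.5-turbo"]},
        ) as resp:
            assert resp.status == 200

        # Key on that team, with no key-level model list of its own.
        async with session.post(
            PROXY + "/key/generate",
            headers=ADMIN,
            json={"team_id": "demo-restricted-team", "models": []},
        ) as resp:
            key = (await resp.json())["key"]

        # After the fix, both endpoints should surface only the team's model.
        for endpoint in ["/v1/models", "/model/info"]:
            async with session.get(
                PROXY + endpoint, headers={"Authorization": f"Bearer {key}"}
            ) as resp:
                data = (await resp.json())["data"]
                assert len(data) == 1, f"{endpoint} leaked models: {data}"


asyncio.run(main())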

litellm/proxy/_types.py

@@ -1112,3 +1112,8 @@ class WebhookEvent(CallInfo):
     ]
     event_group: Literal["user", "key", "team", "proxy"]
     event_message: str  # human-readable description of event
+
+
+class SpecialModelNames(enum.Enum):
+    all_team_models = "all-team-models"
+    all_proxy_models = "all-proxy-models"

litellm/proxy/auth/model_checks.py (new file)

@@ -0,0 +1,76 @@
+# What is this?
+## Common checks for /v1/models and `/model/info`
+from typing import List, Optional
+
+from litellm.proxy._types import UserAPIKeyAuth, SpecialModelNames
+from litellm.utils import get_valid_models
+from litellm._logging import verbose_proxy_logger
+
+
+def get_key_models(
+    user_api_key_dict: UserAPIKeyAuth, proxy_model_list: List[str]
+) -> List[str]:
+    """
+    Returns:
+    - List of model name strings
+    - Empty list if no models set
+    """
+    all_models = []
+    if len(user_api_key_dict.models) > 0:
+        all_models = user_api_key_dict.models
+        if SpecialModelNames.all_team_models.value in all_models:
+            all_models = user_api_key_dict.team_models
+        if SpecialModelNames.all_proxy_models.value in all_models:
+            all_models = proxy_model_list
+
+    verbose_proxy_logger.debug("ALL KEY MODELS - {}".format(len(all_models)))
+    return all_models
+
+
+def get_team_models(
+    user_api_key_dict: UserAPIKeyAuth, proxy_model_list: List[str]
+) -> List[str]:
+    """
+    Returns:
+    - List of model name strings
+    - Empty list if no models set
+    """
+    all_models = []
+    if len(user_api_key_dict.team_models) > 0:
+        all_models = user_api_key_dict.team_models
+        if SpecialModelNames.all_team_models.value in all_models:
+            all_models = user_api_key_dict.team_models
+        if SpecialModelNames.all_proxy_models.value in all_models:
+            all_models = proxy_model_list
+
+    verbose_proxy_logger.debug("ALL TEAM MODELS - {}".format(len(all_models)))
+    return all_models
+
+
+def get_complete_model_list(
+    key_models: List[str],
+    team_models: List[str],
+    proxy_model_list: List[str],
+    user_model: Optional[str],
+    infer_model_from_keys: Optional[bool],
+) -> List[str]:
+    """
+    Logic for returning complete model list for a given key + team pair:
+    - If key list is non-empty -> use the key list
+    - If key list is empty -> defer to team list
+    - If team list is empty -> defer to proxy model list
+    """
+    if len(key_models) > 0:
+        return key_models
+
+    if len(team_models) > 0:
+        return team_models
+
+    returned_models = proxy_model_list
+    if user_model is not None:  # set via `litellm --model ollama/llama3`
+        returned_models.append(user_model)
+
+    if infer_model_from_keys is True:
+        valid_models = get_valid_models()
+        returned_models.extend(valid_models)
+    return returned_models
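The three helpers compose into one precedence rule: key-level models win, then team-level models, then the proxy's full list (optionally extended by `user_model` and env-inferred models). A quick sketch of that precedence with plain lists — this assumes the proxy package is importable and uses illustrative model names:

from litellm.proxy.auth.model_checks import get_complete_model_list

proxy_models = ["gpt-3.5-turbo", "gpt-4", "claude-3-opus"]

# Key-level restriction takes precedence over everything else.
assert get_complete_model_list(
    key_models=["gpt-4"],
    team_models=["gpt-3.5-turbo"],
    proxy_model_list=proxy_models,
    user_model=None,
    infer_model_from_keys=False,
) == ["gpt-4"]

# No key models -> defer to the team's list.
assert get_complete_model_list(
    key_models=[],
    team_models=["gpt-3.5-turbo"],
    proxy_model_list=proxy_models,
    user_model=None,
    infer_model_from_keys=False,
) == ["gpt-3.5-turbo"]

# No key or team models -> the full proxy list.
assert get_complete_model_list(
    key_models=[],
    team_models=[],
    proxy_model_list=proxy_models,
    user_model=None,
    infer_model_from_keys=False,
) == proxy_models

One subtlety: in the fallback branch `returned_models` aliases `proxy_model_list`, so appending `user_model` mutates the caller's list; copying the list first would avoid that.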

litellm/proxy/proxy_server.py

@@ -111,6 +111,11 @@ from litellm.router import ModelInfo as RouterModelInfo
 from litellm._logging import verbose_router_logger, verbose_proxy_logger
 from litellm.proxy.auth.handle_jwt import JWTHandler
 from litellm.proxy.auth.litellm_license import LicenseCheck
+from litellm.proxy.auth.model_checks import (
+    get_complete_model_list,
+    get_key_models,
+    get_team_models,
+)
 from litellm.proxy.hooks.prompt_injection_detection import (
     _OPTIONAL_PromptInjectionDetection,
 )
@@ -265,10 +270,6 @@ class UserAPIKeyCacheTTLEnum(enum.Enum):
     in_memory_cache_ttl = 60  # 1 min ttl ## configure via `general_settings::user_api_key_cache_ttl: <your-value>`


-class SpecialModelNames(enum.Enum):
-    all_team_models = "all-team-models"
-
-
 class CommonProxyErrors(enum.Enum):
     db_not_connected_error = "DB not connected"
     no_llm_router = "No models configured on proxy"
@@ -3777,21 +3778,24 @@ def model_list(
 ):
     global llm_model_list, general_settings
     all_models = []
-    if len(user_api_key_dict.models) > 0:
-        all_models = user_api_key_dict.models
-        if SpecialModelNames.all_team_models.value in all_models:
-            all_models = user_api_key_dict.team_models
-    if len(all_models) == 0:  # has all proxy models
-        ## if no specific model access
-        if general_settings.get("infer_model_from_keys", False):
-            all_models = litellm.utils.get_valid_models()
-        if llm_model_list:
-            all_models = list(
-                set(all_models + [m["model_name"] for m in llm_model_list])
-            )
-        if user_model is not None:
-            all_models += [user_model]
-    verbose_proxy_logger.debug("all_models: %s", all_models)
+    ## CHECK IF MODEL RESTRICTIONS ARE SET AT KEY/TEAM LEVEL ##
+    if llm_model_list is None:
+        proxy_model_list = []
+    else:
+        proxy_model_list = [m["model_name"] for m in llm_model_list]
+    key_models = get_key_models(
+        user_api_key_dict=user_api_key_dict, proxy_model_list=proxy_model_list
+    )
+    team_models = get_team_models(
+        user_api_key_dict=user_api_key_dict, proxy_model_list=proxy_model_list
+    )
+    all_models = get_complete_model_list(
+        key_models=key_models,
+        team_models=team_models,
+        proxy_model_list=proxy_model_list,
+        user_model=user_model,
+        infer_model_from_keys=general_settings.get("infer_model_from_keys", False),
+    )
     return dict(
         data=[
             {
@@ -9640,12 +9644,31 @@ async def model_info_v1(
             status_code=500, detail={"error": "LLM Model List not loaded in"}
         )
-    if len(user_api_key_dict.models) > 0:
-        model_names = user_api_key_dict.models
+    all_models: List[dict] = []
+    ## CHECK IF MODEL RESTRICTIONS ARE SET AT KEY/TEAM LEVEL ##
+    if llm_model_list is None:
+        proxy_model_list = []
+    else:
+        proxy_model_list = [m["model_name"] for m in llm_model_list]
+    key_models = get_key_models(
+        user_api_key_dict=user_api_key_dict, proxy_model_list=proxy_model_list
+    )
+    team_models = get_team_models(
+        user_api_key_dict=user_api_key_dict, proxy_model_list=proxy_model_list
+    )
+    all_models_str = get_complete_model_list(
+        key_models=key_models,
+        team_models=team_models,
+        proxy_model_list=proxy_model_list,
+        user_model=user_model,
+        infer_model_from_keys=general_settings.get("infer_model_from_keys", False),
+    )
+    if len(all_models_str) > 0:
+        model_names = all_models_str
         _relevant_models = [m for m in llm_model_list if m["model_name"] in model_names]
         all_models = copy.deepcopy(_relevant_models)
-    else:
-        all_models = copy.deepcopy(llm_model_list)
     for model in all_models:
         # provided model_info in config.yaml
         model_info = model.get("model_info", {})

litellm/utils.py

@@ -11441,11 +11441,8 @@ class CustomStreamWrapper:
                         self.response_id = original_chunk.id
                     if len(original_chunk.choices) > 0:
                         delta = original_chunk.choices[0].delta
-                        if (
-                            delta is not None and (
-                                delta.function_call is not None
-                                or delta.tool_calls is not None
-                            )
+                        if delta is not None and (
+                            delta.function_call is not None or delta.tool_calls is not None
                         ):
                             try:
                                 model_response.system_fingerprint = (
@@ -11506,7 +11503,11 @@
                     model_response.choices[0].delta = Delta()
                 else:
                     try:
-                        delta = dict() if original_chunk.choices[0].delta is None else dict(original_chunk.choices[0].delta)
+                        delta = (
+                            dict()
+                            if original_chunk.choices[0].delta is None
+                            else dict(original_chunk.choices[0].delta)
+                        )
                         print_verbose(f"original delta: {delta}")
                         model_response.choices[0].delta = Delta(**delta)
                         print_verbose(
@@ -12256,7 +12257,7 @@ def trim_messages(
     return messages


-def get_valid_models():
+def get_valid_models() -> List[str]:
     """
     Returns a list of valid LLMs based on the set environment variables
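For reference, `get_valid_models()` builds its result from whichever provider API keys are present in the environment, so the new return annotation only documents existing behavior. A small sketch (the key value is a placeholder, not a real credential):

import os

from litellm.utils import get_valid_models

os.environ["OPENAI_API_KEY"] = "sk-placeholder"  # illustrative; set a real key in practice
models = get_valid_models()  # e.g. OpenAI models when only that key is set
print(models)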

tests/test_keys.py

@@ -2,7 +2,7 @@
 ## Tests /key endpoints.
 import pytest
-import asyncio, time
+import asyncio, time, uuid
 import aiohttp
 from openai import AsyncOpenAI
 import sys, os
@@ -14,12 +14,14 @@ sys.path.insert(
 import litellm


-async def generate_team(session):
+async def generate_team(
+    session, models: Optional[list] = None, team_id: Optional[str] = None
+):
     url = "http://0.0.0.0:4000/team/new"
     headers = {"Authorization": "Bearer sk-1234", "Content-Type": "application/json"}
-    data = {
-        "team_id": "litellm-dashboard",
-    }
+    if team_id is None:
+        team_id = "litellm-dashboard"
+    data = {"team_id": team_id, "models": models}

     async with session.post(url, headers=headers, json=data) as response:
         status = response.status
@@ -357,11 +359,11 @@ async def get_key_info(session, call_key, get_key=None):
     return await response.json()


-async def get_model_list(session, call_key):
+async def get_model_list(session, call_key, endpoint: str = "/v1/models"):
     """
     Make sure only models user has access to are returned
     """
-    url = "http://0.0.0.0:4000/v1/models"
+    url = "http://0.0.0.0:4000" + endpoint
     headers = {
         "Authorization": f"Bearer {call_key}",
         "Content-Type": "application/json",
@@ -746,29 +748,50 @@ async def test_key_delete_ui():
 @pytest.mark.parametrize("model_access", ["all-team-models", "gpt-3.5-turbo"])
+@pytest.mark.parametrize("model_access_level", ["key", "team"])
+@pytest.mark.parametrize("model_endpoint", ["/v1/models", "/model/info"])
 @pytest.mark.asyncio
-async def test_key_model_list(model_access):
+async def test_key_model_list(model_access, model_access_level, model_endpoint):
     """
     Test if `/v1/models` works as expected.
     """
     async with aiohttp.ClientSession() as session:
-        new_team = await generate_team(session=session)
-        team_id = "litellm-dashboard"
+        _models = [] if model_access == "all-team-models" else [model_access]
+        team_id = "litellm_dashboard_{}".format(uuid.uuid4())
+        new_team = await generate_team(
+            session=session,
+            models=_models if model_access_level == "team" else None,
+            team_id=team_id,
+        )
         key_gen = await generate_key(
             session=session,
             i=0,
             team_id=team_id,
-            models=[] if model_access == "all-team-models" else [model_access],
+            models=_models if model_access_level == "key" else [],
         )
         key = key_gen["key"]
         print(f"key: {key}")
-        model_list = await get_model_list(session=session, call_key=key)
+        model_list = await get_model_list(
+            session=session, call_key=key, endpoint=model_endpoint
+        )
         print(f"model_list: {model_list}")
         if model_access == "all-team-models":
-            assert not isinstance(model_list["data"][0]["id"], list)
-            assert isinstance(model_list["data"][0]["id"], str)
+            if model_endpoint == "/v1/models":
+                assert not isinstance(model_list["data"][0]["id"], list)
+                assert isinstance(model_list["data"][0]["id"], str)
+            elif model_endpoint == "/model/info":
+                assert isinstance(model_list["data"], list)
+                assert len(model_list["data"]) > 0
         if model_access == "gpt-3.5-turbo":
-            assert len(model_list["data"]) == 1
-            assert model_list["data"][0]["id"] == model_access
+            if model_endpoint == "/v1/models":
+                assert (
+                    len(model_list["data"]) == 1
+                ), "model_access={}, model_access_level={}".format(
+                    model_access, model_access_level
+                )
+                assert model_list["data"][0]["id"] == model_access
+            elif model_endpoint == "/model/info":
+                assert isinstance(model_list["data"], list)
+                assert len(model_list["data"]) == 1
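The new parametrization expands the single test into a 2x2x2 matrix (model access value x restriction level x endpoint). Assuming a proxy is already running locally, as these e2e tests expect, the matrix can be run on its own; the file path assumes the repo layout where the key-endpoint tests live under tests/:

import pytest

# Runs only the parametrized matrix from this file.
pytest.main(["tests/test_keys.py", "-k", "test_key_model_list", "-v"])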