fix(proxy_server.py): fix model check for /v1/models endpoint when team has restricted access

Krrish Dholakia 2024-05-25 13:02:03 -07:00
parent 3c961136ea
commit 25a2f00db6
5 changed files with 134 additions and 36 deletions

litellm/proxy/_types.py

@@ -1112,3 +1112,8 @@ class WebhookEvent(CallInfo):
     ]
     event_group: Literal["user", "key", "team", "proxy"]
     event_message: str  # human-readable description of event
+
+
+class SpecialModelNames(enum.Enum):
+    all_team_models = "all-team-models"
+    all_proxy_models = "all-proxy-models"
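These enum values are sentinels rather than real model names: "all-team-models" on a key defers to the key's team, and the new "all-proxy-models" opts into every model configured on the proxy. Below is a hedged sketch of attaching the new sentinel when generating a key; the local URL, master key, and payload mirror the test helpers at the bottom of this diff and are assumptions about a dev setup, not part of the commit.

# Illustration only: create a key that may call every model on the proxy by
# passing the "all-proxy-models" sentinel to /key/generate. URL and master key
# are assumed values for a locally running proxy.
import asyncio

import aiohttp


async def generate_all_proxy_models_key():
    url = "http://0.0.0.0:4000/key/generate"
    headers = {"Authorization": "Bearer sk-1234", "Content-Type": "application/json"}
    data = {"models": ["all-proxy-models"]}
    async with aiohttp.ClientSession() as session:
        async with session.post(url, headers=headers, json=data) as response:
            return await response.json()


if __name__ == "__main__":
    print(asyncio.run(generate_all_proxy_models_key()))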

litellm/proxy/auth/model_checks.py

@@ -0,0 +1,76 @@
# What is this?
## Common checks for /v1/models and `/model/info`

from typing import List, Optional

from litellm.proxy._types import UserAPIKeyAuth, SpecialModelNames
from litellm.utils import get_valid_models
from litellm._logging import verbose_proxy_logger


def get_key_models(
    user_api_key_dict: UserAPIKeyAuth, proxy_model_list: List[str]
) -> List[str]:
    """
    Returns:
    - List of model name strings
    - Empty list if no models set
    """
    all_models = []
    if len(user_api_key_dict.models) > 0:
        all_models = user_api_key_dict.models
        if SpecialModelNames.all_team_models.value in all_models:
            all_models = user_api_key_dict.team_models
        if SpecialModelNames.all_proxy_models.value in all_models:
            all_models = proxy_model_list

    verbose_proxy_logger.debug("ALL KEY MODELS - {}".format(len(all_models)))
    return all_models


def get_team_models(
    user_api_key_dict: UserAPIKeyAuth, proxy_model_list: List[str]
) -> List[str]:
    """
    Returns:
    - List of model name strings
    - Empty list if no models set
    """
    all_models = []
    if len(user_api_key_dict.team_models) > 0:
        all_models = user_api_key_dict.team_models
        if SpecialModelNames.all_team_models.value in all_models:
            all_models = user_api_key_dict.team_models
        if SpecialModelNames.all_proxy_models.value in all_models:
            all_models = proxy_model_list

    verbose_proxy_logger.debug("ALL TEAM MODELS - {}".format(len(all_models)))
    return all_models


def get_complete_model_list(
    key_models: List[str],
    team_models: List[str],
    proxy_model_list: List[str],
    user_model: Optional[str],
    infer_model_from_keys: Optional[bool],
) -> List[str]:
    """Logic for returning complete model list for a given key + team pair"""

    """
    - If key list is empty -> defer to team list
    - If team list is empty -> defer to proxy model list
    """
    if len(key_models) > 0:
        return key_models

    if len(team_models) > 0:
        return team_models

    returned_models = proxy_model_list
    if user_model is not None:  # set via `litellm --model ollama/llam3`
        returned_models.append(user_model)

    if infer_model_from_keys is not None and infer_model_from_keys == True:
        valid_models = get_valid_models()
        returned_models.extend(valid_models)
    return returned_models
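To make the precedence concrete, a small usage sketch of the new helper (assuming a litellm install that includes this module; the model names are made up): key-level restrictions win, then team-level restrictions, then the full proxy list.

# Sketch of the precedence get_complete_model_list implements.
from litellm.proxy.auth.model_checks import get_complete_model_list

proxy_model_list = ["gpt-3.5-turbo", "gpt-4", "claude-3-opus"]

# Key restricted to one model -> only that model is exposed.
assert get_complete_model_list(
    key_models=["gpt-3.5-turbo"],
    team_models=["gpt-3.5-turbo", "gpt-4"],
    proxy_model_list=proxy_model_list,
    user_model=None,
    infer_model_from_keys=False,
) == ["gpt-3.5-turbo"]

# No key-level restriction -> the team restriction applies.
assert get_complete_model_list(
    key_models=[],
    team_models=["gpt-4"],
    proxy_model_list=proxy_model_list,
    user_model=None,
    infer_model_from_keys=False,
) == ["gpt-4"]

# Neither restricted -> everything configured on the proxy.
assert get_complete_model_list(
    key_models=[],
    team_models=[],
    proxy_model_list=proxy_model_list,
    user_model=None,
    infer_model_from_keys=False,
) == proxy_model_list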

litellm/proxy/proxy_server.py

@@ -112,6 +112,11 @@ from litellm.router import ModelInfo as RouterModelInfo
 from litellm._logging import verbose_router_logger, verbose_proxy_logger
 from litellm.proxy.auth.handle_jwt import JWTHandler
 from litellm.proxy.auth.litellm_license import LicenseCheck
+from litellm.proxy.auth.model_checks import (
+    get_complete_model_list,
+    get_key_models,
+    get_team_models,
+)
 from litellm.proxy.hooks.prompt_injection_detection import (
     _OPTIONAL_PromptInjectionDetection,
 )
@@ -266,10 +271,6 @@ class UserAPIKeyCacheTTLEnum(enum.Enum):
     in_memory_cache_ttl = 60  # 1 min ttl ## configure via `general_settings::user_api_key_cache_ttl: <your-value>`


-class SpecialModelNames(enum.Enum):
-    all_team_models = "all-team-models"
-
-
 class CommonProxyErrors(enum.Enum):
     db_not_connected_error = "DB not connected"
     no_llm_router = "No models configured on proxy"
@@ -3778,21 +3779,24 @@ def model_list(
 ):
     global llm_model_list, general_settings
     all_models = []
-    if len(user_api_key_dict.models) > 0:
-        all_models = user_api_key_dict.models
-        if SpecialModelNames.all_team_models.value in all_models:
-            all_models = user_api_key_dict.team_models
-    if len(all_models) == 0:  # has all proxy models
-        ## if no specific model access
-        if general_settings.get("infer_model_from_keys", False):
-            all_models = litellm.utils.get_valid_models()
-        if llm_model_list:
-            all_models = list(
-                set(all_models + [m["model_name"] for m in llm_model_list])
-            )
-        if user_model is not None:
-            all_models += [user_model]
-    verbose_proxy_logger.debug("all_models: %s", all_models)
+    ## CHECK IF MODEL RESTRICTIONS ARE SET AT KEY/TEAM LEVEL ##
+    if llm_model_list is None:
+        proxy_model_list = []
+    else:
+        proxy_model_list = [m["model_name"] for m in llm_model_list]
+    key_models = get_key_models(
+        user_api_key_dict=user_api_key_dict, proxy_model_list=proxy_model_list
+    )
+    team_models = get_team_models(
+        user_api_key_dict=user_api_key_dict, proxy_model_list=proxy_model_list
+    )
+    all_models = get_complete_model_list(
+        key_models=key_models,
+        team_models=team_models,
+        proxy_model_list=proxy_model_list,
+        user_model=user_model,
+        infer_model_from_keys=general_settings.get("infer_model_from_keys", False),
+    )
     return dict(
         data=[
             {
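From a client's perspective, the rewrite means a key whose team is restricted to specific models now sees only those models from /v1/models. A hedged sketch with the OpenAI SDK pointed at a locally running proxy (base URL and key are assumed values):

# Sketch: list the models visible to a given virtual key via /v1/models.
import asyncio

from openai import AsyncOpenAI


async def visible_models(key: str) -> list:
    client = AsyncOpenAI(api_key=key, base_url="http://0.0.0.0:4000")
    page = await client.models.list()
    return [m.id for m in page.data]


# For a key on a team restricted to gpt-3.5-turbo, this should now print
# ["gpt-3.5-turbo"] instead of every model configured on the proxy.
# print(asyncio.run(visible_models("sk-<restricted-key>")))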

litellm/utils.py

@@ -11441,11 +11441,8 @@ class CustomStreamWrapper:
             self.response_id = original_chunk.id
             if len(original_chunk.choices) > 0:
                 delta = original_chunk.choices[0].delta
-                if (
-                    delta is not None and (
-                        delta.function_call is not None
-                        or delta.tool_calls is not None
-                    )
+                if delta is not None and (
+                    delta.function_call is not None or delta.tool_calls is not None
                 ):
                     try:
                         model_response.system_fingerprint = (
@@ -11506,7 +11503,11 @@
                 model_response.choices[0].delta = Delta()
             else:
                 try:
-                    delta = dict() if original_chunk.choices[0].delta is None else dict(original_chunk.choices[0].delta)
+                    delta = (
+                        dict()
+                        if original_chunk.choices[0].delta is None
+                        else dict(original_chunk.choices[0].delta)
+                    )
                     print_verbose(f"original delta: {delta}")
                     model_response.choices[0].delta = Delta(**delta)
                     print_verbose(
@@ -12256,7 +12257,7 @@ def trim_messages(
     return messages


-def get_valid_models():
+def get_valid_models() -> List[str]:
     """
     Returns a list of valid LLMs based on the set environment variables
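The annotation only documents what the helper already returns; a minimal usage sketch, assuming litellm is installed and provider API keys are set in the environment:

# Sketch: get_valid_models() derives usable model names from whichever provider
# API keys are present in the environment, so output varies by machine.
from litellm.utils import get_valid_models

valid_models = get_valid_models()
print(len(valid_models), valid_models[:5])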

tests/test_keys.py

@@ -2,7 +2,7 @@
 ## Tests /key endpoints.

 import pytest
-import asyncio, time
+import asyncio, time, uuid
 import aiohttp
 from openai import AsyncOpenAI
 import sys, os
@@ -14,12 +14,14 @@ sys.path.insert(
 import litellm


-async def generate_team(session):
+async def generate_team(
+    session, models: Optional[list] = None, team_id: Optional[str] = None
+):
     url = "http://0.0.0.0:4000/team/new"
     headers = {"Authorization": "Bearer sk-1234", "Content-Type": "application/json"}
-    data = {
-        "team_id": "litellm-dashboard",
-    }
+    if team_id is None:
+        team_id = "litellm-dashboard"
+    data = {"team_id": team_id, "models": models}

     async with session.post(url, headers=headers, json=data) as response:
         status = response.status
@@ -746,19 +748,25 @@ async def test_key_delete_ui():


 @pytest.mark.parametrize("model_access", ["all-team-models", "gpt-3.5-turbo"])
+@pytest.mark.parametrize("model_access_level", ["key", "team"])
 @pytest.mark.asyncio
-async def test_key_model_list(model_access):
+async def test_key_model_list(model_access, model_access_level):
     """
     Test if `/v1/models` works as expected.
     """
     async with aiohttp.ClientSession() as session:
-        new_team = await generate_team(session=session)
-        team_id = "litellm-dashboard"
+        _models = [] if model_access == "all-team-models" else [model_access]
+        team_id = "litellm_dashboard_{}".format(uuid.uuid4())
+        new_team = await generate_team(
+            session=session,
+            models=_models if model_access_level == "team" else None,
+            team_id=team_id,
+        )
         key_gen = await generate_key(
             session=session,
             i=0,
             team_id=team_id,
-            models=[] if model_access == "all-team-models" else [model_access],
+            models=_models if model_access_level == "key" else [],
         )
         key = key_gen["key"]
         print(f"key: {key}")
@@ -770,5 +778,9 @@ async def test_key_model_list(model_access):
         assert not isinstance(model_list["data"][0]["id"], list)
         assert isinstance(model_list["data"][0]["id"], str)
         if model_access == "gpt-3.5-turbo":
-            assert len(model_list["data"]) == 1
+            assert (
+                len(model_list["data"]) == 1
+            ), "model_access={}, model_access_level={}".format(
+                model_access, model_access_level
+            )
             assert model_list["data"][0]["id"] == model_access