forked from phoenix/litellm-mirror
fix(proxy_server.py): fix model check for /v1/models endpoint when team has restricted access

parent 3c961136ea
commit 25a2f00db6
5 changed files with 134 additions and 36 deletions
litellm/proxy/_types.py

@@ -1112,3 +1112,8 @@ class WebhookEvent(CallInfo):
     ]
     event_group: Literal["user", "key", "team", "proxy"]
     event_message: str  # human-readable description of event
+
+
+class SpecialModelNames(enum.Enum):
+    all_team_models = "all-team-models"
+    all_proxy_models = "all-proxy-models"
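The two members act as sentinel values inside a key's or team's `models` list rather than naming real deployments. A minimal sketch of the membership test the new helpers perform; `is_wildcard_access` is a hypothetical name, for illustration only:

import enum
from typing import List


class SpecialModelNames(enum.Enum):
    # mirror of the enum added above; values are sentinels, not real models
    all_team_models = "all-team-models"
    all_proxy_models = "all-proxy-models"


def is_wildcard_access(models: List[str]) -> bool:
    # a key holding a sentinel is not limited to literal model names;
    # it defers to the team or proxy model list instead
    return (
        SpecialModelNames.all_team_models.value in models
        or SpecialModelNames.all_proxy_models.value in models
    )


print(is_wildcard_access(["all-team-models"]))  # True
print(is_wildcard_access(["gpt-3.5-turbo"]))  # False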
litellm/proxy/auth/model_checks.py (new file, 76 lines)
@@ -0,0 +1,76 @@
+# What is this?
+## Common checks for /v1/models and `/model/info`
+from typing import List, Optional
+from litellm.proxy._types import UserAPIKeyAuth, SpecialModelNames
+from litellm.utils import get_valid_models
+from litellm._logging import verbose_proxy_logger
+
+
+def get_key_models(
+    user_api_key_dict: UserAPIKeyAuth, proxy_model_list: List[str]
+) -> List[str]:
+    """
+    Returns:
+    - List of model name strings
+    - Empty list if no models set
+    """
+    all_models = []
+    if len(user_api_key_dict.models) > 0:
+        all_models = user_api_key_dict.models
+        if SpecialModelNames.all_team_models.value in all_models:
+            all_models = user_api_key_dict.team_models
+        if SpecialModelNames.all_proxy_models.value in all_models:
+            all_models = proxy_model_list
+
+    verbose_proxy_logger.debug("ALL KEY MODELS - {}".format(len(all_models)))
+    return all_models
+
+
+def get_team_models(
+    user_api_key_dict: UserAPIKeyAuth, proxy_model_list: List[str]
+) -> List[str]:
+    """
+    Returns:
+    - List of model name strings
+    - Empty list if no models set
+    """
+    all_models = []
+    if len(user_api_key_dict.team_models) > 0:
+        all_models = user_api_key_dict.team_models
+        if SpecialModelNames.all_team_models.value in all_models:
+            all_models = user_api_key_dict.team_models
+        if SpecialModelNames.all_proxy_models.value in all_models:
+            all_models = proxy_model_list
+
+    verbose_proxy_logger.debug("ALL TEAM MODELS - {}".format(len(all_models)))
+    return all_models
+
+
+def get_complete_model_list(
+    key_models: List[str],
+    team_models: List[str],
+    proxy_model_list: List[str],
+    user_model: Optional[str],
+    infer_model_from_keys: Optional[bool],
+) -> List[str]:
+    """Logic for returning complete model list for a given key + team pair"""
+
+    """
+    - If key list is empty -> defer to team list
+    - If team list is empty -> defer to proxy model list
+    """
+
+    if len(key_models) > 0:
+        return key_models
+
+    if len(team_models) > 0:
+        return team_models
+
+    returned_models = proxy_model_list
+    if user_model is not None:  # set via `litellm --model ollama/llam3`
+        returned_models.append(user_model)
+
+    if infer_model_from_keys is not None and infer_model_from_keys == True:
+        valid_models = get_valid_models()
+        returned_models.extend(valid_models)
+    return returned_models
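The resolution order encoded by `get_complete_model_list` is: a non-empty key-level list wins, then a non-empty team-level list, then the full proxy list. A self-contained sketch of that precedence; `resolve_models` is a hypothetical stand-in that takes plain lists instead of a `UserAPIKeyAuth` object, and copies `proxy_model_list` rather than appending in place:

from typing import List, Optional


def resolve_models(
    key_models: List[str],
    team_models: List[str],
    proxy_model_list: List[str],
    user_model: Optional[str] = None,
) -> List[str]:
    # same precedence as get_complete_model_list above:
    # non-empty key list > non-empty team list > full proxy list
    if key_models:
        return key_models
    if team_models:
        return team_models
    returned = list(proxy_model_list)  # copy to avoid mutating the input
    if user_model is not None:
        returned.append(user_model)
    return returned


# key restricted to one model -> team/proxy lists are ignored
assert resolve_models(["gpt-3.5-turbo"], ["gpt-4"], ["gpt-4", "claude-3"]) == [
    "gpt-3.5-turbo"
]
# no key/team restriction -> full proxy list plus any CLI `--model`
assert resolve_models([], [], ["gpt-4", "claude-3"], user_model="ollama/llama3") == [
    "gpt-4",
    "claude-3",
    "ollama/llama3",
]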
litellm/proxy/proxy_server.py

@@ -112,6 +112,11 @@ from litellm.router import ModelInfo as RouterModelInfo
 from litellm._logging import verbose_router_logger, verbose_proxy_logger
 from litellm.proxy.auth.handle_jwt import JWTHandler
 from litellm.proxy.auth.litellm_license import LicenseCheck
+from litellm.proxy.auth.model_checks import (
+    get_complete_model_list,
+    get_key_models,
+    get_team_models,
+)
 from litellm.proxy.hooks.prompt_injection_detection import (
     _OPTIONAL_PromptInjectionDetection,
 )
@@ -266,10 +271,6 @@ class UserAPIKeyCacheTTLEnum(enum.Enum):
     in_memory_cache_ttl = 60  # 1 min ttl ## configure via `general_settings::user_api_key_cache_ttl: <your-value>`
 
 
-class SpecialModelNames(enum.Enum):
-    all_team_models = "all-team-models"
-
-
 class CommonProxyErrors(enum.Enum):
     db_not_connected_error = "DB not connected"
     no_llm_router = "No models configured on proxy"
@@ -3778,21 +3779,24 @@ def model_list(
 ):
     global llm_model_list, general_settings
-    all_models = []
-    if len(user_api_key_dict.models) > 0:
-        all_models = user_api_key_dict.models
-        if SpecialModelNames.all_team_models.value in all_models:
-            all_models = user_api_key_dict.team_models
-    if len(all_models) == 0:  # has all proxy models
-        ## if no specific model access
-        if general_settings.get("infer_model_from_keys", False):
-            all_models = litellm.utils.get_valid_models()
-        if llm_model_list:
-            all_models = list(
-                set(all_models + [m["model_name"] for m in llm_model_list])
-            )
-        if user_model is not None:
-            all_models += [user_model]
-    verbose_proxy_logger.debug("all_models: %s", all_models)
+    ## CHECK IF MODEL RESTRICTIONS ARE SET AT KEY/TEAM LEVEL ##
+    if llm_model_list is None:
+        proxy_model_list = []
+    else:
+        proxy_model_list = [m["model_name"] for m in llm_model_list]
+    key_models = get_key_models(
+        user_api_key_dict=user_api_key_dict, proxy_model_list=proxy_model_list
+    )
+    team_models = get_team_models(
+        user_api_key_dict=user_api_key_dict, proxy_model_list=proxy_model_list
+    )
+    all_models = get_complete_model_list(
+        key_models=key_models,
+        team_models=team_models,
+        proxy_model_list=proxy_model_list,
+        user_model=user_model,
+        infer_model_from_keys=general_settings.get("infer_model_from_keys", False),
+    )
+
     return dict(
         data=[
             {
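After this change `/v1/models` reflects key- and team-level restrictions instead of always returning every proxy model. A usage sketch with the OpenAI client, assuming a proxy running on `http://0.0.0.0:4000` and a placeholder virtual key:

from openai import OpenAI

# hypothetical values: a local LiteLLM proxy and a virtual key
client = OpenAI(api_key="sk-my-virtual-key", base_url="http://0.0.0.0:4000")

# /v1/models now returns only the models this key/team may access
for model in client.models.list().data:
    print(model.id)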
litellm/utils.py

@@ -11441,11 +11441,8 @@ class CustomStreamWrapper:
             self.response_id = original_chunk.id
             if len(original_chunk.choices) > 0:
                 delta = original_chunk.choices[0].delta
-                if (
-                    delta is not None and (
-                        delta.function_call is not None
-                        or delta.tool_calls is not None
-                    )
+                if delta is not None and (
+                    delta.function_call is not None or delta.tool_calls is not None
                 ):
                     try:
                         model_response.system_fingerprint = (
@@ -11506,7 +11503,11 @@
                     model_response.choices[0].delta = Delta()
                 else:
                     try:
-                        delta = dict() if original_chunk.choices[0].delta is None else dict(original_chunk.choices[0].delta)
+                        delta = (
+                            dict()
+                            if original_chunk.choices[0].delta is None
+                            else dict(original_chunk.choices[0].delta)
+                        )
                         print_verbose(f"original delta: {delta}")
                         model_response.choices[0].delta = Delta(**delta)
                         print_verbose(
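Both hunks above are formatting-only, but the expression they reshape is a guard for streamed chunks whose `delta` is `None`. A standalone sketch of that guard; `delta_to_dict` is a hypothetical name, and the real code inlines the expression on a pydantic delta object (which `dict()` also accepts):

from typing import Any, Dict, Mapping, Optional


def delta_to_dict(delta: Optional[Mapping[str, Any]]) -> Dict[str, Any]:
    # a streamed chunk may carry no delta at all; fall back to an empty
    # dict rather than calling dict(None), which would raise TypeError
    return dict() if delta is None else dict(delta)


print(delta_to_dict(None))  # {}
print(delta_to_dict({"content": "hello"}))  # {'content': 'hello'}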
@@ -12256,7 +12257,7 @@ def trim_messages(
     return messages
 
 
-def get_valid_models():
+def get_valid_models() -> List[str]:
     """
     Returns a list of valid LLMs based on the set environment variables
 
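The new annotation matches what the docstring already promises. A short usage sketch, assuming a placeholder provider key in the environment; per the docstring, the result depends on which provider keys are set:

import os
from typing import List

from litellm.utils import get_valid_models

os.environ["OPENAI_API_KEY"] = "sk-placeholder"  # placeholder, not a real key

models: List[str] = get_valid_models()  # return type now declared as List[str]
print(len(models))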
tests/test_keys.py

@@ -2,7 +2,7 @@
 ## Tests /key endpoints.
 
 import pytest
-import asyncio, time
+import asyncio, time, uuid
 import aiohttp
 from openai import AsyncOpenAI
 import sys, os
@@ -14,12 +14,14 @@ sys.path.insert(
 import litellm
 
 
-async def generate_team(session):
+async def generate_team(
+    session, models: Optional[list] = None, team_id: Optional[str] = None
+):
     url = "http://0.0.0.0:4000/team/new"
     headers = {"Authorization": "Bearer sk-1234", "Content-Type": "application/json"}
-    data = {
-        "team_id": "litellm-dashboard",
-    }
+    if team_id is None:
+        team_id = "litellm-dashboard"
+    data = {"team_id": team_id, "models": models}
 
     async with session.post(url, headers=headers, json=data) as response:
         status = response.status
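A sketch of calling the updated helper from inside this test module, assuming the proxy from these tests is running locally with master key `sk-1234`; the unique `team_id` avoids collisions across runs:

import asyncio
import uuid

import aiohttp


async def _demo():
    async with aiohttp.ClientSession() as session:
        # restricting `models` exercises the team-level access path
        team = await generate_team(
            session=session,
            models=["gpt-3.5-turbo"],
            team_id="team-{}".format(uuid.uuid4()),
        )
        print(team)


asyncio.run(_demo())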
@@ -746,19 +748,25 @@ async def test_key_delete_ui():
 
 
 @pytest.mark.parametrize("model_access", ["all-team-models", "gpt-3.5-turbo"])
+@pytest.mark.parametrize("model_access_level", ["key", "team"])
 @pytest.mark.asyncio
-async def test_key_model_list(model_access):
+async def test_key_model_list(model_access, model_access_level):
     """
     Test if `/v1/models` works as expected.
     """
     async with aiohttp.ClientSession() as session:
-        new_team = await generate_team(session=session)
-        team_id = "litellm-dashboard"
+        _models = [] if model_access == "all-team-models" else [model_access]
+        team_id = "litellm_dashboard_{}".format(uuid.uuid4())
+        new_team = await generate_team(
+            session=session,
+            models=_models if model_access_level == "team" else None,
+            team_id=team_id,
+        )
         key_gen = await generate_key(
             session=session,
             i=0,
             team_id=team_id,
-            models=[] if model_access == "all-team-models" else [model_access],
+            models=_models if model_access_level == "key" else [],
         )
         key = key_gen["key"]
         print(f"key: {key}")
@@ -770,5 +778,9 @@ async def test_key_model_list(model_access):
         assert not isinstance(model_list["data"][0]["id"], list)
         assert isinstance(model_list["data"][0]["id"], str)
         if model_access == "gpt-3.5-turbo":
-            assert len(model_list["data"]) == 1
+            assert (
+                len(model_list["data"]) == 1
+            ), "model_access={}, model_access_level={}".format(
+                model_access, model_access_level
+            )
             assert model_list["data"][0]["id"] == model_access
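The two `parametrize` decorators expand into four test cases; only the restricted `gpt-3.5-turbo` cases pin `/v1/models` to a single entry. A sketch of the resulting grid:

import itertools

for model_access, model_access_level in itertools.product(
    ["all-team-models", "gpt-3.5-turbo"], ["key", "team"]
):
    expectation = (
        "exactly one model listed"
        if model_access == "gpt-3.5-turbo"
        else "all team models listed"
    )
    print(f"model_access={model_access!r}, level={model_access_level!r} -> {expectation}")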