forked from phoenix/litellm-mirror
fix(proxy_server.py): fix model check for /v1/models endpoint when team has restricted access

parent 3c961136ea
commit 25a2f00db6

5 changed files with 134 additions and 36 deletions

@@ -1112,3 +1112,8 @@ class WebhookEvent(CallInfo):
     ]
     event_group: Literal["user", "key", "team", "proxy"]
     event_message: str  # human-readable description of event
+
+
+class SpecialModelNames(enum.Enum):
+    all_team_models = "all-team-models"
+    all_proxy_models = "all-proxy-models"

litellm/proxy/auth/model_checks.py  (new file, 76 lines)

@@ -0,0 +1,76 @@
+# What is this?
+## Common checks for /v1/models and `/model/info`
+from typing import List, Optional
+from litellm.proxy._types import UserAPIKeyAuth, SpecialModelNames
+from litellm.utils import get_valid_models
+from litellm._logging import verbose_proxy_logger
+
+
+def get_key_models(
+    user_api_key_dict: UserAPIKeyAuth, proxy_model_list: List[str]
+) -> List[str]:
+    """
+    Returns:
+    - List of model name strings
+    - Empty list if no models set
+    """
+    all_models = []
+    if len(user_api_key_dict.models) > 0:
+        all_models = user_api_key_dict.models
+        if SpecialModelNames.all_team_models.value in all_models:
+            all_models = user_api_key_dict.team_models
+        if SpecialModelNames.all_proxy_models.value in all_models:
+            all_models = proxy_model_list
+
+    verbose_proxy_logger.debug("ALL KEY MODELS - {}".format(len(all_models)))
+    return all_models
+
+
+def get_team_models(
+    user_api_key_dict: UserAPIKeyAuth, proxy_model_list: List[str]
+) -> List[str]:
+    """
+    Returns:
+    - List of model name strings
+    - Empty list if no models set
+    """
+    all_models = []
+    if len(user_api_key_dict.team_models) > 0:
+        all_models = user_api_key_dict.team_models
+        if SpecialModelNames.all_team_models.value in all_models:
+            all_models = user_api_key_dict.team_models
+        if SpecialModelNames.all_proxy_models.value in all_models:
+            all_models = proxy_model_list
+
+    verbose_proxy_logger.debug("ALL TEAM MODELS - {}".format(len(all_models)))
+    return all_models
+
+
+def get_complete_model_list(
+    key_models: List[str],
+    team_models: List[str],
+    proxy_model_list: List[str],
+    user_model: Optional[str],
+    infer_model_from_keys: Optional[bool],
+) -> List[str]:
+    """Logic for returning complete model list for a given key + team pair"""
+
+    """
+    - If key list is empty -> defer to team list
+    - If team list is empty -> defer to proxy model list
+    """
+
+    if len(key_models) > 0:
+        return key_models
+
+    if len(team_models) > 0:
+        return team_models
+
+    returned_models = proxy_model_list
+    if user_model is not None:  # set via `litellm --model ollama/llam3`
+        returned_models.append(user_model)
+
+    if infer_model_from_keys is not None and infer_model_from_keys == True:
+        valid_models = get_valid_models()
+        returned_models.extend(valid_models)
+    return returned_models
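
A quick sketch of how the two look-up helpers above resolve the special names (illustration only, not part of the commit; it assumes UserAPIKeyAuth accepts these keyword fields with defaults for the rest, and the model names are made up):

    from litellm.proxy._types import UserAPIKeyAuth
    from litellm.proxy.auth.model_checks import get_key_models, get_team_models

    # Hypothetical key: the key itself is scoped to "all-team-models",
    # while its team is restricted to two concrete models.
    auth = UserAPIKeyAuth(
        models=["all-team-models"],
        team_models=["gpt-3.5-turbo", "azure-gpt-4"],
    )
    proxy_models = ["gpt-3.5-turbo", "azure-gpt-4", "claude-3-opus"]

    get_key_models(user_api_key_dict=auth, proxy_model_list=proxy_models)
    # -> ["gpt-3.5-turbo", "azure-gpt-4"]  (the key defers to its team's list)

    get_team_models(user_api_key_dict=auth, proxy_model_list=proxy_models)
    # -> ["gpt-3.5-turbo", "azure-gpt-4"]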

@@ -112,6 +112,11 @@ from litellm.router import ModelInfo as RouterModelInfo
 from litellm._logging import verbose_router_logger, verbose_proxy_logger
 from litellm.proxy.auth.handle_jwt import JWTHandler
 from litellm.proxy.auth.litellm_license import LicenseCheck
+from litellm.proxy.auth.model_checks import (
+    get_complete_model_list,
+    get_key_models,
+    get_team_models,
+)
 from litellm.proxy.hooks.prompt_injection_detection import (
     _OPTIONAL_PromptInjectionDetection,
 )

@@ -266,10 +271,6 @@ class UserAPIKeyCacheTTLEnum(enum.Enum):
     in_memory_cache_ttl = 60  # 1 min ttl ## configure via `general_settings::user_api_key_cache_ttl: <your-value>`


-class SpecialModelNames(enum.Enum):
-    all_team_models = "all-team-models"
-
-
 class CommonProxyErrors(enum.Enum):
     db_not_connected_error = "DB not connected"
     no_llm_router = "No models configured on proxy"

@@ -3778,21 +3779,24 @@ def model_list(
 ):
     global llm_model_list, general_settings
     all_models = []
-    if len(user_api_key_dict.models) > 0:
-        all_models = user_api_key_dict.models
-        if SpecialModelNames.all_team_models.value in all_models:
-            all_models = user_api_key_dict.team_models
-    if len(all_models) == 0:  # has all proxy models
-        ## if no specific model access
-        if general_settings.get("infer_model_from_keys", False):
-            all_models = litellm.utils.get_valid_models()
-        if llm_model_list:
-            all_models = list(
-                set(all_models + [m["model_name"] for m in llm_model_list])
-            )
-        if user_model is not None:
-            all_models += [user_model]
-    verbose_proxy_logger.debug("all_models: %s", all_models)
+    ## CHECK IF MODEL RESTRICTIONS ARE SET AT KEY/TEAM LEVEL ##
+    if llm_model_list is None:
+        proxy_model_list = []
+    else:
+        proxy_model_list = [m["model_name"] for m in llm_model_list]
+    key_models = get_key_models(
+        user_api_key_dict=user_api_key_dict, proxy_model_list=proxy_model_list
+    )
+    team_models = get_team_models(
+        user_api_key_dict=user_api_key_dict, proxy_model_list=proxy_model_list
+    )
+    all_models = get_complete_model_list(
+        key_models=key_models,
+        team_models=team_models,
+        proxy_model_list=proxy_model_list,
+        user_model=user_model,
+        infer_model_from_keys=general_settings.get("infer_model_from_keys", False),
+    )
     return dict(
         data=[
             {
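
For reference, a minimal sketch of the precedence get_complete_model_list() gives the rewritten endpoint (illustrative values only, not part of the diff):

    proxy_models = ["gpt-3.5-turbo", "azure-gpt-4", "claude-3-opus"]

    # A key-level restriction wins over the team and proxy lists.
    get_complete_model_list(
        key_models=["gpt-3.5-turbo"],
        team_models=["gpt-3.5-turbo", "azure-gpt-4"],
        proxy_model_list=proxy_models,
        user_model=None,
        infer_model_from_keys=False,
    )
    # -> ["gpt-3.5-turbo"]

    # No key or team restriction: every proxy model is returned,
    # plus the `--model` value and env-key models when those options are set.
    get_complete_model_list(
        key_models=[],
        team_models=[],
        proxy_model_list=proxy_models,
        user_model=None,
        infer_model_from_keys=False,
    )
    # -> ["gpt-3.5-turbo", "azure-gpt-4", "claude-3-opus"]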

@@ -11441,11 +11441,8 @@ class CustomStreamWrapper:
                         self.response_id = original_chunk.id
                     if len(original_chunk.choices) > 0:
                         delta = original_chunk.choices[0].delta
-                        if (
-                            delta is not None and (
-                                delta.function_call is not None
-                                or delta.tool_calls is not None
-                            )
+                        if delta is not None and (
+                            delta.function_call is not None or delta.tool_calls is not None
                         ):
                             try:
                                 model_response.system_fingerprint = (

@@ -11506,7 +11503,11 @@ class CustomStreamWrapper:
                             model_response.choices[0].delta = Delta()
                         else:
                             try:
-                                delta = dict() if original_chunk.choices[0].delta is None else dict(original_chunk.choices[0].delta)
+                                delta = (
+                                    dict()
+                                    if original_chunk.choices[0].delta is None
+                                    else dict(original_chunk.choices[0].delta)
+                                )
                                 print_verbose(f"original delta: {delta}")
                                 model_response.choices[0].delta = Delta(**delta)
                                 print_verbose(
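
Side note (not part of the commit): the ternary that is reformatted above exists because a streamed chunk's delta can be None, and dict(None) raises a TypeError. A standalone sketch of the guard:

    # raw_delta stands in for original_chunk.choices[0].delta, which may be None
    # on some streamed chunks.
    raw_delta = None
    delta_kwargs = dict() if raw_delta is None else dict(raw_delta)
    # delta_kwargs == {} here, so Delta(**delta_kwargs) builds an empty delta
    # instead of crashing on dict(None).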

@@ -12256,7 +12257,7 @@ def trim_messages(
         return messages


-def get_valid_models():
+def get_valid_models() -> List[str]:
     """
     Returns a list of valid LLMs based on the set environment variables


@@ -2,7 +2,7 @@
 ## Tests /key endpoints.

 import pytest
-import asyncio, time
+import asyncio, time, uuid
 import aiohttp
 from openai import AsyncOpenAI
 import sys, os

@@ -14,12 +14,14 @@ sys.path.insert(
 import litellm


-async def generate_team(session):
+async def generate_team(
+    session, models: Optional[list] = None, team_id: Optional[str] = None
+):
     url = "http://0.0.0.0:4000/team/new"
     headers = {"Authorization": "Bearer sk-1234", "Content-Type": "application/json"}
-    data = {
-        "team_id": "litellm-dashboard",
-    }
+    if team_id is None:
+        team_id = "litellm-dashboard"
+    data = {"team_id": team_id, "models": models}

     async with session.post(url, headers=headers, json=data) as response:
         status = response.status

@@ -746,19 +748,25 @@ async def test_key_delete_ui():


 @pytest.mark.parametrize("model_access", ["all-team-models", "gpt-3.5-turbo"])
+@pytest.mark.parametrize("model_access_level", ["key", "team"])
 @pytest.mark.asyncio
-async def test_key_model_list(model_access):
+async def test_key_model_list(model_access, model_access_level):
     """
     Test if `/v1/models` works as expected.
     """
     async with aiohttp.ClientSession() as session:
-        new_team = await generate_team(session=session)
-        team_id = "litellm-dashboard"
+        _models = [] if model_access == "all-team-models" else [model_access]
+        team_id = "litellm_dashboard_{}".format(uuid.uuid4())
+        new_team = await generate_team(
+            session=session,
+            models=_models if model_access_level == "team" else None,
+            team_id=team_id,
+        )
         key_gen = await generate_key(
             session=session,
             i=0,
             team_id=team_id,
-            models=[] if model_access == "all-team-models" else [model_access],
+            models=_models if model_access_level == "key" else [],
         )
         key = key_gen["key"]
         print(f"key: {key}")

@@ -770,5 +778,9 @@ async def test_key_model_list(model_access):
         assert not isinstance(model_list["data"][0]["id"], list)
         assert isinstance(model_list["data"][0]["id"], str)
         if model_access == "gpt-3.5-turbo":
-            assert len(model_list["data"]) == 1
+            assert (
+                len(model_list["data"]) == 1
+            ), "model_access={}, model_access_level={}".format(
+                model_access, model_access_level
+            )
             assert model_list["data"][0]["id"] == model_access
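
A compact, hedged sketch of what the assertions above exercise end to end (it assumes a proxy running on 0.0.0.0:4000 and a key created with models=["gpt-3.5-turbo"]; not part of the diff):

    import asyncio
    import aiohttp

    async def list_models(key: str) -> list:
        headers = {"Authorization": f"Bearer {key}"}
        async with aiohttp.ClientSession() as session:
            async with session.get(
                "http://0.0.0.0:4000/v1/models", headers=headers
            ) as resp:
                data = await resp.json()
                # A restricted key should only see its allowed model ids.
                return [m["id"] for m in data["data"]]

    # asyncio.run(list_models("sk-..."))  # expected: ["gpt-3.5-turbo"]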