Compare commits


10 commits

Author SHA1 Message Date
Ishaan Jaff 6f34169f17 add e2e tests for keys with regex patterns for /models and /model/info 2024-11-27 18:38:36 -08:00
Ishaan Jaff a4491658e3 add "openai/*" in allowed models check 2024-11-25 15:50:58 -08:00
Ishaan Jaff 632283a371 add coverage 2024-11-25 15:48:40 -08:00
Ishaan Jaff a18aeaa2fb add test_regex_pattern_matching_e2e_test 2024-11-25 15:39:42 -08:00
Ishaan Jaff f76d0e41ef add unit testing for new model regex pattern matching 2024-11-25 15:35:53 -08:00
Ishaan Jaff 79d51e0bdd add doc string 2024-11-25 15:31:56 -08:00
Ishaan Jaff 76dec9a0e8 fix error msg 2024-11-25 14:32:26 -08:00
Ishaan Jaff ba4e1a7a0d fix model auth checks 2024-11-25 14:02:31 -08:00
Ishaan Jaff e9cdbff75a add check for _model_is_within_list_of_allowed_models 2024-11-25 13:58:08 -08:00
Ishaan Jaff 0e36333051 add check for model_matches_patterns 2024-11-25 13:50:33 -08:00
6 changed files with 225 additions and 39 deletions

View file

@@ -9,6 +9,7 @@ Run checks for:
3. If end_user ('user' passed to /chat/completions, /embeddings endpoint) is in budget
"""
+import re
import time
import traceback
from datetime import datetime
@@ -34,6 +35,7 @@ from litellm.proxy._types import (
)
from litellm.proxy.auth.route_checks import RouteChecks
from litellm.proxy.utils import PrismaClient, ProxyLogging, log_db_metrics
+from litellm.router_utils.pattern_match_deployments import PatternMatchRouter
from litellm.types.services import ServiceLoggerPayload, ServiceTypes
from .auth_checks_organization import organization_role_based_access_check
@@ -48,8 +50,8 @@ else:
last_db_access_time = LimitedSizeOrderedDict(max_size=100)
db_cache_expiry = 5  # refresh every 5s
all_routes = LiteLLMRoutes.openai_routes.value + LiteLLMRoutes.management_routes.value

+pattern_router = PatternMatchRouter()


def common_checks(  # noqa: PLR0915
@@ -90,17 +92,15 @@
    ):
-        # this means the team has access to all models on the proxy
        if (
-            "all-proxy-models" in team_object.models
-            or "*" in team_object.models
-            or "openai/*" in team_object.models
+            _model_is_within_list_of_allowed_models(
+                model=_model, allowed_models=team_object.models
+            )
+            is True
        ):
            # this means the team has access to all models on the proxy
            pass
        # check if the team model is an access_group
        elif model_in_access_group(_model, team_object.models) is True:
            pass
-        elif _model and "*" in _model:
-            pass
        else:
            raise Exception(
                f"Team={team_object.team_id} not allowed to call model={_model}. Allowed team models = {team_object.models}"
@@ -828,7 +828,7 @@ async def can_key_call_model(
    model: str, llm_model_list: Optional[list], valid_token: UserAPIKeyAuth
) -> Literal[True]:
    """
-    Checks if token can call a given model
+    Checks if token can call a given model, supporting regex/wildcard patterns

    Returns:
        - True: if token allowed to call model
@@ -863,25 +863,81 @@ async def can_key_call_model(
    # Filter out models that are access_groups
    filtered_models = [m for m in valid_token.models if m not in access_groups]
    filtered_models += models_in_current_access_groups

    verbose_proxy_logger.debug(f"model: {model}; allowed_models: {filtered_models}")

-    all_model_access: bool = False
    if (
-        len(filtered_models) == 0
-        or "*" in filtered_models
-        or "openai/*" in filtered_models
+        _model_is_within_list_of_allowed_models(
+            model=model, allowed_models=filtered_models
+        )
+        is False
    ):
-        all_model_access = True
-
-    if model is not None and model not in filtered_models and all_model_access is False:
        raise ValueError(
            f"API Key not allowed to access model. This token can only access models={valid_token.models}. Tried to access {model}"
        )

    valid_token.models = filtered_models
    verbose_proxy_logger.debug(
        f"filtered allowed_models: {filtered_models}; valid_token.models: {valid_token.models}"
    )
    return True
+
+
+def _model_is_within_list_of_allowed_models(
+    model: str, allowed_models: List[str]
+) -> bool:
+    """
+    Checks if a model is within a list of allowed models (includes pattern matching checks).
+
+    Args:
+        model (str): The model to check (e.g., "custom_engine/model-123")
+        allowed_models (List[str]): List of allowed model patterns (e.g., ["custom_engine/*", "azure/gpt-4", "claude-sonnet"])
+
+    Returns:
+        bool: True if any of the following hold:
+        - len(allowed_models) == 0
+        - "*" is in allowed_models (all models are allowed)
+        - "all-proxy-models" is in allowed_models (all models are allowed)
+        - model is in the allowed_models list
+        - model matches any allowed pattern
+    """
+    # Check for universal access patterns
+    if len(allowed_models) == 0:
+        return True
+    if "*" in allowed_models:
+        return True
+    if "all-proxy-models" in allowed_models:
+        return True
+    # Note: this is here to maintain backwards compatibility. It used to be in
+    # our code, and removing it could impact existing customer code.
+    if "openai/*" in allowed_models:
+        return True
+    if model in allowed_models:
+        return True
+    if _model_matches_patterns(model=model, allowed_models=allowed_models) is True:
+        return True
+    return False
+
+
+def _model_matches_patterns(model: str, allowed_models: List[str]) -> bool:
+    """
+    Helper function to check if a model matches any of the allowed model patterns.
+
+    Args:
+        model (str): The model to check (e.g., "custom_engine/model-123")
+        allowed_models (List[str]): List of allowed model patterns (e.g., ["custom_engine/*", "azure/gpt-4*"])
+
+    Returns:
+        bool: True if model matches any allowed pattern, False otherwise
+    """
+    try:
+        # Only entries containing a wildcard are treated as patterns; convert each
+        # to a regex via the module-level PatternMatchRouter and try to match.
+        for _model in allowed_models:
+            if "*" in _model:
+                regex_pattern = pattern_router._pattern_to_regex(_model)
+                if re.match(regex_pattern, model):
+                    return True
+        return False
+    except Exception as e:
+        verbose_proxy_logger.exception(f"Error in _model_matches_patterns: {str(e)}")
+        return False
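
The heavy lifting in `_model_matches_patterns` is the wildcard-to-regex conversion done by `PatternMatchRouter._pattern_to_regex`, which this diff does not show. As a rough mental model, here is a minimal self-contained sketch of the same matching behavior; the exact regex `PatternMatchRouter` emits is an assumption:

```python
import re
from typing import List


def _pattern_to_regex_sketch(pattern: str) -> str:
    # Assumption: literal segments are regex-escaped and each "*" becomes ".*",
    # mirroring what PatternMatchRouter._pattern_to_regex is expected to produce.
    return ".*".join(re.escape(part) for part in pattern.split("*"))


def _matches_sketch(model: str, allowed_models: List[str]) -> bool:
    # Same shape as _model_matches_patterns above: only entries containing a
    # wildcard are treated as patterns; exact names are the caller's job.
    for allowed in allowed_models:
        if "*" in allowed and re.match(_pattern_to_regex_sketch(allowed), model):
            return True
    return False


# Consistent with the parametrized unit tests further down:
assert _matches_sketch("custom_engine/model-123", ["custom_engine/*"]) is True
assert _matches_sketch("custom-engine/model-123", ["custom_engine/*"]) is False
assert _matches_sketch("gpt-4", ["gpt-4"]) is False  # no wildcard, no pattern match
```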

View file

@@ -1,24 +1,5 @@
model_list:
-  - model_name: gpt-4o
+  - model_name: custom_engine/*
    litellm_params:
-      model: openai/gpt-4o
-      api_base: https://exampleopenaiendpoint-production.up.railway.app/
-  - model_name: fake-anthropic-endpoint
-    litellm_params:
-      model: anthropic/fake
-      api_base: https://exampleanthropicendpoint-production.up.railway.app/
-
-router_settings:
-  provider_budget_config:
-    openai:
-      budget_limit: 0.3 # float of $ value budget for time period
-      time_period: 1d # can be 1d, 2d, 30d
-    anthropic:
-      budget_limit: 5
-      time_period: 1d
-  redis_host: os.environ/REDIS_HOST
-  redis_port: os.environ/REDIS_PORT
-  redis_password: os.environ/REDIS_PASSWORD
-
-litellm_settings:
-  callbacks: ["prometheus"]
+      model: openai/custom_engine
+      api_base: https://exampleopenaiendpoint-production.up.railway.app/

View file

@@ -85,6 +85,10 @@ model_list:
    litellm_params:
      model: openai/*
      api_key: os.environ/OPENAI_API_KEY
+  - model_name: custom_engine/* # regex pattern matching test for /models and /model/info endpoints
+    litellm_params:
+      model: openai/custom_engine
+      api_base: https://exampleopenaiendpoint-production.up.railway.app/

  # provider specific wildcard routing
@@ -96,6 +100,11 @@ model_list:
    litellm_params:
      model: "groq/*"
      api_key: os.environ/GROQ_API_KEY
+  - model_name: "custom-llm-engine/*"
+    litellm_params:
+      model: "openai/my-fake-model"
+      api_base: https://exampleopenaiendpoint-production.up.railway.app/
+      api_key: fake-key
  - model_name: mistral-embed
    litellm_params:
      model: mistral/mistral-embed
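
With this config entry and the updated auth checks, a virtual key scoped to `custom-llm-engine/*` can call any model name matching that prefix, even one that is never listed explicitly. A rough usage sketch against a locally running proxy; the base URL and master key below are placeholder assumptions, not values from this PR:

```python
import requests

PROXY_BASE = "http://localhost:4000"  # assumption: default local proxy address
MASTER_KEY = "sk-1234"  # assumption: placeholder admin key

# Generate a key that may only call models matching the wildcard pattern.
key_resp = requests.post(
    f"{PROXY_BASE}/key/generate",
    headers={"Authorization": f"Bearer {MASTER_KEY}"},
    json={"models": ["custom-llm-engine/*"]},
)
key = key_resp.json()["key"]

# "custom-llm-engine/very-new-model" is not listed anywhere explicitly,
# but it matches the pattern, so the new auth check lets it through.
resp = requests.post(
    f"{PROXY_BASE}/chat/completions",
    headers={"Authorization": f"Bearer {key}"},
    json={
        "model": "custom-llm-engine/very-new-model",
        "messages": [{"role": "user", "content": "hello"}],
    },
)
print(resp.status_code, resp.json())
```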

View file

@@ -15,6 +15,10 @@ from starlette.datastructures import URL

import litellm
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
+from litellm.proxy.auth.auth_checks import (
+    _model_is_within_list_of_allowed_models,
+    _model_matches_patterns,
+)
class Request:
@@ -387,3 +391,62 @@ def test_is_api_route_allowed(route, user_role, expected_result):
            pass
        else:
            raise e
+
+
+@pytest.mark.parametrize(
+    "model, allowed_models, expected_result",
+    [
+        # Empty allowed models list
+        ("gpt-4", [], True),
+        # Universal access patterns
+        ("gpt-4", ["*"], True),
+        ("gpt-4", ["all-proxy-models"], True),
+        # Exact matches
+        ("gpt-4", ["gpt-4"], True),
+        ("gpt-4", ["gpt-3.5", "claude-2"], False),
+        # Pattern matching
+        ("azure/gpt-4", ["azure/*"], True),
+        ("custom_engine/model-123", ["custom_engine/*"], True),
+        ("custom_engine/model-123", ["custom_engine/*", "azure/*"], True),
+        ("custom-engine/model-123", ["custom_engine/*", "azure/*"], False),
+        ("openai/gpt-4o", ["openai/*"], True),
+        ("openai/gpt-12", ["openai/*"], True),
+        ("gpt-4", ["gpt-*"], True),
+        ("gpt-4", ["claude-*"], False),
+        # Mixed scenarios
+        ("gpt-4", ["claude-2", "gpt-*", "palm-2"], True),
+        ("anthropic/claude-instant-1", ["anthropic/*", "gpt-4"], True),
+    ],
+)
+def test_model_is_within_list_of_allowed_models(model, allowed_models, expected_result):
+    result = _model_is_within_list_of_allowed_models(
+        model=model, allowed_models=allowed_models
+    )
+    assert result == expected_result
+
+
+@pytest.mark.parametrize(
+    "model, allowed_models, expected_result",
+    [
+        # Basic pattern matching
+        ("gpt-4", ["gpt-*"], True),
+        ("azure/gpt-4", ["azure/*"], True),
+        ("custom_engine/model-123", ["custom_engine/*"], True),
+        ("openai/my-fake-model", ["openai/*"], True),
+        ("custom_engine/model-123", ["custom_engine/*", "azure/*"], True),
+        ("custom-engine/model-123", ["custom_engine/*", "azure/*"], False),
+        # Multiple patterns
+        ("gpt-4", ["claude-*", "gpt-*"], True),
+        ("anthropic/claude-instant-1", ["anthropic/*", "gpt-*"], True),
+        # No matches
+        ("gpt-4", ["claude-*"], False),
+        ("azure/gpt-4", ["anthropic/*"], False),
+        # Edge cases
+        ("gpt-4", ["gpt-4"], False),  # No wildcard, should return False
+        ("model", ["*-suffix"], False),
+        ("prefix-model", ["prefix-*"], True),
+    ],
+)
+def test_model_matches_patterns(model, allowed_models, expected_result):
+    result = _model_matches_patterns(model=model, allowed_models=allowed_models)
+    assert result == expected_result
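
Since both helpers are plain synchronous functions, the quickest local sanity check is to call them directly; note the deliberate split the tests above rely on: a bare exact name counts as allowed via the outer check, but is not a pattern match. A small illustrative snippet, not part of the PR:

```python
from litellm.proxy.auth.auth_checks import (
    _model_is_within_list_of_allowed_models,
    _model_matches_patterns,
)

# An exact entry with no "*" is only honored by the outer check;
# _model_matches_patterns itself considers wildcard entries alone.
print(_model_matches_patterns(model="gpt-4", allowed_models=["gpt-4"]))  # False
print(
    _model_is_within_list_of_allowed_models(model="gpt-4", allowed_models=["gpt-4"])
)  # True
```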

View file

@@ -48,6 +48,8 @@ async def get_models(session, key):
        if status != 200:
            raise Exception(f"Request did not return a 200 status code: {status}")
        return await response.json()
@pytest.mark.asyncio
async def test_get_models():
@@ -362,3 +364,60 @@ async def test_add_model_run_health():
        # cleanup
        await delete_model(session=session, model_id=model_id)
+
+
+@pytest.mark.asyncio
+async def test_wildcard_model_access():
+    """
+    Test key generation with a wildcard model access pattern (custom_engine/*)
+    - Generate key with access to 'custom_engine/*'
+    - Call /models and /model/info to verify access
+    """
+    async with aiohttp.ClientSession() as session:
+        # Generate key with wildcard access
+        key_gen = await generate_key(session=session, models=["custom_engine/*"])
+        key = key_gen["key"]
+
+        # Get models list
+        print("\nTesting /models endpoint with wildcard key")
+        models_response = await get_models(session=session, key=key)
+
+        # verify /models response
+        _data = models_response["data"]
+        found_custom_engine_model = False
+        for model in _data:
+            if model["id"] == "custom_engine/*":
+                found_custom_engine_model = True
+                assert model["object"] == "model", "Incorrect object type"
+                assert model["owned_by"] == "openai", "Incorrect owner"
+                break
+        assert (
+            found_custom_engine_model is True
+        ), "custom_engine/* model not found in response"
+
+        # Get detailed model info
+        print("\nTesting /model/info endpoint with wildcard key")
+        model_info_response = await get_model_info(session=session, key=key)
+        print("Model info response:", model_info_response)
+
+        # Add assertions to verify response content
+        assert "data" in model_info_response
+        model_data = model_info_response["data"]
+        assert len(model_data) > 0
+
+        # Find and verify the custom_engine/* model
+        custom_engine_model = None
+        for model in model_data:
+            if model["model_name"] == "custom_engine/*":
+                custom_engine_model = model
+                break
+
+        assert (
+            custom_engine_model is not None
+        ), "custom_engine/* model not found in response"
+        assert (
+            custom_engine_model["litellm_params"]["api_base"]
+            == "https://exampleopenaiendpoint-production.up.railway.app/"
+        )
+        assert custom_engine_model["litellm_params"]["model"] == "openai/custom_engine"
+        assert custom_engine_model["model_info"]["litellm_provider"] == "openai"

View file

@@ -473,6 +473,24 @@ async def test_openai_wildcard_chat_completion():
        await chat_completion(session=session, key=key, model="gpt-3.5-turbo-0125")
+
+
+@pytest.mark.asyncio
+async def test_regex_pattern_matching_e2e_test():
+    """
+    - Create a key for model = "custom-llm-engine/*" -> this key has access to all models matching the pattern "custom-llm-engine/*"
+    - proxy_server_config.yaml has model = "custom-llm-engine/*"
+    - Make a chat completion call
+    """
+    async with aiohttp.ClientSession() as session:
+        key_gen = await generate_key(session=session, models=["custom-llm-engine/*"])
+        key = key_gen["key"]
+
+        # call chat/completions with a model that is not explicitly on the config.yaml,
+        # but matches the wildcard pattern the key was created for
+        await chat_completion(
+            session=session, key=key, model="custom-llm-engine/very-new-model"
+        )
@pytest.mark.asyncio
async def test_proxy_all_models():
"""