Mirror of https://github.com/BerriAI/litellm.git, synced 2025-04-25 18:54:30 +00:00
Improved wildcard route handling on /models and /model_group/info (#8473)
* fix(model_checks.py): update returning known models from a wildcard to filter on the given model prefix, so the wildcard route `vertex_ai/gemini-*` returns only known `vertex_ai/gemini-` models
* test(test_proxy_utils.py): add unit testing for the new `get_known_models_from_wildcard` helper
* test(test_models.py): add e2e testing for the `/model_group/info` endpoint
* feat(prometheus.py): support tracking total requests by user_email on Prometheus (initial support for tracking total requests by user_email)
* test(test_prometheus.py): add testing to ensure user email is always tracked
* test: update testing for the new Prometheus metric
* test(test_prometheus_unit_tests.py): add user email to the total proxy requests metric
* test: update tests
* test: fix spend tests
* test: fix test
* fix(pagerduty.py): fix linting error
parent 5e58ae0347
commit 57e5ec07cc

15 changed files with 190 additions and 38 deletions
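
In short: a wildcard route such as `vertex_ai/gemini-*` previously expanded to every known model for the provider; it now expands only to models whose name matches the prefix before the `*`. A minimal standalone sketch of that filtering logic (the model list is illustrative, not litellm's real provider registry):

    from typing import List

    def filter_models_by_wildcard(wildcard_model: str, known_models: List[str]) -> List[str]:
        """Illustrative prefix filter; mirrors the logic added in model_checks.py."""
        try:
            provider, model_pattern = wildcard_model.split("/", 1)
        except ValueError:
            return []
        if model_pattern == "*":
            return known_models  # bare wildcard: every known provider model matches
        prefix = model_pattern.replace("*", "")
        # Keep only models whose name (after the provider segment) starts with the prefix.
        return [m for m in known_models if m.split("/")[1].startswith(prefix)]

    known = [
        "vertex_ai/gemini-1.5-pro",
        "vertex_ai/gemini-1.5-flash",
        "vertex_ai/claude-3-opus",  # hypothetical entry, present only to show filtering
    ]
    print(filter_models_by_wildcard("vertex_ai/gemini-*", known))
    # ['vertex_ai/gemini-1.5-pro', 'vertex_ai/gemini-1.5-flash']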
@@ -118,6 +118,7 @@ class PagerDutyAlerting(SlackAlerting):
                     user_api_key_user_id=_meta.get("user_api_key_user_id"),
                     user_api_key_team_alias=_meta.get("user_api_key_team_alias"),
                     user_api_key_end_user_id=_meta.get("user_api_key_end_user_id"),
+                    user_api_key_user_email=_meta.get("user_api_key_user_email"),
                 )
             )

@@ -195,6 +196,7 @@ class PagerDutyAlerting(SlackAlerting):
                 user_api_key_user_id=user_api_key_dict.user_id,
                 user_api_key_team_alias=user_api_key_dict.team_alias,
                 user_api_key_end_user_id=user_api_key_dict.end_user_id,
+                user_api_key_user_email=user_api_key_dict.user_email,
             )
         )
@@ -423,6 +423,7 @@ class PrometheusLogger(CustomLogger):
             team=user_api_team,
             team_alias=user_api_team_alias,
             user=user_id,
+            user_email=standard_logging_payload["metadata"]["user_api_key_user_email"],
             status_code="200",
             model=model,
             litellm_model_name=model,

@@ -806,6 +807,7 @@ class PrometheusLogger(CustomLogger):
         enum_values = UserAPIKeyLabelValues(
             end_user=user_api_key_dict.end_user_id,
             user=user_api_key_dict.user_id,
+            user_email=user_api_key_dict.user_email,
             hashed_api_key=user_api_key_dict.api_key,
             api_key_alias=user_api_key_dict.key_alias,
             team=user_api_key_dict.team_id,

@@ -853,6 +855,7 @@ class PrometheusLogger(CustomLogger):
             team=user_api_key_dict.team_id,
             team_alias=user_api_key_dict.team_alias,
             user=user_api_key_dict.user_id,
+            user_email=user_api_key_dict.user_email,
             status_code="200",
         )
         _labels = prometheus_label_factory(
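
The practical effect of the PrometheusLogger changes above: request counters that already carried team/user/key labels now also carry `user_email`. A minimal `prometheus_client` sketch of the labeling pattern, with the label set taken from the e2e test further down (a sketch, not litellm's actual metric definition; the help text is assumed):

    from prometheus_client import Counter

    # Sketch of the proxy total-requests counter with the new user_email label.
    proxy_total_requests = Counter(
        "litellm_proxy_total_requests_metric",
        "Total requests made to the proxy server",  # description assumed
        labelnames=[
            "end_user", "hashed_api_key", "api_key_alias", "requested_model",
            "team", "team_alias", "user", "user_email", "status_code",
        ],
    )

    proxy_total_requests.labels(
        end_user="None", hashed_api_key="abc123", api_key_alias="None",
        requested_model="fake-openai-endpoint", team="None", team_alias="None",
        user="default_user_id", user_email="test@example.com", status_code="200",
    ).inc()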
@@ -2894,6 +2894,7 @@ class StandardLoggingPayloadSetup:
             user_api_key_org_id=None,
             user_api_key_user_id=None,
             user_api_key_team_alias=None,
+            user_api_key_user_email=None,
             spend_logs_metadata=None,
             requester_ip_address=None,
             requester_metadata=None,

@@ -3328,6 +3329,7 @@ def get_standard_logging_metadata(
             user_api_key_team_id=None,
             user_api_key_org_id=None,
             user_api_key_user_id=None,
+            user_api_key_user_email=None,
             user_api_key_team_alias=None,
             spend_logs_metadata=None,
             requester_ip_address=None,
@@ -5,6 +5,11 @@ model_list:
   - model_name: gpt-4
     litellm_params:
       model: gpt-3.5-turbo
+  - model_name: fake-openai-endpoint
+    litellm_params:
+      model: openai/fake
+      api_key: fake-key
+      api_base: https://exampleopenaiendpoint-production.up.railway.app/
   - model_name: azure-gpt-35-turbo
     litellm_params:
       model: azure/chatgpt-v-2

@@ -33,28 +38,14 @@ model_list:
       model: openai/fake
       api_key: fake-key
       api_base: https://exampleopenaiendpoint-production.up.railway.app/
+  - model_name: vertex_ai/gemini-*
+    litellm_params:
+      model: vertex_ai/gemini-*
+  - model_name: fake-azure-endpoint
+    litellm_params:
+      model: openai/429
+      api_key: fake-key
+      api_base: https://exampleopenaiendpoint-production.up.railway.app

 litellm_settings:
-  cache: true
-
-general_settings:
-  enable_jwt_auth: True
-  forward_openai_org_id: True
-  litellm_jwtauth:
-    user_id_jwt_field: "sub"
-    team_ids_jwt_field: "groups"
-    user_id_upsert: true # add user_id to the db if they don't exist
-    enforce_team_based_model_access: true # don't allow users to access models unless the team has access
-
-router_settings:
-  redis_host: os.environ/REDIS_HOST
-  redis_password: os.environ/REDIS_PASSWORD
-  redis_port: os.environ/REDIS_PORT
-
-guardrails:
-  - guardrail_name: "aporia-pre-guard"
-    litellm_params:
-      guardrail: aporia # supported values: "aporia", "lakera"
-      mode: "during_call"
-      api_key: os.environ/APORIO_API_KEY
-      api_base: os.environ/APORIO_API_BASE
+  callbacks: ["prometheus"]
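
With the `vertex_ai/gemini-*` entry above, `/models` should list only the known `vertex_ai/gemini-` models for that route rather than every vertex_ai model. A quick way to check against a locally running proxy (the host, port, and `sk-1234` master key mirror the e2e tests in this commit; the response shape assumes the OpenAI-compatible `{"data": [{"id": ...}]}` format):

    import asyncio
    import aiohttp

    async def main():
        headers = {"Authorization": "Bearer sk-1234"}
        async with aiohttp.ClientSession() as session:
            async with session.get("http://0.0.0.0:4000/models", headers=headers) as resp:
                assert resp.status == 200
                models = await resp.json()
        # Wildcard routes should now expand to prefix-matching models only.
        print([m["id"] for m in models["data"]])

    asyncio.run(main())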
@@ -1431,6 +1431,7 @@ class UserAPIKeyAuth(
     tpm_limit_per_model: Optional[Dict[str, int]] = None
     user_tpm_limit: Optional[int] = None
     user_rpm_limit: Optional[int] = None
+    user_email: Optional[str] = None

     model_config = ConfigDict(arbitrary_types_allowed=True)
@@ -17,16 +17,8 @@ def _check_wildcard_routing(model: str) -> bool:
     - openai/*
     - *
     """
-    if model == "*":
+    if "*" in model:
         return True

-    if "/" in model:
-        llm_provider, potential_wildcard = model.split("/", 1)
-        if (
-            llm_provider in litellm.provider_list and potential_wildcard == "*"
-        ):  # e.g. anthropic/*
-            return True
-
     return False

@@ -156,6 +148,28 @@ def get_complete_model_list(
     return list(unique_models) + all_wildcard_models


+def get_known_models_from_wildcard(wildcard_model: str) -> List[str]:
+    try:
+        provider, model = wildcard_model.split("/", 1)
+    except ValueError:  # safely fail
+        return []
+    # get all known provider models
+    wildcard_models = get_provider_models(provider=provider)
+    if wildcard_models is None:
+        return []
+    if model == "*":
+        return wildcard_models or []
+    else:
+        model_prefix = model.replace("*", "")
+        filtered_wildcard_models = [
+            wc_model
+            for wc_model in wildcard_models
+            if wc_model.split("/")[1].startswith(model_prefix)
+        ]
+
+        return filtered_wildcard_models
+
+
 def _get_wildcard_models(
     unique_models: Set[str], return_wildcard_routes: Optional[bool] = False
 ) -> List[str]:

@@ -165,13 +179,13 @@ def _get_wildcard_models(
     if _check_wildcard_routing(model=model):
         if (
-            return_wildcard_routes is True
+            return_wildcard_routes
         ):  # will add the wildcard route to the list eg: anthropic/*.
             all_wildcard_models.append(model)

-        provider = model.split("/")[0]
-        # get all known provider models
-        wildcard_models = get_provider_models(provider=provider)
+        wildcard_models = get_known_models_from_wildcard(wildcard_model=model)

         if wildcard_models is not None:
             models_to_remove.add(model)
             all_wildcard_models.extend(wildcard_models)
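
The new helper can also be exercised directly; the import path below is the one used by the unit test further down, and the returned list depends on litellm's current model registry:

    from litellm.proxy.auth.model_checks import get_known_models_from_wildcard

    # Expand a wildcard route against litellm's known provider models.
    models = get_known_models_from_wildcard(wildcard_model="vertex_ai/gemini-*")
    print(models)  # e.g. ["vertex_ai/gemini-1.5-flash", "vertex_ai/gemini-1.5-pro", ...]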
@@ -1196,6 +1196,7 @@ async def _return_user_api_key_auth_obj(
         user_api_key_kwargs.update(
             user_tpm_limit=user_obj.tpm_limit,
             user_rpm_limit=user_obj.rpm_limit,
+            user_email=user_obj.user_email,
         )
     if user_obj is not None and _is_user_proxy_admin(user_obj=user_obj):
         user_api_key_kwargs.update(
@@ -312,6 +312,7 @@ class LiteLLMProxyRequestSetup:
             user_api_key_org_id=user_api_key_dict.org_id,
             user_api_key_team_alias=user_api_key_dict.team_alias,
             user_api_key_end_user_id=user_api_key_dict.end_user_id,
+            user_api_key_user_email=user_api_key_dict.user_email,
         )
         return user_api_key_logged_metadata
@@ -54,6 +54,7 @@ LATENCY_BUCKETS = (
 class UserAPIKeyLabelNames(Enum):
     END_USER = "end_user"
     USER = "user"
+    USER_EMAIL = "user_email"
     API_KEY_HASH = "hashed_api_key"
     API_KEY_ALIAS = "api_key_alias"
     TEAM = "team"

@@ -123,6 +124,7 @@ class PrometheusMetricLabels:
         UserAPIKeyLabelNames.TEAM_ALIAS.value,
         UserAPIKeyLabelNames.USER.value,
         UserAPIKeyLabelNames.STATUS_CODE.value,
+        UserAPIKeyLabelNames.USER_EMAIL.value,
     ]

     litellm_proxy_failed_requests_metric = [

@@ -156,6 +158,7 @@ class PrometheusMetricLabels:
         UserAPIKeyLabelNames.TEAM.value,
         UserAPIKeyLabelNames.TEAM_ALIAS.value,
         UserAPIKeyLabelNames.USER.value,
+        UserAPIKeyLabelNames.USER_EMAIL.value,
     ]

     litellm_input_tokens_metric = [

@@ -240,6 +243,9 @@ class UserAPIKeyLabelValues(BaseModel):
     user: Annotated[
         Optional[str], Field(..., alias=UserAPIKeyLabelNames.USER.value)
     ] = None
+    user_email: Annotated[
+        Optional[str], Field(..., alias=UserAPIKeyLabelNames.USER_EMAIL.value)
+    ] = None
     hashed_api_key: Annotated[
         Optional[str], Field(..., alias=UserAPIKeyLabelNames.API_KEY_HASH.value)
     ] = None
@@ -1504,6 +1504,7 @@ class StandardLoggingUserAPIKeyMetadata(TypedDict):
     user_api_key_org_id: Optional[str]
     user_api_key_team_id: Optional[str]
     user_api_key_user_id: Optional[str]
+    user_api_key_user_email: Optional[str]
     user_api_key_team_alias: Optional[str]
     user_api_key_end_user_id: Optional[str]
@@ -272,6 +272,7 @@ def validate_redacted_message_span_attributes(span):
         "metadata.user_api_key_user_id",
         "metadata.user_api_key_org_id",
         "metadata.user_api_key_end_user_id",
+        "metadata.user_api_key_user_email",
         "metadata.applied_guardrails",
     ]
@@ -73,6 +73,7 @@ def create_standard_logging_payload() -> StandardLoggingPayload:
         user_api_key_alias="test_alias",
         user_api_key_team_id="test_team",
         user_api_key_user_id="test_user",
+        user_api_key_user_email="test@example.com",
         user_api_key_team_alias="test_team_alias",
         user_api_key_org_id=None,
         spend_logs_metadata=None,

@@ -475,6 +476,7 @@ def test_increment_top_level_request_and_spend_metrics(prometheus_logger):
         team="test_team",
         team_alias="test_team_alias",
         model="gpt-3.5-turbo",
+        user_email=None,
     )
     prometheus_logger.litellm_requests_metric.labels().inc.assert_called_once()

@@ -631,6 +633,7 @@ async def test_async_post_call_failure_hook(prometheus_logger):
         team_alias="test_team_alias",
         user="test_user",
         status_code="429",
+        user_email=None,
     )
     prometheus_logger.litellm_proxy_total_requests_metric.labels().inc.assert_called_once()

@@ -674,6 +677,7 @@ async def test_async_post_call_success_hook(prometheus_logger):
         team_alias="test_team_alias",
         user="test_user",
         status_code="200",
+        user_email=None,
     )
     prometheus_logger.litellm_proxy_total_requests_metric.labels().inc.assert_called_once()
@@ -111,12 +111,12 @@ async def test_proxy_failure_metrics():

     assert (
         expected_metric in metrics
-    ), "Expected failure metric not found in /metrics"
-    expected_llm_deployment_failure = 'litellm_deployment_failure_responses_total{api_base="https://exampleopenaiendpoint-production.up.railway.app",api_provider="openai",exception_class="RateLimitError",exception_status="429",litellm_model_name="429",model_id="7499d31f98cd518cf54486d5a00deda6894239ce16d13543398dc8abf870b15f",requested_model="fake-azure-endpoint"} 1.0'
+    ), "Expected failure metric not found in /metrics."
+    expected_llm_deployment_failure = 'litellm_deployment_failure_responses_total{api_key_alias="None",end_user="None",hashed_api_key="88dc28d0f030c55ed4ab77ed8faf098196cb1c05df778539800c9f1243fe6b4b",requested_model="fake-azure-endpoint",status_code="429",team="None",team_alias="None",user="default_user_id",user_email="None"} 1.0'
     assert expected_llm_deployment_failure

     assert (
-        'litellm_proxy_total_requests_metric_total{api_key_alias="None",end_user="None",hashed_api_key="88dc28d0f030c55ed4ab77ed8faf098196cb1c05df778539800c9f1243fe6b4b",requested_model="fake-azure-endpoint",status_code="429",team="None",team_alias="None",user="default_user_id"} 1.0'
+        'litellm_proxy_total_requests_metric_total{api_key_alias="None",end_user="None",hashed_api_key="88dc28d0f030c55ed4ab77ed8faf098196cb1c05df778539800c9f1243fe6b4b",requested_model="fake-azure-endpoint",status_code="429",team="None",team_alias="None",user="default_user_id",user_email="None"} 1.0'
         in metrics
     )

@@ -258,6 +258,24 @@ async def create_test_team(
     return team_info["team_id"]


+async def create_test_user(
+    session: aiohttp.ClientSession, user_data: Dict[str, Any]
+) -> str:
+    """Create a new user and return the user_id"""
+    url = "http://0.0.0.0:4000/user/new"
+    headers = {
+        "Authorization": "Bearer sk-1234",
+        "Content-Type": "application/json",
+    }
+
+    async with session.post(url, headers=headers, json=user_data) as response:
+        assert (
+            response.status == 200
+        ), f"Failed to create user. Status: {response.status}"
+        user_info = await response.json()
+        return user_info
+
+
 async def get_prometheus_metrics(session: aiohttp.ClientSession) -> str:
     """Fetch current prometheus metrics"""
     async with session.get("http://0.0.0.0:4000/metrics") as response:

@@ -526,3 +544,38 @@ async def test_key_budget_metrics():
     assert (
         abs(key_info_remaining_budget - first_budget["remaining"]) <= 0.00000
     ), f"Spend mismatch: Prometheus={key_info_remaining_budget}, Key Info={first_budget['remaining']}"
+
+
+@pytest.mark.asyncio
+async def test_user_email_metrics():
+    """
+    Test user email tracking metrics:
+    1. Create a user with user_email
+    2. Make chat completion requests using OpenAI SDK with the user's email
+    3. Verify user email is being tracked correctly in `litellm_user_email_metric`
+    """
+    async with aiohttp.ClientSession() as session:
+        # Create a user with user_email
+        user_data = {
+            "user_email": "test@example.com",
+        }
+        user_info = await create_test_user(session, user_data)
+        key = user_info["key"]
+
+        # Initialize OpenAI client with the user's email
+        client = AsyncOpenAI(base_url="http://0.0.0.0:4000", api_key=key)
+
+        # Make initial request and check budget
+        await client.chat.completions.create(
+            model="fake-openai-endpoint",
+            messages=[{"role": "user", "content": f"Hello {uuid.uuid4()}"}],
+        )
+
+        await asyncio.sleep(11)  # Wait for metrics to update
+
+        # Get metrics after request
+        metrics_after_first = await get_prometheus_metrics(session)
+        print("metrics_after_first request", metrics_after_first)
+        assert (
+            "test@example.com" in metrics_after_first
+        ), "user_email should be tracked correctly"
@@ -1618,6 +1618,30 @@ def test_provider_specific_header():
         },
     }

+
+@pytest.mark.parametrize(
+    "wildcard_model, expected_models",
+    [
+        (
+            "anthropic/*",
+            ["anthropic/claude-3-5-haiku-20241022", "anthropic/claude-3-opus-20240229"],
+        ),
+        (
+            "vertex_ai/gemini-*",
+            ["vertex_ai/gemini-1.5-flash", "vertex_ai/gemini-1.5-pro"],
+        ),
+    ],
+)
+def test_get_known_models_from_wildcard(wildcard_model, expected_models):
+    from litellm.proxy.auth.model_checks import get_known_models_from_wildcard
+
+    wildcard_models = get_known_models_from_wildcard(wildcard_model=wildcard_model)
+    # Check if all expected models are in the returned list
+    print(f"wildcard_models: {wildcard_models}\n")
+    for model in expected_models:
+        if model not in wildcard_models:
+            print(f"Missing expected model: {model}")
+
+    assert all(model in wildcard_models for model in expected_models)
+
+
 @pytest.mark.parametrize(
     "data, user_api_key_dict, expected_model",

@@ -1667,3 +1691,4 @@ def test_update_model_if_team_alias_exists(data, user_api_key_dict, expected_model):

     # Check if model was updated correctly
     assert test_data.get("model") == expected_model
+
@@ -47,6 +47,7 @@ async def get_models(session, key):

     if status != 200:
         raise Exception(f"Request did not return a 200 status code: {status}")
     return await response.json()
+


 @pytest.mark.asyncio
@@ -112,6 +113,24 @@ async def get_model_info(session, key, litellm_model_id=None):
     return await response.json()


+async def get_model_group_info(session, key):
+    url = "http://0.0.0.0:4000/model_group/info"
+    headers = {
+        "Authorization": f"Bearer {key}",
+        "Content-Type": "application/json",
+    }
+
+    async with session.get(url, headers=headers) as response:
+        status = response.status
+        response_text = await response.text()
+        print(response_text)
+        print()
+
+        if status != 200:
+            raise Exception(f"Request did not return a 200 status code: {status}")
+        return await response.json()
+
+
 async def chat_completion(session, key, model="azure-gpt-3.5"):
     url = "http://0.0.0.0:4000/chat/completions"
     headers = {
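
For reference, `/model_group/info` returns a JSON object whose `data` list holds one entry per model group; the e2e test below keys off the `model_group` field. A minimal direct call, under the same local-proxy and master-key assumptions as the other e2e tests:

    import asyncio
    import aiohttp

    async def main():
        headers = {"Authorization": "Bearer sk-1234"}
        async with aiohttp.ClientSession() as session:
            async with session.get(
                "http://0.0.0.0:4000/model_group/info", headers=headers
            ) as resp:
                info = await resp.json()
        # With a wildcard route configured, expanded models such as
        # anthropic/claude-3-opus-20240229 appear as their own model groups.
        print([m["model_group"] for m in info["data"]])

    asyncio.run(main())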
@@ -394,3 +413,31 @@ async def test_add_model_run_health():

     # cleanup
     await delete_model(session=session, model_id=model_id)
+
+
+@pytest.mark.asyncio
+async def test_model_group_info_e2e():
+    """
+    Test /model/group/info endpoint
+    """
+    async with aiohttp.ClientSession() as session:
+        models = await get_models(session=session, key="sk-1234")
+        print(models)
+
+        expected_models = [
+            "anthropic/claude-3-5-haiku-20241022",
+            "anthropic/claude-3-opus-20240229",
+        ]
+
+        model_group_info = await get_model_group_info(session=session, key="sk-1234")
+        print(model_group_info)
+
+        has_anthropic_claude_3_5_haiku = False
+        has_anthropic_claude_3_opus = False
+        for model in model_group_info["data"]:
+            if model["model_group"] == "anthropic/claude-3-5-haiku-20241022":
+                has_anthropic_claude_3_5_haiku = True
+            if model["model_group"] == "anthropic/claude-3-opus-20240229":
+                has_anthropic_claude_3_opus = True
+
+        assert has_anthropic_claude_3_5_haiku and has_anthropic_claude_3_opus