forked from phoenix/litellm-mirror
Merge pull request #4217 from BerriAI/litellm_refactor_proxy_server
[Refactor-Proxy] Make proxy_server.py < 10K lines (move management, key, endpoints to their own files)
This commit is contained in:
commit
f84941bdc0
14 changed files with 4512 additions and 4475 deletions
|
@ -1609,3 +1609,10 @@ class ProxyException(Exception):
|
|||
"param": self.param,
|
||||
"code": self.code,
|
||||
}
|
||||
|
||||
|
||||
class CommonProxyErrors(enum.Enum):
|
||||
db_not_connected_error = "DB not connected"
|
||||
no_llm_router = "No models configured on proxy"
|
||||
not_allowed_access = "Admin-only endpoint. Not allowed to access this."
|
||||
not_premium_user = "You must be a LiteLLM Enterprise user to use this feature. If you have a license please set `LITELLM_LICENSE` in your env. If you want to obtain a license meet with us here: https://calendly.com/d/4mp-gd3-k5k/litellm-1-1-onboarding-chat"
|
||||
|
|
194
litellm/proxy/caching_routes.py
Normal file
194
litellm/proxy/caching_routes.py
Normal file
|
@ -0,0 +1,194 @@
|
|||
from typing import Optional
|
||||
from fastapi import Depends, Request, APIRouter
|
||||
from fastapi import HTTPException
|
||||
import copy
|
||||
import litellm
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
|
||||
|
||||
|
||||
router = APIRouter(
|
||||
prefix="/cache",
|
||||
tags=["caching"],
|
||||
)
|
||||
|
||||
|
||||
@router.get(
    "/ping",
    dependencies=[Depends(user_api_key_auth)],
)
async def cache_ping():
    """
    Endpoint for checking if cache can be pinged.

    Returns the (stringified) cache configuration and, for redis caches,
    additionally performs a live ping plus a set-cache round trip.

    Raises:
        HTTPException: 503 if the cache is not initialized or unhealthy.
    """
    # initialized before the try so the except handler can always format them
    litellm_cache_params = {}
    specific_cache_params = {}
    try:
        if litellm.cache is None:
            raise HTTPException(
                status_code=503, detail="Cache not initialized. litellm.cache is None"
            )

        # collect top-level cache config for debugging; values are stringified
        # so the response stays JSON-serializable
        for k, v in vars(litellm.cache).items():
            try:
                if k == "cache":  # skip the underlying cache-client object
                    continue
                litellm_cache_params[k] = str(copy.deepcopy(v))
            except Exception:
                # best-effort: some attributes can't be copied or stringified
                litellm_cache_params[k] = "<unable to copy or convert>"
        for k, v in vars(litellm.cache.cache).items():
            try:
                specific_cache_params[k] = str(v)
            except Exception:
                specific_cache_params[k] = "<unable to copy or convert>"

        if litellm.cache.type == "redis":
            # ping the redis cache
            ping_response = await litellm.cache.ping()
            verbose_proxy_logger.debug(
                "/cache/ping: ping_response: " + str(ping_response)
            )
            # making a set cache call
            # add cache does not return anything
            await litellm.cache.async_add_cache(
                result="test_key",
                model="test-model",
                messages=[{"role": "user", "content": "test from litellm"}],
            )
            verbose_proxy_logger.debug("/cache/ping: done with set_cache()")
            return {
                "status": "healthy",
                "cache_type": litellm.cache.type,
                "ping_response": True,
                "set_cache_response": "success",
                "litellm_cache_params": litellm_cache_params,
                "redis_cache_params": specific_cache_params,
            }
        else:
            return {
                "status": "healthy",
                "cache_type": litellm.cache.type,
                "litellm_cache_params": litellm_cache_params,
            }
    except HTTPException:
        # preserve the specific status code/detail raised above (e.g. the 503
        # "Cache not initialized") instead of re-wrapping it generically
        raise
    except Exception as e:
        raise HTTPException(
            status_code=503,
            detail=f"Service Unhealthy ({str(e)}).Cache parameters: {litellm_cache_params}.specific_cache_params: {specific_cache_params}",
        )
|
||||
|
||||
|
||||
@router.post(
    "/delete",
    tags=["caching"],
    dependencies=[Depends(user_api_key_auth)],
)
async def cache_delete(request: Request):
    """
    Endpoint for deleting a key from the cache. All responses from litellm proxy have `x-litellm-cache-key` in the headers

    Parameters:
    - **keys**: *Optional[List[str]]* - A list of keys to delete from the cache. Example {"keys": ["key1", "key2"]}

    ```shell
    curl -X POST "http://0.0.0.0:4000/cache/delete" \
        -H "Authorization: Bearer sk-1234" \
        -d '{"keys": ["key1", "key2"]}'
    ```

    Raises:
        HTTPException: 503 if the cache is not initialized, 500 if the cache
            type does not support deletion or the delete call fails.
    """
    try:
        if litellm.cache is None:
            raise HTTPException(
                status_code=503, detail="Cache not initialized. litellm.cache is None"
            )

        request_data = await request.json()
        keys = request_data.get("keys", None)

        # only the redis backend exposes key deletion
        if litellm.cache.type == "redis":
            await litellm.cache.delete_cache_keys(keys=keys)
            return {
                "status": "success",
            }
        else:
            raise HTTPException(
                status_code=500,
                detail=f"Cache type {litellm.cache.type} does not support deleting a key. only `redis` is supported",
            )
    except HTTPException:
        # don't mask the specific status code/detail raised above (e.g. the
        # 503 "Cache not initialized") behind a generic 500
        raise
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"Cache Delete Failed({str(e)})",
        )
|
||||
|
||||
|
||||
@router.get(
    "/redis/info",
    dependencies=[Depends(user_api_key_auth)],
)
async def cache_redis_info():
    """
    Endpoint for getting /redis/info

    Returns redis `CLIENT LIST` and `INFO` output plus the connected-client
    count. Only supported for the redis cache backend.

    Raises:
        HTTPException: 503 if the cache is not initialized or the call fails,
            500 if the cache type is not redis.
    """
    try:
        if litellm.cache is None:
            raise HTTPException(
                status_code=503, detail="Cache not initialized. litellm.cache is None"
            )
        if litellm.cache.type == "redis":
            client_list = litellm.cache.cache.client_list()
            redis_info = litellm.cache.cache.info()
            num_clients = len(client_list)
            return {
                "num_clients": num_clients,
                "clients": client_list,
                "info": redis_info,
            }
        else:
            # fixed copy-paste error: this endpoint reports redis info, it
            # does not flush anything
            raise HTTPException(
                status_code=500,
                detail=f"Cache type {litellm.cache.type} does not support getting redis info",
            )
    except HTTPException:
        # preserve the specific status code/detail raised above
        raise
    except Exception as e:
        raise HTTPException(
            status_code=503,
            detail=f"Service Unhealthy ({str(e)})",
        )
|
||||
|
||||
|
||||
@router.post(
    "/flushall",
    tags=["caching"],
    dependencies=[Depends(user_api_key_auth)],
)
async def cache_flushall():
    """
    A function to flush all items from the cache. (All items will be deleted from the cache with this)
    Raises HTTPException if the cache is not initialized or if the cache type does not support flushing.
    Returns a dictionary with the status of the operation.

    Usage:
    ```
    curl -X POST http://0.0.0.0:4000/cache/flushall -H "Authorization: Bearer sk-1234"
    ```
    """
    try:
        if litellm.cache is None:
            raise HTTPException(
                status_code=503, detail="Cache not initialized. litellm.cache is None"
            )
        # only the redis backend supports FLUSHALL
        if litellm.cache.type == "redis":
            litellm.cache.cache.flushall()
            return {
                "status": "success",
            }
        else:
            raise HTTPException(
                status_code=500,
                detail=f"Cache type {litellm.cache.type} does not support flushing",
            )
    except HTTPException:
        # preserve the specific status code/detail raised above instead of
        # converting everything to a generic 503
        raise
    except Exception as e:
        raise HTTPException(
            status_code=503,
            detail=f"Service Unhealthy ({str(e)})",
        )
|
|
@ -1,113 +0,0 @@
|
|||
from datetime import datetime
|
||||
from functools import wraps
|
||||
from litellm.proxy._types import UserAPIKeyAuth, ManagementEndpointLoggingPayload
|
||||
from litellm.proxy.common_utils.http_parsing_utils import _read_request_body
|
||||
from litellm._logging import verbose_logger
|
||||
from fastapi import Request
|
||||
|
||||
|
||||
def management_endpoint_wrapper(func):
    """
    This wrapper does the following:

    1. Log I/O, Exceptions to OTEL
    2. Create an Audit log for success calls

    On success it emits an OTEL "management endpoint success" event (when a
    parent OTEL span is attached to the caller's key) and flushes the
    user_api_key cache after mutating routes. On failure it emits the
    corresponding OTEL failure event and re-raises the original exception.
    """

    @wraps(func)
    async def wrapper(*args, **kwargs):
        start_time = datetime.now()

        try:
            result = await func(*args, **kwargs)
            end_time = datetime.now()
            # Post-success bookkeeping is deliberately best-effort: any error
            # here must not affect the successful response (see inner except).
            try:
                if kwargs is None:
                    kwargs = {}
                # fall back to an empty auth object so attribute access below is safe
                user_api_key_dict: UserAPIKeyAuth = (
                    kwargs.get("user_api_key_dict") or UserAPIKeyAuth()
                )
                # NOTE(review): may be None when the endpoint doesn't pass
                # http_request - guarded by the `if _http_request` checks below
                _http_request: Request = kwargs.get("http_request")
                parent_otel_span = user_api_key_dict.parent_otel_span
                if parent_otel_span is not None:
                    # imported lazily to avoid a circular import with proxy_server
                    from litellm.proxy.proxy_server import open_telemetry_logger

                    if open_telemetry_logger is not None:
                        if _http_request:
                            _route = _http_request.url.path
                            _request_body: dict = await _read_request_body(
                                request=_http_request
                            )
                            _response = dict(result) if result is not None else None

                            logging_payload = ManagementEndpointLoggingPayload(
                                route=_route,
                                request_data=_request_body,
                                response=_response,
                                start_time=start_time,
                                end_time=end_time,
                            )

                            await open_telemetry_logger.async_management_endpoint_success_hook(
                                logging_payload=logging_payload,
                                parent_otel_span=parent_otel_span,
                            )

                if _http_request:
                    _route = _http_request.url.path
                    # Flush user_api_key cache if this was an update/delete call to /key, /team, or /user
                    if _route in [
                        "/key/update",
                        "/key/delete",
                        "/team/update",
                        "/team/delete",
                        "/user/update",
                        "/user/delete",
                        "/customer/update",
                        "/customer/delete",
                    ]:
                        from litellm.proxy.proxy_server import user_api_key_cache

                        user_api_key_cache.flush_cache()
            except Exception as e:
                # Non-Blocking Exception
                verbose_logger.debug("Error in management endpoint wrapper: %s", str(e))
                pass

            return result
        except Exception as e:
            # failure path: mirror the success path, but emit the OTEL failure
            # hook with the exception attached, then re-raise unchanged
            end_time = datetime.now()

            if kwargs is None:
                kwargs = {}
            user_api_key_dict: UserAPIKeyAuth = (
                kwargs.get("user_api_key_dict") or UserAPIKeyAuth()
            )
            parent_otel_span = user_api_key_dict.parent_otel_span
            if parent_otel_span is not None:
                from litellm.proxy.proxy_server import open_telemetry_logger

                if open_telemetry_logger is not None:
                    _http_request: Request = kwargs.get("http_request")
                    if _http_request:
                        _route = _http_request.url.path
                        _request_body: dict = await _read_request_body(
                            request=_http_request
                        )
                        logging_payload = ManagementEndpointLoggingPayload(
                            route=_route,
                            request_data=_request_body,
                            response=None,
                            start_time=start_time,
                            end_time=end_time,
                            exception=e,
                        )

                        await open_telemetry_logger.async_management_endpoint_failure_hook(
                            logging_payload=logging_payload,
                            parent_otel_span=parent_otel_span,
                        )

            raise e

    return wrapper
|
478
litellm/proxy/health_endpoints/_health_endpoints.py
Normal file
478
litellm/proxy/health_endpoints/_health_endpoints.py
Normal file
|
@ -0,0 +1,478 @@
|
|||
from typing import Optional, Literal
|
||||
import litellm
|
||||
import os
|
||||
import asyncio
|
||||
import fastapi
|
||||
import traceback
|
||||
from datetime import datetime, timedelta
|
||||
from fastapi import Depends, Request, APIRouter, Header, status
|
||||
from litellm.proxy.health_check import perform_health_check
|
||||
from fastapi import HTTPException
|
||||
import copy
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
|
||||
from litellm.proxy._types import (
|
||||
UserAPIKeyAuth,
|
||||
ProxyException,
|
||||
WebhookEvent,
|
||||
CallInfo,
|
||||
)
|
||||
|
||||
#### Health ENDPOINTS ####
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.get(
    "/test",
    tags=["health"],
    dependencies=[Depends(user_api_key_auth)],
)
async def test_endpoint(request: Request):
    """
    [DEPRECATED] use `/health/liveliness` instead.

    A test endpoint that pings the proxy server to check if it's healthy.

    Parameters:
        request (Request): The incoming request.

    Returns:
        dict: A dictionary containing the route of the request URL.
    """
    # liveliness probe: simply echo the route that was hit
    requested_route = request.url.path
    return {"route": requested_route}
|
||||
|
||||
|
||||
@router.get(
    "/health/services",
    tags=["health"],
    dependencies=[Depends(user_api_key_auth)],
    include_in_schema=False,
)
async def health_services_endpoint(
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
    service: Literal[
        "slack_budget_alerts", "langfuse", "slack", "openmeter", "webhook", "email"
    ] = fastapi.Query(description="Specify the service being hit."),
):
    """
    Hidden endpoint.

    Used by the UI to let user check if slack alerting is working as expected.

    For each supported service this fires a cheap mock event (mock LLM call,
    test alert, or test email) so the operator can confirm the integration is
    wired up. Raises 400 for unknown services and 422 when slack alerting is
    not configured in general_settings.
    """
    try:
        # imported lazily to avoid a circular import with proxy_server
        from litellm.proxy.proxy_server import (
            proxy_logging_obj,
            prisma_client,
            general_settings,
        )

        if service is None:
            raise HTTPException(
                status_code=400, detail={"error": "Service must be specified."}
            )
        if service not in [
            "slack_budget_alerts",
            "email",
            "langfuse",
            "slack",
            "openmeter",
            "webhook",
        ]:
            raise HTTPException(
                status_code=400,
                detail={
                    "error": f"Service must be in list. Service={service}. List={['slack_budget_alerts']}"
                },
            )

        if service == "openmeter":
            # mock completion exercises the openmeter success-callback path
            _ = await litellm.acompletion(
                model="openai/litellm-mock-response-model",
                messages=[{"role": "user", "content": "Hey, how's it going?"}],
                user="litellm:/health/services",
                mock_response="This is a mock response",
            )
            return {
                "status": "success",
                "message": "Mock LLM request made - check openmeter.",
            }

        if service == "langfuse":
            from litellm.integrations.langfuse import LangFuseLogger

            # auth_check() raises if the configured langfuse credentials are bad
            langfuse_logger = LangFuseLogger()
            langfuse_logger.Langfuse.auth_check()
            _ = litellm.completion(
                model="openai/litellm-mock-response-model",
                messages=[{"role": "user", "content": "Hey, how's it going?"}],
                user="litellm:/health/services",
                mock_response="This is a mock response",
            )
            return {
                "status": "success",
                "message": "Mock LLM request made - check langfuse.",
            }

        if service == "webhook":
            # send a fake "user budget exceeded" event for the calling key
            user_info = CallInfo(
                token=user_api_key_dict.token or "",
                spend=1,
                max_budget=0,
                user_id=user_api_key_dict.user_id,
                key_alias=user_api_key_dict.key_alias,
                team_id=user_api_key_dict.team_id,
            )
            await proxy_logging_obj.budget_alerts(
                type="user_budget",
                user_info=user_info,
            )
            # NOTE(review): no return here - for service == "webhook" the
            # function falls through past the remaining branches and responds
            # with null. Confirm whether a success payload was intended.

        if service == "slack" or service == "slack_budget_alerts":
            if "slack" in general_settings.get("alerting", []):
                # test_message = f"""\n🚨 `ProjectedLimitExceededError` 💸\n\n`Key Alias:` litellm-ui-test-alert \n`Expected Day of Error`: 28th March \n`Current Spend`: $100.00 \n`Projected Spend at end of month`: $1000.00 \n`Soft Limit`: $700"""
                # check if user has opted into unique_alert_webhooks
                if (
                    proxy_logging_obj.slack_alerting_instance.alert_to_webhook_url
                    is not None
                ):
                    # per-alert-type webhooks configured: send one test alert
                    # for each configured (and active) alert type
                    for (
                        alert_type
                    ) in proxy_logging_obj.slack_alerting_instance.alert_to_webhook_url:
                        """
                        "llm_exceptions",
                        "llm_too_slow",
                        "llm_requests_hanging",
                        "budget_alerts",
                        "db_exceptions",
                        """
                        # only test alert if it's in active alert types
                        if (
                            proxy_logging_obj.slack_alerting_instance.alert_types
                            is not None
                            and alert_type
                            not in proxy_logging_obj.slack_alerting_instance.alert_types
                        ):
                            continue
                        test_message = "default test message"
                        if alert_type == "llm_exceptions":
                            test_message = f"LLM Exception test alert"
                        elif alert_type == "llm_too_slow":
                            test_message = f"LLM Too Slow test alert"
                        elif alert_type == "llm_requests_hanging":
                            test_message = f"LLM Requests Hanging test alert"
                        elif alert_type == "budget_alerts":
                            test_message = f"Budget Alert test alert"
                        elif alert_type == "db_exceptions":
                            test_message = f"DB Exception test alert"
                        elif alert_type == "outage_alerts":
                            test_message = f"Outage Alert Exception test alert"
                        elif alert_type == "daily_reports":
                            test_message = f"Daily Reports test alert"

                        await proxy_logging_obj.alerting_handler(
                            message=test_message, level="Low", alert_type=alert_type
                        )
                else:
                    # single shared webhook: one generic test alert is enough
                    await proxy_logging_obj.alerting_handler(
                        message="This is a test slack alert message",
                        level="Low",
                        alert_type="budget_alerts",
                    )

                if prisma_client is not None:
                    # spend reports can be slow; fire-and-forget so the UI
                    # caller isn't blocked
                    asyncio.create_task(
                        proxy_logging_obj.slack_alerting_instance.send_monthly_spend_report()
                    )
                    asyncio.create_task(
                        proxy_logging_obj.slack_alerting_instance.send_weekly_spend_report()
                    )

                alert_types = (
                    proxy_logging_obj.slack_alerting_instance.alert_types or []
                )
                alert_types = list(alert_types)
                return {
                    "status": "success",
                    "alert_types": alert_types,
                    "message": "Mock Slack Alert sent, verify Slack Alert Received on your channel",
                }
            else:
                raise HTTPException(
                    status_code=422,
                    detail={
                        "error": '"{}" not in proxy config: general_settings. Unable to test this.'.format(
                            service
                        )
                    },
                )
        if service == "email":
            webhook_event = WebhookEvent(
                event="key_created",
                event_group="key",
                event_message="Test Email Alert",
                token=user_api_key_dict.token or "",
                key_alias="Email Test key (This is only a test alert key. DO NOT USE THIS IN PRODUCTION.)",
                spend=0,
                max_budget=0,
                user_id=user_api_key_dict.user_id,
                user_email=os.getenv("TEST_EMAIL_ADDRESS"),
                team_id=user_api_key_dict.team_id,
            )

            # use create task - this can take 10 seconds. don't keep ui users waiting for notification to check their email
            asyncio.create_task(
                proxy_logging_obj.slack_alerting_instance.send_key_created_or_user_invited_email(
                    webhook_event=webhook_event
                )
            )

            return {
                "status": "success",
                "message": "Mock Email Alert sent, verify Email Alert Received",
            }

    except Exception as e:
        verbose_proxy_logger.error(
            "litellm.proxy.proxy_server.health_services_endpoint(): Exception occured - {}".format(
                str(e)
            )
        )
        verbose_proxy_logger.debug(traceback.format_exc())
        # all failures are normalized to ProxyException for the proxy's
        # global exception handler
        # NOTE(review): the "auth_error"/"Authentication Error" wording looks
        # copy-pasted from an auth endpoint - these are health-check failures
        if isinstance(e, HTTPException):
            raise ProxyException(
                message=getattr(e, "detail", f"Authentication Error({str(e)})"),
                type="auth_error",
                param=getattr(e, "param", "None"),
                code=getattr(e, "status_code", status.HTTP_500_INTERNAL_SERVER_ERROR),
            )
        elif isinstance(e, ProxyException):
            raise e
        raise ProxyException(
            message="Authentication Error, " + str(e),
            type="auth_error",
            param=getattr(e, "param", "None"),
            code=status.HTTP_500_INTERNAL_SERVER_ERROR,
        )
|
||||
|
||||
|
||||
@router.get("/health", tags=["health"], dependencies=[Depends(user_api_key_auth)])
async def health_endpoint(
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
    model: Optional[str] = fastapi.Query(
        None, description="Specify the model name (optional)"
    ),
):
    """
    🚨 USE `/health/liveliness` to health check the proxy 🚨

    See more 👉 https://docs.litellm.ai/docs/proxy/health


    Check the health of all the endpoints in config.yaml

    To run health checks in the background, add this to config.yaml:
    ```
    general_settings:
        # ... other settings
        background_health_checks: True
    ```
    else, the health checks will be run on models when /health is called.

    Parameters:
        model: optional model name to restrict the health check to.

    Returns:
        dict with healthy/unhealthy endpoint lists and their counts, or the
        cached background-health-check results when those are enabled.
    """
    # imported lazily to avoid a circular import with proxy_server
    from litellm.proxy.proxy_server import (
        health_check_results,
        use_background_health_checks,
        user_model,
        llm_model_list,
    )

    try:
        if llm_model_list is None:
            # if no router set, check if user set a model using litellm --model ollama/llama2
            if user_model is not None:
                healthy_endpoints, unhealthy_endpoints = await perform_health_check(
                    model_list=[], cli_model=user_model
                )
                return {
                    "healthy_endpoints": healthy_endpoints,
                    "unhealthy_endpoints": unhealthy_endpoints,
                    "healthy_count": len(healthy_endpoints),
                    "unhealthy_count": len(unhealthy_endpoints),
                }
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail={"error": "Model list not initialized"},
            )
        # deepcopy so health checks can't mutate the live router model list
        _llm_model_list = copy.deepcopy(llm_model_list)
        ### FILTER MODELS FOR ONLY THOSE USER HAS ACCESS TO ###
        # NOTE(review): allowed_model_names is computed but never applied to
        # _llm_model_list below, so no per-key filtering actually happens -
        # confirm whether this is intentional or an unfinished feature
        if len(user_api_key_dict.models) > 0:
            allowed_model_names = user_api_key_dict.models
        else:
            allowed_model_names = [] #
        if use_background_health_checks:
            # background checks enabled: return the cached results immediately
            return health_check_results
        else:
            healthy_endpoints, unhealthy_endpoints = await perform_health_check(
                _llm_model_list, model
            )

            return {
                "healthy_endpoints": healthy_endpoints,
                "unhealthy_endpoints": unhealthy_endpoints,
                "healthy_count": len(healthy_endpoints),
                "unhealthy_count": len(unhealthy_endpoints),
            }
    except Exception as e:
        verbose_proxy_logger.error(
            "litellm.proxy.proxy_server.py::health_endpoint(): Exception occured - {}".format(
                str(e)
            )
        )
        verbose_proxy_logger.debug(traceback.format_exc())
        raise e
|
||||
|
||||
|
||||
db_health_cache = {"status": "unknown", "last_updated": datetime.now()}


def _db_health_readiness_check():
    """Check DB connectivity, caching a healthy result for up to two minutes.

    Returns the module-level ``db_health_cache`` dict. Intentionally not
    wrapped in try/except: a failing health check should raise to the caller.
    """
    from litellm.proxy.proxy_server import prisma_client

    global db_health_cache

    # Serve the cached status while it is fresh (< 2 minutes and known).
    cache_age = datetime.now() - db_health_cache["last_updated"]
    status_known = db_health_cache["status"] != "unknown"
    if status_known and cache_age < timedelta(minutes=2):
        return db_health_cache

    # Stale or unknown: hit the DB and refresh the cache.
    prisma_client.health_check()
    db_health_cache = {"status": "connected", "last_updated": datetime.now()}
    return db_health_cache
|
||||
|
||||
|
||||
@router.get(
    "/active/callbacks",
    tags=["health"],
    dependencies=[Depends(user_api_key_auth)],
)
async def active_callbacks():
    """
    Returns a list of active callbacks on litellm.callbacks, litellm.input_callback, litellm.failure_callback, litellm.success_callback
    """
    from litellm.proxy.proxy_server import proxy_logging_obj, general_settings

    _alerting = str(general_settings.get("alerting"))

    # stringify each registered callback so the response is JSON-serializable
    def _names(callback_list):
        return [str(cb) for cb in callback_list]

    litellm_callbacks = _names(litellm.callbacks)
    litellm_input_callbacks = _names(litellm.input_callback)
    litellm_failure_callbacks = _names(litellm.failure_callback)
    litellm_success_callbacks = _names(litellm.success_callback)
    litellm_async_success_callbacks = _names(litellm._async_success_callback)
    litellm_async_failure_callbacks = _names(litellm._async_failure_callback)
    litellm_async_input_callbacks = _names(litellm._async_input_callback)

    all_litellm_callbacks = (
        litellm_callbacks
        + litellm_input_callbacks
        + litellm_failure_callbacks
        + litellm_success_callbacks
        + litellm_async_success_callbacks
        + litellm_async_failure_callbacks
        + litellm_async_input_callbacks
    )

    # alerting may be unset or a non-list; only count list-shaped configs
    alerting = proxy_logging_obj.alerting
    _num_alerting = len(alerting) if alerting and isinstance(alerting, list) else 0

    return {
        "alerting": _alerting,
        "litellm.callbacks": litellm_callbacks,
        "litellm.input_callback": litellm_input_callbacks,
        "litellm.failure_callback": litellm_failure_callbacks,
        "litellm.success_callback": litellm_success_callbacks,
        "litellm._async_success_callback": litellm_async_success_callbacks,
        "litellm._async_failure_callback": litellm_async_failure_callbacks,
        "litellm._async_input_callback": litellm_async_input_callbacks,
        "all_litellm_callbacks": all_litellm_callbacks,
        "num_callbacks": len(all_litellm_callbacks),
        "num_alerting": _num_alerting,
    }
|
||||
|
||||
|
||||
@router.get(
    "/health/readiness",
    tags=["health"],
    dependencies=[Depends(user_api_key_auth)],
)
async def health_readiness():
    """
    Unprotected endpoint for checking if worker can receive requests

    Reports DB connection state, cache type (with redis-semantic index info
    when applicable), litellm version, and the registered success callbacks.

    Raises:
        HTTPException: 503 when any readiness probe fails.
    """
    from litellm.proxy.proxy_server import proxy_logging_obj, prisma_client, version

    try:
        # get success callback
        success_callback_names = []

        try:
            # this was returning a JSON of the values in some of the callbacks
            # all we need is the callback name, hence we do str(callback)
            success_callback_names = [str(x) for x in litellm.success_callback]
        except Exception:  # was a bare `except:` - don't swallow SystemExit/KeyboardInterrupt
            # don't let this block the /health/readiness response, if we can't convert to str -> return litellm.success_callback
            success_callback_names = litellm.success_callback

        # check Cache
        cache_type = None
        if litellm.cache is not None:
            from litellm.caching import RedisSemanticCache

            cache_type = litellm.cache.type

            if isinstance(litellm.cache.cache, RedisSemanticCache):
                # ping the cache
                # TODO: @ishaan-jaff - we should probably not ping the cache on every /health/readiness check
                try:
                    index_info = await litellm.cache.cache._index_info()
                except Exception as e:
                    index_info = "index does not exist - error: " + str(e)
                cache_type = {"type": cache_type, "index_info": index_info}

        # check DB
        if prisma_client is not None:  # if db passed in, check if it's connected
            db_health_status = _db_health_readiness_check()
            return {
                "status": "healthy",
                "db": "connected",
                "cache": cache_type,
                "litellm_version": version,
                "success_callbacks": success_callback_names,
                **db_health_status,
            }
        else:
            return {
                "status": "healthy",
                "db": "Not connected",
                "cache": cache_type,
                "litellm_version": version,
                "success_callbacks": success_callback_names,
            }
    except Exception as e:
        raise HTTPException(status_code=503, detail=f"Service Unhealthy ({str(e)})")
|
||||
|
||||
|
||||
@router.get(
    "/health/liveliness",
    tags=["health"],
    dependencies=[Depends(user_api_key_auth)],
)
async def health_liveliness():
    """
    Unprotected endpoint for checking if worker is alive
    """
    # trivially returns a fixed string - reaching this handler proves liveness
    liveliness_response = "I'm alive!"
    return liveliness_response
|
926
litellm/proxy/management_endpoints/key_management_endpoints.py
Normal file
926
litellm/proxy/management_endpoints/key_management_endpoints.py
Normal file
|
@ -0,0 +1,926 @@
|
|||
"""
|
||||
KEY MANAGEMENT
|
||||
|
||||
All /key management endpoints
|
||||
|
||||
/key/generate
|
||||
/key/info
|
||||
/key/update
|
||||
/key/delete
|
||||
"""
|
||||
|
||||
import copy
|
||||
import json
|
||||
import uuid
|
||||
import re
|
||||
import traceback
|
||||
import asyncio
|
||||
import secrets
|
||||
from typing import Optional, List
|
||||
import fastapi
|
||||
from fastapi import Depends, Request, APIRouter, Header, status
|
||||
from fastapi import HTTPException
|
||||
import litellm
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
|
||||
from litellm.proxy._types import *
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
@router.post(
|
||||
"/key/generate",
|
||||
tags=["key management"],
|
||||
dependencies=[Depends(user_api_key_auth)],
|
||||
response_model=GenerateKeyResponse,
|
||||
)
|
||||
async def generate_key_fn(
|
||||
data: GenerateKeyRequest,
|
||||
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
||||
litellm_changed_by: Optional[str] = Header(
|
||||
None,
|
||||
description="The litellm-changed-by header enables tracking of actions performed by authorized users on behalf of other users, providing an audit trail for accountability",
|
||||
),
|
||||
):
|
||||
"""
|
||||
Generate an API key based on the provided data.
|
||||
|
||||
Docs: https://docs.litellm.ai/docs/proxy/virtual_keys
|
||||
|
||||
Parameters:
|
||||
- duration: Optional[str] - Specify the length of time the token is valid for. You can set duration as seconds ("30s"), minutes ("30m"), hours ("30h"), days ("30d").
|
||||
- key_alias: Optional[str] - User defined key alias
|
||||
- team_id: Optional[str] - The team id of the key
|
||||
- user_id: Optional[str] - The user id of the key
|
||||
- models: Optional[list] - Model_name's a user is allowed to call. (if empty, key is allowed to call all models)
|
||||
- aliases: Optional[dict] - Any alias mappings, on top of anything in the config.yaml model list. - https://docs.litellm.ai/docs/proxy/virtual_keys#managing-auth---upgradedowngrade-models
|
||||
- config: Optional[dict] - any key-specific configs, overrides config in config.yaml
|
||||
- spend: Optional[int] - Amount spent by key. Default is 0. Will be updated by proxy whenever key is used. https://docs.litellm.ai/docs/proxy/virtual_keys#managing-auth---tracking-spend
|
||||
- send_invite_email: Optional[bool] - Whether to send an invite email to the user_id, with the generate key
|
||||
- max_budget: Optional[float] - Specify max budget for a given key.
|
||||
- max_parallel_requests: Optional[int] - Rate limit a user based on the number of parallel requests. Raises 429 error, if user's parallel requests > x.
|
||||
- metadata: Optional[dict] - Metadata for key, store information for key. Example metadata = {"team": "core-infra", "app": "app2", "email": "ishaan@berri.ai" }
|
||||
- permissions: Optional[dict] - key-specific permissions. Currently just used for turning off pii masking (if connected). Example - {"pii": false}
|
||||
- model_max_budget: Optional[dict] - key-specific model budget in USD. Example - {"text-davinci-002": 0.5, "gpt-3.5-turbo": 0.5}. IF null or {} then no model specific budget.
|
||||
|
||||
Examples:
|
||||
|
||||
1. Allow users to turn on/off pii masking
|
||||
|
||||
```bash
|
||||
curl --location 'http://0.0.0.0:8000/key/generate' \
|
||||
--header 'Authorization: Bearer sk-1234' \
|
||||
--header 'Content-Type: application/json' \
|
||||
--data '{
|
||||
"permissions": {"allow_pii_controls": true}
|
||||
}'
|
||||
```
|
||||
|
||||
Returns:
|
||||
- key: (str) The generated api key
|
||||
- expires: (datetime) Datetime object for when key expires.
|
||||
- user_id: (str) Unique user id - used for tracking spend across multiple keys for same user id.
|
||||
"""
|
||||
try:
|
||||
from litellm.proxy.proxy_server import (
|
||||
user_custom_key_generate,
|
||||
prisma_client,
|
||||
litellm_proxy_admin_name,
|
||||
general_settings,
|
||||
proxy_logging_obj,
|
||||
create_audit_log_for_update,
|
||||
)
|
||||
|
||||
verbose_proxy_logger.debug("entered /key/generate")
|
||||
|
||||
if user_custom_key_generate is not None:
|
||||
result = await user_custom_key_generate(data)
|
||||
decision = result.get("decision", True)
|
||||
message = result.get("message", "Authentication Failed - Custom Auth Rule")
|
||||
if not decision:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN, detail=message
|
||||
)
|
||||
# check if user set default key/generate params on config.yaml
|
||||
if litellm.default_key_generate_params is not None:
|
||||
for elem in data:
|
||||
key, value = elem
|
||||
if value is None and key in [
|
||||
"max_budget",
|
||||
"user_id",
|
||||
"team_id",
|
||||
"max_parallel_requests",
|
||||
"tpm_limit",
|
||||
"rpm_limit",
|
||||
"budget_duration",
|
||||
]:
|
||||
setattr(
|
||||
data, key, litellm.default_key_generate_params.get(key, None)
|
||||
)
|
||||
elif key == "models" and value == []:
|
||||
setattr(data, key, litellm.default_key_generate_params.get(key, []))
|
||||
elif key == "metadata" and value == {}:
|
||||
setattr(data, key, litellm.default_key_generate_params.get(key, {}))
|
||||
|
||||
# check if user set default key/generate params on config.yaml
|
||||
if litellm.upperbound_key_generate_params is not None:
|
||||
for elem in data:
|
||||
# if key in litellm.upperbound_key_generate_params, use the min of value and litellm.upperbound_key_generate_params[key]
|
||||
key, value = elem
|
||||
if (
|
||||
value is not None
|
||||
and getattr(litellm.upperbound_key_generate_params, key, None)
|
||||
is not None
|
||||
):
|
||||
# if value is float/int
|
||||
if key in [
|
||||
"max_budget",
|
||||
"max_parallel_requests",
|
||||
"tpm_limit",
|
||||
"rpm_limit",
|
||||
]:
|
||||
if value > getattr(litellm.upperbound_key_generate_params, key):
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail={
|
||||
"error": f"{key} is over max limit set in config - user_value={value}; max_value={getattr(litellm.upperbound_key_generate_params, key)}"
|
||||
},
|
||||
)
|
||||
elif key == "budget_duration":
|
||||
# budgets are in 1s, 1m, 1h, 1d, 1m (30s, 30m, 30h, 30d, 30m)
|
||||
# compare the duration in seconds and max duration in seconds
|
||||
upperbound_budget_duration = _duration_in_seconds(
|
||||
duration=getattr(
|
||||
litellm.upperbound_key_generate_params, key
|
||||
)
|
||||
)
|
||||
user_set_budget_duration = _duration_in_seconds(duration=value)
|
||||
if user_set_budget_duration > upperbound_budget_duration:
|
||||
raise HTTPException(
|
||||
status_code=400,
|
||||
detail={
|
||||
"error": f"Budget duration is over max limit set in config - user_value={user_set_budget_duration}; max_value={upperbound_budget_duration}"
|
||||
},
|
||||
)
|
||||
|
||||
# TODO: @ishaan-jaff: Migrate all budget tracking to use LiteLLM_BudgetTable
|
||||
_budget_id = None
|
||||
if prisma_client is not None and data.soft_budget is not None:
|
||||
# create the Budget Row for the LiteLLM Verification Token
|
||||
budget_row = LiteLLM_BudgetTable(
|
||||
soft_budget=data.soft_budget,
|
||||
model_max_budget=data.model_max_budget or {},
|
||||
)
|
||||
new_budget = prisma_client.jsonify_object(
|
||||
budget_row.json(exclude_none=True)
|
||||
)
|
||||
|
||||
_budget = await prisma_client.db.litellm_budgettable.create(
|
||||
data={
|
||||
**new_budget, # type: ignore
|
||||
"created_by": user_api_key_dict.user_id or litellm_proxy_admin_name,
|
||||
"updated_by": user_api_key_dict.user_id or litellm_proxy_admin_name,
|
||||
}
|
||||
)
|
||||
_budget_id = getattr(_budget, "budget_id", None)
|
||||
data_json = data.json() # type: ignore
|
||||
# if we get max_budget passed to /key/generate, then use it as key_max_budget. Since generate_key_helper_fn is used to make new users
|
||||
if "max_budget" in data_json:
|
||||
data_json["key_max_budget"] = data_json.pop("max_budget", None)
|
||||
if _budget_id is not None:
|
||||
data_json["budget_id"] = _budget_id
|
||||
|
||||
if "budget_duration" in data_json:
|
||||
data_json["key_budget_duration"] = data_json.pop("budget_duration", None)
|
||||
|
||||
response = await generate_key_helper_fn(
|
||||
request_type="key", **data_json, table_name="key"
|
||||
)
|
||||
|
||||
response["soft_budget"] = (
|
||||
data.soft_budget
|
||||
) # include the user-input soft budget in the response
|
||||
|
||||
if data.send_invite_email is True:
|
||||
if "email" not in general_settings.get("alerting", []):
|
||||
raise ValueError(
|
||||
"Email alerting not setup on config.yaml. Please set `alerting=['email']. \nDocs: https://docs.litellm.ai/docs/proxy/email`"
|
||||
)
|
||||
event = WebhookEvent(
|
||||
event="key_created",
|
||||
event_group="key",
|
||||
event_message=f"API Key Created",
|
||||
token=response.get("token", ""),
|
||||
spend=response.get("spend", 0.0),
|
||||
max_budget=response.get("max_budget", 0.0),
|
||||
user_id=response.get("user_id", None),
|
||||
team_id=response.get("team_id", "Default Team"),
|
||||
key_alias=response.get("key_alias", None),
|
||||
)
|
||||
|
||||
# If user configured email alerting - send an Email letting their end-user know the key was created
|
||||
asyncio.create_task(
|
||||
proxy_logging_obj.slack_alerting_instance.send_key_created_or_user_invited_email(
|
||||
webhook_event=event,
|
||||
)
|
||||
)
|
||||
|
||||
# Enterprise Feature - Audit Logging. Enable with litellm.store_audit_logs = True
|
||||
if litellm.store_audit_logs is True:
|
||||
_updated_values = json.dumps(response, default=str)
|
||||
asyncio.create_task(
|
||||
create_audit_log_for_update(
|
||||
request_data=LiteLLM_AuditLogs(
|
||||
id=str(uuid.uuid4()),
|
||||
updated_at=datetime.now(timezone.utc),
|
||||
changed_by=litellm_changed_by
|
||||
or user_api_key_dict.user_id
|
||||
or litellm_proxy_admin_name,
|
||||
changed_by_api_key=user_api_key_dict.api_key,
|
||||
table_name=LitellmTableNames.KEY_TABLE_NAME,
|
||||
object_id=response.get("token_id", ""),
|
||||
action="created",
|
||||
updated_values=_updated_values,
|
||||
before_value=None,
|
||||
)
|
||||
)
|
||||
)
|
||||
|
||||
return GenerateKeyResponse(**response)
|
||||
except Exception as e:
|
||||
verbose_proxy_logger.error(
|
||||
"litellm.proxy.proxy_server.generate_key_fn(): Exception occured - {}".format(
|
||||
str(e)
|
||||
)
|
||||
)
|
||||
verbose_proxy_logger.debug(traceback.format_exc())
|
||||
if isinstance(e, HTTPException):
|
||||
raise ProxyException(
|
||||
message=getattr(e, "detail", f"Authentication Error({str(e)})"),
|
||||
type="auth_error",
|
||||
param=getattr(e, "param", "None"),
|
||||
code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
|
||||
)
|
||||
elif isinstance(e, ProxyException):
|
||||
raise e
|
||||
raise ProxyException(
|
||||
message="Authentication Error, " + str(e),
|
||||
type="auth_error",
|
||||
param=getattr(e, "param", "None"),
|
||||
code=status.HTTP_400_BAD_REQUEST,
|
||||
)
|
||||
|
||||
|
||||
@router.post(
    "/key/update", tags=["key management"], dependencies=[Depends(user_api_key_auth)]
)
async def update_key_fn(
    request: Request,
    data: UpdateKeyRequest,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
    litellm_changed_by: Optional[str] = Header(
        None,
        description="The litellm-changed-by header enables tracking of actions performed by authorized users on behalf of other users, providing an audit trail for accountability",
    ),
):
    """
    Update an existing key.

    Parameters:
    - key: str - The token to update (required, in the request body).
    - all other body fields: applied to the key. Values that are None, `[]`,
      `{}` or `0` are treated as "not provided" and left unchanged, since
      those are the request model's defaults.

    Returns:
        Dict containing the key and its updated values.

    Raises:
        ProxyException: wraps any HTTPException / DB error encountered.
    """
    from litellm.proxy.proxy_server import (
        prisma_client,
        litellm_proxy_admin_name,
        create_audit_log_for_update,
        user_api_key_cache,
    )

    try:
        data_json: dict = data.json()
        key = data_json.pop("key")
        # get the row from db
        if prisma_client is None:
            raise Exception("Not connected to DB!")

        existing_key_row = await prisma_client.get_data(
            token=data.key, table_name="key", query_type="find_unique"
        )

        if existing_key_row is None:
            raise HTTPException(
                status_code=404,
                # this endpoint updates keys, not teams - report the missing key
                detail={"error": f"Key not found, passed key={data.key}"},
            )

        # keep only values the caller actually set; models default to [],
        # spend defaults to 0 - we should not reset those on the stored row
        non_default_values = {}
        for k, v in data_json.items():
            if v is not None and v not in (
                [],
                {},
                0,
            ):
                non_default_values[k] = v

        # "duration" is a convenience field - convert it to a concrete expiry
        if "duration" in non_default_values:
            duration = non_default_values.pop("duration")
            duration_s = _duration_in_seconds(duration=duration)
            expires = datetime.now(timezone.utc) + timedelta(seconds=duration_s)
            non_default_values["expires"] = expires

        response = await prisma_client.update_data(
            token=key, data={**non_default_values, "token": key}
        )

        # Delete - key from cache, since it's been updated!
        # key updated - a new model could have been added to this key. it should not block requests after this is done
        user_api_key_cache.delete_cache(key)
        hashed_token = hash_token(key)
        user_api_key_cache.delete_cache(hashed_token)

        # Enterprise Feature - Audit Logging. Enable with litellm.store_audit_logs = True
        if litellm.store_audit_logs is True:
            _updated_values = json.dumps(data_json, default=str)

            _before_value = existing_key_row.json(exclude_none=True)
            _before_value = json.dumps(_before_value, default=str)

            asyncio.create_task(
                create_audit_log_for_update(
                    request_data=LiteLLM_AuditLogs(
                        id=str(uuid.uuid4()),
                        updated_at=datetime.now(timezone.utc),
                        changed_by=litellm_changed_by
                        or user_api_key_dict.user_id
                        or litellm_proxy_admin_name,
                        changed_by_api_key=user_api_key_dict.api_key,
                        table_name=LitellmTableNames.KEY_TABLE_NAME,
                        object_id=data.key,
                        action="updated",
                        updated_values=_updated_values,
                        before_value=_before_value,
                    )
                )
            )

        return {"key": key, **response["data"]}
    except Exception as e:
        if isinstance(e, HTTPException):
            raise ProxyException(
                message=getattr(e, "detail", f"Authentication Error({str(e)})"),
                type="auth_error",
                param=getattr(e, "param", "None"),
                code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
            )
        elif isinstance(e, ProxyException):
            raise e
        raise ProxyException(
            message="Authentication Error, " + str(e),
            type="auth_error",
            param=getattr(e, "param", "None"),
            code=status.HTTP_400_BAD_REQUEST,
        )
|
||||
|
||||
|
||||
@router.post(
    "/key/delete", tags=["key management"], dependencies=[Depends(user_api_key_auth)]
)
async def delete_key_fn(
    data: KeyRequest,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
    litellm_changed_by: Optional[str] = Header(
        None,
        description="The litellm-changed-by header enables tracking of actions performed by authorized users on behalf of other users, providing an audit trail for accountability",
    ),
):
    """
    Delete a key from the key management system.

    Parameters:
    - keys (List[str]): A list of keys or hashed keys to delete. Example {"keys": ["sk-QWrxEynunsNpV1zT48HIrw", "837e17519f44683334df5291321d97b8bf1098cd490e49e215f6fea935aa28be"]}

    Returns:
    - deleted_keys (List[str]): A list of deleted keys. Example {"deleted_keys": ["sk-QWrxEynunsNpV1zT48HIrw", "837e17519f44683334df5291321d97b8bf1098cd490e49e215f6fea935aa28be"]}

    Raises:
        HTTPException: If an error occurs during key deletion (including a
        partial delete, where fewer keys were removed than were requested).
    """
    try:
        # imported lazily to avoid a circular import with proxy_server
        from litellm.proxy.proxy_server import (
            user_custom_key_generate,
            prisma_client,
            litellm_proxy_admin_name,
            general_settings,
            proxy_logging_obj,
            create_audit_log_for_update,
            user_api_key_cache,
        )

        keys = data.keys
        if len(keys) == 0:
            # reject empty payloads early - nothing to delete
            raise ProxyException(
                message=f"No keys provided, passed in: keys={keys}",
                type="auth_error",
                param="keys",
                code=status.HTTP_400_BAD_REQUEST,
            )

        ## only allow user to delete keys they own
        user_id = user_api_key_dict.user_id
        verbose_proxy_logger.debug(
            f"user_api_key_dict.user_role: {user_api_key_dict.user_role}"
        )
        if (
            user_api_key_dict.user_role is not None
            and user_api_key_dict.user_role == LitellmUserRoles.PROXY_ADMIN
        ):
            # admins may delete any key, so drop the ownership filter
            user_id = None  # unless they're admin

        # Enterprise Feature - Audit Logging. Enable with litellm.store_audit_logs = True
        # we do this after the first for loop, since first for loop is for validation. we only want this inserted after validation passes
        if litellm.store_audit_logs is True:
            # make an audit log for each team deleted
            # NOTE(review): prisma_client is not None-checked here (unlike other
            # endpoints) and key_row may be None if the token doesn't exist -
            # confirm upstream guarantees before relying on this path
            for key in data.keys:
                key_row = await prisma_client.get_data(  # type: ignore
                    token=key, table_name="key", query_type="find_unique"
                )

                # snapshot the row as JSON so the audit log records the
                # pre-deletion state of the key
                key_row = key_row.json(exclude_none=True)
                _key_row = json.dumps(key_row, default=str)

                # fire-and-forget: audit writes must not block the delete
                asyncio.create_task(
                    create_audit_log_for_update(
                        request_data=LiteLLM_AuditLogs(
                            id=str(uuid.uuid4()),
                            updated_at=datetime.now(timezone.utc),
                            changed_by=litellm_changed_by
                            or user_api_key_dict.user_id
                            or litellm_proxy_admin_name,
                            changed_by_api_key=user_api_key_dict.api_key,
                            table_name=LitellmTableNames.KEY_TABLE_NAME,
                            object_id=key,
                            action="deleted",
                            updated_values="{}",
                            before_value=_key_row,
                        )
                    )
                )

        number_deleted_keys = await delete_verification_token(
            tokens=keys, user_id=user_id
        )
        verbose_proxy_logger.debug(
            f"/key/delete - deleted_keys={number_deleted_keys['deleted_keys']}"
        )

        # a partial delete means some keys didn't belong to the caller -
        # surface that as a 400 rather than silently succeeding
        try:
            assert len(keys) == number_deleted_keys["deleted_keys"]
        except Exception as e:
            raise HTTPException(
                status_code=400,
                detail={
                    "error": f"Not all keys passed in were deleted. This probably means you don't have access to delete all the keys passed in. Keys passed in={len(keys)}, Deleted keys ={number_deleted_keys['deleted_keys']}"
                },
            )

        # evict both the raw and hashed form of every deleted key so stale
        # entries can't authenticate from cache
        for key in keys:
            user_api_key_cache.delete_cache(key)
            # remove hash token from cache
            hashed_token = hash_token(key)
            user_api_key_cache.delete_cache(hashed_token)

        verbose_proxy_logger.debug(
            f"/keys/delete - cache after delete: {user_api_key_cache.in_memory_cache.cache_dict}"
        )

        return {"deleted_keys": keys}
    except Exception as e:
        # normalize every failure into a ProxyException for consistent
        # error responses across the key-management endpoints
        if isinstance(e, HTTPException):
            raise ProxyException(
                message=getattr(e, "detail", f"Authentication Error({str(e)})"),
                type="auth_error",
                param=getattr(e, "param", "None"),
                code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
            )
        elif isinstance(e, ProxyException):
            raise e
        raise ProxyException(
            message="Authentication Error, " + str(e),
            type="auth_error",
            param=getattr(e, "param", "None"),
            code=status.HTTP_400_BAD_REQUEST,
        )
|
||||
|
||||
|
||||
@router.post(
    "/v2/key/info", tags=["key management"], dependencies=[Depends(user_api_key_auth)]
)
async def info_key_fn_v2(
    data: Optional[KeyRequest] = None,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Retrieve information about a list of keys.

    **New endpoint**. Currently admin only.

    Parameters:
        keys: Optional[list] = body parameter representing the key(s) in the request
        user_api_key_dict: UserAPIKeyAuth = Dependency representing the user's API key

    Returns:
        Dict containing the keys and their associated information.

    Example Curl:
    ```
    curl -X GET "http://0.0.0.0:8000/key/info" \
    -H "Authorization: Bearer sk-1234" \
    -d {"keys": ["sk-1", "sk-2", "sk-3"]}
    ```
    """
    # imported lazily to avoid a circular import with proxy_server
    from litellm.proxy.proxy_server import prisma_client

    try:
        if prisma_client is None:
            raise Exception(
                "Database not connected. Connect a database to your proxy - https://docs.litellm.ai/docs/simple_proxy#managing-auth---virtual-keys"
            )
        if data is None:
            raise HTTPException(
                status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
                detail={"message": "Malformed request. No keys passed in."},
            )

        key_info = await prisma_client.get_data(
            token=data.keys, table_name="key", query_type="find_all"
        )
        filtered_key_info = []
        for k in key_info:
            # pydantic v2 exposes model_dump(); fall back to .dict() on v1
            try:
                k = k.model_dump()  # noqa
            except Exception:
                k = k.dict()
            filtered_key_info.append(k)
        return {"key": data.keys, "info": filtered_key_info}

    except Exception as e:
        if isinstance(e, HTTPException):
            raise ProxyException(
                message=getattr(e, "detail", f"Authentication Error({str(e)})"),
                type="auth_error",
                param=getattr(e, "param", "None"),
                code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
            )
        elif isinstance(e, ProxyException):
            raise e
        raise ProxyException(
            message="Authentication Error, " + str(e),
            type="auth_error",
            param=getattr(e, "param", "None"),
            code=status.HTTP_400_BAD_REQUEST,
        )
|
||||
|
||||
|
||||
@router.get(
    "/key/info", tags=["key management"], dependencies=[Depends(user_api_key_auth)]
)
async def info_key_fn(
    key: Optional[str] = fastapi.Query(
        default=None, description="Key in the request parameters"
    ),
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Retrieve information about a key.

    Parameters:
        key: Optional[str] = Query parameter representing the key in the request
        user_api_key_dict: UserAPIKeyAuth = Dependency representing the user's API key

    Returns:
        Dict containing the key and its associated information.

    Example Curl:
    ```
    curl -X GET "http://0.0.0.0:8000/key/info?key=sk-02Wr4IAlN3NvPXvL5JVvDA" \
    -H "Authorization: Bearer sk-1234"
    ```

    Example Curl - if no key is passed, it will use the Key Passed in Authorization Header
    ```
    curl -X GET "http://0.0.0.0:8000/key/info" \
    -H "Authorization: Bearer sk-02Wr4IAlN3NvPXvL5JVvDA"
    ```
    """
    # imported lazily to avoid a circular import with proxy_server
    from litellm.proxy.proxy_server import prisma_client

    try:
        if prisma_client is None:
            raise Exception(
                "Database not connected. Connect a database to your proxy - https://docs.litellm.ai/docs/simple_proxy#managing-auth---virtual-keys"
            )
        # default to the caller's own key when none is passed explicitly
        if key is None:
            key = user_api_key_dict.api_key
        key_info = await prisma_client.get_data(token=key)
        ## REMOVE HASHED TOKEN INFO BEFORE RETURNING ##
        # pydantic v2 exposes model_dump(); fall back to .dict() on v1
        try:
            key_info = key_info.model_dump()  # noqa
        except Exception:
            key_info = key_info.dict()
        key_info.pop("token")
        return {"key": key, "info": key_info}
    except Exception as e:
        if isinstance(e, HTTPException):
            raise ProxyException(
                message=getattr(e, "detail", f"Authentication Error({str(e)})"),
                type="auth_error",
                param=getattr(e, "param", "None"),
                code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
            )
        elif isinstance(e, ProxyException):
            raise e
        raise ProxyException(
            message="Authentication Error, " + str(e),
            type="auth_error",
            param=getattr(e, "param", "None"),
            code=status.HTTP_400_BAD_REQUEST,
        )
|
||||
|
||||
|
||||
def _duration_in_seconds(duration: str):
|
||||
match = re.match(r"(\d+)([smhd]?)", duration)
|
||||
if not match:
|
||||
raise ValueError("Invalid duration format")
|
||||
|
||||
value, unit = match.groups()
|
||||
value = int(value)
|
||||
|
||||
if unit == "s":
|
||||
return value
|
||||
elif unit == "m":
|
||||
return value * 60
|
||||
elif unit == "h":
|
||||
return value * 3600
|
||||
elif unit == "d":
|
||||
return value * 86400
|
||||
else:
|
||||
raise ValueError("Unsupported duration unit")
|
||||
|
||||
|
||||
async def generate_key_helper_fn(
    request_type: Literal[
        "user", "key"
    ],  # identifies if this request is from /user/new or /key/generate
    duration: Optional[str],
    models: list,
    aliases: dict,
    config: dict,
    spend: float,
    key_max_budget: Optional[float] = None,  # key_max_budget is used to Budget Per key
    key_budget_duration: Optional[str] = None,
    budget_id: Optional[float] = None,  # budget id <-> LiteLLM_BudgetTable
    soft_budget: Optional[
        float
    ] = None,  # soft_budget is used to set soft Budgets Per user
    max_budget: Optional[float] = None,  # max_budget is used to Budget Per user
    budget_duration: Optional[str] = None,  # max_budget is used to Budget Per user
    token: Optional[str] = None,
    user_id: Optional[str] = None,
    team_id: Optional[str] = None,
    user_email: Optional[str] = None,
    user_role: Optional[str] = None,
    max_parallel_requests: Optional[int] = None,
    metadata: Optional[dict] = {},
    tpm_limit: Optional[int] = None,
    rpm_limit: Optional[int] = None,
    query_type: Literal["insert_data", "update_data"] = "insert_data",
    update_key_values: Optional[dict] = None,
    key_alias: Optional[str] = None,
    allowed_cache_controls: Optional[list] = [],
    permissions: Optional[dict] = {},
    model_max_budget: Optional[dict] = {},
    teams: Optional[list] = None,
    organization_id: Optional[str] = None,
    table_name: Optional[Literal["key", "user"]] = None,
    send_invite_email: Optional[bool] = None,
):
    """
    Shared helper behind `/key/generate` and `/user/new`.

    Creates (or updates) a user row and, unless this is a user-only request,
    inserts a new row into the LiteLLM_VerificationToken ("key") table.

    Key behaviors visible in this function:
    - Generates a `sk-...` token via `secrets.token_urlsafe` when none is given.
    - Converts `duration` / `key_budget_duration` / `budget_duration` strings
      into concrete expiry / reset datetimes (None means "never" / one-time).
    - Serializes dict/list params to JSON strings for DB storage, then
      deserializes copies of them for the returned payload.
    - Returns early with just `user_data` for the proxy-budget user or when
      `table_name == "user"` (no key row is created in those cases).
    - On `request_type == "user"`, merges user fields into the returned dict.

    Raises:
        Exception: if neither prisma_client nor custom_db_client is configured.
        ValueError: if a premium-only permission is requested by a free user.
        HTTPException: 500 on any DB failure (original exception is logged).

    NOTE(review): several parameters use mutable default arguments
    (metadata={}, permissions={}, ...). They are only serialized, never
    mutated, here - but confirm no caller mutates them before changing this.
    """
    # imported lazily to avoid a circular import with proxy_server
    from litellm.proxy.proxy_server import (
        prisma_client,
        custom_db_client,
        litellm_proxy_budget_name,
        premium_user,
    )

    if prisma_client is None and custom_db_client is None:
        raise Exception(
            f"Connect Proxy to database to generate keys - https://docs.litellm.ai/docs/proxy/virtual_keys "
        )

    # mint a new cryptographically-random token when none was supplied
    if token is None:
        token = f"sk-{secrets.token_urlsafe(16)}"

    if duration is None:  # allow tokens that never expire
        expires = None
    else:
        duration_s = _duration_in_seconds(duration=duration)
        expires = datetime.now(timezone.utc) + timedelta(seconds=duration_s)

    if key_budget_duration is None:  # one-time budget
        key_reset_at = None
    else:
        duration_s = _duration_in_seconds(duration=key_budget_duration)
        key_reset_at = datetime.now(timezone.utc) + timedelta(seconds=duration_s)

    if budget_duration is None:  # one-time budget
        reset_at = None
    else:
        duration_s = _duration_in_seconds(duration=budget_duration)
        reset_at = datetime.now(timezone.utc) + timedelta(seconds=duration_s)

    # dict/list fields are stored as JSON strings in the DB
    aliases_json = json.dumps(aliases)
    config_json = json.dumps(config)
    permissions_json = json.dumps(permissions)
    metadata_json = json.dumps(metadata)
    model_max_budget_json = json.dumps(model_max_budget)
    # NOTE(review): the next four lines are no-op self-assignments - they
    # have no effect and look like leftovers from a refactor
    user_role = user_role
    tpm_limit = tpm_limit
    rpm_limit = rpm_limit
    allowed_cache_controls = allowed_cache_controls

    try:
        # Create a new verification token (you may want to enhance this logic based on your needs)
        user_data = {
            "max_budget": max_budget,
            "user_email": user_email,
            "user_id": user_id,
            "team_id": team_id,
            "organization_id": organization_id,
            "user_role": user_role,
            "spend": spend,
            "models": models,
            "max_parallel_requests": max_parallel_requests,
            "tpm_limit": tpm_limit,
            "rpm_limit": rpm_limit,
            "budget_duration": budget_duration,
            "budget_reset_at": reset_at,
            "allowed_cache_controls": allowed_cache_controls,
        }
        if teams is not None:
            user_data["teams"] = teams
        key_data = {
            "token": token,
            "key_alias": key_alias,
            "expires": expires,
            "models": models,
            "aliases": aliases_json,
            "config": config_json,
            "spend": spend,
            "max_budget": key_max_budget,
            "user_id": user_id,
            "team_id": team_id,
            "max_parallel_requests": max_parallel_requests,
            "metadata": metadata_json,
            "tpm_limit": tpm_limit,
            "rpm_limit": rpm_limit,
            "budget_duration": key_budget_duration,
            "budget_reset_at": key_reset_at,
            "allowed_cache_controls": allowed_cache_controls,
            "permissions": permissions_json,
            "model_max_budget": model_max_budget_json,
            "budget_id": budget_id,
        }

        if (
            litellm.get_secret("DISABLE_KEY_NAME", False) == True
        ):  # allow user to disable storing abbreviated key name (shown in UI, to help figure out which key spent how much)
            pass
        else:
            key_data["key_name"] = f"sk-...{token[-4:]}"
        # deep copy so the JSON-string -> dict conversions below don't
        # affect the payload being written to the DB
        saved_token = copy.deepcopy(key_data)
        if isinstance(saved_token["aliases"], str):
            saved_token["aliases"] = json.loads(saved_token["aliases"])
        if isinstance(saved_token["config"], str):
            saved_token["config"] = json.loads(saved_token["config"])
        if isinstance(saved_token["metadata"], str):
            saved_token["metadata"] = json.loads(saved_token["metadata"])
        if isinstance(saved_token["permissions"], str):
            # gate enterprise-only permissions before accepting them
            if (
                "get_spend_routes" in saved_token["permissions"]
                and premium_user != True
            ):
                raise ValueError(
                    "get_spend_routes permission is only available for LiteLLM Enterprise users"
                )

            saved_token["permissions"] = json.loads(saved_token["permissions"])
        if isinstance(saved_token["model_max_budget"], str):
            saved_token["model_max_budget"] = json.loads(
                saved_token["model_max_budget"]
            )

        # make the expiry JSON-serializable for the returned payload
        if saved_token.get("expires", None) is not None and isinstance(
            saved_token["expires"], datetime
        ):
            saved_token["expires"] = saved_token["expires"].isoformat()
        if prisma_client is not None:
            if (
                table_name is None or table_name == "user"
            ):  # do not auto-create users for `/key/generate`
                ## CREATE USER (If necessary)
                if query_type == "insert_data":
                    user_row = await prisma_client.insert_data(
                        data=user_data, table_name="user"
                    )
                    ## use default user model list if no key-specific model list provided
                    if len(user_row.models) > 0 and len(key_data["models"]) == 0:  # type: ignore
                        key_data["models"] = user_row.models
                elif query_type == "update_data":
                    user_row = await prisma_client.update_data(
                        data=user_data,
                        table_name="user",
                        update_key_values=update_key_values,
                    )
            if user_id == litellm_proxy_budget_name or (
                table_name is not None and table_name == "user"
            ):
                # do not create a key for litellm_proxy_budget_name or if table name is set to just 'user'
                # we only need to ensure this exists in the user table
                # the LiteLLM_VerificationToken table will increase in size if we don't do this check
                return user_data

            ## CREATE KEY
            verbose_proxy_logger.debug("prisma_client: Creating Key= %s", key_data)
            create_key_response = await prisma_client.insert_data(
                data=key_data, table_name="key"
            )
            key_data["token_id"] = getattr(create_key_response, "token", None)
    except Exception as e:
        verbose_proxy_logger.error(
            "litellm.proxy.proxy_server.generate_key_helper_fn(): Exception occured - {}".format(
                str(e)
            )
        )
        verbose_proxy_logger.debug(traceback.format_exc())
        if isinstance(e, HTTPException):
            raise e
        raise HTTPException(
            status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
            detail={"error": "Internal Server Error."},
        )

    # Add budget related info in key_data - this ensures it's returned
    key_data["budget_id"] = budget_id

    if request_type == "user":
        # if this is a /user/new request update the key_date with user_data fields
        key_data.update(user_data)
    return key_data
|
||||
|
||||
|
||||
async def delete_verification_token(tokens: List, user_id: Optional[str] = None):
    """
    Delete the given verification tokens from the DB.

    Args:
        tokens: raw or hashed tokens to delete.
        user_id: ownership filter - admins (litellm_proxy_admin_name) delete
            without a filter; everyone else only deletes their own tokens.

    Returns:
        The dict returned by `prisma_client.delete_data`, containing a
        "deleted_keys" count.

    Raises:
        Exception: if the DB is not connected, or if fewer tokens were
        deleted than requested (some tokens didn't belong to `user_id`).
        Any failure is logged before being re-raised.
    """
    # imported lazily to avoid a circular import with proxy_server
    from litellm.proxy.proxy_server import prisma_client, litellm_proxy_admin_name

    try:
        if not prisma_client:
            raise Exception("DB not connected. prisma_client is None")

        # Admins delete unfiltered; everyone else is scoped to their user_id.
        delete_kwargs = {"tokens": tokens}
        if user_id != litellm_proxy_admin_name:
            delete_kwargs["user_id"] = user_id
        deleted_tokens = await prisma_client.delete_data(**delete_kwargs)

        if deleted_tokens.get("deleted_keys", 0) != len(tokens):
            raise Exception(
                "Failed to delete all tokens. Tried to delete tokens that don't belong to user: "
                + str(user_id)
            )
    except Exception as e:
        verbose_proxy_logger.error(
            "litellm.proxy.proxy_server.delete_verification_token(): Exception occured - {}".format(
                str(e)
            )
        )
        verbose_proxy_logger.debug(traceback.format_exc())
        raise e
    return deleted_tokens
|
899
litellm/proxy/management_endpoints/team_endpoints.py
Normal file
899
litellm/proxy/management_endpoints/team_endpoints.py
Normal file
|
@ -0,0 +1,899 @@
|
|||
from typing import Optional, List
|
||||
import fastapi
|
||||
from fastapi import Depends, Request, APIRouter, Header, status
|
||||
from fastapi import HTTPException
|
||||
import copy
|
||||
import json
|
||||
import uuid
|
||||
import litellm
|
||||
import asyncio
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
|
||||
from litellm.proxy._types import (
|
||||
UserAPIKeyAuth,
|
||||
LiteLLM_TeamTable,
|
||||
LiteLLM_ModelTable,
|
||||
LitellmUserRoles,
|
||||
NewTeamRequest,
|
||||
TeamMemberAddRequest,
|
||||
UpdateTeamRequest,
|
||||
BlockTeamRequest,
|
||||
DeleteTeamRequest,
|
||||
Member,
|
||||
LitellmTableNames,
|
||||
LiteLLM_AuditLogs,
|
||||
TeamMemberDeleteRequest,
|
||||
ProxyException,
|
||||
CommonProxyErrors,
|
||||
)
|
||||
from litellm.proxy.management_helpers.utils import (
|
||||
add_new_member,
|
||||
management_endpoint_wrapper,
|
||||
)
|
||||
|
||||
# Router for all /team/* management endpoints; included by the proxy server app.
router = APIRouter()
|
||||
|
||||
|
||||
#### TEAM MANAGEMENT ####
|
||||
@router.post(
    "/team/new",
    tags=["team management"],
    dependencies=[Depends(user_api_key_auth)],
    response_model=LiteLLM_TeamTable,
)
@management_endpoint_wrapper
async def new_team(
    data: NewTeamRequest,
    http_request: Request,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
    litellm_changed_by: Optional[str] = Header(
        None,
        description="The litellm-changed-by header enables tracking of actions performed by authorized users on behalf of other users, providing an audit trail for accountability",
    ),
):
    """
    Allow users to create a new team. Apply user permissions to their team.

    👉 [Detailed Doc on setting team budgets](https://docs.litellm.ai/docs/proxy/team_budgets)

    Parameters:
    - team_alias: Optional[str] - User defined team alias
    - team_id: Optional[str] - The team id of the user. If none passed, we'll generate it.
    - members_with_roles: List[{"role": "admin" or "user", "user_id": "<user-id>"}] - A list of users and their roles in the team. Get user_id when making a new user via `/user/new`.
    - metadata: Optional[dict] - Metadata for team, store information for team. Example metadata = {"extra_info": "some info"}
    - tpm_limit: Optional[int] - The TPM (Tokens Per Minute) limit for this team - all keys with this team_id will have at max this TPM limit
    - rpm_limit: Optional[int] - The RPM (Requests Per Minute) limit for this team - all keys associated with this team_id will have at max this RPM limit
    - max_budget: Optional[float] - The maximum budget allocated to the team - all keys for this team_id will have at max this max_budget
    - budget_duration: Optional[str] - The duration of the budget for the team. Doc [here](https://docs.litellm.ai/docs/proxy/team_budgets)
    - models: Optional[list] - A list of models associated with the team - all keys for this team_id will have at most, these models. If empty, assumes all models are allowed.
    - blocked: bool - Flag indicating if the team is blocked or not - will stop all calls from keys with this team_id.

    Returns:
    - team_id: (str) Unique team id - used for tracking spend across multiple keys for same team id.

    _deprecated_params:
    - admins: list - A list of user_id's for the admin role
    - users: list - A list of user_id's for the user role

    Example Request:
    ```
    curl --location 'http://0.0.0.0:4000/team/new' \
    --header 'Authorization: Bearer sk-1234' \
    --header 'Content-Type: application/json' \
    --data '{
      "team_alias": "my-new-team_2",
      "members_with_roles": [{"role": "admin", "user_id": "user-1234"},
        {"role": "user", "user_id": "user-2434"}]
    }'
    ```

    ```
    curl --location 'http://0.0.0.0:4000/team/new' \
    --header 'Authorization: Bearer sk-1234' \
    --header 'Content-Type: application/json' \
    --data '{
      "team_alias": "QA Prod Bot",
      "max_budget": 0.000000001,
      "budget_duration": "1d"
    }'
    ```
    """
    # Imported lazily to avoid a circular import with proxy_server.
    from litellm.proxy.proxy_server import (
        prisma_client,
        litellm_proxy_admin_name,
        create_audit_log_for_update,
        _duration_in_seconds,
    )

    if prisma_client is None:
        raise HTTPException(status_code=500, detail={"error": "No db connected"})

    if data.team_id is None:
        # Auto-generate a team id when the caller did not supply one.
        data.team_id = str(uuid.uuid4())
    else:
        # Check if team_id exists already
        _existing_team_id = await prisma_client.get_data(
            team_id=data.team_id, table_name="team", query_type="find_unique"
        )
        if _existing_team_id is not None:
            raise HTTPException(
                status_code=400,
                detail={
                    "error": f"Team id = {data.team_id} already exists. Please use a different team id."
                },
            )

    # Non-admin callers cannot create a team with higher limits than their own key.
    if (
        user_api_key_dict.user_role is None
        or user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN
    ):  # don't restrict proxy admin
        if (
            data.tpm_limit is not None
            and user_api_key_dict.tpm_limit is not None
            and data.tpm_limit > user_api_key_dict.tpm_limit
        ):
            raise HTTPException(
                status_code=400,
                detail={
                    "error": f"tpm limit higher than user max. User tpm limit={user_api_key_dict.tpm_limit}. User role={user_api_key_dict.user_role}"
                },
            )

        if (
            data.rpm_limit is not None
            and user_api_key_dict.rpm_limit is not None
            and data.rpm_limit > user_api_key_dict.rpm_limit
        ):
            raise HTTPException(
                status_code=400,
                detail={
                    "error": f"rpm limit higher than user max. User rpm limit={user_api_key_dict.rpm_limit}. User role={user_api_key_dict.user_role}"
                },
            )

        if (
            data.max_budget is not None
            and user_api_key_dict.max_budget is not None
            and data.max_budget > user_api_key_dict.max_budget
        ):
            raise HTTPException(
                status_code=400,
                detail={
                    "error": f"max budget higher than user max. User max budget={user_api_key_dict.max_budget}. User role={user_api_key_dict.user_role}"
                },
            )

        # If the caller's key is model-restricted, the team may only use those models.
        if data.models is not None and len(user_api_key_dict.models) > 0:
            for m in data.models:
                if m not in user_api_key_dict.models:
                    raise HTTPException(
                        status_code=400,
                        detail={
                            "error": f"Model not in allowed user models. User allowed models={user_api_key_dict.models}. User id={user_api_key_dict.user_id}"
                        },
                    )

    # Ensure the creating user is a member of their own team (as admin) if not listed.
    if user_api_key_dict.user_id is not None:
        creating_user_in_list = False
        for member in data.members_with_roles:
            if member.user_id == user_api_key_dict.user_id:
                creating_user_in_list = True

        if creating_user_in_list == False:
            data.members_with_roles.append(
                Member(role="admin", user_id=user_api_key_dict.user_id)
            )

    ## ADD TO MODEL TABLE
    # Team-specific model aliases are stored in a separate row; the team row
    # references it via model_id.
    _model_id = None
    if data.model_aliases is not None and isinstance(data.model_aliases, dict):
        litellm_modeltable = LiteLLM_ModelTable(
            model_aliases=json.dumps(data.model_aliases),
            created_by=user_api_key_dict.user_id or litellm_proxy_admin_name,
            updated_by=user_api_key_dict.user_id or litellm_proxy_admin_name,
        )
        model_dict = await prisma_client.db.litellm_modeltable.create(
            {**litellm_modeltable.json(exclude_none=True)}  # type: ignore
        )  # type: ignore

        _model_id = model_dict.id

    ## ADD TO TEAM TABLE
    complete_team_data = LiteLLM_TeamTable(
        **data.json(),
        model_id=_model_id,
    )

    # If budget_duration is set, set `budget_reset_at`
    if complete_team_data.budget_duration is not None:
        duration_s = _duration_in_seconds(duration=complete_team_data.budget_duration)
        reset_at = datetime.now(timezone.utc) + timedelta(seconds=duration_s)
        complete_team_data.budget_reset_at = reset_at

    team_row = await prisma_client.insert_data(
        data=complete_team_data.json(exclude_none=True), table_name="team"
    )

    ## ADD TEAM ID TO USER TABLE ##
    for user in complete_team_data.members_with_roles:
        ## add team id to user row ##
        # NOTE(review): the "push " key below has a trailing space — presumably
        # matched by prisma_client's custom-query handling; confirm intended.
        await prisma_client.update_data(
            user_id=user.user_id,
            data={"user_id": user.user_id, "teams": [team_row.team_id]},
            update_key_values_custom_query={
                "teams": {
                    "push ": [team_row.team_id],
                }
            },
        )

    # Enterprise Feature - Audit Logging. Enable with litellm.store_audit_logs = True
    if litellm.store_audit_logs is True:
        _updated_values = complete_team_data.json(exclude_none=True)

        _updated_values = json.dumps(_updated_values, default=str)

        # Fire-and-forget: audit log write is not awaited on the request path.
        asyncio.create_task(
            create_audit_log_for_update(
                request_data=LiteLLM_AuditLogs(
                    id=str(uuid.uuid4()),
                    updated_at=datetime.now(timezone.utc),
                    changed_by=litellm_changed_by
                    or user_api_key_dict.user_id
                    or litellm_proxy_admin_name,
                    changed_by_api_key=user_api_key_dict.api_key,
                    table_name=LitellmTableNames.TEAM_TABLE_NAME,
                    object_id=data.team_id,
                    action="created",
                    updated_values=_updated_values,
                    before_value=None,
                )
            )
        )

    try:
        # pydantic v2
        return team_row.model_dump()
    except Exception as e:
        # pydantic v1 fallback
        return team_row.dict()
|
||||
|
||||
|
||||
@router.post(
    "/team/update", tags=["team management"], dependencies=[Depends(user_api_key_auth)]
)
@management_endpoint_wrapper
async def update_team(
    data: UpdateTeamRequest,
    http_request: Request,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
    litellm_changed_by: Optional[str] = Header(
        None,
        description="The litellm-changed-by header enables tracking of actions performed by authorized users on behalf of other users, providing an audit trail for accountability",
    ),
):
    """
    Use `/team/member_add` AND `/team/member/delete` to add/remove new team members

    You can now update team budget / rate limits via /team/update

    Parameters:
    - team_id: str - The team id of the user. Required param.
    - team_alias: Optional[str] - User defined team alias
    - metadata: Optional[dict] - Metadata for team, store information for team. Example metadata = {"team": "core-infra", "app": "app2", "email": "ishaan@berri.ai" }
    - tpm_limit: Optional[int] - The TPM (Tokens Per Minute) limit for this team - all keys with this team_id will have at max this TPM limit
    - rpm_limit: Optional[int] - The RPM (Requests Per Minute) limit for this team - all keys associated with this team_id will have at max this RPM limit
    - max_budget: Optional[float] - The maximum budget allocated to the team - all keys for this team_id will have at max this max_budget
    - budget_duration: Optional[str] - The duration of the budget for the team. Doc [here](https://docs.litellm.ai/docs/proxy/team_budgets)
    - models: Optional[list] - A list of models associated with the team - all keys for this team_id will have at most, these models. If empty, assumes all models are allowed.
    - blocked: bool - Flag indicating if the team is blocked or not - will stop all calls from keys with this team_id.

    Example - update team TPM Limit

    ```
    curl --location 'http://0.0.0.0:8000/team/update' \
    --header 'Authorization: Bearer sk-1234' \
    --header 'Content-Type: application/json' \
    --data-raw '{
        "team_id": "litellm-test-client-id-new",
        "tpm_limit": 100
    }'
    ```

    Example - Update Team `max_budget` budget
    ```
    curl --location 'http://0.0.0.0:8000/team/update' \
    --header 'Authorization: Bearer sk-1234' \
    --header 'Content-Type: application/json' \
    --data-raw '{
        "team_id": "litellm-test-client-id-new",
        "max_budget": 10
    }'
    ```
    """
    # Imported lazily to avoid a circular import with proxy_server.
    from litellm.proxy.proxy_server import (
        prisma_client,
        litellm_proxy_admin_name,
        create_audit_log_for_update,
        _duration_in_seconds,
    )

    if prisma_client is None:
        raise HTTPException(status_code=500, detail={"error": "No db connected"})

    if data.team_id is None:
        raise HTTPException(status_code=400, detail={"error": "No team id passed in"})
    verbose_proxy_logger.debug("/team/update - %s", data)

    # Fetch the current row first: both to 404 on unknown teams and to capture
    # the "before" value for the audit log below.
    existing_team_row = await prisma_client.get_data(
        team_id=data.team_id, table_name="team", query_type="find_unique"
    )
    if existing_team_row is None:
        raise HTTPException(
            status_code=404,
            detail={"error": f"Team not found, passed team_id={data.team_id}"},
        )

    # Only fields explicitly passed by the caller are written (exclude_none).
    updated_kv = data.json(exclude_none=True)

    # Check budget_duration and budget_reset_at
    if data.budget_duration is not None:
        duration_s = _duration_in_seconds(duration=data.budget_duration)
        reset_at = datetime.now(timezone.utc) + timedelta(seconds=duration_s)

        # set the budget_reset_at in DB
        updated_kv["budget_reset_at"] = reset_at

    team_row = await prisma_client.update_data(
        update_key_values=updated_kv,
        data=updated_kv,
        table_name="team",
        team_id=data.team_id,
    )

    # Enterprise Feature - Audit Logging. Enable with litellm.store_audit_logs = True
    if litellm.store_audit_logs is True:
        _before_value = existing_team_row.json(exclude_none=True)
        _before_value = json.dumps(_before_value, default=str)
        _after_value: str = json.dumps(updated_kv, default=str)

        # Fire-and-forget: audit log write is not awaited on the request path.
        asyncio.create_task(
            create_audit_log_for_update(
                request_data=LiteLLM_AuditLogs(
                    id=str(uuid.uuid4()),
                    updated_at=datetime.now(timezone.utc),
                    changed_by=litellm_changed_by
                    or user_api_key_dict.user_id
                    or litellm_proxy_admin_name,
                    changed_by_api_key=user_api_key_dict.api_key,
                    table_name=LitellmTableNames.TEAM_TABLE_NAME,
                    object_id=data.team_id,
                    action="updated",
                    updated_values=_after_value,
                    before_value=_before_value,
                )
            )
        )

    return team_row
|
||||
|
||||
|
||||
@router.post(
    "/team/member_add",
    tags=["team management"],
    dependencies=[Depends(user_api_key_auth)],
)
@management_endpoint_wrapper
async def team_member_add(
    data: TeamMemberAddRequest,
    http_request: Request,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    [BETA]

    Add new members (either via user_email or user_id) to a team

    If user doesn't exist, new user row will also be added to User Table

    Parameters:
    - team_id: str - the team to add the member(s) to. Required.
    - member: Member | List[Member] - the member(s) to add. Required.
    - max_budget_in_team: Optional[float] - per-member budget within this team.

    Raises:
    - HTTPException 500 if no DB is connected, 400 on missing/invalid params,
      404 if the team does not exist.

    ```

    curl -X POST 'http://0.0.0.0:4000/team/member_add' \
    -H 'Authorization: Bearer sk-1234' \
    -H 'Content-Type: application/json' \
    -d '{"team_id": "45e3e396-ee08-4a61-a88e-16b3ce7e0849", "member": {"role": "user", "user_id": "krrish247652@berri.ai"}}'

    ```
    """
    # Imported lazily to avoid a circular import with proxy_server.
    from litellm.proxy.proxy_server import (
        prisma_client,
        litellm_proxy_admin_name,
        create_audit_log_for_update,
        _duration_in_seconds,
    )

    if prisma_client is None:
        raise HTTPException(status_code=500, detail={"error": "No db connected"})

    if data.team_id is None:
        raise HTTPException(status_code=400, detail={"error": "No team id passed in"})

    if data.member is None:
        raise HTTPException(
            status_code=400, detail={"error": "No member/members passed in"}
        )

    existing_team_row = await prisma_client.db.litellm_teamtable.find_unique(
        where={"team_id": data.team_id}
    )
    if existing_team_row is None:
        raise HTTPException(
            status_code=404,
            detail={
                "error": f"Team not found for team_id={getattr(data, 'team_id', None)}"
            },
        )

    complete_team_data = LiteLLM_TeamTable(**existing_team_row.model_dump())

    # Normalize single-member and list-of-members requests to one list, then
    # append to the in-memory team row before persisting.
    new_members: List[Member] = (
        [data.member] if isinstance(data.member, Member) else list(data.member)
    )
    complete_team_data.members_with_roles.extend(new_members)

    # ADD MEMBER TO TEAM - persist the updated membership list.
    _db_team_members = [m.model_dump() for m in complete_team_data.members_with_roles]
    updated_team = await prisma_client.db.litellm_teamtable.update(
        where={"team_id": data.team_id},
        data={"members_with_roles": json.dumps(_db_team_members)},  # type: ignore
    )

    # Upsert user rows / team-membership budgets, one member at a time.
    # BUGFIX: the previous version kept a `tasks` list that was never populated
    # and then called `asyncio.gather(*tasks)` - a no-op on an empty list, since
    # each `add_new_member` had already been awaited inside the loop. The dead
    # list and gather call are removed; behavior (sequential awaits) unchanged.
    for m in new_members:
        await add_new_member(
            new_member=m,
            max_budget_in_team=data.max_budget_in_team,
            prisma_client=prisma_client,
            user_api_key_dict=user_api_key_dict,
            litellm_proxy_admin_name=litellm_proxy_admin_name,
            team_id=data.team_id,
        )

    return updated_team
|
||||
|
||||
|
||||
@router.post(
    "/team/member_delete",
    tags=["team management"],
    dependencies=[Depends(user_api_key_auth)],
)
@management_endpoint_wrapper
async def team_member_delete(
    data: TeamMemberDeleteRequest,
    http_request: Request,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    [BETA]

    delete members (either via user_email or user_id) from a team

    If user doesn't exist, an exception will be raised
    ```
    curl -X POST 'http://0.0.0.0:8000/team/update' \
    -H 'Authorization: Bearer sk-1234' \
    -H 'Content-Type: application/json' \
    -D '{
        "team_id": "45e3e396-ee08-4a61-a88e-16b3ce7e0849",
        "user_id": "krrish247652@berri.ai"
    }'
    ```
    """
    # Imported lazily to avoid a circular import with proxy_server.
    from litellm.proxy.proxy_server import (
        prisma_client,
        litellm_proxy_admin_name,
        create_audit_log_for_update,
        _duration_in_seconds,
    )

    if prisma_client is None:
        raise HTTPException(status_code=500, detail={"error": "No db connected"})

    if data.team_id is None:
        raise HTTPException(status_code=400, detail={"error": "No team id passed in"})

    # Member can be identified by either id or email; at least one is required.
    if data.user_id is None and data.user_email is None:
        raise HTTPException(
            status_code=400,
            detail={"error": "Either user_id or user_email needs to be passed in"},
        )

    _existing_team_row = await prisma_client.db.litellm_teamtable.find_unique(
        where={"team_id": data.team_id}
    )

    if _existing_team_row is None:
        raise HTTPException(
            status_code=400,
            detail={"error": "Team id={} does not exist in db".format(data.team_id)},
        )
    existing_team_row = LiteLLM_TeamTable(**_existing_team_row.model_dump())

    ## DELETE MEMBER FROM TEAM
    # Rebuild the member list, skipping any entry matching the given
    # user_id or user_email.
    new_team_members: List[Member] = []
    for m in existing_team_row.members_with_roles:
        if (
            data.user_id is not None
            and m.user_id is not None
            and data.user_id == m.user_id
        ):
            continue
        elif (
            data.user_email is not None
            and m.user_email is not None
            and data.user_email == m.user_email
        ):
            continue
        new_team_members.append(m)
    existing_team_row.members_with_roles = new_team_members

    _db_new_team_members: List[dict] = [m.model_dump() for m in new_team_members]

    _ = await prisma_client.db.litellm_teamtable.update(
        where={
            "team_id": data.team_id,
        },
        data={"members_with_roles": json.dumps(_db_new_team_members)},  # type: ignore
    )

    ## DELETE TEAM ID from USER ROW, IF EXISTS ##
    # get user row
    key_val = {}
    if data.user_id is not None:
        key_val["user_id"] = data.user_id
    elif data.user_email is not None:
        key_val["user_email"] = data.user_email
    existing_user_rows = await prisma_client.db.litellm_usertable.find_many(
        where=key_val  # type: ignore
    )

    if existing_user_rows is not None and (
        isinstance(existing_user_rows, list) and len(existing_user_rows) > 0
    ):
        for existing_user in existing_user_rows:
            team_list = []
            if data.team_id in existing_user.teams:
                # Remove this team from the user's `teams` list and persist.
                team_list = existing_user.teams
                team_list.remove(data.team_id)
                await prisma_client.db.litellm_usertable.update(
                    where={
                        "user_id": existing_user.user_id,
                    },
                    data={"teams": {"set": team_list}},
                )

    # Returns the team row with the member(s) already removed.
    return existing_team_row
|
||||
|
||||
|
||||
@router.post(
    "/team/delete", tags=["team management"], dependencies=[Depends(user_api_key_auth)]
)
@management_endpoint_wrapper
async def delete_team(
    data: DeleteTeamRequest,
    http_request: Request,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
    litellm_changed_by: Optional[str] = Header(
        None,
        description="The litellm-changed-by header enables tracking of actions performed by authorized users on behalf of other users, providing an audit trail for accountability",
    ),
):
    """
    delete team and associated team keys

    ```
    curl --location 'http://0.0.0.0:8000/team/delete' \
    --header 'Authorization: Bearer sk-1234' \
    --header 'Content-Type: application/json' \
    --data-raw '{
        "team_ids": ["45e3e396-ee08-4a61-a88e-16b3ce7e0849"]
    }'
    ```
    """
    # Imported lazily to avoid a circular import with proxy_server.
    from litellm.proxy.proxy_server import (
        prisma_client,
        litellm_proxy_admin_name,
        create_audit_log_for_update,
        _duration_in_seconds,
    )

    if prisma_client is None:
        raise HTTPException(status_code=500, detail={"error": "No db connected"})

    if data.team_ids is None:
        raise HTTPException(status_code=400, detail={"error": "No team id passed in"})

    # check that all teams passed exist
    for team_id in data.team_ids:
        team_row = await prisma_client.get_data(  # type: ignore
            team_id=team_id, table_name="team", query_type="find_unique"
        )
        if team_row is None:
            raise HTTPException(
                status_code=404,
                detail={"error": f"Team not found, passed team_id={team_id}"},
            )

    # Enterprise Feature - Audit Logging. Enable with litellm.store_audit_logs = True
    # we do this after the first for loop, since first for loop is for validation. we only want this inserted after validation passes
    if litellm.store_audit_logs is True:
        # make an audit log for each team deleted
        for team_id in data.team_ids:
            # Re-fetch to capture the row's state as the "before" value.
            team_row = await prisma_client.get_data(  # type: ignore
                team_id=team_id, table_name="team", query_type="find_unique"
            )

            _team_row = team_row.json(exclude_none=True)

            # Fire-and-forget: audit log write is not awaited on the request path.
            asyncio.create_task(
                create_audit_log_for_update(
                    request_data=LiteLLM_AuditLogs(
                        id=str(uuid.uuid4()),
                        updated_at=datetime.now(timezone.utc),
                        changed_by=litellm_changed_by
                        or user_api_key_dict.user_id
                        or litellm_proxy_admin_name,
                        changed_by_api_key=user_api_key_dict.api_key,
                        table_name=LitellmTableNames.TEAM_TABLE_NAME,
                        object_id=team_id,
                        action="deleted",
                        updated_values="{}",
                        before_value=_team_row,
                    )
                )
            )

    # End of Audit logging

    ## DELETE ASSOCIATED KEYS
    # Keys are deleted first so no orphaned keys reference a missing team.
    await prisma_client.delete_data(team_id_list=data.team_ids, table_name="key")
    ## DELETE TEAMS
    deleted_teams = await prisma_client.delete_data(
        team_id_list=data.team_ids, table_name="team"
    )
    return deleted_teams
|
||||
|
||||
|
||||
@router.get(
    "/team/info", tags=["team management"], dependencies=[Depends(user_api_key_auth)]
)
@management_endpoint_wrapper
async def team_info(
    http_request: Request,
    team_id: str = fastapi.Query(
        default=None, description="Team ID in the request parameters"
    ),
):
    """
    get info on team + related keys

    Parameters:
    - team_id: str - the team to look up. Required (422 if missing).

    Returns:
    - dict with `team_id`, `team_info` (the team row) and `keys`
      (the team's key rows, serialized to dicts with the hashed `token` removed).

    Raises:
    - ProxyException wrapping 500 (no DB), 422 (missing team id), 404 (unknown team).

    ```
    curl --location 'http://localhost:4000/team/info' \
    --header 'Authorization: Bearer sk-1234' \
    --header 'Content-Type: application/json' \
    --data '{
        "teams": ["<team-id>",..]
    }'
    ```
    """
    # Imported lazily to avoid a circular import with proxy_server.
    from litellm.proxy.proxy_server import (
        prisma_client,
        litellm_proxy_admin_name,
        create_audit_log_for_update,
        _duration_in_seconds,
    )

    try:
        if prisma_client is None:
            raise HTTPException(
                status_code=status.HTTP_500_INTERNAL_SERVER_ERROR,
                detail={
                    "error": f"Database not connected. Connect a database to your proxy - https://docs.litellm.ai/docs/simple_proxy#managing-auth---virtual-keys"
                },
            )
        if team_id is None:
            raise HTTPException(
                status_code=status.HTTP_422_UNPROCESSABLE_ENTITY,
                detail={"message": "Malformed request. No team id passed in."},
            )

        team_info = await prisma_client.get_data(
            team_id=team_id, table_name="team", query_type="find_unique"
        )
        if team_info is None:
            raise HTTPException(
                status_code=status.HTTP_404_NOT_FOUND,
                detail={"message": f"Team not found, passed team id: {team_id}."},
            )

        ## GET ALL KEYS ##
        keys = await prisma_client.get_data(
            team_id=team_id,
            table_name="key",
            query_type="find_all",
            expires=datetime.now(),
        )
        # NOTE: a second `if team_info is None` fallback (summing key spend) used
        # to live here - it was unreachable, since a None team_info already
        # raised a 404 above. Removed as dead code.

        ## REMOVE HASHED TOKEN INFO before returning ##
        # BUGFIX: previously the loop did `key = key.model_dump(); key.pop("token")`,
        # which rebound the loop variable to a throwaway dict - the original `keys`
        # list (hashed tokens included) was still returned. Build a sanitized list
        # and return that instead.
        sanitized_keys: List[dict] = []
        for key in keys:
            try:
                _key = key.model_dump()  # noqa
            except Exception:
                # if using pydantic v1
                _key = key.dict()
            _key.pop("token", None)
            sanitized_keys.append(_key)

        return {"team_id": team_id, "team_info": team_info, "keys": sanitized_keys}

    except Exception as e:
        # Normalize every failure into a ProxyException so the proxy's error
        # middleware returns a consistent payload.
        if isinstance(e, HTTPException):
            raise ProxyException(
                message=getattr(e, "detail", f"Authentication Error({str(e)})"),
                type="auth_error",
                param=getattr(e, "param", "None"),
                code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
            )
        elif isinstance(e, ProxyException):
            raise e
        raise ProxyException(
            message="Authentication Error, " + str(e),
            type="auth_error",
            param=getattr(e, "param", "None"),
            code=status.HTTP_400_BAD_REQUEST,
        )
|
||||
|
||||
|
||||
@router.post(
    "/team/block", tags=["team management"], dependencies=[Depends(user_api_key_auth)]
)
@management_endpoint_wrapper
async def block_team(
    data: BlockTeamRequest,
    http_request: Request,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Blocks all calls from keys with this team id.
    """
    # Imported lazily to avoid a circular import with proxy_server.
    from litellm.proxy.proxy_server import (
        prisma_client,
        litellm_proxy_admin_name,
        create_audit_log_for_update,
        _duration_in_seconds,
    )

    if prisma_client is None:
        raise Exception("No DB Connected.")

    # Flip the team's `blocked` flag on and return the updated row.
    blocked_team_row = await prisma_client.db.litellm_teamtable.update(
        where={"team_id": data.team_id}, data={"blocked": True}  # type: ignore
    )

    return blocked_team_row
|
||||
|
||||
|
||||
@router.post(
    "/team/unblock", tags=["team management"], dependencies=[Depends(user_api_key_auth)]
)
@management_endpoint_wrapper
async def unblock_team(
    data: BlockTeamRequest,
    http_request: Request,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    Unblocks all calls from keys with this team id.
    """
    # Imported lazily to avoid a circular import with proxy_server.
    from litellm.proxy.proxy_server import (
        prisma_client,
        litellm_proxy_admin_name,
        create_audit_log_for_update,
        _duration_in_seconds,
    )

    if prisma_client is None:
        raise Exception("No DB Connected.")

    # Clear the team's `blocked` flag and return the updated row.
    record = await prisma_client.db.litellm_teamtable.update(
        where={"team_id": data.team_id}, data={"blocked": False}  # type: ignore
    )

    return record
|
||||
|
||||
|
||||
@router.get(
    "/team/list", tags=["team management"], dependencies=[Depends(user_api_key_auth)]
)
@management_endpoint_wrapper
async def list_team(
    http_request: Request,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
):
    """
    [Admin-only] List all available teams

    ```
    curl --location --request GET 'http://0.0.0.0:4000/team/list' \
        --header 'Authorization: Bearer sk-1234'
    ```
    """
    # Imported lazily to avoid a circular import with proxy_server.
    from litellm.proxy.proxy_server import (
        prisma_client,
        litellm_proxy_admin_name,
        create_audit_log_for_update,
        _duration_in_seconds,
    )

    # Only the proxy admin may enumerate every team on the instance.
    if user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN:
        raise HTTPException(
            status_code=401,
            detail={
                "error": "Admin-only endpoint. Your user role={}".format(
                    user_api_key_dict.user_role
                )
            },
        )

    if prisma_client is None:
        raise HTTPException(
            status_code=400,
            detail={"error": CommonProxyErrors.db_not_connected_error.value},
        )

    # Return every team row, unfiltered.
    return await prisma_client.db.litellm_teamtable.find_many()
|
|
@ -1,5 +1,11 @@
|
|||
# What is this?
|
||||
## Helper utils for the management endpoints (keys/users/teams)
|
||||
from datetime import datetime
|
||||
from functools import wraps
|
||||
from litellm.proxy._types import UserAPIKeyAuth, ManagementEndpointLoggingPayload
|
||||
from litellm.proxy.common_utils.http_parsing_utils import _read_request_body
|
||||
from litellm._logging import verbose_logger
|
||||
from fastapi import Request
|
||||
|
||||
from litellm.proxy._types import LiteLLM_TeamTable, Member, UserAPIKeyAuth
|
||||
from litellm.proxy.utils import PrismaClient
|
||||
|
@ -61,3 +67,110 @@ async def add_new_member(
|
|||
"budget_id": _budget_id,
|
||||
}
|
||||
)
|
||||
|
||||
|
||||
def management_endpoint_wrapper(func):
|
||||
"""
|
||||
This wrapper does the following:
|
||||
|
||||
1. Log I/O, Exceptions to OTEL
|
||||
2. Create an Audit log for success calls
|
||||
"""
|
||||
|
||||
@wraps(func)
|
||||
async def wrapper(*args, **kwargs):
|
||||
start_time = datetime.now()
|
||||
|
||||
try:
|
||||
result = await func(*args, **kwargs)
|
||||
end_time = datetime.now()
|
||||
try:
|
||||
if kwargs is None:
|
||||
kwargs = {}
|
||||
user_api_key_dict: UserAPIKeyAuth = (
|
||||
kwargs.get("user_api_key_dict") or UserAPIKeyAuth()
|
||||
)
|
||||
_http_request: Request = kwargs.get("http_request")
|
||||
parent_otel_span = user_api_key_dict.parent_otel_span
|
||||
if parent_otel_span is not None:
|
||||
from litellm.proxy.proxy_server import open_telemetry_logger
|
||||
|
||||
if open_telemetry_logger is not None:
|
||||
if _http_request:
|
||||
_route = _http_request.url.path
|
||||
_request_body: dict = await _read_request_body(
|
||||
request=_http_request
|
||||
)
|
||||
_response = dict(result) if result is not None else None
|
||||
|
||||
logging_payload = ManagementEndpointLoggingPayload(
|
||||
route=_route,
|
||||
request_data=_request_body,
|
||||
response=_response,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
)
|
||||
|
||||
await open_telemetry_logger.async_management_endpoint_success_hook(
|
||||
logging_payload=logging_payload,
|
||||
parent_otel_span=parent_otel_span,
|
||||
)
|
||||
|
||||
if _http_request:
|
||||
_route = _http_request.url.path
|
||||
# Flush user_api_key cache if this was an update/delete call to /key, /team, or /user
|
||||
if _route in [
|
||||
"/key/update",
|
||||
"/key/delete",
|
||||
"/team/update",
|
||||
"/team/delete",
|
||||
"/user/update",
|
||||
"/user/delete",
|
||||
"/customer/update",
|
||||
"/customer/delete",
|
||||
]:
|
||||
from litellm.proxy.proxy_server import user_api_key_cache
|
||||
|
||||
user_api_key_cache.flush_cache()
|
||||
except Exception as e:
|
||||
# Non-Blocking Exception
|
||||
verbose_logger.debug("Error in management endpoint wrapper: %s", str(e))
|
||||
pass
|
||||
|
||||
return result
|
||||
except Exception as e:
|
||||
end_time = datetime.now()
|
||||
|
||||
if kwargs is None:
|
||||
kwargs = {}
|
||||
user_api_key_dict: UserAPIKeyAuth = (
|
||||
kwargs.get("user_api_key_dict") or UserAPIKeyAuth()
|
||||
)
|
||||
parent_otel_span = user_api_key_dict.parent_otel_span
|
||||
if parent_otel_span is not None:
|
||||
from litellm.proxy.proxy_server import open_telemetry_logger
|
||||
|
||||
if open_telemetry_logger is not None:
|
||||
_http_request: Request = kwargs.get("http_request")
|
||||
if _http_request:
|
||||
_route = _http_request.url.path
|
||||
_request_body: dict = await _read_request_body(
|
||||
request=_http_request
|
||||
)
|
||||
logging_payload = ManagementEndpointLoggingPayload(
|
||||
route=_route,
|
||||
request_data=_request_body,
|
||||
response=None,
|
||||
start_time=start_time,
|
||||
end_time=end_time,
|
||||
exception=e,
|
||||
)
|
||||
|
||||
await open_telemetry_logger.async_management_endpoint_failure_hook(
|
||||
logging_payload=logging_payload,
|
||||
parent_otel_span=parent_otel_span,
|
||||
)
|
||||
|
||||
raise e
|
||||
|
||||
return wrapper
|
||||
|
|
File diff suppressed because it is too large
Load diff
File diff suppressed because it is too large
Load diff
|
@ -29,19 +29,22 @@ import pytest, logging, asyncio
|
|||
import litellm, asyncio
|
||||
from litellm.proxy.proxy_server import (
|
||||
new_user,
|
||||
generate_key_fn,
|
||||
user_api_key_auth,
|
||||
user_update,
|
||||
user_info,
|
||||
block_user,
|
||||
)
|
||||
from litellm.proxy.management_endpoints.key_management_endpoints import (
|
||||
delete_key_fn,
|
||||
info_key_fn,
|
||||
update_key_fn,
|
||||
generate_key_fn,
|
||||
generate_key_helper_fn,
|
||||
)
|
||||
from litellm.proxy.spend_reporting_endpoints.spend_management_endpoints import (
|
||||
spend_user_fn,
|
||||
spend_key_fn,
|
||||
view_spend_logs,
|
||||
user_info,
|
||||
block_user,
|
||||
)
|
||||
from litellm.proxy.utils import PrismaClient, ProxyLogging, hash_token
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
|
|
|
@ -14,6 +14,7 @@ sys.path.insert(
|
|||
import pytest
|
||||
from litellm.proxy._types import LiteLLM_JWTAuth, LiteLLMRoutes
|
||||
from litellm.proxy.auth.handle_jwt import JWTHandler
|
||||
from litellm.proxy.management_endpoints.team_endpoints import new_team
|
||||
from litellm.caching import DualCache
|
||||
from datetime import datetime, timedelta
|
||||
from fastapi import Request
|
||||
|
@ -218,7 +219,7 @@ async def test_team_token_output(prisma_client, audience):
|
|||
from cryptography.hazmat.backends import default_backend
|
||||
from fastapi import Request
|
||||
from starlette.datastructures import URL
|
||||
from litellm.proxy.proxy_server import user_api_key_auth, new_team
|
||||
from litellm.proxy.proxy_server import user_api_key_auth
|
||||
from litellm.proxy._types import NewTeamRequest, UserAPIKeyAuth
|
||||
import litellm
|
||||
import uuid
|
||||
|
@ -399,7 +400,6 @@ async def test_user_token_output(
|
|||
from starlette.datastructures import URL
|
||||
from litellm.proxy.proxy_server import (
|
||||
user_api_key_auth,
|
||||
new_team,
|
||||
new_user,
|
||||
user_info,
|
||||
)
|
||||
|
@ -623,7 +623,7 @@ async def test_allowed_routes_admin(prisma_client, audience):
|
|||
from cryptography.hazmat.backends import default_backend
|
||||
from fastapi import Request
|
||||
from starlette.datastructures import URL
|
||||
from litellm.proxy.proxy_server import user_api_key_auth, new_team
|
||||
from litellm.proxy.proxy_server import user_api_key_auth
|
||||
from litellm.proxy._types import NewTeamRequest, UserAPIKeyAuth
|
||||
import litellm
|
||||
import uuid
|
||||
|
|
|
@ -38,22 +38,9 @@ import pytest, logging, asyncio
|
|||
import litellm, asyncio
|
||||
from litellm.proxy.proxy_server import (
|
||||
new_user,
|
||||
generate_key_fn,
|
||||
user_api_key_auth,
|
||||
user_update,
|
||||
delete_key_fn,
|
||||
info_key_fn,
|
||||
update_key_fn,
|
||||
generate_key_fn,
|
||||
generate_key_helper_fn,
|
||||
spend_user_fn,
|
||||
spend_key_fn,
|
||||
view_spend_logs,
|
||||
user_info,
|
||||
team_info,
|
||||
info_key_fn,
|
||||
new_team,
|
||||
update_team,
|
||||
chat_completion,
|
||||
completion,
|
||||
embeddings,
|
||||
|
@ -63,6 +50,23 @@ from litellm.proxy.proxy_server import (
|
|||
model_list,
|
||||
LitellmUserRoles,
|
||||
)
|
||||
from litellm.proxy.management_endpoints.key_management_endpoints import (
|
||||
delete_key_fn,
|
||||
info_key_fn,
|
||||
update_key_fn,
|
||||
generate_key_fn,
|
||||
generate_key_helper_fn,
|
||||
)
|
||||
from litellm.proxy.management_endpoints.team_endpoints import (
|
||||
team_info,
|
||||
new_team,
|
||||
update_team,
|
||||
)
|
||||
from litellm.proxy.spend_reporting_endpoints.spend_management_endpoints import (
|
||||
spend_user_fn,
|
||||
spend_key_fn,
|
||||
view_spend_logs,
|
||||
)
|
||||
from litellm.proxy.utils import PrismaClient, ProxyLogging, hash_token, update_spend
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
|
||||
|
|
|
@ -26,7 +26,7 @@ logging.basicConfig(
|
|||
from fastapi.testclient import TestClient
|
||||
from fastapi import FastAPI
|
||||
from litellm.proxy.proxy_server import (
|
||||
router,
|
||||
app,
|
||||
save_worker_config,
|
||||
initialize,
|
||||
) # Replace with the actual module where your FastAPI router is defined
|
||||
|
@ -119,9 +119,6 @@ def client_no_auth(fake_env_vars):
|
|||
config_fp = f"{filepath}/test_configs/test_config_no_auth.yaml"
|
||||
# initialize can get run in parallel, it sets specific variables for the fast api app, sinc eit gets run in parallel different tests use the wrong variables
|
||||
asyncio.run(initialize(config=config_fp, debug=True))
|
||||
app = FastAPI()
|
||||
app.include_router(router) # Include your router in the test app
|
||||
|
||||
return TestClient(app)
|
||||
|
||||
|
||||
|
@ -426,6 +423,10 @@ def test_add_new_model(client_no_auth):
|
|||
def test_health(client_no_auth):
|
||||
global headers
|
||||
import time
|
||||
from litellm._logging import verbose_logger, verbose_proxy_logger
|
||||
import logging
|
||||
|
||||
verbose_proxy_logger.setLevel(logging.DEBUG)
|
||||
|
||||
try:
|
||||
response = client_no_auth.get("/health")
|
||||
|
|
|
@ -26,19 +26,22 @@ import pytest, logging, asyncio
|
|||
import litellm, asyncio
|
||||
from litellm.proxy.proxy_server import (
|
||||
new_user,
|
||||
generate_key_fn,
|
||||
user_api_key_auth,
|
||||
user_update,
|
||||
user_info,
|
||||
block_user,
|
||||
)
|
||||
from litellm.proxy.spend_reporting_endpoints.spend_management_endpoints import (
|
||||
spend_user_fn,
|
||||
spend_key_fn,
|
||||
view_spend_logs,
|
||||
)
|
||||
from litellm.proxy.management_endpoints.key_management_endpoints import (
|
||||
delete_key_fn,
|
||||
info_key_fn,
|
||||
update_key_fn,
|
||||
generate_key_fn,
|
||||
generate_key_helper_fn,
|
||||
spend_user_fn,
|
||||
spend_key_fn,
|
||||
view_spend_logs,
|
||||
user_info,
|
||||
block_user,
|
||||
)
|
||||
from litellm.proxy.utils import PrismaClient, ProxyLogging, hash_token, update_spend
|
||||
from litellm._logging import verbose_proxy_logger
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue