# Reconstructed from a mangled diff. This span adds `CommonProxyErrors` to
# litellm/proxy/_types.py and creates litellm/proxy/caching_routes.py.

import copy
import enum
from typing import Optional

import litellm
from fastapi import APIRouter, Depends, HTTPException, Request
from litellm._logging import verbose_proxy_logger
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth


class CommonProxyErrors(enum.Enum):
    """Shared, user-facing error strings returned by proxy endpoints."""

    db_not_connected_error = "DB not connected"
    no_llm_router = "No models configured on proxy"
    not_allowed_access = "Admin-only endpoint. Not allowed to access this."
    not_premium_user = "You must be a LiteLLM Enterprise user to use this feature. If you have a license please set `LITELLM_LICENSE` in your env. If you want to obtain a license meet with us here: https://calendly.com/d/4mp-gd3-k5k/litellm-1-1-onboarding-chat"


router = APIRouter(
    prefix="/cache",
    tags=["caching"],
)


@router.get(
    "/ping",
    dependencies=[Depends(user_api_key_auth)],
)
async def cache_ping():
    """
    Endpoint for checking if cache can be pinged
    """
    # Initialize before the try-block so the generic error handler below can
    # always reference them, even when the failure happens early.
    litellm_cache_params = {}
    specific_cache_params = {}
    try:
        if litellm.cache is None:
            raise HTTPException(
                status_code=503, detail="Cache not initialized. litellm.cache is None"
            )

        # Snapshot the cache configuration (stringified) for the response.
        for k, v in vars(litellm.cache).items():
            try:
                if k == "cache":
                    continue  # the concrete cache client is reported separately below
                litellm_cache_params[k] = str(copy.deepcopy(v))
            except Exception:
                litellm_cache_params[k] = ""
        for k, v in vars(litellm.cache.cache).items():
            try:
                specific_cache_params[k] = str(v)
            except Exception:
                specific_cache_params[k] = ""

        if litellm.cache.type == "redis":
            # ping the redis cache
            ping_response = await litellm.cache.ping()
            verbose_proxy_logger.debug(
                "/cache/ping: ping_response: " + str(ping_response)
            )
            # make a set-cache call; async_add_cache does not return anything
            await litellm.cache.async_add_cache(
                result="test_key",
                model="test-model",
                messages=[{"role": "user", "content": "test from litellm"}],
            )
            verbose_proxy_logger.debug("/cache/ping: done with set_cache()")
            return {
                "status": "healthy",
                "cache_type": litellm.cache.type,
                "ping_response": True,
                "set_cache_response": "success",
                "litellm_cache_params": litellm_cache_params,
                "redis_cache_params": specific_cache_params,
            }
        return {
            "status": "healthy",
            "cache_type": litellm.cache.type,
            "litellm_cache_params": litellm_cache_params,
        }
    except HTTPException:
        # BUGFIX: preserve the deliberate HTTP errors raised above (e.g. the
        # 503 for an uninitialized cache) instead of re-wrapping them.
        raise
    except Exception as e:
        raise HTTPException(
            status_code=503,
            detail=f"Service Unhealthy ({str(e)}).Cache parameters: {litellm_cache_params}.specific_cache_params: {specific_cache_params}",
        )


@router.post(
    "/delete",
    tags=["caching"],
    dependencies=[Depends(user_api_key_auth)],
)
async def cache_delete(request: Request):
    """
    Endpoint for deleting a key from the cache. All responses from litellm
    proxy have `x-litellm-cache-key` in the headers.

    Parameters:
    - **keys**: *Optional[List[str]]* - A list of keys to delete from the
      cache. Example {"keys": ["key1", "key2"]}

    ```shell
    curl -X POST "http://0.0.0.0:4000/cache/delete" \
        -H "Authorization: Bearer sk-1234" \
        -d '{"keys": ["key1", "key2"]}'
    ```
    """
    try:
        if litellm.cache is None:
            raise HTTPException(
                status_code=503, detail="Cache not initialized. litellm.cache is None"
            )

        request_data = await request.json()
        keys = request_data.get("keys", None)

        if litellm.cache.type == "redis":
            await litellm.cache.delete_cache_keys(keys=keys)
            return {
                "status": "success",
            }
        raise HTTPException(
            status_code=500,
            detail=f"Cache type {litellm.cache.type} does not support deleting a key. only `redis` is supported",
        )
    except HTTPException:
        # BUGFIX: don't swallow the intended 503/500 above by re-wrapping it
        # as a generic "Cache Delete Failed" 500.
        raise
    except Exception as e:
        raise HTTPException(
            status_code=500,
            detail=f"Cache Delete Failed({str(e)})",
        )


@router.get(
    "/redis/info",
    dependencies=[Depends(user_api_key_auth)],
)
async def cache_redis_info():
    """
    Endpoint for getting /redis/info
    """
    try:
        if litellm.cache is None:
            raise HTTPException(
                status_code=503, detail="Cache not initialized. litellm.cache is None"
            )
        if litellm.cache.type == "redis":
            client_list = litellm.cache.cache.client_list()
            redis_info = litellm.cache.cache.info()
            num_clients = len(client_list)
            return {
                "num_clients": num_clients,
                "clients": client_list,
                "info": redis_info,
            }
        # BUGFIX: message previously said "does not support flushing"
        # (copy-pasted from /cache/flushall).
        raise HTTPException(
            status_code=500,
            detail=f"Cache type {litellm.cache.type} does not support /redis/info",
        )
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(
            status_code=503,
            detail=f"Service Unhealthy ({str(e)})",
        )


@router.post(
    "/flushall",
    tags=["caching"],
    dependencies=[Depends(user_api_key_auth)],
)
async def cache_flushall():
    """
    A function to flush all items from the cache. (All items will be deleted
    from the cache with this.)

    Raises HTTPException if the cache is not initialized or if the cache type
    does not support flushing. Returns a dictionary with the status of the
    operation.

    Usage:
    ```
    curl -X POST http://0.0.0.0:4000/cache/flushall -H "Authorization: Bearer sk-1234"
    ```
    """
    try:
        if litellm.cache is None:
            raise HTTPException(
                status_code=503, detail="Cache not initialized. litellm.cache is None"
            )
        if litellm.cache.type == "redis":
            litellm.cache.cache.flushall()
            return {
                "status": "success",
            }
        raise HTTPException(
            status_code=500,
            detail=f"Cache type {litellm.cache.type} does not support flushing",
        )
    except HTTPException:
        raise
    except Exception as e:
        raise HTTPException(
            status_code=503,
            detail=f"Service Unhealthy ({str(e)})",
        )
# Reconstructed from the deleted
# litellm/proxy/common_utils/management_endpoint_utils.py in this diff.

from datetime import datetime
from functools import wraps

from fastapi import Request

from litellm._logging import verbose_logger
from litellm.proxy._types import ManagementEndpointLoggingPayload, UserAPIKeyAuth
from litellm.proxy.common_utils.http_parsing_utils import _read_request_body

# Management routes whose calls must invalidate the user_api_key cache so
# stale auth/spend data is not served after an update or delete.
_CACHE_FLUSH_ROUTES = frozenset(
    {
        "/key/update",
        "/key/delete",
        "/team/update",
        "/team/delete",
        "/user/update",
        "/user/delete",
        "/customer/update",
        "/customer/delete",
    }
)


def management_endpoint_wrapper(func):
    """
    This wrapper does the following:

    1. Logs I/O and exceptions to OTEL (when a parent span and the OTEL
       logger are configured)
    2. Flushes the user_api_key cache after mutating /key, /team, /user and
       /customer calls
    """

    @wraps(func)
    async def wrapper(*args, **kwargs):
        start_time = datetime.now()
        try:
            result = await func(*args, **kwargs)
            end_time = datetime.now()
            try:
                # NOTE: dead `if kwargs is None` guards removed — `**kwargs`
                # is always a dict.
                user_api_key_dict: UserAPIKeyAuth = (
                    kwargs.get("user_api_key_dict") or UserAPIKeyAuth()
                )
                _http_request: Request = kwargs.get("http_request")
                parent_otel_span = user_api_key_dict.parent_otel_span
                if parent_otel_span is not None:
                    from litellm.proxy.proxy_server import open_telemetry_logger

                    if open_telemetry_logger is not None and _http_request:
                        _route = _http_request.url.path
                        _request_body: dict = await _read_request_body(
                            request=_http_request
                        )
                        _response = dict(result) if result is not None else None

                        logging_payload = ManagementEndpointLoggingPayload(
                            route=_route,
                            request_data=_request_body,
                            response=_response,
                            start_time=start_time,
                            end_time=end_time,
                        )

                        await open_telemetry_logger.async_management_endpoint_success_hook(
                            logging_payload=logging_payload,
                            parent_otel_span=parent_otel_span,
                        )

                if _http_request and _http_request.url.path in _CACHE_FLUSH_ROUTES:
                    from litellm.proxy.proxy_server import user_api_key_cache

                    user_api_key_cache.flush_cache()
            except Exception as e:
                # Non-blocking: telemetry / cache bookkeeping must never fail
                # the management call itself.
                verbose_logger.debug(
                    "Error in management endpoint wrapper: %s", str(e)
                )

            return result
        except Exception as e:
            end_time = datetime.now()

            user_api_key_dict: UserAPIKeyAuth = (
                kwargs.get("user_api_key_dict") or UserAPIKeyAuth()
            )
            parent_otel_span = user_api_key_dict.parent_otel_span
            if parent_otel_span is not None:
                from litellm.proxy.proxy_server import open_telemetry_logger

                if open_telemetry_logger is not None:
                    _http_request: Request = kwargs.get("http_request")
                    if _http_request:
                        _route = _http_request.url.path
                        _request_body: dict = await _read_request_body(
                            request=_http_request
                        )
                        logging_payload = ManagementEndpointLoggingPayload(
                            route=_route,
                            request_data=_request_body,
                            response=None,
                            start_time=start_time,
                            end_time=end_time,
                            exception=e,
                        )

                        await open_telemetry_logger.async_management_endpoint_failure_hook(
                            logging_payload=logging_payload,
                            parent_otel_span=parent_otel_span,
                        )

            # BUGFIX: bare `raise` preserves the original traceback
            # (was `raise e`).
            raise

    return wrapper
+ """ + # ping the proxy server to check if its healthy + return {"route": request.url.path} + + +@router.get( + "/health/services", + tags=["health"], + dependencies=[Depends(user_api_key_auth)], + include_in_schema=False, +) +async def health_services_endpoint( + user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), + service: Literal[ + "slack_budget_alerts", "langfuse", "slack", "openmeter", "webhook", "email" + ] = fastapi.Query(description="Specify the service being hit."), +): + """ + Hidden endpoint. + + Used by the UI to let user check if slack alerting is working as expected. + """ + try: + from litellm.proxy.proxy_server import ( + proxy_logging_obj, + prisma_client, + general_settings, + ) + + if service is None: + raise HTTPException( + status_code=400, detail={"error": "Service must be specified."} + ) + if service not in [ + "slack_budget_alerts", + "email", + "langfuse", + "slack", + "openmeter", + "webhook", + ]: + raise HTTPException( + status_code=400, + detail={ + "error": f"Service must be in list. Service={service}. 
List={['slack_budget_alerts']}" + }, + ) + + if service == "openmeter": + _ = await litellm.acompletion( + model="openai/litellm-mock-response-model", + messages=[{"role": "user", "content": "Hey, how's it going?"}], + user="litellm:/health/services", + mock_response="This is a mock response", + ) + return { + "status": "success", + "message": "Mock LLM request made - check openmeter.", + } + + if service == "langfuse": + from litellm.integrations.langfuse import LangFuseLogger + + langfuse_logger = LangFuseLogger() + langfuse_logger.Langfuse.auth_check() + _ = litellm.completion( + model="openai/litellm-mock-response-model", + messages=[{"role": "user", "content": "Hey, how's it going?"}], + user="litellm:/health/services", + mock_response="This is a mock response", + ) + return { + "status": "success", + "message": "Mock LLM request made - check langfuse.", + } + + if service == "webhook": + user_info = CallInfo( + token=user_api_key_dict.token or "", + spend=1, + max_budget=0, + user_id=user_api_key_dict.user_id, + key_alias=user_api_key_dict.key_alias, + team_id=user_api_key_dict.team_id, + ) + await proxy_logging_obj.budget_alerts( + type="user_budget", + user_info=user_info, + ) + + if service == "slack" or service == "slack_budget_alerts": + if "slack" in general_settings.get("alerting", []): + # test_message = f"""\n🚨 `ProjectedLimitExceededError` šŸ’ø\n\n`Key Alias:` litellm-ui-test-alert \n`Expected Day of Error`: 28th March \n`Current Spend`: $100.00 \n`Projected Spend at end of month`: $1000.00 \n`Soft Limit`: $700""" + # check if user has opted into unique_alert_webhooks + if ( + proxy_logging_obj.slack_alerting_instance.alert_to_webhook_url + is not None + ): + for ( + alert_type + ) in proxy_logging_obj.slack_alerting_instance.alert_to_webhook_url: + """ + "llm_exceptions", + "llm_too_slow", + "llm_requests_hanging", + "budget_alerts", + "db_exceptions", + """ + # only test alert if it's in active alert types + if ( + 
proxy_logging_obj.slack_alerting_instance.alert_types + is not None + and alert_type + not in proxy_logging_obj.slack_alerting_instance.alert_types + ): + continue + test_message = "default test message" + if alert_type == "llm_exceptions": + test_message = f"LLM Exception test alert" + elif alert_type == "llm_too_slow": + test_message = f"LLM Too Slow test alert" + elif alert_type == "llm_requests_hanging": + test_message = f"LLM Requests Hanging test alert" + elif alert_type == "budget_alerts": + test_message = f"Budget Alert test alert" + elif alert_type == "db_exceptions": + test_message = f"DB Exception test alert" + elif alert_type == "outage_alerts": + test_message = f"Outage Alert Exception test alert" + elif alert_type == "daily_reports": + test_message = f"Daily Reports test alert" + + await proxy_logging_obj.alerting_handler( + message=test_message, level="Low", alert_type=alert_type + ) + else: + await proxy_logging_obj.alerting_handler( + message="This is a test slack alert message", + level="Low", + alert_type="budget_alerts", + ) + + if prisma_client is not None: + asyncio.create_task( + proxy_logging_obj.slack_alerting_instance.send_monthly_spend_report() + ) + asyncio.create_task( + proxy_logging_obj.slack_alerting_instance.send_weekly_spend_report() + ) + + alert_types = ( + proxy_logging_obj.slack_alerting_instance.alert_types or [] + ) + alert_types = list(alert_types) + return { + "status": "success", + "alert_types": alert_types, + "message": "Mock Slack Alert sent, verify Slack Alert Received on your channel", + } + else: + raise HTTPException( + status_code=422, + detail={ + "error": '"{}" not in proxy config: general_settings. Unable to test this.'.format( + service + ) + }, + ) + if service == "email": + webhook_event = WebhookEvent( + event="key_created", + event_group="key", + event_message="Test Email Alert", + token=user_api_key_dict.token or "", + key_alias="Email Test key (This is only a test alert key. 
DO NOT USE THIS IN PRODUCTION.)", + spend=0, + max_budget=0, + user_id=user_api_key_dict.user_id, + user_email=os.getenv("TEST_EMAIL_ADDRESS"), + team_id=user_api_key_dict.team_id, + ) + + # use create task - this can take 10 seconds. don't keep ui users waiting for notification to check their email + asyncio.create_task( + proxy_logging_obj.slack_alerting_instance.send_key_created_or_user_invited_email( + webhook_event=webhook_event + ) + ) + + return { + "status": "success", + "message": "Mock Email Alert sent, verify Email Alert Received", + } + + except Exception as e: + verbose_proxy_logger.error( + "litellm.proxy.proxy_server.health_services_endpoint(): Exception occured - {}".format( + str(e) + ) + ) + verbose_proxy_logger.debug(traceback.format_exc()) + if isinstance(e, HTTPException): + raise ProxyException( + message=getattr(e, "detail", f"Authentication Error({str(e)})"), + type="auth_error", + param=getattr(e, "param", "None"), + code=getattr(e, "status_code", status.HTTP_500_INTERNAL_SERVER_ERROR), + ) + elif isinstance(e, ProxyException): + raise e + raise ProxyException( + message="Authentication Error, " + str(e), + type="auth_error", + param=getattr(e, "param", "None"), + code=status.HTTP_500_INTERNAL_SERVER_ERROR, + ) + + +@router.get("/health", tags=["health"], dependencies=[Depends(user_api_key_auth)]) +async def health_endpoint( + user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), + model: Optional[str] = fastapi.Query( + None, description="Specify the model name (optional)" + ), +): + """ + 🚨 USE `/health/liveliness` to health check the proxy 🚨 + + See more šŸ‘‰ https://docs.litellm.ai/docs/proxy/health + + + Check the health of all the endpoints in config.yaml + + To run health checks in the background, add this to config.yaml: + ``` + general_settings: + # ... other settings + background_health_checks: True + ``` + else, the health checks will be run on models when /health is called. 
+ """ + from litellm.proxy.proxy_server import ( + health_check_results, + use_background_health_checks, + user_model, + llm_model_list, + ) + + try: + if llm_model_list is None: + # if no router set, check if user set a model using litellm --model ollama/llama2 + if user_model is not None: + healthy_endpoints, unhealthy_endpoints = await perform_health_check( + model_list=[], cli_model=user_model + ) + return { + "healthy_endpoints": healthy_endpoints, + "unhealthy_endpoints": unhealthy_endpoints, + "healthy_count": len(healthy_endpoints), + "unhealthy_count": len(unhealthy_endpoints), + } + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail={"error": "Model list not initialized"}, + ) + _llm_model_list = copy.deepcopy(llm_model_list) + ### FILTER MODELS FOR ONLY THOSE USER HAS ACCESS TO ### + if len(user_api_key_dict.models) > 0: + allowed_model_names = user_api_key_dict.models + else: + allowed_model_names = [] # + if use_background_health_checks: + return health_check_results + else: + healthy_endpoints, unhealthy_endpoints = await perform_health_check( + _llm_model_list, model + ) + + return { + "healthy_endpoints": healthy_endpoints, + "unhealthy_endpoints": unhealthy_endpoints, + "healthy_count": len(healthy_endpoints), + "unhealthy_count": len(unhealthy_endpoints), + } + except Exception as e: + verbose_proxy_logger.error( + "litellm.proxy.proxy_server.py::health_endpoint(): Exception occured - {}".format( + str(e) + ) + ) + verbose_proxy_logger.debug(traceback.format_exc()) + raise e + + +db_health_cache = {"status": "unknown", "last_updated": datetime.now()} + + +def _db_health_readiness_check(): + from litellm.proxy.proxy_server import prisma_client + + global db_health_cache + + # Note - Intentionally don't try/except this so it raises an exception when it fails + + # if timedelta is less than 2 minutes return DB Status + time_diff = datetime.now() - db_health_cache["last_updated"] + if db_health_cache["status"] != 
"unknown" and time_diff < timedelta(minutes=2): + return db_health_cache + prisma_client.health_check() + db_health_cache = {"status": "connected", "last_updated": datetime.now()} + return db_health_cache + + +@router.get( + "/active/callbacks", + tags=["health"], + dependencies=[Depends(user_api_key_auth)], +) +async def active_callbacks(): + """ + Returns a list of active callbacks on litellm.callbacks, litellm.input_callback, litellm.failure_callback, litellm.success_callback + """ + from litellm.proxy.proxy_server import proxy_logging_obj, general_settings + + _alerting = str(general_settings.get("alerting")) + # get success callbacks + + litellm_callbacks = [str(x) for x in litellm.callbacks] + litellm_input_callbacks = [str(x) for x in litellm.input_callback] + litellm_failure_callbacks = [str(x) for x in litellm.failure_callback] + litellm_success_callbacks = [str(x) for x in litellm.success_callback] + litellm_async_success_callbacks = [str(x) for x in litellm._async_success_callback] + litellm_async_failure_callbacks = [str(x) for x in litellm._async_failure_callback] + litellm_async_input_callbacks = [str(x) for x in litellm._async_input_callback] + + all_litellm_callbacks = ( + litellm_callbacks + + litellm_input_callbacks + + litellm_failure_callbacks + + litellm_success_callbacks + + litellm_async_success_callbacks + + litellm_async_failure_callbacks + + litellm_async_input_callbacks + ) + + alerting = proxy_logging_obj.alerting + _num_alerting = 0 + if alerting and isinstance(alerting, list): + _num_alerting = len(alerting) + + return { + "alerting": _alerting, + "litellm.callbacks": litellm_callbacks, + "litellm.input_callback": litellm_input_callbacks, + "litellm.failure_callback": litellm_failure_callbacks, + "litellm.success_callback": litellm_success_callbacks, + "litellm._async_success_callback": litellm_async_success_callbacks, + "litellm._async_failure_callback": litellm_async_failure_callbacks, + "litellm._async_input_callback": 
litellm_async_input_callbacks, + "all_litellm_callbacks": all_litellm_callbacks, + "num_callbacks": len(all_litellm_callbacks), + "num_alerting": _num_alerting, + } + + +@router.get( + "/health/readiness", + tags=["health"], + dependencies=[Depends(user_api_key_auth)], +) +async def health_readiness(): + """ + Unprotected endpoint for checking if worker can receive requests + """ + from litellm.proxy.proxy_server import proxy_logging_obj, prisma_client, version + + try: + # get success callback + success_callback_names = [] + + try: + # this was returning a JSON of the values in some of the callbacks + # all we need is the callback name, hence we do str(callback) + success_callback_names = [str(x) for x in litellm.success_callback] + except: + # don't let this block the /health/readiness response, if we can't convert to str -> return litellm.success_callback + success_callback_names = litellm.success_callback + + # check Cache + cache_type = None + if litellm.cache is not None: + from litellm.caching import RedisSemanticCache + + cache_type = litellm.cache.type + + if isinstance(litellm.cache.cache, RedisSemanticCache): + # ping the cache + # TODO: @ishaan-jaff - we should probably not ping the cache on every /health/readiness check + try: + index_info = await litellm.cache.cache._index_info() + except Exception as e: + index_info = "index does not exist - error: " + str(e) + cache_type = {"type": cache_type, "index_info": index_info} + + # check DB + if prisma_client is not None: # if db passed in, check if it's connected + db_health_status = _db_health_readiness_check() + return { + "status": "healthy", + "db": "connected", + "cache": cache_type, + "litellm_version": version, + "success_callbacks": success_callback_names, + **db_health_status, + } + else: + return { + "status": "healthy", + "db": "Not connected", + "cache": cache_type, + "litellm_version": version, + "success_callbacks": success_callback_names, + } + except Exception as e: + raise 
@router.get(
    "/health/liveliness",
    tags=["health"],
    dependencies=[Depends(user_api_key_auth)],
)
async def health_liveliness():
    """
    Unprotected endpoint for checking if worker is alive
    """
    return "I'm alive!"


# --- litellm/proxy/management_endpoints/key_management_endpoints.py ---
# (new file in this diff; trailing `update_key_fn` is truncated in the
# source and therefore not reconstructed here)

"""
KEY MANAGEMENT

All /key management endpoints

/key/generate
/key/info
/key/update
/key/delete
"""

import copy
import json
import uuid
import re
import traceback
import asyncio
import secrets
from typing import Optional, List
import fastapi
from fastapi import Depends, Request, APIRouter, Header, status
from fastapi import HTTPException
import litellm
from datetime import datetime, timedelta, timezone
from litellm._logging import verbose_proxy_logger
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
from litellm.proxy._types import *

router = APIRouter()


@router.post(
    "/key/generate",
    tags=["key management"],
    dependencies=[Depends(user_api_key_auth)],
    response_model=GenerateKeyResponse,
)
async def generate_key_fn(
    data: GenerateKeyRequest,
    user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
    litellm_changed_by: Optional[str] = Header(
        None,
        description="The litellm-changed-by header enables tracking of actions performed by authorized users on behalf of other users, providing an audit trail for accountability",
    ),
):
    """
    Generate an API key based on the provided data.

    Docs: https://docs.litellm.ai/docs/proxy/virtual_keys

    Parameters:
    - duration: Optional[str] - Specify the length of time the token is valid for. You can set duration as seconds ("30s"), minutes ("30m"), hours ("30h"), days ("30d").
    - key_alias: Optional[str] - User defined key alias
    - team_id: Optional[str] - The team id of the key
    - user_id: Optional[str] - The user id of the key
    - models: Optional[list] - Model_name's a user is allowed to call. (if empty, key is allowed to call all models)
    - aliases: Optional[dict] - Any alias mappings, on top of anything in the config.yaml model list. - https://docs.litellm.ai/docs/proxy/virtual_keys#managing-auth---upgradedowngrade-models
    - config: Optional[dict] - any key-specific configs, overrides config in config.yaml
    - spend: Optional[int] - Amount spent by key. Default is 0. Will be updated by proxy whenever key is used. https://docs.litellm.ai/docs/proxy/virtual_keys#managing-auth---tracking-spend
    - send_invite_email: Optional[bool] - Whether to send an invite email to the user_id, with the generate key
    - max_budget: Optional[float] - Specify max budget for a given key.
    - max_parallel_requests: Optional[int] - Rate limit a user based on the number of parallel requests. Raises 429 error, if user's parallel requests > x.
    - metadata: Optional[dict] - Metadata for key, store information for key. Example metadata = {"team": "core-infra", "app": "app2", "email": "ishaan@berri.ai" }
    - permissions: Optional[dict] - key-specific permissions. Currently just used for turning off pii masking (if connected). Example - {"pii": false}
    - model_max_budget: Optional[dict] - key-specific model budget in USD. Example - {"text-davinci-002": 0.5, "gpt-3.5-turbo": 0.5}. IF null or {} then no model specific budget.

    Examples:

    1. Allow users to turn on/off pii masking

    ```bash
    curl --location 'http://0.0.0.0:8000/key/generate' \
        --header 'Authorization: Bearer sk-1234' \
        --header 'Content-Type: application/json' \
        --data '{
            "permissions": {"allow_pii_controls": true}
        }'
    ```

    Returns:
    - key: (str) The generated api key
    - expires: (datetime) Datetime object for when key expires.
    - user_id: (str) Unique user id - used for tracking spend across multiple keys for same user id.
    """
    try:
        from litellm.proxy.proxy_server import (
            user_custom_key_generate,
            prisma_client,
            litellm_proxy_admin_name,
            general_settings,
            proxy_logging_obj,
            create_audit_log_for_update,
        )

        verbose_proxy_logger.debug("entered /key/generate")

        if user_custom_key_generate is not None:
            result = await user_custom_key_generate(data)
            decision = result.get("decision", True)
            message = result.get("message", "Authentication Failed - Custom Auth Rule")
            if not decision:
                raise HTTPException(
                    status_code=status.HTTP_403_FORBIDDEN, detail=message
                )

        # check if user set default key/generate params on config.yaml
        if litellm.default_key_generate_params is not None:
            for elem in data:
                key, value = elem
                if value is None and key in [
                    "max_budget",
                    "user_id",
                    "team_id",
                    "max_parallel_requests",
                    "tpm_limit",
                    "rpm_limit",
                    "budget_duration",
                ]:
                    setattr(
                        data, key, litellm.default_key_generate_params.get(key, None)
                    )
                elif key == "models" and value == []:
                    setattr(data, key, litellm.default_key_generate_params.get(key, []))
                elif key == "metadata" and value == {}:
                    setattr(data, key, litellm.default_key_generate_params.get(key, {}))

        # check if user set upperbound key/generate params on config.yaml
        if litellm.upperbound_key_generate_params is not None:
            for elem in data:
                # if key in upperbound_key_generate_params, enforce the bound
                key, value = elem
                if (
                    value is not None
                    and getattr(litellm.upperbound_key_generate_params, key, None)
                    is not None
                ):
                    # if value is float/int
                    if key in [
                        "max_budget",
                        "max_parallel_requests",
                        "tpm_limit",
                        "rpm_limit",
                    ]:
                        if value > getattr(litellm.upperbound_key_generate_params, key):
                            raise HTTPException(
                                status_code=400,
                                detail={
                                    "error": f"{key} is over max limit set in config - user_value={value}; max_value={getattr(litellm.upperbound_key_generate_params, key)}"
                                },
                            )
                    elif key == "budget_duration":
                        # budgets are in 1s, 1m, 1h, 1d, 1m (30s, 30m, 30h, 30d, 30m)
                        # compare the duration in seconds and max duration in seconds
                        # NOTE(review): `_duration_in_seconds` is not imported
                        # explicitly — presumably provided by the wildcard
                        # import above; TODO confirm its source module.
                        upperbound_budget_duration = _duration_in_seconds(
                            duration=getattr(
                                litellm.upperbound_key_generate_params, key
                            )
                        )
                        user_set_budget_duration = _duration_in_seconds(duration=value)
                        if user_set_budget_duration > upperbound_budget_duration:
                            raise HTTPException(
                                status_code=400,
                                detail={
                                    "error": f"Budget duration is over max limit set in config - user_value={user_set_budget_duration}; max_value={upperbound_budget_duration}"
                                },
                            )

        # TODO: @ishaan-jaff: Migrate all budget tracking to use LiteLLM_BudgetTable
        _budget_id = None
        if prisma_client is not None and data.soft_budget is not None:
            # create the Budget Row for the LiteLLM Verification Token
            budget_row = LiteLLM_BudgetTable(
                soft_budget=data.soft_budget,
                model_max_budget=data.model_max_budget or {},
            )
            new_budget = prisma_client.jsonify_object(
                budget_row.json(exclude_none=True)
            )

            _budget = await prisma_client.db.litellm_budgettable.create(
                data={
                    **new_budget,  # type: ignore
                    "created_by": user_api_key_dict.user_id or litellm_proxy_admin_name,
                    "updated_by": user_api_key_dict.user_id or litellm_proxy_admin_name,
                }
            )
            _budget_id = getattr(_budget, "budget_id", None)

        data_json = data.json()  # type: ignore
        # if we get max_budget passed to /key/generate, then use it as
        # key_max_budget, since generate_key_helper_fn is used to make new users
        if "max_budget" in data_json:
            data_json["key_max_budget"] = data_json.pop("max_budget", None)
        if _budget_id is not None:
            data_json["budget_id"] = _budget_id

        if "budget_duration" in data_json:
            data_json["key_budget_duration"] = data_json.pop("budget_duration", None)

        response = await generate_key_helper_fn(
            request_type="key", **data_json, table_name="key"
        )

        # include the user-input soft budget in the response
        response["soft_budget"] = data.soft_budget

        if data.send_invite_email is True:
            if "email" not in general_settings.get("alerting", []):
                raise ValueError(
                    "Email alerting not setup on config.yaml. Please set `alerting=['email']. \nDocs: https://docs.litellm.ai/docs/proxy/email`"
                )
            event = WebhookEvent(
                event="key_created",
                event_group="key",
                event_message="API Key Created",
                token=response.get("token", ""),
                spend=response.get("spend", 0.0),
                max_budget=response.get("max_budget", 0.0),
                user_id=response.get("user_id", None),
                team_id=response.get("team_id", "Default Team"),
                key_alias=response.get("key_alias", None),
            )

            # If user configured email alerting - send an Email letting their
            # end-user know the key was created
            asyncio.create_task(
                proxy_logging_obj.slack_alerting_instance.send_key_created_or_user_invited_email(
                    webhook_event=event,
                )
            )

        # Enterprise Feature - Audit Logging. Enable with
        # litellm.store_audit_logs = True
        if litellm.store_audit_logs is True:
            _updated_values = json.dumps(response, default=str)
            asyncio.create_task(
                create_audit_log_for_update(
                    request_data=LiteLLM_AuditLogs(
                        id=str(uuid.uuid4()),
                        updated_at=datetime.now(timezone.utc),
                        changed_by=litellm_changed_by
                        or user_api_key_dict.user_id
                        or litellm_proxy_admin_name,
                        changed_by_api_key=user_api_key_dict.api_key,
                        table_name=LitellmTableNames.KEY_TABLE_NAME,
                        object_id=response.get("token_id", ""),
                        action="created",
                        updated_values=_updated_values,
                        before_value=None,
                    )
                )
            )

        return GenerateKeyResponse(**response)
    except Exception as e:
        verbose_proxy_logger.error(
            "litellm.proxy.proxy_server.generate_key_fn(): Exception occurred - {}".format(
                str(e)
            )
        )
        verbose_proxy_logger.debug(traceback.format_exc())
        if isinstance(e, HTTPException):
            raise ProxyException(
                message=getattr(e, "detail", f"Authentication Error({str(e)})"),
                type="auth_error",
                param=getattr(e, "param", "None"),
                code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST),
            )
        elif isinstance(e, ProxyException):
            # re-raise untouched, preserving the original traceback
            raise
        raise ProxyException(
            message="Authentication Error, " + str(e),
            type="auth_error",
            param=getattr(e, "param", "None"),
            code=status.HTTP_400_BAD_REQUEST,
        )
+ create_audit_log_for_update, + user_api_key_cache, + ) + + try: + data_json: dict = data.json() + key = data_json.pop("key") + # get the row from db + if prisma_client is None: + raise Exception("Not connected to DB!") + + existing_key_row = await prisma_client.get_data( + token=data.key, table_name="key", query_type="find_unique" + ) + + if existing_key_row is None: + raise HTTPException( + status_code=404, + detail={"error": f"Team not found, passed team_id={data.team_id}"}, + ) + + # get non default values for key + non_default_values = {} + for k, v in data_json.items(): + if v is not None and v not in ( + [], + {}, + 0, + ): # models default to [], spend defaults to 0, we should not reset these values + non_default_values[k] = v + + if "duration" in non_default_values: + duration = non_default_values.pop("duration") + duration_s = _duration_in_seconds(duration=duration) + expires = datetime.now(timezone.utc) + timedelta(seconds=duration_s) + non_default_values["expires"] = expires + + response = await prisma_client.update_data( + token=key, data={**non_default_values, "token": key} + ) + + # Delete - key from cache, since it's been updated! + # key updated - a new model could have been added to this key. it should not block requests after this is done + user_api_key_cache.delete_cache(key) + hashed_token = hash_token(key) + user_api_key_cache.delete_cache(hashed_token) + + # Enterprise Feature - Audit Logging. 
Enable with litellm.store_audit_logs = True + if litellm.store_audit_logs is True: + _updated_values = json.dumps(data_json, default=str) + + _before_value = existing_key_row.json(exclude_none=True) + _before_value = json.dumps(_before_value, default=str) + + asyncio.create_task( + create_audit_log_for_update( + request_data=LiteLLM_AuditLogs( + id=str(uuid.uuid4()), + updated_at=datetime.now(timezone.utc), + changed_by=litellm_changed_by + or user_api_key_dict.user_id + or litellm_proxy_admin_name, + changed_by_api_key=user_api_key_dict.api_key, + table_name=LitellmTableNames.KEY_TABLE_NAME, + object_id=data.key, + action="updated", + updated_values=_updated_values, + before_value=_before_value, + ) + ) + ) + + return {"key": key, **response["data"]} + # update based on remaining passed in values + except Exception as e: + if isinstance(e, HTTPException): + raise ProxyException( + message=getattr(e, "detail", f"Authentication Error({str(e)})"), + type="auth_error", + param=getattr(e, "param", "None"), + code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST), + ) + elif isinstance(e, ProxyException): + raise e + raise ProxyException( + message="Authentication Error, " + str(e), + type="auth_error", + param=getattr(e, "param", "None"), + code=status.HTTP_400_BAD_REQUEST, + ) + + +@router.post( + "/key/delete", tags=["key management"], dependencies=[Depends(user_api_key_auth)] +) +async def delete_key_fn( + data: KeyRequest, + user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), + litellm_changed_by: Optional[str] = Header( + None, + description="The litellm-changed-by header enables tracking of actions performed by authorized users on behalf of other users, providing an audit trail for accountability", + ), +): + """ + Delete a key from the key management system. + + Parameters:: + - keys (List[str]): A list of keys or hashed keys to delete. 
Example {"keys": ["sk-QWrxEynunsNpV1zT48HIrw", "837e17519f44683334df5291321d97b8bf1098cd490e49e215f6fea935aa28be"]} + + Returns: + - deleted_keys (List[str]): A list of deleted keys. Example {"deleted_keys": ["sk-QWrxEynunsNpV1zT48HIrw", "837e17519f44683334df5291321d97b8bf1098cd490e49e215f6fea935aa28be"]} + + + Raises: + HTTPException: If an error occurs during key deletion. + """ + try: + from litellm.proxy.proxy_server import ( + user_custom_key_generate, + prisma_client, + litellm_proxy_admin_name, + general_settings, + proxy_logging_obj, + create_audit_log_for_update, + user_api_key_cache, + ) + + keys = data.keys + if len(keys) == 0: + raise ProxyException( + message=f"No keys provided, passed in: keys={keys}", + type="auth_error", + param="keys", + code=status.HTTP_400_BAD_REQUEST, + ) + + ## only allow user to delete keys they own + user_id = user_api_key_dict.user_id + verbose_proxy_logger.debug( + f"user_api_key_dict.user_role: {user_api_key_dict.user_role}" + ) + if ( + user_api_key_dict.user_role is not None + and user_api_key_dict.user_role == LitellmUserRoles.PROXY_ADMIN + ): + user_id = None # unless they're admin + + # Enterprise Feature - Audit Logging. Enable with litellm.store_audit_logs = True + # we do this after the first for loop, since first for loop is for validation. 
we only want this inserted after validation passes + if litellm.store_audit_logs is True: + # make an audit log for each team deleted + for key in data.keys: + key_row = await prisma_client.get_data( # type: ignore + token=key, table_name="key", query_type="find_unique" + ) + + key_row = key_row.json(exclude_none=True) + _key_row = json.dumps(key_row, default=str) + + asyncio.create_task( + create_audit_log_for_update( + request_data=LiteLLM_AuditLogs( + id=str(uuid.uuid4()), + updated_at=datetime.now(timezone.utc), + changed_by=litellm_changed_by + or user_api_key_dict.user_id + or litellm_proxy_admin_name, + changed_by_api_key=user_api_key_dict.api_key, + table_name=LitellmTableNames.KEY_TABLE_NAME, + object_id=key, + action="deleted", + updated_values="{}", + before_value=_key_row, + ) + ) + ) + + number_deleted_keys = await delete_verification_token( + tokens=keys, user_id=user_id + ) + verbose_proxy_logger.debug( + f"/key/delete - deleted_keys={number_deleted_keys['deleted_keys']}" + ) + + try: + assert len(keys) == number_deleted_keys["deleted_keys"] + except Exception as e: + raise HTTPException( + status_code=400, + detail={ + "error": f"Not all keys passed in were deleted. This probably means you don't have access to delete all the keys passed in. 
Keys passed in={len(keys)}, Deleted keys ={number_deleted_keys['deleted_keys']}" + }, + ) + + for key in keys: + user_api_key_cache.delete_cache(key) + # remove hash token from cache + hashed_token = hash_token(key) + user_api_key_cache.delete_cache(hashed_token) + + verbose_proxy_logger.debug( + f"/keys/delete - cache after delete: {user_api_key_cache.in_memory_cache.cache_dict}" + ) + + return {"deleted_keys": keys} + except Exception as e: + if isinstance(e, HTTPException): + raise ProxyException( + message=getattr(e, "detail", f"Authentication Error({str(e)})"), + type="auth_error", + param=getattr(e, "param", "None"), + code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST), + ) + elif isinstance(e, ProxyException): + raise e + raise ProxyException( + message="Authentication Error, " + str(e), + type="auth_error", + param=getattr(e, "param", "None"), + code=status.HTTP_400_BAD_REQUEST, + ) + + +@router.post( + "/v2/key/info", tags=["key management"], dependencies=[Depends(user_api_key_auth)] +) +async def info_key_fn_v2( + data: Optional[KeyRequest] = None, + user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), +): + """ + Retrieve information about a list of keys. + + **New endpoint**. Currently admin only. + Parameters: + keys: Optional[list] = body parameter representing the key(s) in the request + user_api_key_dict: UserAPIKeyAuth = Dependency representing the user's API key + Returns: + Dict containing the key and its associated information + + Example Curl: + ``` + curl -X GET "http://0.0.0.0:8000/key/info" \ + -H "Authorization: Bearer sk-1234" \ + -d {"keys": ["sk-1", "sk-2", "sk-3"]} + ``` + """ + from litellm.proxy.proxy_server import ( + user_custom_key_generate, + prisma_client, + litellm_proxy_admin_name, + general_settings, + proxy_logging_obj, + create_audit_log_for_update, + ) + + try: + if prisma_client is None: + raise Exception( + f"Database not connected. 
Connect a database to your proxy - https://docs.litellm.ai/docs/simple_proxy#managing-auth---virtual-keys" + ) + if data is None: + raise HTTPException( + status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, + detail={"message": "Malformed request. No keys passed in."}, + ) + + key_info = await prisma_client.get_data( + token=data.keys, table_name="key", query_type="find_all" + ) + filtered_key_info = [] + for k in key_info: + try: + k = k.model_dump() # noqa + except: + # if using pydantic v1 + k = k.dict() + filtered_key_info.append(k) + return {"key": data.keys, "info": filtered_key_info} + + except Exception as e: + if isinstance(e, HTTPException): + raise ProxyException( + message=getattr(e, "detail", f"Authentication Error({str(e)})"), + type="auth_error", + param=getattr(e, "param", "None"), + code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST), + ) + elif isinstance(e, ProxyException): + raise e + raise ProxyException( + message="Authentication Error, " + str(e), + type="auth_error", + param=getattr(e, "param", "None"), + code=status.HTTP_400_BAD_REQUEST, + ) + + +@router.get( + "/key/info", tags=["key management"], dependencies=[Depends(user_api_key_auth)] +) +async def info_key_fn( + key: Optional[str] = fastapi.Query( + default=None, description="Key in the request parameters" + ), + user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), +): + """ + Retrieve information about a key. 
+ Parameters: + key: Optional[str] = Query parameter representing the key in the request + user_api_key_dict: UserAPIKeyAuth = Dependency representing the user's API key + Returns: + Dict containing the key and its associated information + + Example Curl: + ``` + curl -X GET "http://0.0.0.0:8000/key/info?key=sk-02Wr4IAlN3NvPXvL5JVvDA" \ +-H "Authorization: Bearer sk-1234" + ``` + + Example Curl - if no key is passed, it will use the Key Passed in Authorization Header + ``` + curl -X GET "http://0.0.0.0:8000/key/info" \ +-H "Authorization: Bearer sk-02Wr4IAlN3NvPXvL5JVvDA" + ``` + """ + from litellm.proxy.proxy_server import ( + user_custom_key_generate, + prisma_client, + litellm_proxy_admin_name, + general_settings, + proxy_logging_obj, + create_audit_log_for_update, + ) + + try: + if prisma_client is None: + raise Exception( + f"Database not connected. Connect a database to your proxy - https://docs.litellm.ai/docs/simple_proxy#managing-auth---virtual-keys" + ) + if key == None: + key = user_api_key_dict.api_key + key_info = await prisma_client.get_data(token=key) + ## REMOVE HASHED TOKEN INFO BEFORE RETURNING ## + try: + key_info = key_info.model_dump() # noqa + except: + # if using pydantic v1 + key_info = key_info.dict() + key_info.pop("token") + return {"key": key, "info": key_info} + except Exception as e: + if isinstance(e, HTTPException): + raise ProxyException( + message=getattr(e, "detail", f"Authentication Error({str(e)})"), + type="auth_error", + param=getattr(e, "param", "None"), + code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST), + ) + elif isinstance(e, ProxyException): + raise e + raise ProxyException( + message="Authentication Error, " + str(e), + type="auth_error", + param=getattr(e, "param", "None"), + code=status.HTTP_400_BAD_REQUEST, + ) + + +def _duration_in_seconds(duration: str): + match = re.match(r"(\d+)([smhd]?)", duration) + if not match: + raise ValueError("Invalid duration format") + + value, unit = match.groups() + value = 
int(value) + + if unit == "s": + return value + elif unit == "m": + return value * 60 + elif unit == "h": + return value * 3600 + elif unit == "d": + return value * 86400 + else: + raise ValueError("Unsupported duration unit") + + +async def generate_key_helper_fn( + request_type: Literal[ + "user", "key" + ], # identifies if this request is from /user/new or /key/generate + duration: Optional[str], + models: list, + aliases: dict, + config: dict, + spend: float, + key_max_budget: Optional[float] = None, # key_max_budget is used to Budget Per key + key_budget_duration: Optional[str] = None, + budget_id: Optional[float] = None, # budget id <-> LiteLLM_BudgetTable + soft_budget: Optional[ + float + ] = None, # soft_budget is used to set soft Budgets Per user + max_budget: Optional[float] = None, # max_budget is used to Budget Per user + budget_duration: Optional[str] = None, # max_budget is used to Budget Per user + token: Optional[str] = None, + user_id: Optional[str] = None, + team_id: Optional[str] = None, + user_email: Optional[str] = None, + user_role: Optional[str] = None, + max_parallel_requests: Optional[int] = None, + metadata: Optional[dict] = {}, + tpm_limit: Optional[int] = None, + rpm_limit: Optional[int] = None, + query_type: Literal["insert_data", "update_data"] = "insert_data", + update_key_values: Optional[dict] = None, + key_alias: Optional[str] = None, + allowed_cache_controls: Optional[list] = [], + permissions: Optional[dict] = {}, + model_max_budget: Optional[dict] = {}, + teams: Optional[list] = None, + organization_id: Optional[str] = None, + table_name: Optional[Literal["key", "user"]] = None, + send_invite_email: Optional[bool] = None, +): + from litellm.proxy.proxy_server import ( + prisma_client, + custom_db_client, + litellm_proxy_budget_name, + premium_user, + ) + + if prisma_client is None and custom_db_client is None: + raise Exception( + f"Connect Proxy to database to generate keys - https://docs.litellm.ai/docs/proxy/virtual_keys " + 
) + + if token is None: + token = f"sk-{secrets.token_urlsafe(16)}" + + if duration is None: # allow tokens that never expire + expires = None + else: + duration_s = _duration_in_seconds(duration=duration) + expires = datetime.now(timezone.utc) + timedelta(seconds=duration_s) + + if key_budget_duration is None: # one-time budget + key_reset_at = None + else: + duration_s = _duration_in_seconds(duration=key_budget_duration) + key_reset_at = datetime.now(timezone.utc) + timedelta(seconds=duration_s) + + if budget_duration is None: # one-time budget + reset_at = None + else: + duration_s = _duration_in_seconds(duration=budget_duration) + reset_at = datetime.now(timezone.utc) + timedelta(seconds=duration_s) + + aliases_json = json.dumps(aliases) + config_json = json.dumps(config) + permissions_json = json.dumps(permissions) + metadata_json = json.dumps(metadata) + model_max_budget_json = json.dumps(model_max_budget) + user_role = user_role + tpm_limit = tpm_limit + rpm_limit = rpm_limit + allowed_cache_controls = allowed_cache_controls + + try: + # Create a new verification token (you may want to enhance this logic based on your needs) + user_data = { + "max_budget": max_budget, + "user_email": user_email, + "user_id": user_id, + "team_id": team_id, + "organization_id": organization_id, + "user_role": user_role, + "spend": spend, + "models": models, + "max_parallel_requests": max_parallel_requests, + "tpm_limit": tpm_limit, + "rpm_limit": rpm_limit, + "budget_duration": budget_duration, + "budget_reset_at": reset_at, + "allowed_cache_controls": allowed_cache_controls, + } + if teams is not None: + user_data["teams"] = teams + key_data = { + "token": token, + "key_alias": key_alias, + "expires": expires, + "models": models, + "aliases": aliases_json, + "config": config_json, + "spend": spend, + "max_budget": key_max_budget, + "user_id": user_id, + "team_id": team_id, + "max_parallel_requests": max_parallel_requests, + "metadata": metadata_json, + "tpm_limit": tpm_limit, 
+ "rpm_limit": rpm_limit, + "budget_duration": key_budget_duration, + "budget_reset_at": key_reset_at, + "allowed_cache_controls": allowed_cache_controls, + "permissions": permissions_json, + "model_max_budget": model_max_budget_json, + "budget_id": budget_id, + } + + if ( + litellm.get_secret("DISABLE_KEY_NAME", False) == True + ): # allow user to disable storing abbreviated key name (shown in UI, to help figure out which key spent how much) + pass + else: + key_data["key_name"] = f"sk-...{token[-4:]}" + saved_token = copy.deepcopy(key_data) + if isinstance(saved_token["aliases"], str): + saved_token["aliases"] = json.loads(saved_token["aliases"]) + if isinstance(saved_token["config"], str): + saved_token["config"] = json.loads(saved_token["config"]) + if isinstance(saved_token["metadata"], str): + saved_token["metadata"] = json.loads(saved_token["metadata"]) + if isinstance(saved_token["permissions"], str): + if ( + "get_spend_routes" in saved_token["permissions"] + and premium_user != True + ): + raise ValueError( + "get_spend_routes permission is only available for LiteLLM Enterprise users" + ) + + saved_token["permissions"] = json.loads(saved_token["permissions"]) + if isinstance(saved_token["model_max_budget"], str): + saved_token["model_max_budget"] = json.loads( + saved_token["model_max_budget"] + ) + + if saved_token.get("expires", None) is not None and isinstance( + saved_token["expires"], datetime + ): + saved_token["expires"] = saved_token["expires"].isoformat() + if prisma_client is not None: + if ( + table_name is None or table_name == "user" + ): # do not auto-create users for `/key/generate` + ## CREATE USER (If necessary) + if query_type == "insert_data": + user_row = await prisma_client.insert_data( + data=user_data, table_name="user" + ) + ## use default user model list if no key-specific model list provided + if len(user_row.models) > 0 and len(key_data["models"]) == 0: # type: ignore + key_data["models"] = user_row.models + elif query_type == 
"update_data": + user_row = await prisma_client.update_data( + data=user_data, + table_name="user", + update_key_values=update_key_values, + ) + if user_id == litellm_proxy_budget_name or ( + table_name is not None and table_name == "user" + ): + # do not create a key for litellm_proxy_budget_name or if table name is set to just 'user' + # we only need to ensure this exists in the user table + # the LiteLLM_VerificationToken table will increase in size if we don't do this check + return user_data + + ## CREATE KEY + verbose_proxy_logger.debug("prisma_client: Creating Key= %s", key_data) + create_key_response = await prisma_client.insert_data( + data=key_data, table_name="key" + ) + key_data["token_id"] = getattr(create_key_response, "token", None) + except Exception as e: + verbose_proxy_logger.error( + "litellm.proxy.proxy_server.generate_key_helper_fn(): Exception occured - {}".format( + str(e) + ) + ) + verbose_proxy_logger.debug(traceback.format_exc()) + if isinstance(e, HTTPException): + raise e + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail={"error": "Internal Server Error."}, + ) + + # Add budget related info in key_data - this ensures it's returned + key_data["budget_id"] = budget_id + + if request_type == "user": + # if this is a /user/new request update the key_date with user_data fields + key_data.update(user_data) + return key_data + + +async def delete_verification_token(tokens: List, user_id: Optional[str] = None): + from litellm.proxy.proxy_server import prisma_client, litellm_proxy_admin_name + + try: + if prisma_client: + # Assuming 'db' is your Prisma Client instance + # check if admin making request - don't filter by user-id + if user_id == litellm_proxy_admin_name: + deleted_tokens = await prisma_client.delete_data(tokens=tokens) + # else + else: + deleted_tokens = await prisma_client.delete_data( + tokens=tokens, user_id=user_id + ) + _num_deleted_tokens = deleted_tokens.get("deleted_keys", 0) + if 
_num_deleted_tokens != len(tokens): + raise Exception( + "Failed to delete all tokens. Tried to delete tokens that don't belong to user: " + + str(user_id) + ) + else: + raise Exception("DB not connected. prisma_client is None") + except Exception as e: + verbose_proxy_logger.error( + "litellm.proxy.proxy_server.delete_verification_token(): Exception occured - {}".format( + str(e) + ) + ) + verbose_proxy_logger.debug(traceback.format_exc()) + raise e + return deleted_tokens diff --git a/litellm/proxy/management_endpoints/team_endpoints.py b/litellm/proxy/management_endpoints/team_endpoints.py new file mode 100644 index 000000000..50244ee23 --- /dev/null +++ b/litellm/proxy/management_endpoints/team_endpoints.py @@ -0,0 +1,899 @@ +from typing import Optional, List +import fastapi +from fastapi import Depends, Request, APIRouter, Header, status +from fastapi import HTTPException +import copy +import json +import uuid +import litellm +import asyncio +from datetime import datetime, timedelta, timezone +from litellm._logging import verbose_proxy_logger +from litellm.proxy.auth.user_api_key_auth import user_api_key_auth +from litellm.proxy._types import ( + UserAPIKeyAuth, + LiteLLM_TeamTable, + LiteLLM_ModelTable, + LitellmUserRoles, + NewTeamRequest, + TeamMemberAddRequest, + UpdateTeamRequest, + BlockTeamRequest, + DeleteTeamRequest, + Member, + LitellmTableNames, + LiteLLM_AuditLogs, + TeamMemberDeleteRequest, + ProxyException, + CommonProxyErrors, +) +from litellm.proxy.management_helpers.utils import ( + add_new_member, + management_endpoint_wrapper, +) + +router = APIRouter() + + +#### TEAM MANAGEMENT #### +@router.post( + "/team/new", + tags=["team management"], + dependencies=[Depends(user_api_key_auth)], + response_model=LiteLLM_TeamTable, +) +@management_endpoint_wrapper +async def new_team( + data: NewTeamRequest, + http_request: Request, + user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), + litellm_changed_by: Optional[str] = Header( + None, + 
description="The litellm-changed-by header enables tracking of actions performed by authorized users on behalf of other users, providing an audit trail for accountability", + ), +): + """ + Allow users to create a new team. Apply user permissions to their team. + + šŸ‘‰ [Detailed Doc on setting team budgets](https://docs.litellm.ai/docs/proxy/team_budgets) + + + Parameters: + - team_alias: Optional[str] - User defined team alias + - team_id: Optional[str] - The team id of the user. If none passed, we'll generate it. + - members_with_roles: List[{"role": "admin" or "user", "user_id": ""}] - A list of users and their roles in the team. Get user_id when making a new user via `/user/new`. + - metadata: Optional[dict] - Metadata for team, store information for team. Example metadata = {"extra_info": "some info"} + - tpm_limit: Optional[int] - The TPM (Tokens Per Minute) limit for this team - all keys with this team_id will have at max this TPM limit + - rpm_limit: Optional[int] - The RPM (Requests Per Minute) limit for this team - all keys associated with this team_id will have at max this RPM limit + - max_budget: Optional[float] - The maximum budget allocated to the team - all keys for this team_id will have at max this max_budget + - budget_duration: Optional[str] - The duration of the budget for the team. Doc [here](https://docs.litellm.ai/docs/proxy/team_budgets) + - models: Optional[list] - A list of models associated with the team - all keys for this team_id will have at most, these models. If empty, assumes all models are allowed. + - blocked: bool - Flag indicating if the team is blocked or not - will stop all calls from keys with this team_id. + + Returns: + - team_id: (str) Unique team id - used for tracking spend across multiple keys for same team id. 
+ + _deprecated_params: + - admins: list - A list of user_id's for the admin role + - users: list - A list of user_id's for the user role + + Example Request: + ``` + curl --location 'http://0.0.0.0:4000/team/new' \ + + --header 'Authorization: Bearer sk-1234' \ + + --header 'Content-Type: application/json' \ + + --data '{ + "team_alias": "my-new-team_2", + "members_with_roles": [{"role": "admin", "user_id": "user-1234"}, + {"role": "user", "user_id": "user-2434"}] + }' + + ``` + + ``` + curl --location 'http://0.0.0.0:4000/team/new' \ + + --header 'Authorization: Bearer sk-1234' \ + + --header 'Content-Type: application/json' \ + + --data '{ + "team_alias": "QA Prod Bot", + "max_budget": 0.000000001, + "budget_duration": "1d" + }' + + ``` + """ + from litellm.proxy.proxy_server import ( + prisma_client, + litellm_proxy_admin_name, + create_audit_log_for_update, + _duration_in_seconds, + ) + + if prisma_client is None: + raise HTTPException(status_code=500, detail={"error": "No db connected"}) + + if data.team_id is None: + data.team_id = str(uuid.uuid4()) + else: + # Check if team_id exists already + _existing_team_id = await prisma_client.get_data( + team_id=data.team_id, table_name="team", query_type="find_unique" + ) + if _existing_team_id is not None: + raise HTTPException( + status_code=400, + detail={ + "error": f"Team id = {data.team_id} already exists. Please use a different team id." + }, + ) + + if ( + user_api_key_dict.user_role is None + or user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN + ): # don't restrict proxy admin + if ( + data.tpm_limit is not None + and user_api_key_dict.tpm_limit is not None + and data.tpm_limit > user_api_key_dict.tpm_limit + ): + raise HTTPException( + status_code=400, + detail={ + "error": f"tpm limit higher than user max. User tpm limit={user_api_key_dict.tpm_limit}. 
User role={user_api_key_dict.user_role}" + }, + ) + + if ( + data.rpm_limit is not None + and user_api_key_dict.rpm_limit is not None + and data.rpm_limit > user_api_key_dict.rpm_limit + ): + raise HTTPException( + status_code=400, + detail={ + "error": f"rpm limit higher than user max. User rpm limit={user_api_key_dict.rpm_limit}. User role={user_api_key_dict.user_role}" + }, + ) + + if ( + data.max_budget is not None + and user_api_key_dict.max_budget is not None + and data.max_budget > user_api_key_dict.max_budget + ): + raise HTTPException( + status_code=400, + detail={ + "error": f"max budget higher than user max. User max budget={user_api_key_dict.max_budget}. User role={user_api_key_dict.user_role}" + }, + ) + + if data.models is not None and len(user_api_key_dict.models) > 0: + for m in data.models: + if m not in user_api_key_dict.models: + raise HTTPException( + status_code=400, + detail={ + "error": f"Model not in allowed user models. User allowed models={user_api_key_dict.models}. 
User id={user_api_key_dict.user_id}" + }, + ) + + if user_api_key_dict.user_id is not None: + creating_user_in_list = False + for member in data.members_with_roles: + if member.user_id == user_api_key_dict.user_id: + creating_user_in_list = True + + if creating_user_in_list == False: + data.members_with_roles.append( + Member(role="admin", user_id=user_api_key_dict.user_id) + ) + + ## ADD TO MODEL TABLE + _model_id = None + if data.model_aliases is not None and isinstance(data.model_aliases, dict): + litellm_modeltable = LiteLLM_ModelTable( + model_aliases=json.dumps(data.model_aliases), + created_by=user_api_key_dict.user_id or litellm_proxy_admin_name, + updated_by=user_api_key_dict.user_id or litellm_proxy_admin_name, + ) + model_dict = await prisma_client.db.litellm_modeltable.create( + {**litellm_modeltable.json(exclude_none=True)} # type: ignore + ) # type: ignore + + _model_id = model_dict.id + + ## ADD TO TEAM TABLE + complete_team_data = LiteLLM_TeamTable( + **data.json(), + model_id=_model_id, + ) + + # If budget_duration is set, set `budget_reset_at` + if complete_team_data.budget_duration is not None: + duration_s = _duration_in_seconds(duration=complete_team_data.budget_duration) + reset_at = datetime.now(timezone.utc) + timedelta(seconds=duration_s) + complete_team_data.budget_reset_at = reset_at + + team_row = await prisma_client.insert_data( + data=complete_team_data.json(exclude_none=True), table_name="team" + ) + + ## ADD TEAM ID TO USER TABLE ## + for user in complete_team_data.members_with_roles: + ## add team id to user row ## + await prisma_client.update_data( + user_id=user.user_id, + data={"user_id": user.user_id, "teams": [team_row.team_id]}, + update_key_values_custom_query={ + "teams": { + "push ": [team_row.team_id], + } + }, + ) + + # Enterprise Feature - Audit Logging. 
Enable with litellm.store_audit_logs = True + if litellm.store_audit_logs is True: + _updated_values = complete_team_data.json(exclude_none=True) + + _updated_values = json.dumps(_updated_values, default=str) + + asyncio.create_task( + create_audit_log_for_update( + request_data=LiteLLM_AuditLogs( + id=str(uuid.uuid4()), + updated_at=datetime.now(timezone.utc), + changed_by=litellm_changed_by + or user_api_key_dict.user_id + or litellm_proxy_admin_name, + changed_by_api_key=user_api_key_dict.api_key, + table_name=LitellmTableNames.TEAM_TABLE_NAME, + object_id=data.team_id, + action="created", + updated_values=_updated_values, + before_value=None, + ) + ) + ) + + try: + return team_row.model_dump() + except Exception as e: + return team_row.dict() + + +@router.post( + "/team/update", tags=["team management"], dependencies=[Depends(user_api_key_auth)] +) +@management_endpoint_wrapper +async def update_team( + data: UpdateTeamRequest, + http_request: Request, + user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), + litellm_changed_by: Optional[str] = Header( + None, + description="The litellm-changed-by header enables tracking of actions performed by authorized users on behalf of other users, providing an audit trail for accountability", + ), +): + """ + Use `/team/member_add` AND `/team/member/delete` to add/remove new team members + + You can now update team budget / rate limits via /team/update + + Parameters: + - team_id: str - The team id of the user. Required param. + - team_alias: Optional[str] - User defined team alias + - metadata: Optional[dict] - Metadata for team, store information for team. 
Example metadata = {"team": "core-infra", "app": "app2", "email": "ishaan@berri.ai" } + - tpm_limit: Optional[int] - The TPM (Tokens Per Minute) limit for this team - all keys with this team_id will have at max this TPM limit + - rpm_limit: Optional[int] - The RPM (Requests Per Minute) limit for this team - all keys associated with this team_id will have at max this RPM limit + - max_budget: Optional[float] - The maximum budget allocated to the team - all keys for this team_id will have at max this max_budget + - budget_duration: Optional[str] - The duration of the budget for the team. Doc [here](https://docs.litellm.ai/docs/proxy/team_budgets) + - models: Optional[list] - A list of models associated with the team - all keys for this team_id will have at most, these models. If empty, assumes all models are allowed. + - blocked: bool - Flag indicating if the team is blocked or not - will stop all calls from keys with this team_id. + + Example - update team TPM Limit + + ``` + curl --location 'http://0.0.0.0:8000/team/update' \ + + --header 'Authorization: Bearer sk-1234' \ + + --header 'Content-Type: application/json' \ + + --data-raw '{ + "team_id": "litellm-test-client-id-new", + "tpm_limit": 100 + }' + ``` + + Example - Update Team `max_budget` budget + ``` + curl --location 'http://0.0.0.0:8000/team/update' \ + + --header 'Authorization: Bearer sk-1234' \ + + --header 'Content-Type: application/json' \ + + --data-raw '{ + "team_id": "litellm-test-client-id-new", + "max_budget": 10 + }' + ``` + """ + from litellm.proxy.proxy_server import ( + prisma_client, + litellm_proxy_admin_name, + create_audit_log_for_update, + _duration_in_seconds, + ) + + if prisma_client is None: + raise HTTPException(status_code=500, detail={"error": "No db connected"}) + + if data.team_id is None: + raise HTTPException(status_code=400, detail={"error": "No team id passed in"}) + verbose_proxy_logger.debug("/team/update - %s", data) + + existing_team_row = await prisma_client.get_data( 
+ team_id=data.team_id, table_name="team", query_type="find_unique" + ) + if existing_team_row is None: + raise HTTPException( + status_code=404, + detail={"error": f"Team not found, passed team_id={data.team_id}"}, + ) + + updated_kv = data.json(exclude_none=True) + + # Check budget_duration and budget_reset_at + if data.budget_duration is not None: + duration_s = _duration_in_seconds(duration=data.budget_duration) + reset_at = datetime.now(timezone.utc) + timedelta(seconds=duration_s) + + # set the budget_reset_at in DB + updated_kv["budget_reset_at"] = reset_at + + team_row = await prisma_client.update_data( + update_key_values=updated_kv, + data=updated_kv, + table_name="team", + team_id=data.team_id, + ) + + # Enterprise Feature - Audit Logging. Enable with litellm.store_audit_logs = True + if litellm.store_audit_logs is True: + _before_value = existing_team_row.json(exclude_none=True) + _before_value = json.dumps(_before_value, default=str) + _after_value: str = json.dumps(updated_kv, default=str) + + asyncio.create_task( + create_audit_log_for_update( + request_data=LiteLLM_AuditLogs( + id=str(uuid.uuid4()), + updated_at=datetime.now(timezone.utc), + changed_by=litellm_changed_by + or user_api_key_dict.user_id + or litellm_proxy_admin_name, + changed_by_api_key=user_api_key_dict.api_key, + table_name=LitellmTableNames.TEAM_TABLE_NAME, + object_id=data.team_id, + action="updated", + updated_values=_after_value, + before_value=_before_value, + ) + ) + ) + + return team_row + + +@router.post( + "/team/member_add", + tags=["team management"], + dependencies=[Depends(user_api_key_auth)], +) +@management_endpoint_wrapper +async def team_member_add( + data: TeamMemberAddRequest, + http_request: Request, + user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), +): + """ + [BETA] + + Add new members (either via user_email or user_id) to a team + + If user doesn't exist, new user row will also be added to User Table + + ``` + + curl -X POST 
'http://0.0.0.0:4000/team/member_add' \ + -H 'Authorization: Bearer sk-1234' \ + -H 'Content-Type: application/json' \ + -d '{"team_id": "45e3e396-ee08-4a61-a88e-16b3ce7e0849", "member": {"role": "user", "user_id": "krrish247652@berri.ai"}}' + + ``` + """ + from litellm.proxy.proxy_server import ( + prisma_client, + litellm_proxy_admin_name, + create_audit_log_for_update, + _duration_in_seconds, + ) + + if prisma_client is None: + raise HTTPException(status_code=500, detail={"error": "No db connected"}) + + if data.team_id is None: + raise HTTPException(status_code=400, detail={"error": "No team id passed in"}) + + if data.member is None: + raise HTTPException( + status_code=400, detail={"error": "No member/members passed in"} + ) + + existing_team_row = await prisma_client.db.litellm_teamtable.find_unique( + where={"team_id": data.team_id} + ) + if existing_team_row is None: + raise HTTPException( + status_code=404, + detail={ + "error": f"Team not found for team_id={getattr(data, 'team_id', None)}" + }, + ) + + complete_team_data = LiteLLM_TeamTable(**existing_team_row.model_dump()) + + if isinstance(data.member, Member): + # add to team db + new_member = data.member + + complete_team_data.members_with_roles.append(new_member) + + elif isinstance(data.member, List): + # add to team db + new_members = data.member + + complete_team_data.members_with_roles.extend(new_members) + + # ADD MEMBER TO TEAM + _db_team_members = [m.model_dump() for m in complete_team_data.members_with_roles] + updated_team = await prisma_client.db.litellm_teamtable.update( + where={"team_id": data.team_id}, + data={"members_with_roles": json.dumps(_db_team_members)}, # type: ignore + ) + + if isinstance(data.member, Member): + await add_new_member( + new_member=data.member, + max_budget_in_team=data.max_budget_in_team, + prisma_client=prisma_client, + user_api_key_dict=user_api_key_dict, + litellm_proxy_admin_name=litellm_proxy_admin_name, + team_id=data.team_id, + ) + elif 
isinstance(data.member, List): + tasks: List = [] + for m in data.member: + await add_new_member( + new_member=m, + max_budget_in_team=data.max_budget_in_team, + prisma_client=prisma_client, + user_api_key_dict=user_api_key_dict, + litellm_proxy_admin_name=litellm_proxy_admin_name, + team_id=data.team_id, + ) + await asyncio.gather(*tasks) + + return updated_team + + +@router.post( + "/team/member_delete", + tags=["team management"], + dependencies=[Depends(user_api_key_auth)], +) +@management_endpoint_wrapper +async def team_member_delete( + data: TeamMemberDeleteRequest, + http_request: Request, + user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), +): + """ + [BETA] + + delete members (either via user_email or user_id) from a team + + If user doesn't exist, an exception will be raised + ``` + curl -X POST 'http://0.0.0.0:8000/team/update' \ + + -H 'Authorization: Bearer sk-1234' \ + + -H 'Content-Type: application/json' \ + + -D '{ + "team_id": "45e3e396-ee08-4a61-a88e-16b3ce7e0849", + "user_id": "krrish247652@berri.ai" + }' + ``` + """ + from litellm.proxy.proxy_server import ( + prisma_client, + litellm_proxy_admin_name, + create_audit_log_for_update, + _duration_in_seconds, + ) + + if prisma_client is None: + raise HTTPException(status_code=500, detail={"error": "No db connected"}) + + if data.team_id is None: + raise HTTPException(status_code=400, detail={"error": "No team id passed in"}) + + if data.user_id is None and data.user_email is None: + raise HTTPException( + status_code=400, + detail={"error": "Either user_id or user_email needs to be passed in"}, + ) + + _existing_team_row = await prisma_client.db.litellm_teamtable.find_unique( + where={"team_id": data.team_id} + ) + + if _existing_team_row is None: + raise HTTPException( + status_code=400, + detail={"error": "Team id={} does not exist in db".format(data.team_id)}, + ) + existing_team_row = LiteLLM_TeamTable(**_existing_team_row.model_dump()) + + ## DELETE MEMBER FROM TEAM + 
new_team_members: List[Member] = [] + for m in existing_team_row.members_with_roles: + if ( + data.user_id is not None + and m.user_id is not None + and data.user_id == m.user_id + ): + continue + elif ( + data.user_email is not None + and m.user_email is not None + and data.user_email == m.user_email + ): + continue + new_team_members.append(m) + existing_team_row.members_with_roles = new_team_members + + _db_new_team_members: List[dict] = [m.model_dump() for m in new_team_members] + + _ = await prisma_client.db.litellm_teamtable.update( + where={ + "team_id": data.team_id, + }, + data={"members_with_roles": json.dumps(_db_new_team_members)}, # type: ignore + ) + + ## DELETE TEAM ID from USER ROW, IF EXISTS ## + # get user row + key_val = {} + if data.user_id is not None: + key_val["user_id"] = data.user_id + elif data.user_email is not None: + key_val["user_email"] = data.user_email + existing_user_rows = await prisma_client.db.litellm_usertable.find_many( + where=key_val # type: ignore + ) + + if existing_user_rows is not None and ( + isinstance(existing_user_rows, list) and len(existing_user_rows) > 0 + ): + for existing_user in existing_user_rows: + team_list = [] + if data.team_id in existing_user.teams: + team_list = existing_user.teams + team_list.remove(data.team_id) + await prisma_client.db.litellm_usertable.update( + where={ + "user_id": existing_user.user_id, + }, + data={"teams": {"set": team_list}}, + ) + + return existing_team_row + + +@router.post( + "/team/delete", tags=["team management"], dependencies=[Depends(user_api_key_auth)] +) +@management_endpoint_wrapper +async def delete_team( + data: DeleteTeamRequest, + http_request: Request, + user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), + litellm_changed_by: Optional[str] = Header( + None, + description="The litellm-changed-by header enables tracking of actions performed by authorized users on behalf of other users, providing an audit trail for accountability", + ), +): + """ + 
delete team and associated team keys + + ``` + curl --location 'http://0.0.0.0:8000/team/delete' \ + + --header 'Authorization: Bearer sk-1234' \ + + --header 'Content-Type: application/json' \ + + --data-raw '{ + "team_ids": ["45e3e396-ee08-4a61-a88e-16b3ce7e0849"] + }' + ``` + """ + from litellm.proxy.proxy_server import ( + prisma_client, + litellm_proxy_admin_name, + create_audit_log_for_update, + _duration_in_seconds, + ) + + if prisma_client is None: + raise HTTPException(status_code=500, detail={"error": "No db connected"}) + + if data.team_ids is None: + raise HTTPException(status_code=400, detail={"error": "No team id passed in"}) + + # check that all teams passed exist + for team_id in data.team_ids: + team_row = await prisma_client.get_data( # type: ignore + team_id=team_id, table_name="team", query_type="find_unique" + ) + if team_row is None: + raise HTTPException( + status_code=404, + detail={"error": f"Team not found, passed team_id={team_id}"}, + ) + + # Enterprise Feature - Audit Logging. Enable with litellm.store_audit_logs = True + # we do this after the first for loop, since first for loop is for validation. 
we only want this inserted after validation passes + if litellm.store_audit_logs is True: + # make an audit log for each team deleted + for team_id in data.team_ids: + team_row = await prisma_client.get_data( # type: ignore + team_id=team_id, table_name="team", query_type="find_unique" + ) + + _team_row = team_row.json(exclude_none=True) + + asyncio.create_task( + create_audit_log_for_update( + request_data=LiteLLM_AuditLogs( + id=str(uuid.uuid4()), + updated_at=datetime.now(timezone.utc), + changed_by=litellm_changed_by + or user_api_key_dict.user_id + or litellm_proxy_admin_name, + changed_by_api_key=user_api_key_dict.api_key, + table_name=LitellmTableNames.TEAM_TABLE_NAME, + object_id=team_id, + action="deleted", + updated_values="{}", + before_value=_team_row, + ) + ) + ) + + # End of Audit logging + + ## DELETE ASSOCIATED KEYS + await prisma_client.delete_data(team_id_list=data.team_ids, table_name="key") + ## DELETE TEAMS + deleted_teams = await prisma_client.delete_data( + team_id_list=data.team_ids, table_name="team" + ) + return deleted_teams + + +@router.get( + "/team/info", tags=["team management"], dependencies=[Depends(user_api_key_auth)] +) +@management_endpoint_wrapper +async def team_info( + http_request: Request, + team_id: str = fastapi.Query( + default=None, description="Team ID in the request parameters" + ), +): + """ + get info on team + related keys + + ``` + curl --location 'http://localhost:4000/team/info' \ + --header 'Authorization: Bearer sk-1234' \ + --header 'Content-Type: application/json' \ + --data '{ + "teams": ["",..] + }' + ``` + """ + from litellm.proxy.proxy_server import ( + prisma_client, + litellm_proxy_admin_name, + create_audit_log_for_update, + _duration_in_seconds, + ) + + try: + if prisma_client is None: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail={ + "error": f"Database not connected. 
Connect a database to your proxy - https://docs.litellm.ai/docs/simple_proxy#managing-auth---virtual-keys" + }, + ) + if team_id is None: + raise HTTPException( + status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, + detail={"message": "Malformed request. No team id passed in."}, + ) + + team_info = await prisma_client.get_data( + team_id=team_id, table_name="team", query_type="find_unique" + ) + if team_info is None: + raise HTTPException( + status_code=status.HTTP_404_NOT_FOUND, + detail={"message": f"Team not found, passed team id: {team_id}."}, + ) + + ## GET ALL KEYS ## + keys = await prisma_client.get_data( + team_id=team_id, + table_name="key", + query_type="find_all", + expires=datetime.now(), + ) + + if team_info is None: + ## make sure we still return a total spend ## + spend = 0 + for k in keys: + spend += getattr(k, "spend", 0) + team_info = {"spend": spend} + + ## REMOVE HASHED TOKEN INFO before returning ## + for key in keys: + try: + key = key.model_dump() # noqa + except: + # if using pydantic v1 + key = key.dict() + key.pop("token", None) + return {"team_id": team_id, "team_info": team_info, "keys": keys} + + except Exception as e: + if isinstance(e, HTTPException): + raise ProxyException( + message=getattr(e, "detail", f"Authentication Error({str(e)})"), + type="auth_error", + param=getattr(e, "param", "None"), + code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST), + ) + elif isinstance(e, ProxyException): + raise e + raise ProxyException( + message="Authentication Error, " + str(e), + type="auth_error", + param=getattr(e, "param", "None"), + code=status.HTTP_400_BAD_REQUEST, + ) + + +@router.post( + "/team/block", tags=["team management"], dependencies=[Depends(user_api_key_auth)] +) +@management_endpoint_wrapper +async def block_team( + data: BlockTeamRequest, + http_request: Request, + user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), +): + """ + Blocks all calls from keys with this team id. 
+ """ + from litellm.proxy.proxy_server import ( + prisma_client, + litellm_proxy_admin_name, + create_audit_log_for_update, + _duration_in_seconds, + ) + + if prisma_client is None: + raise Exception("No DB Connected.") + + record = await prisma_client.db.litellm_teamtable.update( + where={"team_id": data.team_id}, data={"blocked": True} # type: ignore + ) + + return record + + +@router.post( + "/team/unblock", tags=["team management"], dependencies=[Depends(user_api_key_auth)] +) +@management_endpoint_wrapper +async def unblock_team( + data: BlockTeamRequest, + http_request: Request, + user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), +): + """ + Blocks all calls from keys with this team id. + """ + from litellm.proxy.proxy_server import ( + prisma_client, + litellm_proxy_admin_name, + create_audit_log_for_update, + _duration_in_seconds, + ) + + if prisma_client is None: + raise Exception("No DB Connected.") + + record = await prisma_client.db.litellm_teamtable.update( + where={"team_id": data.team_id}, data={"blocked": False} # type: ignore + ) + + return record + + +@router.get( + "/team/list", tags=["team management"], dependencies=[Depends(user_api_key_auth)] +) +@management_endpoint_wrapper +async def list_team( + http_request: Request, + user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), +): + """ + [Admin-only] List all available teams + + ``` + curl --location --request GET 'http://0.0.0.0:4000/team/list' \ + --header 'Authorization: Bearer sk-1234' + ``` + """ + from litellm.proxy.proxy_server import ( + prisma_client, + litellm_proxy_admin_name, + create_audit_log_for_update, + _duration_in_seconds, + ) + + if user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN: + raise HTTPException( + status_code=401, + detail={ + "error": "Admin-only endpoint. 
Your user role={}".format( + user_api_key_dict.user_role + ) + }, + ) + + if prisma_client is None: + raise HTTPException( + status_code=400, + detail={"error": CommonProxyErrors.db_not_connected_error.value}, + ) + + response = await prisma_client.db.litellm_teamtable.find_many() + + return response diff --git a/litellm/proxy/management_helpers/utils.py b/litellm/proxy/management_helpers/utils.py index 6c035d3ef..1cf22480b 100644 --- a/litellm/proxy/management_helpers/utils.py +++ b/litellm/proxy/management_helpers/utils.py @@ -1,5 +1,11 @@ # What is this? ## Helper utils for the management endpoints (keys/users/teams) +from datetime import datetime +from functools import wraps +from litellm.proxy._types import UserAPIKeyAuth, ManagementEndpointLoggingPayload +from litellm.proxy.common_utils.http_parsing_utils import _read_request_body +from litellm._logging import verbose_logger +from fastapi import Request from litellm.proxy._types import LiteLLM_TeamTable, Member, UserAPIKeyAuth from litellm.proxy.utils import PrismaClient @@ -61,3 +67,110 @@ async def add_new_member( "budget_id": _budget_id, } ) + + +def management_endpoint_wrapper(func): + """ + This wrapper does the following: + + 1. Log I/O, Exceptions to OTEL + 2. 
Create an Audit log for success calls + """ + + @wraps(func) + async def wrapper(*args, **kwargs): + start_time = datetime.now() + + try: + result = await func(*args, **kwargs) + end_time = datetime.now() + try: + if kwargs is None: + kwargs = {} + user_api_key_dict: UserAPIKeyAuth = ( + kwargs.get("user_api_key_dict") or UserAPIKeyAuth() + ) + _http_request: Request = kwargs.get("http_request") + parent_otel_span = user_api_key_dict.parent_otel_span + if parent_otel_span is not None: + from litellm.proxy.proxy_server import open_telemetry_logger + + if open_telemetry_logger is not None: + if _http_request: + _route = _http_request.url.path + _request_body: dict = await _read_request_body( + request=_http_request + ) + _response = dict(result) if result is not None else None + + logging_payload = ManagementEndpointLoggingPayload( + route=_route, + request_data=_request_body, + response=_response, + start_time=start_time, + end_time=end_time, + ) + + await open_telemetry_logger.async_management_endpoint_success_hook( + logging_payload=logging_payload, + parent_otel_span=parent_otel_span, + ) + + if _http_request: + _route = _http_request.url.path + # Flush user_api_key cache if this was an update/delete call to /key, /team, or /user + if _route in [ + "/key/update", + "/key/delete", + "/team/update", + "/team/delete", + "/user/update", + "/user/delete", + "/customer/update", + "/customer/delete", + ]: + from litellm.proxy.proxy_server import user_api_key_cache + + user_api_key_cache.flush_cache() + except Exception as e: + # Non-Blocking Exception + verbose_logger.debug("Error in management endpoint wrapper: %s", str(e)) + pass + + return result + except Exception as e: + end_time = datetime.now() + + if kwargs is None: + kwargs = {} + user_api_key_dict: UserAPIKeyAuth = ( + kwargs.get("user_api_key_dict") or UserAPIKeyAuth() + ) + parent_otel_span = user_api_key_dict.parent_otel_span + if parent_otel_span is not None: + from litellm.proxy.proxy_server import 
open_telemetry_logger + + if open_telemetry_logger is not None: + _http_request: Request = kwargs.get("http_request") + if _http_request: + _route = _http_request.url.path + _request_body: dict = await _read_request_body( + request=_http_request + ) + logging_payload = ManagementEndpointLoggingPayload( + route=_route, + request_data=_request_body, + response=None, + start_time=start_time, + end_time=end_time, + exception=e, + ) + + await open_telemetry_logger.async_management_endpoint_failure_hook( + logging_payload=logging_payload, + parent_otel_span=parent_otel_span, + ) + + raise e + + return wrapper diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py index 844daff11..64a2cbb2f 100644 --- a/litellm/proxy/proxy_server.py +++ b/litellm/proxy/proxy_server.py @@ -90,7 +90,6 @@ from litellm.types.llms.openai import ( HttpxBinaryResponseContent, ) from litellm.proxy.litellm_pre_call_utils import add_litellm_data_to_request -from litellm.proxy.management_helpers.utils import add_new_member from litellm.proxy.utils import ( PrismaClient, DBClient, @@ -162,14 +161,25 @@ from litellm.proxy.auth.auth_checks import ( get_actual_routes, log_to_opentelemetry, ) -from litellm.proxy.common_utils.management_endpoint_utils import ( - management_endpoint_wrapper, -) from litellm.llms.custom_httpx.httpx_handler import HTTPHandler from litellm.exceptions import RejectedRequestError from litellm.integrations.slack_alerting import SlackAlertingArgs, SlackAlerting from litellm.scheduler import Scheduler, FlowItem, DefaultPriorities +## Import All Misc routes here ## +from litellm.proxy.caching_routes import router as caching_router +from litellm.proxy.management_endpoints.team_endpoints import router as team_router +from litellm.proxy.spend_reporting_endpoints.spend_management_endpoints import ( + router as spend_management_router, +) +from litellm.proxy.management_endpoints.key_management_endpoints import ( + router as key_management_router, + 
_duration_in_seconds, + generate_key_helper_fn, + delete_verification_token, +) +from litellm.proxy.health_endpoints._health_endpoints import router as health_router + try: from litellm._version import version except: @@ -277,13 +287,6 @@ class UserAPIKeyCacheTTLEnum(enum.Enum): in_memory_cache_ttl = 60 # 1 min ttl ## configure via `general_settings::user_api_key_cache_ttl: ` -class CommonProxyErrors(enum.Enum): - db_not_connected_error = "DB not connected" - no_llm_router = "No models configured on proxy" - not_allowed_access = "Admin-only endpoint. Not allowed to access this." - not_premium_user = "You must be a LiteLLM Enterprise user to use this feature. If you have a license please set `LITELLM_LICENSE` in your env. If you want to obtain a license meet with us here: https://calendly.com/d/4mp-gd3-k5k/litellm-1-1-onboarding-chat" - - @app.exception_handler(ProxyException) async def openai_exception_handler(request: Request, exc: ProxyException): # NOTE: DO NOT MODIFY THIS, its crucial to map to Openai exceptions @@ -2260,263 +2263,6 @@ class ProxyConfig: proxy_config = ProxyConfig() -def _duration_in_seconds(duration: str): - match = re.match(r"(\d+)([smhd]?)", duration) - if not match: - raise ValueError("Invalid duration format") - - value, unit = match.groups() - value = int(value) - - if unit == "s": - return value - elif unit == "m": - return value * 60 - elif unit == "h": - return value * 3600 - elif unit == "d": - return value * 86400 - else: - raise ValueError("Unsupported duration unit") - - -async def generate_key_helper_fn( - request_type: Literal[ - "user", "key" - ], # identifies if this request is from /user/new or /key/generate - duration: Optional[str], - models: list, - aliases: dict, - config: dict, - spend: float, - key_max_budget: Optional[float] = None, # key_max_budget is used to Budget Per key - key_budget_duration: Optional[str] = None, - budget_id: Optional[float] = None, # budget id <-> LiteLLM_BudgetTable - soft_budget: Optional[ - 
float - ] = None, # soft_budget is used to set soft Budgets Per user - max_budget: Optional[float] = None, # max_budget is used to Budget Per user - budget_duration: Optional[str] = None, # max_budget is used to Budget Per user - token: Optional[str] = None, - user_id: Optional[str] = None, - team_id: Optional[str] = None, - user_email: Optional[str] = None, - user_role: Optional[str] = None, - max_parallel_requests: Optional[int] = None, - metadata: Optional[dict] = {}, - tpm_limit: Optional[int] = None, - rpm_limit: Optional[int] = None, - query_type: Literal["insert_data", "update_data"] = "insert_data", - update_key_values: Optional[dict] = None, - key_alias: Optional[str] = None, - allowed_cache_controls: Optional[list] = [], - permissions: Optional[dict] = {}, - model_max_budget: Optional[dict] = {}, - teams: Optional[list] = None, - organization_id: Optional[str] = None, - table_name: Optional[Literal["key", "user"]] = None, - send_invite_email: Optional[bool] = None, -): - global prisma_client, custom_db_client, user_api_key_cache, litellm_proxy_admin_name, premium_user - - if prisma_client is None and custom_db_client is None: - raise Exception( - f"Connect Proxy to database to generate keys - https://docs.litellm.ai/docs/proxy/virtual_keys " - ) - - if token is None: - token = f"sk-{secrets.token_urlsafe(16)}" - - if duration is None: # allow tokens that never expire - expires = None - else: - duration_s = _duration_in_seconds(duration=duration) - expires = datetime.now(timezone.utc) + timedelta(seconds=duration_s) - - if key_budget_duration is None: # one-time budget - key_reset_at = None - else: - duration_s = _duration_in_seconds(duration=key_budget_duration) - key_reset_at = datetime.now(timezone.utc) + timedelta(seconds=duration_s) - - if budget_duration is None: # one-time budget - reset_at = None - else: - duration_s = _duration_in_seconds(duration=budget_duration) - reset_at = datetime.now(timezone.utc) + timedelta(seconds=duration_s) - - 
aliases_json = json.dumps(aliases) - config_json = json.dumps(config) - permissions_json = json.dumps(permissions) - metadata_json = json.dumps(metadata) - model_max_budget_json = json.dumps(model_max_budget) - user_role = user_role - tpm_limit = tpm_limit - rpm_limit = rpm_limit - allowed_cache_controls = allowed_cache_controls - - try: - # Create a new verification token (you may want to enhance this logic based on your needs) - user_data = { - "max_budget": max_budget, - "user_email": user_email, - "user_id": user_id, - "team_id": team_id, - "organization_id": organization_id, - "user_role": user_role, - "spend": spend, - "models": models, - "max_parallel_requests": max_parallel_requests, - "tpm_limit": tpm_limit, - "rpm_limit": rpm_limit, - "budget_duration": budget_duration, - "budget_reset_at": reset_at, - "allowed_cache_controls": allowed_cache_controls, - } - if teams is not None: - user_data["teams"] = teams - key_data = { - "token": token, - "key_alias": key_alias, - "expires": expires, - "models": models, - "aliases": aliases_json, - "config": config_json, - "spend": spend, - "max_budget": key_max_budget, - "user_id": user_id, - "team_id": team_id, - "max_parallel_requests": max_parallel_requests, - "metadata": metadata_json, - "tpm_limit": tpm_limit, - "rpm_limit": rpm_limit, - "budget_duration": key_budget_duration, - "budget_reset_at": key_reset_at, - "allowed_cache_controls": allowed_cache_controls, - "permissions": permissions_json, - "model_max_budget": model_max_budget_json, - "budget_id": budget_id, - } - - if ( - litellm.get_secret("DISABLE_KEY_NAME", False) == True - ): # allow user to disable storing abbreviated key name (shown in UI, to help figure out which key spent how much) - pass - else: - key_data["key_name"] = f"sk-...{token[-4:]}" - saved_token = copy.deepcopy(key_data) - if isinstance(saved_token["aliases"], str): - saved_token["aliases"] = json.loads(saved_token["aliases"]) - if isinstance(saved_token["config"], str): - 
saved_token["config"] = json.loads(saved_token["config"]) - if isinstance(saved_token["metadata"], str): - saved_token["metadata"] = json.loads(saved_token["metadata"]) - if isinstance(saved_token["permissions"], str): - if ( - "get_spend_routes" in saved_token["permissions"] - and premium_user != True - ): - raise ValueError( - "get_spend_routes permission is only available for LiteLLM Enterprise users" - ) - - saved_token["permissions"] = json.loads(saved_token["permissions"]) - if isinstance(saved_token["model_max_budget"], str): - saved_token["model_max_budget"] = json.loads( - saved_token["model_max_budget"] - ) - - if saved_token.get("expires", None) is not None and isinstance( - saved_token["expires"], datetime - ): - saved_token["expires"] = saved_token["expires"].isoformat() - if prisma_client is not None: - if ( - table_name is None or table_name == "user" - ): # do not auto-create users for `/key/generate` - ## CREATE USER (If necessary) - if query_type == "insert_data": - user_row = await prisma_client.insert_data( - data=user_data, table_name="user" - ) - ## use default user model list if no key-specific model list provided - if len(user_row.models) > 0 and len(key_data["models"]) == 0: # type: ignore - key_data["models"] = user_row.models - elif query_type == "update_data": - user_row = await prisma_client.update_data( - data=user_data, - table_name="user", - update_key_values=update_key_values, - ) - if user_id == litellm_proxy_budget_name or ( - table_name is not None and table_name == "user" - ): - # do not create a key for litellm_proxy_budget_name or if table name is set to just 'user' - # we only need to ensure this exists in the user table - # the LiteLLM_VerificationToken table will increase in size if we don't do this check - return user_data - - ## CREATE KEY - verbose_proxy_logger.debug("prisma_client: Creating Key= %s", key_data) - create_key_response = await prisma_client.insert_data( - data=key_data, table_name="key" - ) - 
key_data["token_id"] = getattr(create_key_response, "token", None) - except Exception as e: - verbose_proxy_logger.error( - "litellm.proxy.proxy_server.generate_key_helper_fn(): Exception occured - {}".format( - str(e) - ) - ) - verbose_proxy_logger.debug(traceback.format_exc()) - if isinstance(e, HTTPException): - raise e - raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail={"error": "Internal Server Error."}, - ) - - # Add budget related info in key_data - this ensures it's returned - key_data["budget_id"] = budget_id - - if request_type == "user": - # if this is a /user/new request update the key_date with user_data fields - key_data.update(user_data) - return key_data - - -async def delete_verification_token(tokens: List, user_id: Optional[str] = None): - global prisma_client - try: - if prisma_client: - # Assuming 'db' is your Prisma Client instance - # check if admin making request - don't filter by user-id - if user_id == litellm_proxy_admin_name: - deleted_tokens = await prisma_client.delete_data(tokens=tokens) - # else - else: - deleted_tokens = await prisma_client.delete_data( - tokens=tokens, user_id=user_id - ) - _num_deleted_tokens = deleted_tokens.get("deleted_keys", 0) - if _num_deleted_tokens != len(tokens): - raise Exception( - "Failed to delete all tokens. Tried to delete tokens that don't belong to user: " - + str(user_id) - ) - else: - raise Exception("DB not connected. 
prisma_client is None") - except Exception as e: - verbose_proxy_logger.error( - "litellm.proxy.proxy_server.delete_verification_token(): Exception occured - {}".format( - str(e) - ) - ) - verbose_proxy_logger.debug(traceback.format_exc()) - raise e - return deleted_tokens - - def save_worker_config(**data): import json @@ -4761,231 +4507,6 @@ async def run_thread( ) -###################################################################### - -# /v1/batches Endpoints - - -###################################################################### -@router.post( - "/v1/batches", - dependencies=[Depends(user_api_key_auth)], - tags=["batch"], -) -@router.post( - "/batches", - dependencies=[Depends(user_api_key_auth)], - tags=["batch"], -) -async def create_batch( - request: Request, - fastapi_response: Response, - user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), -): - """ - Create large batches of API requests for asynchronous processing. - This is the equivalent of POST https://api.openai.com/v1/batch - Supports Identical Params as: https://platform.openai.com/docs/api-reference/batch - - Example Curl - ``` - curl http://localhost:4000/v1/batches \ - -H "Authorization: Bearer sk-1234" \ - -H "Content-Type: application/json" \ - -d '{ - "input_file_id": "file-abc123", - "endpoint": "/v1/chat/completions", - "completion_window": "24h" - }' - ``` - """ - global proxy_logging_obj - data: Dict = {} - try: - # Use orjson to parse JSON data, orjson speeds up requests significantly - form_data = await request.form() - data = {key: value for key, value in form_data.items() if key != "file"} - - # Include original request and headers in the data - data = await add_litellm_data_to_request( - data=data, - request=request, - general_settings=general_settings, - user_api_key_dict=user_api_key_dict, - version=version, - proxy_config=proxy_config, - ) - - _create_batch_data = CreateBatchRequest(**data) - - # for now use custom_llm_provider=="openai" -> this will change as 
LiteLLM adds more providers for acreate_batch - response = await litellm.acreate_batch( - custom_llm_provider="openai", **_create_batch_data - ) - - ### ALERTING ### - data["litellm_status"] = "success" # used for alerting - - ### RESPONSE HEADERS ### - hidden_params = getattr(response, "_hidden_params", {}) or {} - model_id = hidden_params.get("model_id", None) or "" - cache_key = hidden_params.get("cache_key", None) or "" - api_base = hidden_params.get("api_base", None) or "" - - fastapi_response.headers.update( - get_custom_headers( - user_api_key_dict=user_api_key_dict, - model_id=model_id, - cache_key=cache_key, - api_base=api_base, - version=version, - model_region=getattr(user_api_key_dict, "allowed_model_region", ""), - ) - ) - - return response - except Exception as e: - data["litellm_status"] = "fail" # used for alerting - await proxy_logging_obj.post_call_failure_hook( - user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data - ) - verbose_proxy_logger.error( - "litellm.proxy.proxy_server.create_batch(): Exception occured - {}".format( - str(e) - ) - ) - verbose_proxy_logger.debug(traceback.format_exc()) - if isinstance(e, HTTPException): - raise ProxyException( - message=getattr(e, "message", str(e.detail)), - type=getattr(e, "type", "None"), - param=getattr(e, "param", "None"), - code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST), - ) - else: - error_msg = f"{str(e)}" - raise ProxyException( - message=getattr(e, "message", error_msg), - type=getattr(e, "type", "None"), - param=getattr(e, "param", "None"), - code=getattr(e, "status_code", 500), - ) - - -@router.get( - "/v1/batches{batch_id}", - dependencies=[Depends(user_api_key_auth)], - tags=["batch"], -) -@router.get( - "/batches{batch_id}", - dependencies=[Depends(user_api_key_auth)], - tags=["batch"], -) -async def retrieve_batch( - request: Request, - fastapi_response: Response, - user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), - batch_id: str = Path( 
- title="Batch ID to retrieve", description="The ID of the batch to retrieve" - ), -): - """ - Retrieves a batch. - This is the equivalent of GET https://api.openai.com/v1/batches/{batch_id} - Supports Identical Params as: https://platform.openai.com/docs/api-reference/batch/retrieve - - Example Curl - ``` - curl http://localhost:4000/v1/batches/batch_abc123 \ - -H "Authorization: Bearer sk-1234" \ - -H "Content-Type: application/json" \ - - ``` - """ - global proxy_logging_obj - data: Dict = {} - try: - # Use orjson to parse JSON data, orjson speeds up requests significantly - form_data = await request.form() - data = {key: value for key, value in form_data.items() if key != "file"} - - # Include original request and headers in the data - data = await add_litellm_data_to_request( - data=data, - request=request, - general_settings=general_settings, - user_api_key_dict=user_api_key_dict, - version=version, - proxy_config=proxy_config, - ) - - _retrieve_batch_request = RetrieveBatchRequest( - batch_id=batch_id, - ) - - # for now use custom_llm_provider=="openai" -> this will change as LiteLLM adds more providers for acreate_batch - response = await litellm.aretrieve_batch( - custom_llm_provider="openai", **_retrieve_batch_request - ) - - ### ALERTING ### - data["litellm_status"] = "success" # used for alerting - - ### RESPONSE HEADERS ### - hidden_params = getattr(response, "_hidden_params", {}) or {} - model_id = hidden_params.get("model_id", None) or "" - cache_key = hidden_params.get("cache_key", None) or "" - api_base = hidden_params.get("api_base", None) or "" - - fastapi_response.headers.update( - get_custom_headers( - user_api_key_dict=user_api_key_dict, - model_id=model_id, - cache_key=cache_key, - api_base=api_base, - version=version, - model_region=getattr(user_api_key_dict, "allowed_model_region", ""), - ) - ) - - return response - except Exception as e: - data["litellm_status"] = "fail" # used for alerting - await proxy_logging_obj.post_call_failure_hook( 
- user_api_key_dict=user_api_key_dict, original_exception=e, request_data=data - ) - verbose_proxy_logger.error( - "litellm.proxy.proxy_server.retrieve_batch(): Exception occured - {}".format( - str(e) - ) - ) - verbose_proxy_logger.debug(traceback.format_exc()) - if isinstance(e, HTTPException): - raise ProxyException( - message=getattr(e, "message", str(e.detail)), - type=getattr(e, "type", "None"), - param=getattr(e, "param", "None"), - code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST), - ) - else: - error_traceback = traceback.format_exc() - error_msg = f"{str(e)}" - raise ProxyException( - message=getattr(e, "message", error_msg), - type=getattr(e, "type", "None"), - param=getattr(e, "param", "None"), - code=getattr(e, "status_code", 500), - ) - - -###################################################################### - -# END OF /v1/batches Endpoints Implementation - -###################################################################### - - ###################################################################### # /v1/files Endpoints @@ -5334,2404 +4855,6 @@ async def supported_openai_params(model: str): ) -#### KEY MANAGEMENT #### - - -@router.post( - "/key/generate", - tags=["key management"], - dependencies=[Depends(user_api_key_auth)], - response_model=GenerateKeyResponse, -) -async def generate_key_fn( - data: GenerateKeyRequest, - user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), - litellm_changed_by: Optional[str] = Header( - None, - description="The litellm-changed-by header enables tracking of actions performed by authorized users on behalf of other users, providing an audit trail for accountability", - ), -): - """ - Generate an API key based on the provided data. - - Docs: https://docs.litellm.ai/docs/proxy/virtual_keys - - Parameters: - - duration: Optional[str] - Specify the length of time the token is valid for. You can set duration as seconds ("30s"), minutes ("30m"), hours ("30h"), days ("30d"). 
- - key_alias: Optional[str] - User defined key alias - - team_id: Optional[str] - The team id of the key - - user_id: Optional[str] - The user id of the key - - models: Optional[list] - Model_name's a user is allowed to call. (if empty, key is allowed to call all models) - - aliases: Optional[dict] - Any alias mappings, on top of anything in the config.yaml model list. - https://docs.litellm.ai/docs/proxy/virtual_keys#managing-auth---upgradedowngrade-models - - config: Optional[dict] - any key-specific configs, overrides config in config.yaml - - spend: Optional[int] - Amount spent by key. Default is 0. Will be updated by proxy whenever key is used. https://docs.litellm.ai/docs/proxy/virtual_keys#managing-auth---tracking-spend - - send_invite_email: Optional[bool] - Whether to send an invite email to the user_id, with the generate key - - max_budget: Optional[float] - Specify max budget for a given key. - - max_parallel_requests: Optional[int] - Rate limit a user based on the number of parallel requests. Raises 429 error, if user's parallel requests > x. - - metadata: Optional[dict] - Metadata for key, store information for key. Example metadata = {"team": "core-infra", "app": "app2", "email": "ishaan@berri.ai" } - - permissions: Optional[dict] - key-specific permissions. Currently just used for turning off pii masking (if connected). Example - {"pii": false} - - model_max_budget: Optional[dict] - key-specific model budget in USD. Example - {"text-davinci-002": 0.5, "gpt-3.5-turbo": 0.5}. IF null or {} then no model specific budget. - - Examples: - - 1. Allow users to turn on/off pii masking - - ```bash - curl --location 'http://0.0.0.0:8000/key/generate' \ - --header 'Authorization: Bearer sk-1234' \ - --header 'Content-Type: application/json' \ - --data '{ - "permissions": {"allow_pii_controls": true} - }' - ``` - - Returns: - - key: (str) The generated api key - - expires: (datetime) Datetime object for when key expires. 
- - user_id: (str) Unique user id - used for tracking spend across multiple keys for same user id. - """ - try: - global user_custom_key_generate - verbose_proxy_logger.debug("entered /key/generate") - - if user_custom_key_generate is not None: - result = await user_custom_key_generate(data) - decision = result.get("decision", True) - message = result.get("message", "Authentication Failed - Custom Auth Rule") - if not decision: - raise HTTPException( - status_code=status.HTTP_403_FORBIDDEN, detail=message - ) - # check if user set default key/generate params on config.yaml - if litellm.default_key_generate_params is not None: - for elem in data: - key, value = elem - if value is None and key in [ - "max_budget", - "user_id", - "team_id", - "max_parallel_requests", - "tpm_limit", - "rpm_limit", - "budget_duration", - ]: - setattr( - data, key, litellm.default_key_generate_params.get(key, None) - ) - elif key == "models" and value == []: - setattr(data, key, litellm.default_key_generate_params.get(key, [])) - elif key == "metadata" and value == {}: - setattr(data, key, litellm.default_key_generate_params.get(key, {})) - - # check if user set default key/generate params on config.yaml - if litellm.upperbound_key_generate_params is not None: - for elem in data: - # if key in litellm.upperbound_key_generate_params, use the min of value and litellm.upperbound_key_generate_params[key] - key, value = elem - if ( - value is not None - and getattr(litellm.upperbound_key_generate_params, key, None) - is not None - ): - # if value is float/int - if key in [ - "max_budget", - "max_parallel_requests", - "tpm_limit", - "rpm_limit", - ]: - if value > getattr(litellm.upperbound_key_generate_params, key): - raise HTTPException( - status_code=400, - detail={ - "error": f"{key} is over max limit set in config - user_value={value}; max_value={getattr(litellm.upperbound_key_generate_params, key)}" - }, - ) - elif key == "budget_duration": - # budgets are in 1s, 1m, 1h, 1d, 1m (30s, 30m, 
30h, 30d, 30m) - # compare the duration in seconds and max duration in seconds - upperbound_budget_duration = _duration_in_seconds( - duration=getattr( - litellm.upperbound_key_generate_params, key - ) - ) - user_set_budget_duration = _duration_in_seconds(duration=value) - if user_set_budget_duration > upperbound_budget_duration: - raise HTTPException( - status_code=400, - detail={ - "error": f"Budget duration is over max limit set in config - user_value={user_set_budget_duration}; max_value={upperbound_budget_duration}" - }, - ) - - # TODO: @ishaan-jaff: Migrate all budget tracking to use LiteLLM_BudgetTable - _budget_id = None - if prisma_client is not None and data.soft_budget is not None: - # create the Budget Row for the LiteLLM Verification Token - budget_row = LiteLLM_BudgetTable( - soft_budget=data.soft_budget, - model_max_budget=data.model_max_budget or {}, - ) - new_budget = prisma_client.jsonify_object( - budget_row.json(exclude_none=True) - ) - - _budget = await prisma_client.db.litellm_budgettable.create( - data={ - **new_budget, # type: ignore - "created_by": user_api_key_dict.user_id or litellm_proxy_admin_name, - "updated_by": user_api_key_dict.user_id or litellm_proxy_admin_name, - } - ) - _budget_id = getattr(_budget, "budget_id", None) - data_json = data.json() # type: ignore - # if we get max_budget passed to /key/generate, then use it as key_max_budget. 
Since generate_key_helper_fn is used to make new users - if "max_budget" in data_json: - data_json["key_max_budget"] = data_json.pop("max_budget", None) - if _budget_id is not None: - data_json["budget_id"] = _budget_id - - if "budget_duration" in data_json: - data_json["key_budget_duration"] = data_json.pop("budget_duration", None) - - response = await generate_key_helper_fn( - request_type="key", **data_json, table_name="key" - ) - - response["soft_budget"] = ( - data.soft_budget - ) # include the user-input soft budget in the response - - if data.send_invite_email is True: - if "email" not in general_settings.get("alerting", []): - raise ValueError( - "Email alerting not setup on config.yaml. Please set `alerting=['email']. \nDocs: https://docs.litellm.ai/docs/proxy/email`" - ) - event = WebhookEvent( - event="key_created", - event_group="key", - event_message=f"API Key Created", - token=response.get("token", ""), - spend=response.get("spend", 0.0), - max_budget=response.get("max_budget", 0.0), - user_id=response.get("user_id", None), - team_id=response.get("team_id", "Default Team"), - key_alias=response.get("key_alias", None), - ) - - # If user configured email alerting - send an Email letting their end-user know the key was created - asyncio.create_task( - proxy_logging_obj.slack_alerting_instance.send_key_created_or_user_invited_email( - webhook_event=event, - ) - ) - - # Enterprise Feature - Audit Logging. 
Enable with litellm.store_audit_logs = True - if litellm.store_audit_logs is True: - _updated_values = json.dumps(response, default=str) - asyncio.create_task( - create_audit_log_for_update( - request_data=LiteLLM_AuditLogs( - id=str(uuid.uuid4()), - updated_at=datetime.now(timezone.utc), - changed_by=litellm_changed_by - or user_api_key_dict.user_id - or litellm_proxy_admin_name, - changed_by_api_key=user_api_key_dict.api_key, - table_name=LitellmTableNames.KEY_TABLE_NAME, - object_id=response.get("token_id", ""), - action="created", - updated_values=_updated_values, - before_value=None, - ) - ) - ) - - return GenerateKeyResponse(**response) - except Exception as e: - verbose_proxy_logger.error( - "litellm.proxy.proxy_server.generate_key_fn(): Exception occured - {}".format( - str(e) - ) - ) - verbose_proxy_logger.debug(traceback.format_exc()) - if isinstance(e, HTTPException): - raise ProxyException( - message=getattr(e, "detail", f"Authentication Error({str(e)})"), - type="auth_error", - param=getattr(e, "param", "None"), - code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST), - ) - elif isinstance(e, ProxyException): - raise e - raise ProxyException( - message="Authentication Error, " + str(e), - type="auth_error", - param=getattr(e, "param", "None"), - code=status.HTTP_400_BAD_REQUEST, - ) - - -@router.post( - "/key/update", tags=["key management"], dependencies=[Depends(user_api_key_auth)] -) -async def update_key_fn( - request: Request, - data: UpdateKeyRequest, - user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), - litellm_changed_by: Optional[str] = Header( - None, - description="The litellm-changed-by header enables tracking of actions performed by authorized users on behalf of other users, providing an audit trail for accountability", - ), -): - """ - Update an existing key - """ - global prisma_client - try: - data_json: dict = data.json() - key = data_json.pop("key") - # get the row from db - if prisma_client is None: - raise 
Exception("Not connected to DB!") - - existing_key_row = await prisma_client.get_data( - token=data.key, table_name="key", query_type="find_unique" - ) - - if existing_key_row is None: - raise HTTPException( - status_code=404, - detail={"error": f"Team not found, passed team_id={data.team_id}"}, - ) - - # get non default values for key - non_default_values = {} - for k, v in data_json.items(): - if v is not None and v not in ( - [], - {}, - 0, - ): # models default to [], spend defaults to 0, we should not reset these values - non_default_values[k] = v - - if "duration" in non_default_values: - duration = non_default_values.pop("duration") - duration_s = _duration_in_seconds(duration=duration) - expires = datetime.now(timezone.utc) + timedelta(seconds=duration_s) - non_default_values["expires"] = expires - - response = await prisma_client.update_data( - token=key, data={**non_default_values, "token": key} - ) - - # Delete - key from cache, since it's been updated! - # key updated - a new model could have been added to this key. it should not block requests after this is done - user_api_key_cache.delete_cache(key) - hashed_token = hash_token(key) - user_api_key_cache.delete_cache(hashed_token) - - # Enterprise Feature - Audit Logging. 
Enable with litellm.store_audit_logs = True - if litellm.store_audit_logs is True: - _updated_values = json.dumps(data_json, default=str) - - _before_value = existing_key_row.json(exclude_none=True) - _before_value = json.dumps(_before_value, default=str) - - asyncio.create_task( - create_audit_log_for_update( - request_data=LiteLLM_AuditLogs( - id=str(uuid.uuid4()), - updated_at=datetime.now(timezone.utc), - changed_by=litellm_changed_by - or user_api_key_dict.user_id - or litellm_proxy_admin_name, - changed_by_api_key=user_api_key_dict.api_key, - table_name=LitellmTableNames.KEY_TABLE_NAME, - object_id=data.key, - action="updated", - updated_values=_updated_values, - before_value=_before_value, - ) - ) - ) - - return {"key": key, **response["data"]} - # update based on remaining passed in values - except Exception as e: - if isinstance(e, HTTPException): - raise ProxyException( - message=getattr(e, "detail", f"Authentication Error({str(e)})"), - type="auth_error", - param=getattr(e, "param", "None"), - code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST), - ) - elif isinstance(e, ProxyException): - raise e - raise ProxyException( - message="Authentication Error, " + str(e), - type="auth_error", - param=getattr(e, "param", "None"), - code=status.HTTP_400_BAD_REQUEST, - ) - - -@router.post( - "/key/delete", tags=["key management"], dependencies=[Depends(user_api_key_auth)] -) -async def delete_key_fn( - data: KeyRequest, - user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), - litellm_changed_by: Optional[str] = Header( - None, - description="The litellm-changed-by header enables tracking of actions performed by authorized users on behalf of other users, providing an audit trail for accountability", - ), -): - """ - Delete a key from the key management system. - - Parameters:: - - keys (List[str]): A list of keys or hashed keys to delete. 
Example {"keys": ["sk-QWrxEynunsNpV1zT48HIrw", "837e17519f44683334df5291321d97b8bf1098cd490e49e215f6fea935aa28be"]} - - Returns: - - deleted_keys (List[str]): A list of deleted keys. Example {"deleted_keys": ["sk-QWrxEynunsNpV1zT48HIrw", "837e17519f44683334df5291321d97b8bf1098cd490e49e215f6fea935aa28be"]} - - - Raises: - HTTPException: If an error occurs during key deletion. - """ - try: - global user_api_key_cache - keys = data.keys - if len(keys) == 0: - raise ProxyException( - message=f"No keys provided, passed in: keys={keys}", - type="auth_error", - param="keys", - code=status.HTTP_400_BAD_REQUEST, - ) - - ## only allow user to delete keys they own - user_id = user_api_key_dict.user_id - verbose_proxy_logger.debug( - f"user_api_key_dict.user_role: {user_api_key_dict.user_role}" - ) - if ( - user_api_key_dict.user_role is not None - and user_api_key_dict.user_role == LitellmUserRoles.PROXY_ADMIN - ): - user_id = None # unless they're admin - - # Enterprise Feature - Audit Logging. Enable with litellm.store_audit_logs = True - # we do this after the first for loop, since first for loop is for validation. 
we only want this inserted after validation passes - if litellm.store_audit_logs is True: - # make an audit log for each team deleted - for key in data.keys: - key_row = await prisma_client.get_data( # type: ignore - token=key, table_name="key", query_type="find_unique" - ) - - key_row = key_row.json(exclude_none=True) - _key_row = json.dumps(key_row, default=str) - - asyncio.create_task( - create_audit_log_for_update( - request_data=LiteLLM_AuditLogs( - id=str(uuid.uuid4()), - updated_at=datetime.now(timezone.utc), - changed_by=litellm_changed_by - or user_api_key_dict.user_id - or litellm_proxy_admin_name, - changed_by_api_key=user_api_key_dict.api_key, - table_name=LitellmTableNames.KEY_TABLE_NAME, - object_id=key, - action="deleted", - updated_values="{}", - before_value=_key_row, - ) - ) - ) - - number_deleted_keys = await delete_verification_token( - tokens=keys, user_id=user_id - ) - verbose_proxy_logger.debug( - f"/key/delete - deleted_keys={number_deleted_keys['deleted_keys']}" - ) - - try: - assert len(keys) == number_deleted_keys["deleted_keys"] - except Exception as e: - raise HTTPException( - status_code=400, - detail={ - "error": f"Not all keys passed in were deleted. This probably means you don't have access to delete all the keys passed in. 
Keys passed in={len(keys)}, Deleted keys ={number_deleted_keys['deleted_keys']}" - }, - ) - - for key in keys: - user_api_key_cache.delete_cache(key) - # remove hash token from cache - hashed_token = hash_token(key) - user_api_key_cache.delete_cache(hashed_token) - - verbose_proxy_logger.debug( - f"/keys/delete - cache after delete: {user_api_key_cache.in_memory_cache.cache_dict}" - ) - - return {"deleted_keys": keys} - except Exception as e: - if isinstance(e, HTTPException): - raise ProxyException( - message=getattr(e, "detail", f"Authentication Error({str(e)})"), - type="auth_error", - param=getattr(e, "param", "None"), - code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST), - ) - elif isinstance(e, ProxyException): - raise e - raise ProxyException( - message="Authentication Error, " + str(e), - type="auth_error", - param=getattr(e, "param", "None"), - code=status.HTTP_400_BAD_REQUEST, - ) - - -@router.post( - "/v2/key/info", tags=["key management"], dependencies=[Depends(user_api_key_auth)] -) -async def info_key_fn_v2( - data: Optional[KeyRequest] = None, - user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), -): - """ - Retrieve information about a list of keys. - - **New endpoint**. Currently admin only. - Parameters: - keys: Optional[list] = body parameter representing the key(s) in the request - user_api_key_dict: UserAPIKeyAuth = Dependency representing the user's API key - Returns: - Dict containing the key and its associated information - - Example Curl: - ``` - curl -X GET "http://0.0.0.0:8000/key/info" \ - -H "Authorization: Bearer sk-1234" \ - -d {"keys": ["sk-1", "sk-2", "sk-3"]} - ``` - """ - global prisma_client - try: - if prisma_client is None: - raise Exception( - f"Database not connected. 
Connect a database to your proxy - https://docs.litellm.ai/docs/simple_proxy#managing-auth---virtual-keys" - ) - if data is None: - raise HTTPException( - status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, - detail={"message": "Malformed request. No keys passed in."}, - ) - - key_info = await prisma_client.get_data( - token=data.keys, table_name="key", query_type="find_all" - ) - filtered_key_info = [] - for k in key_info: - try: - k = k.model_dump() # noqa - except: - # if using pydantic v1 - k = k.dict() - filtered_key_info.append(k) - return {"key": data.keys, "info": filtered_key_info} - - except Exception as e: - if isinstance(e, HTTPException): - raise ProxyException( - message=getattr(e, "detail", f"Authentication Error({str(e)})"), - type="auth_error", - param=getattr(e, "param", "None"), - code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST), - ) - elif isinstance(e, ProxyException): - raise e - raise ProxyException( - message="Authentication Error, " + str(e), - type="auth_error", - param=getattr(e, "param", "None"), - code=status.HTTP_400_BAD_REQUEST, - ) - - -@router.get( - "/key/info", tags=["key management"], dependencies=[Depends(user_api_key_auth)] -) -async def info_key_fn( - key: Optional[str] = fastapi.Query( - default=None, description="Key in the request parameters" - ), - user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), -): - """ - Retrieve information about a key. 
- Parameters: - key: Optional[str] = Query parameter representing the key in the request - user_api_key_dict: UserAPIKeyAuth = Dependency representing the user's API key - Returns: - Dict containing the key and its associated information - - Example Curl: - ``` - curl -X GET "http://0.0.0.0:8000/key/info?key=sk-02Wr4IAlN3NvPXvL5JVvDA" \ --H "Authorization: Bearer sk-1234" - ``` - - Example Curl - if no key is passed, it will use the Key Passed in Authorization Header - ``` - curl -X GET "http://0.0.0.0:8000/key/info" \ --H "Authorization: Bearer sk-02Wr4IAlN3NvPXvL5JVvDA" - ``` - """ - global prisma_client - try: - if prisma_client is None: - raise Exception( - f"Database not connected. Connect a database to your proxy - https://docs.litellm.ai/docs/simple_proxy#managing-auth---virtual-keys" - ) - if key == None: - key = user_api_key_dict.api_key - key_info = await prisma_client.get_data(token=key) - ## REMOVE HASHED TOKEN INFO BEFORE RETURNING ## - try: - key_info = key_info.model_dump() # noqa - except: - # if using pydantic v1 - key_info = key_info.dict() - key_info.pop("token") - return {"key": key, "info": key_info} - except Exception as e: - if isinstance(e, HTTPException): - raise ProxyException( - message=getattr(e, "detail", f"Authentication Error({str(e)})"), - type="auth_error", - param=getattr(e, "param", "None"), - code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST), - ) - elif isinstance(e, ProxyException): - raise e - raise ProxyException( - message="Authentication Error, " + str(e), - type="auth_error", - param=getattr(e, "param", "None"), - code=status.HTTP_400_BAD_REQUEST, - ) - - -#### SPEND MANAGEMENT ##### - - -@router.get( - "/spend/keys", - tags=["Budget & Spend Tracking"], - dependencies=[Depends(user_api_key_auth)], - include_in_schema=False, -) -async def spend_key_fn(): - """ - View all keys created, ordered by spend - - Example Request: - ``` - curl -X GET "http://0.0.0.0:8000/spend/keys" \ --H "Authorization: Bearer sk-1234" - 
``` - """ - global prisma_client - try: - if prisma_client is None: - raise Exception( - f"Database not connected. Connect a database to your proxy - https://docs.litellm.ai/docs/simple_proxy#managing-auth---virtual-keys" - ) - - key_info = await prisma_client.get_data(table_name="key", query_type="find_all") - return key_info - - except Exception as e: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail={"error": str(e)}, - ) - - -@router.get( - "/spend/users", - tags=["Budget & Spend Tracking"], - dependencies=[Depends(user_api_key_auth)], - include_in_schema=False, -) -async def spend_user_fn( - user_id: Optional[str] = fastapi.Query( - default=None, - description="Get User Table row for user_id", - ), -): - """ - View all users created, ordered by spend - - Example Request: - ``` - curl -X GET "http://0.0.0.0:8000/spend/users" \ --H "Authorization: Bearer sk-1234" - ``` - - View User Table row for user_id - ``` - curl -X GET "http://0.0.0.0:8000/spend/users?user_id=1234" \ --H "Authorization: Bearer sk-1234" - ``` - """ - global prisma_client - try: - if prisma_client is None: - raise Exception( - f"Database not connected. 
Connect a database to your proxy - https://docs.litellm.ai/docs/simple_proxy#managing-auth---virtual-keys" - ) - - if user_id is not None: - user_info = await prisma_client.get_data( - table_name="user", query_type="find_unique", user_id=user_id - ) - return [user_info] - else: - user_info = await prisma_client.get_data( - table_name="user", query_type="find_all" - ) - - return user_info - - except Exception as e: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail={"error": str(e)}, - ) - - -@router.get( - "/spend/tags", - tags=["Budget & Spend Tracking"], - dependencies=[Depends(user_api_key_auth)], - responses={ - 200: {"model": List[LiteLLM_SpendLogs]}, - }, -) -async def view_spend_tags( - start_date: Optional[str] = fastapi.Query( - default=None, - description="Time from which to start viewing key spend", - ), - end_date: Optional[str] = fastapi.Query( - default=None, - description="Time till which to view key spend", - ), -): - """ - LiteLLM Enterprise - View Spend Per Request Tag - - Example Request: - ``` - curl -X GET "http://0.0.0.0:8000/spend/tags" \ --H "Authorization: Bearer sk-1234" - ``` - - Spend with Start Date and End Date - ``` - curl -X GET "http://0.0.0.0:8000/spend/tags?start_date=2022-01-01&end_date=2022-02-01" \ --H "Authorization: Bearer sk-1234" - ``` - """ - - from enterprise.utils import get_spend_by_tags - - global prisma_client - try: - if prisma_client is None: - raise Exception( - f"Database not connected. 
Connect a database to your proxy - https://docs.litellm.ai/docs/simple_proxy#managing-auth---virtual-keys" - ) - - # run the following SQL query on prisma - """ - SELECT - jsonb_array_elements_text(request_tags) AS individual_request_tag, - COUNT(*) AS log_count, - SUM(spend) AS total_spend - FROM "LiteLLM_SpendLogs" - GROUP BY individual_request_tag; - """ - response = await get_spend_by_tags( - start_date=start_date, end_date=end_date, prisma_client=prisma_client - ) - - return response - except Exception as e: - if isinstance(e, HTTPException): - raise ProxyException( - message=getattr(e, "detail", f"/spend/tags Error({str(e)})"), - type="internal_error", - param=getattr(e, "param", "None"), - code=getattr(e, "status_code", status.HTTP_500_INTERNAL_SERVER_ERROR), - ) - elif isinstance(e, ProxyException): - raise e - raise ProxyException( - message="/spend/tags Error" + str(e), - type="internal_error", - param=getattr(e, "param", "None"), - code=status.HTTP_500_INTERNAL_SERVER_ERROR, - ) - - -@router.get( - "/global/activity", - tags=["Budget & Spend Tracking"], - dependencies=[Depends(user_api_key_auth)], - responses={ - 200: {"model": List[LiteLLM_SpendLogs]}, - }, - include_in_schema=False, -) -async def get_global_activity( - start_date: Optional[str] = fastapi.Query( - default=None, - description="Time from which to start viewing spend", - ), - end_date: Optional[str] = fastapi.Query( - default=None, - description="Time till which to view spend", - ), -): - """ - Get number of API Requests, total tokens through proxy - - { - "daily_data": [ - const chartdata = [ - { - date: 'Jan 22', - api_requests: 10, - total_tokens: 2000 - }, - { - date: 'Jan 23', - api_requests: 10, - total_tokens: 12 - }, - ], - "sum_api_requests": 20, - "sum_total_tokens": 2012 - } - """ - from collections import defaultdict - - if start_date is None or end_date is None: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail={"error": "Please provide start_date and 
end_date"}, - ) - - start_date_obj = datetime.strptime(start_date, "%Y-%m-%d") - end_date_obj = datetime.strptime(end_date, "%Y-%m-%d") - - global prisma_client, llm_router - try: - if prisma_client is None: - raise Exception( - f"Database not connected. Connect a database to your proxy - https://docs.litellm.ai/docs/simple_proxy#managing-auth---virtual-keys" - ) - - sql_query = """ - SELECT - date_trunc('day', "startTime") AS date, - COUNT(*) AS api_requests, - SUM(total_tokens) AS total_tokens - FROM "LiteLLM_SpendLogs" - WHERE "startTime" BETWEEN $1::date AND $2::date + interval '1 day' - GROUP BY date_trunc('day', "startTime") - """ - db_response = await prisma_client.db.query_raw( - sql_query, start_date_obj, end_date_obj - ) - - if db_response is None: - return [] - - sum_api_requests = 0 - sum_total_tokens = 0 - daily_data = [] - for row in db_response: - # cast date to datetime - _date_obj = datetime.fromisoformat(row["date"]) - row["date"] = _date_obj.strftime("%b %d") - - daily_data.append(row) - sum_api_requests += row.get("api_requests", 0) - sum_total_tokens += row.get("total_tokens", 0) - - # sort daily_data by date - daily_data = sorted(daily_data, key=lambda x: x["date"]) - - data_to_return = { - "daily_data": daily_data, - "sum_api_requests": sum_api_requests, - "sum_total_tokens": sum_total_tokens, - } - - return data_to_return - - except Exception as e: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail={"error": str(e)}, - ) - - -@router.get( - "/global/activity/model", - tags=["Budget & Spend Tracking"], - dependencies=[Depends(user_api_key_auth)], - responses={ - 200: {"model": List[LiteLLM_SpendLogs]}, - }, - include_in_schema=False, -) -async def get_global_activity_model( - start_date: Optional[str] = fastapi.Query( - default=None, - description="Time from which to start viewing spend", - ), - end_date: Optional[str] = fastapi.Query( - default=None, - description="Time till which to view spend", - ), -): - """ - Get 
number of API Requests, total tokens through proxy - Grouped by MODEL - - [ - { - "model": "gpt-4", - "daily_data": [ - const chartdata = [ - { - date: 'Jan 22', - api_requests: 10, - total_tokens: 2000 - }, - { - date: 'Jan 23', - api_requests: 10, - total_tokens: 12 - }, - ], - "sum_api_requests": 20, - "sum_total_tokens": 2012 - - }, - { - "model": "azure/gpt-4-turbo", - "daily_data": [ - const chartdata = [ - { - date: 'Jan 22', - api_requests: 10, - total_tokens: 2000 - }, - { - date: 'Jan 23', - api_requests: 10, - total_tokens: 12 - }, - ], - "sum_api_requests": 20, - "sum_total_tokens": 2012 - - }, - ] - """ - from collections import defaultdict - - if start_date is None or end_date is None: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail={"error": "Please provide start_date and end_date"}, - ) - - start_date_obj = datetime.strptime(start_date, "%Y-%m-%d") - end_date_obj = datetime.strptime(end_date, "%Y-%m-%d") - - global prisma_client, llm_router, premium_user - try: - if prisma_client is None: - raise Exception( - f"Database not connected. 
Connect a database to your proxy - https://docs.litellm.ai/docs/simple_proxy#managing-auth---virtual-keys" - ) - - sql_query = """ - SELECT - model_group, - date_trunc('day', "startTime") AS date, - COUNT(*) AS api_requests, - SUM(total_tokens) AS total_tokens - FROM "LiteLLM_SpendLogs" - WHERE "startTime" BETWEEN $1::date AND $2::date + interval '1 day' - GROUP BY model_group, date_trunc('day', "startTime") - """ - db_response = await prisma_client.db.query_raw( - sql_query, start_date_obj, end_date_obj - ) - if db_response is None: - return [] - - model_ui_data: dict = ( - {} - ) # {"gpt-4": {"daily_data": [], "sum_api_requests": 0, "sum_total_tokens": 0}} - - for row in db_response: - _model = row["model_group"] - if _model not in model_ui_data: - model_ui_data[_model] = { - "daily_data": [], - "sum_api_requests": 0, - "sum_total_tokens": 0, - } - _date_obj = datetime.fromisoformat(row["date"]) - row["date"] = _date_obj.strftime("%b %d") - - model_ui_data[_model]["daily_data"].append(row) - model_ui_data[_model]["sum_api_requests"] += row.get("api_requests", 0) - model_ui_data[_model]["sum_total_tokens"] += row.get("total_tokens", 0) - - # sort mode ui data by sum_api_requests -> get top 10 models - model_ui_data = dict( - sorted( - model_ui_data.items(), - key=lambda x: x[1]["sum_api_requests"], - reverse=True, - )[:10] - ) - - response = [] - for model, data in model_ui_data.items(): - _sort_daily_data = sorted(data["daily_data"], key=lambda x: x["date"]) - - response.append( - { - "model": model, - "daily_data": _sort_daily_data, - "sum_api_requests": data["sum_api_requests"], - "sum_total_tokens": data["sum_total_tokens"], - } - ) - - return response - - except Exception as e: - raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail={"error": str(e)}, - ) - - -@router.get( - "/global/activity/exceptions/deployment", - tags=["Budget & Spend Tracking"], - dependencies=[Depends(user_api_key_auth)], - responses={ - 200: {"model": 
List[LiteLLM_SpendLogs]}, - }, - include_in_schema=False, -) -async def get_global_activity_exceptions_per_deployment( - model_group: str = fastapi.Query( - description="Filter by model group", - ), - start_date: Optional[str] = fastapi.Query( - default=None, - description="Time from which to start viewing spend", - ), - end_date: Optional[str] = fastapi.Query( - default=None, - description="Time till which to view spend", - ), -): - """ - Get number of 429 errors - Grouped by deployment - - [ - { - "deployment": "https://azure-us-east-1.openai.azure.com/", - "daily_data": [ - const chartdata = [ - { - date: 'Jan 22', - num_rate_limit_exceptions: 10 - }, - { - date: 'Jan 23', - num_rate_limit_exceptions: 12 - }, - ], - "sum_num_rate_limit_exceptions": 20, - - }, - { - "deployment": "https://azure-us-east-1.openai.azure.com/", - "daily_data": [ - const chartdata = [ - { - date: 'Jan 22', - num_rate_limit_exceptions: 10, - }, - { - date: 'Jan 23', - num_rate_limit_exceptions: 12 - }, - ], - "sum_num_rate_limit_exceptions": 20, - - }, - ] - """ - from collections import defaultdict - - if start_date is None or end_date is None: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail={"error": "Please provide start_date and end_date"}, - ) - - start_date_obj = datetime.strptime(start_date, "%Y-%m-%d") - end_date_obj = datetime.strptime(end_date, "%Y-%m-%d") - - global prisma_client, llm_router, premium_user - try: - if prisma_client is None: - raise Exception( - f"Database not connected. 
Connect a database to your proxy - https://docs.litellm.ai/docs/simple_proxy#managing-auth---virtual-keys" - ) - - sql_query = """ - SELECT - api_base, - date_trunc('day', "startTime")::date AS date, - COUNT(*) AS num_rate_limit_exceptions - FROM - "LiteLLM_ErrorLogs" - WHERE - "startTime" >= $1::date - AND "startTime" < ($2::date + INTERVAL '1 day') - AND model_group = $3 - AND status_code = '429' - GROUP BY - api_base, - date_trunc('day', "startTime") - ORDER BY - date; - """ - db_response = await prisma_client.db.query_raw( - sql_query, start_date_obj, end_date_obj, model_group - ) - if db_response is None: - return [] - - model_ui_data: dict = ( - {} - ) # {"gpt-4": {"daily_data": [], "sum_api_requests": 0, "sum_total_tokens": 0}} - - for row in db_response: - _model = row["api_base"] - if _model not in model_ui_data: - model_ui_data[_model] = { - "daily_data": [], - "sum_num_rate_limit_exceptions": 0, - } - _date_obj = datetime.fromisoformat(row["date"]) - row["date"] = _date_obj.strftime("%b %d") - - model_ui_data[_model]["daily_data"].append(row) - model_ui_data[_model]["sum_num_rate_limit_exceptions"] += row.get( - "num_rate_limit_exceptions", 0 - ) - - # sort mode ui data by sum_api_requests -> get top 10 models - model_ui_data = dict( - sorted( - model_ui_data.items(), - key=lambda x: x[1]["sum_num_rate_limit_exceptions"], - reverse=True, - )[:10] - ) - - response = [] - for model, data in model_ui_data.items(): - _sort_daily_data = sorted(data["daily_data"], key=lambda x: x["date"]) - - response.append( - { - "api_base": model, - "daily_data": _sort_daily_data, - "sum_num_rate_limit_exceptions": data[ - "sum_num_rate_limit_exceptions" - ], - } - ) - - return response - - except Exception as e: - raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail={"error": str(e)}, - ) - - -@router.get( - "/global/activity/exceptions", - tags=["Budget & Spend Tracking"], - dependencies=[Depends(user_api_key_auth)], - responses={ - 200: 
{"model": List[LiteLLM_SpendLogs]}, - }, - include_in_schema=False, -) -async def get_global_activity_exceptions( - model_group: str = fastapi.Query( - description="Filter by model group", - ), - start_date: Optional[str] = fastapi.Query( - default=None, - description="Time from which to start viewing spend", - ), - end_date: Optional[str] = fastapi.Query( - default=None, - description="Time till which to view spend", - ), -): - """ - Get number of API Requests, total tokens through proxy - - { - "daily_data": [ - const chartdata = [ - { - date: 'Jan 22', - num_rate_limit_exceptions: 10, - }, - { - date: 'Jan 23', - num_rate_limit_exceptions: 10, - }, - ], - "sum_api_exceptions": 20, - } - """ - from collections import defaultdict - - if start_date is None or end_date is None: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail={"error": "Please provide start_date and end_date"}, - ) - - start_date_obj = datetime.strptime(start_date, "%Y-%m-%d") - end_date_obj = datetime.strptime(end_date, "%Y-%m-%d") - - global prisma_client, llm_router - try: - if prisma_client is None: - raise Exception( - f"Database not connected. 
Connect a database to your proxy - https://docs.litellm.ai/docs/simple_proxy#managing-auth---virtual-keys" - ) - - sql_query = """ - SELECT - date_trunc('day', "startTime")::date AS date, - COUNT(*) AS num_rate_limit_exceptions - FROM - "LiteLLM_ErrorLogs" - WHERE - "startTime" >= $1::date - AND "startTime" < ($2::date + INTERVAL '1 day') - AND model_group = $3 - AND status_code = '429' - GROUP BY - date_trunc('day', "startTime") - ORDER BY - date; - """ - db_response = await prisma_client.db.query_raw( - sql_query, start_date_obj, end_date_obj, model_group - ) - - if db_response is None: - return [] - - sum_num_rate_limit_exceptions = 0 - daily_data = [] - for row in db_response: - # cast date to datetime - _date_obj = datetime.fromisoformat(row["date"]) - row["date"] = _date_obj.strftime("%b %d") - - daily_data.append(row) - sum_num_rate_limit_exceptions += row.get("num_rate_limit_exceptions", 0) - - # sort daily_data by date - daily_data = sorted(daily_data, key=lambda x: x["date"]) - - data_to_return = { - "daily_data": daily_data, - "sum_num_rate_limit_exceptions": sum_num_rate_limit_exceptions, - } - - return data_to_return - - except Exception as e: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail={"error": str(e)}, - ) - - -@router.get( - "/global/spend/provider", - tags=["Budget & Spend Tracking"], - dependencies=[Depends(user_api_key_auth)], - include_in_schema=False, - responses={ - 200: {"model": List[LiteLLM_SpendLogs]}, - }, -) -async def get_global_spend_provider( - start_date: Optional[str] = fastapi.Query( - default=None, - description="Time from which to start viewing spend", - ), - end_date: Optional[str] = fastapi.Query( - default=None, - description="Time till which to view spend", - ), -): - """ - Get breakdown of spend per provider - [ - { - "provider": "Azure OpenAI", - "spend": 20 - }, - { - "provider": "OpenAI", - "spend": 10 - }, - { - "provider": "VertexAI", - "spend": 30 - } - ] - """ - from collections import 
defaultdict - - if start_date is None or end_date is None: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail={"error": "Please provide start_date and end_date"}, - ) - - start_date_obj = datetime.strptime(start_date, "%Y-%m-%d") - end_date_obj = datetime.strptime(end_date, "%Y-%m-%d") - - global prisma_client, llm_router - try: - if prisma_client is None: - raise Exception( - f"Database not connected. Connect a database to your proxy - https://docs.litellm.ai/docs/simple_proxy#managing-auth---virtual-keys" - ) - - sql_query = """ - - SELECT - model_id, - SUM(spend) AS spend - FROM "LiteLLM_SpendLogs" - WHERE "startTime" BETWEEN $1::date AND $2::date AND length(model_id) > 0 - GROUP BY model_id - """ - - db_response = await prisma_client.db.query_raw( - sql_query, start_date_obj, end_date_obj - ) - if db_response is None: - return [] - - ################################### - # Convert model_id -> to Provider # - ################################### - - # we use the in memory router for this - ui_response = [] - provider_spend_mapping: defaultdict = defaultdict(int) - for row in db_response: - _model_id = row["model_id"] - _provider = "Unknown" - if llm_router is not None: - _deployment = llm_router.get_deployment(model_id=_model_id) - if _deployment is not None: - try: - _, _provider, _, _ = litellm.get_llm_provider( - model=_deployment.litellm_params.model, - custom_llm_provider=_deployment.litellm_params.custom_llm_provider, - api_base=_deployment.litellm_params.api_base, - litellm_params=_deployment.litellm_params, - ) - provider_spend_mapping[_provider] += row["spend"] - except: - pass - - for provider, spend in provider_spend_mapping.items(): - ui_response.append({"provider": provider, "spend": spend}) - - return ui_response - - except Exception as e: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail={"error": str(e)}, - ) - - -@router.get( - "/global/spend/report", - tags=["Budget & Spend Tracking"], - 
dependencies=[Depends(user_api_key_auth)], - responses={ - 200: {"model": List[LiteLLM_SpendLogs]}, - }, -) -async def get_global_spend_report( - start_date: Optional[str] = fastapi.Query( - default=None, - description="Time from which to start viewing spend", - ), - end_date: Optional[str] = fastapi.Query( - default=None, - description="Time till which to view spend", - ), - group_by: Optional[Literal["team", "customer"]] = fastapi.Query( - default="team", - description="Group spend by internal team or customer", - ), -): - """ - Get Daily Spend per Team, based on specific startTime and endTime. Per team, view usage by each key, model - [ - { - "group-by-day": "2024-05-10", - "teams": [ - { - "team_name": "team-1" - "spend": 10, - "keys": [ - "key": "1213", - "usage": { - "model-1": { - "cost": 12.50, - "input_tokens": 1000, - "output_tokens": 5000, - "requests": 100 - }, - "audio-modelname1": { - "cost": 25.50, - "seconds": 25, - "requests": 50 - }, - } - } - ] - ] - } - """ - if start_date is None or end_date is None: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail={"error": "Please provide start_date and end_date"}, - ) - - start_date_obj = datetime.strptime(start_date, "%Y-%m-%d") - end_date_obj = datetime.strptime(end_date, "%Y-%m-%d") - - global prisma_client - try: - if prisma_client is None: - raise Exception( - f"Database not connected. 
Connect a database to your proxy - https://docs.litellm.ai/docs/simple_proxy#managing-auth---virtual-keys" - ) - - if group_by == "team": - # first get data from spend logs -> SpendByModelApiKey - # then read data from "SpendByModelApiKey" to format the response obj - sql_query = """ - - WITH SpendByModelApiKey AS ( - SELECT - date_trunc('day', sl."startTime") AS group_by_day, - COALESCE(tt.team_alias, 'Unassigned Team') AS team_name, - sl.model, - sl.api_key, - SUM(sl.spend) AS model_api_spend, - SUM(sl.total_tokens) AS model_api_tokens - FROM - "LiteLLM_SpendLogs" sl - LEFT JOIN - "LiteLLM_TeamTable" tt - ON - sl.team_id = tt.team_id - WHERE - sl."startTime" BETWEEN $1::date AND $2::date - GROUP BY - date_trunc('day', sl."startTime"), - tt.team_alias, - sl.model, - sl.api_key - ) - SELECT - group_by_day, - jsonb_agg(jsonb_build_object( - 'team_name', team_name, - 'total_spend', total_spend, - 'metadata', metadata - )) AS teams - FROM ( - SELECT - group_by_day, - team_name, - SUM(model_api_spend) AS total_spend, - jsonb_agg(jsonb_build_object( - 'model', model, - 'api_key', api_key, - 'spend', model_api_spend, - 'total_tokens', model_api_tokens - )) AS metadata - FROM - SpendByModelApiKey - GROUP BY - group_by_day, - team_name - ) AS aggregated - GROUP BY - group_by_day - ORDER BY - group_by_day; - """ - - db_response = await prisma_client.db.query_raw( - sql_query, start_date_obj, end_date_obj - ) - if db_response is None: - return [] - - return db_response - - elif group_by == "customer": - sql_query = """ - - WITH SpendByModelApiKey AS ( - SELECT - date_trunc('day', sl."startTime") AS group_by_day, - sl.end_user AS customer, - sl.model, - sl.api_key, - SUM(sl.spend) AS model_api_spend, - SUM(sl.total_tokens) AS model_api_tokens - FROM - "LiteLLM_SpendLogs" sl - WHERE - sl."startTime" BETWEEN $1::date AND $2::date - GROUP BY - date_trunc('day', sl."startTime"), - customer, - sl.model, - sl.api_key - ) - SELECT - group_by_day, - jsonb_agg(jsonb_build_object( - 
'customer', customer, - 'total_spend', total_spend, - 'metadata', metadata - )) AS customers - FROM - ( - SELECT - group_by_day, - customer, - SUM(model_api_spend) AS total_spend, - jsonb_agg(jsonb_build_object( - 'model', model, - 'api_key', api_key, - 'spend', model_api_spend, - 'total_tokens', model_api_tokens - )) AS metadata - FROM - SpendByModelApiKey - GROUP BY - group_by_day, - customer - ) AS aggregated - GROUP BY - group_by_day - ORDER BY - group_by_day; - """ - - db_response = await prisma_client.db.query_raw( - sql_query, start_date_obj, end_date_obj - ) - if db_response is None: - return [] - - return db_response - - except Exception as e: - raise HTTPException( - status_code=status.HTTP_400_BAD_REQUEST, - detail={"error": str(e)}, - ) - - -@router.get( - "/global/spend/all_tag_names", - tags=["Budget & Spend Tracking"], - dependencies=[Depends(user_api_key_auth)], - include_in_schema=False, - responses={ - 200: {"model": List[LiteLLM_SpendLogs]}, - }, -) -async def global_get_all_tag_names(): - try: - if prisma_client is None: - raise Exception( - f"Database not connected. 
Connect a database to your proxy - https://docs.litellm.ai/docs/simple_proxy#managing-auth---virtual-keys" - ) - - sql_query = """ - SELECT DISTINCT - jsonb_array_elements_text(request_tags) AS individual_request_tag - FROM "LiteLLM_SpendLogs"; - """ - - db_response = await prisma_client.db.query_raw(sql_query) - if db_response is None: - return [] - - _tag_names = [] - for row in db_response: - _tag_names.append(row.get("individual_request_tag")) - - return {"tag_names": _tag_names} - - except Exception as e: - if isinstance(e, HTTPException): - raise ProxyException( - message=getattr(e, "detail", f"/spend/all_tag_names Error({str(e)})"), - type="internal_error", - param=getattr(e, "param", "None"), - code=getattr(e, "status_code", status.HTTP_500_INTERNAL_SERVER_ERROR), - ) - elif isinstance(e, ProxyException): - raise e - raise ProxyException( - message="/spend/all_tag_names Error" + str(e), - type="internal_error", - param=getattr(e, "param", "None"), - code=status.HTTP_500_INTERNAL_SERVER_ERROR, - ) - - -@router.get( - "/global/spend/tags", - tags=["Budget & Spend Tracking"], - dependencies=[Depends(user_api_key_auth)], - include_in_schema=False, - responses={ - 200: {"model": List[LiteLLM_SpendLogs]}, - }, -) -async def global_view_spend_tags( - start_date: Optional[str] = fastapi.Query( - default=None, - description="Time from which to start viewing key spend", - ), - end_date: Optional[str] = fastapi.Query( - default=None, - description="Time till which to view key spend", - ), - tags: Optional[str] = fastapi.Query( - default=None, - description="comman separated tags to filter on", - ), -): - """ - LiteLLM Enterprise - View Spend Per Request Tag. 
Used by LiteLLM UI - - Example Request: - ``` - curl -X GET "http://0.0.0.0:4000/spend/tags" \ --H "Authorization: Bearer sk-1234" - ``` - - Spend with Start Date and End Date - ``` - curl -X GET "http://0.0.0.0:4000/spend/tags?start_date=2022-01-01&end_date=2022-02-01" \ --H "Authorization: Bearer sk-1234" - ``` - """ - - from enterprise.utils import ui_get_spend_by_tags - - global prisma_client - try: - if prisma_client is None: - raise Exception( - f"Database not connected. Connect a database to your proxy - https://docs.litellm.ai/docs/simple_proxy#managing-auth---virtual-keys" - ) - - if end_date is None or start_date is None: - raise ProxyException( - message="Please provide start_date and end_date", - type="bad_request", - param=None, - code=status.HTTP_400_BAD_REQUEST, - ) - response = await ui_get_spend_by_tags( - start_date=start_date, - end_date=end_date, - tags_str=tags, - prisma_client=prisma_client, - ) - - return response - except Exception as e: - if isinstance(e, HTTPException): - raise ProxyException( - message=getattr(e, "detail", f"/spend/tags Error({str(e)})"), - type="internal_error", - param=getattr(e, "param", "None"), - code=getattr(e, "status_code", status.HTTP_500_INTERNAL_SERVER_ERROR), - ) - elif isinstance(e, ProxyException): - raise e - raise ProxyException( - message="/spend/tags Error" + str(e), - type="internal_error", - param=getattr(e, "param", "None"), - code=status.HTTP_500_INTERNAL_SERVER_ERROR, - ) - - -async def _get_spend_report_for_time_range( - start_date: str, - end_date: str, -): - global prisma_client - if prisma_client is None: - verbose_proxy_logger.error( - f"Database not connected. 
Connect a database to your proxy for weekly, monthly spend reports" - ) - return None - - try: - sql_query = """ - SELECT - t.team_alias, - SUM(s.spend) AS total_spend - FROM - "LiteLLM_SpendLogs" s - LEFT JOIN - "LiteLLM_TeamTable" t ON s.team_id = t.team_id - WHERE - s."startTime"::DATE >= $1::date AND s."startTime"::DATE <= $2::date - GROUP BY - t.team_alias - ORDER BY - total_spend DESC; - """ - response = await prisma_client.db.query_raw(sql_query, start_date, end_date) - - # get spend per tag for today - sql_query = """ - SELECT - jsonb_array_elements_text(request_tags) AS individual_request_tag, - SUM(spend) AS total_spend - FROM "LiteLLM_SpendLogs" - WHERE "startTime"::DATE >= $1::date AND "startTime"::DATE <= $2::date - GROUP BY individual_request_tag - ORDER BY total_spend DESC; - """ - - spend_per_tag = await prisma_client.db.query_raw( - sql_query, start_date, end_date - ) - - return response, spend_per_tag - except Exception as e: - verbose_proxy_logger.error( - "Exception in _get_daily_spend_reports {}".format(str(e)) - ) # noqa - - -@router.post( - "/spend/calculate", - tags=["Budget & Spend Tracking"], - dependencies=[Depends(user_api_key_auth)], - responses={ - 200: { - "cost": { - "description": "The calculated cost", - "example": 0.0, - "type": "float", - } - } - }, -) -async def calculate_spend(request: Request): - """ - Accepts all the params of completion_cost. 
- - Calculate spend **before** making call: - - Note: If you see a spend of $0.0 you need to set custom_pricing for your model: https://docs.litellm.ai/docs/proxy/custom_pricing - - ``` - curl --location 'http://localhost:4000/spend/calculate' - --header 'Authorization: Bearer sk-1234' - --header 'Content-Type: application/json' - --data '{ - "model": "anthropic.claude-v2", - "messages": [{"role": "user", "content": "Hey, how'''s it going?"}] - }' - ``` - - Calculate spend **after** making call: - - ``` - curl --location 'http://localhost:4000/spend/calculate' - --header 'Authorization: Bearer sk-1234' - --header 'Content-Type: application/json' - --data '{ - "completion_response": { - "id": "chatcmpl-123", - "object": "chat.completion", - "created": 1677652288, - "model": "gpt-3.5-turbo-0125", - "system_fingerprint": "fp_44709d6fcb", - "choices": [{ - "index": 0, - "message": { - "role": "assistant", - "content": "Hello there, how may I assist you today?" - }, - "logprobs": null, - "finish_reason": "stop" - }] - "usage": { - "prompt_tokens": 9, - "completion_tokens": 12, - "total_tokens": 21 - } - } - }' - ``` - """ - from litellm import completion_cost - - data = await request.json() - if "completion_response" in data: - data["completion_response"] = litellm.ModelResponse( - **data["completion_response"] - ) - return {"cost": completion_cost(**data)} - - -@router.get( - "/spend/logs", - tags=["Budget & Spend Tracking"], - dependencies=[Depends(user_api_key_auth)], - responses={ - 200: {"model": List[LiteLLM_SpendLogs]}, - }, -) -async def view_spend_logs( - api_key: Optional[str] = fastapi.Query( - default=None, - description="Get spend logs based on api key", - ), - user_id: Optional[str] = fastapi.Query( - default=None, - description="Get spend logs based on user_id", - ), - request_id: Optional[str] = fastapi.Query( - default=None, - description="request_id to get spend logs for specific request_id. 
If none passed then pass spend logs for all requests", - ), - start_date: Optional[str] = fastapi.Query( - default=None, - description="Time from which to start viewing key spend", - ), - end_date: Optional[str] = fastapi.Query( - default=None, - description="Time till which to view key spend", - ), -): - """ - View all spend logs, if request_id is provided, only logs for that request_id will be returned - - Example Request for all logs - ``` - curl -X GET "http://0.0.0.0:8000/spend/logs" \ --H "Authorization: Bearer sk-1234" - ``` - - Example Request for specific request_id - ``` - curl -X GET "http://0.0.0.0:8000/spend/logs?request_id=chatcmpl-6dcb2540-d3d7-4e49-bb27-291f863f112e" \ --H "Authorization: Bearer sk-1234" - ``` - - Example Request for specific api_key - ``` - curl -X GET "http://0.0.0.0:8000/spend/logs?api_key=sk-Fn8Ej39NkBQmUagFEoUWPQ" \ --H "Authorization: Bearer sk-1234" - ``` - - Example Request for specific user_id - ``` - curl -X GET "http://0.0.0.0:8000/spend/logs?user_id=ishaan@berri.ai" \ --H "Authorization: Bearer sk-1234" - ``` - """ - global prisma_client - try: - verbose_proxy_logger.debug("inside view_spend_logs") - if prisma_client is None: - raise Exception( - f"Database not connected. 
Connect a database to your proxy - https://docs.litellm.ai/docs/simple_proxy#managing-auth---virtual-keys" - ) - spend_logs = [] - if ( - start_date is not None - and isinstance(start_date, str) - and end_date is not None - and isinstance(end_date, str) - ): - # Convert the date strings to datetime objects - start_date_obj = datetime.strptime(start_date, "%Y-%m-%d") - end_date_obj = datetime.strptime(end_date, "%Y-%m-%d") - - filter_query = { - "startTime": { - "gte": start_date_obj, # Greater than or equal to Start Date - "lte": end_date_obj, # Less than or equal to End Date - } - } - - if api_key is not None and isinstance(api_key, str): - filter_query["api_key"] = api_key # type: ignore - elif request_id is not None and isinstance(request_id, str): - filter_query["request_id"] = request_id # type: ignore - elif user_id is not None and isinstance(user_id, str): - filter_query["user"] = user_id # type: ignore - - # SQL query - response = await prisma_client.db.litellm_spendlogs.group_by( - by=["api_key", "user", "model", "startTime"], - where=filter_query, # type: ignore - sum={ - "spend": True, - }, - ) - - if ( - isinstance(response, list) - and len(response) > 0 - and isinstance(response[0], dict) - ): - result: dict = {} - for record in response: - dt_object = datetime.strptime( - str(record["startTime"]), "%Y-%m-%dT%H:%M:%S.%fZ" - ) # type: ignore - date = dt_object.date() - if date not in result: - result[date] = {"users": {}, "models": {}} - api_key = record["api_key"] - user_id = record["user"] - model = record["model"] - result[date]["spend"] = ( - result[date].get("spend", 0) + record["_sum"]["spend"] - ) - result[date][api_key] = ( - result[date].get(api_key, 0) + record["_sum"]["spend"] - ) - result[date]["users"][user_id] = ( - result[date]["users"].get(user_id, 0) + record["_sum"]["spend"] - ) - result[date]["models"][model] = ( - result[date]["models"].get(model, 0) + record["_sum"]["spend"] - ) - return_list = [] - final_date = None - for k, v in 
sorted(result.items()): - return_list.append({**v, "startTime": k}) - final_date = k - - end_date_date = end_date_obj.date() - if final_date is not None and final_date < end_date_date: - current_date = final_date + timedelta(days=1) - while current_date <= end_date_date: - # Represent current_date as string because original response has it this way - return_list.append( - { - "startTime": current_date, - "spend": 0, - "users": {}, - "models": {}, - } - ) # If no data, will stay as zero - current_date += timedelta(days=1) # Move on to the next day - - return return_list - - return response - - elif api_key is not None and isinstance(api_key, str): - if api_key.startswith("sk-"): - hashed_token = prisma_client.hash_token(token=api_key) - else: - hashed_token = api_key - spend_log = await prisma_client.get_data( - table_name="spend", - query_type="find_all", - key_val={"key": "api_key", "value": hashed_token}, - ) - if isinstance(spend_log, list): - return spend_log - else: - return [spend_log] - elif request_id is not None: - spend_log = await prisma_client.get_data( - table_name="spend", - query_type="find_unique", - key_val={"key": "request_id", "value": request_id}, - ) - return [spend_log] - elif user_id is not None: - spend_log = await prisma_client.get_data( - table_name="spend", - query_type="find_all", - key_val={"key": "user", "value": user_id}, - ) - if isinstance(spend_log, list): - return spend_log - else: - return [spend_log] - else: - spend_logs = await prisma_client.get_data( - table_name="spend", query_type="find_all" - ) - - return spend_log - - return None - - except Exception as e: - if isinstance(e, HTTPException): - raise ProxyException( - message=getattr(e, "detail", f"/spend/logs Error({str(e)})"), - type="internal_error", - param=getattr(e, "param", "None"), - code=getattr(e, "status_code", status.HTTP_500_INTERNAL_SERVER_ERROR), - ) - elif isinstance(e, ProxyException): - raise e - raise ProxyException( - message="/spend/logs Error" + str(e), 
- type="internal_error", - param=getattr(e, "param", "None"), - code=status.HTTP_500_INTERNAL_SERVER_ERROR, - ) - - -@router.post( - "/global/spend/reset", - tags=["Budget & Spend Tracking"], - dependencies=[Depends(user_api_key_auth)], -) -async def global_spend_reset(): - """ - ADMIN ONLY / MASTER KEY Only Endpoint - - Globally reset spend for All API Keys and Teams, maintain LiteLLM_SpendLogs - - 1. LiteLLM_SpendLogs will maintain the logs on spend, no data gets deleted from there - 2. LiteLLM_VerificationTokens spend will be set = 0 - 3. LiteLLM_TeamTable spend will be set = 0 - - """ - global prisma_client - if prisma_client is None: - raise ProxyException( - message="Prisma Client is not initialized", - type="internal_error", - param="None", - code=status.HTTP_401_UNAUTHORIZED, - ) - - await prisma_client.db.litellm_verificationtoken.update_many( - data={"spend": 0.0}, where={} - ) - await prisma_client.db.litellm_teamtable.update_many(data={"spend": 0.0}, where={}) - - return { - "message": "Spend for all API Keys and Teams reset successfully", - "status": "success", - } - - -@router.get( - "/global/spend/logs", - tags=["Budget & Spend Tracking"], - dependencies=[Depends(user_api_key_auth)], - include_in_schema=False, -) -async def global_spend_logs( - api_key: str = fastapi.Query( - default=None, - description="API Key to get global spend (spend per day for last 30d). Admin-only endpoint", - ) -): - """ - [BETA] This is a beta endpoint. It will change. - - Use this to get global spend (spend per day for last 30d). Admin-only endpoint - - More efficient implementation of /spend/logs, by creating a view over the spend logs table. 
- """ - global prisma_client - if prisma_client is None: - raise ProxyException( - message="Prisma Client is not initialized", - type="internal_error", - param="None", - code=status.HTTP_500_INTERNAL_SERVER_ERROR, - ) - if api_key is None: - sql_query = """SELECT * FROM "MonthlyGlobalSpend" ORDER BY "date";""" - - response = await prisma_client.db.query_raw(query=sql_query) - - return response - else: - sql_query = """ - SELECT * FROM "MonthlyGlobalSpendPerKey" - WHERE "api_key" = $1 - ORDER BY "date"; - """ - - response = await prisma_client.db.query_raw(sql_query, api_key) - - return response - return - - -@router.get( - "/global/spend", - tags=["Budget & Spend Tracking"], - dependencies=[Depends(user_api_key_auth)], - include_in_schema=False, -) -async def global_spend(): - """ - [BETA] This is a beta endpoint. It will change. - - View total spend across all proxy keys - """ - global prisma_client - total_spend = 0.0 - total_proxy_budget = 0.0 - - if prisma_client is None: - raise HTTPException(status_code=500, detail={"error": "No db connected"}) - sql_query = """SELECT SUM(spend) as total_spend FROM "MonthlyGlobalSpend";""" - response = await prisma_client.db.query_raw(query=sql_query) - if response is not None: - if isinstance(response, list) and len(response) > 0: - total_spend = response[0].get("total_spend", 0.0) - - return {"spend": total_spend, "max_budget": litellm.max_budget} - - -@router.get( - "/global/spend/keys", - tags=["Budget & Spend Tracking"], - dependencies=[Depends(user_api_key_auth)], - include_in_schema=False, -) -async def global_spend_keys( - limit: int = fastapi.Query( - default=None, - description="Number of keys to get. Will return Top 'n' keys.", - ) -): - """ - [BETA] This is a beta endpoint. It will change. - - Use this to get the top 'n' keys with the highest spend, ordered by spend. 
- """ - global prisma_client - - if prisma_client is None: - raise HTTPException(status_code=500, detail={"error": "No db connected"}) - sql_query = f"""SELECT * FROM "Last30dKeysBySpend" LIMIT {limit};""" - - response = await prisma_client.db.query_raw(query=sql_query) - - return response - - -@router.get( - "/global/spend/teams", - tags=["Budget & Spend Tracking"], - dependencies=[Depends(user_api_key_auth)], - include_in_schema=False, -) -async def global_spend_per_team(): - """ - [BETA] This is a beta endpoint. It will change. - - Use this to get daily spend, grouped by `team_id` and `date` - """ - global prisma_client - - if prisma_client is None: - raise HTTPException(status_code=500, detail={"error": "No db connected"}) - sql_query = """ - SELECT - t.team_alias as team_alias, - DATE(s."startTime") AS spend_date, - SUM(s.spend) AS total_spend - FROM - "LiteLLM_SpendLogs" s - LEFT JOIN - "LiteLLM_TeamTable" t ON s.team_id = t.team_id - WHERE - s."startTime" >= CURRENT_DATE - INTERVAL '30 days' - GROUP BY - t.team_alias, - DATE(s."startTime") - ORDER BY - spend_date; - """ - response = await prisma_client.db.query_raw(query=sql_query) - - # transform the response for the Admin UI - spend_by_date = {} - team_aliases = set() - total_spend_per_team = {} - for row in response: - row_date = row["spend_date"] - if row_date is None: - continue - team_alias = row["team_alias"] - if team_alias is None: - team_alias = "Unassigned" - team_aliases.add(team_alias) - if row_date in spend_by_date: - # get the team_id for this entry - # get the spend for this entry - spend = row["total_spend"] - spend = round(spend, 2) - current_date_entries = spend_by_date[row_date] - current_date_entries[team_alias] = spend - else: - spend = row["total_spend"] - spend = round(spend, 2) - spend_by_date[row_date] = {team_alias: spend} - - if team_alias in total_spend_per_team: - total_spend_per_team[team_alias] += spend - else: - total_spend_per_team[team_alias] = spend - - 
total_spend_per_team_ui = [] - # order the elements in total_spend_per_team by spend - total_spend_per_team = dict( - sorted(total_spend_per_team.items(), key=lambda item: item[1], reverse=True) - ) - for team_id in total_spend_per_team: - # only add first 10 elements to total_spend_per_team_ui - if len(total_spend_per_team_ui) >= 10: - break - if team_id is None: - team_id = "Unassigned" - total_spend_per_team_ui.append( - {"team_id": team_id, "total_spend": total_spend_per_team[team_id]} - ) - - # sort spend_by_date by it's key (which is a date) - - response_data = [] - for key in spend_by_date: - value = spend_by_date[key] - response_data.append({"date": key, **value}) - - return { - "daily_spend": response_data, - "teams": list(team_aliases), - "total_spend_per_team": total_spend_per_team_ui, - } - - -@router.get( - "/global/all_end_users", - tags=["Budget & Spend Tracking"], - dependencies=[Depends(user_api_key_auth)], - include_in_schema=False, -) -async def global_view_all_end_users(): - """ - [BETA] This is a beta endpoint. It will change. - - Use this to just get all the unique `end_users` - """ - global prisma_client - - if prisma_client is None: - raise HTTPException(status_code=500, detail={"error": "No db connected"}) - - sql_query = """ - SELECT DISTINCT end_user FROM "LiteLLM_SpendLogs" - """ - - db_response = await prisma_client.db.query_raw(query=sql_query) - if db_response is None: - return [] - - _end_users = [] - for row in db_response: - _end_users.append(row["end_user"]) - - return {"end_users": _end_users} - - -@router.post( - "/global/spend/end_users", - tags=["Budget & Spend Tracking"], - dependencies=[Depends(user_api_key_auth)], - include_in_schema=False, -) -async def global_spend_end_users(data: Optional[GlobalEndUsersSpend] = None): - """ - [BETA] This is a beta endpoint. It will change. - - Use this to get the top 'n' keys with the highest spend, ordered by spend. 
- """ - global prisma_client - - if prisma_client is None: - raise HTTPException(status_code=500, detail={"error": "No db connected"}) - - """ - Gets the top 100 end-users for a given api key - """ - startTime = None - endTime = None - selected_api_key = None - if data is not None: - startTime = data.startTime - endTime = data.endTime - selected_api_key = data.api_key - - startTime = startTime or datetime.now() - timedelta(days=30) - endTime = endTime or datetime.now() - - sql_query = """ -SELECT end_user, COUNT(*) AS total_count, SUM(spend) AS total_spend -FROM "LiteLLM_SpendLogs" -WHERE "startTime" >= $1::timestamp - AND "startTime" < $2::timestamp - AND ( - CASE - WHEN $3::TEXT IS NULL THEN TRUE - ELSE api_key = $3 - END - ) -GROUP BY end_user -ORDER BY total_spend DESC -LIMIT 100 - """ - response = await prisma_client.db.query_raw( - sql_query, startTime, endTime, selected_api_key - ) - - return response - - -@router.get( - "/global/spend/models", - tags=["Budget & Spend Tracking"], - dependencies=[Depends(user_api_key_auth)], - include_in_schema=False, -) -async def global_spend_models( - limit: int = fastapi.Query( - default=None, - description="Number of models to get. Will return Top 'n' models.", - ) -): - """ - [BETA] This is a beta endpoint. It will change. - - Use this to get the top 'n' keys with the highest spend, ordered by spend. 
- """ - global prisma_client - - if prisma_client is None: - raise HTTPException(status_code=500, detail={"error": "No db connected"}) - - sql_query = f"""SELECT * FROM "Last30dModelsBySpend" LIMIT {limit};""" - - response = await prisma_client.db.query_raw(query=sql_query) - - return response - - -@router.post( - "/global/predict/spend/logs", - tags=["Budget & Spend Tracking"], - dependencies=[Depends(user_api_key_auth)], - include_in_schema=False, -) -async def global_predict_spend_logs(request: Request): - from enterprise.utils import _forecast_daily_cost - - data = await request.json() - data = data.get("data") - return _forecast_daily_cost(data) - - #### INTERNAL USER MANAGEMENT #### @router.post( "/user/new", @@ -7781,6 +4904,8 @@ async def new_user(data: NewUserRequest): # Admin UI Logic # if team_id passed add this user to the team if data_json.get("team_id", None) is not None: + from litellm.proxy.management_endpoints.team_endpoints import team_member_add + await team_member_add( data=TeamMemberAddRequest( team_id=data_json.get("team_id", None), @@ -8816,235 +5941,6 @@ async def delete_end_user( pass -#### TEAM MANAGEMENT #### - - -@router.post( - "/team/new", - tags=["team management"], - dependencies=[Depends(user_api_key_auth)], - response_model=LiteLLM_TeamTable, -) -@management_endpoint_wrapper -async def new_team( - data: NewTeamRequest, - http_request: Request, - user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), - litellm_changed_by: Optional[str] = Header( - None, - description="The litellm-changed-by header enables tracking of actions performed by authorized users on behalf of other users, providing an audit trail for accountability", - ), -): - """ - Allow users to create a new team. Apply user permissions to their team. 
- - šŸ‘‰ [Detailed Doc on setting team budgets](https://docs.litellm.ai/docs/proxy/team_budgets) - - - Parameters: - - team_alias: Optional[str] - User defined team alias - - team_id: Optional[str] - The team id of the user. If none passed, we'll generate it. - - members_with_roles: List[{"role": "admin" or "user", "user_id": ""}] - A list of users and their roles in the team. Get user_id when making a new user via `/user/new`. - - metadata: Optional[dict] - Metadata for team, store information for team. Example metadata = {"extra_info": "some info"} - - tpm_limit: Optional[int] - The TPM (Tokens Per Minute) limit for this team - all keys with this team_id will have at max this TPM limit - - rpm_limit: Optional[int] - The RPM (Requests Per Minute) limit for this team - all keys associated with this team_id will have at max this RPM limit - - max_budget: Optional[float] - The maximum budget allocated to the team - all keys for this team_id will have at max this max_budget - - budget_duration: Optional[str] - The duration of the budget for the team. Doc [here](https://docs.litellm.ai/docs/proxy/team_budgets) - - models: Optional[list] - A list of models associated with the team - all keys for this team_id will have at most, these models. If empty, assumes all models are allowed. - - blocked: bool - Flag indicating if the team is blocked or not - will stop all calls from keys with this team_id. - - Returns: - - team_id: (str) Unique team id - used for tracking spend across multiple keys for same team id. 
- - _deprecated_params: - - admins: list - A list of user_id's for the admin role - - users: list - A list of user_id's for the user role - - Example Request: - ``` - curl --location 'http://0.0.0.0:4000/team/new' \ - - --header 'Authorization: Bearer sk-1234' \ - - --header 'Content-Type: application/json' \ - - --data '{ - "team_alias": "my-new-team_2", - "members_with_roles": [{"role": "admin", "user_id": "user-1234"}, - {"role": "user", "user_id": "user-2434"}] - }' - - ``` - - ``` - curl --location 'http://0.0.0.0:4000/team/new' \ - - --header 'Authorization: Bearer sk-1234' \ - - --header 'Content-Type: application/json' \ - - --data '{ - "team_alias": "QA Prod Bot", - "max_budget": 0.000000001, - "budget_duration": "1d" - }' - - ``` - """ - global prisma_client - - if prisma_client is None: - raise HTTPException(status_code=500, detail={"error": "No db connected"}) - - if data.team_id is None: - data.team_id = str(uuid.uuid4()) - else: - # Check if team_id exists already - _existing_team_id = await prisma_client.get_data( - team_id=data.team_id, table_name="team", query_type="find_unique" - ) - if _existing_team_id is not None: - raise HTTPException( - status_code=400, - detail={ - "error": f"Team id = {data.team_id} already exists. Please use a different team id." - }, - ) - - if ( - user_api_key_dict.user_role is None - or user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN - ): # don't restrict proxy admin - if ( - data.tpm_limit is not None - and user_api_key_dict.tpm_limit is not None - and data.tpm_limit > user_api_key_dict.tpm_limit - ): - raise HTTPException( - status_code=400, - detail={ - "error": f"tpm limit higher than user max. User tpm limit={user_api_key_dict.tpm_limit}. 
User role={user_api_key_dict.user_role}" - }, - ) - - if ( - data.rpm_limit is not None - and user_api_key_dict.rpm_limit is not None - and data.rpm_limit > user_api_key_dict.rpm_limit - ): - raise HTTPException( - status_code=400, - detail={ - "error": f"rpm limit higher than user max. User rpm limit={user_api_key_dict.rpm_limit}. User role={user_api_key_dict.user_role}" - }, - ) - - if ( - data.max_budget is not None - and user_api_key_dict.max_budget is not None - and data.max_budget > user_api_key_dict.max_budget - ): - raise HTTPException( - status_code=400, - detail={ - "error": f"max budget higher than user max. User max budget={user_api_key_dict.max_budget}. User role={user_api_key_dict.user_role}" - }, - ) - - if data.models is not None and len(user_api_key_dict.models) > 0: - for m in data.models: - if m not in user_api_key_dict.models: - raise HTTPException( - status_code=400, - detail={ - "error": f"Model not in allowed user models. User allowed models={user_api_key_dict.models}. 
User id={user_api_key_dict.user_id}" - }, - ) - - if user_api_key_dict.user_id is not None: - creating_user_in_list = False - for member in data.members_with_roles: - if member.user_id == user_api_key_dict.user_id: - creating_user_in_list = True - - if creating_user_in_list == False: - data.members_with_roles.append( - Member(role="admin", user_id=user_api_key_dict.user_id) - ) - - ## ADD TO MODEL TABLE - _model_id = None - if data.model_aliases is not None and isinstance(data.model_aliases, dict): - litellm_modeltable = LiteLLM_ModelTable( - model_aliases=json.dumps(data.model_aliases), - created_by=user_api_key_dict.user_id or litellm_proxy_admin_name, - updated_by=user_api_key_dict.user_id or litellm_proxy_admin_name, - ) - model_dict = await prisma_client.db.litellm_modeltable.create( - {**litellm_modeltable.json(exclude_none=True)} # type: ignore - ) # type: ignore - - _model_id = model_dict.id - - ## ADD TO TEAM TABLE - complete_team_data = LiteLLM_TeamTable( - **data.json(), - model_id=_model_id, - ) - - # If budget_duration is set, set `budget_reset_at` - if complete_team_data.budget_duration is not None: - duration_s = _duration_in_seconds(duration=complete_team_data.budget_duration) - reset_at = datetime.now(timezone.utc) + timedelta(seconds=duration_s) - complete_team_data.budget_reset_at = reset_at - - team_row = await prisma_client.insert_data( - data=complete_team_data.json(exclude_none=True), table_name="team" - ) - - ## ADD TEAM ID TO USER TABLE ## - for user in complete_team_data.members_with_roles: - ## add team id to user row ## - await prisma_client.update_data( - user_id=user.user_id, - data={"user_id": user.user_id, "teams": [team_row.team_id]}, - update_key_values_custom_query={ - "teams": { - "push ": [team_row.team_id], - } - }, - ) - - # Enterprise Feature - Audit Logging. 
Enable with litellm.store_audit_logs = True - if litellm.store_audit_logs is True: - _updated_values = complete_team_data.json(exclude_none=True) - - _updated_values = json.dumps(_updated_values, default=str) - - asyncio.create_task( - create_audit_log_for_update( - request_data=LiteLLM_AuditLogs( - id=str(uuid.uuid4()), - updated_at=datetime.now(timezone.utc), - changed_by=litellm_changed_by - or user_api_key_dict.user_id - or litellm_proxy_admin_name, - changed_by_api_key=user_api_key_dict.api_key, - table_name=LitellmTableNames.TEAM_TABLE_NAME, - object_id=data.team_id, - action="created", - updated_values=_updated_values, - before_value=None, - ) - ) - ) - - try: - return team_row.model_dump() - except Exception as e: - return team_row.dict() - - async def create_audit_log_for_update(request_data: LiteLLM_AuditLogs): if premium_user is not True: return @@ -9077,593 +5973,6 @@ async def create_audit_log_for_update(request_data: LiteLLM_AuditLogs): return -@router.post( - "/team/update", tags=["team management"], dependencies=[Depends(user_api_key_auth)] -) -@management_endpoint_wrapper -async def update_team( - data: UpdateTeamRequest, - http_request: Request, - user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), - litellm_changed_by: Optional[str] = Header( - None, - description="The litellm-changed-by header enables tracking of actions performed by authorized users on behalf of other users, providing an audit trail for accountability", - ), -): - """ - Use `/team/member_add` AND `/team/member/delete` to add/remove new team members - - You can now update team budget / rate limits via /team/update - - Parameters: - - team_id: str - The team id of the user. Required param. - - team_alias: Optional[str] - User defined team alias - - metadata: Optional[dict] - Metadata for team, store information for team. 
Example metadata = {"team": "core-infra", "app": "app2", "email": "ishaan@berri.ai" } - - tpm_limit: Optional[int] - The TPM (Tokens Per Minute) limit for this team - all keys with this team_id will have at max this TPM limit - - rpm_limit: Optional[int] - The RPM (Requests Per Minute) limit for this team - all keys associated with this team_id will have at max this RPM limit - - max_budget: Optional[float] - The maximum budget allocated to the team - all keys for this team_id will have at max this max_budget - - budget_duration: Optional[str] - The duration of the budget for the team. Doc [here](https://docs.litellm.ai/docs/proxy/team_budgets) - - models: Optional[list] - A list of models associated with the team - all keys for this team_id will have at most, these models. If empty, assumes all models are allowed. - - blocked: bool - Flag indicating if the team is blocked or not - will stop all calls from keys with this team_id. - - Example - update team TPM Limit - - ``` - curl --location 'http://0.0.0.0:8000/team/update' \ - - --header 'Authorization: Bearer sk-1234' \ - - --header 'Content-Type: application/json' \ - - --data-raw '{ - "team_id": "litellm-test-client-id-new", - "tpm_limit": 100 - }' - ``` - - Example - Update Team `max_budget` budget - ``` - curl --location 'http://0.0.0.0:8000/team/update' \ - - --header 'Authorization: Bearer sk-1234' \ - - --header 'Content-Type: application/json' \ - - --data-raw '{ - "team_id": "litellm-test-client-id-new", - "max_budget": 10 - }' - ``` - """ - global prisma_client - - if prisma_client is None: - raise HTTPException(status_code=500, detail={"error": "No db connected"}) - - if data.team_id is None: - raise HTTPException(status_code=400, detail={"error": "No team id passed in"}) - verbose_proxy_logger.debug("/team/update - %s", data) - - existing_team_row = await prisma_client.get_data( - team_id=data.team_id, table_name="team", query_type="find_unique" - ) - if existing_team_row is None: - raise 
HTTPException( - status_code=404, - detail={"error": f"Team not found, passed team_id={data.team_id}"}, - ) - - updated_kv = data.json(exclude_none=True) - - # Check budget_duration and budget_reset_at - if data.budget_duration is not None: - duration_s = _duration_in_seconds(duration=data.budget_duration) - reset_at = datetime.now(timezone.utc) + timedelta(seconds=duration_s) - - # set the budget_reset_at in DB - updated_kv["budget_reset_at"] = reset_at - - team_row = await prisma_client.update_data( - update_key_values=updated_kv, - data=updated_kv, - table_name="team", - team_id=data.team_id, - ) - - # Enterprise Feature - Audit Logging. Enable with litellm.store_audit_logs = True - if litellm.store_audit_logs is True: - _before_value = existing_team_row.json(exclude_none=True) - _before_value = json.dumps(_before_value, default=str) - _after_value: str = json.dumps(updated_kv, default=str) - - asyncio.create_task( - create_audit_log_for_update( - request_data=LiteLLM_AuditLogs( - id=str(uuid.uuid4()), - updated_at=datetime.now(timezone.utc), - changed_by=litellm_changed_by - or user_api_key_dict.user_id - or litellm_proxy_admin_name, - changed_by_api_key=user_api_key_dict.api_key, - table_name=LitellmTableNames.TEAM_TABLE_NAME, - object_id=data.team_id, - action="updated", - updated_values=_after_value, - before_value=_before_value, - ) - ) - ) - - return team_row - - -@router.post( - "/team/member_add", - tags=["team management"], - dependencies=[Depends(user_api_key_auth)], -) -@management_endpoint_wrapper -async def team_member_add( - data: TeamMemberAddRequest, - http_request: Request, - user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), -): - """ - [BETA] - - Add new members (either via user_email or user_id) to a team - - If user doesn't exist, new user row will also be added to User Table - - ``` - - curl -X POST 'http://0.0.0.0:4000/team/member_add' \ - -H 'Authorization: Bearer sk-1234' \ - -H 'Content-Type: application/json' \ - -d 
'{"team_id": "45e3e396-ee08-4a61-a88e-16b3ce7e0849", "member": {"role": "user", "user_id": "krrish247652@berri.ai"}}' - - ``` - """ - if prisma_client is None: - raise HTTPException(status_code=500, detail={"error": "No db connected"}) - - if data.team_id is None: - raise HTTPException(status_code=400, detail={"error": "No team id passed in"}) - - if data.member is None: - raise HTTPException( - status_code=400, detail={"error": "No member/members passed in"} - ) - - existing_team_row = await prisma_client.db.litellm_teamtable.find_unique( - where={"team_id": data.team_id} - ) - if existing_team_row is None: - raise HTTPException( - status_code=404, - detail={ - "error": f"Team not found for team_id={getattr(data, 'team_id', None)}" - }, - ) - - complete_team_data = LiteLLM_TeamTable(**existing_team_row.model_dump()) - - if isinstance(data.member, Member): - # add to team db - new_member = data.member - - complete_team_data.members_with_roles.append(new_member) - - elif isinstance(data.member, List): - # add to team db - new_members = data.member - - complete_team_data.members_with_roles.extend(new_members) - - # ADD MEMBER TO TEAM - _db_team_members = [m.model_dump() for m in complete_team_data.members_with_roles] - updated_team = await prisma_client.db.litellm_teamtable.update( - where={"team_id": data.team_id}, - data={"members_with_roles": json.dumps(_db_team_members)}, # type: ignore - ) - - if isinstance(data.member, Member): - await add_new_member( - new_member=data.member, - max_budget_in_team=data.max_budget_in_team, - prisma_client=prisma_client, - user_api_key_dict=user_api_key_dict, - litellm_proxy_admin_name=litellm_proxy_admin_name, - team_id=data.team_id, - ) - elif isinstance(data.member, List): - tasks: List = [] - for m in data.member: - await add_new_member( - new_member=m, - max_budget_in_team=data.max_budget_in_team, - prisma_client=prisma_client, - user_api_key_dict=user_api_key_dict, - litellm_proxy_admin_name=litellm_proxy_admin_name, - 
team_id=data.team_id, - ) - await asyncio.gather(*tasks) - - return updated_team - - -@router.post( - "/team/member_delete", - tags=["team management"], - dependencies=[Depends(user_api_key_auth)], -) -@management_endpoint_wrapper -async def team_member_delete( - data: TeamMemberDeleteRequest, - http_request: Request, - user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), -): - """ - [BETA] - - delete members (either via user_email or user_id) from a team - - If user doesn't exist, an exception will be raised - ``` - curl -X POST 'http://0.0.0.0:8000/team/update' \ - - -H 'Authorization: Bearer sk-1234' \ - - -H 'Content-Type: application/json' \ - - -D '{ - "team_id": "45e3e396-ee08-4a61-a88e-16b3ce7e0849", - "user_id": "krrish247652@berri.ai" - }' - ``` - """ - if prisma_client is None: - raise HTTPException(status_code=500, detail={"error": "No db connected"}) - - if data.team_id is None: - raise HTTPException(status_code=400, detail={"error": "No team id passed in"}) - - if data.user_id is None and data.user_email is None: - raise HTTPException( - status_code=400, - detail={"error": "Either user_id or user_email needs to be passed in"}, - ) - - _existing_team_row = await prisma_client.db.litellm_teamtable.find_unique( - where={"team_id": data.team_id} - ) - - if _existing_team_row is None: - raise HTTPException( - status_code=400, - detail={"error": "Team id={} does not exist in db".format(data.team_id)}, - ) - existing_team_row = LiteLLM_TeamTable(**_existing_team_row.model_dump()) - - ## DELETE MEMBER FROM TEAM - new_team_members: List[Member] = [] - for m in existing_team_row.members_with_roles: - if ( - data.user_id is not None - and m.user_id is not None - and data.user_id == m.user_id - ): - continue - elif ( - data.user_email is not None - and m.user_email is not None - and data.user_email == m.user_email - ): - continue - new_team_members.append(m) - existing_team_row.members_with_roles = new_team_members - - _db_new_team_members: List[dict] 
= [m.model_dump() for m in new_team_members] - - _ = await prisma_client.db.litellm_teamtable.update( - where={ - "team_id": data.team_id, - }, - data={"members_with_roles": json.dumps(_db_new_team_members)}, # type: ignore - ) - - ## DELETE TEAM ID from USER ROW, IF EXISTS ## - # get user row - key_val = {} - if data.user_id is not None: - key_val["user_id"] = data.user_id - elif data.user_email is not None: - key_val["user_email"] = data.user_email - existing_user_rows = await prisma_client.db.litellm_usertable.find_many( - where=key_val # type: ignore - ) - - if existing_user_rows is not None and ( - isinstance(existing_user_rows, list) and len(existing_user_rows) > 0 - ): - for existing_user in existing_user_rows: - team_list = [] - if data.team_id in existing_user.teams: - team_list = existing_user.teams - team_list.remove(data.team_id) - await prisma_client.db.litellm_usertable.update( - where={ - "user_id": existing_user.user_id, - }, - data={"teams": {"set": team_list}}, - ) - - return existing_team_row - - -@router.post( - "/team/delete", tags=["team management"], dependencies=[Depends(user_api_key_auth)] -) -@management_endpoint_wrapper -async def delete_team( - data: DeleteTeamRequest, - http_request: Request, - user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), - litellm_changed_by: Optional[str] = Header( - None, - description="The litellm-changed-by header enables tracking of actions performed by authorized users on behalf of other users, providing an audit trail for accountability", - ), -): - """ - delete team and associated team keys - - ``` - curl --location 'http://0.0.0.0:8000/team/delete' \ - - --header 'Authorization: Bearer sk-1234' \ - - --header 'Content-Type: application/json' \ - - --data-raw '{ - "team_ids": ["45e3e396-ee08-4a61-a88e-16b3ce7e0849"] - }' - ``` - """ - global prisma_client - - if prisma_client is None: - raise HTTPException(status_code=500, detail={"error": "No db connected"}) - - if data.team_ids is None: - 
raise HTTPException(status_code=400, detail={"error": "No team id passed in"}) - - # check that all teams passed exist - for team_id in data.team_ids: - team_row = await prisma_client.get_data( # type: ignore - team_id=team_id, table_name="team", query_type="find_unique" - ) - if team_row is None: - raise HTTPException( - status_code=404, - detail={"error": f"Team not found, passed team_id={team_id}"}, - ) - - # Enterprise Feature - Audit Logging. Enable with litellm.store_audit_logs = True - # we do this after the first for loop, since first for loop is for validation. we only want this inserted after validation passes - if litellm.store_audit_logs is True: - # make an audit log for each team deleted - for team_id in data.team_ids: - team_row = await prisma_client.get_data( # type: ignore - team_id=team_id, table_name="team", query_type="find_unique" - ) - - _team_row = team_row.json(exclude_none=True) - - asyncio.create_task( - create_audit_log_for_update( - request_data=LiteLLM_AuditLogs( - id=str(uuid.uuid4()), - updated_at=datetime.now(timezone.utc), - changed_by=litellm_changed_by - or user_api_key_dict.user_id - or litellm_proxy_admin_name, - changed_by_api_key=user_api_key_dict.api_key, - table_name=LitellmTableNames.TEAM_TABLE_NAME, - object_id=team_id, - action="deleted", - updated_values="{}", - before_value=_team_row, - ) - ) - ) - - # End of Audit logging - - ## DELETE ASSOCIATED KEYS - await prisma_client.delete_data(team_id_list=data.team_ids, table_name="key") - ## DELETE TEAMS - deleted_teams = await prisma_client.delete_data( - team_id_list=data.team_ids, table_name="team" - ) - return deleted_teams - - -@router.get( - "/team/info", tags=["team management"], dependencies=[Depends(user_api_key_auth)] -) -@management_endpoint_wrapper -async def team_info( - http_request: Request, - team_id: str = fastapi.Query( - default=None, description="Team ID in the request parameters" - ), -): - """ - get info on team + related keys - - ``` - curl --location 
'http://localhost:4000/team/info' \ - --header 'Authorization: Bearer sk-1234' \ - --header 'Content-Type: application/json' \ - --data '{ - "teams": ["",..] - }' - ``` - """ - global prisma_client - try: - if prisma_client is None: - raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail={ - "error": f"Database not connected. Connect a database to your proxy - https://docs.litellm.ai/docs/simple_proxy#managing-auth---virtual-keys" - }, - ) - if team_id is None: - raise HTTPException( - status_code=status.HTTP_422_UNPROCESSABLE_ENTITY, - detail={"message": "Malformed request. No team id passed in."}, - ) - - team_info = await prisma_client.get_data( - team_id=team_id, table_name="team", query_type="find_unique" - ) - if team_info is None: - raise HTTPException( - status_code=status.HTTP_404_NOT_FOUND, - detail={"message": f"Team not found, passed team id: {team_id}."}, - ) - - ## GET ALL KEYS ## - keys = await prisma_client.get_data( - team_id=team_id, - table_name="key", - query_type="find_all", - expires=datetime.now(), - ) - - if team_info is None: - ## make sure we still return a total spend ## - spend = 0 - for k in keys: - spend += getattr(k, "spend", 0) - team_info = {"spend": spend} - - ## REMOVE HASHED TOKEN INFO before returning ## - for key in keys: - try: - key = key.model_dump() # noqa - except: - # if using pydantic v1 - key = key.dict() - key.pop("token", None) - return {"team_id": team_id, "team_info": team_info, "keys": keys} - - except Exception as e: - if isinstance(e, HTTPException): - raise ProxyException( - message=getattr(e, "detail", f"Authentication Error({str(e)})"), - type="auth_error", - param=getattr(e, "param", "None"), - code=getattr(e, "status_code", status.HTTP_400_BAD_REQUEST), - ) - elif isinstance(e, ProxyException): - raise e - raise ProxyException( - message="Authentication Error, " + str(e), - type="auth_error", - param=getattr(e, "param", "None"), - code=status.HTTP_400_BAD_REQUEST, - ) - - 
-@router.post( - "/team/block", tags=["team management"], dependencies=[Depends(user_api_key_auth)] -) -@management_endpoint_wrapper -async def block_team( - data: BlockTeamRequest, - http_request: Request, - user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), -): - """ - Blocks all calls from keys with this team id. - """ - global prisma_client - - if prisma_client is None: - raise Exception("No DB Connected.") - - record = await prisma_client.db.litellm_teamtable.update( - where={"team_id": data.team_id}, data={"blocked": True} # type: ignore - ) - - return record - - -@router.post( - "/team/unblock", tags=["team management"], dependencies=[Depends(user_api_key_auth)] -) -@management_endpoint_wrapper -async def unblock_team( - data: BlockTeamRequest, - http_request: Request, - user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), -): - """ - Blocks all calls from keys with this team id. - """ - global prisma_client - - if prisma_client is None: - raise Exception("No DB Connected.") - - record = await prisma_client.db.litellm_teamtable.update( - where={"team_id": data.team_id}, data={"blocked": False} # type: ignore - ) - - return record - - -@router.get( - "/team/list", tags=["team management"], dependencies=[Depends(user_api_key_auth)] -) -@management_endpoint_wrapper -async def list_team( - http_request: Request, - user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), -): - """ - [Admin-only] List all available teams - - ``` - curl --location --request GET 'http://0.0.0.0:4000/team/list' \ - --header 'Authorization: Bearer sk-1234' - ``` - """ - global prisma_client - - if user_api_key_dict.user_role != LitellmUserRoles.PROXY_ADMIN: - raise HTTPException( - status_code=401, - detail={ - "error": "Admin-only endpoint. 
Your user role={}".format( - user_api_key_dict.user_role - ) - }, - ) - - if prisma_client is None: - raise HTTPException( - status_code=400, - detail={"error": CommonProxyErrors.db_not_connected_error.value}, - ) - - response = await prisma_client.db.litellm_teamtable.find_many() - - return response - - #### ORGANIZATION MANAGEMENT #### @@ -13057,631 +9366,6 @@ async def config_yaml_endpoint(config_info: ConfigYAML): return {"hello": "world"} -#### BASIC ENDPOINTS #### -@router.get( - "/test", - tags=["health"], - dependencies=[Depends(user_api_key_auth)], -) -async def test_endpoint(request: Request): - """ - [DEPRECATED] use `/health/liveliness` instead. - - A test endpoint that pings the proxy server to check if it's healthy. - - Parameters: - request (Request): The incoming request. - - Returns: - dict: A dictionary containing the route of the request URL. - """ - # ping the proxy server to check if its healthy - return {"route": request.url.path} - - -@router.get( - "/health/services", - tags=["health"], - dependencies=[Depends(user_api_key_auth)], - include_in_schema=False, -) -async def health_services_endpoint( - user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), - service: Literal[ - "slack_budget_alerts", "langfuse", "slack", "openmeter", "webhook", "email" - ] = fastapi.Query(description="Specify the service being hit."), -): - """ - Hidden endpoint. - - Used by the UI to let user check if slack alerting is working as expected. - """ - try: - global general_settings, proxy_logging_obj - - if service is None: - raise HTTPException( - status_code=400, detail={"error": "Service must be specified."} - ) - if service not in [ - "slack_budget_alerts", - "email", - "langfuse", - "slack", - "openmeter", - "webhook", - ]: - raise HTTPException( - status_code=400, - detail={ - "error": f"Service must be in list. Service={service}. 
List={['slack_budget_alerts']}" - }, - ) - - if service == "openmeter": - _ = await litellm.acompletion( - model="openai/litellm-mock-response-model", - messages=[{"role": "user", "content": "Hey, how's it going?"}], - user="litellm:/health/services", - mock_response="This is a mock response", - ) - return { - "status": "success", - "message": "Mock LLM request made - check openmeter.", - } - - if service == "langfuse": - from litellm.integrations.langfuse import LangFuseLogger - - langfuse_logger = LangFuseLogger() - langfuse_logger.Langfuse.auth_check() - _ = litellm.completion( - model="openai/litellm-mock-response-model", - messages=[{"role": "user", "content": "Hey, how's it going?"}], - user="litellm:/health/services", - mock_response="This is a mock response", - ) - return { - "status": "success", - "message": "Mock LLM request made - check langfuse.", - } - - if service == "webhook": - user_info = CallInfo( - token=user_api_key_dict.token or "", - spend=1, - max_budget=0, - user_id=user_api_key_dict.user_id, - key_alias=user_api_key_dict.key_alias, - team_id=user_api_key_dict.team_id, - ) - await proxy_logging_obj.budget_alerts( - type="user_budget", - user_info=user_info, - ) - - if service == "slack" or service == "slack_budget_alerts": - if "slack" in general_settings.get("alerting", []): - # test_message = f"""\n🚨 `ProjectedLimitExceededError` šŸ’ø\n\n`Key Alias:` litellm-ui-test-alert \n`Expected Day of Error`: 28th March \n`Current Spend`: $100.00 \n`Projected Spend at end of month`: $1000.00 \n`Soft Limit`: $700""" - # check if user has opted into unique_alert_webhooks - if ( - proxy_logging_obj.slack_alerting_instance.alert_to_webhook_url - is not None - ): - for ( - alert_type - ) in proxy_logging_obj.slack_alerting_instance.alert_to_webhook_url: - """ - "llm_exceptions", - "llm_too_slow", - "llm_requests_hanging", - "budget_alerts", - "db_exceptions", - """ - # only test alert if it's in active alert types - if ( - 
proxy_logging_obj.slack_alerting_instance.alert_types - is not None - and alert_type - not in proxy_logging_obj.slack_alerting_instance.alert_types - ): - continue - test_message = "default test message" - if alert_type == "llm_exceptions": - test_message = f"LLM Exception test alert" - elif alert_type == "llm_too_slow": - test_message = f"LLM Too Slow test alert" - elif alert_type == "llm_requests_hanging": - test_message = f"LLM Requests Hanging test alert" - elif alert_type == "budget_alerts": - test_message = f"Budget Alert test alert" - elif alert_type == "db_exceptions": - test_message = f"DB Exception test alert" - elif alert_type == "outage_alerts": - test_message = f"Outage Alert Exception test alert" - elif alert_type == "daily_reports": - test_message = f"Daily Reports test alert" - - await proxy_logging_obj.alerting_handler( - message=test_message, level="Low", alert_type=alert_type - ) - else: - await proxy_logging_obj.alerting_handler( - message="This is a test slack alert message", - level="Low", - alert_type="budget_alerts", - ) - - if prisma_client is not None: - asyncio.create_task( - proxy_logging_obj.slack_alerting_instance.send_monthly_spend_report() - ) - asyncio.create_task( - proxy_logging_obj.slack_alerting_instance.send_weekly_spend_report() - ) - - alert_types = ( - proxy_logging_obj.slack_alerting_instance.alert_types or [] - ) - alert_types = list(alert_types) - return { - "status": "success", - "alert_types": alert_types, - "message": "Mock Slack Alert sent, verify Slack Alert Received on your channel", - } - else: - raise HTTPException( - status_code=422, - detail={ - "error": '"{}" not in proxy config: general_settings. Unable to test this.'.format( - service - ) - }, - ) - if service == "email": - webhook_event = WebhookEvent( - event="key_created", - event_group="key", - event_message="Test Email Alert", - token=user_api_key_dict.token or "", - key_alias="Email Test key (This is only a test alert key. 
DO NOT USE THIS IN PRODUCTION.)", - spend=0, - max_budget=0, - user_id=user_api_key_dict.user_id, - user_email=os.getenv("TEST_EMAIL_ADDRESS"), - team_id=user_api_key_dict.team_id, - ) - - # use create task - this can take 10 seconds. don't keep ui users waiting for notification to check their email - asyncio.create_task( - proxy_logging_obj.slack_alerting_instance.send_key_created_or_user_invited_email( - webhook_event=webhook_event - ) - ) - - return { - "status": "success", - "message": "Mock Email Alert sent, verify Email Alert Received", - } - - except Exception as e: - verbose_proxy_logger.error( - "litellm.proxy.proxy_server.health_services_endpoint(): Exception occured - {}".format( - str(e) - ) - ) - verbose_proxy_logger.debug(traceback.format_exc()) - if isinstance(e, HTTPException): - raise ProxyException( - message=getattr(e, "detail", f"Authentication Error({str(e)})"), - type="auth_error", - param=getattr(e, "param", "None"), - code=getattr(e, "status_code", status.HTTP_500_INTERNAL_SERVER_ERROR), - ) - elif isinstance(e, ProxyException): - raise e - raise ProxyException( - message="Authentication Error, " + str(e), - type="auth_error", - param=getattr(e, "param", "None"), - code=status.HTTP_500_INTERNAL_SERVER_ERROR, - ) - - -@router.get("/health", tags=["health"], dependencies=[Depends(user_api_key_auth)]) -async def health_endpoint( - user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), - model: Optional[str] = fastapi.Query( - None, description="Specify the model name (optional)" - ), -): - """ - 🚨 USE `/health/liveliness` to health check the proxy 🚨 - - See more šŸ‘‰ https://docs.litellm.ai/docs/proxy/health - - - Check the health of all the endpoints in config.yaml - - To run health checks in the background, add this to config.yaml: - ``` - general_settings: - # ... other settings - background_health_checks: True - ``` - else, the health checks will be run on models when /health is called. 
- """ - global health_check_results, use_background_health_checks, user_model, llm_model_list - try: - if llm_model_list is None: - # if no router set, check if user set a model using litellm --model ollama/llama2 - if user_model is not None: - healthy_endpoints, unhealthy_endpoints = await perform_health_check( - model_list=[], cli_model=user_model - ) - return { - "healthy_endpoints": healthy_endpoints, - "unhealthy_endpoints": unhealthy_endpoints, - "healthy_count": len(healthy_endpoints), - "unhealthy_count": len(unhealthy_endpoints), - } - raise HTTPException( - status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, - detail={"error": "Model list not initialized"}, - ) - _llm_model_list = copy.deepcopy(llm_model_list) - ### FILTER MODELS FOR ONLY THOSE USER HAS ACCESS TO ### - if len(user_api_key_dict.models) > 0: - allowed_model_names = user_api_key_dict.models - else: - allowed_model_names = [] # - if use_background_health_checks: - return health_check_results - else: - healthy_endpoints, unhealthy_endpoints = await perform_health_check( - _llm_model_list, model - ) - - return { - "healthy_endpoints": healthy_endpoints, - "unhealthy_endpoints": unhealthy_endpoints, - "healthy_count": len(healthy_endpoints), - "unhealthy_count": len(unhealthy_endpoints), - } - except Exception as e: - verbose_proxy_logger.error( - "litellm.proxy.proxy_server.py::health_endpoint(): Exception occured - {}".format( - str(e) - ) - ) - verbose_proxy_logger.debug(traceback.format_exc()) - raise e - - -db_health_cache = {"status": "unknown", "last_updated": datetime.now()} - - -def _db_health_readiness_check(): - global db_health_cache, prisma_client - - # Note - Intentionally don't try/except this so it raises an exception when it fails - - # if timedelta is less than 2 minutes return DB Status - time_diff = datetime.now() - db_health_cache["last_updated"] - if db_health_cache["status"] != "unknown" and time_diff < timedelta(minutes=2): - return db_health_cache - 
prisma_client.health_check() - db_health_cache = {"status": "connected", "last_updated": datetime.now()} - return db_health_cache - - -@router.get( - "/active/callbacks", - tags=["health"], - dependencies=[Depends(user_api_key_auth)], -) -async def active_callbacks(): - """ - Returns a list of active callbacks on litellm.callbacks, litellm.input_callback, litellm.failure_callback, litellm.success_callback - """ - global proxy_logging_obj - _alerting = str(general_settings.get("alerting")) - # get success callbacks - - litellm_callbacks = [str(x) for x in litellm.callbacks] - litellm_input_callbacks = [str(x) for x in litellm.input_callback] - litellm_failure_callbacks = [str(x) for x in litellm.failure_callback] - litellm_success_callbacks = [str(x) for x in litellm.success_callback] - litellm_async_success_callbacks = [str(x) for x in litellm._async_success_callback] - litellm_async_failure_callbacks = [str(x) for x in litellm._async_failure_callback] - litellm_async_input_callbacks = [str(x) for x in litellm._async_input_callback] - - all_litellm_callbacks = ( - litellm_callbacks - + litellm_input_callbacks - + litellm_failure_callbacks - + litellm_success_callbacks - + litellm_async_success_callbacks - + litellm_async_failure_callbacks - + litellm_async_input_callbacks - ) - - alerting = proxy_logging_obj.alerting - _num_alerting = 0 - if alerting and isinstance(alerting, list): - _num_alerting = len(alerting) - - return { - "alerting": _alerting, - "litellm.callbacks": litellm_callbacks, - "litellm.input_callback": litellm_input_callbacks, - "litellm.failure_callback": litellm_failure_callbacks, - "litellm.success_callback": litellm_success_callbacks, - "litellm._async_success_callback": litellm_async_success_callbacks, - "litellm._async_failure_callback": litellm_async_failure_callbacks, - "litellm._async_input_callback": litellm_async_input_callbacks, - "all_litellm_callbacks": all_litellm_callbacks, - "num_callbacks": len(all_litellm_callbacks), - 
"num_alerting": _num_alerting, - } - - -@router.get( - "/health/readiness", - tags=["health"], - dependencies=[Depends(user_api_key_auth)], -) -async def health_readiness(): - """ - Unprotected endpoint for checking if worker can receive requests - """ - global general_settings - try: - # get success callback - success_callback_names = [] - - try: - # this was returning a JSON of the values in some of the callbacks - # all we need is the callback name, hence we do str(callback) - success_callback_names = [str(x) for x in litellm.success_callback] - except: - # don't let this block the /health/readiness response, if we can't convert to str -> return litellm.success_callback - success_callback_names = litellm.success_callback - - # check Cache - cache_type = None - if litellm.cache is not None: - from litellm.caching import RedisSemanticCache - - cache_type = litellm.cache.type - - if isinstance(litellm.cache.cache, RedisSemanticCache): - # ping the cache - # TODO: @ishaan-jaff - we should probably not ping the cache on every /health/readiness check - try: - index_info = await litellm.cache.cache._index_info() - except Exception as e: - index_info = "index does not exist - error: " + str(e) - cache_type = {"type": cache_type, "index_info": index_info} - - # check DB - if prisma_client is not None: # if db passed in, check if it's connected - db_health_status = _db_health_readiness_check() - return { - "status": "healthy", - "db": "connected", - "cache": cache_type, - "litellm_version": version, - "success_callbacks": success_callback_names, - **db_health_status, - } - else: - return { - "status": "healthy", - "db": "Not connected", - "cache": cache_type, - "litellm_version": version, - "success_callbacks": success_callback_names, - } - except Exception as e: - raise HTTPException(status_code=503, detail=f"Service Unhealthy ({str(e)})") - - -@router.get( - "/health/liveliness", - tags=["health"], - dependencies=[Depends(user_api_key_auth)], -) -async def 
health_liveliness(): - """ - Unprotected endpoint for checking if worker is alive - """ - return "I'm alive!" - - -@router.get( - "/cache/ping", - tags=["caching"], - dependencies=[Depends(user_api_key_auth)], -) -async def cache_ping(): - """ - Endpoint for checking if cache can be pinged - """ - try: - litellm_cache_params = {} - specific_cache_params = {} - - if litellm.cache is None: - raise HTTPException( - status_code=503, detail="Cache not initialized. litellm.cache is None" - ) - - for k, v in vars(litellm.cache).items(): - try: - if k == "cache": - continue - litellm_cache_params[k] = str(copy.deepcopy(v)) - except Exception: - litellm_cache_params[k] = "" - for k, v in vars(litellm.cache.cache).items(): - try: - specific_cache_params[k] = str(v) - except Exception: - specific_cache_params[k] = "" - if litellm.cache.type == "redis": - # ping the redis cache - ping_response = await litellm.cache.ping() - verbose_proxy_logger.debug( - "/cache/ping: ping_response: " + str(ping_response) - ) - # making a set cache call - # add cache does not return anything - await litellm.cache.async_add_cache( - result="test_key", - model="test-model", - messages=[{"role": "user", "content": "test from litellm"}], - ) - verbose_proxy_logger.debug("/cache/ping: done with set_cache()") - return { - "status": "healthy", - "cache_type": litellm.cache.type, - "ping_response": True, - "set_cache_response": "success", - "litellm_cache_params": litellm_cache_params, - "redis_cache_params": specific_cache_params, - } - else: - return { - "status": "healthy", - "cache_type": litellm.cache.type, - "litellm_cache_params": litellm_cache_params, - } - except Exception as e: - raise HTTPException( - status_code=503, - detail=f"Service Unhealthy ({str(e)}).Cache parameters: {litellm_cache_params}.specific_cache_params: {specific_cache_params}", - ) - - -@router.post( - "/cache/delete", - tags=["caching"], - dependencies=[Depends(user_api_key_auth)], -) -async def cache_delete(request: 
Request): - """ - Endpoint for deleting a key from the cache. All responses from litellm proxy have `x-litellm-cache-key` in the headers - - Parameters: - - **keys**: *Optional[List[str]]* - A list of keys to delete from the cache. Example {"keys": ["key1", "key2"]} - - ```shell - curl -X POST "http://0.0.0.0:4000/cache/delete" \ - -H "Authorization: Bearer sk-1234" \ - -d '{"keys": ["key1", "key2"]}' - ``` - - """ - try: - if litellm.cache is None: - raise HTTPException( - status_code=503, detail="Cache not initialized. litellm.cache is None" - ) - - request_data = await request.json() - keys = request_data.get("keys", None) - - if litellm.cache.type == "redis": - await litellm.cache.delete_cache_keys(keys=keys) - return { - "status": "success", - } - else: - raise HTTPException( - status_code=500, - detail=f"Cache type {litellm.cache.type} does not support deleting a key. only `redis` is supported", - ) - except Exception as e: - raise HTTPException( - status_code=500, - detail=f"Cache Delete Failed({str(e)})", - ) - - -@router.get( - "/cache/redis/info", - tags=["caching"], - dependencies=[Depends(user_api_key_auth)], -) -async def cache_redis_info(): - """ - Endpoint for getting /redis/info - """ - try: - if litellm.cache is None: - raise HTTPException( - status_code=503, detail="Cache not initialized. 
litellm.cache is None" - ) - if litellm.cache.type == "redis": - client_list = litellm.cache.cache.client_list() - redis_info = litellm.cache.cache.info() - num_clients = len(client_list) - return { - "num_clients": num_clients, - "clients": client_list, - "info": redis_info, - } - else: - raise HTTPException( - status_code=500, - detail=f"Cache type {litellm.cache.type} does not support flushing", - ) - except Exception as e: - raise HTTPException( - status_code=503, - detail=f"Service Unhealthy ({str(e)})", - ) - - -@router.post( - "/cache/flushall", - tags=["caching"], - dependencies=[Depends(user_api_key_auth)], -) -async def cache_flushall(): - """ - A function to flush all items from the cache. (All items will be deleted from the cache with this) - Raises HTTPException if the cache is not initialized or if the cache type does not support flushing. - Returns a dictionary with the status of the operation. - - Usage: - ``` - curl -X POST http://0.0.0.0:4000/cache/flushall -H "Authorization: Bearer sk-1234" - ``` - """ - try: - if litellm.cache is None: - raise HTTPException( - status_code=503, detail="Cache not initialized. litellm.cache is None" - ) - if litellm.cache.type == "redis": - litellm.cache.cache.flushall() - return { - "status": "success", - } - else: - raise HTTPException( - status_code=500, - detail=f"Cache type {litellm.cache.type} does not support flushing", - ) - except Exception as e: - raise HTTPException( - status_code=503, - detail=f"Service Unhealthy ({str(e)})", - ) - - @router.get( "/get/litellm_model_cost_map", include_in_schema=False, @@ -13726,7 +9410,11 @@ async def get_routes(): #### TEST ENDPOINTS #### -@router.get("/token/generate", dependencies=[Depends(user_api_key_auth)]) +@router.get( + "/token/generate", + dependencies=[Depends(user_api_key_auth)], + include_in_schema=False, +) async def token_generate(): """ Test endpoint. Admin-only access. 
Meant for generating admin tokens with specific claims and testing if they work for creating keys, etc. @@ -13794,3 +9482,8 @@ def cleanup_router_config_variables(): app.include_router(router) +app.include_router(health_router) +app.include_router(key_management_router) +app.include_router(team_router) +app.include_router(spend_management_router) +app.include_router(caching_router) diff --git a/litellm/proxy/spend_reporting_endpoints/spend_management_endpoints.py b/litellm/proxy/spend_reporting_endpoints/spend_management_endpoints.py new file mode 100644 index 000000000..901a92645 --- /dev/null +++ b/litellm/proxy/spend_reporting_endpoints/spend_management_endpoints.py @@ -0,0 +1,1829 @@ +#### SPEND MANAGEMENT ##### +from typing import Optional, List +import litellm +from litellm._logging import verbose_proxy_logger +from datetime import datetime, timedelta, timezone +from litellm.proxy.auth.user_api_key_auth import user_api_key_auth +import fastapi +from fastapi import Depends, Request, APIRouter, Header, status +from fastapi import HTTPException +from litellm.proxy._types import * + +router = APIRouter() + + +@router.get( + "/spend/keys", + tags=["Budget & Spend Tracking"], + dependencies=[Depends(user_api_key_auth)], + include_in_schema=False, +) +async def spend_key_fn(): + """ + View all keys created, ordered by spend + + Example Request: + ``` + curl -X GET "http://0.0.0.0:8000/spend/keys" \ +-H "Authorization: Bearer sk-1234" + ``` + """ + + from litellm.proxy.proxy_server import prisma_client + + try: + if prisma_client is None: + raise Exception( + f"Database not connected. 
Connect a database to your proxy - https://docs.litellm.ai/docs/simple_proxy#managing-auth---virtual-keys" + ) + + key_info = await prisma_client.get_data(table_name="key", query_type="find_all") + return key_info + + except Exception as e: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail={"error": str(e)}, + ) + + +@router.get( + "/spend/users", + tags=["Budget & Spend Tracking"], + dependencies=[Depends(user_api_key_auth)], + include_in_schema=False, +) +async def spend_user_fn( + user_id: Optional[str] = fastapi.Query( + default=None, + description="Get User Table row for user_id", + ), +): + """ + View all users created, ordered by spend + + Example Request: + ``` + curl -X GET "http://0.0.0.0:8000/spend/users" \ +-H "Authorization: Bearer sk-1234" + ``` + + View User Table row for user_id + ``` + curl -X GET "http://0.0.0.0:8000/spend/users?user_id=1234" \ +-H "Authorization: Bearer sk-1234" + ``` + """ + from litellm.proxy.proxy_server import prisma_client + + try: + if prisma_client is None: + raise Exception( + f"Database not connected. 
Connect a database to your proxy - https://docs.litellm.ai/docs/simple_proxy#managing-auth---virtual-keys" + ) + + if user_id is not None: + user_info = await prisma_client.get_data( + table_name="user", query_type="find_unique", user_id=user_id + ) + return [user_info] + else: + user_info = await prisma_client.get_data( + table_name="user", query_type="find_all" + ) + + return user_info + + except Exception as e: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail={"error": str(e)}, + ) + + +@router.get( + "/spend/tags", + tags=["Budget & Spend Tracking"], + dependencies=[Depends(user_api_key_auth)], + responses={ + 200: {"model": List[LiteLLM_SpendLogs]}, + }, +) +async def view_spend_tags( + start_date: Optional[str] = fastapi.Query( + default=None, + description="Time from which to start viewing key spend", + ), + end_date: Optional[str] = fastapi.Query( + default=None, + description="Time till which to view key spend", + ), +): + """ + LiteLLM Enterprise - View Spend Per Request Tag + + Example Request: + ``` + curl -X GET "http://0.0.0.0:8000/spend/tags" \ +-H "Authorization: Bearer sk-1234" + ``` + + Spend with Start Date and End Date + ``` + curl -X GET "http://0.0.0.0:8000/spend/tags?start_date=2022-01-01&end_date=2022-02-01" \ +-H "Authorization: Bearer sk-1234" + ``` + """ + + from enterprise.utils import get_spend_by_tags + from litellm.proxy.proxy_server import prisma_client + + try: + if prisma_client is None: + raise Exception( + f"Database not connected. 
Connect a database to your proxy - https://docs.litellm.ai/docs/simple_proxy#managing-auth---virtual-keys" + ) + + # run the following SQL query on prisma + """ + SELECT + jsonb_array_elements_text(request_tags) AS individual_request_tag, + COUNT(*) AS log_count, + SUM(spend) AS total_spend + FROM "LiteLLM_SpendLogs" + GROUP BY individual_request_tag; + """ + response = await get_spend_by_tags( + start_date=start_date, end_date=end_date, prisma_client=prisma_client + ) + + return response + except Exception as e: + if isinstance(e, HTTPException): + raise ProxyException( + message=getattr(e, "detail", f"/spend/tags Error({str(e)})"), + type="internal_error", + param=getattr(e, "param", "None"), + code=getattr(e, "status_code", status.HTTP_500_INTERNAL_SERVER_ERROR), + ) + elif isinstance(e, ProxyException): + raise e + raise ProxyException( + message="/spend/tags Error" + str(e), + type="internal_error", + param=getattr(e, "param", "None"), + code=status.HTTP_500_INTERNAL_SERVER_ERROR, + ) + + +@router.get( + "/global/activity", + tags=["Budget & Spend Tracking"], + dependencies=[Depends(user_api_key_auth)], + responses={ + 200: {"model": List[LiteLLM_SpendLogs]}, + }, + include_in_schema=False, +) +async def get_global_activity( + start_date: Optional[str] = fastapi.Query( + default=None, + description="Time from which to start viewing spend", + ), + end_date: Optional[str] = fastapi.Query( + default=None, + description="Time till which to view spend", + ), +): + """ + Get number of API Requests, total tokens through proxy + + { + "daily_data": [ + const chartdata = [ + { + date: 'Jan 22', + api_requests: 10, + total_tokens: 2000 + }, + { + date: 'Jan 23', + api_requests: 10, + total_tokens: 12 + }, + ], + "sum_api_requests": 20, + "sum_total_tokens": 2012 + } + """ + from collections import defaultdict + + if start_date is None or end_date is None: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail={"error": "Please provide start_date and 
end_date"}, + ) + + start_date_obj = datetime.strptime(start_date, "%Y-%m-%d") + end_date_obj = datetime.strptime(end_date, "%Y-%m-%d") + + from litellm.proxy.proxy_server import prisma_client, llm_router + + try: + if prisma_client is None: + raise Exception( + f"Database not connected. Connect a database to your proxy - https://docs.litellm.ai/docs/simple_proxy#managing-auth---virtual-keys" + ) + + sql_query = """ + SELECT + date_trunc('day', "startTime") AS date, + COUNT(*) AS api_requests, + SUM(total_tokens) AS total_tokens + FROM "LiteLLM_SpendLogs" + WHERE "startTime" BETWEEN $1::date AND $2::date + interval '1 day' + GROUP BY date_trunc('day', "startTime") + """ + db_response = await prisma_client.db.query_raw( + sql_query, start_date_obj, end_date_obj + ) + + if db_response is None: + return [] + + sum_api_requests = 0 + sum_total_tokens = 0 + daily_data = [] + for row in db_response: + # cast date to datetime + _date_obj = datetime.fromisoformat(row["date"]) + row["date"] = _date_obj.strftime("%b %d") + + daily_data.append(row) + sum_api_requests += row.get("api_requests", 0) + sum_total_tokens += row.get("total_tokens", 0) + + # sort daily_data by date + daily_data = sorted(daily_data, key=lambda x: x["date"]) + + data_to_return = { + "daily_data": daily_data, + "sum_api_requests": sum_api_requests, + "sum_total_tokens": sum_total_tokens, + } + + return data_to_return + + except Exception as e: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail={"error": str(e)}, + ) + + +@router.get( + "/global/activity/model", + tags=["Budget & Spend Tracking"], + dependencies=[Depends(user_api_key_auth)], + responses={ + 200: {"model": List[LiteLLM_SpendLogs]}, + }, + include_in_schema=False, +) +async def get_global_activity_model( + start_date: Optional[str] = fastapi.Query( + default=None, + description="Time from which to start viewing spend", + ), + end_date: Optional[str] = fastapi.Query( + default=None, + description="Time till which to 
view spend", + ), +): + """ + Get number of API Requests, total tokens through proxy - Grouped by MODEL + + [ + { + "model": "gpt-4", + "daily_data": [ + const chartdata = [ + { + date: 'Jan 22', + api_requests: 10, + total_tokens: 2000 + }, + { + date: 'Jan 23', + api_requests: 10, + total_tokens: 12 + }, + ], + "sum_api_requests": 20, + "sum_total_tokens": 2012 + + }, + { + "model": "azure/gpt-4-turbo", + "daily_data": [ + const chartdata = [ + { + date: 'Jan 22', + api_requests: 10, + total_tokens: 2000 + }, + { + date: 'Jan 23', + api_requests: 10, + total_tokens: 12 + }, + ], + "sum_api_requests": 20, + "sum_total_tokens": 2012 + + }, + ] + """ + from collections import defaultdict + + if start_date is None or end_date is None: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail={"error": "Please provide start_date and end_date"}, + ) + + start_date_obj = datetime.strptime(start_date, "%Y-%m-%d") + end_date_obj = datetime.strptime(end_date, "%Y-%m-%d") + + from litellm.proxy.proxy_server import prisma_client, llm_router, premium_user + + try: + if prisma_client is None: + raise Exception( + f"Database not connected. 
Connect a database to your proxy - https://docs.litellm.ai/docs/simple_proxy#managing-auth---virtual-keys" + ) + + sql_query = """ + SELECT + model_group, + date_trunc('day', "startTime") AS date, + COUNT(*) AS api_requests, + SUM(total_tokens) AS total_tokens + FROM "LiteLLM_SpendLogs" + WHERE "startTime" BETWEEN $1::date AND $2::date + interval '1 day' + GROUP BY model_group, date_trunc('day', "startTime") + """ + db_response = await prisma_client.db.query_raw( + sql_query, start_date_obj, end_date_obj + ) + if db_response is None: + return [] + + model_ui_data: dict = ( + {} + ) # {"gpt-4": {"daily_data": [], "sum_api_requests": 0, "sum_total_tokens": 0}} + + for row in db_response: + _model = row["model_group"] + if _model not in model_ui_data: + model_ui_data[_model] = { + "daily_data": [], + "sum_api_requests": 0, + "sum_total_tokens": 0, + } + _date_obj = datetime.fromisoformat(row["date"]) + row["date"] = _date_obj.strftime("%b %d") + + model_ui_data[_model]["daily_data"].append(row) + model_ui_data[_model]["sum_api_requests"] += row.get("api_requests", 0) + model_ui_data[_model]["sum_total_tokens"] += row.get("total_tokens", 0) + + # sort mode ui data by sum_api_requests -> get top 10 models + model_ui_data = dict( + sorted( + model_ui_data.items(), + key=lambda x: x[1]["sum_api_requests"], + reverse=True, + )[:10] + ) + + response = [] + for model, data in model_ui_data.items(): + _sort_daily_data = sorted(data["daily_data"], key=lambda x: x["date"]) + + response.append( + { + "model": model, + "daily_data": _sort_daily_data, + "sum_api_requests": data["sum_api_requests"], + "sum_total_tokens": data["sum_total_tokens"], + } + ) + + return response + + except Exception as e: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail={"error": str(e)}, + ) + + +@router.get( + "/global/activity/exceptions/deployment", + tags=["Budget & Spend Tracking"], + dependencies=[Depends(user_api_key_auth)], + responses={ + 200: {"model": 
List[LiteLLM_SpendLogs]}, + }, + include_in_schema=False, +) +async def get_global_activity_exceptions_per_deployment( + model_group: str = fastapi.Query( + description="Filter by model group", + ), + start_date: Optional[str] = fastapi.Query( + default=None, + description="Time from which to start viewing spend", + ), + end_date: Optional[str] = fastapi.Query( + default=None, + description="Time till which to view spend", + ), +): + """ + Get number of 429 errors - Grouped by deployment + + [ + { + "deployment": "https://azure-us-east-1.openai.azure.com/", + "daily_data": [ + const chartdata = [ + { + date: 'Jan 22', + num_rate_limit_exceptions: 10 + }, + { + date: 'Jan 23', + num_rate_limit_exceptions: 12 + }, + ], + "sum_num_rate_limit_exceptions": 20, + + }, + { + "deployment": "https://azure-us-east-1.openai.azure.com/", + "daily_data": [ + const chartdata = [ + { + date: 'Jan 22', + num_rate_limit_exceptions: 10, + }, + { + date: 'Jan 23', + num_rate_limit_exceptions: 12 + }, + ], + "sum_num_rate_limit_exceptions": 20, + + }, + ] + """ + from collections import defaultdict + + if start_date is None or end_date is None: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail={"error": "Please provide start_date and end_date"}, + ) + + start_date_obj = datetime.strptime(start_date, "%Y-%m-%d") + end_date_obj = datetime.strptime(end_date, "%Y-%m-%d") + + from litellm.proxy.proxy_server import prisma_client, llm_router, premium_user + + try: + if prisma_client is None: + raise Exception( + f"Database not connected. 
Connect a database to your proxy - https://docs.litellm.ai/docs/simple_proxy#managing-auth---virtual-keys" + ) + + sql_query = """ + SELECT + api_base, + date_trunc('day', "startTime")::date AS date, + COUNT(*) AS num_rate_limit_exceptions + FROM + "LiteLLM_ErrorLogs" + WHERE + "startTime" >= $1::date + AND "startTime" < ($2::date + INTERVAL '1 day') + AND model_group = $3 + AND status_code = '429' + GROUP BY + api_base, + date_trunc('day', "startTime") + ORDER BY + date; + """ + db_response = await prisma_client.db.query_raw( + sql_query, start_date_obj, end_date_obj, model_group + ) + if db_response is None: + return [] + + model_ui_data: dict = ( + {} + ) # {"gpt-4": {"daily_data": [], "sum_api_requests": 0, "sum_total_tokens": 0}} + + for row in db_response: + _model = row["api_base"] + if _model not in model_ui_data: + model_ui_data[_model] = { + "daily_data": [], + "sum_num_rate_limit_exceptions": 0, + } + _date_obj = datetime.fromisoformat(row["date"]) + row["date"] = _date_obj.strftime("%b %d") + + model_ui_data[_model]["daily_data"].append(row) + model_ui_data[_model]["sum_num_rate_limit_exceptions"] += row.get( + "num_rate_limit_exceptions", 0 + ) + + # sort mode ui data by sum_api_requests -> get top 10 models + model_ui_data = dict( + sorted( + model_ui_data.items(), + key=lambda x: x[1]["sum_num_rate_limit_exceptions"], + reverse=True, + )[:10] + ) + + response = [] + for model, data in model_ui_data.items(): + _sort_daily_data = sorted(data["daily_data"], key=lambda x: x["date"]) + + response.append( + { + "api_base": model, + "daily_data": _sort_daily_data, + "sum_num_rate_limit_exceptions": data[ + "sum_num_rate_limit_exceptions" + ], + } + ) + + return response + + except Exception as e: + raise HTTPException( + status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, + detail={"error": str(e)}, + ) + + +@router.get( + "/global/activity/exceptions", + tags=["Budget & Spend Tracking"], + dependencies=[Depends(user_api_key_auth)], + responses={ + 200: 
{"model": List[LiteLLM_SpendLogs]}, + }, + include_in_schema=False, +) +async def get_global_activity_exceptions( + model_group: str = fastapi.Query( + description="Filter by model group", + ), + start_date: Optional[str] = fastapi.Query( + default=None, + description="Time from which to start viewing spend", + ), + end_date: Optional[str] = fastapi.Query( + default=None, + description="Time till which to view spend", + ), +): + """ + Get number of API Requests, total tokens through proxy + + { + "daily_data": [ + const chartdata = [ + { + date: 'Jan 22', + num_rate_limit_exceptions: 10, + }, + { + date: 'Jan 23', + num_rate_limit_exceptions: 10, + }, + ], + "sum_api_exceptions": 20, + } + """ + from collections import defaultdict + + if start_date is None or end_date is None: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail={"error": "Please provide start_date and end_date"}, + ) + + start_date_obj = datetime.strptime(start_date, "%Y-%m-%d") + end_date_obj = datetime.strptime(end_date, "%Y-%m-%d") + + from litellm.proxy.proxy_server import prisma_client, llm_router + + try: + if prisma_client is None: + raise Exception( + f"Database not connected. 
Connect a database to your proxy - https://docs.litellm.ai/docs/simple_proxy#managing-auth---virtual-keys" + ) + + sql_query = """ + SELECT + date_trunc('day', "startTime")::date AS date, + COUNT(*) AS num_rate_limit_exceptions + FROM + "LiteLLM_ErrorLogs" + WHERE + "startTime" >= $1::date + AND "startTime" < ($2::date + INTERVAL '1 day') + AND model_group = $3 + AND status_code = '429' + GROUP BY + date_trunc('day', "startTime") + ORDER BY + date; + """ + db_response = await prisma_client.db.query_raw( + sql_query, start_date_obj, end_date_obj, model_group + ) + + if db_response is None: + return [] + + sum_num_rate_limit_exceptions = 0 + daily_data = [] + for row in db_response: + # cast date to datetime + _date_obj = datetime.fromisoformat(row["date"]) + row["date"] = _date_obj.strftime("%b %d") + + daily_data.append(row) + sum_num_rate_limit_exceptions += row.get("num_rate_limit_exceptions", 0) + + # sort daily_data by date + daily_data = sorted(daily_data, key=lambda x: x["date"]) + + data_to_return = { + "daily_data": daily_data, + "sum_num_rate_limit_exceptions": sum_num_rate_limit_exceptions, + } + + return data_to_return + + except Exception as e: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail={"error": str(e)}, + ) + + +@router.get( + "/global/spend/provider", + tags=["Budget & Spend Tracking"], + dependencies=[Depends(user_api_key_auth)], + include_in_schema=False, + responses={ + 200: {"model": List[LiteLLM_SpendLogs]}, + }, +) +async def get_global_spend_provider( + start_date: Optional[str] = fastapi.Query( + default=None, + description="Time from which to start viewing spend", + ), + end_date: Optional[str] = fastapi.Query( + default=None, + description="Time till which to view spend", + ), +): + """ + Get breakdown of spend per provider + [ + { + "provider": "Azure OpenAI", + "spend": 20 + }, + { + "provider": "OpenAI", + "spend": 10 + }, + { + "provider": "VertexAI", + "spend": 30 + } + ] + """ + from collections import 
defaultdict + + if start_date is None or end_date is None: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail={"error": "Please provide start_date and end_date"}, + ) + + start_date_obj = datetime.strptime(start_date, "%Y-%m-%d") + end_date_obj = datetime.strptime(end_date, "%Y-%m-%d") + + from litellm.proxy.proxy_server import prisma_client, llm_router + + try: + if prisma_client is None: + raise Exception( + f"Database not connected. Connect a database to your proxy - https://docs.litellm.ai/docs/simple_proxy#managing-auth---virtual-keys" + ) + + sql_query = """ + + SELECT + model_id, + SUM(spend) AS spend + FROM "LiteLLM_SpendLogs" + WHERE "startTime" BETWEEN $1::date AND $2::date AND length(model_id) > 0 + GROUP BY model_id + """ + + db_response = await prisma_client.db.query_raw( + sql_query, start_date_obj, end_date_obj + ) + if db_response is None: + return [] + + ################################### + # Convert model_id -> to Provider # + ################################### + + # we use the in memory router for this + ui_response = [] + provider_spend_mapping: defaultdict = defaultdict(int) + for row in db_response: + _model_id = row["model_id"] + _provider = "Unknown" + if llm_router is not None: + _deployment = llm_router.get_deployment(model_id=_model_id) + if _deployment is not None: + try: + _, _provider, _, _ = litellm.get_llm_provider( + model=_deployment.litellm_params.model, + custom_llm_provider=_deployment.litellm_params.custom_llm_provider, + api_base=_deployment.litellm_params.api_base, + litellm_params=_deployment.litellm_params, + ) + provider_spend_mapping[_provider] += row["spend"] + except: + pass + + for provider, spend in provider_spend_mapping.items(): + ui_response.append({"provider": provider, "spend": spend}) + + return ui_response + + except Exception as e: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail={"error": str(e)}, + ) + + +@router.get( + "/global/spend/report", + 
tags=["Budget & Spend Tracking"], + dependencies=[Depends(user_api_key_auth)], + responses={ + 200: {"model": List[LiteLLM_SpendLogs]}, + }, +) +async def get_global_spend_report( + start_date: Optional[str] = fastapi.Query( + default=None, + description="Time from which to start viewing spend", + ), + end_date: Optional[str] = fastapi.Query( + default=None, + description="Time till which to view spend", + ), + group_by: Optional[Literal["team", "customer"]] = fastapi.Query( + default="team", + description="Group spend by internal team or customer", + ), +): + """ + Get Daily Spend per Team, based on specific startTime and endTime. Per team, view usage by each key, model + [ + { + "group-by-day": "2024-05-10", + "teams": [ + { + "team_name": "team-1" + "spend": 10, + "keys": [ + "key": "1213", + "usage": { + "model-1": { + "cost": 12.50, + "input_tokens": 1000, + "output_tokens": 5000, + "requests": 100 + }, + "audio-modelname1": { + "cost": 25.50, + "seconds": 25, + "requests": 50 + }, + } + } + ] + ] + } + """ + if start_date is None or end_date is None: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail={"error": "Please provide start_date and end_date"}, + ) + + start_date_obj = datetime.strptime(start_date, "%Y-%m-%d") + end_date_obj = datetime.strptime(end_date, "%Y-%m-%d") + + from litellm.proxy.proxy_server import prisma_client + + try: + if prisma_client is None: + raise Exception( + f"Database not connected. 
Connect a database to your proxy - https://docs.litellm.ai/docs/simple_proxy#managing-auth---virtual-keys" + ) + + if group_by == "team": + # first get data from spend logs -> SpendByModelApiKey + # then read data from "SpendByModelApiKey" to format the response obj + sql_query = """ + + WITH SpendByModelApiKey AS ( + SELECT + date_trunc('day', sl."startTime") AS group_by_day, + COALESCE(tt.team_alias, 'Unassigned Team') AS team_name, + sl.model, + sl.api_key, + SUM(sl.spend) AS model_api_spend, + SUM(sl.total_tokens) AS model_api_tokens + FROM + "LiteLLM_SpendLogs" sl + LEFT JOIN + "LiteLLM_TeamTable" tt + ON + sl.team_id = tt.team_id + WHERE + sl."startTime" BETWEEN $1::date AND $2::date + GROUP BY + date_trunc('day', sl."startTime"), + tt.team_alias, + sl.model, + sl.api_key + ) + SELECT + group_by_day, + jsonb_agg(jsonb_build_object( + 'team_name', team_name, + 'total_spend', total_spend, + 'metadata', metadata + )) AS teams + FROM ( + SELECT + group_by_day, + team_name, + SUM(model_api_spend) AS total_spend, + jsonb_agg(jsonb_build_object( + 'model', model, + 'api_key', api_key, + 'spend', model_api_spend, + 'total_tokens', model_api_tokens + )) AS metadata + FROM + SpendByModelApiKey + GROUP BY + group_by_day, + team_name + ) AS aggregated + GROUP BY + group_by_day + ORDER BY + group_by_day; + """ + + db_response = await prisma_client.db.query_raw( + sql_query, start_date_obj, end_date_obj + ) + if db_response is None: + return [] + + return db_response + + elif group_by == "customer": + sql_query = """ + + WITH SpendByModelApiKey AS ( + SELECT + date_trunc('day', sl."startTime") AS group_by_day, + sl.end_user AS customer, + sl.model, + sl.api_key, + SUM(sl.spend) AS model_api_spend, + SUM(sl.total_tokens) AS model_api_tokens + FROM + "LiteLLM_SpendLogs" sl + WHERE + sl."startTime" BETWEEN $1::date AND $2::date + GROUP BY + date_trunc('day', sl."startTime"), + customer, + sl.model, + sl.api_key + ) + SELECT + group_by_day, + jsonb_agg(jsonb_build_object( + 
'customer', customer, + 'total_spend', total_spend, + 'metadata', metadata + )) AS customers + FROM + ( + SELECT + group_by_day, + customer, + SUM(model_api_spend) AS total_spend, + jsonb_agg(jsonb_build_object( + 'model', model, + 'api_key', api_key, + 'spend', model_api_spend, + 'total_tokens', model_api_tokens + )) AS metadata + FROM + SpendByModelApiKey + GROUP BY + group_by_day, + customer + ) AS aggregated + GROUP BY + group_by_day + ORDER BY + group_by_day; + """ + + db_response = await prisma_client.db.query_raw( + sql_query, start_date_obj, end_date_obj + ) + if db_response is None: + return [] + + return db_response + + except Exception as e: + raise HTTPException( + status_code=status.HTTP_400_BAD_REQUEST, + detail={"error": str(e)}, + ) + + +@router.get( + "/global/spend/all_tag_names", + tags=["Budget & Spend Tracking"], + dependencies=[Depends(user_api_key_auth)], + include_in_schema=False, + responses={ + 200: {"model": List[LiteLLM_SpendLogs]}, + }, +) +async def global_get_all_tag_names(): + try: + from litellm.proxy.proxy_server import prisma_client + + if prisma_client is None: + raise Exception( + f"Database not connected. 
Connect a database to your proxy - https://docs.litellm.ai/docs/simple_proxy#managing-auth---virtual-keys" + ) + + sql_query = """ + SELECT DISTINCT + jsonb_array_elements_text(request_tags) AS individual_request_tag + FROM "LiteLLM_SpendLogs"; + """ + + db_response = await prisma_client.db.query_raw(sql_query) + if db_response is None: + return [] + + _tag_names = [] + for row in db_response: + _tag_names.append(row.get("individual_request_tag")) + + return {"tag_names": _tag_names} + + except Exception as e: + if isinstance(e, HTTPException): + raise ProxyException( + message=getattr(e, "detail", f"/spend/all_tag_names Error({str(e)})"), + type="internal_error", + param=getattr(e, "param", "None"), + code=getattr(e, "status_code", status.HTTP_500_INTERNAL_SERVER_ERROR), + ) + elif isinstance(e, ProxyException): + raise e + raise ProxyException( + message="/spend/all_tag_names Error" + str(e), + type="internal_error", + param=getattr(e, "param", "None"), + code=status.HTTP_500_INTERNAL_SERVER_ERROR, + ) + + +@router.get( + "/global/spend/tags", + tags=["Budget & Spend Tracking"], + dependencies=[Depends(user_api_key_auth)], + include_in_schema=False, + responses={ + 200: {"model": List[LiteLLM_SpendLogs]}, + }, +) +async def global_view_spend_tags( + start_date: Optional[str] = fastapi.Query( + default=None, + description="Time from which to start viewing key spend", + ), + end_date: Optional[str] = fastapi.Query( + default=None, + description="Time till which to view key spend", + ), + tags: Optional[str] = fastapi.Query( + default=None, + description="comman separated tags to filter on", + ), +): + """ + LiteLLM Enterprise - View Spend Per Request Tag. 
Used by LiteLLM UI + + Example Request: + ``` + curl -X GET "http://0.0.0.0:4000/spend/tags" \ +-H "Authorization: Bearer sk-1234" + ``` + + Spend with Start Date and End Date + ``` + curl -X GET "http://0.0.0.0:4000/spend/tags?start_date=2022-01-01&end_date=2022-02-01" \ +-H "Authorization: Bearer sk-1234" + ``` + """ + + from enterprise.utils import ui_get_spend_by_tags + + from litellm.proxy.proxy_server import prisma_client + + try: + if prisma_client is None: + raise Exception( + f"Database not connected. Connect a database to your proxy - https://docs.litellm.ai/docs/simple_proxy#managing-auth---virtual-keys" + ) + + if end_date is None or start_date is None: + raise ProxyException( + message="Please provide start_date and end_date", + type="bad_request", + param=None, + code=status.HTTP_400_BAD_REQUEST, + ) + response = await ui_get_spend_by_tags( + start_date=start_date, + end_date=end_date, + tags_str=tags, + prisma_client=prisma_client, + ) + + return response + except Exception as e: + if isinstance(e, HTTPException): + raise ProxyException( + message=getattr(e, "detail", f"/spend/tags Error({str(e)})"), + type="internal_error", + param=getattr(e, "param", "None"), + code=getattr(e, "status_code", status.HTTP_500_INTERNAL_SERVER_ERROR), + ) + elif isinstance(e, ProxyException): + raise e + raise ProxyException( + message="/spend/tags Error" + str(e), + type="internal_error", + param=getattr(e, "param", "None"), + code=status.HTTP_500_INTERNAL_SERVER_ERROR, + ) + + +async def _get_spend_report_for_time_range( + start_date: str, + end_date: str, +): + from litellm.proxy.proxy_server import prisma_client + + if prisma_client is None: + verbose_proxy_logger.error( + f"Database not connected. 
Connect a database to your proxy for weekly, monthly spend reports" + ) + return None + + try: + sql_query = """ + SELECT + t.team_alias, + SUM(s.spend) AS total_spend + FROM + "LiteLLM_SpendLogs" s + LEFT JOIN + "LiteLLM_TeamTable" t ON s.team_id = t.team_id + WHERE + s."startTime"::DATE >= $1::date AND s."startTime"::DATE <= $2::date + GROUP BY + t.team_alias + ORDER BY + total_spend DESC; + """ + response = await prisma_client.db.query_raw(sql_query, start_date, end_date) + + # get spend per tag for today + sql_query = """ + SELECT + jsonb_array_elements_text(request_tags) AS individual_request_tag, + SUM(spend) AS total_spend + FROM "LiteLLM_SpendLogs" + WHERE "startTime"::DATE >= $1::date AND "startTime"::DATE <= $2::date + GROUP BY individual_request_tag + ORDER BY total_spend DESC; + """ + + spend_per_tag = await prisma_client.db.query_raw( + sql_query, start_date, end_date + ) + + return response, spend_per_tag + except Exception as e: + verbose_proxy_logger.error( + "Exception in _get_daily_spend_reports {}".format(str(e)) + ) # noqa + + +@router.post( + "/spend/calculate", + tags=["Budget & Spend Tracking"], + dependencies=[Depends(user_api_key_auth)], + responses={ + 200: { + "cost": { + "description": "The calculated cost", + "example": 0.0, + "type": "float", + } + } + }, +) +async def calculate_spend(request: Request): + """ + Accepts all the params of completion_cost. 
+ + Calculate spend **before** making call: + + Note: If you see a spend of $0.0 you need to set custom_pricing for your model: https://docs.litellm.ai/docs/proxy/custom_pricing + + ``` + curl --location 'http://localhost:4000/spend/calculate' + --header 'Authorization: Bearer sk-1234' + --header 'Content-Type: application/json' + --data '{ + "model": "anthropic.claude-v2", + "messages": [{"role": "user", "content": "Hey, how'''s it going?"}] + }' + ``` + + Calculate spend **after** making call: + + ``` + curl --location 'http://localhost:4000/spend/calculate' + --header 'Authorization: Bearer sk-1234' + --header 'Content-Type: application/json' + --data '{ + "completion_response": { + "id": "chatcmpl-123", + "object": "chat.completion", + "created": 1677652288, + "model": "gpt-3.5-turbo-0125", + "system_fingerprint": "fp_44709d6fcb", + "choices": [{ + "index": 0, + "message": { + "role": "assistant", + "content": "Hello there, how may I assist you today?" + }, + "logprobs": null, + "finish_reason": "stop" + }] + "usage": { + "prompt_tokens": 9, + "completion_tokens": 12, + "total_tokens": 21 + } + } + }' + ``` + """ + from litellm import completion_cost + + data = await request.json() + if "completion_response" in data: + data["completion_response"] = litellm.ModelResponse( + **data["completion_response"] + ) + return {"cost": completion_cost(**data)} + + +@router.get( + "/spend/logs", + tags=["Budget & Spend Tracking"], + dependencies=[Depends(user_api_key_auth)], + responses={ + 200: {"model": List[LiteLLM_SpendLogs]}, + }, +) +async def view_spend_logs( + api_key: Optional[str] = fastapi.Query( + default=None, + description="Get spend logs based on api key", + ), + user_id: Optional[str] = fastapi.Query( + default=None, + description="Get spend logs based on user_id", + ), + request_id: Optional[str] = fastapi.Query( + default=None, + description="request_id to get spend logs for specific request_id. 
If none passed then pass spend logs for all requests", + ), + start_date: Optional[str] = fastapi.Query( + default=None, + description="Time from which to start viewing key spend", + ), + end_date: Optional[str] = fastapi.Query( + default=None, + description="Time till which to view key spend", + ), +): + """ + View all spend logs, if request_id is provided, only logs for that request_id will be returned + + Example Request for all logs + ``` + curl -X GET "http://0.0.0.0:8000/spend/logs" \ +-H "Authorization: Bearer sk-1234" + ``` + + Example Request for specific request_id + ``` + curl -X GET "http://0.0.0.0:8000/spend/logs?request_id=chatcmpl-6dcb2540-d3d7-4e49-bb27-291f863f112e" \ +-H "Authorization: Bearer sk-1234" + ``` + + Example Request for specific api_key + ``` + curl -X GET "http://0.0.0.0:8000/spend/logs?api_key=sk-Fn8Ej39NkBQmUagFEoUWPQ" \ +-H "Authorization: Bearer sk-1234" + ``` + + Example Request for specific user_id + ``` + curl -X GET "http://0.0.0.0:8000/spend/logs?user_id=ishaan@berri.ai" \ +-H "Authorization: Bearer sk-1234" + ``` + """ + from litellm.proxy.proxy_server import prisma_client + + try: + verbose_proxy_logger.debug("inside view_spend_logs") + if prisma_client is None: + raise Exception( + f"Database not connected. 
Connect a database to your proxy - https://docs.litellm.ai/docs/simple_proxy#managing-auth---virtual-keys" + ) + spend_logs = [] + if ( + start_date is not None + and isinstance(start_date, str) + and end_date is not None + and isinstance(end_date, str) + ): + # Convert the date strings to datetime objects + start_date_obj = datetime.strptime(start_date, "%Y-%m-%d") + end_date_obj = datetime.strptime(end_date, "%Y-%m-%d") + + filter_query = { + "startTime": { + "gte": start_date_obj, # Greater than or equal to Start Date + "lte": end_date_obj, # Less than or equal to End Date + } + } + + if api_key is not None and isinstance(api_key, str): + filter_query["api_key"] = api_key # type: ignore + elif request_id is not None and isinstance(request_id, str): + filter_query["request_id"] = request_id # type: ignore + elif user_id is not None and isinstance(user_id, str): + filter_query["user"] = user_id # type: ignore + + # SQL query + response = await prisma_client.db.litellm_spendlogs.group_by( + by=["api_key", "user", "model", "startTime"], + where=filter_query, # type: ignore + sum={ + "spend": True, + }, + ) + + if ( + isinstance(response, list) + and len(response) > 0 + and isinstance(response[0], dict) + ): + result: dict = {} + for record in response: + dt_object = datetime.strptime( + str(record["startTime"]), "%Y-%m-%dT%H:%M:%S.%fZ" + ) # type: ignore + date = dt_object.date() + if date not in result: + result[date] = {"users": {}, "models": {}} + api_key = record["api_key"] + user_id = record["user"] + model = record["model"] + result[date]["spend"] = ( + result[date].get("spend", 0) + record["_sum"]["spend"] + ) + result[date][api_key] = ( + result[date].get(api_key, 0) + record["_sum"]["spend"] + ) + result[date]["users"][user_id] = ( + result[date]["users"].get(user_id, 0) + record["_sum"]["spend"] + ) + result[date]["models"][model] = ( + result[date]["models"].get(model, 0) + record["_sum"]["spend"] + ) + return_list = [] + final_date = None + for k, v in 
sorted(result.items()): + return_list.append({**v, "startTime": k}) + final_date = k + + end_date_date = end_date_obj.date() + if final_date is not None and final_date < end_date_date: + current_date = final_date + timedelta(days=1) + while current_date <= end_date_date: + # Represent current_date as string because original response has it this way + return_list.append( + { + "startTime": current_date, + "spend": 0, + "users": {}, + "models": {}, + } + ) # If no data, will stay as zero + current_date += timedelta(days=1) # Move on to the next day + + return return_list + + return response + + elif api_key is not None and isinstance(api_key, str): + if api_key.startswith("sk-"): + hashed_token = prisma_client.hash_token(token=api_key) + else: + hashed_token = api_key + spend_log = await prisma_client.get_data( + table_name="spend", + query_type="find_all", + key_val={"key": "api_key", "value": hashed_token}, + ) + if isinstance(spend_log, list): + return spend_log + else: + return [spend_log] + elif request_id is not None: + spend_log = await prisma_client.get_data( + table_name="spend", + query_type="find_unique", + key_val={"key": "request_id", "value": request_id}, + ) + return [spend_log] + elif user_id is not None: + spend_log = await prisma_client.get_data( + table_name="spend", + query_type="find_all", + key_val={"key": "user", "value": user_id}, + ) + if isinstance(spend_log, list): + return spend_log + else: + return [spend_log] + else: + spend_logs = await prisma_client.get_data( + table_name="spend", query_type="find_all" + ) + + return spend_log + + return None + + except Exception as e: + if isinstance(e, HTTPException): + raise ProxyException( + message=getattr(e, "detail", f"/spend/logs Error({str(e)})"), + type="internal_error", + param=getattr(e, "param", "None"), + code=getattr(e, "status_code", status.HTTP_500_INTERNAL_SERVER_ERROR), + ) + elif isinstance(e, ProxyException): + raise e + raise ProxyException( + message="/spend/logs Error" + str(e), 
+ type="internal_error", + param=getattr(e, "param", "None"), + code=status.HTTP_500_INTERNAL_SERVER_ERROR, + ) + + +@router.post( + "/global/spend/reset", + tags=["Budget & Spend Tracking"], + dependencies=[Depends(user_api_key_auth)], +) +async def global_spend_reset(): + """ + ADMIN ONLY / MASTER KEY Only Endpoint + + Globally reset spend for All API Keys and Teams, maintain LiteLLM_SpendLogs + + 1. LiteLLM_SpendLogs will maintain the logs on spend, no data gets deleted from there + 2. LiteLLM_VerificationTokens spend will be set = 0 + 3. LiteLLM_TeamTable spend will be set = 0 + + """ + from litellm.proxy.proxy_server import prisma_client + + if prisma_client is None: + raise ProxyException( + message="Prisma Client is not initialized", + type="internal_error", + param="None", + code=status.HTTP_401_UNAUTHORIZED, + ) + + await prisma_client.db.litellm_verificationtoken.update_many( + data={"spend": 0.0}, where={} + ) + await prisma_client.db.litellm_teamtable.update_many(data={"spend": 0.0}, where={}) + + return { + "message": "Spend for all API Keys and Teams reset successfully", + "status": "success", + } + + +@router.get( + "/global/spend/logs", + tags=["Budget & Spend Tracking"], + dependencies=[Depends(user_api_key_auth)], + include_in_schema=False, +) +async def global_spend_logs( + api_key: str = fastapi.Query( + default=None, + description="API Key to get global spend (spend per day for last 30d). Admin-only endpoint", + ) +): + """ + [BETA] This is a beta endpoint. It will change. + + Use this to get global spend (spend per day for last 30d). Admin-only endpoint + + More efficient implementation of /spend/logs, by creating a view over the spend logs table. 
+ """ + from litellm.proxy.proxy_server import prisma_client + + if prisma_client is None: + raise ProxyException( + message="Prisma Client is not initialized", + type="internal_error", + param="None", + code=status.HTTP_500_INTERNAL_SERVER_ERROR, + ) + if api_key is None: + sql_query = """SELECT * FROM "MonthlyGlobalSpend" ORDER BY "date";""" + + response = await prisma_client.db.query_raw(query=sql_query) + + return response + else: + sql_query = """ + SELECT * FROM "MonthlyGlobalSpendPerKey" + WHERE "api_key" = $1 + ORDER BY "date"; + """ + + response = await prisma_client.db.query_raw(sql_query, api_key) + + return response + return + + +@router.get( + "/global/spend", + tags=["Budget & Spend Tracking"], + dependencies=[Depends(user_api_key_auth)], + include_in_schema=False, +) +async def global_spend(): + """ + [BETA] This is a beta endpoint. It will change. + + View total spend across all proxy keys + """ + from litellm.proxy.proxy_server import prisma_client + + total_spend = 0.0 + total_proxy_budget = 0.0 + + if prisma_client is None: + raise HTTPException(status_code=500, detail={"error": "No db connected"}) + sql_query = """SELECT SUM(spend) as total_spend FROM "MonthlyGlobalSpend";""" + response = await prisma_client.db.query_raw(query=sql_query) + if response is not None: + if isinstance(response, list) and len(response) > 0: + total_spend = response[0].get("total_spend", 0.0) + + return {"spend": total_spend, "max_budget": litellm.max_budget} + + +@router.get( + "/global/spend/keys", + tags=["Budget & Spend Tracking"], + dependencies=[Depends(user_api_key_auth)], + include_in_schema=False, +) +async def global_spend_keys( + limit: int = fastapi.Query( + default=None, + description="Number of keys to get. Will return Top 'n' keys.", + ) +): + """ + [BETA] This is a beta endpoint. It will change. + + Use this to get the top 'n' keys with the highest spend, ordered by spend. 
+ """ + from litellm.proxy.proxy_server import prisma_client + + if prisma_client is None: + raise HTTPException(status_code=500, detail={"error": "No db connected"}) + sql_query = f"""SELECT * FROM "Last30dKeysBySpend" LIMIT {limit};""" + + response = await prisma_client.db.query_raw(query=sql_query) + + return response + + +@router.get( + "/global/spend/teams", + tags=["Budget & Spend Tracking"], + dependencies=[Depends(user_api_key_auth)], + include_in_schema=False, +) +async def global_spend_per_team(): + """ + [BETA] This is a beta endpoint. It will change. + + Use this to get daily spend, grouped by `team_id` and `date` + """ + from litellm.proxy.proxy_server import prisma_client + + if prisma_client is None: + raise HTTPException(status_code=500, detail={"error": "No db connected"}) + sql_query = """ + SELECT + t.team_alias as team_alias, + DATE(s."startTime") AS spend_date, + SUM(s.spend) AS total_spend + FROM + "LiteLLM_SpendLogs" s + LEFT JOIN + "LiteLLM_TeamTable" t ON s.team_id = t.team_id + WHERE + s."startTime" >= CURRENT_DATE - INTERVAL '30 days' + GROUP BY + t.team_alias, + DATE(s."startTime") + ORDER BY + spend_date; + """ + response = await prisma_client.db.query_raw(query=sql_query) + + # transform the response for the Admin UI + spend_by_date = {} + team_aliases = set() + total_spend_per_team = {} + for row in response: + row_date = row["spend_date"] + if row_date is None: + continue + team_alias = row["team_alias"] + if team_alias is None: + team_alias = "Unassigned" + team_aliases.add(team_alias) + if row_date in spend_by_date: + # get the team_id for this entry + # get the spend for this entry + spend = row["total_spend"] + spend = round(spend, 2) + current_date_entries = spend_by_date[row_date] + current_date_entries[team_alias] = spend + else: + spend = row["total_spend"] + spend = round(spend, 2) + spend_by_date[row_date] = {team_alias: spend} + + if team_alias in total_spend_per_team: + total_spend_per_team[team_alias] += spend + else: + 
total_spend_per_team[team_alias] = spend + + total_spend_per_team_ui = [] + # order the elements in total_spend_per_team by spend + total_spend_per_team = dict( + sorted(total_spend_per_team.items(), key=lambda item: item[1], reverse=True) + ) + for team_id in total_spend_per_team: + # only add first 10 elements to total_spend_per_team_ui + if len(total_spend_per_team_ui) >= 10: + break + if team_id is None: + team_id = "Unassigned" + total_spend_per_team_ui.append( + {"team_id": team_id, "total_spend": total_spend_per_team[team_id]} + ) + + # sort spend_by_date by it's key (which is a date) + + response_data = [] + for key in spend_by_date: + value = spend_by_date[key] + response_data.append({"date": key, **value}) + + return { + "daily_spend": response_data, + "teams": list(team_aliases), + "total_spend_per_team": total_spend_per_team_ui, + } + + +@router.get( + "/global/all_end_users", + tags=["Budget & Spend Tracking"], + dependencies=[Depends(user_api_key_auth)], + include_in_schema=False, +) +async def global_view_all_end_users(): + """ + [BETA] This is a beta endpoint. It will change. + + Use this to just get all the unique `end_users` + """ + from litellm.proxy.proxy_server import prisma_client + + if prisma_client is None: + raise HTTPException(status_code=500, detail={"error": "No db connected"}) + + sql_query = """ + SELECT DISTINCT end_user FROM "LiteLLM_SpendLogs" + """ + + db_response = await prisma_client.db.query_raw(query=sql_query) + if db_response is None: + return [] + + _end_users = [] + for row in db_response: + _end_users.append(row["end_user"]) + + return {"end_users": _end_users} + + +@router.post( + "/global/spend/end_users", + tags=["Budget & Spend Tracking"], + dependencies=[Depends(user_api_key_auth)], + include_in_schema=False, +) +async def global_spend_end_users(data: Optional[GlobalEndUsersSpend] = None): + """ + [BETA] This is a beta endpoint. It will change. 
+ + Use this to get the top 'n' keys with the highest spend, ordered by spend. + """ + from litellm.proxy.proxy_server import prisma_client + + if prisma_client is None: + raise HTTPException(status_code=500, detail={"error": "No db connected"}) + + """ + Gets the top 100 end-users for a given api key + """ + startTime = None + endTime = None + selected_api_key = None + if data is not None: + startTime = data.startTime + endTime = data.endTime + selected_api_key = data.api_key + + startTime = startTime or datetime.now() - timedelta(days=30) + endTime = endTime or datetime.now() + + sql_query = """ +SELECT end_user, COUNT(*) AS total_count, SUM(spend) AS total_spend +FROM "LiteLLM_SpendLogs" +WHERE "startTime" >= $1::timestamp + AND "startTime" < $2::timestamp + AND ( + CASE + WHEN $3::TEXT IS NULL THEN TRUE + ELSE api_key = $3 + END + ) +GROUP BY end_user +ORDER BY total_spend DESC +LIMIT 100 + """ + response = await prisma_client.db.query_raw( + sql_query, startTime, endTime, selected_api_key + ) + + return response + + +@router.get( + "/global/spend/models", + tags=["Budget & Spend Tracking"], + dependencies=[Depends(user_api_key_auth)], + include_in_schema=False, +) +async def global_spend_models( + limit: int = fastapi.Query( + default=None, + description="Number of models to get. Will return Top 'n' models.", + ) +): + """ + [BETA] This is a beta endpoint. It will change. + + Use this to get the top 'n' keys with the highest spend, ordered by spend. 
+ """ + from litellm.proxy.proxy_server import prisma_client + + if prisma_client is None: + raise HTTPException(status_code=500, detail={"error": "No db connected"}) + + sql_query = f"""SELECT * FROM "Last30dModelsBySpend" LIMIT {limit};""" + + response = await prisma_client.db.query_raw(query=sql_query) + + return response + + +@router.post( + "/global/predict/spend/logs", + tags=["Budget & Spend Tracking"], + dependencies=[Depends(user_api_key_auth)], + include_in_schema=False, +) +async def global_predict_spend_logs(request: Request): + from enterprise.utils import _forecast_daily_cost + + data = await request.json() + data = data.get("data") + return _forecast_daily_cost(data) diff --git a/litellm/tests/test_blocked_user_list.py b/litellm/tests/test_blocked_user_list.py index 3c277a2d4..f49084d03 100644 --- a/litellm/tests/test_blocked_user_list.py +++ b/litellm/tests/test_blocked_user_list.py @@ -29,19 +29,22 @@ import pytest, logging, asyncio import litellm, asyncio from litellm.proxy.proxy_server import ( new_user, - generate_key_fn, user_api_key_auth, user_update, + user_info, + block_user, +) +from litellm.proxy.management_endpoints.key_management_endpoints import ( delete_key_fn, info_key_fn, update_key_fn, generate_key_fn, generate_key_helper_fn, +) +from litellm.proxy.spend_reporting_endpoints.spend_management_endpoints import ( spend_user_fn, spend_key_fn, view_spend_logs, - user_info, - block_user, ) from litellm.proxy.utils import PrismaClient, ProxyLogging, hash_token from litellm._logging import verbose_proxy_logger diff --git a/litellm/tests/test_jwt.py b/litellm/tests/test_jwt.py index 17de57291..a2edbcc2f 100644 --- a/litellm/tests/test_jwt.py +++ b/litellm/tests/test_jwt.py @@ -14,6 +14,7 @@ sys.path.insert( import pytest from litellm.proxy._types import LiteLLM_JWTAuth, LiteLLMRoutes from litellm.proxy.auth.handle_jwt import JWTHandler +from litellm.proxy.management_endpoints.team_endpoints import new_team from litellm.caching import 
DualCache from datetime import datetime, timedelta from fastapi import Request @@ -218,7 +219,7 @@ async def test_team_token_output(prisma_client, audience): from cryptography.hazmat.backends import default_backend from fastapi import Request from starlette.datastructures import URL - from litellm.proxy.proxy_server import user_api_key_auth, new_team + from litellm.proxy.proxy_server import user_api_key_auth from litellm.proxy._types import NewTeamRequest, UserAPIKeyAuth import litellm import uuid @@ -399,7 +400,6 @@ async def test_user_token_output( from starlette.datastructures import URL from litellm.proxy.proxy_server import ( user_api_key_auth, - new_team, new_user, user_info, ) @@ -623,7 +623,7 @@ async def test_allowed_routes_admin(prisma_client, audience): from cryptography.hazmat.backends import default_backend from fastapi import Request from starlette.datastructures import URL - from litellm.proxy.proxy_server import user_api_key_auth, new_team + from litellm.proxy.proxy_server import user_api_key_auth from litellm.proxy._types import NewTeamRequest, UserAPIKeyAuth import litellm import uuid diff --git a/litellm/tests/test_key_generate_prisma.py b/litellm/tests/test_key_generate_prisma.py index f53e241ce..cf93cfae9 100644 --- a/litellm/tests/test_key_generate_prisma.py +++ b/litellm/tests/test_key_generate_prisma.py @@ -38,22 +38,9 @@ import pytest, logging, asyncio import litellm, asyncio from litellm.proxy.proxy_server import ( new_user, - generate_key_fn, user_api_key_auth, user_update, - delete_key_fn, - info_key_fn, - update_key_fn, - generate_key_fn, - generate_key_helper_fn, - spend_user_fn, - spend_key_fn, - view_spend_logs, user_info, - team_info, - info_key_fn, - new_team, - update_team, chat_completion, completion, embeddings, @@ -63,6 +50,23 @@ from litellm.proxy.proxy_server import ( model_list, LitellmUserRoles, ) +from litellm.proxy.management_endpoints.key_management_endpoints import ( + delete_key_fn, + info_key_fn, + update_key_fn, + 
generate_key_fn, + generate_key_helper_fn, +) +from litellm.proxy.management_endpoints.team_endpoints import ( + team_info, + new_team, + update_team, +) +from litellm.proxy.spend_reporting_endpoints.spend_management_endpoints import ( + spend_user_fn, + spend_key_fn, + view_spend_logs, +) from litellm.proxy.utils import PrismaClient, ProxyLogging, hash_token, update_spend from litellm._logging import verbose_proxy_logger diff --git a/litellm/tests/test_proxy_server.py b/litellm/tests/test_proxy_server.py index 114b96872..d8bfb5229 100644 --- a/litellm/tests/test_proxy_server.py +++ b/litellm/tests/test_proxy_server.py @@ -26,7 +26,7 @@ logging.basicConfig( from fastapi.testclient import TestClient from fastapi import FastAPI from litellm.proxy.proxy_server import ( - router, + app, save_worker_config, initialize, ) # Replace with the actual module where your FastAPI router is defined @@ -119,9 +119,6 @@ def client_no_auth(fake_env_vars): config_fp = f"{filepath}/test_configs/test_config_no_auth.yaml" # initialize can get run in parallel, it sets specific variables for the fast api app, sinc eit gets run in parallel different tests use the wrong variables asyncio.run(initialize(config=config_fp, debug=True)) - app = FastAPI() - app.include_router(router) # Include your router in the test app - return TestClient(app) @@ -426,6 +423,10 @@ def test_add_new_model(client_no_auth): def test_health(client_no_auth): global headers import time + from litellm._logging import verbose_logger, verbose_proxy_logger + import logging + + verbose_proxy_logger.setLevel(logging.DEBUG) try: response = client_no_auth.get("/health") diff --git a/litellm/tests/test_update_spend.py b/litellm/tests/test_update_spend.py index 529e90e3c..d4f5662e1 100644 --- a/litellm/tests/test_update_spend.py +++ b/litellm/tests/test_update_spend.py @@ -26,19 +26,22 @@ import pytest, logging, asyncio import litellm, asyncio from litellm.proxy.proxy_server import ( new_user, - generate_key_fn, 
user_api_key_auth, user_update, + user_info, + block_user, +) +from litellm.proxy.spend_reporting_endpoints.spend_management_endpoints import ( + spend_user_fn, + spend_key_fn, + view_spend_logs, +) +from litellm.proxy.management_endpoints.key_management_endpoints import ( delete_key_fn, info_key_fn, update_key_fn, generate_key_fn, generate_key_helper_fn, - spend_user_fn, - spend_key_fn, - view_spend_logs, - user_info, - block_user, ) from litellm.proxy.utils import PrismaClient, ProxyLogging, hash_token, update_spend from litellm._logging import verbose_proxy_logger