diff --git a/docs/my-website/docs/routing.md b/docs/my-website/docs/routing.md
index ac5dcbe29e..20dfdcdb82 100644
--- a/docs/my-website/docs/routing.md
+++ b/docs/my-website/docs/routing.md
@@ -2,10 +2,6 @@ import Image from '@theme/IdealImage';
 
 # Reliability - Fallbacks, Azure Deployments, etc.
 
-Prevent failed calls and slow response times with multiple deployments for API calls (E.g. multiple azure-openai deployments).
-
-HF_Dashboard
-
 ## Manage Multiple Deployments
 
 Use this if you're trying to load-balance across multiple deployments (e.g. Azure/OpenAI).
diff --git a/docs/my-website/docs/simple_proxy.md b/docs/my-website/docs/simple_proxy.md
index 77977cd755..a4ae6267d2 100644
--- a/docs/my-website/docs/simple_proxy.md
+++ b/docs/my-website/docs/simple_proxy.md
@@ -483,6 +483,7 @@ The Config allows you to set the following params
 |----------------------|---------------------------------------------------------------|
 | `model_list` | List of supported models on the server, with model-specific configs |
 | `litellm_settings` | litellm Module settings, example `litellm.drop_params=True`, `litellm.set_verbose=True`, `litellm.api_base` |
+| `general_settings` | Server settings, example setting `master_key: sk-my_special_key` |
 
 ### Example Config
 ```yaml
@@ -499,6 +500,9 @@ model_list:
 litellm_settings:
   drop_params: True
   set_verbose: True
+
+general_settings:
+  master_key: sk-1234 # [OPTIONAL] Only use this if you want to require all calls to contain this key (Authorization: Bearer sk-1234)
 ```
 
 ### Quick Start - Config
diff --git a/litellm/proxy/config.yaml b/litellm/proxy/config.yaml
index 06215ef1d0..333c59719c 100644
--- a/litellm/proxy/config.yaml
+++ b/litellm/proxy/config.yaml
@@ -15,3 +15,6 @@ model_list:
 
 litellm_settings:
   drop_params: True
+
+general_settings:
+  master_key: sk-12345
diff --git a/litellm/proxy/proxy_server.py b/litellm/proxy/proxy_server.py
index 7ea0da58dd..61c2473939 100644
--- a/litellm/proxy/proxy_server.py
+++ b/litellm/proxy/proxy_server.py
@@ -85,11 +85,12 @@ def generate_feedback_box():
 
 import litellm
 
 litellm.suppress_debug_info = True
-from fastapi import FastAPI, Request
+from fastapi import FastAPI, Request, HTTPException, status, Depends
 from fastapi.routing import APIRouter
 from fastapi.encoders import jsonable_encoder
 from fastapi.responses import StreamingResponse, FileResponse
 from fastapi.middleware.cors import CORSMiddleware
+from fastapi.security import OAuth2PasswordBearer
 import json
 import logging
@@ -121,6 +122,7 @@ config_dir = appdirs.user_config_dir("litellm")
 user_config_path = os.getenv(
     "LITELLM_CONFIG_PATH", os.path.join(config_dir, config_filename)
 )
+oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
 experimental = False
 #### GLOBAL VARIABLES ####
 llm_router: Optional[litellm.Router] = None
@@ -128,7 +130,7 @@ llm_model_list: Optional[list] = None
 server_settings: dict = {}
 log_file = "api_log.json"
 worker_config = None
-
+master_key = None
 #### HELPER FUNCTIONS ####
 def print_verbose(print_statement):
     global user_debug
@@ -145,6 +147,20 @@ def usage_telemetry(
         target=litellm.utils.litellm_telemetry, args=(data,), daemon=True
     ).start()
 
+async def user_api_key_auth(request: Request):
+    global master_key
+    if master_key is None:
+        return
+    try:
+        api_key = await oauth2_scheme(request=request)
+        if api_key == master_key:
+            return
+    except:
+        pass
+    raise HTTPException(
+        status_code=status.HTTP_401_UNAUTHORIZED,
+        detail={"error": "invalid user key"},
+    )
 
 def add_keys_to_config(key, value):
     # Check if file exists
@@ -198,6 +214,7 @@ def save_params_to_config(data: dict):
 
 
 def load_router_config(router: Optional[litellm.Router], config_file_path: str):
+    global master_key
     config = {}
     server_settings: dict = {}
     try:
@@ -210,6 +227,12 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str)
         raise Exception(f"Exception while reading Config: {e}")
 
     print_verbose(f"Configs passed in, loaded config YAML\n{config}")
+
+    ## GENERAL SERVER SETTINGS (e.g. master key,..)
+    general_settings = config.get("general_settings", None)
+    if general_settings:
+        master_key = general_settings.get("master_key", None)
+
     ## LITELLM MODULE SETTINGS (e.g. litellm.drop_params=True,..)
     litellm_settings = config.get('litellm_settings', None)
     if litellm_settings:
@@ -460,8 +483,8 @@ def startup_event():
     # print(f"\033[32mWorker Initialized\033[0m\n")
 
 #### API ENDPOINTS ####
-@router.get("/v1/models")
-@router.get("/models") # if project requires model list
+@router.get("/v1/models", dependencies=[Depends(user_api_key_auth)])
+@router.get("/models", dependencies=[Depends(user_api_key_auth)]) # if project requires model list
 def model_list():
     global llm_model_list, server_settings
     all_models = []
@@ -492,9 +515,9 @@ def model_list():
         object="list",
     )
 
-@router.post("/v1/completions")
-@router.post("/completions")
-@router.post("/engines/{model:path}/completions")
+@router.post("/v1/completions", dependencies=[Depends(user_api_key_auth)])
+@router.post("/completions", dependencies=[Depends(user_api_key_auth)])
+@router.post("/engines/{model:path}/completions", dependencies=[Depends(user_api_key_auth)])
 async def completion(request: Request, model: Optional[str] = None):
     try:
         body = await request.body()
@@ -523,9 +546,9 @@ async def completion(request: Request, model: Optional[str] = None):
         return {"error": error_msg}
 
 
-@router.post("/v1/chat/completions")
-@router.post("/chat/completions")
-@router.post("/openai/deployments/{model:path}/chat/completions") # azure compatible endpoint
+@router.post("/v1/chat/completions", dependencies=[Depends(user_api_key_auth)])
+@router.post("/chat/completions", dependencies=[Depends(user_api_key_auth)])
+@router.post("/openai/deployments/{model:path}/chat/completions", dependencies=[Depends(user_api_key_auth)]) # azure compatible endpoint
 async def chat_completion(request: Request, model: Optional[str] = None):
     global server_settings
     try:
@@ -552,7 +575,7 @@ async def chat_completion(request: Request, model: Optional[str] = None):
         return {"error": error_msg}
 
 
-@router.post("/router/chat/completions")
+@router.post("/router/chat/completions", dependencies=[Depends(user_api_key_auth)])
 async def router_completion(request: Request):
     try:
         body = await request.body()
@@ -568,7 +591,7 @@ async def router_completion(request: Request):
         error_msg = f"{str(e)}\n\n{error_traceback}"
         return {"error": error_msg}
 
-@router.get("/ollama_logs")
+@router.get("/ollama_logs", dependencies=[Depends(user_api_key_auth)])
 async def retrieve_server_log(request: Request):
     filepath = os.path.expanduser("~/.ollama/logs/server.log")
     return FileResponse(filepath)
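A minimal smoke test for the auth flow this diff adds, assuming the proxy is running locally; the base URL, port, model name, and keys below are illustrative assumptions, not values taken from this diff, so adjust them to your config. With `general_settings.master_key` set, `user_api_key_auth` accepts only requests whose `Authorization: Bearer <key>` matches the master key and raises a 401 otherwise; when no master key is configured it returns immediately, so existing deployments keep their current open behavior.

```python
# Hypothetical client-side check of the master-key auth (sketch, not part of this diff).
# Assumes the proxy listens on http://0.0.0.0:8000 and was started with a config
# containing general_settings.master_key: sk-1234 (both are assumptions).
import requests

BASE_URL = "http://0.0.0.0:8000"  # assumed proxy address
payload = {
    "model": "gpt-3.5-turbo",  # illustrative model name
    "messages": [{"role": "user", "content": "Hello!"}],
}

# Matching key -> passes user_api_key_auth and reaches the routed endpoint
ok = requests.post(
    f"{BASE_URL}/chat/completions",
    headers={"Authorization": "Bearer sk-1234"},
    json=payload,
)
print(ok.status_code)  # expect 200

# Non-matching key -> HTTPException(401) from user_api_key_auth
bad = requests.post(
    f"{BASE_URL}/chat/completions",
    headers={"Authorization": "Bearer sk-wrong"},
    json=payload,
)
print(bad.status_code, bad.json())  # expect 401 {'detail': {'error': 'invalid user key'}}
```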