fix(simple_proxy.md): enable setting a master key to protect proxy endpoints

Krrish Dholakia 2023-11-14 12:44:26 -08:00
parent 0779dfbbd6
commit 7ef8611952
4 changed files with 42 additions and 16 deletions

View file

@@ -2,10 +2,6 @@ import Image from '@theme/IdealImage';
 # Reliability - Fallbacks, Azure Deployments, etc.
-Prevent failed calls and slow response times with multiple deployments for API calls (E.g. multiple azure-openai deployments).
-<Image img={require('../img/multiple_deployments.png')} alt="HF_Dashboard" style={{ maxWidth: '100%', height: 'auto' }}/>
 ## Manage Multiple Deployments
 Use this if you're trying to load-balance across multiple deployments (e.g. Azure/OpenAI).

View file

@@ -483,6 +483,7 @@ The Config allows you to set the following params
 |----------------------|---------------------------------------------------------------|
 | `model_list` | List of supported models on the server, with model-specific configs |
 | `litellm_settings` | litellm Module settings, example `litellm.drop_params=True`, `litellm.set_verbose=True`, `litellm.api_base` |
+| `general_settings` | Server settings, example setting `master_key: sk-my_special_key` |

 ### Example Config
 ```yaml
@@ -499,6 +500,9 @@ model_list:
 litellm_settings:
   drop_params: True
   set_verbose: True
+general_settings:
+  master_key: sk-1234 # [OPTIONAL] Only use this if you want to require all calls to contain this key (Authorization: Bearer sk-1234)
 ```
 ### Quick Start - Config
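
Worth spelling out what this means for clients: once `master_key` is set, every call to the proxy has to carry it as a bearer token. A minimal sketch using `requests` — the proxy address and model name below are assumptions, substitute your own deployment's values:

```python
import requests

# Hypothetical local deployment; adjust host/port to wherever the proxy runs.
PROXY_BASE = "http://0.0.0.0:8000"

response = requests.post(
    f"{PROXY_BASE}/chat/completions",
    # Must match the master_key from general_settings, sent as a bearer token.
    headers={"Authorization": "Bearer sk-1234"},
    json={
        "model": "zephyr-alpha",  # any model name exposed by your model_list
        "messages": [{"role": "user", "content": "Hello!"}],
    },
)
print(response.status_code, response.json())
```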

View file

@@ -15,3 +15,6 @@ model_list:
 litellm_settings:
   drop_params: True
+general_settings:
+  master_key: sk-12345

View file

@@ -85,11 +85,12 @@ def generate_feedback_box():
 import litellm
 litellm.suppress_debug_info = True
-from fastapi import FastAPI, Request
+from fastapi import FastAPI, Request, HTTPException, status, Depends
 from fastapi.routing import APIRouter
 from fastapi.encoders import jsonable_encoder
 from fastapi.responses import StreamingResponse, FileResponse
 from fastapi.middleware.cors import CORSMiddleware
+from fastapi.security import OAuth2PasswordBearer
 import json
 import logging
@@ -121,6 +122,7 @@ config_dir = appdirs.user_config_dir("litellm")
 user_config_path = os.getenv(
     "LITELLM_CONFIG_PATH", os.path.join(config_dir, config_filename)
 )
+oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
 experimental = False
 #### GLOBAL VARIABLES ####
 llm_router: Optional[litellm.Router] = None
@@ -128,7 +130,7 @@ llm_model_list: Optional[list] = None
 server_settings: dict = {}
 log_file = "api_log.json"
 worker_config = None
+master_key = None
 #### HELPER FUNCTIONS ####
 def print_verbose(print_statement):
     global user_debug
@@ -145,6 +147,20 @@ def usage_telemetry(
         target=litellm.utils.litellm_telemetry, args=(data,), daemon=True
     ).start()

+async def user_api_key_auth(request: Request):
+    global master_key
+    if master_key is None:
+        return
+    try:
+        api_key = await oauth2_scheme(request=request)
+        if api_key == master_key:
+            return
+    except:
+        pass
+    raise HTTPException(
+        status_code=status.HTTP_401_UNAUTHORIZED,
+        detail={"error": "invalid user key"},
+    )

 def add_keys_to_config(key, value):
     # Check if file exists
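
The dependency is deliberately a no-op when no `master_key` is configured, so existing deployments keep working unchanged; wrapping the scheme in try/except also means a missing or malformed Authorization header falls through to the proxy's own 401 shape rather than FastAPI's default. A self-contained sketch of the same pattern — a hypothetical toy app, not the proxy's actual module layout — shows how `OAuth2PasswordBearer` extracts the bearer token and how the dependency gates a route:

```python
from fastapi import Depends, FastAPI, HTTPException, Request, status
from fastapi.security import OAuth2PasswordBearer
from fastapi.testclient import TestClient

app = FastAPI()
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
master_key = "sk-1234"  # stands in for the value read from general_settings

async def user_api_key_auth(request: Request):
    if master_key is None:
        return  # no key configured -> auth is a no-op
    try:
        # OAuth2PasswordBearer pulls the token out of "Authorization: Bearer <token>"
        api_key = await oauth2_scheme(request=request)
        if api_key == master_key:
            return
    except Exception:
        pass  # missing/malformed header falls through to the 401 below
    raise HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail={"error": "invalid user key"},
    )

@app.get("/models", dependencies=[Depends(user_api_key_auth)])
def model_list():
    return {"object": "list", "data": []}

client = TestClient(app)
assert client.get("/models").status_code == 401  # no token
assert client.get("/models", headers={"Authorization": "Bearer wrong"}).status_code == 401
assert client.get("/models", headers={"Authorization": "Bearer sk-1234"}).status_code == 200
```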
@@ -198,6 +214,7 @@ def save_params_to_config(data: dict):

 def load_router_config(router: Optional[litellm.Router], config_file_path: str):
+    global master_key
     config = {}
     server_settings: dict = {}
     try:
@@ -210,6 +227,12 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str):
         raise Exception(f"Exception while reading Config: {e}")
     print_verbose(f"Configs passed in, loaded config YAML\n{config}")

+    ## GENERAL SERVER SETTINGS (e.g. master key,..)
+    general_settings = config.get("general_settings", None)
+    if general_settings:
+        master_key = general_settings.get("master_key", None)
+
     ## LITELLM MODULE SETTINGS (e.g. litellm.drop_params=True,..)
     litellm_settings = config.get('litellm_settings', None)
     if litellm_settings:
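
For clarity, the new branch just lifts one optional key out of the parsed YAML into the module-level global. A standalone sketch of that read path, using an inline stand-in for the config file the proxy reads via `config_file_path`:

```python
import yaml

# Inline stand-in for a proxy config file.
raw = """
litellm_settings:
  drop_params: True
general_settings:
  master_key: sk-1234
"""
config = yaml.safe_load(raw)

# A missing section degrades to None, so a config without general_settings
# leaves master_key unset and the auth dependency stays a no-op.
general_settings = config.get("general_settings", None)
master_key = general_settings.get("master_key", None) if general_settings else None
print(master_key)  # -> sk-1234
```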
@@ -460,8 +483,8 @@ def startup_event():
     # print(f"\033[32mWorker Initialized\033[0m\n")

 #### API ENDPOINTS ####
-@router.get("/v1/models")
-@router.get("/models") # if project requires model list
+@router.get("/v1/models", dependencies=[Depends(user_api_key_auth)])
+@router.get("/models", dependencies=[Depends(user_api_key_auth)]) # if project requires model list
 def model_list():
     global llm_model_list, server_settings
     all_models = []
@@ -492,9 +515,9 @@ def model_list():
         object="list",
     )

-@router.post("/v1/completions")
-@router.post("/completions")
-@router.post("/engines/{model:path}/completions")
+@router.post("/v1/completions", dependencies=[Depends(user_api_key_auth)])
+@router.post("/completions", dependencies=[Depends(user_api_key_auth)])
+@router.post("/engines/{model:path}/completions", dependencies=[Depends(user_api_key_auth)])
 async def completion(request: Request, model: Optional[str] = None):
     try:
         body = await request.body()
@@ -523,9 +546,9 @@ async def completion(request: Request, model: Optional[str] = None):
         return {"error": error_msg}

-@router.post("/v1/chat/completions")
-@router.post("/chat/completions")
-@router.post("/openai/deployments/{model:path}/chat/completions") # azure compatible endpoint
+@router.post("/v1/chat/completions", dependencies=[Depends(user_api_key_auth)])
+@router.post("/chat/completions", dependencies=[Depends(user_api_key_auth)])
+@router.post("/openai/deployments/{model:path}/chat/completions", dependencies=[Depends(user_api_key_auth)]) # azure compatible endpoint
 async def chat_completion(request: Request, model: Optional[str] = None):
     global server_settings
     try:
@@ -552,7 +575,7 @@ async def chat_completion(request: Request, model: Optional[str] = None):
         return {"error": error_msg}

-@router.post("/router/chat/completions")
+@router.post("/router/chat/completions", dependencies=[Depends(user_api_key_auth)])
 async def router_completion(request: Request):
     try:
         body = await request.body()
@@ -568,7 +591,7 @@ async def router_completion(request: Request):
         error_msg = f"{str(e)}\n\n{error_traceback}"
         return {"error": error_msg}

-@router.get("/ollama_logs")
+@router.get("/ollama_logs", dependencies=[Depends(user_api_key_auth)])
 async def retrieve_server_log(request: Request):
     filepath = os.path.expanduser("~/.ollama/logs/server.log")
     return FileResponse(filepath)
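
With every route now declaring `dependencies=[Depends(user_api_key_auth)]`, a quick end-to-end check of the protection might look like this (assuming a local proxy started with the example config above):

```python
import requests

PROXY_BASE = "http://0.0.0.0:8000"  # hypothetical local deployment

# No key -> the auth dependency raises 401 {"error": "invalid user key"}.
unauthorized = requests.get(f"{PROXY_BASE}/models")
print(unauthorized.status_code)  # 401

# Correct master key as a bearer token -> 200 and the model list.
authorized = requests.get(
    f"{PROXY_BASE}/models",
    headers={"Authorization": "Bearer sk-1234"},
)
print(authorized.status_code)  # 200
```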