fix(simple_proxy.md): enable setting a master key to protect proxy endpoints

Krrish Dholakia 2023-11-14 12:44:26 -08:00
parent 1bb99af134
commit 1283f345dc
4 changed files with 42 additions and 16 deletions


@@ -85,11 +85,12 @@ def generate_feedback_box():
 import litellm
 litellm.suppress_debug_info = True
-from fastapi import FastAPI, Request
+from fastapi import FastAPI, Request, HTTPException, status, Depends
 from fastapi.routing import APIRouter
 from fastapi.encoders import jsonable_encoder
 from fastapi.responses import StreamingResponse, FileResponse
 from fastapi.middleware.cors import CORSMiddleware
+from fastapi.security import OAuth2PasswordBearer
 import json
 import logging
@@ -121,6 +122,7 @@ config_dir = appdirs.user_config_dir("litellm")
 user_config_path = os.getenv(
     "LITELLM_CONFIG_PATH", os.path.join(config_dir, config_filename)
 )
+oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
 experimental = False
 #### GLOBAL VARIABLES ####
 llm_router: Optional[litellm.Router] = None
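Note on the new scheme: OAuth2PasswordBearer is a callable dependency that pulls the token out of an incoming "Authorization: Bearer <token>" header; the tokenUrl argument only affects the generated OpenAPI docs. When the header is missing it raises its own 401, which is why the auth helper below wraps the call in try/except. A minimal sketch, not part of this commit (the /echo-token route and the sk-1234 value are made up for illustration):

from fastapi import FastAPI, Request
from fastapi.security import OAuth2PasswordBearer
from fastapi.testclient import TestClient

app = FastAPI()
scheme = OAuth2PasswordBearer(tokenUrl="token")

@app.get("/echo-token")  # hypothetical route, illustration only
async def echo_token(request: Request):
    # hands back just the token portion, e.g. "sk-1234" from "Bearer sk-1234"
    return {"token": await scheme(request=request)}

client = TestClient(app)
print(client.get("/echo-token", headers={"Authorization": "Bearer sk-1234"}).json())  # {'token': 'sk-1234'}
print(client.get("/echo-token").status_code)  # 401 when no Authorization header is sent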
@@ -128,7 +130,7 @@ llm_model_list: Optional[list] = None
 server_settings: dict = {}
 log_file = "api_log.json"
 worker_config = None
+master_key = None
 #### HELPER FUNCTIONS ####
 def print_verbose(print_statement):
     global user_debug
@@ -145,6 +147,20 @@ def usage_telemetry(
             target=litellm.utils.litellm_telemetry, args=(data,), daemon=True
         ).start()
+async def user_api_key_auth(request: Request):
+    global master_key
+    if master_key is None:
+        return
+    try:
+        api_key = await oauth2_scheme(request=request)
+        if api_key == master_key:
+            return
+    except:
+        pass
+    raise HTTPException(
+        status_code=status.HTTP_401_UNAUTHORIZED,
+        detail={"error": "invalid user key"},
+    )
 def add_keys_to_config(key, value):
     # Check if file exists
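A stand-alone sketch of how this dependency behaves once wired into a route (not part of the diff; the /ping route, the sk-1234 key, and the client calls are made up for illustration): with a master key configured, only a matching Bearer token gets through; with master_key left as None the check is a no-op.

from fastapi import Depends, FastAPI, HTTPException, Request, status
from fastapi.security import OAuth2PasswordBearer
from fastapi.testclient import TestClient

master_key = "sk-1234"  # hypothetical value; the proxy reads it from general_settings.master_key
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")

async def user_api_key_auth(request: Request):
    if master_key is None:         # no key configured -> auth is effectively disabled
        return
    try:
        api_key = await oauth2_scheme(request=request)
        if api_key == master_key:  # exact match against the configured master key
            return
    except Exception:
        pass
    raise HTTPException(
        status_code=status.HTTP_401_UNAUTHORIZED,
        detail={"error": "invalid user key"},
    )

app = FastAPI()

@app.get("/ping", dependencies=[Depends(user_api_key_auth)])  # hypothetical protected route
def ping():
    return {"ok": True}

client = TestClient(app)
print(client.get("/ping").status_code)                                               # 401, no token
print(client.get("/ping", headers={"Authorization": "Bearer wrong"}).status_code)    # 401, wrong token
print(client.get("/ping", headers={"Authorization": "Bearer sk-1234"}).status_code)  # 200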
@@ -198,6 +214,7 @@ def save_params_to_config(data: dict):
 def load_router_config(router: Optional[litellm.Router], config_file_path: str):
+    global master_key
     config = {}
     server_settings: dict = {}
     try:
@@ -210,6 +227,12 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str):
         raise Exception(f"Exception while reading Config: {e}")
     print_verbose(f"Configs passed in, loaded config YAML\n{config}")
+    ## GENERAL SERVER SETTINGS (e.g. master key,..)
+    general_settings = config.get("general_settings", None)
+    if general_settings:
+        master_key = general_settings.get("master_key", None)
     ## LITELLM MODULE SETTINGS (e.g. litellm.drop_params=True,..)
     litellm_settings = config.get('litellm_settings', None)
     if litellm_settings:
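The master key itself comes from the proxy's config YAML. A small sketch, not part of the diff, of the lookup these new lines perform (assumes PyYAML is available and that general_settings carries a made-up key; the litellm_settings block mirrors the comment above):

import yaml  # PyYAML; the proxy loads its config file the same way

raw_config = """
general_settings:
  master_key: sk-1234   # hypothetical key; callers must send it as a Bearer token
litellm_settings:
  drop_params: true
"""

config = yaml.safe_load(raw_config)

master_key = None
general_settings = config.get("general_settings", None)
if general_settings:
    master_key = general_settings.get("master_key", None)

print(master_key)  # sk-1234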
@@ -460,8 +483,8 @@ def startup_event():
     # print(f"\033[32mWorker Initialized\033[0m\n")
 #### API ENDPOINTS ####
-@router.get("/v1/models")
-@router.get("/models") # if project requires model list
+@router.get("/v1/models", dependencies=[Depends(user_api_key_auth)])
+@router.get("/models", dependencies=[Depends(user_api_key_auth)]) # if project requires model list
 def model_list():
     global llm_model_list, server_settings
     all_models = []
@@ -492,9 +515,9 @@ def model_list():
         object="list",
     )
-@router.post("/v1/completions")
-@router.post("/completions")
-@router.post("/engines/{model:path}/completions")
+@router.post("/v1/completions", dependencies=[Depends(user_api_key_auth)])
+@router.post("/completions", dependencies=[Depends(user_api_key_auth)])
+@router.post("/engines/{model:path}/completions", dependencies=[Depends(user_api_key_auth)])
 async def completion(request: Request, model: Optional[str] = None):
     try:
         body = await request.body()
@@ -523,9 +546,9 @@ async def completion(request: Request, model: Optional[str] = None):
         return {"error": error_msg}
-@router.post("/v1/chat/completions")
-@router.post("/chat/completions")
-@router.post("/openai/deployments/{model:path}/chat/completions") # azure compatible endpoint
+@router.post("/v1/chat/completions", dependencies=[Depends(user_api_key_auth)])
+@router.post("/chat/completions", dependencies=[Depends(user_api_key_auth)])
+@router.post("/openai/deployments/{model:path}/chat/completions", dependencies=[Depends(user_api_key_auth)]) # azure compatible endpoint
 async def chat_completion(request: Request, model: Optional[str] = None):
     global server_settings
     try:
@@ -552,7 +575,7 @@ async def chat_completion(request: Request, model: Optional[str] = None):
         return {"error": error_msg}
-@router.post("/router/chat/completions")
+@router.post("/router/chat/completions", dependencies=[Depends(user_api_key_auth)])
 async def router_completion(request: Request):
     try:
         body = await request.body()
@@ -568,7 +591,7 @@ async def router_completion(request: Request):
         error_msg = f"{str(e)}\n\n{error_traceback}"
         return {"error": error_msg}
-@router.get("/ollama_logs")
+@router.get("/ollama_logs", dependencies=[Depends(user_api_key_auth)])
 async def retrieve_server_log(request: Request):
     filepath = os.path.expanduser("~/.ollama/logs/server.log")
     return FileResponse(filepath)
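End to end, a caller now has to present the master key as a Bearer token on every protected route. A hedged sketch of such a call, not part of the diff (assumes the proxy is running locally on port 8000, that sk-1234 is the configured master key, and that gpt-3.5-turbo is a model name defined in the proxy config):

import requests

resp = requests.post(
    "http://0.0.0.0:8000/chat/completions",       # assumed local proxy address
    headers={"Authorization": "Bearer sk-1234"},  # the configured master key
    json={
        "model": "gpt-3.5-turbo",                 # any model name from the proxy config
        "messages": [{"role": "user", "content": "hello"}],
    },
)
print(resp.status_code)  # 401 if the key is wrong or missing, 200 otherwise
print(resp.json())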