Mirror of https://github.com/BerriAI/litellm.git, synced 2025-04-26 03:04:13 +00:00
fix(simple_proxy.md): enable setting a master key to protect proxy endpoints
commit 7ef8611952
parent 0779dfbbd6
4 changed files with 42 additions and 16 deletions
@@ -85,11 +85,12 @@ def generate_feedback_box():
 import litellm
 litellm.suppress_debug_info = True
-from fastapi import FastAPI, Request
+from fastapi import FastAPI, Request, HTTPException, status, Depends
 from fastapi.routing import APIRouter
 from fastapi.encoders import jsonable_encoder
 from fastapi.responses import StreamingResponse, FileResponse
 from fastapi.middleware.cors import CORSMiddleware
+from fastapi.security import OAuth2PasswordBearer
 import json
 import logging
@@ -121,6 +122,7 @@ config_dir = appdirs.user_config_dir("litellm")
 user_config_path = os.getenv(
     "LITELLM_CONFIG_PATH", os.path.join(config_dir, config_filename)
 )
+oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
 experimental = False
 #### GLOBAL VARIABLES ####
 llm_router: Optional[litellm.Router] = None
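
The oauth2_scheme object added above is FastAPI's stock bearer-token dependency: when invoked for a request it returns the token from an "Authorization: Bearer <token>" header, and raises a 401 if the header is absent. A minimal sketch, not part of this commit (the app and route below are made up for illustration):

    from fastapi import Depends, FastAPI
    from fastapi.security import OAuth2PasswordBearer

    app = FastAPI()
    oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")

    @app.get("/whoami")
    async def whoami(token: str = Depends(oauth2_scheme)):
        # With the header "Authorization: Bearer sk-1234", token == "sk-1234".
        return {"token": token}
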
@@ -128,7 +130,7 @@ llm_model_list: Optional[list] = None
 server_settings: dict = {}
 log_file = "api_log.json"
 worker_config = None
-
+master_key = None
 #### HELPER FUNCTIONS ####
 def print_verbose(print_statement):
     global user_debug
@@ -145,6 +147,20 @@ def usage_telemetry(
             target=litellm.utils.litellm_telemetry, args=(data,), daemon=True
         ).start()
 
+async def user_api_key_auth(request: Request):
+    global master_key
+    if master_key is None:
+        return
+    try:
+        api_key = await oauth2_scheme(request=request)
+        if api_key == master_key:
+            return
+    except:
+        pass
+    raise HTTPException(
+        status_code=status.HTTP_401_UNAUTHORIZED,
+        detail={"error": "invalid user key"},
+    )
+
 def add_keys_to_config(key, value):
     # Check if file exists
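
The new dependency is a no-op when no master key is configured, so existing deployments keep working unchanged; once a master key is set, any guarded route only responds when the exact key arrives as a bearer token. A self-contained sketch, assuming a toy app and the made-up key sk-1234, that exercises the same logic with FastAPI's TestClient:

    from fastapi import Depends, FastAPI, HTTPException, Request, status
    from fastapi.security import OAuth2PasswordBearer
    from fastapi.testclient import TestClient

    app = FastAPI()
    oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
    master_key = "sk-1234"  # made-up key; None would disable the check entirely

    async def user_api_key_auth(request: Request):
        # Same logic as the dependency added in this commit.
        if master_key is None:
            return
        try:
            api_key = await oauth2_scheme(request=request)
            if api_key == master_key:
                return
        except Exception:
            pass
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED,
            detail={"error": "invalid user key"},
        )

    @app.get("/models", dependencies=[Depends(user_api_key_auth)])
    def model_list():
        return {"data": []}

    client = TestClient(app)
    print(client.get("/models").status_code)                                            # 401, no key sent
    print(client.get("/models", headers={"Authorization": "Bearer wrong"}).status_code)  # 401
    print(client.get("/models", headers={"Authorization": "Bearer sk-1234"}).status_code)  # 200
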
@@ -198,6 +214,7 @@ def save_params_to_config(data: dict):
 
 
 def load_router_config(router: Optional[litellm.Router], config_file_path: str):
+    global master_key
     config = {}
     server_settings: dict = {}
     try:
@@ -210,6 +227,12 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str):
         raise Exception(f"Exception while reading Config: {e}")
 
     print_verbose(f"Configs passed in, loaded config YAML\n{config}")
+
+    ## GENERAL SERVER SETTINGS (e.g. master key,..)
+    general_settings = config.get("general_settings", None)
+    if general_settings:
+        master_key = general_settings.get("master_key", None)
+
     ## LITELLM MODULE SETTINGS (e.g. litellm.drop_params=True,..)
     litellm_settings = config.get('litellm_settings', None)
     if litellm_settings:
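
load_router_config now looks for a general_settings block and lifts master_key out of it. A sketch of a YAML config that would exercise this path, parsed the same way (the key value is a placeholder, and drop_params is just the example the litellm_settings comment above mentions):

    import yaml

    example_config = """
    general_settings:
      master_key: sk-1234          # placeholder value
    litellm_settings:
      drop_params: true
    """

    config = yaml.safe_load(example_config)
    general_settings = config.get("general_settings", None)
    master_key = general_settings.get("master_key", None) if general_settings else None
    print(master_key)  # -> sk-1234
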
@@ -460,8 +483,8 @@ def startup_event():
     # print(f"\033[32mWorker Initialized\033[0m\n")
 
 #### API ENDPOINTS ####
-@router.get("/v1/models")
-@router.get("/models") # if project requires model list
+@router.get("/v1/models", dependencies=[Depends(user_api_key_auth)])
+@router.get("/models", dependencies=[Depends(user_api_key_auth)]) # if project requires model list
 def model_list():
     global llm_model_list, server_settings
     all_models = []
@@ -492,9 +515,9 @@ def model_list():
         object="list",
     )
 
-@router.post("/v1/completions")
-@router.post("/completions")
-@router.post("/engines/{model:path}/completions")
+@router.post("/v1/completions", dependencies=[Depends(user_api_key_auth)])
+@router.post("/completions", dependencies=[Depends(user_api_key_auth)])
+@router.post("/engines/{model:path}/completions", dependencies=[Depends(user_api_key_auth)])
 async def completion(request: Request, model: Optional[str] = None):
     try:
         body = await request.body()
@@ -523,9 +546,9 @@ async def completion(request: Request, model: Optional[str] = None):
         return {"error": error_msg}
 
 
-@router.post("/v1/chat/completions")
-@router.post("/chat/completions")
-@router.post("/openai/deployments/{model:path}/chat/completions") # azure compatible endpoint
+@router.post("/v1/chat/completions", dependencies=[Depends(user_api_key_auth)])
+@router.post("/chat/completions", dependencies=[Depends(user_api_key_auth)])
+@router.post("/openai/deployments/{model:path}/chat/completions", dependencies=[Depends(user_api_key_auth)]) # azure compatible endpoint
 async def chat_completion(request: Request, model: Optional[str] = None):
     global server_settings
     try:
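
With a master key configured, clients must send it as a bearer token on every call to these routes. A usage sketch against the chat completions endpoint; the host, port, key, and model name are assumptions, not values from the commit:

    import requests

    response = requests.post(
        "http://0.0.0.0:8000/chat/completions",        # assumed proxy host and port
        headers={"Authorization": "Bearer sk-1234"},   # the configured master key
        json={
            "model": "gpt-3.5-turbo",
            "messages": [{"role": "user", "content": "hello"}],
        },
    )
    # A wrong or missing key returns 401 with {"error": "invalid user key"}.
    print(response.status_code, response.json())
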
@@ -552,7 +575,7 @@ async def chat_completion(request: Request, model: Optional[str] = None):
         return {"error": error_msg}
 
 
-@router.post("/router/chat/completions")
+@router.post("/router/chat/completions", dependencies=[Depends(user_api_key_auth)])
 async def router_completion(request: Request):
     try:
         body = await request.body()
@@ -568,7 +591,7 @@ async def router_completion(request: Request):
         error_msg = f"{str(e)}\n\n{error_traceback}"
         return {"error": error_msg}
 
-@router.get("/ollama_logs")
+@router.get("/ollama_logs", dependencies=[Depends(user_api_key_auth)])
 async def retrieve_server_log(request: Request):
     filepath = os.path.expanduser("~/.ollama/logs/server.log")
     return FileResponse(filepath)