forked from phoenix/litellm-mirror

fix(simple_proxy.md): enable setting a master key to protect proxy endpoints

This commit is contained in: parent 0779dfbbd6, commit 7ef8611952
4 changed files with 42 additions and 16 deletions
@@ -2,10 +2,6 @@ import Image from '@theme/IdealImage';
 # Reliability - Fallbacks, Azure Deployments, etc.
 
-Prevent failed calls and slow response times with multiple deployments for API calls (E.g. multiple azure-openai deployments).
-
-<Image img={require('../img/multiple_deployments.png')} alt="HF_Dashboard" style={{ maxWidth: '100%', height: 'auto' }}/>
-
 ## Manage Multiple Deployments
 
 Use this if you're trying to load-balance across multiple deployments (e.g. Azure/OpenAI).
 
@@ -483,6 +483,7 @@ The Config allows you to set the following params
 |----------------------|---------------------------------------------------------------|
 | `model_list` | List of supported models on the server, with model-specific configs |
 | `litellm_settings` | litellm Module settings, example `litellm.drop_params=True`, `litellm.set_verbose=True`, `litellm.api_base` |
+| `general_settings` | Server settings, example setting `master_key: sk-my_special_key` |
 
 ### Example Config
 ```yaml
@@ -499,6 +500,9 @@ model_list:
 litellm_settings:
   drop_params: True
   set_verbose: True
+
+general_settings:
+  master_key: sk-1234 # [OPTIONAL] Only use this if you want to require all calls to contain this key (Authorization: Bearer sk-1234)
 ```
 
 ### Quick Start - Config
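For context, once a `master_key` is set under `general_settings`, callers must send it as a bearer token on every protected endpoint. A minimal sketch of such a call, assuming the proxy from this commit is running locally on port 8000 (URL, port, and model name are illustrative):

```python
import requests

# Assumes the proxy is running locally with `master_key: sk-1234`
# configured under general_settings; all names here are illustrative.
response = requests.post(
    "http://0.0.0.0:8000/chat/completions",
    headers={"Authorization": "Bearer sk-1234"},  # must match the master key
    json={
        "model": "gpt-3.5-turbo",
        "messages": [{"role": "user", "content": "hello"}],
    },
)
print(response.status_code)  # 401 with a missing or wrong key, 200 otherwise
```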
@@ -15,3 +15,6 @@ model_list:
 litellm_settings:
   drop_params: True
 
+general_settings:
+  master_key: sk-12345
+
@@ -85,11 +85,12 @@ def generate_feedback_box():
 
 import litellm
 litellm.suppress_debug_info = True
-from fastapi import FastAPI, Request
+from fastapi import FastAPI, Request, HTTPException, status, Depends
 from fastapi.routing import APIRouter
 from fastapi.encoders import jsonable_encoder
 from fastapi.responses import StreamingResponse, FileResponse
 from fastapi.middleware.cors import CORSMiddleware
+from fastapi.security import OAuth2PasswordBearer
 import json
 import logging
 
@@ -121,6 +122,7 @@ config_dir = appdirs.user_config_dir("litellm")
 user_config_path = os.getenv(
     "LITELLM_CONFIG_PATH", os.path.join(config_dir, config_filename)
 )
+oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
 experimental = False
 #### GLOBAL VARIABLES ####
 llm_router: Optional[litellm.Router] = None
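`OAuth2PasswordBearer` is the FastAPI helper the new auth check leans on: awaiting it with a request extracts the token from the `Authorization: Bearer <token>` header and raises a 401 when the header is missing or malformed. A simplified sketch of that behavior, not the actual FastAPI source:

```python
from fastapi import HTTPException, Request, status

# Rough illustration of what `await oauth2_scheme(request=request)` yields;
# the real logic lives in fastapi.security.OAuth2PasswordBearer.__call__.
async def extract_bearer_token(request: Request) -> str:
    authorization = request.headers.get("Authorization", "")
    scheme, _, token = authorization.partition(" ")
    if scheme.lower() != "bearer" or not token:
        raise HTTPException(
            status_code=status.HTTP_401_UNAUTHORIZED, detail="Not authenticated"
        )
    return token
```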
@@ -128,7 +130,7 @@ llm_model_list: Optional[list] = None
 server_settings: dict = {}
 log_file = "api_log.json"
 worker_config = None
-
+master_key = None
 #### HELPER FUNCTIONS ####
 def print_verbose(print_statement):
     global user_debug
@@ -145,6 +147,20 @@ def usage_telemetry(
         target=litellm.utils.litellm_telemetry, args=(data,), daemon=True
     ).start()
 
+async def user_api_key_auth(request: Request):
+    global master_key
+    if master_key is None:
+        return
+    try:
+        api_key = await oauth2_scheme(request=request)
+        if api_key == master_key:
+            return
+    except:
+        pass
+    raise HTTPException(
+        status_code=status.HTTP_401_UNAUTHORIZED,
+        detail={"error": "invalid user key"},
+    )
 
 def add_keys_to_config(key, value):
     # Check if file exists
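A quick way to sanity-check this dependency end to end, assuming `user_api_key_auth` is importable from the proxy module and the loaded config set `master_key: sk-1234` (the route name and keys are hypothetical):

```python
from fastapi import Depends, FastAPI
from fastapi.testclient import TestClient

# Import path assumed from the repo layout; adjust if it differs.
from litellm.proxy.proxy_server import user_api_key_auth

app = FastAPI()

@app.get("/ping", dependencies=[Depends(user_api_key_auth)])  # hypothetical route
def ping():
    return {"ok": True}

client = TestClient(app)
# With master_key loaded as "sk-1234", only the matching bearer token passes.
assert client.get("/ping", headers={"Authorization": "Bearer sk-1234"}).status_code == 200
assert client.get("/ping", headers={"Authorization": "Bearer wrong-key"}).status_code == 401
```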
@@ -198,6 +214,7 @@ def save_params_to_config(data: dict):
 
 
 def load_router_config(router: Optional[litellm.Router], config_file_path: str):
+    global master_key
     config = {}
     server_settings: dict = {}
     try:
@@ -210,6 +227,12 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str):
         raise Exception(f"Exception while reading Config: {e}")
 
     print_verbose(f"Configs passed in, loaded config YAML\n{config}")
+
+    ## GENERAL SERVER SETTINGS (e.g. master key,..)
+    general_settings = config.get("general_settings", None)
+    if general_settings:
+        master_key = general_settings.get("master_key", None)
+
     ## LITELLM MODULE SETTINGS (e.g. litellm.drop_params=True,..)
     litellm_settings = config.get('litellm_settings', None)
     if litellm_settings:
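Isolated, the new `general_settings` parsing amounts to the following; a standalone sketch using PyYAML on a config shaped like the docs example above:

```python
import yaml

# Same shape as the example config documented in this commit.
config = yaml.safe_load(
    """
general_settings:
  master_key: sk-1234
"""
)
general_settings = config.get("general_settings", None)
master_key = general_settings.get("master_key", None) if general_settings else None
print(master_key)  # -> sk-1234
```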
@@ -460,8 +483,8 @@ def startup_event():
     # print(f"\033[32mWorker Initialized\033[0m\n")
 
 #### API ENDPOINTS ####
-@router.get("/v1/models")
-@router.get("/models")  # if project requires model list
+@router.get("/v1/models", dependencies=[Depends(user_api_key_auth)])
+@router.get("/models", dependencies=[Depends(user_api_key_auth)])  # if project requires model list
 def model_list():
     global llm_model_list, server_settings
     all_models = []
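Attaching the check via `dependencies=[Depends(user_api_key_auth)]`, rather than as a handler parameter, runs it before each request without injecting anything into the endpoint signature; raising is its only way to reject a call. A minimal standalone illustration with a hypothetical guard:

```python
from fastapi import Depends, FastAPI, HTTPException, Request, status

app = FastAPI()

async def guard(request: Request):
    # Runs before the endpoint body; the return value is discarded when
    # attached via dependencies=[...], so raising is the only way to block.
    if request.headers.get("Authorization") != "Bearer sk-1234":  # illustrative key
        raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED)

@app.get("/v1/models", dependencies=[Depends(guard)])
def model_list():
    return {"data": [], "object": "list"}
```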
@@ -492,9 +515,9 @@ def model_list():
         object="list",
     )
 
-@router.post("/v1/completions")
-@router.post("/completions")
-@router.post("/engines/{model:path}/completions")
+@router.post("/v1/completions", dependencies=[Depends(user_api_key_auth)])
+@router.post("/completions", dependencies=[Depends(user_api_key_auth)])
+@router.post("/engines/{model:path}/completions", dependencies=[Depends(user_api_key_auth)])
 async def completion(request: Request, model: Optional[str] = None):
     try:
         body = await request.body()
@@ -523,9 +546,9 @@ async def completion(request: Request, model: Optional[str] = None):
         return {"error": error_msg}
 
 
-@router.post("/v1/chat/completions")
-@router.post("/chat/completions")
-@router.post("/openai/deployments/{model:path}/chat/completions")  # azure compatible endpoint
+@router.post("/v1/chat/completions", dependencies=[Depends(user_api_key_auth)])
+@router.post("/chat/completions", dependencies=[Depends(user_api_key_auth)])
+@router.post("/openai/deployments/{model:path}/chat/completions", dependencies=[Depends(user_api_key_auth)])  # azure compatible endpoint
 async def chat_completion(request: Request, model: Optional[str] = None):
     global server_settings
     try:
@@ -552,7 +575,7 @@ async def chat_completion(request: Request, model: Optional[str] = None):
         return {"error": error_msg}
 
 
-@router.post("/router/chat/completions")
+@router.post("/router/chat/completions", dependencies=[Depends(user_api_key_auth)])
 async def router_completion(request: Request):
     try:
         body = await request.body()
@@ -568,7 +591,7 @@ async def router_completion(request: Request):
         error_msg = f"{str(e)}\n\n{error_traceback}"
         return {"error": error_msg}
 
-@router.get("/ollama_logs")
+@router.get("/ollama_logs", dependencies=[Depends(user_api_key_auth)])
 async def retrieve_server_log(request: Request):
     filepath = os.path.expanduser("~/.ollama/logs/server.log")
     return FileResponse(filepath)