mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-25 18:54:30 +00:00
feat(proxy_server.py): allow user to override api key auth
This commit is contained in:
parent
51cddf1e97
commit
030bd22078
9 changed files with 274 additions and 118 deletions
|
@ -92,12 +92,16 @@ def generate_feedback_box():
|
|||
|
||||
import litellm
|
||||
from litellm.proxy.utils import (
|
||||
PrismaClient
|
||||
PrismaClient,
|
||||
get_instance_fn
|
||||
)
|
||||
import pydantic
|
||||
from litellm.proxy.types import *
|
||||
from litellm.caching import DualCache
|
||||
litellm.suppress_debug_info = True
|
||||
from fastapi import FastAPI, Request, HTTPException, status, Depends, BackgroundTasks
|
||||
from fastapi.routing import APIRouter
|
||||
from fastapi.security import OAuth2PasswordBearer
|
||||
from fastapi.encoders import jsonable_encoder
|
||||
from fastapi.responses import StreamingResponse, FileResponse, ORJSONResponse
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
|
@ -163,70 +167,8 @@ def log_input_output(request, response, custom_logger=None):
|
|||
return True
|
||||
|
||||
from typing import Dict
|
||||
from pydantic import BaseModel
|
||||
######### Request Class Definition ######
|
||||
class ProxyChatCompletionRequest(BaseModel):
|
||||
model: str
|
||||
messages: List[Dict[str, str]]
|
||||
temperature: Optional[float] = None
|
||||
top_p: Optional[float] = None
|
||||
n: Optional[int] = None
|
||||
stream: Optional[bool] = None
|
||||
stop: Optional[List[str]] = None
|
||||
max_tokens: Optional[int] = None
|
||||
presence_penalty: Optional[float] = None
|
||||
frequency_penalty: Optional[float] = None
|
||||
logit_bias: Optional[Dict[str, float]] = None
|
||||
user: Optional[str] = None
|
||||
response_format: Optional[Dict[str, str]] = None
|
||||
seed: Optional[int] = None
|
||||
tools: Optional[List[str]] = None
|
||||
tool_choice: Optional[str] = None
|
||||
functions: Optional[List[str]] = None # soon to be deprecated
|
||||
function_call: Optional[str] = None # soon to be deprecated
|
||||
|
||||
# Optional LiteLLM params
|
||||
caching: Optional[bool] = None
|
||||
api_base: Optional[str] = None
|
||||
api_version: Optional[str] = None
|
||||
api_key: Optional[str] = None
|
||||
num_retries: Optional[int] = None
|
||||
context_window_fallback_dict: Optional[Dict[str, str]] = None
|
||||
fallbacks: Optional[List[str]] = None
|
||||
metadata: Optional[Dict[str, str]] = {}
|
||||
deployment_id: Optional[str] = None
|
||||
request_timeout: Optional[int] = None
|
||||
|
||||
class Config:
|
||||
extra='allow' # allow params not defined here, these fall in litellm.completion(**kwargs)
|
||||
|
||||
class ModelParams(BaseModel):
|
||||
model_name: str
|
||||
litellm_params: dict
|
||||
model_info: Optional[dict]
|
||||
class Config:
|
||||
protected_namespaces = ()
|
||||
|
||||
class GenerateKeyRequest(BaseModel):
|
||||
duration: str = "1h"
|
||||
models: list = []
|
||||
aliases: dict = {}
|
||||
config: dict = {}
|
||||
spend: int = 0
|
||||
user_id: Optional[str] = None
|
||||
|
||||
class GenerateKeyResponse(BaseModel):
|
||||
key: str
|
||||
expires: datetime
|
||||
user_id: str
|
||||
|
||||
class _DeleteKeyObject(BaseModel):
|
||||
key: str
|
||||
|
||||
class DeleteKeyRequest(BaseModel):
|
||||
keys: List[_DeleteKeyObject]
|
||||
|
||||
|
||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
|
||||
user_api_base = None
|
||||
user_model = None
|
||||
user_debug = False
|
||||
|
@ -249,6 +191,7 @@ master_key = None
|
|||
otel_logging = False
|
||||
prisma_client: Optional[PrismaClient] = None
|
||||
user_api_key_cache = DualCache()
|
||||
user_custom_auth = None
|
||||
### REDIS QUEUE ###
|
||||
async_result = None
|
||||
celery_app_conn = None
|
||||
|
@ -268,21 +211,21 @@ def usage_telemetry(
|
|||
target=litellm.utils.litellm_telemetry, args=(data,), daemon=True
|
||||
).start()
|
||||
|
||||
api_key_header = APIKeyHeader(name="Authorization", auto_error=False)
|
||||
|
||||
async def user_api_key_auth(request: Request, api_key: str = fastapi.Security(api_key_header)):
|
||||
global master_key, prisma_client, llm_model_list
|
||||
print(f"master_key - {master_key}; api_key - {api_key}")
|
||||
if master_key is None:
|
||||
if isinstance(api_key, str):
|
||||
return {
|
||||
"api_key": api_key.replace("Bearer ", "")
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"api_key": api_key
|
||||
}
|
||||
|
||||
async def user_api_key_auth(request: Request, api_key: str = Depends(oauth2_scheme)) -> UserAPIKeyAuth:
|
||||
global master_key, prisma_client, llm_model_list, user_custom_auth
|
||||
try:
|
||||
### USER-DEFINED AUTH FUNCTION ###
|
||||
if user_custom_auth:
|
||||
response = await user_custom_auth(request=request, api_key=api_key)
|
||||
return UserAPIKeyAuth.model_validate(response)
|
||||
|
||||
if master_key is None:
|
||||
if isinstance(api_key, str):
|
||||
return UserAPIKeyAuth(api_key=api_key.replace("Bearer ", ""))
|
||||
else:
|
||||
return UserAPIKeyAuth()
|
||||
if api_key is None:
|
||||
raise Exception("No api key passed in.")
|
||||
route = request.url.path
|
||||
|
@ -290,9 +233,7 @@ async def user_api_key_auth(request: Request, api_key: str = fastapi.Security(ap
|
|||
# note: never string compare api keys, this is vulenerable to a time attack. Use secrets.compare_digest instead
|
||||
is_master_key_valid = secrets.compare_digest(api_key, master_key) or secrets.compare_digest(api_key, "Bearer " + master_key)
|
||||
if is_master_key_valid:
|
||||
return {
|
||||
"api_key": master_key
|
||||
}
|
||||
return UserAPIKeyAuth(api_key=master_key)
|
||||
|
||||
if (route == "/key/generate" or route == "/key/delete" or route == "/key/info") and not is_master_key_valid:
|
||||
raise Exception(f"If master key is set, only master key can be used to generate, delete or get info for new keys")
|
||||
|
@ -318,7 +259,7 @@ async def user_api_key_auth(request: Request, api_key: str = fastapi.Security(ap
|
|||
return_dict = {"api_key": valid_token.token}
|
||||
if valid_token.user_id:
|
||||
return_dict["user_id"] = valid_token.user_id
|
||||
return return_dict
|
||||
return UserAPIKeyAuth(**return_dict)
|
||||
else:
|
||||
data = await request.json()
|
||||
model = data.get("model", None)
|
||||
|
@ -329,14 +270,14 @@ async def user_api_key_auth(request: Request, api_key: str = fastapi.Security(ap
|
|||
return_dict = {"api_key": valid_token.token}
|
||||
if valid_token.user_id:
|
||||
return_dict["user_id"] = valid_token.user_id
|
||||
return return_dict
|
||||
return UserAPIKeyAuth(**return_dict)
|
||||
else:
|
||||
raise Exception(f"Invalid token")
|
||||
except Exception as e:
|
||||
print(f"An exception occurred - {traceback.format_exc()}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail={"error": "invalid user key"},
|
||||
detail="invalid user key",
|
||||
)
|
||||
|
||||
def prisma_setup(database_url: Optional[str]):
|
||||
|
@ -464,7 +405,7 @@ def run_ollama_serve():
|
|||
""")
|
||||
|
||||
def load_router_config(router: Optional[litellm.Router], config_file_path: str):
|
||||
global master_key, user_config_file_path, otel_logging
|
||||
global master_key, user_config_file_path, otel_logging, user_custom_auth
|
||||
config = {}
|
||||
try:
|
||||
if os.path.exists(config_file_path):
|
||||
|
@ -499,7 +440,6 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str):
|
|||
### LOAD FROM AZURE KEY VAULT ###
|
||||
use_azure_key_vault = general_settings.get("use_azure_key_vault", False)
|
||||
load_from_azure_key_vault(use_azure_key_vault=use_azure_key_vault)
|
||||
|
||||
### CONNECT TO DATABASE ###
|
||||
database_url = general_settings.get("database_url", None)
|
||||
if database_url and database_url.startswith("os.environ/"):
|
||||
|
@ -514,12 +454,14 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str):
|
|||
master_key = general_settings.get("master_key", None)
|
||||
if master_key and master_key.startswith("os.environ/"):
|
||||
master_key = litellm.get_secret(master_key)
|
||||
|
||||
#### OpenTelemetry Logging (OTEL) ########
|
||||
otel_logging = general_settings.get("otel", False)
|
||||
if otel_logging == True:
|
||||
print("\nOpenTelemetry Logging Activated")
|
||||
|
||||
### CUSTOM API KEY AUTH ###
|
||||
custom_auth = general_settings.get("custom_auth", None)
|
||||
if custom_auth:
|
||||
user_custom_auth = get_instance_fn(value=custom_auth, config_file_path=config_file_path)
|
||||
## LITELLM MODULE SETTINGS (e.g. litellm.drop_params=True,..)
|
||||
litellm_settings = config.get('litellm_settings', None)
|
||||
if litellm_settings:
|
||||
|
@ -549,23 +491,7 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str):
|
|||
password=cache_password
|
||||
)
|
||||
elif key == "callbacks":
|
||||
print(f"{blue_color_code}\nSetting custom callbacks on Proxy")
|
||||
passed_module, instance_name = value.split(".")
|
||||
|
||||
# Dynamically import the module
|
||||
module = importlib.import_module(passed_module)
|
||||
# Get the instance from the module
|
||||
instance = getattr(module, instance_name)
|
||||
|
||||
methods = [method for method in dir(instance) if callable(getattr(instance, method))]
|
||||
# Print the methods
|
||||
print("Methods in the custom callbacks instance:")
|
||||
for method in methods:
|
||||
print(method)
|
||||
|
||||
litellm.callbacks = [instance]
|
||||
print()
|
||||
|
||||
litellm.callbacks = [get_instance_fn(value=value)]
|
||||
else:
|
||||
setattr(litellm, key, value)
|
||||
|
||||
|
@ -844,7 +770,7 @@ def model_list():
|
|||
@router.post("/v1/completions", dependencies=[Depends(user_api_key_auth)])
|
||||
@router.post("/completions", dependencies=[Depends(user_api_key_auth)])
|
||||
@router.post("/engines/{model:path}/completions", dependencies=[Depends(user_api_key_auth)])
|
||||
async def completion(request: Request, model: Optional[str] = None, user_api_key_dict: dict = Depends(user_api_key_auth)):
|
||||
async def completion(request: Request, model: Optional[str] = None, user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth)):
|
||||
try:
|
||||
body = await request.body()
|
||||
body_str = body.decode()
|
||||
|
@ -853,7 +779,7 @@ async def completion(request: Request, model: Optional[str] = None, user_api_key
|
|||
except:
|
||||
data = json.loads(body_str)
|
||||
|
||||
data["user"] = user_api_key_dict.get("user_id", None)
|
||||
data["user"] = user_api_key_dict.user_id
|
||||
data["model"] = (
|
||||
general_settings.get("completion_model", None) # server default
|
||||
or user_model # model name passed via cli args
|
||||
|
@ -864,9 +790,9 @@ async def completion(request: Request, model: Optional[str] = None, user_api_key
|
|||
data["model"] = user_model
|
||||
data["call_type"] = "text_completion"
|
||||
if "metadata" in data:
|
||||
data["metadata"]["user_api_key"] = user_api_key_dict["api_key"]
|
||||
data["metadata"]["user_api_key"] = user_api_key_dict.api_key
|
||||
else:
|
||||
data["metadata"] = {"user_api_key": user_api_key_dict["api_key"]}
|
||||
data["metadata"] = {"user_api_key": user_api_key_dict.api_key}
|
||||
|
||||
return litellm_completion(
|
||||
**data
|
||||
|
@ -888,7 +814,7 @@ async def completion(request: Request, model: Optional[str] = None, user_api_key
|
|||
@router.post("/v1/chat/completions", dependencies=[Depends(user_api_key_auth)], tags=["chat/completions"])
|
||||
@router.post("/chat/completions", dependencies=[Depends(user_api_key_auth)], tags=["chat/completions"])
|
||||
@router.post("/openai/deployments/{model:path}/chat/completions", dependencies=[Depends(user_api_key_auth)], tags=["chat/completions"]) # azure compatible endpoint
|
||||
async def chat_completion(request: Request, model: Optional[str] = None, user_api_key_dict: dict = Depends(user_api_key_auth), background_tasks: BackgroundTasks = BackgroundTasks()):
|
||||
async def chat_completion(request: Request, model: Optional[str] = None, user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), background_tasks: BackgroundTasks = BackgroundTasks()):
|
||||
global general_settings, user_debug
|
||||
try:
|
||||
data = {}
|
||||
|
@ -905,13 +831,13 @@ async def chat_completion(request: Request, model: Optional[str] = None, user_ap
|
|||
# users can pass in 'user' param to /chat/completions. Don't override it
|
||||
if data.get("user", None) is None:
|
||||
# if users are using user_api_key_auth, set `user` in `data`
|
||||
data["user"] = user_api_key_dict.get("user_id", None)
|
||||
data["user"] = user_api_key_dict.user_id
|
||||
|
||||
if "metadata" in data:
|
||||
data["metadata"]["user_api_key"] = user_api_key_dict["api_key"]
|
||||
data["metadata"]["user_api_key"] = user_api_key_dict.api_key
|
||||
data["metadata"]["headers"] = request.headers
|
||||
else:
|
||||
data["metadata"] = {"user_api_key": user_api_key_dict["api_key"]}
|
||||
data["metadata"] = {"user_api_key": user_api_key_dict.api_key}
|
||||
data["metadata"]["headers"] = request.headers
|
||||
global user_temperature, user_request_timeout, user_max_tokens, user_api_base
|
||||
# override with user settings, these are params passed via cli
|
||||
|
@ -962,14 +888,14 @@ async def chat_completion(request: Request, model: Optional[str] = None, user_ap
|
|||
|
||||
@router.post("/v1/embeddings", dependencies=[Depends(user_api_key_auth)], response_class=ORJSONResponse)
|
||||
@router.post("/embeddings", dependencies=[Depends(user_api_key_auth)], response_class=ORJSONResponse)
|
||||
async def embeddings(request: Request, user_api_key_dict: dict = Depends(user_api_key_auth), background_tasks: BackgroundTasks = BackgroundTasks()):
|
||||
async def embeddings(request: Request, user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), background_tasks: BackgroundTasks = BackgroundTasks()):
|
||||
try:
|
||||
|
||||
# Use orjson to parse JSON data, orjson speeds up requests significantly
|
||||
body = await request.body()
|
||||
data = orjson.loads(body)
|
||||
|
||||
data["user"] = user_api_key_dict.get("user_id", None)
|
||||
data["user"] = user_api_key_dict.user_id
|
||||
data["model"] = (
|
||||
general_settings.get("embedding_model", None) # server default
|
||||
or user_model # model name passed via cli args
|
||||
|
@ -978,9 +904,9 @@ async def embeddings(request: Request, user_api_key_dict: dict = Depends(user_ap
|
|||
if user_model:
|
||||
data["model"] = user_model
|
||||
if "metadata" in data:
|
||||
data["metadata"]["user_api_key"] = user_api_key_dict["api_key"]
|
||||
data["metadata"]["user_api_key"] = user_api_key_dict.api_key
|
||||
else:
|
||||
data["metadata"] = {"user_api_key": user_api_key_dict["api_key"]}
|
||||
data["metadata"] = {"user_api_key": user_api_key_dict.api_key}
|
||||
|
||||
## ROUTE TO CORRECT ENDPOINT ##
|
||||
router_model_names = [m["model_name"] for m in llm_model_list] if llm_model_list is not None else []
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue