feat(proxy_server.py): allow user to override api key auth

Krrish Dholakia 2023-12-04 18:32:47 -08:00
parent 51cddf1e97
commit 030bd22078
9 changed files with 274 additions and 118 deletions
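In practice, this change lets a deployment swap out the proxy's built-in key check: general_settings.custom_auth in the proxy config points at a module.function, and the proxy awaits that function for every request instead of comparing against the master key. A minimal sketch of such a handler, assuming a custom_auth.py placed next to the config file and a made-up "sk-internal-" key scheme:

# custom_auth.py (sketch) -- referenced from the config as:
#   general_settings:
#     custom_auth: custom_auth.user_api_key_auth
from fastapi import Request
from litellm.proxy.types import UserAPIKeyAuth

async def user_api_key_auth(request: Request, api_key: str) -> UserAPIKeyAuth:
    # api_key arrives without the "Bearer " prefix (it comes through the
    # OAuth2PasswordBearer dependency); raising here surfaces to the client as a 401
    if api_key.startswith("sk-internal-"):  # hypothetical key scheme
        return UserAPIKeyAuth(api_key=api_key)
    raise Exception("key not recognized by custom auth")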


@@ -0,0 +1,14 @@
from litellm.proxy.types import UserAPIKeyAuth
from fastapi import Request
from dotenv import load_dotenv
import os

load_dotenv()

async def user_api_key_auth(request: Request, api_key: str) -> UserAPIKeyAuth:
    try:
        modified_master_key = f"{os.getenv('PROXY_MASTER_KEY')}-1234"
        if api_key == modified_master_key:
            return UserAPIKeyAuth(api_key=api_key)
        raise Exception
    except:
        raise Exception


@@ -92,12 +92,16 @@ def generate_feedback_box():
import litellm
from litellm.proxy.utils import (
PrismaClient
PrismaClient,
get_instance_fn
)
import pydantic
from litellm.proxy.types import *
from litellm.caching import DualCache
litellm.suppress_debug_info = True
from fastapi import FastAPI, Request, HTTPException, status, Depends, BackgroundTasks
from fastapi.routing import APIRouter
from fastapi.security import OAuth2PasswordBearer
from fastapi.encoders import jsonable_encoder
from fastapi.responses import StreamingResponse, FileResponse, ORJSONResponse
from fastapi.middleware.cors import CORSMiddleware
@@ -163,70 +167,8 @@ def log_input_output(request, response, custom_logger=None):
return True
from typing import Dict
from pydantic import BaseModel
######### Request Class Definition ######
class ProxyChatCompletionRequest(BaseModel):
    model: str
    messages: List[Dict[str, str]]
    temperature: Optional[float] = None
    top_p: Optional[float] = None
    n: Optional[int] = None
    stream: Optional[bool] = None
    stop: Optional[List[str]] = None
    max_tokens: Optional[int] = None
    presence_penalty: Optional[float] = None
    frequency_penalty: Optional[float] = None
    logit_bias: Optional[Dict[str, float]] = None
    user: Optional[str] = None
    response_format: Optional[Dict[str, str]] = None
    seed: Optional[int] = None
    tools: Optional[List[str]] = None
    tool_choice: Optional[str] = None
    functions: Optional[List[str]] = None  # soon to be deprecated
    function_call: Optional[str] = None  # soon to be deprecated
    # Optional LiteLLM params
    caching: Optional[bool] = None
    api_base: Optional[str] = None
    api_version: Optional[str] = None
    api_key: Optional[str] = None
    num_retries: Optional[int] = None
    context_window_fallback_dict: Optional[Dict[str, str]] = None
    fallbacks: Optional[List[str]] = None
    metadata: Optional[Dict[str, str]] = {}
    deployment_id: Optional[str] = None
    request_timeout: Optional[int] = None

    class Config:
        extra='allow' # allow params not defined here, these fall in litellm.completion(**kwargs)

class ModelParams(BaseModel):
    model_name: str
    litellm_params: dict
    model_info: Optional[dict]

    class Config:
        protected_namespaces = ()

class GenerateKeyRequest(BaseModel):
    duration: str = "1h"
    models: list = []
    aliases: dict = {}
    config: dict = {}
    spend: int = 0
    user_id: Optional[str] = None

class GenerateKeyResponse(BaseModel):
    key: str
    expires: datetime
    user_id: str

class _DeleteKeyObject(BaseModel):
    key: str

class DeleteKeyRequest(BaseModel):
    keys: List[_DeleteKeyObject]
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
user_api_base = None
user_model = None
user_debug = False
@@ -249,6 +191,7 @@ master_key = None
otel_logging = False
prisma_client: Optional[PrismaClient] = None
user_api_key_cache = DualCache()
user_custom_auth = None
### REDIS QUEUE ###
async_result = None
celery_app_conn = None
@@ -268,21 +211,21 @@ def usage_telemetry(
target=litellm.utils.litellm_telemetry, args=(data,), daemon=True
).start()
api_key_header = APIKeyHeader(name="Authorization", auto_error=False)
async def user_api_key_auth(request: Request, api_key: str = fastapi.Security(api_key_header)):
global master_key, prisma_client, llm_model_list
print(f"master_key - {master_key}; api_key - {api_key}")
async def user_api_key_auth(request: Request, api_key: str = Depends(oauth2_scheme)) -> UserAPIKeyAuth:
global master_key, prisma_client, llm_model_list, user_custom_auth
try:
### USER-DEFINED AUTH FUNCTION ###
if user_custom_auth:
response = await user_custom_auth(request=request, api_key=api_key)
return UserAPIKeyAuth.model_validate(response)
if master_key is None:
if isinstance(api_key, str):
return {
"api_key": api_key.replace("Bearer ", "")
}
return UserAPIKeyAuth(api_key=api_key.replace("Bearer ", ""))
else:
return {
"api_key": api_key
}
try:
return UserAPIKeyAuth()
if api_key is None:
raise Exception("No api key passed in.")
route = request.url.path
@@ -290,9 +233,7 @@ async def user_api_key_auth(request: Request, api_key: str = fastapi.Security(ap
# note: never string compare api keys, this is vulnerable to a timing attack. Use secrets.compare_digest instead
is_master_key_valid = secrets.compare_digest(api_key, master_key) or secrets.compare_digest(api_key, "Bearer " + master_key)
if is_master_key_valid:
return {
"api_key": master_key
}
return UserAPIKeyAuth(api_key=master_key)
if (route == "/key/generate" or route == "/key/delete" or route == "/key/info") and not is_master_key_valid:
raise Exception(f"If master key is set, only master key can be used to generate, delete or get info for new keys")
@@ -318,7 +259,7 @@ async def user_api_key_auth(request: Request, api_key: str = fastapi.Security(ap
return_dict = {"api_key": valid_token.token}
if valid_token.user_id:
return_dict["user_id"] = valid_token.user_id
return return_dict
return UserAPIKeyAuth(**return_dict)
else:
data = await request.json()
model = data.get("model", None)
@@ -329,14 +270,14 @@ async def user_api_key_auth(request: Request, api_key: str = fastapi.Security(ap
return_dict = {"api_key": valid_token.token}
if valid_token.user_id:
return_dict["user_id"] = valid_token.user_id
return return_dict
return UserAPIKeyAuth(**return_dict)
else:
raise Exception(f"Invalid token")
except Exception as e:
print(f"An exception occurred - {traceback.format_exc()}")
raise HTTPException(
status_code=status.HTTP_401_UNAUTHORIZED,
detail={"error": "invalid user key"},
detail="invalid user key",
)
def prisma_setup(database_url: Optional[str]):
@@ -464,7 +405,7 @@ def run_ollama_serve():
""")
def load_router_config(router: Optional[litellm.Router], config_file_path: str):
global master_key, user_config_file_path, otel_logging
global master_key, user_config_file_path, otel_logging, user_custom_auth
config = {}
try:
if os.path.exists(config_file_path):
@@ -499,7 +440,6 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str):
### LOAD FROM AZURE KEY VAULT ###
use_azure_key_vault = general_settings.get("use_azure_key_vault", False)
load_from_azure_key_vault(use_azure_key_vault=use_azure_key_vault)
### CONNECT TO DATABASE ###
database_url = general_settings.get("database_url", None)
if database_url and database_url.startswith("os.environ/"):
@@ -514,12 +454,14 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str):
master_key = general_settings.get("master_key", None)
if master_key and master_key.startswith("os.environ/"):
master_key = litellm.get_secret(master_key)
#### OpenTelemetry Logging (OTEL) ########
otel_logging = general_settings.get("otel", False)
if otel_logging == True:
print("\nOpenTelemetry Logging Activated")
### CUSTOM API KEY AUTH ###
custom_auth = general_settings.get("custom_auth", None)
if custom_auth:
user_custom_auth = get_instance_fn(value=custom_auth, config_file_path=config_file_path)
## LITELLM MODULE SETTINGS (e.g. litellm.drop_params=True,..)
litellm_settings = config.get('litellm_settings', None)
if litellm_settings:
@@ -549,23 +491,7 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str):
password=cache_password
)
elif key == "callbacks":
print(f"{blue_color_code}\nSetting custom callbacks on Proxy")
passed_module, instance_name = value.split(".")
# Dynamically import the module
module = importlib.import_module(passed_module)
# Get the instance from the module
instance = getattr(module, instance_name)
methods = [method for method in dir(instance) if callable(getattr(instance, method))]
# Print the methods
print("Methods in the custom callbacks instance:")
for method in methods:
print(method)
litellm.callbacks = [instance]
print()
litellm.callbacks = [get_instance_fn(value=value)]
else:
setattr(litellm, key, value)
@@ -844,7 +770,7 @@ def model_list():
@router.post("/v1/completions", dependencies=[Depends(user_api_key_auth)])
@router.post("/completions", dependencies=[Depends(user_api_key_auth)])
@router.post("/engines/{model:path}/completions", dependencies=[Depends(user_api_key_auth)])
async def completion(request: Request, model: Optional[str] = None, user_api_key_dict: dict = Depends(user_api_key_auth)):
async def completion(request: Request, model: Optional[str] = None, user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth)):
try:
body = await request.body()
body_str = body.decode()
@@ -853,7 +779,7 @@ async def completion(request: Request, model: Optional[str] = None, user_api_key
except:
data = json.loads(body_str)
data["user"] = user_api_key_dict.get("user_id", None)
data["user"] = user_api_key_dict.user_id
data["model"] = (
general_settings.get("completion_model", None) # server default
or user_model # model name passed via cli args
@@ -864,9 +790,9 @@ async def completion(request: Request, model: Optional[str] = None, user_api_key
data["model"] = user_model
data["call_type"] = "text_completion"
if "metadata" in data:
data["metadata"]["user_api_key"] = user_api_key_dict["api_key"]
data["metadata"]["user_api_key"] = user_api_key_dict.api_key
else:
data["metadata"] = {"user_api_key": user_api_key_dict["api_key"]}
data["metadata"] = {"user_api_key": user_api_key_dict.api_key}
return litellm_completion(
**data
@@ -888,7 +814,7 @@ async def completion(request: Request, model: Optional[str] = None, user_api_key
@router.post("/v1/chat/completions", dependencies=[Depends(user_api_key_auth)], tags=["chat/completions"])
@router.post("/chat/completions", dependencies=[Depends(user_api_key_auth)], tags=["chat/completions"])
@router.post("/openai/deployments/{model:path}/chat/completions", dependencies=[Depends(user_api_key_auth)], tags=["chat/completions"]) # azure compatible endpoint
async def chat_completion(request: Request, model: Optional[str] = None, user_api_key_dict: dict = Depends(user_api_key_auth), background_tasks: BackgroundTasks = BackgroundTasks()):
async def chat_completion(request: Request, model: Optional[str] = None, user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), background_tasks: BackgroundTasks = BackgroundTasks()):
global general_settings, user_debug
try:
data = {}
@@ -905,13 +831,13 @@ async def chat_completion(request: Request, model: Optional[str] = None, user_ap
# users can pass in 'user' param to /chat/completions. Don't override it
if data.get("user", None) is None:
# if users are using user_api_key_auth, set `user` in `data`
data["user"] = user_api_key_dict.get("user_id", None)
data["user"] = user_api_key_dict.user_id
if "metadata" in data:
data["metadata"]["user_api_key"] = user_api_key_dict["api_key"]
data["metadata"]["user_api_key"] = user_api_key_dict.api_key
data["metadata"]["headers"] = request.headers
else:
data["metadata"] = {"user_api_key": user_api_key_dict["api_key"]}
data["metadata"] = {"user_api_key": user_api_key_dict.api_key}
data["metadata"]["headers"] = request.headers
global user_temperature, user_request_timeout, user_max_tokens, user_api_base
# override with user settings, these are params passed via cli
@@ -962,14 +888,14 @@ async def chat_completion(request: Request, model: Optional[str] = None, user_ap
@router.post("/v1/embeddings", dependencies=[Depends(user_api_key_auth)], response_class=ORJSONResponse)
@router.post("/embeddings", dependencies=[Depends(user_api_key_auth)], response_class=ORJSONResponse)
async def embeddings(request: Request, user_api_key_dict: dict = Depends(user_api_key_auth), background_tasks: BackgroundTasks = BackgroundTasks()):
async def embeddings(request: Request, user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), background_tasks: BackgroundTasks = BackgroundTasks()):
try:
# Use orjson to parse JSON data, orjson speeds up requests significantly
body = await request.body()
data = orjson.loads(body)
data["user"] = user_api_key_dict.get("user_id", None)
data["user"] = user_api_key_dict.user_id
data["model"] = (
general_settings.get("embedding_model", None) # server default
or user_model # model name passed via cli args
@@ -978,9 +904,9 @@ async def embeddings(request: Request, user_api_key_dict: dict = Depends(user_ap
if user_model:
data["model"] = user_model
if "metadata" in data:
data["metadata"]["user_api_key"] = user_api_key_dict["api_key"]
data["metadata"]["user_api_key"] = user_api_key_dict.api_key
else:
data["metadata"] = {"user_api_key": user_api_key_dict["api_key"]}
data["metadata"] = {"user_api_key": user_api_key_dict.api_key}
## ROUTE TO CORRECT ENDPOINT ##
router_model_names = [m["model_name"] for m in llm_model_list] if llm_model_list is not None else []

litellm/proxy/types.py Normal file

@@ -0,0 +1,70 @@
from pydantic import BaseModel
from typing import Optional, List, Union, Dict
from datetime import datetime

######### Request Class Definition ######
class ProxyChatCompletionRequest(BaseModel):
    model: str
    messages: List[Dict[str, str]]
    temperature: Optional[float] = None
    top_p: Optional[float] = None
    n: Optional[int] = None
    stream: Optional[bool] = None
    stop: Optional[List[str]] = None
    max_tokens: Optional[int] = None
    presence_penalty: Optional[float] = None
    frequency_penalty: Optional[float] = None
    logit_bias: Optional[Dict[str, float]] = None
    user: Optional[str] = None
    response_format: Optional[Dict[str, str]] = None
    seed: Optional[int] = None
    tools: Optional[List[str]] = None
    tool_choice: Optional[str] = None
    functions: Optional[List[str]] = None  # soon to be deprecated
    function_call: Optional[str] = None  # soon to be deprecated
    # Optional LiteLLM params
    caching: Optional[bool] = None
    api_base: Optional[str] = None
    api_version: Optional[str] = None
    api_key: Optional[str] = None
    num_retries: Optional[int] = None
    context_window_fallback_dict: Optional[Dict[str, str]] = None
    fallbacks: Optional[List[str]] = None
    metadata: Optional[Dict[str, str]] = {}
    deployment_id: Optional[str] = None
    request_timeout: Optional[int] = None

    class Config:
        extra='allow' # allow params not defined here, these fall in litellm.completion(**kwargs)

class ModelParams(BaseModel):
    model_name: str
    litellm_params: dict
    model_info: Optional[dict]

    class Config:
        protected_namespaces = ()

class GenerateKeyRequest(BaseModel):
    duration: str = "1h"
    models: list = []
    aliases: dict = {}
    config: dict = {}
    spend: int = 0
    user_id: Optional[str] = None

class GenerateKeyResponse(BaseModel):
    key: str
    expires: datetime
    user_id: str

class _DeleteKeyObject(BaseModel):
    key: str

class DeleteKeyRequest(BaseModel):
    keys: List[_DeleteKeyObject]

class UserAPIKeyAuth(BaseModel):  # the expected response object for user api key auth
    api_key: Optional[str] = None
    user_id: Optional[str] = None
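Because proxy_server.py wraps the custom handler's result with UserAPIKeyAuth.model_validate(...), a custom auth function can hand back either a UserAPIKeyAuth instance or a plain dict with the same fields. A minimal sketch (key and user id values are made up):

from litellm.proxy.types import UserAPIKeyAuth

# constructing the model directly
auth = UserAPIKeyAuth(api_key="sk-example-key", user_id="user-123")

# or validating a plain dict into the same object
auth = UserAPIKeyAuth.model_validate({"api_key": "sk-example-key", "user_id": "user-123"})
print(auth.user_id)  # "user-123"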


@@ -1,6 +1,7 @@
from typing import Optional, List, Any
import os, subprocess, hashlib
import os, subprocess, hashlib, importlib
### DB CONNECTOR ###
class PrismaClient:
def __init__(self, database_url: str):
print("LiteLLM: DATABASE_URL Set in config, trying to 'pip install prisma'")
@@ -95,3 +96,60 @@ class PrismaClient:
async def disconnect(self):
await self.db.disconnect()
# ### CUSTOM FILE ###
# def get_instance_fn(value: str, config_file_path: Optional[str]=None):
# try:
# # Split the path by dots to separate module from instance
# parts = value.split(".")
# # The module path is all but the last part, and the instance is the last part
# module_path = ".".join(parts[:-1])
# instance_name = parts[-1]
# if config_file_path is not None:
# directory = os.path.dirname(config_file_path)
# module_path = os.path.join(directory, module_path)
# # Dynamically import the module
# module = importlib.import_module(module_path)
# # Get the instance from the module
# instance = getattr(module, instance_name)
# return instance
# except ImportError as e:
# print(e)
# raise ImportError(f"Could not import file at {value}")
def get_instance_fn(value: str, config_file_path: Optional[str] = None) -> Any:
    try:
        print(f"value: {value}")
        # Split the path by dots to separate module from instance
        parts = value.split(".")
        # The module path is all but the last part, and the instance_name is the last part
        module_name = ".".join(parts[:-1])
        instance_name = parts[-1]
        # If config_file_path is provided, use it to determine the module spec and load the module
        if config_file_path is not None:
            directory = os.path.dirname(config_file_path)
            module_file_path = os.path.join(directory, *module_name.split('.'))
            module_file_path += '.py'
            spec = importlib.util.spec_from_file_location(module_name, module_file_path)
            if spec is None:
                raise ImportError(f"Could not find a module specification for {module_file_path}")
            module = importlib.util.module_from_spec(spec)
            spec.loader.exec_module(module)
        else:
            # Dynamically import the module
            module = importlib.import_module(module_name)
        # Get the instance from the module
        instance = getattr(module, instance_name)
        return instance
    except ImportError as e:
        # Print the error message for easier debugging
        print(e)
        # Re-raise the exception with a user-friendly message
        raise ImportError(f"Could not import {instance_name} from {module_name}") from e
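A sketch of how load_router_config uses this helper to resolve the custom_auth setting: with a config at, say, /path/to/config.yaml (illustrative path) and custom_auth: custom_auth.user_api_key_auth, the helper loads /path/to/custom_auth.py and returns its user_api_key_auth attribute; without a config path it falls back to a normal dotted import, as the callbacks setting does:

from litellm.proxy.utils import get_instance_fn

# resolves ./custom_auth.py next to the config file and returns the
# user_api_key_auth coroutine defined inside it
custom_handler = get_instance_fn(
    value="custom_auth.user_api_key_auth",
    config_file_path="/path/to/config.yaml",  # illustrative path
)

# without config_file_path, a regular dotted import is used instead
# (module name here is hypothetical)
callback_instance = get_instance_fn(value="my_callbacks.proxy_handler_instance")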


@@ -0,0 +1,14 @@
from litellm.proxy.types import UserAPIKeyAuth
from fastapi import Request
from dotenv import load_dotenv
import os

load_dotenv()

async def user_api_key_auth(request: Request, api_key: str) -> UserAPIKeyAuth:
    try:
        print(f"api_key: {api_key}")
        if api_key == f"{os.getenv('PROXY_MASTER_KEY')}-1234":
            return UserAPIKeyAuth(api_key=api_key)
        raise Exception
    except:
        raise Exception


@@ -0,0 +1,11 @@
model_list:
  - model_name: "openai-model"
    litellm_params:
      model: "gpt-3.5-turbo"

litellm_settings:
  drop_params: True
  set_verbose: True

general_settings:
  custom_auth: custom_auth.user_api_key_auth


@@ -0,0 +1,63 @@
import sys, os
import traceback
from dotenv import load_dotenv

load_dotenv()
import os, io

# this file is to test litellm/proxy
sys.path.insert(
    0, os.path.abspath("../..")
)  # Adds the parent directory to the system path
import pytest
import litellm
from litellm import embedding, completion, completion_cost, Timeout
from litellm import RateLimitError

# test /chat/completion request to the proxy
from fastapi.testclient import TestClient
from fastapi import FastAPI
from litellm.proxy.proxy_server import router, save_worker_config, startup_event  # Replace with the actual module where your FastAPI router is defined

filepath = os.path.dirname(os.path.abspath(__file__))
config_fp = f"{filepath}/test_configs/test_config_custom_auth.yaml"
save_worker_config(config=config_fp, model=None, alias=None, api_base=None, api_version=None, debug=False, temperature=None, max_tokens=None, request_timeout=600, max_budget=None, telemetry=False, drop_params=True, add_function_to_prompt=False, headers=None, save=False, use_queue=False)
app = FastAPI()
app.include_router(router)  # Include your router in the test app

@app.on_event("startup")
async def wrapper_startup_event():
    await startup_event()

# Here you create a fixture that will be used by your tests
# Make sure the fixture returns TestClient(app)
@pytest.fixture(autouse=True)
def client():
    with TestClient(app) as client:
        yield client

def test_custom_auth(client):
    try:
        # Your test data
        test_data = {
            "model": "openai-model",
            "messages": [
                {
                    "role": "user",
                    "content": "hi"
                },
            ],
            "max_tokens": 10,
        }
        # Your bearer token
        token = os.getenv("PROXY_MASTER_KEY")

        headers = {
            "Authorization": f"Bearer {token}"
        }
        response = client.post("/chat/completions", json=test_data, headers=headers)
        print(f"response: {response.text}")
        assert response.status_code == 401
        result = response.json()
        print(f"Received response: {result}")
    except Exception as e:
        pytest.fail(f"LiteLLM Proxy test failed. Exception: {str(e)}")
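The 401 assertion is the point of this test: the config above routes auth through custom_auth.user_api_key_auth, which only accepts the master key with a "-1234" suffix, so the raw PROXY_MASTER_KEY is rejected. A hypothetical companion test for the accepted key might look like this (not part of the commit; the completion call itself can still fail without a valid OpenAI key):

def test_custom_auth_modified_key(client):
    # the suffixed key is what the custom handler in this commit accepts
    token = f"{os.getenv('PROXY_MASTER_KEY')}-1234"
    headers = {"Authorization": f"Bearer {token}"}
    response = client.post(
        "/chat/completions",
        json={"model": "openai-model", "messages": [{"role": "user", "content": "hi"}]},
        headers=headers,
    )
    assert response.status_code != 401  # auth passes; the upstream call may still error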


@@ -25,7 +25,7 @@ from fastapi.testclient import TestClient
from fastapi import FastAPI
from litellm.proxy.proxy_server import router, save_worker_config, startup_event # Replace with the actual module where your FastAPI router is defined
filepath = os.path.dirname(os.path.abspath(__file__))
config_fp = f"{filepath}/test_config.yaml"
config_fp = f"{filepath}/test_configs/test_config.yaml"
save_worker_config(config=config_fp, model=None, alias=None, api_base=None, api_version=None, debug=False, temperature=None, max_tokens=None, request_timeout=600, max_budget=None, telemetry=False, drop_params=True, add_function_to_prompt=False, headers=None, save=False, use_queue=False)
app = FastAPI()
app.include_router(router) # Include your router in the test app