forked from phoenix/litellm-mirror
feat(proxy_server.py): allow user to override api key auth
This commit is contained in:
parent
51cddf1e97
commit
030bd22078
9 changed files with 274 additions and 118 deletions
14
litellm/proxy/custom_auth.py
Normal file
14
litellm/proxy/custom_auth.py
Normal file
|
@ -0,0 +1,14 @@
|
|||
from litellm.proxy.types import UserAPIKeyAuth
|
||||
from fastapi import Request
|
||||
from dotenv import load_dotenv
|
||||
import os
|
||||
|
||||
load_dotenv()
|
||||
async def user_api_key_auth(request: Request, api_key: str) -> UserAPIKeyAuth:
|
||||
try:
|
||||
modified_master_key = f"{os.getenv('PROXY_MASTER_KEY')}-1234"
|
||||
if api_key == modified_master_key:
|
||||
return UserAPIKeyAuth(api_key=api_key)
|
||||
raise Exception
|
||||
except:
|
||||
raise Exception
|
|
@ -92,12 +92,16 @@ def generate_feedback_box():
|
|||
|
||||
import litellm
|
||||
from litellm.proxy.utils import (
|
||||
PrismaClient
|
||||
PrismaClient,
|
||||
get_instance_fn
|
||||
)
|
||||
import pydantic
|
||||
from litellm.proxy.types import *
|
||||
from litellm.caching import DualCache
|
||||
litellm.suppress_debug_info = True
|
||||
from fastapi import FastAPI, Request, HTTPException, status, Depends, BackgroundTasks
|
||||
from fastapi.routing import APIRouter
|
||||
from fastapi.security import OAuth2PasswordBearer
|
||||
from fastapi.encoders import jsonable_encoder
|
||||
from fastapi.responses import StreamingResponse, FileResponse, ORJSONResponse
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
|
@ -163,70 +167,8 @@ def log_input_output(request, response, custom_logger=None):
|
|||
return True
|
||||
|
||||
from typing import Dict
|
||||
from pydantic import BaseModel
|
||||
######### Request Class Definition ######
|
||||
class ProxyChatCompletionRequest(BaseModel):
|
||||
model: str
|
||||
messages: List[Dict[str, str]]
|
||||
temperature: Optional[float] = None
|
||||
top_p: Optional[float] = None
|
||||
n: Optional[int] = None
|
||||
stream: Optional[bool] = None
|
||||
stop: Optional[List[str]] = None
|
||||
max_tokens: Optional[int] = None
|
||||
presence_penalty: Optional[float] = None
|
||||
frequency_penalty: Optional[float] = None
|
||||
logit_bias: Optional[Dict[str, float]] = None
|
||||
user: Optional[str] = None
|
||||
response_format: Optional[Dict[str, str]] = None
|
||||
seed: Optional[int] = None
|
||||
tools: Optional[List[str]] = None
|
||||
tool_choice: Optional[str] = None
|
||||
functions: Optional[List[str]] = None # soon to be deprecated
|
||||
function_call: Optional[str] = None # soon to be deprecated
|
||||
|
||||
# Optional LiteLLM params
|
||||
caching: Optional[bool] = None
|
||||
api_base: Optional[str] = None
|
||||
api_version: Optional[str] = None
|
||||
api_key: Optional[str] = None
|
||||
num_retries: Optional[int] = None
|
||||
context_window_fallback_dict: Optional[Dict[str, str]] = None
|
||||
fallbacks: Optional[List[str]] = None
|
||||
metadata: Optional[Dict[str, str]] = {}
|
||||
deployment_id: Optional[str] = None
|
||||
request_timeout: Optional[int] = None
|
||||
|
||||
class Config:
|
||||
extra='allow' # allow params not defined here, these fall in litellm.completion(**kwargs)
|
||||
|
||||
class ModelParams(BaseModel):
|
||||
model_name: str
|
||||
litellm_params: dict
|
||||
model_info: Optional[dict]
|
||||
class Config:
|
||||
protected_namespaces = ()
|
||||
|
||||
class GenerateKeyRequest(BaseModel):
|
||||
duration: str = "1h"
|
||||
models: list = []
|
||||
aliases: dict = {}
|
||||
config: dict = {}
|
||||
spend: int = 0
|
||||
user_id: Optional[str] = None
|
||||
|
||||
class GenerateKeyResponse(BaseModel):
|
||||
key: str
|
||||
expires: datetime
|
||||
user_id: str
|
||||
|
||||
class _DeleteKeyObject(BaseModel):
|
||||
key: str
|
||||
|
||||
class DeleteKeyRequest(BaseModel):
|
||||
keys: List[_DeleteKeyObject]
|
||||
|
||||
|
||||
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
|
||||
user_api_base = None
|
||||
user_model = None
|
||||
user_debug = False
|
||||
|
@ -249,6 +191,7 @@ master_key = None
|
|||
otel_logging = False
|
||||
prisma_client: Optional[PrismaClient] = None
|
||||
user_api_key_cache = DualCache()
|
||||
user_custom_auth = None
|
||||
### REDIS QUEUE ###
|
||||
async_result = None
|
||||
celery_app_conn = None
|
||||
|
@ -268,21 +211,21 @@ def usage_telemetry(
|
|||
target=litellm.utils.litellm_telemetry, args=(data,), daemon=True
|
||||
).start()
|
||||
|
||||
api_key_header = APIKeyHeader(name="Authorization", auto_error=False)
|
||||
|
||||
async def user_api_key_auth(request: Request, api_key: str = fastapi.Security(api_key_header)):
|
||||
global master_key, prisma_client, llm_model_list
|
||||
print(f"master_key - {master_key}; api_key - {api_key}")
|
||||
if master_key is None:
|
||||
if isinstance(api_key, str):
|
||||
return {
|
||||
"api_key": api_key.replace("Bearer ", "")
|
||||
}
|
||||
else:
|
||||
return {
|
||||
"api_key": api_key
|
||||
}
|
||||
|
||||
async def user_api_key_auth(request: Request, api_key: str = Depends(oauth2_scheme)) -> UserAPIKeyAuth:
|
||||
global master_key, prisma_client, llm_model_list, user_custom_auth
|
||||
try:
|
||||
### USER-DEFINED AUTH FUNCTION ###
|
||||
if user_custom_auth:
|
||||
response = await user_custom_auth(request=request, api_key=api_key)
|
||||
return UserAPIKeyAuth.model_validate(response)
|
||||
|
||||
if master_key is None:
|
||||
if isinstance(api_key, str):
|
||||
return UserAPIKeyAuth(api_key=api_key.replace("Bearer ", ""))
|
||||
else:
|
||||
return UserAPIKeyAuth()
|
||||
if api_key is None:
|
||||
raise Exception("No api key passed in.")
|
||||
route = request.url.path
|
||||
|
@ -290,9 +233,7 @@ async def user_api_key_auth(request: Request, api_key: str = fastapi.Security(ap
|
|||
# note: never string compare api keys, this is vulenerable to a time attack. Use secrets.compare_digest instead
|
||||
is_master_key_valid = secrets.compare_digest(api_key, master_key) or secrets.compare_digest(api_key, "Bearer " + master_key)
|
||||
if is_master_key_valid:
|
||||
return {
|
||||
"api_key": master_key
|
||||
}
|
||||
return UserAPIKeyAuth(api_key=master_key)
|
||||
|
||||
if (route == "/key/generate" or route == "/key/delete" or route == "/key/info") and not is_master_key_valid:
|
||||
raise Exception(f"If master key is set, only master key can be used to generate, delete or get info for new keys")
|
||||
|
@ -318,7 +259,7 @@ async def user_api_key_auth(request: Request, api_key: str = fastapi.Security(ap
|
|||
return_dict = {"api_key": valid_token.token}
|
||||
if valid_token.user_id:
|
||||
return_dict["user_id"] = valid_token.user_id
|
||||
return return_dict
|
||||
return UserAPIKeyAuth(**return_dict)
|
||||
else:
|
||||
data = await request.json()
|
||||
model = data.get("model", None)
|
||||
|
@ -329,14 +270,14 @@ async def user_api_key_auth(request: Request, api_key: str = fastapi.Security(ap
|
|||
return_dict = {"api_key": valid_token.token}
|
||||
if valid_token.user_id:
|
||||
return_dict["user_id"] = valid_token.user_id
|
||||
return return_dict
|
||||
return UserAPIKeyAuth(**return_dict)
|
||||
else:
|
||||
raise Exception(f"Invalid token")
|
||||
except Exception as e:
|
||||
print(f"An exception occurred - {traceback.format_exc()}")
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||
detail={"error": "invalid user key"},
|
||||
detail="invalid user key",
|
||||
)
|
||||
|
||||
def prisma_setup(database_url: Optional[str]):
|
||||
|
@ -464,7 +405,7 @@ def run_ollama_serve():
|
|||
""")
|
||||
|
||||
def load_router_config(router: Optional[litellm.Router], config_file_path: str):
|
||||
global master_key, user_config_file_path, otel_logging
|
||||
global master_key, user_config_file_path, otel_logging, user_custom_auth
|
||||
config = {}
|
||||
try:
|
||||
if os.path.exists(config_file_path):
|
||||
|
@ -499,7 +440,6 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str):
|
|||
### LOAD FROM AZURE KEY VAULT ###
|
||||
use_azure_key_vault = general_settings.get("use_azure_key_vault", False)
|
||||
load_from_azure_key_vault(use_azure_key_vault=use_azure_key_vault)
|
||||
|
||||
### CONNECT TO DATABASE ###
|
||||
database_url = general_settings.get("database_url", None)
|
||||
if database_url and database_url.startswith("os.environ/"):
|
||||
|
@ -514,12 +454,14 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str):
|
|||
master_key = general_settings.get("master_key", None)
|
||||
if master_key and master_key.startswith("os.environ/"):
|
||||
master_key = litellm.get_secret(master_key)
|
||||
|
||||
#### OpenTelemetry Logging (OTEL) ########
|
||||
otel_logging = general_settings.get("otel", False)
|
||||
if otel_logging == True:
|
||||
print("\nOpenTelemetry Logging Activated")
|
||||
|
||||
### CUSTOM API KEY AUTH ###
|
||||
custom_auth = general_settings.get("custom_auth", None)
|
||||
if custom_auth:
|
||||
user_custom_auth = get_instance_fn(value=custom_auth, config_file_path=config_file_path)
|
||||
## LITELLM MODULE SETTINGS (e.g. litellm.drop_params=True,..)
|
||||
litellm_settings = config.get('litellm_settings', None)
|
||||
if litellm_settings:
|
||||
|
@ -549,23 +491,7 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str):
|
|||
password=cache_password
|
||||
)
|
||||
elif key == "callbacks":
|
||||
print(f"{blue_color_code}\nSetting custom callbacks on Proxy")
|
||||
passed_module, instance_name = value.split(".")
|
||||
|
||||
# Dynamically import the module
|
||||
module = importlib.import_module(passed_module)
|
||||
# Get the instance from the module
|
||||
instance = getattr(module, instance_name)
|
||||
|
||||
methods = [method for method in dir(instance) if callable(getattr(instance, method))]
|
||||
# Print the methods
|
||||
print("Methods in the custom callbacks instance:")
|
||||
for method in methods:
|
||||
print(method)
|
||||
|
||||
litellm.callbacks = [instance]
|
||||
print()
|
||||
|
||||
litellm.callbacks = [get_instance_fn(value=value)]
|
||||
else:
|
||||
setattr(litellm, key, value)
|
||||
|
||||
|
@ -844,7 +770,7 @@ def model_list():
|
|||
@router.post("/v1/completions", dependencies=[Depends(user_api_key_auth)])
|
||||
@router.post("/completions", dependencies=[Depends(user_api_key_auth)])
|
||||
@router.post("/engines/{model:path}/completions", dependencies=[Depends(user_api_key_auth)])
|
||||
async def completion(request: Request, model: Optional[str] = None, user_api_key_dict: dict = Depends(user_api_key_auth)):
|
||||
async def completion(request: Request, model: Optional[str] = None, user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth)):
|
||||
try:
|
||||
body = await request.body()
|
||||
body_str = body.decode()
|
||||
|
@ -853,7 +779,7 @@ async def completion(request: Request, model: Optional[str] = None, user_api_key
|
|||
except:
|
||||
data = json.loads(body_str)
|
||||
|
||||
data["user"] = user_api_key_dict.get("user_id", None)
|
||||
data["user"] = user_api_key_dict.user_id
|
||||
data["model"] = (
|
||||
general_settings.get("completion_model", None) # server default
|
||||
or user_model # model name passed via cli args
|
||||
|
@ -864,9 +790,9 @@ async def completion(request: Request, model: Optional[str] = None, user_api_key
|
|||
data["model"] = user_model
|
||||
data["call_type"] = "text_completion"
|
||||
if "metadata" in data:
|
||||
data["metadata"]["user_api_key"] = user_api_key_dict["api_key"]
|
||||
data["metadata"]["user_api_key"] = user_api_key_dict.api_key
|
||||
else:
|
||||
data["metadata"] = {"user_api_key": user_api_key_dict["api_key"]}
|
||||
data["metadata"] = {"user_api_key": user_api_key_dict.api_key}
|
||||
|
||||
return litellm_completion(
|
||||
**data
|
||||
|
@ -888,7 +814,7 @@ async def completion(request: Request, model: Optional[str] = None, user_api_key
|
|||
@router.post("/v1/chat/completions", dependencies=[Depends(user_api_key_auth)], tags=["chat/completions"])
|
||||
@router.post("/chat/completions", dependencies=[Depends(user_api_key_auth)], tags=["chat/completions"])
|
||||
@router.post("/openai/deployments/{model:path}/chat/completions", dependencies=[Depends(user_api_key_auth)], tags=["chat/completions"]) # azure compatible endpoint
|
||||
async def chat_completion(request: Request, model: Optional[str] = None, user_api_key_dict: dict = Depends(user_api_key_auth), background_tasks: BackgroundTasks = BackgroundTasks()):
|
||||
async def chat_completion(request: Request, model: Optional[str] = None, user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), background_tasks: BackgroundTasks = BackgroundTasks()):
|
||||
global general_settings, user_debug
|
||||
try:
|
||||
data = {}
|
||||
|
@ -905,13 +831,13 @@ async def chat_completion(request: Request, model: Optional[str] = None, user_ap
|
|||
# users can pass in 'user' param to /chat/completions. Don't override it
|
||||
if data.get("user", None) is None:
|
||||
# if users are using user_api_key_auth, set `user` in `data`
|
||||
data["user"] = user_api_key_dict.get("user_id", None)
|
||||
data["user"] = user_api_key_dict.user_id
|
||||
|
||||
if "metadata" in data:
|
||||
data["metadata"]["user_api_key"] = user_api_key_dict["api_key"]
|
||||
data["metadata"]["user_api_key"] = user_api_key_dict.api_key
|
||||
data["metadata"]["headers"] = request.headers
|
||||
else:
|
||||
data["metadata"] = {"user_api_key": user_api_key_dict["api_key"]}
|
||||
data["metadata"] = {"user_api_key": user_api_key_dict.api_key}
|
||||
data["metadata"]["headers"] = request.headers
|
||||
global user_temperature, user_request_timeout, user_max_tokens, user_api_base
|
||||
# override with user settings, these are params passed via cli
|
||||
|
@ -962,14 +888,14 @@ async def chat_completion(request: Request, model: Optional[str] = None, user_ap
|
|||
|
||||
@router.post("/v1/embeddings", dependencies=[Depends(user_api_key_auth)], response_class=ORJSONResponse)
|
||||
@router.post("/embeddings", dependencies=[Depends(user_api_key_auth)], response_class=ORJSONResponse)
|
||||
async def embeddings(request: Request, user_api_key_dict: dict = Depends(user_api_key_auth), background_tasks: BackgroundTasks = BackgroundTasks()):
|
||||
async def embeddings(request: Request, user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), background_tasks: BackgroundTasks = BackgroundTasks()):
|
||||
try:
|
||||
|
||||
# Use orjson to parse JSON data, orjson speeds up requests significantly
|
||||
body = await request.body()
|
||||
data = orjson.loads(body)
|
||||
|
||||
data["user"] = user_api_key_dict.get("user_id", None)
|
||||
data["user"] = user_api_key_dict.user_id
|
||||
data["model"] = (
|
||||
general_settings.get("embedding_model", None) # server default
|
||||
or user_model # model name passed via cli args
|
||||
|
@ -978,9 +904,9 @@ async def embeddings(request: Request, user_api_key_dict: dict = Depends(user_ap
|
|||
if user_model:
|
||||
data["model"] = user_model
|
||||
if "metadata" in data:
|
||||
data["metadata"]["user_api_key"] = user_api_key_dict["api_key"]
|
||||
data["metadata"]["user_api_key"] = user_api_key_dict.api_key
|
||||
else:
|
||||
data["metadata"] = {"user_api_key": user_api_key_dict["api_key"]}
|
||||
data["metadata"] = {"user_api_key": user_api_key_dict.api_key}
|
||||
|
||||
## ROUTE TO CORRECT ENDPOINT ##
|
||||
router_model_names = [m["model_name"] for m in llm_model_list] if llm_model_list is not None else []
|
||||
|
|
70
litellm/proxy/types.py
Normal file
70
litellm/proxy/types.py
Normal file
|
@ -0,0 +1,70 @@
|
|||
from pydantic import BaseModel
|
||||
from typing import Optional, List, Union, Dict
|
||||
from datetime import datetime
|
||||
|
||||
######### Request Class Definition ######
|
||||
class ProxyChatCompletionRequest(BaseModel):
|
||||
model: str
|
||||
messages: List[Dict[str, str]]
|
||||
temperature: Optional[float] = None
|
||||
top_p: Optional[float] = None
|
||||
n: Optional[int] = None
|
||||
stream: Optional[bool] = None
|
||||
stop: Optional[List[str]] = None
|
||||
max_tokens: Optional[int] = None
|
||||
presence_penalty: Optional[float] = None
|
||||
frequency_penalty: Optional[float] = None
|
||||
logit_bias: Optional[Dict[str, float]] = None
|
||||
user: Optional[str] = None
|
||||
response_format: Optional[Dict[str, str]] = None
|
||||
seed: Optional[int] = None
|
||||
tools: Optional[List[str]] = None
|
||||
tool_choice: Optional[str] = None
|
||||
functions: Optional[List[str]] = None # soon to be deprecated
|
||||
function_call: Optional[str] = None # soon to be deprecated
|
||||
|
||||
# Optional LiteLLM params
|
||||
caching: Optional[bool] = None
|
||||
api_base: Optional[str] = None
|
||||
api_version: Optional[str] = None
|
||||
api_key: Optional[str] = None
|
||||
num_retries: Optional[int] = None
|
||||
context_window_fallback_dict: Optional[Dict[str, str]] = None
|
||||
fallbacks: Optional[List[str]] = None
|
||||
metadata: Optional[Dict[str, str]] = {}
|
||||
deployment_id: Optional[str] = None
|
||||
request_timeout: Optional[int] = None
|
||||
|
||||
class Config:
|
||||
extra='allow' # allow params not defined here, these fall in litellm.completion(**kwargs)
|
||||
|
||||
class ModelParams(BaseModel):
|
||||
model_name: str
|
||||
litellm_params: dict
|
||||
model_info: Optional[dict]
|
||||
class Config:
|
||||
protected_namespaces = ()
|
||||
|
||||
class GenerateKeyRequest(BaseModel):
|
||||
duration: str = "1h"
|
||||
models: list = []
|
||||
aliases: dict = {}
|
||||
config: dict = {}
|
||||
spend: int = 0
|
||||
user_id: Optional[str] = None
|
||||
|
||||
class GenerateKeyResponse(BaseModel):
|
||||
key: str
|
||||
expires: datetime
|
||||
user_id: str
|
||||
|
||||
class _DeleteKeyObject(BaseModel):
|
||||
key: str
|
||||
|
||||
class DeleteKeyRequest(BaseModel):
|
||||
keys: List[_DeleteKeyObject]
|
||||
|
||||
|
||||
class UserAPIKeyAuth(BaseModel): # the expected response object for user api key auth
|
||||
api_key: Optional[str] = None
|
||||
user_id: Optional[str] = None
|
|
@ -1,6 +1,7 @@
|
|||
from typing import Optional, List, Any
|
||||
import os, subprocess, hashlib
|
||||
import os, subprocess, hashlib, importlib
|
||||
|
||||
### DB CONNECTOR ###
|
||||
class PrismaClient:
|
||||
def __init__(self, database_url: str):
|
||||
print("LiteLLM: DATABASE_URL Set in config, trying to 'pip install prisma'")
|
||||
|
@ -95,3 +96,60 @@ class PrismaClient:
|
|||
async def disconnect(self):
|
||||
await self.db.disconnect()
|
||||
|
||||
# ### CUSTOM FILE ###
|
||||
# def get_instance_fn(value: str, config_file_path: Optional[str]=None):
|
||||
# try:
|
||||
# # Split the path by dots to separate module from instance
|
||||
# parts = value.split(".")
|
||||
# # The module path is all but the last part, and the instance is the last part
|
||||
# module_path = ".".join(parts[:-1])
|
||||
# instance_name = parts[-1]
|
||||
|
||||
# if config_file_path is not None:
|
||||
# directory = os.path.dirname(config_file_path)
|
||||
# module_path = os.path.join(directory, module_path)
|
||||
# # Dynamically import the module
|
||||
# module = importlib.import_module(module_path)
|
||||
|
||||
# # Get the instance from the module
|
||||
# instance = getattr(module, instance_name)
|
||||
|
||||
# return instance
|
||||
# except ImportError as e:
|
||||
# print(e)
|
||||
# raise ImportError(f"Could not import file at {value}")
|
||||
|
||||
def get_instance_fn(value: str, config_file_path: Optional[str] = None) -> Any:
|
||||
try:
|
||||
print(f"value: {value}")
|
||||
# Split the path by dots to separate module from instance
|
||||
parts = value.split(".")
|
||||
|
||||
# The module path is all but the last part, and the instance_name is the last part
|
||||
module_name = ".".join(parts[:-1])
|
||||
instance_name = parts[-1]
|
||||
|
||||
# If config_file_path is provided, use it to determine the module spec and load the module
|
||||
if config_file_path is not None:
|
||||
directory = os.path.dirname(config_file_path)
|
||||
module_file_path = os.path.join(directory, *module_name.split('.'))
|
||||
module_file_path += '.py'
|
||||
|
||||
spec = importlib.util.spec_from_file_location(module_name, module_file_path)
|
||||
if spec is None:
|
||||
raise ImportError(f"Could not find a module specification for {module_file_path}")
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(module)
|
||||
else:
|
||||
# Dynamically import the module
|
||||
module = importlib.import_module(module_name)
|
||||
|
||||
# Get the instance from the module
|
||||
instance = getattr(module, instance_name)
|
||||
|
||||
return instance
|
||||
except ImportError as e:
|
||||
# Print the error message for easier debugging
|
||||
print(e)
|
||||
# Re-raise the exception with a user-friendly message
|
||||
raise ImportError(f"Could not import {instance_name} from {module_name}") from e
|
14
litellm/tests/test_configs/custom_auth.py
Normal file
14
litellm/tests/test_configs/custom_auth.py
Normal file
|
@ -0,0 +1,14 @@
|
|||
from litellm.proxy.types import UserAPIKeyAuth
|
||||
from fastapi import Request
|
||||
from dotenv import load_dotenv
|
||||
import os
|
||||
|
||||
load_dotenv()
|
||||
async def user_api_key_auth(request: Request, api_key: str) -> UserAPIKeyAuth:
|
||||
try:
|
||||
print(f"api_key: {api_key}")
|
||||
if api_key == f"{os.getenv('PROXY_MASTER_KEY')}-1234":
|
||||
return UserAPIKeyAuth(api_key=api_key)
|
||||
raise Exception
|
||||
except:
|
||||
raise Exception
|
11
litellm/tests/test_configs/test_config_custom_auth.yaml
Normal file
11
litellm/tests/test_configs/test_config_custom_auth.yaml
Normal file
|
@ -0,0 +1,11 @@
|
|||
model_list:
|
||||
- model_name: "openai-model"
|
||||
litellm_params:
|
||||
model: "gpt-3.5-turbo"
|
||||
|
||||
litellm_settings:
|
||||
drop_params: True
|
||||
set_verbose: True
|
||||
|
||||
general_settings:
|
||||
custom_auth: custom_auth.user_api_key_auth
|
63
litellm/tests/test_proxy_custom_auth.py
Normal file
63
litellm/tests/test_proxy_custom_auth.py
Normal file
|
@ -0,0 +1,63 @@
|
|||
import sys, os
|
||||
import traceback
|
||||
from dotenv import load_dotenv
|
||||
|
||||
load_dotenv()
|
||||
import os, io
|
||||
|
||||
# this file is to test litellm/proxy
|
||||
|
||||
sys.path.insert(
|
||||
0, os.path.abspath("../..")
|
||||
) # Adds the parent directory to the system path
|
||||
import pytest
|
||||
import litellm
|
||||
from litellm import embedding, completion, completion_cost, Timeout
|
||||
from litellm import RateLimitError
|
||||
|
||||
# test /chat/completion request to the proxy
|
||||
from fastapi.testclient import TestClient
|
||||
from fastapi import FastAPI
|
||||
from litellm.proxy.proxy_server import router, save_worker_config, startup_event # Replace with the actual module where your FastAPI router is defined
|
||||
filepath = os.path.dirname(os.path.abspath(__file__))
|
||||
config_fp = f"{filepath}/test_configs/test_config_custom_auth.yaml"
|
||||
save_worker_config(config=config_fp, model=None, alias=None, api_base=None, api_version=None, debug=False, temperature=None, max_tokens=None, request_timeout=600, max_budget=None, telemetry=False, drop_params=True, add_function_to_prompt=False, headers=None, save=False, use_queue=False)
|
||||
app = FastAPI()
|
||||
app.include_router(router) # Include your router in the test app
|
||||
@app.on_event("startup")
|
||||
async def wrapper_startup_event():
|
||||
await startup_event()
|
||||
|
||||
# Here you create a fixture that will be used by your tests
|
||||
# Make sure the fixture returns TestClient(app)
|
||||
@pytest.fixture(autouse=True)
|
||||
def client():
|
||||
with TestClient(app) as client:
|
||||
yield client
|
||||
|
||||
def test_custom_auth(client):
|
||||
try:
|
||||
# Your test data
|
||||
test_data = {
|
||||
"model": "openai-model",
|
||||
"messages": [
|
||||
{
|
||||
"role": "user",
|
||||
"content": "hi"
|
||||
},
|
||||
],
|
||||
"max_tokens": 10,
|
||||
}
|
||||
# Your bearer token
|
||||
token = os.getenv("PROXY_MASTER_KEY")
|
||||
|
||||
headers = {
|
||||
"Authorization": f"Bearer {token}"
|
||||
}
|
||||
response = client.post("/chat/completions", json=test_data, headers=headers)
|
||||
print(f"response: {response.text}")
|
||||
assert response.status_code == 401
|
||||
result = response.json()
|
||||
print(f"Received response: {result}")
|
||||
except Exception as e:
|
||||
pytest.fail("LiteLLM Proxy test failed. Exception", e)
|
|
@ -25,7 +25,7 @@ from fastapi.testclient import TestClient
|
|||
from fastapi import FastAPI
|
||||
from litellm.proxy.proxy_server import router, save_worker_config, startup_event # Replace with the actual module where your FastAPI router is defined
|
||||
filepath = os.path.dirname(os.path.abspath(__file__))
|
||||
config_fp = f"{filepath}/test_config.yaml"
|
||||
config_fp = f"{filepath}/test_configs/test_config.yaml"
|
||||
save_worker_config(config=config_fp, model=None, alias=None, api_base=None, api_version=None, debug=False, temperature=None, max_tokens=None, request_timeout=600, max_budget=None, telemetry=False, drop_params=True, add_function_to_prompt=False, headers=None, save=False, use_queue=False)
|
||||
app = FastAPI()
|
||||
app.include_router(router) # Include your router in the test app
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue