mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-25 18:54:30 +00:00
feat(proxy_server.py): allow user to override api key auth
This commit is contained in:
parent
51cddf1e97
commit
030bd22078
9 changed files with 274 additions and 118 deletions
14
litellm/proxy/custom_auth.py
Normal file
14
litellm/proxy/custom_auth.py
Normal file
|
@ -0,0 +1,14 @@
|
||||||
|
from litellm.proxy.types import UserAPIKeyAuth
|
||||||
|
from fastapi import Request
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
import os
|
||||||
|
|
||||||
|
load_dotenv()
|
||||||
|
async def user_api_key_auth(request: Request, api_key: str) -> UserAPIKeyAuth:
|
||||||
|
try:
|
||||||
|
modified_master_key = f"{os.getenv('PROXY_MASTER_KEY')}-1234"
|
||||||
|
if api_key == modified_master_key:
|
||||||
|
return UserAPIKeyAuth(api_key=api_key)
|
||||||
|
raise Exception
|
||||||
|
except:
|
||||||
|
raise Exception
|
|
@ -92,12 +92,16 @@ def generate_feedback_box():
|
||||||
|
|
||||||
import litellm
|
import litellm
|
||||||
from litellm.proxy.utils import (
|
from litellm.proxy.utils import (
|
||||||
PrismaClient
|
PrismaClient,
|
||||||
|
get_instance_fn
|
||||||
)
|
)
|
||||||
|
import pydantic
|
||||||
|
from litellm.proxy.types import *
|
||||||
from litellm.caching import DualCache
|
from litellm.caching import DualCache
|
||||||
litellm.suppress_debug_info = True
|
litellm.suppress_debug_info = True
|
||||||
from fastapi import FastAPI, Request, HTTPException, status, Depends, BackgroundTasks
|
from fastapi import FastAPI, Request, HTTPException, status, Depends, BackgroundTasks
|
||||||
from fastapi.routing import APIRouter
|
from fastapi.routing import APIRouter
|
||||||
|
from fastapi.security import OAuth2PasswordBearer
|
||||||
from fastapi.encoders import jsonable_encoder
|
from fastapi.encoders import jsonable_encoder
|
||||||
from fastapi.responses import StreamingResponse, FileResponse, ORJSONResponse
|
from fastapi.responses import StreamingResponse, FileResponse, ORJSONResponse
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
|
@ -163,70 +167,8 @@ def log_input_output(request, response, custom_logger=None):
|
||||||
return True
|
return True
|
||||||
|
|
||||||
from typing import Dict
|
from typing import Dict
|
||||||
from pydantic import BaseModel
|
|
||||||
######### Request Class Definition ######
|
|
||||||
class ProxyChatCompletionRequest(BaseModel):
|
|
||||||
model: str
|
|
||||||
messages: List[Dict[str, str]]
|
|
||||||
temperature: Optional[float] = None
|
|
||||||
top_p: Optional[float] = None
|
|
||||||
n: Optional[int] = None
|
|
||||||
stream: Optional[bool] = None
|
|
||||||
stop: Optional[List[str]] = None
|
|
||||||
max_tokens: Optional[int] = None
|
|
||||||
presence_penalty: Optional[float] = None
|
|
||||||
frequency_penalty: Optional[float] = None
|
|
||||||
logit_bias: Optional[Dict[str, float]] = None
|
|
||||||
user: Optional[str] = None
|
|
||||||
response_format: Optional[Dict[str, str]] = None
|
|
||||||
seed: Optional[int] = None
|
|
||||||
tools: Optional[List[str]] = None
|
|
||||||
tool_choice: Optional[str] = None
|
|
||||||
functions: Optional[List[str]] = None # soon to be deprecated
|
|
||||||
function_call: Optional[str] = None # soon to be deprecated
|
|
||||||
|
|
||||||
# Optional LiteLLM params
|
|
||||||
caching: Optional[bool] = None
|
|
||||||
api_base: Optional[str] = None
|
|
||||||
api_version: Optional[str] = None
|
|
||||||
api_key: Optional[str] = None
|
|
||||||
num_retries: Optional[int] = None
|
|
||||||
context_window_fallback_dict: Optional[Dict[str, str]] = None
|
|
||||||
fallbacks: Optional[List[str]] = None
|
|
||||||
metadata: Optional[Dict[str, str]] = {}
|
|
||||||
deployment_id: Optional[str] = None
|
|
||||||
request_timeout: Optional[int] = None
|
|
||||||
|
|
||||||
class Config:
|
|
||||||
extra='allow' # allow params not defined here, these fall in litellm.completion(**kwargs)
|
|
||||||
|
|
||||||
class ModelParams(BaseModel):
|
|
||||||
model_name: str
|
|
||||||
litellm_params: dict
|
|
||||||
model_info: Optional[dict]
|
|
||||||
class Config:
|
|
||||||
protected_namespaces = ()
|
|
||||||
|
|
||||||
class GenerateKeyRequest(BaseModel):
|
|
||||||
duration: str = "1h"
|
|
||||||
models: list = []
|
|
||||||
aliases: dict = {}
|
|
||||||
config: dict = {}
|
|
||||||
spend: int = 0
|
|
||||||
user_id: Optional[str] = None
|
|
||||||
|
|
||||||
class GenerateKeyResponse(BaseModel):
|
|
||||||
key: str
|
|
||||||
expires: datetime
|
|
||||||
user_id: str
|
|
||||||
|
|
||||||
class _DeleteKeyObject(BaseModel):
|
|
||||||
key: str
|
|
||||||
|
|
||||||
class DeleteKeyRequest(BaseModel):
|
|
||||||
keys: List[_DeleteKeyObject]
|
|
||||||
|
|
||||||
|
|
||||||
|
oauth2_scheme = OAuth2PasswordBearer(tokenUrl="token")
|
||||||
user_api_base = None
|
user_api_base = None
|
||||||
user_model = None
|
user_model = None
|
||||||
user_debug = False
|
user_debug = False
|
||||||
|
@ -249,6 +191,7 @@ master_key = None
|
||||||
otel_logging = False
|
otel_logging = False
|
||||||
prisma_client: Optional[PrismaClient] = None
|
prisma_client: Optional[PrismaClient] = None
|
||||||
user_api_key_cache = DualCache()
|
user_api_key_cache = DualCache()
|
||||||
|
user_custom_auth = None
|
||||||
### REDIS QUEUE ###
|
### REDIS QUEUE ###
|
||||||
async_result = None
|
async_result = None
|
||||||
celery_app_conn = None
|
celery_app_conn = None
|
||||||
|
@ -268,21 +211,21 @@ def usage_telemetry(
|
||||||
target=litellm.utils.litellm_telemetry, args=(data,), daemon=True
|
target=litellm.utils.litellm_telemetry, args=(data,), daemon=True
|
||||||
).start()
|
).start()
|
||||||
|
|
||||||
api_key_header = APIKeyHeader(name="Authorization", auto_error=False)
|
|
||||||
|
|
||||||
async def user_api_key_auth(request: Request, api_key: str = fastapi.Security(api_key_header)):
|
|
||||||
global master_key, prisma_client, llm_model_list
|
async def user_api_key_auth(request: Request, api_key: str = Depends(oauth2_scheme)) -> UserAPIKeyAuth:
|
||||||
print(f"master_key - {master_key}; api_key - {api_key}")
|
global master_key, prisma_client, llm_model_list, user_custom_auth
|
||||||
if master_key is None:
|
|
||||||
if isinstance(api_key, str):
|
|
||||||
return {
|
|
||||||
"api_key": api_key.replace("Bearer ", "")
|
|
||||||
}
|
|
||||||
else:
|
|
||||||
return {
|
|
||||||
"api_key": api_key
|
|
||||||
}
|
|
||||||
try:
|
try:
|
||||||
|
### USER-DEFINED AUTH FUNCTION ###
|
||||||
|
if user_custom_auth:
|
||||||
|
response = await user_custom_auth(request=request, api_key=api_key)
|
||||||
|
return UserAPIKeyAuth.model_validate(response)
|
||||||
|
|
||||||
|
if master_key is None:
|
||||||
|
if isinstance(api_key, str):
|
||||||
|
return UserAPIKeyAuth(api_key=api_key.replace("Bearer ", ""))
|
||||||
|
else:
|
||||||
|
return UserAPIKeyAuth()
|
||||||
if api_key is None:
|
if api_key is None:
|
||||||
raise Exception("No api key passed in.")
|
raise Exception("No api key passed in.")
|
||||||
route = request.url.path
|
route = request.url.path
|
||||||
|
@ -290,9 +233,7 @@ async def user_api_key_auth(request: Request, api_key: str = fastapi.Security(ap
|
||||||
# note: never string compare api keys, this is vulenerable to a time attack. Use secrets.compare_digest instead
|
# note: never string compare api keys, this is vulenerable to a time attack. Use secrets.compare_digest instead
|
||||||
is_master_key_valid = secrets.compare_digest(api_key, master_key) or secrets.compare_digest(api_key, "Bearer " + master_key)
|
is_master_key_valid = secrets.compare_digest(api_key, master_key) or secrets.compare_digest(api_key, "Bearer " + master_key)
|
||||||
if is_master_key_valid:
|
if is_master_key_valid:
|
||||||
return {
|
return UserAPIKeyAuth(api_key=master_key)
|
||||||
"api_key": master_key
|
|
||||||
}
|
|
||||||
|
|
||||||
if (route == "/key/generate" or route == "/key/delete" or route == "/key/info") and not is_master_key_valid:
|
if (route == "/key/generate" or route == "/key/delete" or route == "/key/info") and not is_master_key_valid:
|
||||||
raise Exception(f"If master key is set, only master key can be used to generate, delete or get info for new keys")
|
raise Exception(f"If master key is set, only master key can be used to generate, delete or get info for new keys")
|
||||||
|
@ -318,7 +259,7 @@ async def user_api_key_auth(request: Request, api_key: str = fastapi.Security(ap
|
||||||
return_dict = {"api_key": valid_token.token}
|
return_dict = {"api_key": valid_token.token}
|
||||||
if valid_token.user_id:
|
if valid_token.user_id:
|
||||||
return_dict["user_id"] = valid_token.user_id
|
return_dict["user_id"] = valid_token.user_id
|
||||||
return return_dict
|
return UserAPIKeyAuth(**return_dict)
|
||||||
else:
|
else:
|
||||||
data = await request.json()
|
data = await request.json()
|
||||||
model = data.get("model", None)
|
model = data.get("model", None)
|
||||||
|
@ -329,14 +270,14 @@ async def user_api_key_auth(request: Request, api_key: str = fastapi.Security(ap
|
||||||
return_dict = {"api_key": valid_token.token}
|
return_dict = {"api_key": valid_token.token}
|
||||||
if valid_token.user_id:
|
if valid_token.user_id:
|
||||||
return_dict["user_id"] = valid_token.user_id
|
return_dict["user_id"] = valid_token.user_id
|
||||||
return return_dict
|
return UserAPIKeyAuth(**return_dict)
|
||||||
else:
|
else:
|
||||||
raise Exception(f"Invalid token")
|
raise Exception(f"Invalid token")
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(f"An exception occurred - {traceback.format_exc()}")
|
print(f"An exception occurred - {traceback.format_exc()}")
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
detail={"error": "invalid user key"},
|
detail="invalid user key",
|
||||||
)
|
)
|
||||||
|
|
||||||
def prisma_setup(database_url: Optional[str]):
|
def prisma_setup(database_url: Optional[str]):
|
||||||
|
@ -464,7 +405,7 @@ def run_ollama_serve():
|
||||||
""")
|
""")
|
||||||
|
|
||||||
def load_router_config(router: Optional[litellm.Router], config_file_path: str):
|
def load_router_config(router: Optional[litellm.Router], config_file_path: str):
|
||||||
global master_key, user_config_file_path, otel_logging
|
global master_key, user_config_file_path, otel_logging, user_custom_auth
|
||||||
config = {}
|
config = {}
|
||||||
try:
|
try:
|
||||||
if os.path.exists(config_file_path):
|
if os.path.exists(config_file_path):
|
||||||
|
@ -499,7 +440,6 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str):
|
||||||
### LOAD FROM AZURE KEY VAULT ###
|
### LOAD FROM AZURE KEY VAULT ###
|
||||||
use_azure_key_vault = general_settings.get("use_azure_key_vault", False)
|
use_azure_key_vault = general_settings.get("use_azure_key_vault", False)
|
||||||
load_from_azure_key_vault(use_azure_key_vault=use_azure_key_vault)
|
load_from_azure_key_vault(use_azure_key_vault=use_azure_key_vault)
|
||||||
|
|
||||||
### CONNECT TO DATABASE ###
|
### CONNECT TO DATABASE ###
|
||||||
database_url = general_settings.get("database_url", None)
|
database_url = general_settings.get("database_url", None)
|
||||||
if database_url and database_url.startswith("os.environ/"):
|
if database_url and database_url.startswith("os.environ/"):
|
||||||
|
@ -514,12 +454,14 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str):
|
||||||
master_key = general_settings.get("master_key", None)
|
master_key = general_settings.get("master_key", None)
|
||||||
if master_key and master_key.startswith("os.environ/"):
|
if master_key and master_key.startswith("os.environ/"):
|
||||||
master_key = litellm.get_secret(master_key)
|
master_key = litellm.get_secret(master_key)
|
||||||
|
|
||||||
#### OpenTelemetry Logging (OTEL) ########
|
#### OpenTelemetry Logging (OTEL) ########
|
||||||
otel_logging = general_settings.get("otel", False)
|
otel_logging = general_settings.get("otel", False)
|
||||||
if otel_logging == True:
|
if otel_logging == True:
|
||||||
print("\nOpenTelemetry Logging Activated")
|
print("\nOpenTelemetry Logging Activated")
|
||||||
|
### CUSTOM API KEY AUTH ###
|
||||||
|
custom_auth = general_settings.get("custom_auth", None)
|
||||||
|
if custom_auth:
|
||||||
|
user_custom_auth = get_instance_fn(value=custom_auth, config_file_path=config_file_path)
|
||||||
## LITELLM MODULE SETTINGS (e.g. litellm.drop_params=True,..)
|
## LITELLM MODULE SETTINGS (e.g. litellm.drop_params=True,..)
|
||||||
litellm_settings = config.get('litellm_settings', None)
|
litellm_settings = config.get('litellm_settings', None)
|
||||||
if litellm_settings:
|
if litellm_settings:
|
||||||
|
@ -549,23 +491,7 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str):
|
||||||
password=cache_password
|
password=cache_password
|
||||||
)
|
)
|
||||||
elif key == "callbacks":
|
elif key == "callbacks":
|
||||||
print(f"{blue_color_code}\nSetting custom callbacks on Proxy")
|
litellm.callbacks = [get_instance_fn(value=value)]
|
||||||
passed_module, instance_name = value.split(".")
|
|
||||||
|
|
||||||
# Dynamically import the module
|
|
||||||
module = importlib.import_module(passed_module)
|
|
||||||
# Get the instance from the module
|
|
||||||
instance = getattr(module, instance_name)
|
|
||||||
|
|
||||||
methods = [method for method in dir(instance) if callable(getattr(instance, method))]
|
|
||||||
# Print the methods
|
|
||||||
print("Methods in the custom callbacks instance:")
|
|
||||||
for method in methods:
|
|
||||||
print(method)
|
|
||||||
|
|
||||||
litellm.callbacks = [instance]
|
|
||||||
print()
|
|
||||||
|
|
||||||
else:
|
else:
|
||||||
setattr(litellm, key, value)
|
setattr(litellm, key, value)
|
||||||
|
|
||||||
|
@ -844,7 +770,7 @@ def model_list():
|
||||||
@router.post("/v1/completions", dependencies=[Depends(user_api_key_auth)])
|
@router.post("/v1/completions", dependencies=[Depends(user_api_key_auth)])
|
||||||
@router.post("/completions", dependencies=[Depends(user_api_key_auth)])
|
@router.post("/completions", dependencies=[Depends(user_api_key_auth)])
|
||||||
@router.post("/engines/{model:path}/completions", dependencies=[Depends(user_api_key_auth)])
|
@router.post("/engines/{model:path}/completions", dependencies=[Depends(user_api_key_auth)])
|
||||||
async def completion(request: Request, model: Optional[str] = None, user_api_key_dict: dict = Depends(user_api_key_auth)):
|
async def completion(request: Request, model: Optional[str] = None, user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth)):
|
||||||
try:
|
try:
|
||||||
body = await request.body()
|
body = await request.body()
|
||||||
body_str = body.decode()
|
body_str = body.decode()
|
||||||
|
@ -853,7 +779,7 @@ async def completion(request: Request, model: Optional[str] = None, user_api_key
|
||||||
except:
|
except:
|
||||||
data = json.loads(body_str)
|
data = json.loads(body_str)
|
||||||
|
|
||||||
data["user"] = user_api_key_dict.get("user_id", None)
|
data["user"] = user_api_key_dict.user_id
|
||||||
data["model"] = (
|
data["model"] = (
|
||||||
general_settings.get("completion_model", None) # server default
|
general_settings.get("completion_model", None) # server default
|
||||||
or user_model # model name passed via cli args
|
or user_model # model name passed via cli args
|
||||||
|
@ -864,9 +790,9 @@ async def completion(request: Request, model: Optional[str] = None, user_api_key
|
||||||
data["model"] = user_model
|
data["model"] = user_model
|
||||||
data["call_type"] = "text_completion"
|
data["call_type"] = "text_completion"
|
||||||
if "metadata" in data:
|
if "metadata" in data:
|
||||||
data["metadata"]["user_api_key"] = user_api_key_dict["api_key"]
|
data["metadata"]["user_api_key"] = user_api_key_dict.api_key
|
||||||
else:
|
else:
|
||||||
data["metadata"] = {"user_api_key": user_api_key_dict["api_key"]}
|
data["metadata"] = {"user_api_key": user_api_key_dict.api_key}
|
||||||
|
|
||||||
return litellm_completion(
|
return litellm_completion(
|
||||||
**data
|
**data
|
||||||
|
@ -888,7 +814,7 @@ async def completion(request: Request, model: Optional[str] = None, user_api_key
|
||||||
@router.post("/v1/chat/completions", dependencies=[Depends(user_api_key_auth)], tags=["chat/completions"])
|
@router.post("/v1/chat/completions", dependencies=[Depends(user_api_key_auth)], tags=["chat/completions"])
|
||||||
@router.post("/chat/completions", dependencies=[Depends(user_api_key_auth)], tags=["chat/completions"])
|
@router.post("/chat/completions", dependencies=[Depends(user_api_key_auth)], tags=["chat/completions"])
|
||||||
@router.post("/openai/deployments/{model:path}/chat/completions", dependencies=[Depends(user_api_key_auth)], tags=["chat/completions"]) # azure compatible endpoint
|
@router.post("/openai/deployments/{model:path}/chat/completions", dependencies=[Depends(user_api_key_auth)], tags=["chat/completions"]) # azure compatible endpoint
|
||||||
async def chat_completion(request: Request, model: Optional[str] = None, user_api_key_dict: dict = Depends(user_api_key_auth), background_tasks: BackgroundTasks = BackgroundTasks()):
|
async def chat_completion(request: Request, model: Optional[str] = None, user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), background_tasks: BackgroundTasks = BackgroundTasks()):
|
||||||
global general_settings, user_debug
|
global general_settings, user_debug
|
||||||
try:
|
try:
|
||||||
data = {}
|
data = {}
|
||||||
|
@ -905,13 +831,13 @@ async def chat_completion(request: Request, model: Optional[str] = None, user_ap
|
||||||
# users can pass in 'user' param to /chat/completions. Don't override it
|
# users can pass in 'user' param to /chat/completions. Don't override it
|
||||||
if data.get("user", None) is None:
|
if data.get("user", None) is None:
|
||||||
# if users are using user_api_key_auth, set `user` in `data`
|
# if users are using user_api_key_auth, set `user` in `data`
|
||||||
data["user"] = user_api_key_dict.get("user_id", None)
|
data["user"] = user_api_key_dict.user_id
|
||||||
|
|
||||||
if "metadata" in data:
|
if "metadata" in data:
|
||||||
data["metadata"]["user_api_key"] = user_api_key_dict["api_key"]
|
data["metadata"]["user_api_key"] = user_api_key_dict.api_key
|
||||||
data["metadata"]["headers"] = request.headers
|
data["metadata"]["headers"] = request.headers
|
||||||
else:
|
else:
|
||||||
data["metadata"] = {"user_api_key": user_api_key_dict["api_key"]}
|
data["metadata"] = {"user_api_key": user_api_key_dict.api_key}
|
||||||
data["metadata"]["headers"] = request.headers
|
data["metadata"]["headers"] = request.headers
|
||||||
global user_temperature, user_request_timeout, user_max_tokens, user_api_base
|
global user_temperature, user_request_timeout, user_max_tokens, user_api_base
|
||||||
# override with user settings, these are params passed via cli
|
# override with user settings, these are params passed via cli
|
||||||
|
@ -962,14 +888,14 @@ async def chat_completion(request: Request, model: Optional[str] = None, user_ap
|
||||||
|
|
||||||
@router.post("/v1/embeddings", dependencies=[Depends(user_api_key_auth)], response_class=ORJSONResponse)
|
@router.post("/v1/embeddings", dependencies=[Depends(user_api_key_auth)], response_class=ORJSONResponse)
|
||||||
@router.post("/embeddings", dependencies=[Depends(user_api_key_auth)], response_class=ORJSONResponse)
|
@router.post("/embeddings", dependencies=[Depends(user_api_key_auth)], response_class=ORJSONResponse)
|
||||||
async def embeddings(request: Request, user_api_key_dict: dict = Depends(user_api_key_auth), background_tasks: BackgroundTasks = BackgroundTasks()):
|
async def embeddings(request: Request, user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth), background_tasks: BackgroundTasks = BackgroundTasks()):
|
||||||
try:
|
try:
|
||||||
|
|
||||||
# Use orjson to parse JSON data, orjson speeds up requests significantly
|
# Use orjson to parse JSON data, orjson speeds up requests significantly
|
||||||
body = await request.body()
|
body = await request.body()
|
||||||
data = orjson.loads(body)
|
data = orjson.loads(body)
|
||||||
|
|
||||||
data["user"] = user_api_key_dict.get("user_id", None)
|
data["user"] = user_api_key_dict.user_id
|
||||||
data["model"] = (
|
data["model"] = (
|
||||||
general_settings.get("embedding_model", None) # server default
|
general_settings.get("embedding_model", None) # server default
|
||||||
or user_model # model name passed via cli args
|
or user_model # model name passed via cli args
|
||||||
|
@ -978,9 +904,9 @@ async def embeddings(request: Request, user_api_key_dict: dict = Depends(user_ap
|
||||||
if user_model:
|
if user_model:
|
||||||
data["model"] = user_model
|
data["model"] = user_model
|
||||||
if "metadata" in data:
|
if "metadata" in data:
|
||||||
data["metadata"]["user_api_key"] = user_api_key_dict["api_key"]
|
data["metadata"]["user_api_key"] = user_api_key_dict.api_key
|
||||||
else:
|
else:
|
||||||
data["metadata"] = {"user_api_key": user_api_key_dict["api_key"]}
|
data["metadata"] = {"user_api_key": user_api_key_dict.api_key}
|
||||||
|
|
||||||
## ROUTE TO CORRECT ENDPOINT ##
|
## ROUTE TO CORRECT ENDPOINT ##
|
||||||
router_model_names = [m["model_name"] for m in llm_model_list] if llm_model_list is not None else []
|
router_model_names = [m["model_name"] for m in llm_model_list] if llm_model_list is not None else []
|
||||||
|
|
70
litellm/proxy/types.py
Normal file
70
litellm/proxy/types.py
Normal file
|
@ -0,0 +1,70 @@
|
||||||
|
from pydantic import BaseModel
|
||||||
|
from typing import Optional, List, Union, Dict
|
||||||
|
from datetime import datetime
|
||||||
|
|
||||||
|
######### Request Class Definition ######
|
||||||
|
class ProxyChatCompletionRequest(BaseModel):
|
||||||
|
model: str
|
||||||
|
messages: List[Dict[str, str]]
|
||||||
|
temperature: Optional[float] = None
|
||||||
|
top_p: Optional[float] = None
|
||||||
|
n: Optional[int] = None
|
||||||
|
stream: Optional[bool] = None
|
||||||
|
stop: Optional[List[str]] = None
|
||||||
|
max_tokens: Optional[int] = None
|
||||||
|
presence_penalty: Optional[float] = None
|
||||||
|
frequency_penalty: Optional[float] = None
|
||||||
|
logit_bias: Optional[Dict[str, float]] = None
|
||||||
|
user: Optional[str] = None
|
||||||
|
response_format: Optional[Dict[str, str]] = None
|
||||||
|
seed: Optional[int] = None
|
||||||
|
tools: Optional[List[str]] = None
|
||||||
|
tool_choice: Optional[str] = None
|
||||||
|
functions: Optional[List[str]] = None # soon to be deprecated
|
||||||
|
function_call: Optional[str] = None # soon to be deprecated
|
||||||
|
|
||||||
|
# Optional LiteLLM params
|
||||||
|
caching: Optional[bool] = None
|
||||||
|
api_base: Optional[str] = None
|
||||||
|
api_version: Optional[str] = None
|
||||||
|
api_key: Optional[str] = None
|
||||||
|
num_retries: Optional[int] = None
|
||||||
|
context_window_fallback_dict: Optional[Dict[str, str]] = None
|
||||||
|
fallbacks: Optional[List[str]] = None
|
||||||
|
metadata: Optional[Dict[str, str]] = {}
|
||||||
|
deployment_id: Optional[str] = None
|
||||||
|
request_timeout: Optional[int] = None
|
||||||
|
|
||||||
|
class Config:
|
||||||
|
extra='allow' # allow params not defined here, these fall in litellm.completion(**kwargs)
|
||||||
|
|
||||||
|
class ModelParams(BaseModel):
|
||||||
|
model_name: str
|
||||||
|
litellm_params: dict
|
||||||
|
model_info: Optional[dict]
|
||||||
|
class Config:
|
||||||
|
protected_namespaces = ()
|
||||||
|
|
||||||
|
class GenerateKeyRequest(BaseModel):
|
||||||
|
duration: str = "1h"
|
||||||
|
models: list = []
|
||||||
|
aliases: dict = {}
|
||||||
|
config: dict = {}
|
||||||
|
spend: int = 0
|
||||||
|
user_id: Optional[str] = None
|
||||||
|
|
||||||
|
class GenerateKeyResponse(BaseModel):
|
||||||
|
key: str
|
||||||
|
expires: datetime
|
||||||
|
user_id: str
|
||||||
|
|
||||||
|
class _DeleteKeyObject(BaseModel):
|
||||||
|
key: str
|
||||||
|
|
||||||
|
class DeleteKeyRequest(BaseModel):
|
||||||
|
keys: List[_DeleteKeyObject]
|
||||||
|
|
||||||
|
|
||||||
|
class UserAPIKeyAuth(BaseModel): # the expected response object for user api key auth
|
||||||
|
api_key: Optional[str] = None
|
||||||
|
user_id: Optional[str] = None
|
|
@ -1,6 +1,7 @@
|
||||||
from typing import Optional, List, Any
|
from typing import Optional, List, Any
|
||||||
import os, subprocess, hashlib
|
import os, subprocess, hashlib, importlib
|
||||||
|
|
||||||
|
### DB CONNECTOR ###
|
||||||
class PrismaClient:
|
class PrismaClient:
|
||||||
def __init__(self, database_url: str):
|
def __init__(self, database_url: str):
|
||||||
print("LiteLLM: DATABASE_URL Set in config, trying to 'pip install prisma'")
|
print("LiteLLM: DATABASE_URL Set in config, trying to 'pip install prisma'")
|
||||||
|
@ -95,3 +96,60 @@ class PrismaClient:
|
||||||
async def disconnect(self):
|
async def disconnect(self):
|
||||||
await self.db.disconnect()
|
await self.db.disconnect()
|
||||||
|
|
||||||
|
# ### CUSTOM FILE ###
|
||||||
|
# def get_instance_fn(value: str, config_file_path: Optional[str]=None):
|
||||||
|
# try:
|
||||||
|
# # Split the path by dots to separate module from instance
|
||||||
|
# parts = value.split(".")
|
||||||
|
# # The module path is all but the last part, and the instance is the last part
|
||||||
|
# module_path = ".".join(parts[:-1])
|
||||||
|
# instance_name = parts[-1]
|
||||||
|
|
||||||
|
# if config_file_path is not None:
|
||||||
|
# directory = os.path.dirname(config_file_path)
|
||||||
|
# module_path = os.path.join(directory, module_path)
|
||||||
|
# # Dynamically import the module
|
||||||
|
# module = importlib.import_module(module_path)
|
||||||
|
|
||||||
|
# # Get the instance from the module
|
||||||
|
# instance = getattr(module, instance_name)
|
||||||
|
|
||||||
|
# return instance
|
||||||
|
# except ImportError as e:
|
||||||
|
# print(e)
|
||||||
|
# raise ImportError(f"Could not import file at {value}")
|
||||||
|
|
||||||
|
def get_instance_fn(value: str, config_file_path: Optional[str] = None) -> Any:
|
||||||
|
try:
|
||||||
|
print(f"value: {value}")
|
||||||
|
# Split the path by dots to separate module from instance
|
||||||
|
parts = value.split(".")
|
||||||
|
|
||||||
|
# The module path is all but the last part, and the instance_name is the last part
|
||||||
|
module_name = ".".join(parts[:-1])
|
||||||
|
instance_name = parts[-1]
|
||||||
|
|
||||||
|
# If config_file_path is provided, use it to determine the module spec and load the module
|
||||||
|
if config_file_path is not None:
|
||||||
|
directory = os.path.dirname(config_file_path)
|
||||||
|
module_file_path = os.path.join(directory, *module_name.split('.'))
|
||||||
|
module_file_path += '.py'
|
||||||
|
|
||||||
|
spec = importlib.util.spec_from_file_location(module_name, module_file_path)
|
||||||
|
if spec is None:
|
||||||
|
raise ImportError(f"Could not find a module specification for {module_file_path}")
|
||||||
|
module = importlib.util.module_from_spec(spec)
|
||||||
|
spec.loader.exec_module(module)
|
||||||
|
else:
|
||||||
|
# Dynamically import the module
|
||||||
|
module = importlib.import_module(module_name)
|
||||||
|
|
||||||
|
# Get the instance from the module
|
||||||
|
instance = getattr(module, instance_name)
|
||||||
|
|
||||||
|
return instance
|
||||||
|
except ImportError as e:
|
||||||
|
# Print the error message for easier debugging
|
||||||
|
print(e)
|
||||||
|
# Re-raise the exception with a user-friendly message
|
||||||
|
raise ImportError(f"Could not import {instance_name} from {module_name}") from e
|
14
litellm/tests/test_configs/custom_auth.py
Normal file
14
litellm/tests/test_configs/custom_auth.py
Normal file
|
@ -0,0 +1,14 @@
|
||||||
|
from litellm.proxy.types import UserAPIKeyAuth
|
||||||
|
from fastapi import Request
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
import os
|
||||||
|
|
||||||
|
load_dotenv()
|
||||||
|
async def user_api_key_auth(request: Request, api_key: str) -> UserAPIKeyAuth:
|
||||||
|
try:
|
||||||
|
print(f"api_key: {api_key}")
|
||||||
|
if api_key == f"{os.getenv('PROXY_MASTER_KEY')}-1234":
|
||||||
|
return UserAPIKeyAuth(api_key=api_key)
|
||||||
|
raise Exception
|
||||||
|
except:
|
||||||
|
raise Exception
|
11
litellm/tests/test_configs/test_config_custom_auth.yaml
Normal file
11
litellm/tests/test_configs/test_config_custom_auth.yaml
Normal file
|
@ -0,0 +1,11 @@
|
||||||
|
model_list:
|
||||||
|
- model_name: "openai-model"
|
||||||
|
litellm_params:
|
||||||
|
model: "gpt-3.5-turbo"
|
||||||
|
|
||||||
|
litellm_settings:
|
||||||
|
drop_params: True
|
||||||
|
set_verbose: True
|
||||||
|
|
||||||
|
general_settings:
|
||||||
|
custom_auth: custom_auth.user_api_key_auth
|
63
litellm/tests/test_proxy_custom_auth.py
Normal file
63
litellm/tests/test_proxy_custom_auth.py
Normal file
|
@ -0,0 +1,63 @@
|
||||||
|
import sys, os
|
||||||
|
import traceback
|
||||||
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
|
load_dotenv()
|
||||||
|
import os, io
|
||||||
|
|
||||||
|
# this file is to test litellm/proxy
|
||||||
|
|
||||||
|
sys.path.insert(
|
||||||
|
0, os.path.abspath("../..")
|
||||||
|
) # Adds the parent directory to the system path
|
||||||
|
import pytest
|
||||||
|
import litellm
|
||||||
|
from litellm import embedding, completion, completion_cost, Timeout
|
||||||
|
from litellm import RateLimitError
|
||||||
|
|
||||||
|
# test /chat/completion request to the proxy
|
||||||
|
from fastapi.testclient import TestClient
|
||||||
|
from fastapi import FastAPI
|
||||||
|
from litellm.proxy.proxy_server import router, save_worker_config, startup_event # Replace with the actual module where your FastAPI router is defined
|
||||||
|
filepath = os.path.dirname(os.path.abspath(__file__))
|
||||||
|
config_fp = f"{filepath}/test_configs/test_config_custom_auth.yaml"
|
||||||
|
save_worker_config(config=config_fp, model=None, alias=None, api_base=None, api_version=None, debug=False, temperature=None, max_tokens=None, request_timeout=600, max_budget=None, telemetry=False, drop_params=True, add_function_to_prompt=False, headers=None, save=False, use_queue=False)
|
||||||
|
app = FastAPI()
|
||||||
|
app.include_router(router) # Include your router in the test app
|
||||||
|
@app.on_event("startup")
|
||||||
|
async def wrapper_startup_event():
|
||||||
|
await startup_event()
|
||||||
|
|
||||||
|
# Here you create a fixture that will be used by your tests
|
||||||
|
# Make sure the fixture returns TestClient(app)
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def client():
|
||||||
|
with TestClient(app) as client:
|
||||||
|
yield client
|
||||||
|
|
||||||
|
def test_custom_auth(client):
|
||||||
|
try:
|
||||||
|
# Your test data
|
||||||
|
test_data = {
|
||||||
|
"model": "openai-model",
|
||||||
|
"messages": [
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": "hi"
|
||||||
|
},
|
||||||
|
],
|
||||||
|
"max_tokens": 10,
|
||||||
|
}
|
||||||
|
# Your bearer token
|
||||||
|
token = os.getenv("PROXY_MASTER_KEY")
|
||||||
|
|
||||||
|
headers = {
|
||||||
|
"Authorization": f"Bearer {token}"
|
||||||
|
}
|
||||||
|
response = client.post("/chat/completions", json=test_data, headers=headers)
|
||||||
|
print(f"response: {response.text}")
|
||||||
|
assert response.status_code == 401
|
||||||
|
result = response.json()
|
||||||
|
print(f"Received response: {result}")
|
||||||
|
except Exception as e:
|
||||||
|
pytest.fail("LiteLLM Proxy test failed. Exception", e)
|
|
@ -25,7 +25,7 @@ from fastapi.testclient import TestClient
|
||||||
from fastapi import FastAPI
|
from fastapi import FastAPI
|
||||||
from litellm.proxy.proxy_server import router, save_worker_config, startup_event # Replace with the actual module where your FastAPI router is defined
|
from litellm.proxy.proxy_server import router, save_worker_config, startup_event # Replace with the actual module where your FastAPI router is defined
|
||||||
filepath = os.path.dirname(os.path.abspath(__file__))
|
filepath = os.path.dirname(os.path.abspath(__file__))
|
||||||
config_fp = f"{filepath}/test_config.yaml"
|
config_fp = f"{filepath}/test_configs/test_config.yaml"
|
||||||
save_worker_config(config=config_fp, model=None, alias=None, api_base=None, api_version=None, debug=False, temperature=None, max_tokens=None, request_timeout=600, max_budget=None, telemetry=False, drop_params=True, add_function_to_prompt=False, headers=None, save=False, use_queue=False)
|
save_worker_config(config=config_fp, model=None, alias=None, api_base=None, api_version=None, debug=False, temperature=None, max_tokens=None, request_timeout=600, max_budget=None, telemetry=False, drop_params=True, add_function_to_prompt=False, headers=None, save=False, use_queue=False)
|
||||||
app = FastAPI()
|
app = FastAPI()
|
||||||
app.include_router(router) # Include your router in the test app
|
app.include_router(router) # Include your router in the test app
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue