mirror of
https://github.com/BerriAI/litellm.git
synced 2025-04-25 18:54:30 +00:00
feat(proxy_server.py): enable token based authentication for server endpoints
This commit is contained in:
parent
8ae855e008
commit
c02794d3ff
1 changed files with 39 additions and 16 deletions
|
@ -134,6 +134,7 @@ server_settings: dict = {}
|
||||||
log_file = "api_log.json"
|
log_file = "api_log.json"
|
||||||
worker_config = None
|
worker_config = None
|
||||||
master_key = None
|
master_key = None
|
||||||
|
prisma_client = None
|
||||||
#### HELPER FUNCTIONS ####
|
#### HELPER FUNCTIONS ####
|
||||||
def print_verbose(print_statement):
|
def print_verbose(print_statement):
|
||||||
global user_debug
|
global user_debug
|
||||||
|
@ -151,21 +152,36 @@ def usage_telemetry(
|
||||||
).start()
|
).start()
|
||||||
|
|
||||||
async def user_api_key_auth(request: Request):
|
async def user_api_key_auth(request: Request):
|
||||||
global master_key
|
global master_key, prisma_client
|
||||||
if master_key is None:
|
if master_key is None:
|
||||||
return
|
return
|
||||||
try:
|
try:
|
||||||
api_key = await oauth2_scheme(request=request)
|
api_key = await oauth2_scheme(request=request)
|
||||||
if api_key == master_key:
|
if api_key == master_key:
|
||||||
return
|
return
|
||||||
except:
|
print(f"prisma_client: {prisma_client}")
|
||||||
pass
|
if prisma_client:
|
||||||
|
await prisma_client.connect()
|
||||||
|
valid_token = await prisma_client.litellm_verificationtoken.find_first(
|
||||||
|
where={
|
||||||
|
"token": api_key,
|
||||||
|
"expires": {"gte": datetime.utcnow()} # Check if the token is not expired
|
||||||
|
}
|
||||||
|
)
|
||||||
|
print(f"valid_token: {valid_token}")
|
||||||
|
if valid_token:
|
||||||
|
return
|
||||||
|
else:
|
||||||
|
raise Exception
|
||||||
|
except Exception as e:
|
||||||
|
print(e)
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_401_UNAUTHORIZED,
|
status_code=status.HTTP_401_UNAUTHORIZED,
|
||||||
detail={"error": "invalid user key"},
|
detail={"error": "invalid user key"},
|
||||||
)
|
)
|
||||||
|
|
||||||
def add_keys_to_config(key, value):
|
def add_keys_to_config(key, value):
|
||||||
|
#### DEPRECATED #### - this uses the older .toml config approach, which has been deprecated for config.yaml
|
||||||
# Check if file exists
|
# Check if file exists
|
||||||
if os.path.exists(user_config_path):
|
if os.path.exists(user_config_path):
|
||||||
# Load existing file
|
# Load existing file
|
||||||
|
@ -184,6 +200,7 @@ def add_keys_to_config(key, value):
|
||||||
|
|
||||||
|
|
||||||
def save_params_to_config(data: dict):
|
def save_params_to_config(data: dict):
|
||||||
|
#### DEPRECATED #### - this uses the older .toml config approach, which has been deprecated for config.yaml
|
||||||
# Check if file exists
|
# Check if file exists
|
||||||
if os.path.exists(user_config_path):
|
if os.path.exists(user_config_path):
|
||||||
# Load existing file
|
# Load existing file
|
||||||
|
@ -215,13 +232,14 @@ def save_params_to_config(data: dict):
|
||||||
with open(user_config_path, "wb") as f:
|
with open(user_config_path, "wb") as f:
|
||||||
tomli_w.dump(config, f)
|
tomli_w.dump(config, f)
|
||||||
|
|
||||||
def prisma_setup(database_url: Optional[str]):
|
def prisma_setup():
|
||||||
if database_url:
|
global prisma_client
|
||||||
subprocess.run(['pip', 'install', 'prisma'])
|
subprocess.run(['pip', 'install', 'prisma'])
|
||||||
subprocess.run(['python3', '-m', 'pip', 'install', 'prisma'])
|
subprocess.run(['python3', '-m', 'pip', 'install', 'prisma'])
|
||||||
subprocess.run(['prisma', 'db', 'push'])
|
subprocess.run(['prisma', 'db', 'push'])
|
||||||
# Now you can import the Prisma Client
|
# Now you can import the Prisma Client
|
||||||
from prisma import Client
|
from prisma import Client
|
||||||
|
prisma_client = Client()
|
||||||
|
|
||||||
|
|
||||||
def load_router_config(router: Optional[litellm.Router], config_file_path: str):
|
def load_router_config(router: Optional[litellm.Router], config_file_path: str):
|
||||||
|
@ -244,6 +262,9 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str):
|
||||||
if environment_variables:
|
if environment_variables:
|
||||||
for key, value in environment_variables.items():
|
for key, value in environment_variables.items():
|
||||||
os.environ[key] = value
|
os.environ[key] = value
|
||||||
|
### CONNECT TO DATABASE ###
|
||||||
|
if key == "DATABASE_URL":
|
||||||
|
prisma_setup()
|
||||||
|
|
||||||
|
|
||||||
## GENERAL SERVER SETTINGS (e.g. master key,..)
|
## GENERAL SERVER SETTINGS (e.g. master key,..)
|
||||||
|
@ -251,9 +272,6 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str):
|
||||||
if general_settings:
|
if general_settings:
|
||||||
### MASTER KEY ###
|
### MASTER KEY ###
|
||||||
master_key = general_settings.get("master_key", None)
|
master_key = general_settings.get("master_key", None)
|
||||||
### CONNECT TO DATABASE ###
|
|
||||||
database_url = general_settings.get("database_url", None)
|
|
||||||
prisma_setup(database_url=database_url)
|
|
||||||
|
|
||||||
|
|
||||||
## LITELLM MODULE SETTINGS (e.g. litellm.drop_params=True,..)
|
## LITELLM MODULE SETTINGS (e.g. litellm.drop_params=True,..)
|
||||||
|
@ -493,7 +511,13 @@ def startup_event():
|
||||||
import json
|
import json
|
||||||
worker_config = json.loads(os.getenv("WORKER_CONFIG"))
|
worker_config = json.loads(os.getenv("WORKER_CONFIG"))
|
||||||
initialize(**worker_config)
|
initialize(**worker_config)
|
||||||
# print(f"\033[32mWorker Initialized\033[0m\n")
|
|
||||||
|
@app.on_event("shutdown")
|
||||||
|
async def shutdown_event():
|
||||||
|
global prisma_client
|
||||||
|
if prisma_client:
|
||||||
|
print("Disconnecting from Prisma")
|
||||||
|
await prisma_client.disconnect()
|
||||||
|
|
||||||
#### API ENDPOINTS ####
|
#### API ENDPOINTS ####
|
||||||
@router.get("/v1/models", dependencies=[Depends(user_api_key_auth)])
|
@router.get("/v1/models", dependencies=[Depends(user_api_key_auth)])
|
||||||
|
@ -629,8 +653,7 @@ async def generate_key(request: Request):
|
||||||
duration = _duration_in_seconds(duration=duration_str)
|
duration = _duration_in_seconds(duration=duration_str)
|
||||||
expires = datetime.utcnow() + timedelta(seconds=duration)
|
expires = datetime.utcnow() + timedelta(seconds=duration)
|
||||||
try:
|
try:
|
||||||
from prisma import Client
|
db = prisma_client
|
||||||
db = Client()
|
|
||||||
await db.connect()
|
await db.connect()
|
||||||
# Create a new verification token (you may want to enhance this logic based on your needs)
|
# Create a new verification token (you may want to enhance this logic based on your needs)
|
||||||
print(dir(db))
|
print(dir(db))
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue