fix(proxy_server.py): fix /key/generate post endpoint

This commit is contained in:
Krrish Dholakia 2023-12-04 10:43:42 -08:00
parent 89972239a6
commit 813bb15a00
6 changed files with 115 additions and 27 deletions

View file

@ -213,11 +213,11 @@ class GenerateKeyRequest(BaseModel):
aliases: dict = {}
config: dict = {}
spend: int = 0
user_id: Optional[str]
user_id: Optional[str] = None
class GenerateKeyResponse(BaseModel):
key: str
expires: str
expires: datetime
user_id: str
class _DeleteKeyObject(BaseModel):
@ -277,6 +277,8 @@ async def user_api_key_auth(request: Request, api_key: str = fastapi.Security(ap
"api_key": None
}
try:
if api_key is None:
raise Exception("No api key passed in.")
route = request.url.path
# note: never string compare api keys, this is vulenerable to a time attack. Use secrets.compare_digest instead
@ -491,8 +493,9 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str):
## PRINT YAML FOR CONFIRMING IT WORKS
printed_yaml = copy.deepcopy(config)
printed_yaml.pop("environment_variables", None)
for model in printed_yaml["model_list"]:
model["litellm_params"].pop("api_key", None)
if "model_list" in printed_yaml:
for model in printed_yaml["model_list"]:
model["litellm_params"].pop("api_key", None)
print(f"Loaded config YAML (api_key and environment_variables are not shown):\n{json.dumps(printed_yaml, indent=2)}")
@ -507,22 +510,24 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str):
if general_settings is None:
general_settings = {}
if general_settings:
### MASTER KEY ###
master_key = general_settings.get("master_key", None)
if master_key and master_key.startswith("os.environ/"):
master_key_env_name = master_key.replace("os.environ/", "")
master_key = os.getenv(master_key_env_name)
### LOAD FROM AZURE KEY VAULT ###
use_azure_key_vault = general_settings.get("use_azure_key_vault", False)
load_from_azure_key_vault(use_azure_key_vault=use_azure_key_vault)
### CONNECT TO DATABASE ###
database_url = general_settings.get("database_url", None)
if database_url and database_url.startswith("os.environ/"):
database_url = litellm.get_secret(database_url)
prisma_setup(database_url=database_url)
## COST TRACKING ##
cost_tracking()
### START REDIS QUEUE ###
use_queue = general_settings.get("use_queue", False)
celery_setup(use_queue=use_queue)
### LOAD FROM AZURE KEY VAULT ###
use_azure_key_vault = general_settings.get("use_azure_key_vault", False)
load_from_azure_key_vault(use_azure_key_vault=use_azure_key_vault)
### MASTER KEY ###
master_key = general_settings.get("master_key", None)
if master_key and master_key.startswith("os.environ/"):
master_key = litellm.get_secret(master_key)
#### OpenTelemetry Logging (OTEL) ########
otel_logging = general_settings.get("otel", False)
@ -540,9 +545,9 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str):
print(f"{blue_color_code}\nSetting Cache on Proxy")
from litellm.caching import Cache
cache_type = value["type"]
cache_host = os.environ.get("REDIS_HOST")
cache_port = os.environ.get("REDIS_PORT")
cache_password = os.environ.get("REDIS_PASSWORD")
cache_host = litellm.get_secret("REDIS_HOST")
cache_port = litellm.get_secret("REDIS_PORT")
cache_password = litellm.get_secret("REDIS_PASSWORD")
# Assuming cache_type, cache_host, cache_port, and cache_password are strings
print(f"{blue_color_code}Cache Type:{reset_color_code} {cache_type}")
@ -794,12 +799,14 @@ def litellm_completion(*args, **kwargs):
return StreamingResponse(data_generator(response), media_type='text/event-stream')
return response
@app.on_event("startup")
@router.on_event("startup")
async def startup_event():
global prisma_client, master_key
import json
worker_config = json.loads(os.getenv("WORKER_CONFIG"))
print(f"worker_config: {worker_config}")
initialize(**worker_config)
print(f"prisma client - {prisma_client}")
if prisma_client:
await prisma_client.connect()
@ -807,7 +814,7 @@ async def startup_event():
# add master key to db
await generate_key_helper_fn(duration_str=None, models=[], aliases={}, config={}, spend=0, token=master_key)
@app.on_event("shutdown")
@router.on_event("shutdown")
async def shutdown_event():
global prisma_client
if prisma_client:
@ -1022,8 +1029,7 @@ async def generate_key_fn(request: Request, data: GenerateKeyRequest):
- key: The generated api key
- expires: Datetime object for when key expires.
"""
data = await request.json()
# data = await request.json()
duration_str = data.duration # Default to 1 hour if duration is not provided
models = data.models # Default to an empty list (meaning allow token to call all models)
aliases = data.aliases # Default to an empty dict (no alias mappings, on top of anything in the config.yaml model_list)
@ -1042,8 +1048,6 @@ async def generate_key_fn(request: Request, data: GenerateKeyRequest):
@router.post("/key/delete", tags=["key management"], dependencies=[Depends(user_api_key_auth)])
async def delete_key_fn(request: Request, data: DeleteKeyRequest):
try:
data = await request.json()
keys = data.keys
deleted_keys = await delete_verification_token(tokens=keys)