forked from phoenix/litellm-mirror
fix(proxy_server.py): fix /key/generate post endpoint
This commit is contained in:
parent
d7d8c5f6e6
commit
63e55f1865
6 changed files with 115 additions and 27 deletions
|
@ -213,11 +213,11 @@ class GenerateKeyRequest(BaseModel):
|
|||
aliases: dict = {}
|
||||
config: dict = {}
|
||||
spend: int = 0
|
||||
user_id: Optional[str]
|
||||
user_id: Optional[str] = None
|
||||
|
||||
class GenerateKeyResponse(BaseModel):
|
||||
key: str
|
||||
expires: str
|
||||
expires: datetime
|
||||
user_id: str
|
||||
|
||||
class _DeleteKeyObject(BaseModel):
|
||||
|
@ -277,6 +277,8 @@ async def user_api_key_auth(request: Request, api_key: str = fastapi.Security(ap
|
|||
"api_key": None
|
||||
}
|
||||
try:
|
||||
if api_key is None:
|
||||
raise Exception("No api key passed in.")
|
||||
route = request.url.path
|
||||
|
||||
# note: never string compare api keys, this is vulenerable to a time attack. Use secrets.compare_digest instead
|
||||
|
@ -491,8 +493,9 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str):
|
|||
## PRINT YAML FOR CONFIRMING IT WORKS
|
||||
printed_yaml = copy.deepcopy(config)
|
||||
printed_yaml.pop("environment_variables", None)
|
||||
for model in printed_yaml["model_list"]:
|
||||
model["litellm_params"].pop("api_key", None)
|
||||
if "model_list" in printed_yaml:
|
||||
for model in printed_yaml["model_list"]:
|
||||
model["litellm_params"].pop("api_key", None)
|
||||
|
||||
print(f"Loaded config YAML (api_key and environment_variables are not shown):\n{json.dumps(printed_yaml, indent=2)}")
|
||||
|
||||
|
@ -507,22 +510,24 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str):
|
|||
if general_settings is None:
|
||||
general_settings = {}
|
||||
if general_settings:
|
||||
### MASTER KEY ###
|
||||
master_key = general_settings.get("master_key", None)
|
||||
if master_key and master_key.startswith("os.environ/"):
|
||||
master_key_env_name = master_key.replace("os.environ/", "")
|
||||
master_key = os.getenv(master_key_env_name)
|
||||
### LOAD FROM AZURE KEY VAULT ###
|
||||
use_azure_key_vault = general_settings.get("use_azure_key_vault", False)
|
||||
load_from_azure_key_vault(use_azure_key_vault=use_azure_key_vault)
|
||||
|
||||
### CONNECT TO DATABASE ###
|
||||
database_url = general_settings.get("database_url", None)
|
||||
if database_url and database_url.startswith("os.environ/"):
|
||||
database_url = litellm.get_secret(database_url)
|
||||
prisma_setup(database_url=database_url)
|
||||
## COST TRACKING ##
|
||||
cost_tracking()
|
||||
### START REDIS QUEUE ###
|
||||
use_queue = general_settings.get("use_queue", False)
|
||||
celery_setup(use_queue=use_queue)
|
||||
### LOAD FROM AZURE KEY VAULT ###
|
||||
use_azure_key_vault = general_settings.get("use_azure_key_vault", False)
|
||||
load_from_azure_key_vault(use_azure_key_vault=use_azure_key_vault)
|
||||
### MASTER KEY ###
|
||||
master_key = general_settings.get("master_key", None)
|
||||
if master_key and master_key.startswith("os.environ/"):
|
||||
master_key = litellm.get_secret(master_key)
|
||||
|
||||
#### OpenTelemetry Logging (OTEL) ########
|
||||
otel_logging = general_settings.get("otel", False)
|
||||
|
@ -540,9 +545,9 @@ def load_router_config(router: Optional[litellm.Router], config_file_path: str):
|
|||
print(f"{blue_color_code}\nSetting Cache on Proxy")
|
||||
from litellm.caching import Cache
|
||||
cache_type = value["type"]
|
||||
cache_host = os.environ.get("REDIS_HOST")
|
||||
cache_port = os.environ.get("REDIS_PORT")
|
||||
cache_password = os.environ.get("REDIS_PASSWORD")
|
||||
cache_host = litellm.get_secret("REDIS_HOST")
|
||||
cache_port = litellm.get_secret("REDIS_PORT")
|
||||
cache_password = litellm.get_secret("REDIS_PASSWORD")
|
||||
|
||||
# Assuming cache_type, cache_host, cache_port, and cache_password are strings
|
||||
print(f"{blue_color_code}Cache Type:{reset_color_code} {cache_type}")
|
||||
|
@ -794,12 +799,14 @@ def litellm_completion(*args, **kwargs):
|
|||
return StreamingResponse(data_generator(response), media_type='text/event-stream')
|
||||
return response
|
||||
|
||||
@app.on_event("startup")
|
||||
@router.on_event("startup")
|
||||
async def startup_event():
|
||||
global prisma_client, master_key
|
||||
import json
|
||||
worker_config = json.loads(os.getenv("WORKER_CONFIG"))
|
||||
print(f"worker_config: {worker_config}")
|
||||
initialize(**worker_config)
|
||||
print(f"prisma client - {prisma_client}")
|
||||
if prisma_client:
|
||||
await prisma_client.connect()
|
||||
|
||||
|
@ -807,7 +814,7 @@ async def startup_event():
|
|||
# add master key to db
|
||||
await generate_key_helper_fn(duration_str=None, models=[], aliases={}, config={}, spend=0, token=master_key)
|
||||
|
||||
@app.on_event("shutdown")
|
||||
@router.on_event("shutdown")
|
||||
async def shutdown_event():
|
||||
global prisma_client
|
||||
if prisma_client:
|
||||
|
@ -1022,8 +1029,7 @@ async def generate_key_fn(request: Request, data: GenerateKeyRequest):
|
|||
- key: The generated api key
|
||||
- expires: Datetime object for when key expires.
|
||||
"""
|
||||
data = await request.json()
|
||||
|
||||
# data = await request.json()
|
||||
duration_str = data.duration # Default to 1 hour if duration is not provided
|
||||
models = data.models # Default to an empty list (meaning allow token to call all models)
|
||||
aliases = data.aliases # Default to an empty dict (no alias mappings, on top of anything in the config.yaml model_list)
|
||||
|
@ -1042,8 +1048,6 @@ async def generate_key_fn(request: Request, data: GenerateKeyRequest):
|
|||
@router.post("/key/delete", tags=["key management"], dependencies=[Depends(user_api_key_auth)])
|
||||
async def delete_key_fn(request: Request, data: DeleteKeyRequest):
|
||||
try:
|
||||
data = await request.json()
|
||||
|
||||
keys = data.keys
|
||||
|
||||
deleted_keys = await delete_verification_token(tokens=keys)
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue