forked from phoenix/litellm-mirror
test(tests/): add unit testing for proxy server endpoints
This commit is contained in:
parent
b2b41727ce
commit
f5ced089d6
11 changed files with 870 additions and 111 deletions
|
@ -264,16 +264,6 @@ async def user_api_key_auth(
|
|||
if route.startswith("/config/") and not is_master_key_valid:
|
||||
raise Exception(f"Only admin can modify config")
|
||||
|
||||
if (
|
||||
(route.startswith("/key/") or route.startswith("/user/"))
|
||||
or route.startswith("/model/")
|
||||
and not is_master_key_valid
|
||||
and general_settings.get("allow_user_auth", False) != True
|
||||
):
|
||||
raise Exception(
|
||||
f"If master key is set, only master key can be used to generate, delete, update or get info for new keys/users"
|
||||
)
|
||||
|
||||
if (
|
||||
prisma_client is None and custom_db_client is None
|
||||
): # if both master key + user key submitted, and user key != master key, and no db connected, raise an error
|
||||
|
@ -432,6 +422,39 @@ async def user_api_key_auth(
|
|||
db=custom_db_client,
|
||||
)
|
||||
)
|
||||
|
||||
if (
|
||||
(route.startswith("/key/") or route.startswith("/user/"))
|
||||
or route.startswith("/model/")
|
||||
and not is_master_key_valid
|
||||
and general_settings.get("allow_user_auth", False) != True
|
||||
):
|
||||
if route == "/key/info":
|
||||
# check if user can access this route
|
||||
query_params = request.query_params
|
||||
key = query_params.get("key")
|
||||
if prisma_client.hash_token(token=key) != api_key:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="user not allowed to access this key's info",
|
||||
)
|
||||
elif route == "/user/info":
|
||||
# check if user can access this route
|
||||
query_params = request.query_params
|
||||
user_id = query_params.get("user_id")
|
||||
if user_id != valid_token.user_id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="user not allowed to access this key's info",
|
||||
)
|
||||
elif route == "/model/info":
|
||||
# /model/info just shows models user has access to
|
||||
pass
|
||||
else:
|
||||
raise Exception(
|
||||
f"If master key is set, only master key can be used to generate, delete, update or get info for new keys/users"
|
||||
)
|
||||
|
||||
return UserAPIKeyAuth(api_key=api_key, **valid_token_dict)
|
||||
else:
|
||||
raise Exception(f"Invalid Key Passed to LiteLLM Proxy")
|
||||
|
@ -2160,7 +2183,7 @@ async def update_key_fn(request: Request, data: UpdateKeyRequest):
|
|||
@router.post(
|
||||
"/key/delete", tags=["key management"], dependencies=[Depends(user_api_key_auth)]
|
||||
)
|
||||
async def delete_key_fn(request: Request, data: DeleteKeyRequest):
|
||||
async def delete_key_fn(data: DeleteKeyRequest):
|
||||
"""
|
||||
Delete a key from the key management system.
|
||||
|
||||
|
@ -2203,6 +2226,9 @@ async def info_key_fn(
|
|||
f"Database not connected. Connect a database to your proxy - https://docs.litellm.ai/docs/simple_proxy#managing-auth---virtual-keys"
|
||||
)
|
||||
key_info = await prisma_client.get_data(token=key)
|
||||
## REMOVE HASHED TOKEN INFO BEFORE RETURNING ##
|
||||
key_info = key_info.model_dump()
|
||||
key_info.pop("token")
|
||||
return {"key": key, "info": key_info}
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
|
@ -2338,6 +2364,10 @@ async def user_info(
|
|||
keys = await prisma_client.get_data(
|
||||
user_id=user_id, table_name="key", query_type="find_all"
|
||||
)
|
||||
## REMOVE HASHED TOKEN INFO before returning ##
|
||||
for key in keys:
|
||||
key = key.model_dump()
|
||||
key.pop("token", None)
|
||||
return {"user_id": user_id, "user_info": user_info, "keys": keys}
|
||||
except Exception as e:
|
||||
raise HTTPException(
|
||||
|
@ -2415,13 +2445,19 @@ async def add_new_model(model_params: ModelParams):
|
|||
tags=["model management"],
|
||||
dependencies=[Depends(user_api_key_auth)],
|
||||
)
|
||||
async def model_info_v1(request: Request):
|
||||
async def model_info_v1(
|
||||
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
||||
):
|
||||
global llm_model_list, general_settings, user_config_file_path, proxy_config
|
||||
|
||||
# Load existing config
|
||||
config = await proxy_config.get_config()
|
||||
|
||||
all_models = config["model_list"]
|
||||
if len(user_api_key_dict.models) > 0:
|
||||
model_names = user_api_key_dict.models
|
||||
all_models = [m for m in config["model_list"] if m in model_names]
|
||||
else:
|
||||
all_models = config["model_list"]
|
||||
for model in all_models:
|
||||
# provided model_info in config.yaml
|
||||
model_info = model.get("model_info", {})
|
||||
|
@ -2750,7 +2786,7 @@ async def test_endpoint(request: Request):
|
|||
|
||||
@router.get("/health", tags=["health"], dependencies=[Depends(user_api_key_auth)])
|
||||
async def health_endpoint(
|
||||
request: Request,
|
||||
user_api_key_dict: UserAPIKeyAuth = Depends(user_api_key_auth),
|
||||
model: Optional[str] = fastapi.Query(
|
||||
None, description="Specify the model name (optional)"
|
||||
),
|
||||
|
@ -2785,6 +2821,11 @@ async def health_endpoint(
|
|||
detail={"error": "Model list not initialized"},
|
||||
)
|
||||
|
||||
### FILTER MODELS FOR ONLY THOSE USER HAS ACCESS TO ###
|
||||
if len(user_api_key_dict.models) > 0:
|
||||
allowed_model_names = user_api_key_dict.models
|
||||
else:
|
||||
allowed_model_names = [] #
|
||||
if use_background_health_checks:
|
||||
return health_check_results
|
||||
else:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue