forked from phoenix/litellm-mirror
fix raise correct error 404 when /key/info is called on non-existent key (#6653)
* fix raise correct error on /key/info * add not_found_error error * fix key not found in DB error * use 1 helper for checking token hash * fix error code on key info * fix test key gen prisma * test_generate_and_call_key_info * test fix test_call_with_valid_model_using_all_models * fix key info tests
This commit is contained in:
parent
25bae4cc23
commit
de2f9aed3a
8 changed files with 3593 additions and 57 deletions
|
@ -1894,6 +1894,7 @@ class ProxyErrorTypes(str, enum.Enum):
|
|||
auth_error = "auth_error"
|
||||
internal_server_error = "internal_server_error"
|
||||
bad_request_error = "bad_request_error"
|
||||
not_found_error = "not_found_error"
|
||||
|
||||
|
||||
class SSOUserDefinedValues(TypedDict):
|
||||
|
|
|
@ -44,14 +44,8 @@ class RouteChecks:
|
|||
route in LiteLLMRoutes.info_routes.value
|
||||
): # check if user allowed to call an info route
|
||||
if route == "/key/info":
|
||||
# check if user can access this route
|
||||
query_params = request.query_params
|
||||
key = query_params.get("key")
|
||||
if key is not None and hash_token(token=key) != api_key:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="user not allowed to access this key's info",
|
||||
)
|
||||
# handled by function itself
|
||||
pass
|
||||
elif route == "/user/info":
|
||||
# check if user can access this route
|
||||
query_params = request.query_params
|
||||
|
|
|
@ -32,7 +32,7 @@ from litellm.proxy.auth.auth_checks import (
|
|||
)
|
||||
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
|
||||
from litellm.proxy.management_helpers.utils import management_endpoint_wrapper
|
||||
from litellm.proxy.utils import _duration_in_seconds
|
||||
from litellm.proxy.utils import _duration_in_seconds, _hash_token_if_needed
|
||||
from litellm.secret_managers.main import get_secret
|
||||
|
||||
router = APIRouter()
|
||||
|
@ -734,13 +734,37 @@ async def info_key_fn(
|
|||
raise Exception(
|
||||
"Database not connected. Connect a database to your proxy - https://docs.litellm.ai/docs/simple_proxy#managing-auth---virtual-keys"
|
||||
)
|
||||
if key is None:
|
||||
key = user_api_key_dict.api_key
|
||||
key_info = await prisma_client.get_data(token=key)
|
||||
|
||||
# default to using Auth token if no key is passed in
|
||||
key = key or user_api_key_dict.api_key
|
||||
hashed_key: Optional[str] = key
|
||||
if key is not None:
|
||||
hashed_key = _hash_token_if_needed(token=key)
|
||||
key_info = await prisma_client.db.litellm_verificationtoken.find_unique(
|
||||
where={"token": hashed_key}, # type: ignore
|
||||
include={"litellm_budget_table": True},
|
||||
)
|
||||
if key_info is None:
|
||||
raise ProxyException(
|
||||
message="Key not found in database",
|
||||
type=ProxyErrorTypes.not_found_error,
|
||||
param="key",
|
||||
code=status.HTTP_404_NOT_FOUND,
|
||||
)
|
||||
|
||||
if (
|
||||
_can_user_query_key_info(
|
||||
user_api_key_dict=user_api_key_dict,
|
||||
key=key,
|
||||
key_info=key_info,
|
||||
)
|
||||
is not True
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_404_NOT_FOUND,
|
||||
detail={"message": "No keys found"},
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="You are not allowed to access this key's info. Your role={}".format(
|
||||
user_api_key_dict.user_role
|
||||
),
|
||||
)
|
||||
## REMOVE HASHED TOKEN INFO BEFORE RETURNING ##
|
||||
try:
|
||||
|
@ -1540,6 +1564,27 @@ async def key_health(
|
|||
)
|
||||
|
||||
|
||||
def _can_user_query_key_info(
|
||||
user_api_key_dict: UserAPIKeyAuth,
|
||||
key: Optional[str],
|
||||
key_info: LiteLLM_VerificationToken,
|
||||
) -> bool:
|
||||
"""
|
||||
Helper to check if the user has access to the key's info
|
||||
"""
|
||||
if (
|
||||
user_api_key_dict.user_role == LitellmUserRoles.PROXY_ADMIN.value
|
||||
or user_api_key_dict.user_role == LitellmUserRoles.PROXY_ADMIN_VIEW_ONLY.value
|
||||
):
|
||||
return True
|
||||
elif user_api_key_dict.api_key == key:
|
||||
return True
|
||||
# user can query their own key info
|
||||
elif key_info.user_id == user_api_key_dict.user_id:
|
||||
return True
|
||||
return False
|
||||
|
||||
|
||||
async def test_key_logging(
|
||||
user_api_key_dict: UserAPIKeyAuth,
|
||||
request: Request,
|
||||
|
|
|
@ -1424,9 +1424,7 @@ class PrismaClient:
|
|||
# check if plain text or hash
|
||||
if token is not None:
|
||||
if isinstance(token, str):
|
||||
hashed_token = token
|
||||
if token.startswith("sk-"):
|
||||
hashed_token = self.hash_token(token=token)
|
||||
hashed_token = _hash_token_if_needed(token=token)
|
||||
verbose_proxy_logger.debug(
|
||||
f"PrismaClient: find_unique for token: {hashed_token}"
|
||||
)
|
||||
|
@ -1493,8 +1491,7 @@ class PrismaClient:
|
|||
if token is not None:
|
||||
where_filter["token"] = {}
|
||||
if isinstance(token, str):
|
||||
if token.startswith("sk-"):
|
||||
token = self.hash_token(token=token)
|
||||
token = _hash_token_if_needed(token=token)
|
||||
where_filter["token"]["in"] = [token]
|
||||
elif isinstance(token, list):
|
||||
hashed_tokens = []
|
||||
|
@ -1630,9 +1627,7 @@ class PrismaClient:
|
|||
# check if plain text or hash
|
||||
if token is not None:
|
||||
if isinstance(token, str):
|
||||
hashed_token = token
|
||||
if token.startswith("sk-"):
|
||||
hashed_token = self.hash_token(token=token)
|
||||
hashed_token = _hash_token_if_needed(token=token)
|
||||
verbose_proxy_logger.debug(
|
||||
f"PrismaClient: find_unique for token: {hashed_token}"
|
||||
)
|
||||
|
@ -1912,8 +1907,7 @@ class PrismaClient:
|
|||
if token is not None:
|
||||
print_verbose(f"token: {token}")
|
||||
# check if plain text or hash
|
||||
if token.startswith("sk-"):
|
||||
token = self.hash_token(token=token)
|
||||
token = _hash_token_if_needed(token=token)
|
||||
db_data["token"] = token
|
||||
response = await self.db.litellm_verificationtoken.update(
|
||||
where={"token": token}, # type: ignore
|
||||
|
@ -2424,6 +2418,18 @@ def hash_token(token: str):
|
|||
return hashed_token
|
||||
|
||||
|
||||
def _hash_token_if_needed(token: str) -> str:
|
||||
"""
|
||||
Hash the token if it's a string and starts with "sk-"
|
||||
|
||||
Else return the token as is
|
||||
"""
|
||||
if token.startswith("sk-"):
|
||||
return hash_token(token=token)
|
||||
else:
|
||||
return token
|
||||
|
||||
|
||||
def _extract_from_regex(duration: str) -> Tuple[int, str]:
|
||||
match = re.match(r"(\d+)(mo|[smhd]?)", duration)
|
||||
|
||||
|
|
3469
tests/local_testing/test_key_generate_prisma.py
Normal file
3469
tests/local_testing/test_key_generate_prisma.py
Normal file
File diff suppressed because it is too large
Load diff
|
@ -147,23 +147,6 @@ def test_key_info_route_allowed(route_checks):
|
|||
)
|
||||
|
||||
|
||||
def test_key_info_route_forbidden(route_checks):
|
||||
"""
|
||||
Internal User is not allowed to access /key/info route for a key they're not using in Authenticated API Key
|
||||
"""
|
||||
with pytest.raises(HTTPException) as exc_info:
|
||||
route_checks.non_proxy_admin_allowed_routes_check(
|
||||
user_obj=None,
|
||||
_user_role=LitellmUserRoles.INTERNAL_USER.value,
|
||||
route="/key/info",
|
||||
request=MockRequest(query_params={"key": "wrong_key"}),
|
||||
valid_token=UserAPIKeyAuth(api_key="test_key"),
|
||||
api_key="test_key",
|
||||
request_data={},
|
||||
)
|
||||
assert exc_info.value.status_code == 403
|
||||
|
||||
|
||||
def test_user_info_route_allowed(route_checks):
|
||||
"""
|
||||
Internal User is allowed to access /user/info route for their own user_id
|
||||
|
|
|
@ -456,7 +456,10 @@ async def test_call_with_valid_model_using_all_models(prisma_client):
|
|||
print("result from user auth with new key", result)
|
||||
|
||||
# call /key/info for key - models == "all-proxy-models"
|
||||
key_info = await info_key_fn(key=generated_key)
|
||||
key_info = await info_key_fn(
|
||||
key=generated_key,
|
||||
user_api_key_dict=UserAPIKeyAuth(user_role=LitellmUserRoles.PROXY_ADMIN),
|
||||
)
|
||||
print("key_info", key_info)
|
||||
models = key_info["info"]["models"]
|
||||
assert models == ["all-team-models"]
|
||||
|
@ -1179,7 +1182,12 @@ def test_generate_and_call_key_info(prisma_client):
|
|||
generated_key = key.key
|
||||
|
||||
# use generated key to auth in
|
||||
result = await info_key_fn(key=generated_key)
|
||||
result = await info_key_fn(
|
||||
key=generated_key,
|
||||
user_api_key_dict=UserAPIKeyAuth(
|
||||
user_role=LitellmUserRoles.PROXY_ADMIN,
|
||||
),
|
||||
)
|
||||
print("result from info_key_fn", result)
|
||||
assert result["key"] == generated_key
|
||||
print("\n info for key=", result["info"])
|
||||
|
@ -1271,7 +1279,12 @@ def test_generate_and_update_key(prisma_client):
|
|||
generated_key = key.key
|
||||
|
||||
# use generated key to auth in
|
||||
result = await info_key_fn(key=generated_key)
|
||||
result = await info_key_fn(
|
||||
key=generated_key,
|
||||
user_api_key_dict=UserAPIKeyAuth(
|
||||
user_role=LitellmUserRoles.PROXY_ADMIN,
|
||||
),
|
||||
)
|
||||
print("result from info_key_fn", result)
|
||||
assert result["key"] == generated_key
|
||||
print("\n info for key=", result["info"])
|
||||
|
@ -1303,7 +1316,12 @@ def test_generate_and_update_key(prisma_client):
|
|||
print("response2=", response2)
|
||||
|
||||
# get info on key after update
|
||||
result = await info_key_fn(key=generated_key)
|
||||
result = await info_key_fn(
|
||||
key=generated_key,
|
||||
user_api_key_dict=UserAPIKeyAuth(
|
||||
user_role=LitellmUserRoles.PROXY_ADMIN,
|
||||
),
|
||||
)
|
||||
print("result from info_key_fn", result)
|
||||
assert result["key"] == generated_key
|
||||
print("\n info for key=", result["info"])
|
||||
|
@ -1989,7 +2007,10 @@ async def test_key_name_null(prisma_client):
|
|||
key = await generate_key_fn(request)
|
||||
print("generated key=", key)
|
||||
generated_key = key.key
|
||||
result = await info_key_fn(key=generated_key)
|
||||
result = await info_key_fn(
|
||||
key=generated_key,
|
||||
user_api_key_dict=UserAPIKeyAuth(user_role=LitellmUserRoles.PROXY_ADMIN),
|
||||
)
|
||||
print("result from info_key_fn", result)
|
||||
assert result["info"]["key_name"] is None
|
||||
except Exception as e:
|
||||
|
@ -2014,7 +2035,10 @@ async def test_key_name_set(prisma_client):
|
|||
request = GenerateKeyRequest()
|
||||
key = await generate_key_fn(request)
|
||||
generated_key = key.key
|
||||
result = await info_key_fn(key=generated_key)
|
||||
result = await info_key_fn(
|
||||
key=generated_key,
|
||||
user_api_key_dict=UserAPIKeyAuth(user_role=LitellmUserRoles.PROXY_ADMIN),
|
||||
)
|
||||
print("result from info_key_fn", result)
|
||||
assert isinstance(result["info"]["key_name"], str)
|
||||
except Exception as e:
|
||||
|
@ -2038,7 +2062,10 @@ async def test_default_key_params(prisma_client):
|
|||
request = GenerateKeyRequest()
|
||||
key = await generate_key_fn(request)
|
||||
generated_key = key.key
|
||||
result = await info_key_fn(key=generated_key)
|
||||
result = await info_key_fn(
|
||||
key=generated_key,
|
||||
user_api_key_dict=UserAPIKeyAuth(user_role=LitellmUserRoles.PROXY_ADMIN),
|
||||
)
|
||||
print("result from info_key_fn", result)
|
||||
assert result["info"]["max_budget"] == 0.000122
|
||||
except Exception as e:
|
||||
|
@ -2804,7 +2831,10 @@ async def test_generate_key_with_model_tpm_limit(prisma_client):
|
|||
generated_key = key.key
|
||||
|
||||
# use generated key to auth in
|
||||
result = await info_key_fn(key=generated_key)
|
||||
result = await info_key_fn(
|
||||
key=generated_key,
|
||||
user_api_key_dict=UserAPIKeyAuth(user_role=LitellmUserRoles.PROXY_ADMIN),
|
||||
)
|
||||
print("result from info_key_fn", result)
|
||||
assert result["key"] == generated_key
|
||||
print("\n info for key=", result["info"])
|
||||
|
@ -2825,7 +2855,10 @@ async def test_generate_key_with_model_tpm_limit(prisma_client):
|
|||
_request._url = URL(url="/update/key")
|
||||
|
||||
await update_key_fn(data=request, request=_request)
|
||||
result = await info_key_fn(key=generated_key)
|
||||
result = await info_key_fn(
|
||||
key=generated_key,
|
||||
user_api_key_dict=UserAPIKeyAuth(user_role=LitellmUserRoles.PROXY_ADMIN),
|
||||
)
|
||||
print("result from info_key_fn", result)
|
||||
assert result["key"] == generated_key
|
||||
print("\n info for key=", result["info"])
|
||||
|
@ -2863,7 +2896,10 @@ async def test_generate_key_with_guardrails(prisma_client):
|
|||
generated_key = key.key
|
||||
|
||||
# use generated key to auth in
|
||||
result = await info_key_fn(key=generated_key)
|
||||
result = await info_key_fn(
|
||||
key=generated_key,
|
||||
user_api_key_dict=UserAPIKeyAuth(user_role=LitellmUserRoles.PROXY_ADMIN),
|
||||
)
|
||||
print("result from info_key_fn", result)
|
||||
assert result["key"] == generated_key
|
||||
print("\n info for key=", result["info"])
|
||||
|
@ -2882,7 +2918,10 @@ async def test_generate_key_with_guardrails(prisma_client):
|
|||
_request._url = URL(url="/update/key")
|
||||
|
||||
await update_key_fn(data=request, request=_request)
|
||||
result = await info_key_fn(key=generated_key)
|
||||
result = await info_key_fn(
|
||||
key=generated_key,
|
||||
user_api_key_dict=UserAPIKeyAuth(user_role=LitellmUserRoles.PROXY_ADMIN),
|
||||
)
|
||||
print("result from info_key_fn", result)
|
||||
assert result["key"] == generated_key
|
||||
print("\n info for key=", result["info"])
|
||||
|
|
|
@ -412,7 +412,7 @@ async def test_key_info():
|
|||
Get key info
|
||||
- as admin -> 200
|
||||
- as key itself -> 200
|
||||
- as random key -> 403
|
||||
- as non existent key -> 404
|
||||
"""
|
||||
async with aiohttp.ClientSession() as session:
|
||||
key_gen = await generate_key(session=session, i=0)
|
||||
|
@ -425,10 +425,9 @@ async def test_key_info():
|
|||
# as key itself, use the auth param, and no query key needed
|
||||
await get_key_info(session=session, call_key=key)
|
||||
# as random key #
|
||||
key_gen = await generate_key(session=session, i=0)
|
||||
random_key = key_gen["key"]
|
||||
status = await get_key_info(session=session, get_key=key, call_key=random_key)
|
||||
assert status == 403
|
||||
random_key = f"sk-{uuid.uuid4()}"
|
||||
status = await get_key_info(session=session, get_key=random_key, call_key=key)
|
||||
assert status == 404
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue