fix raise correct error 404 when /key/info is called on non-existent key (#6653)

* fix raise correct error on /key/info

* add not_found_error error

* fix key not found in DB error

* use 1 helper for checking token hash

* fix error code on key info

* fix test key gen prisma

* test_generate_and_call_key_info

* test fix test_call_with_valid_model_using_all_models

* fix key info tests
This commit is contained in:
Ishaan Jaff 2024-11-11 21:00:39 -08:00 committed by GitHub
parent 25bae4cc23
commit de2f9aed3a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 3593 additions and 57 deletions

View file

@ -1894,6 +1894,7 @@ class ProxyErrorTypes(str, enum.Enum):
auth_error = "auth_error" auth_error = "auth_error"
internal_server_error = "internal_server_error" internal_server_error = "internal_server_error"
bad_request_error = "bad_request_error" bad_request_error = "bad_request_error"
not_found_error = "not_found_error"
class SSOUserDefinedValues(TypedDict): class SSOUserDefinedValues(TypedDict):

View file

@ -44,14 +44,8 @@ class RouteChecks:
route in LiteLLMRoutes.info_routes.value route in LiteLLMRoutes.info_routes.value
): # check if user allowed to call an info route ): # check if user allowed to call an info route
if route == "/key/info": if route == "/key/info":
# check if user can access this route # handled by function itself
query_params = request.query_params pass
key = query_params.get("key")
if key is not None and hash_token(token=key) != api_key:
raise HTTPException(
status_code=status.HTTP_403_FORBIDDEN,
detail="user not allowed to access this key's info",
)
elif route == "/user/info": elif route == "/user/info":
# check if user can access this route # check if user can access this route
query_params = request.query_params query_params = request.query_params

View file

@ -32,7 +32,7 @@ from litellm.proxy.auth.auth_checks import (
) )
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
from litellm.proxy.management_helpers.utils import management_endpoint_wrapper from litellm.proxy.management_helpers.utils import management_endpoint_wrapper
from litellm.proxy.utils import _duration_in_seconds from litellm.proxy.utils import _duration_in_seconds, _hash_token_if_needed
from litellm.secret_managers.main import get_secret from litellm.secret_managers.main import get_secret
router = APIRouter() router = APIRouter()
@ -734,13 +734,37 @@ async def info_key_fn(
raise Exception( raise Exception(
"Database not connected. Connect a database to your proxy - https://docs.litellm.ai/docs/simple_proxy#managing-auth---virtual-keys" "Database not connected. Connect a database to your proxy - https://docs.litellm.ai/docs/simple_proxy#managing-auth---virtual-keys"
) )
if key is None:
key = user_api_key_dict.api_key # default to using Auth token if no key is passed in
key_info = await prisma_client.get_data(token=key) key = key or user_api_key_dict.api_key
hashed_key: Optional[str] = key
if key is not None:
hashed_key = _hash_token_if_needed(token=key)
key_info = await prisma_client.db.litellm_verificationtoken.find_unique(
where={"token": hashed_key}, # type: ignore
include={"litellm_budget_table": True},
)
if key_info is None: if key_info is None:
raise ProxyException(
message="Key not found in database",
type=ProxyErrorTypes.not_found_error,
param="key",
code=status.HTTP_404_NOT_FOUND,
)
if (
_can_user_query_key_info(
user_api_key_dict=user_api_key_dict,
key=key,
key_info=key_info,
)
is not True
):
raise HTTPException( raise HTTPException(
status_code=status.HTTP_404_NOT_FOUND, status_code=status.HTTP_403_FORBIDDEN,
detail={"message": "No keys found"}, detail="You are not allowed to access this key's info. Your role={}".format(
user_api_key_dict.user_role
),
) )
## REMOVE HASHED TOKEN INFO BEFORE RETURNING ## ## REMOVE HASHED TOKEN INFO BEFORE RETURNING ##
try: try:
@ -1540,6 +1564,27 @@ async def key_health(
) )
def _can_user_query_key_info(
user_api_key_dict: UserAPIKeyAuth,
key: Optional[str],
key_info: LiteLLM_VerificationToken,
) -> bool:
"""
Helper to check if the user has access to the key's info
"""
if (
user_api_key_dict.user_role == LitellmUserRoles.PROXY_ADMIN.value
or user_api_key_dict.user_role == LitellmUserRoles.PROXY_ADMIN_VIEW_ONLY.value
):
return True
elif user_api_key_dict.api_key == key:
return True
# user can query their own key info
elif key_info.user_id == user_api_key_dict.user_id:
return True
return False
async def test_key_logging( async def test_key_logging(
user_api_key_dict: UserAPIKeyAuth, user_api_key_dict: UserAPIKeyAuth,
request: Request, request: Request,

View file

@ -1424,9 +1424,7 @@ class PrismaClient:
# check if plain text or hash # check if plain text or hash
if token is not None: if token is not None:
if isinstance(token, str): if isinstance(token, str):
hashed_token = token hashed_token = _hash_token_if_needed(token=token)
if token.startswith("sk-"):
hashed_token = self.hash_token(token=token)
verbose_proxy_logger.debug( verbose_proxy_logger.debug(
f"PrismaClient: find_unique for token: {hashed_token}" f"PrismaClient: find_unique for token: {hashed_token}"
) )
@ -1493,8 +1491,7 @@ class PrismaClient:
if token is not None: if token is not None:
where_filter["token"] = {} where_filter["token"] = {}
if isinstance(token, str): if isinstance(token, str):
if token.startswith("sk-"): token = _hash_token_if_needed(token=token)
token = self.hash_token(token=token)
where_filter["token"]["in"] = [token] where_filter["token"]["in"] = [token]
elif isinstance(token, list): elif isinstance(token, list):
hashed_tokens = [] hashed_tokens = []
@ -1630,9 +1627,7 @@ class PrismaClient:
# check if plain text or hash # check if plain text or hash
if token is not None: if token is not None:
if isinstance(token, str): if isinstance(token, str):
hashed_token = token hashed_token = _hash_token_if_needed(token=token)
if token.startswith("sk-"):
hashed_token = self.hash_token(token=token)
verbose_proxy_logger.debug( verbose_proxy_logger.debug(
f"PrismaClient: find_unique for token: {hashed_token}" f"PrismaClient: find_unique for token: {hashed_token}"
) )
@ -1912,8 +1907,7 @@ class PrismaClient:
if token is not None: if token is not None:
print_verbose(f"token: {token}") print_verbose(f"token: {token}")
# check if plain text or hash # check if plain text or hash
if token.startswith("sk-"): token = _hash_token_if_needed(token=token)
token = self.hash_token(token=token)
db_data["token"] = token db_data["token"] = token
response = await self.db.litellm_verificationtoken.update( response = await self.db.litellm_verificationtoken.update(
where={"token": token}, # type: ignore where={"token": token}, # type: ignore
@ -2424,6 +2418,18 @@ def hash_token(token: str):
return hashed_token return hashed_token
def _hash_token_if_needed(token: str) -> str:
"""
Hash the token if it's a string and starts with "sk-"
Else return the token as is
"""
if token.startswith("sk-"):
return hash_token(token=token)
else:
return token
def _extract_from_regex(duration: str) -> Tuple[int, str]: def _extract_from_regex(duration: str) -> Tuple[int, str]:
match = re.match(r"(\d+)(mo|[smhd]?)", duration) match = re.match(r"(\d+)(mo|[smhd]?)", duration)

File diff suppressed because it is too large Load diff

View file

@ -147,23 +147,6 @@ def test_key_info_route_allowed(route_checks):
) )
def test_key_info_route_forbidden(route_checks):
"""
Internal User is not allowed to access /key/info route for a key they're not using in Authenticated API Key
"""
with pytest.raises(HTTPException) as exc_info:
route_checks.non_proxy_admin_allowed_routes_check(
user_obj=None,
_user_role=LitellmUserRoles.INTERNAL_USER.value,
route="/key/info",
request=MockRequest(query_params={"key": "wrong_key"}),
valid_token=UserAPIKeyAuth(api_key="test_key"),
api_key="test_key",
request_data={},
)
assert exc_info.value.status_code == 403
def test_user_info_route_allowed(route_checks): def test_user_info_route_allowed(route_checks):
""" """
Internal User is allowed to access /user/info route for their own user_id Internal User is allowed to access /user/info route for their own user_id

View file

@ -456,7 +456,10 @@ async def test_call_with_valid_model_using_all_models(prisma_client):
print("result from user auth with new key", result) print("result from user auth with new key", result)
# call /key/info for key - models == "all-proxy-models" # call /key/info for key - models == "all-proxy-models"
key_info = await info_key_fn(key=generated_key) key_info = await info_key_fn(
key=generated_key,
user_api_key_dict=UserAPIKeyAuth(user_role=LitellmUserRoles.PROXY_ADMIN),
)
print("key_info", key_info) print("key_info", key_info)
models = key_info["info"]["models"] models = key_info["info"]["models"]
assert models == ["all-team-models"] assert models == ["all-team-models"]
@ -1179,7 +1182,12 @@ def test_generate_and_call_key_info(prisma_client):
generated_key = key.key generated_key = key.key
# use generated key to auth in # use generated key to auth in
result = await info_key_fn(key=generated_key) result = await info_key_fn(
key=generated_key,
user_api_key_dict=UserAPIKeyAuth(
user_role=LitellmUserRoles.PROXY_ADMIN,
),
)
print("result from info_key_fn", result) print("result from info_key_fn", result)
assert result["key"] == generated_key assert result["key"] == generated_key
print("\n info for key=", result["info"]) print("\n info for key=", result["info"])
@ -1271,7 +1279,12 @@ def test_generate_and_update_key(prisma_client):
generated_key = key.key generated_key = key.key
# use generated key to auth in # use generated key to auth in
result = await info_key_fn(key=generated_key) result = await info_key_fn(
key=generated_key,
user_api_key_dict=UserAPIKeyAuth(
user_role=LitellmUserRoles.PROXY_ADMIN,
),
)
print("result from info_key_fn", result) print("result from info_key_fn", result)
assert result["key"] == generated_key assert result["key"] == generated_key
print("\n info for key=", result["info"]) print("\n info for key=", result["info"])
@ -1303,7 +1316,12 @@ def test_generate_and_update_key(prisma_client):
print("response2=", response2) print("response2=", response2)
# get info on key after update # get info on key after update
result = await info_key_fn(key=generated_key) result = await info_key_fn(
key=generated_key,
user_api_key_dict=UserAPIKeyAuth(
user_role=LitellmUserRoles.PROXY_ADMIN,
),
)
print("result from info_key_fn", result) print("result from info_key_fn", result)
assert result["key"] == generated_key assert result["key"] == generated_key
print("\n info for key=", result["info"]) print("\n info for key=", result["info"])
@ -1989,7 +2007,10 @@ async def test_key_name_null(prisma_client):
key = await generate_key_fn(request) key = await generate_key_fn(request)
print("generated key=", key) print("generated key=", key)
generated_key = key.key generated_key = key.key
result = await info_key_fn(key=generated_key) result = await info_key_fn(
key=generated_key,
user_api_key_dict=UserAPIKeyAuth(user_role=LitellmUserRoles.PROXY_ADMIN),
)
print("result from info_key_fn", result) print("result from info_key_fn", result)
assert result["info"]["key_name"] is None assert result["info"]["key_name"] is None
except Exception as e: except Exception as e:
@ -2014,7 +2035,10 @@ async def test_key_name_set(prisma_client):
request = GenerateKeyRequest() request = GenerateKeyRequest()
key = await generate_key_fn(request) key = await generate_key_fn(request)
generated_key = key.key generated_key = key.key
result = await info_key_fn(key=generated_key) result = await info_key_fn(
key=generated_key,
user_api_key_dict=UserAPIKeyAuth(user_role=LitellmUserRoles.PROXY_ADMIN),
)
print("result from info_key_fn", result) print("result from info_key_fn", result)
assert isinstance(result["info"]["key_name"], str) assert isinstance(result["info"]["key_name"], str)
except Exception as e: except Exception as e:
@ -2038,7 +2062,10 @@ async def test_default_key_params(prisma_client):
request = GenerateKeyRequest() request = GenerateKeyRequest()
key = await generate_key_fn(request) key = await generate_key_fn(request)
generated_key = key.key generated_key = key.key
result = await info_key_fn(key=generated_key) result = await info_key_fn(
key=generated_key,
user_api_key_dict=UserAPIKeyAuth(user_role=LitellmUserRoles.PROXY_ADMIN),
)
print("result from info_key_fn", result) print("result from info_key_fn", result)
assert result["info"]["max_budget"] == 0.000122 assert result["info"]["max_budget"] == 0.000122
except Exception as e: except Exception as e:
@ -2804,7 +2831,10 @@ async def test_generate_key_with_model_tpm_limit(prisma_client):
generated_key = key.key generated_key = key.key
# use generated key to auth in # use generated key to auth in
result = await info_key_fn(key=generated_key) result = await info_key_fn(
key=generated_key,
user_api_key_dict=UserAPIKeyAuth(user_role=LitellmUserRoles.PROXY_ADMIN),
)
print("result from info_key_fn", result) print("result from info_key_fn", result)
assert result["key"] == generated_key assert result["key"] == generated_key
print("\n info for key=", result["info"]) print("\n info for key=", result["info"])
@ -2825,7 +2855,10 @@ async def test_generate_key_with_model_tpm_limit(prisma_client):
_request._url = URL(url="/update/key") _request._url = URL(url="/update/key")
await update_key_fn(data=request, request=_request) await update_key_fn(data=request, request=_request)
result = await info_key_fn(key=generated_key) result = await info_key_fn(
key=generated_key,
user_api_key_dict=UserAPIKeyAuth(user_role=LitellmUserRoles.PROXY_ADMIN),
)
print("result from info_key_fn", result) print("result from info_key_fn", result)
assert result["key"] == generated_key assert result["key"] == generated_key
print("\n info for key=", result["info"]) print("\n info for key=", result["info"])
@ -2863,7 +2896,10 @@ async def test_generate_key_with_guardrails(prisma_client):
generated_key = key.key generated_key = key.key
# use generated key to auth in # use generated key to auth in
result = await info_key_fn(key=generated_key) result = await info_key_fn(
key=generated_key,
user_api_key_dict=UserAPIKeyAuth(user_role=LitellmUserRoles.PROXY_ADMIN),
)
print("result from info_key_fn", result) print("result from info_key_fn", result)
assert result["key"] == generated_key assert result["key"] == generated_key
print("\n info for key=", result["info"]) print("\n info for key=", result["info"])
@ -2882,7 +2918,10 @@ async def test_generate_key_with_guardrails(prisma_client):
_request._url = URL(url="/update/key") _request._url = URL(url="/update/key")
await update_key_fn(data=request, request=_request) await update_key_fn(data=request, request=_request)
result = await info_key_fn(key=generated_key) result = await info_key_fn(
key=generated_key,
user_api_key_dict=UserAPIKeyAuth(user_role=LitellmUserRoles.PROXY_ADMIN),
)
print("result from info_key_fn", result) print("result from info_key_fn", result)
assert result["key"] == generated_key assert result["key"] == generated_key
print("\n info for key=", result["info"]) print("\n info for key=", result["info"])

View file

@ -412,7 +412,7 @@ async def test_key_info():
Get key info Get key info
- as admin -> 200 - as admin -> 200
- as key itself -> 200 - as key itself -> 200
- as random key -> 403 - as non existent key -> 404
""" """
async with aiohttp.ClientSession() as session: async with aiohttp.ClientSession() as session:
key_gen = await generate_key(session=session, i=0) key_gen = await generate_key(session=session, i=0)
@ -425,10 +425,9 @@ async def test_key_info():
# as key itself, use the auth param, and no query key needed # as key itself, use the auth param, and no query key needed
await get_key_info(session=session, call_key=key) await get_key_info(session=session, call_key=key)
# as random key # # as random key #
key_gen = await generate_key(session=session, i=0) random_key = f"sk-{uuid.uuid4()}"
random_key = key_gen["key"] status = await get_key_info(session=session, get_key=random_key, call_key=key)
status = await get_key_info(session=session, get_key=key, call_key=random_key) assert status == 404
assert status == 403
@pytest.mark.asyncio @pytest.mark.asyncio