forked from phoenix/litellm-mirror
fix raise correct error 404 when /key/info is called on non-existent key (#6653)
* fix raise correct error on /key/info * add not_found_error error * fix key not found in DB error * use 1 helper for checking token hash * fix error code on key info * fix test key gen prisma * test_generate_and_call_key_info * test fix test_call_with_valid_model_using_all_models * fix key info tests
This commit is contained in:
parent
25bae4cc23
commit
de2f9aed3a
8 changed files with 3593 additions and 57 deletions
|
@ -1894,6 +1894,7 @@ class ProxyErrorTypes(str, enum.Enum):
|
||||||
auth_error = "auth_error"
|
auth_error = "auth_error"
|
||||||
internal_server_error = "internal_server_error"
|
internal_server_error = "internal_server_error"
|
||||||
bad_request_error = "bad_request_error"
|
bad_request_error = "bad_request_error"
|
||||||
|
not_found_error = "not_found_error"
|
||||||
|
|
||||||
|
|
||||||
class SSOUserDefinedValues(TypedDict):
|
class SSOUserDefinedValues(TypedDict):
|
||||||
|
|
|
@ -44,14 +44,8 @@ class RouteChecks:
|
||||||
route in LiteLLMRoutes.info_routes.value
|
route in LiteLLMRoutes.info_routes.value
|
||||||
): # check if user allowed to call an info route
|
): # check if user allowed to call an info route
|
||||||
if route == "/key/info":
|
if route == "/key/info":
|
||||||
# check if user can access this route
|
# handled by function itself
|
||||||
query_params = request.query_params
|
pass
|
||||||
key = query_params.get("key")
|
|
||||||
if key is not None and hash_token(token=key) != api_key:
|
|
||||||
raise HTTPException(
|
|
||||||
status_code=status.HTTP_403_FORBIDDEN,
|
|
||||||
detail="user not allowed to access this key's info",
|
|
||||||
)
|
|
||||||
elif route == "/user/info":
|
elif route == "/user/info":
|
||||||
# check if user can access this route
|
# check if user can access this route
|
||||||
query_params = request.query_params
|
query_params = request.query_params
|
||||||
|
|
|
@ -32,7 +32,7 @@ from litellm.proxy.auth.auth_checks import (
|
||||||
)
|
)
|
||||||
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
|
from litellm.proxy.auth.user_api_key_auth import user_api_key_auth
|
||||||
from litellm.proxy.management_helpers.utils import management_endpoint_wrapper
|
from litellm.proxy.management_helpers.utils import management_endpoint_wrapper
|
||||||
from litellm.proxy.utils import _duration_in_seconds
|
from litellm.proxy.utils import _duration_in_seconds, _hash_token_if_needed
|
||||||
from litellm.secret_managers.main import get_secret
|
from litellm.secret_managers.main import get_secret
|
||||||
|
|
||||||
router = APIRouter()
|
router = APIRouter()
|
||||||
|
@ -734,13 +734,37 @@ async def info_key_fn(
|
||||||
raise Exception(
|
raise Exception(
|
||||||
"Database not connected. Connect a database to your proxy - https://docs.litellm.ai/docs/simple_proxy#managing-auth---virtual-keys"
|
"Database not connected. Connect a database to your proxy - https://docs.litellm.ai/docs/simple_proxy#managing-auth---virtual-keys"
|
||||||
)
|
)
|
||||||
if key is None:
|
|
||||||
key = user_api_key_dict.api_key
|
# default to using Auth token if no key is passed in
|
||||||
key_info = await prisma_client.get_data(token=key)
|
key = key or user_api_key_dict.api_key
|
||||||
|
hashed_key: Optional[str] = key
|
||||||
|
if key is not None:
|
||||||
|
hashed_key = _hash_token_if_needed(token=key)
|
||||||
|
key_info = await prisma_client.db.litellm_verificationtoken.find_unique(
|
||||||
|
where={"token": hashed_key}, # type: ignore
|
||||||
|
include={"litellm_budget_table": True},
|
||||||
|
)
|
||||||
if key_info is None:
|
if key_info is None:
|
||||||
|
raise ProxyException(
|
||||||
|
message="Key not found in database",
|
||||||
|
type=ProxyErrorTypes.not_found_error,
|
||||||
|
param="key",
|
||||||
|
code=status.HTTP_404_NOT_FOUND,
|
||||||
|
)
|
||||||
|
|
||||||
|
if (
|
||||||
|
_can_user_query_key_info(
|
||||||
|
user_api_key_dict=user_api_key_dict,
|
||||||
|
key=key,
|
||||||
|
key_info=key_info,
|
||||||
|
)
|
||||||
|
is not True
|
||||||
|
):
|
||||||
raise HTTPException(
|
raise HTTPException(
|
||||||
status_code=status.HTTP_404_NOT_FOUND,
|
status_code=status.HTTP_403_FORBIDDEN,
|
||||||
detail={"message": "No keys found"},
|
detail="You are not allowed to access this key's info. Your role={}".format(
|
||||||
|
user_api_key_dict.user_role
|
||||||
|
),
|
||||||
)
|
)
|
||||||
## REMOVE HASHED TOKEN INFO BEFORE RETURNING ##
|
## REMOVE HASHED TOKEN INFO BEFORE RETURNING ##
|
||||||
try:
|
try:
|
||||||
|
@ -1540,6 +1564,27 @@ async def key_health(
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def _can_user_query_key_info(
|
||||||
|
user_api_key_dict: UserAPIKeyAuth,
|
||||||
|
key: Optional[str],
|
||||||
|
key_info: LiteLLM_VerificationToken,
|
||||||
|
) -> bool:
|
||||||
|
"""
|
||||||
|
Helper to check if the user has access to the key's info
|
||||||
|
"""
|
||||||
|
if (
|
||||||
|
user_api_key_dict.user_role == LitellmUserRoles.PROXY_ADMIN.value
|
||||||
|
or user_api_key_dict.user_role == LitellmUserRoles.PROXY_ADMIN_VIEW_ONLY.value
|
||||||
|
):
|
||||||
|
return True
|
||||||
|
elif user_api_key_dict.api_key == key:
|
||||||
|
return True
|
||||||
|
# user can query their own key info
|
||||||
|
elif key_info.user_id == user_api_key_dict.user_id:
|
||||||
|
return True
|
||||||
|
return False
|
||||||
|
|
||||||
|
|
||||||
async def test_key_logging(
|
async def test_key_logging(
|
||||||
user_api_key_dict: UserAPIKeyAuth,
|
user_api_key_dict: UserAPIKeyAuth,
|
||||||
request: Request,
|
request: Request,
|
||||||
|
|
|
@ -1424,9 +1424,7 @@ class PrismaClient:
|
||||||
# check if plain text or hash
|
# check if plain text or hash
|
||||||
if token is not None:
|
if token is not None:
|
||||||
if isinstance(token, str):
|
if isinstance(token, str):
|
||||||
hashed_token = token
|
hashed_token = _hash_token_if_needed(token=token)
|
||||||
if token.startswith("sk-"):
|
|
||||||
hashed_token = self.hash_token(token=token)
|
|
||||||
verbose_proxy_logger.debug(
|
verbose_proxy_logger.debug(
|
||||||
f"PrismaClient: find_unique for token: {hashed_token}"
|
f"PrismaClient: find_unique for token: {hashed_token}"
|
||||||
)
|
)
|
||||||
|
@ -1493,8 +1491,7 @@ class PrismaClient:
|
||||||
if token is not None:
|
if token is not None:
|
||||||
where_filter["token"] = {}
|
where_filter["token"] = {}
|
||||||
if isinstance(token, str):
|
if isinstance(token, str):
|
||||||
if token.startswith("sk-"):
|
token = _hash_token_if_needed(token=token)
|
||||||
token = self.hash_token(token=token)
|
|
||||||
where_filter["token"]["in"] = [token]
|
where_filter["token"]["in"] = [token]
|
||||||
elif isinstance(token, list):
|
elif isinstance(token, list):
|
||||||
hashed_tokens = []
|
hashed_tokens = []
|
||||||
|
@ -1630,9 +1627,7 @@ class PrismaClient:
|
||||||
# check if plain text or hash
|
# check if plain text or hash
|
||||||
if token is not None:
|
if token is not None:
|
||||||
if isinstance(token, str):
|
if isinstance(token, str):
|
||||||
hashed_token = token
|
hashed_token = _hash_token_if_needed(token=token)
|
||||||
if token.startswith("sk-"):
|
|
||||||
hashed_token = self.hash_token(token=token)
|
|
||||||
verbose_proxy_logger.debug(
|
verbose_proxy_logger.debug(
|
||||||
f"PrismaClient: find_unique for token: {hashed_token}"
|
f"PrismaClient: find_unique for token: {hashed_token}"
|
||||||
)
|
)
|
||||||
|
@ -1912,8 +1907,7 @@ class PrismaClient:
|
||||||
if token is not None:
|
if token is not None:
|
||||||
print_verbose(f"token: {token}")
|
print_verbose(f"token: {token}")
|
||||||
# check if plain text or hash
|
# check if plain text or hash
|
||||||
if token.startswith("sk-"):
|
token = _hash_token_if_needed(token=token)
|
||||||
token = self.hash_token(token=token)
|
|
||||||
db_data["token"] = token
|
db_data["token"] = token
|
||||||
response = await self.db.litellm_verificationtoken.update(
|
response = await self.db.litellm_verificationtoken.update(
|
||||||
where={"token": token}, # type: ignore
|
where={"token": token}, # type: ignore
|
||||||
|
@ -2424,6 +2418,18 @@ def hash_token(token: str):
|
||||||
return hashed_token
|
return hashed_token
|
||||||
|
|
||||||
|
|
||||||
|
def _hash_token_if_needed(token: str) -> str:
|
||||||
|
"""
|
||||||
|
Hash the token if it's a string and starts with "sk-"
|
||||||
|
|
||||||
|
Else return the token as is
|
||||||
|
"""
|
||||||
|
if token.startswith("sk-"):
|
||||||
|
return hash_token(token=token)
|
||||||
|
else:
|
||||||
|
return token
|
||||||
|
|
||||||
|
|
||||||
def _extract_from_regex(duration: str) -> Tuple[int, str]:
|
def _extract_from_regex(duration: str) -> Tuple[int, str]:
|
||||||
match = re.match(r"(\d+)(mo|[smhd]?)", duration)
|
match = re.match(r"(\d+)(mo|[smhd]?)", duration)
|
||||||
|
|
||||||
|
|
3469
tests/local_testing/test_key_generate_prisma.py
Normal file
3469
tests/local_testing/test_key_generate_prisma.py
Normal file
File diff suppressed because it is too large
Load diff
|
@ -147,23 +147,6 @@ def test_key_info_route_allowed(route_checks):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
def test_key_info_route_forbidden(route_checks):
|
|
||||||
"""
|
|
||||||
Internal User is not allowed to access /key/info route for a key they're not using in Authenticated API Key
|
|
||||||
"""
|
|
||||||
with pytest.raises(HTTPException) as exc_info:
|
|
||||||
route_checks.non_proxy_admin_allowed_routes_check(
|
|
||||||
user_obj=None,
|
|
||||||
_user_role=LitellmUserRoles.INTERNAL_USER.value,
|
|
||||||
route="/key/info",
|
|
||||||
request=MockRequest(query_params={"key": "wrong_key"}),
|
|
||||||
valid_token=UserAPIKeyAuth(api_key="test_key"),
|
|
||||||
api_key="test_key",
|
|
||||||
request_data={},
|
|
||||||
)
|
|
||||||
assert exc_info.value.status_code == 403
|
|
||||||
|
|
||||||
|
|
||||||
def test_user_info_route_allowed(route_checks):
|
def test_user_info_route_allowed(route_checks):
|
||||||
"""
|
"""
|
||||||
Internal User is allowed to access /user/info route for their own user_id
|
Internal User is allowed to access /user/info route for their own user_id
|
||||||
|
|
|
@ -456,7 +456,10 @@ async def test_call_with_valid_model_using_all_models(prisma_client):
|
||||||
print("result from user auth with new key", result)
|
print("result from user auth with new key", result)
|
||||||
|
|
||||||
# call /key/info for key - models == "all-proxy-models"
|
# call /key/info for key - models == "all-proxy-models"
|
||||||
key_info = await info_key_fn(key=generated_key)
|
key_info = await info_key_fn(
|
||||||
|
key=generated_key,
|
||||||
|
user_api_key_dict=UserAPIKeyAuth(user_role=LitellmUserRoles.PROXY_ADMIN),
|
||||||
|
)
|
||||||
print("key_info", key_info)
|
print("key_info", key_info)
|
||||||
models = key_info["info"]["models"]
|
models = key_info["info"]["models"]
|
||||||
assert models == ["all-team-models"]
|
assert models == ["all-team-models"]
|
||||||
|
@ -1179,7 +1182,12 @@ def test_generate_and_call_key_info(prisma_client):
|
||||||
generated_key = key.key
|
generated_key = key.key
|
||||||
|
|
||||||
# use generated key to auth in
|
# use generated key to auth in
|
||||||
result = await info_key_fn(key=generated_key)
|
result = await info_key_fn(
|
||||||
|
key=generated_key,
|
||||||
|
user_api_key_dict=UserAPIKeyAuth(
|
||||||
|
user_role=LitellmUserRoles.PROXY_ADMIN,
|
||||||
|
),
|
||||||
|
)
|
||||||
print("result from info_key_fn", result)
|
print("result from info_key_fn", result)
|
||||||
assert result["key"] == generated_key
|
assert result["key"] == generated_key
|
||||||
print("\n info for key=", result["info"])
|
print("\n info for key=", result["info"])
|
||||||
|
@ -1271,7 +1279,12 @@ def test_generate_and_update_key(prisma_client):
|
||||||
generated_key = key.key
|
generated_key = key.key
|
||||||
|
|
||||||
# use generated key to auth in
|
# use generated key to auth in
|
||||||
result = await info_key_fn(key=generated_key)
|
result = await info_key_fn(
|
||||||
|
key=generated_key,
|
||||||
|
user_api_key_dict=UserAPIKeyAuth(
|
||||||
|
user_role=LitellmUserRoles.PROXY_ADMIN,
|
||||||
|
),
|
||||||
|
)
|
||||||
print("result from info_key_fn", result)
|
print("result from info_key_fn", result)
|
||||||
assert result["key"] == generated_key
|
assert result["key"] == generated_key
|
||||||
print("\n info for key=", result["info"])
|
print("\n info for key=", result["info"])
|
||||||
|
@ -1303,7 +1316,12 @@ def test_generate_and_update_key(prisma_client):
|
||||||
print("response2=", response2)
|
print("response2=", response2)
|
||||||
|
|
||||||
# get info on key after update
|
# get info on key after update
|
||||||
result = await info_key_fn(key=generated_key)
|
result = await info_key_fn(
|
||||||
|
key=generated_key,
|
||||||
|
user_api_key_dict=UserAPIKeyAuth(
|
||||||
|
user_role=LitellmUserRoles.PROXY_ADMIN,
|
||||||
|
),
|
||||||
|
)
|
||||||
print("result from info_key_fn", result)
|
print("result from info_key_fn", result)
|
||||||
assert result["key"] == generated_key
|
assert result["key"] == generated_key
|
||||||
print("\n info for key=", result["info"])
|
print("\n info for key=", result["info"])
|
||||||
|
@ -1989,7 +2007,10 @@ async def test_key_name_null(prisma_client):
|
||||||
key = await generate_key_fn(request)
|
key = await generate_key_fn(request)
|
||||||
print("generated key=", key)
|
print("generated key=", key)
|
||||||
generated_key = key.key
|
generated_key = key.key
|
||||||
result = await info_key_fn(key=generated_key)
|
result = await info_key_fn(
|
||||||
|
key=generated_key,
|
||||||
|
user_api_key_dict=UserAPIKeyAuth(user_role=LitellmUserRoles.PROXY_ADMIN),
|
||||||
|
)
|
||||||
print("result from info_key_fn", result)
|
print("result from info_key_fn", result)
|
||||||
assert result["info"]["key_name"] is None
|
assert result["info"]["key_name"] is None
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
@ -2014,7 +2035,10 @@ async def test_key_name_set(prisma_client):
|
||||||
request = GenerateKeyRequest()
|
request = GenerateKeyRequest()
|
||||||
key = await generate_key_fn(request)
|
key = await generate_key_fn(request)
|
||||||
generated_key = key.key
|
generated_key = key.key
|
||||||
result = await info_key_fn(key=generated_key)
|
result = await info_key_fn(
|
||||||
|
key=generated_key,
|
||||||
|
user_api_key_dict=UserAPIKeyAuth(user_role=LitellmUserRoles.PROXY_ADMIN),
|
||||||
|
)
|
||||||
print("result from info_key_fn", result)
|
print("result from info_key_fn", result)
|
||||||
assert isinstance(result["info"]["key_name"], str)
|
assert isinstance(result["info"]["key_name"], str)
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
@ -2038,7 +2062,10 @@ async def test_default_key_params(prisma_client):
|
||||||
request = GenerateKeyRequest()
|
request = GenerateKeyRequest()
|
||||||
key = await generate_key_fn(request)
|
key = await generate_key_fn(request)
|
||||||
generated_key = key.key
|
generated_key = key.key
|
||||||
result = await info_key_fn(key=generated_key)
|
result = await info_key_fn(
|
||||||
|
key=generated_key,
|
||||||
|
user_api_key_dict=UserAPIKeyAuth(user_role=LitellmUserRoles.PROXY_ADMIN),
|
||||||
|
)
|
||||||
print("result from info_key_fn", result)
|
print("result from info_key_fn", result)
|
||||||
assert result["info"]["max_budget"] == 0.000122
|
assert result["info"]["max_budget"] == 0.000122
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
@ -2804,7 +2831,10 @@ async def test_generate_key_with_model_tpm_limit(prisma_client):
|
||||||
generated_key = key.key
|
generated_key = key.key
|
||||||
|
|
||||||
# use generated key to auth in
|
# use generated key to auth in
|
||||||
result = await info_key_fn(key=generated_key)
|
result = await info_key_fn(
|
||||||
|
key=generated_key,
|
||||||
|
user_api_key_dict=UserAPIKeyAuth(user_role=LitellmUserRoles.PROXY_ADMIN),
|
||||||
|
)
|
||||||
print("result from info_key_fn", result)
|
print("result from info_key_fn", result)
|
||||||
assert result["key"] == generated_key
|
assert result["key"] == generated_key
|
||||||
print("\n info for key=", result["info"])
|
print("\n info for key=", result["info"])
|
||||||
|
@ -2825,7 +2855,10 @@ async def test_generate_key_with_model_tpm_limit(prisma_client):
|
||||||
_request._url = URL(url="/update/key")
|
_request._url = URL(url="/update/key")
|
||||||
|
|
||||||
await update_key_fn(data=request, request=_request)
|
await update_key_fn(data=request, request=_request)
|
||||||
result = await info_key_fn(key=generated_key)
|
result = await info_key_fn(
|
||||||
|
key=generated_key,
|
||||||
|
user_api_key_dict=UserAPIKeyAuth(user_role=LitellmUserRoles.PROXY_ADMIN),
|
||||||
|
)
|
||||||
print("result from info_key_fn", result)
|
print("result from info_key_fn", result)
|
||||||
assert result["key"] == generated_key
|
assert result["key"] == generated_key
|
||||||
print("\n info for key=", result["info"])
|
print("\n info for key=", result["info"])
|
||||||
|
@ -2863,7 +2896,10 @@ async def test_generate_key_with_guardrails(prisma_client):
|
||||||
generated_key = key.key
|
generated_key = key.key
|
||||||
|
|
||||||
# use generated key to auth in
|
# use generated key to auth in
|
||||||
result = await info_key_fn(key=generated_key)
|
result = await info_key_fn(
|
||||||
|
key=generated_key,
|
||||||
|
user_api_key_dict=UserAPIKeyAuth(user_role=LitellmUserRoles.PROXY_ADMIN),
|
||||||
|
)
|
||||||
print("result from info_key_fn", result)
|
print("result from info_key_fn", result)
|
||||||
assert result["key"] == generated_key
|
assert result["key"] == generated_key
|
||||||
print("\n info for key=", result["info"])
|
print("\n info for key=", result["info"])
|
||||||
|
@ -2882,7 +2918,10 @@ async def test_generate_key_with_guardrails(prisma_client):
|
||||||
_request._url = URL(url="/update/key")
|
_request._url = URL(url="/update/key")
|
||||||
|
|
||||||
await update_key_fn(data=request, request=_request)
|
await update_key_fn(data=request, request=_request)
|
||||||
result = await info_key_fn(key=generated_key)
|
result = await info_key_fn(
|
||||||
|
key=generated_key,
|
||||||
|
user_api_key_dict=UserAPIKeyAuth(user_role=LitellmUserRoles.PROXY_ADMIN),
|
||||||
|
)
|
||||||
print("result from info_key_fn", result)
|
print("result from info_key_fn", result)
|
||||||
assert result["key"] == generated_key
|
assert result["key"] == generated_key
|
||||||
print("\n info for key=", result["info"])
|
print("\n info for key=", result["info"])
|
||||||
|
|
|
@ -412,7 +412,7 @@ async def test_key_info():
|
||||||
Get key info
|
Get key info
|
||||||
- as admin -> 200
|
- as admin -> 200
|
||||||
- as key itself -> 200
|
- as key itself -> 200
|
||||||
- as random key -> 403
|
- as non existent key -> 404
|
||||||
"""
|
"""
|
||||||
async with aiohttp.ClientSession() as session:
|
async with aiohttp.ClientSession() as session:
|
||||||
key_gen = await generate_key(session=session, i=0)
|
key_gen = await generate_key(session=session, i=0)
|
||||||
|
@ -425,10 +425,9 @@ async def test_key_info():
|
||||||
# as key itself, use the auth param, and no query key needed
|
# as key itself, use the auth param, and no query key needed
|
||||||
await get_key_info(session=session, call_key=key)
|
await get_key_info(session=session, call_key=key)
|
||||||
# as random key #
|
# as random key #
|
||||||
key_gen = await generate_key(session=session, i=0)
|
random_key = f"sk-{uuid.uuid4()}"
|
||||||
random_key = key_gen["key"]
|
status = await get_key_info(session=session, get_key=random_key, call_key=key)
|
||||||
status = await get_key_info(session=session, get_key=key, call_key=random_key)
|
assert status == 404
|
||||||
assert status == 403
|
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue