forked from phoenix/litellm-mirror
fix(proxy_server.py): fix key create logic + add unit tests
This commit is contained in:
parent
a0d230e3a2
commit
70716b3373
2 changed files with 75 additions and 40 deletions
|
@ -1010,48 +1010,50 @@ async def user_api_key_auth(
|
|||
db=custom_db_client,
|
||||
)
|
||||
)
|
||||
if route in LiteLLMRoutes.info_routes.value and (
|
||||
not _is_user_proxy_admin(user_id_information)
|
||||
): # check if user allowed to call an info route
|
||||
if route == "/key/info":
|
||||
# check if user can access this route
|
||||
query_params = request.query_params
|
||||
key = query_params.get("key")
|
||||
if (
|
||||
key is not None
|
||||
and prisma_client.hash_token(token=key) != api_key
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="user not allowed to access this key's info",
|
||||
)
|
||||
elif route == "/user/info":
|
||||
# check if user can access this route
|
||||
query_params = request.query_params
|
||||
user_id = query_params.get("user_id")
|
||||
verbose_proxy_logger.debug(
|
||||
f"user_id: {user_id} & valid_token.user_id: {valid_token.user_id}"
|
||||
)
|
||||
if user_id != valid_token.user_id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="key not allowed to access this user's info",
|
||||
)
|
||||
elif route == "/model/info":
|
||||
# /model/info just shows models user has access to
|
||||
pass
|
||||
elif route == "/team/info":
|
||||
# check if key can access this team's info
|
||||
query_params = request.query_params
|
||||
team_id = query_params.get("team_id")
|
||||
if team_id != valid_token.team_id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="key not allowed to access this team's info",
|
||||
|
||||
if not _is_user_proxy_admin(user_id_information): # if non-admin
|
||||
if (
|
||||
route in LiteLLMRoutes.info_routes.value
|
||||
): # check if user allowed to call an info route
|
||||
if route == "/key/info":
|
||||
# check if user can access this route
|
||||
query_params = request.query_params
|
||||
key = query_params.get("key")
|
||||
if (
|
||||
key is not None
|
||||
and prisma_client.hash_token(token=key) != api_key
|
||||
):
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="user not allowed to access this key's info",
|
||||
)
|
||||
elif route == "/user/info":
|
||||
# check if user can access this route
|
||||
query_params = request.query_params
|
||||
user_id = query_params.get("user_id")
|
||||
verbose_proxy_logger.debug(
|
||||
f"user_id: {user_id} & valid_token.user_id: {valid_token.user_id}"
|
||||
)
|
||||
if user_id != valid_token.user_id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="key not allowed to access this user's info",
|
||||
)
|
||||
elif route == "/model/info":
|
||||
# /model/info just shows models user has access to
|
||||
pass
|
||||
elif route == "/team/info":
|
||||
# check if key can access this team's info
|
||||
query_params = request.query_params
|
||||
team_id = query_params.get("team_id")
|
||||
if team_id != valid_token.team_id:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
detail="key not allowed to access this team's info",
|
||||
)
|
||||
else:
|
||||
raise Exception(
|
||||
f"Only master key can be used to generate, delete, update info for new keys/users."
|
||||
f"Only master key can be used to generate, delete, update info for new keys/users/teams."
|
||||
)
|
||||
|
||||
# check if token is from litellm-ui, litellm ui makes keys to allow users to login with sso. These keys can only be used for LiteLLM UI functions
|
||||
|
|
|
@ -44,9 +44,13 @@ async def generate_key(
|
|||
models=["azure-models", "gpt-4", "dall-e-3"],
|
||||
max_parallel_requests: Optional[int] = None,
|
||||
user_id: Optional[str] = None,
|
||||
calling_key="sk-1234",
|
||||
):
|
||||
url = "http://0.0.0.0:4000/key/generate"
|
||||
headers = {"Authorization": "Bearer sk-1234", "Content-Type": "application/json"}
|
||||
headers = {
|
||||
"Authorization": f"Bearer {calling_key}",
|
||||
"Content-Type": "application/json",
|
||||
}
|
||||
data = {
|
||||
"models": models,
|
||||
"aliases": {"mistral-7b": "gpt-3.5-turbo"},
|
||||
|
@ -80,6 +84,35 @@ async def test_key_gen():
|
|||
await asyncio.gather(*tasks)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_key_gen_bad_key():
|
||||
"""
|
||||
Test if you can create a key with a non-admin key, even with UI setup
|
||||
"""
|
||||
async with aiohttp.ClientSession() as session:
|
||||
## LOGIN TO UI
|
||||
form_data = {"username": "admin", "password": "sk-1234"}
|
||||
async with session.post(
|
||||
"http://0.0.0.0:4000/login", data=form_data
|
||||
) as response:
|
||||
assert (
|
||||
response.status == 200
|
||||
) # Assuming the endpoint returns a 500 status code for error handling
|
||||
text = await response.text()
|
||||
print(text)
|
||||
## create user key with admin key -> expect to work
|
||||
key_data = await generate_key(session=session, i=0, user_id="user-1234")
|
||||
key = key_data["key"]
|
||||
## create new key with user key -> expect to fail
|
||||
try:
|
||||
await generate_key(
|
||||
session=session, i=0, user_id="user-1234", calling_key=key
|
||||
)
|
||||
pytest.fail("Expected to fail")
|
||||
except Exception as e:
|
||||
pass
|
||||
|
||||
|
||||
async def update_key(session, get_key):
|
||||
"""
|
||||
Make sure only models user has access to are returned
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue