From 736c41332f7204eafcf0b39a82884a563a5d5b81 Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Sun, 18 May 2025 08:08:44 -0700 Subject: [PATCH] add invalid token test --- tests/unit/server/test_auth.py | 55 ++++++++++++++++++++++++++++++++-- 1 file changed, 52 insertions(+), 3 deletions(-) diff --git a/tests/unit/server/test_auth.py b/tests/unit/server/test_auth.py index 4a60d43c6..f15ca9de4 100644 --- a/tests/unit/server/test_auth.py +++ b/tests/unit/server/test_auth.py @@ -4,6 +4,7 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +import base64 from unittest.mock import AsyncMock, patch import pytest @@ -16,6 +17,7 @@ from llama_stack.distribution.server.auth_providers import ( AuthProviderConfig, AuthProviderType, TokenValidationResult, + get_attributes_from_claims, ) @@ -435,7 +437,7 @@ async def mock_jwks_response(*args, **kwargs): "kty": "oct", "alg": "HS256", "use": "sig", - "k": "MTIzNDU2Nzg5MA", # Base64-encoded "1234567890" + "k": base64.b64encode(b"foobarbaz").decode(), } ] }, @@ -446,7 +448,6 @@ async def mock_jwks_response(*args, **kwargs): def jwt_token_valid(): from jose import jwt - # correctly signed jwt token with "kid" in header return jwt.encode( { "sub": "my-user", @@ -454,7 +455,7 @@ def jwt_token_valid(): "scope": "foo bar", "aud": "llama-stack", }, - key="1234567890", + key="foobarbaz", algorithm="HS256", headers={"kid": "1234567890"}, ) @@ -467,4 +468,52 @@ def test_valid_oauth2_authentication(oauth2_client, jwt_token_valid): assert response.json() == {"message": "Authentication successful"} +@patch("httpx.AsyncClient.get", new=mock_jwks_response) +def test_invalid_oauth2_authentication(oauth2_client, invalid_token): + response = oauth2_client.get("/test", headers={"Authorization": f"Bearer {invalid_token}"}) + assert response.status_code == 401 + assert "Invalid JWT token" in response.json()["error"]["message"] + + +def test_get_attributes_from_claims(): + claims = { + "sub": "my-user", + "groups": ["group1", "group2"], + "scope": "foo bar", + "aud": "llama-stack", + } + attributes = get_attributes_from_claims(claims, {"sub": "roles", "groups": "teams"}) + assert attributes.roles == ["my-user"] + assert attributes.teams == ["group1", "group2"] + + claims = { + "sub": "my-user", + "tenant": "my-tenant", + } + attributes = get_attributes_from_claims(claims, {"sub": "roles", "tenant": "namespaces"}) + assert attributes.roles == ["my-user"] + assert attributes.namespaces == ["my-tenant"] + + claims = { + "sub": "my-user", + "username": "my-username", + "tenant": "my-tenant", + "groups": ["group1", "group2"], + "team": "my-team", + } + attributes = get_attributes_from_claims( + claims, + { + "sub": "roles", + "tenant": "namespaces", + "username": "roles", + "team": "teams", + "groups": "teams", + }, + ) + assert set(attributes.roles) == {"my-user", "my-username"} + assert set(attributes.teams) == {"my-team", "group1", "group2"} + assert attributes.namespaces == ["my-tenant"] + + # TODO: add more tests for oauth2 token provider