mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-28 01:21:59 +00:00
more fixes
This commit is contained in:
parent
b5d5d1fba0
commit
cc77a1b4c8
2 changed files with 152 additions and 36 deletions
|
|
@ -27,6 +27,10 @@ class MockResponse:
|
|||
def json(self):
|
||||
return self._json_data
|
||||
|
||||
def raise_for_status(self):
|
||||
if self.status_code != 200:
|
||||
raise Exception(f"HTTP error: {self.status_code}")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_auth_endpoint():
|
||||
|
|
@ -379,3 +383,88 @@ async def test_k8s_middleware_no_attributes(mock_k8s_middleware, mock_scope):
|
|||
assert attributes["roles"] == ["admin"]
|
||||
|
||||
mock_app.assert_called_once_with(mock_scope, mock_receive, mock_send)
|
||||
|
||||
|
||||
# oauth2 token provider tests
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def oauth2_app():
|
||||
app = FastAPI()
|
||||
auth_config = AuthProviderConfig(
|
||||
provider_type=AuthProviderType.OAUTH2_TOKEN,
|
||||
config={
|
||||
"jwks_uri": "http://mock-authz-service/token/introspect",
|
||||
"cache_ttl": "3600",
|
||||
"audience": "llama-stack",
|
||||
},
|
||||
)
|
||||
app.add_middleware(AuthenticationMiddleware, auth_config=auth_config)
|
||||
|
||||
@app.get("/test")
|
||||
def test_endpoint():
|
||||
return {"message": "Authentication successful"}
|
||||
|
||||
return app
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def oauth2_client(oauth2_app):
|
||||
return TestClient(oauth2_app)
|
||||
|
||||
|
||||
def test_missing_auth_header_oauth2(oauth2_client):
|
||||
response = oauth2_client.get("/test")
|
||||
assert response.status_code == 401
|
||||
assert "Missing or invalid Authorization header" in response.json()["error"]["message"]
|
||||
|
||||
|
||||
def test_invalid_auth_header_format_oauth2(oauth2_client):
|
||||
response = oauth2_client.get("/test", headers={"Authorization": "InvalidFormat token123"})
|
||||
assert response.status_code == 401
|
||||
assert "Missing or invalid Authorization header" in response.json()["error"]["message"]
|
||||
|
||||
|
||||
async def mock_jwks_response(*args, **kwargs):
|
||||
return MockResponse(
|
||||
200,
|
||||
{
|
||||
"keys": [
|
||||
{
|
||||
"kid": "1234567890",
|
||||
"kty": "oct",
|
||||
"alg": "HS256",
|
||||
"use": "sig",
|
||||
"k": "MTIzNDU2Nzg5MA", # Base64-encoded "1234567890"
|
||||
}
|
||||
]
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def jwt_token_valid():
|
||||
from jose import jwt
|
||||
|
||||
# correctly signed jwt token with "kid" in header
|
||||
return jwt.encode(
|
||||
{
|
||||
"sub": "my-user",
|
||||
"groups": ["group1", "group2"],
|
||||
"scope": "foo bar",
|
||||
"aud": "llama-stack",
|
||||
},
|
||||
key="1234567890",
|
||||
algorithm="HS256",
|
||||
headers={"kid": "1234567890"},
|
||||
)
|
||||
|
||||
|
||||
@patch("httpx.AsyncClient.get", new=mock_jwks_response)
|
||||
def test_valid_oauth2_authentication(oauth2_client, jwt_token_valid):
|
||||
response = oauth2_client.get("/test", headers={"Authorization": f"Bearer {jwt_token_valid}"})
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"message": "Authentication successful"}
|
||||
|
||||
|
||||
# TODO: add more tests for oauth2 token provider
|
||||
|
|
|
|||
Loading…
Add table
Add a link
Reference in a new issue