mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-13 08:36:09 +00:00
feat(auth): allow token to be provided for use against jwks endpoint
Though the jwks endpoint does not usually require authentication, it does by default in a kubernetes cluster. The cluster can be configured to allow anonymous access to that endpoint, but by allowing a token to be presented that cluster configuration is not necessary.
This commit is contained in:
parent
c8c742ba45
commit
74d891db72
3 changed files with 57 additions and 24 deletions
25
.github/workflows/integration-auth-tests.yml
vendored
25
.github/workflows/integration-auth-tests.yml
vendored
|
@ -54,28 +54,6 @@ jobs:
|
|||
kubectl create serviceaccount llama-stack-auth -n llama-stack
|
||||
kubectl create rolebinding llama-stack-auth-rolebinding --clusterrole=admin --serviceaccount=llama-stack:llama-stack-auth -n llama-stack
|
||||
kubectl create token llama-stack-auth -n llama-stack > llama-stack-auth-token
|
||||
cat <<EOF | kubectl apply -f -
|
||||
apiVersion: rbac.authorization.k8s.io/v1
|
||||
kind: ClusterRole
|
||||
metadata:
|
||||
name: allow-anonymous-openid
|
||||
rules:
|
||||
- nonResourceURLs: ["/openid/v1/jwks"]
|
||||
verbs: ["get"]
|
||||
---
|
||||
apiVersion: rbac.authorization.k8s.io/v1
|
||||
kind: ClusterRoleBinding
|
||||
metadata:
|
||||
name: allow-anonymous-openid
|
||||
roleRef:
|
||||
apiGroup: rbac.authorization.k8s.io
|
||||
kind: ClusterRole
|
||||
name: allow-anonymous-openid
|
||||
subjects:
|
||||
- kind: User
|
||||
name: system:anonymous
|
||||
apiGroup: rbac.authorization.k8s.io
|
||||
EOF
|
||||
|
||||
- name: Set Kubernetes Config
|
||||
if: ${{ matrix.auth-provider == 'oauth2_token' }}
|
||||
|
@ -84,6 +62,7 @@ jobs:
|
|||
echo "KUBERNETES_CA_CERT_PATH=$(kubectl config view --minify -o jsonpath='{.clusters[0].cluster.certificate-authority}')" >> $GITHUB_ENV
|
||||
echo "KUBERNETES_ISSUER=$(kubectl get --raw /.well-known/openid-configuration| jq -r .issuer)" >> $GITHUB_ENV
|
||||
echo "KUBERNETES_AUDIENCE=$(kubectl create token llama-stack-auth -n llama-stack --duration=1h | cut -d. -f2 | base64 -d | jq -r '.aud[0]')" >> $GITHUB_ENV
|
||||
echo "TOKEN=$(kubectl create token llama-stack-auth -n llama-stack --duration=1h)" >> $GITHUB_ENV
|
||||
|
||||
- name: Set Kube Auth Config and run server
|
||||
env:
|
||||
|
@ -101,7 +80,7 @@ jobs:
|
|||
EOF
|
||||
yq eval '.server.auth = {"provider_type": "${{ matrix.auth-provider }}"}' -i $run_dir/run.yaml
|
||||
yq eval '.server.auth.config = {"tls_cafile": "${{ env.KUBERNETES_CA_CERT_PATH }}", "issuer": "${{ env.KUBERNETES_ISSUER }}", "audience": "${{ env.KUBERNETES_AUDIENCE }}"}' -i $run_dir/run.yaml
|
||||
yq eval '.server.auth.config.jwks = {"uri": "${{ env.KUBERNETES_API_SERVER_URL }}"}' -i $run_dir/run.yaml
|
||||
yq eval '.server.auth.config.jwks = {"uri": "${{ env.KUBERNETES_API_SERVER_URL }}", "token": "${{ env.TOKEN }}"}' -i $run_dir/run.yaml
|
||||
cat $run_dir/run.yaml
|
||||
|
||||
nohup uv run llama stack run $run_dir/run.yaml --image-type venv > server.log 2>&1 &
|
||||
|
|
|
@ -84,6 +84,7 @@ def get_attributes_from_claims(claims: dict[str, str], mapping: dict[str, str])
|
|||
class OAuth2JWKSConfig(BaseModel):
|
||||
# The JWKS URI for collecting public keys
|
||||
uri: str
|
||||
token: str | None = Field(default=None, description="token to authorise access to jwks")
|
||||
key_recheck_period: int = Field(default=3600, description="The period to recheck the JWKS URI for key updates")
|
||||
|
||||
|
||||
|
@ -246,9 +247,12 @@ class OAuth2TokenAuthProvider(AuthProvider):
|
|||
if self.config.jwks is None:
|
||||
raise ValueError("JWKS is not configured")
|
||||
if time.time() - self._jwks_at > self.config.jwks.key_recheck_period:
|
||||
headers = {}
|
||||
if self.config.jwks.token:
|
||||
headers["Authorization"] = f"Bearer {self.config.jwks.token}"
|
||||
verify = self.config.tls_cafile.as_posix() if self.config.tls_cafile else self.config.verify_tls
|
||||
async with httpx.AsyncClient(verify=verify) as client:
|
||||
res = await client.get(self.config.jwks.uri, timeout=5)
|
||||
res = await client.get(self.config.jwks.uri, timeout=5, headers=headers)
|
||||
res.raise_for_status()
|
||||
jwks_data = res.json()["keys"]
|
||||
updated = {}
|
||||
|
|
|
@ -345,6 +345,56 @@ def test_invalid_oauth2_authentication(oauth2_client, invalid_token):
|
|||
assert "Invalid JWT token" in response.json()["error"]["message"]
|
||||
|
||||
|
||||
async def mock_auth_jwks_response(*args, **kwargs):
|
||||
if "headers" not in kwargs or "Authorization" not in kwargs["headers"]:
|
||||
return MockResponse(401, {})
|
||||
authz = kwargs["headers"]["Authorization"]
|
||||
if authz != "Bearer abcdefg":
|
||||
return MockResponse(401, {})
|
||||
return await mock_jwks_response(args, kwargs)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def oauth2_app_with_jwks_token():
|
||||
app = FastAPI()
|
||||
auth_config = AuthenticationConfig(
|
||||
provider_type=AuthProviderType.OAUTH2_TOKEN,
|
||||
config={
|
||||
"jwks": {
|
||||
"uri": "http://mock-authz-service/token/introspect",
|
||||
"key_recheck_period": "3600",
|
||||
"token": "abcdefg",
|
||||
},
|
||||
"audience": "llama-stack",
|
||||
},
|
||||
)
|
||||
app.add_middleware(AuthenticationMiddleware, auth_config=auth_config)
|
||||
|
||||
@app.get("/test")
|
||||
def test_endpoint():
|
||||
return {"message": "Authentication successful"}
|
||||
|
||||
return app
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def oauth2_client_with_jwks_token(oauth2_app_with_jwks_token):
|
||||
return TestClient(oauth2_app_with_jwks_token)
|
||||
|
||||
|
||||
@patch("httpx.AsyncClient.get", new=mock_auth_jwks_response)
|
||||
def test_oauth2_with_jwks_token_expected(oauth2_client, jwt_token_valid):
|
||||
response = oauth2_client.get("/test", headers={"Authorization": f"Bearer {jwt_token_valid}"})
|
||||
assert response.status_code == 401
|
||||
|
||||
|
||||
@patch("httpx.AsyncClient.get", new=mock_auth_jwks_response)
|
||||
def test_oauth2_with_jwks_token_configured(oauth2_client_with_jwks_token, jwt_token_valid):
|
||||
response = oauth2_client_with_jwks_token.get("/test", headers={"Authorization": f"Bearer {jwt_token_valid}"})
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"message": "Authentication successful"}
|
||||
|
||||
|
||||
def test_get_attributes_from_claims():
|
||||
claims = {
|
||||
"sub": "my-user",
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue