From b20cce5c4389aae26919dde0c4c6073ceb02bc98 Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Thu, 15 May 2025 17:16:14 -0700 Subject: [PATCH] minor --- .../distribution/server/auth_providers.py | 16 +++++++++++----- 1 file changed, 11 insertions(+), 5 deletions(-) diff --git a/llama_stack/distribution/server/auth_providers.py b/llama_stack/distribution/server/auth_providers.py index 651d4819d..3feb17c28 100644 --- a/llama_stack/distribution/server/auth_providers.py +++ b/llama_stack/distribution/server/auth_providers.py @@ -181,8 +181,6 @@ JWT_AUDIENCE = "llama-stack" class JWTAuthProviderConfig(BaseModel): - """Configuration for JWT token authentication provider.""" - # The JWKS URI for collecting public keys jwks_uri: str algorithm: str = "RS256" @@ -190,7 +188,11 @@ class JWTAuthProviderConfig(BaseModel): class JWTAuthProvider(AuthProvider): - """JWT token authentication provider that validates tokens against the JWT token.""" + """ + JWT token authentication provider that validates a JWT token and extracts access attributes. + + This should be the standard authentication provider for most use cases. + """ def __init__(self, config: JWTAuthProviderConfig): self.config = config @@ -203,7 +205,9 @@ class JWTAuthProvider(AuthProvider): try: kid = jwt.get_unverified_header(token)["kid"] - key = self._jwks[kid] # raises if unknown + if kid not in self._jwks: + raise ValueError(f"Unknown key ID: {kid}") + key = self._jwks[kid] claims = jwt.decode( token, key, @@ -212,8 +216,10 @@ class JWTAuthProvider(AuthProvider): options={"verify_exp": True}, ) except Exception as exc: - raise ValueError(f"invalid token: {token}") from exc + raise ValueError(f"Invalid JWT token: {token}") from exc + # There are other standard claims, the most relevant of which is `scope`. + # We should incorporate these into the access attributes. principal = f"{claims['iss']}:{claims['sub']}" teams = claims.get("teams", [])