This commit is contained in:
Ashwin Bharambe 2025-05-15 17:16:14 -07:00
parent 529b12dc5e
commit b20cce5c43

View file

@ -181,8 +181,6 @@ JWT_AUDIENCE = "llama-stack"
class JWTAuthProviderConfig(BaseModel):
"""Configuration for JWT token authentication provider."""
# The JWKS URI for collecting public keys
jwks_uri: str
algorithm: str = "RS256"
@ -190,7 +188,11 @@ class JWTAuthProviderConfig(BaseModel):
class JWTAuthProvider(AuthProvider):
"""JWT token authentication provider that validates tokens against the JWT token."""
"""
JWT token authentication provider that validates a JWT token and extracts access attributes.
This should be the standard authentication provider for most use cases.
"""
def __init__(self, config: JWTAuthProviderConfig):
self.config = config
@ -203,7 +205,9 @@ class JWTAuthProvider(AuthProvider):
try:
kid = jwt.get_unverified_header(token)["kid"]
key = self._jwks[kid] # raises if unknown
if kid not in self._jwks:
raise ValueError(f"Unknown key ID: {kid}")
key = self._jwks[kid]
claims = jwt.decode(
token,
key,
@ -212,8 +216,10 @@ class JWTAuthProvider(AuthProvider):
options={"verify_exp": True},
)
except Exception as exc:
raise ValueError(f"invalid token: {token}") from exc
raise ValueError(f"Invalid JWT token: {token}") from exc
# There are other standard claims, the most relevant of which is `scope`.
# We should incorporate these into the access attributes.
principal = f"{claims['iss']}:{claims['sub']}"
teams = claims.get("teams", [])