This commit is contained in:
Ashwin Bharambe 2025-05-15 17:16:14 -07:00
parent 529b12dc5e
commit b20cce5c43

View file

@ -181,8 +181,6 @@ JWT_AUDIENCE = "llama-stack"
class JWTAuthProviderConfig(BaseModel): class JWTAuthProviderConfig(BaseModel):
"""Configuration for JWT token authentication provider."""
# The JWKS URI for collecting public keys # The JWKS URI for collecting public keys
jwks_uri: str jwks_uri: str
algorithm: str = "RS256" algorithm: str = "RS256"
@ -190,7 +188,11 @@ class JWTAuthProviderConfig(BaseModel):
class JWTAuthProvider(AuthProvider): class JWTAuthProvider(AuthProvider):
"""JWT token authentication provider that validates tokens against the JWT token.""" """
JWT token authentication provider that validates a JWT token and extracts access attributes.
This should be the standard authentication provider for most use cases.
"""
def __init__(self, config: JWTAuthProviderConfig): def __init__(self, config: JWTAuthProviderConfig):
self.config = config self.config = config
@ -203,7 +205,9 @@ class JWTAuthProvider(AuthProvider):
try: try:
kid = jwt.get_unverified_header(token)["kid"] kid = jwt.get_unverified_header(token)["kid"]
key = self._jwks[kid] # raises if unknown if kid not in self._jwks:
raise ValueError(f"Unknown key ID: {kid}")
key = self._jwks[kid]
claims = jwt.decode( claims = jwt.decode(
token, token,
key, key,
@ -212,8 +216,10 @@ class JWTAuthProvider(AuthProvider):
options={"verify_exp": True}, options={"verify_exp": True},
) )
except Exception as exc: except Exception as exc:
raise ValueError(f"invalid token: {token}") from exc raise ValueError(f"Invalid JWT token: {token}") from exc
# There are other standard claims, the most relevant of which is `scope`.
# We should incorporate these into the access attributes.
principal = f"{claims['iss']}:{claims['sub']}" principal = f"{claims['iss']}:{claims['sub']}"
teams = claims.get("teams", []) teams = claims.get("teams", [])