mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-01 16:24:44 +00:00
minor
This commit is contained in:
parent
529b12dc5e
commit
b20cce5c43
1 changed files with 11 additions and 5 deletions
|
@ -181,8 +181,6 @@ JWT_AUDIENCE = "llama-stack"
|
||||||
|
|
||||||
|
|
||||||
class JWTAuthProviderConfig(BaseModel):
|
class JWTAuthProviderConfig(BaseModel):
|
||||||
"""Configuration for JWT token authentication provider."""
|
|
||||||
|
|
||||||
# The JWKS URI for collecting public keys
|
# The JWKS URI for collecting public keys
|
||||||
jwks_uri: str
|
jwks_uri: str
|
||||||
algorithm: str = "RS256"
|
algorithm: str = "RS256"
|
||||||
|
@ -190,7 +188,11 @@ class JWTAuthProviderConfig(BaseModel):
|
||||||
|
|
||||||
|
|
||||||
class JWTAuthProvider(AuthProvider):
|
class JWTAuthProvider(AuthProvider):
|
||||||
"""JWT token authentication provider that validates tokens against the JWT token."""
|
"""
|
||||||
|
JWT token authentication provider that validates a JWT token and extracts access attributes.
|
||||||
|
|
||||||
|
This should be the standard authentication provider for most use cases.
|
||||||
|
"""
|
||||||
|
|
||||||
def __init__(self, config: JWTAuthProviderConfig):
|
def __init__(self, config: JWTAuthProviderConfig):
|
||||||
self.config = config
|
self.config = config
|
||||||
|
@ -203,7 +205,9 @@ class JWTAuthProvider(AuthProvider):
|
||||||
|
|
||||||
try:
|
try:
|
||||||
kid = jwt.get_unverified_header(token)["kid"]
|
kid = jwt.get_unverified_header(token)["kid"]
|
||||||
key = self._jwks[kid] # raises if unknown
|
if kid not in self._jwks:
|
||||||
|
raise ValueError(f"Unknown key ID: {kid}")
|
||||||
|
key = self._jwks[kid]
|
||||||
claims = jwt.decode(
|
claims = jwt.decode(
|
||||||
token,
|
token,
|
||||||
key,
|
key,
|
||||||
|
@ -212,8 +216,10 @@ class JWTAuthProvider(AuthProvider):
|
||||||
options={"verify_exp": True},
|
options={"verify_exp": True},
|
||||||
)
|
)
|
||||||
except Exception as exc:
|
except Exception as exc:
|
||||||
raise ValueError(f"invalid token: {token}") from exc
|
raise ValueError(f"Invalid JWT token: {token}") from exc
|
||||||
|
|
||||||
|
# There are other standard claims, the most relevant of which is `scope`.
|
||||||
|
# We should incorporate these into the access attributes.
|
||||||
principal = f"{claims['iss']}:{claims['sub']}"
|
principal = f"{claims['iss']}:{claims['sub']}"
|
||||||
|
|
||||||
teams = claims.get("teams", [])
|
teams = claims.get("teams", [])
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue