more fixes

This commit is contained in:
Ashwin Bharambe 2025-05-18 07:40:11 -07:00
parent b5d5d1fba0
commit cc77a1b4c8
2 changed files with 152 additions and 36 deletions

View file

@ -12,7 +12,7 @@ from urllib.parse import parse_qs
import httpx
from jose import jwt
from pydantic import BaseModel, Field
from pydantic import BaseModel, Field, field_validator
from llama_stack.distribution.datatypes import AccessAttributes
from llama_stack.log import get_logger
@ -77,6 +77,7 @@ class AuthProviderType(str, Enum):
KUBERNETES = "kubernetes"
CUSTOM = "custom"
OAUTH2_TOKEN = "oauth2_token"
class AuthProviderConfig(BaseModel):
@ -177,24 +178,61 @@ class KubernetesAuthProvider(AuthProvider):
self._client = None
JWT_AUDIENCE = "llama-stack"
def get_attributes_from_claims(claims: dict[str, str], mapping: dict[str, str]) -> AccessAttributes:
attributes = AccessAttributes()
for claim_key, attribute_key in mapping.items():
if claim_key not in claims or not hasattr(attributes, attribute_key):
continue
claim = claims[claim_key]
if isinstance(claim, list):
values = claim
else:
values = claim.split()
current = getattr(attributes, attribute_key)
if current:
current.extend(values)
else:
setattr(attributes, attribute_key, values)
return attributes
class JWTAuthProviderConfig(BaseModel):
class OAuth2TokenAuthProviderConfig(BaseModel):
# The JWKS URI for collecting public keys
jwks_uri: str
algorithm: str = "RS256"
cache_ttl: int = 3600
audience: str = "llama-stack"
claims_mapping: dict[str, str] = Field(
default_factory=lambda: {
"sub": "roles",
"username": "roles",
"groups": "teams",
"team": "teams",
"project": "projects",
"tenant": "namespaces",
"namespace": "namespaces",
},
)
@classmethod
@field_validator("claims_mapping")
def validate_claims_mapping(cls, v):
for key, value in v.items():
if not value:
raise ValueError(f"claims_mapping value cannot be empty: {key}")
if value not in AccessAttributes.model_fields:
raise ValueError(f"claims_mapping value is not a valid attribute: {value}")
return v
class JWTAuthProvider(AuthProvider):
class OAuth2TokenAuthProvider(AuthProvider):
"""
JWT token authentication provider that validates a JWT token and extracts access attributes.
This should be the standard authentication provider for most use cases.
"""
def __init__(self, config: JWTAuthProviderConfig):
def __init__(self, config: OAuth2TokenAuthProviderConfig):
self.config = config
self._jwks_at: float = 0.0
self._jwks: dict[str, str] = {}
@ -204,15 +242,17 @@ class JWTAuthProvider(AuthProvider):
await self._refresh_jwks()
try:
kid = jwt.get_unverified_header(token)["kid"]
header = jwt.get_unverified_header(token)
kid = header["kid"]
if kid not in self._jwks:
raise ValueError(f"Unknown key ID: {kid}")
key = self._jwks[kid]
key_data = self._jwks[kid]
algorithm = header.get("alg", "RS256")
claims = jwt.decode(
token,
key,
algorithms=[self.config.algorithm],
audience=JWT_AUDIENCE,
key_data,
algorithms=[algorithm],
audience=self.config.audience,
options={"verify_exp": True},
)
except Exception as exc:
@ -220,29 +260,11 @@ class JWTAuthProvider(AuthProvider):
# There are other standard claims, the most relevant of which is `scope`.
# We should incorporate these into the access attributes.
principal = f"{claims['iss']}:{claims['sub']}"
teams = claims.get("teams", [])
if not teams:
if team := claims.get("team", claims.get("team_id")):
teams = [team]
projects = claims.get("projects", [])
if not projects:
if project := claims.get("project", claims.get("project_id")):
projects = [project]
namespaces = claims.get("namespaces", [])
if not namespaces:
if namespace := claims.get("namespace", claims.get("tenant")):
namespaces = [namespace]
principal = claims["sub"]
access_attributes = get_attributes_from_claims(claims, self.config.claims_mapping)
return TokenValidationResult(
principal=principal,
access_attributes=AccessAttributes(
roles=claims.get("groups", claims.get("roles", [])), # Okta / Auth0
teams=teams,
projects=projects,
namespaces=namespaces,
),
access_attributes=access_attributes,
)
async def close(self):
@ -250,10 +272,15 @@ class JWTAuthProvider(AuthProvider):
async def _refresh_jwks(self) -> None:
if time.time() - self._jwks_at > self.config.cache_ttl:
with httpx.AsyncClient() as client:
async with httpx.AsyncClient() as client:
res = await client.get(self.config.jwks_uri, timeout=5)
res.raise_for_status()
self._jwks = {k["kid"]: k for k in res.json()["keys"]}
jwks_data = res.json()["keys"]
self._jwks = {}
for k in jwks_data:
kid = k["kid"]
# Store the entire key object as it may be needed for different algorithms
self._jwks[kid] = k
self._jwks_at = time.time()
@ -340,8 +367,8 @@ def create_auth_provider(config: AuthProviderConfig) -> AuthProvider:
return KubernetesAuthProvider(KubernetesAuthProviderConfig.model_validate(config.config))
elif provider_type == "custom":
return CustomAuthProvider(CustomAuthProviderConfig.model_validate(config.config))
elif provider_type == "jwt":
return JWTAuthProvider(JWTAuthProviderConfig.model_validate(config.config))
elif provider_type == "oauth2_token":
return OAuth2TokenAuthProvider(OAuth2TokenAuthProviderConfig.model_validate(config.config))
else:
supported_providers = ", ".join([t.value for t in AuthProviderType])
raise ValueError(f"Unsupported auth provider type: {provider_type}. Supported types are: {supported_providers}")

View file

@ -27,6 +27,10 @@ class MockResponse:
def json(self):
return self._json_data
def raise_for_status(self):
if self.status_code != 200:
raise Exception(f"HTTP error: {self.status_code}")
@pytest.fixture
def mock_auth_endpoint():
@ -379,3 +383,88 @@ async def test_k8s_middleware_no_attributes(mock_k8s_middleware, mock_scope):
assert attributes["roles"] == ["admin"]
mock_app.assert_called_once_with(mock_scope, mock_receive, mock_send)
# oauth2 token provider tests
@pytest.fixture
def oauth2_app():
app = FastAPI()
auth_config = AuthProviderConfig(
provider_type=AuthProviderType.OAUTH2_TOKEN,
config={
"jwks_uri": "http://mock-authz-service/token/introspect",
"cache_ttl": "3600",
"audience": "llama-stack",
},
)
app.add_middleware(AuthenticationMiddleware, auth_config=auth_config)
@app.get("/test")
def test_endpoint():
return {"message": "Authentication successful"}
return app
@pytest.fixture
def oauth2_client(oauth2_app):
return TestClient(oauth2_app)
def test_missing_auth_header_oauth2(oauth2_client):
response = oauth2_client.get("/test")
assert response.status_code == 401
assert "Missing or invalid Authorization header" in response.json()["error"]["message"]
def test_invalid_auth_header_format_oauth2(oauth2_client):
response = oauth2_client.get("/test", headers={"Authorization": "InvalidFormat token123"})
assert response.status_code == 401
assert "Missing or invalid Authorization header" in response.json()["error"]["message"]
async def mock_jwks_response(*args, **kwargs):
return MockResponse(
200,
{
"keys": [
{
"kid": "1234567890",
"kty": "oct",
"alg": "HS256",
"use": "sig",
"k": "MTIzNDU2Nzg5MA", # Base64-encoded "1234567890"
}
]
},
)
@pytest.fixture
def jwt_token_valid():
from jose import jwt
# correctly signed jwt token with "kid" in header
return jwt.encode(
{
"sub": "my-user",
"groups": ["group1", "group2"],
"scope": "foo bar",
"aud": "llama-stack",
},
key="1234567890",
algorithm="HS256",
headers={"kid": "1234567890"},
)
@patch("httpx.AsyncClient.get", new=mock_jwks_response)
def test_valid_oauth2_authentication(oauth2_client, jwt_token_valid):
response = oauth2_client.get("/test", headers={"Authorization": f"Bearer {jwt_token_valid}"})
assert response.status_code == 200
assert response.json() == {"message": "Authentication successful"}
# TODO: add more tests for oauth2 token provider