mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-07-31 16:01:46 +00:00
more fixes
This commit is contained in:
parent
b5d5d1fba0
commit
cc77a1b4c8
2 changed files with 152 additions and 36 deletions
|
@ -12,7 +12,7 @@ from urllib.parse import parse_qs
|
|||
|
||||
import httpx
|
||||
from jose import jwt
|
||||
from pydantic import BaseModel, Field
|
||||
from pydantic import BaseModel, Field, field_validator
|
||||
|
||||
from llama_stack.distribution.datatypes import AccessAttributes
|
||||
from llama_stack.log import get_logger
|
||||
|
@ -77,6 +77,7 @@ class AuthProviderType(str, Enum):
|
|||
|
||||
KUBERNETES = "kubernetes"
|
||||
CUSTOM = "custom"
|
||||
OAUTH2_TOKEN = "oauth2_token"
|
||||
|
||||
|
||||
class AuthProviderConfig(BaseModel):
|
||||
|
@ -177,24 +178,61 @@ class KubernetesAuthProvider(AuthProvider):
|
|||
self._client = None
|
||||
|
||||
|
||||
JWT_AUDIENCE = "llama-stack"
|
||||
def get_attributes_from_claims(claims: dict[str, str], mapping: dict[str, str]) -> AccessAttributes:
|
||||
attributes = AccessAttributes()
|
||||
for claim_key, attribute_key in mapping.items():
|
||||
if claim_key not in claims or not hasattr(attributes, attribute_key):
|
||||
continue
|
||||
claim = claims[claim_key]
|
||||
if isinstance(claim, list):
|
||||
values = claim
|
||||
else:
|
||||
values = claim.split()
|
||||
|
||||
current = getattr(attributes, attribute_key)
|
||||
if current:
|
||||
current.extend(values)
|
||||
else:
|
||||
setattr(attributes, attribute_key, values)
|
||||
return attributes
|
||||
|
||||
|
||||
class JWTAuthProviderConfig(BaseModel):
|
||||
class OAuth2TokenAuthProviderConfig(BaseModel):
|
||||
# The JWKS URI for collecting public keys
|
||||
jwks_uri: str
|
||||
algorithm: str = "RS256"
|
||||
cache_ttl: int = 3600
|
||||
audience: str = "llama-stack"
|
||||
claims_mapping: dict[str, str] = Field(
|
||||
default_factory=lambda: {
|
||||
"sub": "roles",
|
||||
"username": "roles",
|
||||
"groups": "teams",
|
||||
"team": "teams",
|
||||
"project": "projects",
|
||||
"tenant": "namespaces",
|
||||
"namespace": "namespaces",
|
||||
},
|
||||
)
|
||||
|
||||
@classmethod
|
||||
@field_validator("claims_mapping")
|
||||
def validate_claims_mapping(cls, v):
|
||||
for key, value in v.items():
|
||||
if not value:
|
||||
raise ValueError(f"claims_mapping value cannot be empty: {key}")
|
||||
if value not in AccessAttributes.model_fields:
|
||||
raise ValueError(f"claims_mapping value is not a valid attribute: {value}")
|
||||
return v
|
||||
|
||||
|
||||
class JWTAuthProvider(AuthProvider):
|
||||
class OAuth2TokenAuthProvider(AuthProvider):
|
||||
"""
|
||||
JWT token authentication provider that validates a JWT token and extracts access attributes.
|
||||
|
||||
This should be the standard authentication provider for most use cases.
|
||||
"""
|
||||
|
||||
def __init__(self, config: JWTAuthProviderConfig):
|
||||
def __init__(self, config: OAuth2TokenAuthProviderConfig):
|
||||
self.config = config
|
||||
self._jwks_at: float = 0.0
|
||||
self._jwks: dict[str, str] = {}
|
||||
|
@ -204,15 +242,17 @@ class JWTAuthProvider(AuthProvider):
|
|||
await self._refresh_jwks()
|
||||
|
||||
try:
|
||||
kid = jwt.get_unverified_header(token)["kid"]
|
||||
header = jwt.get_unverified_header(token)
|
||||
kid = header["kid"]
|
||||
if kid not in self._jwks:
|
||||
raise ValueError(f"Unknown key ID: {kid}")
|
||||
key = self._jwks[kid]
|
||||
key_data = self._jwks[kid]
|
||||
algorithm = header.get("alg", "RS256")
|
||||
claims = jwt.decode(
|
||||
token,
|
||||
key,
|
||||
algorithms=[self.config.algorithm],
|
||||
audience=JWT_AUDIENCE,
|
||||
key_data,
|
||||
algorithms=[algorithm],
|
||||
audience=self.config.audience,
|
||||
options={"verify_exp": True},
|
||||
)
|
||||
except Exception as exc:
|
||||
|
@ -220,29 +260,11 @@ class JWTAuthProvider(AuthProvider):
|
|||
|
||||
# There are other standard claims, the most relevant of which is `scope`.
|
||||
# We should incorporate these into the access attributes.
|
||||
principal = f"{claims['iss']}:{claims['sub']}"
|
||||
|
||||
teams = claims.get("teams", [])
|
||||
if not teams:
|
||||
if team := claims.get("team", claims.get("team_id")):
|
||||
teams = [team]
|
||||
projects = claims.get("projects", [])
|
||||
if not projects:
|
||||
if project := claims.get("project", claims.get("project_id")):
|
||||
projects = [project]
|
||||
namespaces = claims.get("namespaces", [])
|
||||
if not namespaces:
|
||||
if namespace := claims.get("namespace", claims.get("tenant")):
|
||||
namespaces = [namespace]
|
||||
|
||||
principal = claims["sub"]
|
||||
access_attributes = get_attributes_from_claims(claims, self.config.claims_mapping)
|
||||
return TokenValidationResult(
|
||||
principal=principal,
|
||||
access_attributes=AccessAttributes(
|
||||
roles=claims.get("groups", claims.get("roles", [])), # Okta / Auth0
|
||||
teams=teams,
|
||||
projects=projects,
|
||||
namespaces=namespaces,
|
||||
),
|
||||
access_attributes=access_attributes,
|
||||
)
|
||||
|
||||
async def close(self):
|
||||
|
@ -250,10 +272,15 @@ class JWTAuthProvider(AuthProvider):
|
|||
|
||||
async def _refresh_jwks(self) -> None:
|
||||
if time.time() - self._jwks_at > self.config.cache_ttl:
|
||||
with httpx.AsyncClient() as client:
|
||||
async with httpx.AsyncClient() as client:
|
||||
res = await client.get(self.config.jwks_uri, timeout=5)
|
||||
res.raise_for_status()
|
||||
self._jwks = {k["kid"]: k for k in res.json()["keys"]}
|
||||
jwks_data = res.json()["keys"]
|
||||
self._jwks = {}
|
||||
for k in jwks_data:
|
||||
kid = k["kid"]
|
||||
# Store the entire key object as it may be needed for different algorithms
|
||||
self._jwks[kid] = k
|
||||
self._jwks_at = time.time()
|
||||
|
||||
|
||||
|
@ -340,8 +367,8 @@ def create_auth_provider(config: AuthProviderConfig) -> AuthProvider:
|
|||
return KubernetesAuthProvider(KubernetesAuthProviderConfig.model_validate(config.config))
|
||||
elif provider_type == "custom":
|
||||
return CustomAuthProvider(CustomAuthProviderConfig.model_validate(config.config))
|
||||
elif provider_type == "jwt":
|
||||
return JWTAuthProvider(JWTAuthProviderConfig.model_validate(config.config))
|
||||
elif provider_type == "oauth2_token":
|
||||
return OAuth2TokenAuthProvider(OAuth2TokenAuthProviderConfig.model_validate(config.config))
|
||||
else:
|
||||
supported_providers = ", ".join([t.value for t in AuthProviderType])
|
||||
raise ValueError(f"Unsupported auth provider type: {provider_type}. Supported types are: {supported_providers}")
|
||||
|
|
|
@ -27,6 +27,10 @@ class MockResponse:
|
|||
def json(self):
|
||||
return self._json_data
|
||||
|
||||
def raise_for_status(self):
|
||||
if self.status_code != 200:
|
||||
raise Exception(f"HTTP error: {self.status_code}")
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def mock_auth_endpoint():
|
||||
|
@ -379,3 +383,88 @@ async def test_k8s_middleware_no_attributes(mock_k8s_middleware, mock_scope):
|
|||
assert attributes["roles"] == ["admin"]
|
||||
|
||||
mock_app.assert_called_once_with(mock_scope, mock_receive, mock_send)
|
||||
|
||||
|
||||
# oauth2 token provider tests
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def oauth2_app():
|
||||
app = FastAPI()
|
||||
auth_config = AuthProviderConfig(
|
||||
provider_type=AuthProviderType.OAUTH2_TOKEN,
|
||||
config={
|
||||
"jwks_uri": "http://mock-authz-service/token/introspect",
|
||||
"cache_ttl": "3600",
|
||||
"audience": "llama-stack",
|
||||
},
|
||||
)
|
||||
app.add_middleware(AuthenticationMiddleware, auth_config=auth_config)
|
||||
|
||||
@app.get("/test")
|
||||
def test_endpoint():
|
||||
return {"message": "Authentication successful"}
|
||||
|
||||
return app
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def oauth2_client(oauth2_app):
|
||||
return TestClient(oauth2_app)
|
||||
|
||||
|
||||
def test_missing_auth_header_oauth2(oauth2_client):
|
||||
response = oauth2_client.get("/test")
|
||||
assert response.status_code == 401
|
||||
assert "Missing or invalid Authorization header" in response.json()["error"]["message"]
|
||||
|
||||
|
||||
def test_invalid_auth_header_format_oauth2(oauth2_client):
|
||||
response = oauth2_client.get("/test", headers={"Authorization": "InvalidFormat token123"})
|
||||
assert response.status_code == 401
|
||||
assert "Missing or invalid Authorization header" in response.json()["error"]["message"]
|
||||
|
||||
|
||||
async def mock_jwks_response(*args, **kwargs):
|
||||
return MockResponse(
|
||||
200,
|
||||
{
|
||||
"keys": [
|
||||
{
|
||||
"kid": "1234567890",
|
||||
"kty": "oct",
|
||||
"alg": "HS256",
|
||||
"use": "sig",
|
||||
"k": "MTIzNDU2Nzg5MA", # Base64-encoded "1234567890"
|
||||
}
|
||||
]
|
||||
},
|
||||
)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def jwt_token_valid():
|
||||
from jose import jwt
|
||||
|
||||
# correctly signed jwt token with "kid" in header
|
||||
return jwt.encode(
|
||||
{
|
||||
"sub": "my-user",
|
||||
"groups": ["group1", "group2"],
|
||||
"scope": "foo bar",
|
||||
"aud": "llama-stack",
|
||||
},
|
||||
key="1234567890",
|
||||
algorithm="HS256",
|
||||
headers={"kid": "1234567890"},
|
||||
)
|
||||
|
||||
|
||||
@patch("httpx.AsyncClient.get", new=mock_jwks_response)
|
||||
def test_valid_oauth2_authentication(oauth2_client, jwt_token_valid):
|
||||
response = oauth2_client.get("/test", headers={"Authorization": f"Bearer {jwt_token_valid}"})
|
||||
assert response.status_code == 200
|
||||
assert response.json() == {"message": "Authentication successful"}
|
||||
|
||||
|
||||
# TODO: add more tests for oauth2 token provider
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue