Merge branch 'main' into feat/sambanova-safety

This commit is contained in:
Jorge Piedrahita Ortiz 2025-05-21 11:32:42 -05:00 committed by GitHub
commit e12df4293b
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
26 changed files with 1094 additions and 494 deletions

View file

@ -0,0 +1,56 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import argparse
from pathlib import Path
from llama_stack.cli.subcommand import Subcommand
from llama_stack.cli.table import print_table
class StackListBuilds(Subcommand):
"""List built stacks in .llama/distributions directory"""
def __init__(self, subparsers: argparse._SubParsersAction):
super().__init__()
self.parser = subparsers.add_parser(
"list",
prog="llama stack list",
description="list the build stacks",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
self._add_arguments()
self.parser.set_defaults(func=self._list_stack_command)
def _get_distribution_dirs(self) -> dict[str, Path]:
"""Return a dictionary of distribution names and their paths"""
distributions = {}
dist_dir = Path.home() / ".llama" / "distributions"
if dist_dir.exists():
for stack_dir in dist_dir.iterdir():
if stack_dir.is_dir():
distributions[stack_dir.name] = stack_dir
return distributions
def _list_stack_command(self, args: argparse.Namespace) -> None:
distributions = self._get_distribution_dirs()
if not distributions:
print("No stacks found in ~/.llama/distributions")
return
headers = ["Stack Name", "Path"]
headers.extend(["Build Config", "Run Config"])
rows = []
for name, path in distributions.items():
row = [name, str(path)]
# Check for build and run config files
build_config = "Yes" if (path / f"{name}-build.yaml").exists() else "No"
run_config = "Yes" if (path / f"{name}-run.yaml").exists() else "No"
row.extend([build_config, run_config])
rows.append(row)
print_table(rows, headers, separate_rows=True)

View file

@ -0,0 +1,116 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import argparse
import shutil
import sys
from pathlib import Path
from termcolor import cprint
from llama_stack.cli.subcommand import Subcommand
from llama_stack.cli.table import print_table
class StackRemove(Subcommand):
"""Remove the build stack"""
def __init__(self, subparsers: argparse._SubParsersAction):
super().__init__()
self.parser = subparsers.add_parser(
"rm",
prog="llama stack rm",
description="Remove the build stack",
formatter_class=argparse.ArgumentDefaultsHelpFormatter,
)
self._add_arguments()
self.parser.set_defaults(func=self._remove_stack_build_command)
def _add_arguments(self) -> None:
self.parser.add_argument(
"name",
type=str,
nargs="?",
help="Name of the stack to delete",
)
self.parser.add_argument(
"--all",
"-a",
action="store_true",
help="Delete all stacks (use with caution)",
)
def _get_distribution_dirs(self) -> dict[str, Path]:
"""Return a dictionary of distribution names and their paths"""
distributions = {}
dist_dir = Path.home() / ".llama" / "distributions"
if dist_dir.exists():
for stack_dir in dist_dir.iterdir():
if stack_dir.is_dir():
distributions[stack_dir.name] = stack_dir
return distributions
def _list_stacks(self) -> None:
"""Display available stacks in a table"""
distributions = self._get_distribution_dirs()
if not distributions:
print("No stacks found in ~/.llama/distributions")
return
headers = ["Stack Name", "Path"]
rows = [[name, str(path)] for name, path in distributions.items()]
print_table(rows, headers, separate_rows=True)
def _remove_stack_build_command(self, args: argparse.Namespace) -> None:
distributions = self._get_distribution_dirs()
if args.all:
confirm = input("Are you sure you want to delete ALL stacks? [yes-i-really-want/N] ").lower()
if confirm != "yes-i-really-want":
print("Deletion cancelled.")
return
for name, path in distributions.items():
try:
shutil.rmtree(path)
print(f"Deleted stack: {name}")
except Exception as e:
cprint(
f"Failed to delete stack {name}: {e}",
color="red",
)
sys.exit(2)
if not args.name:
self._list_stacks()
if not args.name:
return
if args.name not in distributions:
self._list_stacks()
cprint(
f"Stack not found: {args.name}",
color="red",
)
return
stack_path = distributions[args.name]
confirm = input(f"Are you sure you want to delete stack '{args.name}'? [y/N] ").lower()
if confirm != "y":
print("Deletion cancelled.")
return
try:
shutil.rmtree(stack_path)
print(f"Successfully deleted stack: {args.name}")
except Exception as e:
cprint(
f"Failed to delete stack {args.name}: {e}",
color="red",
)
sys.exit(2)

View file

@ -7,12 +7,14 @@
import argparse
from importlib.metadata import version
from llama_stack.cli.stack.list_stacks import StackListBuilds
from llama_stack.cli.stack.utils import print_subcommand_description
from llama_stack.cli.subcommand import Subcommand
from .build import StackBuild
from .list_apis import StackListApis
from .list_providers import StackListProviders
from .remove import StackRemove
from .run import StackRun
@ -41,5 +43,6 @@ class StackParser(Subcommand):
StackListApis.create(subparsers)
StackListProviders.create(subparsers)
StackRun.create(subparsers)
StackRemove.create(subparsers)
StackListBuilds.create(subparsers)
print_subcommand_description(self.parser, subparsers)

View file

@ -25,7 +25,7 @@ from llama_stack.apis.tools import Tool, ToolGroup, ToolGroupInput, ToolRuntime
from llama_stack.apis.vector_dbs import VectorDB, VectorDBInput
from llama_stack.apis.vector_io import VectorIO
from llama_stack.providers.datatypes import Api, ProviderSpec
from llama_stack.providers.utils.kvstore.config import KVStoreConfig
from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig
LLAMA_STACK_BUILD_CONFIG_VERSION = "2"
LLAMA_STACK_RUN_CONFIG_VERSION = "2"
@ -220,21 +220,34 @@ class LoggingConfig(BaseModel):
class AuthProviderType(str, Enum):
"""Supported authentication provider types."""
KUBERNETES = "kubernetes"
OAUTH2_TOKEN = "oauth2_token"
CUSTOM = "custom"
class AuthenticationConfig(BaseModel):
provider_type: AuthProviderType = Field(
...,
description="Type of authentication provider (e.g., 'kubernetes', 'custom')",
description="Type of authentication provider",
)
config: dict[str, str] = Field(
config: dict[str, Any] = Field(
...,
description="Provider-specific configuration",
)
class QuotaPeriod(str, Enum):
DAY = "day"
class QuotaConfig(BaseModel):
kvstore: SqliteKVStoreConfig = Field(description="Config for KV store backend (SQLite only for now)")
anonymous_max_requests: int = Field(default=100, description="Max requests for unauthenticated clients per period")
authenticated_max_requests: int = Field(
default=1000, description="Max requests for authenticated clients per period"
)
period: QuotaPeriod = Field(default=QuotaPeriod.DAY, description="Quota period to set")
class ServerConfig(BaseModel):
port: int = Field(
default=8321,
@ -262,6 +275,10 @@ class ServerConfig(BaseModel):
default=None,
description="The host the server should listen on",
)
quota: QuotaConfig | None = Field(
default=None,
description="Per client quota request configuration",
)
class StackRunConfig(BaseModel):

View file

@ -8,7 +8,8 @@ import json
import httpx
from llama_stack.distribution.server.auth_providers import AuthProviderConfig, create_auth_provider
from llama_stack.distribution.datatypes import AuthenticationConfig
from llama_stack.distribution.server.auth_providers import create_auth_provider
from llama_stack.log import get_logger
logger = get_logger(name=__name__, category="auth")
@ -77,7 +78,7 @@ class AuthenticationMiddleware:
access resources that don't have access_attributes defined.
"""
def __init__(self, app, auth_config: AuthProviderConfig):
def __init__(self, app, auth_config: AuthenticationConfig):
self.app = app
self.auth_provider = create_auth_provider(auth_config)
@ -113,6 +114,10 @@ class AuthenticationMiddleware:
"roles": [token],
}
# Store the client ID in the request scope so that downstream middleware (like QuotaMiddleware)
# can identify the requester and enforce per-client rate limits.
scope["authenticated_client_id"] = token
# Store attributes in request scope
scope["user_attributes"] = user_attributes
scope["principal"] = validation_result.principal

View file

@ -4,18 +4,19 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
import ssl
import time
from abc import ABC, abstractmethod
from asyncio import Lock
from enum import Enum
from pathlib import Path
from urllib.parse import parse_qs
import httpx
from jose import jwt
from pydantic import BaseModel, Field, field_validator
from pydantic import BaseModel, Field, field_validator, model_validator
from typing_extensions import Self
from llama_stack.distribution.datatypes import AccessAttributes
from llama_stack.distribution.datatypes import AccessAttributes, AuthenticationConfig, AuthProviderType
from llama_stack.log import get_logger
logger = get_logger(name=__name__, category="auth")
@ -73,21 +74,6 @@ class AuthRequest(BaseModel):
request: AuthRequestContext = Field(description="Context information about the request being authenticated")
class AuthProviderType(str, Enum):
"""Supported authentication provider types."""
KUBERNETES = "kubernetes"
CUSTOM = "custom"
OAUTH2_TOKEN = "oauth2_token"
class AuthProviderConfig(BaseModel):
"""Base configuration for authentication providers."""
provider_type: AuthProviderType = Field(..., description="Type of authentication provider")
config: dict[str, str] = Field(..., description="Provider-specific configuration")
class AuthProvider(ABC):
"""Abstract base class for authentication providers."""
@ -102,83 +88,6 @@ class AuthProvider(ABC):
pass
class KubernetesAuthProviderConfig(BaseModel):
api_server_url: str
ca_cert_path: str | None = None
class KubernetesAuthProvider(AuthProvider):
"""Kubernetes authentication provider that validates tokens against the Kubernetes API server."""
def __init__(self, config: KubernetesAuthProviderConfig):
self.config = config
self._client = None
async def _get_client(self):
"""Get or create a Kubernetes client."""
if self._client is None:
# kubernetes-client has not async support, see:
# https://github.com/kubernetes-client/python/issues/323
from kubernetes import client
from kubernetes.client import ApiClient
# Configure the client
configuration = client.Configuration()
configuration.host = self.config.api_server_url
if self.config.ca_cert_path:
configuration.ssl_ca_cert = self.config.ca_cert_path
configuration.verify_ssl = bool(self.config.ca_cert_path)
# Create API client
self._client = ApiClient(configuration)
return self._client
async def validate_token(self, token: str, scope: dict | None = None) -> TokenValidationResult:
"""Validate a Kubernetes token and return access attributes."""
try:
client = await self._get_client()
# Set the token in the client
client.set_default_header("Authorization", f"Bearer {token}")
# Make a request to validate the token
# We use the /api endpoint which requires authentication
from kubernetes.client import CoreV1Api
api = CoreV1Api(client)
api.get_api_resources(_request_timeout=3.0) # Set timeout for this specific request
# If we get here, the token is valid
# Extract user info from the token claims
import base64
# Decode the token (without verification since we've already validated it)
token_parts = token.split(".")
payload = json.loads(base64.b64decode(token_parts[1] + "=" * (-len(token_parts[1]) % 4)))
# Extract user information from the token
username = payload.get("sub", "")
groups = payload.get("groups", [])
return TokenValidationResult(
principal=username,
access_attributes=AccessAttributes(
roles=[username], # Use username as a role
teams=groups, # Use Kubernetes groups as teams
),
)
except Exception as e:
logger.exception("Failed to validate Kubernetes token")
raise ValueError("Invalid or expired token") from e
async def close(self):
"""Close the HTTP client."""
if self._client:
self._client.close()
self._client = None
def get_attributes_from_claims(claims: dict[str, str], mapping: dict[str, str]) -> AccessAttributes:
attributes = AccessAttributes()
for claim_key, attribute_key in mapping.items():
@ -198,11 +107,24 @@ def get_attributes_from_claims(claims: dict[str, str], mapping: dict[str, str])
return attributes
class OAuth2TokenAuthProviderConfig(BaseModel):
class OAuth2JWKSConfig(BaseModel):
# The JWKS URI for collecting public keys
jwks_uri: str
cache_ttl: int = 3600
uri: str
key_recheck_period: int = Field(default=3600, description="The period to recheck the JWKS URI for key updates")
class OAuth2IntrospectionConfig(BaseModel):
url: str
client_id: str
client_secret: str
send_secret_in_body: bool = False
class OAuth2TokenAuthProviderConfig(BaseModel):
audience: str = "llama-stack"
verify_tls: bool = True
tls_cafile: Path | None = None
issuer: str | None = Field(default=None, description="The OIDC issuer URL.")
claims_mapping: dict[str, str] = Field(
default_factory=lambda: {
"sub": "roles",
@ -214,6 +136,8 @@ class OAuth2TokenAuthProviderConfig(BaseModel):
"namespace": "namespaces",
},
)
jwks: OAuth2JWKSConfig | None
introspection: OAuth2IntrospectionConfig | None = None
@classmethod
@field_validator("claims_mapping")
@ -225,6 +149,14 @@ class OAuth2TokenAuthProviderConfig(BaseModel):
raise ValueError(f"claims_mapping value is not a valid attribute: {value}")
return v
@model_validator(mode="after")
def validate_mode(self) -> Self:
if not self.jwks and not self.introspection:
raise ValueError("One of jwks or introspection must be configured")
if self.jwks and self.introspection:
raise ValueError("At present only one of jwks or introspection should be configured")
return self
class OAuth2TokenAuthProvider(AuthProvider):
"""
@ -240,6 +172,13 @@ class OAuth2TokenAuthProvider(AuthProvider):
self._jwks_lock = Lock()
async def validate_token(self, token: str, scope: dict | None = None) -> TokenValidationResult:
if self.config.jwks:
return await self.validate_jwt_token(token, scope)
if self.config.introspection:
return await self.introspect_token(token, scope)
raise ValueError("One of jwks or introspection must be configured")
async def validate_jwt_token(self, token: str, scope: dict | None = None) -> TokenValidationResult:
"""Validate a token using the JWT token."""
await self._refresh_jwks()
@ -255,7 +194,7 @@ class OAuth2TokenAuthProvider(AuthProvider):
key_data,
algorithms=[algorithm],
audience=self.config.audience,
options={"verify_exp": True},
issuer=self.config.issuer,
)
except Exception as exc:
raise ValueError(f"Invalid JWT token: {token}") from exc
@ -269,14 +208,75 @@ class OAuth2TokenAuthProvider(AuthProvider):
access_attributes=access_attributes,
)
async def introspect_token(self, token: str, scope: dict | None = None) -> TokenValidationResult:
"""Validate a token using token introspection as defined by RFC 7662."""
form = {
"token": token,
}
if self.config.introspection is None:
raise ValueError("Introspection is not configured")
if self.config.introspection.send_secret_in_body:
form["client_id"] = self.config.introspection.client_id
form["client_secret"] = self.config.introspection.client_secret
auth = None
else:
auth = (self.config.introspection.client_id, self.config.introspection.client_secret)
ssl_ctxt = None
if self.config.tls_cafile:
ssl_ctxt = ssl.create_default_context(cafile=self.config.tls_cafile.as_posix())
try:
async with httpx.AsyncClient(verify=ssl_ctxt) as client:
response = await client.post(
self.config.introspection.url,
data=form,
auth=auth,
timeout=10.0, # Add a reasonable timeout
)
if response.status_code != 200:
logger.warning(f"Token introspection failed with status code: {response.status_code}")
raise ValueError(f"Token introspection failed: {response.status_code}")
fields = response.json()
if not fields["active"]:
raise ValueError("Token not active")
principal = fields["sub"] or fields["username"]
access_attributes = get_attributes_from_claims(fields, self.config.claims_mapping)
return TokenValidationResult(
principal=principal,
access_attributes=access_attributes,
)
except httpx.TimeoutException:
logger.exception("Token introspection request timed out")
raise
except ValueError:
# Re-raise ValueError exceptions to preserve their message
raise
except Exception as e:
logger.exception("Error during token introspection")
raise ValueError("Token introspection error") from e
async def close(self):
"""Close the HTTP client."""
pass
async def _refresh_jwks(self) -> None:
"""
Refresh the JWKS cache.
This is a simple cache that expires after a certain amount of time (defined by `key_recheck_period`).
If the cache is expired, we refresh the JWKS from the JWKS URI.
Notes: for Kubernetes which doesn't fully implement the OIDC protocol:
* It doesn't have user authentication flows
* It doesn't have refresh tokens
"""
async with self._jwks_lock:
if time.time() - self._jwks_at > self.config.cache_ttl:
async with httpx.AsyncClient() as client:
res = await client.get(self.config.jwks_uri, timeout=5)
if self.config.jwks is None:
raise ValueError("JWKS is not configured")
if time.time() - self._jwks_at > self.config.jwks.key_recheck_period:
verify = self.config.tls_cafile.as_posix() if self.config.tls_cafile else self.config.verify_tls
async with httpx.AsyncClient(verify=verify) as client:
res = await client.get(self.config.jwks.uri, timeout=5)
res.raise_for_status()
jwks_data = res.json()["keys"]
updated = {}
@ -363,13 +363,11 @@ class CustomAuthProvider(AuthProvider):
self._client = None
def create_auth_provider(config: AuthProviderConfig) -> AuthProvider:
def create_auth_provider(config: AuthenticationConfig) -> AuthProvider:
"""Factory function to create the appropriate auth provider."""
provider_type = config.provider_type.lower()
if provider_type == "kubernetes":
return KubernetesAuthProvider(KubernetesAuthProviderConfig.model_validate(config.config))
elif provider_type == "custom":
if provider_type == "custom":
return CustomAuthProvider(CustomAuthProviderConfig.model_validate(config.config))
elif provider_type == "oauth2_token":
return OAuth2TokenAuthProvider(OAuth2TokenAuthProviderConfig.model_validate(config.config))

View file

@ -0,0 +1,110 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
import time
from datetime import datetime, timedelta, timezone
from starlette.types import ASGIApp, Receive, Scope, Send
from llama_stack.log import get_logger
from llama_stack.providers.utils.kvstore.api import KVStore
from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig
from llama_stack.providers.utils.kvstore.kvstore import kvstore_impl
logger = get_logger(name=__name__, category="quota")
class QuotaMiddleware:
"""
ASGI middleware that enforces separate quotas for authenticated and anonymous clients
within a configurable time window.
- For authenticated requests, it reads the client ID from the
`Authorization: Bearer <client_id>` header.
- For anonymous requests, it falls back to the IP address of the client.
Requests are counted in a KV store (e.g., SQLite), and HTTP 429 is returned
once a client exceeds its quota.
"""
def __init__(
self,
app: ASGIApp,
kv_config: KVStoreConfig,
anonymous_max_requests: int,
authenticated_max_requests: int,
window_seconds: int = 86400,
):
self.app = app
self.kv_config = kv_config
self.kv: KVStore | None = None
self.anonymous_max_requests = anonymous_max_requests
self.authenticated_max_requests = authenticated_max_requests
self.window_seconds = window_seconds
if isinstance(self.kv_config, SqliteKVStoreConfig):
logger.warning(
"QuotaMiddleware: Using SQLite backend. Expiry/TTL is not enforced; cleanup is manual. "
f"window_seconds={self.window_seconds}"
)
async def _get_kv(self) -> KVStore:
if self.kv is None:
self.kv = await kvstore_impl(self.kv_config)
return self.kv
async def __call__(self, scope: Scope, receive: Receive, send: Send):
if scope["type"] == "http":
# pick key & limit based on auth
auth_id = scope.get("authenticated_client_id")
if auth_id:
key_id = auth_id
limit = self.authenticated_max_requests
else:
# fallback to IP
client = scope.get("client")
key_id = client[0] if client else "anonymous"
limit = self.anonymous_max_requests
current_window = int(time.time() // self.window_seconds)
key = f"quota:{key_id}:{current_window}"
try:
kv = await self._get_kv()
prev = await kv.get(key) or "0"
count = int(prev) + 1
if int(prev) == 0:
# Set with expiration datetime when it is the first request in the window.
expiration = datetime.now(timezone.utc) + timedelta(seconds=self.window_seconds)
await kv.set(key, str(count), expiration=expiration)
else:
await kv.set(key, str(count))
except Exception:
logger.exception("Failed to access KV store for quota")
return await self._send_error(send, 500, "Quota service error")
if count > limit:
logger.warning(
"Quota exceeded for client %s: %d/%d",
key_id,
count,
limit,
)
return await self._send_error(send, 429, "Quota exceeded")
return await self.app(scope, receive, send)
async def _send_error(self, send: Send, status: int, message: str):
await send(
{
"type": "http.response.start",
"status": status,
"headers": [[b"content-type", b"application/json"]],
}
)
body = json.dumps({"error": {"message": message}}).encode()
await send({"type": "http.response.body", "body": body})

View file

@ -60,6 +60,7 @@ from llama_stack.providers.utils.telemetry.tracing import (
from .auth import AuthenticationMiddleware
from .endpoints import get_all_api_endpoints
from .quota import QuotaMiddleware
REPO_ROOT = Path(__file__).parent.parent.parent.parent
@ -434,6 +435,35 @@ def main(args: argparse.Namespace | None = None):
if config.server.auth:
logger.info(f"Enabling authentication with provider: {config.server.auth.provider_type.value}")
app.add_middleware(AuthenticationMiddleware, auth_config=config.server.auth)
else:
if config.server.quota:
quota = config.server.quota
logger.warning(
"Configured authenticated_max_requests (%d) but no auth is enabled; "
"falling back to anonymous_max_requests (%d) for all the requests",
quota.authenticated_max_requests,
quota.anonymous_max_requests,
)
if config.server.quota:
logger.info("Enabling quota middleware for authenticated and anonymous clients")
quota = config.server.quota
anonymous_max_requests = quota.anonymous_max_requests
# if auth is disabled, use the anonymous max requests
authenticated_max_requests = quota.authenticated_max_requests if config.server.auth else anonymous_max_requests
kv_config = quota.kvstore
window_map = {"day": 86400}
window_seconds = window_map[quota.period.value]
app.add_middleware(
QuotaMiddleware,
kv_config=kv_config,
anonymous_max_requests=anonymous_max_requests,
authenticated_max_requests=authenticated_max_requests,
window_seconds=window_seconds,
)
try:
impls = asyncio.run(construct_stack(config))