mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-12 13:00:39 +00:00
chore(rename): move llama_stack.distribution to llama_stack.core (#2975)
We would like to rename the term `template` to `distribution`. To prepare for that, this is a precursor. cc @leseb
This commit is contained in:
parent
f3d5459647
commit
2665f00102
211 changed files with 351 additions and 348 deletions
5
llama_stack/core/server/__init__.py
Normal file
5
llama_stack/core/server/__init__.py
Normal file
|
@ -0,0 +1,5 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
175
llama_stack/core/server/auth.py
Normal file
175
llama_stack/core/server/auth.py
Normal file
|
@ -0,0 +1,175 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import json
|
||||
|
||||
import httpx
|
||||
from aiohttp import hdrs
|
||||
|
||||
from llama_stack.core.datatypes import AuthenticationConfig, User
|
||||
from llama_stack.core.request_headers import user_from_scope
|
||||
from llama_stack.core.server.auth_providers import create_auth_provider
|
||||
from llama_stack.core.server.routes import find_matching_route, initialize_route_impls
|
||||
from llama_stack.log import get_logger
|
||||
|
||||
# Module-level logger for the authentication middleware, created with category "auth".
logger = get_logger(name=__name__, category="auth")
|
||||
|
||||
|
||||
class AuthenticationMiddleware:
    """Middleware that authenticates requests using configured authentication provider.

    This middleware:
    1. Extracts the Bearer token from the Authorization header
    2. Uses the configured auth provider to validate the token
    3. Stores the principal and attributes returned by the provider in the ASGI scope
       (``scope["principal"]``, ``scope["user_attributes"]``) together with
       ``scope["authenticated_client_id"]`` for downstream middleware
    4. Rejects the request with HTTP 403 when the matched route declares a
       ``required_scope`` that is not listed in the user's ``scopes`` attribute

    The provider is built by ``create_auth_provider`` from ``auth_config``; each provider
    implements its own token validation logic behind the AuthProvider interface.

    Authentication Request Format for Custom Auth Provider:
    ```json
    {
        "api_key": "the-api-key-extracted-from-auth-header",
        "request": {
            "path": "/models/list",
            "headers": {
                "content-type": "application/json",
                "user-agent": "..."
                // All headers except Authorization
            },
            "params": {
                "limit": ["100"],
                "offset": ["0"]
                // Query parameters as key -> list of values
            }
        }
    }
    ```

    Expected Auth Endpoint Response Format:
    ```json
    {
        "principal": "user-identifier",
        "attributes": {  // Optional structured attributes
            "roles": ["admin", "user"],
            "teams": ["ml-team", "nlp-team"],
            "projects": ["llama-3", "project-x"],
            "namespaces": ["research"]
        },
        "message": "Optional message about auth result"
    }
    ```

    Error responses are JSON: ``{"error": {"message": ...}}`` for 401 and
    ``{"error": {"detail": ...}}`` for any other status.

    Attribute-Based Access Control:
    NOTE(review): the rules below describe how attributes are consumed elsewhere in the
    stack; they are not enforced in this class -- confirm against the access-control code.
    The attributes returned by the auth provider are used to determine which
    resources the user can access. For a user to access a resource:

    1. All attribute categories specified in the resource must be present in the user's attributes
    2. For each category, the user must have at least one matching value

    If the auth provider doesn't return any attributes, the user will only be able to
    access resources that don't have access attributes defined.
    """

    def __init__(self, app, auth_config: AuthenticationConfig, impls):
        """Wrap ``app`` and build the auth provider described by ``auth_config``.

        ``impls`` is kept so the route table can be built lazily on the first request.
        """
        self.app = app
        self.impls = impls
        self.auth_provider = create_auth_provider(auth_config)

    async def __call__(self, scope, receive, send):
        """ASGI entry point: authenticate HTTP requests, pass everything else through."""
        if scope["type"] == "http":
            # First, handle authentication
            headers = dict(scope.get("headers", []))
            try:
                auth_header = headers.get(b"authorization", b"").decode()
            except UnicodeDecodeError:
                # Header bytes are client-controlled; malformed input is a 401, not a crash.
                return await self._send_auth_error(send, "Invalid Authorization header format")

            if not auth_header:
                error_msg = self.auth_provider.get_auth_error_message(scope)
                return await self._send_auth_error(send, error_msg)

            if not auth_header.startswith("Bearer "):
                return await self._send_auth_error(send, "Invalid Authorization header format")

            token = auth_header.split("Bearer ", 1)[1]

            # Validate token and get access attributes
            try:
                validation_result = await self.auth_provider.validate_token(token, scope)
            except httpx.TimeoutException:
                logger.exception("Authentication request timed out")
                return await self._send_auth_error(send, "Authentication service timeout")
            except ValueError as e:
                # Providers raise ValueError with a client-safe message.
                logger.exception("Error during authentication")
                return await self._send_auth_error(send, str(e))
            except Exception:
                logger.exception("Error during authentication")
                return await self._send_auth_error(send, "Authentication service error")

            # Store the client ID in the request scope so that downstream middleware (like QuotaMiddleware)
            # can identify the requester and enforce per-client rate limits.
            scope["authenticated_client_id"] = token

            # Store attributes in request scope
            scope["principal"] = validation_result.principal
            attributes = validation_result.attributes
            if attributes:
                scope["user_attributes"] = attributes
            # attributes may be None, and the f-string is evaluated even when debug
            # logging is disabled, so len() must not be called on it directly.
            logger.debug(
                f"Authentication successful: {validation_result.principal} with {len(attributes or {})} attributes"
            )

            # Scope-based API access control
            path = scope.get("path", "")
            method = scope.get("method", hdrs.METH_GET)

            # Built lazily on first request and cached on the instance.
            if not hasattr(self, "route_impls"):
                self.route_impls = initialize_route_impls(self.impls)

            try:
                _, _, _, webmethod = find_matching_route(method, path, self.route_impls)
            except ValueError:
                # If no matching endpoint is found, pass through to FastAPI
                return await self.app(scope, receive, send)

            if webmethod.required_scope:
                user = user_from_scope(scope)
                if not _has_required_scope(webmethod.required_scope, user):
                    return await self._send_auth_error(
                        send,
                        f"Access denied: user does not have required scope: {webmethod.required_scope}",
                        status=403,
                    )

        return await self.app(scope, receive, send)

    async def _send_auth_error(self, send, message, status=401):
        """Send a complete JSON error response with the given status (401 by default)."""
        await send(
            {
                "type": "http.response.start",
                "status": status,
                "headers": [[b"content-type", b"application/json"]],
            }
        )
        # 401 bodies use the "message" key, every other status uses "detail".
        error_key = "message" if status == 401 else "detail"
        error_msg = json.dumps({"error": {error_key: message}}).encode()
        await send({"type": "http.response.body", "body": error_msg})
|
||||
|
||||
|
||||
def _has_required_scope(required_scope: str, user: User | None) -> bool:
    """Tell whether ``user`` is allowed to call an endpoint guarded by ``required_scope``.

    A missing user is taken to mean authentication is disabled, so access is granted.
    Otherwise the scope must be listed under the ``"scopes"`` key of the user's attributes.
    """
    if not user:
        # if no user, assume auth is not enabled
        return True

    attributes = user.attributes
    return bool(attributes) and required_scope in attributes.get("scopes", [])
|
388
llama_stack/core/server/auth_providers.py
Normal file
388
llama_stack/core/server/auth_providers.py
Normal file
|
@ -0,0 +1,388 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import ssl
|
||||
import time
|
||||
from abc import ABC, abstractmethod
|
||||
from asyncio import Lock
|
||||
from urllib.parse import parse_qs, urlparse
|
||||
|
||||
import httpx
|
||||
from jose import jwt
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.core.datatypes import (
|
||||
AuthenticationConfig,
|
||||
CustomAuthConfig,
|
||||
GitHubTokenAuthConfig,
|
||||
OAuth2TokenAuthConfig,
|
||||
User,
|
||||
)
|
||||
from llama_stack.log import get_logger
|
||||
|
||||
# Module-level logger shared by all auth providers, created with category "auth".
logger = get_logger(name=__name__, category="auth")
|
||||
|
||||
|
||||
class AuthResponse(BaseModel):
    """The format of the authentication response from the auth endpoint.

    CustomAuthProvider parses the endpoint's JSON body into this model and copies
    ``principal`` and ``attributes`` into the returned User.
    """

    # Identifier of the authenticated caller (required).
    principal: str
    # further attributes that may be used for access control decisions
    attributes: dict[str, list[str]] | None = None
    message: str | None = Field(
        default=None, description="Optional message providing additional context about the authentication result."
    )
|
||||
|
||||
|
||||
class AuthRequestContext(BaseModel):
    """Details of the incoming HTTP request that are forwarded to a custom auth endpoint."""

    path: str = Field(description="The path of the request being authenticated")

    headers: dict[str, str] = Field(description="HTTP headers from the original request (excluding Authorization)")

    # Shape matches urllib.parse.parse_qs output: each key maps to a list of values.
    params: dict[str, list[str]] = Field(default_factory=dict, description="Query parameters from the original request")
|
||||
|
||||
|
||||
class AuthRequest(BaseModel):
    """JSON payload that CustomAuthProvider POSTs to the configured auth endpoint."""

    api_key: str = Field(description="The API key extracted from the Authorization header")

    request: AuthRequestContext = Field(description="Context information about the request being authenticated")
|
||||
|
||||
|
||||
class AuthProvider(ABC):
    """Interface every authentication backend implements.

    Subclasses must provide ``validate_token`` and ``close``; ``get_auth_error_message``
    has a generic default that backends may override.
    """

    @abstractmethod
    async def validate_token(self, token: str, scope: dict | None = None) -> User:
        """Check ``token`` and return the authenticated User (principal plus attributes)."""
        ...

    @abstractmethod
    async def close(self):
        """Release whatever resources the provider holds."""
        ...

    def get_auth_error_message(self, scope: dict | None = None) -> str:
        """Message shown to clients that sent no credentials."""
        return "Authentication required"
|
||||
|
||||
|
||||
def get_attributes_from_claims(claims: dict[str, str], mapping: dict[str, str]) -> dict[str, list[str]]:
    """Translate token claims into access attributes.

    Args:
        claims: Claim name -> value. A value is either a list of strings or a single
            whitespace-separated string (e.g. an OAuth2 ``scope`` claim).
        mapping: Claim name -> attribute name. Several claims may map to the same
            attribute, in which case their values are concatenated in mapping order.

    Returns:
        Attribute name -> list of values. Claims absent from ``claims`` are skipped.
        The returned lists are always new objects; ``claims`` is never modified.
    """
    attributes: dict[str, list[str]] = {}
    for claim_key, attribute_key in mapping.items():
        if claim_key not in claims:
            continue
        claim = claims[claim_key]
        # Copy list claims: storing the caller's list and later extending it (when a
        # second claim maps to the same attribute) would silently mutate ``claims``.
        values = list(claim) if isinstance(claim, list) else claim.split()
        attributes.setdefault(attribute_key, []).extend(values)
    return attributes
|
||||
|
||||
|
||||
class OAuth2TokenAuthProvider(AuthProvider):
    """
    JWT token authentication provider that validates a JWT token and extracts access attributes.

    This should be the standard authentication provider for most use cases.

    Validation uses the JWKS key set when ``config.jwks`` is set, otherwise RFC 7662 token
    introspection when ``config.introspection`` is set.
    """

    def __init__(self, config: OAuth2TokenAuthConfig):
        self.config = config
        # time.time() of the last successful JWKS fetch; 0.0 forces a fetch on first use.
        self._jwks_at: float = 0.0
        # kid -> complete key object as returned by the JWKS endpoint.
        self._jwks: dict[str, dict] = {}
        # Serializes cache refreshes so concurrent requests don't all refetch the key set.
        self._jwks_lock = Lock()

    async def validate_token(self, token: str, scope: dict | None = None) -> User:
        """Validate ``token`` with JWKS if configured, else with introspection.

        Raises:
            ValueError: if neither mechanism is configured or the token is rejected.
        """
        if self.config.jwks:
            return await self.validate_jwt_token(token, scope)
        if self.config.introspection:
            return await self.introspect_token(token, scope)
        raise ValueError("One of jwks or introspection must be configured")

    async def validate_jwt_token(self, token: str, scope: dict | None = None) -> User:
        """Validate a JWT against the cached JWKS keys and build a User from its claims.

        Raises:
            ValueError: "Invalid JWT token" for any failure while decoding/verifying,
                including a missing or unknown ``kid``.
        """
        await self._refresh_jwks()

        try:
            header = jwt.get_unverified_header(token)
            kid = header["kid"]
            if kid not in self._jwks:
                raise ValueError(f"Unknown key ID: {kid}")
            key_data = self._jwks[kid]
            # NOTE(review): the accepted algorithm is taken from the token's own
            # (unverified) header. Consider pinning the allowed algorithms in config.
            algorithm = header.get("alg", "RS256")
            claims = jwt.decode(
                token,
                key_data,
                algorithms=[algorithm],
                audience=self.config.audience,
                issuer=self.config.issuer,
            )
        except Exception as exc:
            raise ValueError("Invalid JWT token") from exc

        # There are other standard claims, the most relevant of which is `scope`.
        # We should incorporate these into the access attributes.
        principal = claims["sub"]
        access_attributes = get_attributes_from_claims(claims, self.config.claims_mapping)
        return User(
            principal=principal,
            attributes=access_attributes,
        )

    async def introspect_token(self, token: str, scope: dict | None = None) -> User:
        """Validate a token using token introspection as defined by RFC 7662.

        Client credentials are sent either in the form body or as HTTP basic auth,
        depending on ``introspection.send_secret_in_body``.

        Raises:
            ValueError: introspection is not configured, the endpoint answered with a
                non-200 status, the token is inactive, the response identifies no
                principal, or any other non-timeout error occurred.
            httpx.TimeoutException: the introspection request timed out.
        """
        introspection = self.config.introspection
        if introspection is None:
            raise ValueError("Introspection is not configured")

        form = {"token": token}
        if introspection.send_secret_in_body:
            form["client_id"] = introspection.client_id
            form["client_secret"] = introspection.client_secret
            auth = None
        else:
            auth = (introspection.client_id, introspection.client_secret)

        # Never hand httpx verify=None: it is not a documented value and, being falsy,
        # may be treated as "verification off". Fall back to the configured verify_tls
        # flag, the same setting the JWKS fetch uses.
        if self.config.tls_cafile:
            verify = ssl.create_default_context(cafile=self.config.tls_cafile.as_posix())
        else:
            verify = self.config.verify_tls

        try:
            async with httpx.AsyncClient(verify=verify) as client:
                response = await client.post(
                    introspection.url,
                    data=form,
                    auth=auth,
                    timeout=10.0,  # Add a reasonable timeout
                )
                if response.status_code != 200:
                    logger.warning(f"Token introspection failed with status code: {response.status_code}")
                    raise ValueError(f"Token introspection failed: {response.status_code}")

                fields = response.json()
                if not fields["active"]:
                    raise ValueError("Token not active")
                # Both members are optional in RFC 7662, so neither may be indexed
                # directly; fall back from "sub" to "username".
                principal = fields.get("sub") or fields.get("username")
                if not principal:
                    raise ValueError("Token introspection response has neither 'sub' nor 'username'")
                access_attributes = get_attributes_from_claims(fields, self.config.claims_mapping)
                return User(
                    principal=principal,
                    attributes=access_attributes,
                )
        except httpx.TimeoutException:
            logger.exception("Token introspection request timed out")
            raise
        except ValueError:
            # Re-raise ValueError exceptions to preserve their message
            raise
        except Exception as e:
            logger.exception("Error during token introspection")
            raise ValueError("Token introspection error") from e

    async def close(self):
        """Nothing to release: HTTP clients are created per request."""
        pass

    def get_auth_error_message(self, scope: dict | None = None) -> str:
        """Return OAuth2-specific authentication error message."""
        if self.config.issuer:
            return f"Authentication required. Please provide a valid OAuth2 Bearer token from {self.config.issuer}"
        elif self.config.introspection:
            # Extract domain from introspection URL for a cleaner message
            domain = urlparse(self.config.introspection.url).netloc
            return f"Authentication required. Please provide a valid OAuth2 Bearer token validated by {domain}"
        else:
            return "Authentication required. Please provide a valid OAuth2 Bearer token in the Authorization header"

    async def _refresh_jwks(self) -> None:
        """
        Refresh the JWKS cache.

        This is a simple cache that expires after a certain amount of time (defined by `key_recheck_period`).
        If the cache is expired, we refresh the JWKS from the JWKS URI.

        Notes: for Kubernetes which doesn't fully implement the OIDC protocol:
            * It doesn't have user authentication flows
            * It doesn't have refresh tokens

        Raises:
            ValueError: if JWKS is not configured.
            httpx.HTTPStatusError: if the JWKS endpoint answers with an error status.
        """
        async with self._jwks_lock:
            if self.config.jwks is None:
                raise ValueError("JWKS is not configured")
            if time.time() - self._jwks_at > self.config.jwks.key_recheck_period:
                headers = {}
                if self.config.jwks.token:
                    headers["Authorization"] = f"Bearer {self.config.jwks.token}"
                # A CA file takes precedence; otherwise honour the verify_tls flag.
                verify = self.config.tls_cafile.as_posix() if self.config.tls_cafile else self.config.verify_tls
                async with httpx.AsyncClient(verify=verify) as client:
                    res = await client.get(self.config.jwks.uri, timeout=5, headers=headers)
                    res.raise_for_status()
                    jwks_data = res.json()["keys"]
                    updated = {}
                    for k in jwks_data:
                        kid = k["kid"]
                        # Store the entire key object as it may be needed for different algorithms
                        updated[kid] = k
                    # Swap the cache only after the whole key set parsed successfully.
                    self._jwks = updated
                    self._jwks_at = time.time()
|
||||
|
||||
|
||||
class CustomAuthProvider(AuthProvider):
    """Auth provider that delegates token validation to an external HTTP endpoint."""

    def __init__(self, config: CustomAuthConfig):
        self._client = None
        self.config = config

    async def validate_token(self, token: str, scope: dict | None = None) -> User:
        """POST the token plus request context to ``config.endpoint`` and map the reply to a User.

        Raises:
            ValueError: non-200 reply, unparsable reply, or any other non-timeout failure.
            httpx.TimeoutException: the endpoint did not answer in time.
        """
        asgi_scope = {} if scope is None else scope

        raw_headers = dict(asgi_scope.get("headers", []))
        path = asgi_scope.get("path", "")
        forwarded_headers = {name.decode(): value.decode() for name, value in raw_headers.items()}

        # The credential itself travels in api_key; never echo it back as a header.
        forwarded_headers.pop("authorization", None)

        params = parse_qs(asgi_scope.get("query_string", b"").decode())

        payload = AuthRequest(
            api_key=token,
            request=AuthRequestContext(path=path, headers=forwarded_headers, params=params),
        )

        try:
            async with httpx.AsyncClient() as http:
                reply = await http.post(
                    self.config.endpoint,
                    json=payload.model_dump(),
                    timeout=10.0,  # Add a reasonable timeout
                )
                status = reply.status_code
                if status != 200:
                    logger.warning(f"Authentication failed with status code: {status}")
                    raise ValueError(f"Authentication failed: {status}")

                # Any problem decoding or validating the body is reported uniformly.
                try:
                    parsed = AuthResponse(**reply.json())
                    return User(principal=parsed.principal, attributes=parsed.attributes)
                except Exception as e:
                    logger.exception("Error parsing authentication response")
                    raise ValueError("Invalid authentication response format") from e

        except httpx.TimeoutException:
            logger.exception("Authentication request timed out")
            raise
        except ValueError:
            # Keep the specific message chosen above.
            raise
        except Exception as e:
            logger.exception("Error during authentication")
            raise ValueError("Authentication service error") from e

    async def close(self):
        """Close the stored HTTP client, if one was ever set."""
        if not self._client:
            return
        await self._client.aclose()
        self._client = None

    def get_auth_error_message(self, scope: dict | None = None) -> str:
        """Message for missing credentials, naming the validating host when it is known."""
        domain = urlparse(self.config.endpoint).netloc
        if not domain:
            return "Authentication required. Please provide your API key as a Bearer token in the Authorization header"
        return f"Authentication required. Please provide your API key as a Bearer token (validated by {domain})"
|
||||
|
||||
|
||||
class GitHubTokenAuthProvider(AuthProvider):
    """
    Auth provider that accepts GitHub access tokens.

    A token (personal access token or OAuth token) is verified by asking the GitHub
    API who it belongs to; the answer supplies the principal and attributes.
    """

    def __init__(self, config: GitHubTokenAuthConfig):
        self.config = config

    async def validate_token(self, token: str, scope: dict | None = None) -> User:
        """Look the token up on GitHub and return the corresponding User.

        Raises:
            ValueError: GitHub answered the lookup with an HTTP error status.
        """
        try:
            info = await _get_github_user_info(token, self.config.github_api_base_url)
        except httpx.HTTPStatusError as e:
            logger.warning(f"GitHub token validation failed: {e}")
            raise ValueError("GitHub token validation failed. Please check your token and try again.") from e

        gh_user = info["user"]
        login = gh_user["login"]

        # Claim-shaped view of the GitHub data, fed through the configured mapping.
        github_claims = {
            "login": login,
            "id": str(gh_user["id"]),
            "organizations": info.get("organizations", []),
        }

        return User(
            principal=login,
            attributes=get_attributes_from_claims(github_claims, self.config.claims_mapping),
        )

    async def close(self):
        """No resources are held between requests."""
        pass

    def get_auth_error_message(self, scope: dict | None = None) -> str:
        """Message for missing credentials, pointing at GitHub's token documentation."""
        return "Authentication required. Please provide a valid GitHub access token (https://docs.github.com/en/authentication/keeping-your-account-and-data-secure/managing-your-personal-access-tokens) in the Authorization header (Bearer <token>)"
|
||||
|
||||
|
||||
async def _get_github_user_info(access_token: str, github_api_base_url: str) -> dict:
    """Call ``{github_api_base_url}/user`` with the token and return ``{"user": <json>}``.

    Raises httpx.HTTPStatusError when GitHub answers with an error status.
    """
    request_headers = {
        "Authorization": f"Bearer {access_token}",
        "Accept": "application/vnd.github.v3+json",
        "User-Agent": "llama-stack",
    }
    user_url = f"{github_api_base_url}/user"

    async with httpx.AsyncClient() as client:
        reply = await client.get(user_url, headers=request_headers, timeout=10.0)
        reply.raise_for_status()
        return {"user": reply.json()}
|
||||
|
||||
|
||||
def create_auth_provider(config: AuthenticationConfig) -> AuthProvider:
    """Build the auth provider matching the type of ``config.provider_config``.

    Raises:
        ValueError: the config is of none of the supported types.
    """
    provider_config = config.provider_config

    # Checked in order; the first matching config type wins.
    known_providers = (
        (CustomAuthConfig, CustomAuthProvider),
        (OAuth2TokenAuthConfig, OAuth2TokenAuthProvider),
        (GitHubTokenAuthConfig, GitHubTokenAuthProvider),
    )
    for config_type, provider_type in known_providers:
        if isinstance(provider_config, config_type):
            return provider_type(provider_config)

    raise ValueError(f"Unknown authentication provider config type: {type(provider_config)}")
|
110
llama_stack/core/server/quota.py
Normal file
110
llama_stack/core/server/quota.py
Normal file
|
@ -0,0 +1,110 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import json
|
||||
import time
|
||||
from datetime import UTC, datetime, timedelta
|
||||
|
||||
from starlette.types import ASGIApp, Receive, Scope, Send
|
||||
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.utils.kvstore.api import KVStore
|
||||
from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig
|
||||
from llama_stack.providers.utils.kvstore.kvstore import kvstore_impl
|
||||
|
||||
# Module-level logger for the quota middleware, created with category "quota".
logger = get_logger(name=__name__, category="quota")
|
||||
|
||||
|
||||
class QuotaMiddleware:
    """
    ASGI middleware enforcing per-client request quotas over a fixed time window,
    with separate limits for authenticated and anonymous clients.

    - Authenticated clients are identified by ``scope["authenticated_client_id"]``.
    - Everyone else is identified by the client IP from ``scope["client"]``
      (or the literal "anonymous" when that is missing).

    One counter per client and window is kept in a KV store; once it exceeds the
    applicable limit the request is answered with HTTP 429. KV store failures
    produce HTTP 500.
    """

    def __init__(
        self,
        app: ASGIApp,
        kv_config: KVStoreConfig,
        anonymous_max_requests: int,
        authenticated_max_requests: int,
        window_seconds: int = 86400,
    ):
        self.app = app
        self.window_seconds = window_seconds
        self.anonymous_max_requests = anonymous_max_requests
        self.authenticated_max_requests = authenticated_max_requests
        self.kv_config = kv_config
        # Created on first use by _get_kv().
        self.kv: KVStore | None = None

        uses_sqlite = isinstance(self.kv_config, SqliteKVStoreConfig)
        if uses_sqlite:
            logger.warning(
                "QuotaMiddleware: Using SQLite backend. Expiry/TTL is not enforced; cleanup is manual. "
                f"window_seconds={self.window_seconds}"
            )

    async def _get_kv(self) -> KVStore:
        """Return the KV store, creating it from ``kv_config`` on the first call."""
        if self.kv is not None:
            return self.kv
        self.kv = await kvstore_impl(self.kv_config)
        return self.kv

    async def __call__(self, scope: Scope, receive: Receive, send: Send):
        """Count the request against the caller's quota, then forward or reject it."""
        if scope["type"] != "http":
            return await self.app(scope, receive, send)

        # Choose the identity and the limit that applies to it.
        auth_id = scope.get("authenticated_client_id")
        if auth_id:
            key_id, limit = auth_id, self.authenticated_max_requests
        else:
            peer = scope.get("client")
            key_id = peer[0] if peer else "anonymous"
            limit = self.anonymous_max_requests

        # Fixed windows: the window index is part of the key.
        window_index = int(time.time() // self.window_seconds)
        key = f"quota:{key_id}:{window_index}"

        try:
            kv = await self._get_kv()
            previous = int(await kv.get(key) or "0")
            count = previous + 1

            set_options = {}
            if previous == 0:
                # First request of the window: attach an expiration to the new counter.
                set_options["expiration"] = datetime.now(UTC) + timedelta(seconds=self.window_seconds)
            await kv.set(key, str(count), **set_options)
        except Exception:
            logger.exception("Failed to access KV store for quota")
            return await self._send_error(send, 500, "Quota service error")

        if count > limit:
            logger.warning(
                "Quota exceeded for client %s: %d/%d",
                key_id,
                count,
                limit,
            )
            return await self._send_error(send, 429, "Quota exceeded")

        return await self.app(scope, receive, send)

    async def _send_error(self, send: Send, status: int, message: str):
        """Emit a JSON ``{"error": {"message": ...}}`` response with the given status."""
        start_event = {
            "type": "http.response.start",
            "status": status,
            "headers": [[b"content-type", b"application/json"]],
        }
        await send(start_event)
        payload = json.dumps({"error": {"message": message}}).encode()
        await send({"type": "http.response.body", "body": payload})
|
141
llama_stack/core/server/routes.py
Normal file
141
llama_stack/core/server/routes.py
Normal file
|
@ -0,0 +1,141 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import inspect
|
||||
import re
|
||||
from collections.abc import Callable
|
||||
from typing import Any
|
||||
|
||||
from aiohttp import hdrs
|
||||
from starlette.routing import Route
|
||||
|
||||
from llama_stack.apis.datatypes import Api, ExternalApiSpec
|
||||
from llama_stack.apis.tools import RAGToolRuntime, SpecialToolGroup
|
||||
from llama_stack.apis.version import LLAMA_STACK_API_VERSION
|
||||
from llama_stack.core.resolver import api_protocol_map
|
||||
from llama_stack.schema_utils import WebMethod
|
||||
|
||||
# Callable implementing an endpoint.
EndpointFunc = Callable[..., Any]
# Path parameter name -> value captured from the request path.
PathParams = dict[str, str]
# (endpoint function, route path template, webmethod metadata)
RouteInfo = tuple[EndpointFunc, str, WebMethod]
# Anchored path regex -> route info, for a single HTTP method.
PathImpl = dict[str, RouteInfo]
# Lower-cased HTTP method -> routes registered for it.
RouteImpls = dict[str, PathImpl]
# (endpoint function, captured path params, route path template, webmethod metadata)
RouteMatch = tuple[EndpointFunc, PathParams, str, WebMethod]
|
||||
|
||||
|
||||
def toolgroup_protocol_map():
    """Map each special tool group to the protocol class that declares its endpoints."""
    protocols_by_group = {SpecialToolGroup.rag_tool: RAGToolRuntime}
    return protocols_by_group
|
||||
|
||||
|
||||
def get_all_api_routes(
    external_apis: dict[Api, ExternalApiSpec] | None = None,
) -> dict[Api, list[tuple[Route, WebMethod]]]:
    """Collect, per API, a Route plus webmethod metadata for every ``@webmethod`` function.

    Protocol functions without a ``__webmethod__`` attribute are ignored. Paths are
    prefixed with the API version; any declared method other than GET or DELETE is
    registered as POST.
    """
    routes_by_api = {}

    protocols = api_protocol_map(external_apis)
    toolgroup_protocols = toolgroup_protocol_map()
    for api, protocol in protocols.items():
        candidates = inspect.getmembers(protocol, predicate=inspect.isfunction)

        # HACK ALERT: special tool groups expose their endpoints through tool_runtime,
        # under names of the form "<tool_group>.<function>".
        if api == Api.tool_runtime:
            for tool_group in SpecialToolGroup:
                sub_protocol = toolgroup_protocols[tool_group]
                candidates.extend(
                    (f"{tool_group.value}.{fn_name}", fn)
                    for fn_name, fn in inspect.getmembers(sub_protocol, predicate=inspect.isfunction)
                    if hasattr(fn, "__webmethod__")
                )

        api_routes = []
        for name, method in candidates:
            if not hasattr(method, "__webmethod__"):
                continue

            # __webmethod__ is attached dynamically by the @webmethod decorator, which
            # mypy cannot see, hence the attr-defined ignore.
            webmethod = method.__webmethod__  # type: ignore[attr-defined]
            path = f"/{LLAMA_STACK_API_VERSION}/{webmethod.route.lstrip('/')}"
            http_method = next(
                (verb for verb in (hdrs.METH_GET, hdrs.METH_DELETE) if webmethod.method == verb),
                hdrs.METH_POST,
            )
            # endpoint is None because no Router object is used
            route = Route(path=path, methods=[http_method], name=name, endpoint=None)
            api_routes.append((route, webmethod))

        routes_by_api[api] = api_routes

    return routes_by_api
|
||||
|
||||
|
||||
def initialize_route_impls(impls, external_apis: dict[Api, ExternalApiSpec] | None = None) -> RouteImpls:
    """Build the lookup table used to match incoming requests to endpoint implementations.

    Args:
        impls: Mapping from Api to its implementation object; APIs without an entry are skipped.
        external_apis: Optional external API specs forwarded to ``get_all_api_routes``.

    Returns:
        ``{http_method_lowercase: {anchored_path_regex: (bound_function, route_path, webmethod)}}``.
    """
    api_to_routes = get_all_api_routes(external_apis)
    route_impls: RouteImpls = {}

    def _convert_path_to_regex(path: str) -> str:
        # Convert {param} to a named group matching one path segment, and
        # {param:path} to a named group that may span forward slashes.
        def _param_to_group(m: re.Match) -> str:
            # Decide on the captured ":path" suffix. The full match always ends with
            # "}", so testing it with endswith(":path") can never succeed.
            segment = ".+" if m.group(2) else "[^/]+"
            return f"(?P<{m.group(1)}>{segment})"

        pattern = re.sub(r"{(\w+)(:path)?}", _param_to_group, path)

        return f"^{pattern}$"

    for api, api_routes in api_to_routes.items():
        if api not in impls:
            continue
        for route, webmethod in api_routes:
            impl = impls[api]
            func = getattr(impl, route.name)
            # Get the first (and typically only) method from the set, filtering out HEAD
            available_methods = [m for m in route.methods if m != "HEAD"]
            if not available_methods:
                continue  # Skip if only HEAD method is available
            method = available_methods[0].lower()
            if method not in route_impls:
                route_impls[method] = {}
            route_impls[method][_convert_path_to_regex(route.path)] = (
                func,
                route.path,
                webmethod,
            )

    return route_impls
|
||||
|
||||
|
||||
def find_matching_route(method: str, path: str, route_impls: RouteImpls) -> RouteMatch:
    """Resolve a request (method + path) to its registered endpoint.

    Args:
        method: HTTP method (GET, POST, etc.); compared case-insensitively.
        path: URL path to match against
        route_impls: A dictionary of endpoint implementations, keyed by
            lower-cased method and then by path regex.

    Returns:
        A tuple of (endpoint_function, path_params, route_path, webmethod_metadata)

    Raises:
        ValueError: If no matching endpoint is found
    """
    # A missing (or empty) method table simply leaves nothing to iterate over.
    candidates = route_impls.get(method.lower()) or {}

    for pattern, (handler, template, metadata) in candidates.items():
        hit = re.match(pattern, path)
        if hit is None:
            continue
        # Named groups in the pattern carry the path parameter values.
        return handler, hit.groupdict(), template, metadata

    raise ValueError(f"No endpoint found for {path}")
|
625
llama_stack/core/server/server.py
Normal file
625
llama_stack/core/server/server.py
Normal file
|
@ -0,0 +1,625 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import argparse
|
||||
import asyncio
|
||||
import functools
|
||||
import inspect
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import ssl
|
||||
import sys
|
||||
import traceback
|
||||
import warnings
|
||||
from collections.abc import Callable
|
||||
from contextlib import asynccontextmanager
|
||||
from importlib.metadata import version as parse_version
|
||||
from pathlib import Path
|
||||
from typing import Annotated, Any, get_origin
|
||||
|
||||
import rich.pretty
|
||||
import yaml
|
||||
from aiohttp import hdrs
|
||||
from fastapi import Body, FastAPI, HTTPException, Request
|
||||
from fastapi import Path as FastapiPath
|
||||
from fastapi.exceptions import RequestValidationError
|
||||
from fastapi.responses import JSONResponse, StreamingResponse
|
||||
from openai import BadRequestError
|
||||
from pydantic import BaseModel, ValidationError
|
||||
|
||||
from llama_stack.apis.common.responses import PaginatedResponse
|
||||
from llama_stack.cli.utils import add_config_template_args, get_config_from_args
|
||||
from llama_stack.core.access_control.access_control import AccessDeniedError
|
||||
from llama_stack.core.datatypes import (
|
||||
AuthenticationRequiredError,
|
||||
LoggingConfig,
|
||||
StackRunConfig,
|
||||
)
|
||||
from llama_stack.core.distribution import builtin_automatically_routed_apis
|
||||
from llama_stack.core.external import ExternalApiSpec, load_external_apis
|
||||
from llama_stack.core.request_headers import (
|
||||
PROVIDER_DATA_VAR,
|
||||
request_provider_data_context,
|
||||
user_from_scope,
|
||||
)
|
||||
from llama_stack.core.resolver import InvalidProviderError
|
||||
from llama_stack.core.server.routes import (
|
||||
find_matching_route,
|
||||
get_all_api_routes,
|
||||
initialize_route_impls,
|
||||
)
|
||||
from llama_stack.core.stack import (
|
||||
cast_image_name_to_string,
|
||||
construct_stack,
|
||||
replace_env_vars,
|
||||
shutdown_stack,
|
||||
validate_env_pair,
|
||||
)
|
||||
from llama_stack.core.utils.config import redact_sensitive_fields
|
||||
from llama_stack.core.utils.config_resolution import Mode, resolve_config_or_template
|
||||
from llama_stack.core.utils.context import preserve_contexts_async_generator
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.datatypes import Api
|
||||
from llama_stack.providers.inline.telemetry.meta_reference.config import TelemetryConfig
|
||||
from llama_stack.providers.inline.telemetry.meta_reference.telemetry import (
|
||||
TelemetryAdapter,
|
||||
)
|
||||
from llama_stack.providers.utils.telemetry.tracing import (
|
||||
CURRENT_TRACE_CONTEXT,
|
||||
end_trace,
|
||||
setup_logger,
|
||||
start_trace,
|
||||
)
|
||||
|
||||
from .auth import AuthenticationMiddleware
|
||||
from .quota import QuotaMiddleware
|
||||
|
||||
# Directory four levels above this file.
REPO_ROOT = Path(__file__).parent.parent.parent.parent

# Module-level logger registered under the "server" category.
logger = get_logger(name=__name__, category="server")
|
||||
|
||||
|
||||
def warn_with_traceback(message, category, filename, lineno, file=None, line=None):
    """Replacement for ``warnings.showwarning`` that also prints the call stack.

    The stack and the formatted warning are written to ``file`` when it
    exposes a ``write`` attribute, and to ``sys.stderr`` otherwise.
    """
    sink = sys.stderr
    if hasattr(file, "write"):
        sink = file
    traceback.print_stack(file=sink)
    formatted = warnings.formatwarning(message, category, filename, lineno, line)
    sink.write(formatted)


# Opt-in debugging aid: install the hook only when the env var is set.
if os.environ.get("LLAMA_STACK_TRACE_WARNINGS"):
    warnings.showwarning = warn_with_traceback
|
||||
|
||||
|
||||
def create_sse_event(data: Any) -> str:
    """Serialize ``data`` as a single server-sent-event ``data:`` frame.

    Pydantic models are dumped with ``model_dump_json``; anything else goes
    through ``json.dumps``.
    """
    payload = data.model_dump_json() if isinstance(data, BaseModel) else json.dumps(data)
    return f"data: {payload}\n\n"
|
||||
|
||||
|
||||
async def global_exception_handler(request: Request, exc: Exception):
    """Print the exception and answer with its translated HTTP error as JSON."""
    traceback.print_exception(exc)
    translated = translate_exception(exc)
    body = {"error": {"detail": translated.detail}}
    return JSONResponse(status_code=translated.status_code, content=body)
|
||||
|
||||
|
||||
def translate_exception(exc: Exception) -> HTTPException | RequestValidationError:
    """Map an arbitrary exception onto the HTTP error returned to the client.

    Validation errors become a 400 carrying the per-field error list. A fixed
    set of known exception types maps to specific status codes, and anything
    else becomes a generic 500 that does not expose the original message.
    """
    if isinstance(exc, ValidationError):
        exc = RequestValidationError(exc.errors())

    if isinstance(exc, RequestValidationError):
        errors = []
        for error in exc.errors():
            errors.append(
                {
                    "loc": list(error["loc"]),
                    "msg": error["msg"],
                    "type": error["type"],
                }
            )
        return HTTPException(status_code=400, detail={"errors": errors})

    # Checked top to bottom; the first matching entry wins.
    known_errors = (
        (ValueError, 400, "Invalid value: "),
        (BadRequestError, 400, ""),
        ((PermissionError, AccessDeniedError), 403, "Permission denied: "),
        ((asyncio.TimeoutError, TimeoutError), 504, "Operation timed out: "),
        (NotImplementedError, 501, "Not implemented: "),
        (AuthenticationRequiredError, 401, "Authentication required: "),
    )
    for exc_types, status_code, prefix in known_errors:
        if isinstance(exc, exc_types):
            return HTTPException(status_code=status_code, detail=f"{prefix}{str(exc)}")

    return HTTPException(
        status_code=500,
        detail="Internal server error: An unexpected error occurred.",
    )
|
||||
|
||||
|
||||
async def shutdown(app):
    """Shut down every implementation registered on the application.

    Called from the lifespan context manager; the implementations are read
    from ``app.__llama_stack_impls__`` and handed to ``shutdown_stack``.
    """
    registered_impls = app.__llama_stack_impls__
    await shutdown_stack(registered_impls)
|
||||
|
||||
|
||||
@asynccontextmanager
async def lifespan(app: FastAPI):
    """Application lifespan: log startup, then shut the stack down on exit.

    The cleanup runs in a ``finally`` block so the implementations are shut
    down even when an exception (including cancellation) propagates through
    the context manager instead of a normal exit.
    """
    logger.info("Starting up")
    try:
        yield
    finally:
        logger.info("Shutting down")
        await shutdown(app)
|
||||
|
||||
|
||||
def is_streaming_request(func_name: str, request: Request, **kwargs):
    """Return the value of the ``stream`` keyword argument, or False if absent."""
    # TODO: pass the api method and punt it to the Protocol definition directly
    stream_flag = kwargs["stream"] if "stream" in kwargs else False
    return stream_flag
|
||||
|
||||
|
||||
async def maybe_await(value):
    """Await ``value`` when it is a coroutine object; otherwise return it as-is."""
    return (await value) if inspect.iscoroutine(value) else value
|
||||
|
||||
|
||||
async def sse_generator(event_gen_coroutine):
    """Await the given coroutine and re-emit each item it produces as an SSE frame.

    On cancellation the underlying generator (if it was obtained) is closed.
    Any other exception is logged and reported to the client as a final
    ``error`` event.
    """
    stream = None
    try:
        stream = await event_gen_coroutine
        async for chunk in stream:
            frame = create_sse_event(chunk)
            yield frame
            await asyncio.sleep(0.01)
    except asyncio.CancelledError:
        logger.info("Generator cancelled")
        if stream:
            await stream.aclose()
    except Exception as e:
        logger.exception("Error in sse_generator")
        error_payload = {
            "error": {
                "message": str(translate_exception(e)),
            },
        }
        yield create_sse_event(error_payload)
|
||||
|
||||
|
||||
async def log_request_pre_validation(request: Request):
    """Debug-log the raw body of POST/PUT/PATCH requests.

    JSON bodies are pretty-printed; anything that cannot be decoded or parsed
    is logged via ``repr``. Failures while reading or logging are reported as
    a warning and never propagate.
    """
    if request.method not in ("POST", "PUT", "PATCH"):
        return

    try:
        raw_body = await request.body()
        if not raw_body:
            logger.debug(f"Incoming {request.method} {request.url.path} request with empty body.")
            return

        try:
            rendered = rich.pretty.pretty_repr(json.loads(raw_body.decode()))
        except (json.JSONDecodeError, UnicodeDecodeError):
            rendered = repr(raw_body)
        logger.debug(f"Incoming raw request body for {request.method} {request.url.path}:\n{rendered}")
    except Exception as e:
        logger.warning(f"Could not read or log request body for {request.method} {request.url.path}: {e}")
|
||||
|
||||
|
||||
def create_dynamic_typed_route(func: Any, method: str, route: str) -> Callable:
    """Wrap an implementation method in an async handler with a derived signature.

    The returned handler takes a leading ``request`` parameter followed by the
    parameters of ``func``. For ``post`` routes, parameters named in the route
    path are annotated with ``Path(...)``, parameters that are already
    ``Annotated`` are left untouched, and the rest are annotated with
    ``Body(..., embed=True)``.

    Args:
        func: The implementation callable to invoke.
        method: Lower-cased HTTP method of the route.
        route: Route path template; also used as the URL of paginated
            responses that do not set one.

    Returns:
        The async route handler.
    """

    @functools.wraps(func)
    async def route_handler(request: Request, **kwargs):
        # Get auth attributes from the request scope
        user = user_from_scope(request.scope)

        await log_request_pre_validation(request)

        # Use context manager with both provider data and auth attributes
        with request_provider_data_context(request.headers, user):
            wants_stream = is_streaming_request(func.__name__, request, **kwargs)

            try:
                if not wants_stream:
                    outcome = await maybe_await(func(**kwargs))
                    if isinstance(outcome, PaginatedResponse) and outcome.url is None:
                        outcome.url = route
                    return outcome

                event_stream = preserve_contexts_async_generator(
                    sse_generator(func(**kwargs)), [CURRENT_TRACE_CONTEXT, PROVIDER_DATA_VAR]
                )
                return StreamingResponse(event_stream, media_type="text/event-stream")
            except Exception as e:
                if logger.isEnabledFor(logging.DEBUG):
                    logger.exception(f"Error executing endpoint {route=} {method=}")
                else:
                    logger.error(f"Error executing endpoint {route=} {method=}: {str(e)}")
                raise translate_exception(e) from e

    sig = inspect.signature(func)

    request_param = inspect.Parameter("request", inspect.Parameter.POSITIONAL_OR_KEYWORD, annotation=Request)
    func_params = list(sig.parameters.values())

    path_params = extract_path_params(route)
    if method == "post":
        # Annotate parameters that are in the path with Path(...) and others with Body(...),
        # but preserve existing File() and Form() annotations for multipart form data
        annotated_params = []
        for param in func_params:
            if param.name in path_params:
                param = param.replace(annotation=Annotated[param.annotation, FastapiPath(..., title=param.name)])
            elif get_origin(param.annotation) is not Annotated:
                param = param.replace(annotation=Annotated[param.annotation, Body(..., embed=True)])
            # else: keep the original annotation, it is already an Annotated type
            annotated_params.append(param)
        func_params = annotated_params

    route_handler.__signature__ = sig.replace(parameters=[request_param, *func_params])

    return route_handler
|
||||
|
||||
|
||||
class TracingMiddleware:
    """ASGI middleware that wraps requests to known API routes in a trace.

    Lifespan events, FastAPI built-in paths and requests that match no
    registered route are passed straight through. For matched routes a trace
    is started (propagating incoming ``traceparent`` / ``tracestate`` headers),
    the trace id is returned in an ``x-trace-id`` response header, and the
    trace is ended once the downstream app returns or raises.
    """

    def __init__(self, app, impls, external_apis: dict[str, ExternalApiSpec]):
        self.app = app
        self.impls = impls
        self.external_apis = external_apis
        # FastAPI built-in paths that should bypass custom routing
        self.fastapi_paths = ("/docs", "/redoc", "/openapi.json", "/favicon.ico", "/static")

    async def __call__(self, scope, receive, send):
        if scope.get("type") == "lifespan":
            return await self.app(scope, receive, send)

        path = scope.get("path", "")

        # FastAPI's built-in paths are handed to its own handlers untouched
        if path.startswith(self.fastapi_paths):
            logger.debug(f"Bypassing custom routing for FastAPI built-in path: {path}")
            return await self.app(scope, receive, send)

        # The routing table is built once, on the first request that needs it
        if not hasattr(self, "route_impls"):
            self.route_impls = initialize_route_impls(self.impls, self.external_apis)

        try:
            http_method = scope.get("method", hdrs.METH_GET)
            _, _, route_path, webmethod = find_matching_route(http_method, path, self.route_impls)
        except ValueError:
            # If no matching endpoint is found, pass through to FastAPI
            logger.debug(f"No matching route found for path: {path}, falling back to FastAPI")
            return await self.app(scope, receive, send)

        trace_attributes = {"__location__": "server", "raw_path": path}

        # Extract W3C trace context headers and store as trace attributes
        request_headers = dict(scope.get("headers", []))
        for header_name in ("traceparent", "tracestate"):
            header_value = request_headers.get(header_name.encode(), b"").decode()
            if header_value:
                trace_attributes[header_name] = header_value

        trace_context = await start_trace(webmethod.descriptive_name or route_path, trace_attributes)

        async def send_with_trace_id(message):
            # Only the response-start message carries headers
            if message["type"] == "http.response.start":
                response_headers = message.get("headers", [])
                response_headers.append([b"x-trace-id", str(trace_context.trace_id).encode()])
                message["headers"] = response_headers
            await send(message)

        try:
            return await self.app(scope, receive, send_with_trace_id)
        finally:
            await end_trace()
|
||||
|
||||
|
||||
class ClientVersionMiddleware:
    """ASGI middleware that rejects clients whose version does not match the server.

    Only the first two dot-separated components of the
    ``x-llamastack-client-version`` header and of the installed
    ``llama-stack`` version are compared. A mismatch is answered with HTTP 426
    and a JSON error body. Requests without the header, or with a header that
    cannot be decoded or parsed, are let through.
    """

    def __init__(self, app):
        self.app = app
        self.server_version = parse_version("llama-stack")

    async def _send_version_error(self, send, client_version):
        """Send the 426 response describing the version mismatch."""
        await send(
            {
                "type": "http.response.start",
                "status": 426,
                "headers": [[b"content-type", b"application/json"]],
            }
        )
        error_msg = json.dumps(
            {
                "error": {
                    "message": f"Client version {client_version} is not compatible with server version {self.server_version}. Please update your client."
                }
            }
        ).encode()
        await send({"type": "http.response.body", "body": error_msg})

    async def __call__(self, scope, receive, send):
        if scope["type"] == "http":
            headers = dict(scope.get("headers", []))
            raw_client_version = headers.get(b"x-llamastack-client-version", b"")
            if raw_client_version:
                incompatible = False
                client_version = ""
                try:
                    # Decoding happens inside the guarded region: invalid bytes
                    # raise UnicodeDecodeError (a ValueError) and must not
                    # crash the request.
                    client_version = raw_client_version.decode()
                    client_version_parts = tuple(map(int, client_version.split(".")[:2]))
                    server_version_parts = tuple(map(int, self.server_version.split(".")[:2]))
                    incompatible = client_version_parts != server_version_parts
                except (ValueError, IndexError):
                    # If version parsing fails, let the request through
                    pass

                # The rejection is sent outside the try so that an error raised
                # by ``send`` is not mistaken for a parsing failure, which
                # would then also invoke the wrapped app.
                if incompatible:
                    return await self._send_version_error(send, client_version)

        return await self.app(scope, receive, send)
|
||||
|
||||
|
||||
def main(args: argparse.Namespace | None = None):
    """Start the LlamaStack server.

    Resolves and loads the run configuration, builds the FastAPI application
    (middlewares, routes, exception handlers), constructs the stack and then
    serves the app with uvicorn on the same event loop.

    Args:
        args: Pre-parsed arguments. When ``None`` they are parsed from the
            command line.

    NOTE(review): the indentation of this function was reconstructed from a
    rendering that had lost it — in particular the extent of the
    ``with open(...)`` body should be confirmed against the original file.
    """
    parser = argparse.ArgumentParser(description="Start the LlamaStack server.")

    add_config_template_args(parser)
    parser.add_argument(
        "--port",
        type=int,
        default=int(os.getenv("LLAMA_STACK_PORT", 8321)),
        help="Port to listen on",
    )
    parser.add_argument(
        "--env",
        action="append",
        help="Environment variables in KEY=value format. Can be specified multiple times.",
    )

    # Determine whether the server args are being passed by the "run" command, if this is the case
    # the args will be passed as a Namespace object to the main function, otherwise they will be
    # parsed from the command line
    if args is None:
        args = parser.parse_args()

    config_or_template = get_config_from_args(args)
    config_file = resolve_config_or_template(config_or_template, Mode.RUN)

    logger_config = None
    with open(config_file) as fp:
        config_contents = yaml.safe_load(fp)
        # An optional "logging_config" section configures the logger used below
        if isinstance(config_contents, dict) and (cfg := config_contents.get("logging_config")):
            logger_config = LoggingConfig(**cfg)
        # This assignment makes ``logger`` a local name for the rest of this function
        logger = get_logger(name=__name__, category="server", config=logger_config)
        # --env pairs are exported before env-var substitution is applied to the config
        if args.env:
            for env_pair in args.env:
                try:
                    key, value = validate_env_pair(env_pair)
                    logger.info(f"Setting CLI environment variable {key} => {value}")
                    os.environ[key] = value
                except ValueError as e:
                    logger.error(f"Error: {str(e)}")
                    sys.exit(1)
        config = replace_env_vars(config_contents)
        config = StackRunConfig(**cast_image_name_to_string(config))

    _log_run_config(run_config=config)

    app = FastAPI(
        lifespan=lifespan,
        docs_url="/docs",
        redoc_url="/redoc",
        openapi_url="/openapi.json",
    )

    # The client version check can be turned off through the environment
    if not os.environ.get("LLAMA_STACK_DISABLE_VERSION_CHECK"):
        app.add_middleware(ClientVersionMiddleware)

    try:
        # Create and set the event loop that will be used for both construction and server runtime
        loop = asyncio.new_event_loop()
        asyncio.set_event_loop(loop)

        # Construct the stack in the persistent event loop
        impls = loop.run_until_complete(construct_stack(config))

    except InvalidProviderError as e:
        logger.error(f"Error: {str(e)}")
        sys.exit(1)

    if config.server.auth:
        logger.info(f"Enabling authentication with provider: {config.server.auth.provider_config.type.value}")
        app.add_middleware(AuthenticationMiddleware, auth_config=config.server.auth, impls=impls)
    else:
        # Quota configured without auth: warn that the authenticated limit is not used
        if config.server.quota:
            quota = config.server.quota
            logger.warning(
                "Configured authenticated_max_requests (%d) but no auth is enabled; "
                "falling back to anonymous_max_requests (%d) for all the requests",
                quota.authenticated_max_requests,
                quota.anonymous_max_requests,
            )

    if config.server.quota:
        logger.info("Enabling quota middleware for authenticated and anonymous clients")

        quota = config.server.quota
        anonymous_max_requests = quota.anonymous_max_requests
        # if auth is disabled, use the anonymous max requests
        authenticated_max_requests = quota.authenticated_max_requests if config.server.auth else anonymous_max_requests

        kv_config = quota.kvstore
        # Only the "day" period has a mapping; any other period value raises KeyError here
        window_map = {"day": 86400}
        window_seconds = window_map[quota.period.value]

        app.add_middleware(
            QuotaMiddleware,
            kv_config=kv_config,
            anonymous_max_requests=anonymous_max_requests,
            authenticated_max_requests=authenticated_max_requests,
            window_seconds=window_seconds,
        )

    # Use the configured telemetry implementation, or a default adapter when none is present
    if Api.telemetry in impls:
        setup_logger(impls[Api.telemetry])
    else:
        setup_logger(TelemetryAdapter(TelemetryConfig(), {}))

    # Load external APIs if configured
    external_apis = load_external_apis(config)
    all_routes = get_all_api_routes(external_apis)

    # Serve the APIs listed in the config, or every API with an implementation
    if config.apis:
        apis_to_serve = set(config.apis)
    else:
        apis_to_serve = set(impls.keys())

    for inf in builtin_automatically_routed_apis():
        # if we do not serve the corresponding router API, we should not serve the routing table API
        if inf.router_api.value not in apis_to_serve:
            continue
        apis_to_serve.add(inf.routing_table_api.value)

    # These two APIs are always served
    apis_to_serve.add("inspect")
    apis_to_serve.add("providers")
    for api_str in apis_to_serve:
        api = Api(api_str)

        routes = all_routes[api]
        try:
            impl = impls[api]
        except KeyError as e:
            raise ValueError(f"Could not find provider implementation for {api} API") from e

        for route, _ in routes:
            if not hasattr(impl, route.name):
                # ideally this should be a typing violation already
                raise ValueError(f"Could not find method {route.name} on {impl}!")

            impl_method = getattr(impl, route.name)
            # Filter out HEAD method since it's automatically handled by FastAPI for GET routes
            available_methods = [m for m in route.methods if m != "HEAD"]
            if not available_methods:
                raise ValueError(f"No methods found for {route.name} on {impl}")
            method = available_methods[0]
            logger.debug(f"{method} {route.path}")

            with warnings.catch_warnings():
                # Silence UserWarnings from pydantic's field handling while registering the route
                warnings.filterwarnings("ignore", category=UserWarning, module="pydantic._internal._fields")
                # Equivalent to decorating the generated handler with app.<method>(path, response_model=None)
                getattr(app, method.lower())(route.path, response_model=None)(
                    create_dynamic_typed_route(
                        impl_method,
                        method.lower(),
                        route.path,
                    )
                )

    logger.debug(f"serving APIs: {apis_to_serve}")

    app.exception_handler(RequestValidationError)(global_exception_handler)
    app.exception_handler(Exception)(global_exception_handler)

    # Stored on the app so the shutdown hook can reach the implementations
    app.__llama_stack_impls__ = impls
    app.add_middleware(TracingMiddleware, impls=impls, external_apis=external_apis)

    import uvicorn

    # Configure SSL if certificates are provided
    # NOTE(review): when args come from this parser, --port always has a default, so
    # config.server.port is only used if args.port is falsy — confirm this precedence is intended.
    port = args.port or config.server.port

    ssl_config = None
    keyfile = config.server.tls_keyfile
    certfile = config.server.tls_certfile

    if keyfile and certfile:
        ssl_config = {
            "ssl_keyfile": keyfile,
            "ssl_certfile": certfile,
        }
        if config.server.tls_cafile:
            # With a CA file configured, certificates are required (ssl.CERT_REQUIRED)
            ssl_config["ssl_ca_certs"] = config.server.tls_cafile
            ssl_config["ssl_cert_reqs"] = ssl.CERT_REQUIRED
            logger.info(
                f"HTTPS enabled with certificates:\n Key: {keyfile}\n Cert: {certfile}\n CA: {config.server.tls_cafile}"
            )
        else:
            logger.info(f"HTTPS enabled with certificates:\n Key: {keyfile}\n Cert: {certfile}")

    # Without a configured host, listen on both the IPv6 and IPv4 wildcard addresses
    listen_host = config.server.host or ["::", "0.0.0.0"]
    logger.info(f"Listening on {listen_host}:{port}")

    uvicorn_config = {
        "app": app,
        "host": listen_host,
        "port": port,
        "lifespan": "on",
        "log_level": logger.getEffectiveLevel(),
        "log_config": logger_config,
    }
    if ssl_config:
        uvicorn_config.update(ssl_config)

    # Run uvicorn in the existing event loop to preserve background tasks
    # We need to catch KeyboardInterrupt because uvicorn's signal handling
    # re-raises SIGINT signals using signal.raise_signal(), which Python
    # converts to KeyboardInterrupt. Without this catch, we'd get a confusing
    # stack trace when using Ctrl+C or kill -2 (SIGINT).
    # SIGTERM (kill -15) works fine without this because Python doesn't
    # have a default handler for it.
    #
    # Another approach would be to ignore SIGINT entirely - let uvicorn handle it through its own
    # signal handling but this is quite intrusive and not worth the effort.
    try:
        loop.run_until_complete(uvicorn.Server(uvicorn.Config(**uvicorn_config)).serve())
    except (KeyboardInterrupt, SystemExit):
        logger.info("Received interrupt signal, shutting down gracefully...")
    finally:
        if not loop.is_closed():
            logger.debug("Closing event loop")
            loop.close()
|
||||
|
||||
|
||||
def _log_run_config(run_config: StackRunConfig):
    """Log the run config as YAML, with sensitive fields redacted and disabled providers dropped."""
    logger.info("Run configuration:")
    dumped = run_config.model_dump(mode="json")
    printable = remove_disabled_providers(redact_sensitive_fields(dumped))
    logger.info(yaml.dump(printable, indent=2))
|
||||
|
||||
|
||||
def extract_path_params(route: str) -> list[str]:
    """Return the names of the ``{param}`` placeholders in a route template.

    Only path segments that consist entirely of a placeholder are considered,
    and a converter suffix such as ``:path`` is dropped from the name.
    """
    names: list[str] = []
    for segment in route.split("/"):
        if not (segment.startswith("{") and segment.endswith("}")):
            continue
        placeholder = segment[1:-1]
        # to handle path params like {param:path}
        names.append(placeholder.split(":")[0])
    return names
|
||||
|
||||
|
||||
def remove_disabled_providers(obj):
    """Recursively drop entries that refer to a disabled or missing provider.

    A dict is replaced by ``None`` when any of its ``provider_id``,
    ``shield_id``, ``provider_model_id`` or ``model_id`` keys is present with
    the value ``"__disabled__"``, ``""`` or ``None``. Values that end up as
    ``None`` are removed from their parent dict or list; every other value is
    returned unchanged.
    """
    if isinstance(obj, list):
        kept_items = []
        for element in obj:
            cleaned = remove_disabled_providers(element)
            if cleaned is not None:
                kept_items.append(cleaned)
        return kept_items

    if not isinstance(obj, dict):
        return obj

    disabled_markers = ("__disabled__", "", None)
    for id_key in ("provider_id", "shield_id", "provider_model_id", "model_id"):
        if id_key in obj and obj[id_key] in disabled_markers:
            return None

    kept_fields = {}
    for key, value in obj.items():
        cleaned = remove_disabled_providers(value)
        if cleaned is not None:
            kept_fields[key] = cleaned
    return kept_fields
|
||||
|
||||
|
||||
# Start the server when this module is executed directly as a script.
if __name__ == "__main__":
    main()
|
Loading…
Add table
Add a link
Reference in a new issue