mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-07 02:58:21 +00:00
add authentication middleware
This commit is contained in:
parent
9c8e88ea9c
commit
eff54c1640
3 changed files with 86 additions and 0 deletions
|
@ -125,6 +125,13 @@ class LoggingConfig(BaseModel):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class AuthenticationConfig(BaseModel):
    """Settings for delegating request authentication to an external service."""

    # Required (no default): the URL that token-validation requests are sent to.
    endpoint: str = Field(..., description="Endpoint URL to validate authentication tokens")
|
||||||
|
|
||||||
|
|
||||||
class ServerConfig(BaseModel):
|
class ServerConfig(BaseModel):
|
||||||
port: int = Field(
|
port: int = Field(
|
||||||
default=8321,
|
default=8321,
|
||||||
|
@ -140,6 +147,10 @@ class ServerConfig(BaseModel):
|
||||||
default=None,
|
default=None,
|
||||||
description="Path to TLS key file for HTTPS",
|
description="Path to TLS key file for HTTPS",
|
||||||
)
|
)
|
||||||
|
auth: Optional[AuthenticationConfig] = Field(
|
||||||
|
default=None,
|
||||||
|
description="Authentication configuration for the server",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class StackRunConfig(BaseModel):
|
class StackRunConfig(BaseModel):
|
||||||
|
|
69
llama_stack/distribution/server/auth.py
Normal file
69
llama_stack/distribution/server/auth.py
Normal file
|
@ -0,0 +1,69 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
import json
|
||||||
|
from urllib.parse import parse_qs
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
from llama_stack.log import get_logger
|
||||||
|
|
||||||
|
logger = get_logger(name=__name__, category="auth")
|
||||||
|
|
||||||
|
|
||||||
|
class AuthenticationMiddleware:
    """ASGI middleware that validates bearer tokens against an external endpoint.

    For every HTTP request the middleware extracts the token from the
    ``Authorization: Bearer <token>`` header and POSTs it, together with the
    request path, headers and query parameters, to ``auth_endpoint``. The
    request is forwarded to the wrapped application only when that endpoint
    answers with HTTP 200; otherwise a 401 JSON error is returned.

    Non-HTTP scopes are passed through without any check.
    """

    def __init__(self, app, auth_endpoint: str) -> None:
        """Store the wrapped ASGI app and the URL of the validation endpoint."""
        self.app = app
        self.auth_endpoint = auth_endpoint

    async def __call__(self, scope, receive, send):
        """Authenticate an HTTP request, then delegate to the wrapped app."""
        # Only HTTP requests are checked; every other scope type is forwarded untouched.
        if scope["type"] != "http":
            return await self.app(scope, receive, send)

        # With duplicated header names the last occurrence wins.
        headers = dict(scope.get("headers", []))

        # Header bytes are not guaranteed to be UTF-8. latin-1 maps every byte
        # to a code point, so a malformed header yields a 401 instead of an
        # unhandled UnicodeDecodeError.
        auth_header = headers.get(b"authorization", b"").decode("latin-1")

        # The auth scheme is case-insensitive; an empty token is rejected here
        # rather than being sent to the validation endpoint.
        scheme, _, token = auth_header.partition(" ")
        api_key = token.strip()
        if scheme.lower() != "bearer" or not api_key:
            return await self._send_auth_error(send, "Missing or invalid Authorization header")

        path = scope.get("path", "")
        request_headers = {k.decode("latin-1"): v.decode("latin-1") for k, v in headers.items()}

        # Invalid bytes are replaced instead of raising on a malformed query string.
        query_string = scope.get("query_string", b"").decode("utf-8", errors="replace")
        params = parse_qs(query_string)

        auth_data = {
            "api_key": api_key,
            "request": {
                "path": path,
                "headers": request_headers,
                "params": params,
            },
        }

        # Validate with authentication endpoint. Only the outbound call is
        # guarded, so a failure while sending the rejection response is not
        # mistaken for an authentication service error.
        try:
            async with httpx.AsyncClient() as client:
                response = await client.post(self.auth_endpoint, json=auth_data)
        except Exception:
            # Fail closed: any problem reaching the endpoint denies the request.
            logger.exception("Error during authentication")
            return await self._send_auth_error(send, "Authentication service error")

        if response.status_code != 200:
            logger.warning(f"Authentication failed: {response.status_code}")
            return await self._send_auth_error(send, "Authentication failed")

        return await self.app(scope, receive, send)

    async def _send_auth_error(self, send, message):
        """Send a 401 response whose JSON body is ``{"error": {"message": message}}``."""
        error_msg = json.dumps({"error": {"message": message}}).encode()
        await send(
            {
                "type": "http.response.start",
                "status": 401,
                "headers": [
                    [b"content-type", b"application/json"],
                    # A 401 response is required to carry a challenge.
                    [b"www-authenticate", b"Bearer"],
                ],
            }
        )
        await send({"type": "http.response.body", "body": error_msg})
|
|
@ -52,6 +52,7 @@ from llama_stack.providers.utils.telemetry.tracing import (
|
||||||
start_trace,
|
start_trace,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
from .auth import AuthenticationMiddleware
|
||||||
from .endpoints import get_all_api_endpoints
|
from .endpoints import get_all_api_endpoints
|
||||||
|
|
||||||
REPO_ROOT = Path(__file__).parent.parent.parent.parent
|
REPO_ROOT = Path(__file__).parent.parent.parent.parent
|
||||||
|
@ -351,6 +352,11 @@ def main():
|
||||||
if not os.environ.get("LLAMA_STACK_DISABLE_VERSION_CHECK"):
|
if not os.environ.get("LLAMA_STACK_DISABLE_VERSION_CHECK"):
|
||||||
app.add_middleware(ClientVersionMiddleware)
|
app.add_middleware(ClientVersionMiddleware)
|
||||||
|
|
||||||
|
# Add authentication middleware if configured
|
||||||
|
if config.server.auth and config.server.auth.endpoint:
|
||||||
|
logger.info(f"Enabling authentication with endpoint: {config.server.auth.endpoint}")
|
||||||
|
app.add_middleware(AuthenticationMiddleware, auth_endpoint=config.server.auth.endpoint)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
impls = asyncio.run(construct_stack(config))
|
impls = asyncio.run(construct_stack(config))
|
||||||
except InvalidProviderError as e:
|
except InvalidProviderError as e:
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue