Add CORS configuration support for FastAPI server

This commit is contained in:
skamenan7 2025-08-19 11:40:19 -04:00
parent 58e164b8bc
commit c716c8cd03
4 changed files with 272 additions and 0 deletions

View file

@ -225,6 +225,7 @@ server:
port: 8321 # Port to listen on (default: 8321) port: 8321 # Port to listen on (default: 8321)
tls_certfile: "/path/to/cert.pem" # Optional: Path to TLS certificate for HTTPS tls_certfile: "/path/to/cert.pem" # Optional: Path to TLS certificate for HTTPS
tls_keyfile: "/path/to/key.pem" # Optional: Path to TLS key for HTTPS tls_keyfile: "/path/to/key.pem" # Optional: Path to TLS key for HTTPS
cors: true # Optional: Enable CORS (dev mode) or full config object
``` ```
### Authentication Configuration ### Authentication Configuration
@ -618,6 +619,54 @@ Content-Type: application/json
} }
``` ```
### CORS Configuration
Configure CORS to allow web browsers to make requests from different domains. Disabled by default.
#### Quick Setup
For development, use the simple boolean flag:
```yaml
server:
cors: true # Auto-enables localhost with any port
```
This automatically allows `http://localhost:*` and `https://localhost:*` with secure defaults.
#### Custom Configuration
For specific origins and full control:
```yaml
server:
cors:
allow_origins: ["https://myapp.com", "https://staging.myapp.com"]
allow_credentials: true
allow_methods: ["GET", "POST", "PUT", "DELETE"]
allow_headers: ["Content-Type", "Authorization"]
allow_origin_regex: "https://.*\\.example\\.com" # Optional regex pattern
expose_headers: ["X-Total-Count"]
max_age: 86400
```
#### Configuration Options
| Field | Description | Default |
| -------------------- | ---------------------------------------------- | ------- |
| `allow_origins` | List of allowed origins. Use `["*"]` for any. | `["*"]` |
| `allow_origin_regex` | Regex pattern for allowed origins (optional). | `None` |
| `allow_methods` | Allowed HTTP methods. | `["*"]` |
| `allow_headers` | Allowed headers. | `["*"]` |
| `allow_credentials` | Allow credentials (cookies, auth headers). | `false` |
| `expose_headers` | Headers exposed to browser. | `[]` |
| `max_age` | Preflight cache time (seconds). | `600` |
**Security Notes**:
- `allow_credentials: true` requires explicit origins (no wildcards)
- `cors: true` enables localhost access only (secure for development)
- For public APIs, always specify exact allowed origins
## Extending to handle Safety ## Extending to handle Safety
Configuring Safety can be a little involved so it is instructive to go through an example. Configuring Safety can be a little involved so it is instructive to go through an example.

View file

@ -318,6 +318,53 @@ class QuotaConfig(BaseModel):
period: QuotaPeriod = Field(default=QuotaPeriod.DAY, description="Quota period to set") period: QuotaPeriod = Field(default=QuotaPeriod.DAY, description="Quota period to set")
class CORSConfig(BaseModel):
allow_origins: list[str] = Field(default=["*"])
allow_origin_regex: str | None = Field(default=None)
allow_methods: list[str] = Field(default=["*"])
allow_headers: list[str] = Field(default=["*"])
allow_credentials: bool = Field(default=False)
expose_headers: list[str] = Field(default_factory=list)
max_age: int = Field(default=600, ge=0)
@model_validator(mode="after")
def _validate_credentials_with_wildcard(self) -> Self:
if self.allow_credentials and (self.allow_origins == ["*"] or "*" in self.allow_origins):
raise ValueError("CORS: allow_credentials=True requires explicit origins")
return self
# Union type for flexible CORS configuration
CORSConfiguration = bool | CORSConfig
def process_cors_config(cors_config: CORSConfiguration | None) -> CORSConfig | None:
"""Process CORS config: bool -> dev defaults, object -> passthrough."""
if cors_config is None:
return None
if cors_config is False:
return None
if cors_config is True:
# Dev mode: localhost with any port
return CORSConfig(
allow_origins=[],
allow_origin_regex=r"https?://localhost:\d+",
allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"],
allow_headers=["Content-Type", "Authorization", "X-Requested-With"],
allow_credentials=False,
expose_headers=[],
max_age=600,
)
elif isinstance(cors_config, CORSConfig):
return cors_config
else:
raise ValueError(f"Invalid CORS configuration type: {type(cors_config)}")
class ServerConfig(BaseModel): class ServerConfig(BaseModel):
port: int = Field( port: int = Field(
default=8321, default=8321,
@ -349,6 +396,12 @@ class ServerConfig(BaseModel):
default=None, default=None,
description="Per client quota request configuration", description="Per client quota request configuration",
) )
cors: CORSConfiguration | None = Field(
default=None,
description="CORS configuration for cross-origin requests. Can be:\n"
"- true: Enable localhost CORS for development\n"
"- {allow_origins: [...], allow_methods: [...], ...}: Full configuration",
)
class StackRunConfig(BaseModel): class StackRunConfig(BaseModel):

View file

@ -28,6 +28,7 @@ from aiohttp import hdrs
from fastapi import Body, FastAPI, HTTPException, Request, Response from fastapi import Body, FastAPI, HTTPException, Request, Response
from fastapi import Path as FastapiPath from fastapi import Path as FastapiPath
from fastapi.exceptions import RequestValidationError from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, StreamingResponse from fastapi.responses import JSONResponse, StreamingResponse
from openai import BadRequestError from openai import BadRequestError
from pydantic import BaseModel, ValidationError from pydantic import BaseModel, ValidationError
@ -40,6 +41,7 @@ from llama_stack.core.datatypes import (
AuthenticationRequiredError, AuthenticationRequiredError,
LoggingConfig, LoggingConfig,
StackRunConfig, StackRunConfig,
process_cors_config,
) )
from llama_stack.core.distribution import builtin_automatically_routed_apis from llama_stack.core.distribution import builtin_automatically_routed_apis
from llama_stack.core.external import ExternalApiSpec, load_external_apis from llama_stack.core.external import ExternalApiSpec, load_external_apis
@ -483,6 +485,12 @@ def main(args: argparse.Namespace | None = None):
window_seconds=window_seconds, window_seconds=window_seconds,
) )
if config.server.cors:
logger.info("Enabling CORS")
cors_config = process_cors_config(config.server.cors)
if cors_config:
app.add_middleware(CORSMiddleware, **cors_config.model_dump())
if Api.telemetry in impls: if Api.telemetry in impls:
setup_logger(impls[Api.telemetry]) setup_logger(impls[Api.telemetry])
else: else:

View file

@ -0,0 +1,162 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import pytest
from llama_stack.core.datatypes import CORSConfig, process_cors_config
class TestCORSConfig:
"""Test basic CORS configuration."""
def test_defaults(self):
config = CORSConfig()
assert config.allow_origins == ["*"]
assert config.allow_origin_regex is None
assert config.allow_methods == ["*"]
assert config.allow_headers == ["*"]
assert config.allow_credentials is False
assert config.expose_headers == []
assert config.max_age == 600
def test_custom_values(self):
config = CORSConfig(allow_origins=["https://example.com"], allow_credentials=True, max_age=3600)
assert config.allow_origins == ["https://example.com"]
assert config.allow_credentials is True
assert config.max_age == 3600
def test_regex_field(self):
config = CORSConfig(allow_origins=[], allow_origin_regex=r"https?://localhost:\d+")
assert config.allow_origins == []
assert config.allow_origin_regex == r"https?://localhost:\d+"
def test_credentials_with_wildcard_error(self):
"""Should raise error when using credentials with wildcard origins."""
with pytest.raises(ValueError, match="CORS: allow_credentials=True requires explicit origins"):
CORSConfig(allow_origins=["*"], allow_credentials=True)
class TestProcessCORSConfig:
"""Test the process_cors_config function."""
def test_none_returns_none(self):
result = process_cors_config(None)
assert result is None
def test_false_returns_none(self):
result = process_cors_config(False)
assert result is None
def test_true_returns_dev_config(self):
"""Test dev mode: cors: true"""
result = process_cors_config(True)
assert isinstance(result, CORSConfig)
assert result.allow_origins == []
assert result.allow_origin_regex == r"https?://localhost:\d+"
assert result.allow_credentials is False
assert "GET" in result.allow_methods
assert "POST" in result.allow_methods
def test_cors_object_returned_as_is(self):
original = CORSConfig(allow_origins=["https://example.com"])
result = process_cors_config(original)
assert result is original
def test_invalid_type_raises_error(self):
with pytest.raises(ValueError, match="Invalid CORS configuration type"):
process_cors_config("invalid")
class TestCORSIntegration:
"""Test CORS with FastAPI integration."""
def test_dev_mode_with_fastapi(self):
"""Test that dev mode config works with FastAPI middleware."""
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.testclient import TestClient
app = FastAPI()
# Use our dev mode config
cors_config = process_cors_config(True)
app.add_middleware(CORSMiddleware, **cors_config.model_dump())
@app.get("/test")
def test_endpoint():
return {"message": "hello"}
client = TestClient(app)
# Test localhost origins work
response = client.get("/test", headers={"Origin": "http://localhost:3000"})
assert response.status_code == 200
assert response.headers.get("Access-Control-Allow-Origin") == "http://localhost:3000"
# Test non-localhost doesn't get CORS headers
response = client.get("/test", headers={"Origin": "https://evil.com"})
assert response.status_code == 200
assert "Access-Control-Allow-Origin" not in response.headers
def test_production_mode_with_fastapi(self):
"""Test explicit origins configuration."""
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.testclient import TestClient
app = FastAPI()
# Production config
cors_config = CORSConfig(allow_origins=["https://myapp.com"], allow_credentials=True)
app.add_middleware(CORSMiddleware, **cors_config.model_dump())
@app.get("/test")
def test_endpoint():
return {"message": "hello"}
client = TestClient(app)
# Test allowed origin works
response = client.get("/test", headers={"Origin": "https://myapp.com"})
assert response.status_code == 200
assert response.headers.get("Access-Control-Allow-Origin") == "https://myapp.com"
assert response.headers.get("Access-Control-Allow-Credentials") == "true"
# Test disallowed origin
response = client.get("/test", headers={"Origin": "https://evil.com"})
assert response.status_code == 200
assert "Access-Control-Allow-Origin" not in response.headers
def test_preflight_request(self):
"""Test CORS preflight OPTIONS request."""
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.testclient import TestClient
app = FastAPI()
cors_config = process_cors_config(True)
app.add_middleware(CORSMiddleware, **cors_config.model_dump())
@app.get("/test")
def test_endpoint():
return {"message": "hello"}
client = TestClient(app)
# Preflight request
response = client.options(
"/test", headers={"Origin": "http://localhost:3000", "Access-Control-Request-Method": "GET"}
)
assert response.status_code == 200
assert response.headers.get("Access-Control-Allow-Origin") == "http://localhost:3000"
assert "GET" in response.headers.get("Access-Control-Allow-Methods", "")