Add CORS configuration support for FastAPI server

This commit is contained in:
skamenan7 2025-08-19 11:40:19 -04:00
parent c716c8cd03
commit 815b5c7279
3 changed files with 107 additions and 141 deletions

View file

@ -228,6 +228,29 @@ server:
cors: true # Optional: Enable CORS (dev mode) or full config object cors: true # Optional: Enable CORS (dev mode) or full config object
``` ```
### CORS Configuration
CORS (Cross-Origin Resource Sharing) can be configured in two ways:
**Local development** (allows localhost origins only):
```yaml
server:
cors: true
```
**Explicit configuration** (custom origins and settings):
```yaml
server:
cors:
allow_origins: ["https://myapp.com", "https://app.example.com"]
allow_methods: ["GET", "POST", "PUT", "DELETE"]
allow_headers: ["Content-Type", "Authorization"]
allow_credentials: true
max_age: 3600
```
When `cors: true`, the server enables secure localhost-only access for local development. For production, specify exact origins to maintain security.
### Authentication Configuration ### Authentication Configuration
> **Breaking Change (v0.2.14)**: The authentication configuration structure has changed. The previous format with `provider_type` and `config` fields has been replaced with a unified `provider_config` field that includes the `type` field. Update your configuration files accordingly. > **Breaking Change (v0.2.14)**: The authentication configuration structure has changed. The previous format with `provider_type` and `config` fields has been replaced with a unified `provider_config` field that includes the `type` field. Update your configuration files accordingly.

View file

@ -318,11 +318,13 @@ class QuotaConfig(BaseModel):
period: QuotaPeriod = Field(default=QuotaPeriod.DAY, description="Quota period to set") period: QuotaPeriod = Field(default=QuotaPeriod.DAY, description="Quota period to set")
class CORSConfig(BaseModel): class CORSSpec(BaseModel):
allow_origins: list[str] = Field(default=["*"]) """CORS configuration with strict defaults (minimal permissions)."""
allow_origins: list[str] = Field(default_factory=list)
allow_origin_regex: str | None = Field(default=None) allow_origin_regex: str | None = Field(default=None)
allow_methods: list[str] = Field(default=["*"]) allow_methods: list[str] = Field(default=["OPTIONS"])
allow_headers: list[str] = Field(default=["*"]) allow_headers: list[str] = Field(default_factory=list)
allow_credentials: bool = Field(default=False) allow_credentials: bool = Field(default=False)
expose_headers: list[str] = Field(default_factory=list) expose_headers: list[str] = Field(default_factory=list)
max_age: int = Field(default=600, ge=0) max_age: int = Field(default=600, ge=0)
@ -334,21 +336,19 @@ class CORSConfig(BaseModel):
return self return self
# Union type for flexible CORS configuration # Union type for flexible CORS configuration input
CORSConfiguration = bool | CORSConfig # Accepts: bool (dev shortcuts) or CORSSpec (explicit config)
CORSConfig = bool | CORSSpec
def process_cors_config(cors_config: CORSConfiguration | None) -> CORSConfig | None: def process_cors_config(cors_config: CORSConfig) -> CORSSpec | None:
"""Process CORS config: bool -> dev defaults, object -> passthrough.""" """Process CORS config: bool -> dev defaults, CORSSpec -> passthrough."""
if cors_config is None:
return None
if cors_config is False: if cors_config is False:
return None return None
if cors_config is True: if cors_config is True:
# Dev mode: localhost with any port # Dev mode: localhost with any port
return CORSConfig( return CORSSpec(
allow_origins=[], allow_origins=[],
allow_origin_regex=r"https?://localhost:\d+", allow_origin_regex=r"https?://localhost:\d+",
allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"], allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"],
@ -358,7 +358,7 @@ def process_cors_config(cors_config: CORSConfiguration | None) -> CORSConfig | N
max_age=600, max_age=600,
) )
elif isinstance(cors_config, CORSConfig): elif isinstance(cors_config, CORSSpec):
return cors_config return cors_config
else: else:
@ -396,7 +396,7 @@ class ServerConfig(BaseModel):
default=None, default=None,
description="Per client quota request configuration", description="Per client quota request configuration",
) )
cors: CORSConfiguration | None = Field( cors: CORSConfig | None = Field(
default=None, default=None,
description="CORS configuration for cross-origin requests. Can be:\n" description="CORS configuration for cross-origin requests. Can be:\n"
"- true: Enable localhost CORS for development\n" "- true: Enable localhost CORS for development\n"

View file

@ -6,157 +6,100 @@
import pytest import pytest
from llama_stack.core.datatypes import CORSConfig, process_cors_config from llama_stack.core.datatypes import CORSSpec, process_cors_config
class TestCORSConfig: def test_cors_spec_defaults():
"""Test basic CORS configuration.""" config = CORSSpec()
def test_defaults(self): assert config.allow_origins == []
config = CORSConfig() assert config.allow_origin_regex is None
assert config.allow_methods == ["OPTIONS"]
assert config.allow_origins == ["*"] assert config.allow_headers == []
assert config.allow_origin_regex is None assert config.allow_credentials is False
assert config.allow_methods == ["*"] assert config.expose_headers == []
assert config.allow_headers == ["*"] assert config.max_age == 600
assert config.allow_credentials is False
assert config.expose_headers == []
assert config.max_age == 600
def test_custom_values(self):
config = CORSConfig(allow_origins=["https://example.com"], allow_credentials=True, max_age=3600)
assert config.allow_origins == ["https://example.com"]
assert config.allow_credentials is True
assert config.max_age == 3600
def test_regex_field(self):
config = CORSConfig(allow_origins=[], allow_origin_regex=r"https?://localhost:\d+")
assert config.allow_origins == []
assert config.allow_origin_regex == r"https?://localhost:\d+"
def test_credentials_with_wildcard_error(self):
"""Should raise error when using credentials with wildcard origins."""
with pytest.raises(ValueError, match="CORS: allow_credentials=True requires explicit origins"):
CORSConfig(allow_origins=["*"], allow_credentials=True)
class TestProcessCORSConfig: def test_cors_spec_explicit_config():
"""Test the process_cors_config function.""" config = CORSSpec(
allow_origins=["https://example.com"], allow_credentials=True, max_age=3600, allow_methods=["GET", "POST"]
)
def test_none_returns_none(self): assert config.allow_origins == ["https://example.com"]
result = process_cors_config(None) assert config.allow_credentials is True
assert result is None assert config.max_age == 3600
assert config.allow_methods == ["GET", "POST"]
def test_false_returns_none(self):
result = process_cors_config(False)
assert result is None
def test_true_returns_dev_config(self):
"""Test dev mode: cors: true"""
result = process_cors_config(True)
assert isinstance(result, CORSConfig)
assert result.allow_origins == []
assert result.allow_origin_regex == r"https?://localhost:\d+"
assert result.allow_credentials is False
assert "GET" in result.allow_methods
assert "POST" in result.allow_methods
def test_cors_object_returned_as_is(self):
original = CORSConfig(allow_origins=["https://example.com"])
result = process_cors_config(original)
assert result is original
def test_invalid_type_raises_error(self):
with pytest.raises(ValueError, match="Invalid CORS configuration type"):
process_cors_config("invalid")
class TestCORSIntegration: def test_cors_spec_regex():
"""Test CORS with FastAPI integration.""" config = CORSSpec(allow_origins=[], allow_origin_regex=r"https?://localhost:\d+")
def test_dev_mode_with_fastapi(self): assert config.allow_origins == []
"""Test that dev mode config works with FastAPI middleware.""" assert config.allow_origin_regex == r"https?://localhost:\d+"
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.testclient import TestClient
app = FastAPI()
# Use our dev mode config def test_cors_spec_wildcard_credentials_error():
cors_config = process_cors_config(True) with pytest.raises(ValueError, match="CORS: allow_credentials=True requires explicit origins"):
app.add_middleware(CORSMiddleware, **cors_config.model_dump()) CORSSpec(allow_origins=["*"], allow_credentials=True)
@app.get("/test") with pytest.raises(ValueError, match="CORS: allow_credentials=True requires explicit origins"):
def test_endpoint(): CORSSpec(allow_origins=["https://example.com", "*"], allow_credentials=True)
return {"message": "hello"}
client = TestClient(app)
# Test localhost origins work def test_process_cors_config_false():
response = client.get("/test", headers={"Origin": "http://localhost:3000"}) result = process_cors_config(False)
assert response.status_code == 200 assert result is None
assert response.headers.get("Access-Control-Allow-Origin") == "http://localhost:3000"
# Test non-localhost doesn't get CORS headers
response = client.get("/test", headers={"Origin": "https://evil.com"})
assert response.status_code == 200
assert "Access-Control-Allow-Origin" not in response.headers
def test_production_mode_with_fastapi(self): def test_process_cors_config_true():
"""Test explicit origins configuration.""" result = process_cors_config(True)
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.testclient import TestClient
app = FastAPI() assert isinstance(result, CORSSpec)
assert result.allow_origins == []
assert result.allow_origin_regex == r"https?://localhost:\d+"
assert result.allow_credentials is False
assert "GET" in result.allow_methods
assert "POST" in result.allow_methods
assert "OPTIONS" in result.allow_methods
# Production config
cors_config = CORSConfig(allow_origins=["https://myapp.com"], allow_credentials=True)
app.add_middleware(CORSMiddleware, **cors_config.model_dump())
@app.get("/test") def test_process_cors_config_passthrough():
def test_endpoint(): original = CORSSpec(allow_origins=["https://example.com"], allow_methods=["GET"])
return {"message": "hello"} result = process_cors_config(original)
client = TestClient(app) assert result is original
# Test allowed origin works
response = client.get("/test", headers={"Origin": "https://myapp.com"})
assert response.status_code == 200
assert response.headers.get("Access-Control-Allow-Origin") == "https://myapp.com"
assert response.headers.get("Access-Control-Allow-Credentials") == "true"
# Test disallowed origin def test_process_cors_config_invalid_type():
response = client.get("/test", headers={"Origin": "https://evil.com"}) with pytest.raises(ValueError, match="Invalid CORS configuration type"):
assert response.status_code == 200 process_cors_config("invalid")
assert "Access-Control-Allow-Origin" not in response.headers
def test_preflight_request(self):
"""Test CORS preflight OPTIONS request."""
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware
from fastapi.testclient import TestClient
app = FastAPI() def test_cors_spec_model_dump():
cors_spec = CORSSpec(
allow_origins=["https://example.com"],
allow_methods=["GET", "POST"],
allow_headers=["Content-Type"],
allow_credentials=True,
max_age=3600,
)
cors_config = process_cors_config(True) config_dict = cors_spec.model_dump()
app.add_middleware(CORSMiddleware, **cors_config.model_dump())
@app.get("/test") assert config_dict["allow_origins"] == ["https://example.com"]
def test_endpoint(): assert config_dict["allow_methods"] == ["GET", "POST"]
return {"message": "hello"} assert config_dict["allow_headers"] == ["Content-Type"]
assert config_dict["allow_credentials"] is True
assert config_dict["max_age"] == 3600
client = TestClient(app) expected_keys = {
"allow_origins",
# Preflight request "allow_origin_regex",
response = client.options( "allow_methods",
"/test", headers={"Origin": "http://localhost:3000", "Access-Control-Request-Method": "GET"} "allow_headers",
) "allow_credentials",
"expose_headers",
assert response.status_code == 200 "max_age",
assert response.headers.get("Access-Control-Allow-Origin") == "http://localhost:3000" }
assert "GET" in response.headers.get("Access-Control-Allow-Methods", "") assert set(config_dict.keys()) == expected_keys