diff --git a/docs/source/distributions/configuration.md b/docs/source/distributions/configuration.md index 715f73284..c9677b3b6 100644 --- a/docs/source/distributions/configuration.md +++ b/docs/source/distributions/configuration.md @@ -228,6 +228,29 @@ server: cors: true # Optional: Enable CORS (dev mode) or full config object ``` +### CORS Configuration + +CORS (Cross-Origin Resource Sharing) can be configured in two ways: + +**Local development** (allows localhost origins only): +```yaml +server: + cors: true +``` + +**Explicit configuration** (custom origins and settings): +```yaml +server: + cors: + allow_origins: ["https://myapp.com", "https://app.example.com"] + allow_methods: ["GET", "POST", "PUT", "DELETE"] + allow_headers: ["Content-Type", "Authorization"] + allow_credentials: true + max_age: 3600 +``` + +When `cors: true`, the server enables secure localhost-only access for local development. For production, specify exact origins to maintain security. + ### Authentication Configuration > **Breaking Change (v0.2.14)**: The authentication configuration structure has changed. The previous format with `provider_type` and `config` fields has been replaced with a unified `provider_config` field that includes the `type` field. Update your configuration files accordingly. diff --git a/llama_stack/core/datatypes.py b/llama_stack/core/datatypes.py index 5e82922c0..bb99bc636 100644 --- a/llama_stack/core/datatypes.py +++ b/llama_stack/core/datatypes.py @@ -318,11 +318,13 @@ class QuotaConfig(BaseModel): period: QuotaPeriod = Field(default=QuotaPeriod.DAY, description="Quota period to set") -class CORSConfig(BaseModel): - allow_origins: list[str] = Field(default=["*"]) +class CORSSpec(BaseModel): + """CORS configuration with strict defaults (minimal permissions).""" + + allow_origins: list[str] = Field(default_factory=list) allow_origin_regex: str | None = Field(default=None) - allow_methods: list[str] = Field(default=["*"]) - allow_headers: list[str] = Field(default=["*"]) + allow_methods: list[str] = Field(default=["OPTIONS"]) + allow_headers: list[str] = Field(default_factory=list) allow_credentials: bool = Field(default=False) expose_headers: list[str] = Field(default_factory=list) max_age: int = Field(default=600, ge=0) @@ -334,21 +336,19 @@ class CORSConfig(BaseModel): return self -# Union type for flexible CORS configuration -CORSConfiguration = bool | CORSConfig +# Union type for flexible CORS configuration input +# Accepts: bool (dev shortcuts) or CORSSpec (explicit config) +CORSConfig = bool | CORSSpec -def process_cors_config(cors_config: CORSConfiguration | None) -> CORSConfig | None: - """Process CORS config: bool -> dev defaults, object -> passthrough.""" - if cors_config is None: - return None - +def process_cors_config(cors_config: CORSConfig) -> CORSSpec | None: + """Process CORS config: bool -> dev defaults, CORSSpec -> passthrough.""" if cors_config is False: return None if cors_config is True: # Dev mode: localhost with any port - return CORSConfig( + return CORSSpec( allow_origins=[], allow_origin_regex=r"https?://localhost:\d+", allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"], @@ -358,7 +358,7 @@ def process_cors_config(cors_config: CORSConfiguration | None) -> CORSConfig | N max_age=600, ) - elif isinstance(cors_config, CORSConfig): + elif isinstance(cors_config, CORSSpec): return cors_config else: @@ -396,7 +396,7 @@ class ServerConfig(BaseModel): default=None, description="Per client quota request configuration", ) - cors: CORSConfiguration | None = Field( + cors: CORSConfig | None = Field( default=None, description="CORS configuration for cross-origin requests. Can be:\n" "- true: Enable localhost CORS for development\n" diff --git a/tests/unit/server/test_cors.py b/tests/unit/server/test_cors.py index b44b1c10a..512be1bd2 100644 --- a/tests/unit/server/test_cors.py +++ b/tests/unit/server/test_cors.py @@ -6,157 +6,100 @@ import pytest -from llama_stack.core.datatypes import CORSConfig, process_cors_config +from llama_stack.core.datatypes import CORSSpec, process_cors_config -class TestCORSConfig: - """Test basic CORS configuration.""" +def test_cors_spec_defaults(): + config = CORSSpec() - def test_defaults(self): - config = CORSConfig() - - assert config.allow_origins == ["*"] - assert config.allow_origin_regex is None - assert config.allow_methods == ["*"] - assert config.allow_headers == ["*"] - assert config.allow_credentials is False - assert config.expose_headers == [] - assert config.max_age == 600 - - def test_custom_values(self): - config = CORSConfig(allow_origins=["https://example.com"], allow_credentials=True, max_age=3600) - - assert config.allow_origins == ["https://example.com"] - assert config.allow_credentials is True - assert config.max_age == 3600 - - def test_regex_field(self): - config = CORSConfig(allow_origins=[], allow_origin_regex=r"https?://localhost:\d+") - - assert config.allow_origins == [] - assert config.allow_origin_regex == r"https?://localhost:\d+" - - def test_credentials_with_wildcard_error(self): - """Should raise error when using credentials with wildcard origins.""" - with pytest.raises(ValueError, match="CORS: allow_credentials=True requires explicit origins"): - CORSConfig(allow_origins=["*"], allow_credentials=True) + assert config.allow_origins == [] + assert config.allow_origin_regex is None + assert config.allow_methods == ["OPTIONS"] + assert config.allow_headers == [] + assert config.allow_credentials is False + assert config.expose_headers == [] + assert config.max_age == 600 -class TestProcessCORSConfig: - """Test the process_cors_config function.""" +def test_cors_spec_explicit_config(): + config = CORSSpec( + allow_origins=["https://example.com"], allow_credentials=True, max_age=3600, allow_methods=["GET", "POST"] + ) - def test_none_returns_none(self): - result = process_cors_config(None) - assert result is None - - def test_false_returns_none(self): - result = process_cors_config(False) - assert result is None - - def test_true_returns_dev_config(self): - """Test dev mode: cors: true""" - result = process_cors_config(True) - - assert isinstance(result, CORSConfig) - assert result.allow_origins == [] - assert result.allow_origin_regex == r"https?://localhost:\d+" - assert result.allow_credentials is False - assert "GET" in result.allow_methods - assert "POST" in result.allow_methods - - def test_cors_object_returned_as_is(self): - original = CORSConfig(allow_origins=["https://example.com"]) - result = process_cors_config(original) - - assert result is original - - def test_invalid_type_raises_error(self): - with pytest.raises(ValueError, match="Invalid CORS configuration type"): - process_cors_config("invalid") + assert config.allow_origins == ["https://example.com"] + assert config.allow_credentials is True + assert config.max_age == 3600 + assert config.allow_methods == ["GET", "POST"] -class TestCORSIntegration: - """Test CORS with FastAPI integration.""" +def test_cors_spec_regex(): + config = CORSSpec(allow_origins=[], allow_origin_regex=r"https?://localhost:\d+") - def test_dev_mode_with_fastapi(self): - """Test that dev mode config works with FastAPI middleware.""" - from fastapi import FastAPI - from fastapi.middleware.cors import CORSMiddleware - from fastapi.testclient import TestClient + assert config.allow_origins == [] + assert config.allow_origin_regex == r"https?://localhost:\d+" - app = FastAPI() - # Use our dev mode config - cors_config = process_cors_config(True) - app.add_middleware(CORSMiddleware, **cors_config.model_dump()) +def test_cors_spec_wildcard_credentials_error(): + with pytest.raises(ValueError, match="CORS: allow_credentials=True requires explicit origins"): + CORSSpec(allow_origins=["*"], allow_credentials=True) - @app.get("/test") - def test_endpoint(): - return {"message": "hello"} + with pytest.raises(ValueError, match="CORS: allow_credentials=True requires explicit origins"): + CORSSpec(allow_origins=["https://example.com", "*"], allow_credentials=True) - client = TestClient(app) - # Test localhost origins work - response = client.get("/test", headers={"Origin": "http://localhost:3000"}) - assert response.status_code == 200 - assert response.headers.get("Access-Control-Allow-Origin") == "http://localhost:3000" +def test_process_cors_config_false(): + result = process_cors_config(False) + assert result is None - # Test non-localhost doesn't get CORS headers - response = client.get("/test", headers={"Origin": "https://evil.com"}) - assert response.status_code == 200 - assert "Access-Control-Allow-Origin" not in response.headers - def test_production_mode_with_fastapi(self): - """Test explicit origins configuration.""" - from fastapi import FastAPI - from fastapi.middleware.cors import CORSMiddleware - from fastapi.testclient import TestClient +def test_process_cors_config_true(): + result = process_cors_config(True) - app = FastAPI() + assert isinstance(result, CORSSpec) + assert result.allow_origins == [] + assert result.allow_origin_regex == r"https?://localhost:\d+" + assert result.allow_credentials is False + assert "GET" in result.allow_methods + assert "POST" in result.allow_methods + assert "OPTIONS" in result.allow_methods - # Production config - cors_config = CORSConfig(allow_origins=["https://myapp.com"], allow_credentials=True) - app.add_middleware(CORSMiddleware, **cors_config.model_dump()) - @app.get("/test") - def test_endpoint(): - return {"message": "hello"} +def test_process_cors_config_passthrough(): + original = CORSSpec(allow_origins=["https://example.com"], allow_methods=["GET"]) + result = process_cors_config(original) - client = TestClient(app) + assert result is original - # Test allowed origin works - response = client.get("/test", headers={"Origin": "https://myapp.com"}) - assert response.status_code == 200 - assert response.headers.get("Access-Control-Allow-Origin") == "https://myapp.com" - assert response.headers.get("Access-Control-Allow-Credentials") == "true" - # Test disallowed origin - response = client.get("/test", headers={"Origin": "https://evil.com"}) - assert response.status_code == 200 - assert "Access-Control-Allow-Origin" not in response.headers +def test_process_cors_config_invalid_type(): + with pytest.raises(ValueError, match="Invalid CORS configuration type"): + process_cors_config("invalid") - def test_preflight_request(self): - """Test CORS preflight OPTIONS request.""" - from fastapi import FastAPI - from fastapi.middleware.cors import CORSMiddleware - from fastapi.testclient import TestClient - app = FastAPI() +def test_cors_spec_model_dump(): + cors_spec = CORSSpec( + allow_origins=["https://example.com"], + allow_methods=["GET", "POST"], + allow_headers=["Content-Type"], + allow_credentials=True, + max_age=3600, + ) - cors_config = process_cors_config(True) - app.add_middleware(CORSMiddleware, **cors_config.model_dump()) + config_dict = cors_spec.model_dump() - @app.get("/test") - def test_endpoint(): - return {"message": "hello"} + assert config_dict["allow_origins"] == ["https://example.com"] + assert config_dict["allow_methods"] == ["GET", "POST"] + assert config_dict["allow_headers"] == ["Content-Type"] + assert config_dict["allow_credentials"] is True + assert config_dict["max_age"] == 3600 - client = TestClient(app) - - # Preflight request - response = client.options( - "/test", headers={"Origin": "http://localhost:3000", "Access-Control-Request-Method": "GET"} - ) - - assert response.status_code == 200 - assert response.headers.get("Access-Control-Allow-Origin") == "http://localhost:3000" - assert "GET" in response.headers.get("Access-Control-Allow-Methods", "") + expected_keys = { + "allow_origins", + "allow_origin_regex", + "allow_methods", + "allow_headers", + "allow_credentials", + "expose_headers", + "max_age", + } + assert set(config_dict.keys()) == expected_keys