diff --git a/docs/source/distributions/configuration.md b/docs/source/distributions/configuration.md index 335fa3a68..715f73284 100644 --- a/docs/source/distributions/configuration.md +++ b/docs/source/distributions/configuration.md @@ -225,6 +225,7 @@ server: port: 8321 # Port to listen on (default: 8321) tls_certfile: "/path/to/cert.pem" # Optional: Path to TLS certificate for HTTPS tls_keyfile: "/path/to/key.pem" # Optional: Path to TLS key for HTTPS + cors: true # Optional: Enable CORS (dev mode) or full config object ``` ### Authentication Configuration @@ -618,6 +619,54 @@ Content-Type: application/json } ``` +### CORS Configuration + +Configure CORS to allow web browsers to make requests from different domains. Disabled by default. + +#### Quick Setup + +For development, use the simple boolean flag: + +```yaml +server: + cors: true # Auto-enables localhost with any port +``` + +This automatically allows `http://localhost:*` and `https://localhost:*` with secure defaults. + +#### Custom Configuration + +For specific origins and full control: + +```yaml +server: + cors: + allow_origins: ["https://myapp.com", "https://staging.myapp.com"] + allow_credentials: true + allow_methods: ["GET", "POST", "PUT", "DELETE"] + allow_headers: ["Content-Type", "Authorization"] + allow_origin_regex: "https://.*\\.example\\.com" # Optional regex pattern + expose_headers: ["X-Total-Count"] + max_age: 86400 +``` + +#### Configuration Options + +| Field | Description | Default | +| -------------------- | ---------------------------------------------- | ------- | +| `allow_origins` | List of allowed origins. Use `["*"]` for any. | `["*"]` | +| `allow_origin_regex` | Regex pattern for allowed origins (optional). | `None` | +| `allow_methods` | Allowed HTTP methods. | `["*"]` | +| `allow_headers` | Allowed headers. | `["*"]` | +| `allow_credentials` | Allow credentials (cookies, auth headers). | `false` | +| `expose_headers` | Headers exposed to browser. | `[]` | +| `max_age` | Preflight cache time (seconds). | `600` | + +**Security Notes**: +- `allow_credentials: true` requires explicit origins (no wildcards) +- `cors: true` enables localhost access only (secure for development) +- For public APIs, always specify exact allowed origins + ## Extending to handle Safety Configuring Safety can be a little involved so it is instructive to go through an example. diff --git a/llama_stack/core/datatypes.py b/llama_stack/core/datatypes.py index a1b6ad32b..5e82922c0 100644 --- a/llama_stack/core/datatypes.py +++ b/llama_stack/core/datatypes.py @@ -318,6 +318,53 @@ class QuotaConfig(BaseModel): period: QuotaPeriod = Field(default=QuotaPeriod.DAY, description="Quota period to set") +class CORSConfig(BaseModel): + allow_origins: list[str] = Field(default=["*"]) + allow_origin_regex: str | None = Field(default=None) + allow_methods: list[str] = Field(default=["*"]) + allow_headers: list[str] = Field(default=["*"]) + allow_credentials: bool = Field(default=False) + expose_headers: list[str] = Field(default_factory=list) + max_age: int = Field(default=600, ge=0) + + @model_validator(mode="after") + def _validate_credentials_with_wildcard(self) -> Self: + if self.allow_credentials and (self.allow_origins == ["*"] or "*" in self.allow_origins): + raise ValueError("CORS: allow_credentials=True requires explicit origins") + return self + + +# Union type for flexible CORS configuration +CORSConfiguration = bool | CORSConfig + + +def process_cors_config(cors_config: CORSConfiguration | None) -> CORSConfig | None: + """Process CORS config: bool -> dev defaults, object -> passthrough.""" + if cors_config is None: + return None + + if cors_config is False: + return None + + if cors_config is True: + # Dev mode: localhost with any port + return CORSConfig( + allow_origins=[], + allow_origin_regex=r"https?://localhost:\d+", + allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"], + allow_headers=["Content-Type", "Authorization", "X-Requested-With"], + allow_credentials=False, + expose_headers=[], + max_age=600, + ) + + elif isinstance(cors_config, CORSConfig): + return cors_config + + else: + raise ValueError(f"Invalid CORS configuration type: {type(cors_config)}") + + class ServerConfig(BaseModel): port: int = Field( default=8321, @@ -349,6 +396,12 @@ class ServerConfig(BaseModel): default=None, description="Per client quota request configuration", ) + cors: CORSConfiguration | None = Field( + default=None, + description="CORS configuration for cross-origin requests. Can be:\n" + "- true: Enable localhost CORS for development\n" + "- {allow_origins: [...], allow_methods: [...], ...}: Full configuration", + ) class StackRunConfig(BaseModel): diff --git a/llama_stack/core/server/server.py b/llama_stack/core/server/server.py index 3d94b6e81..350ce0052 100644 --- a/llama_stack/core/server/server.py +++ b/llama_stack/core/server/server.py @@ -28,6 +28,7 @@ from aiohttp import hdrs from fastapi import Body, FastAPI, HTTPException, Request, Response from fastapi import Path as FastapiPath from fastapi.exceptions import RequestValidationError +from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import JSONResponse, StreamingResponse from openai import BadRequestError from pydantic import BaseModel, ValidationError @@ -40,6 +41,7 @@ from llama_stack.core.datatypes import ( AuthenticationRequiredError, LoggingConfig, StackRunConfig, + process_cors_config, ) from llama_stack.core.distribution import builtin_automatically_routed_apis from llama_stack.core.external import ExternalApiSpec, load_external_apis @@ -483,6 +485,12 @@ def main(args: argparse.Namespace | None = None): window_seconds=window_seconds, ) + if config.server.cors: + logger.info("Enabling CORS") + cors_config = process_cors_config(config.server.cors) + if cors_config: + app.add_middleware(CORSMiddleware, **cors_config.model_dump()) + if Api.telemetry in impls: setup_logger(impls[Api.telemetry]) else: diff --git a/tests/unit/server/test_cors.py b/tests/unit/server/test_cors.py new file mode 100644 index 000000000..b44b1c10a --- /dev/null +++ b/tests/unit/server/test_cors.py @@ -0,0 +1,162 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import pytest + +from llama_stack.core.datatypes import CORSConfig, process_cors_config + + +class TestCORSConfig: + """Test basic CORS configuration.""" + + def test_defaults(self): + config = CORSConfig() + + assert config.allow_origins == ["*"] + assert config.allow_origin_regex is None + assert config.allow_methods == ["*"] + assert config.allow_headers == ["*"] + assert config.allow_credentials is False + assert config.expose_headers == [] + assert config.max_age == 600 + + def test_custom_values(self): + config = CORSConfig(allow_origins=["https://example.com"], allow_credentials=True, max_age=3600) + + assert config.allow_origins == ["https://example.com"] + assert config.allow_credentials is True + assert config.max_age == 3600 + + def test_regex_field(self): + config = CORSConfig(allow_origins=[], allow_origin_regex=r"https?://localhost:\d+") + + assert config.allow_origins == [] + assert config.allow_origin_regex == r"https?://localhost:\d+" + + def test_credentials_with_wildcard_error(self): + """Should raise error when using credentials with wildcard origins.""" + with pytest.raises(ValueError, match="CORS: allow_credentials=True requires explicit origins"): + CORSConfig(allow_origins=["*"], allow_credentials=True) + + +class TestProcessCORSConfig: + """Test the process_cors_config function.""" + + def test_none_returns_none(self): + result = process_cors_config(None) + assert result is None + + def test_false_returns_none(self): + result = process_cors_config(False) + assert result is None + + def test_true_returns_dev_config(self): + """Test dev mode: cors: true""" + result = process_cors_config(True) + + assert isinstance(result, CORSConfig) + assert result.allow_origins == [] + assert result.allow_origin_regex == r"https?://localhost:\d+" + assert result.allow_credentials is False + assert "GET" in result.allow_methods + assert "POST" in result.allow_methods + + def test_cors_object_returned_as_is(self): + original = CORSConfig(allow_origins=["https://example.com"]) + result = process_cors_config(original) + + assert result is original + + def test_invalid_type_raises_error(self): + with pytest.raises(ValueError, match="Invalid CORS configuration type"): + process_cors_config("invalid") + + +class TestCORSIntegration: + """Test CORS with FastAPI integration.""" + + def test_dev_mode_with_fastapi(self): + """Test that dev mode config works with FastAPI middleware.""" + from fastapi import FastAPI + from fastapi.middleware.cors import CORSMiddleware + from fastapi.testclient import TestClient + + app = FastAPI() + + # Use our dev mode config + cors_config = process_cors_config(True) + app.add_middleware(CORSMiddleware, **cors_config.model_dump()) + + @app.get("/test") + def test_endpoint(): + return {"message": "hello"} + + client = TestClient(app) + + # Test localhost origins work + response = client.get("/test", headers={"Origin": "http://localhost:3000"}) + assert response.status_code == 200 + assert response.headers.get("Access-Control-Allow-Origin") == "http://localhost:3000" + + # Test non-localhost doesn't get CORS headers + response = client.get("/test", headers={"Origin": "https://evil.com"}) + assert response.status_code == 200 + assert "Access-Control-Allow-Origin" not in response.headers + + def test_production_mode_with_fastapi(self): + """Test explicit origins configuration.""" + from fastapi import FastAPI + from fastapi.middleware.cors import CORSMiddleware + from fastapi.testclient import TestClient + + app = FastAPI() + + # Production config + cors_config = CORSConfig(allow_origins=["https://myapp.com"], allow_credentials=True) + app.add_middleware(CORSMiddleware, **cors_config.model_dump()) + + @app.get("/test") + def test_endpoint(): + return {"message": "hello"} + + client = TestClient(app) + + # Test allowed origin works + response = client.get("/test", headers={"Origin": "https://myapp.com"}) + assert response.status_code == 200 + assert response.headers.get("Access-Control-Allow-Origin") == "https://myapp.com" + assert response.headers.get("Access-Control-Allow-Credentials") == "true" + + # Test disallowed origin + response = client.get("/test", headers={"Origin": "https://evil.com"}) + assert response.status_code == 200 + assert "Access-Control-Allow-Origin" not in response.headers + + def test_preflight_request(self): + """Test CORS preflight OPTIONS request.""" + from fastapi import FastAPI + from fastapi.middleware.cors import CORSMiddleware + from fastapi.testclient import TestClient + + app = FastAPI() + + cors_config = process_cors_config(True) + app.add_middleware(CORSMiddleware, **cors_config.model_dump()) + + @app.get("/test") + def test_endpoint(): + return {"message": "hello"} + + client = TestClient(app) + + # Preflight request + response = client.options( + "/test", headers={"Origin": "http://localhost:3000", "Access-Control-Request-Method": "GET"} + ) + + assert response.status_code == 200 + assert response.headers.get("Access-Control-Allow-Origin") == "http://localhost:3000" + assert "GET" in response.headers.get("Access-Control-Allow-Methods", "")