diff --git a/docs/source/distributions/configuration.md b/docs/source/distributions/configuration.md index 335fa3a68..c9677b3b6 100644 --- a/docs/source/distributions/configuration.md +++ b/docs/source/distributions/configuration.md @@ -225,8 +225,32 @@ server: port: 8321 # Port to listen on (default: 8321) tls_certfile: "/path/to/cert.pem" # Optional: Path to TLS certificate for HTTPS tls_keyfile: "/path/to/key.pem" # Optional: Path to TLS key for HTTPS + cors: true # Optional: Enable CORS (dev mode) or full config object ``` +### CORS Configuration + +CORS (Cross-Origin Resource Sharing) can be configured in two ways: + +**Local development** (allows localhost origins only): +```yaml +server: + cors: true +``` + +**Explicit configuration** (custom origins and settings): +```yaml +server: + cors: + allow_origins: ["https://myapp.com", "https://app.example.com"] + allow_methods: ["GET", "POST", "PUT", "DELETE"] + allow_headers: ["Content-Type", "Authorization"] + allow_credentials: true + max_age: 3600 +``` + +When `cors: true`, the server enables secure localhost-only access for local development. For production, specify exact origins to maintain security. + ### Authentication Configuration > **Breaking Change (v0.2.14)**: The authentication configuration structure has changed. The previous format with `provider_type` and `config` fields has been replaced with a unified `provider_config` field that includes the `type` field. Update your configuration files accordingly. @@ -618,6 +642,54 @@ Content-Type: application/json } ``` +### CORS Configuration + +Configure CORS to allow web browsers to make requests from different domains. Disabled by default. + +#### Quick Setup + +For development, use the simple boolean flag: + +```yaml +server: + cors: true # Auto-enables localhost with any port +``` + +This automatically allows `http://localhost:*` and `https://localhost:*` with secure defaults. + +#### Custom Configuration + +For specific origins and full control: + +```yaml +server: + cors: + allow_origins: ["https://myapp.com", "https://staging.myapp.com"] + allow_credentials: true + allow_methods: ["GET", "POST", "PUT", "DELETE"] + allow_headers: ["Content-Type", "Authorization"] + allow_origin_regex: "https://.*\\.example\\.com" # Optional regex pattern + expose_headers: ["X-Total-Count"] + max_age: 86400 +``` + +#### Configuration Options + +| Field | Description | Default | +| -------------------- | ---------------------------------------------- | ------- | +| `allow_origins` | List of allowed origins. Use `["*"]` for any. | `["*"]` | +| `allow_origin_regex` | Regex pattern for allowed origins (optional). | `None` | +| `allow_methods` | Allowed HTTP methods. | `["*"]` | +| `allow_headers` | Allowed headers. | `["*"]` | +| `allow_credentials` | Allow credentials (cookies, auth headers). | `false` | +| `expose_headers` | Headers exposed to browser. | `[]` | +| `max_age` | Preflight cache time (seconds). | `600` | + +**Security Notes**: +- `allow_credentials: true` requires explicit origins (no wildcards) +- `cors: true` enables localhost access only (secure for development) +- For public APIs, always specify exact allowed origins + ## Extending to handle Safety Configuring Safety can be a little involved so it is instructive to go through an example. diff --git a/llama_stack/core/datatypes.py b/llama_stack/core/datatypes.py index a1b6ad32b..c3940fcbd 100644 --- a/llama_stack/core/datatypes.py +++ b/llama_stack/core/datatypes.py @@ -318,6 +318,41 @@ class QuotaConfig(BaseModel): period: QuotaPeriod = Field(default=QuotaPeriod.DAY, description="Quota period to set") +class CORSConfig(BaseModel): + allow_origins: list[str] = Field(default_factory=list) + allow_origin_regex: str | None = Field(default=None) + allow_methods: list[str] = Field(default=["OPTIONS"]) + allow_headers: list[str] = Field(default_factory=list) + allow_credentials: bool = Field(default=False) + expose_headers: list[str] = Field(default_factory=list) + max_age: int = Field(default=600, ge=0) + + @model_validator(mode="after") + def validate_credentials_config(self) -> Self: + if self.allow_credentials and (self.allow_origins == ["*"] or "*" in self.allow_origins): + raise ValueError("Cannot use wildcard origins with credentials enabled") + return self + + +def process_cors_config(cors_config: bool | CORSConfig | None) -> CORSConfig | None: + if cors_config is False or cors_config is None: + return None + + if cors_config is True: + # dev mode: allow localhost on any port + return CORSConfig( + allow_origins=[], + allow_origin_regex=r"https?://localhost:\d+", + allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"], + allow_headers=["Content-Type", "Authorization", "X-Requested-With"], + ) + + if isinstance(cors_config, CORSConfig): + return cors_config + + raise ValueError(f"Expected bool or CORSConfig, got {type(cors_config).__name__}") + + class ServerConfig(BaseModel): port: int = Field( default=8321, @@ -349,6 +384,12 @@ class ServerConfig(BaseModel): default=None, description="Per client quota request configuration", ) + cors: bool | CORSConfig | None = Field( + default=None, + description="CORS configuration for cross-origin requests. Can be:\n" + "- true: Enable localhost CORS for development\n" + "- {allow_origins: [...], allow_methods: [...], ...}: Full configuration", + ) class StackRunConfig(BaseModel): diff --git a/llama_stack/core/server/server.py b/llama_stack/core/server/server.py index 3d94b6e81..350ce0052 100644 --- a/llama_stack/core/server/server.py +++ b/llama_stack/core/server/server.py @@ -28,6 +28,7 @@ from aiohttp import hdrs from fastapi import Body, FastAPI, HTTPException, Request, Response from fastapi import Path as FastapiPath from fastapi.exceptions import RequestValidationError +from fastapi.middleware.cors import CORSMiddleware from fastapi.responses import JSONResponse, StreamingResponse from openai import BadRequestError from pydantic import BaseModel, ValidationError @@ -40,6 +41,7 @@ from llama_stack.core.datatypes import ( AuthenticationRequiredError, LoggingConfig, StackRunConfig, + process_cors_config, ) from llama_stack.core.distribution import builtin_automatically_routed_apis from llama_stack.core.external import ExternalApiSpec, load_external_apis @@ -483,6 +485,12 @@ def main(args: argparse.Namespace | None = None): window_seconds=window_seconds, ) + if config.server.cors: + logger.info("Enabling CORS") + cors_config = process_cors_config(config.server.cors) + if cors_config: + app.add_middleware(CORSMiddleware, **cors_config.model_dump()) + if Api.telemetry in impls: setup_logger(impls[Api.telemetry]) else: diff --git a/tests/unit/server/test_cors.py b/tests/unit/server/test_cors.py new file mode 100644 index 000000000..8fd2515ba --- /dev/null +++ b/tests/unit/server/test_cors.py @@ -0,0 +1,105 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import pytest + +from llama_stack.core.datatypes import CORSConfig, process_cors_config + + +def test_cors_config_defaults(): + config = CORSConfig() + + assert config.allow_origins == [] + assert config.allow_origin_regex is None + assert config.allow_methods == ["OPTIONS"] + assert config.allow_headers == [] + assert config.allow_credentials is False + assert config.expose_headers == [] + assert config.max_age == 600 + + +def test_cors_config_explicit_config(): + config = CORSConfig( + allow_origins=["https://example.com"], allow_credentials=True, max_age=3600, allow_methods=["GET", "POST"] + ) + + assert config.allow_origins == ["https://example.com"] + assert config.allow_credentials is True + assert config.max_age == 3600 + assert config.allow_methods == ["GET", "POST"] + + +def test_cors_config_regex(): + config = CORSConfig(allow_origins=[], allow_origin_regex=r"https?://localhost:\d+") + + assert config.allow_origins == [] + assert config.allow_origin_regex == r"https?://localhost:\d+" + + +def test_cors_config_wildcard_credentials_error(): + with pytest.raises(ValueError, match="Cannot use wildcard origins with credentials enabled"): + CORSConfig(allow_origins=["*"], allow_credentials=True) + + with pytest.raises(ValueError, match="Cannot use wildcard origins with credentials enabled"): + CORSConfig(allow_origins=["https://example.com", "*"], allow_credentials=True) + + +def test_process_cors_config_false(): + result = process_cors_config(False) + assert result is None + + +def test_process_cors_config_true(): + result = process_cors_config(True) + + assert isinstance(result, CORSConfig) + assert result.allow_origins == [] + assert result.allow_origin_regex == r"https?://localhost:\d+" + assert result.allow_credentials is False + expected_methods = ["GET", "POST", "PUT", "DELETE", "OPTIONS"] + for method in expected_methods: + assert method in result.allow_methods + + +def test_process_cors_config_passthrough(): + original = CORSConfig(allow_origins=["https://example.com"], allow_methods=["GET"]) + result = process_cors_config(original) + + assert result is original + + +def test_process_cors_config_invalid_type(): + with pytest.raises(ValueError, match="Expected bool or CORSConfig, got str"): + process_cors_config("invalid") + + +def test_cors_config_model_dump(): + cors_config = CORSConfig( + allow_origins=["https://example.com"], + allow_methods=["GET", "POST"], + allow_headers=["Content-Type"], + allow_credentials=True, + max_age=3600, + ) + + config_dict = cors_config.model_dump() + + assert config_dict["allow_origins"] == ["https://example.com"] + assert config_dict["allow_methods"] == ["GET", "POST"] + assert config_dict["allow_headers"] == ["Content-Type"] + assert config_dict["allow_credentials"] is True + assert config_dict["max_age"] == 3600 + + expected_keys = { + "allow_origins", + "allow_origin_regex", + "allow_methods", + "allow_headers", + "allow_credentials", + "expose_headers", + "max_age", + } + assert set(config_dict.keys()) == expected_keys