feat: Add CORS configuration support for server (#3201)

Adds flexible CORS (Cross-Origin Resource Sharing) configuration support
to the FastAPI
  server with both local development and explicit configuration modes:

- **Local development mode**: `cors: true` enables localhost-only access
with regex
  pattern `https?://localhost:\d+`
- **Explicit configuration mode**: Specific origins configuration with
credential support
   and validation
   
- Prevents insecure combinations (wildcards with credentials)
  
- FastAPI CORSMiddleware integration via `model_dump()`

Addresses the need for configurable CORS policies to support web
frontends and
  cross-origin API access while maintaining security.

  Closes #2119

  ## Test Plan

  1.  Ran Unit Tests.

2. Manual tests: FastAPI middleware integration with actual HTTP
requests
    - Local development mode localhost access validation
    - Explicit configuration mode origins validation
    - Preflight OPTIONS request handling

Some screenshots of manual tests.
<img width="1920" height="927" alt="image"
src="https://github.com/user-attachments/assets/79322338-40c7-45c9-a9ea-e3e8d8e2f849"
/>

<img width="1911" height="1037" alt="image"
src="https://github.com/user-attachments/assets/1683524e-b0c9-48c9-a0a5-782e949cde01"
/>

cc: @leseb @rhuss @franciscojavierarceo
This commit is contained in:
Sumanth Kamenani 2025-08-21 17:23:27 -04:00 committed by GitHub
parent 58e164b8bc
commit ac25e35124
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
4 changed files with 226 additions and 0 deletions

View file

@ -318,6 +318,41 @@ class QuotaConfig(BaseModel):
period: QuotaPeriod = Field(default=QuotaPeriod.DAY, description="Quota period to set")
class CORSConfig(BaseModel):
allow_origins: list[str] = Field(default_factory=list)
allow_origin_regex: str | None = Field(default=None)
allow_methods: list[str] = Field(default=["OPTIONS"])
allow_headers: list[str] = Field(default_factory=list)
allow_credentials: bool = Field(default=False)
expose_headers: list[str] = Field(default_factory=list)
max_age: int = Field(default=600, ge=0)
@model_validator(mode="after")
def validate_credentials_config(self) -> Self:
if self.allow_credentials and (self.allow_origins == ["*"] or "*" in self.allow_origins):
raise ValueError("Cannot use wildcard origins with credentials enabled")
return self
def process_cors_config(cors_config: bool | CORSConfig | None) -> CORSConfig | None:
if cors_config is False or cors_config is None:
return None
if cors_config is True:
# dev mode: allow localhost on any port
return CORSConfig(
allow_origins=[],
allow_origin_regex=r"https?://localhost:\d+",
allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"],
allow_headers=["Content-Type", "Authorization", "X-Requested-With"],
)
if isinstance(cors_config, CORSConfig):
return cors_config
raise ValueError(f"Expected bool or CORSConfig, got {type(cors_config).__name__}")
class ServerConfig(BaseModel):
port: int = Field(
default=8321,
@ -349,6 +384,12 @@ class ServerConfig(BaseModel):
default=None,
description="Per client quota request configuration",
)
cors: bool | CORSConfig | None = Field(
default=None,
description="CORS configuration for cross-origin requests. Can be:\n"
"- true: Enable localhost CORS for development\n"
"- {allow_origins: [...], allow_methods: [...], ...}: Full configuration",
)
class StackRunConfig(BaseModel):

View file

@ -28,6 +28,7 @@ from aiohttp import hdrs
from fastapi import Body, FastAPI, HTTPException, Request, Response
from fastapi import Path as FastapiPath
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, StreamingResponse
from openai import BadRequestError
from pydantic import BaseModel, ValidationError
@ -40,6 +41,7 @@ from llama_stack.core.datatypes import (
AuthenticationRequiredError,
LoggingConfig,
StackRunConfig,
process_cors_config,
)
from llama_stack.core.distribution import builtin_automatically_routed_apis
from llama_stack.core.external import ExternalApiSpec, load_external_apis
@ -483,6 +485,12 @@ def main(args: argparse.Namespace | None = None):
window_seconds=window_seconds,
)
if config.server.cors:
logger.info("Enabling CORS")
cors_config = process_cors_config(config.server.cors)
if cors_config:
app.add_middleware(CORSMiddleware, **cors_config.model_dump())
if Api.telemetry in impls:
setup_logger(impls[Api.telemetry])
else: