Add CORS configuration support for FastAPI server

This commit is contained in:
skamenan7 2025-08-19 11:40:19 -04:00
parent 58e164b8bc
commit c716c8cd03
4 changed files with 272 additions and 0 deletions

View file

@ -318,6 +318,53 @@ class QuotaConfig(BaseModel):
period: QuotaPeriod = Field(default=QuotaPeriod.DAY, description="Quota period to set")
class CORSConfig(BaseModel):
allow_origins: list[str] = Field(default=["*"])
allow_origin_regex: str | None = Field(default=None)
allow_methods: list[str] = Field(default=["*"])
allow_headers: list[str] = Field(default=["*"])
allow_credentials: bool = Field(default=False)
expose_headers: list[str] = Field(default_factory=list)
max_age: int = Field(default=600, ge=0)
@model_validator(mode="after")
def _validate_credentials_with_wildcard(self) -> Self:
if self.allow_credentials and (self.allow_origins == ["*"] or "*" in self.allow_origins):
raise ValueError("CORS: allow_credentials=True requires explicit origins")
return self
# Union type for flexible CORS configuration
CORSConfiguration = bool | CORSConfig
def process_cors_config(cors_config: CORSConfiguration | None) -> CORSConfig | None:
"""Process CORS config: bool -> dev defaults, object -> passthrough."""
if cors_config is None:
return None
if cors_config is False:
return None
if cors_config is True:
# Dev mode: localhost with any port
return CORSConfig(
allow_origins=[],
allow_origin_regex=r"https?://localhost:\d+",
allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"],
allow_headers=["Content-Type", "Authorization", "X-Requested-With"],
allow_credentials=False,
expose_headers=[],
max_age=600,
)
elif isinstance(cors_config, CORSConfig):
return cors_config
else:
raise ValueError(f"Invalid CORS configuration type: {type(cors_config)}")
class ServerConfig(BaseModel):
port: int = Field(
default=8321,
@ -349,6 +396,12 @@ class ServerConfig(BaseModel):
default=None,
description="Per client quota request configuration",
)
cors: CORSConfiguration | None = Field(
default=None,
description="CORS configuration for cross-origin requests. Can be:\n"
"- true: Enable localhost CORS for development\n"
"- {allow_origins: [...], allow_methods: [...], ...}: Full configuration",
)
class StackRunConfig(BaseModel):

View file

@ -28,6 +28,7 @@ from aiohttp import hdrs
from fastapi import Body, FastAPI, HTTPException, Request, Response
from fastapi import Path as FastapiPath
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, StreamingResponse
from openai import BadRequestError
from pydantic import BaseModel, ValidationError
@ -40,6 +41,7 @@ from llama_stack.core.datatypes import (
AuthenticationRequiredError,
LoggingConfig,
StackRunConfig,
process_cors_config,
)
from llama_stack.core.distribution import builtin_automatically_routed_apis
from llama_stack.core.external import ExternalApiSpec, load_external_apis
@ -483,6 +485,12 @@ def main(args: argparse.Namespace | None = None):
window_seconds=window_seconds,
)
if config.server.cors:
logger.info("Enabling CORS")
cors_config = process_cors_config(config.server.cors)
if cors_config:
app.add_middleware(CORSMiddleware, **cors_config.model_dump())
if Api.telemetry in impls:
setup_logger(impls[Api.telemetry])
else: