mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-10-04 04:04:14 +00:00
feat: Add CORS configuration support for server (#3201)
Adds flexible CORS (Cross-Origin Resource Sharing) configuration support to the FastAPI server with both local development and explicit configuration modes: - **Local development mode**: `cors: true` enables localhost-only access with regex pattern `https?://localhost:\d+` - **Explicit configuration mode**: Specific origins configuration with credential support and validation - Prevents insecure combinations (wildcards with credentials) - FastAPI CORSMiddleware integration via `model_dump()` Addresses the need for configurable CORS policies to support web frontends and cross-origin API access while maintaining security. Closes #2119 ## Test Plan 1. Ran Unit Tests. 2. Manual tests: FastAPI middleware integration with actual HTTP requests - Local development mode localhost access validation - Explicit configuration mode origins validation - Preflight OPTIONS request handling Some screenshots of manual tests. <img width="1920" height="927" alt="image" src="https://github.com/user-attachments/assets/79322338-40c7-45c9-a9ea-e3e8d8e2f849" /> <img width="1911" height="1037" alt="image" src="https://github.com/user-attachments/assets/1683524e-b0c9-48c9-a0a5-782e949cde01" /> cc: @leseb @rhuss @franciscojavierarceo
This commit is contained in:
parent
58e164b8bc
commit
ac25e35124
4 changed files with 226 additions and 0 deletions
|
@ -318,6 +318,41 @@ class QuotaConfig(BaseModel):
|
|||
period: QuotaPeriod = Field(default=QuotaPeriod.DAY, description="Quota period to set")
|
||||
|
||||
|
||||
class CORSConfig(BaseModel):
|
||||
allow_origins: list[str] = Field(default_factory=list)
|
||||
allow_origin_regex: str | None = Field(default=None)
|
||||
allow_methods: list[str] = Field(default=["OPTIONS"])
|
||||
allow_headers: list[str] = Field(default_factory=list)
|
||||
allow_credentials: bool = Field(default=False)
|
||||
expose_headers: list[str] = Field(default_factory=list)
|
||||
max_age: int = Field(default=600, ge=0)
|
||||
|
||||
@model_validator(mode="after")
|
||||
def validate_credentials_config(self) -> Self:
|
||||
if self.allow_credentials and (self.allow_origins == ["*"] or "*" in self.allow_origins):
|
||||
raise ValueError("Cannot use wildcard origins with credentials enabled")
|
||||
return self
|
||||
|
||||
|
||||
def process_cors_config(cors_config: bool | CORSConfig | None) -> CORSConfig | None:
|
||||
if cors_config is False or cors_config is None:
|
||||
return None
|
||||
|
||||
if cors_config is True:
|
||||
# dev mode: allow localhost on any port
|
||||
return CORSConfig(
|
||||
allow_origins=[],
|
||||
allow_origin_regex=r"https?://localhost:\d+",
|
||||
allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"],
|
||||
allow_headers=["Content-Type", "Authorization", "X-Requested-With"],
|
||||
)
|
||||
|
||||
if isinstance(cors_config, CORSConfig):
|
||||
return cors_config
|
||||
|
||||
raise ValueError(f"Expected bool or CORSConfig, got {type(cors_config).__name__}")
|
||||
|
||||
|
||||
class ServerConfig(BaseModel):
|
||||
port: int = Field(
|
||||
default=8321,
|
||||
|
@ -349,6 +384,12 @@ class ServerConfig(BaseModel):
|
|||
default=None,
|
||||
description="Per client quota request configuration",
|
||||
)
|
||||
cors: bool | CORSConfig | None = Field(
|
||||
default=None,
|
||||
description="CORS configuration for cross-origin requests. Can be:\n"
|
||||
"- true: Enable localhost CORS for development\n"
|
||||
"- {allow_origins: [...], allow_methods: [...], ...}: Full configuration",
|
||||
)
|
||||
|
||||
|
||||
class StackRunConfig(BaseModel):
|
||||
|
|
|
@ -28,6 +28,7 @@ from aiohttp import hdrs
|
|||
from fastapi import Body, FastAPI, HTTPException, Request, Response
|
||||
from fastapi import Path as FastapiPath
|
||||
from fastapi.exceptions import RequestValidationError
|
||||
from fastapi.middleware.cors import CORSMiddleware
|
||||
from fastapi.responses import JSONResponse, StreamingResponse
|
||||
from openai import BadRequestError
|
||||
from pydantic import BaseModel, ValidationError
|
||||
|
@ -40,6 +41,7 @@ from llama_stack.core.datatypes import (
|
|||
AuthenticationRequiredError,
|
||||
LoggingConfig,
|
||||
StackRunConfig,
|
||||
process_cors_config,
|
||||
)
|
||||
from llama_stack.core.distribution import builtin_automatically_routed_apis
|
||||
from llama_stack.core.external import ExternalApiSpec, load_external_apis
|
||||
|
@ -483,6 +485,12 @@ def main(args: argparse.Namespace | None = None):
|
|||
window_seconds=window_seconds,
|
||||
)
|
||||
|
||||
if config.server.cors:
|
||||
logger.info("Enabling CORS")
|
||||
cors_config = process_cors_config(config.server.cors)
|
||||
if cors_config:
|
||||
app.add_middleware(CORSMiddleware, **cors_config.model_dump())
|
||||
|
||||
if Api.telemetry in impls:
|
||||
setup_logger(impls[Api.telemetry])
|
||||
else:
|
||||
|
|
Loading…
Add table
Add a link
Reference in a new issue