mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-17 20:19:52 +00:00
Add CORS configuration support for FastAPI server
This commit is contained in:
parent
58e164b8bc
commit
c716c8cd03
4 changed files with 272 additions and 0 deletions
|
|
@ -225,6 +225,7 @@ server:
|
||||||
port: 8321 # Port to listen on (default: 8321)
|
port: 8321 # Port to listen on (default: 8321)
|
||||||
tls_certfile: "/path/to/cert.pem" # Optional: Path to TLS certificate for HTTPS
|
tls_certfile: "/path/to/cert.pem" # Optional: Path to TLS certificate for HTTPS
|
||||||
tls_keyfile: "/path/to/key.pem" # Optional: Path to TLS key for HTTPS
|
tls_keyfile: "/path/to/key.pem" # Optional: Path to TLS key for HTTPS
|
||||||
|
cors: true # Optional: Enable CORS (dev mode) or full config object
|
||||||
```
|
```
|
||||||
|
|
||||||
### Authentication Configuration
|
### Authentication Configuration
|
||||||
|
|
@ -618,6 +619,54 @@ Content-Type: application/json
|
||||||
}
|
}
|
||||||
```
|
```
|
||||||
|
|
||||||
|
### CORS Configuration
|
||||||
|
|
||||||
|
Configure CORS to allow web browsers to make requests from different domains. Disabled by default.
|
||||||
|
|
||||||
|
#### Quick Setup
|
||||||
|
|
||||||
|
For development, use the simple boolean flag:
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
server:
|
||||||
|
cors: true # Auto-enables localhost with any port
|
||||||
|
```
|
||||||
|
|
||||||
|
This automatically allows `http://localhost:*` and `https://localhost:*` with secure defaults.
|
||||||
|
|
||||||
|
#### Custom Configuration
|
||||||
|
|
||||||
|
For specific origins and full control:
|
||||||
|
|
||||||
|
```yaml
|
||||||
|
server:
|
||||||
|
cors:
|
||||||
|
allow_origins: ["https://myapp.com", "https://staging.myapp.com"]
|
||||||
|
allow_credentials: true
|
||||||
|
allow_methods: ["GET", "POST", "PUT", "DELETE"]
|
||||||
|
allow_headers: ["Content-Type", "Authorization"]
|
||||||
|
allow_origin_regex: "https://.*\\.example\\.com" # Optional regex pattern
|
||||||
|
expose_headers: ["X-Total-Count"]
|
||||||
|
max_age: 86400
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Configuration Options
|
||||||
|
|
||||||
|
| Field | Description | Default |
|
||||||
|
| -------------------- | ---------------------------------------------- | ------- |
|
||||||
|
| `allow_origins` | List of allowed origins. Use `["*"]` for any. | `["*"]` |
|
||||||
|
| `allow_origin_regex` | Regex pattern for allowed origins (optional). | `None` |
|
||||||
|
| `allow_methods` | Allowed HTTP methods. | `["*"]` |
|
||||||
|
| `allow_headers` | Allowed headers. | `["*"]` |
|
||||||
|
| `allow_credentials` | Allow credentials (cookies, auth headers). | `false` |
|
||||||
|
| `expose_headers` | Headers exposed to browser. | `[]` |
|
||||||
|
| `max_age` | Preflight cache time (seconds). | `600` |
|
||||||
|
|
||||||
|
**Security Notes**:
|
||||||
|
- `allow_credentials: true` requires explicit origins (no wildcards)
|
||||||
|
- `cors: true` enables localhost access only (secure for development)
|
||||||
|
- For public APIs, always specify exact allowed origins
|
||||||
|
|
||||||
## Extending to handle Safety
|
## Extending to handle Safety
|
||||||
|
|
||||||
Configuring Safety can be a little involved so it is instructive to go through an example.
|
Configuring Safety can be a little involved so it is instructive to go through an example.
|
||||||
|
|
|
||||||
|
|
@ -318,6 +318,53 @@ class QuotaConfig(BaseModel):
|
||||||
period: QuotaPeriod = Field(default=QuotaPeriod.DAY, description="Quota period to set")
|
period: QuotaPeriod = Field(default=QuotaPeriod.DAY, description="Quota period to set")
|
||||||
|
|
||||||
|
|
||||||
|
class CORSConfig(BaseModel):
|
||||||
|
allow_origins: list[str] = Field(default=["*"])
|
||||||
|
allow_origin_regex: str | None = Field(default=None)
|
||||||
|
allow_methods: list[str] = Field(default=["*"])
|
||||||
|
allow_headers: list[str] = Field(default=["*"])
|
||||||
|
allow_credentials: bool = Field(default=False)
|
||||||
|
expose_headers: list[str] = Field(default_factory=list)
|
||||||
|
max_age: int = Field(default=600, ge=0)
|
||||||
|
|
||||||
|
@model_validator(mode="after")
|
||||||
|
def _validate_credentials_with_wildcard(self) -> Self:
|
||||||
|
if self.allow_credentials and (self.allow_origins == ["*"] or "*" in self.allow_origins):
|
||||||
|
raise ValueError("CORS: allow_credentials=True requires explicit origins")
|
||||||
|
return self
|
||||||
|
|
||||||
|
|
||||||
|
# Union type for flexible CORS configuration
|
||||||
|
CORSConfiguration = bool | CORSConfig
|
||||||
|
|
||||||
|
|
||||||
|
def process_cors_config(cors_config: CORSConfiguration | None) -> CORSConfig | None:
|
||||||
|
"""Process CORS config: bool -> dev defaults, object -> passthrough."""
|
||||||
|
if cors_config is None:
|
||||||
|
return None
|
||||||
|
|
||||||
|
if cors_config is False:
|
||||||
|
return None
|
||||||
|
|
||||||
|
if cors_config is True:
|
||||||
|
# Dev mode: localhost with any port
|
||||||
|
return CORSConfig(
|
||||||
|
allow_origins=[],
|
||||||
|
allow_origin_regex=r"https?://localhost:\d+",
|
||||||
|
allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"],
|
||||||
|
allow_headers=["Content-Type", "Authorization", "X-Requested-With"],
|
||||||
|
allow_credentials=False,
|
||||||
|
expose_headers=[],
|
||||||
|
max_age=600,
|
||||||
|
)
|
||||||
|
|
||||||
|
elif isinstance(cors_config, CORSConfig):
|
||||||
|
return cors_config
|
||||||
|
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Invalid CORS configuration type: {type(cors_config)}")
|
||||||
|
|
||||||
|
|
||||||
class ServerConfig(BaseModel):
|
class ServerConfig(BaseModel):
|
||||||
port: int = Field(
|
port: int = Field(
|
||||||
default=8321,
|
default=8321,
|
||||||
|
|
@ -349,6 +396,12 @@ class ServerConfig(BaseModel):
|
||||||
default=None,
|
default=None,
|
||||||
description="Per client quota request configuration",
|
description="Per client quota request configuration",
|
||||||
)
|
)
|
||||||
|
cors: CORSConfiguration | None = Field(
|
||||||
|
default=None,
|
||||||
|
description="CORS configuration for cross-origin requests. Can be:\n"
|
||||||
|
"- true: Enable localhost CORS for development\n"
|
||||||
|
"- {allow_origins: [...], allow_methods: [...], ...}: Full configuration",
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class StackRunConfig(BaseModel):
|
class StackRunConfig(BaseModel):
|
||||||
|
|
|
||||||
|
|
@ -28,6 +28,7 @@ from aiohttp import hdrs
|
||||||
from fastapi import Body, FastAPI, HTTPException, Request, Response
|
from fastapi import Body, FastAPI, HTTPException, Request, Response
|
||||||
from fastapi import Path as FastapiPath
|
from fastapi import Path as FastapiPath
|
||||||
from fastapi.exceptions import RequestValidationError
|
from fastapi.exceptions import RequestValidationError
|
||||||
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
from fastapi.responses import JSONResponse, StreamingResponse
|
from fastapi.responses import JSONResponse, StreamingResponse
|
||||||
from openai import BadRequestError
|
from openai import BadRequestError
|
||||||
from pydantic import BaseModel, ValidationError
|
from pydantic import BaseModel, ValidationError
|
||||||
|
|
@ -40,6 +41,7 @@ from llama_stack.core.datatypes import (
|
||||||
AuthenticationRequiredError,
|
AuthenticationRequiredError,
|
||||||
LoggingConfig,
|
LoggingConfig,
|
||||||
StackRunConfig,
|
StackRunConfig,
|
||||||
|
process_cors_config,
|
||||||
)
|
)
|
||||||
from llama_stack.core.distribution import builtin_automatically_routed_apis
|
from llama_stack.core.distribution import builtin_automatically_routed_apis
|
||||||
from llama_stack.core.external import ExternalApiSpec, load_external_apis
|
from llama_stack.core.external import ExternalApiSpec, load_external_apis
|
||||||
|
|
@ -483,6 +485,12 @@ def main(args: argparse.Namespace | None = None):
|
||||||
window_seconds=window_seconds,
|
window_seconds=window_seconds,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
if config.server.cors:
|
||||||
|
logger.info("Enabling CORS")
|
||||||
|
cors_config = process_cors_config(config.server.cors)
|
||||||
|
if cors_config:
|
||||||
|
app.add_middleware(CORSMiddleware, **cors_config.model_dump())
|
||||||
|
|
||||||
if Api.telemetry in impls:
|
if Api.telemetry in impls:
|
||||||
setup_logger(impls[Api.telemetry])
|
setup_logger(impls[Api.telemetry])
|
||||||
else:
|
else:
|
||||||
|
|
|
||||||
162
tests/unit/server/test_cors.py
Normal file
162
tests/unit/server/test_cors.py
Normal file
|
|
@ -0,0 +1,162 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from llama_stack.core.datatypes import CORSConfig, process_cors_config
|
||||||
|
|
||||||
|
|
||||||
|
class TestCORSConfig:
|
||||||
|
"""Test basic CORS configuration."""
|
||||||
|
|
||||||
|
def test_defaults(self):
|
||||||
|
config = CORSConfig()
|
||||||
|
|
||||||
|
assert config.allow_origins == ["*"]
|
||||||
|
assert config.allow_origin_regex is None
|
||||||
|
assert config.allow_methods == ["*"]
|
||||||
|
assert config.allow_headers == ["*"]
|
||||||
|
assert config.allow_credentials is False
|
||||||
|
assert config.expose_headers == []
|
||||||
|
assert config.max_age == 600
|
||||||
|
|
||||||
|
def test_custom_values(self):
|
||||||
|
config = CORSConfig(allow_origins=["https://example.com"], allow_credentials=True, max_age=3600)
|
||||||
|
|
||||||
|
assert config.allow_origins == ["https://example.com"]
|
||||||
|
assert config.allow_credentials is True
|
||||||
|
assert config.max_age == 3600
|
||||||
|
|
||||||
|
def test_regex_field(self):
|
||||||
|
config = CORSConfig(allow_origins=[], allow_origin_regex=r"https?://localhost:\d+")
|
||||||
|
|
||||||
|
assert config.allow_origins == []
|
||||||
|
assert config.allow_origin_regex == r"https?://localhost:\d+"
|
||||||
|
|
||||||
|
def test_credentials_with_wildcard_error(self):
|
||||||
|
"""Should raise error when using credentials with wildcard origins."""
|
||||||
|
with pytest.raises(ValueError, match="CORS: allow_credentials=True requires explicit origins"):
|
||||||
|
CORSConfig(allow_origins=["*"], allow_credentials=True)
|
||||||
|
|
||||||
|
|
||||||
|
class TestProcessCORSConfig:
|
||||||
|
"""Test the process_cors_config function."""
|
||||||
|
|
||||||
|
def test_none_returns_none(self):
|
||||||
|
result = process_cors_config(None)
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
def test_false_returns_none(self):
|
||||||
|
result = process_cors_config(False)
|
||||||
|
assert result is None
|
||||||
|
|
||||||
|
def test_true_returns_dev_config(self):
|
||||||
|
"""Test dev mode: cors: true"""
|
||||||
|
result = process_cors_config(True)
|
||||||
|
|
||||||
|
assert isinstance(result, CORSConfig)
|
||||||
|
assert result.allow_origins == []
|
||||||
|
assert result.allow_origin_regex == r"https?://localhost:\d+"
|
||||||
|
assert result.allow_credentials is False
|
||||||
|
assert "GET" in result.allow_methods
|
||||||
|
assert "POST" in result.allow_methods
|
||||||
|
|
||||||
|
def test_cors_object_returned_as_is(self):
|
||||||
|
original = CORSConfig(allow_origins=["https://example.com"])
|
||||||
|
result = process_cors_config(original)
|
||||||
|
|
||||||
|
assert result is original
|
||||||
|
|
||||||
|
def test_invalid_type_raises_error(self):
|
||||||
|
with pytest.raises(ValueError, match="Invalid CORS configuration type"):
|
||||||
|
process_cors_config("invalid")
|
||||||
|
|
||||||
|
|
||||||
|
class TestCORSIntegration:
|
||||||
|
"""Test CORS with FastAPI integration."""
|
||||||
|
|
||||||
|
def test_dev_mode_with_fastapi(self):
|
||||||
|
"""Test that dev mode config works with FastAPI middleware."""
|
||||||
|
from fastapi import FastAPI
|
||||||
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
|
from fastapi.testclient import TestClient
|
||||||
|
|
||||||
|
app = FastAPI()
|
||||||
|
|
||||||
|
# Use our dev mode config
|
||||||
|
cors_config = process_cors_config(True)
|
||||||
|
app.add_middleware(CORSMiddleware, **cors_config.model_dump())
|
||||||
|
|
||||||
|
@app.get("/test")
|
||||||
|
def test_endpoint():
|
||||||
|
return {"message": "hello"}
|
||||||
|
|
||||||
|
client = TestClient(app)
|
||||||
|
|
||||||
|
# Test localhost origins work
|
||||||
|
response = client.get("/test", headers={"Origin": "http://localhost:3000"})
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.headers.get("Access-Control-Allow-Origin") == "http://localhost:3000"
|
||||||
|
|
||||||
|
# Test non-localhost doesn't get CORS headers
|
||||||
|
response = client.get("/test", headers={"Origin": "https://evil.com"})
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert "Access-Control-Allow-Origin" not in response.headers
|
||||||
|
|
||||||
|
def test_production_mode_with_fastapi(self):
|
||||||
|
"""Test explicit origins configuration."""
|
||||||
|
from fastapi import FastAPI
|
||||||
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
|
from fastapi.testclient import TestClient
|
||||||
|
|
||||||
|
app = FastAPI()
|
||||||
|
|
||||||
|
# Production config
|
||||||
|
cors_config = CORSConfig(allow_origins=["https://myapp.com"], allow_credentials=True)
|
||||||
|
app.add_middleware(CORSMiddleware, **cors_config.model_dump())
|
||||||
|
|
||||||
|
@app.get("/test")
|
||||||
|
def test_endpoint():
|
||||||
|
return {"message": "hello"}
|
||||||
|
|
||||||
|
client = TestClient(app)
|
||||||
|
|
||||||
|
# Test allowed origin works
|
||||||
|
response = client.get("/test", headers={"Origin": "https://myapp.com"})
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.headers.get("Access-Control-Allow-Origin") == "https://myapp.com"
|
||||||
|
assert response.headers.get("Access-Control-Allow-Credentials") == "true"
|
||||||
|
|
||||||
|
# Test disallowed origin
|
||||||
|
response = client.get("/test", headers={"Origin": "https://evil.com"})
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert "Access-Control-Allow-Origin" not in response.headers
|
||||||
|
|
||||||
|
def test_preflight_request(self):
|
||||||
|
"""Test CORS preflight OPTIONS request."""
|
||||||
|
from fastapi import FastAPI
|
||||||
|
from fastapi.middleware.cors import CORSMiddleware
|
||||||
|
from fastapi.testclient import TestClient
|
||||||
|
|
||||||
|
app = FastAPI()
|
||||||
|
|
||||||
|
cors_config = process_cors_config(True)
|
||||||
|
app.add_middleware(CORSMiddleware, **cors_config.model_dump())
|
||||||
|
|
||||||
|
@app.get("/test")
|
||||||
|
def test_endpoint():
|
||||||
|
return {"message": "hello"}
|
||||||
|
|
||||||
|
client = TestClient(app)
|
||||||
|
|
||||||
|
# Preflight request
|
||||||
|
response = client.options(
|
||||||
|
"/test", headers={"Origin": "http://localhost:3000", "Access-Control-Request-Method": "GET"}
|
||||||
|
)
|
||||||
|
|
||||||
|
assert response.status_code == 200
|
||||||
|
assert response.headers.get("Access-Control-Allow-Origin") == "http://localhost:3000"
|
||||||
|
assert "GET" in response.headers.get("Access-Control-Allow-Methods", "")
|
||||||
Loading…
Add table
Add a link
Reference in a new issue