mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-17 16:32:38 +00:00
Add CORS configuration support for FastAPI server
This commit is contained in:
parent
c716c8cd03
commit
815b5c7279
3 changed files with 107 additions and 141 deletions
|
|
@ -228,6 +228,29 @@ server:
|
||||||
cors: true # Optional: Enable CORS (dev mode) or full config object
|
cors: true # Optional: Enable CORS (dev mode) or full config object
|
||||||
```
|
```
|
||||||
|
|
||||||
|
### CORS Configuration
|
||||||
|
|
||||||
|
CORS (Cross-Origin Resource Sharing) can be configured in two ways:
|
||||||
|
|
||||||
|
**Local development** (allows localhost origins only):
|
||||||
|
```yaml
|
||||||
|
server:
|
||||||
|
cors: true
|
||||||
|
```
|
||||||
|
|
||||||
|
**Explicit configuration** (custom origins and settings):
|
||||||
|
```yaml
|
||||||
|
server:
|
||||||
|
cors:
|
||||||
|
allow_origins: ["https://myapp.com", "https://app.example.com"]
|
||||||
|
allow_methods: ["GET", "POST", "PUT", "DELETE"]
|
||||||
|
allow_headers: ["Content-Type", "Authorization"]
|
||||||
|
allow_credentials: true
|
||||||
|
max_age: 3600
|
||||||
|
```
|
||||||
|
|
||||||
|
When `cors: true`, the server enables secure localhost-only access for local development. For production, specify exact origins to maintain security.
|
||||||
|
|
||||||
### Authentication Configuration
|
### Authentication Configuration
|
||||||
|
|
||||||
> **Breaking Change (v0.2.14)**: The authentication configuration structure has changed. The previous format with `provider_type` and `config` fields has been replaced with a unified `provider_config` field that includes the `type` field. Update your configuration files accordingly.
|
> **Breaking Change (v0.2.14)**: The authentication configuration structure has changed. The previous format with `provider_type` and `config` fields has been replaced with a unified `provider_config` field that includes the `type` field. Update your configuration files accordingly.
|
||||||
|
|
|
||||||
|
|
@ -318,11 +318,13 @@ class QuotaConfig(BaseModel):
|
||||||
period: QuotaPeriod = Field(default=QuotaPeriod.DAY, description="Quota period to set")
|
period: QuotaPeriod = Field(default=QuotaPeriod.DAY, description="Quota period to set")
|
||||||
|
|
||||||
|
|
||||||
class CORSConfig(BaseModel):
|
class CORSSpec(BaseModel):
|
||||||
allow_origins: list[str] = Field(default=["*"])
|
"""CORS configuration with strict defaults (minimal permissions)."""
|
||||||
|
|
||||||
|
allow_origins: list[str] = Field(default_factory=list)
|
||||||
allow_origin_regex: str | None = Field(default=None)
|
allow_origin_regex: str | None = Field(default=None)
|
||||||
allow_methods: list[str] = Field(default=["*"])
|
allow_methods: list[str] = Field(default=["OPTIONS"])
|
||||||
allow_headers: list[str] = Field(default=["*"])
|
allow_headers: list[str] = Field(default_factory=list)
|
||||||
allow_credentials: bool = Field(default=False)
|
allow_credentials: bool = Field(default=False)
|
||||||
expose_headers: list[str] = Field(default_factory=list)
|
expose_headers: list[str] = Field(default_factory=list)
|
||||||
max_age: int = Field(default=600, ge=0)
|
max_age: int = Field(default=600, ge=0)
|
||||||
|
|
@ -334,21 +336,19 @@ class CORSConfig(BaseModel):
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
|
||||||
# Union type for flexible CORS configuration
|
# Union type for flexible CORS configuration input
|
||||||
CORSConfiguration = bool | CORSConfig
|
# Accepts: bool (dev shortcuts) or CORSSpec (explicit config)
|
||||||
|
CORSConfig = bool | CORSSpec
|
||||||
|
|
||||||
|
|
||||||
def process_cors_config(cors_config: CORSConfiguration | None) -> CORSConfig | None:
|
def process_cors_config(cors_config: CORSConfig) -> CORSSpec | None:
|
||||||
"""Process CORS config: bool -> dev defaults, object -> passthrough."""
|
"""Process CORS config: bool -> dev defaults, CORSSpec -> passthrough."""
|
||||||
if cors_config is None:
|
|
||||||
return None
|
|
||||||
|
|
||||||
if cors_config is False:
|
if cors_config is False:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
if cors_config is True:
|
if cors_config is True:
|
||||||
# Dev mode: localhost with any port
|
# Dev mode: localhost with any port
|
||||||
return CORSConfig(
|
return CORSSpec(
|
||||||
allow_origins=[],
|
allow_origins=[],
|
||||||
allow_origin_regex=r"https?://localhost:\d+",
|
allow_origin_regex=r"https?://localhost:\d+",
|
||||||
allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"],
|
allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"],
|
||||||
|
|
@ -358,7 +358,7 @@ def process_cors_config(cors_config: CORSConfiguration | None) -> CORSConfig | N
|
||||||
max_age=600,
|
max_age=600,
|
||||||
)
|
)
|
||||||
|
|
||||||
elif isinstance(cors_config, CORSConfig):
|
elif isinstance(cors_config, CORSSpec):
|
||||||
return cors_config
|
return cors_config
|
||||||
|
|
||||||
else:
|
else:
|
||||||
|
|
@ -396,7 +396,7 @@ class ServerConfig(BaseModel):
|
||||||
default=None,
|
default=None,
|
||||||
description="Per client quota request configuration",
|
description="Per client quota request configuration",
|
||||||
)
|
)
|
||||||
cors: CORSConfiguration | None = Field(
|
cors: CORSConfig | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
description="CORS configuration for cross-origin requests. Can be:\n"
|
description="CORS configuration for cross-origin requests. Can be:\n"
|
||||||
"- true: Enable localhost CORS for development\n"
|
"- true: Enable localhost CORS for development\n"
|
||||||
|
|
|
||||||
|
|
@ -6,157 +6,100 @@
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from llama_stack.core.datatypes import CORSConfig, process_cors_config
|
from llama_stack.core.datatypes import CORSSpec, process_cors_config
|
||||||
|
|
||||||
|
|
||||||
class TestCORSConfig:
|
def test_cors_spec_defaults():
|
||||||
"""Test basic CORS configuration."""
|
config = CORSSpec()
|
||||||
|
|
||||||
def test_defaults(self):
|
assert config.allow_origins == []
|
||||||
config = CORSConfig()
|
assert config.allow_origin_regex is None
|
||||||
|
assert config.allow_methods == ["OPTIONS"]
|
||||||
assert config.allow_origins == ["*"]
|
assert config.allow_headers == []
|
||||||
assert config.allow_origin_regex is None
|
assert config.allow_credentials is False
|
||||||
assert config.allow_methods == ["*"]
|
assert config.expose_headers == []
|
||||||
assert config.allow_headers == ["*"]
|
assert config.max_age == 600
|
||||||
assert config.allow_credentials is False
|
|
||||||
assert config.expose_headers == []
|
|
||||||
assert config.max_age == 600
|
|
||||||
|
|
||||||
def test_custom_values(self):
|
|
||||||
config = CORSConfig(allow_origins=["https://example.com"], allow_credentials=True, max_age=3600)
|
|
||||||
|
|
||||||
assert config.allow_origins == ["https://example.com"]
|
|
||||||
assert config.allow_credentials is True
|
|
||||||
assert config.max_age == 3600
|
|
||||||
|
|
||||||
def test_regex_field(self):
|
|
||||||
config = CORSConfig(allow_origins=[], allow_origin_regex=r"https?://localhost:\d+")
|
|
||||||
|
|
||||||
assert config.allow_origins == []
|
|
||||||
assert config.allow_origin_regex == r"https?://localhost:\d+"
|
|
||||||
|
|
||||||
def test_credentials_with_wildcard_error(self):
|
|
||||||
"""Should raise error when using credentials with wildcard origins."""
|
|
||||||
with pytest.raises(ValueError, match="CORS: allow_credentials=True requires explicit origins"):
|
|
||||||
CORSConfig(allow_origins=["*"], allow_credentials=True)
|
|
||||||
|
|
||||||
|
|
||||||
class TestProcessCORSConfig:
|
def test_cors_spec_explicit_config():
|
||||||
"""Test the process_cors_config function."""
|
config = CORSSpec(
|
||||||
|
allow_origins=["https://example.com"], allow_credentials=True, max_age=3600, allow_methods=["GET", "POST"]
|
||||||
|
)
|
||||||
|
|
||||||
def test_none_returns_none(self):
|
assert config.allow_origins == ["https://example.com"]
|
||||||
result = process_cors_config(None)
|
assert config.allow_credentials is True
|
||||||
assert result is None
|
assert config.max_age == 3600
|
||||||
|
assert config.allow_methods == ["GET", "POST"]
|
||||||
def test_false_returns_none(self):
|
|
||||||
result = process_cors_config(False)
|
|
||||||
assert result is None
|
|
||||||
|
|
||||||
def test_true_returns_dev_config(self):
|
|
||||||
"""Test dev mode: cors: true"""
|
|
||||||
result = process_cors_config(True)
|
|
||||||
|
|
||||||
assert isinstance(result, CORSConfig)
|
|
||||||
assert result.allow_origins == []
|
|
||||||
assert result.allow_origin_regex == r"https?://localhost:\d+"
|
|
||||||
assert result.allow_credentials is False
|
|
||||||
assert "GET" in result.allow_methods
|
|
||||||
assert "POST" in result.allow_methods
|
|
||||||
|
|
||||||
def test_cors_object_returned_as_is(self):
|
|
||||||
original = CORSConfig(allow_origins=["https://example.com"])
|
|
||||||
result = process_cors_config(original)
|
|
||||||
|
|
||||||
assert result is original
|
|
||||||
|
|
||||||
def test_invalid_type_raises_error(self):
|
|
||||||
with pytest.raises(ValueError, match="Invalid CORS configuration type"):
|
|
||||||
process_cors_config("invalid")
|
|
||||||
|
|
||||||
|
|
||||||
class TestCORSIntegration:
|
def test_cors_spec_regex():
|
||||||
"""Test CORS with FastAPI integration."""
|
config = CORSSpec(allow_origins=[], allow_origin_regex=r"https?://localhost:\d+")
|
||||||
|
|
||||||
def test_dev_mode_with_fastapi(self):
|
assert config.allow_origins == []
|
||||||
"""Test that dev mode config works with FastAPI middleware."""
|
assert config.allow_origin_regex == r"https?://localhost:\d+"
|
||||||
from fastapi import FastAPI
|
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
|
||||||
from fastapi.testclient import TestClient
|
|
||||||
|
|
||||||
app = FastAPI()
|
|
||||||
|
|
||||||
# Use our dev mode config
|
def test_cors_spec_wildcard_credentials_error():
|
||||||
cors_config = process_cors_config(True)
|
with pytest.raises(ValueError, match="CORS: allow_credentials=True requires explicit origins"):
|
||||||
app.add_middleware(CORSMiddleware, **cors_config.model_dump())
|
CORSSpec(allow_origins=["*"], allow_credentials=True)
|
||||||
|
|
||||||
@app.get("/test")
|
with pytest.raises(ValueError, match="CORS: allow_credentials=True requires explicit origins"):
|
||||||
def test_endpoint():
|
CORSSpec(allow_origins=["https://example.com", "*"], allow_credentials=True)
|
||||||
return {"message": "hello"}
|
|
||||||
|
|
||||||
client = TestClient(app)
|
|
||||||
|
|
||||||
# Test localhost origins work
|
def test_process_cors_config_false():
|
||||||
response = client.get("/test", headers={"Origin": "http://localhost:3000"})
|
result = process_cors_config(False)
|
||||||
assert response.status_code == 200
|
assert result is None
|
||||||
assert response.headers.get("Access-Control-Allow-Origin") == "http://localhost:3000"
|
|
||||||
|
|
||||||
# Test non-localhost doesn't get CORS headers
|
|
||||||
response = client.get("/test", headers={"Origin": "https://evil.com"})
|
|
||||||
assert response.status_code == 200
|
|
||||||
assert "Access-Control-Allow-Origin" not in response.headers
|
|
||||||
|
|
||||||
def test_production_mode_with_fastapi(self):
|
def test_process_cors_config_true():
|
||||||
"""Test explicit origins configuration."""
|
result = process_cors_config(True)
|
||||||
from fastapi import FastAPI
|
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
|
||||||
from fastapi.testclient import TestClient
|
|
||||||
|
|
||||||
app = FastAPI()
|
assert isinstance(result, CORSSpec)
|
||||||
|
assert result.allow_origins == []
|
||||||
|
assert result.allow_origin_regex == r"https?://localhost:\d+"
|
||||||
|
assert result.allow_credentials is False
|
||||||
|
assert "GET" in result.allow_methods
|
||||||
|
assert "POST" in result.allow_methods
|
||||||
|
assert "OPTIONS" in result.allow_methods
|
||||||
|
|
||||||
# Production config
|
|
||||||
cors_config = CORSConfig(allow_origins=["https://myapp.com"], allow_credentials=True)
|
|
||||||
app.add_middleware(CORSMiddleware, **cors_config.model_dump())
|
|
||||||
|
|
||||||
@app.get("/test")
|
def test_process_cors_config_passthrough():
|
||||||
def test_endpoint():
|
original = CORSSpec(allow_origins=["https://example.com"], allow_methods=["GET"])
|
||||||
return {"message": "hello"}
|
result = process_cors_config(original)
|
||||||
|
|
||||||
client = TestClient(app)
|
assert result is original
|
||||||
|
|
||||||
# Test allowed origin works
|
|
||||||
response = client.get("/test", headers={"Origin": "https://myapp.com"})
|
|
||||||
assert response.status_code == 200
|
|
||||||
assert response.headers.get("Access-Control-Allow-Origin") == "https://myapp.com"
|
|
||||||
assert response.headers.get("Access-Control-Allow-Credentials") == "true"
|
|
||||||
|
|
||||||
# Test disallowed origin
|
def test_process_cors_config_invalid_type():
|
||||||
response = client.get("/test", headers={"Origin": "https://evil.com"})
|
with pytest.raises(ValueError, match="Invalid CORS configuration type"):
|
||||||
assert response.status_code == 200
|
process_cors_config("invalid")
|
||||||
assert "Access-Control-Allow-Origin" not in response.headers
|
|
||||||
|
|
||||||
def test_preflight_request(self):
|
|
||||||
"""Test CORS preflight OPTIONS request."""
|
|
||||||
from fastapi import FastAPI
|
|
||||||
from fastapi.middleware.cors import CORSMiddleware
|
|
||||||
from fastapi.testclient import TestClient
|
|
||||||
|
|
||||||
app = FastAPI()
|
def test_cors_spec_model_dump():
|
||||||
|
cors_spec = CORSSpec(
|
||||||
|
allow_origins=["https://example.com"],
|
||||||
|
allow_methods=["GET", "POST"],
|
||||||
|
allow_headers=["Content-Type"],
|
||||||
|
allow_credentials=True,
|
||||||
|
max_age=3600,
|
||||||
|
)
|
||||||
|
|
||||||
cors_config = process_cors_config(True)
|
config_dict = cors_spec.model_dump()
|
||||||
app.add_middleware(CORSMiddleware, **cors_config.model_dump())
|
|
||||||
|
|
||||||
@app.get("/test")
|
assert config_dict["allow_origins"] == ["https://example.com"]
|
||||||
def test_endpoint():
|
assert config_dict["allow_methods"] == ["GET", "POST"]
|
||||||
return {"message": "hello"}
|
assert config_dict["allow_headers"] == ["Content-Type"]
|
||||||
|
assert config_dict["allow_credentials"] is True
|
||||||
|
assert config_dict["max_age"] == 3600
|
||||||
|
|
||||||
client = TestClient(app)
|
expected_keys = {
|
||||||
|
"allow_origins",
|
||||||
# Preflight request
|
"allow_origin_regex",
|
||||||
response = client.options(
|
"allow_methods",
|
||||||
"/test", headers={"Origin": "http://localhost:3000", "Access-Control-Request-Method": "GET"}
|
"allow_headers",
|
||||||
)
|
"allow_credentials",
|
||||||
|
"expose_headers",
|
||||||
assert response.status_code == 200
|
"max_age",
|
||||||
assert response.headers.get("Access-Control-Allow-Origin") == "http://localhost:3000"
|
}
|
||||||
assert "GET" in response.headers.get("Access-Control-Allow-Methods", "")
|
assert set(config_dict.keys()) == expected_keys
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue