diff --git a/llama_stack/core/datatypes.py b/llama_stack/core/datatypes.py index bb99bc636..c3940fcbd 100644 --- a/llama_stack/core/datatypes.py +++ b/llama_stack/core/datatypes.py @@ -318,9 +318,7 @@ class QuotaConfig(BaseModel): period: QuotaPeriod = Field(default=QuotaPeriod.DAY, description="Quota period to set") -class CORSSpec(BaseModel): - """CORS configuration with strict defaults (minimal permissions).""" - +class CORSConfig(BaseModel): allow_origins: list[str] = Field(default_factory=list) allow_origin_regex: str | None = Field(default=None) allow_methods: list[str] = Field(default=["OPTIONS"]) @@ -330,39 +328,29 @@ class CORSSpec(BaseModel): max_age: int = Field(default=600, ge=0) @model_validator(mode="after") - def _validate_credentials_with_wildcard(self) -> Self: + def validate_credentials_config(self) -> Self: if self.allow_credentials and (self.allow_origins == ["*"] or "*" in self.allow_origins): - raise ValueError("CORS: allow_credentials=True requires explicit origins") + raise ValueError("Cannot use wildcard origins with credentials enabled") return self -# Union type for flexible CORS configuration input -# Accepts: bool (dev shortcuts) or CORSSpec (explicit config) -CORSConfig = bool | CORSSpec - - -def process_cors_config(cors_config: CORSConfig) -> CORSSpec | None: - """Process CORS config: bool -> dev defaults, CORSSpec -> passthrough.""" - if cors_config is False: +def process_cors_config(cors_config: bool | CORSConfig | None) -> CORSConfig | None: + if cors_config is False or cors_config is None: return None if cors_config is True: - # Dev mode: localhost with any port - return CORSSpec( + # dev mode: allow localhost on any port + return CORSConfig( allow_origins=[], allow_origin_regex=r"https?://localhost:\d+", allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"], allow_headers=["Content-Type", "Authorization", "X-Requested-With"], - allow_credentials=False, - expose_headers=[], - max_age=600, ) - elif isinstance(cors_config, CORSSpec): + if isinstance(cors_config, CORSConfig): return cors_config - else: - raise ValueError(f"Invalid CORS configuration type: {type(cors_config)}") + raise ValueError(f"Expected bool or CORSConfig, got {type(cors_config).__name__}") class ServerConfig(BaseModel): @@ -396,7 +384,7 @@ class ServerConfig(BaseModel): default=None, description="Per client quota request configuration", ) - cors: CORSConfig | None = Field( + cors: bool | CORSConfig | None = Field( default=None, description="CORS configuration for cross-origin requests. Can be:\n" "- true: Enable localhost CORS for development\n" diff --git a/tests/unit/server/test_cors.py b/tests/unit/server/test_cors.py index 512be1bd2..8fd2515ba 100644 --- a/tests/unit/server/test_cors.py +++ b/tests/unit/server/test_cors.py @@ -6,11 +6,11 @@ import pytest -from llama_stack.core.datatypes import CORSSpec, process_cors_config +from llama_stack.core.datatypes import CORSConfig, process_cors_config -def test_cors_spec_defaults(): - config = CORSSpec() +def test_cors_config_defaults(): + config = CORSConfig() assert config.allow_origins == [] assert config.allow_origin_regex is None @@ -21,8 +21,8 @@ def test_cors_spec_defaults(): assert config.max_age == 600 -def test_cors_spec_explicit_config(): - config = CORSSpec( +def test_cors_config_explicit_config(): + config = CORSConfig( allow_origins=["https://example.com"], allow_credentials=True, max_age=3600, allow_methods=["GET", "POST"] ) @@ -32,19 +32,19 @@ def test_cors_spec_explicit_config(): assert config.allow_methods == ["GET", "POST"] -def test_cors_spec_regex(): - config = CORSSpec(allow_origins=[], allow_origin_regex=r"https?://localhost:\d+") +def test_cors_config_regex(): + config = CORSConfig(allow_origins=[], allow_origin_regex=r"https?://localhost:\d+") assert config.allow_origins == [] assert config.allow_origin_regex == r"https?://localhost:\d+" -def test_cors_spec_wildcard_credentials_error(): - with pytest.raises(ValueError, match="CORS: allow_credentials=True requires explicit origins"): - CORSSpec(allow_origins=["*"], allow_credentials=True) +def test_cors_config_wildcard_credentials_error(): + with pytest.raises(ValueError, match="Cannot use wildcard origins with credentials enabled"): + CORSConfig(allow_origins=["*"], allow_credentials=True) - with pytest.raises(ValueError, match="CORS: allow_credentials=True requires explicit origins"): - CORSSpec(allow_origins=["https://example.com", "*"], allow_credentials=True) + with pytest.raises(ValueError, match="Cannot use wildcard origins with credentials enabled"): + CORSConfig(allow_origins=["https://example.com", "*"], allow_credentials=True) def test_process_cors_config_false(): @@ -55,29 +55,29 @@ def test_process_cors_config_false(): def test_process_cors_config_true(): result = process_cors_config(True) - assert isinstance(result, CORSSpec) + assert isinstance(result, CORSConfig) assert result.allow_origins == [] assert result.allow_origin_regex == r"https?://localhost:\d+" assert result.allow_credentials is False - assert "GET" in result.allow_methods - assert "POST" in result.allow_methods - assert "OPTIONS" in result.allow_methods + expected_methods = ["GET", "POST", "PUT", "DELETE", "OPTIONS"] + for method in expected_methods: + assert method in result.allow_methods def test_process_cors_config_passthrough(): - original = CORSSpec(allow_origins=["https://example.com"], allow_methods=["GET"]) + original = CORSConfig(allow_origins=["https://example.com"], allow_methods=["GET"]) result = process_cors_config(original) assert result is original def test_process_cors_config_invalid_type(): - with pytest.raises(ValueError, match="Invalid CORS configuration type"): + with pytest.raises(ValueError, match="Expected bool or CORSConfig, got str"): process_cors_config("invalid") -def test_cors_spec_model_dump(): - cors_spec = CORSSpec( +def test_cors_config_model_dump(): + cors_config = CORSConfig( allow_origins=["https://example.com"], allow_methods=["GET", "POST"], allow_headers=["Content-Type"], @@ -85,7 +85,7 @@ def test_cors_spec_model_dump(): max_age=3600, ) - config_dict = cors_spec.model_dump() + config_dict = cors_config.model_dump() assert config_dict["allow_origins"] == ["https://example.com"] assert config_dict["allow_methods"] == ["GET", "POST"]