Renamed and addressed PR comments

This commit is contained in:
skamenan7 2025-08-21 09:49:15 -04:00
parent 815b5c7279
commit 17f488036f
2 changed files with 31 additions and 43 deletions

View file

@ -318,9 +318,7 @@ class QuotaConfig(BaseModel):
period: QuotaPeriod = Field(default=QuotaPeriod.DAY, description="Quota period to set") period: QuotaPeriod = Field(default=QuotaPeriod.DAY, description="Quota period to set")
class CORSSpec(BaseModel): class CORSConfig(BaseModel):
"""CORS configuration with strict defaults (minimal permissions)."""
allow_origins: list[str] = Field(default_factory=list) allow_origins: list[str] = Field(default_factory=list)
allow_origin_regex: str | None = Field(default=None) allow_origin_regex: str | None = Field(default=None)
allow_methods: list[str] = Field(default=["OPTIONS"]) allow_methods: list[str] = Field(default=["OPTIONS"])
@ -330,39 +328,29 @@ class CORSSpec(BaseModel):
max_age: int = Field(default=600, ge=0) max_age: int = Field(default=600, ge=0)
@model_validator(mode="after") @model_validator(mode="after")
def _validate_credentials_with_wildcard(self) -> Self: def validate_credentials_config(self) -> Self:
if self.allow_credentials and (self.allow_origins == ["*"] or "*" in self.allow_origins): if self.allow_credentials and (self.allow_origins == ["*"] or "*" in self.allow_origins):
raise ValueError("CORS: allow_credentials=True requires explicit origins") raise ValueError("Cannot use wildcard origins with credentials enabled")
return self return self
# Union type for flexible CORS configuration input def process_cors_config(cors_config: bool | CORSConfig | None) -> CORSConfig | None:
# Accepts: bool (dev shortcuts) or CORSSpec (explicit config) if cors_config is False or cors_config is None:
CORSConfig = bool | CORSSpec
def process_cors_config(cors_config: CORSConfig) -> CORSSpec | None:
"""Process CORS config: bool -> dev defaults, CORSSpec -> passthrough."""
if cors_config is False:
return None return None
if cors_config is True: if cors_config is True:
# Dev mode: localhost with any port # dev mode: allow localhost on any port
return CORSSpec( return CORSConfig(
allow_origins=[], allow_origins=[],
allow_origin_regex=r"https?://localhost:\d+", allow_origin_regex=r"https?://localhost:\d+",
allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"], allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"],
allow_headers=["Content-Type", "Authorization", "X-Requested-With"], allow_headers=["Content-Type", "Authorization", "X-Requested-With"],
allow_credentials=False,
expose_headers=[],
max_age=600,
) )
elif isinstance(cors_config, CORSSpec): if isinstance(cors_config, CORSConfig):
return cors_config return cors_config
else: raise ValueError(f"Expected bool or CORSConfig, got {type(cors_config).__name__}")
raise ValueError(f"Invalid CORS configuration type: {type(cors_config)}")
class ServerConfig(BaseModel): class ServerConfig(BaseModel):
@ -396,7 +384,7 @@ class ServerConfig(BaseModel):
default=None, default=None,
description="Per client quota request configuration", description="Per client quota request configuration",
) )
cors: CORSConfig | None = Field( cors: bool | CORSConfig | None = Field(
default=None, default=None,
description="CORS configuration for cross-origin requests. Can be:\n" description="CORS configuration for cross-origin requests. Can be:\n"
"- true: Enable localhost CORS for development\n" "- true: Enable localhost CORS for development\n"

View file

@ -6,11 +6,11 @@
import pytest import pytest
from llama_stack.core.datatypes import CORSSpec, process_cors_config from llama_stack.core.datatypes import CORSConfig, process_cors_config
def test_cors_spec_defaults(): def test_cors_config_defaults():
config = CORSSpec() config = CORSConfig()
assert config.allow_origins == [] assert config.allow_origins == []
assert config.allow_origin_regex is None assert config.allow_origin_regex is None
@ -21,8 +21,8 @@ def test_cors_spec_defaults():
assert config.max_age == 600 assert config.max_age == 600
def test_cors_spec_explicit_config(): def test_cors_config_explicit_config():
config = CORSSpec( config = CORSConfig(
allow_origins=["https://example.com"], allow_credentials=True, max_age=3600, allow_methods=["GET", "POST"] allow_origins=["https://example.com"], allow_credentials=True, max_age=3600, allow_methods=["GET", "POST"]
) )
@ -32,19 +32,19 @@ def test_cors_spec_explicit_config():
assert config.allow_methods == ["GET", "POST"] assert config.allow_methods == ["GET", "POST"]
def test_cors_spec_regex(): def test_cors_config_regex():
config = CORSSpec(allow_origins=[], allow_origin_regex=r"https?://localhost:\d+") config = CORSConfig(allow_origins=[], allow_origin_regex=r"https?://localhost:\d+")
assert config.allow_origins == [] assert config.allow_origins == []
assert config.allow_origin_regex == r"https?://localhost:\d+" assert config.allow_origin_regex == r"https?://localhost:\d+"
def test_cors_spec_wildcard_credentials_error(): def test_cors_config_wildcard_credentials_error():
with pytest.raises(ValueError, match="CORS: allow_credentials=True requires explicit origins"): with pytest.raises(ValueError, match="Cannot use wildcard origins with credentials enabled"):
CORSSpec(allow_origins=["*"], allow_credentials=True) CORSConfig(allow_origins=["*"], allow_credentials=True)
with pytest.raises(ValueError, match="CORS: allow_credentials=True requires explicit origins"): with pytest.raises(ValueError, match="Cannot use wildcard origins with credentials enabled"):
CORSSpec(allow_origins=["https://example.com", "*"], allow_credentials=True) CORSConfig(allow_origins=["https://example.com", "*"], allow_credentials=True)
def test_process_cors_config_false(): def test_process_cors_config_false():
@ -55,29 +55,29 @@ def test_process_cors_config_false():
def test_process_cors_config_true(): def test_process_cors_config_true():
result = process_cors_config(True) result = process_cors_config(True)
assert isinstance(result, CORSSpec) assert isinstance(result, CORSConfig)
assert result.allow_origins == [] assert result.allow_origins == []
assert result.allow_origin_regex == r"https?://localhost:\d+" assert result.allow_origin_regex == r"https?://localhost:\d+"
assert result.allow_credentials is False assert result.allow_credentials is False
assert "GET" in result.allow_methods expected_methods = ["GET", "POST", "PUT", "DELETE", "OPTIONS"]
assert "POST" in result.allow_methods for method in expected_methods:
assert "OPTIONS" in result.allow_methods assert method in result.allow_methods
def test_process_cors_config_passthrough(): def test_process_cors_config_passthrough():
original = CORSSpec(allow_origins=["https://example.com"], allow_methods=["GET"]) original = CORSConfig(allow_origins=["https://example.com"], allow_methods=["GET"])
result = process_cors_config(original) result = process_cors_config(original)
assert result is original assert result is original
def test_process_cors_config_invalid_type(): def test_process_cors_config_invalid_type():
with pytest.raises(ValueError, match="Invalid CORS configuration type"): with pytest.raises(ValueError, match="Expected bool or CORSConfig, got str"):
process_cors_config("invalid") process_cors_config("invalid")
def test_cors_spec_model_dump(): def test_cors_config_model_dump():
cors_spec = CORSSpec( cors_config = CORSConfig(
allow_origins=["https://example.com"], allow_origins=["https://example.com"],
allow_methods=["GET", "POST"], allow_methods=["GET", "POST"],
allow_headers=["Content-Type"], allow_headers=["Content-Type"],
@ -85,7 +85,7 @@ def test_cors_spec_model_dump():
max_age=3600, max_age=3600,
) )
config_dict = cors_spec.model_dump() config_dict = cors_config.model_dump()
assert config_dict["allow_origins"] == ["https://example.com"] assert config_dict["allow_origins"] == ["https://example.com"]
assert config_dict["allow_methods"] == ["GET", "POST"] assert config_dict["allow_methods"] == ["GET", "POST"]