mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-12-17 18:12:43 +00:00
Renamed and addressed PR comments
This commit is contained in:
parent
815b5c7279
commit
17f488036f
2 changed files with 31 additions and 43 deletions
|
|
@ -318,9 +318,7 @@ class QuotaConfig(BaseModel):
|
||||||
period: QuotaPeriod = Field(default=QuotaPeriod.DAY, description="Quota period to set")
|
period: QuotaPeriod = Field(default=QuotaPeriod.DAY, description="Quota period to set")
|
||||||
|
|
||||||
|
|
||||||
class CORSSpec(BaseModel):
|
class CORSConfig(BaseModel):
|
||||||
"""CORS configuration with strict defaults (minimal permissions)."""
|
|
||||||
|
|
||||||
allow_origins: list[str] = Field(default_factory=list)
|
allow_origins: list[str] = Field(default_factory=list)
|
||||||
allow_origin_regex: str | None = Field(default=None)
|
allow_origin_regex: str | None = Field(default=None)
|
||||||
allow_methods: list[str] = Field(default=["OPTIONS"])
|
allow_methods: list[str] = Field(default=["OPTIONS"])
|
||||||
|
|
@ -330,39 +328,29 @@ class CORSSpec(BaseModel):
|
||||||
max_age: int = Field(default=600, ge=0)
|
max_age: int = Field(default=600, ge=0)
|
||||||
|
|
||||||
@model_validator(mode="after")
|
@model_validator(mode="after")
|
||||||
def _validate_credentials_with_wildcard(self) -> Self:
|
def validate_credentials_config(self) -> Self:
|
||||||
if self.allow_credentials and (self.allow_origins == ["*"] or "*" in self.allow_origins):
|
if self.allow_credentials and (self.allow_origins == ["*"] or "*" in self.allow_origins):
|
||||||
raise ValueError("CORS: allow_credentials=True requires explicit origins")
|
raise ValueError("Cannot use wildcard origins with credentials enabled")
|
||||||
return self
|
return self
|
||||||
|
|
||||||
|
|
||||||
# Union type for flexible CORS configuration input
|
def process_cors_config(cors_config: bool | CORSConfig | None) -> CORSConfig | None:
|
||||||
# Accepts: bool (dev shortcuts) or CORSSpec (explicit config)
|
if cors_config is False or cors_config is None:
|
||||||
CORSConfig = bool | CORSSpec
|
|
||||||
|
|
||||||
|
|
||||||
def process_cors_config(cors_config: CORSConfig) -> CORSSpec | None:
|
|
||||||
"""Process CORS config: bool -> dev defaults, CORSSpec -> passthrough."""
|
|
||||||
if cors_config is False:
|
|
||||||
return None
|
return None
|
||||||
|
|
||||||
if cors_config is True:
|
if cors_config is True:
|
||||||
# Dev mode: localhost with any port
|
# dev mode: allow localhost on any port
|
||||||
return CORSSpec(
|
return CORSConfig(
|
||||||
allow_origins=[],
|
allow_origins=[],
|
||||||
allow_origin_regex=r"https?://localhost:\d+",
|
allow_origin_regex=r"https?://localhost:\d+",
|
||||||
allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"],
|
allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"],
|
||||||
allow_headers=["Content-Type", "Authorization", "X-Requested-With"],
|
allow_headers=["Content-Type", "Authorization", "X-Requested-With"],
|
||||||
allow_credentials=False,
|
|
||||||
expose_headers=[],
|
|
||||||
max_age=600,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
elif isinstance(cors_config, CORSSpec):
|
if isinstance(cors_config, CORSConfig):
|
||||||
return cors_config
|
return cors_config
|
||||||
|
|
||||||
else:
|
raise ValueError(f"Expected bool or CORSConfig, got {type(cors_config).__name__}")
|
||||||
raise ValueError(f"Invalid CORS configuration type: {type(cors_config)}")
|
|
||||||
|
|
||||||
|
|
||||||
class ServerConfig(BaseModel):
|
class ServerConfig(BaseModel):
|
||||||
|
|
@ -396,7 +384,7 @@ class ServerConfig(BaseModel):
|
||||||
default=None,
|
default=None,
|
||||||
description="Per client quota request configuration",
|
description="Per client quota request configuration",
|
||||||
)
|
)
|
||||||
cors: CORSConfig | None = Field(
|
cors: bool | CORSConfig | None = Field(
|
||||||
default=None,
|
default=None,
|
||||||
description="CORS configuration for cross-origin requests. Can be:\n"
|
description="CORS configuration for cross-origin requests. Can be:\n"
|
||||||
"- true: Enable localhost CORS for development\n"
|
"- true: Enable localhost CORS for development\n"
|
||||||
|
|
|
||||||
|
|
@ -6,11 +6,11 @@
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
|
||||||
from llama_stack.core.datatypes import CORSSpec, process_cors_config
|
from llama_stack.core.datatypes import CORSConfig, process_cors_config
|
||||||
|
|
||||||
|
|
||||||
def test_cors_spec_defaults():
|
def test_cors_config_defaults():
|
||||||
config = CORSSpec()
|
config = CORSConfig()
|
||||||
|
|
||||||
assert config.allow_origins == []
|
assert config.allow_origins == []
|
||||||
assert config.allow_origin_regex is None
|
assert config.allow_origin_regex is None
|
||||||
|
|
@ -21,8 +21,8 @@ def test_cors_spec_defaults():
|
||||||
assert config.max_age == 600
|
assert config.max_age == 600
|
||||||
|
|
||||||
|
|
||||||
def test_cors_spec_explicit_config():
|
def test_cors_config_explicit_config():
|
||||||
config = CORSSpec(
|
config = CORSConfig(
|
||||||
allow_origins=["https://example.com"], allow_credentials=True, max_age=3600, allow_methods=["GET", "POST"]
|
allow_origins=["https://example.com"], allow_credentials=True, max_age=3600, allow_methods=["GET", "POST"]
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -32,19 +32,19 @@ def test_cors_spec_explicit_config():
|
||||||
assert config.allow_methods == ["GET", "POST"]
|
assert config.allow_methods == ["GET", "POST"]
|
||||||
|
|
||||||
|
|
||||||
def test_cors_spec_regex():
|
def test_cors_config_regex():
|
||||||
config = CORSSpec(allow_origins=[], allow_origin_regex=r"https?://localhost:\d+")
|
config = CORSConfig(allow_origins=[], allow_origin_regex=r"https?://localhost:\d+")
|
||||||
|
|
||||||
assert config.allow_origins == []
|
assert config.allow_origins == []
|
||||||
assert config.allow_origin_regex == r"https?://localhost:\d+"
|
assert config.allow_origin_regex == r"https?://localhost:\d+"
|
||||||
|
|
||||||
|
|
||||||
def test_cors_spec_wildcard_credentials_error():
|
def test_cors_config_wildcard_credentials_error():
|
||||||
with pytest.raises(ValueError, match="CORS: allow_credentials=True requires explicit origins"):
|
with pytest.raises(ValueError, match="Cannot use wildcard origins with credentials enabled"):
|
||||||
CORSSpec(allow_origins=["*"], allow_credentials=True)
|
CORSConfig(allow_origins=["*"], allow_credentials=True)
|
||||||
|
|
||||||
with pytest.raises(ValueError, match="CORS: allow_credentials=True requires explicit origins"):
|
with pytest.raises(ValueError, match="Cannot use wildcard origins with credentials enabled"):
|
||||||
CORSSpec(allow_origins=["https://example.com", "*"], allow_credentials=True)
|
CORSConfig(allow_origins=["https://example.com", "*"], allow_credentials=True)
|
||||||
|
|
||||||
|
|
||||||
def test_process_cors_config_false():
|
def test_process_cors_config_false():
|
||||||
|
|
@ -55,29 +55,29 @@ def test_process_cors_config_false():
|
||||||
def test_process_cors_config_true():
|
def test_process_cors_config_true():
|
||||||
result = process_cors_config(True)
|
result = process_cors_config(True)
|
||||||
|
|
||||||
assert isinstance(result, CORSSpec)
|
assert isinstance(result, CORSConfig)
|
||||||
assert result.allow_origins == []
|
assert result.allow_origins == []
|
||||||
assert result.allow_origin_regex == r"https?://localhost:\d+"
|
assert result.allow_origin_regex == r"https?://localhost:\d+"
|
||||||
assert result.allow_credentials is False
|
assert result.allow_credentials is False
|
||||||
assert "GET" in result.allow_methods
|
expected_methods = ["GET", "POST", "PUT", "DELETE", "OPTIONS"]
|
||||||
assert "POST" in result.allow_methods
|
for method in expected_methods:
|
||||||
assert "OPTIONS" in result.allow_methods
|
assert method in result.allow_methods
|
||||||
|
|
||||||
|
|
||||||
def test_process_cors_config_passthrough():
|
def test_process_cors_config_passthrough():
|
||||||
original = CORSSpec(allow_origins=["https://example.com"], allow_methods=["GET"])
|
original = CORSConfig(allow_origins=["https://example.com"], allow_methods=["GET"])
|
||||||
result = process_cors_config(original)
|
result = process_cors_config(original)
|
||||||
|
|
||||||
assert result is original
|
assert result is original
|
||||||
|
|
||||||
|
|
||||||
def test_process_cors_config_invalid_type():
|
def test_process_cors_config_invalid_type():
|
||||||
with pytest.raises(ValueError, match="Invalid CORS configuration type"):
|
with pytest.raises(ValueError, match="Expected bool or CORSConfig, got str"):
|
||||||
process_cors_config("invalid")
|
process_cors_config("invalid")
|
||||||
|
|
||||||
|
|
||||||
def test_cors_spec_model_dump():
|
def test_cors_config_model_dump():
|
||||||
cors_spec = CORSSpec(
|
cors_config = CORSConfig(
|
||||||
allow_origins=["https://example.com"],
|
allow_origins=["https://example.com"],
|
||||||
allow_methods=["GET", "POST"],
|
allow_methods=["GET", "POST"],
|
||||||
allow_headers=["Content-Type"],
|
allow_headers=["Content-Type"],
|
||||||
|
|
@ -85,7 +85,7 @@ def test_cors_spec_model_dump():
|
||||||
max_age=3600,
|
max_age=3600,
|
||||||
)
|
)
|
||||||
|
|
||||||
config_dict = cors_spec.model_dump()
|
config_dict = cors_config.model_dump()
|
||||||
|
|
||||||
assert config_dict["allow_origins"] == ["https://example.com"]
|
assert config_dict["allow_origins"] == ["https://example.com"]
|
||||||
assert config_dict["allow_methods"] == ["GET", "POST"]
|
assert config_dict["allow_methods"] == ["GET", "POST"]
|
||||||
|
|
|
||||||
Loading…
Add table
Add a link
Reference in a new issue