Merge branch 'main' into session-manager

This commit is contained in:
Francisco Arceo 2025-08-21 18:44:02 -06:00 committed by GitHub
commit d0c5e07f8a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
80 changed files with 529 additions and 173 deletions

View file

@ -18,7 +18,7 @@ on:
- '.github/workflows/integration-auth-tests.yml' # This workflow
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
group: ${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && github.run_id || github.ref }}
cancel-in-progress: true
jobs:

View file

@ -16,7 +16,7 @@ on:
- '.github/workflows/integration-sql-store-tests.yml' # This workflow
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
group: ${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && github.run_id || github.ref }}
cancel-in-progress: true
jobs:

View file

@ -8,7 +8,7 @@ on:
branches: [main]
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
group: ${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && github.run_id || github.ref }}
cancel-in-progress: true
jobs:

View file

@ -26,7 +26,7 @@ on:
- 'pyproject.toml'
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
group: ${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && github.run_id || github.ref }}
cancel-in-progress: true
jobs:
@ -106,6 +106,10 @@ jobs:
- name: Inspect the container image entrypoint
run: |
IMAGE_ID=$(docker images --format "{{.Repository}}:{{.Tag}}" | head -n 1)
if [ -z "$IMAGE_ID" ]; then
echo "No image found"
exit 1
fi
entrypoint=$(docker inspect --format '{{ .Config.Entrypoint }}' $IMAGE_ID)
echo "Entrypoint: $entrypoint"
if [ "$entrypoint" != "[python -m llama_stack.core.server.server /app/run.yaml]" ]; then
@ -140,6 +144,10 @@ jobs:
- name: Inspect UBI9 image
run: |
IMAGE_ID=$(docker images --format "{{.Repository}}:{{.Tag}}" | head -n 1)
if [ -z "$IMAGE_ID" ]; then
echo "No image found"
exit 1
fi
entrypoint=$(docker inspect --format '{{ .Config.Entrypoint }}' $IMAGE_ID)
echo "Entrypoint: $entrypoint"
if [ "$entrypoint" != "[python -m llama_stack.core.server.server /app/run.yaml]" ]; then

View file

@ -13,7 +13,7 @@ on:
workflow_dispatch:
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
group: ${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && github.run_id || github.ref }}
cancel-in-progress: true
jobs:

View file

@ -18,7 +18,7 @@ on:
workflow_dispatch:
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
group: ${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && github.run_id || github.ref }}
cancel-in-progress: true
jobs:

View file

@ -27,7 +27,7 @@ on:
- '.github/workflows/update-readthedocs.yml'
concurrency:
group: ${{ github.workflow }}-${{ github.ref }}
group: ${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && github.run_id || github.ref }}
cancel-in-progress: true
jobs:

View file

@ -225,8 +225,32 @@ server:
port: 8321 # Port to listen on (default: 8321)
tls_certfile: "/path/to/cert.pem" # Optional: Path to TLS certificate for HTTPS
tls_keyfile: "/path/to/key.pem" # Optional: Path to TLS key for HTTPS
cors: true # Optional: Enable CORS (dev mode) or full config object
```
### CORS Configuration
CORS (Cross-Origin Resource Sharing) can be configured in two ways:
**Local development** (allows localhost origins only):
```yaml
server:
cors: true
```
**Explicit configuration** (custom origins and settings):
```yaml
server:
cors:
allow_origins: ["https://myapp.com", "https://app.example.com"]
allow_methods: ["GET", "POST", "PUT", "DELETE"]
allow_headers: ["Content-Type", "Authorization"]
allow_credentials: true
max_age: 3600
```
When `cors: true`, the server enables secure localhost-only access for local development. For production, specify exact origins to maintain security.
### Authentication Configuration
> **Breaking Change (v0.2.14)**: The authentication configuration structure has changed. The previous format with `provider_type` and `config` fields has been replaced with a unified `provider_config` field that includes the `type` field. Update your configuration files accordingly.
@ -618,6 +642,54 @@ Content-Type: application/json
}
```
### CORS Configuration
Configure CORS to allow web browsers to make requests from different domains. Disabled by default.
#### Quick Setup
For development, use the simple boolean flag:
```yaml
server:
cors: true # Auto-enables localhost with any port
```
This automatically allows `http://localhost:*` and `https://localhost:*` with secure defaults.
#### Custom Configuration
For specific origins and full control:
```yaml
server:
cors:
allow_origins: ["https://myapp.com", "https://staging.myapp.com"]
allow_credentials: true
allow_methods: ["GET", "POST", "PUT", "DELETE"]
allow_headers: ["Content-Type", "Authorization"]
allow_origin_regex: "https://.*\\.example\\.com" # Optional regex pattern
expose_headers: ["X-Total-Count"]
max_age: 86400
```
#### Configuration Options
| Field | Description | Default |
| -------------------- | ---------------------------------------------- | ------- |
| `allow_origins` | List of allowed origins. Use `["*"]` for any. | `["*"]` |
| `allow_origin_regex` | Regex pattern for allowed origins (optional). | `None` |
| `allow_methods` | Allowed HTTP methods. | `["*"]` |
| `allow_headers` | Allowed headers. | `["*"]` |
| `allow_credentials` | Allow credentials (cookies, auth headers). | `false` |
| `expose_headers` | Headers exposed to browser. | `[]` |
| `max_age` | Preflight cache time (seconds). | `600` |
**Security Notes**:
- `allow_credentials: true` requires explicit origins (no wildcards)
- `cors: true` enables localhost access only (secure for development)
- For public APIs, always specify exact allowed origins
## Extending to handle Safety
Configuring Safety can be a little involved so it is instructive to go through an example.

View file

@ -17,7 +17,6 @@ client = LlamaStackAsLibraryClient(
# provider_data is optional, but if you need to pass in any provider specific data, you can do so here.
provider_data={"tavily_search_api_key": os.environ["TAVILY_SEARCH_API_KEY"]},
)
client.initialize()
```
This will parse your config and set up any inline implementations and remote clients needed for your implementation.
@ -32,5 +31,4 @@ If you've created a [custom distribution](https://llama-stack.readthedocs.io/en/
```python
client = LlamaStackAsLibraryClient(config_path)
client.initialize()
```

View file

@ -15,7 +15,7 @@ from llama_stack.log import get_logger
REPO_ROOT = Path(__file__).parent.parent.parent.parent
logger = get_logger(name=__name__, category="server")
logger = get_logger(name=__name__, category="cli")
class StackRun(Subcommand):

View file

@ -318,6 +318,41 @@ class QuotaConfig(BaseModel):
period: QuotaPeriod = Field(default=QuotaPeriod.DAY, description="Quota period to set")
class CORSConfig(BaseModel):
allow_origins: list[str] = Field(default_factory=list)
allow_origin_regex: str | None = Field(default=None)
allow_methods: list[str] = Field(default=["OPTIONS"])
allow_headers: list[str] = Field(default_factory=list)
allow_credentials: bool = Field(default=False)
expose_headers: list[str] = Field(default_factory=list)
max_age: int = Field(default=600, ge=0)
@model_validator(mode="after")
def validate_credentials_config(self) -> Self:
if self.allow_credentials and (self.allow_origins == ["*"] or "*" in self.allow_origins):
raise ValueError("Cannot use wildcard origins with credentials enabled")
return self
def process_cors_config(cors_config: bool | CORSConfig | None) -> CORSConfig | None:
if cors_config is False or cors_config is None:
return None
if cors_config is True:
# dev mode: allow localhost on any port
return CORSConfig(
allow_origins=[],
allow_origin_regex=r"https?://localhost:\d+",
allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"],
allow_headers=["Content-Type", "Authorization", "X-Requested-With"],
)
if isinstance(cors_config, CORSConfig):
return cors_config
raise ValueError(f"Expected bool or CORSConfig, got {type(cors_config).__name__}")
class ServerConfig(BaseModel):
port: int = Field(
default=8321,
@ -349,6 +384,12 @@ class ServerConfig(BaseModel):
default=None,
description="Per client quota request configuration",
)
cors: bool | CORSConfig | None = Field(
default=None,
description="CORS configuration for cross-origin requests. Can be:\n"
"- true: Enable localhost CORS for development\n"
"- {allow_origins: [...], allow_methods: [...], ...}: Full configuration",
)
class StackRunConfig(BaseModel):

View file

@ -146,39 +146,26 @@ class LlamaStackAsLibraryClient(LlamaStackClient):
):
super().__init__()
self.async_client = AsyncLlamaStackAsLibraryClient(
config_path_or_distro_name, custom_provider_registry, provider_data
config_path_or_distro_name, custom_provider_registry, provider_data, skip_logger_removal
)
self.pool_executor = ThreadPoolExecutor(max_workers=4)
self.skip_logger_removal = skip_logger_removal
self.provider_data = provider_data
self.loop = asyncio.new_event_loop()
def initialize(self):
if in_notebook():
import nest_asyncio
nest_asyncio.apply()
if not self.skip_logger_removal:
self._remove_root_logger_handlers()
# use a new event loop to avoid interfering with the main event loop
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
try:
return loop.run_until_complete(self.async_client.initialize())
loop.run_until_complete(self.async_client.initialize())
finally:
asyncio.set_event_loop(None)
def _remove_root_logger_handlers(self):
def initialize(self):
"""
Remove all handlers from the root logger. Needed to avoid polluting the console with logs.
Deprecated method for backward compatibility.
"""
root_logger = logging.getLogger()
for handler in root_logger.handlers[:]:
root_logger.removeHandler(handler)
logger.info(f"Removed handler {handler.__class__.__name__} from root logger")
pass
def request(self, *args, **kwargs):
loop = self.loop
@ -216,6 +203,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
config_path_or_distro_name: str,
custom_provider_registry: ProviderRegistry | None = None,
provider_data: dict[str, Any] | None = None,
skip_logger_removal: bool = False,
):
super().__init__()
# when using the library client, we should not log to console since many
@ -223,6 +211,13 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
current_sinks = os.environ.get("TELEMETRY_SINKS", "sqlite").split(",")
os.environ["TELEMETRY_SINKS"] = ",".join(sink for sink in current_sinks if sink != "console")
if in_notebook():
import nest_asyncio
nest_asyncio.apply()
if not skip_logger_removal:
self._remove_root_logger_handlers()
if config_path_or_distro_name.endswith(".yaml"):
config_path = Path(config_path_or_distro_name)
if not config_path.exists():
@ -239,7 +234,24 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
self.provider_data = provider_data
self.route_impls: RouteImpls | None = None # Initialize to None to prevent AttributeError
def _remove_root_logger_handlers(self):
"""
Remove all handlers from the root logger. Needed to avoid polluting the console with logs.
"""
root_logger = logging.getLogger()
for handler in root_logger.handlers[:]:
root_logger.removeHandler(handler)
logger.info(f"Removed handler {handler.__class__.__name__} from root logger")
async def initialize(self) -> bool:
"""
Initialize the async client.
Returns:
bool: True if initialization was successful
"""
try:
self.route_impls = None
self.impls = await construct_stack(self.config, self.custom_provider_registry)

View file

@ -12,7 +12,7 @@ from llama_stack.apis.datasets import DatasetPurpose, DataSource
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import RoutingTable
logger = get_logger(name=__name__, category="core")
logger = get_logger(name=__name__, category="core::routers")
class DatasetIORouter(DatasetIO):

View file

@ -16,7 +16,7 @@ from llama_stack.apis.scoring import (
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import RoutingTable
logger = get_logger(name=__name__, category="core")
logger = get_logger(name=__name__, category="core::routers")
class ScoringRouter(Scoring):

View file

@ -65,7 +65,7 @@ from llama_stack.providers.datatypes import HealthResponse, HealthStatus, Routin
from llama_stack.providers.utils.inference.inference_store import InferenceStore
from llama_stack.providers.utils.telemetry.tracing import get_current_span
logger = get_logger(name=__name__, category="inference")
logger = get_logger(name=__name__, category="core::routers")
class InferenceRouter(Inference):

View file

@ -13,7 +13,7 @@ from llama_stack.apis.shields import Shield
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import RoutingTable
logger = get_logger(name=__name__, category="core")
logger = get_logger(name=__name__, category="core::routers")
class SafetyRouter(Safety):

View file

@ -22,7 +22,7 @@ from llama_stack.log import get_logger
from ..routing_tables.toolgroups import ToolGroupsRoutingTable
logger = get_logger(name=__name__, category="core")
logger = get_logger(name=__name__, category="core::routers")
class ToolRuntimeRouter(ToolRuntime):

View file

@ -30,7 +30,7 @@ from llama_stack.apis.vector_io import (
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import HealthResponse, HealthStatus, RoutingTable
logger = get_logger(name=__name__, category="core")
logger = get_logger(name=__name__, category="core::routers")
class VectorIORouter(VectorIO):

View file

@ -14,7 +14,7 @@ from llama_stack.log import get_logger
from .common import CommonRoutingTableImpl
logger = get_logger(name=__name__, category="core")
logger = get_logger(name=__name__, category="core::routing_tables")
class BenchmarksRoutingTable(CommonRoutingTableImpl, Benchmarks):

View file

@ -23,7 +23,7 @@ from llama_stack.core.store import DistributionRegistry
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import Api, RoutingTable
logger = get_logger(name=__name__, category="core")
logger = get_logger(name=__name__, category="core::routing_tables")
def get_impl_api(p: Any) -> Api:

View file

@ -26,7 +26,7 @@ from llama_stack.log import get_logger
from .common import CommonRoutingTableImpl
logger = get_logger(name=__name__, category="core")
logger = get_logger(name=__name__, category="core::routing_tables")
class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):

View file

@ -17,7 +17,7 @@ from llama_stack.log import get_logger
from .common import CommonRoutingTableImpl, lookup_model
logger = get_logger(name=__name__, category="core")
logger = get_logger(name=__name__, category="core::routing_tables")
class ModelsRoutingTable(CommonRoutingTableImpl, Models):

View file

@ -19,7 +19,7 @@ from llama_stack.log import get_logger
from .common import CommonRoutingTableImpl
logger = get_logger(name=__name__, category="core")
logger = get_logger(name=__name__, category="core::routing_tables")
class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions):

View file

@ -15,7 +15,7 @@ from llama_stack.log import get_logger
from .common import CommonRoutingTableImpl
logger = get_logger(name=__name__, category="core")
logger = get_logger(name=__name__, category="core::routing_tables")
class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):

View file

@ -14,7 +14,7 @@ from llama_stack.log import get_logger
from .common import CommonRoutingTableImpl
logger = get_logger(name=__name__, category="core")
logger = get_logger(name=__name__, category="core::routing_tables")
def parse_toolgroup_from_toolgroup_name_pair(toolgroup_name_with_maybe_tool_name: str) -> str | None:

View file

@ -30,7 +30,7 @@ from llama_stack.log import get_logger
from .common import CommonRoutingTableImpl, lookup_model
logger = get_logger(name=__name__, category="core")
logger = get_logger(name=__name__, category="core::routing_tables")
class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):

View file

@ -15,7 +15,7 @@ from llama_stack.core.server.auth_providers import create_auth_provider
from llama_stack.core.server.routes import find_matching_route, initialize_route_impls
from llama_stack.log import get_logger
logger = get_logger(name=__name__, category="auth")
logger = get_logger(name=__name__, category="core::auth")
class AuthenticationMiddleware:

View file

@ -23,7 +23,7 @@ from llama_stack.core.datatypes import (
)
from llama_stack.log import get_logger
logger = get_logger(name=__name__, category="auth")
logger = get_logger(name=__name__, category="core::auth")
class AuthResponse(BaseModel):

View file

@ -15,7 +15,7 @@ from llama_stack.providers.utils.kvstore.api import KVStore
from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig
from llama_stack.providers.utils.kvstore.kvstore import kvstore_impl
logger = get_logger(name=__name__, category="quota")
logger = get_logger(name=__name__, category="core::server")
class QuotaMiddleware:

View file

@ -28,6 +28,7 @@ from aiohttp import hdrs
from fastapi import Body, FastAPI, HTTPException, Request, Response
from fastapi import Path as FastapiPath
from fastapi.exceptions import RequestValidationError
from fastapi.middleware.cors import CORSMiddleware
from fastapi.responses import JSONResponse, StreamingResponse
from openai import BadRequestError
from pydantic import BaseModel, ValidationError
@ -40,6 +41,7 @@ from llama_stack.core.datatypes import (
AuthenticationRequiredError,
LoggingConfig,
StackRunConfig,
process_cors_config,
)
from llama_stack.core.distribution import builtin_automatically_routed_apis
from llama_stack.core.external import ExternalApiSpec, load_external_apis
@ -82,7 +84,7 @@ from .quota import QuotaMiddleware
REPO_ROOT = Path(__file__).parent.parent.parent.parent
logger = get_logger(name=__name__, category="server")
logger = get_logger(name=__name__, category="core::server")
def warn_with_traceback(message, category, filename, lineno, file=None, line=None):
@ -413,7 +415,7 @@ def main(args: argparse.Namespace | None = None):
config_contents = yaml.safe_load(fp)
if isinstance(config_contents, dict) and (cfg := config_contents.get("logging_config")):
logger_config = LoggingConfig(**cfg)
logger = get_logger(name=__name__, category="server", config=logger_config)
logger = get_logger(name=__name__, category="core::server", config=logger_config)
if args.env:
for env_pair in args.env:
try:
@ -483,6 +485,12 @@ def main(args: argparse.Namespace | None = None):
window_seconds=window_seconds,
)
if config.server.cors:
logger.info("Enabling CORS")
cors_config = process_cors_config(config.server.cors)
if cors_config:
app.add_middleware(CORSMiddleware, **cors_config.model_dump())
if Api.telemetry in impls:
setup_logger(impls[Api.telemetry])
else:

View file

@ -16,7 +16,7 @@ from llama_stack.log import get_logger
from llama_stack.providers.utils.kvstore import KVStore, kvstore_impl
from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig
logger = get_logger(__name__, category="core")
logger = get_logger(__name__, category="core::registry")
class DistributionRegistry(Protocol):

View file

@ -10,7 +10,7 @@ from pathlib import Path
from llama_stack.core.utils.config_dirs import DISTRIBS_BASE_DIR
from llama_stack.log import get_logger
logger = get_logger(name=__name__, category="config_resolution")
logger = get_logger(name=__name__, category="core")
DISTRO_DIR = Path(__file__).parent.parent.parent.parent / "llama_stack" / "distributions"

View file

@ -36,7 +36,7 @@ from .utils import get_negative_inf_value, to_2tuple
MP_SCALE = 8
logger = get_logger(name=__name__, category="models")
logger = get_logger(name=__name__, category="models::llama")
def reduce_from_tensor_model_parallel_region(input_):

View file

@ -11,7 +11,7 @@ from llama_stack.log import get_logger
from ..datatypes import BuiltinTool, RecursiveType, ToolCall, ToolPromptFormat
logger = get_logger(name=__name__, category="inference")
logger = get_logger(name=__name__, category="models::llama")
BUILTIN_TOOL_PATTERN = r'\b(?P<tool_name>\w+)\.call\(query="(?P<query>[^"]*)"\)'
CUSTOM_TOOL_CALL_PATTERN = re.compile(r"<function=(?P<function_name>[^}]+)>(?P<args>{.*?})")

View file

@ -18,7 +18,7 @@ from ...datatypes import QuantizationMode
from ..model import Transformer, TransformerBlock
from ..moe import MoE
log = get_logger(name=__name__, category="models")
log = get_logger(name=__name__, category="models::llama")
def swiglu_wrapper_no_reduce(

View file

@ -9,7 +9,7 @@ import collections
from llama_stack.log import get_logger
log = get_logger(name=__name__, category="llama")
log = get_logger(name=__name__, category="models::llama")
try:
import fbgemm_gpu.experimental.gen_ai # noqa: F401

View file

@ -84,7 +84,7 @@ MEMORY_QUERY_TOOL = "knowledge_search"
WEB_SEARCH_TOOL = "web_search"
RAG_TOOL_GROUP = "builtin::rag"
logger = get_logger(name=__name__, category="agents")
logger = get_logger(name=__name__, category="agents::meta_reference")
class ChatAgent(ShieldRunnerMixin):

View file

@ -51,7 +51,7 @@ from .config import MetaReferenceAgentsImplConfig
from .persistence import AgentInfo
from .responses.openai_responses import OpenAIResponsesImpl
logger = get_logger(name=__name__, category="agents")
logger = get_logger(name=__name__, category="agents::meta_reference")
class MetaReferenceAgentsImpl(Agents):

View file

@ -17,7 +17,7 @@ from llama_stack.core.request_headers import get_authenticated_user
from llama_stack.log import get_logger
from llama_stack.providers.utils.kvstore import KVStore
log = get_logger(name=__name__, category="agents")
log = get_logger(name=__name__, category="agents::meta_reference")
class AgentSessionInfo(Session):

View file

@ -41,7 +41,7 @@ from .utils import (
convert_response_text_to_chat_response_format,
)
logger = get_logger(name=__name__, category="responses")
logger = get_logger(name=__name__, category="openai::responses")
class OpenAIResponsePreviousResponseWithInputItems(BaseModel):

View file

@ -47,7 +47,7 @@ from llama_stack.log import get_logger
from .types import ChatCompletionContext, ChatCompletionResult
from .utils import convert_chat_choice_to_response_message, is_function_tool_call
logger = get_logger(name=__name__, category="responses")
logger = get_logger(name=__name__, category="agents::meta_reference")
class StreamingResponseOrchestrator:

View file

@ -38,7 +38,7 @@ from llama_stack.log import get_logger
from .types import ChatCompletionContext, ToolExecutionResult
logger = get_logger(name=__name__, category="responses")
logger = get_logger(name=__name__, category="agents::meta_reference")
class ToolExecutor:

View file

@ -11,7 +11,7 @@ from llama_stack.apis.safety import Safety, SafetyViolation, ViolationLevel
from llama_stack.log import get_logger
from llama_stack.providers.utils.telemetry import tracing
log = get_logger(name=__name__, category="agents")
log = get_logger(name=__name__, category="agents::meta_reference")
class SafetyException(Exception): # noqa: N818

View file

@ -65,7 +65,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
from .config import FireworksImplConfig
from .models import MODEL_ENTRIES
logger = get_logger(name=__name__, category="inference")
logger = get_logger(name=__name__, category="inference::fireworks")
class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData):

View file

@ -10,7 +10,7 @@ from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
from .models import MODEL_ENTRIES
logger = get_logger(name=__name__, category="inference")
logger = get_logger(name=__name__, category="inference::llama_openai_compat")
class LlamaCompatInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):

View file

@ -41,6 +41,11 @@ client.initialize()
### Create Completion
> Note on Completion API
>
> The hosted NVIDIA Llama NIMs (e.g., `meta-llama/Llama-3.1-8B-Instruct`) with ```NVIDIA_BASE_URL="https://integrate.api.nvidia.com"``` does not support the ```completion``` method, while the locally deployed NIM does.
```python
response = client.inference.completion(
model_id="meta-llama/Llama-3.1-8B-Instruct",
@ -76,6 +81,73 @@ response = client.inference.chat_completion(
print(f"Response: {response.completion_message.content}")
```
### Tool Calling Example ###
```python
from llama_stack.models.llama.datatypes import ToolDefinition, ToolParamDefinition
tool_definition = ToolDefinition(
tool_name="get_weather",
description="Get current weather information for a location",
parameters={
"location": ToolParamDefinition(
param_type="string",
description="The city and state, e.g. San Francisco, CA",
required=True,
),
"unit": ToolParamDefinition(
param_type="string",
description="Temperature unit (celsius or fahrenheit)",
required=False,
default="celsius",
),
},
)
tool_response = client.inference.chat_completion(
model_id="meta-llama/Llama-3.1-8B-Instruct",
messages=[{"role": "user", "content": "What's the weather like in San Francisco?"}],
tools=[tool_definition],
)
print(f"Tool Response: {tool_response.completion_message.content}")
if tool_response.completion_message.tool_calls:
for tool_call in tool_response.completion_message.tool_calls:
print(f"Tool Called: {tool_call.tool_name}")
print(f"Arguments: {tool_call.arguments}")
```
### Structured Output Example
```python
from llama_stack.apis.inference import JsonSchemaResponseFormat, ResponseFormatType
person_schema = {
"type": "object",
"properties": {
"name": {"type": "string"},
"age": {"type": "integer"},
"occupation": {"type": "string"},
},
"required": ["name", "age", "occupation"],
}
response_format = JsonSchemaResponseFormat(
type=ResponseFormatType.json_schema, json_schema=person_schema
)
structured_response = client.inference.chat_completion(
model_id="meta-llama/Llama-3.1-8B-Instruct",
messages=[
{
"role": "user",
"content": "Create a profile for a fictional person named Alice who is 30 years old and is a software engineer. ",
}
],
response_format=response_format,
)
print(f"Structured Response: {structured_response.completion_message.content}")
```
### Create Embeddings
> Note on OpenAI embeddings compatibility
>

View file

@ -7,7 +7,7 @@
import warnings
from collections.abc import AsyncIterator
from openai import NOT_GIVEN, APIConnectionError, BadRequestError
from openai import NOT_GIVEN, APIConnectionError
from llama_stack.apis.common.content_types import (
InterleavedContent,
@ -57,7 +57,7 @@ from .openai_utils import (
)
from .utils import _is_nvidia_hosted
logger = get_logger(name=__name__, category="inference")
logger = get_logger(name=__name__, category="inference::nvidia")
class NVIDIAInferenceAdapter(OpenAIMixin, Inference, ModelRegistryHelper):
@ -197,15 +197,11 @@ class NVIDIAInferenceAdapter(OpenAIMixin, Inference, ModelRegistryHelper):
}
extra_body["input_type"] = task_type_options[task_type]
try:
response = await self.client.embeddings.create(
model=provider_model_id,
input=input,
extra_body=extra_body,
)
except BadRequestError as e:
raise ValueError(f"Failed to get embeddings: {e}") from e
response = await self.client.embeddings.create(
model=provider_model_id,
input=input,
extra_body=extra_body,
)
#
# OpenAI: CreateEmbeddingResponse(data=[Embedding(embedding=list[float], ...)], ...)
# ->

View file

@ -10,7 +10,7 @@ from llama_stack.log import get_logger
from . import NVIDIAConfig
logger = get_logger(name=__name__, category="inference")
logger = get_logger(name=__name__, category="inference::nvidia")
def _is_nvidia_hosted(config: NVIDIAConfig) -> bool:

View file

@ -85,7 +85,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
from .models import MODEL_ENTRIES
logger = get_logger(name=__name__, category="inference")
logger = get_logger(name=__name__, category="inference::ollama")
class OllamaInferenceAdapter(

View file

@ -11,7 +11,7 @@ from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
from .config import OpenAIConfig
from .models import MODEL_ENTRIES
logger = get_logger(name=__name__, category="inference")
logger = get_logger(name=__name__, category="inference::openai")
#

View file

@ -58,7 +58,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
from .config import InferenceAPIImplConfig, InferenceEndpointImplConfig, TGIImplConfig
log = get_logger(name=__name__, category="inference")
log = get_logger(name=__name__, category="inference::tgi")
def build_hf_repo_model_entries():

View file

@ -61,7 +61,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
from .config import TogetherImplConfig
from .models import MODEL_ENTRIES
logger = get_logger(name=__name__, category="inference")
logger = get_logger(name=__name__, category="inference::together")
class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData):

View file

@ -85,7 +85,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
from .config import VLLMInferenceAdapterConfig
log = get_logger(name=__name__, category="inference")
log = get_logger(name=__name__, category="inference::vllm")
def build_hf_repo_model_entries():

View file

@ -15,7 +15,7 @@ from llama_stack.providers.remote.post_training.nvidia.config import SFTLoRADefa
from .config import NvidiaPostTrainingConfig
logger = get_logger(name=__name__, category="integration")
logger = get_logger(name=__name__, category="post_training::nvidia")
def warn_unsupported_params(config_dict: Any, supported_keys: set[str], config_name: str) -> None:

View file

@ -21,7 +21,7 @@ from llama_stack.providers.utils.bedrock.client import create_bedrock_client
from .config import BedrockSafetyConfig
logger = get_logger(name=__name__, category="safety")
logger = get_logger(name=__name__, category="safety::bedrock")
class BedrockSafetyAdapter(Safety, ShieldsProtocolPrivate):

View file

@ -9,7 +9,7 @@ from typing import Any
import requests
from llama_stack.apis.inference import Message
from llama_stack.apis.safety import RunShieldResponse, Safety, SafetyViolation, ViolationLevel
from llama_stack.apis.safety import ModerationObject, RunShieldResponse, Safety, SafetyViolation, ViolationLevel
from llama_stack.apis.shields import Shield
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import ShieldsProtocolPrivate
@ -17,7 +17,7 @@ from llama_stack.providers.utils.inference.openai_compat import convert_message_
from .config import NVIDIASafetyConfig
logger = get_logger(name=__name__, category="safety")
logger = get_logger(name=__name__, category="safety::nvidia")
class NVIDIASafetyAdapter(Safety, ShieldsProtocolPrivate):
@ -67,6 +67,9 @@ class NVIDIASafetyAdapter(Safety, ShieldsProtocolPrivate):
self.shield = NeMoGuardrails(self.config, shield.shield_id)
return await self.shield.run(messages)
async def run_moderation(self, input: str | list[str], model: str) -> ModerationObject:
raise NotImplementedError("NVIDIA safety provider currently does not implement run_moderation")
class NeMoGuardrails:
"""

View file

@ -25,7 +25,7 @@ from llama_stack.providers.utils.inference.openai_compat import convert_message_
from .config import SambaNovaSafetyConfig
logger = get_logger(name=__name__, category="safety")
logger = get_logger(name=__name__, category="safety::sambanova")
CANNED_RESPONSE_TEXT = "I can't answer that. Can I help with something else?"

View file

@ -33,7 +33,7 @@ from llama_stack.providers.utils.memory.vector_store import (
from .config import ChromaVectorIOConfig as RemoteChromaVectorIOConfig
log = get_logger(name=__name__, category="vector_io")
log = get_logger(name=__name__, category="vector_io::chroma")
ChromaClientType = chromadb.api.AsyncClientAPI | chromadb.api.ClientAPI

View file

@ -36,7 +36,7 @@ from llama_stack.providers.utils.vector_io.vector_utils import sanitize_collecti
from .config import MilvusVectorIOConfig as RemoteMilvusVectorIOConfig
logger = get_logger(name=__name__, category="vector_io")
logger = get_logger(name=__name__, category="vector_io::milvus")
VERSION = "v3"
VECTOR_DBS_PREFIX = f"vector_dbs:milvus:{VERSION}::"

View file

@ -34,7 +34,7 @@ from llama_stack.providers.utils.memory.vector_store import (
from .config import PGVectorVectorIOConfig
log = get_logger(name=__name__, category="vector_io")
log = get_logger(name=__name__, category="vector_io::pgvector")
VERSION = "v3"
VECTOR_DBS_PREFIX = f"vector_dbs:pgvector:{VERSION}::"

View file

@ -36,7 +36,7 @@ from llama_stack.providers.utils.memory.vector_store import (
from .config import QdrantVectorIOConfig as RemoteQdrantVectorIOConfig
log = get_logger(name=__name__, category="vector_io")
log = get_logger(name=__name__, category="vector_io::qdrant")
CHUNK_ID_KEY = "_chunk_id"
# KV store prefixes for vector databases

View file

@ -34,7 +34,7 @@ from llama_stack.providers.utils.vector_io.vector_utils import sanitize_collecti
from .config import WeaviateVectorIOConfig
log = get_logger(name=__name__, category="vector_io")
log = get_logger(name=__name__, category="vector_io::weaviate")
VERSION = "v3"
VECTOR_DBS_PREFIX = f"vector_dbs:weaviate:{VERSION}::"

View file

@ -28,7 +28,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import interleaved_con
EMBEDDING_MODELS = {}
log = get_logger(name=__name__, category="inference")
log = get_logger(name=__name__, category="providers::utils")
class SentenceTransformerEmbeddingMixin:

View file

@ -54,7 +54,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
interleaved_content_as_str,
)
logger = get_logger(name=__name__, category="inference")
logger = get_logger(name=__name__, category="providers::utils")
class LiteLLMOpenAIMixin(

View file

@ -17,7 +17,7 @@ from llama_stack.providers.utils.inference import (
ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR,
)
logger = get_logger(name=__name__, category="core")
logger = get_logger(name=__name__, category="providers::utils")
class RemoteInferenceProviderConfig(BaseModel):

View file

@ -134,7 +134,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
decode_assistant_message,
)
logger = get_logger(name=__name__, category="inference")
logger = get_logger(name=__name__, category="providers::utils")
class OpenAICompatCompletionChoiceDelta(BaseModel):

View file

@ -25,7 +25,7 @@ from llama_stack.apis.inference import (
from llama_stack.log import get_logger
from llama_stack.providers.utils.inference.openai_compat import prepare_openai_completion_params
logger = get_logger(name=__name__, category="core")
logger = get_logger(name=__name__, category="providers::utils")
class OpenAIMixin(ABC):

View file

@ -58,7 +58,7 @@ from llama_stack.models.llama.sku_list import resolve_model
from llama_stack.models.llama.sku_types import ModelFamily, is_multimodal
from llama_stack.providers.utils.inference import supported_inference_models
log = get_logger(name=__name__, category="inference")
log = get_logger(name=__name__, category="providers::utils")
class ChatCompletionRequestWithRawContent(ChatCompletionRequest):

View file

@ -13,7 +13,7 @@ from llama_stack.providers.utils.kvstore import KVStore
from ..config import MongoDBKVStoreConfig
log = get_logger(name=__name__, category="kvstore")
log = get_logger(name=__name__, category="providers::utils")
class MongoDBKVStoreImpl(KVStore):

View file

@ -14,7 +14,7 @@ from llama_stack.log import get_logger
from ..api import KVStore
from ..config import PostgresKVStoreConfig
log = get_logger(name=__name__, category="kvstore")
log = get_logger(name=__name__, category="providers::utils")
class PostgresKVStoreImpl(KVStore):

View file

@ -44,7 +44,7 @@ from llama_stack.providers.utils.memory.vector_store import (
make_overlapped_chunks,
)
logger = get_logger(name=__name__, category="memory")
logger = get_logger(name=__name__, category="providers::utils")
# Constants for OpenAI vector stores
CHUNK_MULTIPLIER = 5

View file

@ -33,7 +33,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
)
from llama_stack.providers.utils.vector_io.vector_utils import generate_chunk_id
log = get_logger(name=__name__, category="memory")
log = get_logger(name=__name__, category="providers::utils")
class ChunkForDeletion(BaseModel):

View file

@ -17,7 +17,7 @@ from pydantic import BaseModel
from llama_stack.log import get_logger
logger = get_logger(name=__name__, category="scheduler")
logger = get_logger(name=__name__, category="providers::utils")
# TODO: revisit the list of possible statuses when defining a more coherent

View file

@ -17,7 +17,7 @@ from llama_stack.log import get_logger
from .api import ColumnDefinition, ColumnType, PaginatedResponse, SqlStore
from .sqlstore import SqlStoreType
logger = get_logger(name=__name__, category="authorized_sqlstore")
logger = get_logger(name=__name__, category="providers::utils")
# Hardcoded copy of the default policy that our SQL filtering implements
# WARNING: If default_policy() changes, this constant must be updated accordingly

View file

@ -30,7 +30,7 @@ from llama_stack.log import get_logger
from .api import ColumnDefinition, ColumnType, SqlStore
from .sqlstore import SqlAlchemySqlStoreConfig
logger = get_logger(name=__name__, category="sqlstore")
logger = get_logger(name=__name__, category="providers::utils")
TYPE_MAPPING: dict[ColumnType, Any] = {
ColumnType.INTEGER: Integer,

View file

@ -256,9 +256,6 @@ def instantiate_llama_stack_client(session):
provider_data=get_provider_data(),
skip_logger_removal=True,
)
if not client.initialize():
raise RuntimeError("Initialization failed")
return client

View file

@ -55,7 +55,7 @@
#
import pytest
from llama_stack_client import BadRequestError
from llama_stack_client import BadRequestError as LlamaStackBadRequestError
from llama_stack_client.types import EmbeddingsResponse
from llama_stack_client.types.shared.interleaved_content import (
ImageContentItem,
@ -63,6 +63,9 @@ from llama_stack_client.types.shared.interleaved_content import (
ImageContentItemImageURL,
TextContentItem,
)
from openai import BadRequestError as OpenAIBadRequestError
from llama_stack.core.library_client import LlamaStackAsLibraryClient
DUMMY_STRING = "hello"
DUMMY_STRING2 = "world"
@ -203,7 +206,14 @@ def test_embedding_truncation_error(
):
if inference_provider_type not in SUPPORTED_PROVIDERS:
pytest.xfail(f"{inference_provider_type} doesn't support embedding model yet")
with pytest.raises(BadRequestError):
# Using LlamaStackClient from llama_stack_client will raise llama_stack_client.BadRequestError
# While using LlamaStackAsLibraryClient from llama_stack.distribution.library_client will raise the error that the backend raises
error_type = (
OpenAIBadRequestError
if isinstance(llama_stack_client, LlamaStackAsLibraryClient)
else LlamaStackBadRequestError
)
with pytest.raises(error_type):
llama_stack_client.inference.embeddings(
model_id=embedding_model_id,
contents=[DUMMY_LONG_TEXT],
@ -283,7 +293,8 @@ def test_embedding_text_truncation_error(
):
if inference_provider_type not in SUPPORTED_PROVIDERS:
pytest.xfail(f"{inference_provider_type} doesn't support embedding model yet")
with pytest.raises(BadRequestError):
error_type = ValueError if isinstance(llama_stack_client, LlamaStackAsLibraryClient) else LlamaStackBadRequestError
with pytest.raises(error_type):
llama_stack_client.inference.embeddings(
model_id=embedding_model_id,
contents=[DUMMY_STRING],

View file

@ -113,8 +113,6 @@ def openai_client(base_url, api_key, provider):
raise ValueError(f"Invalid config for Llama Stack: {provider}, it must be of the form 'stack:<config>'")
config = parts[1]
client = LlamaStackAsLibraryClient(config, skip_logger_removal=True)
if not client.initialize():
raise RuntimeError("Initialization failed")
return client
return OpenAI(

View file

@ -5,86 +5,121 @@
# the root directory of this source tree.
"""
Unit tests for LlamaStackAsLibraryClient initialization error handling.
Unit tests for LlamaStackAsLibraryClient automatic initialization.
These tests ensure that users get proper error messages when they forget to call
initialize() on the library client, preventing AttributeError regressions.
These tests ensure that the library client is automatically initialized
and ready to use immediately after construction.
"""
import pytest
from llama_stack.core.library_client import (
AsyncLlamaStackAsLibraryClient,
LlamaStackAsLibraryClient,
)
from llama_stack.core.server.routes import RouteImpls
class TestLlamaStackAsLibraryClientInitialization:
"""Test proper error handling for uninitialized library clients."""
class TestLlamaStackAsLibraryClientAutoInitialization:
"""Test automatic initialization of library clients."""
@pytest.mark.parametrize(
"api_call",
[
lambda client: client.models.list(),
lambda client: client.chat.completions.create(model="test", messages=[{"role": "user", "content": "test"}]),
lambda client: next(
client.chat.completions.create(
model="test", messages=[{"role": "user", "content": "test"}], stream=True
)
),
],
ids=["models.list", "chat.completions.create", "chat.completions.create_stream"],
)
def test_sync_client_proper_error_without_initialization(self, api_call):
"""Test that sync client raises ValueError with helpful message when not initialized."""
client = LlamaStackAsLibraryClient("nvidia")
def test_sync_client_auto_initialization(self, monkeypatch):
"""Test that sync client is automatically initialized after construction."""
# Mock the stack construction to avoid dependency issues
mock_impls = {}
mock_route_impls = RouteImpls({})
with pytest.raises(ValueError) as exc_info:
api_call(client)
async def mock_construct_stack(config, custom_provider_registry):
return mock_impls
error_msg = str(exc_info.value)
assert "Client not initialized" in error_msg
assert "Please call initialize() first" in error_msg
def mock_initialize_route_impls(impls):
return mock_route_impls
@pytest.mark.parametrize(
"api_call",
[
lambda client: client.models.list(),
lambda client: client.chat.completions.create(model="test", messages=[{"role": "user", "content": "test"}]),
],
ids=["models.list", "chat.completions.create"],
)
async def test_async_client_proper_error_without_initialization(self, api_call):
"""Test that async client raises ValueError with helpful message when not initialized."""
client = AsyncLlamaStackAsLibraryClient("nvidia")
monkeypatch.setattr("llama_stack.core.library_client.construct_stack", mock_construct_stack)
monkeypatch.setattr("llama_stack.core.library_client.initialize_route_impls", mock_initialize_route_impls)
with pytest.raises(ValueError) as exc_info:
await api_call(client)
client = LlamaStackAsLibraryClient("ci-tests")
error_msg = str(exc_info.value)
assert "Client not initialized" in error_msg
assert "Please call initialize() first" in error_msg
assert client.async_client.route_impls is not None
async def test_async_client_streaming_error_without_initialization(self):
"""Test that async client streaming raises ValueError with helpful message when not initialized."""
client = AsyncLlamaStackAsLibraryClient("nvidia")
async def test_async_client_auto_initialization(self, monkeypatch):
"""Test that async client can be initialized and works properly."""
# Mock the stack construction to avoid dependency issues
mock_impls = {}
mock_route_impls = RouteImpls({})
with pytest.raises(ValueError) as exc_info:
stream = await client.chat.completions.create(
model="test", messages=[{"role": "user", "content": "test"}], stream=True
)
await anext(stream)
async def mock_construct_stack(config, custom_provider_registry):
return mock_impls
error_msg = str(exc_info.value)
assert "Client not initialized" in error_msg
assert "Please call initialize() first" in error_msg
def mock_initialize_route_impls(impls):
return mock_route_impls
def test_route_impls_initialized_to_none(self):
"""Test that route_impls is initialized to None to prevent AttributeError."""
# Test sync client
sync_client = LlamaStackAsLibraryClient("nvidia")
assert sync_client.async_client.route_impls is None
monkeypatch.setattr("llama_stack.core.library_client.construct_stack", mock_construct_stack)
monkeypatch.setattr("llama_stack.core.library_client.initialize_route_impls", mock_initialize_route_impls)
# Test async client directly
async_client = AsyncLlamaStackAsLibraryClient("nvidia")
assert async_client.route_impls is None
client = AsyncLlamaStackAsLibraryClient("ci-tests")
# Initialize the client
result = await client.initialize()
assert result is True
assert client.route_impls is not None
def test_initialize_method_backward_compatibility(self, monkeypatch):
"""Test that initialize() method still works for backward compatibility."""
# Mock the stack construction to avoid dependency issues
mock_impls = {}
mock_route_impls = RouteImpls({})
async def mock_construct_stack(config, custom_provider_registry):
return mock_impls
def mock_initialize_route_impls(impls):
return mock_route_impls
monkeypatch.setattr("llama_stack.core.library_client.construct_stack", mock_construct_stack)
monkeypatch.setattr("llama_stack.core.library_client.initialize_route_impls", mock_initialize_route_impls)
client = LlamaStackAsLibraryClient("ci-tests")
result = client.initialize()
assert result is None
result2 = client.initialize()
assert result2 is None
async def test_async_initialize_method_idempotent(self, monkeypatch):
"""Test that async initialize() method can be called multiple times safely."""
mock_impls = {}
mock_route_impls = RouteImpls({})
async def mock_construct_stack(config, custom_provider_registry):
return mock_impls
def mock_initialize_route_impls(impls):
return mock_route_impls
monkeypatch.setattr("llama_stack.core.library_client.construct_stack", mock_construct_stack)
monkeypatch.setattr("llama_stack.core.library_client.initialize_route_impls", mock_initialize_route_impls)
client = AsyncLlamaStackAsLibraryClient("ci-tests")
result1 = await client.initialize()
assert result1 is True
result2 = await client.initialize()
assert result2 is True
def test_route_impls_automatically_set(self, monkeypatch):
"""Test that route_impls is automatically set during construction."""
mock_impls = {}
mock_route_impls = RouteImpls({})
async def mock_construct_stack(config, custom_provider_registry):
return mock_impls
def mock_initialize_route_impls(impls):
return mock_route_impls
monkeypatch.setattr("llama_stack.core.library_client.construct_stack", mock_construct_stack)
monkeypatch.setattr("llama_stack.core.library_client.initialize_route_impls", mock_initialize_route_impls)
sync_client = LlamaStackAsLibraryClient("ci-tests")
assert sync_client.async_client.route_impls is not None

View file

@ -0,0 +1,105 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import pytest
from llama_stack.core.datatypes import CORSConfig, process_cors_config
def test_cors_config_defaults():
config = CORSConfig()
assert config.allow_origins == []
assert config.allow_origin_regex is None
assert config.allow_methods == ["OPTIONS"]
assert config.allow_headers == []
assert config.allow_credentials is False
assert config.expose_headers == []
assert config.max_age == 600
def test_cors_config_explicit_config():
config = CORSConfig(
allow_origins=["https://example.com"], allow_credentials=True, max_age=3600, allow_methods=["GET", "POST"]
)
assert config.allow_origins == ["https://example.com"]
assert config.allow_credentials is True
assert config.max_age == 3600
assert config.allow_methods == ["GET", "POST"]
def test_cors_config_regex():
config = CORSConfig(allow_origins=[], allow_origin_regex=r"https?://localhost:\d+")
assert config.allow_origins == []
assert config.allow_origin_regex == r"https?://localhost:\d+"
def test_cors_config_wildcard_credentials_error():
with pytest.raises(ValueError, match="Cannot use wildcard origins with credentials enabled"):
CORSConfig(allow_origins=["*"], allow_credentials=True)
with pytest.raises(ValueError, match="Cannot use wildcard origins with credentials enabled"):
CORSConfig(allow_origins=["https://example.com", "*"], allow_credentials=True)
def test_process_cors_config_false():
result = process_cors_config(False)
assert result is None
def test_process_cors_config_true():
result = process_cors_config(True)
assert isinstance(result, CORSConfig)
assert result.allow_origins == []
assert result.allow_origin_regex == r"https?://localhost:\d+"
assert result.allow_credentials is False
expected_methods = ["GET", "POST", "PUT", "DELETE", "OPTIONS"]
for method in expected_methods:
assert method in result.allow_methods
def test_process_cors_config_passthrough():
original = CORSConfig(allow_origins=["https://example.com"], allow_methods=["GET"])
result = process_cors_config(original)
assert result is original
def test_process_cors_config_invalid_type():
with pytest.raises(ValueError, match="Expected bool or CORSConfig, got str"):
process_cors_config("invalid")
def test_cors_config_model_dump():
cors_config = CORSConfig(
allow_origins=["https://example.com"],
allow_methods=["GET", "POST"],
allow_headers=["Content-Type"],
allow_credentials=True,
max_age=3600,
)
config_dict = cors_config.model_dump()
assert config_dict["allow_origins"] == ["https://example.com"]
assert config_dict["allow_methods"] == ["GET", "POST"]
assert config_dict["allow_headers"] == ["Content-Type"]
assert config_dict["allow_credentials"] is True
assert config_dict["max_age"] == 3600
expected_keys = {
"allow_origins",
"allow_origin_regex",
"allow_methods",
"allow_headers",
"allow_credentials",
"expose_headers",
"max_age",
}
assert set(config_dict.keys()) == expected_keys