Merge branch 'main' into session-manager

Francisco Arceo 2025-08-21 18:44:02 -06:00 committed by GitHub
commit d0c5e07f8a
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
80 changed files with 529 additions and 173 deletions

@@ -18,7 +18,7 @@ on:
       - '.github/workflows/integration-auth-tests.yml' # This workflow
 
 concurrency:
-  group: ${{ github.workflow }}-${{ github.ref }}
+  group: ${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && github.run_id || github.ref }}
   cancel-in-progress: true
 
 jobs:

@@ -16,7 +16,7 @@ on:
       - '.github/workflows/integration-sql-store-tests.yml' # This workflow
 
 concurrency:
-  group: ${{ github.workflow }}-${{ github.ref }}
+  group: ${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && github.run_id || github.ref }}
   cancel-in-progress: true
 
 jobs:

@@ -8,7 +8,7 @@ on:
     branches: [main]
 
 concurrency:
-  group: ${{ github.workflow }}-${{ github.ref }}
+  group: ${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && github.run_id || github.ref }}
   cancel-in-progress: true
 
 jobs:

@@ -26,7 +26,7 @@ on:
       - 'pyproject.toml'
 
 concurrency:
-  group: ${{ github.workflow }}-${{ github.ref }}
+  group: ${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && github.run_id || github.ref }}
   cancel-in-progress: true
 
 jobs:
@@ -106,6 +106,10 @@ jobs:
       - name: Inspect the container image entrypoint
         run: |
           IMAGE_ID=$(docker images --format "{{.Repository}}:{{.Tag}}" | head -n 1)
+          if [ -z "$IMAGE_ID" ]; then
+            echo "No image found"
+            exit 1
+          fi
           entrypoint=$(docker inspect --format '{{ .Config.Entrypoint }}' $IMAGE_ID)
           echo "Entrypoint: $entrypoint"
           if [ "$entrypoint" != "[python -m llama_stack.core.server.server /app/run.yaml]" ]; then
@@ -140,6 +144,10 @@ jobs:
       - name: Inspect UBI9 image
         run: |
           IMAGE_ID=$(docker images --format "{{.Repository}}:{{.Tag}}" | head -n 1)
+          if [ -z "$IMAGE_ID" ]; then
+            echo "No image found"
+            exit 1
+          fi
           entrypoint=$(docker inspect --format '{{ .Config.Entrypoint }}' $IMAGE_ID)
           echo "Entrypoint: $entrypoint"
           if [ "$entrypoint" != "[python -m llama_stack.core.server.server /app/run.yaml]" ]; then

@@ -13,7 +13,7 @@ on:
   workflow_dispatch:
 
 concurrency:
-  group: ${{ github.workflow }}-${{ github.ref }}
+  group: ${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && github.run_id || github.ref }}
   cancel-in-progress: true
 
 jobs:

@@ -18,7 +18,7 @@ on:
   workflow_dispatch:
 
 concurrency:
-  group: ${{ github.workflow }}-${{ github.ref }}
+  group: ${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && github.run_id || github.ref }}
   cancel-in-progress: true
 
 jobs:

@@ -27,7 +27,7 @@ on:
       - '.github/workflows/update-readthedocs.yml'
 
 concurrency:
-  group: ${{ github.workflow }}-${{ github.ref }}
+  group: ${{ github.workflow }}-${{ github.ref == 'refs/heads/main' && github.run_id || github.ref }}
   cancel-in-progress: true
 
 jobs:

@@ -225,8 +225,32 @@ server:
   port: 8321 # Port to listen on (default: 8321)
   tls_certfile: "/path/to/cert.pem" # Optional: Path to TLS certificate for HTTPS
   tls_keyfile: "/path/to/key.pem" # Optional: Path to TLS key for HTTPS
+  cors: true # Optional: Enable CORS (dev mode) or full config object
 ```
+
+### CORS Configuration
+
+CORS (Cross-Origin Resource Sharing) can be configured in two ways:
+
+**Local development** (allows localhost origins only):
+```yaml
+server:
+  cors: true
+```
+
+**Explicit configuration** (custom origins and settings):
+```yaml
+server:
+  cors:
+    allow_origins: ["https://myapp.com", "https://app.example.com"]
+    allow_methods: ["GET", "POST", "PUT", "DELETE"]
+    allow_headers: ["Content-Type", "Authorization"]
+    allow_credentials: true
+    max_age: 3600
+```
+
+When `cors: true`, the server enables secure localhost-only access for local development. For production, specify exact origins to maintain security.
+
 ### Authentication Configuration
 
 > **Breaking Change (v0.2.14)**: The authentication configuration structure has changed. The previous format with `provider_type` and `config` fields has been replaced with a unified `provider_config` field that includes the `type` field. Update your configuration files accordingly.
@@ -618,6 +642,54 @@ Content-Type: application/json
 }
 ```
 
+### CORS Configuration
+
+Configure CORS to allow web browsers to make requests from different domains. Disabled by default.
+
+#### Quick Setup
+
+For development, use the simple boolean flag:
+
+```yaml
+server:
+  cors: true  # Auto-enables localhost with any port
+```
+
+This automatically allows `http://localhost:*` and `https://localhost:*` with secure defaults.
+
+#### Custom Configuration
+
+For specific origins and full control:
+
+```yaml
+server:
+  cors:
+    allow_origins: ["https://myapp.com", "https://staging.myapp.com"]
+    allow_credentials: true
+    allow_methods: ["GET", "POST", "PUT", "DELETE"]
+    allow_headers: ["Content-Type", "Authorization"]
+    allow_origin_regex: "https://.*\\.example\\.com"  # Optional regex pattern
+    expose_headers: ["X-Total-Count"]
+    max_age: 86400
+```
+
+#### Configuration Options
+
+| Field                | Description                                    | Default |
+| -------------------- | ---------------------------------------------- | ------- |
+| `allow_origins`      | List of allowed origins. Use `["*"]` for any.  | `["*"]` |
+| `allow_origin_regex` | Regex pattern for allowed origins (optional).  | `None`  |
+| `allow_methods`      | Allowed HTTP methods.                          | `["*"]` |
+| `allow_headers`      | Allowed headers.                               | `["*"]` |
+| `allow_credentials`  | Allow credentials (cookies, auth headers).     | `false` |
+| `expose_headers`     | Headers exposed to browser.                    | `[]`    |
+| `max_age`            | Preflight cache time (seconds).                | `600`   |
+
+**Security Notes**:
+- `allow_credentials: true` requires explicit origins (no wildcards)
+- `cors: true` enables localhost access only (secure for development)
+- For public APIs, always specify exact allowed origins
+
 ## Extending to handle Safety
 
 Configuring Safety can be a little involved so it is instructive to go through an example.
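
A quick way to sanity-check the dev-mode behavior documented above is to issue the browser's preflight request by hand. This is a minimal sketch, assuming a stack running locally on the default port 8321 with `cors: true` and the `requests` package installed; the endpoint path and origin are illustrative:

```python
import requests

# Browsers send an OPTIONS preflight with an Origin header before cross-origin calls.
resp = requests.options(
    "http://localhost:8321/v1/models",  # illustrative endpoint
    headers={
        "Origin": "http://localhost:3000",  # matched by the dev-mode localhost rule
        "Access-Control-Request-Method": "GET",
    },
)

# With dev-mode CORS enabled, the allowed origin should be echoed back;
# a disallowed origin would yield no Access-Control-Allow-Origin header.
print(resp.status_code, resp.headers.get("access-control-allow-origin"))
```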

@@ -17,7 +17,6 @@ client = LlamaStackAsLibraryClient(
     # provider_data is optional, but if you need to pass in any provider specific data, you can do so here.
     provider_data={"tavily_search_api_key": os.environ["TAVILY_SEARCH_API_KEY"]},
 )
-client.initialize()
 ```
 
 This will parse your config and set up any inline implementations and remote clients needed for your implementation.
@@ -32,5 +31,4 @@ If you've created a [custom distribution](https://llama-stack.readthedocs.io/en/
 
 ```python
 client = LlamaStackAsLibraryClient(config_path)
-client.initialize()
 ```

@@ -15,7 +15,7 @@ from llama_stack.log import get_logger
 
 REPO_ROOT = Path(__file__).parent.parent.parent.parent
 
-logger = get_logger(name=__name__, category="server")
+logger = get_logger(name=__name__, category="cli")
 
 
 class StackRun(Subcommand):

@@ -318,6 +318,41 @@ class QuotaConfig(BaseModel):
     period: QuotaPeriod = Field(default=QuotaPeriod.DAY, description="Quota period to set")
 
 
+class CORSConfig(BaseModel):
+    allow_origins: list[str] = Field(default_factory=list)
+    allow_origin_regex: str | None = Field(default=None)
+    allow_methods: list[str] = Field(default=["OPTIONS"])
+    allow_headers: list[str] = Field(default_factory=list)
+    allow_credentials: bool = Field(default=False)
+    expose_headers: list[str] = Field(default_factory=list)
+    max_age: int = Field(default=600, ge=0)
+
+    @model_validator(mode="after")
+    def validate_credentials_config(self) -> Self:
+        if self.allow_credentials and (self.allow_origins == ["*"] or "*" in self.allow_origins):
+            raise ValueError("Cannot use wildcard origins with credentials enabled")
+        return self
+
+
+def process_cors_config(cors_config: bool | CORSConfig | None) -> CORSConfig | None:
+    if cors_config is False or cors_config is None:
+        return None
+
+    if cors_config is True:
+        # dev mode: allow localhost on any port
+        return CORSConfig(
+            allow_origins=[],
+            allow_origin_regex=r"https?://localhost:\d+",
+            allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"],
+            allow_headers=["Content-Type", "Authorization", "X-Requested-With"],
+        )
+
+    if isinstance(cors_config, CORSConfig):
+        return cors_config
+
+    raise ValueError(f"Expected bool or CORSConfig, got {type(cors_config).__name__}")
+
+
 class ServerConfig(BaseModel):
     port: int = Field(
         default=8321,
@@ -349,6 +384,12 @@ class ServerConfig(BaseModel):
         default=None,
         description="Per client quota request configuration",
     )
+    cors: bool | CORSConfig | None = Field(
+        default=None,
+        description="CORS configuration for cross-origin requests. Can be:\n"
+        "- true: Enable localhost CORS for development\n"
+        "- {allow_origins: [...], allow_methods: [...], ...}: Full configuration",
+    )
 
 
 class StackRunConfig(BaseModel):
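
A short sketch of how the new field's two accepted forms map through the helper above, with the behavior read directly from the code in this hunk (import path as used by the server changes below):

```python
from llama_stack.core.datatypes import CORSConfig, process_cors_config

# `cors: true` expands to the dev-mode config: no fixed origins, a localhost regex.
dev = process_cors_config(True)
assert dev is not None and dev.allow_origin_regex == r"https?://localhost:\d+"

# `cors: false` or an absent field disables the middleware entirely.
assert process_cors_config(None) is None

# The validator rejects the insecure wildcard-plus-credentials combination.
try:
    CORSConfig(allow_origins=["*"], allow_credentials=True)
except ValueError as e:
    print(e)  # Cannot use wildcard origins with credentials enabled
```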

@@ -146,39 +146,26 @@ class LlamaStackAsLibraryClient(LlamaStackClient):
     ):
         super().__init__()
         self.async_client = AsyncLlamaStackAsLibraryClient(
-            config_path_or_distro_name, custom_provider_registry, provider_data
+            config_path_or_distro_name, custom_provider_registry, provider_data, skip_logger_removal
         )
         self.pool_executor = ThreadPoolExecutor(max_workers=4)
-        self.skip_logger_removal = skip_logger_removal
         self.provider_data = provider_data
 
         self.loop = asyncio.new_event_loop()
 
-    def initialize(self):
-        if in_notebook():
-            import nest_asyncio
-
-            nest_asyncio.apply()
-        if not self.skip_logger_removal:
-            self._remove_root_logger_handlers()
-
         # use a new event loop to avoid interfering with the main event loop
         loop = asyncio.new_event_loop()
         asyncio.set_event_loop(loop)
         try:
-            return loop.run_until_complete(self.async_client.initialize())
+            loop.run_until_complete(self.async_client.initialize())
         finally:
             asyncio.set_event_loop(None)
 
-    def _remove_root_logger_handlers(self):
+    def initialize(self):
         """
-        Remove all handlers from the root logger. Needed to avoid polluting the console with logs.
+        Deprecated method for backward compatibility.
         """
-        root_logger = logging.getLogger()
-        for handler in root_logger.handlers[:]:
-            root_logger.removeHandler(handler)
-            logger.info(f"Removed handler {handler.__class__.__name__} from root logger")
+        pass
 
     def request(self, *args, **kwargs):
         loop = self.loop
@@ -216,6 +203,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
         config_path_or_distro_name: str,
         custom_provider_registry: ProviderRegistry | None = None,
         provider_data: dict[str, Any] | None = None,
+        skip_logger_removal: bool = False,
     ):
         super().__init__()
         # when using the library client, we should not log to console since many
@@ -223,6 +211,13 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
         current_sinks = os.environ.get("TELEMETRY_SINKS", "sqlite").split(",")
         os.environ["TELEMETRY_SINKS"] = ",".join(sink for sink in current_sinks if sink != "console")
 
+        if in_notebook():
+            import nest_asyncio
+
+            nest_asyncio.apply()
+        if not skip_logger_removal:
+            self._remove_root_logger_handlers()
+
         if config_path_or_distro_name.endswith(".yaml"):
             config_path = Path(config_path_or_distro_name)
             if not config_path.exists():
@@ -239,7 +234,24 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
         self.provider_data = provider_data
         self.route_impls: RouteImpls | None = None  # Initialize to None to prevent AttributeError
 
+    def _remove_root_logger_handlers(self):
+        """
+        Remove all handlers from the root logger. Needed to avoid polluting the console with logs.
+        """
+        root_logger = logging.getLogger()
+        for handler in root_logger.handlers[:]:
+            root_logger.removeHandler(handler)
+            logger.info(f"Removed handler {handler.__class__.__name__} from root logger")
+
     async def initialize(self) -> bool:
+        """
+        Initialize the async client.
+
+        Returns:
+            bool: True if initialization was successful
+        """
         try:
             self.route_impls = None
             self.impls = await construct_stack(self.config, self.custom_provider_registry)
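
The net effect for callers, sketched here under the assumption of a valid distro name; `initialize()` survives only as a backward-compatibility no-op:

```python
from llama_stack.core.library_client import LlamaStackAsLibraryClient

# Construction now builds the stack and sets up route implementations eagerly.
client = LlamaStackAsLibraryClient("ci-tests")  # distro name is illustrative

client.initialize()  # deprecated no-op, kept so existing call sites keep working

models = client.models.list()  # usable immediately, no explicit initialize() needed
```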

@@ -12,7 +12,7 @@ from llama_stack.apis.datasets import DatasetPurpose, DataSource
 from llama_stack.log import get_logger
 from llama_stack.providers.datatypes import RoutingTable
 
-logger = get_logger(name=__name__, category="core")
+logger = get_logger(name=__name__, category="core::routers")
 
 
 class DatasetIORouter(DatasetIO):

@@ -16,7 +16,7 @@ from llama_stack.apis.scoring import (
 from llama_stack.log import get_logger
 from llama_stack.providers.datatypes import RoutingTable
 
-logger = get_logger(name=__name__, category="core")
+logger = get_logger(name=__name__, category="core::routers")
 
 
 class ScoringRouter(Scoring):

@@ -65,7 +65,7 @@ from llama_stack.providers.datatypes import HealthResponse, HealthStatus, Routin
 from llama_stack.providers.utils.inference.inference_store import InferenceStore
 from llama_stack.providers.utils.telemetry.tracing import get_current_span
 
-logger = get_logger(name=__name__, category="inference")
+logger = get_logger(name=__name__, category="core::routers")
 
 
 class InferenceRouter(Inference):

@@ -13,7 +13,7 @@ from llama_stack.apis.shields import Shield
 from llama_stack.log import get_logger
 from llama_stack.providers.datatypes import RoutingTable
 
-logger = get_logger(name=__name__, category="core")
+logger = get_logger(name=__name__, category="core::routers")
 
 
 class SafetyRouter(Safety):

@@ -22,7 +22,7 @@ from llama_stack.log import get_logger
 
 from ..routing_tables.toolgroups import ToolGroupsRoutingTable
 
-logger = get_logger(name=__name__, category="core")
+logger = get_logger(name=__name__, category="core::routers")
 
 
 class ToolRuntimeRouter(ToolRuntime):

@@ -30,7 +30,7 @@ from llama_stack.apis.vector_io import (
 from llama_stack.log import get_logger
 from llama_stack.providers.datatypes import HealthResponse, HealthStatus, RoutingTable
 
-logger = get_logger(name=__name__, category="core")
+logger = get_logger(name=__name__, category="core::routers")
 
 
 class VectorIORouter(VectorIO):

@@ -14,7 +14,7 @@ from llama_stack.log import get_logger
 
 from .common import CommonRoutingTableImpl
 
-logger = get_logger(name=__name__, category="core")
+logger = get_logger(name=__name__, category="core::routing_tables")
 
 
 class BenchmarksRoutingTable(CommonRoutingTableImpl, Benchmarks):

@@ -23,7 +23,7 @@ from llama_stack.core.store import DistributionRegistry
 from llama_stack.log import get_logger
 from llama_stack.providers.datatypes import Api, RoutingTable
 
-logger = get_logger(name=__name__, category="core")
+logger = get_logger(name=__name__, category="core::routing_tables")
 
 
 def get_impl_api(p: Any) -> Api:

@@ -26,7 +26,7 @@ from llama_stack.log import get_logger
 
 from .common import CommonRoutingTableImpl
 
-logger = get_logger(name=__name__, category="core")
+logger = get_logger(name=__name__, category="core::routing_tables")
 
 
 class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):

@@ -17,7 +17,7 @@ from llama_stack.log import get_logger
 
 from .common import CommonRoutingTableImpl, lookup_model
 
-logger = get_logger(name=__name__, category="core")
+logger = get_logger(name=__name__, category="core::routing_tables")
 
 
 class ModelsRoutingTable(CommonRoutingTableImpl, Models):

@@ -19,7 +19,7 @@ from llama_stack.log import get_logger
 
 from .common import CommonRoutingTableImpl
 
-logger = get_logger(name=__name__, category="core")
+logger = get_logger(name=__name__, category="core::routing_tables")
 
 
 class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions):

@@ -15,7 +15,7 @@ from llama_stack.log import get_logger
 
 from .common import CommonRoutingTableImpl
 
-logger = get_logger(name=__name__, category="core")
+logger = get_logger(name=__name__, category="core::routing_tables")
 
 
 class ShieldsRoutingTable(CommonRoutingTableImpl, Shields):

@@ -14,7 +14,7 @@ from llama_stack.log import get_logger
 
 from .common import CommonRoutingTableImpl
 
-logger = get_logger(name=__name__, category="core")
+logger = get_logger(name=__name__, category="core::routing_tables")
 
 
 def parse_toolgroup_from_toolgroup_name_pair(toolgroup_name_with_maybe_tool_name: str) -> str | None:

@@ -30,7 +30,7 @@ from llama_stack.log import get_logger
 
 from .common import CommonRoutingTableImpl, lookup_model
 
-logger = get_logger(name=__name__, category="core")
+logger = get_logger(name=__name__, category="core::routing_tables")
 
 
 class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):

@@ -15,7 +15,7 @@ from llama_stack.core.server.auth_providers import create_auth_provider
 from llama_stack.core.server.routes import find_matching_route, initialize_route_impls
 from llama_stack.log import get_logger
 
-logger = get_logger(name=__name__, category="auth")
+logger = get_logger(name=__name__, category="core::auth")
 
 
 class AuthenticationMiddleware:

@@ -23,7 +23,7 @@ from llama_stack.core.datatypes import (
 )
 from llama_stack.log import get_logger
 
-logger = get_logger(name=__name__, category="auth")
+logger = get_logger(name=__name__, category="core::auth")
 
 
 class AuthResponse(BaseModel):

@@ -15,7 +15,7 @@ from llama_stack.providers.utils.kvstore.api import KVStore
 from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig
 from llama_stack.providers.utils.kvstore.kvstore import kvstore_impl
 
-logger = get_logger(name=__name__, category="quota")
+logger = get_logger(name=__name__, category="core::server")
 
 
 class QuotaMiddleware:

@@ -28,6 +28,7 @@ from aiohttp import hdrs
 from fastapi import Body, FastAPI, HTTPException, Request, Response
 from fastapi import Path as FastapiPath
 from fastapi.exceptions import RequestValidationError
+from fastapi.middleware.cors import CORSMiddleware
 from fastapi.responses import JSONResponse, StreamingResponse
 from openai import BadRequestError
 from pydantic import BaseModel, ValidationError
@@ -40,6 +41,7 @@ from llama_stack.core.datatypes import (
     AuthenticationRequiredError,
     LoggingConfig,
     StackRunConfig,
+    process_cors_config,
 )
 from llama_stack.core.distribution import builtin_automatically_routed_apis
 from llama_stack.core.external import ExternalApiSpec, load_external_apis
@@ -82,7 +84,7 @@ from .quota import QuotaMiddleware
 
 REPO_ROOT = Path(__file__).parent.parent.parent.parent
 
-logger = get_logger(name=__name__, category="server")
+logger = get_logger(name=__name__, category="core::server")
 
 
 def warn_with_traceback(message, category, filename, lineno, file=None, line=None):
@@ -413,7 +415,7 @@ def main(args: argparse.Namespace | None = None):
         config_contents = yaml.safe_load(fp)
         if isinstance(config_contents, dict) and (cfg := config_contents.get("logging_config")):
             logger_config = LoggingConfig(**cfg)
-        logger = get_logger(name=__name__, category="server", config=logger_config)
+        logger = get_logger(name=__name__, category="core::server", config=logger_config)
     if args.env:
         for env_pair in args.env:
             try:
@@ -483,6 +485,12 @@ def main(args: argparse.Namespace | None = None):
             window_seconds=window_seconds,
         )
 
+    if config.server.cors:
+        logger.info("Enabling CORS")
+        cors_config = process_cors_config(config.server.cors)
+        if cors_config:
+            app.add_middleware(CORSMiddleware, **cors_config.model_dump())
+
     if Api.telemetry in impls:
         setup_logger(impls[Api.telemetry])
     else:
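
For reference, the registration above is plain Starlette/FastAPI CORS wiring. A self-contained equivalent of what the dev-mode config produces might look like this (a standalone sketch, not the server's actual startup path):

```python
from fastapi import FastAPI
from fastapi.middleware.cors import CORSMiddleware

app = FastAPI()

# Mirrors the dev-mode CORSConfig produced by process_cors_config(True).
app.add_middleware(
    CORSMiddleware,
    allow_origins=[],
    allow_origin_regex=r"https?://localhost:\d+",
    allow_methods=["GET", "POST", "PUT", "DELETE", "OPTIONS"],
    allow_headers=["Content-Type", "Authorization", "X-Requested-With"],
)
```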

@@ -16,7 +16,7 @@ from llama_stack.log import get_logger
 from llama_stack.providers.utils.kvstore import KVStore, kvstore_impl
 from llama_stack.providers.utils.kvstore.config import KVStoreConfig, SqliteKVStoreConfig
 
-logger = get_logger(__name__, category="core")
+logger = get_logger(__name__, category="core::registry")
 
 
 class DistributionRegistry(Protocol):

@@ -10,7 +10,7 @@ from pathlib import Path
 from llama_stack.core.utils.config_dirs import DISTRIBS_BASE_DIR
 from llama_stack.log import get_logger
 
-logger = get_logger(name=__name__, category="config_resolution")
+logger = get_logger(name=__name__, category="core")
 
 
 DISTRO_DIR = Path(__file__).parent.parent.parent.parent / "llama_stack" / "distributions"

@@ -36,7 +36,7 @@ from .utils import get_negative_inf_value, to_2tuple
 
 MP_SCALE = 8
 
-logger = get_logger(name=__name__, category="models")
+logger = get_logger(name=__name__, category="models::llama")
 
 
 def reduce_from_tensor_model_parallel_region(input_):

@@ -11,7 +11,7 @@ from llama_stack.log import get_logger
 
 from ..datatypes import BuiltinTool, RecursiveType, ToolCall, ToolPromptFormat
 
-logger = get_logger(name=__name__, category="inference")
+logger = get_logger(name=__name__, category="models::llama")
 
 BUILTIN_TOOL_PATTERN = r'\b(?P<tool_name>\w+)\.call\(query="(?P<query>[^"]*)"\)'
 CUSTOM_TOOL_CALL_PATTERN = re.compile(r"<function=(?P<function_name>[^}]+)>(?P<args>{.*?})")

@@ -18,7 +18,7 @@ from ...datatypes import QuantizationMode
 from ..model import Transformer, TransformerBlock
 from ..moe import MoE
 
-log = get_logger(name=__name__, category="models")
+log = get_logger(name=__name__, category="models::llama")
 
 
 def swiglu_wrapper_no_reduce(

@@ -9,7 +9,7 @@ import collections
 
 from llama_stack.log import get_logger
 
-log = get_logger(name=__name__, category="llama")
+log = get_logger(name=__name__, category="models::llama")
 
 try:
     import fbgemm_gpu.experimental.gen_ai  # noqa: F401

@@ -84,7 +84,7 @@ MEMORY_QUERY_TOOL = "knowledge_search"
 WEB_SEARCH_TOOL = "web_search"
 RAG_TOOL_GROUP = "builtin::rag"
 
-logger = get_logger(name=__name__, category="agents")
+logger = get_logger(name=__name__, category="agents::meta_reference")
 
 
 class ChatAgent(ShieldRunnerMixin):

@@ -51,7 +51,7 @@ from .config import MetaReferenceAgentsImplConfig
 from .persistence import AgentInfo
 from .responses.openai_responses import OpenAIResponsesImpl
 
-logger = get_logger(name=__name__, category="agents")
+logger = get_logger(name=__name__, category="agents::meta_reference")
 
 
 class MetaReferenceAgentsImpl(Agents):

@@ -17,7 +17,7 @@ from llama_stack.core.request_headers import get_authenticated_user
 from llama_stack.log import get_logger
 from llama_stack.providers.utils.kvstore import KVStore
 
-log = get_logger(name=__name__, category="agents")
+log = get_logger(name=__name__, category="agents::meta_reference")
 
 
 class AgentSessionInfo(Session):

@@ -41,7 +41,7 @@ from .utils import (
     convert_response_text_to_chat_response_format,
 )
 
-logger = get_logger(name=__name__, category="responses")
+logger = get_logger(name=__name__, category="openai::responses")
 
 
 class OpenAIResponsePreviousResponseWithInputItems(BaseModel):

@@ -47,7 +47,7 @@ from llama_stack.log import get_logger
 from .types import ChatCompletionContext, ChatCompletionResult
 from .utils import convert_chat_choice_to_response_message, is_function_tool_call
 
-logger = get_logger(name=__name__, category="responses")
+logger = get_logger(name=__name__, category="agents::meta_reference")
 
 
 class StreamingResponseOrchestrator:

@@ -38,7 +38,7 @@ from llama_stack.log import get_logger
 
 from .types import ChatCompletionContext, ToolExecutionResult
 
-logger = get_logger(name=__name__, category="responses")
+logger = get_logger(name=__name__, category="agents::meta_reference")
 
 
 class ToolExecutor:

@@ -11,7 +11,7 @@ from llama_stack.apis.safety import Safety, SafetyViolation, ViolationLevel
 from llama_stack.log import get_logger
 from llama_stack.providers.utils.telemetry import tracing
 
-log = get_logger(name=__name__, category="agents")
+log = get_logger(name=__name__, category="agents::meta_reference")
 
 
 class SafetyException(Exception):  # noqa: N818

@@ -65,7 +65,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
 from .config import FireworksImplConfig
 from .models import MODEL_ENTRIES
 
-logger = get_logger(name=__name__, category="inference")
+logger = get_logger(name=__name__, category="inference::fireworks")
 
 
 class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData):

@@ -10,7 +10,7 @@ from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
 
 from .models import MODEL_ENTRIES
 
-logger = get_logger(name=__name__, category="inference")
+logger = get_logger(name=__name__, category="inference::llama_openai_compat")
 
 
 class LlamaCompatInferenceAdapter(OpenAIMixin, LiteLLMOpenAIMixin):

@@ -41,6 +41,11 @@ client.initialize()
 
 ### Create Completion
 
+> Note on Completion API
+>
+> The hosted NVIDIA Llama NIMs (e.g., `meta-llama/Llama-3.1-8B-Instruct`) with `NVIDIA_BASE_URL="https://integrate.api.nvidia.com"` do not support the `completion` method, while a locally deployed NIM does.
+
 ```python
 response = client.inference.completion(
     model_id="meta-llama/Llama-3.1-8B-Instruct",
@@ -76,6 +81,73 @@ response = client.inference.chat_completion(
 
 print(f"Response: {response.completion_message.content}")
 ```
 
+### Tool Calling Example
+
+```python
+from llama_stack.models.llama.datatypes import ToolDefinition, ToolParamDefinition
+
+tool_definition = ToolDefinition(
+    tool_name="get_weather",
+    description="Get current weather information for a location",
+    parameters={
+        "location": ToolParamDefinition(
+            param_type="string",
+            description="The city and state, e.g. San Francisco, CA",
+            required=True,
+        ),
+        "unit": ToolParamDefinition(
+            param_type="string",
+            description="Temperature unit (celsius or fahrenheit)",
+            required=False,
+            default="celsius",
+        ),
+    },
+)
+
+tool_response = client.inference.chat_completion(
+    model_id="meta-llama/Llama-3.1-8B-Instruct",
+    messages=[{"role": "user", "content": "What's the weather like in San Francisco?"}],
+    tools=[tool_definition],
+)
+
+print(f"Tool Response: {tool_response.completion_message.content}")
+if tool_response.completion_message.tool_calls:
+    for tool_call in tool_response.completion_message.tool_calls:
+        print(f"Tool Called: {tool_call.tool_name}")
+        print(f"Arguments: {tool_call.arguments}")
+```
+
+### Structured Output Example
+
+```python
+from llama_stack.apis.inference import JsonSchemaResponseFormat, ResponseFormatType
+
+person_schema = {
+    "type": "object",
+    "properties": {
+        "name": {"type": "string"},
+        "age": {"type": "integer"},
+        "occupation": {"type": "string"},
+    },
+    "required": ["name", "age", "occupation"],
+}
+
+response_format = JsonSchemaResponseFormat(
+    type=ResponseFormatType.json_schema, json_schema=person_schema
+)
+
+structured_response = client.inference.chat_completion(
+    model_id="meta-llama/Llama-3.1-8B-Instruct",
+    messages=[
+        {
+            "role": "user",
+            "content": "Create a profile for a fictional person named Alice who is 30 years old and is a software engineer.",
+        }
+    ],
+    response_format=response_format,
+)
+
+print(f"Structured Response: {structured_response.completion_message.content}")
+```
+
 ### Create Embeddings
 
 > Note on OpenAI embeddings compatibility
 >
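
Given the note on the Completion API above, code that must run against both hosted and self-hosted NIMs can branch on the base URL. A hedged sketch (the environment variable name is taken from the note; `client` is the one constructed earlier):

```python
import os

hosted = os.environ.get("NVIDIA_BASE_URL", "") == "https://integrate.api.nvidia.com"

if hosted:
    # Hosted NIMs support chat_completion but not completion.
    response = client.inference.chat_completion(
        model_id="meta-llama/Llama-3.1-8B-Instruct",
        messages=[{"role": "user", "content": "Write a haiku about coding."}],
    )
else:
    # A locally deployed NIM also supports the completion method.
    response = client.inference.completion(
        model_id="meta-llama/Llama-3.1-8B-Instruct",
        content="The capital of France is ",
    )
```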

@@ -7,7 +7,7 @@
 import warnings
 from collections.abc import AsyncIterator
 
-from openai import NOT_GIVEN, APIConnectionError, BadRequestError
+from openai import NOT_GIVEN, APIConnectionError
 
 from llama_stack.apis.common.content_types import (
     InterleavedContent,
@@ -57,7 +57,7 @@ from .openai_utils import (
 )
 from .utils import _is_nvidia_hosted
 
-logger = get_logger(name=__name__, category="inference")
+logger = get_logger(name=__name__, category="inference::nvidia")
 
 
 class NVIDIAInferenceAdapter(OpenAIMixin, Inference, ModelRegistryHelper):
@@ -197,15 +197,11 @@ class NVIDIAInferenceAdapter(OpenAIMixin, Inference, ModelRegistryHelper):
             }
             extra_body["input_type"] = task_type_options[task_type]
 
-        try:
-            response = await self.client.embeddings.create(
-                model=provider_model_id,
-                input=input,
-                extra_body=extra_body,
-            )
-        except BadRequestError as e:
-            raise ValueError(f"Failed to get embeddings: {e}") from e
+        response = await self.client.embeddings.create(
+            model=provider_model_id,
+            input=input,
+            extra_body=extra_body,
+        )
 
         #
         # OpenAI: CreateEmbeddingResponse(data=[Embedding(embedding=list[float], ...)], ...)
         # ->

@@ -10,7 +10,7 @@ from llama_stack.log import get_logger
 
 from . import NVIDIAConfig
 
-logger = get_logger(name=__name__, category="inference")
+logger = get_logger(name=__name__, category="inference::nvidia")
 
 
 def _is_nvidia_hosted(config: NVIDIAConfig) -> bool:

@@ -85,7 +85,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
 
 from .models import MODEL_ENTRIES
 
-logger = get_logger(name=__name__, category="inference")
+logger = get_logger(name=__name__, category="inference::ollama")
 
 
 class OllamaInferenceAdapter(

@@ -11,7 +11,7 @@ from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
 from .config import OpenAIConfig
 from .models import MODEL_ENTRIES
 
-logger = get_logger(name=__name__, category="inference")
+logger = get_logger(name=__name__, category="inference::openai")
 
 
 #

@@ -58,7 +58,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
 
 from .config import InferenceAPIImplConfig, InferenceEndpointImplConfig, TGIImplConfig
 
-log = get_logger(name=__name__, category="inference")
+log = get_logger(name=__name__, category="inference::tgi")
 
 
 def build_hf_repo_model_entries():

@@ -61,7 +61,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
 from .config import TogetherImplConfig
 from .models import MODEL_ENTRIES
 
-logger = get_logger(name=__name__, category="inference")
+logger = get_logger(name=__name__, category="inference::together")
 
 
 class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData):

@@ -85,7 +85,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
 
 from .config import VLLMInferenceAdapterConfig
 
-log = get_logger(name=__name__, category="inference")
+log = get_logger(name=__name__, category="inference::vllm")
 
 
 def build_hf_repo_model_entries():

@@ -15,7 +15,7 @@ from llama_stack.providers.remote.post_training.nvidia.config import SFTLoRADefa
 
 from .config import NvidiaPostTrainingConfig
 
-logger = get_logger(name=__name__, category="integration")
+logger = get_logger(name=__name__, category="post_training::nvidia")
 
 
 def warn_unsupported_params(config_dict: Any, supported_keys: set[str], config_name: str) -> None:

@@ -21,7 +21,7 @@ from llama_stack.providers.utils.bedrock.client import create_bedrock_client
 
 from .config import BedrockSafetyConfig
 
-logger = get_logger(name=__name__, category="safety")
+logger = get_logger(name=__name__, category="safety::bedrock")
 
 
 class BedrockSafetyAdapter(Safety, ShieldsProtocolPrivate):

@@ -9,7 +9,7 @@ from typing import Any
 import requests
 
 from llama_stack.apis.inference import Message
-from llama_stack.apis.safety import RunShieldResponse, Safety, SafetyViolation, ViolationLevel
+from llama_stack.apis.safety import ModerationObject, RunShieldResponse, Safety, SafetyViolation, ViolationLevel
 from llama_stack.apis.shields import Shield
 from llama_stack.log import get_logger
 from llama_stack.providers.datatypes import ShieldsProtocolPrivate
@@ -17,7 +17,7 @@ from llama_stack.providers.utils.inference.openai_compat import convert_message_
 
 from .config import NVIDIASafetyConfig
 
-logger = get_logger(name=__name__, category="safety")
+logger = get_logger(name=__name__, category="safety::nvidia")
 
 
 class NVIDIASafetyAdapter(Safety, ShieldsProtocolPrivate):
@@ -67,6 +67,9 @@ class NVIDIASafetyAdapter(Safety, ShieldsProtocolPrivate):
         self.shield = NeMoGuardrails(self.config, shield.shield_id)
         return await self.shield.run(messages)
 
+    async def run_moderation(self, input: str | list[str], model: str) -> ModerationObject:
+        raise NotImplementedError("NVIDIA safety provider currently does not implement run_moderation")
+
 
 class NeMoGuardrails:
     """

@@ -25,7 +25,7 @@ from llama_stack.providers.utils.inference.openai_compat import convert_message_
 
 from .config import SambaNovaSafetyConfig
 
-logger = get_logger(name=__name__, category="safety")
+logger = get_logger(name=__name__, category="safety::sambanova")
 
 CANNED_RESPONSE_TEXT = "I can't answer that. Can I help with something else?"

@@ -33,7 +33,7 @@ from llama_stack.providers.utils.memory.vector_store import (
 
 from .config import ChromaVectorIOConfig as RemoteChromaVectorIOConfig
 
-log = get_logger(name=__name__, category="vector_io")
+log = get_logger(name=__name__, category="vector_io::chroma")
 
 ChromaClientType = chromadb.api.AsyncClientAPI | chromadb.api.ClientAPI

@@ -36,7 +36,7 @@ from llama_stack.providers.utils.vector_io.vector_utils import sanitize_collecti
 
 from .config import MilvusVectorIOConfig as RemoteMilvusVectorIOConfig
 
-logger = get_logger(name=__name__, category="vector_io")
+logger = get_logger(name=__name__, category="vector_io::milvus")
 
 VERSION = "v3"
 VECTOR_DBS_PREFIX = f"vector_dbs:milvus:{VERSION}::"

@@ -34,7 +34,7 @@ from llama_stack.providers.utils.memory.vector_store import (
 
 from .config import PGVectorVectorIOConfig
 
-log = get_logger(name=__name__, category="vector_io")
+log = get_logger(name=__name__, category="vector_io::pgvector")
 
 VERSION = "v3"
 VECTOR_DBS_PREFIX = f"vector_dbs:pgvector:{VERSION}::"

@@ -36,7 +36,7 @@ from llama_stack.providers.utils.memory.vector_store import (
 
 from .config import QdrantVectorIOConfig as RemoteQdrantVectorIOConfig
 
-log = get_logger(name=__name__, category="vector_io")
+log = get_logger(name=__name__, category="vector_io::qdrant")
 
 CHUNK_ID_KEY = "_chunk_id"
 
 # KV store prefixes for vector databases

@@ -34,7 +34,7 @@ from llama_stack.providers.utils.vector_io.vector_utils import sanitize_collecti
 
 from .config import WeaviateVectorIOConfig
 
-log = get_logger(name=__name__, category="vector_io")
+log = get_logger(name=__name__, category="vector_io::weaviate")
 
 VERSION = "v3"
 VECTOR_DBS_PREFIX = f"vector_dbs:weaviate:{VERSION}::"

@@ -28,7 +28,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import interleaved_con
 
 EMBEDDING_MODELS = {}
 
-log = get_logger(name=__name__, category="inference")
+log = get_logger(name=__name__, category="providers::utils")
 
 
 class SentenceTransformerEmbeddingMixin:

@@ -54,7 +54,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
     interleaved_content_as_str,
 )
 
-logger = get_logger(name=__name__, category="inference")
+logger = get_logger(name=__name__, category="providers::utils")
 
 
 class LiteLLMOpenAIMixin(

@@ -17,7 +17,7 @@ from llama_stack.providers.utils.inference import (
     ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR,
 )
 
-logger = get_logger(name=__name__, category="core")
+logger = get_logger(name=__name__, category="providers::utils")
 
 
 class RemoteInferenceProviderConfig(BaseModel):

@@ -134,7 +134,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
     decode_assistant_message,
 )
 
-logger = get_logger(name=__name__, category="inference")
+logger = get_logger(name=__name__, category="providers::utils")
 
 
 class OpenAICompatCompletionChoiceDelta(BaseModel):

@@ -25,7 +25,7 @@ from llama_stack.apis.inference import (
 from llama_stack.log import get_logger
 from llama_stack.providers.utils.inference.openai_compat import prepare_openai_completion_params
 
-logger = get_logger(name=__name__, category="core")
+logger = get_logger(name=__name__, category="providers::utils")
 
 
 class OpenAIMixin(ABC):

@@ -58,7 +58,7 @@ from llama_stack.models.llama.sku_list import resolve_model
 from llama_stack.models.llama.sku_types import ModelFamily, is_multimodal
 from llama_stack.providers.utils.inference import supported_inference_models
 
-log = get_logger(name=__name__, category="inference")
+log = get_logger(name=__name__, category="providers::utils")
 
 
 class ChatCompletionRequestWithRawContent(ChatCompletionRequest):

@@ -13,7 +13,7 @@ from llama_stack.providers.utils.kvstore import KVStore
 
 from ..config import MongoDBKVStoreConfig
 
-log = get_logger(name=__name__, category="kvstore")
+log = get_logger(name=__name__, category="providers::utils")
 
 
 class MongoDBKVStoreImpl(KVStore):

@@ -14,7 +14,7 @@ from llama_stack.log import get_logger
 
 from ..api import KVStore
 from ..config import PostgresKVStoreConfig
 
-log = get_logger(name=__name__, category="kvstore")
+log = get_logger(name=__name__, category="providers::utils")
 
 
 class PostgresKVStoreImpl(KVStore):

@@ -44,7 +44,7 @@ from llama_stack.providers.utils.memory.vector_store import (
     make_overlapped_chunks,
 )
 
-logger = get_logger(name=__name__, category="memory")
+logger = get_logger(name=__name__, category="providers::utils")
 
 # Constants for OpenAI vector stores
 CHUNK_MULTIPLIER = 5

@@ -33,7 +33,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
 )
 from llama_stack.providers.utils.vector_io.vector_utils import generate_chunk_id
 
-log = get_logger(name=__name__, category="memory")
+log = get_logger(name=__name__, category="providers::utils")
 
 
 class ChunkForDeletion(BaseModel):

@@ -17,7 +17,7 @@ from pydantic import BaseModel
 
 from llama_stack.log import get_logger
 
-logger = get_logger(name=__name__, category="scheduler")
+logger = get_logger(name=__name__, category="providers::utils")
 
 # TODO: revisit the list of possible statuses when defining a more coherent

@@ -17,7 +17,7 @@ from llama_stack.log import get_logger
 
 from .api import ColumnDefinition, ColumnType, PaginatedResponse, SqlStore
 from .sqlstore import SqlStoreType
 
-logger = get_logger(name=__name__, category="authorized_sqlstore")
+logger = get_logger(name=__name__, category="providers::utils")
 
 # Hardcoded copy of the default policy that our SQL filtering implements
 # WARNING: If default_policy() changes, this constant must be updated accordingly

@@ -30,7 +30,7 @@ from llama_stack.log import get_logger
 
 from .api import ColumnDefinition, ColumnType, SqlStore
 from .sqlstore import SqlAlchemySqlStoreConfig
 
-logger = get_logger(name=__name__, category="sqlstore")
+logger = get_logger(name=__name__, category="providers::utils")
 
 TYPE_MAPPING: dict[ColumnType, Any] = {
     ColumnType.INTEGER: Integer,

@@ -256,9 +256,6 @@ def instantiate_llama_stack_client(session):
         provider_data=get_provider_data(),
         skip_logger_removal=True,
     )
-    if not client.initialize():
-        raise RuntimeError("Initialization failed")
-
     return client

@@ -55,7 +55,7 @@
 #
 
 import pytest
-from llama_stack_client import BadRequestError
+from llama_stack_client import BadRequestError as LlamaStackBadRequestError
 from llama_stack_client.types import EmbeddingsResponse
 from llama_stack_client.types.shared.interleaved_content import (
ImageContentItem, ImageContentItem,
@ -63,6 +63,9 @@ from llama_stack_client.types.shared.interleaved_content import (
ImageContentItemImageURL, ImageContentItemImageURL,
TextContentItem, TextContentItem,
) )
from openai import BadRequestError as OpenAIBadRequestError
from llama_stack.core.library_client import LlamaStackAsLibraryClient
DUMMY_STRING = "hello" DUMMY_STRING = "hello"
DUMMY_STRING2 = "world" DUMMY_STRING2 = "world"
@@ -203,7 +206,14 @@ def test_embedding_truncation_error(
 ):
     if inference_provider_type not in SUPPORTED_PROVIDERS:
         pytest.xfail(f"{inference_provider_type} doesn't support embedding model yet")
-    with pytest.raises(BadRequestError):
+    # Using LlamaStackClient from llama_stack_client will raise llama_stack_client.BadRequestError
+    # While using LlamaStackAsLibraryClient from llama_stack.distribution.library_client will raise the error that the backend raises
+    error_type = (
+        OpenAIBadRequestError
+        if isinstance(llama_stack_client, LlamaStackAsLibraryClient)
+        else LlamaStackBadRequestError
+    )
+    with pytest.raises(error_type):
         llama_stack_client.inference.embeddings(
             model_id=embedding_model_id,
             contents=[DUMMY_LONG_TEXT],
@@ -283,7 +293,8 @@ def test_embedding_text_truncation_error(
 ):
     if inference_provider_type not in SUPPORTED_PROVIDERS:
         pytest.xfail(f"{inference_provider_type} doesn't support embedding model yet")
-    with pytest.raises(BadRequestError):
+    error_type = ValueError if isinstance(llama_stack_client, LlamaStackAsLibraryClient) else LlamaStackBadRequestError
+    with pytest.raises(error_type):
         llama_stack_client.inference.embeddings(
             model_id=embedding_model_id,
             contents=[DUMMY_STRING],

@@ -113,8 +113,6 @@ def openai_client(base_url, api_key, provider):
             raise ValueError(f"Invalid config for Llama Stack: {provider}, it must be of the form 'stack:<config>'")
         config = parts[1]
         client = LlamaStackAsLibraryClient(config, skip_logger_removal=True)
-        if not client.initialize():
-            raise RuntimeError("Initialization failed")
        return client

    return OpenAI(

View file

@ -5,86 +5,121 @@
# the root directory of this source tree.
"""
-Unit tests for LlamaStackAsLibraryClient initialization error handling.
+Unit tests for LlamaStackAsLibraryClient automatic initialization.

-These tests ensure that users get proper error messages when they forget to call
-initialize() on the library client, preventing AttributeError regressions.
+These tests ensure that the library client is automatically initialized
+and ready to use immediately after construction.
"""

-import pytest

from llama_stack.core.library_client import (
    AsyncLlamaStackAsLibraryClient,
    LlamaStackAsLibraryClient,
)
+from llama_stack.core.server.routes import RouteImpls


-class TestLlamaStackAsLibraryClientInitialization:
-    """Test proper error handling for uninitialized library clients."""
+class TestLlamaStackAsLibraryClientAutoInitialization:
+    """Test automatic initialization of library clients."""

-    @pytest.mark.parametrize(
-        "api_call",
-        [
-            lambda client: client.models.list(),
-            lambda client: client.chat.completions.create(model="test", messages=[{"role": "user", "content": "test"}]),
-            lambda client: next(
-                client.chat.completions.create(
-                    model="test", messages=[{"role": "user", "content": "test"}], stream=True
-                )
-            ),
-        ],
-        ids=["models.list", "chat.completions.create", "chat.completions.create_stream"],
-    )
-    def test_sync_client_proper_error_without_initialization(self, api_call):
-        """Test that sync client raises ValueError with helpful message when not initialized."""
-        client = LlamaStackAsLibraryClient("nvidia")
-        with pytest.raises(ValueError) as exc_info:
-            api_call(client)
-        error_msg = str(exc_info.value)
-        assert "Client not initialized" in error_msg
-        assert "Please call initialize() first" in error_msg

-    @pytest.mark.parametrize(
-        "api_call",
-        [
-            lambda client: client.models.list(),
-            lambda client: client.chat.completions.create(model="test", messages=[{"role": "user", "content": "test"}]),
-        ],
-        ids=["models.list", "chat.completions.create"],
-    )
-    async def test_async_client_proper_error_without_initialization(self, api_call):
-        """Test that async client raises ValueError with helpful message when not initialized."""
-        client = AsyncLlamaStackAsLibraryClient("nvidia")
-        with pytest.raises(ValueError) as exc_info:
-            await api_call(client)
-        error_msg = str(exc_info.value)
-        assert "Client not initialized" in error_msg
-        assert "Please call initialize() first" in error_msg

-    async def test_async_client_streaming_error_without_initialization(self):
-        """Test that async client streaming raises ValueError with helpful message when not initialized."""
-        client = AsyncLlamaStackAsLibraryClient("nvidia")
-        with pytest.raises(ValueError) as exc_info:
-            stream = await client.chat.completions.create(
-                model="test", messages=[{"role": "user", "content": "test"}], stream=True
-            )
-            await anext(stream)
-        error_msg = str(exc_info.value)
-        assert "Client not initialized" in error_msg
-        assert "Please call initialize() first" in error_msg

-    def test_route_impls_initialized_to_none(self):
-        """Test that route_impls is initialized to None to prevent AttributeError."""
-        # Test sync client
-        sync_client = LlamaStackAsLibraryClient("nvidia")
-        assert sync_client.async_client.route_impls is None
-        # Test async client directly
-        async_client = AsyncLlamaStackAsLibraryClient("nvidia")
-        assert async_client.route_impls is None

+    def test_sync_client_auto_initialization(self, monkeypatch):
+        """Test that sync client is automatically initialized after construction."""
+        # Mock the stack construction to avoid dependency issues
+        mock_impls = {}
+        mock_route_impls = RouteImpls({})
+
+        async def mock_construct_stack(config, custom_provider_registry):
+            return mock_impls
+
+        def mock_initialize_route_impls(impls):
+            return mock_route_impls
+
+        monkeypatch.setattr("llama_stack.core.library_client.construct_stack", mock_construct_stack)
+        monkeypatch.setattr("llama_stack.core.library_client.initialize_route_impls", mock_initialize_route_impls)
+
+        client = LlamaStackAsLibraryClient("ci-tests")
+
+        assert client.async_client.route_impls is not None
+
+    async def test_async_client_auto_initialization(self, monkeypatch):
+        """Test that async client can be initialized and works properly."""
+        # Mock the stack construction to avoid dependency issues
+        mock_impls = {}
+        mock_route_impls = RouteImpls({})
+
+        async def mock_construct_stack(config, custom_provider_registry):
+            return mock_impls
+
+        def mock_initialize_route_impls(impls):
+            return mock_route_impls
+
+        monkeypatch.setattr("llama_stack.core.library_client.construct_stack", mock_construct_stack)
+        monkeypatch.setattr("llama_stack.core.library_client.initialize_route_impls", mock_initialize_route_impls)
+
+        client = AsyncLlamaStackAsLibraryClient("ci-tests")
+
+        # Initialize the client
+        result = await client.initialize()
+        assert result is True
+        assert client.route_impls is not None
+
+    def test_initialize_method_backward_compatibility(self, monkeypatch):
+        """Test that initialize() method still works for backward compatibility."""
+        # Mock the stack construction to avoid dependency issues
+        mock_impls = {}
+        mock_route_impls = RouteImpls({})
+
+        async def mock_construct_stack(config, custom_provider_registry):
+            return mock_impls
+
+        def mock_initialize_route_impls(impls):
+            return mock_route_impls
+
+        monkeypatch.setattr("llama_stack.core.library_client.construct_stack", mock_construct_stack)
+        monkeypatch.setattr("llama_stack.core.library_client.initialize_route_impls", mock_initialize_route_impls)
+
+        client = LlamaStackAsLibraryClient("ci-tests")
+
+        result = client.initialize()
+        assert result is None
+
+        result2 = client.initialize()
+        assert result2 is None
+
+    async def test_async_initialize_method_idempotent(self, monkeypatch):
+        """Test that async initialize() method can be called multiple times safely."""
+        mock_impls = {}
+        mock_route_impls = RouteImpls({})
+
+        async def mock_construct_stack(config, custom_provider_registry):
+            return mock_impls
+
+        def mock_initialize_route_impls(impls):
+            return mock_route_impls
+
+        monkeypatch.setattr("llama_stack.core.library_client.construct_stack", mock_construct_stack)
+        monkeypatch.setattr("llama_stack.core.library_client.initialize_route_impls", mock_initialize_route_impls)
+
+        client = AsyncLlamaStackAsLibraryClient("ci-tests")
+
+        result1 = await client.initialize()
+        assert result1 is True
+
+        result2 = await client.initialize()
+        assert result2 is True
+
+    def test_route_impls_automatically_set(self, monkeypatch):
+        """Test that route_impls is automatically set during construction."""
+        mock_impls = {}
+        mock_route_impls = RouteImpls({})
+
+        async def mock_construct_stack(config, custom_provider_registry):
+            return mock_impls
+
+        def mock_initialize_route_impls(impls):
+            return mock_route_impls
+
+        monkeypatch.setattr("llama_stack.core.library_client.construct_stack", mock_construct_stack)
+        monkeypatch.setattr("llama_stack.core.library_client.initialize_route_impls", mock_initialize_route_impls)
+
+        sync_client = LlamaStackAsLibraryClient("ci-tests")
+        assert sync_client.async_client.route_impls is not None
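Read as a group, the new tests pin down the client contract: the sync client is usable straight after construction, while initialize() survives only for backward compatibility (returning None on the sync client and True, idempotently, on the async client). A small usage sketch under those assumptions, reusing the "ci-tests" distribution name from the tests:

from llama_stack.core.library_client import (
    AsyncLlamaStackAsLibraryClient,
    LlamaStackAsLibraryClient,
)

# Sync: ready immediately; initialize() is a no-op kept for older callers.
sync_client = LlamaStackAsLibraryClient("ci-tests")
assert sync_client.async_client.route_impls is not None
sync_client.initialize()  # returns None; safe to call (or skip) any number of times

# Async: initialize() still exists, returns True, and can be awaited more than once.
async def main() -> None:
    async_client = AsyncLlamaStackAsLibraryClient("ci-tests")
    assert await async_client.initialize() is True
    assert await async_client.initialize() is True  # idempotent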

View file

@ -0,0 +1,105 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import pytest
from llama_stack.core.datatypes import CORSConfig, process_cors_config
def test_cors_config_defaults():
config = CORSConfig()
assert config.allow_origins == []
assert config.allow_origin_regex is None
assert config.allow_methods == ["OPTIONS"]
assert config.allow_headers == []
assert config.allow_credentials is False
assert config.expose_headers == []
assert config.max_age == 600
def test_cors_config_explicit_config():
config = CORSConfig(
allow_origins=["https://example.com"], allow_credentials=True, max_age=3600, allow_methods=["GET", "POST"]
)
assert config.allow_origins == ["https://example.com"]
assert config.allow_credentials is True
assert config.max_age == 3600
assert config.allow_methods == ["GET", "POST"]
def test_cors_config_regex():
config = CORSConfig(allow_origins=[], allow_origin_regex=r"https?://localhost:\d+")
assert config.allow_origins == []
assert config.allow_origin_regex == r"https?://localhost:\d+"
def test_cors_config_wildcard_credentials_error():
with pytest.raises(ValueError, match="Cannot use wildcard origins with credentials enabled"):
CORSConfig(allow_origins=["*"], allow_credentials=True)
with pytest.raises(ValueError, match="Cannot use wildcard origins with credentials enabled"):
CORSConfig(allow_origins=["https://example.com", "*"], allow_credentials=True)
def test_process_cors_config_false():
result = process_cors_config(False)
assert result is None
def test_process_cors_config_true():
result = process_cors_config(True)
assert isinstance(result, CORSConfig)
assert result.allow_origins == []
assert result.allow_origin_regex == r"https?://localhost:\d+"
assert result.allow_credentials is False
expected_methods = ["GET", "POST", "PUT", "DELETE", "OPTIONS"]
for method in expected_methods:
assert method in result.allow_methods
def test_process_cors_config_passthrough():
original = CORSConfig(allow_origins=["https://example.com"], allow_methods=["GET"])
result = process_cors_config(original)
assert result is original
def test_process_cors_config_invalid_type():
with pytest.raises(ValueError, match="Expected bool or CORSConfig, got str"):
process_cors_config("invalid")
def test_cors_config_model_dump():
cors_config = CORSConfig(
allow_origins=["https://example.com"],
allow_methods=["GET", "POST"],
allow_headers=["Content-Type"],
allow_credentials=True,
max_age=3600,
)
config_dict = cors_config.model_dump()
assert config_dict["allow_origins"] == ["https://example.com"]
assert config_dict["allow_methods"] == ["GET", "POST"]
assert config_dict["allow_headers"] == ["Content-Type"]
assert config_dict["allow_credentials"] is True
assert config_dict["max_age"] == 3600
expected_keys = {
"allow_origins",
"allow_origin_regex",
"allow_methods",
"allow_headers",
"allow_credentials",
"expose_headers",
"max_age",
}
assert set(config_dict.keys()) == expected_keys
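Since model_dump() exposes exactly the keyword arguments Starlette's CORSMiddleware accepts, one plausible way a server could consume the processed config is sketched below; this is a wiring sketch under that assumption, not necessarily how llama_stack applies it:

from fastapi import FastAPI
from starlette.middleware.cors import CORSMiddleware

from llama_stack.core.datatypes import process_cors_config

app = FastAPI()
cors = process_cors_config(True)  # True -> localhost dev defaults; False -> None (CORS disabled)
if cors is not None:
    app.add_middleware(CORSMiddleware, **cors.model_dump())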