Merge branch 'main' into use-openai-for-databricks

This commit is contained in:
Matthew Farrellee 2025-09-20 06:16:54 -04:00
commit c8623607f5
31 changed files with 665 additions and 691 deletions

View file

@@ -17,11 +17,8 @@ export POSTGRES_PASSWORD=llamastack
 export INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct
 export SAFETY_MODEL=meta-llama/Llama-Guard-3-1B
-export MOCK_INFERENCE_MODEL=mock-inference
-export MOCK_INFERENCE_URL=openai-mock-service:8080
 export BENCHMARK_INFERENCE_MODEL=$INFERENCE_MODEL
+export LLAMA_STACK_WORKERS=4
 
 set -euo pipefail
 set -x

View file

@@ -5,6 +5,7 @@ data:
   image_name: kubernetes-benchmark-demo
   apis:
   - agents
+  - files
  - inference
  - safety
  - telemetry
@@ -23,6 +24,14 @@ data:
      - provider_id: sentence-transformers
        provider_type: inline::sentence-transformers
        config: {}
+    files:
+    - provider_id: meta-reference-files
+      provider_type: inline::localfs
+      config:
+        storage_dir: ${env.FILES_STORAGE_DIR:=~/.llama/distributions/starter/files}
+        metadata_store:
+          type: sqlite
+          db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/files_metadata.db
     vector_io:
     - provider_id: ${env.ENABLE_CHROMADB:+chromadb}
       provider_type: remote::chromadb
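
Reviewer note: the new `files` block relies on the run config's environment substitution operators visible above — `${env.VAR:=default}` expands to the variable's value or a default, while `${env.VAR:+value}` (used for `ENABLE_CHROMADB`) expands only when the variable is set. A minimal sketch of those semantics; the real implementation is `replace_env_vars` in `llama_stack.core.stack` (imported elsewhere in this diff), and the helper name here is ours:

```python
import os
import re

# Illustrative only: mimics the ${env.VAR:=default} and ${env.VAR:+value}
# operators used in the run config above.
_ENV_PATTERN = re.compile(r"\$\{env\.(\w+):([=+])([^}]*)\}")

def substitute_env(text: str) -> str:
    def repl(match: re.Match) -> str:
        name, op, literal = match.groups()
        if op == "=":  # ':=' means: use the env var if set, else the default
            return os.environ.get(name, literal)
        # ':+' means: use the literal only when the env var is set, else empty
        return literal if os.environ.get(name) else ""
    return _ENV_PATTERN.sub(repl, text)

print(substitute_env("${env.FILES_STORAGE_DIR:=~/.llama/distributions/starter/files}"))
print(substitute_env("${env.ENABLE_CHROMADB:+chromadb}"))
```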

View file

@@ -52,9 +52,20 @@ spec:
           value: http://vllm-server-safety.default.svc.cluster.local:8001/v1
         - name: VLLM_TLS_VERIFY
           value: "false"
-        command: ["python", "-m", "llama_stack.core.server.server", "/etc/config/stack_run_config.yaml", "--port", "8323"]
+        - name: LLAMA_STACK_LOGGING
+          value: "all=WARNING"
+        - name: LLAMA_STACK_CONFIG
+          value: "/etc/config/stack_run_config.yaml"
+        - name: LLAMA_STACK_WORKERS
+          value: "${LLAMA_STACK_WORKERS}"
+        command: ["uvicorn", "llama_stack.core.server.server:create_app", "--host", "0.0.0.0", "--port", "8323", "--workers", "$LLAMA_STACK_WORKERS", "--factory"]
         ports:
         - containerPort: 8323
+        resources:
+          requests:
+            cpu: "${LLAMA_STACK_WORKERS}"
+          limits:
+            cpu: "${LLAMA_STACK_WORKERS}"
         volumeMounts:
         - name: llama-storage
           mountPath: /root/.llama
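
For reviewers parsing the new command: `--factory` tells uvicorn that `llama_stack.core.server.server:create_app` is a callable returning the application rather than an application object, which is what allows `--workers` to spawn N processes that each build their own stack. A rough programmatic equivalent (the literal worker count is an assumption standing in for `$LLAMA_STACK_WORKERS`):

```python
import uvicorn

# Roughly what the container command above does: each worker process imports
# the factory and calls it, so every worker constructs its own Stack.
if __name__ == "__main__":
    uvicorn.run(
        "llama_stack.core.server.server:create_app",
        host="0.0.0.0",
        port=8323,
        workers=4,     # the manifest wires this to $LLAMA_STACK_WORKERS
        factory=True,  # call create_app() instead of treating it as an app
    )
```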

View file

@@ -460,10 +460,12 @@ client = LlamaStackClient(base_url="http://localhost:8321")
 embed_lm = next(m for m in client.models.list() if m.model_type == "embedding")
 embedding_model = embed_lm.identifier
 vector_db_id = f"v{uuid.uuid4().hex}"
-client.vector_dbs.register(
+# The VectorDB API is deprecated; the server now returns its own authoritative ID.
+# We capture the correct ID from the response's .identifier attribute.
+vector_db_id = client.vector_dbs.register(
     vector_db_id=vector_db_id,
     embedding_model=embedding_model,
-)
+).identifier

 # Create Documents
 urls = [
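
The hazard the new comment is pointing at, spelled out: if the server rewrites the requested ID, any follow-up call made with the client-generated string would target a vector DB that does not exist. A short sketch of the safe pattern, continuing the snippet above:

```python
# Treat the client-side ID as a hint; the server's response is authoritative.
requested_id = f"v{uuid.uuid4().hex}"
vector_db_id = client.vector_dbs.register(
    vector_db_id=requested_id,
    embedding_model=embedding_model,
).identifier

if vector_db_id != requested_id:
    print(f"server renamed {requested_id!r} -> {vector_db_id!r}")
```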

View file

@@ -23,7 +23,13 @@ To use Milvus in your Llama Stack project, follow these steps:
 ## Installation
 
-You can install Milvus using pymilvus:
+If you want to use inline Milvus, you can install:
+
+```bash
+pip install pymilvus[milvus-lite]
+```
+
+If you want to use remote Milvus, you can install:
 
 ```bash
 pip install pymilvus
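
As a quick smoke test of the inline (milvus-lite) install, something like the following should work — the database file path is arbitrary, and pymilvus spins up an embedded instance when handed a file path instead of a server URI:

```python
from pymilvus import MilvusClient

# A file path creates an embedded milvus-lite instance; a remote deployment
# would instead pass a server URI such as "http://localhost:19530".
client = MilvusClient("./milvus_demo.db")
client.create_collection(collection_name="smoke_test", dimension=8)
print(client.list_collections())
```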

View file

@@ -121,10 +121,6 @@ class AutoRoutedProviderSpec(ProviderSpec):
         default=None,
     )
 
-    @property
-    def pip_packages(self) -> list[str]:
-        raise AssertionError("Should not be called on AutoRoutedProviderSpec")
-
 # Example: /models, /shields
 class RoutingTableProviderSpec(ProviderSpec):

View file

@@ -16,11 +16,10 @@ from llama_stack.core.datatypes import BuildConfig, DistributionSpec
 from llama_stack.core.external import load_external_apis
 from llama_stack.log import get_logger
 from llama_stack.providers.datatypes import (
-    AdapterSpec,
     Api,
     InlineProviderSpec,
     ProviderSpec,
-    remote_provider_spec,
+    RemoteProviderSpec,
 )
 
 logger = get_logger(name=__name__, category="core")
@@ -77,27 +76,12 @@ def providable_apis() -> list[Api]:
 
 def _load_remote_provider_spec(spec_data: dict[str, Any], api: Api) -> ProviderSpec:
-    adapter = AdapterSpec(**spec_data["adapter"])
-    spec = remote_provider_spec(
-        api=api,
-        adapter=adapter,
-        api_dependencies=[Api(dep) for dep in spec_data.get("api_dependencies", [])],
-    )
+    spec = RemoteProviderSpec(api=api, provider_type=f"remote::{spec_data['adapter_type']}", **spec_data)
     return spec
 
 
 def _load_inline_provider_spec(spec_data: dict[str, Any], api: Api, provider_name: str) -> ProviderSpec:
-    spec = InlineProviderSpec(
-        api=api,
-        provider_type=f"inline::{provider_name}",
-        pip_packages=spec_data.get("pip_packages", []),
-        module=spec_data["module"],
-        config_class=spec_data["config_class"],
-        api_dependencies=[Api(dep) for dep in spec_data.get("api_dependencies", [])],
-        optional_api_dependencies=[Api(dep) for dep in spec_data.get("optional_api_dependencies", [])],
-        provider_data_validator=spec_data.get("provider_data_validator"),
-        container_image=spec_data.get("container_image"),
-    )
+    spec = InlineProviderSpec(api=api, provider_type=f"inline::{provider_name}", **spec_data)
     return spec
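
Net effect of this simplification for external providers: the spec file is now splatted directly into the pydantic model, so its keys must line up exactly with `RemoteProviderSpec` / `InlineProviderSpec` field names. A hypothetical remote `spec_data` payload (the provider name and module are invented for illustration; the keys mirror fields used elsewhere in this PR):

```python
# Hypothetical spec_data, as loaded from an external provider's YAML file.
spec_data = {
    "adapter_type": "acme",
    "module": "acme_stack.providers.remote.inference.acme",
    "config_class": "acme_stack.providers.remote.inference.acme.AcmeConfig",
    "pip_packages": ["acme-sdk"],
}

# The new loader then does, per the diff above:
#   RemoteProviderSpec(api=api, provider_type=f"remote::{spec_data['adapter_type']}", **spec_data)
```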

View file

@@ -40,7 +40,7 @@ from llama_stack.core.request_headers import (
 from llama_stack.core.resolver import ProviderRegistry
 from llama_stack.core.server.routes import RouteImpls, find_matching_route, initialize_route_impls
 from llama_stack.core.stack import (
-    construct_stack,
+    Stack,
     get_stack_run_config_from_distro,
     replace_env_vars,
 )
@@ -252,7 +252,10 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
         try:
             self.route_impls = None
-            self.impls = await construct_stack(self.config, self.custom_provider_registry)
+
+            stack = Stack(self.config, self.custom_provider_registry)
+            await stack.initialize()
+            self.impls = stack.impls
         except ModuleNotFoundError as _e:
             cprint(_e.msg, color="red", file=sys.stderr)
             cprint(
@@ -289,6 +292,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
             )
             raise _e
 
+        assert self.impls is not None
         if Api.telemetry in self.impls:
             setup_logger(self.impls[Api.telemetry])
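
The same three-step lifecycle applies anywhere `construct_stack()` used to be called. A minimal async sketch (import paths assumed from the surrounding diff; config loading elided):

```python
from llama_stack.core.datatypes import StackRunConfig  # assumed location
from llama_stack.core.stack import Stack

async def run_with_stack(config: StackRunConfig) -> None:
    # Old: impls = await construct_stack(config)
    stack = Stack(config)
    await stack.initialize()
    impls = stack.impls  # dict mapping Api -> implementation, as before
    try:
        ...  # use impls
    finally:
        await stack.shutdown()  # replaces the old shutdown_stack(impls)
```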

View file

@@ -6,6 +6,7 @@
 
 import argparse
 import asyncio
+import concurrent.futures
 import functools
 import inspect
 import json
@@ -50,17 +51,15 @@ from llama_stack.core.request_headers import (
     request_provider_data_context,
     user_from_scope,
 )
-from llama_stack.core.resolver import InvalidProviderError
 from llama_stack.core.server.routes import (
     find_matching_route,
     get_all_api_routes,
     initialize_route_impls,
 )
 from llama_stack.core.stack import (
+    Stack,
     cast_image_name_to_string,
-    construct_stack,
     replace_env_vars,
-    shutdown_stack,
     validate_env_pair,
 )
 from llama_stack.core.utils.config import redact_sensitive_fields
@@ -156,21 +155,34 @@ def translate_exception(exc: Exception) -> HTTPException | RequestValidationErro
     )
 
 
-async def shutdown(app):
-    """Initiate a graceful shutdown of the application.
-
-    Handled by the lifespan context manager. The shutdown process involves
-    shutting down all implementations registered in the application.
+class StackApp(FastAPI):
     """
-    await shutdown_stack(app.__llama_stack_impls__)
+    A wrapper around the FastAPI application to hold a reference to the Stack instance so that we can
+    start background tasks (e.g. refresh model registry periodically) from the lifespan context manager.
+    """
+
+    def __init__(self, config: StackRunConfig, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.stack: Stack = Stack(config)
+
+        # This code is called from a running event loop managed by uvicorn so we cannot simply call
+        # asyncio.run() to initialize the stack. We cannot await either since this is not an async
+        # function.
+        # As a workaround, we use a thread pool executor to run the initialize() method
+        # in a separate thread.
+        with concurrent.futures.ThreadPoolExecutor() as executor:
+            future = executor.submit(asyncio.run, self.stack.initialize())
+            future.result()
 
 
 @asynccontextmanager
-async def lifespan(app: FastAPI):
+async def lifespan(app: StackApp):
     logger.info("Starting up")
+    assert app.stack is not None
+    app.stack.create_registry_refresh_task()
     yield
 
     logger.info("Shutting down")
-    await shutdown(app)
+    await app.stack.shutdown()
 
 
 def is_streaming_request(func_name: str, request: Request, **kwargs):
@@ -386,73 +398,61 @@ class ClientVersionMiddleware:
         return await self.app(scope, receive, send)
 
 
-def main(args: argparse.Namespace | None = None):
-    """Start the LlamaStack server."""
-    parser = argparse.ArgumentParser(description="Start the LlamaStack server.")
-
-    add_config_distro_args(parser)
-    parser.add_argument(
-        "--port",
-        type=int,
-        default=int(os.getenv("LLAMA_STACK_PORT", 8321)),
-        help="Port to listen on",
-    )
-    parser.add_argument(
-        "--env",
-        action="append",
-        help="Environment variables in KEY=value format. Can be specified multiple times.",
-    )
-
-    # Determine whether the server args are being passed by the "run" command, if this is the case
-    # the args will be passed as a Namespace object to the main function, otherwise they will be
-    # parsed from the command line
-    if args is None:
-        args = parser.parse_args()
-
-    config_or_distro = get_config_from_args(args)
-    config_file = resolve_config_or_distro(config_or_distro, Mode.RUN)
+def create_app(
+    config_file: str | None = None,
+    env_vars: list[str] | None = None,
+) -> StackApp:
+    """Create and configure the FastAPI application.
+
+    Args:
+        config_file: Path to config file. If None, uses LLAMA_STACK_CONFIG env var or default resolution.
+        env_vars: List of environment variables in KEY=value format.
+        disable_version_check: Whether to disable version checking. If None, uses LLAMA_STACK_DISABLE_VERSION_CHECK env var.
+
+    Returns:
+        Configured StackApp instance.
+    """
+    config_file = config_file or os.getenv("LLAMA_STACK_CONFIG")
+    if config_file is None:
+        raise ValueError("No config file provided and LLAMA_STACK_CONFIG env var is not set")
+
+    config_file = resolve_config_or_distro(config_file, Mode.RUN)
 
+    # Load and process configuration
     logger_config = None
     with open(config_file) as fp:
         config_contents = yaml.safe_load(fp)
         if isinstance(config_contents, dict) and (cfg := config_contents.get("logging_config")):
             logger_config = LoggingConfig(**cfg)
         logger = get_logger(name=__name__, category="core::server", config=logger_config)
-    if args.env:
-        for env_pair in args.env:
+
+    if env_vars:
+        for env_pair in env_vars:
             try:
                 key, value = validate_env_pair(env_pair)
-                logger.info(f"Setting CLI environment variable {key} => {value}")
+                logger.info(f"Setting environment variable {key} => {value}")
                 os.environ[key] = value
             except ValueError as e:
                 logger.error(f"Error: {str(e)}")
-                sys.exit(1)
+                raise ValueError(f"Invalid environment variable format: {env_pair}") from e
 
     config = replace_env_vars(config_contents)
     config = StackRunConfig(**cast_image_name_to_string(config))
     _log_run_config(run_config=config)
 
-    app = FastAPI(
+    app = StackApp(
         lifespan=lifespan,
         docs_url="/docs",
         redoc_url="/redoc",
         openapi_url="/openapi.json",
+        config=config,
     )
 
     if not os.environ.get("LLAMA_STACK_DISABLE_VERSION_CHECK"):
         app.add_middleware(ClientVersionMiddleware)
 
-    try:
-        # Create and set the event loop that will be used for both construction and server runtime
-        loop = asyncio.new_event_loop()
-        asyncio.set_event_loop(loop)
-        # Construct the stack in the persistent event loop
-        impls = loop.run_until_complete(construct_stack(config))
-    except InvalidProviderError as e:
-        logger.error(f"Error: {str(e)}")
-        sys.exit(1)
+    impls = app.stack.impls
 
     if config.server.auth:
         logger.info(f"Enabling authentication with provider: {config.server.auth.provider_config.type.value}")
@@ -553,9 +553,54 @@ def main(args: argparse.Namespace | None = None):
     app.exception_handler(RequestValidationError)(global_exception_handler)
     app.exception_handler(Exception)(global_exception_handler)
 
-    app.__llama_stack_impls__ = impls
     app.add_middleware(TracingMiddleware, impls=impls, external_apis=external_apis)
 
+    return app
+
+
+def main(args: argparse.Namespace | None = None):
+    """Start the LlamaStack server."""
+    parser = argparse.ArgumentParser(description="Start the LlamaStack server.")
+    add_config_distro_args(parser)
+    parser.add_argument(
+        "--port",
+        type=int,
+        default=int(os.getenv("LLAMA_STACK_PORT", 8321)),
+        help="Port to listen on",
+    )
+    parser.add_argument(
+        "--env",
+        action="append",
+        help="Environment variables in KEY=value format. Can be specified multiple times.",
+    )
+
+    # Determine whether the server args are being passed by the "run" command, if this is the case
+    # the args will be passed as a Namespace object to the main function, otherwise they will be
+    # parsed from the command line
+    if args is None:
+        args = parser.parse_args()
+
+    config_or_distro = get_config_from_args(args)
+
+    try:
+        app = create_app(
+            config_file=config_or_distro,
+            env_vars=args.env,
+        )
+    except Exception as e:
+        logger.error(f"Error creating app: {str(e)}")
+        sys.exit(1)
+
+    config_file = resolve_config_or_distro(config_or_distro, Mode.RUN)
+    with open(config_file) as fp:
+        config_contents = yaml.safe_load(fp)
+        if isinstance(config_contents, dict) and (cfg := config_contents.get("logging_config")):
+            logger_config = LoggingConfig(**cfg)
+        else:
+            logger_config = None
+        config = StackRunConfig(**cast_image_name_to_string(replace_env_vars(config_contents)))
+
     import uvicorn
 
     # Configure SSL if certificates are provided
@@ -593,7 +638,6 @@ def main(args: argparse.Namespace | None = None):
     if ssl_config:
         uvicorn_config.update(ssl_config)
 
-    # Run uvicorn in the existing event loop to preserve background tasks
     # We need to catch KeyboardInterrupt because uvicorn's signal handling
     # re-raises SIGINT signals using signal.raise_signal(), which Python
     # converts to KeyboardInterrupt. Without this catch, we'd get a confusing
@@ -604,13 +648,9 @@ def main(args: argparse.Namespace | None = None):
     # Another approach would be to ignore SIGINT entirely - let uvicorn handle it through its own
     # signal handling but this is quite intrusive and not worth the effort.
     try:
-        loop.run_until_complete(uvicorn.Server(uvicorn.Config(**uvicorn_config)).serve())
+        asyncio.run(uvicorn.Server(uvicorn.Config(**uvicorn_config)).serve())
     except (KeyboardInterrupt, SystemExit):
         logger.info("Received interrupt signal, shutting down gracefully...")
-    finally:
-        if not loop.is_closed():
-            logger.debug("Closing event loop")
-            loop.close()
 
 
 def _log_run_config(run_config: StackRunConfig):
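
The workaround in `StackApp.__init__` above deserves a standalone illustration: `asyncio.run()` raises `RuntimeError` on a thread whose event loop is already running (uvicorn invokes the app factory from inside its loop), so the coroutine is shipped to a worker thread that owns a fresh loop. A self-contained sketch of the pattern:

```python
import asyncio
import concurrent.futures

async def initialize() -> str:
    await asyncio.sleep(0.1)  # stand-in for Stack.initialize()
    return "ready"

def blocking_init() -> str:
    # asyncio.run() would fail here if this thread already runs a loop, so
    # run the coroutine on a fresh loop in a worker thread and block on it.
    with concurrent.futures.ThreadPoolExecutor() as executor:
        future = executor.submit(asyncio.run, initialize())
        return future.result()

print(blocking_init())
```

The trade-off: initialization briefly blocks the calling thread, but the factory stays a plain synchronous callable, which is how uvicorn's `--factory` mode invokes it.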

View file

@@ -315,11 +315,15 @@ def add_internal_implementations(impls: dict[Api, Any], run_config: StackRunConf
     impls[Api.prompts] = prompts_impl
 
 
-# Produces a stack of providers for the given run config. Not all APIs may be
-# asked for in the run config.
-async def construct_stack(
-    run_config: StackRunConfig, provider_registry: ProviderRegistry | None = None
-) -> dict[Api, Any]:
+class Stack:
+    def __init__(self, run_config: StackRunConfig, provider_registry: ProviderRegistry | None = None):
+        self.run_config = run_config
+        self.provider_registry = provider_registry
+        self.impls = None
+
+    # Produces a stack of providers for the given run config. Not all APIs may be
+    # asked for in the run config.
+    async def initialize(self):
         if "LLAMA_STACK_TEST_INFERENCE_MODE" in os.environ:
             from llama_stack.testing.inference_recorder import setup_inference_recording
@@ -329,24 +333,28 @@ async def construct_stack(
             TEST_RECORDING_CONTEXT.__enter__()
             logger.info(f"Inference recording enabled: mode={os.environ.get('LLAMA_STACK_TEST_INFERENCE_MODE')}")
 
-        dist_registry, _ = await create_dist_registry(run_config.metadata_store, run_config.image_name)
-        policy = run_config.server.auth.access_policy if run_config.server.auth else []
+        dist_registry, _ = await create_dist_registry(self.run_config.metadata_store, self.run_config.image_name)
+        policy = self.run_config.server.auth.access_policy if self.run_config.server.auth else []
         impls = await resolve_impls(
-            run_config, provider_registry or get_provider_registry(run_config), dist_registry, policy
+            self.run_config, self.provider_registry or get_provider_registry(self.run_config), dist_registry, policy
         )
 
         # Add internal implementations after all other providers are resolved
-        add_internal_implementations(impls, run_config)
+        add_internal_implementations(impls, self.run_config)
 
         if Api.prompts in impls:
             await impls[Api.prompts].initialize()
 
-        await register_resources(run_config, impls)
+        await register_resources(self.run_config, impls)
 
         await refresh_registry_once(impls)
+        self.impls = impls
 
+    def create_registry_refresh_task(self):
+        assert self.impls is not None, "Must call initialize() before starting"
 
         global REGISTRY_REFRESH_TASK
-        REGISTRY_REFRESH_TASK = asyncio.create_task(refresh_registry_task(impls))
+        REGISTRY_REFRESH_TASK = asyncio.create_task(refresh_registry_task(self.impls))
 
         def cb(task):
             import traceback
@@ -360,11 +368,9 @@ async def construct_stack(
                 logger.debug("Model refresh task completed")
 
         REGISTRY_REFRESH_TASK.add_done_callback(cb)
 
-    return impls
-
-
-async def shutdown_stack(impls: dict[Api, Any]):
-    for impl in impls.values():
+    async def shutdown(self):
+        for impl in self.impls.values():
         impl_name = impl.__class__.__name__
         logger.info(f"Shutting down {impl_name}")
         try:

View file

@@ -78,12 +78,12 @@ def get_remote_inference_providers() -> list[Provider]:
     remote_providers = [
         provider
         for provider in available_providers()
-        if isinstance(provider, RemoteProviderSpec) and provider.adapter.adapter_type in ENABLED_INFERENCE_PROVIDERS
+        if isinstance(provider, RemoteProviderSpec) and provider.adapter_type in ENABLED_INFERENCE_PROVIDERS
     ]
 
     inference_providers = []
     for provider_spec in remote_providers:
-        provider_type = provider_spec.adapter.adapter_type
+        provider_type = provider_spec.adapter_type
 
         if provider_type in INFERENCE_PROVIDER_IDS:
             provider_id = INFERENCE_PROVIDER_IDS[provider_type]

View file

@@ -131,6 +131,15 @@ class ProviderSpec(BaseModel):
     """,
     )
 
+    pip_packages: list[str] = Field(
+        default_factory=list,
+        description="The pip dependencies needed for this implementation",
+    )
+
+    provider_data_validator: str | None = Field(
+        default=None,
+    )
+
     is_external: bool = Field(default=False, description="Notes whether this provider is an external provider.")
 
     # used internally by the resolver; this is a hack for now
@@ -145,45 +154,8 @@ class RoutingTable(Protocol):
     async def get_provider_impl(self, routing_key: str) -> Any: ...
 
 
-# TODO: this can now be inlined into RemoteProviderSpec
-@json_schema_type
-class AdapterSpec(BaseModel):
-    adapter_type: str = Field(
-        ...,
-        description="Unique identifier for this adapter",
-    )
-    module: str = Field(
-        default_factory=str,
-        description="""
-Fully-qualified name of the module to import. The module is expected to have:
-
- - `get_adapter_impl(config, deps)`: returns the adapter implementation
-""",
-    )
-    pip_packages: list[str] = Field(
-        default_factory=list,
-        description="The pip dependencies needed for this implementation",
-    )
-    config_class: str = Field(
-        description="Fully-qualified classname of the config for this provider",
-    )
-    provider_data_validator: str | None = Field(
-        default=None,
-    )
-    description: str | None = Field(
-        default=None,
-        description="""
-A description of the provider. This is used to display in the documentation.
-""",
-    )
-
-
 @json_schema_type
 class InlineProviderSpec(ProviderSpec):
-    pip_packages: list[str] = Field(
-        default_factory=list,
-        description="The pip dependencies needed for this implementation",
-    )
     container_image: str | None = Field(
         default=None,
         description="""
@@ -191,10 +163,6 @@ The container image to use for this implementation. If one is provided, pip_pack
 If a provider depends on other providers, the dependencies MUST NOT specify a container image.
 """,
     )
-    # module field is inherited from ProviderSpec
-    provider_data_validator: str | None = Field(
-        default=None,
-    )
     description: str | None = Field(
         default=None,
         description="""
@@ -223,10 +191,15 @@ class RemoteProviderConfig(BaseModel):
 @json_schema_type
 class RemoteProviderSpec(ProviderSpec):
-    adapter: AdapterSpec = Field(
+    adapter_type: str = Field(
         ...,
+        description="Unique identifier for this adapter",
+    )
+
+    description: str | None = Field(
+        default=None,
         description="""
-If some code is needed to convert the remote responses into Llama Stack compatible
-API responses, specify the adapter here.
+A description of the provider. This is used to display in the documentation.
 """,
     )
@@ -234,33 +207,6 @@ API responses, specify the adapter here.
     def container_image(self) -> str | None:
         return None
 
-    # module field is inherited from ProviderSpec
-    @property
-    def pip_packages(self) -> list[str]:
-        return self.adapter.pip_packages
-
-    @property
-    def provider_data_validator(self) -> str | None:
-        return self.adapter.provider_data_validator
-
-
-def remote_provider_spec(
-    api: Api,
-    adapter: AdapterSpec,
-    api_dependencies: list[Api] | None = None,
-    optional_api_dependencies: list[Api] | None = None,
-) -> RemoteProviderSpec:
-    return RemoteProviderSpec(
-        api=api,
-        provider_type=f"remote::{adapter.adapter_type}",
-        config_class=adapter.config_class,
-        module=adapter.module,
-        adapter=adapter,
-        api_dependencies=api_dependencies or [],
-        optional_api_dependencies=optional_api_dependencies or [],
-    )
-
 
 class HealthStatus(StrEnum):
     OK = "OK"
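
For registry authors, the shape of the migration in one place (the provider values are copied from the ollama entry elsewhere in this PR):

```python
from llama_stack.providers.datatypes import Api, RemoteProviderSpec

# Before: a nested AdapterSpec wrapped by the remote_provider_spec() helper.
# remote_provider_spec(
#     api=Api.inference,
#     adapter=AdapterSpec(
#         adapter_type="ollama",
#         pip_packages=["ollama", "aiohttp", "h11>=0.16.0"],
#         module="llama_stack.providers.remote.inference.ollama",
#         config_class="llama_stack.providers.remote.inference.ollama.OllamaImplConfig",
#     ),
# )

# After: one flat RemoteProviderSpec; provider_type is now spelled out.
spec = RemoteProviderSpec(
    api=Api.inference,
    adapter_type="ollama",
    provider_type="remote::ollama",
    pip_packages=["ollama", "aiohttp", "h11>=0.16.0"],
    module="llama_stack.providers.remote.inference.ollama",
    config_class="llama_stack.providers.remote.inference.ollama.OllamaImplConfig",
)
```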

View file

@@ -44,7 +44,7 @@ class LocalfsFilesImpl(Files):
         storage_path.mkdir(parents=True, exist_ok=True)
 
         # Initialize SQL store for metadata
-        self.sql_store = AuthorizedSqlStore(sqlstore_impl(self.config.metadata_store))
+        self.sql_store = AuthorizedSqlStore(sqlstore_impl(self.config.metadata_store), self.policy)
         await self.sql_store.create_table(
             "openai_files",
             {
@@ -74,7 +74,7 @@ class LocalfsFilesImpl(Files):
         if not self.sql_store:
             raise RuntimeError("Files provider not initialized")
 
-        row = await self.sql_store.fetch_one("openai_files", policy=self.policy, where={"id": file_id})
+        row = await self.sql_store.fetch_one("openai_files", where={"id": file_id})
         if not row:
             raise ResourceNotFoundError(file_id, "File", "client.files.list()")
@@ -150,7 +150,6 @@ class LocalfsFilesImpl(Files):
         paginated_result = await self.sql_store.fetch_all(
             table="openai_files",
-            policy=self.policy,
             where=where_conditions if where_conditions else None,
             order_by=[("created_at", order.value)],
             cursor=("id", after) if after else None,
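
The access-control change in this file is structural rather than behavioral: the policy moves from a per-call keyword into the `AuthorizedSqlStore` constructor, so an individual query can no longer forget it. Abridged shapes taken from the hunks above:

```python
# Before: every call site had to remember to thread the policy through.
#   self.sql_store = AuthorizedSqlStore(sqlstore_impl(self.config.metadata_store))
#   row = await self.sql_store.fetch_one("openai_files", policy=self.policy, where={"id": file_id})

# After: the policy is fixed once at construction and applied to every query.
#   self.sql_store = AuthorizedSqlStore(sqlstore_impl(self.config.metadata_store), self.policy)
#   row = await self.sql_store.fetch_one("openai_files", where={"id": file_id})
```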

View file

@ -6,11 +6,10 @@
from llama_stack.providers.datatypes import ( from llama_stack.providers.datatypes import (
AdapterSpec,
Api, Api,
InlineProviderSpec, InlineProviderSpec,
ProviderSpec, ProviderSpec,
remote_provider_spec, RemoteProviderSpec,
) )
@ -25,10 +24,10 @@ def available_providers() -> list[ProviderSpec]:
api_dependencies=[], api_dependencies=[],
description="Local filesystem-based dataset I/O provider for reading and writing datasets to local storage.", description="Local filesystem-based dataset I/O provider for reading and writing datasets to local storage.",
), ),
remote_provider_spec( RemoteProviderSpec(
api=Api.datasetio, api=Api.datasetio,
adapter=AdapterSpec(
adapter_type="huggingface", adapter_type="huggingface",
provider_type="remote::huggingface",
pip_packages=[ pip_packages=[
"datasets>=4.0.0", "datasets>=4.0.0",
], ],
@ -36,17 +35,15 @@ def available_providers() -> list[ProviderSpec]:
config_class="llama_stack.providers.remote.datasetio.huggingface.HuggingfaceDatasetIOConfig", config_class="llama_stack.providers.remote.datasetio.huggingface.HuggingfaceDatasetIOConfig",
description="HuggingFace datasets provider for accessing and managing datasets from the HuggingFace Hub.", description="HuggingFace datasets provider for accessing and managing datasets from the HuggingFace Hub.",
), ),
), RemoteProviderSpec(
remote_provider_spec(
api=Api.datasetio, api=Api.datasetio,
adapter=AdapterSpec(
adapter_type="nvidia", adapter_type="nvidia",
provider_type="remote::nvidia",
module="llama_stack.providers.remote.datasetio.nvidia",
config_class="llama_stack.providers.remote.datasetio.nvidia.NvidiaDatasetIOConfig",
pip_packages=[ pip_packages=[
"datasets>=4.0.0", "datasets>=4.0.0",
], ],
module="llama_stack.providers.remote.datasetio.nvidia",
config_class="llama_stack.providers.remote.datasetio.nvidia.NvidiaDatasetIOConfig",
description="NVIDIA's dataset I/O provider for accessing datasets from NVIDIA's data platform.", description="NVIDIA's dataset I/O provider for accessing datasets from NVIDIA's data platform.",
), ),
),
] ]

View file

@@ -5,7 +5,7 @@
 # the root directory of this source tree.
 
-from llama_stack.providers.datatypes import AdapterSpec, Api, InlineProviderSpec, ProviderSpec, remote_provider_spec
+from llama_stack.providers.datatypes import Api, InlineProviderSpec, ProviderSpec, RemoteProviderSpec
 
 
 def available_providers() -> list[ProviderSpec]:
@@ -25,17 +25,16 @@ def available_providers() -> list[ProviderSpec]:
         ],
         description="Meta's reference implementation of evaluation tasks with support for multiple languages and evaluation metrics.",
     ),
-    remote_provider_spec(
+    RemoteProviderSpec(
         api=Api.eval,
-        adapter=AdapterSpec(
         adapter_type="nvidia",
         pip_packages=[
             "requests",
         ],
+        provider_type="remote::nvidia",
         module="llama_stack.providers.remote.eval.nvidia",
         config_class="llama_stack.providers.remote.eval.nvidia.NVIDIAEvalConfig",
         description="NVIDIA's evaluation provider for running evaluation tasks on NVIDIA's platform.",
-        ),
         api_dependencies=[
             Api.datasetio,
             Api.datasets,

View file

@@ -4,13 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from llama_stack.providers.datatypes import (
-    AdapterSpec,
-    Api,
-    InlineProviderSpec,
-    ProviderSpec,
-    remote_provider_spec,
-)
+from llama_stack.providers.datatypes import Api, InlineProviderSpec, ProviderSpec, RemoteProviderSpec
 from llama_stack.providers.utils.sqlstore.sqlstore import sql_store_pip_packages
@@ -25,14 +19,13 @@ def available_providers() -> list[ProviderSpec]:
         config_class="llama_stack.providers.inline.files.localfs.config.LocalfsFilesImplConfig",
         description="Local filesystem-based file storage provider for managing files and documents locally.",
     ),
-    remote_provider_spec(
+    RemoteProviderSpec(
         api=Api.files,
-        adapter=AdapterSpec(
-            adapter_type="s3",
+        provider_type="remote::s3",
+        adapter_type="s3",
         pip_packages=["boto3"] + sql_store_pip_packages,
         module="llama_stack.providers.remote.files.s3",
         config_class="llama_stack.providers.remote.files.s3.config.S3FilesImplConfig",
         description="AWS S3-based file storage provider for scalable cloud file management with metadata persistence.",
     ),
-    ),
 ]

View file

@ -6,11 +6,10 @@
from llama_stack.providers.datatypes import ( from llama_stack.providers.datatypes import (
AdapterSpec,
Api, Api,
InlineProviderSpec, InlineProviderSpec,
ProviderSpec, ProviderSpec,
remote_provider_spec, RemoteProviderSpec,
) )
META_REFERENCE_DEPS = [ META_REFERENCE_DEPS = [
@ -49,10 +48,10 @@ def available_providers() -> list[ProviderSpec]:
config_class="llama_stack.providers.inline.inference.sentence_transformers.config.SentenceTransformersInferenceConfig", config_class="llama_stack.providers.inline.inference.sentence_transformers.config.SentenceTransformersInferenceConfig",
description="Sentence Transformers inference provider for text embeddings and similarity search.", description="Sentence Transformers inference provider for text embeddings and similarity search.",
), ),
remote_provider_spec( RemoteProviderSpec(
api=Api.inference, api=Api.inference,
adapter=AdapterSpec(
adapter_type="cerebras", adapter_type="cerebras",
provider_type="remote::cerebras",
pip_packages=[ pip_packages=[
"cerebras_cloud_sdk", "cerebras_cloud_sdk",
], ],
@@ -60,62 +59,56 @@ def available_providers() -> list[ProviderSpec]:
         config_class="llama_stack.providers.remote.inference.cerebras.CerebrasImplConfig",
         description="Cerebras inference provider for running models on Cerebras Cloud platform.",
     ),
-    ),
-    remote_provider_spec(
+    RemoteProviderSpec(
         api=Api.inference,
-        adapter=AdapterSpec(
-            adapter_type="ollama",
+        adapter_type="ollama",
+        provider_type="remote::ollama",
         pip_packages=["ollama", "aiohttp", "h11>=0.16.0"],
         config_class="llama_stack.providers.remote.inference.ollama.OllamaImplConfig",
         module="llama_stack.providers.remote.inference.ollama",
         description="Ollama inference provider for running local models through the Ollama runtime.",
     ),
-    ),
-    remote_provider_spec(
+    RemoteProviderSpec(
         api=Api.inference,
-        adapter=AdapterSpec(
-            adapter_type="vllm",
+        adapter_type="vllm",
+        provider_type="remote::vllm",
         pip_packages=[],
         module="llama_stack.providers.remote.inference.vllm",
         config_class="llama_stack.providers.remote.inference.vllm.VLLMInferenceAdapterConfig",
         provider_data_validator="llama_stack.providers.remote.inference.vllm.VLLMProviderDataValidator",
         description="Remote vLLM inference provider for connecting to vLLM servers.",
     ),
-    ),
-    remote_provider_spec(
+    RemoteProviderSpec(
         api=Api.inference,
-        adapter=AdapterSpec(
-            adapter_type="tgi",
+        adapter_type="tgi",
+        provider_type="remote::tgi",
         pip_packages=["huggingface_hub", "aiohttp"],
         module="llama_stack.providers.remote.inference.tgi",
         config_class="llama_stack.providers.remote.inference.tgi.TGIImplConfig",
         description="Text Generation Inference (TGI) provider for HuggingFace model serving.",
     ),
-    ),
-    remote_provider_spec(
+    RemoteProviderSpec(
         api=Api.inference,
-        adapter=AdapterSpec(
-            adapter_type="hf::serverless",
+        adapter_type="hf::serverless",
+        provider_type="remote::hf::serverless",
         pip_packages=["huggingface_hub", "aiohttp"],
         module="llama_stack.providers.remote.inference.tgi",
         config_class="llama_stack.providers.remote.inference.tgi.InferenceAPIImplConfig",
         description="HuggingFace Inference API serverless provider for on-demand model inference.",
     ),
-    ),
-    remote_provider_spec(
+    RemoteProviderSpec(
         api=Api.inference,
-        adapter=AdapterSpec(
+        provider_type="remote::hf::endpoint",
         adapter_type="hf::endpoint",
         pip_packages=["huggingface_hub", "aiohttp"],
         module="llama_stack.providers.remote.inference.tgi",
         config_class="llama_stack.providers.remote.inference.tgi.InferenceEndpointImplConfig",
         description="HuggingFace Inference Endpoints provider for dedicated model serving.",
     ),
-    ),
-    remote_provider_spec(
+    RemoteProviderSpec(
         api=Api.inference,
-        adapter=AdapterSpec(
-            adapter_type="fireworks",
+        adapter_type="fireworks",
+        provider_type="remote::fireworks",
         pip_packages=[
             "fireworks-ai<=0.17.16",
         ],
@@ -124,11 +117,10 @@ def available_providers() -> list[ProviderSpec]:
         provider_data_validator="llama_stack.providers.remote.inference.fireworks.FireworksProviderDataValidator",
         description="Fireworks AI inference provider for Llama models and other AI models on the Fireworks platform.",
     ),
-    ),
-    remote_provider_spec(
+    RemoteProviderSpec(
         api=Api.inference,
-        adapter=AdapterSpec(
-            adapter_type="together",
+        adapter_type="together",
+        provider_type="remote::together",
         pip_packages=[
             "together",
         ],
@@ -137,85 +129,82 @@ def available_providers() -> list[ProviderSpec]:
         provider_data_validator="llama_stack.providers.remote.inference.together.TogetherProviderDataValidator",
         description="Together AI inference provider for open-source models and collaborative AI development.",
     ),
-    ),
-    remote_provider_spec(
+    RemoteProviderSpec(
         api=Api.inference,
-        adapter=AdapterSpec(
-            adapter_type="bedrock",
+        adapter_type="bedrock",
+        provider_type="remote::bedrock",
         pip_packages=["boto3"],
         module="llama_stack.providers.remote.inference.bedrock",
         config_class="llama_stack.providers.remote.inference.bedrock.BedrockConfig",
         description="AWS Bedrock inference provider for accessing various AI models through AWS's managed service.",
     ),
-    ),
-    remote_provider_spec(
+    RemoteProviderSpec(
         api=Api.inference,
-        adapter=AdapterSpec(
-            adapter_type="databricks",
+        adapter_type="databricks",
+        provider_type="remote::databricks",
         pip_packages=["databricks-sdk"],
         module="llama_stack.providers.remote.inference.databricks",
         config_class="llama_stack.providers.remote.inference.databricks.DatabricksImplConfig",
         description="Databricks inference provider for running models on Databricks' unified analytics platform.",
     ),
-    ),
-    remote_provider_spec(
+    RemoteProviderSpec(
         api=Api.inference,
-        adapter=AdapterSpec(
-            adapter_type="nvidia",
+        adapter_type="nvidia",
+        provider_type="remote::nvidia",
         pip_packages=[],
         module="llama_stack.providers.remote.inference.nvidia",
         config_class="llama_stack.providers.remote.inference.nvidia.NVIDIAConfig",
         description="NVIDIA inference provider for accessing NVIDIA NIM models and AI services.",
     ),
-    ),
-    remote_provider_spec(
+    RemoteProviderSpec(
         api=Api.inference,
-        adapter=AdapterSpec(
-            adapter_type="runpod",
+        adapter_type="runpod",
+        provider_type="remote::runpod",
         pip_packages=[],
         module="llama_stack.providers.remote.inference.runpod",
         config_class="llama_stack.providers.remote.inference.runpod.RunpodImplConfig",
         description="RunPod inference provider for running models on RunPod's cloud GPU platform.",
     ),
-    ),
-    remote_provider_spec(
+    RemoteProviderSpec(
         api=Api.inference,
-        adapter=AdapterSpec(
-            adapter_type="openai",
+        adapter_type="openai",
+        provider_type="remote::openai",
         pip_packages=["litellm"],
         module="llama_stack.providers.remote.inference.openai",
         config_class="llama_stack.providers.remote.inference.openai.OpenAIConfig",
         provider_data_validator="llama_stack.providers.remote.inference.openai.config.OpenAIProviderDataValidator",
         description="OpenAI inference provider for accessing GPT models and other OpenAI services.",
     ),
-    ),
-    remote_provider_spec(
+    RemoteProviderSpec(
         api=Api.inference,
-        adapter=AdapterSpec(
-            adapter_type="anthropic",
+        adapter_type="anthropic",
+        provider_type="remote::anthropic",
         pip_packages=["litellm"],
         module="llama_stack.providers.remote.inference.anthropic",
         config_class="llama_stack.providers.remote.inference.anthropic.AnthropicConfig",
         provider_data_validator="llama_stack.providers.remote.inference.anthropic.config.AnthropicProviderDataValidator",
         description="Anthropic inference provider for accessing Claude models and Anthropic's AI services.",
     ),
-    ),
-    remote_provider_spec(
+    RemoteProviderSpec(
         api=Api.inference,
-        adapter=AdapterSpec(
-            adapter_type="gemini",
-            pip_packages=["litellm"],
+        adapter_type="gemini",
+        provider_type="remote::gemini",
+        pip_packages=[
+            "litellm",
+        ],
         module="llama_stack.providers.remote.inference.gemini",
         config_class="llama_stack.providers.remote.inference.gemini.GeminiConfig",
         provider_data_validator="llama_stack.providers.remote.inference.gemini.config.GeminiProviderDataValidator",
         description="Google Gemini inference provider for accessing Gemini models and Google's AI services.",
     ),
-    ),
-    remote_provider_spec(
+    RemoteProviderSpec(
         api=Api.inference,
-        adapter=AdapterSpec(
-            adapter_type="vertexai",
-            pip_packages=["litellm", "google-cloud-aiplatform"],
+        adapter_type="vertexai",
+        provider_type="remote::vertexai",
+        pip_packages=[
+            "litellm",
+            "google-cloud-aiplatform",
+        ],
         module="llama_stack.providers.remote.inference.vertexai",
         config_class="llama_stack.providers.remote.inference.vertexai.VertexAIConfig",
         provider_data_validator="llama_stack.providers.remote.inference.vertexai.config.VertexAIProviderDataValidator",
@@ -240,65 +229,63 @@ Available Models:
 - vertex_ai/gemini-2.5-flash
 - vertex_ai/gemini-2.5-pro""",
     ),
-    ),
-    remote_provider_spec(
+    RemoteProviderSpec(
         api=Api.inference,
-        adapter=AdapterSpec(
-            adapter_type="groq",
-            pip_packages=["litellm"],
+        adapter_type="groq",
+        provider_type="remote::groq",
+        pip_packages=[
+            "litellm",
+        ],
         module="llama_stack.providers.remote.inference.groq",
         config_class="llama_stack.providers.remote.inference.groq.GroqConfig",
         provider_data_validator="llama_stack.providers.remote.inference.groq.config.GroqProviderDataValidator",
         description="Groq inference provider for ultra-fast inference using Groq's LPU technology.",
     ),
-    ),
-    remote_provider_spec(
+    RemoteProviderSpec(
         api=Api.inference,
-        adapter=AdapterSpec(
-            adapter_type="llama-openai-compat",
+        adapter_type="llama-openai-compat",
+        provider_type="remote::llama-openai-compat",
         pip_packages=["litellm"],
         module="llama_stack.providers.remote.inference.llama_openai_compat",
         config_class="llama_stack.providers.remote.inference.llama_openai_compat.config.LlamaCompatConfig",
         provider_data_validator="llama_stack.providers.remote.inference.llama_openai_compat.config.LlamaProviderDataValidator",
         description="Llama OpenAI-compatible provider for using Llama models with OpenAI API format.",
     ),
-    ),
-    remote_provider_spec(
+    RemoteProviderSpec(
         api=Api.inference,
-        adapter=AdapterSpec(
-            adapter_type="sambanova",
-            pip_packages=["litellm"],
+        adapter_type="sambanova",
+        provider_type="remote::sambanova",
+        pip_packages=[
+            "litellm",
+        ],
         module="llama_stack.providers.remote.inference.sambanova",
         config_class="llama_stack.providers.remote.inference.sambanova.SambaNovaImplConfig",
         provider_data_validator="llama_stack.providers.remote.inference.sambanova.config.SambaNovaProviderDataValidator",
         description="SambaNova inference provider for running models on SambaNova's dataflow architecture.",
     ),
-    ),
-    remote_provider_spec(
+    RemoteProviderSpec(
         api=Api.inference,
-        adapter=AdapterSpec(
-            adapter_type="passthrough",
+        adapter_type="passthrough",
+        provider_type="remote::passthrough",
         pip_packages=[],
         module="llama_stack.providers.remote.inference.passthrough",
         config_class="llama_stack.providers.remote.inference.passthrough.PassthroughImplConfig",
         provider_data_validator="llama_stack.providers.remote.inference.passthrough.PassthroughProviderDataValidator",
         description="Passthrough inference provider for connecting to any external inference service not directly supported.",
     ),
-    ),
-    remote_provider_spec(
+    RemoteProviderSpec(
         api=Api.inference,
-        adapter=AdapterSpec(
-            adapter_type="watsonx",
+        adapter_type="watsonx",
+        provider_type="remote::watsonx",
         pip_packages=["ibm_watsonx_ai"],
         module="llama_stack.providers.remote.inference.watsonx",
         config_class="llama_stack.providers.remote.inference.watsonx.WatsonXConfig",
         provider_data_validator="llama_stack.providers.remote.inference.watsonx.WatsonXProviderDataValidator",
         description="IBM WatsonX inference provider for accessing AI models on IBM's WatsonX platform.",
     ),
-    ),
-    remote_provider_spec(
+    RemoteProviderSpec(
         api=Api.inference,
-        adapter=AdapterSpec(
+        provider_type="remote::azure",
         adapter_type="azure",
         pip_packages=["litellm"],
         module="llama_stack.providers.remote.inference.azure",
@@ -310,5 +297,4 @@ Provider documentation
 https://learn.microsoft.com/en-us/azure/ai-foundry/openai/overview
 """,
     ),
-    ),
 ]

View file

@@ -7,7 +7,7 @@
 
 from typing import cast
 
-from llama_stack.providers.datatypes import AdapterSpec, Api, InlineProviderSpec, ProviderSpec, remote_provider_spec
+from llama_stack.providers.datatypes import Api, InlineProviderSpec, ProviderSpec, RemoteProviderSpec
 
 # We provide two versions of these providers so that distributions can package the appropriate version of torch.
 # The CPU version is used for distributions that don't have GPU support -- they result in smaller container images.
@@ -57,14 +57,13 @@ def available_providers() -> list[ProviderSpec]:
         ],
         description="HuggingFace-based post-training provider for fine-tuning models using the HuggingFace ecosystem.",
     ),
-    remote_provider_spec(
+    RemoteProviderSpec(
         api=Api.post_training,
-        adapter=AdapterSpec(
-            adapter_type="nvidia",
+        adapter_type="nvidia",
+        provider_type="remote::nvidia",
         pip_packages=["requests", "aiohttp"],
         module="llama_stack.providers.remote.post_training.nvidia",
         config_class="llama_stack.providers.remote.post_training.nvidia.NvidiaPostTrainingConfig",
         description="NVIDIA's post-training provider for fine-tuning models on NVIDIA's platform.",
     ),
-    ),
 ]

View file

@ -6,11 +6,10 @@
from llama_stack.providers.datatypes import ( from llama_stack.providers.datatypes import (
AdapterSpec,
Api, Api,
InlineProviderSpec, InlineProviderSpec,
ProviderSpec, ProviderSpec,
remote_provider_spec, RemoteProviderSpec,
) )
@ -48,35 +47,32 @@ def available_providers() -> list[ProviderSpec]:
config_class="llama_stack.providers.inline.safety.code_scanner.CodeScannerConfig", config_class="llama_stack.providers.inline.safety.code_scanner.CodeScannerConfig",
description="Code Scanner safety provider for detecting security vulnerabilities and unsafe code patterns.", description="Code Scanner safety provider for detecting security vulnerabilities and unsafe code patterns.",
), ),
remote_provider_spec( RemoteProviderSpec(
api=Api.safety, api=Api.safety,
adapter=AdapterSpec(
adapter_type="bedrock", adapter_type="bedrock",
provider_type="remote::bedrock",
pip_packages=["boto3"], pip_packages=["boto3"],
module="llama_stack.providers.remote.safety.bedrock", module="llama_stack.providers.remote.safety.bedrock",
config_class="llama_stack.providers.remote.safety.bedrock.BedrockSafetyConfig", config_class="llama_stack.providers.remote.safety.bedrock.BedrockSafetyConfig",
description="AWS Bedrock safety provider for content moderation using AWS's safety services.", description="AWS Bedrock safety provider for content moderation using AWS's safety services.",
), ),
), RemoteProviderSpec(
remote_provider_spec(
api=Api.safety, api=Api.safety,
adapter=AdapterSpec(
adapter_type="nvidia", adapter_type="nvidia",
provider_type="remote::nvidia",
pip_packages=["requests"], pip_packages=["requests"],
module="llama_stack.providers.remote.safety.nvidia", module="llama_stack.providers.remote.safety.nvidia",
config_class="llama_stack.providers.remote.safety.nvidia.NVIDIASafetyConfig", config_class="llama_stack.providers.remote.safety.nvidia.NVIDIASafetyConfig",
description="NVIDIA's safety provider for content moderation and safety filtering.", description="NVIDIA's safety provider for content moderation and safety filtering.",
), ),
), RemoteProviderSpec(
remote_provider_spec(
api=Api.safety, api=Api.safety,
adapter=AdapterSpec(
adapter_type="sambanova", adapter_type="sambanova",
provider_type="remote::sambanova",
pip_packages=["litellm", "requests"], pip_packages=["litellm", "requests"],
module="llama_stack.providers.remote.safety.sambanova", module="llama_stack.providers.remote.safety.sambanova",
config_class="llama_stack.providers.remote.safety.sambanova.SambaNovaSafetyConfig", config_class="llama_stack.providers.remote.safety.sambanova.SambaNovaSafetyConfig",
provider_data_validator="llama_stack.providers.remote.safety.sambanova.config.SambaNovaProviderDataValidator", provider_data_validator="llama_stack.providers.remote.safety.sambanova.config.SambaNovaProviderDataValidator",
description="SambaNova's safety provider for content moderation and safety filtering.", description="SambaNova's safety provider for content moderation and safety filtering.",
), ),
),
] ]

View file

@@ -6,11 +6,10 @@
 from llama_stack.providers.datatypes import (
-    AdapterSpec,
     Api,
     InlineProviderSpec,
     ProviderSpec,
-    remote_provider_spec,
+    RemoteProviderSpec,
 )
@@ -35,59 +34,54 @@ def available_providers() -> list[ProviderSpec]:
         api_dependencies=[Api.vector_io, Api.inference, Api.files],
         description="RAG (Retrieval-Augmented Generation) tool runtime for document ingestion, chunking, and semantic search.",
     ),
-    remote_provider_spec(
+    RemoteProviderSpec(
         api=Api.tool_runtime,
-        adapter=AdapterSpec(
         adapter_type="brave-search",
+        provider_type="remote::brave-search",
         module="llama_stack.providers.remote.tool_runtime.brave_search",
         config_class="llama_stack.providers.remote.tool_runtime.brave_search.config.BraveSearchToolConfig",
         pip_packages=["requests"],
         provider_data_validator="llama_stack.providers.remote.tool_runtime.brave_search.BraveSearchToolProviderDataValidator",
         description="Brave Search tool for web search capabilities with privacy-focused results.",
     ),
-    ),
-    remote_provider_spec(
+    RemoteProviderSpec(
         api=Api.tool_runtime,
-        adapter=AdapterSpec(
         adapter_type="bing-search",
+        provider_type="remote::bing-search",
         module="llama_stack.providers.remote.tool_runtime.bing_search",
         config_class="llama_stack.providers.remote.tool_runtime.bing_search.config.BingSearchToolConfig",
         pip_packages=["requests"],
         provider_data_validator="llama_stack.providers.remote.tool_runtime.bing_search.BingSearchToolProviderDataValidator",
         description="Bing Search tool for web search capabilities using Microsoft's search engine.",
     ),
-    ),
-    remote_provider_spec(
+    RemoteProviderSpec(
         api=Api.tool_runtime,
-        adapter=AdapterSpec(
         adapter_type="tavily-search",
+        provider_type="remote::tavily-search",
         module="llama_stack.providers.remote.tool_runtime.tavily_search",
         config_class="llama_stack.providers.remote.tool_runtime.tavily_search.config.TavilySearchToolConfig",
         pip_packages=["requests"],
         provider_data_validator="llama_stack.providers.remote.tool_runtime.tavily_search.TavilySearchToolProviderDataValidator",
         description="Tavily Search tool for AI-optimized web search with structured results.",
     ),
-    ),
-    remote_provider_spec(
+    RemoteProviderSpec(
         api=Api.tool_runtime,
-        adapter=AdapterSpec(
         adapter_type="wolfram-alpha",
+        provider_type="remote::wolfram-alpha",
         module="llama_stack.providers.remote.tool_runtime.wolfram_alpha",
         config_class="llama_stack.providers.remote.tool_runtime.wolfram_alpha.config.WolframAlphaToolConfig",
         pip_packages=["requests"],
         provider_data_validator="llama_stack.providers.remote.tool_runtime.wolfram_alpha.WolframAlphaToolProviderDataValidator",
         description="Wolfram Alpha tool for computational knowledge and mathematical calculations.",
     ),
-    ),
-    remote_provider_spec(
+    RemoteProviderSpec(
         api=Api.tool_runtime,
-        adapter=AdapterSpec(
         adapter_type="model-context-protocol",
+        provider_type="remote::model-context-protocol",
         module="llama_stack.providers.remote.tool_runtime.model_context_protocol",
         config_class="llama_stack.providers.remote.tool_runtime.model_context_protocol.config.MCPProviderConfig",
         pip_packages=["mcp>=1.8.1"],
         provider_data_validator="llama_stack.providers.remote.tool_runtime.model_context_protocol.config.MCPProviderDataValidator",
         description="Model Context Protocol (MCP) tool for standardized tool calling and context management.",
     ),
-    ),
 ]

View file

@@ -6,11 +6,10 @@
 from llama_stack.providers.datatypes import (
-    AdapterSpec,
     Api,
     InlineProviderSpec,
     ProviderSpec,
-    remote_provider_spec,
+    RemoteProviderSpec,
 )
@@ -300,13 +299,15 @@ See [sqlite-vec's GitHub repo](https://github.com/asg017/sqlite-vec/tree/main) f
 Please refer to the sqlite-vec provider documentation.
 """,
     ),
-    remote_provider_spec(
-        Api.vector_io,
-        AdapterSpec(
+    RemoteProviderSpec(
+        api=Api.vector_io,
         adapter_type="chromadb",
+        provider_type="remote::chromadb",
         pip_packages=["chromadb-client"],
         module="llama_stack.providers.remote.vector_io.chroma",
         config_class="llama_stack.providers.remote.vector_io.chroma.ChromaVectorIOConfig",
+        api_dependencies=[Api.inference],
+        optional_api_dependencies=[Api.files],
         description="""
 [Chroma](https://www.trychroma.com/) is an inline and remote vector
 database provider for Llama Stack. It allows you to store and query vectors directly within a Chroma database.
@@ -341,9 +342,6 @@ pip install chromadb
 See [Chroma's documentation](https://docs.trychroma.com/docs/overview/introduction) for more details about Chroma in general.
 """,
     ),
-        api_dependencies=[Api.inference],
-        optional_api_dependencies=[Api.files],
-    ),
     InlineProviderSpec(
         api=Api.vector_io,
         provider_type="inline::chromadb",
@@ -387,13 +385,15 @@ See [Chroma's documentation](https://docs.trychroma.com/docs/overview/introducti
 """,
     ),
-    remote_provider_spec(
-        Api.vector_io,
-        AdapterSpec(
+    RemoteProviderSpec(
+        api=Api.vector_io,
         adapter_type="pgvector",
+        provider_type="remote::pgvector",
         pip_packages=["psycopg2-binary"],
         module="llama_stack.providers.remote.vector_io.pgvector",
         config_class="llama_stack.providers.remote.vector_io.pgvector.PGVectorVectorIOConfig",
+        api_dependencies=[Api.inference],
+        optional_api_dependencies=[Api.files],
         description="""
 [PGVector](https://github.com/pgvector/pgvector) is a remote vector database provider for Llama Stack. It
 allows you to store and query vectors directly in memory.
@@ -496,17 +496,16 @@ docker pull pgvector/pgvector:pg17
 See [PGVector's documentation](https://github.com/pgvector/pgvector) for more details about PGVector in general.
 """,
     ),
-        api_dependencies=[Api.inference],
-        optional_api_dependencies=[Api.files],
-    ),
-    remote_provider_spec(
-        Api.vector_io,
-        AdapterSpec(
+    RemoteProviderSpec(
+        api=Api.vector_io,
         adapter_type="weaviate",
+        provider_type="remote::weaviate",
         pip_packages=["weaviate-client"],
         module="llama_stack.providers.remote.vector_io.weaviate",
         config_class="llama_stack.providers.remote.vector_io.weaviate.WeaviateVectorIOConfig",
         provider_data_validator="llama_stack.providers.remote.vector_io.weaviate.WeaviateRequestProviderData",
+        api_dependencies=[Api.inference],
+        optional_api_dependencies=[Api.files],
         description="""
 [Weaviate](https://weaviate.io/) is a vector database provider for Llama Stack.
 It allows you to store and query vectors directly within a Weaviate database.
@@ -539,9 +538,6 @@ To install Weaviate see the [Weaviate quickstart documentation](https://weaviate
 See [Weaviate's documentation](https://weaviate.io/developers/weaviate) for more details about Weaviate in general.
 """,
     ),
-        api_dependencies=[Api.inference],
-        optional_api_dependencies=[Api.files],
-    ),
     InlineProviderSpec(
         api=Api.vector_io,
         provider_type="inline::qdrant",
@@ -594,27 +590,28 @@ docker pull qdrant/qdrant
 See the [Qdrant documentation](https://qdrant.tech/documentation/) for more details about Qdrant in general.
 """,
     ),
-    remote_provider_spec(
-        Api.vector_io,
-        AdapterSpec(
+    RemoteProviderSpec(
+        api=Api.vector_io,
         adapter_type="qdrant",
+        provider_type="remote::qdrant",
         pip_packages=["qdrant-client"],
         module="llama_stack.providers.remote.vector_io.qdrant",
         config_class="llama_stack.providers.remote.vector_io.qdrant.QdrantVectorIOConfig",
+        api_dependencies=[Api.inference],
+        optional_api_dependencies=[Api.files],
        description="""
 Please refer to the inline provider documentation.
 """,
     ),
-        api_dependencies=[Api.inference],
-        optional_api_dependencies=[Api.files],
-    ),
-    remote_provider_spec(
-        Api.vector_io,
-        AdapterSpec(
+    RemoteProviderSpec(
+        api=Api.vector_io,
         adapter_type="milvus",
+        provider_type="remote::milvus",
         pip_packages=["pymilvus>=2.4.10"],
         module="llama_stack.providers.remote.vector_io.milvus",
         config_class="llama_stack.providers.remote.vector_io.milvus.MilvusVectorIOConfig",
+        api_dependencies=[Api.inference],
+        optional_api_dependencies=[Api.files],
         description="""
 [Milvus](https://milvus.io/) is an inline and remote vector database provider for Llama Stack. It
 allows you to store and query vectors directly within a Milvus database.
@@ -636,7 +633,13 @@ To use Milvus in your Llama Stack project, follow these steps:
 ## Installation

-You can install Milvus using pymilvus:
+If you want to use inline Milvus, you can install:
+
+```bash
+pip install pymilvus[milvus-lite]
+```
+
+If you want to use remote Milvus, you can install:

 ```bash
 pip install pymilvus
@@ -807,13 +810,10 @@ See the [Milvus documentation](https://milvus.io/docs/install-overview.md) for m
 For more details on TLS configuration, refer to the [TLS setup guide](https://milvus.io/docs/tls.md).
 """,
     ),
-        api_dependencies=[Api.inference],
-        optional_api_dependencies=[Api.files],
-    ),
     InlineProviderSpec(
         api=Api.vector_io,
         provider_type="inline::milvus",
-        pip_packages=["pymilvus>=2.4.10"],
+        pip_packages=["pymilvus[milvus-lite]>=2.4.10"],
         module="llama_stack.providers.inline.vector_io.milvus",
         config_class="llama_stack.providers.inline.vector_io.milvus.MilvusVectorIOConfig",
         api_dependencies=[Api.inference],
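
Note that `api_dependencies` and `optional_api_dependencies` also move onto the spec itself; previously they trailed the `AdapterSpec` as extra `remote_provider_spec` arguments. A sketch of the remote Milvus entry in the new form (description shortened here for brevity):

```python
from llama_stack.providers.datatypes import Api, RemoteProviderSpec

milvus_remote = RemoteProviderSpec(
    api=Api.vector_io,
    adapter_type="milvus",
    provider_type="remote::milvus",
    pip_packages=["pymilvus>=2.4.10"],  # remote client; inline pins pymilvus[milvus-lite]
    module="llama_stack.providers.remote.vector_io.milvus",
    config_class="llama_stack.providers.remote.vector_io.milvus.MilvusVectorIOConfig",
    api_dependencies=[Api.inference],        # required: embeddings for ingestion and search
    optional_api_dependencies=[Api.files],   # optional: used when the Files API is configured
    description="Milvus remote vector database provider.",  # abbreviated
)
```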

View file

@@ -137,7 +137,7 @@ class S3FilesImpl(Files):
         where: dict[str, str | dict] = {"id": file_id}
         if not return_expired:
             where["expires_at"] = {">": self._now()}
-        if not (row := await self.sql_store.fetch_one("openai_files", policy=self.policy, where=where)):
+        if not (row := await self.sql_store.fetch_one("openai_files", where=where)):
             raise ResourceNotFoundError(file_id, "File", "files.list()")
         return row
@@ -164,7 +164,7 @@ class S3FilesImpl(Files):
         self._client = _create_s3_client(self._config)
         await _create_bucket_if_not_exists(self._client, self._config)
-        self._sql_store = AuthorizedSqlStore(sqlstore_impl(self._config.metadata_store))
+        self._sql_store = AuthorizedSqlStore(sqlstore_impl(self._config.metadata_store), self.policy)
         await self._sql_store.create_table(
             "openai_files",
             {
@@ -268,7 +268,6 @@ class S3FilesImpl(Files):
         paginated_result = await self.sql_store.fetch_all(
             table="openai_files",
-            policy=self.policy,
             where=where_conditions,
             order_by=[("created_at", order.value)],
             cursor=("id", after) if after else None,

View file

@@ -54,7 +54,7 @@ class InferenceStore:
     async def initialize(self):
         """Create the necessary tables if they don't exist."""
-        self.sql_store = AuthorizedSqlStore(sqlstore_impl(self.sql_store_config))
+        self.sql_store = AuthorizedSqlStore(sqlstore_impl(self.sql_store_config), self.policy)
         await self.sql_store.create_table(
             "chat_completions",
             {
@@ -202,7 +202,6 @@ class InferenceStore:
             order_by=[("created", order.value)],
             cursor=("id", after) if after else None,
             limit=limit,
-            policy=self.policy,
         )

         data = [
@@ -229,7 +228,6 @@ class InferenceStore:
         row = await self.sql_store.fetch_one(
             table="chat_completions",
             where={"id": completion_id},
-            policy=self.policy,
         )

         if not row:
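
The listing path above also shows the keyset-pagination shape that survives the change: an ordering plus an `(column, value)` cursor, with access control now implicit in the store. A hedged sketch of the call shape (the wrapper function and its parameters are illustrative):

```python
async def list_recent_completions(store, after: str | None, limit: int = 20):
    # Keyset pagination: order by creation time, resume after a given id.
    # Access control is applied by the store's constructor-bound policy.
    return await store.fetch_all(
        table="chat_completions",
        order_by=[("created", "desc")],
        cursor=("id", after) if after else None,
        limit=limit,
    )
```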

View file

@@ -28,8 +28,7 @@ class ResponsesStore:
             sql_store_config = SqliteSqlStoreConfig(
                 db_path=(RUNTIME_BASE_DIR / "sqlstore.db").as_posix(),
             )
-        self.sql_store = AuthorizedSqlStore(sqlstore_impl(sql_store_config))
-        self.policy = policy
+        self.sql_store = AuthorizedSqlStore(sqlstore_impl(sql_store_config), policy)

     async def initialize(self):
         """Create the necessary tables if they don't exist."""
@@ -87,7 +86,6 @@ class ResponsesStore:
             order_by=[("created_at", order.value)],
             cursor=("id", after) if after else None,
             limit=limit,
-            policy=self.policy,
         )

         data = [OpenAIResponseObjectWithInput(**row["response_object"]) for row in paginated_result.data]
@@ -105,7 +103,6 @@ class ResponsesStore:
         row = await self.sql_store.fetch_one(
             "openai_responses",
             where={"id": response_id},
-            policy=self.policy,
         )

         if not row:
@@ -116,7 +113,7 @@ class ResponsesStore:
         return OpenAIResponseObjectWithInput(**row["response_object"])

     async def delete_response_object(self, response_id: str) -> OpenAIDeleteResponseObject:
-        row = await self.sql_store.fetch_one("openai_responses", where={"id": response_id}, policy=self.policy)
+        row = await self.sql_store.fetch_one("openai_responses", where={"id": response_id})
         if not row:
             raise ValueError(f"Response with id {response_id} not found")
         await self.sql_store.delete("openai_responses", where={"id": response_id})

View file

@@ -53,13 +53,15 @@ class AuthorizedSqlStore:
     access control policies, user attribute capture, and SQL filtering optimization.
     """

-    def __init__(self, sql_store: SqlStore):
+    def __init__(self, sql_store: SqlStore, policy: list[AccessRule]):
         """
         Initialize the authorization layer.

         :param sql_store: Base SqlStore implementation to wrap
+        :param policy: Access control policy to use for authorization
         """
         self.sql_store = sql_store
+        self.policy = policy
         self._detect_database_type()
         self._validate_sql_optimized_policy()
@@ -117,14 +119,13 @@ class AuthorizedSqlStore:
     async def fetch_all(
         self,
         table: str,
-        policy: list[AccessRule],
         where: Mapping[str, Any] | None = None,
         limit: int | None = None,
         order_by: list[tuple[str, Literal["asc", "desc"]]] | None = None,
         cursor: tuple[str, str] | None = None,
     ) -> PaginatedResponse:
         """Fetch all rows with automatic access control filtering."""
-        access_where = self._build_access_control_where_clause(policy)
+        access_where = self._build_access_control_where_clause(self.policy)
         rows = await self.sql_store.fetch_all(
             table=table,
             where=where,
@@ -146,7 +147,7 @@ class AuthorizedSqlStore:
                 str(record_id), table, User(principal=stored_owner_principal, attributes=stored_access_attrs)
             )

-            if is_action_allowed(policy, Action.READ, sql_record, current_user):
+            if is_action_allowed(self.policy, Action.READ, sql_record, current_user):
                 filtered_rows.append(row)

         return PaginatedResponse(
@@ -157,14 +158,12 @@ class AuthorizedSqlStore:
     async def fetch_one(
         self,
         table: str,
-        policy: list[AccessRule],
         where: Mapping[str, Any] | None = None,
         order_by: list[tuple[str, Literal["asc", "desc"]]] | None = None,
     ) -> dict[str, Any] | None:
         """Fetch one row with automatic access control checking."""
         results = await self.fetch_all(
             table=table,
-            policy=policy,
             where=where,
             limit=1,
             order_by=order_by,
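
Taken together, the call-site effect is captured by this sketch (assumes an existing `SqlStore` named `base` and a rule list `rules`; previously every `fetch_*` call repeated `policy=rules`):

```python
async def demo(base, rules):
    # The policy now travels with the store object itself.
    store = AuthorizedSqlStore(base, rules)
    rows = await store.fetch_all("chat_completions", limit=10)
    one = await store.fetch_one("chat_completions", where={"id": "abc"})  # id is illustrative
    return rows, one
```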

View file

@@ -1,6 +1,5 @@
-adapter:
-  adapter_type: kaze
-  pip_packages: ["tests/external/llama-stack-provider-kaze"]
-  config_class: llama_stack_provider_kaze.config.KazeProviderConfig
-  module: llama_stack_provider_kaze
+adapter_type: kaze
+pip_packages: ["tests/external/llama-stack-provider-kaze"]
+config_class: llama_stack_provider_kaze.config.KazeProviderConfig
+module: llama_stack_provider_kaze
 optional_api_dependencies: []

View file

@@ -6,7 +6,7 @@
 from typing import Protocol

-from llama_stack.providers.datatypes import AdapterSpec, Api, ProviderSpec, RemoteProviderSpec
+from llama_stack.providers.datatypes import Api, ProviderSpec, RemoteProviderSpec
 from llama_stack.schema_utils import webmethod
@@ -16,12 +16,9 @@ def available_providers() -> list[ProviderSpec]:
             api=Api.weather,
             provider_type="remote::kaze",
             config_class="llama_stack_provider_kaze.KazeProviderConfig",
-            adapter=AdapterSpec(
             adapter_type="kaze",
             module="llama_stack_provider_kaze",
             pip_packages=["llama_stack_provider_kaze"],
-            config_class="llama_stack_provider_kaze.KazeProviderConfig",
-            ),
         ),
     ]

View file

@@ -57,7 +57,7 @@ def authorized_store(backend_config):
     config = config_func()

     base_sqlstore = sqlstore_impl(config)
-    authorized_store = AuthorizedSqlStore(base_sqlstore)
+    authorized_store = AuthorizedSqlStore(base_sqlstore, default_policy())

     yield authorized_store
@@ -106,7 +106,7 @@ async def test_authorized_store_attributes(mock_get_authenticated_user, authoriz
         await authorized_store.insert(table_name, {"id": "1", "data": "public_data"})

         # Test fetching with no user - should not error on JSON comparison
-        result = await authorized_store.fetch_all(table_name, policy=default_policy())
+        result = await authorized_store.fetch_all(table_name)
         assert len(result.data) == 1
         assert result.data[0]["id"] == "1"
         assert result.data[0]["access_attributes"] is None
@@ -119,7 +119,7 @@ async def test_authorized_store_attributes(mock_get_authenticated_user, authoriz
         await authorized_store.insert(table_name, {"id": "2", "data": "admin_data"})

         # Fetch all - admin should see both
-        result = await authorized_store.fetch_all(table_name, policy=default_policy())
+        result = await authorized_store.fetch_all(table_name)
         assert len(result.data) == 2

         # Test with non-admin user
@@ -127,7 +127,7 @@ async def test_authorized_store_attributes(mock_get_authenticated_user, authoriz
         mock_get_authenticated_user.return_value = regular_user

         # Should only see public record
-        result = await authorized_store.fetch_all(table_name, policy=default_policy())
+        result = await authorized_store.fetch_all(table_name)
         assert len(result.data) == 1
         assert result.data[0]["id"] == "1"
@@ -156,7 +156,7 @@ async def test_authorized_store_attributes(mock_get_authenticated_user, authoriz
         # Now test with the multi-user who has both roles=admin and teams=dev
         mock_get_authenticated_user.return_value = multi_user
-        result = await authorized_store.fetch_all(table_name, policy=default_policy())
+        result = await authorized_store.fetch_all(table_name)

         # Should see:
         # - public record (1) - no access_attributes
@@ -217,21 +217,24 @@ async def test_user_ownership_policy(mock_get_authenticated_user, authorized_sto
             ),
         ]

+        # Create a new authorized store with the owner-only policy
+        owner_only_store = AuthorizedSqlStore(authorized_store.sql_store, owner_only_policy)
+
         # Test user1 access - should only see their own record
         mock_get_authenticated_user.return_value = user1
-        result = await authorized_store.fetch_all(table_name, policy=owner_only_policy)
+        result = await owner_only_store.fetch_all(table_name)
         assert len(result.data) == 1, f"Expected user1 to see 1 record, got {len(result.data)}"
         assert result.data[0]["id"] == "1", f"Expected user1's record, got {result.data[0]['id']}"

         # Test user2 access - should only see their own record
         mock_get_authenticated_user.return_value = user2
-        result = await authorized_store.fetch_all(table_name, policy=owner_only_policy)
+        result = await owner_only_store.fetch_all(table_name)
         assert len(result.data) == 1, f"Expected user2 to see 1 record, got {len(result.data)}"
         assert result.data[0]["id"] == "2", f"Expected user2's record, got {result.data[0]['id']}"

         # Test with anonymous user - should see no records
         mock_get_authenticated_user.return_value = None
-        result = await authorized_store.fetch_all(table_name, policy=owner_only_policy)
+        result = await owner_only_store.fetch_all(table_name)
         assert len(result.data) == 0, f"Expected anonymous user to see 0 records, got {len(result.data)}"

     finally:
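
Since the policy is now constructor state, testing an alternative policy means wrapping the same base store a second time rather than threading a `policy=` argument through every call. A short sketch using the names from the tests above:

```python
# Two wrappers over one base store, each enforcing its own policy.
default_store = AuthorizedSqlStore(base_sqlstore, default_policy())
owner_only_store = AuthorizedSqlStore(base_sqlstore, owner_only_policy)

async def compare_views(table_name: str):
    # Reads go through whichever policy the wrapper was built with.
    everyone = await default_store.fetch_all(table_name)
    mine_only = await owner_only_store.fetch_all(table_name)
    return everyone, mine_only
```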

View file

@@ -66,10 +66,9 @@ def base_config(tmp_path):
 def provider_spec_yaml():
     """Common provider spec YAML for testing."""
     return """
-adapter:
-  adapter_type: test_provider
-  config_class: test_provider.config.TestProviderConfig
-  module: test_provider
+adapter_type: test_provider
+config_class: test_provider.config.TestProviderConfig
+module: test_provider
 api_dependencies:
   - safety
 """
@@ -182,9 +181,9 @@ class TestProviderRegistry:
         assert Api.inference in registry
         assert "remote::test_provider" in registry[Api.inference]
         provider = registry[Api.inference]["remote::test_provider"]
-        assert provider.adapter.adapter_type == "test_provider"
-        assert provider.adapter.module == "test_provider"
-        assert provider.adapter.config_class == "test_provider.config.TestProviderConfig"
+        assert provider.adapter_type == "test_provider"
+        assert provider.module == "test_provider"
+        assert provider.config_class == "test_provider.config.TestProviderConfig"
         assert Api.safety in provider.api_dependencies

     def test_external_inline_providers(self, api_directories, mock_providers, base_config, inline_provider_spec_yaml):
@@ -246,8 +245,7 @@ class TestProviderRegistry:
         """Test handling of malformed remote provider spec (missing required fields)."""
         remote_dir, _ = api_directories
         malformed_spec = """
-adapter:
-  adapter_type: test_provider
+adapter_type: test_provider
 # Missing required fields
 api_dependencies:
   - safety
@@ -270,7 +268,7 @@ pip_packages:
         with open(inline_dir / "malformed.yaml", "w") as f:
             f.write(malformed_spec)

-        with pytest.raises(KeyError) as exc_info:
+        with pytest.raises(ValidationError) as exc_info:
             get_provider_registry(base_config)
         assert "config_class" in str(exc_info.value)

View file

@@ -27,13 +27,17 @@ class TestLlamaStackAsLibraryClientAutoInitialization:
         mock_impls = {}
         mock_route_impls = RouteImpls({})

-        async def mock_construct_stack(config, custom_provider_registry):
-            return mock_impls
+        class MockStack:
+            def __init__(self, config, custom_provider_registry=None):
+                self.impls = mock_impls
+
+            async def initialize(self):
+                pass

         def mock_initialize_route_impls(impls):
             return mock_route_impls

-        monkeypatch.setattr("llama_stack.core.library_client.construct_stack", mock_construct_stack)
+        monkeypatch.setattr("llama_stack.core.library_client.Stack", MockStack)
         monkeypatch.setattr("llama_stack.core.library_client.initialize_route_impls", mock_initialize_route_impls)

         client = LlamaStackAsLibraryClient("ci-tests")
@@ -46,13 +50,17 @@ class TestLlamaStackAsLibraryClientAutoInitialization:
         mock_impls = {}
         mock_route_impls = RouteImpls({})

-        async def mock_construct_stack(config, custom_provider_registry):
-            return mock_impls
+        class MockStack:
+            def __init__(self, config, custom_provider_registry=None):
+                self.impls = mock_impls
+
+            async def initialize(self):
+                pass

         def mock_initialize_route_impls(impls):
             return mock_route_impls

-        monkeypatch.setattr("llama_stack.core.library_client.construct_stack", mock_construct_stack)
+        monkeypatch.setattr("llama_stack.core.library_client.Stack", MockStack)
         monkeypatch.setattr("llama_stack.core.library_client.initialize_route_impls", mock_initialize_route_impls)

         client = AsyncLlamaStackAsLibraryClient("ci-tests")
@@ -68,13 +76,17 @@ class TestLlamaStackAsLibraryClientAutoInitialization:
         mock_impls = {}
         mock_route_impls = RouteImpls({})

-        async def mock_construct_stack(config, custom_provider_registry):
-            return mock_impls
+        class MockStack:
+            def __init__(self, config, custom_provider_registry=None):
+                self.impls = mock_impls
+
+            async def initialize(self):
+                pass

         def mock_initialize_route_impls(impls):
             return mock_route_impls

-        monkeypatch.setattr("llama_stack.core.library_client.construct_stack", mock_construct_stack)
+        monkeypatch.setattr("llama_stack.core.library_client.Stack", MockStack)
         monkeypatch.setattr("llama_stack.core.library_client.initialize_route_impls", mock_initialize_route_impls)

         client = LlamaStackAsLibraryClient("ci-tests")
@@ -90,13 +102,17 @@ class TestLlamaStackAsLibraryClientAutoInitialization:
         mock_impls = {}
         mock_route_impls = RouteImpls({})

-        async def mock_construct_stack(config, custom_provider_registry):
-            return mock_impls
+        class MockStack:
+            def __init__(self, config, custom_provider_registry=None):
+                self.impls = mock_impls
+
+            async def initialize(self):
+                pass

         def mock_initialize_route_impls(impls):
             return mock_route_impls

-        monkeypatch.setattr("llama_stack.core.library_client.construct_stack", mock_construct_stack)
+        monkeypatch.setattr("llama_stack.core.library_client.Stack", MockStack)
         monkeypatch.setattr("llama_stack.core.library_client.initialize_route_impls", mock_initialize_route_impls)

         client = AsyncLlamaStackAsLibraryClient("ci-tests")
@@ -112,13 +128,17 @@ class TestLlamaStackAsLibraryClientAutoInitialization:
         mock_impls = {}
         mock_route_impls = RouteImpls({})

-        async def mock_construct_stack(config, custom_provider_registry):
-            return mock_impls
+        class MockStack:
+            def __init__(self, config, custom_provider_registry=None):
+                self.impls = mock_impls
+
+            async def initialize(self):
+                pass

         def mock_initialize_route_impls(impls):
             return mock_route_impls

-        monkeypatch.setattr("llama_stack.core.library_client.construct_stack", mock_construct_stack)
+        monkeypatch.setattr("llama_stack.core.library_client.Stack", MockStack)
         monkeypatch.setattr("llama_stack.core.library_client.initialize_route_impls", mock_initialize_route_impls)

         sync_client = LlamaStackAsLibraryClient("ci-tests")
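
The repeated fixture change above encodes the new contract: the library client no longer calls a `construct_stack()` function but instantiates a `Stack` and expects an `impls` mapping plus an awaitable `initialize()`. Any stand-in only needs to satisfy that duck-typed shape, as in this minimal sketch:

```python
# Minimal duck-typed stand-in for what the library client uses:
# construction with (config, custom_provider_registry), an .impls mapping,
# and an async initialize(). Useful wherever a real Stack is too heavy.
class FakeStack:
    def __init__(self, config, custom_provider_registry=None):
        self.impls = {}  # api -> implementation; empty suffices for a smoke test

    async def initialize(self):
        pass  # the real Stack wires up providers here
```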

View file

@@ -26,7 +26,7 @@ async def test_authorized_fetch_with_where_sql_access_control(mock_get_authentic
             db_path=tmp_dir + "/" + db_name,
         )
     )
-    sqlstore = AuthorizedSqlStore(base_sqlstore)
+    sqlstore = AuthorizedSqlStore(base_sqlstore, default_policy())

     # Create table with access control
     await sqlstore.create_table(
@@ -56,24 +56,24 @@ async def test_authorized_fetch_with_where_sql_access_control(mock_get_authentic
     mock_get_authenticated_user.return_value = admin_user

     # Admin should see both documents
-    result = await sqlstore.fetch_all("documents", policy=default_policy(), where={"id": 1})
+    result = await sqlstore.fetch_all("documents", where={"id": 1})
     assert len(result.data) == 1
     assert result.data[0]["title"] == "Admin Document"

     # User should only see their document
     mock_get_authenticated_user.return_value = regular_user
-    result = await sqlstore.fetch_all("documents", policy=default_policy(), where={"id": 1})
+    result = await sqlstore.fetch_all("documents", where={"id": 1})
     assert len(result.data) == 0

-    result = await sqlstore.fetch_all("documents", policy=default_policy(), where={"id": 2})
+    result = await sqlstore.fetch_all("documents", where={"id": 2})
     assert len(result.data) == 1
     assert result.data[0]["title"] == "User Document"

-    row = await sqlstore.fetch_one("documents", policy=default_policy(), where={"id": 1})
+    row = await sqlstore.fetch_one("documents", where={"id": 1})
     assert row is None

-    row = await sqlstore.fetch_one("documents", policy=default_policy(), where={"id": 2})
+    row = await sqlstore.fetch_one("documents", where={"id": 2})
     assert row is not None
     assert row["title"] == "User Document"
@@ -88,7 +88,7 @@ async def test_sql_policy_consistency(mock_get_authenticated_user):
             db_path=tmp_dir + "/" + db_name,
         )
     )
-    sqlstore = AuthorizedSqlStore(base_sqlstore)
+    sqlstore = AuthorizedSqlStore(base_sqlstore, default_policy())

     await sqlstore.create_table(
         table="resources",
@@ -144,7 +144,7 @@ async def test_sql_policy_consistency(mock_get_authenticated_user):
         user = User(principal=user_data["principal"], attributes=user_data["attributes"])
         mock_get_authenticated_user.return_value = user

-        sql_results = await sqlstore.fetch_all("resources", policy=policy)
+        sql_results = await sqlstore.fetch_all("resources")
         sql_ids = {row["id"] for row in sql_results.data}
         policy_ids = set()
         for scenario in test_scenarios:
@@ -174,7 +174,7 @@ async def test_authorized_store_user_attribute_capture(mock_get_authenticated_us
             db_path=tmp_dir + "/" + db_name,
         )
     )
-    authorized_store = AuthorizedSqlStore(base_sqlstore)
+    authorized_store = AuthorizedSqlStore(base_sqlstore, default_policy())

     await authorized_store.create_table(
         table="user_data",