Merge branch 'main' into use-openai-for-databricks

This commit is contained in:
Matthew Farrellee 2025-09-20 06:16:54 -04:00
commit c8623607f5
31 changed files with 665 additions and 691 deletions

View file

@ -17,11 +17,8 @@ export POSTGRES_PASSWORD=llamastack
export INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct export INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct
export SAFETY_MODEL=meta-llama/Llama-Guard-3-1B export SAFETY_MODEL=meta-llama/Llama-Guard-3-1B
export MOCK_INFERENCE_MODEL=mock-inference
export MOCK_INFERENCE_URL=openai-mock-service:8080
export BENCHMARK_INFERENCE_MODEL=$INFERENCE_MODEL export BENCHMARK_INFERENCE_MODEL=$INFERENCE_MODEL
export LLAMA_STACK_WORKERS=4
set -euo pipefail set -euo pipefail
set -x set -x

View file

@ -5,6 +5,7 @@ data:
image_name: kubernetes-benchmark-demo image_name: kubernetes-benchmark-demo
apis: apis:
- agents - agents
- files
- inference - inference
- files - files
- safety - safety
@ -23,6 +24,14 @@ data:
- provider_id: sentence-transformers - provider_id: sentence-transformers
provider_type: inline::sentence-transformers provider_type: inline::sentence-transformers
config: {} config: {}
files:
- provider_id: meta-reference-files
provider_type: inline::localfs
config:
storage_dir: ${env.FILES_STORAGE_DIR:=~/.llama/distributions/starter/files}
metadata_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/files_metadata.db
vector_io: vector_io:
- provider_id: ${env.ENABLE_CHROMADB:+chromadb} - provider_id: ${env.ENABLE_CHROMADB:+chromadb}
provider_type: remote::chromadb provider_type: remote::chromadb

View file

@ -52,9 +52,20 @@ spec:
value: http://vllm-server-safety.default.svc.cluster.local:8001/v1 value: http://vllm-server-safety.default.svc.cluster.local:8001/v1
- name: VLLM_TLS_VERIFY - name: VLLM_TLS_VERIFY
value: "false" value: "false"
command: ["python", "-m", "llama_stack.core.server.server", "/etc/config/stack_run_config.yaml", "--port", "8323"] - name: LLAMA_STACK_LOGGING
value: "all=WARNING"
- name: LLAMA_STACK_CONFIG
value: "/etc/config/stack_run_config.yaml"
- name: LLAMA_STACK_WORKERS
value: "${LLAMA_STACK_WORKERS}"
command: ["uvicorn", "llama_stack.core.server.server:create_app", "--host", "0.0.0.0", "--port", "8323", "--workers", "$LLAMA_STACK_WORKERS", "--factory"]
ports: ports:
- containerPort: 8323 - containerPort: 8323
resources:
requests:
cpu: "${LLAMA_STACK_WORKERS}"
limits:
cpu: "${LLAMA_STACK_WORKERS}"
volumeMounts: volumeMounts:
- name: llama-storage - name: llama-storage
mountPath: /root/.llama mountPath: /root/.llama

View file

@ -460,10 +460,12 @@ client = LlamaStackClient(base_url="http://localhost:8321")
embed_lm = next(m for m in client.models.list() if m.model_type == "embedding") embed_lm = next(m for m in client.models.list() if m.model_type == "embedding")
embedding_model = embed_lm.identifier embedding_model = embed_lm.identifier
vector_db_id = f"v{uuid.uuid4().hex}" vector_db_id = f"v{uuid.uuid4().hex}"
client.vector_dbs.register( # The VectorDB API is deprecated; the server now returns its own authoritative ID.
# We capture the correct ID from the response's .identifier attribute.
vector_db_id = client.vector_dbs.register(
vector_db_id=vector_db_id, vector_db_id=vector_db_id,
embedding_model=embedding_model, embedding_model=embedding_model,
) ).identifier
# Create Documents # Create Documents
urls = [ urls = [

View file

@ -23,7 +23,13 @@ To use Milvus in your Llama Stack project, follow these steps:
## Installation ## Installation
You can install Milvus using pymilvus: If you want to use inline Milvus, you can install:
```bash
pip install pymilvus[milvus-lite]
```
If you want to use remote Milvus, you can install:
```bash ```bash
pip install pymilvus pip install pymilvus

View file

@ -121,10 +121,6 @@ class AutoRoutedProviderSpec(ProviderSpec):
default=None, default=None,
) )
@property
def pip_packages(self) -> list[str]:
raise AssertionError("Should not be called on AutoRoutedProviderSpec")
# Example: /models, /shields # Example: /models, /shields
class RoutingTableProviderSpec(ProviderSpec): class RoutingTableProviderSpec(ProviderSpec):

View file

@ -16,11 +16,10 @@ from llama_stack.core.datatypes import BuildConfig, DistributionSpec
from llama_stack.core.external import load_external_apis from llama_stack.core.external import load_external_apis
from llama_stack.log import get_logger from llama_stack.log import get_logger
from llama_stack.providers.datatypes import ( from llama_stack.providers.datatypes import (
AdapterSpec,
Api, Api,
InlineProviderSpec, InlineProviderSpec,
ProviderSpec, ProviderSpec,
remote_provider_spec, RemoteProviderSpec,
) )
logger = get_logger(name=__name__, category="core") logger = get_logger(name=__name__, category="core")
@ -77,27 +76,12 @@ def providable_apis() -> list[Api]:
def _load_remote_provider_spec(spec_data: dict[str, Any], api: Api) -> ProviderSpec: def _load_remote_provider_spec(spec_data: dict[str, Any], api: Api) -> ProviderSpec:
adapter = AdapterSpec(**spec_data["adapter"]) spec = RemoteProviderSpec(api=api, provider_type=f"remote::{spec_data['adapter_type']}", **spec_data)
spec = remote_provider_spec(
api=api,
adapter=adapter,
api_dependencies=[Api(dep) for dep in spec_data.get("api_dependencies", [])],
)
return spec return spec
def _load_inline_provider_spec(spec_data: dict[str, Any], api: Api, provider_name: str) -> ProviderSpec: def _load_inline_provider_spec(spec_data: dict[str, Any], api: Api, provider_name: str) -> ProviderSpec:
spec = InlineProviderSpec( spec = InlineProviderSpec(api=api, provider_type=f"inline::{provider_name}", **spec_data)
api=api,
provider_type=f"inline::{provider_name}",
pip_packages=spec_data.get("pip_packages", []),
module=spec_data["module"],
config_class=spec_data["config_class"],
api_dependencies=[Api(dep) for dep in spec_data.get("api_dependencies", [])],
optional_api_dependencies=[Api(dep) for dep in spec_data.get("optional_api_dependencies", [])],
provider_data_validator=spec_data.get("provider_data_validator"),
container_image=spec_data.get("container_image"),
)
return spec return spec

View file

@ -40,7 +40,7 @@ from llama_stack.core.request_headers import (
from llama_stack.core.resolver import ProviderRegistry from llama_stack.core.resolver import ProviderRegistry
from llama_stack.core.server.routes import RouteImpls, find_matching_route, initialize_route_impls from llama_stack.core.server.routes import RouteImpls, find_matching_route, initialize_route_impls
from llama_stack.core.stack import ( from llama_stack.core.stack import (
construct_stack, Stack,
get_stack_run_config_from_distro, get_stack_run_config_from_distro,
replace_env_vars, replace_env_vars,
) )
@ -252,7 +252,10 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
try: try:
self.route_impls = None self.route_impls = None
self.impls = await construct_stack(self.config, self.custom_provider_registry)
stack = Stack(self.config, self.custom_provider_registry)
await stack.initialize()
self.impls = stack.impls
except ModuleNotFoundError as _e: except ModuleNotFoundError as _e:
cprint(_e.msg, color="red", file=sys.stderr) cprint(_e.msg, color="red", file=sys.stderr)
cprint( cprint(
@ -289,6 +292,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
) )
raise _e raise _e
assert self.impls is not None
if Api.telemetry in self.impls: if Api.telemetry in self.impls:
setup_logger(self.impls[Api.telemetry]) setup_logger(self.impls[Api.telemetry])

View file

@ -6,6 +6,7 @@
import argparse import argparse
import asyncio import asyncio
import concurrent.futures
import functools import functools
import inspect import inspect
import json import json
@ -50,17 +51,15 @@ from llama_stack.core.request_headers import (
request_provider_data_context, request_provider_data_context,
user_from_scope, user_from_scope,
) )
from llama_stack.core.resolver import InvalidProviderError
from llama_stack.core.server.routes import ( from llama_stack.core.server.routes import (
find_matching_route, find_matching_route,
get_all_api_routes, get_all_api_routes,
initialize_route_impls, initialize_route_impls,
) )
from llama_stack.core.stack import ( from llama_stack.core.stack import (
Stack,
cast_image_name_to_string, cast_image_name_to_string,
construct_stack,
replace_env_vars, replace_env_vars,
shutdown_stack,
validate_env_pair, validate_env_pair,
) )
from llama_stack.core.utils.config import redact_sensitive_fields from llama_stack.core.utils.config import redact_sensitive_fields
@ -156,21 +155,34 @@ def translate_exception(exc: Exception) -> HTTPException | RequestValidationErro
) )
async def shutdown(app): class StackApp(FastAPI):
"""Initiate a graceful shutdown of the application.
Handled by the lifespan context manager. The shutdown process involves
shutting down all implementations registered in the application.
""" """
await shutdown_stack(app.__llama_stack_impls__) A wrapper around the FastAPI application to hold a reference to the Stack instance so that we can
start background tasks (e.g. refresh model registry periodically) from the lifespan context manager.
"""
def __init__(self, config: StackRunConfig, *args, **kwargs):
super().__init__(*args, **kwargs)
self.stack: Stack = Stack(config)
# This code is called from a running event loop managed by uvicorn so we cannot simply call
# asyncio.run() to initialize the stack. We cannot await either since this is not an async
# function.
# As a workaround, we use a thread pool executor to run the initialize() method
# in a separate thread.
with concurrent.futures.ThreadPoolExecutor() as executor:
future = executor.submit(asyncio.run, self.stack.initialize())
future.result()
@asynccontextmanager @asynccontextmanager
async def lifespan(app: FastAPI): async def lifespan(app: StackApp):
logger.info("Starting up") logger.info("Starting up")
assert app.stack is not None
app.stack.create_registry_refresh_task()
yield yield
logger.info("Shutting down") logger.info("Shutting down")
await shutdown(app) await app.stack.shutdown()
def is_streaming_request(func_name: str, request: Request, **kwargs): def is_streaming_request(func_name: str, request: Request, **kwargs):
@ -386,73 +398,61 @@ class ClientVersionMiddleware:
return await self.app(scope, receive, send) return await self.app(scope, receive, send)
def main(args: argparse.Namespace | None = None): def create_app(
"""Start the LlamaStack server.""" config_file: str | None = None,
parser = argparse.ArgumentParser(description="Start the LlamaStack server.") env_vars: list[str] | None = None,
) -> StackApp:
"""Create and configure the FastAPI application.
add_config_distro_args(parser) Args:
parser.add_argument( config_file: Path to config file. If None, uses LLAMA_STACK_CONFIG env var or default resolution.
"--port", env_vars: List of environment variables in KEY=value format.
type=int, disable_version_check: Whether to disable version checking. If None, uses LLAMA_STACK_DISABLE_VERSION_CHECK env var.
default=int(os.getenv("LLAMA_STACK_PORT", 8321)),
help="Port to listen on",
)
parser.add_argument(
"--env",
action="append",
help="Environment variables in KEY=value format. Can be specified multiple times.",
)
# Determine whether the server args are being passed by the "run" command, if this is the case Returns:
# the args will be passed as a Namespace object to the main function, otherwise they will be Configured StackApp instance.
# parsed from the command line """
if args is None: config_file = config_file or os.getenv("LLAMA_STACK_CONFIG")
args = parser.parse_args() if config_file is None:
raise ValueError("No config file provided and LLAMA_STACK_CONFIG env var is not set")
config_or_distro = get_config_from_args(args) config_file = resolve_config_or_distro(config_file, Mode.RUN)
config_file = resolve_config_or_distro(config_or_distro, Mode.RUN)
# Load and process configuration
logger_config = None logger_config = None
with open(config_file) as fp: with open(config_file) as fp:
config_contents = yaml.safe_load(fp) config_contents = yaml.safe_load(fp)
if isinstance(config_contents, dict) and (cfg := config_contents.get("logging_config")): if isinstance(config_contents, dict) and (cfg := config_contents.get("logging_config")):
logger_config = LoggingConfig(**cfg) logger_config = LoggingConfig(**cfg)
logger = get_logger(name=__name__, category="core::server", config=logger_config) logger = get_logger(name=__name__, category="core::server", config=logger_config)
if args.env:
for env_pair in args.env: if env_vars:
for env_pair in env_vars:
try: try:
key, value = validate_env_pair(env_pair) key, value = validate_env_pair(env_pair)
logger.info(f"Setting CLI environment variable {key} => {value}") logger.info(f"Setting environment variable {key} => {value}")
os.environ[key] = value os.environ[key] = value
except ValueError as e: except ValueError as e:
logger.error(f"Error: {str(e)}") logger.error(f"Error: {str(e)}")
sys.exit(1) raise ValueError(f"Invalid environment variable format: {env_pair}") from e
config = replace_env_vars(config_contents) config = replace_env_vars(config_contents)
config = StackRunConfig(**cast_image_name_to_string(config)) config = StackRunConfig(**cast_image_name_to_string(config))
_log_run_config(run_config=config) _log_run_config(run_config=config)
app = FastAPI( app = StackApp(
lifespan=lifespan, lifespan=lifespan,
docs_url="/docs", docs_url="/docs",
redoc_url="/redoc", redoc_url="/redoc",
openapi_url="/openapi.json", openapi_url="/openapi.json",
config=config,
) )
if not os.environ.get("LLAMA_STACK_DISABLE_VERSION_CHECK"): if not os.environ.get("LLAMA_STACK_DISABLE_VERSION_CHECK"):
app.add_middleware(ClientVersionMiddleware) app.add_middleware(ClientVersionMiddleware)
try: impls = app.stack.impls
# Create and set the event loop that will be used for both construction and server runtime
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
# Construct the stack in the persistent event loop
impls = loop.run_until_complete(construct_stack(config))
except InvalidProviderError as e:
logger.error(f"Error: {str(e)}")
sys.exit(1)
if config.server.auth: if config.server.auth:
logger.info(f"Enabling authentication with provider: {config.server.auth.provider_config.type.value}") logger.info(f"Enabling authentication with provider: {config.server.auth.provider_config.type.value}")
@ -553,9 +553,54 @@ def main(args: argparse.Namespace | None = None):
app.exception_handler(RequestValidationError)(global_exception_handler) app.exception_handler(RequestValidationError)(global_exception_handler)
app.exception_handler(Exception)(global_exception_handler) app.exception_handler(Exception)(global_exception_handler)
app.__llama_stack_impls__ = impls
app.add_middleware(TracingMiddleware, impls=impls, external_apis=external_apis) app.add_middleware(TracingMiddleware, impls=impls, external_apis=external_apis)
return app
def main(args: argparse.Namespace | None = None):
"""Start the LlamaStack server."""
parser = argparse.ArgumentParser(description="Start the LlamaStack server.")
add_config_distro_args(parser)
parser.add_argument(
"--port",
type=int,
default=int(os.getenv("LLAMA_STACK_PORT", 8321)),
help="Port to listen on",
)
parser.add_argument(
"--env",
action="append",
help="Environment variables in KEY=value format. Can be specified multiple times.",
)
# Determine whether the server args are being passed by the "run" command, if this is the case
# the args will be passed as a Namespace object to the main function, otherwise they will be
# parsed from the command line
if args is None:
args = parser.parse_args()
config_or_distro = get_config_from_args(args)
try:
app = create_app(
config_file=config_or_distro,
env_vars=args.env,
)
except Exception as e:
logger.error(f"Error creating app: {str(e)}")
sys.exit(1)
config_file = resolve_config_or_distro(config_or_distro, Mode.RUN)
with open(config_file) as fp:
config_contents = yaml.safe_load(fp)
if isinstance(config_contents, dict) and (cfg := config_contents.get("logging_config")):
logger_config = LoggingConfig(**cfg)
else:
logger_config = None
config = StackRunConfig(**cast_image_name_to_string(replace_env_vars(config_contents)))
import uvicorn import uvicorn
# Configure SSL if certificates are provided # Configure SSL if certificates are provided
@ -593,7 +638,6 @@ def main(args: argparse.Namespace | None = None):
if ssl_config: if ssl_config:
uvicorn_config.update(ssl_config) uvicorn_config.update(ssl_config)
# Run uvicorn in the existing event loop to preserve background tasks
# We need to catch KeyboardInterrupt because uvicorn's signal handling # We need to catch KeyboardInterrupt because uvicorn's signal handling
# re-raises SIGINT signals using signal.raise_signal(), which Python # re-raises SIGINT signals using signal.raise_signal(), which Python
# converts to KeyboardInterrupt. Without this catch, we'd get a confusing # converts to KeyboardInterrupt. Without this catch, we'd get a confusing
@ -604,13 +648,9 @@ def main(args: argparse.Namespace | None = None):
# Another approach would be to ignore SIGINT entirely - let uvicorn handle it through its own # Another approach would be to ignore SIGINT entirely - let uvicorn handle it through its own
# signal handling but this is quite intrusive and not worth the effort. # signal handling but this is quite intrusive and not worth the effort.
try: try:
loop.run_until_complete(uvicorn.Server(uvicorn.Config(**uvicorn_config)).serve()) asyncio.run(uvicorn.Server(uvicorn.Config(**uvicorn_config)).serve())
except (KeyboardInterrupt, SystemExit): except (KeyboardInterrupt, SystemExit):
logger.info("Received interrupt signal, shutting down gracefully...") logger.info("Received interrupt signal, shutting down gracefully...")
finally:
if not loop.is_closed():
logger.debug("Closing event loop")
loop.close()
def _log_run_config(run_config: StackRunConfig): def _log_run_config(run_config: StackRunConfig):

View file

@ -315,78 +315,84 @@ def add_internal_implementations(impls: dict[Api, Any], run_config: StackRunConf
impls[Api.prompts] = prompts_impl impls[Api.prompts] = prompts_impl
# Produces a stack of providers for the given run config. Not all APIs may be class Stack:
# asked for in the run config. def __init__(self, run_config: StackRunConfig, provider_registry: ProviderRegistry | None = None):
async def construct_stack( self.run_config = run_config
run_config: StackRunConfig, provider_registry: ProviderRegistry | None = None self.provider_registry = provider_registry
) -> dict[Api, Any]: self.impls = None
if "LLAMA_STACK_TEST_INFERENCE_MODE" in os.environ:
from llama_stack.testing.inference_recorder import setup_inference_recording # Produces a stack of providers for the given run config. Not all APIs may be
# asked for in the run config.
async def initialize(self):
if "LLAMA_STACK_TEST_INFERENCE_MODE" in os.environ:
from llama_stack.testing.inference_recorder import setup_inference_recording
global TEST_RECORDING_CONTEXT
TEST_RECORDING_CONTEXT = setup_inference_recording()
if TEST_RECORDING_CONTEXT:
TEST_RECORDING_CONTEXT.__enter__()
logger.info(f"Inference recording enabled: mode={os.environ.get('LLAMA_STACK_TEST_INFERENCE_MODE')}")
dist_registry, _ = await create_dist_registry(self.run_config.metadata_store, self.run_config.image_name)
policy = self.run_config.server.auth.access_policy if self.run_config.server.auth else []
impls = await resolve_impls(
self.run_config, self.provider_registry or get_provider_registry(self.run_config), dist_registry, policy
)
# Add internal implementations after all other providers are resolved
add_internal_implementations(impls, self.run_config)
if Api.prompts in impls:
await impls[Api.prompts].initialize()
await register_resources(self.run_config, impls)
await refresh_registry_once(impls)
self.impls = impls
def create_registry_refresh_task(self):
assert self.impls is not None, "Must call initialize() before starting"
global REGISTRY_REFRESH_TASK
REGISTRY_REFRESH_TASK = asyncio.create_task(refresh_registry_task(self.impls))
def cb(task):
import traceback
if task.cancelled():
logger.error("Model refresh task cancelled")
elif task.exception():
logger.error(f"Model refresh task failed: {task.exception()}")
traceback.print_exception(task.exception())
else:
logger.debug("Model refresh task completed")
REGISTRY_REFRESH_TASK.add_done_callback(cb)
async def shutdown(self):
for impl in self.impls.values():
impl_name = impl.__class__.__name__
logger.info(f"Shutting down {impl_name}")
try:
if hasattr(impl, "shutdown"):
await asyncio.wait_for(impl.shutdown(), timeout=5)
else:
logger.warning(f"No shutdown method for {impl_name}")
except TimeoutError:
logger.exception(f"Shutdown timeout for {impl_name}")
except (Exception, asyncio.CancelledError) as e:
logger.exception(f"Failed to shutdown {impl_name}: {e}")
global TEST_RECORDING_CONTEXT global TEST_RECORDING_CONTEXT
TEST_RECORDING_CONTEXT = setup_inference_recording()
if TEST_RECORDING_CONTEXT: if TEST_RECORDING_CONTEXT:
TEST_RECORDING_CONTEXT.__enter__() try:
logger.info(f"Inference recording enabled: mode={os.environ.get('LLAMA_STACK_TEST_INFERENCE_MODE')}") TEST_RECORDING_CONTEXT.__exit__(None, None, None)
except Exception as e:
logger.error(f"Error during inference recording cleanup: {e}")
dist_registry, _ = await create_dist_registry(run_config.metadata_store, run_config.image_name) global REGISTRY_REFRESH_TASK
policy = run_config.server.auth.access_policy if run_config.server.auth else [] if REGISTRY_REFRESH_TASK:
impls = await resolve_impls( REGISTRY_REFRESH_TASK.cancel()
run_config, provider_registry or get_provider_registry(run_config), dist_registry, policy
)
# Add internal implementations after all other providers are resolved
add_internal_implementations(impls, run_config)
if Api.prompts in impls:
await impls[Api.prompts].initialize()
await register_resources(run_config, impls)
await refresh_registry_once(impls)
global REGISTRY_REFRESH_TASK
REGISTRY_REFRESH_TASK = asyncio.create_task(refresh_registry_task(impls))
def cb(task):
import traceback
if task.cancelled():
logger.error("Model refresh task cancelled")
elif task.exception():
logger.error(f"Model refresh task failed: {task.exception()}")
traceback.print_exception(task.exception())
else:
logger.debug("Model refresh task completed")
REGISTRY_REFRESH_TASK.add_done_callback(cb)
return impls
async def shutdown_stack(impls: dict[Api, Any]):
for impl in impls.values():
impl_name = impl.__class__.__name__
logger.info(f"Shutting down {impl_name}")
try:
if hasattr(impl, "shutdown"):
await asyncio.wait_for(impl.shutdown(), timeout=5)
else:
logger.warning(f"No shutdown method for {impl_name}")
except TimeoutError:
logger.exception(f"Shutdown timeout for {impl_name}")
except (Exception, asyncio.CancelledError) as e:
logger.exception(f"Failed to shutdown {impl_name}: {e}")
global TEST_RECORDING_CONTEXT
if TEST_RECORDING_CONTEXT:
try:
TEST_RECORDING_CONTEXT.__exit__(None, None, None)
except Exception as e:
logger.error(f"Error during inference recording cleanup: {e}")
global REGISTRY_REFRESH_TASK
if REGISTRY_REFRESH_TASK:
REGISTRY_REFRESH_TASK.cancel()
async def refresh_registry_once(impls: dict[Api, Any]): async def refresh_registry_once(impls: dict[Api, Any]):

View file

@ -78,12 +78,12 @@ def get_remote_inference_providers() -> list[Provider]:
remote_providers = [ remote_providers = [
provider provider
for provider in available_providers() for provider in available_providers()
if isinstance(provider, RemoteProviderSpec) and provider.adapter.adapter_type in ENABLED_INFERENCE_PROVIDERS if isinstance(provider, RemoteProviderSpec) and provider.adapter_type in ENABLED_INFERENCE_PROVIDERS
] ]
inference_providers = [] inference_providers = []
for provider_spec in remote_providers: for provider_spec in remote_providers:
provider_type = provider_spec.adapter.adapter_type provider_type = provider_spec.adapter_type
if provider_type in INFERENCE_PROVIDER_IDS: if provider_type in INFERENCE_PROVIDER_IDS:
provider_id = INFERENCE_PROVIDER_IDS[provider_type] provider_id = INFERENCE_PROVIDER_IDS[provider_type]

View file

@ -131,6 +131,15 @@ class ProviderSpec(BaseModel):
""", """,
) )
pip_packages: list[str] = Field(
default_factory=list,
description="The pip dependencies needed for this implementation",
)
provider_data_validator: str | None = Field(
default=None,
)
is_external: bool = Field(default=False, description="Notes whether this provider is an external provider.") is_external: bool = Field(default=False, description="Notes whether this provider is an external provider.")
# used internally by the resolver; this is a hack for now # used internally by the resolver; this is a hack for now
@ -145,45 +154,8 @@ class RoutingTable(Protocol):
async def get_provider_impl(self, routing_key: str) -> Any: ... async def get_provider_impl(self, routing_key: str) -> Any: ...
# TODO: this can now be inlined into RemoteProviderSpec
@json_schema_type
class AdapterSpec(BaseModel):
adapter_type: str = Field(
...,
description="Unique identifier for this adapter",
)
module: str = Field(
default_factory=str,
description="""
Fully-qualified name of the module to import. The module is expected to have:
- `get_adapter_impl(config, deps)`: returns the adapter implementation
""",
)
pip_packages: list[str] = Field(
default_factory=list,
description="The pip dependencies needed for this implementation",
)
config_class: str = Field(
description="Fully-qualified classname of the config for this provider",
)
provider_data_validator: str | None = Field(
default=None,
)
description: str | None = Field(
default=None,
description="""
A description of the provider. This is used to display in the documentation.
""",
)
@json_schema_type @json_schema_type
class InlineProviderSpec(ProviderSpec): class InlineProviderSpec(ProviderSpec):
pip_packages: list[str] = Field(
default_factory=list,
description="The pip dependencies needed for this implementation",
)
container_image: str | None = Field( container_image: str | None = Field(
default=None, default=None,
description=""" description="""
@ -191,10 +163,6 @@ The container image to use for this implementation. If one is provided, pip_pack
If a provider depends on other providers, the dependencies MUST NOT specify a container image. If a provider depends on other providers, the dependencies MUST NOT specify a container image.
""", """,
) )
# module field is inherited from ProviderSpec
provider_data_validator: str | None = Field(
default=None,
)
description: str | None = Field( description: str | None = Field(
default=None, default=None,
description=""" description="""
@ -223,10 +191,15 @@ class RemoteProviderConfig(BaseModel):
@json_schema_type @json_schema_type
class RemoteProviderSpec(ProviderSpec): class RemoteProviderSpec(ProviderSpec):
adapter: AdapterSpec = Field( adapter_type: str = Field(
...,
description="Unique identifier for this adapter",
)
description: str | None = Field(
default=None,
description=""" description="""
If some code is needed to convert the remote responses into Llama Stack compatible A description of the provider. This is used to display in the documentation.
API responses, specify the adapter here.
""", """,
) )
@ -234,33 +207,6 @@ API responses, specify the adapter here.
def container_image(self) -> str | None: def container_image(self) -> str | None:
return None return None
# module field is inherited from ProviderSpec
@property
def pip_packages(self) -> list[str]:
return self.adapter.pip_packages
@property
def provider_data_validator(self) -> str | None:
return self.adapter.provider_data_validator
def remote_provider_spec(
api: Api,
adapter: AdapterSpec,
api_dependencies: list[Api] | None = None,
optional_api_dependencies: list[Api] | None = None,
) -> RemoteProviderSpec:
return RemoteProviderSpec(
api=api,
provider_type=f"remote::{adapter.adapter_type}",
config_class=adapter.config_class,
module=adapter.module,
adapter=adapter,
api_dependencies=api_dependencies or [],
optional_api_dependencies=optional_api_dependencies or [],
)
class HealthStatus(StrEnum): class HealthStatus(StrEnum):
OK = "OK" OK = "OK"

View file

@ -44,7 +44,7 @@ class LocalfsFilesImpl(Files):
storage_path.mkdir(parents=True, exist_ok=True) storage_path.mkdir(parents=True, exist_ok=True)
# Initialize SQL store for metadata # Initialize SQL store for metadata
self.sql_store = AuthorizedSqlStore(sqlstore_impl(self.config.metadata_store)) self.sql_store = AuthorizedSqlStore(sqlstore_impl(self.config.metadata_store), self.policy)
await self.sql_store.create_table( await self.sql_store.create_table(
"openai_files", "openai_files",
{ {
@ -74,7 +74,7 @@ class LocalfsFilesImpl(Files):
if not self.sql_store: if not self.sql_store:
raise RuntimeError("Files provider not initialized") raise RuntimeError("Files provider not initialized")
row = await self.sql_store.fetch_one("openai_files", policy=self.policy, where={"id": file_id}) row = await self.sql_store.fetch_one("openai_files", where={"id": file_id})
if not row: if not row:
raise ResourceNotFoundError(file_id, "File", "client.files.list()") raise ResourceNotFoundError(file_id, "File", "client.files.list()")
@ -150,7 +150,6 @@ class LocalfsFilesImpl(Files):
paginated_result = await self.sql_store.fetch_all( paginated_result = await self.sql_store.fetch_all(
table="openai_files", table="openai_files",
policy=self.policy,
where=where_conditions if where_conditions else None, where=where_conditions if where_conditions else None,
order_by=[("created_at", order.value)], order_by=[("created_at", order.value)],
cursor=("id", after) if after else None, cursor=("id", after) if after else None,

View file

@ -6,11 +6,10 @@
from llama_stack.providers.datatypes import ( from llama_stack.providers.datatypes import (
AdapterSpec,
Api, Api,
InlineProviderSpec, InlineProviderSpec,
ProviderSpec, ProviderSpec,
remote_provider_spec, RemoteProviderSpec,
) )
@ -25,28 +24,26 @@ def available_providers() -> list[ProviderSpec]:
api_dependencies=[], api_dependencies=[],
description="Local filesystem-based dataset I/O provider for reading and writing datasets to local storage.", description="Local filesystem-based dataset I/O provider for reading and writing datasets to local storage.",
), ),
remote_provider_spec( RemoteProviderSpec(
api=Api.datasetio, api=Api.datasetio,
adapter=AdapterSpec( adapter_type="huggingface",
adapter_type="huggingface", provider_type="remote::huggingface",
pip_packages=[ pip_packages=[
"datasets>=4.0.0", "datasets>=4.0.0",
], ],
module="llama_stack.providers.remote.datasetio.huggingface", module="llama_stack.providers.remote.datasetio.huggingface",
config_class="llama_stack.providers.remote.datasetio.huggingface.HuggingfaceDatasetIOConfig", config_class="llama_stack.providers.remote.datasetio.huggingface.HuggingfaceDatasetIOConfig",
description="HuggingFace datasets provider for accessing and managing datasets from the HuggingFace Hub.", description="HuggingFace datasets provider for accessing and managing datasets from the HuggingFace Hub.",
),
), ),
remote_provider_spec( RemoteProviderSpec(
api=Api.datasetio, api=Api.datasetio,
adapter=AdapterSpec( adapter_type="nvidia",
adapter_type="nvidia", provider_type="remote::nvidia",
pip_packages=[ module="llama_stack.providers.remote.datasetio.nvidia",
"datasets>=4.0.0", config_class="llama_stack.providers.remote.datasetio.nvidia.NvidiaDatasetIOConfig",
], pip_packages=[
module="llama_stack.providers.remote.datasetio.nvidia", "datasets>=4.0.0",
config_class="llama_stack.providers.remote.datasetio.nvidia.NvidiaDatasetIOConfig", ],
description="NVIDIA's dataset I/O provider for accessing datasets from NVIDIA's data platform.", description="NVIDIA's dataset I/O provider for accessing datasets from NVIDIA's data platform.",
),
), ),
] ]

View file

@ -5,7 +5,7 @@
# the root directory of this source tree. # the root directory of this source tree.
from llama_stack.providers.datatypes import AdapterSpec, Api, InlineProviderSpec, ProviderSpec, remote_provider_spec from llama_stack.providers.datatypes import Api, InlineProviderSpec, ProviderSpec, RemoteProviderSpec
def available_providers() -> list[ProviderSpec]: def available_providers() -> list[ProviderSpec]:
@ -25,17 +25,16 @@ def available_providers() -> list[ProviderSpec]:
], ],
description="Meta's reference implementation of evaluation tasks with support for multiple languages and evaluation metrics.", description="Meta's reference implementation of evaluation tasks with support for multiple languages and evaluation metrics.",
), ),
remote_provider_spec( RemoteProviderSpec(
api=Api.eval, api=Api.eval,
adapter=AdapterSpec( adapter_type="nvidia",
adapter_type="nvidia", pip_packages=[
pip_packages=[ "requests",
"requests", ],
], provider_type="remote::nvidia",
module="llama_stack.providers.remote.eval.nvidia", module="llama_stack.providers.remote.eval.nvidia",
config_class="llama_stack.providers.remote.eval.nvidia.NVIDIAEvalConfig", config_class="llama_stack.providers.remote.eval.nvidia.NVIDIAEvalConfig",
description="NVIDIA's evaluation provider for running evaluation tasks on NVIDIA's platform.", description="NVIDIA's evaluation provider for running evaluation tasks on NVIDIA's platform.",
),
api_dependencies=[ api_dependencies=[
Api.datasetio, Api.datasetio,
Api.datasets, Api.datasets,

View file

@ -4,13 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from llama_stack.providers.datatypes import ( from llama_stack.providers.datatypes import Api, InlineProviderSpec, ProviderSpec, RemoteProviderSpec
AdapterSpec,
Api,
InlineProviderSpec,
ProviderSpec,
remote_provider_spec,
)
from llama_stack.providers.utils.sqlstore.sqlstore import sql_store_pip_packages from llama_stack.providers.utils.sqlstore.sqlstore import sql_store_pip_packages
@ -25,14 +19,13 @@ def available_providers() -> list[ProviderSpec]:
config_class="llama_stack.providers.inline.files.localfs.config.LocalfsFilesImplConfig", config_class="llama_stack.providers.inline.files.localfs.config.LocalfsFilesImplConfig",
description="Local filesystem-based file storage provider for managing files and documents locally.", description="Local filesystem-based file storage provider for managing files and documents locally.",
), ),
remote_provider_spec( RemoteProviderSpec(
api=Api.files, api=Api.files,
adapter=AdapterSpec( provider_type="remote::s3",
adapter_type="s3", adapter_type="s3",
pip_packages=["boto3"] + sql_store_pip_packages, pip_packages=["boto3"] + sql_store_pip_packages,
module="llama_stack.providers.remote.files.s3", module="llama_stack.providers.remote.files.s3",
config_class="llama_stack.providers.remote.files.s3.config.S3FilesImplConfig", config_class="llama_stack.providers.remote.files.s3.config.S3FilesImplConfig",
description="AWS S3-based file storage provider for scalable cloud file management with metadata persistence.", description="AWS S3-based file storage provider for scalable cloud file management with metadata persistence.",
),
), ),
] ]

View file

@ -6,11 +6,10 @@
from llama_stack.providers.datatypes import ( from llama_stack.providers.datatypes import (
AdapterSpec,
Api, Api,
InlineProviderSpec, InlineProviderSpec,
ProviderSpec, ProviderSpec,
remote_provider_spec, RemoteProviderSpec,
) )
META_REFERENCE_DEPS = [ META_REFERENCE_DEPS = [
@ -49,177 +48,167 @@ def available_providers() -> list[ProviderSpec]:
config_class="llama_stack.providers.inline.inference.sentence_transformers.config.SentenceTransformersInferenceConfig", config_class="llama_stack.providers.inline.inference.sentence_transformers.config.SentenceTransformersInferenceConfig",
description="Sentence Transformers inference provider for text embeddings and similarity search.", description="Sentence Transformers inference provider for text embeddings and similarity search.",
), ),
remote_provider_spec( RemoteProviderSpec(
api=Api.inference, api=Api.inference,
adapter=AdapterSpec( adapter_type="cerebras",
adapter_type="cerebras", provider_type="remote::cerebras",
pip_packages=[ pip_packages=[
"cerebras_cloud_sdk", "cerebras_cloud_sdk",
], ],
module="llama_stack.providers.remote.inference.cerebras", module="llama_stack.providers.remote.inference.cerebras",
config_class="llama_stack.providers.remote.inference.cerebras.CerebrasImplConfig", config_class="llama_stack.providers.remote.inference.cerebras.CerebrasImplConfig",
description="Cerebras inference provider for running models on Cerebras Cloud platform.", description="Cerebras inference provider for running models on Cerebras Cloud platform.",
),
), ),
remote_provider_spec( RemoteProviderSpec(
api=Api.inference, api=Api.inference,
adapter=AdapterSpec( adapter_type="ollama",
adapter_type="ollama", provider_type="remote::ollama",
pip_packages=["ollama", "aiohttp", "h11>=0.16.0"], pip_packages=["ollama", "aiohttp", "h11>=0.16.0"],
config_class="llama_stack.providers.remote.inference.ollama.OllamaImplConfig", config_class="llama_stack.providers.remote.inference.ollama.OllamaImplConfig",
module="llama_stack.providers.remote.inference.ollama", module="llama_stack.providers.remote.inference.ollama",
description="Ollama inference provider for running local models through the Ollama runtime.", description="Ollama inference provider for running local models through the Ollama runtime.",
),
), ),
remote_provider_spec( RemoteProviderSpec(
api=Api.inference, api=Api.inference,
adapter=AdapterSpec( adapter_type="vllm",
adapter_type="vllm", provider_type="remote::vllm",
pip_packages=[], pip_packages=[],
module="llama_stack.providers.remote.inference.vllm", module="llama_stack.providers.remote.inference.vllm",
config_class="llama_stack.providers.remote.inference.vllm.VLLMInferenceAdapterConfig", config_class="llama_stack.providers.remote.inference.vllm.VLLMInferenceAdapterConfig",
provider_data_validator="llama_stack.providers.remote.inference.vllm.VLLMProviderDataValidator", provider_data_validator="llama_stack.providers.remote.inference.vllm.VLLMProviderDataValidator",
description="Remote vLLM inference provider for connecting to vLLM servers.", description="Remote vLLM inference provider for connecting to vLLM servers.",
),
), ),
remote_provider_spec( RemoteProviderSpec(
api=Api.inference, api=Api.inference,
adapter=AdapterSpec( adapter_type="tgi",
adapter_type="tgi", provider_type="remote::tgi",
pip_packages=["huggingface_hub", "aiohttp"], pip_packages=["huggingface_hub", "aiohttp"],
module="llama_stack.providers.remote.inference.tgi", module="llama_stack.providers.remote.inference.tgi",
config_class="llama_stack.providers.remote.inference.tgi.TGIImplConfig", config_class="llama_stack.providers.remote.inference.tgi.TGIImplConfig",
description="Text Generation Inference (TGI) provider for HuggingFace model serving.", description="Text Generation Inference (TGI) provider for HuggingFace model serving.",
),
), ),
remote_provider_spec( RemoteProviderSpec(
api=Api.inference, api=Api.inference,
adapter=AdapterSpec( adapter_type="hf::serverless",
adapter_type="hf::serverless", provider_type="remote::hf::serverless",
pip_packages=["huggingface_hub", "aiohttp"], pip_packages=["huggingface_hub", "aiohttp"],
module="llama_stack.providers.remote.inference.tgi", module="llama_stack.providers.remote.inference.tgi",
config_class="llama_stack.providers.remote.inference.tgi.InferenceAPIImplConfig", config_class="llama_stack.providers.remote.inference.tgi.InferenceAPIImplConfig",
description="HuggingFace Inference API serverless provider for on-demand model inference.", description="HuggingFace Inference API serverless provider for on-demand model inference.",
),
), ),
remote_provider_spec( RemoteProviderSpec(
api=Api.inference, api=Api.inference,
adapter=AdapterSpec( provider_type="remote::hf::endpoint",
adapter_type="hf::endpoint", adapter_type="hf::endpoint",
pip_packages=["huggingface_hub", "aiohttp"], pip_packages=["huggingface_hub", "aiohttp"],
module="llama_stack.providers.remote.inference.tgi", module="llama_stack.providers.remote.inference.tgi",
config_class="llama_stack.providers.remote.inference.tgi.InferenceEndpointImplConfig", config_class="llama_stack.providers.remote.inference.tgi.InferenceEndpointImplConfig",
description="HuggingFace Inference Endpoints provider for dedicated model serving.", description="HuggingFace Inference Endpoints provider for dedicated model serving.",
),
), ),
remote_provider_spec( RemoteProviderSpec(
api=Api.inference, api=Api.inference,
adapter=AdapterSpec( adapter_type="fireworks",
adapter_type="fireworks", provider_type="remote::fireworks",
pip_packages=[ pip_packages=[
"fireworks-ai<=0.17.16", "fireworks-ai<=0.17.16",
], ],
module="llama_stack.providers.remote.inference.fireworks", module="llama_stack.providers.remote.inference.fireworks",
config_class="llama_stack.providers.remote.inference.fireworks.FireworksImplConfig", config_class="llama_stack.providers.remote.inference.fireworks.FireworksImplConfig",
provider_data_validator="llama_stack.providers.remote.inference.fireworks.FireworksProviderDataValidator", provider_data_validator="llama_stack.providers.remote.inference.fireworks.FireworksProviderDataValidator",
description="Fireworks AI inference provider for Llama models and other AI models on the Fireworks platform.", description="Fireworks AI inference provider for Llama models and other AI models on the Fireworks platform.",
),
), ),
remote_provider_spec( RemoteProviderSpec(
api=Api.inference, api=Api.inference,
adapter=AdapterSpec( adapter_type="together",
adapter_type="together", provider_type="remote::together",
pip_packages=[ pip_packages=[
"together", "together",
], ],
module="llama_stack.providers.remote.inference.together", module="llama_stack.providers.remote.inference.together",
config_class="llama_stack.providers.remote.inference.together.TogetherImplConfig", config_class="llama_stack.providers.remote.inference.together.TogetherImplConfig",
provider_data_validator="llama_stack.providers.remote.inference.together.TogetherProviderDataValidator", provider_data_validator="llama_stack.providers.remote.inference.together.TogetherProviderDataValidator",
description="Together AI inference provider for open-source models and collaborative AI development.", description="Together AI inference provider for open-source models and collaborative AI development.",
),
), ),
remote_provider_spec( RemoteProviderSpec(
api=Api.inference, api=Api.inference,
adapter=AdapterSpec( adapter_type="bedrock",
adapter_type="bedrock", provider_type="remote::bedrock",
pip_packages=["boto3"], pip_packages=["boto3"],
module="llama_stack.providers.remote.inference.bedrock", module="llama_stack.providers.remote.inference.bedrock",
config_class="llama_stack.providers.remote.inference.bedrock.BedrockConfig", config_class="llama_stack.providers.remote.inference.bedrock.BedrockConfig",
description="AWS Bedrock inference provider for accessing various AI models through AWS's managed service.", description="AWS Bedrock inference provider for accessing various AI models through AWS's managed service.",
),
), ),
remote_provider_spec( RemoteProviderSpec(
api=Api.inference, api=Api.inference,
adapter=AdapterSpec( adapter_type="databricks",
adapter_type="databricks", provider_type="remote::databricks",
pip_packages=["databricks-sdk"], pip_packages=["databricks-sdk"],
module="llama_stack.providers.remote.inference.databricks", module="llama_stack.providers.remote.inference.databricks",
config_class="llama_stack.providers.remote.inference.databricks.DatabricksImplConfig", config_class="llama_stack.providers.remote.inference.databricks.DatabricksImplConfig",
description="Databricks inference provider for running models on Databricks' unified analytics platform.", description="Databricks inference provider for running models on Databricks' unified analytics platform.",
),
), ),
remote_provider_spec( RemoteProviderSpec(
api=Api.inference, api=Api.inference,
adapter=AdapterSpec( adapter_type="nvidia",
adapter_type="nvidia", provider_type="remote::nvidia",
pip_packages=[], pip_packages=[],
module="llama_stack.providers.remote.inference.nvidia", module="llama_stack.providers.remote.inference.nvidia",
config_class="llama_stack.providers.remote.inference.nvidia.NVIDIAConfig", config_class="llama_stack.providers.remote.inference.nvidia.NVIDIAConfig",
description="NVIDIA inference provider for accessing NVIDIA NIM models and AI services.", description="NVIDIA inference provider for accessing NVIDIA NIM models and AI services.",
),
), ),
remote_provider_spec( RemoteProviderSpec(
api=Api.inference, api=Api.inference,
adapter=AdapterSpec( adapter_type="runpod",
adapter_type="runpod", provider_type="remote::runpod",
pip_packages=[], pip_packages=[],
module="llama_stack.providers.remote.inference.runpod", module="llama_stack.providers.remote.inference.runpod",
config_class="llama_stack.providers.remote.inference.runpod.RunpodImplConfig", config_class="llama_stack.providers.remote.inference.runpod.RunpodImplConfig",
description="RunPod inference provider for running models on RunPod's cloud GPU platform.", description="RunPod inference provider for running models on RunPod's cloud GPU platform.",
),
), ),
remote_provider_spec( RemoteProviderSpec(
api=Api.inference, api=Api.inference,
adapter=AdapterSpec( adapter_type="openai",
adapter_type="openai", provider_type="remote::openai",
pip_packages=["litellm"], pip_packages=["litellm"],
module="llama_stack.providers.remote.inference.openai", module="llama_stack.providers.remote.inference.openai",
config_class="llama_stack.providers.remote.inference.openai.OpenAIConfig", config_class="llama_stack.providers.remote.inference.openai.OpenAIConfig",
provider_data_validator="llama_stack.providers.remote.inference.openai.config.OpenAIProviderDataValidator", provider_data_validator="llama_stack.providers.remote.inference.openai.config.OpenAIProviderDataValidator",
description="OpenAI inference provider for accessing GPT models and other OpenAI services.", description="OpenAI inference provider for accessing GPT models and other OpenAI services.",
),
), ),
remote_provider_spec( RemoteProviderSpec(
api=Api.inference, api=Api.inference,
adapter=AdapterSpec( adapter_type="anthropic",
adapter_type="anthropic", provider_type="remote::anthropic",
pip_packages=["litellm"], pip_packages=["litellm"],
module="llama_stack.providers.remote.inference.anthropic", module="llama_stack.providers.remote.inference.anthropic",
config_class="llama_stack.providers.remote.inference.anthropic.AnthropicConfig", config_class="llama_stack.providers.remote.inference.anthropic.AnthropicConfig",
provider_data_validator="llama_stack.providers.remote.inference.anthropic.config.AnthropicProviderDataValidator", provider_data_validator="llama_stack.providers.remote.inference.anthropic.config.AnthropicProviderDataValidator",
description="Anthropic inference provider for accessing Claude models and Anthropic's AI services.", description="Anthropic inference provider for accessing Claude models and Anthropic's AI services.",
),
), ),
remote_provider_spec( RemoteProviderSpec(
api=Api.inference, api=Api.inference,
adapter=AdapterSpec( adapter_type="gemini",
adapter_type="gemini", provider_type="remote::gemini",
pip_packages=["litellm"], pip_packages=[
module="llama_stack.providers.remote.inference.gemini", "litellm",
config_class="llama_stack.providers.remote.inference.gemini.GeminiConfig", ],
provider_data_validator="llama_stack.providers.remote.inference.gemini.config.GeminiProviderDataValidator", module="llama_stack.providers.remote.inference.gemini",
description="Google Gemini inference provider for accessing Gemini models and Google's AI services.", config_class="llama_stack.providers.remote.inference.gemini.GeminiConfig",
), provider_data_validator="llama_stack.providers.remote.inference.gemini.config.GeminiProviderDataValidator",
description="Google Gemini inference provider for accessing Gemini models and Google's AI services.",
), ),
remote_provider_spec( RemoteProviderSpec(
api=Api.inference, api=Api.inference,
adapter=AdapterSpec( adapter_type="vertexai",
adapter_type="vertexai", provider_type="remote::vertexai",
pip_packages=["litellm", "google-cloud-aiplatform"], pip_packages=[
module="llama_stack.providers.remote.inference.vertexai", "litellm",
config_class="llama_stack.providers.remote.inference.vertexai.VertexAIConfig", "google-cloud-aiplatform",
provider_data_validator="llama_stack.providers.remote.inference.vertexai.config.VertexAIProviderDataValidator", ],
description="""Google Vertex AI inference provider enables you to use Google's Gemini models through Google Cloud's Vertex AI platform, providing several advantages: module="llama_stack.providers.remote.inference.vertexai",
config_class="llama_stack.providers.remote.inference.vertexai.VertexAIConfig",
provider_data_validator="llama_stack.providers.remote.inference.vertexai.config.VertexAIProviderDataValidator",
description="""Google Vertex AI inference provider enables you to use Google's Gemini models through Google Cloud's Vertex AI platform, providing several advantages:
Enterprise-grade security: Uses Google Cloud's security controls and IAM Enterprise-grade security: Uses Google Cloud's security controls and IAM
Better integration: Seamless integration with other Google Cloud services Better integration: Seamless integration with other Google Cloud services
@ -239,76 +228,73 @@ Available Models:
- vertex_ai/gemini-2.0-flash - vertex_ai/gemini-2.0-flash
- vertex_ai/gemini-2.5-flash - vertex_ai/gemini-2.5-flash
- vertex_ai/gemini-2.5-pro""", - vertex_ai/gemini-2.5-pro""",
),
), ),
remote_provider_spec( RemoteProviderSpec(
api=Api.inference, api=Api.inference,
adapter=AdapterSpec( adapter_type="groq",
adapter_type="groq", provider_type="remote::groq",
pip_packages=["litellm"], pip_packages=[
module="llama_stack.providers.remote.inference.groq", "litellm",
config_class="llama_stack.providers.remote.inference.groq.GroqConfig", ],
provider_data_validator="llama_stack.providers.remote.inference.groq.config.GroqProviderDataValidator", module="llama_stack.providers.remote.inference.groq",
description="Groq inference provider for ultra-fast inference using Groq's LPU technology.", config_class="llama_stack.providers.remote.inference.groq.GroqConfig",
), provider_data_validator="llama_stack.providers.remote.inference.groq.config.GroqProviderDataValidator",
description="Groq inference provider for ultra-fast inference using Groq's LPU technology.",
), ),
remote_provider_spec( RemoteProviderSpec(
api=Api.inference, api=Api.inference,
adapter=AdapterSpec( adapter_type="llama-openai-compat",
adapter_type="llama-openai-compat", provider_type="remote::llama-openai-compat",
pip_packages=["litellm"], pip_packages=["litellm"],
module="llama_stack.providers.remote.inference.llama_openai_compat", module="llama_stack.providers.remote.inference.llama_openai_compat",
config_class="llama_stack.providers.remote.inference.llama_openai_compat.config.LlamaCompatConfig", config_class="llama_stack.providers.remote.inference.llama_openai_compat.config.LlamaCompatConfig",
provider_data_validator="llama_stack.providers.remote.inference.llama_openai_compat.config.LlamaProviderDataValidator", provider_data_validator="llama_stack.providers.remote.inference.llama_openai_compat.config.LlamaProviderDataValidator",
description="Llama OpenAI-compatible provider for using Llama models with OpenAI API format.", description="Llama OpenAI-compatible provider for using Llama models with OpenAI API format.",
),
), ),
remote_provider_spec( RemoteProviderSpec(
api=Api.inference, api=Api.inference,
adapter=AdapterSpec( adapter_type="sambanova",
adapter_type="sambanova", provider_type="remote::sambanova",
pip_packages=["litellm"], pip_packages=[
module="llama_stack.providers.remote.inference.sambanova", "litellm",
config_class="llama_stack.providers.remote.inference.sambanova.SambaNovaImplConfig", ],
provider_data_validator="llama_stack.providers.remote.inference.sambanova.config.SambaNovaProviderDataValidator", module="llama_stack.providers.remote.inference.sambanova",
description="SambaNova inference provider for running models on SambaNova's dataflow architecture.", config_class="llama_stack.providers.remote.inference.sambanova.SambaNovaImplConfig",
), provider_data_validator="llama_stack.providers.remote.inference.sambanova.config.SambaNovaProviderDataValidator",
description="SambaNova inference provider for running models on SambaNova's dataflow architecture.",
), ),
remote_provider_spec( RemoteProviderSpec(
api=Api.inference, api=Api.inference,
adapter=AdapterSpec( adapter_type="passthrough",
adapter_type="passthrough", provider_type="remote::passthrough",
pip_packages=[], pip_packages=[],
module="llama_stack.providers.remote.inference.passthrough", module="llama_stack.providers.remote.inference.passthrough",
config_class="llama_stack.providers.remote.inference.passthrough.PassthroughImplConfig", config_class="llama_stack.providers.remote.inference.passthrough.PassthroughImplConfig",
provider_data_validator="llama_stack.providers.remote.inference.passthrough.PassthroughProviderDataValidator", provider_data_validator="llama_stack.providers.remote.inference.passthrough.PassthroughProviderDataValidator",
description="Passthrough inference provider for connecting to any external inference service not directly supported.", description="Passthrough inference provider for connecting to any external inference service not directly supported.",
),
), ),
remote_provider_spec( RemoteProviderSpec(
api=Api.inference, api=Api.inference,
adapter=AdapterSpec( adapter_type="watsonx",
adapter_type="watsonx", provider_type="remote::watsonx",
pip_packages=["ibm_watsonx_ai"], pip_packages=["ibm_watsonx_ai"],
module="llama_stack.providers.remote.inference.watsonx", module="llama_stack.providers.remote.inference.watsonx",
config_class="llama_stack.providers.remote.inference.watsonx.WatsonXConfig", config_class="llama_stack.providers.remote.inference.watsonx.WatsonXConfig",
provider_data_validator="llama_stack.providers.remote.inference.watsonx.WatsonXProviderDataValidator", provider_data_validator="llama_stack.providers.remote.inference.watsonx.WatsonXProviderDataValidator",
description="IBM WatsonX inference provider for accessing AI models on IBM's WatsonX platform.", description="IBM WatsonX inference provider for accessing AI models on IBM's WatsonX platform.",
),
), ),
remote_provider_spec( RemoteProviderSpec(
api=Api.inference, api=Api.inference,
adapter=AdapterSpec( provider_type="remote::azure",
adapter_type="azure", adapter_type="azure",
pip_packages=["litellm"], pip_packages=["litellm"],
module="llama_stack.providers.remote.inference.azure", module="llama_stack.providers.remote.inference.azure",
config_class="llama_stack.providers.remote.inference.azure.AzureConfig", config_class="llama_stack.providers.remote.inference.azure.AzureConfig",
provider_data_validator="llama_stack.providers.remote.inference.azure.config.AzureProviderDataValidator", provider_data_validator="llama_stack.providers.remote.inference.azure.config.AzureProviderDataValidator",
description=""" description="""
Azure OpenAI inference provider for accessing GPT models and other Azure services. Azure OpenAI inference provider for accessing GPT models and other Azure services.
Provider documentation Provider documentation
https://learn.microsoft.com/en-us/azure/ai-foundry/openai/overview https://learn.microsoft.com/en-us/azure/ai-foundry/openai/overview
""", """,
),
), ),
] ]

View file

@ -7,7 +7,7 @@
from typing import cast from typing import cast
from llama_stack.providers.datatypes import AdapterSpec, Api, InlineProviderSpec, ProviderSpec, remote_provider_spec from llama_stack.providers.datatypes import Api, InlineProviderSpec, ProviderSpec, RemoteProviderSpec
# We provide two versions of these providers so that distributions can package the appropriate version of torch. # We provide two versions of these providers so that distributions can package the appropriate version of torch.
# The CPU version is used for distributions that don't have GPU support -- they result in smaller container images. # The CPU version is used for distributions that don't have GPU support -- they result in smaller container images.
@ -57,14 +57,13 @@ def available_providers() -> list[ProviderSpec]:
], ],
description="HuggingFace-based post-training provider for fine-tuning models using the HuggingFace ecosystem.", description="HuggingFace-based post-training provider for fine-tuning models using the HuggingFace ecosystem.",
), ),
remote_provider_spec( RemoteProviderSpec(
api=Api.post_training, api=Api.post_training,
adapter=AdapterSpec( adapter_type="nvidia",
adapter_type="nvidia", provider_type="remote::nvidia",
pip_packages=["requests", "aiohttp"], pip_packages=["requests", "aiohttp"],
module="llama_stack.providers.remote.post_training.nvidia", module="llama_stack.providers.remote.post_training.nvidia",
config_class="llama_stack.providers.remote.post_training.nvidia.NvidiaPostTrainingConfig", config_class="llama_stack.providers.remote.post_training.nvidia.NvidiaPostTrainingConfig",
description="NVIDIA's post-training provider for fine-tuning models on NVIDIA's platform.", description="NVIDIA's post-training provider for fine-tuning models on NVIDIA's platform.",
),
), ),
] ]

View file

@ -6,11 +6,10 @@
from llama_stack.providers.datatypes import ( from llama_stack.providers.datatypes import (
AdapterSpec,
Api, Api,
InlineProviderSpec, InlineProviderSpec,
ProviderSpec, ProviderSpec,
remote_provider_spec, RemoteProviderSpec,
) )
@ -48,35 +47,32 @@ def available_providers() -> list[ProviderSpec]:
config_class="llama_stack.providers.inline.safety.code_scanner.CodeScannerConfig", config_class="llama_stack.providers.inline.safety.code_scanner.CodeScannerConfig",
description="Code Scanner safety provider for detecting security vulnerabilities and unsafe code patterns.", description="Code Scanner safety provider for detecting security vulnerabilities and unsafe code patterns.",
), ),
remote_provider_spec( RemoteProviderSpec(
api=Api.safety, api=Api.safety,
adapter=AdapterSpec( adapter_type="bedrock",
adapter_type="bedrock", provider_type="remote::bedrock",
pip_packages=["boto3"], pip_packages=["boto3"],
module="llama_stack.providers.remote.safety.bedrock", module="llama_stack.providers.remote.safety.bedrock",
config_class="llama_stack.providers.remote.safety.bedrock.BedrockSafetyConfig", config_class="llama_stack.providers.remote.safety.bedrock.BedrockSafetyConfig",
description="AWS Bedrock safety provider for content moderation using AWS's safety services.", description="AWS Bedrock safety provider for content moderation using AWS's safety services.",
),
), ),
remote_provider_spec( RemoteProviderSpec(
api=Api.safety, api=Api.safety,
adapter=AdapterSpec( adapter_type="nvidia",
adapter_type="nvidia", provider_type="remote::nvidia",
pip_packages=["requests"], pip_packages=["requests"],
module="llama_stack.providers.remote.safety.nvidia", module="llama_stack.providers.remote.safety.nvidia",
config_class="llama_stack.providers.remote.safety.nvidia.NVIDIASafetyConfig", config_class="llama_stack.providers.remote.safety.nvidia.NVIDIASafetyConfig",
description="NVIDIA's safety provider for content moderation and safety filtering.", description="NVIDIA's safety provider for content moderation and safety filtering.",
),
), ),
remote_provider_spec( RemoteProviderSpec(
api=Api.safety, api=Api.safety,
adapter=AdapterSpec( adapter_type="sambanova",
adapter_type="sambanova", provider_type="remote::sambanova",
pip_packages=["litellm", "requests"], pip_packages=["litellm", "requests"],
module="llama_stack.providers.remote.safety.sambanova", module="llama_stack.providers.remote.safety.sambanova",
config_class="llama_stack.providers.remote.safety.sambanova.SambaNovaSafetyConfig", config_class="llama_stack.providers.remote.safety.sambanova.SambaNovaSafetyConfig",
provider_data_validator="llama_stack.providers.remote.safety.sambanova.config.SambaNovaProviderDataValidator", provider_data_validator="llama_stack.providers.remote.safety.sambanova.config.SambaNovaProviderDataValidator",
description="SambaNova's safety provider for content moderation and safety filtering.", description="SambaNova's safety provider for content moderation and safety filtering.",
),
), ),
] ]

View file

@ -6,11 +6,10 @@
from llama_stack.providers.datatypes import ( from llama_stack.providers.datatypes import (
AdapterSpec,
Api, Api,
InlineProviderSpec, InlineProviderSpec,
ProviderSpec, ProviderSpec,
remote_provider_spec, RemoteProviderSpec,
) )
@ -35,59 +34,54 @@ def available_providers() -> list[ProviderSpec]:
api_dependencies=[Api.vector_io, Api.inference, Api.files], api_dependencies=[Api.vector_io, Api.inference, Api.files],
description="RAG (Retrieval-Augmented Generation) tool runtime for document ingestion, chunking, and semantic search.", description="RAG (Retrieval-Augmented Generation) tool runtime for document ingestion, chunking, and semantic search.",
), ),
remote_provider_spec( RemoteProviderSpec(
api=Api.tool_runtime, api=Api.tool_runtime,
adapter=AdapterSpec( adapter_type="brave-search",
adapter_type="brave-search", provider_type="remote::brave-search",
module="llama_stack.providers.remote.tool_runtime.brave_search", module="llama_stack.providers.remote.tool_runtime.brave_search",
config_class="llama_stack.providers.remote.tool_runtime.brave_search.config.BraveSearchToolConfig", config_class="llama_stack.providers.remote.tool_runtime.brave_search.config.BraveSearchToolConfig",
pip_packages=["requests"], pip_packages=["requests"],
provider_data_validator="llama_stack.providers.remote.tool_runtime.brave_search.BraveSearchToolProviderDataValidator", provider_data_validator="llama_stack.providers.remote.tool_runtime.brave_search.BraveSearchToolProviderDataValidator",
description="Brave Search tool for web search capabilities with privacy-focused results.", description="Brave Search tool for web search capabilities with privacy-focused results.",
),
), ),
remote_provider_spec( RemoteProviderSpec(
api=Api.tool_runtime, api=Api.tool_runtime,
adapter=AdapterSpec( adapter_type="bing-search",
adapter_type="bing-search", provider_type="remote::bing-search",
module="llama_stack.providers.remote.tool_runtime.bing_search", module="llama_stack.providers.remote.tool_runtime.bing_search",
config_class="llama_stack.providers.remote.tool_runtime.bing_search.config.BingSearchToolConfig", config_class="llama_stack.providers.remote.tool_runtime.bing_search.config.BingSearchToolConfig",
pip_packages=["requests"], pip_packages=["requests"],
provider_data_validator="llama_stack.providers.remote.tool_runtime.bing_search.BingSearchToolProviderDataValidator", provider_data_validator="llama_stack.providers.remote.tool_runtime.bing_search.BingSearchToolProviderDataValidator",
description="Bing Search tool for web search capabilities using Microsoft's search engine.", description="Bing Search tool for web search capabilities using Microsoft's search engine.",
),
), ),
remote_provider_spec( RemoteProviderSpec(
api=Api.tool_runtime, api=Api.tool_runtime,
adapter=AdapterSpec( adapter_type="tavily-search",
adapter_type="tavily-search", provider_type="remote::tavily-search",
module="llama_stack.providers.remote.tool_runtime.tavily_search", module="llama_stack.providers.remote.tool_runtime.tavily_search",
config_class="llama_stack.providers.remote.tool_runtime.tavily_search.config.TavilySearchToolConfig", config_class="llama_stack.providers.remote.tool_runtime.tavily_search.config.TavilySearchToolConfig",
pip_packages=["requests"], pip_packages=["requests"],
provider_data_validator="llama_stack.providers.remote.tool_runtime.tavily_search.TavilySearchToolProviderDataValidator", provider_data_validator="llama_stack.providers.remote.tool_runtime.tavily_search.TavilySearchToolProviderDataValidator",
description="Tavily Search tool for AI-optimized web search with structured results.", description="Tavily Search tool for AI-optimized web search with structured results.",
),
), ),
remote_provider_spec( RemoteProviderSpec(
api=Api.tool_runtime, api=Api.tool_runtime,
adapter=AdapterSpec( adapter_type="wolfram-alpha",
adapter_type="wolfram-alpha", provider_type="remote::wolfram-alpha",
module="llama_stack.providers.remote.tool_runtime.wolfram_alpha", module="llama_stack.providers.remote.tool_runtime.wolfram_alpha",
config_class="llama_stack.providers.remote.tool_runtime.wolfram_alpha.config.WolframAlphaToolConfig", config_class="llama_stack.providers.remote.tool_runtime.wolfram_alpha.config.WolframAlphaToolConfig",
pip_packages=["requests"], pip_packages=["requests"],
provider_data_validator="llama_stack.providers.remote.tool_runtime.wolfram_alpha.WolframAlphaToolProviderDataValidator", provider_data_validator="llama_stack.providers.remote.tool_runtime.wolfram_alpha.WolframAlphaToolProviderDataValidator",
description="Wolfram Alpha tool for computational knowledge and mathematical calculations.", description="Wolfram Alpha tool for computational knowledge and mathematical calculations.",
),
), ),
remote_provider_spec( RemoteProviderSpec(
api=Api.tool_runtime, api=Api.tool_runtime,
adapter=AdapterSpec( adapter_type="model-context-protocol",
adapter_type="model-context-protocol", provider_type="remote::model-context-protocol",
module="llama_stack.providers.remote.tool_runtime.model_context_protocol", module="llama_stack.providers.remote.tool_runtime.model_context_protocol",
config_class="llama_stack.providers.remote.tool_runtime.model_context_protocol.config.MCPProviderConfig", config_class="llama_stack.providers.remote.tool_runtime.model_context_protocol.config.MCPProviderConfig",
pip_packages=["mcp>=1.8.1"], pip_packages=["mcp>=1.8.1"],
provider_data_validator="llama_stack.providers.remote.tool_runtime.model_context_protocol.config.MCPProviderDataValidator", provider_data_validator="llama_stack.providers.remote.tool_runtime.model_context_protocol.config.MCPProviderDataValidator",
description="Model Context Protocol (MCP) tool for standardized tool calling and context management.", description="Model Context Protocol (MCP) tool for standardized tool calling and context management.",
),
), ),
] ]

View file

@ -6,11 +6,10 @@
from llama_stack.providers.datatypes import ( from llama_stack.providers.datatypes import (
AdapterSpec,
Api, Api,
InlineProviderSpec, InlineProviderSpec,
ProviderSpec, ProviderSpec,
remote_provider_spec, RemoteProviderSpec,
) )
@ -300,14 +299,16 @@ See [sqlite-vec's GitHub repo](https://github.com/asg017/sqlite-vec/tree/main) f
Please refer to the sqlite-vec provider documentation. Please refer to the sqlite-vec provider documentation.
""", """,
), ),
remote_provider_spec( RemoteProviderSpec(
Api.vector_io, api=Api.vector_io,
AdapterSpec( adapter_type="chromadb",
adapter_type="chromadb", provider_type="remote::chromadb",
pip_packages=["chromadb-client"], pip_packages=["chromadb-client"],
module="llama_stack.providers.remote.vector_io.chroma", module="llama_stack.providers.remote.vector_io.chroma",
config_class="llama_stack.providers.remote.vector_io.chroma.ChromaVectorIOConfig", config_class="llama_stack.providers.remote.vector_io.chroma.ChromaVectorIOConfig",
description=""" api_dependencies=[Api.inference],
optional_api_dependencies=[Api.files],
description="""
[Chroma](https://www.trychroma.com/) is an inline and remote vector [Chroma](https://www.trychroma.com/) is an inline and remote vector
database provider for Llama Stack. It allows you to store and query vectors directly within a Chroma database. database provider for Llama Stack. It allows you to store and query vectors directly within a Chroma database.
That means you're not limited to storing vectors in memory or in a separate service. That means you're not limited to storing vectors in memory or in a separate service.
@ -340,9 +341,6 @@ pip install chromadb
## Documentation ## Documentation
See [Chroma's documentation](https://docs.trychroma.com/docs/overview/introduction) for more details about Chroma in general. See [Chroma's documentation](https://docs.trychroma.com/docs/overview/introduction) for more details about Chroma in general.
""", """,
),
api_dependencies=[Api.inference],
optional_api_dependencies=[Api.files],
), ),
InlineProviderSpec( InlineProviderSpec(
api=Api.vector_io, api=Api.vector_io,
@ -387,14 +385,16 @@ See [Chroma's documentation](https://docs.trychroma.com/docs/overview/introducti
""", """,
), ),
remote_provider_spec( RemoteProviderSpec(
Api.vector_io, api=Api.vector_io,
AdapterSpec( adapter_type="pgvector",
adapter_type="pgvector", provider_type="remote::pgvector",
pip_packages=["psycopg2-binary"], pip_packages=["psycopg2-binary"],
module="llama_stack.providers.remote.vector_io.pgvector", module="llama_stack.providers.remote.vector_io.pgvector",
config_class="llama_stack.providers.remote.vector_io.pgvector.PGVectorVectorIOConfig", config_class="llama_stack.providers.remote.vector_io.pgvector.PGVectorVectorIOConfig",
description=""" api_dependencies=[Api.inference],
optional_api_dependencies=[Api.files],
description="""
[PGVector](https://github.com/pgvector/pgvector) is a remote vector database provider for Llama Stack. It [PGVector](https://github.com/pgvector/pgvector) is a remote vector database provider for Llama Stack. It
allows you to store and query vectors directly in memory. allows you to store and query vectors directly in memory.
That means you'll get fast and efficient vector retrieval. That means you'll get fast and efficient vector retrieval.
@ -495,19 +495,18 @@ docker pull pgvector/pgvector:pg17
## Documentation ## Documentation
See [PGVector's documentation](https://github.com/pgvector/pgvector) for more details about PGVector in general. See [PGVector's documentation](https://github.com/pgvector/pgvector) for more details about PGVector in general.
""", """,
), ),
RemoteProviderSpec(
api=Api.vector_io,
adapter_type="weaviate",
provider_type="remote::weaviate",
pip_packages=["weaviate-client"],
module="llama_stack.providers.remote.vector_io.weaviate",
config_class="llama_stack.providers.remote.vector_io.weaviate.WeaviateVectorIOConfig",
provider_data_validator="llama_stack.providers.remote.vector_io.weaviate.WeaviateRequestProviderData",
api_dependencies=[Api.inference], api_dependencies=[Api.inference],
optional_api_dependencies=[Api.files], optional_api_dependencies=[Api.files],
), description="""
remote_provider_spec(
Api.vector_io,
AdapterSpec(
adapter_type="weaviate",
pip_packages=["weaviate-client"],
module="llama_stack.providers.remote.vector_io.weaviate",
config_class="llama_stack.providers.remote.vector_io.weaviate.WeaviateVectorIOConfig",
provider_data_validator="llama_stack.providers.remote.vector_io.weaviate.WeaviateRequestProviderData",
description="""
[Weaviate](https://weaviate.io/) is a vector database provider for Llama Stack. [Weaviate](https://weaviate.io/) is a vector database provider for Llama Stack.
It allows you to store and query vectors directly within a Weaviate database. It allows you to store and query vectors directly within a Weaviate database.
That means you're not limited to storing vectors in memory or in a separate service. That means you're not limited to storing vectors in memory or in a separate service.
@ -538,9 +537,6 @@ To install Weaviate see the [Weaviate quickstart documentation](https://weaviate
## Documentation ## Documentation
See [Weaviate's documentation](https://weaviate.io/developers/weaviate) for more details about Weaviate in general. See [Weaviate's documentation](https://weaviate.io/developers/weaviate) for more details about Weaviate in general.
""", """,
),
api_dependencies=[Api.inference],
optional_api_dependencies=[Api.files],
), ),
InlineProviderSpec( InlineProviderSpec(
api=Api.vector_io, api=Api.vector_io,
@ -594,28 +590,29 @@ docker pull qdrant/qdrant
See the [Qdrant documentation](https://qdrant.tech/documentation/) for more details about Qdrant in general. See the [Qdrant documentation](https://qdrant.tech/documentation/) for more details about Qdrant in general.
""", """,
), ),
remote_provider_spec( RemoteProviderSpec(
Api.vector_io, api=Api.vector_io,
AdapterSpec( adapter_type="qdrant",
adapter_type="qdrant", provider_type="remote::qdrant",
pip_packages=["qdrant-client"], pip_packages=["qdrant-client"],
module="llama_stack.providers.remote.vector_io.qdrant", module="llama_stack.providers.remote.vector_io.qdrant",
config_class="llama_stack.providers.remote.vector_io.qdrant.QdrantVectorIOConfig", config_class="llama_stack.providers.remote.vector_io.qdrant.QdrantVectorIOConfig",
description="""
Please refer to the inline provider documentation.
""",
),
api_dependencies=[Api.inference], api_dependencies=[Api.inference],
optional_api_dependencies=[Api.files], optional_api_dependencies=[Api.files],
description="""
Please refer to the inline provider documentation.
""",
), ),
remote_provider_spec( RemoteProviderSpec(
Api.vector_io, api=Api.vector_io,
AdapterSpec( adapter_type="milvus",
adapter_type="milvus", provider_type="remote::milvus",
pip_packages=["pymilvus>=2.4.10"], pip_packages=["pymilvus>=2.4.10"],
module="llama_stack.providers.remote.vector_io.milvus", module="llama_stack.providers.remote.vector_io.milvus",
config_class="llama_stack.providers.remote.vector_io.milvus.MilvusVectorIOConfig", config_class="llama_stack.providers.remote.vector_io.milvus.MilvusVectorIOConfig",
description=""" api_dependencies=[Api.inference],
optional_api_dependencies=[Api.files],
description="""
[Milvus](https://milvus.io/) is an inline and remote vector database provider for Llama Stack. It [Milvus](https://milvus.io/) is an inline and remote vector database provider for Llama Stack. It
allows you to store and query vectors directly within a Milvus database. allows you to store and query vectors directly within a Milvus database.
That means you're not limited to storing vectors in memory or in a separate service. That means you're not limited to storing vectors in memory or in a separate service.
@ -636,7 +633,13 @@ To use Milvus in your Llama Stack project, follow these steps:
## Installation ## Installation
You can install Milvus using pymilvus: If you want to use inline Milvus, you can install:
```bash
pip install pymilvus[milvus-lite]
```
If you want to use remote Milvus, you can install:
```bash ```bash
pip install pymilvus pip install pymilvus
@ -806,14 +809,11 @@ See the [Milvus documentation](https://milvus.io/docs/install-overview.md) for m
For more details on TLS configuration, refer to the [TLS setup guide](https://milvus.io/docs/tls.md). For more details on TLS configuration, refer to the [TLS setup guide](https://milvus.io/docs/tls.md).
""", """,
),
api_dependencies=[Api.inference],
optional_api_dependencies=[Api.files],
), ),
InlineProviderSpec( InlineProviderSpec(
api=Api.vector_io, api=Api.vector_io,
provider_type="inline::milvus", provider_type="inline::milvus",
pip_packages=["pymilvus>=2.4.10"], pip_packages=["pymilvus[milvus-lite]>=2.4.10"],
module="llama_stack.providers.inline.vector_io.milvus", module="llama_stack.providers.inline.vector_io.milvus",
config_class="llama_stack.providers.inline.vector_io.milvus.MilvusVectorIOConfig", config_class="llama_stack.providers.inline.vector_io.milvus.MilvusVectorIOConfig",
api_dependencies=[Api.inference], api_dependencies=[Api.inference],

View file

@ -137,7 +137,7 @@ class S3FilesImpl(Files):
where: dict[str, str | dict] = {"id": file_id} where: dict[str, str | dict] = {"id": file_id}
if not return_expired: if not return_expired:
where["expires_at"] = {">": self._now()} where["expires_at"] = {">": self._now()}
if not (row := await self.sql_store.fetch_one("openai_files", policy=self.policy, where=where)): if not (row := await self.sql_store.fetch_one("openai_files", where=where)):
raise ResourceNotFoundError(file_id, "File", "files.list()") raise ResourceNotFoundError(file_id, "File", "files.list()")
return row return row
@ -164,7 +164,7 @@ class S3FilesImpl(Files):
self._client = _create_s3_client(self._config) self._client = _create_s3_client(self._config)
await _create_bucket_if_not_exists(self._client, self._config) await _create_bucket_if_not_exists(self._client, self._config)
self._sql_store = AuthorizedSqlStore(sqlstore_impl(self._config.metadata_store)) self._sql_store = AuthorizedSqlStore(sqlstore_impl(self._config.metadata_store), self.policy)
await self._sql_store.create_table( await self._sql_store.create_table(
"openai_files", "openai_files",
{ {
@ -268,7 +268,6 @@ class S3FilesImpl(Files):
paginated_result = await self.sql_store.fetch_all( paginated_result = await self.sql_store.fetch_all(
table="openai_files", table="openai_files",
policy=self.policy,
where=where_conditions, where=where_conditions,
order_by=[("created_at", order.value)], order_by=[("created_at", order.value)],
cursor=("id", after) if after else None, cursor=("id", after) if after else None,

View file

@ -54,7 +54,7 @@ class InferenceStore:
async def initialize(self): async def initialize(self):
"""Create the necessary tables if they don't exist.""" """Create the necessary tables if they don't exist."""
self.sql_store = AuthorizedSqlStore(sqlstore_impl(self.sql_store_config)) self.sql_store = AuthorizedSqlStore(sqlstore_impl(self.sql_store_config), self.policy)
await self.sql_store.create_table( await self.sql_store.create_table(
"chat_completions", "chat_completions",
{ {
@ -202,7 +202,6 @@ class InferenceStore:
order_by=[("created", order.value)], order_by=[("created", order.value)],
cursor=("id", after) if after else None, cursor=("id", after) if after else None,
limit=limit, limit=limit,
policy=self.policy,
) )
data = [ data = [
@ -229,7 +228,6 @@ class InferenceStore:
row = await self.sql_store.fetch_one( row = await self.sql_store.fetch_one(
table="chat_completions", table="chat_completions",
where={"id": completion_id}, where={"id": completion_id},
policy=self.policy,
) )
if not row: if not row:

View file

@ -28,8 +28,7 @@ class ResponsesStore:
sql_store_config = SqliteSqlStoreConfig( sql_store_config = SqliteSqlStoreConfig(
db_path=(RUNTIME_BASE_DIR / "sqlstore.db").as_posix(), db_path=(RUNTIME_BASE_DIR / "sqlstore.db").as_posix(),
) )
self.sql_store = AuthorizedSqlStore(sqlstore_impl(sql_store_config)) self.sql_store = AuthorizedSqlStore(sqlstore_impl(sql_store_config), policy)
self.policy = policy
async def initialize(self): async def initialize(self):
"""Create the necessary tables if they don't exist.""" """Create the necessary tables if they don't exist."""
@ -87,7 +86,6 @@ class ResponsesStore:
order_by=[("created_at", order.value)], order_by=[("created_at", order.value)],
cursor=("id", after) if after else None, cursor=("id", after) if after else None,
limit=limit, limit=limit,
policy=self.policy,
) )
data = [OpenAIResponseObjectWithInput(**row["response_object"]) for row in paginated_result.data] data = [OpenAIResponseObjectWithInput(**row["response_object"]) for row in paginated_result.data]
@ -105,7 +103,6 @@ class ResponsesStore:
row = await self.sql_store.fetch_one( row = await self.sql_store.fetch_one(
"openai_responses", "openai_responses",
where={"id": response_id}, where={"id": response_id},
policy=self.policy,
) )
if not row: if not row:
@ -116,7 +113,7 @@ class ResponsesStore:
return OpenAIResponseObjectWithInput(**row["response_object"]) return OpenAIResponseObjectWithInput(**row["response_object"])
async def delete_response_object(self, response_id: str) -> OpenAIDeleteResponseObject: async def delete_response_object(self, response_id: str) -> OpenAIDeleteResponseObject:
row = await self.sql_store.fetch_one("openai_responses", where={"id": response_id}, policy=self.policy) row = await self.sql_store.fetch_one("openai_responses", where={"id": response_id})
if not row: if not row:
raise ValueError(f"Response with id {response_id} not found") raise ValueError(f"Response with id {response_id} not found")
await self.sql_store.delete("openai_responses", where={"id": response_id}) await self.sql_store.delete("openai_responses", where={"id": response_id})

View file

@ -53,13 +53,15 @@ class AuthorizedSqlStore:
access control policies, user attribute capture, and SQL filtering optimization. access control policies, user attribute capture, and SQL filtering optimization.
""" """
def __init__(self, sql_store: SqlStore): def __init__(self, sql_store: SqlStore, policy: list[AccessRule]):
""" """
Initialize the authorization layer. Initialize the authorization layer.
:param sql_store: Base SqlStore implementation to wrap :param sql_store: Base SqlStore implementation to wrap
:param policy: Access control policy to use for authorization
""" """
self.sql_store = sql_store self.sql_store = sql_store
self.policy = policy
self._detect_database_type() self._detect_database_type()
self._validate_sql_optimized_policy() self._validate_sql_optimized_policy()
@ -117,14 +119,13 @@ class AuthorizedSqlStore:
async def fetch_all( async def fetch_all(
self, self,
table: str, table: str,
policy: list[AccessRule],
where: Mapping[str, Any] | None = None, where: Mapping[str, Any] | None = None,
limit: int | None = None, limit: int | None = None,
order_by: list[tuple[str, Literal["asc", "desc"]]] | None = None, order_by: list[tuple[str, Literal["asc", "desc"]]] | None = None,
cursor: tuple[str, str] | None = None, cursor: tuple[str, str] | None = None,
) -> PaginatedResponse: ) -> PaginatedResponse:
"""Fetch all rows with automatic access control filtering.""" """Fetch all rows with automatic access control filtering."""
access_where = self._build_access_control_where_clause(policy) access_where = self._build_access_control_where_clause(self.policy)
rows = await self.sql_store.fetch_all( rows = await self.sql_store.fetch_all(
table=table, table=table,
where=where, where=where,
@ -146,7 +147,7 @@ class AuthorizedSqlStore:
str(record_id), table, User(principal=stored_owner_principal, attributes=stored_access_attrs) str(record_id), table, User(principal=stored_owner_principal, attributes=stored_access_attrs)
) )
if is_action_allowed(policy, Action.READ, sql_record, current_user): if is_action_allowed(self.policy, Action.READ, sql_record, current_user):
filtered_rows.append(row) filtered_rows.append(row)
return PaginatedResponse( return PaginatedResponse(
@ -157,14 +158,12 @@ class AuthorizedSqlStore:
async def fetch_one( async def fetch_one(
self, self,
table: str, table: str,
policy: list[AccessRule],
where: Mapping[str, Any] | None = None, where: Mapping[str, Any] | None = None,
order_by: list[tuple[str, Literal["asc", "desc"]]] | None = None, order_by: list[tuple[str, Literal["asc", "desc"]]] | None = None,
) -> dict[str, Any] | None: ) -> dict[str, Any] | None:
"""Fetch one row with automatic access control checking.""" """Fetch one row with automatic access control checking."""
results = await self.fetch_all( results = await self.fetch_all(
table=table, table=table,
policy=policy,
where=where, where=where,
limit=1, limit=1,
order_by=order_by, order_by=order_by,

View file

@ -1,6 +1,5 @@
adapter: adapter_type: kaze
adapter_type: kaze pip_packages: ["tests/external/llama-stack-provider-kaze"]
pip_packages: ["tests/external/llama-stack-provider-kaze"] config_class: llama_stack_provider_kaze.config.KazeProviderConfig
config_class: llama_stack_provider_kaze.config.KazeProviderConfig module: llama_stack_provider_kaze
module: llama_stack_provider_kaze
optional_api_dependencies: [] optional_api_dependencies: []

View file

@ -6,7 +6,7 @@
from typing import Protocol from typing import Protocol
from llama_stack.providers.datatypes import AdapterSpec, Api, ProviderSpec, RemoteProviderSpec from llama_stack.providers.datatypes import Api, ProviderSpec, RemoteProviderSpec
from llama_stack.schema_utils import webmethod from llama_stack.schema_utils import webmethod
@ -16,12 +16,9 @@ def available_providers() -> list[ProviderSpec]:
api=Api.weather, api=Api.weather,
provider_type="remote::kaze", provider_type="remote::kaze",
config_class="llama_stack_provider_kaze.KazeProviderConfig", config_class="llama_stack_provider_kaze.KazeProviderConfig",
adapter=AdapterSpec( adapter_type="kaze",
adapter_type="kaze", module="llama_stack_provider_kaze",
module="llama_stack_provider_kaze", pip_packages=["llama_stack_provider_kaze"],
pip_packages=["llama_stack_provider_kaze"],
config_class="llama_stack_provider_kaze.KazeProviderConfig",
),
), ),
] ]

View file

@ -57,7 +57,7 @@ def authorized_store(backend_config):
config = config_func() config = config_func()
base_sqlstore = sqlstore_impl(config) base_sqlstore = sqlstore_impl(config)
authorized_store = AuthorizedSqlStore(base_sqlstore) authorized_store = AuthorizedSqlStore(base_sqlstore, default_policy())
yield authorized_store yield authorized_store
@ -106,7 +106,7 @@ async def test_authorized_store_attributes(mock_get_authenticated_user, authoriz
await authorized_store.insert(table_name, {"id": "1", "data": "public_data"}) await authorized_store.insert(table_name, {"id": "1", "data": "public_data"})
# Test fetching with no user - should not error on JSON comparison # Test fetching with no user - should not error on JSON comparison
result = await authorized_store.fetch_all(table_name, policy=default_policy()) result = await authorized_store.fetch_all(table_name)
assert len(result.data) == 1 assert len(result.data) == 1
assert result.data[0]["id"] == "1" assert result.data[0]["id"] == "1"
assert result.data[0]["access_attributes"] is None assert result.data[0]["access_attributes"] is None
@ -119,7 +119,7 @@ async def test_authorized_store_attributes(mock_get_authenticated_user, authoriz
await authorized_store.insert(table_name, {"id": "2", "data": "admin_data"}) await authorized_store.insert(table_name, {"id": "2", "data": "admin_data"})
# Fetch all - admin should see both # Fetch all - admin should see both
result = await authorized_store.fetch_all(table_name, policy=default_policy()) result = await authorized_store.fetch_all(table_name)
assert len(result.data) == 2 assert len(result.data) == 2
# Test with non-admin user # Test with non-admin user
@ -127,7 +127,7 @@ async def test_authorized_store_attributes(mock_get_authenticated_user, authoriz
mock_get_authenticated_user.return_value = regular_user mock_get_authenticated_user.return_value = regular_user
# Should only see public record # Should only see public record
result = await authorized_store.fetch_all(table_name, policy=default_policy()) result = await authorized_store.fetch_all(table_name)
assert len(result.data) == 1 assert len(result.data) == 1
assert result.data[0]["id"] == "1" assert result.data[0]["id"] == "1"
@ -156,7 +156,7 @@ async def test_authorized_store_attributes(mock_get_authenticated_user, authoriz
# Now test with the multi-user who has both roles=admin and teams=dev # Now test with the multi-user who has both roles=admin and teams=dev
mock_get_authenticated_user.return_value = multi_user mock_get_authenticated_user.return_value = multi_user
result = await authorized_store.fetch_all(table_name, policy=default_policy()) result = await authorized_store.fetch_all(table_name)
# Should see: # Should see:
# - public record (1) - no access_attributes # - public record (1) - no access_attributes
@ -217,21 +217,24 @@ async def test_user_ownership_policy(mock_get_authenticated_user, authorized_sto
), ),
] ]
# Create a new authorized store with the owner-only policy
owner_only_store = AuthorizedSqlStore(authorized_store.sql_store, owner_only_policy)
# Test user1 access - should only see their own record # Test user1 access - should only see their own record
mock_get_authenticated_user.return_value = user1 mock_get_authenticated_user.return_value = user1
result = await authorized_store.fetch_all(table_name, policy=owner_only_policy) result = await owner_only_store.fetch_all(table_name)
assert len(result.data) == 1, f"Expected user1 to see 1 record, got {len(result.data)}" assert len(result.data) == 1, f"Expected user1 to see 1 record, got {len(result.data)}"
assert result.data[0]["id"] == "1", f"Expected user1's record, got {result.data[0]['id']}" assert result.data[0]["id"] == "1", f"Expected user1's record, got {result.data[0]['id']}"
# Test user2 access - should only see their own record # Test user2 access - should only see their own record
mock_get_authenticated_user.return_value = user2 mock_get_authenticated_user.return_value = user2
result = await authorized_store.fetch_all(table_name, policy=owner_only_policy) result = await owner_only_store.fetch_all(table_name)
assert len(result.data) == 1, f"Expected user2 to see 1 record, got {len(result.data)}" assert len(result.data) == 1, f"Expected user2 to see 1 record, got {len(result.data)}"
assert result.data[0]["id"] == "2", f"Expected user2's record, got {result.data[0]['id']}" assert result.data[0]["id"] == "2", f"Expected user2's record, got {result.data[0]['id']}"
# Test with anonymous user - should see no records # Test with anonymous user - should see no records
mock_get_authenticated_user.return_value = None mock_get_authenticated_user.return_value = None
result = await authorized_store.fetch_all(table_name, policy=owner_only_policy) result = await owner_only_store.fetch_all(table_name)
assert len(result.data) == 0, f"Expected anonymous user to see 0 records, got {len(result.data)}" assert len(result.data) == 0, f"Expected anonymous user to see 0 records, got {len(result.data)}"
finally: finally:

View file

@ -66,10 +66,9 @@ def base_config(tmp_path):
def provider_spec_yaml(): def provider_spec_yaml():
"""Common provider spec YAML for testing.""" """Common provider spec YAML for testing."""
return """ return """
adapter: adapter_type: test_provider
adapter_type: test_provider config_class: test_provider.config.TestProviderConfig
config_class: test_provider.config.TestProviderConfig module: test_provider
module: test_provider
api_dependencies: api_dependencies:
- safety - safety
""" """
@ -182,9 +181,9 @@ class TestProviderRegistry:
assert Api.inference in registry assert Api.inference in registry
assert "remote::test_provider" in registry[Api.inference] assert "remote::test_provider" in registry[Api.inference]
provider = registry[Api.inference]["remote::test_provider"] provider = registry[Api.inference]["remote::test_provider"]
assert provider.adapter.adapter_type == "test_provider" assert provider.adapter_type == "test_provider"
assert provider.adapter.module == "test_provider" assert provider.module == "test_provider"
assert provider.adapter.config_class == "test_provider.config.TestProviderConfig" assert provider.config_class == "test_provider.config.TestProviderConfig"
assert Api.safety in provider.api_dependencies assert Api.safety in provider.api_dependencies
def test_external_inline_providers(self, api_directories, mock_providers, base_config, inline_provider_spec_yaml): def test_external_inline_providers(self, api_directories, mock_providers, base_config, inline_provider_spec_yaml):
@ -246,8 +245,7 @@ class TestProviderRegistry:
"""Test handling of malformed remote provider spec (missing required fields).""" """Test handling of malformed remote provider spec (missing required fields)."""
remote_dir, _ = api_directories remote_dir, _ = api_directories
malformed_spec = """ malformed_spec = """
adapter: adapter_type: test_provider
adapter_type: test_provider
# Missing required fields # Missing required fields
api_dependencies: api_dependencies:
- safety - safety
@ -270,7 +268,7 @@ pip_packages:
with open(inline_dir / "malformed.yaml", "w") as f: with open(inline_dir / "malformed.yaml", "w") as f:
f.write(malformed_spec) f.write(malformed_spec)
with pytest.raises(KeyError) as exc_info: with pytest.raises(ValidationError) as exc_info:
get_provider_registry(base_config) get_provider_registry(base_config)
assert "config_class" in str(exc_info.value) assert "config_class" in str(exc_info.value)

View file

@ -27,13 +27,17 @@ class TestLlamaStackAsLibraryClientAutoInitialization:
mock_impls = {} mock_impls = {}
mock_route_impls = RouteImpls({}) mock_route_impls = RouteImpls({})
async def mock_construct_stack(config, custom_provider_registry): class MockStack:
return mock_impls def __init__(self, config, custom_provider_registry=None):
self.impls = mock_impls
async def initialize(self):
pass
def mock_initialize_route_impls(impls): def mock_initialize_route_impls(impls):
return mock_route_impls return mock_route_impls
monkeypatch.setattr("llama_stack.core.library_client.construct_stack", mock_construct_stack) monkeypatch.setattr("llama_stack.core.library_client.Stack", MockStack)
monkeypatch.setattr("llama_stack.core.library_client.initialize_route_impls", mock_initialize_route_impls) monkeypatch.setattr("llama_stack.core.library_client.initialize_route_impls", mock_initialize_route_impls)
client = LlamaStackAsLibraryClient("ci-tests") client = LlamaStackAsLibraryClient("ci-tests")
@ -46,13 +50,17 @@ class TestLlamaStackAsLibraryClientAutoInitialization:
mock_impls = {} mock_impls = {}
mock_route_impls = RouteImpls({}) mock_route_impls = RouteImpls({})
async def mock_construct_stack(config, custom_provider_registry): class MockStack:
return mock_impls def __init__(self, config, custom_provider_registry=None):
self.impls = mock_impls
async def initialize(self):
pass
def mock_initialize_route_impls(impls): def mock_initialize_route_impls(impls):
return mock_route_impls return mock_route_impls
monkeypatch.setattr("llama_stack.core.library_client.construct_stack", mock_construct_stack) monkeypatch.setattr("llama_stack.core.library_client.Stack", MockStack)
monkeypatch.setattr("llama_stack.core.library_client.initialize_route_impls", mock_initialize_route_impls) monkeypatch.setattr("llama_stack.core.library_client.initialize_route_impls", mock_initialize_route_impls)
client = AsyncLlamaStackAsLibraryClient("ci-tests") client = AsyncLlamaStackAsLibraryClient("ci-tests")
@ -68,13 +76,17 @@ class TestLlamaStackAsLibraryClientAutoInitialization:
mock_impls = {} mock_impls = {}
mock_route_impls = RouteImpls({}) mock_route_impls = RouteImpls({})
async def mock_construct_stack(config, custom_provider_registry): class MockStack:
return mock_impls def __init__(self, config, custom_provider_registry=None):
self.impls = mock_impls
async def initialize(self):
pass
def mock_initialize_route_impls(impls): def mock_initialize_route_impls(impls):
return mock_route_impls return mock_route_impls
monkeypatch.setattr("llama_stack.core.library_client.construct_stack", mock_construct_stack) monkeypatch.setattr("llama_stack.core.library_client.Stack", MockStack)
monkeypatch.setattr("llama_stack.core.library_client.initialize_route_impls", mock_initialize_route_impls) monkeypatch.setattr("llama_stack.core.library_client.initialize_route_impls", mock_initialize_route_impls)
client = LlamaStackAsLibraryClient("ci-tests") client = LlamaStackAsLibraryClient("ci-tests")
@ -90,13 +102,17 @@ class TestLlamaStackAsLibraryClientAutoInitialization:
mock_impls = {} mock_impls = {}
mock_route_impls = RouteImpls({}) mock_route_impls = RouteImpls({})
async def mock_construct_stack(config, custom_provider_registry): class MockStack:
return mock_impls def __init__(self, config, custom_provider_registry=None):
self.impls = mock_impls
async def initialize(self):
pass
def mock_initialize_route_impls(impls): def mock_initialize_route_impls(impls):
return mock_route_impls return mock_route_impls
monkeypatch.setattr("llama_stack.core.library_client.construct_stack", mock_construct_stack) monkeypatch.setattr("llama_stack.core.library_client.Stack", MockStack)
monkeypatch.setattr("llama_stack.core.library_client.initialize_route_impls", mock_initialize_route_impls) monkeypatch.setattr("llama_stack.core.library_client.initialize_route_impls", mock_initialize_route_impls)
client = AsyncLlamaStackAsLibraryClient("ci-tests") client = AsyncLlamaStackAsLibraryClient("ci-tests")
@ -112,13 +128,17 @@ class TestLlamaStackAsLibraryClientAutoInitialization:
mock_impls = {} mock_impls = {}
mock_route_impls = RouteImpls({}) mock_route_impls = RouteImpls({})
async def mock_construct_stack(config, custom_provider_registry): class MockStack:
return mock_impls def __init__(self, config, custom_provider_registry=None):
self.impls = mock_impls
async def initialize(self):
pass
def mock_initialize_route_impls(impls): def mock_initialize_route_impls(impls):
return mock_route_impls return mock_route_impls
monkeypatch.setattr("llama_stack.core.library_client.construct_stack", mock_construct_stack) monkeypatch.setattr("llama_stack.core.library_client.Stack", MockStack)
monkeypatch.setattr("llama_stack.core.library_client.initialize_route_impls", mock_initialize_route_impls) monkeypatch.setattr("llama_stack.core.library_client.initialize_route_impls", mock_initialize_route_impls)
sync_client = LlamaStackAsLibraryClient("ci-tests") sync_client = LlamaStackAsLibraryClient("ci-tests")

View file

@ -26,7 +26,7 @@ async def test_authorized_fetch_with_where_sql_access_control(mock_get_authentic
db_path=tmp_dir + "/" + db_name, db_path=tmp_dir + "/" + db_name,
) )
) )
sqlstore = AuthorizedSqlStore(base_sqlstore) sqlstore = AuthorizedSqlStore(base_sqlstore, default_policy())
# Create table with access control # Create table with access control
await sqlstore.create_table( await sqlstore.create_table(
@ -56,24 +56,24 @@ async def test_authorized_fetch_with_where_sql_access_control(mock_get_authentic
mock_get_authenticated_user.return_value = admin_user mock_get_authenticated_user.return_value = admin_user
# Admin should see both documents # Admin should see both documents
result = await sqlstore.fetch_all("documents", policy=default_policy(), where={"id": 1}) result = await sqlstore.fetch_all("documents", where={"id": 1})
assert len(result.data) == 1 assert len(result.data) == 1
assert result.data[0]["title"] == "Admin Document" assert result.data[0]["title"] == "Admin Document"
# User should only see their document # User should only see their document
mock_get_authenticated_user.return_value = regular_user mock_get_authenticated_user.return_value = regular_user
result = await sqlstore.fetch_all("documents", policy=default_policy(), where={"id": 1}) result = await sqlstore.fetch_all("documents", where={"id": 1})
assert len(result.data) == 0 assert len(result.data) == 0
result = await sqlstore.fetch_all("documents", policy=default_policy(), where={"id": 2}) result = await sqlstore.fetch_all("documents", where={"id": 2})
assert len(result.data) == 1 assert len(result.data) == 1
assert result.data[0]["title"] == "User Document" assert result.data[0]["title"] == "User Document"
row = await sqlstore.fetch_one("documents", policy=default_policy(), where={"id": 1}) row = await sqlstore.fetch_one("documents", where={"id": 1})
assert row is None assert row is None
row = await sqlstore.fetch_one("documents", policy=default_policy(), where={"id": 2}) row = await sqlstore.fetch_one("documents", where={"id": 2})
assert row is not None assert row is not None
assert row["title"] == "User Document" assert row["title"] == "User Document"
@ -88,7 +88,7 @@ async def test_sql_policy_consistency(mock_get_authenticated_user):
db_path=tmp_dir + "/" + db_name, db_path=tmp_dir + "/" + db_name,
) )
) )
sqlstore = AuthorizedSqlStore(base_sqlstore) sqlstore = AuthorizedSqlStore(base_sqlstore, default_policy())
await sqlstore.create_table( await sqlstore.create_table(
table="resources", table="resources",
@ -144,7 +144,7 @@ async def test_sql_policy_consistency(mock_get_authenticated_user):
user = User(principal=user_data["principal"], attributes=user_data["attributes"]) user = User(principal=user_data["principal"], attributes=user_data["attributes"])
mock_get_authenticated_user.return_value = user mock_get_authenticated_user.return_value = user
sql_results = await sqlstore.fetch_all("resources", policy=policy) sql_results = await sqlstore.fetch_all("resources")
sql_ids = {row["id"] for row in sql_results.data} sql_ids = {row["id"] for row in sql_results.data}
policy_ids = set() policy_ids = set()
for scenario in test_scenarios: for scenario in test_scenarios:
@ -174,7 +174,7 @@ async def test_authorized_store_user_attribute_capture(mock_get_authenticated_us
db_path=tmp_dir + "/" + db_name, db_path=tmp_dir + "/" + db_name,
) )
) )
authorized_store = AuthorizedSqlStore(base_sqlstore) authorized_store = AuthorizedSqlStore(base_sqlstore, default_policy())
await authorized_store.create_table( await authorized_store.create_table(
table="user_data", table="user_data",