diff --git a/benchmarking/k8s-benchmark/apply.sh b/benchmarking/k8s-benchmark/apply.sh index 4f2270da8..6e6607663 100755 --- a/benchmarking/k8s-benchmark/apply.sh +++ b/benchmarking/k8s-benchmark/apply.sh @@ -17,11 +17,8 @@ export POSTGRES_PASSWORD=llamastack export INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct export SAFETY_MODEL=meta-llama/Llama-Guard-3-1B -export MOCK_INFERENCE_MODEL=mock-inference - -export MOCK_INFERENCE_URL=openai-mock-service:8080 - export BENCHMARK_INFERENCE_MODEL=$INFERENCE_MODEL +export LLAMA_STACK_WORKERS=4 set -euo pipefail set -x diff --git a/benchmarking/k8s-benchmark/stack-configmap.yaml b/benchmarking/k8s-benchmark/stack-configmap.yaml index bf6109b68..286ba5f77 100644 --- a/benchmarking/k8s-benchmark/stack-configmap.yaml +++ b/benchmarking/k8s-benchmark/stack-configmap.yaml @@ -5,6 +5,7 @@ data: image_name: kubernetes-benchmark-demo apis: - agents + - files - inference - files - safety @@ -23,6 +24,14 @@ data: - provider_id: sentence-transformers provider_type: inline::sentence-transformers config: {} + files: + - provider_id: meta-reference-files + provider_type: inline::localfs + config: + storage_dir: ${env.FILES_STORAGE_DIR:=~/.llama/distributions/starter/files} + metadata_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/files_metadata.db vector_io: - provider_id: ${env.ENABLE_CHROMADB:+chromadb} provider_type: remote::chromadb diff --git a/benchmarking/k8s-benchmark/stack-k8s.yaml.template b/benchmarking/k8s-benchmark/stack-k8s.yaml.template index 9cb1e5be3..8842c0bea 100644 --- a/benchmarking/k8s-benchmark/stack-k8s.yaml.template +++ b/benchmarking/k8s-benchmark/stack-k8s.yaml.template @@ -52,9 +52,20 @@ spec: value: http://vllm-server-safety.default.svc.cluster.local:8001/v1 - name: VLLM_TLS_VERIFY value: "false" - command: ["python", "-m", "llama_stack.core.server.server", "/etc/config/stack_run_config.yaml", "--port", "8323"] + - name: LLAMA_STACK_LOGGING + value: "all=WARNING" + 
- name: LLAMA_STACK_CONFIG + value: "/etc/config/stack_run_config.yaml" + - name: LLAMA_STACK_WORKERS + value: "${LLAMA_STACK_WORKERS}" + command: ["uvicorn", "llama_stack.core.server.server:create_app", "--host", "0.0.0.0", "--port", "8323", "--workers", "$LLAMA_STACK_WORKERS", "--factory"] ports: - containerPort: 8323 + resources: + requests: + cpu: "${LLAMA_STACK_WORKERS}" + limits: + cpu: "${LLAMA_STACK_WORKERS}" volumeMounts: - name: llama-storage mountPath: /root/.llama diff --git a/docs/source/getting_started/detailed_tutorial.md b/docs/source/getting_started/detailed_tutorial.md index 14f888628..77a899c48 100644 --- a/docs/source/getting_started/detailed_tutorial.md +++ b/docs/source/getting_started/detailed_tutorial.md @@ -460,10 +460,12 @@ client = LlamaStackClient(base_url="http://localhost:8321") embed_lm = next(m for m in client.models.list() if m.model_type == "embedding") embedding_model = embed_lm.identifier vector_db_id = f"v{uuid.uuid4().hex}" -client.vector_dbs.register( +# The VectorDB API is deprecated; the server now returns its own authoritative ID. +# We capture the correct ID from the response's .identifier attribute. 
+vector_db_id = client.vector_dbs.register( vector_db_id=vector_db_id, embedding_model=embedding_model, -) +).identifier # Create Documents urls = [ diff --git a/docs/source/providers/vector_io/remote_milvus.md b/docs/source/providers/vector_io/remote_milvus.md index 075423d04..8974ada10 100644 --- a/docs/source/providers/vector_io/remote_milvus.md +++ b/docs/source/providers/vector_io/remote_milvus.md @@ -23,7 +23,13 @@ To use Milvus in your Llama Stack project, follow these steps: ## Installation -You can install Milvus using pymilvus: +If you want to use inline Milvus, you can install: + +```bash +pip install "pymilvus[milvus-lite]" +``` + +If you want to use remote Milvus, you can install: ```bash pip install pymilvus diff --git a/llama_stack/core/datatypes.py b/llama_stack/core/datatypes.py index faaeefd01..b5558c66f 100644 --- a/llama_stack/core/datatypes.py +++ b/llama_stack/core/datatypes.py @@ -121,10 +121,6 @@ class AutoRoutedProviderSpec(ProviderSpec): default=None, ) - @property - def pip_packages(self) -> list[str]: - raise AssertionError("Should not be called on AutoRoutedProviderSpec") - # Example: /models, /shields class RoutingTableProviderSpec(ProviderSpec): diff --git a/llama_stack/core/distribution.py b/llama_stack/core/distribution.py index c104b6764..302ecb960 100644 --- a/llama_stack/core/distribution.py +++ b/llama_stack/core/distribution.py @@ -16,11 +16,10 @@ from llama_stack.core.datatypes import BuildConfig, DistributionSpec from llama_stack.core.external import load_external_apis from llama_stack.log import get_logger from llama_stack.providers.datatypes import ( - AdapterSpec, Api, InlineProviderSpec, ProviderSpec, - remote_provider_spec, + RemoteProviderSpec, ) logger = get_logger(name=__name__, category="core") @@ -77,27 +76,12 @@ def providable_apis() -> list[Api]: def _load_remote_provider_spec(spec_data: dict[str, Any], api: Api) -> ProviderSpec: - adapter = AdapterSpec(**spec_data["adapter"]) - spec = remote_provider_spec( - 
api=api, - adapter=adapter, - api_dependencies=[Api(dep) for dep in spec_data.get("api_dependencies", [])], - ) + spec = RemoteProviderSpec(api=api, provider_type=f"remote::{spec_data['adapter_type']}", **spec_data) return spec def _load_inline_provider_spec(spec_data: dict[str, Any], api: Api, provider_name: str) -> ProviderSpec: - spec = InlineProviderSpec( - api=api, - provider_type=f"inline::{provider_name}", - pip_packages=spec_data.get("pip_packages", []), - module=spec_data["module"], - config_class=spec_data["config_class"], - api_dependencies=[Api(dep) for dep in spec_data.get("api_dependencies", [])], - optional_api_dependencies=[Api(dep) for dep in spec_data.get("optional_api_dependencies", [])], - provider_data_validator=spec_data.get("provider_data_validator"), - container_image=spec_data.get("container_image"), - ) + spec = InlineProviderSpec(api=api, provider_type=f"inline::{provider_name}", **spec_data) return spec diff --git a/llama_stack/core/library_client.py b/llama_stack/core/library_client.py index ea5a2ac8e..e722e4de6 100644 --- a/llama_stack/core/library_client.py +++ b/llama_stack/core/library_client.py @@ -40,7 +40,7 @@ from llama_stack.core.request_headers import ( from llama_stack.core.resolver import ProviderRegistry from llama_stack.core.server.routes import RouteImpls, find_matching_route, initialize_route_impls from llama_stack.core.stack import ( - construct_stack, + Stack, get_stack_run_config_from_distro, replace_env_vars, ) @@ -252,7 +252,10 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient): try: self.route_impls = None - self.impls = await construct_stack(self.config, self.custom_provider_registry) + + stack = Stack(self.config, self.custom_provider_registry) + await stack.initialize() + self.impls = stack.impls except ModuleNotFoundError as _e: cprint(_e.msg, color="red", file=sys.stderr) cprint( @@ -289,6 +292,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient): ) raise _e + assert self.impls is not 
None if Api.telemetry in self.impls: setup_logger(self.impls[Api.telemetry]) diff --git a/llama_stack/core/server/server.py b/llama_stack/core/server/server.py index d3e875fec..9cca42268 100644 --- a/llama_stack/core/server/server.py +++ b/llama_stack/core/server/server.py @@ -6,6 +6,7 @@ import argparse import asyncio +import concurrent.futures import functools import inspect import json @@ -50,17 +51,15 @@ from llama_stack.core.request_headers import ( request_provider_data_context, user_from_scope, ) -from llama_stack.core.resolver import InvalidProviderError from llama_stack.core.server.routes import ( find_matching_route, get_all_api_routes, initialize_route_impls, ) from llama_stack.core.stack import ( + Stack, cast_image_name_to_string, - construct_stack, replace_env_vars, - shutdown_stack, validate_env_pair, ) from llama_stack.core.utils.config import redact_sensitive_fields @@ -156,21 +155,34 @@ def translate_exception(exc: Exception) -> HTTPException | RequestValidationErro ) -async def shutdown(app): - """Initiate a graceful shutdown of the application. - - Handled by the lifespan context manager. The shutdown process involves - shutting down all implementations registered in the application. +class StackApp(FastAPI): """ - await shutdown_stack(app.__llama_stack_impls__) + A wrapper around the FastAPI application to hold a reference to the Stack instance so that we can + start background tasks (e.g. refresh model registry periodically) from the lifespan context manager. + """ + + def __init__(self, config: StackRunConfig, *args, **kwargs): + super().__init__(*args, **kwargs) + self.stack: Stack = Stack(config) + + # This code is called from a running event loop managed by uvicorn so we cannot simply call + # asyncio.run() to initialize the stack. We cannot await either since this is not an async + # function. + # As a workaround, we use a thread pool executor to run the initialize() method + # in a separate thread. 
+ with concurrent.futures.ThreadPoolExecutor() as executor: + future = executor.submit(asyncio.run, self.stack.initialize()) + future.result() @asynccontextmanager -async def lifespan(app: FastAPI): +async def lifespan(app: StackApp): logger.info("Starting up") + assert app.stack is not None + app.stack.create_registry_refresh_task() yield logger.info("Shutting down") - await shutdown(app) + await app.stack.shutdown() def is_streaming_request(func_name: str, request: Request, **kwargs): @@ -386,73 +398,61 @@ class ClientVersionMiddleware: return await self.app(scope, receive, send) -def main(args: argparse.Namespace | None = None): - """Start the LlamaStack server.""" - parser = argparse.ArgumentParser(description="Start the LlamaStack server.") +def create_app( + config_file: str | None = None, + env_vars: list[str] | None = None, +) -> StackApp: + """Create and configure the FastAPI application. - add_config_distro_args(parser) - parser.add_argument( - "--port", - type=int, - default=int(os.getenv("LLAMA_STACK_PORT", 8321)), - help="Port to listen on", - ) - parser.add_argument( - "--env", - action="append", - help="Environment variables in KEY=value format. Can be specified multiple times.", - ) + Args: + config_file: Path to a config file (or a distro name to resolve). If None, falls back to the LLAMA_STACK_CONFIG env var; a ValueError is raised when neither is set. + env_vars: List of environment variables in KEY=value format. + Note: version checking is not a parameter; the check is skipped when the LLAMA_STACK_DISABLE_VERSION_CHECK env var is set to a non-empty value. - # Determine whether the server args are being passed by the "run" command, if this is the case - # the args will be passed as a Namespace object to the main function, otherwise they will be - # parsed from the command line - if args is None: - args = parser.parse_args() + Returns: + Configured StackApp instance. 
+ """ + config_file = config_file or os.getenv("LLAMA_STACK_CONFIG") + if config_file is None: + raise ValueError("No config file provided and LLAMA_STACK_CONFIG env var is not set") - config_or_distro = get_config_from_args(args) - config_file = resolve_config_or_distro(config_or_distro, Mode.RUN) + config_file = resolve_config_or_distro(config_file, Mode.RUN) + # Load and process configuration logger_config = None with open(config_file) as fp: config_contents = yaml.safe_load(fp) if isinstance(config_contents, dict) and (cfg := config_contents.get("logging_config")): logger_config = LoggingConfig(**cfg) logger = get_logger(name=__name__, category="core::server", config=logger_config) - if args.env: - for env_pair in args.env: + + if env_vars: + for env_pair in env_vars: try: key, value = validate_env_pair(env_pair) - logger.info(f"Setting CLI environment variable {key} => {value}") + logger.info(f"Setting environment variable {key} => {value}") os.environ[key] = value except ValueError as e: logger.error(f"Error: {str(e)}") - sys.exit(1) + raise ValueError(f"Invalid environment variable format: {env_pair}") from e + config = replace_env_vars(config_contents) config = StackRunConfig(**cast_image_name_to_string(config)) _log_run_config(run_config=config) - app = FastAPI( + app = StackApp( lifespan=lifespan, docs_url="/docs", redoc_url="/redoc", openapi_url="/openapi.json", + config=config, ) if not os.environ.get("LLAMA_STACK_DISABLE_VERSION_CHECK"): app.add_middleware(ClientVersionMiddleware) - try: - # Create and set the event loop that will be used for both construction and server runtime - loop = asyncio.new_event_loop() - asyncio.set_event_loop(loop) - - # Construct the stack in the persistent event loop - impls = loop.run_until_complete(construct_stack(config)) - - except InvalidProviderError as e: - logger.error(f"Error: {str(e)}") - sys.exit(1) + impls = app.stack.impls if config.server.auth: logger.info(f"Enabling authentication with provider: 
{config.server.auth.provider_config.type.value}") @@ -553,9 +553,54 @@ def main(args: argparse.Namespace | None = None): app.exception_handler(RequestValidationError)(global_exception_handler) app.exception_handler(Exception)(global_exception_handler) - app.__llama_stack_impls__ = impls app.add_middleware(TracingMiddleware, impls=impls, external_apis=external_apis) + return app + + +def main(args: argparse.Namespace | None = None): + """Start the LlamaStack server.""" + parser = argparse.ArgumentParser(description="Start the LlamaStack server.") + + add_config_distro_args(parser) + parser.add_argument( + "--port", + type=int, + default=int(os.getenv("LLAMA_STACK_PORT", 8321)), + help="Port to listen on", + ) + parser.add_argument( + "--env", + action="append", + help="Environment variables in KEY=value format. Can be specified multiple times.", + ) + + # Determine whether the server args are being passed by the "run" command, if this is the case + # the args will be passed as a Namespace object to the main function, otherwise they will be + # parsed from the command line + if args is None: + args = parser.parse_args() + + config_or_distro = get_config_from_args(args) + + try: + app = create_app( + config_file=config_or_distro, + env_vars=args.env, + ) + except Exception as e: + logger.error(f"Error creating app: {str(e)}") + sys.exit(1) + + config_file = resolve_config_or_distro(config_or_distro, Mode.RUN) + with open(config_file) as fp: + config_contents = yaml.safe_load(fp) + if isinstance(config_contents, dict) and (cfg := config_contents.get("logging_config")): + logger_config = LoggingConfig(**cfg) + else: + logger_config = None + config = StackRunConfig(**cast_image_name_to_string(replace_env_vars(config_contents))) + import uvicorn # Configure SSL if certificates are provided @@ -593,7 +638,6 @@ def main(args: argparse.Namespace | None = None): if ssl_config: uvicorn_config.update(ssl_config) - # Run uvicorn in the existing event loop to preserve background 
tasks # We need to catch KeyboardInterrupt because uvicorn's signal handling # re-raises SIGINT signals using signal.raise_signal(), which Python # converts to KeyboardInterrupt. Without this catch, we'd get a confusing @@ -604,13 +648,9 @@ def main(args: argparse.Namespace | None = None): # Another approach would be to ignore SIGINT entirely - let uvicorn handle it through its own # signal handling but this is quite intrusive and not worth the effort. try: - loop.run_until_complete(uvicorn.Server(uvicorn.Config(**uvicorn_config)).serve()) + asyncio.run(uvicorn.Server(uvicorn.Config(**uvicorn_config)).serve()) except (KeyboardInterrupt, SystemExit): logger.info("Received interrupt signal, shutting down gracefully...") - finally: - if not loop.is_closed(): - logger.debug("Closing event loop") - loop.close() def _log_run_config(run_config: StackRunConfig): diff --git a/llama_stack/core/stack.py b/llama_stack/core/stack.py index 7ab8d2c64..a6c5093eb 100644 --- a/llama_stack/core/stack.py +++ b/llama_stack/core/stack.py @@ -315,78 +315,84 @@ def add_internal_implementations(impls: dict[Api, Any], run_config: StackRunConf impls[Api.prompts] = prompts_impl -# Produces a stack of providers for the given run config. Not all APIs may be -# asked for in the run config. -async def construct_stack( - run_config: StackRunConfig, provider_registry: ProviderRegistry | None = None -) -> dict[Api, Any]: - if "LLAMA_STACK_TEST_INFERENCE_MODE" in os.environ: - from llama_stack.testing.inference_recorder import setup_inference_recording +class Stack: + def __init__(self, run_config: StackRunConfig, provider_registry: ProviderRegistry | None = None): + self.run_config = run_config + self.provider_registry = provider_registry + self.impls = None + + # Produces a stack of providers for the given run config. Not all APIs may be + # asked for in the run config. 
+ async def initialize(self): + if "LLAMA_STACK_TEST_INFERENCE_MODE" in os.environ: + from llama_stack.testing.inference_recorder import setup_inference_recording + + global TEST_RECORDING_CONTEXT + TEST_RECORDING_CONTEXT = setup_inference_recording() + if TEST_RECORDING_CONTEXT: + TEST_RECORDING_CONTEXT.__enter__() + logger.info(f"Inference recording enabled: mode={os.environ.get('LLAMA_STACK_TEST_INFERENCE_MODE')}") + + dist_registry, _ = await create_dist_registry(self.run_config.metadata_store, self.run_config.image_name) + policy = self.run_config.server.auth.access_policy if self.run_config.server.auth else [] + impls = await resolve_impls( + self.run_config, self.provider_registry or get_provider_registry(self.run_config), dist_registry, policy + ) + + # Add internal implementations after all other providers are resolved + add_internal_implementations(impls, self.run_config) + + if Api.prompts in impls: + await impls[Api.prompts].initialize() + + await register_resources(self.run_config, impls) + + await refresh_registry_once(impls) + self.impls = impls + + def create_registry_refresh_task(self): + assert self.impls is not None, "Must call initialize() before starting" + + global REGISTRY_REFRESH_TASK + REGISTRY_REFRESH_TASK = asyncio.create_task(refresh_registry_task(self.impls)) + + def cb(task): + import traceback + + if task.cancelled(): + logger.error("Model refresh task cancelled") + elif task.exception(): + logger.error(f"Model refresh task failed: {task.exception()}") + traceback.print_exception(task.exception()) + else: + logger.debug("Model refresh task completed") + + REGISTRY_REFRESH_TASK.add_done_callback(cb) + + async def shutdown(self): + for impl in self.impls.values(): + impl_name = impl.__class__.__name__ + logger.info(f"Shutting down {impl_name}") + try: + if hasattr(impl, "shutdown"): + await asyncio.wait_for(impl.shutdown(), timeout=5) + else: + logger.warning(f"No shutdown method for {impl_name}") + except TimeoutError: + 
logger.exception(f"Shutdown timeout for {impl_name}") + except (Exception, asyncio.CancelledError) as e: + logger.exception(f"Failed to shutdown {impl_name}: {e}") global TEST_RECORDING_CONTEXT - TEST_RECORDING_CONTEXT = setup_inference_recording() if TEST_RECORDING_CONTEXT: - TEST_RECORDING_CONTEXT.__enter__() - logger.info(f"Inference recording enabled: mode={os.environ.get('LLAMA_STACK_TEST_INFERENCE_MODE')}") + try: + TEST_RECORDING_CONTEXT.__exit__(None, None, None) + except Exception as e: + logger.error(f"Error during inference recording cleanup: {e}") - dist_registry, _ = await create_dist_registry(run_config.metadata_store, run_config.image_name) - policy = run_config.server.auth.access_policy if run_config.server.auth else [] - impls = await resolve_impls( - run_config, provider_registry or get_provider_registry(run_config), dist_registry, policy - ) - - # Add internal implementations after all other providers are resolved - add_internal_implementations(impls, run_config) - - if Api.prompts in impls: - await impls[Api.prompts].initialize() - - await register_resources(run_config, impls) - - await refresh_registry_once(impls) - - global REGISTRY_REFRESH_TASK - REGISTRY_REFRESH_TASK = asyncio.create_task(refresh_registry_task(impls)) - - def cb(task): - import traceback - - if task.cancelled(): - logger.error("Model refresh task cancelled") - elif task.exception(): - logger.error(f"Model refresh task failed: {task.exception()}") - traceback.print_exception(task.exception()) - else: - logger.debug("Model refresh task completed") - - REGISTRY_REFRESH_TASK.add_done_callback(cb) - return impls - - -async def shutdown_stack(impls: dict[Api, Any]): - for impl in impls.values(): - impl_name = impl.__class__.__name__ - logger.info(f"Shutting down {impl_name}") - try: - if hasattr(impl, "shutdown"): - await asyncio.wait_for(impl.shutdown(), timeout=5) - else: - logger.warning(f"No shutdown method for {impl_name}") - except TimeoutError: - logger.exception(f"Shutdown 
timeout for {impl_name}") - except (Exception, asyncio.CancelledError) as e: - logger.exception(f"Failed to shutdown {impl_name}: {e}") - - global TEST_RECORDING_CONTEXT - if TEST_RECORDING_CONTEXT: - try: - TEST_RECORDING_CONTEXT.__exit__(None, None, None) - except Exception as e: - logger.error(f"Error during inference recording cleanup: {e}") - - global REGISTRY_REFRESH_TASK - if REGISTRY_REFRESH_TASK: - REGISTRY_REFRESH_TASK.cancel() + global REGISTRY_REFRESH_TASK + if REGISTRY_REFRESH_TASK: + REGISTRY_REFRESH_TASK.cancel() async def refresh_registry_once(impls: dict[Api, Any]): diff --git a/llama_stack/distributions/starter/starter.py b/llama_stack/distributions/starter/starter.py index c2dfe95ad..6bee51ff0 100644 --- a/llama_stack/distributions/starter/starter.py +++ b/llama_stack/distributions/starter/starter.py @@ -78,12 +78,12 @@ def get_remote_inference_providers() -> list[Provider]: remote_providers = [ provider for provider in available_providers() - if isinstance(provider, RemoteProviderSpec) and provider.adapter.adapter_type in ENABLED_INFERENCE_PROVIDERS + if isinstance(provider, RemoteProviderSpec) and provider.adapter_type in ENABLED_INFERENCE_PROVIDERS ] inference_providers = [] for provider_spec in remote_providers: - provider_type = provider_spec.adapter.adapter_type + provider_type = provider_spec.adapter_type if provider_type in INFERENCE_PROVIDER_IDS: provider_id = INFERENCE_PROVIDER_IDS[provider_type] diff --git a/llama_stack/providers/datatypes.py b/llama_stack/providers/datatypes.py index 5e15dd8e1..c8ff9cecb 100644 --- a/llama_stack/providers/datatypes.py +++ b/llama_stack/providers/datatypes.py @@ -131,6 +131,15 @@ class ProviderSpec(BaseModel): """, ) + pip_packages: list[str] = Field( + default_factory=list, + description="The pip dependencies needed for this implementation", + ) + + provider_data_validator: str | None = Field( + default=None, + ) + is_external: bool = Field(default=False, description="Notes whether this provider is an 
external provider.") # used internally by the resolver; this is a hack for now @@ -145,45 +154,8 @@ class RoutingTable(Protocol): async def get_provider_impl(self, routing_key: str) -> Any: ... -# TODO: this can now be inlined into RemoteProviderSpec -@json_schema_type -class AdapterSpec(BaseModel): - adapter_type: str = Field( - ..., - description="Unique identifier for this adapter", - ) - module: str = Field( - default_factory=str, - description=""" -Fully-qualified name of the module to import. The module is expected to have: - - - `get_adapter_impl(config, deps)`: returns the adapter implementation -""", - ) - pip_packages: list[str] = Field( - default_factory=list, - description="The pip dependencies needed for this implementation", - ) - config_class: str = Field( - description="Fully-qualified classname of the config for this provider", - ) - provider_data_validator: str | None = Field( - default=None, - ) - description: str | None = Field( - default=None, - description=""" -A description of the provider. This is used to display in the documentation. -""", - ) - - @json_schema_type class InlineProviderSpec(ProviderSpec): - pip_packages: list[str] = Field( - default_factory=list, - description="The pip dependencies needed for this implementation", - ) container_image: str | None = Field( default=None, description=""" @@ -191,10 +163,6 @@ The container image to use for this implementation. If one is provided, pip_pack If a provider depends on other providers, the dependencies MUST NOT specify a container image. 
""", ) - # module field is inherited from ProviderSpec - provider_data_validator: str | None = Field( - default=None, - ) description: str | None = Field( default=None, description=""" @@ -223,10 +191,15 @@ class RemoteProviderConfig(BaseModel): @json_schema_type class RemoteProviderSpec(ProviderSpec): - adapter: AdapterSpec = Field( + adapter_type: str = Field( + ..., + description="Unique identifier for this adapter", + ) + + description: str | None = Field( + default=None, description=""" -If some code is needed to convert the remote responses into Llama Stack compatible -API responses, specify the adapter here. +A description of the provider. This is used to display in the documentation. """, ) @@ -234,33 +207,6 @@ API responses, specify the adapter here. def container_image(self) -> str | None: return None - # module field is inherited from ProviderSpec - - @property - def pip_packages(self) -> list[str]: - return self.adapter.pip_packages - - @property - def provider_data_validator(self) -> str | None: - return self.adapter.provider_data_validator - - -def remote_provider_spec( - api: Api, - adapter: AdapterSpec, - api_dependencies: list[Api] | None = None, - optional_api_dependencies: list[Api] | None = None, -) -> RemoteProviderSpec: - return RemoteProviderSpec( - api=api, - provider_type=f"remote::{adapter.adapter_type}", - config_class=adapter.config_class, - module=adapter.module, - adapter=adapter, - api_dependencies=api_dependencies or [], - optional_api_dependencies=optional_api_dependencies or [], - ) - class HealthStatus(StrEnum): OK = "OK" diff --git a/llama_stack/providers/inline/files/localfs/files.py b/llama_stack/providers/inline/files/localfs/files.py index 9c610c1ba..65cf8d815 100644 --- a/llama_stack/providers/inline/files/localfs/files.py +++ b/llama_stack/providers/inline/files/localfs/files.py @@ -44,7 +44,7 @@ class LocalfsFilesImpl(Files): storage_path.mkdir(parents=True, exist_ok=True) # Initialize SQL store for metadata - 
self.sql_store = AuthorizedSqlStore(sqlstore_impl(self.config.metadata_store)) + self.sql_store = AuthorizedSqlStore(sqlstore_impl(self.config.metadata_store), self.policy) await self.sql_store.create_table( "openai_files", { @@ -74,7 +74,7 @@ class LocalfsFilesImpl(Files): if not self.sql_store: raise RuntimeError("Files provider not initialized") - row = await self.sql_store.fetch_one("openai_files", policy=self.policy, where={"id": file_id}) + row = await self.sql_store.fetch_one("openai_files", where={"id": file_id}) if not row: raise ResourceNotFoundError(file_id, "File", "client.files.list()") @@ -150,7 +150,6 @@ class LocalfsFilesImpl(Files): paginated_result = await self.sql_store.fetch_all( table="openai_files", - policy=self.policy, where=where_conditions if where_conditions else None, order_by=[("created_at", order.value)], cursor=("id", after) if after else None, diff --git a/llama_stack/providers/registry/datasetio.py b/llama_stack/providers/registry/datasetio.py index f641b4ce3..a9feb0bac 100644 --- a/llama_stack/providers/registry/datasetio.py +++ b/llama_stack/providers/registry/datasetio.py @@ -6,11 +6,10 @@ from llama_stack.providers.datatypes import ( - AdapterSpec, Api, InlineProviderSpec, ProviderSpec, - remote_provider_spec, + RemoteProviderSpec, ) @@ -25,28 +24,26 @@ def available_providers() -> list[ProviderSpec]: api_dependencies=[], description="Local filesystem-based dataset I/O provider for reading and writing datasets to local storage.", ), - remote_provider_spec( + RemoteProviderSpec( api=Api.datasetio, - adapter=AdapterSpec( - adapter_type="huggingface", - pip_packages=[ - "datasets>=4.0.0", - ], - module="llama_stack.providers.remote.datasetio.huggingface", - config_class="llama_stack.providers.remote.datasetio.huggingface.HuggingfaceDatasetIOConfig", - description="HuggingFace datasets provider for accessing and managing datasets from the HuggingFace Hub.", - ), + adapter_type="huggingface", + provider_type="remote::huggingface", + 
pip_packages=[ + "datasets>=4.0.0", + ], + module="llama_stack.providers.remote.datasetio.huggingface", + config_class="llama_stack.providers.remote.datasetio.huggingface.HuggingfaceDatasetIOConfig", + description="HuggingFace datasets provider for accessing and managing datasets from the HuggingFace Hub.", ), - remote_provider_spec( + RemoteProviderSpec( api=Api.datasetio, - adapter=AdapterSpec( - adapter_type="nvidia", - pip_packages=[ - "datasets>=4.0.0", - ], - module="llama_stack.providers.remote.datasetio.nvidia", - config_class="llama_stack.providers.remote.datasetio.nvidia.NvidiaDatasetIOConfig", - description="NVIDIA's dataset I/O provider for accessing datasets from NVIDIA's data platform.", - ), + adapter_type="nvidia", + provider_type="remote::nvidia", + module="llama_stack.providers.remote.datasetio.nvidia", + config_class="llama_stack.providers.remote.datasetio.nvidia.NvidiaDatasetIOConfig", + pip_packages=[ + "datasets>=4.0.0", + ], + description="NVIDIA's dataset I/O provider for accessing datasets from NVIDIA's data platform.", ), ] diff --git a/llama_stack/providers/registry/eval.py b/llama_stack/providers/registry/eval.py index 9f0d17916..4ef0bb41f 100644 --- a/llama_stack/providers/registry/eval.py +++ b/llama_stack/providers/registry/eval.py @@ -5,7 +5,7 @@ # the root directory of this source tree. 
-from llama_stack.providers.datatypes import AdapterSpec, Api, InlineProviderSpec, ProviderSpec, remote_provider_spec +from llama_stack.providers.datatypes import Api, InlineProviderSpec, ProviderSpec, RemoteProviderSpec def available_providers() -> list[ProviderSpec]: @@ -25,17 +25,16 @@ def available_providers() -> list[ProviderSpec]: ], description="Meta's reference implementation of evaluation tasks with support for multiple languages and evaluation metrics.", ), - remote_provider_spec( + RemoteProviderSpec( api=Api.eval, - adapter=AdapterSpec( - adapter_type="nvidia", - pip_packages=[ - "requests", - ], - module="llama_stack.providers.remote.eval.nvidia", - config_class="llama_stack.providers.remote.eval.nvidia.NVIDIAEvalConfig", - description="NVIDIA's evaluation provider for running evaluation tasks on NVIDIA's platform.", - ), + adapter_type="nvidia", + pip_packages=[ + "requests", + ], + provider_type="remote::nvidia", + module="llama_stack.providers.remote.eval.nvidia", + config_class="llama_stack.providers.remote.eval.nvidia.NVIDIAEvalConfig", + description="NVIDIA's evaluation provider for running evaluation tasks on NVIDIA's platform.", api_dependencies=[ Api.datasetio, Api.datasets, diff --git a/llama_stack/providers/registry/files.py b/llama_stack/providers/registry/files.py index ebe90310c..9acabfacd 100644 --- a/llama_stack/providers/registry/files.py +++ b/llama_stack/providers/registry/files.py @@ -4,13 +4,7 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. 
-from llama_stack.providers.datatypes import ( - AdapterSpec, - Api, - InlineProviderSpec, - ProviderSpec, - remote_provider_spec, -) +from llama_stack.providers.datatypes import Api, InlineProviderSpec, ProviderSpec, RemoteProviderSpec from llama_stack.providers.utils.sqlstore.sqlstore import sql_store_pip_packages @@ -25,14 +19,13 @@ def available_providers() -> list[ProviderSpec]: config_class="llama_stack.providers.inline.files.localfs.config.LocalfsFilesImplConfig", description="Local filesystem-based file storage provider for managing files and documents locally.", ), - remote_provider_spec( + RemoteProviderSpec( api=Api.files, - adapter=AdapterSpec( - adapter_type="s3", - pip_packages=["boto3"] + sql_store_pip_packages, - module="llama_stack.providers.remote.files.s3", - config_class="llama_stack.providers.remote.files.s3.config.S3FilesImplConfig", - description="AWS S3-based file storage provider for scalable cloud file management with metadata persistence.", - ), + provider_type="remote::s3", + adapter_type="s3", + pip_packages=["boto3"] + sql_store_pip_packages, + module="llama_stack.providers.remote.files.s3", + config_class="llama_stack.providers.remote.files.s3.config.S3FilesImplConfig", + description="AWS S3-based file storage provider for scalable cloud file management with metadata persistence.", ), ] diff --git a/llama_stack/providers/registry/inference.py b/llama_stack/providers/registry/inference.py index 9b70f4f7b..89d7f55e8 100644 --- a/llama_stack/providers/registry/inference.py +++ b/llama_stack/providers/registry/inference.py @@ -6,11 +6,10 @@ from llama_stack.providers.datatypes import ( - AdapterSpec, Api, InlineProviderSpec, ProviderSpec, - remote_provider_spec, + RemoteProviderSpec, ) META_REFERENCE_DEPS = [ @@ -49,177 +48,167 @@ def available_providers() -> list[ProviderSpec]: config_class="llama_stack.providers.inline.inference.sentence_transformers.config.SentenceTransformersInferenceConfig", description="Sentence Transformers 
inference provider for text embeddings and similarity search.", ), - remote_provider_spec( + RemoteProviderSpec( api=Api.inference, - adapter=AdapterSpec( - adapter_type="cerebras", - pip_packages=[ - "cerebras_cloud_sdk", - ], - module="llama_stack.providers.remote.inference.cerebras", - config_class="llama_stack.providers.remote.inference.cerebras.CerebrasImplConfig", - description="Cerebras inference provider for running models on Cerebras Cloud platform.", - ), + adapter_type="cerebras", + provider_type="remote::cerebras", + pip_packages=[ + "cerebras_cloud_sdk", + ], + module="llama_stack.providers.remote.inference.cerebras", + config_class="llama_stack.providers.remote.inference.cerebras.CerebrasImplConfig", + description="Cerebras inference provider for running models on Cerebras Cloud platform.", ), - remote_provider_spec( + RemoteProviderSpec( api=Api.inference, - adapter=AdapterSpec( - adapter_type="ollama", - pip_packages=["ollama", "aiohttp", "h11>=0.16.0"], - config_class="llama_stack.providers.remote.inference.ollama.OllamaImplConfig", - module="llama_stack.providers.remote.inference.ollama", - description="Ollama inference provider for running local models through the Ollama runtime.", - ), + adapter_type="ollama", + provider_type="remote::ollama", + pip_packages=["ollama", "aiohttp", "h11>=0.16.0"], + config_class="llama_stack.providers.remote.inference.ollama.OllamaImplConfig", + module="llama_stack.providers.remote.inference.ollama", + description="Ollama inference provider for running local models through the Ollama runtime.", ), - remote_provider_spec( + RemoteProviderSpec( api=Api.inference, - adapter=AdapterSpec( - adapter_type="vllm", - pip_packages=[], - module="llama_stack.providers.remote.inference.vllm", - config_class="llama_stack.providers.remote.inference.vllm.VLLMInferenceAdapterConfig", - provider_data_validator="llama_stack.providers.remote.inference.vllm.VLLMProviderDataValidator", - description="Remote vLLM inference provider for 
connecting to vLLM servers.", - ), + adapter_type="vllm", + provider_type="remote::vllm", + pip_packages=[], + module="llama_stack.providers.remote.inference.vllm", + config_class="llama_stack.providers.remote.inference.vllm.VLLMInferenceAdapterConfig", + provider_data_validator="llama_stack.providers.remote.inference.vllm.VLLMProviderDataValidator", + description="Remote vLLM inference provider for connecting to vLLM servers.", ), - remote_provider_spec( + RemoteProviderSpec( api=Api.inference, - adapter=AdapterSpec( - adapter_type="tgi", - pip_packages=["huggingface_hub", "aiohttp"], - module="llama_stack.providers.remote.inference.tgi", - config_class="llama_stack.providers.remote.inference.tgi.TGIImplConfig", - description="Text Generation Inference (TGI) provider for HuggingFace model serving.", - ), + adapter_type="tgi", + provider_type="remote::tgi", + pip_packages=["huggingface_hub", "aiohttp"], + module="llama_stack.providers.remote.inference.tgi", + config_class="llama_stack.providers.remote.inference.tgi.TGIImplConfig", + description="Text Generation Inference (TGI) provider for HuggingFace model serving.", ), - remote_provider_spec( + RemoteProviderSpec( api=Api.inference, - adapter=AdapterSpec( - adapter_type="hf::serverless", - pip_packages=["huggingface_hub", "aiohttp"], - module="llama_stack.providers.remote.inference.tgi", - config_class="llama_stack.providers.remote.inference.tgi.InferenceAPIImplConfig", - description="HuggingFace Inference API serverless provider for on-demand model inference.", - ), + adapter_type="hf::serverless", + provider_type="remote::hf::serverless", + pip_packages=["huggingface_hub", "aiohttp"], + module="llama_stack.providers.remote.inference.tgi", + config_class="llama_stack.providers.remote.inference.tgi.InferenceAPIImplConfig", + description="HuggingFace Inference API serverless provider for on-demand model inference.", ), - remote_provider_spec( + RemoteProviderSpec( api=Api.inference, - adapter=AdapterSpec( - 
adapter_type="hf::endpoint", - pip_packages=["huggingface_hub", "aiohttp"], - module="llama_stack.providers.remote.inference.tgi", - config_class="llama_stack.providers.remote.inference.tgi.InferenceEndpointImplConfig", - description="HuggingFace Inference Endpoints provider for dedicated model serving.", - ), + provider_type="remote::hf::endpoint", + adapter_type="hf::endpoint", + pip_packages=["huggingface_hub", "aiohttp"], + module="llama_stack.providers.remote.inference.tgi", + config_class="llama_stack.providers.remote.inference.tgi.InferenceEndpointImplConfig", + description="HuggingFace Inference Endpoints provider for dedicated model serving.", ), - remote_provider_spec( + RemoteProviderSpec( api=Api.inference, - adapter=AdapterSpec( - adapter_type="fireworks", - pip_packages=[ - "fireworks-ai<=0.17.16", - ], - module="llama_stack.providers.remote.inference.fireworks", - config_class="llama_stack.providers.remote.inference.fireworks.FireworksImplConfig", - provider_data_validator="llama_stack.providers.remote.inference.fireworks.FireworksProviderDataValidator", - description="Fireworks AI inference provider for Llama models and other AI models on the Fireworks platform.", - ), + adapter_type="fireworks", + provider_type="remote::fireworks", + pip_packages=[ + "fireworks-ai<=0.17.16", + ], + module="llama_stack.providers.remote.inference.fireworks", + config_class="llama_stack.providers.remote.inference.fireworks.FireworksImplConfig", + provider_data_validator="llama_stack.providers.remote.inference.fireworks.FireworksProviderDataValidator", + description="Fireworks AI inference provider for Llama models and other AI models on the Fireworks platform.", ), - remote_provider_spec( + RemoteProviderSpec( api=Api.inference, - adapter=AdapterSpec( - adapter_type="together", - pip_packages=[ - "together", - ], - module="llama_stack.providers.remote.inference.together", - config_class="llama_stack.providers.remote.inference.together.TogetherImplConfig", - 
provider_data_validator="llama_stack.providers.remote.inference.together.TogetherProviderDataValidator", - description="Together AI inference provider for open-source models and collaborative AI development.", - ), + adapter_type="together", + provider_type="remote::together", + pip_packages=[ + "together", + ], + module="llama_stack.providers.remote.inference.together", + config_class="llama_stack.providers.remote.inference.together.TogetherImplConfig", + provider_data_validator="llama_stack.providers.remote.inference.together.TogetherProviderDataValidator", + description="Together AI inference provider for open-source models and collaborative AI development.", ), - remote_provider_spec( + RemoteProviderSpec( api=Api.inference, - adapter=AdapterSpec( - adapter_type="bedrock", - pip_packages=["boto3"], - module="llama_stack.providers.remote.inference.bedrock", - config_class="llama_stack.providers.remote.inference.bedrock.BedrockConfig", - description="AWS Bedrock inference provider for accessing various AI models through AWS's managed service.", - ), + adapter_type="bedrock", + provider_type="remote::bedrock", + pip_packages=["boto3"], + module="llama_stack.providers.remote.inference.bedrock", + config_class="llama_stack.providers.remote.inference.bedrock.BedrockConfig", + description="AWS Bedrock inference provider for accessing various AI models through AWS's managed service.", ), - remote_provider_spec( + RemoteProviderSpec( api=Api.inference, - adapter=AdapterSpec( - adapter_type="databricks", - pip_packages=["databricks-sdk"], - module="llama_stack.providers.remote.inference.databricks", - config_class="llama_stack.providers.remote.inference.databricks.DatabricksImplConfig", - description="Databricks inference provider for running models on Databricks' unified analytics platform.", - ), + adapter_type="databricks", + provider_type="remote::databricks", + pip_packages=["databricks-sdk"], + module="llama_stack.providers.remote.inference.databricks", + 
config_class="llama_stack.providers.remote.inference.databricks.DatabricksImplConfig", + description="Databricks inference provider for running models on Databricks' unified analytics platform.", ), - remote_provider_spec( + RemoteProviderSpec( api=Api.inference, - adapter=AdapterSpec( - adapter_type="nvidia", - pip_packages=[], - module="llama_stack.providers.remote.inference.nvidia", - config_class="llama_stack.providers.remote.inference.nvidia.NVIDIAConfig", - description="NVIDIA inference provider for accessing NVIDIA NIM models and AI services.", - ), + adapter_type="nvidia", + provider_type="remote::nvidia", + pip_packages=[], + module="llama_stack.providers.remote.inference.nvidia", + config_class="llama_stack.providers.remote.inference.nvidia.NVIDIAConfig", + description="NVIDIA inference provider for accessing NVIDIA NIM models and AI services.", ), - remote_provider_spec( + RemoteProviderSpec( api=Api.inference, - adapter=AdapterSpec( - adapter_type="runpod", - pip_packages=[], - module="llama_stack.providers.remote.inference.runpod", - config_class="llama_stack.providers.remote.inference.runpod.RunpodImplConfig", - description="RunPod inference provider for running models on RunPod's cloud GPU platform.", - ), + adapter_type="runpod", + provider_type="remote::runpod", + pip_packages=[], + module="llama_stack.providers.remote.inference.runpod", + config_class="llama_stack.providers.remote.inference.runpod.RunpodImplConfig", + description="RunPod inference provider for running models on RunPod's cloud GPU platform.", ), - remote_provider_spec( + RemoteProviderSpec( api=Api.inference, - adapter=AdapterSpec( - adapter_type="openai", - pip_packages=["litellm"], - module="llama_stack.providers.remote.inference.openai", - config_class="llama_stack.providers.remote.inference.openai.OpenAIConfig", - provider_data_validator="llama_stack.providers.remote.inference.openai.config.OpenAIProviderDataValidator", - description="OpenAI inference provider for accessing GPT 
models and other OpenAI services.", - ), + adapter_type="openai", + provider_type="remote::openai", + pip_packages=["litellm"], + module="llama_stack.providers.remote.inference.openai", + config_class="llama_stack.providers.remote.inference.openai.OpenAIConfig", + provider_data_validator="llama_stack.providers.remote.inference.openai.config.OpenAIProviderDataValidator", + description="OpenAI inference provider for accessing GPT models and other OpenAI services.", ), - remote_provider_spec( + RemoteProviderSpec( api=Api.inference, - adapter=AdapterSpec( - adapter_type="anthropic", - pip_packages=["litellm"], - module="llama_stack.providers.remote.inference.anthropic", - config_class="llama_stack.providers.remote.inference.anthropic.AnthropicConfig", - provider_data_validator="llama_stack.providers.remote.inference.anthropic.config.AnthropicProviderDataValidator", - description="Anthropic inference provider for accessing Claude models and Anthropic's AI services.", - ), + adapter_type="anthropic", + provider_type="remote::anthropic", + pip_packages=["litellm"], + module="llama_stack.providers.remote.inference.anthropic", + config_class="llama_stack.providers.remote.inference.anthropic.AnthropicConfig", + provider_data_validator="llama_stack.providers.remote.inference.anthropic.config.AnthropicProviderDataValidator", + description="Anthropic inference provider for accessing Claude models and Anthropic's AI services.", ), - remote_provider_spec( + RemoteProviderSpec( api=Api.inference, - adapter=AdapterSpec( - adapter_type="gemini", - pip_packages=["litellm"], - module="llama_stack.providers.remote.inference.gemini", - config_class="llama_stack.providers.remote.inference.gemini.GeminiConfig", - provider_data_validator="llama_stack.providers.remote.inference.gemini.config.GeminiProviderDataValidator", - description="Google Gemini inference provider for accessing Gemini models and Google's AI services.", - ), + adapter_type="gemini", + provider_type="remote::gemini", + 
pip_packages=[ + "litellm", + ], + module="llama_stack.providers.remote.inference.gemini", + config_class="llama_stack.providers.remote.inference.gemini.GeminiConfig", + provider_data_validator="llama_stack.providers.remote.inference.gemini.config.GeminiProviderDataValidator", + description="Google Gemini inference provider for accessing Gemini models and Google's AI services.", ), - remote_provider_spec( + RemoteProviderSpec( api=Api.inference, - adapter=AdapterSpec( - adapter_type="vertexai", - pip_packages=["litellm", "google-cloud-aiplatform"], - module="llama_stack.providers.remote.inference.vertexai", - config_class="llama_stack.providers.remote.inference.vertexai.VertexAIConfig", - provider_data_validator="llama_stack.providers.remote.inference.vertexai.config.VertexAIProviderDataValidator", - description="""Google Vertex AI inference provider enables you to use Google's Gemini models through Google Cloud's Vertex AI platform, providing several advantages: + adapter_type="vertexai", + provider_type="remote::vertexai", + pip_packages=[ + "litellm", + "google-cloud-aiplatform", + ], + module="llama_stack.providers.remote.inference.vertexai", + config_class="llama_stack.providers.remote.inference.vertexai.VertexAIConfig", + provider_data_validator="llama_stack.providers.remote.inference.vertexai.config.VertexAIProviderDataValidator", + description="""Google Vertex AI inference provider enables you to use Google's Gemini models through Google Cloud's Vertex AI platform, providing several advantages: • Enterprise-grade security: Uses Google Cloud's security controls and IAM • Better integration: Seamless integration with other Google Cloud services @@ -239,76 +228,73 @@ Available Models: - vertex_ai/gemini-2.0-flash - vertex_ai/gemini-2.5-flash - vertex_ai/gemini-2.5-pro""", - ), ), - remote_provider_spec( + RemoteProviderSpec( api=Api.inference, - adapter=AdapterSpec( - adapter_type="groq", - pip_packages=["litellm"], - 
module="llama_stack.providers.remote.inference.groq", - config_class="llama_stack.providers.remote.inference.groq.GroqConfig", - provider_data_validator="llama_stack.providers.remote.inference.groq.config.GroqProviderDataValidator", - description="Groq inference provider for ultra-fast inference using Groq's LPU technology.", - ), + adapter_type="groq", + provider_type="remote::groq", + pip_packages=[ + "litellm", + ], + module="llama_stack.providers.remote.inference.groq", + config_class="llama_stack.providers.remote.inference.groq.GroqConfig", + provider_data_validator="llama_stack.providers.remote.inference.groq.config.GroqProviderDataValidator", + description="Groq inference provider for ultra-fast inference using Groq's LPU technology.", ), - remote_provider_spec( + RemoteProviderSpec( api=Api.inference, - adapter=AdapterSpec( - adapter_type="llama-openai-compat", - pip_packages=["litellm"], - module="llama_stack.providers.remote.inference.llama_openai_compat", - config_class="llama_stack.providers.remote.inference.llama_openai_compat.config.LlamaCompatConfig", - provider_data_validator="llama_stack.providers.remote.inference.llama_openai_compat.config.LlamaProviderDataValidator", - description="Llama OpenAI-compatible provider for using Llama models with OpenAI API format.", - ), + adapter_type="llama-openai-compat", + provider_type="remote::llama-openai-compat", + pip_packages=["litellm"], + module="llama_stack.providers.remote.inference.llama_openai_compat", + config_class="llama_stack.providers.remote.inference.llama_openai_compat.config.LlamaCompatConfig", + provider_data_validator="llama_stack.providers.remote.inference.llama_openai_compat.config.LlamaProviderDataValidator", + description="Llama OpenAI-compatible provider for using Llama models with OpenAI API format.", ), - remote_provider_spec( + RemoteProviderSpec( api=Api.inference, - adapter=AdapterSpec( - adapter_type="sambanova", - pip_packages=["litellm"], - 
module="llama_stack.providers.remote.inference.sambanova", - config_class="llama_stack.providers.remote.inference.sambanova.SambaNovaImplConfig", - provider_data_validator="llama_stack.providers.remote.inference.sambanova.config.SambaNovaProviderDataValidator", - description="SambaNova inference provider for running models on SambaNova's dataflow architecture.", - ), + adapter_type="sambanova", + provider_type="remote::sambanova", + pip_packages=[ + "litellm", + ], + module="llama_stack.providers.remote.inference.sambanova", + config_class="llama_stack.providers.remote.inference.sambanova.SambaNovaImplConfig", + provider_data_validator="llama_stack.providers.remote.inference.sambanova.config.SambaNovaProviderDataValidator", + description="SambaNova inference provider for running models on SambaNova's dataflow architecture.", ), - remote_provider_spec( + RemoteProviderSpec( api=Api.inference, - adapter=AdapterSpec( - adapter_type="passthrough", - pip_packages=[], - module="llama_stack.providers.remote.inference.passthrough", - config_class="llama_stack.providers.remote.inference.passthrough.PassthroughImplConfig", - provider_data_validator="llama_stack.providers.remote.inference.passthrough.PassthroughProviderDataValidator", - description="Passthrough inference provider for connecting to any external inference service not directly supported.", - ), + adapter_type="passthrough", + provider_type="remote::passthrough", + pip_packages=[], + module="llama_stack.providers.remote.inference.passthrough", + config_class="llama_stack.providers.remote.inference.passthrough.PassthroughImplConfig", + provider_data_validator="llama_stack.providers.remote.inference.passthrough.PassthroughProviderDataValidator", + description="Passthrough inference provider for connecting to any external inference service not directly supported.", ), - remote_provider_spec( + RemoteProviderSpec( api=Api.inference, - adapter=AdapterSpec( - adapter_type="watsonx", - pip_packages=["ibm_watsonx_ai"], - 
module="llama_stack.providers.remote.inference.watsonx", - config_class="llama_stack.providers.remote.inference.watsonx.WatsonXConfig", - provider_data_validator="llama_stack.providers.remote.inference.watsonx.WatsonXProviderDataValidator", - description="IBM WatsonX inference provider for accessing AI models on IBM's WatsonX platform.", - ), + adapter_type="watsonx", + provider_type="remote::watsonx", + pip_packages=["ibm_watsonx_ai"], + module="llama_stack.providers.remote.inference.watsonx", + config_class="llama_stack.providers.remote.inference.watsonx.WatsonXConfig", + provider_data_validator="llama_stack.providers.remote.inference.watsonx.WatsonXProviderDataValidator", + description="IBM WatsonX inference provider for accessing AI models on IBM's WatsonX platform.", ), - remote_provider_spec( + RemoteProviderSpec( api=Api.inference, - adapter=AdapterSpec( - adapter_type="azure", - pip_packages=["litellm"], - module="llama_stack.providers.remote.inference.azure", - config_class="llama_stack.providers.remote.inference.azure.AzureConfig", - provider_data_validator="llama_stack.providers.remote.inference.azure.config.AzureProviderDataValidator", - description=""" + provider_type="remote::azure", + adapter_type="azure", + pip_packages=["litellm"], + module="llama_stack.providers.remote.inference.azure", + config_class="llama_stack.providers.remote.inference.azure.AzureConfig", + provider_data_validator="llama_stack.providers.remote.inference.azure.config.AzureProviderDataValidator", + description=""" Azure OpenAI inference provider for accessing GPT models and other Azure services. 
Provider documentation https://learn.microsoft.com/en-us/azure/ai-foundry/openai/overview """, - ), ), ] diff --git a/llama_stack/providers/registry/post_training.py b/llama_stack/providers/registry/post_training.py index 47aeb401e..2092e3b2d 100644 --- a/llama_stack/providers/registry/post_training.py +++ b/llama_stack/providers/registry/post_training.py @@ -7,7 +7,7 @@ from typing import cast -from llama_stack.providers.datatypes import AdapterSpec, Api, InlineProviderSpec, ProviderSpec, remote_provider_spec +from llama_stack.providers.datatypes import Api, InlineProviderSpec, ProviderSpec, RemoteProviderSpec # We provide two versions of these providers so that distributions can package the appropriate version of torch. # The CPU version is used for distributions that don't have GPU support -- they result in smaller container images. @@ -57,14 +57,13 @@ def available_providers() -> list[ProviderSpec]: ], description="HuggingFace-based post-training provider for fine-tuning models using the HuggingFace ecosystem.", ), - remote_provider_spec( + RemoteProviderSpec( api=Api.post_training, - adapter=AdapterSpec( - adapter_type="nvidia", - pip_packages=["requests", "aiohttp"], - module="llama_stack.providers.remote.post_training.nvidia", - config_class="llama_stack.providers.remote.post_training.nvidia.NvidiaPostTrainingConfig", - description="NVIDIA's post-training provider for fine-tuning models on NVIDIA's platform.", - ), + adapter_type="nvidia", + provider_type="remote::nvidia", + pip_packages=["requests", "aiohttp"], + module="llama_stack.providers.remote.post_training.nvidia", + config_class="llama_stack.providers.remote.post_training.nvidia.NvidiaPostTrainingConfig", + description="NVIDIA's post-training provider for fine-tuning models on NVIDIA's platform.", ), ] diff --git a/llama_stack/providers/registry/safety.py b/llama_stack/providers/registry/safety.py index 9dd791bd8..b30074398 100644 --- a/llama_stack/providers/registry/safety.py +++ 
b/llama_stack/providers/registry/safety.py @@ -6,11 +6,10 @@ from llama_stack.providers.datatypes import ( - AdapterSpec, Api, InlineProviderSpec, ProviderSpec, - remote_provider_spec, + RemoteProviderSpec, ) @@ -48,35 +47,32 @@ def available_providers() -> list[ProviderSpec]: config_class="llama_stack.providers.inline.safety.code_scanner.CodeScannerConfig", description="Code Scanner safety provider for detecting security vulnerabilities and unsafe code patterns.", ), - remote_provider_spec( + RemoteProviderSpec( api=Api.safety, - adapter=AdapterSpec( - adapter_type="bedrock", - pip_packages=["boto3"], - module="llama_stack.providers.remote.safety.bedrock", - config_class="llama_stack.providers.remote.safety.bedrock.BedrockSafetyConfig", - description="AWS Bedrock safety provider for content moderation using AWS's safety services.", - ), + adapter_type="bedrock", + provider_type="remote::bedrock", + pip_packages=["boto3"], + module="llama_stack.providers.remote.safety.bedrock", + config_class="llama_stack.providers.remote.safety.bedrock.BedrockSafetyConfig", + description="AWS Bedrock safety provider for content moderation using AWS's safety services.", ), - remote_provider_spec( + RemoteProviderSpec( api=Api.safety, - adapter=AdapterSpec( - adapter_type="nvidia", - pip_packages=["requests"], - module="llama_stack.providers.remote.safety.nvidia", - config_class="llama_stack.providers.remote.safety.nvidia.NVIDIASafetyConfig", - description="NVIDIA's safety provider for content moderation and safety filtering.", - ), + adapter_type="nvidia", + provider_type="remote::nvidia", + pip_packages=["requests"], + module="llama_stack.providers.remote.safety.nvidia", + config_class="llama_stack.providers.remote.safety.nvidia.NVIDIASafetyConfig", + description="NVIDIA's safety provider for content moderation and safety filtering.", ), - remote_provider_spec( + RemoteProviderSpec( api=Api.safety, - adapter=AdapterSpec( - adapter_type="sambanova", - pip_packages=["litellm", 
"requests"], - module="llama_stack.providers.remote.safety.sambanova", - config_class="llama_stack.providers.remote.safety.sambanova.SambaNovaSafetyConfig", - provider_data_validator="llama_stack.providers.remote.safety.sambanova.config.SambaNovaProviderDataValidator", - description="SambaNova's safety provider for content moderation and safety filtering.", - ), + adapter_type="sambanova", + provider_type="remote::sambanova", + pip_packages=["litellm", "requests"], + module="llama_stack.providers.remote.safety.sambanova", + config_class="llama_stack.providers.remote.safety.sambanova.SambaNovaSafetyConfig", + provider_data_validator="llama_stack.providers.remote.safety.sambanova.config.SambaNovaProviderDataValidator", + description="SambaNova's safety provider for content moderation and safety filtering.", ), ] diff --git a/llama_stack/providers/registry/tool_runtime.py b/llama_stack/providers/registry/tool_runtime.py index 5a58fa7af..ad8c31dfd 100644 --- a/llama_stack/providers/registry/tool_runtime.py +++ b/llama_stack/providers/registry/tool_runtime.py @@ -6,11 +6,10 @@ from llama_stack.providers.datatypes import ( - AdapterSpec, Api, InlineProviderSpec, ProviderSpec, - remote_provider_spec, + RemoteProviderSpec, ) @@ -35,59 +34,54 @@ def available_providers() -> list[ProviderSpec]: api_dependencies=[Api.vector_io, Api.inference, Api.files], description="RAG (Retrieval-Augmented Generation) tool runtime for document ingestion, chunking, and semantic search.", ), - remote_provider_spec( + RemoteProviderSpec( api=Api.tool_runtime, - adapter=AdapterSpec( - adapter_type="brave-search", - module="llama_stack.providers.remote.tool_runtime.brave_search", - config_class="llama_stack.providers.remote.tool_runtime.brave_search.config.BraveSearchToolConfig", - pip_packages=["requests"], - provider_data_validator="llama_stack.providers.remote.tool_runtime.brave_search.BraveSearchToolProviderDataValidator", - description="Brave Search tool for web search capabilities with 
privacy-focused results.", - ), + adapter_type="brave-search", + provider_type="remote::brave-search", + module="llama_stack.providers.remote.tool_runtime.brave_search", + config_class="llama_stack.providers.remote.tool_runtime.brave_search.config.BraveSearchToolConfig", + pip_packages=["requests"], + provider_data_validator="llama_stack.providers.remote.tool_runtime.brave_search.BraveSearchToolProviderDataValidator", + description="Brave Search tool for web search capabilities with privacy-focused results.", ), - remote_provider_spec( + RemoteProviderSpec( api=Api.tool_runtime, - adapter=AdapterSpec( - adapter_type="bing-search", - module="llama_stack.providers.remote.tool_runtime.bing_search", - config_class="llama_stack.providers.remote.tool_runtime.bing_search.config.BingSearchToolConfig", - pip_packages=["requests"], - provider_data_validator="llama_stack.providers.remote.tool_runtime.bing_search.BingSearchToolProviderDataValidator", - description="Bing Search tool for web search capabilities using Microsoft's search engine.", - ), + adapter_type="bing-search", + provider_type="remote::bing-search", + module="llama_stack.providers.remote.tool_runtime.bing_search", + config_class="llama_stack.providers.remote.tool_runtime.bing_search.config.BingSearchToolConfig", + pip_packages=["requests"], + provider_data_validator="llama_stack.providers.remote.tool_runtime.bing_search.BingSearchToolProviderDataValidator", + description="Bing Search tool for web search capabilities using Microsoft's search engine.", ), - remote_provider_spec( + RemoteProviderSpec( api=Api.tool_runtime, - adapter=AdapterSpec( - adapter_type="tavily-search", - module="llama_stack.providers.remote.tool_runtime.tavily_search", - config_class="llama_stack.providers.remote.tool_runtime.tavily_search.config.TavilySearchToolConfig", - pip_packages=["requests"], - provider_data_validator="llama_stack.providers.remote.tool_runtime.tavily_search.TavilySearchToolProviderDataValidator", - 
description="Tavily Search tool for AI-optimized web search with structured results.", - ), + adapter_type="tavily-search", + provider_type="remote::tavily-search", + module="llama_stack.providers.remote.tool_runtime.tavily_search", + config_class="llama_stack.providers.remote.tool_runtime.tavily_search.config.TavilySearchToolConfig", + pip_packages=["requests"], + provider_data_validator="llama_stack.providers.remote.tool_runtime.tavily_search.TavilySearchToolProviderDataValidator", + description="Tavily Search tool for AI-optimized web search with structured results.", ), - remote_provider_spec( + RemoteProviderSpec( api=Api.tool_runtime, - adapter=AdapterSpec( - adapter_type="wolfram-alpha", - module="llama_stack.providers.remote.tool_runtime.wolfram_alpha", - config_class="llama_stack.providers.remote.tool_runtime.wolfram_alpha.config.WolframAlphaToolConfig", - pip_packages=["requests"], - provider_data_validator="llama_stack.providers.remote.tool_runtime.wolfram_alpha.WolframAlphaToolProviderDataValidator", - description="Wolfram Alpha tool for computational knowledge and mathematical calculations.", - ), + adapter_type="wolfram-alpha", + provider_type="remote::wolfram-alpha", + module="llama_stack.providers.remote.tool_runtime.wolfram_alpha", + config_class="llama_stack.providers.remote.tool_runtime.wolfram_alpha.config.WolframAlphaToolConfig", + pip_packages=["requests"], + provider_data_validator="llama_stack.providers.remote.tool_runtime.wolfram_alpha.WolframAlphaToolProviderDataValidator", + description="Wolfram Alpha tool for computational knowledge and mathematical calculations.", ), - remote_provider_spec( + RemoteProviderSpec( api=Api.tool_runtime, - adapter=AdapterSpec( - adapter_type="model-context-protocol", - module="llama_stack.providers.remote.tool_runtime.model_context_protocol", - config_class="llama_stack.providers.remote.tool_runtime.model_context_protocol.config.MCPProviderConfig", - pip_packages=["mcp>=1.8.1"], - 
provider_data_validator="llama_stack.providers.remote.tool_runtime.model_context_protocol.config.MCPProviderDataValidator", - description="Model Context Protocol (MCP) tool for standardized tool calling and context management.", - ), + adapter_type="model-context-protocol", + provider_type="remote::model-context-protocol", + module="llama_stack.providers.remote.tool_runtime.model_context_protocol", + config_class="llama_stack.providers.remote.tool_runtime.model_context_protocol.config.MCPProviderConfig", + pip_packages=["mcp>=1.8.1"], + provider_data_validator="llama_stack.providers.remote.tool_runtime.model_context_protocol.config.MCPProviderDataValidator", + description="Model Context Protocol (MCP) tool for standardized tool calling and context management.", ), ] diff --git a/llama_stack/providers/registry/vector_io.py b/llama_stack/providers/registry/vector_io.py index 511734d57..e8237bc62 100644 --- a/llama_stack/providers/registry/vector_io.py +++ b/llama_stack/providers/registry/vector_io.py @@ -6,11 +6,10 @@ from llama_stack.providers.datatypes import ( - AdapterSpec, Api, InlineProviderSpec, ProviderSpec, - remote_provider_spec, + RemoteProviderSpec, ) @@ -300,14 +299,16 @@ See [sqlite-vec's GitHub repo](https://github.com/asg017/sqlite-vec/tree/main) f Please refer to the sqlite-vec provider documentation. 
""", ), - remote_provider_spec( - Api.vector_io, - AdapterSpec( - adapter_type="chromadb", - pip_packages=["chromadb-client"], - module="llama_stack.providers.remote.vector_io.chroma", - config_class="llama_stack.providers.remote.vector_io.chroma.ChromaVectorIOConfig", - description=""" + RemoteProviderSpec( + api=Api.vector_io, + adapter_type="chromadb", + provider_type="remote::chromadb", + pip_packages=["chromadb-client"], + module="llama_stack.providers.remote.vector_io.chroma", + config_class="llama_stack.providers.remote.vector_io.chroma.ChromaVectorIOConfig", + api_dependencies=[Api.inference], + optional_api_dependencies=[Api.files], + description=""" [Chroma](https://www.trychroma.com/) is an inline and remote vector database provider for Llama Stack. It allows you to store and query vectors directly within a Chroma database. That means you're not limited to storing vectors in memory or in a separate service. @@ -340,9 +341,6 @@ pip install chromadb ## Documentation See [Chroma's documentation](https://docs.trychroma.com/docs/overview/introduction) for more details about Chroma in general. 
""", - ), - api_dependencies=[Api.inference], - optional_api_dependencies=[Api.files], ), InlineProviderSpec( api=Api.vector_io, @@ -387,14 +385,16 @@ See [Chroma's documentation](https://docs.trychroma.com/docs/overview/introducti """, ), - remote_provider_spec( - Api.vector_io, - AdapterSpec( - adapter_type="pgvector", - pip_packages=["psycopg2-binary"], - module="llama_stack.providers.remote.vector_io.pgvector", - config_class="llama_stack.providers.remote.vector_io.pgvector.PGVectorVectorIOConfig", - description=""" + RemoteProviderSpec( + api=Api.vector_io, + adapter_type="pgvector", + provider_type="remote::pgvector", + pip_packages=["psycopg2-binary"], + module="llama_stack.providers.remote.vector_io.pgvector", + config_class="llama_stack.providers.remote.vector_io.pgvector.PGVectorVectorIOConfig", + api_dependencies=[Api.inference], + optional_api_dependencies=[Api.files], + description=""" [PGVector](https://github.com/pgvector/pgvector) is a remote vector database provider for Llama Stack. It allows you to store and query vectors directly in memory. That means you'll get fast and efficient vector retrieval. @@ -495,19 +495,18 @@ docker pull pgvector/pgvector:pg17 ## Documentation See [PGVector's documentation](https://github.com/pgvector/pgvector) for more details about PGVector in general. 
""", - ), + ), + RemoteProviderSpec( + api=Api.vector_io, + adapter_type="weaviate", + provider_type="remote::weaviate", + pip_packages=["weaviate-client"], + module="llama_stack.providers.remote.vector_io.weaviate", + config_class="llama_stack.providers.remote.vector_io.weaviate.WeaviateVectorIOConfig", + provider_data_validator="llama_stack.providers.remote.vector_io.weaviate.WeaviateRequestProviderData", api_dependencies=[Api.inference], optional_api_dependencies=[Api.files], - ), - remote_provider_spec( - Api.vector_io, - AdapterSpec( - adapter_type="weaviate", - pip_packages=["weaviate-client"], - module="llama_stack.providers.remote.vector_io.weaviate", - config_class="llama_stack.providers.remote.vector_io.weaviate.WeaviateVectorIOConfig", - provider_data_validator="llama_stack.providers.remote.vector_io.weaviate.WeaviateRequestProviderData", - description=""" + description=""" [Weaviate](https://weaviate.io/) is a vector database provider for Llama Stack. It allows you to store and query vectors directly within a Weaviate database. That means you're not limited to storing vectors in memory or in a separate service. @@ -538,9 +537,6 @@ To install Weaviate see the [Weaviate quickstart documentation](https://weaviate ## Documentation See [Weaviate's documentation](https://weaviate.io/developers/weaviate) for more details about Weaviate in general. """, - ), - api_dependencies=[Api.inference], - optional_api_dependencies=[Api.files], ), InlineProviderSpec( api=Api.vector_io, @@ -594,28 +590,29 @@ docker pull qdrant/qdrant See the [Qdrant documentation](https://qdrant.tech/documentation/) for more details about Qdrant in general. """, ), - remote_provider_spec( - Api.vector_io, - AdapterSpec( - adapter_type="qdrant", - pip_packages=["qdrant-client"], - module="llama_stack.providers.remote.vector_io.qdrant", - config_class="llama_stack.providers.remote.vector_io.qdrant.QdrantVectorIOConfig", - description=""" -Please refer to the inline provider documentation. 
-""", - ), + RemoteProviderSpec( + api=Api.vector_io, + adapter_type="qdrant", + provider_type="remote::qdrant", + pip_packages=["qdrant-client"], + module="llama_stack.providers.remote.vector_io.qdrant", + config_class="llama_stack.providers.remote.vector_io.qdrant.QdrantVectorIOConfig", api_dependencies=[Api.inference], optional_api_dependencies=[Api.files], + description=""" +Please refer to the inline provider documentation. +""", ), - remote_provider_spec( - Api.vector_io, - AdapterSpec( - adapter_type="milvus", - pip_packages=["pymilvus>=2.4.10"], - module="llama_stack.providers.remote.vector_io.milvus", - config_class="llama_stack.providers.remote.vector_io.milvus.MilvusVectorIOConfig", - description=""" + RemoteProviderSpec( + api=Api.vector_io, + adapter_type="milvus", + provider_type="remote::milvus", + pip_packages=["pymilvus>=2.4.10"], + module="llama_stack.providers.remote.vector_io.milvus", + config_class="llama_stack.providers.remote.vector_io.milvus.MilvusVectorIOConfig", + api_dependencies=[Api.inference], + optional_api_dependencies=[Api.files], + description=""" [Milvus](https://milvus.io/) is an inline and remote vector database provider for Llama Stack. It allows you to store and query vectors directly within a Milvus database. That means you're not limited to storing vectors in memory or in a separate service. @@ -636,7 +633,13 @@ To use Milvus in your Llama Stack project, follow these steps: ## Installation -You can install Milvus using pymilvus: +If you want to use inline Milvus, you can install: + +```bash +pip install pymilvus[milvus-lite] +``` + +If you want to use remote Milvus, you can install: ```bash pip install pymilvus @@ -806,14 +809,11 @@ See the [Milvus documentation](https://milvus.io/docs/install-overview.md) for m For more details on TLS configuration, refer to the [TLS setup guide](https://milvus.io/docs/tls.md). 
""", - ), - api_dependencies=[Api.inference], - optional_api_dependencies=[Api.files], ), InlineProviderSpec( api=Api.vector_io, provider_type="inline::milvus", - pip_packages=["pymilvus>=2.4.10"], + pip_packages=["pymilvus[milvus-lite]>=2.4.10"], module="llama_stack.providers.inline.vector_io.milvus", config_class="llama_stack.providers.inline.vector_io.milvus.MilvusVectorIOConfig", api_dependencies=[Api.inference], diff --git a/llama_stack/providers/remote/files/s3/files.py b/llama_stack/providers/remote/files/s3/files.py index 54742d900..8ea96af9e 100644 --- a/llama_stack/providers/remote/files/s3/files.py +++ b/llama_stack/providers/remote/files/s3/files.py @@ -137,7 +137,7 @@ class S3FilesImpl(Files): where: dict[str, str | dict] = {"id": file_id} if not return_expired: where["expires_at"] = {">": self._now()} - if not (row := await self.sql_store.fetch_one("openai_files", policy=self.policy, where=where)): + if not (row := await self.sql_store.fetch_one("openai_files", where=where)): raise ResourceNotFoundError(file_id, "File", "files.list()") return row @@ -164,7 +164,7 @@ class S3FilesImpl(Files): self._client = _create_s3_client(self._config) await _create_bucket_if_not_exists(self._client, self._config) - self._sql_store = AuthorizedSqlStore(sqlstore_impl(self._config.metadata_store)) + self._sql_store = AuthorizedSqlStore(sqlstore_impl(self._config.metadata_store), self.policy) await self._sql_store.create_table( "openai_files", { @@ -268,7 +268,6 @@ class S3FilesImpl(Files): paginated_result = await self.sql_store.fetch_all( table="openai_files", - policy=self.policy, where=where_conditions, order_by=[("created_at", order.value)], cursor=("id", after) if after else None, diff --git a/llama_stack/providers/utils/inference/inference_store.py b/llama_stack/providers/utils/inference/inference_store.py index 17f4c6268..ffc9f3e11 100644 --- a/llama_stack/providers/utils/inference/inference_store.py +++ 
b/llama_stack/providers/utils/inference/inference_store.py @@ -54,7 +54,7 @@ class InferenceStore: async def initialize(self): """Create the necessary tables if they don't exist.""" - self.sql_store = AuthorizedSqlStore(sqlstore_impl(self.sql_store_config)) + self.sql_store = AuthorizedSqlStore(sqlstore_impl(self.sql_store_config), self.policy) await self.sql_store.create_table( "chat_completions", { @@ -202,7 +202,6 @@ class InferenceStore: order_by=[("created", order.value)], cursor=("id", after) if after else None, limit=limit, - policy=self.policy, ) data = [ @@ -229,7 +228,6 @@ class InferenceStore: row = await self.sql_store.fetch_one( table="chat_completions", where={"id": completion_id}, - policy=self.policy, ) if not row: diff --git a/llama_stack/providers/utils/responses/responses_store.py b/llama_stack/providers/utils/responses/responses_store.py index 04778ed1c..829cd8a62 100644 --- a/llama_stack/providers/utils/responses/responses_store.py +++ b/llama_stack/providers/utils/responses/responses_store.py @@ -28,8 +28,7 @@ class ResponsesStore: sql_store_config = SqliteSqlStoreConfig( db_path=(RUNTIME_BASE_DIR / "sqlstore.db").as_posix(), ) - self.sql_store = AuthorizedSqlStore(sqlstore_impl(sql_store_config)) - self.policy = policy + self.sql_store = AuthorizedSqlStore(sqlstore_impl(sql_store_config), policy) async def initialize(self): """Create the necessary tables if they don't exist.""" @@ -87,7 +86,6 @@ class ResponsesStore: order_by=[("created_at", order.value)], cursor=("id", after) if after else None, limit=limit, - policy=self.policy, ) data = [OpenAIResponseObjectWithInput(**row["response_object"]) for row in paginated_result.data] @@ -105,7 +103,6 @@ class ResponsesStore: row = await self.sql_store.fetch_one( "openai_responses", where={"id": response_id}, - policy=self.policy, ) if not row: @@ -116,7 +113,7 @@ class ResponsesStore: return OpenAIResponseObjectWithInput(**row["response_object"]) async def delete_response_object(self, response_id: 
str) -> OpenAIDeleteResponseObject: - row = await self.sql_store.fetch_one("openai_responses", where={"id": response_id}, policy=self.policy) + row = await self.sql_store.fetch_one("openai_responses", where={"id": response_id}) if not row: raise ValueError(f"Response with id {response_id} not found") await self.sql_store.delete("openai_responses", where={"id": response_id}) diff --git a/llama_stack/providers/utils/sqlstore/authorized_sqlstore.py b/llama_stack/providers/utils/sqlstore/authorized_sqlstore.py index acb688f96..ab67f7052 100644 --- a/llama_stack/providers/utils/sqlstore/authorized_sqlstore.py +++ b/llama_stack/providers/utils/sqlstore/authorized_sqlstore.py @@ -53,13 +53,15 @@ class AuthorizedSqlStore: access control policies, user attribute capture, and SQL filtering optimization. """ - def __init__(self, sql_store: SqlStore): + def __init__(self, sql_store: SqlStore, policy: list[AccessRule]): """ Initialize the authorization layer. :param sql_store: Base SqlStore implementation to wrap + :param policy: Access control policy to use for authorization """ self.sql_store = sql_store + self.policy = policy self._detect_database_type() self._validate_sql_optimized_policy() @@ -117,14 +119,13 @@ class AuthorizedSqlStore: async def fetch_all( self, table: str, - policy: list[AccessRule], where: Mapping[str, Any] | None = None, limit: int | None = None, order_by: list[tuple[str, Literal["asc", "desc"]]] | None = None, cursor: tuple[str, str] | None = None, ) -> PaginatedResponse: """Fetch all rows with automatic access control filtering.""" - access_where = self._build_access_control_where_clause(policy) + access_where = self._build_access_control_where_clause(self.policy) rows = await self.sql_store.fetch_all( table=table, where=where, @@ -146,7 +147,7 @@ class AuthorizedSqlStore: str(record_id), table, User(principal=stored_owner_principal, attributes=stored_access_attrs) ) - if is_action_allowed(policy, Action.READ, sql_record, current_user): + if 
is_action_allowed(self.policy, Action.READ, sql_record, current_user): filtered_rows.append(row) return PaginatedResponse( @@ -157,14 +158,12 @@ class AuthorizedSqlStore: async def fetch_one( self, table: str, - policy: list[AccessRule], where: Mapping[str, Any] | None = None, order_by: list[tuple[str, Literal["asc", "desc"]]] | None = None, ) -> dict[str, Any] | None: """Fetch one row with automatic access control checking.""" results = await self.fetch_all( table=table, - policy=policy, where=where, limit=1, order_by=order_by, diff --git a/tests/external/kaze.yaml b/tests/external/kaze.yaml index c61ac0e31..1b42f2e14 100644 --- a/tests/external/kaze.yaml +++ b/tests/external/kaze.yaml @@ -1,6 +1,5 @@ -adapter: - adapter_type: kaze - pip_packages: ["tests/external/llama-stack-provider-kaze"] - config_class: llama_stack_provider_kaze.config.KazeProviderConfig - module: llama_stack_provider_kaze +adapter_type: kaze +pip_packages: ["tests/external/llama-stack-provider-kaze"] +config_class: llama_stack_provider_kaze.config.KazeProviderConfig +module: llama_stack_provider_kaze optional_api_dependencies: [] diff --git a/tests/external/llama-stack-api-weather/src/llama_stack_api_weather/weather.py b/tests/external/llama-stack-api-weather/src/llama_stack_api_weather/weather.py index 4b3bfb641..de1427bfd 100644 --- a/tests/external/llama-stack-api-weather/src/llama_stack_api_weather/weather.py +++ b/tests/external/llama-stack-api-weather/src/llama_stack_api_weather/weather.py @@ -6,7 +6,7 @@ from typing import Protocol -from llama_stack.providers.datatypes import AdapterSpec, Api, ProviderSpec, RemoteProviderSpec +from llama_stack.providers.datatypes import Api, ProviderSpec, RemoteProviderSpec from llama_stack.schema_utils import webmethod @@ -16,12 +16,9 @@ def available_providers() -> list[ProviderSpec]: api=Api.weather, provider_type="remote::kaze", config_class="llama_stack_provider_kaze.KazeProviderConfig", - adapter=AdapterSpec( - adapter_type="kaze", - 
module="llama_stack_provider_kaze", - pip_packages=["llama_stack_provider_kaze"], - config_class="llama_stack_provider_kaze.KazeProviderConfig", - ), + adapter_type="kaze", + module="llama_stack_provider_kaze", + pip_packages=["llama_stack_provider_kaze"], ), ] diff --git a/tests/integration/providers/utils/sqlstore/test_authorized_sqlstore.py b/tests/integration/providers/utils/sqlstore/test_authorized_sqlstore.py index 4002f2e1f..98bef0f2c 100644 --- a/tests/integration/providers/utils/sqlstore/test_authorized_sqlstore.py +++ b/tests/integration/providers/utils/sqlstore/test_authorized_sqlstore.py @@ -57,7 +57,7 @@ def authorized_store(backend_config): config = config_func() base_sqlstore = sqlstore_impl(config) - authorized_store = AuthorizedSqlStore(base_sqlstore) + authorized_store = AuthorizedSqlStore(base_sqlstore, default_policy()) yield authorized_store @@ -106,7 +106,7 @@ async def test_authorized_store_attributes(mock_get_authenticated_user, authoriz await authorized_store.insert(table_name, {"id": "1", "data": "public_data"}) # Test fetching with no user - should not error on JSON comparison - result = await authorized_store.fetch_all(table_name, policy=default_policy()) + result = await authorized_store.fetch_all(table_name) assert len(result.data) == 1 assert result.data[0]["id"] == "1" assert result.data[0]["access_attributes"] is None @@ -119,7 +119,7 @@ async def test_authorized_store_attributes(mock_get_authenticated_user, authoriz await authorized_store.insert(table_name, {"id": "2", "data": "admin_data"}) # Fetch all - admin should see both - result = await authorized_store.fetch_all(table_name, policy=default_policy()) + result = await authorized_store.fetch_all(table_name) assert len(result.data) == 2 # Test with non-admin user @@ -127,7 +127,7 @@ async def test_authorized_store_attributes(mock_get_authenticated_user, authoriz mock_get_authenticated_user.return_value = regular_user # Should only see public record - result = await 
authorized_store.fetch_all(table_name, policy=default_policy()) + result = await authorized_store.fetch_all(table_name) assert len(result.data) == 1 assert result.data[0]["id"] == "1" @@ -156,7 +156,7 @@ async def test_authorized_store_attributes(mock_get_authenticated_user, authoriz # Now test with the multi-user who has both roles=admin and teams=dev mock_get_authenticated_user.return_value = multi_user - result = await authorized_store.fetch_all(table_name, policy=default_policy()) + result = await authorized_store.fetch_all(table_name) # Should see: # - public record (1) - no access_attributes @@ -217,21 +217,24 @@ async def test_user_ownership_policy(mock_get_authenticated_user, authorized_sto ), ] + # Create a new authorized store with the owner-only policy + owner_only_store = AuthorizedSqlStore(authorized_store.sql_store, owner_only_policy) + # Test user1 access - should only see their own record mock_get_authenticated_user.return_value = user1 - result = await authorized_store.fetch_all(table_name, policy=owner_only_policy) + result = await owner_only_store.fetch_all(table_name) assert len(result.data) == 1, f"Expected user1 to see 1 record, got {len(result.data)}" assert result.data[0]["id"] == "1", f"Expected user1's record, got {result.data[0]['id']}" # Test user2 access - should only see their own record mock_get_authenticated_user.return_value = user2 - result = await authorized_store.fetch_all(table_name, policy=owner_only_policy) + result = await owner_only_store.fetch_all(table_name) assert len(result.data) == 1, f"Expected user2 to see 1 record, got {len(result.data)}" assert result.data[0]["id"] == "2", f"Expected user2's record, got {result.data[0]['id']}" # Test with anonymous user - should see no records mock_get_authenticated_user.return_value = None - result = await authorized_store.fetch_all(table_name, policy=owner_only_policy) + result = await owner_only_store.fetch_all(table_name) assert len(result.data) == 0, f"Expected anonymous user 
to see 0 records, got {len(result.data)}" finally: diff --git a/tests/unit/distribution/test_distribution.py b/tests/unit/distribution/test_distribution.py index c6c2eb2c7..f24de0644 100644 --- a/tests/unit/distribution/test_distribution.py +++ b/tests/unit/distribution/test_distribution.py @@ -66,10 +66,9 @@ def base_config(tmp_path): def provider_spec_yaml(): """Common provider spec YAML for testing.""" return """ -adapter: - adapter_type: test_provider - config_class: test_provider.config.TestProviderConfig - module: test_provider +adapter_type: test_provider +config_class: test_provider.config.TestProviderConfig +module: test_provider api_dependencies: - safety """ @@ -182,9 +181,9 @@ class TestProviderRegistry: assert Api.inference in registry assert "remote::test_provider" in registry[Api.inference] provider = registry[Api.inference]["remote::test_provider"] - assert provider.adapter.adapter_type == "test_provider" - assert provider.adapter.module == "test_provider" - assert provider.adapter.config_class == "test_provider.config.TestProviderConfig" + assert provider.adapter_type == "test_provider" + assert provider.module == "test_provider" + assert provider.config_class == "test_provider.config.TestProviderConfig" assert Api.safety in provider.api_dependencies def test_external_inline_providers(self, api_directories, mock_providers, base_config, inline_provider_spec_yaml): @@ -246,8 +245,7 @@ class TestProviderRegistry: """Test handling of malformed remote provider spec (missing required fields).""" remote_dir, _ = api_directories malformed_spec = """ -adapter: - adapter_type: test_provider +adapter_type: test_provider # Missing required fields api_dependencies: - safety @@ -270,7 +268,7 @@ pip_packages: with open(inline_dir / "malformed.yaml", "w") as f: f.write(malformed_spec) - with pytest.raises(KeyError) as exc_info: + with pytest.raises(ValidationError) as exc_info: get_provider_registry(base_config) assert "config_class" in str(exc_info.value) diff 
--git a/tests/unit/distribution/test_library_client_initialization.py b/tests/unit/distribution/test_library_client_initialization.py index b7e7a1857..b01a5c3e2 100644 --- a/tests/unit/distribution/test_library_client_initialization.py +++ b/tests/unit/distribution/test_library_client_initialization.py @@ -27,13 +27,17 @@ class TestLlamaStackAsLibraryClientAutoInitialization: mock_impls = {} mock_route_impls = RouteImpls({}) - async def mock_construct_stack(config, custom_provider_registry): - return mock_impls + class MockStack: + def __init__(self, config, custom_provider_registry=None): + self.impls = mock_impls + + async def initialize(self): + pass def mock_initialize_route_impls(impls): return mock_route_impls - monkeypatch.setattr("llama_stack.core.library_client.construct_stack", mock_construct_stack) + monkeypatch.setattr("llama_stack.core.library_client.Stack", MockStack) monkeypatch.setattr("llama_stack.core.library_client.initialize_route_impls", mock_initialize_route_impls) client = LlamaStackAsLibraryClient("ci-tests") @@ -46,13 +50,17 @@ class TestLlamaStackAsLibraryClientAutoInitialization: mock_impls = {} mock_route_impls = RouteImpls({}) - async def mock_construct_stack(config, custom_provider_registry): - return mock_impls + class MockStack: + def __init__(self, config, custom_provider_registry=None): + self.impls = mock_impls + + async def initialize(self): + pass def mock_initialize_route_impls(impls): return mock_route_impls - monkeypatch.setattr("llama_stack.core.library_client.construct_stack", mock_construct_stack) + monkeypatch.setattr("llama_stack.core.library_client.Stack", MockStack) monkeypatch.setattr("llama_stack.core.library_client.initialize_route_impls", mock_initialize_route_impls) client = AsyncLlamaStackAsLibraryClient("ci-tests") @@ -68,13 +76,17 @@ class TestLlamaStackAsLibraryClientAutoInitialization: mock_impls = {} mock_route_impls = RouteImpls({}) - async def mock_construct_stack(config, custom_provider_registry): - 
return mock_impls + class MockStack: + def __init__(self, config, custom_provider_registry=None): + self.impls = mock_impls + + async def initialize(self): + pass def mock_initialize_route_impls(impls): return mock_route_impls - monkeypatch.setattr("llama_stack.core.library_client.construct_stack", mock_construct_stack) + monkeypatch.setattr("llama_stack.core.library_client.Stack", MockStack) monkeypatch.setattr("llama_stack.core.library_client.initialize_route_impls", mock_initialize_route_impls) client = LlamaStackAsLibraryClient("ci-tests") @@ -90,13 +102,17 @@ class TestLlamaStackAsLibraryClientAutoInitialization: mock_impls = {} mock_route_impls = RouteImpls({}) - async def mock_construct_stack(config, custom_provider_registry): - return mock_impls + class MockStack: + def __init__(self, config, custom_provider_registry=None): + self.impls = mock_impls + + async def initialize(self): + pass def mock_initialize_route_impls(impls): return mock_route_impls - monkeypatch.setattr("llama_stack.core.library_client.construct_stack", mock_construct_stack) + monkeypatch.setattr("llama_stack.core.library_client.Stack", MockStack) monkeypatch.setattr("llama_stack.core.library_client.initialize_route_impls", mock_initialize_route_impls) client = AsyncLlamaStackAsLibraryClient("ci-tests") @@ -112,13 +128,17 @@ class TestLlamaStackAsLibraryClientAutoInitialization: mock_impls = {} mock_route_impls = RouteImpls({}) - async def mock_construct_stack(config, custom_provider_registry): - return mock_impls + class MockStack: + def __init__(self, config, custom_provider_registry=None): + self.impls = mock_impls + + async def initialize(self): + pass def mock_initialize_route_impls(impls): return mock_route_impls - monkeypatch.setattr("llama_stack.core.library_client.construct_stack", mock_construct_stack) + monkeypatch.setattr("llama_stack.core.library_client.Stack", MockStack) monkeypatch.setattr("llama_stack.core.library_client.initialize_route_impls", 
mock_initialize_route_impls) sync_client = LlamaStackAsLibraryClient("ci-tests") diff --git a/tests/unit/utils/test_authorized_sqlstore.py b/tests/unit/utils/test_authorized_sqlstore.py index 90eb706e4..d85e784a9 100644 --- a/tests/unit/utils/test_authorized_sqlstore.py +++ b/tests/unit/utils/test_authorized_sqlstore.py @@ -26,7 +26,7 @@ async def test_authorized_fetch_with_where_sql_access_control(mock_get_authentic db_path=tmp_dir + "/" + db_name, ) ) - sqlstore = AuthorizedSqlStore(base_sqlstore) + sqlstore = AuthorizedSqlStore(base_sqlstore, default_policy()) # Create table with access control await sqlstore.create_table( @@ -56,24 +56,24 @@ async def test_authorized_fetch_with_where_sql_access_control(mock_get_authentic mock_get_authenticated_user.return_value = admin_user # Admin should see both documents - result = await sqlstore.fetch_all("documents", policy=default_policy(), where={"id": 1}) + result = await sqlstore.fetch_all("documents", where={"id": 1}) assert len(result.data) == 1 assert result.data[0]["title"] == "Admin Document" # User should only see their document mock_get_authenticated_user.return_value = regular_user - result = await sqlstore.fetch_all("documents", policy=default_policy(), where={"id": 1}) + result = await sqlstore.fetch_all("documents", where={"id": 1}) assert len(result.data) == 0 - result = await sqlstore.fetch_all("documents", policy=default_policy(), where={"id": 2}) + result = await sqlstore.fetch_all("documents", where={"id": 2}) assert len(result.data) == 1 assert result.data[0]["title"] == "User Document" - row = await sqlstore.fetch_one("documents", policy=default_policy(), where={"id": 1}) + row = await sqlstore.fetch_one("documents", where={"id": 1}) assert row is None - row = await sqlstore.fetch_one("documents", policy=default_policy(), where={"id": 2}) + row = await sqlstore.fetch_one("documents", where={"id": 2}) assert row is not None assert row["title"] == "User Document" @@ -88,7 +88,7 @@ async def 
test_sql_policy_consistency(mock_get_authenticated_user): db_path=tmp_dir + "/" + db_name, ) ) - sqlstore = AuthorizedSqlStore(base_sqlstore) + sqlstore = AuthorizedSqlStore(base_sqlstore, default_policy()) await sqlstore.create_table( table="resources", @@ -144,7 +144,7 @@ async def test_sql_policy_consistency(mock_get_authenticated_user): user = User(principal=user_data["principal"], attributes=user_data["attributes"]) mock_get_authenticated_user.return_value = user - sql_results = await sqlstore.fetch_all("resources", policy=policy) + sql_results = await sqlstore.fetch_all("resources") sql_ids = {row["id"] for row in sql_results.data} policy_ids = set() for scenario in test_scenarios: @@ -174,7 +174,7 @@ async def test_authorized_store_user_attribute_capture(mock_get_authenticated_us db_path=tmp_dir + "/" + db_name, ) ) - authorized_store = AuthorizedSqlStore(base_sqlstore) + authorized_store = AuthorizedSqlStore(base_sqlstore, default_policy()) await authorized_store.create_table( table="user_data",