Merge branch 'main' into use-openai-for-cerebras

Matthew Farrellee · 2025-09-23 07:31:11 -04:00
commit 9ceb45f611
53 changed files with 2612 additions and 1966 deletions


@@ -24,7 +24,7 @@ jobs:
         uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
       - name: Install uv
-        uses: astral-sh/setup-uv@557e51de59eb14aaaba2ed9621916900a91d50c6 # v6.6.1
+        uses: astral-sh/setup-uv@b75a909f75acd358c2196fb9a5f1299a9a8868a4 # v6.7.0
         with:
           python-version: ${{ matrix.python-version }}
           activate-environment: true


@@ -17,11 +17,8 @@ export POSTGRES_PASSWORD=llamastack
 export INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct
 export SAFETY_MODEL=meta-llama/Llama-Guard-3-1B
-export MOCK_INFERENCE_MODEL=mock-inference
-export MOCK_INFERENCE_URL=openai-mock-service:8080
 export BENCHMARK_INFERENCE_MODEL=$INFERENCE_MODEL
+export LLAMA_STACK_WORKERS=4

 set -euo pipefail
 set -x


@@ -5,6 +5,7 @@ data:
   image_name: kubernetes-benchmark-demo
   apis:
   - agents
+  - files
   - inference
   - files
   - safety
@@ -23,6 +24,14 @@ data:
     - provider_id: sentence-transformers
       provider_type: inline::sentence-transformers
       config: {}
+    files:
+    - provider_id: meta-reference-files
+      provider_type: inline::localfs
+      config:
+        storage_dir: ${env.FILES_STORAGE_DIR:=~/.llama/distributions/starter/files}
+        metadata_store:
+          type: sqlite
+          db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/files_metadata.db
     vector_io:
     - provider_id: ${env.ENABLE_CHROMADB:+chromadb}
       provider_type: remote::chromadb


@@ -52,9 +52,20 @@ spec:
             value: http://vllm-server-safety.default.svc.cluster.local:8001/v1
           - name: VLLM_TLS_VERIFY
             value: "false"
-        command: ["python", "-m", "llama_stack.core.server.server", "/etc/config/stack_run_config.yaml", "--port", "8323"]
+          - name: LLAMA_STACK_LOGGING
+            value: "all=WARNING"
+          - name: LLAMA_STACK_CONFIG
+            value: "/etc/config/stack_run_config.yaml"
+          - name: LLAMA_STACK_WORKERS
+            value: "${LLAMA_STACK_WORKERS}"
+        command: ["uvicorn", "llama_stack.core.server.server:create_app", "--host", "0.0.0.0", "--port", "8323", "--workers", "$LLAMA_STACK_WORKERS", "--factory"]
         ports:
           - containerPort: 8323
+        resources:
+          requests:
+            cpu: "${LLAMA_STACK_WORKERS}"
+          limits:
+            cpu: "${LLAMA_STACK_WORKERS}"
        volumeMounts:
          - name: llama-storage
            mountPath: /root/.llama
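The new container command is uvicorn's multi-worker factory mode. Roughly equivalent, as a sketch only (the config path and worker count mirror the manifest's env vars; this is not part of the change itself):

```python
# Sketch: what the k8s command roughly amounts to when run programmatically.
import os

import uvicorn

# The deployment points LLAMA_STACK_CONFIG at the mounted run config.
os.environ.setdefault("LLAMA_STACK_CONFIG", "/etc/config/stack_run_config.yaml")

uvicorn.run(
    "llama_stack.core.server.server:create_app",  # import string resolved in each worker
    factory=True,            # call create_app() to build the app per worker
    host="0.0.0.0",
    port=8323,
    workers=int(os.environ.get("LLAMA_STACK_WORKERS", "1")),
)
```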


@@ -11,6 +11,7 @@ The `llamastack/distribution-nvidia` distribution consists of the following prov
 | agents | `inline::meta-reference` |
 | datasetio | `inline::localfs`, `remote::nvidia` |
 | eval | `remote::nvidia` |
+| files | `inline::localfs` |
 | inference | `remote::nvidia` |
 | post_training | `remote::nvidia` |
 | safety | `remote::nvidia` |


@@ -460,10 +460,12 @@ client = LlamaStackClient(base_url="http://localhost:8321")
 embed_lm = next(m for m in client.models.list() if m.model_type == "embedding")
 embedding_model = embed_lm.identifier
 vector_db_id = f"v{uuid.uuid4().hex}"
-client.vector_dbs.register(
+# The VectorDB API is deprecated; the server now returns its own authoritative ID.
+# We capture the correct ID from the response's .identifier attribute.
+vector_db_id = client.vector_dbs.register(
     vector_db_id=vector_db_id,
     embedding_model=embedding_model,
-)
+).identifier

 # Create Documents
 urls = [
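As a hedged follow-up sketch, not part of this diff: the server-assigned identifier captured above is what later calls should receive. The `RAGDocument` import path and `rag_tool.insert` signature are assumed from the guide's usual flow.

```python
from llama_stack_client import RAGDocument  # assumed import path

document = RAGDocument(
    document_id="doc-1",
    content="Llama Stack is a set of composable APIs for building AI applications.",
    mime_type="text/plain",
    metadata={},
)

# Use the server-assigned vector_db_id captured above, not the locally generated string.
client.tool_runtime.rag_tool.insert(
    documents=[document],
    vector_db_id=vector_db_id,
    chunk_size_in_tokens=512,
)
```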


@@ -23,7 +23,13 @@ To use Milvus in your Llama Stack project, follow these steps:
 ## Installation

-You can install Milvus using pymilvus:
+If you want to use inline Milvus, you can install:
+
+```bash
+pip install pymilvus[milvus-lite]
+```
+
+If you want to use remote Milvus, you can install:

 ```bash
 pip install pymilvus


@@ -147,7 +147,7 @@ WORKDIR /app

 RUN dnf -y update && dnf install -y iputils git net-tools wget \
     vim-minimal python3.12 python3.12-pip python3.12-wheel \
-    python3.12-setuptools python3.12-devel gcc make && \
+    python3.12-setuptools python3.12-devel gcc gcc-c++ make && \
     ln -s /bin/pip3.12 /bin/pip && ln -s /bin/python3.12 /bin/python && dnf clean all

 ENV UV_SYSTEM_PYTHON=1
@@ -164,7 +164,7 @@ RUN apt-get update && apt-get install -y \
     procps psmisc lsof \
     traceroute \
     bubblewrap \
-    gcc \
+    gcc g++ \
     && rm -rf /var/lib/apt/lists/*

 ENV UV_SYSTEM_PYTHON=1


@@ -121,10 +121,6 @@ class AutoRoutedProviderSpec(ProviderSpec):
         default=None,
     )

-    @property
-    def pip_packages(self) -> list[str]:
-        raise AssertionError("Should not be called on AutoRoutedProviderSpec")
-
 # Example: /models, /shields
 class RoutingTableProviderSpec(ProviderSpec):


@@ -16,11 +16,10 @@ from llama_stack.core.datatypes import BuildConfig, DistributionSpec
 from llama_stack.core.external import load_external_apis
 from llama_stack.log import get_logger
 from llama_stack.providers.datatypes import (
-    AdapterSpec,
     Api,
     InlineProviderSpec,
     ProviderSpec,
-    remote_provider_spec,
+    RemoteProviderSpec,
 )

 logger = get_logger(name=__name__, category="core")
@@ -77,27 +76,12 @@ def providable_apis() -> list[Api]:


 def _load_remote_provider_spec(spec_data: dict[str, Any], api: Api) -> ProviderSpec:
-    adapter = AdapterSpec(**spec_data["adapter"])
-    spec = remote_provider_spec(
-        api=api,
-        adapter=adapter,
-        api_dependencies=[Api(dep) for dep in spec_data.get("api_dependencies", [])],
-    )
+    spec = RemoteProviderSpec(api=api, provider_type=f"remote::{spec_data['adapter_type']}", **spec_data)
     return spec


 def _load_inline_provider_spec(spec_data: dict[str, Any], api: Api, provider_name: str) -> ProviderSpec:
-    spec = InlineProviderSpec(
-        api=api,
-        provider_type=f"inline::{provider_name}",
-        pip_packages=spec_data.get("pip_packages", []),
-        module=spec_data["module"],
-        config_class=spec_data["config_class"],
-        api_dependencies=[Api(dep) for dep in spec_data.get("api_dependencies", [])],
-        optional_api_dependencies=[Api(dep) for dep in spec_data.get("optional_api_dependencies", [])],
-        provider_data_validator=spec_data.get("provider_data_validator"),
-        container_image=spec_data.get("container_image"),
-    )
+    spec = InlineProviderSpec(api=api, provider_type=f"inline::{provider_name}", **spec_data)
     return spec
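To make the new loading path concrete, a sketch with a made-up "acme" provider (field names mirror the constructor calls above; nothing here ships with the change):

```python
# Sketch: how an external remote provider's spec data (e.g. parsed from its YAML)
# maps straight onto RemoteProviderSpec under the new scheme. All "acme" names are hypothetical.
from typing import Any

from llama_stack.providers.datatypes import Api, RemoteProviderSpec

spec_data: dict[str, Any] = {
    "adapter_type": "acme",
    "module": "llama_stack_acme.inference",
    "config_class": "llama_stack_acme.inference.AcmeImplConfig",
    "pip_packages": ["acme-sdk"],
    "description": "Hypothetical Acme inference adapter.",
}

spec = RemoteProviderSpec(api=Api.inference, provider_type=f"remote::{spec_data['adapter_type']}", **spec_data)
print(spec.provider_type)  # remote::acme
```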


@@ -40,7 +40,7 @@ from llama_stack.core.request_headers import (
 from llama_stack.core.resolver import ProviderRegistry
 from llama_stack.core.server.routes import RouteImpls, find_matching_route, initialize_route_impls
 from llama_stack.core.stack import (
-    construct_stack,
+    Stack,
     get_stack_run_config_from_distro,
     replace_env_vars,
 )
@@ -252,7 +252,10 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
         try:
             self.route_impls = None
-            self.impls = await construct_stack(self.config, self.custom_provider_registry)
+
+            stack = Stack(self.config, self.custom_provider_registry)
+            await stack.initialize()
+            self.impls = stack.impls
         except ModuleNotFoundError as _e:
             cprint(_e.msg, color="red", file=sys.stderr)
             cprint(
@@ -289,6 +292,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
             )
             raise _e

+        assert self.impls is not None
         if Api.telemetry in self.impls:
             setup_logger(self.impls[Api.telemetry])


@@ -33,7 +33,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
             try:
                 models = await provider.list_models()
             except Exception as e:
-                logger.exception(f"Model refresh failed for provider {provider_id}: {e}")
+                logger.warning(f"Model refresh failed for provider {provider_id}: {e}")
                 continue

             self.listed_providers.add(provider_id)


@@ -6,6 +6,7 @@
 import argparse
 import asyncio
+import concurrent.futures
 import functools
 import inspect
 import json
@@ -50,17 +51,15 @@ from llama_stack.core.request_headers import (
     request_provider_data_context,
     user_from_scope,
 )
-from llama_stack.core.resolver import InvalidProviderError
 from llama_stack.core.server.routes import (
     find_matching_route,
     get_all_api_routes,
     initialize_route_impls,
 )
 from llama_stack.core.stack import (
+    Stack,
     cast_image_name_to_string,
-    construct_stack,
     replace_env_vars,
-    shutdown_stack,
     validate_env_pair,
 )
 from llama_stack.core.utils.config import redact_sensitive_fields
@@ -156,21 +155,34 @@ def translate_exception(exc: Exception) -> HTTPException | RequestValidationErro
     )


-async def shutdown(app):
-    """Initiate a graceful shutdown of the application.
-
-    Handled by the lifespan context manager. The shutdown process involves
-    shutting down all implementations registered in the application.
-    """
-    await shutdown_stack(app.__llama_stack_impls__)
+class StackApp(FastAPI):
+    """
+    A wrapper around the FastAPI application to hold a reference to the Stack instance so that we can
+    start background tasks (e.g. refresh model registry periodically) from the lifespan context manager.
+    """
+
+    def __init__(self, config: StackRunConfig, *args, **kwargs):
+        super().__init__(*args, **kwargs)
+        self.stack: Stack = Stack(config)
+
+        # This code is called from a running event loop managed by uvicorn so we cannot simply call
+        # asyncio.run() to initialize the stack. We cannot await either since this is not an async
+        # function.
+        # As a workaround, we use a thread pool executor to run the initialize() method
+        # in a separate thread.
+        with concurrent.futures.ThreadPoolExecutor() as executor:
+            future = executor.submit(asyncio.run, self.stack.initialize())
+            future.result()


 @asynccontextmanager
-async def lifespan(app: FastAPI):
+async def lifespan(app: StackApp):
     logger.info("Starting up")
+    assert app.stack is not None
+    app.stack.create_registry_refresh_task()
     yield
     logger.info("Shutting down")
-    await shutdown(app)
+    await app.stack.shutdown()


 def is_streaming_request(func_name: str, request: Request, **kwargs):
@@ -386,73 +398,61 @@ class ClientVersionMiddleware:
         return await self.app(scope, receive, send)


-def main(args: argparse.Namespace | None = None):
-    """Start the LlamaStack server."""
-    parser = argparse.ArgumentParser(description="Start the LlamaStack server.")
-
-    add_config_distro_args(parser)
-    parser.add_argument(
-        "--port",
-        type=int,
-        default=int(os.getenv("LLAMA_STACK_PORT", 8321)),
-        help="Port to listen on",
-    )
-    parser.add_argument(
-        "--env",
-        action="append",
-        help="Environment variables in KEY=value format. Can be specified multiple times.",
-    )
-
-    # Determine whether the server args are being passed by the "run" command, if this is the case
-    # the args will be passed as a Namespace object to the main function, otherwise they will be
-    # parsed from the command line
-    if args is None:
-        args = parser.parse_args()
-
-    config_or_distro = get_config_from_args(args)
-    config_file = resolve_config_or_distro(config_or_distro, Mode.RUN)
+def create_app(
+    config_file: str | None = None,
+    env_vars: list[str] | None = None,
+) -> StackApp:
+    """Create and configure the FastAPI application.
+
+    Args:
+        config_file: Path to config file. If None, uses LLAMA_STACK_CONFIG env var or default resolution.
+        env_vars: List of environment variables in KEY=value format.
+        disable_version_check: Whether to disable version checking. If None, uses LLAMA_STACK_DISABLE_VERSION_CHECK env var.
+
+    Returns:
+        Configured StackApp instance.
+    """
+    config_file = config_file or os.getenv("LLAMA_STACK_CONFIG")
+    if config_file is None:
+        raise ValueError("No config file provided and LLAMA_STACK_CONFIG env var is not set")
+
+    config_file = resolve_config_or_distro(config_file, Mode.RUN)

+    # Load and process configuration
     logger_config = None
     with open(config_file) as fp:
         config_contents = yaml.safe_load(fp)
         if isinstance(config_contents, dict) and (cfg := config_contents.get("logging_config")):
             logger_config = LoggingConfig(**cfg)
         logger = get_logger(name=__name__, category="core::server", config=logger_config)
-        if args.env:
-            for env_pair in args.env:
+
+        if env_vars:
+            for env_pair in env_vars:
                 try:
                     key, value = validate_env_pair(env_pair)
-                    logger.info(f"Setting CLI environment variable {key} => {value}")
+                    logger.info(f"Setting environment variable {key} => {value}")
                     os.environ[key] = value
                 except ValueError as e:
                     logger.error(f"Error: {str(e)}")
-                    sys.exit(1)
+                    raise ValueError(f"Invalid environment variable format: {env_pair}") from e
+
         config = replace_env_vars(config_contents)
         config = StackRunConfig(**cast_image_name_to_string(config))

     _log_run_config(run_config=config)

-    app = FastAPI(
+    app = StackApp(
         lifespan=lifespan,
         docs_url="/docs",
         redoc_url="/redoc",
         openapi_url="/openapi.json",
+        config=config,
     )

     if not os.environ.get("LLAMA_STACK_DISABLE_VERSION_CHECK"):
         app.add_middleware(ClientVersionMiddleware)

-    try:
-        # Create and set the event loop that will be used for both construction and server runtime
-        loop = asyncio.new_event_loop()
-        asyncio.set_event_loop(loop)
-
-        # Construct the stack in the persistent event loop
-        impls = loop.run_until_complete(construct_stack(config))
-
-    except InvalidProviderError as e:
-        logger.error(f"Error: {str(e)}")
-        sys.exit(1)
+    impls = app.stack.impls

     if config.server.auth:
         logger.info(f"Enabling authentication with provider: {config.server.auth.provider_config.type.value}")
@@ -553,9 +553,54 @@ def main(args: argparse.Namespace | None = None):
     app.exception_handler(RequestValidationError)(global_exception_handler)
     app.exception_handler(Exception)(global_exception_handler)

-    app.__llama_stack_impls__ = impls
     app.add_middleware(TracingMiddleware, impls=impls, external_apis=external_apis)

+    return app
+
+
+def main(args: argparse.Namespace | None = None):
+    """Start the LlamaStack server."""
+    parser = argparse.ArgumentParser(description="Start the LlamaStack server.")
+
+    add_config_distro_args(parser)
+    parser.add_argument(
+        "--port",
+        type=int,
+        default=int(os.getenv("LLAMA_STACK_PORT", 8321)),
+        help="Port to listen on",
+    )
+    parser.add_argument(
+        "--env",
+        action="append",
+        help="Environment variables in KEY=value format. Can be specified multiple times.",
+    )
+
+    # Determine whether the server args are being passed by the "run" command, if this is the case
+    # the args will be passed as a Namespace object to the main function, otherwise they will be
+    # parsed from the command line
+    if args is None:
+        args = parser.parse_args()
+
+    config_or_distro = get_config_from_args(args)
+
+    try:
+        app = create_app(
+            config_file=config_or_distro,
+            env_vars=args.env,
+        )
+    except Exception as e:
+        logger.error(f"Error creating app: {str(e)}")
+        sys.exit(1)
+
+    config_file = resolve_config_or_distro(config_or_distro, Mode.RUN)
+    with open(config_file) as fp:
+        config_contents = yaml.safe_load(fp)
+        if isinstance(config_contents, dict) and (cfg := config_contents.get("logging_config")):
+            logger_config = LoggingConfig(**cfg)
+        else:
+            logger_config = None
+        config = StackRunConfig(**cast_image_name_to_string(replace_env_vars(config_contents)))
+
     import uvicorn

     # Configure SSL if certificates are provided
@@ -593,7 +638,6 @@ def main(args: argparse.Namespace | None = None):
     if ssl_config:
         uvicorn_config.update(ssl_config)

-    # Run uvicorn in the existing event loop to preserve background tasks
     # We need to catch KeyboardInterrupt because uvicorn's signal handling
     # re-raises SIGINT signals using signal.raise_signal(), which Python
     # converts to KeyboardInterrupt. Without this catch, we'd get a confusing
@@ -604,13 +648,9 @@ def main(args: argparse.Namespace | None = None):
     # Another approach would be to ignore SIGINT entirely - let uvicorn handle it through its own
     # signal handling but this is quite intrusive and not worth the effort.
     try:
-        loop.run_until_complete(uvicorn.Server(uvicorn.Config(**uvicorn_config)).serve())
+        asyncio.run(uvicorn.Server(uvicorn.Config(**uvicorn_config)).serve())
     except (KeyboardInterrupt, SystemExit):
         logger.info("Received interrupt signal, shutting down gracefully...")
-    finally:
-        if not loop.is_closed():
-            logger.debug("Closing event loop")
-            loop.close()


 def _log_run_config(run_config: StackRunConfig):
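A usage sketch of the new factory entry point (the config path and the TestClient harness are illustrative assumptions; only `create_app`'s signature comes from the diff):

```python
# Sketch: creating the app programmatically, e.g. in a test harness.
from fastapi.testclient import TestClient

from llama_stack.core.server.server import create_app

app = create_app(
    config_file="./run.yaml",  # placeholder path to a valid run config
    env_vars=["INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct"],
)

with TestClient(app) as client:  # entering the context runs the lifespan startup
    print(client.get("/openapi.json").status_code)
```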


@@ -315,11 +315,15 @@ def add_internal_implementations(impls: dict[Api, Any], run_config: StackRunConf
     impls[Api.prompts] = prompts_impl


-# Produces a stack of providers for the given run config. Not all APIs may be
-# asked for in the run config.
-async def construct_stack(
-    run_config: StackRunConfig, provider_registry: ProviderRegistry | None = None
-) -> dict[Api, Any]:
+class Stack:
+    def __init__(self, run_config: StackRunConfig, provider_registry: ProviderRegistry | None = None):
+        self.run_config = run_config
+        self.provider_registry = provider_registry
+        self.impls = None
+
+    # Produces a stack of providers for the given run config. Not all APIs may be
+    # asked for in the run config.
+    async def initialize(self):
         if "LLAMA_STACK_TEST_INFERENCE_MODE" in os.environ:
             from llama_stack.testing.inference_recorder import setup_inference_recording
@@ -329,24 +333,28 @@ async def construct_stack(
             TEST_RECORDING_CONTEXT.__enter__()
             logger.info(f"Inference recording enabled: mode={os.environ.get('LLAMA_STACK_TEST_INFERENCE_MODE')}")

-        dist_registry, _ = await create_dist_registry(run_config.metadata_store, run_config.image_name)
-        policy = run_config.server.auth.access_policy if run_config.server.auth else []
+        dist_registry, _ = await create_dist_registry(self.run_config.metadata_store, self.run_config.image_name)
+        policy = self.run_config.server.auth.access_policy if self.run_config.server.auth else []
         impls = await resolve_impls(
-            run_config, provider_registry or get_provider_registry(run_config), dist_registry, policy
+            self.run_config, self.provider_registry or get_provider_registry(self.run_config), dist_registry, policy
         )

         # Add internal implementations after all other providers are resolved
-        add_internal_implementations(impls, run_config)
+        add_internal_implementations(impls, self.run_config)

         if Api.prompts in impls:
             await impls[Api.prompts].initialize()

-        await register_resources(run_config, impls)
+        await register_resources(self.run_config, impls)

         await refresh_registry_once(impls)
+        self.impls = impls
+
+    def create_registry_refresh_task(self):
+        assert self.impls is not None, "Must call initialize() before starting"

         global REGISTRY_REFRESH_TASK
-        REGISTRY_REFRESH_TASK = asyncio.create_task(refresh_registry_task(impls))
+        REGISTRY_REFRESH_TASK = asyncio.create_task(refresh_registry_task(self.impls))

         def cb(task):
             import traceback
@@ -360,11 +368,9 @@ async def construct_stack(
                 logger.debug("Model refresh task completed")

         REGISTRY_REFRESH_TASK.add_done_callback(cb)
-    return impls

-
-async def shutdown_stack(impls: dict[Api, Any]):
-    for impl in impls.values():
+    async def shutdown(self):
+        for impl in self.impls.values():
             impl_name = impl.__class__.__name__
             logger.info(f"Shutting down {impl_name}")
             try:
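Taken together, the new lifecycle reads roughly like the sketch below. The config loading is abbreviated and the import path for `StackRunConfig` is assumed; only the `Stack` methods shown in the diff are used.

```python
# Sketch of the new Stack lifecycle replacing construct_stack()/shutdown_stack().
import asyncio

import yaml

from llama_stack.core.datatypes import StackRunConfig  # assumed import location
from llama_stack.core.stack import Stack


async def run_once(config_path: str) -> None:
    # Abbreviated config load; the server additionally runs replace_env_vars()
    # and cast_image_name_to_string() before constructing StackRunConfig.
    with open(config_path) as fp:
        run_config = StackRunConfig(**yaml.safe_load(fp))

    stack = Stack(run_config)
    await stack.initialize()               # resolves providers and fills stack.impls
    stack.create_registry_refresh_task()   # optional: periodic model registry refresh
    try:
        ...  # use stack.impls, e.g. the inference implementation
    finally:
        await stack.shutdown()             # shuts down every resolved implementation


asyncio.run(run_once("./run.yaml"))
```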


@@ -23,6 +23,8 @@ distribution_spec:
     - provider_type: inline::basic
     tool_runtime:
     - provider_type: inline::rag-runtime
+    files:
+    - provider_type: inline::localfs
   image_type: venv
   additional_pip_packages:
   - aiosqlite


@@ -8,6 +8,7 @@ from pathlib import Path

 from llama_stack.core.datatypes import BuildProvider, ModelInput, Provider, ShieldInput, ToolGroupInput
 from llama_stack.distributions.template import DistributionTemplate, RunConfigSettings, get_model_registry
+from llama_stack.providers.inline.files.localfs.config import LocalfsFilesImplConfig
 from llama_stack.providers.remote.datasetio.nvidia import NvidiaDatasetIOConfig
 from llama_stack.providers.remote.eval.nvidia import NVIDIAEvalConfig
 from llama_stack.providers.remote.inference.nvidia import NVIDIAConfig
@@ -15,7 +16,7 @@ from llama_stack.providers.remote.inference.nvidia.models import MODEL_ENTRIES
 from llama_stack.providers.remote.safety.nvidia import NVIDIASafetyConfig


-def get_distribution_template() -> DistributionTemplate:
+def get_distribution_template(name: str = "nvidia") -> DistributionTemplate:
     providers = {
         "inference": [BuildProvider(provider_type="remote::nvidia")],
         "vector_io": [BuildProvider(provider_type="inline::faiss")],
@@ -30,6 +31,7 @@ def get_distribution_template() -> DistributionTemplate:
         ],
         "scoring": [BuildProvider(provider_type="inline::basic")],
         "tool_runtime": [BuildProvider(provider_type="inline::rag-runtime")],
+        "files": [BuildProvider(provider_type="inline::localfs")],
     }

     inference_provider = Provider(
@@ -52,6 +54,11 @@ def get_distribution_template() -> DistributionTemplate:
         provider_type="remote::nvidia",
         config=NVIDIAEvalConfig.sample_run_config(),
     )
+    files_provider = Provider(
+        provider_id="meta-reference-files",
+        provider_type="inline::localfs",
+        config=LocalfsFilesImplConfig.sample_run_config(f"~/.llama/distributions/{name}"),
+    )
     inference_model = ModelInput(
         model_id="${env.INFERENCE_MODEL}",
         provider_id="nvidia",
@@ -73,7 +80,7 @@ def get_distribution_template() -> DistributionTemplate:
     default_models, _ = get_model_registry(available_models)
     return DistributionTemplate(
-        name="nvidia",
+        name=name,
         distro_type="self_hosted",
         description="Use NVIDIA NIM for running LLM inference, evaluation and safety",
         container_image=None,
@@ -86,6 +93,7 @@ def get_distribution_template() -> DistributionTemplate:
                 "inference": [inference_provider],
                 "datasetio": [datasetio_provider],
                 "eval": [eval_provider],
+                "files": [files_provider],
             },
             default_models=default_models,
             default_tool_groups=default_tool_groups,
@@ -97,6 +105,7 @@ def get_distribution_template() -> DistributionTemplate:
                     safety_provider,
                 ],
                 "eval": [eval_provider],
+                "files": [files_provider],
             },
             default_models=[inference_model, safety_model],
             default_shields=[ShieldInput(shield_id="${env.SAFETY_MODEL}", provider_id="nvidia")],


@@ -4,6 +4,7 @@ apis:
 - agents
 - datasetio
 - eval
+- files
 - inference
 - post_training
 - safety
@@ -88,6 +89,14 @@ providers:
   tool_runtime:
   - provider_id: rag-runtime
     provider_type: inline::rag-runtime
+  files:
+  - provider_id: meta-reference-files
+    provider_type: inline::localfs
+    config:
+      storage_dir: ${env.FILES_STORAGE_DIR:=~/.llama/distributions/nvidia/files}
+      metadata_store:
+        type: sqlite
+        db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/nvidia}/files_metadata.db
 metadata_store:
   type: sqlite
   db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/nvidia}/registry.db


@@ -4,6 +4,7 @@ apis:
 - agents
 - datasetio
 - eval
+- files
 - inference
 - post_training
 - safety
@@ -77,6 +78,14 @@ providers:
   tool_runtime:
   - provider_id: rag-runtime
     provider_type: inline::rag-runtime
+  files:
+  - provider_id: meta-reference-files
+    provider_type: inline::localfs
+    config:
+      storage_dir: ${env.FILES_STORAGE_DIR:=~/.llama/distributions/nvidia/files}
+      metadata_store:
+        type: sqlite
+        db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/nvidia}/files_metadata.db
 metadata_store:
   type: sqlite
   db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/nvidia}/registry.db


@@ -78,12 +78,12 @@ def get_remote_inference_providers() -> list[Provider]:
     remote_providers = [
         provider
         for provider in available_providers()
-        if isinstance(provider, RemoteProviderSpec) and provider.adapter.adapter_type in ENABLED_INFERENCE_PROVIDERS
+        if isinstance(provider, RemoteProviderSpec) and provider.adapter_type in ENABLED_INFERENCE_PROVIDERS
     ]

     inference_providers = []
     for provider_spec in remote_providers:
-        provider_type = provider_spec.adapter.adapter_type
+        provider_type = provider_spec.adapter_type

         if provider_type in INFERENCE_PROVIDER_IDS:
             provider_id = INFERENCE_PROVIDER_IDS[provider_type]


@@ -131,6 +131,15 @@ class ProviderSpec(BaseModel):
 """,
     )

+    pip_packages: list[str] = Field(
+        default_factory=list,
+        description="The pip dependencies needed for this implementation",
+    )
+
+    provider_data_validator: str | None = Field(
+        default=None,
+    )
+
     is_external: bool = Field(default=False, description="Notes whether this provider is an external provider.")

     # used internally by the resolver; this is a hack for now
@@ -145,45 +154,8 @@ class RoutingTable(Protocol):
     async def get_provider_impl(self, routing_key: str) -> Any: ...


-# TODO: this can now be inlined into RemoteProviderSpec
-@json_schema_type
-class AdapterSpec(BaseModel):
-    adapter_type: str = Field(
-        ...,
-        description="Unique identifier for this adapter",
-    )
-    module: str = Field(
-        default_factory=str,
-        description="""
-Fully-qualified name of the module to import. The module is expected to have:
- - `get_adapter_impl(config, deps)`: returns the adapter implementation
-""",
-    )
-    pip_packages: list[str] = Field(
-        default_factory=list,
-        description="The pip dependencies needed for this implementation",
-    )
-    config_class: str = Field(
-        description="Fully-qualified classname of the config for this provider",
-    )
-    provider_data_validator: str | None = Field(
-        default=None,
-    )
-    description: str | None = Field(
-        default=None,
-        description="""
-A description of the provider. This is used to display in the documentation.
-""",
-    )
-
-
 @json_schema_type
 class InlineProviderSpec(ProviderSpec):
-    pip_packages: list[str] = Field(
-        default_factory=list,
-        description="The pip dependencies needed for this implementation",
-    )
     container_image: str | None = Field(
         default=None,
         description="""
@@ -191,10 +163,6 @@ The container image to use for this implementation. If one is provided, pip_pack
 If a provider depends on other providers, the dependencies MUST NOT specify a container image.
 """,
     )
-    # module field is inherited from ProviderSpec
-    provider_data_validator: str | None = Field(
-        default=None,
-    )
     description: str | None = Field(
         default=None,
         description="""
@@ -223,10 +191,15 @@ class RemoteProviderConfig(BaseModel):

 @json_schema_type
 class RemoteProviderSpec(ProviderSpec):
-    adapter: AdapterSpec = Field(
+    adapter_type: str = Field(
+        ...,
+        description="Unique identifier for this adapter",
+    )
+
+    description: str | None = Field(
+        default=None,
         description="""
-If some code is needed to convert the remote responses into Llama Stack compatible
-API responses, specify the adapter here.
+A description of the provider. This is used to display in the documentation.
 """,
     )
@@ -234,33 +207,6 @@ API responses, specify the adapter here.
     def container_image(self) -> str | None:
         return None

-    # module field is inherited from ProviderSpec
-    @property
-    def pip_packages(self) -> list[str]:
-        return self.adapter.pip_packages
-
-    @property
-    def provider_data_validator(self) -> str | None:
-        return self.adapter.provider_data_validator
-
-
-def remote_provider_spec(
-    api: Api,
-    adapter: AdapterSpec,
-    api_dependencies: list[Api] | None = None,
-    optional_api_dependencies: list[Api] | None = None,
-) -> RemoteProviderSpec:
-    return RemoteProviderSpec(
-        api=api,
-        provider_type=f"remote::{adapter.adapter_type}",
-        config_class=adapter.config_class,
-        module=adapter.module,
-        adapter=adapter,
-        api_dependencies=api_dependencies or [],
-        optional_api_dependencies=optional_api_dependencies or [],
-    )
-

 class HealthStatus(StrEnum):
     OK = "OK"
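For orientation, a registry entry under the new shape, using a hypothetical "acme" provider (previously the same fields were wrapped in `AdapterSpec` and passed through `remote_provider_spec()`):

```python
# Sketch: defining a remote provider entry directly on RemoteProviderSpec.
# The "acme" provider, module path, and package are made up for illustration.
from llama_stack.providers.datatypes import Api, RemoteProviderSpec

acme_inference = RemoteProviderSpec(
    api=Api.inference,
    adapter_type="acme",
    provider_type="remote::acme",
    pip_packages=["acme-sdk"],
    module="llama_stack_acme.inference",
    config_class="llama_stack_acme.inference.AcmeImplConfig",
    description="Hypothetical Acme inference provider, for illustration only.",
)
```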


@@ -44,7 +44,7 @@ class LocalfsFilesImpl(Files):
         storage_path.mkdir(parents=True, exist_ok=True)

         # Initialize SQL store for metadata
-        self.sql_store = AuthorizedSqlStore(sqlstore_impl(self.config.metadata_store))
+        self.sql_store = AuthorizedSqlStore(sqlstore_impl(self.config.metadata_store), self.policy)
         await self.sql_store.create_table(
             "openai_files",
             {
@@ -74,7 +74,7 @@ class LocalfsFilesImpl(Files):
         if not self.sql_store:
             raise RuntimeError("Files provider not initialized")

-        row = await self.sql_store.fetch_one("openai_files", policy=self.policy, where={"id": file_id})
+        row = await self.sql_store.fetch_one("openai_files", where={"id": file_id})
         if not row:
             raise ResourceNotFoundError(file_id, "File", "client.files.list()")
@@ -150,7 +150,6 @@ class LocalfsFilesImpl(Files):
         paginated_result = await self.sql_store.fetch_all(
             table="openai_files",
-            policy=self.policy,
             where=where_conditions if where_conditions else None,
             order_by=[("created_at", order.value)],
             cursor=("id", after) if after else None,
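The pattern behind these hunks, sketched with assumed import paths: the access policy is now bound to `AuthorizedSqlStore` at construction time, so the per-call `policy=` argument disappears.

```python
# Sketch of the new AuthorizedSqlStore usage (import locations are assumptions).
from llama_stack.providers.utils.sqlstore.authorized_sqlstore import AuthorizedSqlStore
from llama_stack.providers.utils.sqlstore.sqlstore import sqlstore_impl


async def lookup(config, policy, file_id: str):
    # Policy is supplied once, when the store is built.
    sql_store = AuthorizedSqlStore(sqlstore_impl(config.metadata_store), policy)
    # No policy kwarg here anymore; the store applies the policy it was built with.
    return await sql_store.fetch_one("openai_files", where={"id": file_id})
```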


@@ -6,11 +6,10 @@

 from llama_stack.providers.datatypes import (
-    AdapterSpec,
     Api,
     InlineProviderSpec,
     ProviderSpec,
-    remote_provider_spec,
+    RemoteProviderSpec,
 )
@@ -25,10 +24,10 @@ def available_providers() -> list[ProviderSpec]:
         api_dependencies=[],
         description="Local filesystem-based dataset I/O provider for reading and writing datasets to local storage.",
     ),
-    remote_provider_spec(
+    RemoteProviderSpec(
         api=Api.datasetio,
-        adapter=AdapterSpec(
         adapter_type="huggingface",
+        provider_type="remote::huggingface",
         pip_packages=[
             "datasets>=4.0.0",
         ],
@@ -36,17 +35,15 @@ def available_providers() -> list[ProviderSpec]:
         config_class="llama_stack.providers.remote.datasetio.huggingface.HuggingfaceDatasetIOConfig",
         description="HuggingFace datasets provider for accessing and managing datasets from the HuggingFace Hub.",
     ),
-    ),
-    remote_provider_spec(
+    RemoteProviderSpec(
         api=Api.datasetio,
-        adapter=AdapterSpec(
         adapter_type="nvidia",
+        provider_type="remote::nvidia",
+        module="llama_stack.providers.remote.datasetio.nvidia",
+        config_class="llama_stack.providers.remote.datasetio.nvidia.NvidiaDatasetIOConfig",
         pip_packages=[
             "datasets>=4.0.0",
         ],
-        module="llama_stack.providers.remote.datasetio.nvidia",
-        config_class="llama_stack.providers.remote.datasetio.nvidia.NvidiaDatasetIOConfig",
         description="NVIDIA's dataset I/O provider for accessing datasets from NVIDIA's data platform.",
     ),
-    ),
 ]


@@ -5,7 +5,7 @@
 # the root directory of this source tree.

-from llama_stack.providers.datatypes import AdapterSpec, Api, InlineProviderSpec, ProviderSpec, remote_provider_spec
+from llama_stack.providers.datatypes import Api, InlineProviderSpec, ProviderSpec, RemoteProviderSpec


 def available_providers() -> list[ProviderSpec]:
@@ -25,17 +25,16 @@ def available_providers() -> list[ProviderSpec]:
         ],
         description="Meta's reference implementation of evaluation tasks with support for multiple languages and evaluation metrics.",
     ),
-    remote_provider_spec(
+    RemoteProviderSpec(
         api=Api.eval,
-        adapter=AdapterSpec(
         adapter_type="nvidia",
         pip_packages=[
             "requests",
         ],
+        provider_type="remote::nvidia",
         module="llama_stack.providers.remote.eval.nvidia",
         config_class="llama_stack.providers.remote.eval.nvidia.NVIDIAEvalConfig",
         description="NVIDIA's evaluation provider for running evaluation tasks on NVIDIA's platform.",
-        ),
         api_dependencies=[
             Api.datasetio,
             Api.datasets,


@@ -4,13 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from llama_stack.providers.datatypes import (
-    AdapterSpec,
-    Api,
-    InlineProviderSpec,
-    ProviderSpec,
-    remote_provider_spec,
-)
+from llama_stack.providers.datatypes import Api, InlineProviderSpec, ProviderSpec, RemoteProviderSpec
 from llama_stack.providers.utils.sqlstore.sqlstore import sql_store_pip_packages
@@ -25,14 +19,13 @@ def available_providers() -> list[ProviderSpec]:
         config_class="llama_stack.providers.inline.files.localfs.config.LocalfsFilesImplConfig",
         description="Local filesystem-based file storage provider for managing files and documents locally.",
     ),
-    remote_provider_spec(
+    RemoteProviderSpec(
         api=Api.files,
-        adapter=AdapterSpec(
+        provider_type="remote::s3",
         adapter_type="s3",
         pip_packages=["boto3"] + sql_store_pip_packages,
         module="llama_stack.providers.remote.files.s3",
         config_class="llama_stack.providers.remote.files.s3.config.S3FilesImplConfig",
         description="AWS S3-based file storage provider for scalable cloud file management with metadata persistence.",
     ),
-    ),
 ]


@@ -6,11 +6,10 @@

 from llama_stack.providers.datatypes import (
-    AdapterSpec,
     Api,
     InlineProviderSpec,
     ProviderSpec,
-    remote_provider_spec,
+    RemoteProviderSpec,
 )


 META_REFERENCE_DEPS = [
@@ -49,10 +48,10 @@ def available_providers() -> list[ProviderSpec]:
         config_class="llama_stack.providers.inline.inference.sentence_transformers.config.SentenceTransformersInferenceConfig",
         description="Sentence Transformers inference provider for text embeddings and similarity search.",
     ),
-    remote_provider_spec(
+    RemoteProviderSpec(
         api=Api.inference,
-        adapter=AdapterSpec(
         adapter_type="cerebras",
+        provider_type="remote::cerebras",
         pip_packages=[
             "cerebras_cloud_sdk",
         ],
@@ -60,62 +59,56 @@ def available_providers() -> list[ProviderSpec]:
         config_class="llama_stack.providers.remote.inference.cerebras.CerebrasImplConfig",
         description="Cerebras inference provider for running models on Cerebras Cloud platform.",
     ),
-    ),
-    remote_provider_spec(
+    RemoteProviderSpec(
         api=Api.inference,
-        adapter=AdapterSpec(
         adapter_type="ollama",
+        provider_type="remote::ollama",
         pip_packages=["ollama", "aiohttp", "h11>=0.16.0"],
         config_class="llama_stack.providers.remote.inference.ollama.OllamaImplConfig",
         module="llama_stack.providers.remote.inference.ollama",
         description="Ollama inference provider for running local models through the Ollama runtime.",
     ),
-    ),
-    remote_provider_spec(
+    RemoteProviderSpec(
         api=Api.inference,
-        adapter=AdapterSpec(
         adapter_type="vllm",
+        provider_type="remote::vllm",
         pip_packages=[],
         module="llama_stack.providers.remote.inference.vllm",
         config_class="llama_stack.providers.remote.inference.vllm.VLLMInferenceAdapterConfig",
         provider_data_validator="llama_stack.providers.remote.inference.vllm.VLLMProviderDataValidator",
         description="Remote vLLM inference provider for connecting to vLLM servers.",
     ),
-    ),
-    remote_provider_spec(
+    RemoteProviderSpec(
         api=Api.inference,
-        adapter=AdapterSpec(
         adapter_type="tgi",
+        provider_type="remote::tgi",
         pip_packages=["huggingface_hub", "aiohttp"],
         module="llama_stack.providers.remote.inference.tgi",
         config_class="llama_stack.providers.remote.inference.tgi.TGIImplConfig",
         description="Text Generation Inference (TGI) provider for HuggingFace model serving.",
     ),
-    ),
-    remote_provider_spec(
+    RemoteProviderSpec(
         api=Api.inference,
-        adapter=AdapterSpec(
         adapter_type="hf::serverless",
+        provider_type="remote::hf::serverless",
         pip_packages=["huggingface_hub", "aiohttp"],
         module="llama_stack.providers.remote.inference.tgi",
         config_class="llama_stack.providers.remote.inference.tgi.InferenceAPIImplConfig",
         description="HuggingFace Inference API serverless provider for on-demand model inference.",
     ),
-    ),
-    remote_provider_spec(
+    RemoteProviderSpec(
         api=Api.inference,
-        adapter=AdapterSpec(
+        provider_type="remote::hf::endpoint",
         adapter_type="hf::endpoint",
         pip_packages=["huggingface_hub", "aiohttp"],
         module="llama_stack.providers.remote.inference.tgi",
         config_class="llama_stack.providers.remote.inference.tgi.InferenceEndpointImplConfig",
         description="HuggingFace Inference Endpoints provider for dedicated model serving.",
     ),
-    ),
-    remote_provider_spec(
+    RemoteProviderSpec(
         api=Api.inference,
-        adapter=AdapterSpec(
         adapter_type="fireworks",
+        provider_type="remote::fireworks",
         pip_packages=[
             "fireworks-ai<=0.17.16",
         ],
@@ -124,11 +117,10 @@ def available_providers() -> list[ProviderSpec]:
         provider_data_validator="llama_stack.providers.remote.inference.fireworks.FireworksProviderDataValidator",
         description="Fireworks AI inference provider for Llama models and other AI models on the Fireworks platform.",
     ),
-    ),
-    remote_provider_spec(
+    RemoteProviderSpec(
         api=Api.inference,
-        adapter=AdapterSpec(
         adapter_type="together",
+        provider_type="remote::together",
         pip_packages=[
             "together",
         ],
@@ -137,85 +129,82 @@ def available_providers() -> list[ProviderSpec]:
         provider_data_validator="llama_stack.providers.remote.inference.together.TogetherProviderDataValidator",
         description="Together AI inference provider for open-source models and collaborative AI development.",
     ),
-    ),
-    remote_provider_spec(
+    RemoteProviderSpec(
         api=Api.inference,
-        adapter=AdapterSpec(
         adapter_type="bedrock",
+        provider_type="remote::bedrock",
         pip_packages=["boto3"],
         module="llama_stack.providers.remote.inference.bedrock",
         config_class="llama_stack.providers.remote.inference.bedrock.BedrockConfig",
         description="AWS Bedrock inference provider for accessing various AI models through AWS's managed service.",
     ),
-    ),
-    remote_provider_spec(
+    RemoteProviderSpec(
         api=Api.inference,
-        adapter=AdapterSpec(
         adapter_type="databricks",
+        provider_type="remote::databricks",
         pip_packages=[],
         module="llama_stack.providers.remote.inference.databricks",
         config_class="llama_stack.providers.remote.inference.databricks.DatabricksImplConfig",
         description="Databricks inference provider for running models on Databricks' unified analytics platform.",
     ),
-    ),
-    remote_provider_spec(
+    RemoteProviderSpec(
         api=Api.inference,
-        adapter=AdapterSpec(
         adapter_type="nvidia",
+        provider_type="remote::nvidia",
         pip_packages=[],
         module="llama_stack.providers.remote.inference.nvidia",
         config_class="llama_stack.providers.remote.inference.nvidia.NVIDIAConfig",
         description="NVIDIA inference provider for accessing NVIDIA NIM models and AI services.",
     ),
-    ),
-    remote_provider_spec(
+    RemoteProviderSpec(
         api=Api.inference,
-        adapter=AdapterSpec(
         adapter_type="runpod",
+        provider_type="remote::runpod",
         pip_packages=[],
         module="llama_stack.providers.remote.inference.runpod",
         config_class="llama_stack.providers.remote.inference.runpod.RunpodImplConfig",
         description="RunPod inference provider for running models on RunPod's cloud GPU platform.",
     ),
-    ),
-    remote_provider_spec(
+    RemoteProviderSpec(
         api=Api.inference,
-        adapter=AdapterSpec(
         adapter_type="openai",
+        provider_type="remote::openai",
         pip_packages=["litellm"],
         module="llama_stack.providers.remote.inference.openai",
         config_class="llama_stack.providers.remote.inference.openai.OpenAIConfig",
         provider_data_validator="llama_stack.providers.remote.inference.openai.config.OpenAIProviderDataValidator",
         description="OpenAI inference provider for accessing GPT models and other OpenAI services.",
     ),
-    ),
-    remote_provider_spec(
+    RemoteProviderSpec(
         api=Api.inference,
-        adapter=AdapterSpec(
         adapter_type="anthropic",
+        provider_type="remote::anthropic",
         pip_packages=["litellm"],
         module="llama_stack.providers.remote.inference.anthropic",
         config_class="llama_stack.providers.remote.inference.anthropic.AnthropicConfig",
         provider_data_validator="llama_stack.providers.remote.inference.anthropic.config.AnthropicProviderDataValidator",
         description="Anthropic inference provider for accessing Claude models and Anthropic's AI services.",
     ),
-    ),
-    remote_provider_spec(
+    RemoteProviderSpec(
         api=Api.inference,
-        adapter=AdapterSpec(
         adapter_type="gemini",
-        pip_packages=["litellm"],
+        provider_type="remote::gemini",
+        pip_packages=[
+            "litellm",
+        ],
         module="llama_stack.providers.remote.inference.gemini",
         config_class="llama_stack.providers.remote.inference.gemini.GeminiConfig",
         provider_data_validator="llama_stack.providers.remote.inference.gemini.config.GeminiProviderDataValidator",
         description="Google Gemini inference provider for accessing Gemini models and Google's AI services.",
     ),
-    ),
-    remote_provider_spec(
+    RemoteProviderSpec(
         api=Api.inference,
-        adapter=AdapterSpec(
         adapter_type="vertexai",
-        pip_packages=["litellm", "google-cloud-aiplatform"],
+        provider_type="remote::vertexai",
+        pip_packages=[
+            "litellm",
+            "google-cloud-aiplatform",
+        ],
         module="llama_stack.providers.remote.inference.vertexai",
         config_class="llama_stack.providers.remote.inference.vertexai.VertexAIConfig",
         provider_data_validator="llama_stack.providers.remote.inference.vertexai.config.VertexAIProviderDataValidator",
@@ -240,65 +229,63 @@ Available Models:
 - vertex_ai/gemini-2.5-flash
 - vertex_ai/gemini-2.5-pro""",
     ),
-    ),
-    remote_provider_spec(
+    RemoteProviderSpec(
         api=Api.inference,
-        adapter=AdapterSpec(
         adapter_type="groq",
-        pip_packages=["litellm"],
+        provider_type="remote::groq",
+        pip_packages=[
+            "litellm",
+        ],
         module="llama_stack.providers.remote.inference.groq",
         config_class="llama_stack.providers.remote.inference.groq.GroqConfig",
         provider_data_validator="llama_stack.providers.remote.inference.groq.config.GroqProviderDataValidator",
         description="Groq inference provider for ultra-fast inference using Groq's LPU technology.",
     ),
-    ),
-    remote_provider_spec(
+    RemoteProviderSpec(
         api=Api.inference,
-        adapter=AdapterSpec(
         adapter_type="llama-openai-compat",
+        provider_type="remote::llama-openai-compat",
         pip_packages=["litellm"],
         module="llama_stack.providers.remote.inference.llama_openai_compat",
         config_class="llama_stack.providers.remote.inference.llama_openai_compat.config.LlamaCompatConfig",
         provider_data_validator="llama_stack.providers.remote.inference.llama_openai_compat.config.LlamaProviderDataValidator",
         description="Llama OpenAI-compatible provider for using Llama models with OpenAI API format.",
     ),
-    ),
-    remote_provider_spec(
+    RemoteProviderSpec(
         api=Api.inference,
-        adapter=AdapterSpec(
         adapter_type="sambanova",
-        pip_packages=["litellm"],
+        provider_type="remote::sambanova",
+        pip_packages=[
+            "litellm",
+        ],
         module="llama_stack.providers.remote.inference.sambanova",
         config_class="llama_stack.providers.remote.inference.sambanova.SambaNovaImplConfig",
         provider_data_validator="llama_stack.providers.remote.inference.sambanova.config.SambaNovaProviderDataValidator",
         description="SambaNova inference provider for running models on SambaNova's dataflow architecture.",
     ),
-    ),
-    remote_provider_spec(
+    RemoteProviderSpec(
         api=Api.inference,
-        adapter=AdapterSpec(
         adapter_type="passthrough",
+        provider_type="remote::passthrough",
         pip_packages=[],
         module="llama_stack.providers.remote.inference.passthrough",
         config_class="llama_stack.providers.remote.inference.passthrough.PassthroughImplConfig",
         provider_data_validator="llama_stack.providers.remote.inference.passthrough.PassthroughProviderDataValidator",
         description="Passthrough inference provider for connecting to any external inference service not directly supported.",
     ),
-    ),
-    remote_provider_spec(
+    RemoteProviderSpec(
         api=Api.inference,
-        adapter=AdapterSpec(
         adapter_type="watsonx",
+        provider_type="remote::watsonx",
         pip_packages=["ibm_watsonx_ai"],
         module="llama_stack.providers.remote.inference.watsonx",
         config_class="llama_stack.providers.remote.inference.watsonx.WatsonXConfig",
         provider_data_validator="llama_stack.providers.remote.inference.watsonx.WatsonXProviderDataValidator",
         description="IBM WatsonX inference provider for accessing AI models on IBM's WatsonX platform.",
     ),
-    ),
-    remote_provider_spec(
+    RemoteProviderSpec(
         api=Api.inference,
-        adapter=AdapterSpec(
+        provider_type="remote::azure",
         adapter_type="azure",
         pip_packages=["litellm"],
         module="llama_stack.providers.remote.inference.azure",
@ -310,5 +297,4 @@ Provider documentation
https://learn.microsoft.com/en-us/azure/ai-foundry/openai/overview https://learn.microsoft.com/en-us/azure/ai-foundry/openai/overview
""", """,
), ),
),
] ]
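The hunks above replace the old `remote_provider_spec(..., adapter=AdapterSpec(...))` wrapper with direct `RemoteProviderSpec` construction, flattening the adapter fields and adding an explicit `provider_type`. A minimal sketch of the new shape, reusing the Gemini values from this diff; the `available_providers()` wrapper mirrors the registry modules shown here and is illustrative, not the full file:

```python
from llama_stack.providers.datatypes import Api, ProviderSpec, RemoteProviderSpec


def available_providers() -> list[ProviderSpec]:
    # Flattened style: adapter_type, provider_type, and pip_packages sit directly on the spec.
    return [
        RemoteProviderSpec(
            api=Api.inference,
            adapter_type="gemini",
            provider_type="remote::gemini",
            pip_packages=["litellm"],
            module="llama_stack.providers.remote.inference.gemini",
            config_class="llama_stack.providers.remote.inference.gemini.GeminiConfig",
            provider_data_validator="llama_stack.providers.remote.inference.gemini.config.GeminiProviderDataValidator",
            description="Google Gemini inference provider for accessing Gemini models and Google's AI services.",
        ),
    ]
```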

View file

@ -7,7 +7,7 @@
from typing import cast from typing import cast
from llama_stack.providers.datatypes import AdapterSpec, Api, InlineProviderSpec, ProviderSpec, remote_provider_spec from llama_stack.providers.datatypes import Api, InlineProviderSpec, ProviderSpec, RemoteProviderSpec
# We provide two versions of these providers so that distributions can package the appropriate version of torch. # We provide two versions of these providers so that distributions can package the appropriate version of torch.
# The CPU version is used for distributions that don't have GPU support -- they result in smaller container images. # The CPU version is used for distributions that don't have GPU support -- they result in smaller container images.
@ -57,14 +57,13 @@ def available_providers() -> list[ProviderSpec]:
], ],
description="HuggingFace-based post-training provider for fine-tuning models using the HuggingFace ecosystem.", description="HuggingFace-based post-training provider for fine-tuning models using the HuggingFace ecosystem.",
), ),
remote_provider_spec( RemoteProviderSpec(
api=Api.post_training, api=Api.post_training,
adapter=AdapterSpec(
adapter_type="nvidia", adapter_type="nvidia",
provider_type="remote::nvidia",
pip_packages=["requests", "aiohttp"], pip_packages=["requests", "aiohttp"],
module="llama_stack.providers.remote.post_training.nvidia", module="llama_stack.providers.remote.post_training.nvidia",
config_class="llama_stack.providers.remote.post_training.nvidia.NvidiaPostTrainingConfig", config_class="llama_stack.providers.remote.post_training.nvidia.NvidiaPostTrainingConfig",
description="NVIDIA's post-training provider for fine-tuning models on NVIDIA's platform.", description="NVIDIA's post-training provider for fine-tuning models on NVIDIA's platform.",
), ),
),
] ]

View file

@ -6,11 +6,10 @@
from llama_stack.providers.datatypes import ( from llama_stack.providers.datatypes import (
AdapterSpec,
Api, Api,
InlineProviderSpec, InlineProviderSpec,
ProviderSpec, ProviderSpec,
remote_provider_spec, RemoteProviderSpec,
) )
@ -48,35 +47,32 @@ def available_providers() -> list[ProviderSpec]:
config_class="llama_stack.providers.inline.safety.code_scanner.CodeScannerConfig", config_class="llama_stack.providers.inline.safety.code_scanner.CodeScannerConfig",
description="Code Scanner safety provider for detecting security vulnerabilities and unsafe code patterns.", description="Code Scanner safety provider for detecting security vulnerabilities and unsafe code patterns.",
), ),
remote_provider_spec( RemoteProviderSpec(
api=Api.safety, api=Api.safety,
adapter=AdapterSpec(
adapter_type="bedrock", adapter_type="bedrock",
provider_type="remote::bedrock",
pip_packages=["boto3"], pip_packages=["boto3"],
module="llama_stack.providers.remote.safety.bedrock", module="llama_stack.providers.remote.safety.bedrock",
config_class="llama_stack.providers.remote.safety.bedrock.BedrockSafetyConfig", config_class="llama_stack.providers.remote.safety.bedrock.BedrockSafetyConfig",
description="AWS Bedrock safety provider for content moderation using AWS's safety services.", description="AWS Bedrock safety provider for content moderation using AWS's safety services.",
), ),
), RemoteProviderSpec(
remote_provider_spec(
api=Api.safety, api=Api.safety,
adapter=AdapterSpec(
adapter_type="nvidia", adapter_type="nvidia",
provider_type="remote::nvidia",
pip_packages=["requests"], pip_packages=["requests"],
module="llama_stack.providers.remote.safety.nvidia", module="llama_stack.providers.remote.safety.nvidia",
config_class="llama_stack.providers.remote.safety.nvidia.NVIDIASafetyConfig", config_class="llama_stack.providers.remote.safety.nvidia.NVIDIASafetyConfig",
description="NVIDIA's safety provider for content moderation and safety filtering.", description="NVIDIA's safety provider for content moderation and safety filtering.",
), ),
), RemoteProviderSpec(
remote_provider_spec(
api=Api.safety, api=Api.safety,
adapter=AdapterSpec(
adapter_type="sambanova", adapter_type="sambanova",
provider_type="remote::sambanova",
pip_packages=["litellm", "requests"], pip_packages=["litellm", "requests"],
module="llama_stack.providers.remote.safety.sambanova", module="llama_stack.providers.remote.safety.sambanova",
config_class="llama_stack.providers.remote.safety.sambanova.SambaNovaSafetyConfig", config_class="llama_stack.providers.remote.safety.sambanova.SambaNovaSafetyConfig",
provider_data_validator="llama_stack.providers.remote.safety.sambanova.config.SambaNovaProviderDataValidator", provider_data_validator="llama_stack.providers.remote.safety.sambanova.config.SambaNovaProviderDataValidator",
description="SambaNova's safety provider for content moderation and safety filtering.", description="SambaNova's safety provider for content moderation and safety filtering.",
), ),
),
] ]

View file

@ -6,11 +6,10 @@
from llama_stack.providers.datatypes import ( from llama_stack.providers.datatypes import (
AdapterSpec,
Api, Api,
InlineProviderSpec, InlineProviderSpec,
ProviderSpec, ProviderSpec,
remote_provider_spec, RemoteProviderSpec,
) )
@ -35,59 +34,54 @@ def available_providers() -> list[ProviderSpec]:
api_dependencies=[Api.vector_io, Api.inference, Api.files], api_dependencies=[Api.vector_io, Api.inference, Api.files],
description="RAG (Retrieval-Augmented Generation) tool runtime for document ingestion, chunking, and semantic search.", description="RAG (Retrieval-Augmented Generation) tool runtime for document ingestion, chunking, and semantic search.",
), ),
remote_provider_spec( RemoteProviderSpec(
api=Api.tool_runtime, api=Api.tool_runtime,
adapter=AdapterSpec(
adapter_type="brave-search", adapter_type="brave-search",
provider_type="remote::brave-search",
module="llama_stack.providers.remote.tool_runtime.brave_search", module="llama_stack.providers.remote.tool_runtime.brave_search",
config_class="llama_stack.providers.remote.tool_runtime.brave_search.config.BraveSearchToolConfig", config_class="llama_stack.providers.remote.tool_runtime.brave_search.config.BraveSearchToolConfig",
pip_packages=["requests"], pip_packages=["requests"],
provider_data_validator="llama_stack.providers.remote.tool_runtime.brave_search.BraveSearchToolProviderDataValidator", provider_data_validator="llama_stack.providers.remote.tool_runtime.brave_search.BraveSearchToolProviderDataValidator",
description="Brave Search tool for web search capabilities with privacy-focused results.", description="Brave Search tool for web search capabilities with privacy-focused results.",
), ),
), RemoteProviderSpec(
remote_provider_spec(
api=Api.tool_runtime, api=Api.tool_runtime,
adapter=AdapterSpec(
adapter_type="bing-search", adapter_type="bing-search",
provider_type="remote::bing-search",
module="llama_stack.providers.remote.tool_runtime.bing_search", module="llama_stack.providers.remote.tool_runtime.bing_search",
config_class="llama_stack.providers.remote.tool_runtime.bing_search.config.BingSearchToolConfig", config_class="llama_stack.providers.remote.tool_runtime.bing_search.config.BingSearchToolConfig",
pip_packages=["requests"], pip_packages=["requests"],
provider_data_validator="llama_stack.providers.remote.tool_runtime.bing_search.BingSearchToolProviderDataValidator", provider_data_validator="llama_stack.providers.remote.tool_runtime.bing_search.BingSearchToolProviderDataValidator",
description="Bing Search tool for web search capabilities using Microsoft's search engine.", description="Bing Search tool for web search capabilities using Microsoft's search engine.",
), ),
), RemoteProviderSpec(
remote_provider_spec(
api=Api.tool_runtime, api=Api.tool_runtime,
adapter=AdapterSpec(
adapter_type="tavily-search", adapter_type="tavily-search",
provider_type="remote::tavily-search",
module="llama_stack.providers.remote.tool_runtime.tavily_search", module="llama_stack.providers.remote.tool_runtime.tavily_search",
config_class="llama_stack.providers.remote.tool_runtime.tavily_search.config.TavilySearchToolConfig", config_class="llama_stack.providers.remote.tool_runtime.tavily_search.config.TavilySearchToolConfig",
pip_packages=["requests"], pip_packages=["requests"],
provider_data_validator="llama_stack.providers.remote.tool_runtime.tavily_search.TavilySearchToolProviderDataValidator", provider_data_validator="llama_stack.providers.remote.tool_runtime.tavily_search.TavilySearchToolProviderDataValidator",
description="Tavily Search tool for AI-optimized web search with structured results.", description="Tavily Search tool for AI-optimized web search with structured results.",
), ),
), RemoteProviderSpec(
remote_provider_spec(
api=Api.tool_runtime, api=Api.tool_runtime,
adapter=AdapterSpec(
adapter_type="wolfram-alpha", adapter_type="wolfram-alpha",
provider_type="remote::wolfram-alpha",
module="llama_stack.providers.remote.tool_runtime.wolfram_alpha", module="llama_stack.providers.remote.tool_runtime.wolfram_alpha",
config_class="llama_stack.providers.remote.tool_runtime.wolfram_alpha.config.WolframAlphaToolConfig", config_class="llama_stack.providers.remote.tool_runtime.wolfram_alpha.config.WolframAlphaToolConfig",
pip_packages=["requests"], pip_packages=["requests"],
provider_data_validator="llama_stack.providers.remote.tool_runtime.wolfram_alpha.WolframAlphaToolProviderDataValidator", provider_data_validator="llama_stack.providers.remote.tool_runtime.wolfram_alpha.WolframAlphaToolProviderDataValidator",
description="Wolfram Alpha tool for computational knowledge and mathematical calculations.", description="Wolfram Alpha tool for computational knowledge and mathematical calculations.",
), ),
), RemoteProviderSpec(
remote_provider_spec(
api=Api.tool_runtime, api=Api.tool_runtime,
adapter=AdapterSpec(
adapter_type="model-context-protocol", adapter_type="model-context-protocol",
provider_type="remote::model-context-protocol",
module="llama_stack.providers.remote.tool_runtime.model_context_protocol", module="llama_stack.providers.remote.tool_runtime.model_context_protocol",
config_class="llama_stack.providers.remote.tool_runtime.model_context_protocol.config.MCPProviderConfig", config_class="llama_stack.providers.remote.tool_runtime.model_context_protocol.config.MCPProviderConfig",
pip_packages=["mcp>=1.8.1"], pip_packages=["mcp>=1.8.1"],
provider_data_validator="llama_stack.providers.remote.tool_runtime.model_context_protocol.config.MCPProviderDataValidator", provider_data_validator="llama_stack.providers.remote.tool_runtime.model_context_protocol.config.MCPProviderDataValidator",
description="Model Context Protocol (MCP) tool for standardized tool calling and context management.", description="Model Context Protocol (MCP) tool for standardized tool calling and context management.",
), ),
),
] ]

View file

@ -6,11 +6,10 @@
from llama_stack.providers.datatypes import ( from llama_stack.providers.datatypes import (
AdapterSpec,
Api, Api,
InlineProviderSpec, InlineProviderSpec,
ProviderSpec, ProviderSpec,
remote_provider_spec, RemoteProviderSpec,
) )
@ -300,13 +299,15 @@ See [sqlite-vec's GitHub repo](https://github.com/asg017/sqlite-vec/tree/main) f
Please refer to the sqlite-vec provider documentation. Please refer to the sqlite-vec provider documentation.
""", """,
), ),
remote_provider_spec( RemoteProviderSpec(
Api.vector_io, api=Api.vector_io,
AdapterSpec(
adapter_type="chromadb", adapter_type="chromadb",
provider_type="remote::chromadb",
pip_packages=["chromadb-client"], pip_packages=["chromadb-client"],
module="llama_stack.providers.remote.vector_io.chroma", module="llama_stack.providers.remote.vector_io.chroma",
config_class="llama_stack.providers.remote.vector_io.chroma.ChromaVectorIOConfig", config_class="llama_stack.providers.remote.vector_io.chroma.ChromaVectorIOConfig",
api_dependencies=[Api.inference],
optional_api_dependencies=[Api.files],
description=""" description="""
[Chroma](https://www.trychroma.com/) is an inline and remote vector [Chroma](https://www.trychroma.com/) is an inline and remote vector
database provider for Llama Stack. It allows you to store and query vectors directly within a Chroma database. database provider for Llama Stack. It allows you to store and query vectors directly within a Chroma database.
@ -341,9 +342,6 @@ pip install chromadb
See [Chroma's documentation](https://docs.trychroma.com/docs/overview/introduction) for more details about Chroma in general. See [Chroma's documentation](https://docs.trychroma.com/docs/overview/introduction) for more details about Chroma in general.
""", """,
), ),
api_dependencies=[Api.inference],
optional_api_dependencies=[Api.files],
),
InlineProviderSpec( InlineProviderSpec(
api=Api.vector_io, api=Api.vector_io,
provider_type="inline::chromadb", provider_type="inline::chromadb",
@ -387,13 +385,15 @@ See [Chroma's documentation](https://docs.trychroma.com/docs/overview/introducti
""", """,
), ),
remote_provider_spec( RemoteProviderSpec(
Api.vector_io, api=Api.vector_io,
AdapterSpec(
adapter_type="pgvector", adapter_type="pgvector",
provider_type="remote::pgvector",
pip_packages=["psycopg2-binary"], pip_packages=["psycopg2-binary"],
module="llama_stack.providers.remote.vector_io.pgvector", module="llama_stack.providers.remote.vector_io.pgvector",
config_class="llama_stack.providers.remote.vector_io.pgvector.PGVectorVectorIOConfig", config_class="llama_stack.providers.remote.vector_io.pgvector.PGVectorVectorIOConfig",
api_dependencies=[Api.inference],
optional_api_dependencies=[Api.files],
description=""" description="""
[PGVector](https://github.com/pgvector/pgvector) is a remote vector database provider for Llama Stack. It [PGVector](https://github.com/pgvector/pgvector) is a remote vector database provider for Llama Stack. It
allows you to store and query vectors directly in memory. allows you to store and query vectors directly in memory.
@ -496,17 +496,16 @@ docker pull pgvector/pgvector:pg17
See [PGVector's documentation](https://github.com/pgvector/pgvector) for more details about PGVector in general. See [PGVector's documentation](https://github.com/pgvector/pgvector) for more details about PGVector in general.
""", """,
), ),
api_dependencies=[Api.inference], RemoteProviderSpec(
optional_api_dependencies=[Api.files], api=Api.vector_io,
),
remote_provider_spec(
Api.vector_io,
AdapterSpec(
adapter_type="weaviate", adapter_type="weaviate",
provider_type="remote::weaviate",
pip_packages=["weaviate-client"], pip_packages=["weaviate-client"],
module="llama_stack.providers.remote.vector_io.weaviate", module="llama_stack.providers.remote.vector_io.weaviate",
config_class="llama_stack.providers.remote.vector_io.weaviate.WeaviateVectorIOConfig", config_class="llama_stack.providers.remote.vector_io.weaviate.WeaviateVectorIOConfig",
provider_data_validator="llama_stack.providers.remote.vector_io.weaviate.WeaviateRequestProviderData", provider_data_validator="llama_stack.providers.remote.vector_io.weaviate.WeaviateRequestProviderData",
api_dependencies=[Api.inference],
optional_api_dependencies=[Api.files],
description=""" description="""
[Weaviate](https://weaviate.io/) is a vector database provider for Llama Stack. [Weaviate](https://weaviate.io/) is a vector database provider for Llama Stack.
It allows you to store and query vectors directly within a Weaviate database. It allows you to store and query vectors directly within a Weaviate database.
@ -539,9 +538,6 @@ To install Weaviate see the [Weaviate quickstart documentation](https://weaviate
See [Weaviate's documentation](https://weaviate.io/developers/weaviate) for more details about Weaviate in general. See [Weaviate's documentation](https://weaviate.io/developers/weaviate) for more details about Weaviate in general.
""", """,
), ),
api_dependencies=[Api.inference],
optional_api_dependencies=[Api.files],
),
InlineProviderSpec( InlineProviderSpec(
api=Api.vector_io, api=Api.vector_io,
provider_type="inline::qdrant", provider_type="inline::qdrant",
@ -594,27 +590,28 @@ docker pull qdrant/qdrant
See the [Qdrant documentation](https://qdrant.tech/documentation/) for more details about Qdrant in general. See the [Qdrant documentation](https://qdrant.tech/documentation/) for more details about Qdrant in general.
""", """,
), ),
remote_provider_spec( RemoteProviderSpec(
Api.vector_io, api=Api.vector_io,
AdapterSpec(
adapter_type="qdrant", adapter_type="qdrant",
provider_type="remote::qdrant",
pip_packages=["qdrant-client"], pip_packages=["qdrant-client"],
module="llama_stack.providers.remote.vector_io.qdrant", module="llama_stack.providers.remote.vector_io.qdrant",
config_class="llama_stack.providers.remote.vector_io.qdrant.QdrantVectorIOConfig", config_class="llama_stack.providers.remote.vector_io.qdrant.QdrantVectorIOConfig",
api_dependencies=[Api.inference],
optional_api_dependencies=[Api.files],
description=""" description="""
Please refer to the inline provider documentation. Please refer to the inline provider documentation.
""", """,
), ),
api_dependencies=[Api.inference], RemoteProviderSpec(
optional_api_dependencies=[Api.files], api=Api.vector_io,
),
remote_provider_spec(
Api.vector_io,
AdapterSpec(
adapter_type="milvus", adapter_type="milvus",
provider_type="remote::milvus",
pip_packages=["pymilvus>=2.4.10"], pip_packages=["pymilvus>=2.4.10"],
module="llama_stack.providers.remote.vector_io.milvus", module="llama_stack.providers.remote.vector_io.milvus",
config_class="llama_stack.providers.remote.vector_io.milvus.MilvusVectorIOConfig", config_class="llama_stack.providers.remote.vector_io.milvus.MilvusVectorIOConfig",
api_dependencies=[Api.inference],
optional_api_dependencies=[Api.files],
description=""" description="""
[Milvus](https://milvus.io/) is an inline and remote vector database provider for Llama Stack. It [Milvus](https://milvus.io/) is an inline and remote vector database provider for Llama Stack. It
allows you to store and query vectors directly within a Milvus database. allows you to store and query vectors directly within a Milvus database.
@ -636,7 +633,13 @@ To use Milvus in your Llama Stack project, follow these steps:
## Installation ## Installation
You can install Milvus using pymilvus: If you want to use inline Milvus, you can install:
```bash
pip install pymilvus[milvus-lite]
```
If you want to use remote Milvus, you can install:
```bash ```bash
pip install pymilvus pip install pymilvus
@ -807,13 +810,10 @@ See the [Milvus documentation](https://milvus.io/docs/install-overview.md) for m
For more details on TLS configuration, refer to the [TLS setup guide](https://milvus.io/docs/tls.md). For more details on TLS configuration, refer to the [TLS setup guide](https://milvus.io/docs/tls.md).
""", """,
), ),
api_dependencies=[Api.inference],
optional_api_dependencies=[Api.files],
),
InlineProviderSpec( InlineProviderSpec(
api=Api.vector_io, api=Api.vector_io,
provider_type="inline::milvus", provider_type="inline::milvus",
pip_packages=["pymilvus>=2.4.10"], pip_packages=["pymilvus[milvus-lite]>=2.4.10"],
module="llama_stack.providers.inline.vector_io.milvus", module="llama_stack.providers.inline.vector_io.milvus",
config_class="llama_stack.providers.inline.vector_io.milvus.MilvusVectorIOConfig", config_class="llama_stack.providers.inline.vector_io.milvus.MilvusVectorIOConfig",
api_dependencies=[Api.inference], api_dependencies=[Api.inference],
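For the vector_io registry, the same migration also moves `api_dependencies` and `optional_api_dependencies` from the outer `remote_provider_spec(...)` call onto the `RemoteProviderSpec` itself. A sketch of one resulting entry, using the chromadb values from the hunks above (description shortened):

```python
from llama_stack.providers.datatypes import Api, RemoteProviderSpec

chromadb_spec = RemoteProviderSpec(
    api=Api.vector_io,
    adapter_type="chromadb",
    provider_type="remote::chromadb",
    pip_packages=["chromadb-client"],
    module="llama_stack.providers.remote.vector_io.chroma",
    config_class="llama_stack.providers.remote.vector_io.chroma.ChromaVectorIOConfig",
    # Dependency lists now live on the spec rather than on the old wrapper call.
    api_dependencies=[Api.inference],
    optional_api_dependencies=[Api.files],
    description="Chroma remote vector database provider for Llama Stack.",
)
```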

View file

@ -137,7 +137,7 @@ class S3FilesImpl(Files):
where: dict[str, str | dict] = {"id": file_id} where: dict[str, str | dict] = {"id": file_id}
if not return_expired: if not return_expired:
where["expires_at"] = {">": self._now()} where["expires_at"] = {">": self._now()}
if not (row := await self.sql_store.fetch_one("openai_files", policy=self.policy, where=where)): if not (row := await self.sql_store.fetch_one("openai_files", where=where)):
raise ResourceNotFoundError(file_id, "File", "files.list()") raise ResourceNotFoundError(file_id, "File", "files.list()")
return row return row
@ -164,7 +164,7 @@ class S3FilesImpl(Files):
self._client = _create_s3_client(self._config) self._client = _create_s3_client(self._config)
await _create_bucket_if_not_exists(self._client, self._config) await _create_bucket_if_not_exists(self._client, self._config)
self._sql_store = AuthorizedSqlStore(sqlstore_impl(self._config.metadata_store)) self._sql_store = AuthorizedSqlStore(sqlstore_impl(self._config.metadata_store), self.policy)
await self._sql_store.create_table( await self._sql_store.create_table(
"openai_files", "openai_files",
{ {
@ -268,7 +268,6 @@ class S3FilesImpl(Files):
paginated_result = await self.sql_store.fetch_all( paginated_result = await self.sql_store.fetch_all(
table="openai_files", table="openai_files",
policy=self.policy,
where=where_conditions, where=where_conditions,
order_by=[("created_at", order.value)], order_by=[("created_at", order.value)],
cursor=("id", after) if after else None, cursor=("id", after) if after else None,

View file

@ -4,15 +4,9 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from pydantic import BaseModel
from .config import AnthropicConfig from .config import AnthropicConfig
class AnthropicProviderDataValidator(BaseModel):
anthropic_api_key: str | None = None
async def get_adapter_impl(config: AnthropicConfig, _deps): async def get_adapter_impl(config: AnthropicConfig, _deps):
from .anthropic import AnthropicInferenceAdapter from .anthropic import AnthropicInferenceAdapter

View file

@ -4,11 +4,9 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from collections.abc import AsyncGenerator, AsyncIterator from collections.abc import AsyncGenerator
from typing import Any
from fireworks.client import Fireworks from fireworks.client import Fireworks
from openai import AsyncOpenAI
from llama_stack.apis.common.content_types import ( from llama_stack.apis.common.content_types import (
InterleavedContent, InterleavedContent,
@ -24,12 +22,6 @@ from llama_stack.apis.inference import (
Inference, Inference,
LogProbConfig, LogProbConfig,
Message, Message,
OpenAIChatCompletion,
OpenAIChatCompletionChunk,
OpenAICompletion,
OpenAIEmbeddingsResponse,
OpenAIMessageParam,
OpenAIResponseFormatParam,
ResponseFormat, ResponseFormat,
ResponseFormatType, ResponseFormatType,
SamplingParams, SamplingParams,
@ -45,15 +37,14 @@ from llama_stack.providers.utils.inference.model_registry import (
ModelRegistryHelper, ModelRegistryHelper,
) )
from llama_stack.providers.utils.inference.openai_compat import ( from llama_stack.providers.utils.inference.openai_compat import (
OpenAIChatCompletionToLlamaStackMixin,
convert_message_to_openai_dict, convert_message_to_openai_dict,
get_sampling_options, get_sampling_options,
prepare_openai_completion_params,
process_chat_completion_response, process_chat_completion_response,
process_chat_completion_stream_response, process_chat_completion_stream_response,
process_completion_response, process_completion_response,
process_completion_stream_response, process_completion_stream_response,
) )
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
from llama_stack.providers.utils.inference.prompt_adapter import ( from llama_stack.providers.utils.inference.prompt_adapter import (
chat_completion_request_to_prompt, chat_completion_request_to_prompt,
completion_request_to_prompt, completion_request_to_prompt,
@ -68,7 +59,7 @@ from .models import MODEL_ENTRIES
logger = get_logger(name=__name__, category="inference::fireworks") logger = get_logger(name=__name__, category="inference::fireworks")
class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData): class FireworksInferenceAdapter(OpenAIMixin, ModelRegistryHelper, Inference, NeedsRequestProviderData):
def __init__(self, config: FireworksImplConfig) -> None: def __init__(self, config: FireworksImplConfig) -> None:
ModelRegistryHelper.__init__(self, MODEL_ENTRIES, config.allowed_models) ModelRegistryHelper.__init__(self, MODEL_ENTRIES, config.allowed_models)
self.config = config self.config = config
@ -79,7 +70,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
async def shutdown(self) -> None: async def shutdown(self) -> None:
pass pass
def _get_api_key(self) -> str: def get_api_key(self) -> str:
config_api_key = self.config.api_key.get_secret_value() if self.config.api_key else None config_api_key = self.config.api_key.get_secret_value() if self.config.api_key else None
if config_api_key: if config_api_key:
return config_api_key return config_api_key
@ -91,15 +82,18 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
) )
return provider_data.fireworks_api_key return provider_data.fireworks_api_key
def _get_base_url(self) -> str: def get_base_url(self) -> str:
return "https://api.fireworks.ai/inference/v1" return "https://api.fireworks.ai/inference/v1"
def _get_client(self) -> Fireworks: def _get_client(self) -> Fireworks:
fireworks_api_key = self._get_api_key() fireworks_api_key = self.get_api_key()
return Fireworks(api_key=fireworks_api_key) return Fireworks(api_key=fireworks_api_key)
def _get_openai_client(self) -> AsyncOpenAI: def _preprocess_prompt_for_fireworks(self, prompt: str) -> str:
return AsyncOpenAI(base_url=self._get_base_url(), api_key=self._get_api_key()) """Remove BOS token as Fireworks automatically prepends it"""
if prompt.startswith("<|begin_of_text|>"):
return prompt[len("<|begin_of_text|>") :]
return prompt
async def completion( async def completion(
self, self,
@ -285,153 +279,3 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
embeddings = [data.embedding for data in response.data] embeddings = [data.embedding for data in response.data]
return EmbeddingsResponse(embeddings=embeddings) return EmbeddingsResponse(embeddings=embeddings)
async def openai_embeddings(
self,
model: str,
input: str | list[str],
encoding_format: str | None = "float",
dimensions: int | None = None,
user: str | None = None,
) -> OpenAIEmbeddingsResponse:
raise NotImplementedError()
async def openai_completion(
self,
model: str,
prompt: str | list[str] | list[int] | list[list[int]],
best_of: int | None = None,
echo: bool | None = None,
frequency_penalty: float | None = None,
logit_bias: dict[str, float] | None = None,
logprobs: bool | None = None,
max_tokens: int | None = None,
n: int | None = None,
presence_penalty: float | None = None,
seed: int | None = None,
stop: str | list[str] | None = None,
stream: bool | None = None,
stream_options: dict[str, Any] | None = None,
temperature: float | None = None,
top_p: float | None = None,
user: str | None = None,
guided_choice: list[str] | None = None,
prompt_logprobs: int | None = None,
suffix: str | None = None,
) -> OpenAICompletion:
model_obj = await self.model_store.get_model(model)
# Fireworks always prepends with BOS
if isinstance(prompt, str) and prompt.startswith("<|begin_of_text|>"):
prompt = prompt[len("<|begin_of_text|>") :]
params = await prepare_openai_completion_params(
model=model_obj.provider_resource_id,
prompt=prompt,
best_of=best_of,
echo=echo,
frequency_penalty=frequency_penalty,
logit_bias=logit_bias,
logprobs=logprobs,
max_tokens=max_tokens,
n=n,
presence_penalty=presence_penalty,
seed=seed,
stop=stop,
stream=stream,
stream_options=stream_options,
temperature=temperature,
top_p=top_p,
user=user,
)
return await self._get_openai_client().completions.create(**params)
async def openai_chat_completion(
self,
model: str,
messages: list[OpenAIMessageParam],
frequency_penalty: float | None = None,
function_call: str | dict[str, Any] | None = None,
functions: list[dict[str, Any]] | None = None,
logit_bias: dict[str, float] | None = None,
logprobs: bool | None = None,
max_completion_tokens: int | None = None,
max_tokens: int | None = None,
n: int | None = None,
parallel_tool_calls: bool | None = None,
presence_penalty: float | None = None,
response_format: OpenAIResponseFormatParam | None = None,
seed: int | None = None,
stop: str | list[str] | None = None,
stream: bool | None = None,
stream_options: dict[str, Any] | None = None,
temperature: float | None = None,
tool_choice: str | dict[str, Any] | None = None,
tools: list[dict[str, Any]] | None = None,
top_logprobs: int | None = None,
top_p: float | None = None,
user: str | None = None,
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
model_obj = await self.model_store.get_model(model)
# Divert Llama Models through Llama Stack inference APIs because
# Fireworks chat completions OpenAI-compatible API does not support
# tool calls properly.
llama_model = self.get_llama_model(model_obj.provider_resource_id)
if llama_model:
return await OpenAIChatCompletionToLlamaStackMixin.openai_chat_completion(
self,
model=model,
messages=messages,
frequency_penalty=frequency_penalty,
function_call=function_call,
functions=functions,
logit_bias=logit_bias,
logprobs=logprobs,
max_completion_tokens=max_completion_tokens,
max_tokens=max_tokens,
n=n,
parallel_tool_calls=parallel_tool_calls,
presence_penalty=presence_penalty,
response_format=response_format,
seed=seed,
stop=stop,
stream=stream,
stream_options=stream_options,
temperature=temperature,
tool_choice=tool_choice,
tools=tools,
top_logprobs=top_logprobs,
top_p=top_p,
user=user,
)
params = await prepare_openai_completion_params(
messages=messages,
frequency_penalty=frequency_penalty,
function_call=function_call,
functions=functions,
logit_bias=logit_bias,
logprobs=logprobs,
max_completion_tokens=max_completion_tokens,
max_tokens=max_tokens,
n=n,
parallel_tool_calls=parallel_tool_calls,
presence_penalty=presence_penalty,
response_format=response_format,
seed=seed,
stop=stop,
stream=stream,
stream_options=stream_options,
temperature=temperature,
tool_choice=tool_choice,
tools=tools,
top_logprobs=top_logprobs,
top_p=top_p,
user=user,
)
logger.debug(f"fireworks params: {params}")
return await self._get_openai_client().chat.completions.create(model=model_obj.provider_resource_id, **params)
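With the Fireworks adapter now built on `OpenAIMixin`, the hand-rolled `openai_completion`/`openai_chat_completion` methods above are deleted and the adapter mainly supplies the hooks the mixin needs, plus a BOS-stripping helper because Fireworks prepends that token itself. A condensed sketch of that surface, assuming (not shown in this diff) that `OpenAIMixin` builds its OpenAI-compatible client from `get_api_key()` and `get_base_url()`; the real adapter also inherits `ModelRegistryHelper`, `Inference`, and `NeedsRequestProviderData`:

```python
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin


class FireworksStyleAdapter(OpenAIMixin):
    """Illustrative subset of FireworksInferenceAdapter after the OpenAIMixin migration."""

    def get_api_key(self) -> str:
        # The real adapter prefers the configured key and falls back to request provider data.
        return "fw-your-api-key"

    def get_base_url(self) -> str:
        return "https://api.fireworks.ai/inference/v1"

    def _preprocess_prompt_for_fireworks(self, prompt: str) -> str:
        """Remove the BOS token, since Fireworks prepends it automatically."""
        if prompt.startswith("<|begin_of_text|>"):
            return prompt[len("<|begin_of_text|>") :]
        return prompt
```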

View file

@ -4,15 +4,9 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from pydantic import BaseModel
from .config import GeminiConfig from .config import GeminiConfig
class GeminiProviderDataValidator(BaseModel):
gemini_api_key: str | None = None
async def get_adapter_impl(config: GeminiConfig, _deps): async def get_adapter_impl(config: GeminiConfig, _deps):
from .gemini import GeminiInferenceAdapter from .gemini import GeminiInferenceAdapter

View file

@ -7,12 +7,10 @@
import asyncio import asyncio
import base64 import base64
import uuid
from collections.abc import AsyncGenerator, AsyncIterator from collections.abc import AsyncGenerator, AsyncIterator
from typing import Any from typing import Any
from ollama import AsyncClient # type: ignore[attr-defined] from ollama import AsyncClient as AsyncOllamaClient
from openai import AsyncOpenAI
from llama_stack.apis.common.content_types import ( from llama_stack.apis.common.content_types import (
ImageContentItem, ImageContentItem,
@ -37,9 +35,6 @@ from llama_stack.apis.inference import (
Message, Message,
OpenAIChatCompletion, OpenAIChatCompletion,
OpenAIChatCompletionChunk, OpenAIChatCompletionChunk,
OpenAICompletion,
OpenAIEmbeddingsResponse,
OpenAIEmbeddingUsage,
OpenAIMessageParam, OpenAIMessageParam,
OpenAIResponseFormatParam, OpenAIResponseFormatParam,
ResponseFormat, ResponseFormat,
@ -64,15 +59,14 @@ from llama_stack.providers.utils.inference.model_registry import (
from llama_stack.providers.utils.inference.openai_compat import ( from llama_stack.providers.utils.inference.openai_compat import (
OpenAICompatCompletionChoice, OpenAICompatCompletionChoice,
OpenAICompatCompletionResponse, OpenAICompatCompletionResponse,
b64_encode_openai_embeddings_response,
get_sampling_options, get_sampling_options,
prepare_openai_completion_params, prepare_openai_completion_params,
prepare_openai_embeddings_params,
process_chat_completion_response, process_chat_completion_response,
process_chat_completion_stream_response, process_chat_completion_stream_response,
process_completion_response, process_completion_response,
process_completion_stream_response, process_completion_stream_response,
) )
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
from llama_stack.providers.utils.inference.prompt_adapter import ( from llama_stack.providers.utils.inference.prompt_adapter import (
chat_completion_request_to_prompt, chat_completion_request_to_prompt,
completion_request_to_prompt, completion_request_to_prompt,
@ -89,6 +83,7 @@ logger = get_logger(name=__name__, category="inference::ollama")
class OllamaInferenceAdapter( class OllamaInferenceAdapter(
OpenAIMixin,
InferenceProvider, InferenceProvider,
ModelsProtocolPrivate, ModelsProtocolPrivate,
): ):
@ -98,23 +93,21 @@ class OllamaInferenceAdapter(
def __init__(self, config: OllamaImplConfig) -> None: def __init__(self, config: OllamaImplConfig) -> None:
self.register_helper = ModelRegistryHelper(MODEL_ENTRIES) self.register_helper = ModelRegistryHelper(MODEL_ENTRIES)
self.config = config self.config = config
self._clients: dict[asyncio.AbstractEventLoop, AsyncClient] = {} self._clients: dict[asyncio.AbstractEventLoop, AsyncOllamaClient] = {}
self._openai_client = None
@property @property
def client(self) -> AsyncClient: def ollama_client(self) -> AsyncOllamaClient:
# ollama client attaches itself to the current event loop (sadly?) # ollama client attaches itself to the current event loop (sadly?)
loop = asyncio.get_running_loop() loop = asyncio.get_running_loop()
if loop not in self._clients: if loop not in self._clients:
self._clients[loop] = AsyncClient(host=self.config.url) self._clients[loop] = AsyncOllamaClient(host=self.config.url)
return self._clients[loop] return self._clients[loop]
@property def get_api_key(self):
def openai_client(self) -> AsyncOpenAI: return "NO_KEY"
if self._openai_client is None:
url = self.config.url.rstrip("/") def get_base_url(self):
self._openai_client = AsyncOpenAI(base_url=f"{url}/v1", api_key="ollama") return self.config.url.rstrip("/") + "/v1"
return self._openai_client
async def initialize(self) -> None: async def initialize(self) -> None:
logger.info(f"checking connectivity to Ollama at `{self.config.url}`...") logger.info(f"checking connectivity to Ollama at `{self.config.url}`...")
@ -129,7 +122,7 @@ class OllamaInferenceAdapter(
async def list_models(self) -> list[Model] | None: async def list_models(self) -> list[Model] | None:
provider_id = self.__provider_id__ provider_id = self.__provider_id__
response = await self.client.list() response = await self.ollama_client.list()
# always add the two embedding models which can be pulled on demand # always add the two embedding models which can be pulled on demand
models = [ models = [
@ -189,7 +182,7 @@ class OllamaInferenceAdapter(
HealthResponse: A dictionary containing the health status. HealthResponse: A dictionary containing the health status.
""" """
try: try:
await self.client.ps() await self.ollama_client.ps()
return HealthResponse(status=HealthStatus.OK) return HealthResponse(status=HealthStatus.OK)
except Exception as e: except Exception as e:
return HealthResponse(status=HealthStatus.ERROR, message=f"Health check failed: {str(e)}") return HealthResponse(status=HealthStatus.ERROR, message=f"Health check failed: {str(e)}")
@ -238,7 +231,7 @@ class OllamaInferenceAdapter(
params = await self._get_params(request) params = await self._get_params(request)
async def _generate_and_convert_to_openai_compat(): async def _generate_and_convert_to_openai_compat():
s = await self.client.generate(**params) s = await self.ollama_client.generate(**params)
async for chunk in s: async for chunk in s:
choice = OpenAICompatCompletionChoice( choice = OpenAICompatCompletionChoice(
finish_reason=chunk["done_reason"] if chunk["done"] else None, finish_reason=chunk["done_reason"] if chunk["done"] else None,
@ -254,7 +247,7 @@ class OllamaInferenceAdapter(
async def _nonstream_completion(self, request: CompletionRequest) -> CompletionResponse: async def _nonstream_completion(self, request: CompletionRequest) -> CompletionResponse:
params = await self._get_params(request) params = await self._get_params(request)
r = await self.client.generate(**params) r = await self.ollama_client.generate(**params)
choice = OpenAICompatCompletionChoice( choice = OpenAICompatCompletionChoice(
finish_reason=r["done_reason"] if r["done"] else None, finish_reason=r["done_reason"] if r["done"] else None,
@ -346,9 +339,9 @@ class OllamaInferenceAdapter(
async def _nonstream_chat_completion(self, request: ChatCompletionRequest) -> ChatCompletionResponse: async def _nonstream_chat_completion(self, request: ChatCompletionRequest) -> ChatCompletionResponse:
params = await self._get_params(request) params = await self._get_params(request)
if "messages" in params: if "messages" in params:
r = await self.client.chat(**params) r = await self.ollama_client.chat(**params)
else: else:
r = await self.client.generate(**params) r = await self.ollama_client.generate(**params)
if "message" in r: if "message" in r:
choice = OpenAICompatCompletionChoice( choice = OpenAICompatCompletionChoice(
@ -372,9 +365,9 @@ class OllamaInferenceAdapter(
async def _generate_and_convert_to_openai_compat(): async def _generate_and_convert_to_openai_compat():
if "messages" in params: if "messages" in params:
s = await self.client.chat(**params) s = await self.ollama_client.chat(**params)
else: else:
s = await self.client.generate(**params) s = await self.ollama_client.generate(**params)
async for chunk in s: async for chunk in s:
if "message" in chunk: if "message" in chunk:
choice = OpenAICompatCompletionChoice( choice = OpenAICompatCompletionChoice(
@ -407,7 +400,7 @@ class OllamaInferenceAdapter(
assert all(not content_has_media(content) for content in contents), ( assert all(not content_has_media(content) for content in contents), (
"Ollama does not support media for embeddings" "Ollama does not support media for embeddings"
) )
response = await self.client.embed( response = await self.ollama_client.embed(
model=model.provider_resource_id, model=model.provider_resource_id,
input=[interleaved_content_as_str(content) for content in contents], input=[interleaved_content_as_str(content) for content in contents],
) )
@ -422,14 +415,14 @@ class OllamaInferenceAdapter(
pass # Ignore statically unknown model, will check live listing pass # Ignore statically unknown model, will check live listing
if model.model_type == ModelType.embedding: if model.model_type == ModelType.embedding:
response = await self.client.list() response = await self.ollama_client.list()
if model.provider_resource_id not in [m.model for m in response.models]: if model.provider_resource_id not in [m.model for m in response.models]:
await self.client.pull(model.provider_resource_id) await self.ollama_client.pull(model.provider_resource_id)
# we use list() here instead of ps() - # we use list() here instead of ps() -
# - ps() only lists running models, not available models # - ps() only lists running models, not available models
# - models not currently running are run by the ollama server as needed # - models not currently running are run by the ollama server as needed
response = await self.client.list() response = await self.ollama_client.list()
available_models = [m.model for m in response.models] available_models = [m.model for m in response.models]
provider_resource_id = model.provider_resource_id provider_resource_id = model.provider_resource_id
@ -448,90 +441,6 @@ class OllamaInferenceAdapter(
return model return model
async def openai_embeddings(
self,
model: str,
input: str | list[str],
encoding_format: str | None = "float",
dimensions: int | None = None,
user: str | None = None,
) -> OpenAIEmbeddingsResponse:
model_obj = await self._get_model(model)
if model_obj.provider_resource_id is None:
raise ValueError(f"Model {model} has no provider_resource_id set")
# Note, at the moment Ollama does not support encoding_format, dimensions, and user parameters
params = prepare_openai_embeddings_params(
model=model_obj.provider_resource_id,
input=input,
encoding_format=encoding_format,
dimensions=dimensions,
user=user,
)
response = await self.openai_client.embeddings.create(**params)
data = b64_encode_openai_embeddings_response(response.data, encoding_format)
usage = OpenAIEmbeddingUsage(
prompt_tokens=response.usage.prompt_tokens,
total_tokens=response.usage.total_tokens,
)
# TODO: Investigate why model_obj.identifier is used instead of response.model
return OpenAIEmbeddingsResponse(
data=data,
model=model_obj.identifier,
usage=usage,
)
async def openai_completion(
self,
model: str,
prompt: str | list[str] | list[int] | list[list[int]],
best_of: int | None = None,
echo: bool | None = None,
frequency_penalty: float | None = None,
logit_bias: dict[str, float] | None = None,
logprobs: bool | None = None,
max_tokens: int | None = None,
n: int | None = None,
presence_penalty: float | None = None,
seed: int | None = None,
stop: str | list[str] | None = None,
stream: bool | None = None,
stream_options: dict[str, Any] | None = None,
temperature: float | None = None,
top_p: float | None = None,
user: str | None = None,
guided_choice: list[str] | None = None,
prompt_logprobs: int | None = None,
suffix: str | None = None,
) -> OpenAICompletion:
if not isinstance(prompt, str):
raise ValueError("Ollama does not support non-string prompts for completion")
model_obj = await self._get_model(model)
params = await prepare_openai_completion_params(
model=model_obj.provider_resource_id,
prompt=prompt,
best_of=best_of,
echo=echo,
frequency_penalty=frequency_penalty,
logit_bias=logit_bias,
logprobs=logprobs,
max_tokens=max_tokens,
n=n,
presence_penalty=presence_penalty,
seed=seed,
stop=stop,
stream=stream,
stream_options=stream_options,
temperature=temperature,
top_p=top_p,
user=user,
suffix=suffix,
)
return await self.openai_client.completions.create(**params) # type: ignore
async def openai_chat_completion( async def openai_chat_completion(
self, self,
model: str, model: str,
@ -599,25 +508,7 @@ class OllamaInferenceAdapter(
top_p=top_p, top_p=top_p,
user=user, user=user,
) )
response = await self.openai_client.chat.completions.create(**params) return await OpenAIMixin.openai_chat_completion(self, **params)
return await self._adjust_ollama_chat_completion_response_ids(response)
async def _adjust_ollama_chat_completion_response_ids(
self,
response: OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk],
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
id = f"chatcmpl-{uuid.uuid4()}"
if isinstance(response, AsyncIterator):
async def stream_with_chunk_ids() -> AsyncIterator[OpenAIChatCompletionChunk]:
async for chunk in response:
chunk.id = id
yield chunk
return stream_with_chunk_ids()
else:
response.id = id
return response
async def convert_message_to_openai_dict_for_ollama(message: Message) -> list[dict]: async def convert_message_to_openai_dict_for_ollama(message: Message) -> list[dict]:
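One detail worth noting from the Ollama hunks: the client property (renamed to `ollama_client`) caches one `AsyncClient` per running event loop, because the ollama client binds itself to the loop it was created on. The pattern in isolation, with only the wrapper class name invented for illustration:

```python
import asyncio

from ollama import AsyncClient as AsyncOllamaClient


class PerLoopOllamaClients:
    """Caches one ollama AsyncClient per running event loop (illustrative wrapper)."""

    def __init__(self, host: str) -> None:
        self._host = host
        self._clients: dict[asyncio.AbstractEventLoop, AsyncOllamaClient] = {}

    @property
    def client(self) -> AsyncOllamaClient:
        loop = asyncio.get_running_loop()  # raises if called outside an event loop
        if loop not in self._clients:
            self._clients[loop] = AsyncOllamaClient(host=self._host)
        return self._clients[loop]
```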

View file

@ -4,15 +4,9 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from pydantic import BaseModel
from .config import OpenAIConfig from .config import OpenAIConfig
class OpenAIProviderDataValidator(BaseModel):
openai_api_key: str | None = None
async def get_adapter_impl(config: OpenAIConfig, _deps): async def get_adapter_impl(config: OpenAIConfig, _deps):
from .openai import OpenAIInferenceAdapter from .openai import OpenAIInferenceAdapter

View file

@ -54,7 +54,7 @@ class InferenceStore:
async def initialize(self): async def initialize(self):
"""Create the necessary tables if they don't exist.""" """Create the necessary tables if they don't exist."""
self.sql_store = AuthorizedSqlStore(sqlstore_impl(self.sql_store_config)) self.sql_store = AuthorizedSqlStore(sqlstore_impl(self.sql_store_config), self.policy)
await self.sql_store.create_table( await self.sql_store.create_table(
"chat_completions", "chat_completions",
{ {
@ -202,7 +202,6 @@ class InferenceStore:
order_by=[("created", order.value)], order_by=[("created", order.value)],
cursor=("id", after) if after else None, cursor=("id", after) if after else None,
limit=limit, limit=limit,
policy=self.policy,
) )
data = [ data = [
@ -229,7 +228,6 @@ class InferenceStore:
row = await self.sql_store.fetch_one( row = await self.sql_store.fetch_one(
table="chat_completions", table="chat_completions",
where={"id": completion_id}, where={"id": completion_id},
policy=self.policy,
) )
if not row: if not row:

View file

@ -103,7 +103,7 @@ class ModelRegistryHelper(ModelsProtocolPrivate):
Model( Model(
identifier=id, identifier=id,
provider_resource_id=entry.provider_model_id, provider_resource_id=entry.provider_model_id,
model_type=ModelType.llm, model_type=entry.model_type,
metadata=entry.metadata, metadata=entry.metadata,
provider_id=self.__provider_id__, provider_id=self.__provider_id__,
) )
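This one-line change preserves a registry entry's declared `model_type` instead of registering every entry as an LLM, which matters for embedding models. A hedged sketch of an entry that benefits; the `ProviderModelEntry` fields and import paths are assumptions based on how the helper is used elsewhere, and the model id and dimension are illustrative:

```python
from llama_stack.apis.models import ModelType  # assumed import path
from llama_stack.providers.utils.inference.model_registry import ProviderModelEntry  # assumed import path

MODEL_ENTRIES = [
    ProviderModelEntry(
        provider_model_id="nomic-embed-text",   # illustrative model id
        model_type=ModelType.embedding,          # now preserved at registration time
        metadata={"embedding_dimension": 768},   # illustrative metadata
    ),
]
```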

View file

@ -28,8 +28,7 @@ class ResponsesStore:
sql_store_config = SqliteSqlStoreConfig( sql_store_config = SqliteSqlStoreConfig(
db_path=(RUNTIME_BASE_DIR / "sqlstore.db").as_posix(), db_path=(RUNTIME_BASE_DIR / "sqlstore.db").as_posix(),
) )
self.sql_store = AuthorizedSqlStore(sqlstore_impl(sql_store_config)) self.sql_store = AuthorizedSqlStore(sqlstore_impl(sql_store_config), policy)
self.policy = policy
async def initialize(self): async def initialize(self):
"""Create the necessary tables if they don't exist.""" """Create the necessary tables if they don't exist."""
@ -87,7 +86,6 @@ class ResponsesStore:
order_by=[("created_at", order.value)], order_by=[("created_at", order.value)],
cursor=("id", after) if after else None, cursor=("id", after) if after else None,
limit=limit, limit=limit,
policy=self.policy,
) )
data = [OpenAIResponseObjectWithInput(**row["response_object"]) for row in paginated_result.data] data = [OpenAIResponseObjectWithInput(**row["response_object"]) for row in paginated_result.data]
@ -105,7 +103,6 @@ class ResponsesStore:
row = await self.sql_store.fetch_one( row = await self.sql_store.fetch_one(
"openai_responses", "openai_responses",
where={"id": response_id}, where={"id": response_id},
policy=self.policy,
) )
if not row: if not row:
@ -116,7 +113,7 @@ class ResponsesStore:
return OpenAIResponseObjectWithInput(**row["response_object"]) return OpenAIResponseObjectWithInput(**row["response_object"])
async def delete_response_object(self, response_id: str) -> OpenAIDeleteResponseObject: async def delete_response_object(self, response_id: str) -> OpenAIDeleteResponseObject:
row = await self.sql_store.fetch_one("openai_responses", where={"id": response_id}, policy=self.policy) row = await self.sql_store.fetch_one("openai_responses", where={"id": response_id})
if not row: if not row:
raise ValueError(f"Response with id {response_id} not found") raise ValueError(f"Response with id {response_id} not found")
await self.sql_store.delete("openai_responses", where={"id": response_id}) await self.sql_store.delete("openai_responses", where={"id": response_id})

View file

@ -53,13 +53,15 @@ class AuthorizedSqlStore:
access control policies, user attribute capture, and SQL filtering optimization. access control policies, user attribute capture, and SQL filtering optimization.
""" """
def __init__(self, sql_store: SqlStore): def __init__(self, sql_store: SqlStore, policy: list[AccessRule]):
""" """
Initialize the authorization layer. Initialize the authorization layer.
:param sql_store: Base SqlStore implementation to wrap :param sql_store: Base SqlStore implementation to wrap
:param policy: Access control policy to use for authorization
""" """
self.sql_store = sql_store self.sql_store = sql_store
self.policy = policy
self._detect_database_type() self._detect_database_type()
self._validate_sql_optimized_policy() self._validate_sql_optimized_policy()
@ -117,14 +119,13 @@ class AuthorizedSqlStore:
async def fetch_all( async def fetch_all(
self, self,
table: str, table: str,
policy: list[AccessRule],
where: Mapping[str, Any] | None = None, where: Mapping[str, Any] | None = None,
limit: int | None = None, limit: int | None = None,
order_by: list[tuple[str, Literal["asc", "desc"]]] | None = None, order_by: list[tuple[str, Literal["asc", "desc"]]] | None = None,
cursor: tuple[str, str] | None = None, cursor: tuple[str, str] | None = None,
) -> PaginatedResponse: ) -> PaginatedResponse:
"""Fetch all rows with automatic access control filtering.""" """Fetch all rows with automatic access control filtering."""
access_where = self._build_access_control_where_clause(policy) access_where = self._build_access_control_where_clause(self.policy)
rows = await self.sql_store.fetch_all( rows = await self.sql_store.fetch_all(
table=table, table=table,
where=where, where=where,
@ -146,7 +147,7 @@ class AuthorizedSqlStore:
str(record_id), table, User(principal=stored_owner_principal, attributes=stored_access_attrs) str(record_id), table, User(principal=stored_owner_principal, attributes=stored_access_attrs)
) )
if is_action_allowed(policy, Action.READ, sql_record, current_user): if is_action_allowed(self.policy, Action.READ, sql_record, current_user):
filtered_rows.append(row) filtered_rows.append(row)
return PaginatedResponse( return PaginatedResponse(
@ -157,14 +158,12 @@ class AuthorizedSqlStore:
async def fetch_one( async def fetch_one(
self, self,
table: str, table: str,
policy: list[AccessRule],
where: Mapping[str, Any] | None = None, where: Mapping[str, Any] | None = None,
order_by: list[tuple[str, Literal["asc", "desc"]]] | None = None, order_by: list[tuple[str, Literal["asc", "desc"]]] | None = None,
) -> dict[str, Any] | None: ) -> dict[str, Any] | None:
"""Fetch one row with automatic access control checking.""" """Fetch one row with automatic access control checking."""
results = await self.fetch_all( results = await self.fetch_all(
table=table, table=table,
policy=policy,
where=where, where=where,
limit=1, limit=1,
order_by=order_by, order_by=order_by,
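The `AuthorizedSqlStore` change binds the access policy once at construction, so `fetch_all()`/`fetch_one()` drop their per-call `policy` argument (as reflected in the S3 files, inference store, and responses store hunks above). A hedged usage sketch; import paths are assumptions based on the names in this diff, and the table is assumed to have been created via `create_table()` beforehand:

```python
from llama_stack.providers.utils.sqlstore.authorized_sqlstore import AuthorizedSqlStore  # assumed path
from llama_stack.providers.utils.sqlstore.sqlstore import SqliteSqlStoreConfig, sqlstore_impl  # assumed path


async def lookup_file(policy, db_path: str, file_id: str):
    # Policy is captured once; reads are filtered against it automatically.
    store = AuthorizedSqlStore(sqlstore_impl(SqliteSqlStoreConfig(db_path=db_path)), policy)
    page = await store.fetch_all(table="openai_files", order_by=[("created_at", "desc")], limit=10)
    row = await store.fetch_one("openai_files", where={"id": file_id})
    return page, row
```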

View file

@ -203,6 +203,11 @@ def _model_identifiers_digest(endpoint: str, response: dict[str, Any]) -> str:
- '/v1/models' (OpenAI): response body is: [ { id: ... }, ... ] - '/v1/models' (OpenAI): response body is: [ { id: ... }, ... ]
Returns a list of unique identifiers or None if structure doesn't match. Returns a list of unique identifiers or None if structure doesn't match.
""" """
if "models" in response["body"]:
# ollama
items = response["body"]["models"]
else:
# openai
items = response["body"] items = response["body"]
idents = [m.model if endpoint == "/api/tags" else m.id for m in items] idents = [m.model if endpoint == "/api/tags" else m.id for m in items]
return sorted(set(idents)) return sorted(set(idents))
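The added branch lets the digest helper handle both model-listing shapes seen in recordings: Ollama wraps the list under a `models` key, while OpenAI-style `/v1/models` bodies are a bare list. The extraction logic in isolation (function name shortened for this sketch):

```python
from typing import Any


def _model_identifiers(endpoint: str, response: dict[str, Any]) -> list[str]:
    if "models" in response["body"]:
        items = response["body"]["models"]  # ollama /api/tags shape
    else:
        items = response["body"]  # openai /v1/models shape
    # /api/tags items expose .model, /v1/models items expose .id
    return sorted({m.model if endpoint == "/api/tags" else m.id for m in items})
```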

File diff suppressed because it is too large

View file

@ -14,7 +14,7 @@
}, },
"dependencies": { "dependencies": {
"@radix-ui/react-collapsible": "^1.1.12", "@radix-ui/react-collapsible": "^1.1.12",
"@radix-ui/react-dialog": "^1.1.13", "@radix-ui/react-dialog": "^1.1.15",
"@radix-ui/react-dropdown-menu": "^2.1.16", "@radix-ui/react-dropdown-menu": "^2.1.16",
"@radix-ui/react-select": "^2.2.6", "@radix-ui/react-select": "^2.2.6",
"@radix-ui/react-separator": "^1.1.7", "@radix-ui/react-separator": "^1.1.7",
@ -32,7 +32,7 @@
"react-dom": "^19.1.1", "react-dom": "^19.1.1",
"react-markdown": "^10.1.0", "react-markdown": "^10.1.0",
"remark-gfm": "^4.0.1", "remark-gfm": "^4.0.1",
"remeda": "^2.30.0", "remeda": "^2.32.0",
"shiki": "^1.29.2", "shiki": "^1.29.2",
"sonner": "^2.0.7", "sonner": "^2.0.7",
"tailwind-merge": "^3.3.1" "tailwind-merge": "^3.3.1"
@ -52,7 +52,7 @@
"eslint-config-prettier": "^10.1.8", "eslint-config-prettier": "^10.1.8",
"eslint-plugin-prettier": "^5.5.4", "eslint-plugin-prettier": "^5.5.4",
"jest": "^29.7.0", "jest": "^29.7.0",
"jest-environment-jsdom": "^29.7.0", "jest-environment-jsdom": "^30.1.2",
"prettier": "3.6.2", "prettier": "3.6.2",
"tailwindcss": "^4", "tailwindcss": "^4",
"ts-node": "^10.9.2", "ts-node": "^10.9.2",

View file

@ -1,6 +1,5 @@
adapter: adapter_type: kaze
adapter_type: kaze pip_packages: ["tests/external/llama-stack-provider-kaze"]
pip_packages: ["tests/external/llama-stack-provider-kaze"] config_class: llama_stack_provider_kaze.config.KazeProviderConfig
config_class: llama_stack_provider_kaze.config.KazeProviderConfig module: llama_stack_provider_kaze
module: llama_stack_provider_kaze
optional_api_dependencies: [] optional_api_dependencies: []

View file

@ -6,7 +6,7 @@
from typing import Protocol from typing import Protocol
from llama_stack.providers.datatypes import AdapterSpec, Api, ProviderSpec, RemoteProviderSpec from llama_stack.providers.datatypes import Api, ProviderSpec, RemoteProviderSpec
from llama_stack.schema_utils import webmethod from llama_stack.schema_utils import webmethod
@ -16,12 +16,9 @@ def available_providers() -> list[ProviderSpec]:
api=Api.weather, api=Api.weather,
provider_type="remote::kaze", provider_type="remote::kaze",
config_class="llama_stack_provider_kaze.KazeProviderConfig", config_class="llama_stack_provider_kaze.KazeProviderConfig",
adapter=AdapterSpec(
adapter_type="kaze", adapter_type="kaze",
module="llama_stack_provider_kaze", module="llama_stack_provider_kaze",
pip_packages=["llama_stack_provider_kaze"], pip_packages=["llama_stack_provider_kaze"],
config_class="llama_stack_provider_kaze.KazeProviderConfig",
),
), ),
] ]
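
Taken together with the YAML change above, the nested AdapterSpec wrapper is removed and its fields now sit directly on RemoteProviderSpec. A sketch of the flattened spec using only the field names visible in this hunk (any other RemoteProviderSpec fields and defaults are assumed unchanged):

RemoteProviderSpec(
    api=Api.weather,
    provider_type="remote::kaze",
    config_class="llama_stack_provider_kaze.KazeProviderConfig",
    adapter_type="kaze",                       # previously AdapterSpec.adapter_type
    module="llama_stack_provider_kaze",        # previously AdapterSpec.module
    pip_packages=["llama_stack_provider_kaze"],
)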

View file

@ -33,6 +33,7 @@ def skip_if_model_doesnt_support_user_param(client, model_id):
provider = provider_from_model(client, model_id) provider = provider_from_model(client, model_id)
if provider.provider_type in ( if provider.provider_type in (
"remote::together", # service returns 400 "remote::together", # service returns 400
"remote::fireworks", # service returns 400 malformed input
): ):
pytest.skip(f"Model {model_id} hosted by {provider.provider_type} does not support user param.") pytest.skip(f"Model {model_id} hosted by {provider.provider_type} does not support user param.")
@ -41,6 +42,7 @@ def skip_if_model_doesnt_support_encoding_format_base64(client, model_id):
provider = provider_from_model(client, model_id) provider = provider_from_model(client, model_id)
if provider.provider_type in ( if provider.provider_type in (
"remote::together", # param silently ignored, always returns floats "remote::together", # param silently ignored, always returns floats
"remote::fireworks", # param silently ignored, always returns list of floats
): ):
pytest.skip(f"Model {model_id} hosted by {provider.provider_type} does not support encoding_format='base64'.") pytest.skip(f"Model {model_id} hosted by {provider.provider_type} does not support encoding_format='base64'.")
@ -287,7 +289,6 @@ def test_openai_embeddings_base64_batch_processing(compat_client, client_with_mo
input=input_texts, input=input_texts,
encoding_format="base64", encoding_format="base64",
) )
# Validate response structure # Validate response structure
assert response.object == "list" assert response.object == "list"
assert response.model == embedding_model_id assert response.model == embedding_model_id

View file

@ -57,7 +57,7 @@ def authorized_store(backend_config):
config = config_func() config = config_func()
base_sqlstore = sqlstore_impl(config) base_sqlstore = sqlstore_impl(config)
authorized_store = AuthorizedSqlStore(base_sqlstore) authorized_store = AuthorizedSqlStore(base_sqlstore, default_policy())
yield authorized_store yield authorized_store
@ -106,7 +106,7 @@ async def test_authorized_store_attributes(mock_get_authenticated_user, authoriz
await authorized_store.insert(table_name, {"id": "1", "data": "public_data"}) await authorized_store.insert(table_name, {"id": "1", "data": "public_data"})
# Test fetching with no user - should not error on JSON comparison # Test fetching with no user - should not error on JSON comparison
result = await authorized_store.fetch_all(table_name, policy=default_policy()) result = await authorized_store.fetch_all(table_name)
assert len(result.data) == 1 assert len(result.data) == 1
assert result.data[0]["id"] == "1" assert result.data[0]["id"] == "1"
assert result.data[0]["access_attributes"] is None assert result.data[0]["access_attributes"] is None
@ -119,7 +119,7 @@ async def test_authorized_store_attributes(mock_get_authenticated_user, authoriz
await authorized_store.insert(table_name, {"id": "2", "data": "admin_data"}) await authorized_store.insert(table_name, {"id": "2", "data": "admin_data"})
# Fetch all - admin should see both # Fetch all - admin should see both
result = await authorized_store.fetch_all(table_name, policy=default_policy()) result = await authorized_store.fetch_all(table_name)
assert len(result.data) == 2 assert len(result.data) == 2
# Test with non-admin user # Test with non-admin user
@ -127,7 +127,7 @@ async def test_authorized_store_attributes(mock_get_authenticated_user, authoriz
mock_get_authenticated_user.return_value = regular_user mock_get_authenticated_user.return_value = regular_user
# Should only see public record # Should only see public record
result = await authorized_store.fetch_all(table_name, policy=default_policy()) result = await authorized_store.fetch_all(table_name)
assert len(result.data) == 1 assert len(result.data) == 1
assert result.data[0]["id"] == "1" assert result.data[0]["id"] == "1"
@ -156,7 +156,7 @@ async def test_authorized_store_attributes(mock_get_authenticated_user, authoriz
# Now test with the multi-user who has both roles=admin and teams=dev # Now test with the multi-user who has both roles=admin and teams=dev
mock_get_authenticated_user.return_value = multi_user mock_get_authenticated_user.return_value = multi_user
result = await authorized_store.fetch_all(table_name, policy=default_policy()) result = await authorized_store.fetch_all(table_name)
# Should see: # Should see:
# - public record (1) - no access_attributes # - public record (1) - no access_attributes
@ -217,21 +217,24 @@ async def test_user_ownership_policy(mock_get_authenticated_user, authorized_sto
), ),
] ]
# Create a new authorized store with the owner-only policy
owner_only_store = AuthorizedSqlStore(authorized_store.sql_store, owner_only_policy)
# Test user1 access - should only see their own record # Test user1 access - should only see their own record
mock_get_authenticated_user.return_value = user1 mock_get_authenticated_user.return_value = user1
result = await authorized_store.fetch_all(table_name, policy=owner_only_policy) result = await owner_only_store.fetch_all(table_name)
assert len(result.data) == 1, f"Expected user1 to see 1 record, got {len(result.data)}" assert len(result.data) == 1, f"Expected user1 to see 1 record, got {len(result.data)}"
assert result.data[0]["id"] == "1", f"Expected user1's record, got {result.data[0]['id']}" assert result.data[0]["id"] == "1", f"Expected user1's record, got {result.data[0]['id']}"
# Test user2 access - should only see their own record # Test user2 access - should only see their own record
mock_get_authenticated_user.return_value = user2 mock_get_authenticated_user.return_value = user2
result = await authorized_store.fetch_all(table_name, policy=owner_only_policy) result = await owner_only_store.fetch_all(table_name)
assert len(result.data) == 1, f"Expected user2 to see 1 record, got {len(result.data)}" assert len(result.data) == 1, f"Expected user2 to see 1 record, got {len(result.data)}"
assert result.data[0]["id"] == "2", f"Expected user2's record, got {result.data[0]['id']}" assert result.data[0]["id"] == "2", f"Expected user2's record, got {result.data[0]['id']}"
# Test with anonymous user - should see no records # Test with anonymous user - should see no records
mock_get_authenticated_user.return_value = None mock_get_authenticated_user.return_value = None
result = await authorized_store.fetch_all(table_name, policy=owner_only_policy) result = await owner_only_store.fetch_all(table_name)
assert len(result.data) == 0, f"Expected anonymous user to see 0 records, got {len(result.data)}" assert len(result.data) == 0, f"Expected anonymous user to see 0 records, got {len(result.data)}"
finally: finally:
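
Because the policy now lives on the store, the owner-only scenario wraps the same underlying sql_store in a second AuthorizedSqlStore rather than passing a different policy per call. The pattern, condensed from the hunk above (variable names as in the test):

owner_only_store = AuthorizedSqlStore(authorized_store.sql_store, owner_only_policy)
result = await owner_only_store.fetch_all(table_name)   # owner-only policy applied implicitly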

View file

@ -0,0 +1,990 @@
{
"request": {
"method": "POST",
"url": "http://0.0.0.0:11434/v1/v1/completions",
"headers": {},
"body": {
"model": "llama3.2:3b-instruct-fp16",
"prompt": "Respond to this question and explain your answer. Complete the sentence using one word: Roses are red, violets are ",
"max_tokens": 50,
"stream": true,
"extra_body": {}
},
"endpoint": "/v1/completions",
"model": "llama3.2:3b-instruct-fp16"
},
"response": {
"body": [
{
"__type__": "openai.types.completion.Completion",
"__data__": {
"id": "cmpl-439",
"choices": [
{
"finish_reason": null,
"index": 0,
"logprobs": null,
"text": "Blue"
}
],
"created": 1757857132,
"model": "llama3.2:3b-instruct-fp16",
"object": "text_completion",
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.completion.Completion",
"__data__": {
"id": "cmpl-439",
"choices": [
{
"finish_reason": null,
"index": 0,
"logprobs": null,
"text": ".\n\n"
}
],
"created": 1757857132,
"model": "llama3.2:3b-instruct-fp16",
"object": "text_completion",
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.completion.Completion",
"__data__": {
"id": "cmpl-439",
"choices": [
{
"finish_reason": null,
"index": 0,
"logprobs": null,
"text": "The"
}
],
"created": 1757857132,
"model": "llama3.2:3b-instruct-fp16",
"object": "text_completion",
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.completion.Completion",
"__data__": {
"id": "cmpl-439",
"choices": [
{
"finish_reason": null,
"index": 0,
"logprobs": null,
"text": " completed"
}
],
"created": 1757857132,
"model": "llama3.2:3b-instruct-fp16",
"object": "text_completion",
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.completion.Completion",
"__data__": {
"id": "cmpl-439",
"choices": [
{
"finish_reason": null,
"index": 0,
"logprobs": null,
"text": " sentence"
}
],
"created": 1757857132,
"model": "llama3.2:3b-instruct-fp16",
"object": "text_completion",
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.completion.Completion",
"__data__": {
"id": "cmpl-439",
"choices": [
{
"finish_reason": null,
"index": 0,
"logprobs": null,
"text": " is"
}
],
"created": 1757857132,
"model": "llama3.2:3b-instruct-fp16",
"object": "text_completion",
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.completion.Completion",
"__data__": {
"id": "cmpl-439",
"choices": [
{
"finish_reason": null,
"index": 0,
"logprobs": null,
"text": " a"
}
],
"created": 1757857132,
"model": "llama3.2:3b-instruct-fp16",
"object": "text_completion",
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.completion.Completion",
"__data__": {
"id": "cmpl-439",
"choices": [
{
"finish_reason": null,
"index": 0,
"logprobs": null,
"text": " well"
}
],
"created": 1757857132,
"model": "llama3.2:3b-instruct-fp16",
"object": "text_completion",
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.completion.Completion",
"__data__": {
"id": "cmpl-439",
"choices": [
{
"finish_reason": null,
"index": 0,
"logprobs": null,
"text": "-known"
}
],
"created": 1757857132,
"model": "llama3.2:3b-instruct-fp16",
"object": "text_completion",
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.completion.Completion",
"__data__": {
"id": "cmpl-439",
"choices": [
{
"finish_reason": null,
"index": 0,
"logprobs": null,
"text": " phrase"
}
],
"created": 1757857132,
"model": "llama3.2:3b-instruct-fp16",
"object": "text_completion",
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.completion.Completion",
"__data__": {
"id": "cmpl-439",
"choices": [
{
"finish_reason": null,
"index": 0,
"logprobs": null,
"text": " from"
}
],
"created": 1757857132,
"model": "llama3.2:3b-instruct-fp16",
"object": "text_completion",
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.completion.Completion",
"__data__": {
"id": "cmpl-439",
"choices": [
{
"finish_reason": null,
"index": 0,
"logprobs": null,
"text": " a"
}
],
"created": 1757857132,
"model": "llama3.2:3b-instruct-fp16",
"object": "text_completion",
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.completion.Completion",
"__data__": {
"id": "cmpl-439",
"choices": [
{
"finish_reason": null,
"index": 0,
"logprobs": null,
"text": " traditional"
}
],
"created": 1757857132,
"model": "llama3.2:3b-instruct-fp16",
"object": "text_completion",
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.completion.Completion",
"__data__": {
"id": "cmpl-439",
"choices": [
{
"finish_reason": null,
"index": 0,
"logprobs": null,
"text": " English"
}
],
"created": 1757857132,
"model": "llama3.2:3b-instruct-fp16",
"object": "text_completion",
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.completion.Completion",
"__data__": {
"id": "cmpl-439",
"choices": [
{
"finish_reason": null,
"index": 0,
"logprobs": null,
"text": " poem"
}
],
"created": 1757857132,
"model": "llama3.2:3b-instruct-fp16",
"object": "text_completion",
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.completion.Completion",
"__data__": {
"id": "cmpl-439",
"choices": [
{
"finish_reason": null,
"index": 0,
"logprobs": null,
"text": ":\n\n"
}
],
"created": 1757857132,
"model": "llama3.2:3b-instruct-fp16",
"object": "text_completion",
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.completion.Completion",
"__data__": {
"id": "cmpl-439",
"choices": [
{
"finish_reason": null,
"index": 0,
"logprobs": null,
"text": "\""
}
],
"created": 1757857132,
"model": "llama3.2:3b-instruct-fp16",
"object": "text_completion",
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.completion.Completion",
"__data__": {
"id": "cmpl-439",
"choices": [
{
"finish_reason": null,
"index": 0,
"logprobs": null,
"text": "R"
}
],
"created": 1757857132,
"model": "llama3.2:3b-instruct-fp16",
"object": "text_completion",
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.completion.Completion",
"__data__": {
"id": "cmpl-439",
"choices": [
{
"finish_reason": null,
"index": 0,
"logprobs": null,
"text": "oses"
}
],
"created": 1757857132,
"model": "llama3.2:3b-instruct-fp16",
"object": "text_completion",
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.completion.Completion",
"__data__": {
"id": "cmpl-439",
"choices": [
{
"finish_reason": null,
"index": 0,
"logprobs": null,
"text": " are"
}
],
"created": 1757857132,
"model": "llama3.2:3b-instruct-fp16",
"object": "text_completion",
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.completion.Completion",
"__data__": {
"id": "cmpl-439",
"choices": [
{
"finish_reason": null,
"index": 0,
"logprobs": null,
"text": " red"
}
],
"created": 1757857132,
"model": "llama3.2:3b-instruct-fp16",
"object": "text_completion",
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.completion.Completion",
"__data__": {
"id": "cmpl-439",
"choices": [
{
"finish_reason": null,
"index": 0,
"logprobs": null,
"text": ","
}
],
"created": 1757857132,
"model": "llama3.2:3b-instruct-fp16",
"object": "text_completion",
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.completion.Completion",
"__data__": {
"id": "cmpl-439",
"choices": [
{
"finish_reason": null,
"index": 0,
"logprobs": null,
"text": " v"
}
],
"created": 1757857132,
"model": "llama3.2:3b-instruct-fp16",
"object": "text_completion",
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.completion.Completion",
"__data__": {
"id": "cmpl-439",
"choices": [
{
"finish_reason": null,
"index": 0,
"logprobs": null,
"text": "io"
}
],
"created": 1757857132,
"model": "llama3.2:3b-instruct-fp16",
"object": "text_completion",
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.completion.Completion",
"__data__": {
"id": "cmpl-439",
"choices": [
{
"finish_reason": null,
"index": 0,
"logprobs": null,
"text": "lets"
}
],
"created": 1757857132,
"model": "llama3.2:3b-instruct-fp16",
"object": "text_completion",
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.completion.Completion",
"__data__": {
"id": "cmpl-439",
"choices": [
{
"finish_reason": null,
"index": 0,
"logprobs": null,
"text": " are"
}
],
"created": 1757857132,
"model": "llama3.2:3b-instruct-fp16",
"object": "text_completion",
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.completion.Completion",
"__data__": {
"id": "cmpl-439",
"choices": [
{
"finish_reason": null,
"index": 0,
"logprobs": null,
"text": " blue"
}
],
"created": 1757857132,
"model": "llama3.2:3b-instruct-fp16",
"object": "text_completion",
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.completion.Completion",
"__data__": {
"id": "cmpl-439",
"choices": [
{
"finish_reason": null,
"index": 0,
"logprobs": null,
"text": ",\n"
}
],
"created": 1757857132,
"model": "llama3.2:3b-instruct-fp16",
"object": "text_completion",
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.completion.Completion",
"__data__": {
"id": "cmpl-439",
"choices": [
{
"finish_reason": null,
"index": 0,
"logprobs": null,
"text": "Sugar"
}
],
"created": 1757857132,
"model": "llama3.2:3b-instruct-fp16",
"object": "text_completion",
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.completion.Completion",
"__data__": {
"id": "cmpl-439",
"choices": [
{
"finish_reason": null,
"index": 0,
"logprobs": null,
"text": " is"
}
],
"created": 1757857132,
"model": "llama3.2:3b-instruct-fp16",
"object": "text_completion",
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.completion.Completion",
"__data__": {
"id": "cmpl-439",
"choices": [
{
"finish_reason": null,
"index": 0,
"logprobs": null,
"text": " sweet"
}
],
"created": 1757857132,
"model": "llama3.2:3b-instruct-fp16",
"object": "text_completion",
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.completion.Completion",
"__data__": {
"id": "cmpl-439",
"choices": [
{
"finish_reason": null,
"index": 0,
"logprobs": null,
"text": ","
}
],
"created": 1757857132,
"model": "llama3.2:3b-instruct-fp16",
"object": "text_completion",
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.completion.Completion",
"__data__": {
"id": "cmpl-439",
"choices": [
{
"finish_reason": null,
"index": 0,
"logprobs": null,
"text": " and"
}
],
"created": 1757857132,
"model": "llama3.2:3b-instruct-fp16",
"object": "text_completion",
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.completion.Completion",
"__data__": {
"id": "cmpl-439",
"choices": [
{
"finish_reason": null,
"index": 0,
"logprobs": null,
"text": " so"
}
],
"created": 1757857132,
"model": "llama3.2:3b-instruct-fp16",
"object": "text_completion",
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.completion.Completion",
"__data__": {
"id": "cmpl-439",
"choices": [
{
"finish_reason": null,
"index": 0,
"logprobs": null,
"text": " are"
}
],
"created": 1757857132,
"model": "llama3.2:3b-instruct-fp16",
"object": "text_completion",
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.completion.Completion",
"__data__": {
"id": "cmpl-439",
"choices": [
{
"finish_reason": null,
"index": 0,
"logprobs": null,
"text": " you"
}
],
"created": 1757857132,
"model": "llama3.2:3b-instruct-fp16",
"object": "text_completion",
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.completion.Completion",
"__data__": {
"id": "cmpl-439",
"choices": [
{
"finish_reason": null,
"index": 0,
"logprobs": null,
"text": ".\""
}
],
"created": 1757857132,
"model": "llama3.2:3b-instruct-fp16",
"object": "text_completion",
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.completion.Completion",
"__data__": {
"id": "cmpl-439",
"choices": [
{
"finish_reason": null,
"index": 0,
"logprobs": null,
"text": " However"
}
],
"created": 1757857132,
"model": "llama3.2:3b-instruct-fp16",
"object": "text_completion",
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.completion.Completion",
"__data__": {
"id": "cmpl-439",
"choices": [
{
"finish_reason": null,
"index": 0,
"logprobs": null,
"text": ","
}
],
"created": 1757857132,
"model": "llama3.2:3b-instruct-fp16",
"object": "text_completion",
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.completion.Completion",
"__data__": {
"id": "cmpl-439",
"choices": [
{
"finish_reason": null,
"index": 0,
"logprobs": null,
"text": " in"
}
],
"created": 1757857132,
"model": "llama3.2:3b-instruct-fp16",
"object": "text_completion",
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.completion.Completion",
"__data__": {
"id": "cmpl-439",
"choices": [
{
"finish_reason": null,
"index": 0,
"logprobs": null,
"text": " many"
}
],
"created": 1757857132,
"model": "llama3.2:3b-instruct-fp16",
"object": "text_completion",
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.completion.Completion",
"__data__": {
"id": "cmpl-439",
"choices": [
{
"finish_reason": null,
"index": 0,
"logprobs": null,
"text": " variations"
}
],
"created": 1757857132,
"model": "llama3.2:3b-instruct-fp16",
"object": "text_completion",
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.completion.Completion",
"__data__": {
"id": "cmpl-439",
"choices": [
{
"finish_reason": null,
"index": 0,
"logprobs": null,
"text": " of"
}
],
"created": 1757857132,
"model": "llama3.2:3b-instruct-fp16",
"object": "text_completion",
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.completion.Completion",
"__data__": {
"id": "cmpl-439",
"choices": [
{
"finish_reason": null,
"index": 0,
"logprobs": null,
"text": " this"
}
],
"created": 1757857132,
"model": "llama3.2:3b-instruct-fp16",
"object": "text_completion",
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.completion.Completion",
"__data__": {
"id": "cmpl-439",
"choices": [
{
"finish_reason": null,
"index": 0,
"logprobs": null,
"text": " poem"
}
],
"created": 1757857132,
"model": "llama3.2:3b-instruct-fp16",
"object": "text_completion",
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.completion.Completion",
"__data__": {
"id": "cmpl-439",
"choices": [
{
"finish_reason": null,
"index": 0,
"logprobs": null,
"text": ","
}
],
"created": 1757857132,
"model": "llama3.2:3b-instruct-fp16",
"object": "text_completion",
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.completion.Completion",
"__data__": {
"id": "cmpl-439",
"choices": [
{
"finish_reason": null,
"index": 0,
"logprobs": null,
"text": " the"
}
],
"created": 1757857132,
"model": "llama3.2:3b-instruct-fp16",
"object": "text_completion",
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.completion.Completion",
"__data__": {
"id": "cmpl-439",
"choices": [
{
"finish_reason": null,
"index": 0,
"logprobs": null,
"text": " line"
}
],
"created": 1757857132,
"model": "llama3.2:3b-instruct-fp16",
"object": "text_completion",
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.completion.Completion",
"__data__": {
"id": "cmpl-439",
"choices": [
{
"finish_reason": null,
"index": 0,
"logprobs": null,
"text": " \""
}
],
"created": 1757857132,
"model": "llama3.2:3b-instruct-fp16",
"object": "text_completion",
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.completion.Completion",
"__data__": {
"id": "cmpl-439",
"choices": [
{
"finish_reason": null,
"index": 0,
"logprobs": null,
"text": "vio"
}
],
"created": 1757857132,
"model": "llama3.2:3b-instruct-fp16",
"object": "text_completion",
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.completion.Completion",
"__data__": {
"id": "cmpl-439",
"choices": [
{
"finish_reason": "length",
"index": 0,
"logprobs": null,
"text": ""
}
],
"created": 1757857132,
"model": "llama3.2:3b-instruct-fp16",
"object": "text_completion",
"system_fingerprint": "fp_ollama",
"usage": null
}
}
],
"is_streaming": true
}
}

View file

@ -0,0 +1,43 @@
{
"request": {
"method": "POST",
"url": "http://0.0.0.0:11434/v1/v1/completions",
"headers": {},
"body": {
"model": "llama3.2:3b-instruct-fp16",
"prompt": "Respond to this question and explain your answer. Complete the sentence using one word: Roses are red, violets are ",
"stream": false,
"extra_body": {}
},
"endpoint": "/v1/completions",
"model": "llama3.2:3b-instruct-fp16"
},
"response": {
"body": {
"__type__": "openai.types.completion.Completion",
"__data__": {
"id": "cmpl-104",
"choices": [
{
"finish_reason": "stop",
"index": 0,
"logprobs": null,
"text": "blue.\n\nI completed the sentence with \"blue\" because it is a common completion used to complete the traditional nursery rhyme, which ends with:\n\nRoses are red,\nViolets are blue.\n\nThe complete rhyme is often remembered and recited as follows:\n\nRoses are red,\nViolets are blue,\nSugar is sweet,\nAnd so are you!"
}
],
"created": 1757857132,
"model": "llama3.2:3b-instruct-fp16",
"object": "text_completion",
"system_fingerprint": "fp_ollama",
"usage": {
"completion_tokens": 72,
"prompt_tokens": 50,
"total_tokens": 122,
"completion_tokens_details": null,
"prompt_tokens_details": null
}
}
},
"is_streaming": false
}
}

View file

@ -0,0 +1,43 @@
{
"request": {
"method": "POST",
"url": "http://0.0.0.0:11434/v1/v1/completions",
"headers": {},
"body": {
"model": "llama3.2:3b-instruct-fp16",
"prompt": "Say completions",
"max_tokens": 20,
"extra_body": {}
},
"endpoint": "/v1/completions",
"model": "llama3.2:3b-instruct-fp16"
},
"response": {
"body": {
"__type__": "openai.types.completion.Completion",
"__data__": {
"id": "cmpl-406",
"choices": [
{
"finish_reason": "length",
"index": 0,
"logprobs": null,
"text": "Sure, I'd be happy to provide some definitions and examples of related words or phrases.\n\nTo better"
}
],
"created": 1757857133,
"model": "llama3.2:3b-instruct-fp16",
"object": "text_completion",
"system_fingerprint": "fp_ollama",
"usage": {
"completion_tokens": 20,
"prompt_tokens": 28,
"total_tokens": 48,
"completion_tokens_details": null,
"prompt_tokens_details": null
}
}
},
"is_streaming": false
}
}

View file

@ -115,6 +115,15 @@ SETUP_DEFINITIONS: dict[str, Setup] = {
"text_model": "cerebras/llama-3.3-70b", "text_model": "cerebras/llama-3.3-70b",
}, },
), ),
"fireworks": Setup(
name="fireworks",
description="Fireworks provider with a text model",
defaults={
"text_model": "accounts/fireworks/models/llama-v3p1-8b-instruct",
"vision_model": "accounts/fireworks/models/llama-v3p2-90b-vision-instruct",
"embedding_model": "nomic-ai/nomic-embed-text-v1.5",
},
),
} }
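
The new entry makes a Fireworks-backed setup resolvable by name alongside the existing ones; its defaults map the model roles (text, vision, embedding) to Fireworks model IDs. A tiny sketch of reading those defaults, assuming only the dict[str, Setup] shape declared above (how a setup is actually selected at runtime, e.g. via a CLI flag, is not shown here):

fireworks = SETUP_DEFINITIONS["fireworks"]
text_model = fireworks.defaults["text_model"]        # "accounts/fireworks/models/llama-v3p1-8b-instruct"
vision_model = fireworks.defaults["vision_model"]    # "accounts/fireworks/models/llama-v3p2-90b-vision-instruct"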

View file

@ -66,10 +66,9 @@ def base_config(tmp_path):
def provider_spec_yaml(): def provider_spec_yaml():
"""Common provider spec YAML for testing.""" """Common provider spec YAML for testing."""
return """ return """
adapter: adapter_type: test_provider
adapter_type: test_provider config_class: test_provider.config.TestProviderConfig
config_class: test_provider.config.TestProviderConfig module: test_provider
module: test_provider
api_dependencies: api_dependencies:
- safety - safety
""" """
@ -182,9 +181,9 @@ class TestProviderRegistry:
assert Api.inference in registry assert Api.inference in registry
assert "remote::test_provider" in registry[Api.inference] assert "remote::test_provider" in registry[Api.inference]
provider = registry[Api.inference]["remote::test_provider"] provider = registry[Api.inference]["remote::test_provider"]
assert provider.adapter.adapter_type == "test_provider" assert provider.adapter_type == "test_provider"
assert provider.adapter.module == "test_provider" assert provider.module == "test_provider"
assert provider.adapter.config_class == "test_provider.config.TestProviderConfig" assert provider.config_class == "test_provider.config.TestProviderConfig"
assert Api.safety in provider.api_dependencies assert Api.safety in provider.api_dependencies
def test_external_inline_providers(self, api_directories, mock_providers, base_config, inline_provider_spec_yaml): def test_external_inline_providers(self, api_directories, mock_providers, base_config, inline_provider_spec_yaml):
@ -246,8 +245,7 @@ class TestProviderRegistry:
"""Test handling of malformed remote provider spec (missing required fields).""" """Test handling of malformed remote provider spec (missing required fields)."""
remote_dir, _ = api_directories remote_dir, _ = api_directories
malformed_spec = """ malformed_spec = """
adapter: adapter_type: test_provider
adapter_type: test_provider
# Missing required fields # Missing required fields
api_dependencies: api_dependencies:
- safety - safety
@ -270,7 +268,7 @@ pip_packages:
with open(inline_dir / "malformed.yaml", "w") as f: with open(inline_dir / "malformed.yaml", "w") as f:
f.write(malformed_spec) f.write(malformed_spec)
with pytest.raises(KeyError) as exc_info: with pytest.raises(ValidationError) as exc_info:
get_provider_registry(base_config) get_provider_registry(base_config)
assert "config_class" in str(exc_info.value) assert "config_class" in str(exc_info.value)

View file

@ -27,13 +27,17 @@ class TestLlamaStackAsLibraryClientAutoInitialization:
mock_impls = {} mock_impls = {}
mock_route_impls = RouteImpls({}) mock_route_impls = RouteImpls({})
async def mock_construct_stack(config, custom_provider_registry): class MockStack:
return mock_impls def __init__(self, config, custom_provider_registry=None):
self.impls = mock_impls
async def initialize(self):
pass
def mock_initialize_route_impls(impls): def mock_initialize_route_impls(impls):
return mock_route_impls return mock_route_impls
monkeypatch.setattr("llama_stack.core.library_client.construct_stack", mock_construct_stack) monkeypatch.setattr("llama_stack.core.library_client.Stack", MockStack)
monkeypatch.setattr("llama_stack.core.library_client.initialize_route_impls", mock_initialize_route_impls) monkeypatch.setattr("llama_stack.core.library_client.initialize_route_impls", mock_initialize_route_impls)
client = LlamaStackAsLibraryClient("ci-tests") client = LlamaStackAsLibraryClient("ci-tests")
@ -46,13 +50,17 @@ class TestLlamaStackAsLibraryClientAutoInitialization:
mock_impls = {} mock_impls = {}
mock_route_impls = RouteImpls({}) mock_route_impls = RouteImpls({})
async def mock_construct_stack(config, custom_provider_registry): class MockStack:
return mock_impls def __init__(self, config, custom_provider_registry=None):
self.impls = mock_impls
async def initialize(self):
pass
def mock_initialize_route_impls(impls): def mock_initialize_route_impls(impls):
return mock_route_impls return mock_route_impls
monkeypatch.setattr("llama_stack.core.library_client.construct_stack", mock_construct_stack) monkeypatch.setattr("llama_stack.core.library_client.Stack", MockStack)
monkeypatch.setattr("llama_stack.core.library_client.initialize_route_impls", mock_initialize_route_impls) monkeypatch.setattr("llama_stack.core.library_client.initialize_route_impls", mock_initialize_route_impls)
client = AsyncLlamaStackAsLibraryClient("ci-tests") client = AsyncLlamaStackAsLibraryClient("ci-tests")
@ -68,13 +76,17 @@ class TestLlamaStackAsLibraryClientAutoInitialization:
mock_impls = {} mock_impls = {}
mock_route_impls = RouteImpls({}) mock_route_impls = RouteImpls({})
async def mock_construct_stack(config, custom_provider_registry): class MockStack:
return mock_impls def __init__(self, config, custom_provider_registry=None):
self.impls = mock_impls
async def initialize(self):
pass
def mock_initialize_route_impls(impls): def mock_initialize_route_impls(impls):
return mock_route_impls return mock_route_impls
monkeypatch.setattr("llama_stack.core.library_client.construct_stack", mock_construct_stack) monkeypatch.setattr("llama_stack.core.library_client.Stack", MockStack)
monkeypatch.setattr("llama_stack.core.library_client.initialize_route_impls", mock_initialize_route_impls) monkeypatch.setattr("llama_stack.core.library_client.initialize_route_impls", mock_initialize_route_impls)
client = LlamaStackAsLibraryClient("ci-tests") client = LlamaStackAsLibraryClient("ci-tests")
@ -90,13 +102,17 @@ class TestLlamaStackAsLibraryClientAutoInitialization:
mock_impls = {} mock_impls = {}
mock_route_impls = RouteImpls({}) mock_route_impls = RouteImpls({})
async def mock_construct_stack(config, custom_provider_registry): class MockStack:
return mock_impls def __init__(self, config, custom_provider_registry=None):
self.impls = mock_impls
async def initialize(self):
pass
def mock_initialize_route_impls(impls): def mock_initialize_route_impls(impls):
return mock_route_impls return mock_route_impls
monkeypatch.setattr("llama_stack.core.library_client.construct_stack", mock_construct_stack) monkeypatch.setattr("llama_stack.core.library_client.Stack", MockStack)
monkeypatch.setattr("llama_stack.core.library_client.initialize_route_impls", mock_initialize_route_impls) monkeypatch.setattr("llama_stack.core.library_client.initialize_route_impls", mock_initialize_route_impls)
client = AsyncLlamaStackAsLibraryClient("ci-tests") client = AsyncLlamaStackAsLibraryClient("ci-tests")
@ -112,13 +128,17 @@ class TestLlamaStackAsLibraryClientAutoInitialization:
mock_impls = {} mock_impls = {}
mock_route_impls = RouteImpls({}) mock_route_impls = RouteImpls({})
async def mock_construct_stack(config, custom_provider_registry): class MockStack:
return mock_impls def __init__(self, config, custom_provider_registry=None):
self.impls = mock_impls
async def initialize(self):
pass
def mock_initialize_route_impls(impls): def mock_initialize_route_impls(impls):
return mock_route_impls return mock_route_impls
monkeypatch.setattr("llama_stack.core.library_client.construct_stack", mock_construct_stack) monkeypatch.setattr("llama_stack.core.library_client.Stack", MockStack)
monkeypatch.setattr("llama_stack.core.library_client.initialize_route_impls", mock_initialize_route_impls) monkeypatch.setattr("llama_stack.core.library_client.initialize_route_impls", mock_initialize_route_impls)
sync_client = LlamaStackAsLibraryClient("ci-tests") sync_client = LlamaStackAsLibraryClient("ci-tests")
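
Each of these tests now monkeypatches the Stack class itself rather than a construct_stack function, so the stand-in only needs to expose .impls and an async initialize(). The shared pattern, shown once (condensed from the hunks above; monkeypatch is the pytest fixture and the RouteImpls mock is unchanged):

mock_impls = {}

class MockStack:
    def __init__(self, config, custom_provider_registry=None):
        self.impls = mock_impls       # route implementations the client will expose

    async def initialize(self):
        pass                          # nothing to start when providers are mocked

monkeypatch.setattr("llama_stack.core.library_client.Stack", MockStack)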

View file

@ -26,7 +26,7 @@ async def test_authorized_fetch_with_where_sql_access_control(mock_get_authentic
db_path=tmp_dir + "/" + db_name, db_path=tmp_dir + "/" + db_name,
) )
) )
sqlstore = AuthorizedSqlStore(base_sqlstore) sqlstore = AuthorizedSqlStore(base_sqlstore, default_policy())
# Create table with access control # Create table with access control
await sqlstore.create_table( await sqlstore.create_table(
@ -56,24 +56,24 @@ async def test_authorized_fetch_with_where_sql_access_control(mock_get_authentic
mock_get_authenticated_user.return_value = admin_user mock_get_authenticated_user.return_value = admin_user
# Admin should see both documents # Admin should see both documents
result = await sqlstore.fetch_all("documents", policy=default_policy(), where={"id": 1}) result = await sqlstore.fetch_all("documents", where={"id": 1})
assert len(result.data) == 1 assert len(result.data) == 1
assert result.data[0]["title"] == "Admin Document" assert result.data[0]["title"] == "Admin Document"
# User should only see their document # User should only see their document
mock_get_authenticated_user.return_value = regular_user mock_get_authenticated_user.return_value = regular_user
result = await sqlstore.fetch_all("documents", policy=default_policy(), where={"id": 1}) result = await sqlstore.fetch_all("documents", where={"id": 1})
assert len(result.data) == 0 assert len(result.data) == 0
result = await sqlstore.fetch_all("documents", policy=default_policy(), where={"id": 2}) result = await sqlstore.fetch_all("documents", where={"id": 2})
assert len(result.data) == 1 assert len(result.data) == 1
assert result.data[0]["title"] == "User Document" assert result.data[0]["title"] == "User Document"
row = await sqlstore.fetch_one("documents", policy=default_policy(), where={"id": 1}) row = await sqlstore.fetch_one("documents", where={"id": 1})
assert row is None assert row is None
row = await sqlstore.fetch_one("documents", policy=default_policy(), where={"id": 2}) row = await sqlstore.fetch_one("documents", where={"id": 2})
assert row is not None assert row is not None
assert row["title"] == "User Document" assert row["title"] == "User Document"
@ -88,7 +88,7 @@ async def test_sql_policy_consistency(mock_get_authenticated_user):
db_path=tmp_dir + "/" + db_name, db_path=tmp_dir + "/" + db_name,
) )
) )
sqlstore = AuthorizedSqlStore(base_sqlstore) sqlstore = AuthorizedSqlStore(base_sqlstore, default_policy())
await sqlstore.create_table( await sqlstore.create_table(
table="resources", table="resources",
@ -144,7 +144,7 @@ async def test_sql_policy_consistency(mock_get_authenticated_user):
user = User(principal=user_data["principal"], attributes=user_data["attributes"]) user = User(principal=user_data["principal"], attributes=user_data["attributes"])
mock_get_authenticated_user.return_value = user mock_get_authenticated_user.return_value = user
sql_results = await sqlstore.fetch_all("resources", policy=policy) sql_results = await sqlstore.fetch_all("resources")
sql_ids = {row["id"] for row in sql_results.data} sql_ids = {row["id"] for row in sql_results.data}
policy_ids = set() policy_ids = set()
for scenario in test_scenarios: for scenario in test_scenarios:
@ -174,7 +174,7 @@ async def test_authorized_store_user_attribute_capture(mock_get_authenticated_us
db_path=tmp_dir + "/" + db_name, db_path=tmp_dir + "/" + db_name,
) )
) )
authorized_store = AuthorizedSqlStore(base_sqlstore) authorized_store = AuthorizedSqlStore(base_sqlstore, default_policy())
await authorized_store.create_table( await authorized_store.create_table(
table="user_data", table="user_data",