Merge branch 'main' into use-openai-for-cerebras

Matthew Farrellee 2025-09-23 07:31:11 -04:00
commit 9ceb45f611
53 changed files with 2612 additions and 1966 deletions

View file

@ -24,7 +24,7 @@ jobs:
uses: actions/checkout@08c6903cd8c0fde910a37f88322edcfb5dd907a8 # v5.0.0
- name: Install uv
uses: astral-sh/setup-uv@557e51de59eb14aaaba2ed9621916900a91d50c6 # v6.6.1
uses: astral-sh/setup-uv@b75a909f75acd358c2196fb9a5f1299a9a8868a4 # v6.7.0
with:
python-version: ${{ matrix.python-version }}
activate-environment: true

View file

@ -17,11 +17,8 @@ export POSTGRES_PASSWORD=llamastack
export INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct
export SAFETY_MODEL=meta-llama/Llama-Guard-3-1B
export MOCK_INFERENCE_MODEL=mock-inference
export MOCK_INFERENCE_URL=openai-mock-service:8080
export BENCHMARK_INFERENCE_MODEL=$INFERENCE_MODEL
export LLAMA_STACK_WORKERS=4
set -euo pipefail
set -x

View file

@ -5,6 +5,7 @@ data:
image_name: kubernetes-benchmark-demo
apis:
- agents
- files
- inference
- files
- safety
@ -23,6 +24,14 @@ data:
- provider_id: sentence-transformers
provider_type: inline::sentence-transformers
config: {}
files:
- provider_id: meta-reference-files
provider_type: inline::localfs
config:
storage_dir: ${env.FILES_STORAGE_DIR:=~/.llama/distributions/starter/files}
metadata_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/files_metadata.db
vector_io:
- provider_id: ${env.ENABLE_CHROMADB:+chromadb}
provider_type: remote::chromadb

View file

@ -52,9 +52,20 @@ spec:
value: http://vllm-server-safety.default.svc.cluster.local:8001/v1
- name: VLLM_TLS_VERIFY
value: "false"
command: ["python", "-m", "llama_stack.core.server.server", "/etc/config/stack_run_config.yaml", "--port", "8323"]
- name: LLAMA_STACK_LOGGING
value: "all=WARNING"
- name: LLAMA_STACK_CONFIG
value: "/etc/config/stack_run_config.yaml"
- name: LLAMA_STACK_WORKERS
value: "${LLAMA_STACK_WORKERS}"
command: ["uvicorn", "llama_stack.core.server.server:create_app", "--host", "0.0.0.0", "--port", "8323", "--workers", "$LLAMA_STACK_WORKERS", "--factory"]
ports:
- containerPort: 8323
resources:
requests:
cpu: "${LLAMA_STACK_WORKERS}"
limits:
cpu: "${LLAMA_STACK_WORKERS}"
volumeMounts:
- name: llama-storage
mountPath: /root/.llama
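
The command change in this hunk switches the server container from invoking the module directly to running uvicorn against the new `create_app` factory, with the worker count (and the CPU requests/limits) driven by `LLAMA_STACK_WORKERS`. A minimal sketch of the programmatic equivalent, assuming the same module path, port, and environment variables as the manifest:

```python
# Sketch only: mirrors the container command above; LLAMA_STACK_CONFIG and
# LLAMA_STACK_WORKERS are expected in the environment, as in the Deployment.
import os

import uvicorn

if __name__ == "__main__":
    uvicorn.run(
        "llama_stack.core.server.server:create_app",  # import string, resolved in each worker
        factory=True,                                  # call create_app() to build the app
        host="0.0.0.0",
        port=8323,
        workers=int(os.environ.get("LLAMA_STACK_WORKERS", "1")),
    )
```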

View file

@ -11,6 +11,7 @@ The `llamastack/distribution-nvidia` distribution consists of the following prov
| agents | `inline::meta-reference` |
| datasetio | `inline::localfs`, `remote::nvidia` |
| eval | `remote::nvidia` |
| files | `inline::localfs` |
| inference | `remote::nvidia` |
| post_training | `remote::nvidia` |
| safety | `remote::nvidia` |

View file

@ -460,10 +460,12 @@ client = LlamaStackClient(base_url="http://localhost:8321")
embed_lm = next(m for m in client.models.list() if m.model_type == "embedding")
embedding_model = embed_lm.identifier
vector_db_id = f"v{uuid.uuid4().hex}"
client.vector_dbs.register(
# The VectorDB API is deprecated; the server now returns its own authoritative ID.
# We capture the correct ID from the response's .identifier attribute.
vector_db_id = client.vector_dbs.register(
vector_db_id=vector_db_id,
embedding_model=embedding_model,
)
).identifier
# Create Documents
urls = [

View file

@ -23,7 +23,13 @@ To use Milvus in your Llama Stack project, follow these steps:
## Installation
You can install Milvus using pymilvus:
If you want to use inline Milvus, you can install:
```bash
pip install pymilvus[milvus-lite]
```
If you want to use remote Milvus, you can install:
```bash
pip install pymilvus

View file

@ -147,7 +147,7 @@ WORKDIR /app
RUN dnf -y update && dnf install -y iputils git net-tools wget \
vim-minimal python3.12 python3.12-pip python3.12-wheel \
python3.12-setuptools python3.12-devel gcc make && \
python3.12-setuptools python3.12-devel gcc gcc-c++ make && \
ln -s /bin/pip3.12 /bin/pip && ln -s /bin/python3.12 /bin/python && dnf clean all
ENV UV_SYSTEM_PYTHON=1
@ -164,7 +164,7 @@ RUN apt-get update && apt-get install -y \
procps psmisc lsof \
traceroute \
bubblewrap \
gcc \
gcc g++ \
&& rm -rf /var/lib/apt/lists/*
ENV UV_SYSTEM_PYTHON=1

View file

@ -121,10 +121,6 @@ class AutoRoutedProviderSpec(ProviderSpec):
default=None,
)
@property
def pip_packages(self) -> list[str]:
raise AssertionError("Should not be called on AutoRoutedProviderSpec")
# Example: /models, /shields
class RoutingTableProviderSpec(ProviderSpec):

View file

@ -16,11 +16,10 @@ from llama_stack.core.datatypes import BuildConfig, DistributionSpec
from llama_stack.core.external import load_external_apis
from llama_stack.log import get_logger
from llama_stack.providers.datatypes import (
AdapterSpec,
Api,
InlineProviderSpec,
ProviderSpec,
remote_provider_spec,
RemoteProviderSpec,
)
logger = get_logger(name=__name__, category="core")
@ -77,27 +76,12 @@ def providable_apis() -> list[Api]:
def _load_remote_provider_spec(spec_data: dict[str, Any], api: Api) -> ProviderSpec:
adapter = AdapterSpec(**spec_data["adapter"])
spec = remote_provider_spec(
api=api,
adapter=adapter,
api_dependencies=[Api(dep) for dep in spec_data.get("api_dependencies", [])],
)
spec = RemoteProviderSpec(api=api, provider_type=f"remote::{spec_data['adapter_type']}", **spec_data)
return spec
def _load_inline_provider_spec(spec_data: dict[str, Any], api: Api, provider_name: str) -> ProviderSpec:
spec = InlineProviderSpec(
api=api,
provider_type=f"inline::{provider_name}",
pip_packages=spec_data.get("pip_packages", []),
module=spec_data["module"],
config_class=spec_data["config_class"],
api_dependencies=[Api(dep) for dep in spec_data.get("api_dependencies", [])],
optional_api_dependencies=[Api(dep) for dep in spec_data.get("optional_api_dependencies", [])],
provider_data_validator=spec_data.get("provider_data_validator"),
container_image=spec_data.get("container_image"),
)
spec = InlineProviderSpec(api=api, provider_type=f"inline::{provider_name}", **spec_data)
return spec
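
The simplified loaders above now pass an external provider's spec dict straight into the Pydantic models rather than copying fields one by one. A hedged sketch of the kind of `spec_data` this expects, with hypothetical values (field names follow the diff):

```python
# Hypothetical external remote-provider spec: _load_remote_provider_spec expands
# it with **spec_data and derives provider_type from adapter_type.
spec_data = {
    "adapter_type": "acme",                                       # -> provider_type="remote::acme"
    "module": "acme_llama_stack.inference",                       # must expose get_adapter_impl(config, deps)
    "config_class": "acme_llama_stack.inference.AcmeImplConfig",
    "pip_packages": ["acme-sdk"],
    "description": "Hypothetical external inference adapter.",
}
```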

View file

@ -40,7 +40,7 @@ from llama_stack.core.request_headers import (
from llama_stack.core.resolver import ProviderRegistry
from llama_stack.core.server.routes import RouteImpls, find_matching_route, initialize_route_impls
from llama_stack.core.stack import (
construct_stack,
Stack,
get_stack_run_config_from_distro,
replace_env_vars,
)
@ -252,7 +252,10 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
try:
self.route_impls = None
self.impls = await construct_stack(self.config, self.custom_provider_registry)
stack = Stack(self.config, self.custom_provider_registry)
await stack.initialize()
self.impls = stack.impls
except ModuleNotFoundError as _e:
cprint(_e.msg, color="red", file=sys.stderr)
cprint(
@ -289,6 +292,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
)
raise _e
assert self.impls is not None
if Api.telemetry in self.impls:
setup_logger(self.impls[Api.telemetry])
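
With this change the library client builds its implementations through the new `Stack` wrapper instead of the removed `construct_stack` helper; the calling pattern is unchanged. A hedged usage sketch (the distro name and import path are assumptions):

```python
# Sketch only: initialize() internally runs Stack(config).initialize() as shown above.
import asyncio

from llama_stack.core.library_client import AsyncLlamaStackAsLibraryClient  # assumed module path


async def main() -> None:
    client = AsyncLlamaStackAsLibraryClient("starter")  # hypothetical distro name
    await client.initialize()
    print([m.identifier for m in await client.models.list()])


asyncio.run(main())
```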

View file

@ -33,7 +33,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
try:
models = await provider.list_models()
except Exception as e:
logger.exception(f"Model refresh failed for provider {provider_id}: {e}")
logger.warning(f"Model refresh failed for provider {provider_id}: {e}")
continue
self.listed_providers.add(provider_id)

View file

@ -6,6 +6,7 @@
import argparse
import asyncio
import concurrent.futures
import functools
import inspect
import json
@ -50,17 +51,15 @@ from llama_stack.core.request_headers import (
request_provider_data_context,
user_from_scope,
)
from llama_stack.core.resolver import InvalidProviderError
from llama_stack.core.server.routes import (
find_matching_route,
get_all_api_routes,
initialize_route_impls,
)
from llama_stack.core.stack import (
Stack,
cast_image_name_to_string,
construct_stack,
replace_env_vars,
shutdown_stack,
validate_env_pair,
)
from llama_stack.core.utils.config import redact_sensitive_fields
@ -156,21 +155,34 @@ def translate_exception(exc: Exception) -> HTTPException | RequestValidationErro
)
async def shutdown(app):
"""Initiate a graceful shutdown of the application.
Handled by the lifespan context manager. The shutdown process involves
shutting down all implementations registered in the application.
class StackApp(FastAPI):
"""
await shutdown_stack(app.__llama_stack_impls__)
A wrapper around the FastAPI application to hold a reference to the Stack instance so that we can
start background tasks (e.g. refresh model registry periodically) from the lifespan context manager.
"""
def __init__(self, config: StackRunConfig, *args, **kwargs):
super().__init__(*args, **kwargs)
self.stack: Stack = Stack(config)
# This code is called from a running event loop managed by uvicorn so we cannot simply call
# asyncio.run() to initialize the stack. We cannot await either since this is not an async
# function.
# As a workaround, we use a thread pool executor to run the initialize() method
# in a separate thread.
with concurrent.futures.ThreadPoolExecutor() as executor:
future = executor.submit(asyncio.run, self.stack.initialize())
future.result()
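
The comment above captures the constraint: `StackApp.__init__` is invoked from within uvicorn's already-running event loop, where `asyncio.run()` raises and `await` is unavailable. A self-contained sketch of the same workaround, independent of the server code:

```python
# Minimal sketch: run a coroutine to completion from synchronous code even when
# the calling thread may already host a running event loop. The worker thread
# gets its own fresh loop via asyncio.run().
import asyncio
import concurrent.futures


async def initialize() -> str:
    await asyncio.sleep(0)  # stand-in for the real async setup work
    return "initialized"


def blocking_init() -> str:
    with concurrent.futures.ThreadPoolExecutor() as executor:
        future = executor.submit(asyncio.run, initialize())
        return future.result()


print(blocking_init())  # -> initialized
```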
@asynccontextmanager
async def lifespan(app: FastAPI):
async def lifespan(app: StackApp):
logger.info("Starting up")
assert app.stack is not None
app.stack.create_registry_refresh_task()
yield
logger.info("Shutting down")
await shutdown(app)
await app.stack.shutdown()
def is_streaming_request(func_name: str, request: Request, **kwargs):
@ -386,73 +398,61 @@ class ClientVersionMiddleware:
return await self.app(scope, receive, send)
def main(args: argparse.Namespace | None = None):
"""Start the LlamaStack server."""
parser = argparse.ArgumentParser(description="Start the LlamaStack server.")
def create_app(
config_file: str | None = None,
env_vars: list[str] | None = None,
) -> StackApp:
"""Create and configure the FastAPI application.
add_config_distro_args(parser)
parser.add_argument(
"--port",
type=int,
default=int(os.getenv("LLAMA_STACK_PORT", 8321)),
help="Port to listen on",
)
parser.add_argument(
"--env",
action="append",
help="Environment variables in KEY=value format. Can be specified multiple times.",
)
Args:
config_file: Path to config file. If None, uses LLAMA_STACK_CONFIG env var or default resolution.
env_vars: List of environment variables in KEY=value format.
disable_version_check: Whether to disable version checking. If None, uses LLAMA_STACK_DISABLE_VERSION_CHECK env var.
# Determine whether the server args are being passed by the "run" command, if this is the case
# the args will be passed as a Namespace object to the main function, otherwise they will be
# parsed from the command line
if args is None:
args = parser.parse_args()
Returns:
Configured StackApp instance.
"""
config_file = config_file or os.getenv("LLAMA_STACK_CONFIG")
if config_file is None:
raise ValueError("No config file provided and LLAMA_STACK_CONFIG env var is not set")
config_or_distro = get_config_from_args(args)
config_file = resolve_config_or_distro(config_or_distro, Mode.RUN)
config_file = resolve_config_or_distro(config_file, Mode.RUN)
# Load and process configuration
logger_config = None
with open(config_file) as fp:
config_contents = yaml.safe_load(fp)
if isinstance(config_contents, dict) and (cfg := config_contents.get("logging_config")):
logger_config = LoggingConfig(**cfg)
logger = get_logger(name=__name__, category="core::server", config=logger_config)
if args.env:
for env_pair in args.env:
if env_vars:
for env_pair in env_vars:
try:
key, value = validate_env_pair(env_pair)
logger.info(f"Setting CLI environment variable {key} => {value}")
logger.info(f"Setting environment variable {key} => {value}")
os.environ[key] = value
except ValueError as e:
logger.error(f"Error: {str(e)}")
sys.exit(1)
raise ValueError(f"Invalid environment variable format: {env_pair}") from e
config = replace_env_vars(config_contents)
config = StackRunConfig(**cast_image_name_to_string(config))
_log_run_config(run_config=config)
app = FastAPI(
app = StackApp(
lifespan=lifespan,
docs_url="/docs",
redoc_url="/redoc",
openapi_url="/openapi.json",
config=config,
)
if not os.environ.get("LLAMA_STACK_DISABLE_VERSION_CHECK"):
app.add_middleware(ClientVersionMiddleware)
try:
# Create and set the event loop that will be used for both construction and server runtime
loop = asyncio.new_event_loop()
asyncio.set_event_loop(loop)
# Construct the stack in the persistent event loop
impls = loop.run_until_complete(construct_stack(config))
except InvalidProviderError as e:
logger.error(f"Error: {str(e)}")
sys.exit(1)
impls = app.stack.impls
if config.server.auth:
logger.info(f"Enabling authentication with provider: {config.server.auth.provider_config.type.value}")
@ -553,9 +553,54 @@ def main(args: argparse.Namespace | None = None):
app.exception_handler(RequestValidationError)(global_exception_handler)
app.exception_handler(Exception)(global_exception_handler)
app.__llama_stack_impls__ = impls
app.add_middleware(TracingMiddleware, impls=impls, external_apis=external_apis)
return app
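
`create_app` is now the single construction path, used both by `main()` below and by external process managers such as `uvicorn --factory`. A hedged sketch of calling it directly (the config path and env var are placeholders):

```python
# Sketch only: builds the fully-initialized ASGI app without starting a server,
# e.g. for tests. The file path and variable below are hypothetical.
from llama_stack.core.server.server import create_app

app = create_app(
    config_file="run.yaml",
    env_vars=["INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct"],
)
```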
def main(args: argparse.Namespace | None = None):
"""Start the LlamaStack server."""
parser = argparse.ArgumentParser(description="Start the LlamaStack server.")
add_config_distro_args(parser)
parser.add_argument(
"--port",
type=int,
default=int(os.getenv("LLAMA_STACK_PORT", 8321)),
help="Port to listen on",
)
parser.add_argument(
"--env",
action="append",
help="Environment variables in KEY=value format. Can be specified multiple times.",
)
# Determine whether the server args are being passed by the "run" command, if this is the case
# the args will be passed as a Namespace object to the main function, otherwise they will be
# parsed from the command line
if args is None:
args = parser.parse_args()
config_or_distro = get_config_from_args(args)
try:
app = create_app(
config_file=config_or_distro,
env_vars=args.env,
)
except Exception as e:
logger.error(f"Error creating app: {str(e)}")
sys.exit(1)
config_file = resolve_config_or_distro(config_or_distro, Mode.RUN)
with open(config_file) as fp:
config_contents = yaml.safe_load(fp)
if isinstance(config_contents, dict) and (cfg := config_contents.get("logging_config")):
logger_config = LoggingConfig(**cfg)
else:
logger_config = None
config = StackRunConfig(**cast_image_name_to_string(replace_env_vars(config_contents)))
import uvicorn
# Configure SSL if certificates are provided
@ -593,7 +638,6 @@ def main(args: argparse.Namespace | None = None):
if ssl_config:
uvicorn_config.update(ssl_config)
# Run uvicorn in the existing event loop to preserve background tasks
# We need to catch KeyboardInterrupt because uvicorn's signal handling
# re-raises SIGINT signals using signal.raise_signal(), which Python
# converts to KeyboardInterrupt. Without this catch, we'd get a confusing
@ -604,13 +648,9 @@ def main(args: argparse.Namespace | None = None):
# Another approach would be to ignore SIGINT entirely - let uvicorn handle it through its own
# signal handling but this is quite intrusive and not worth the effort.
try:
loop.run_until_complete(uvicorn.Server(uvicorn.Config(**uvicorn_config)).serve())
asyncio.run(uvicorn.Server(uvicorn.Config(**uvicorn_config)).serve())
except (KeyboardInterrupt, SystemExit):
logger.info("Received interrupt signal, shutting down gracefully...")
finally:
if not loop.is_closed():
logger.debug("Closing event loop")
loop.close()
def _log_run_config(run_config: StackRunConfig):

View file

@ -315,11 +315,15 @@ def add_internal_implementations(impls: dict[Api, Any], run_config: StackRunConf
impls[Api.prompts] = prompts_impl
class Stack:
def __init__(self, run_config: StackRunConfig, provider_registry: ProviderRegistry | None = None):
self.run_config = run_config
self.provider_registry = provider_registry
self.impls = None
# Produces a stack of providers for the given run config. Not all APIs may be
# asked for in the run config.
async def construct_stack(
run_config: StackRunConfig, provider_registry: ProviderRegistry | None = None
) -> dict[Api, Any]:
async def initialize(self):
if "LLAMA_STACK_TEST_INFERENCE_MODE" in os.environ:
from llama_stack.testing.inference_recorder import setup_inference_recording
@ -329,24 +333,28 @@ async def construct_stack(
TEST_RECORDING_CONTEXT.__enter__()
logger.info(f"Inference recording enabled: mode={os.environ.get('LLAMA_STACK_TEST_INFERENCE_MODE')}")
dist_registry, _ = await create_dist_registry(run_config.metadata_store, run_config.image_name)
policy = run_config.server.auth.access_policy if run_config.server.auth else []
dist_registry, _ = await create_dist_registry(self.run_config.metadata_store, self.run_config.image_name)
policy = self.run_config.server.auth.access_policy if self.run_config.server.auth else []
impls = await resolve_impls(
run_config, provider_registry or get_provider_registry(run_config), dist_registry, policy
self.run_config, self.provider_registry or get_provider_registry(self.run_config), dist_registry, policy
)
# Add internal implementations after all other providers are resolved
add_internal_implementations(impls, run_config)
add_internal_implementations(impls, self.run_config)
if Api.prompts in impls:
await impls[Api.prompts].initialize()
await register_resources(run_config, impls)
await register_resources(self.run_config, impls)
await refresh_registry_once(impls)
self.impls = impls
def create_registry_refresh_task(self):
assert self.impls is not None, "Must call initialize() before starting"
global REGISTRY_REFRESH_TASK
REGISTRY_REFRESH_TASK = asyncio.create_task(refresh_registry_task(impls))
REGISTRY_REFRESH_TASK = asyncio.create_task(refresh_registry_task(self.impls))
def cb(task):
import traceback
@ -360,11 +368,9 @@ async def construct_stack(
logger.debug("Model refresh task completed")
REGISTRY_REFRESH_TASK.add_done_callback(cb)
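
The wiring above follows the standard pattern for long-lived background work: create the task once the implementations exist, then attach a done-callback so cancellation or failure is logged instead of vanishing silently. A minimal standalone sketch of that pattern:

```python
# Sketch of the create_task + add_done_callback pattern used by
# create_registry_refresh_task; the loop body is a stand-in.
import asyncio


async def refresh_registry_forever() -> None:
    while True:
        await asyncio.sleep(60)  # stand-in for the real periodic refresh


def start_refresh_task() -> asyncio.Task:
    task = asyncio.create_task(refresh_registry_forever())

    def cb(t: asyncio.Task) -> None:
        if t.cancelled():
            print("registry refresh task cancelled")
        elif t.exception() is not None:
            print(f"registry refresh task failed: {t.exception()!r}")
        else:
            print("registry refresh task completed")

    task.add_done_callback(cb)
    return task
```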
return impls
async def shutdown_stack(impls: dict[Api, Any]):
for impl in impls.values():
async def shutdown(self):
for impl in self.impls.values():
impl_name = impl.__class__.__name__
logger.info(f"Shutting down {impl_name}")
try:

View file

@ -23,6 +23,8 @@ distribution_spec:
- provider_type: inline::basic
tool_runtime:
- provider_type: inline::rag-runtime
files:
- provider_type: inline::localfs
image_type: venv
additional_pip_packages:
- aiosqlite

View file

@ -8,6 +8,7 @@ from pathlib import Path
from llama_stack.core.datatypes import BuildProvider, ModelInput, Provider, ShieldInput, ToolGroupInput
from llama_stack.distributions.template import DistributionTemplate, RunConfigSettings, get_model_registry
from llama_stack.providers.inline.files.localfs.config import LocalfsFilesImplConfig
from llama_stack.providers.remote.datasetio.nvidia import NvidiaDatasetIOConfig
from llama_stack.providers.remote.eval.nvidia import NVIDIAEvalConfig
from llama_stack.providers.remote.inference.nvidia import NVIDIAConfig
@ -15,7 +16,7 @@ from llama_stack.providers.remote.inference.nvidia.models import MODEL_ENTRIES
from llama_stack.providers.remote.safety.nvidia import NVIDIASafetyConfig
def get_distribution_template() -> DistributionTemplate:
def get_distribution_template(name: str = "nvidia") -> DistributionTemplate:
providers = {
"inference": [BuildProvider(provider_type="remote::nvidia")],
"vector_io": [BuildProvider(provider_type="inline::faiss")],
@ -30,6 +31,7 @@ def get_distribution_template() -> DistributionTemplate:
],
"scoring": [BuildProvider(provider_type="inline::basic")],
"tool_runtime": [BuildProvider(provider_type="inline::rag-runtime")],
"files": [BuildProvider(provider_type="inline::localfs")],
}
inference_provider = Provider(
@ -52,6 +54,11 @@ def get_distribution_template() -> DistributionTemplate:
provider_type="remote::nvidia",
config=NVIDIAEvalConfig.sample_run_config(),
)
files_provider = Provider(
provider_id="meta-reference-files",
provider_type="inline::localfs",
config=LocalfsFilesImplConfig.sample_run_config(f"~/.llama/distributions/{name}"),
)
inference_model = ModelInput(
model_id="${env.INFERENCE_MODEL}",
provider_id="nvidia",
@ -73,7 +80,7 @@ def get_distribution_template() -> DistributionTemplate:
default_models, _ = get_model_registry(available_models)
return DistributionTemplate(
name="nvidia",
name=name,
distro_type="self_hosted",
description="Use NVIDIA NIM for running LLM inference, evaluation and safety",
container_image=None,
@ -86,6 +93,7 @@ def get_distribution_template() -> DistributionTemplate:
"inference": [inference_provider],
"datasetio": [datasetio_provider],
"eval": [eval_provider],
"files": [files_provider],
},
default_models=default_models,
default_tool_groups=default_tool_groups,
@ -97,6 +105,7 @@ def get_distribution_template() -> DistributionTemplate:
safety_provider,
],
"eval": [eval_provider],
"files": [files_provider],
},
default_models=[inference_model, safety_model],
default_shields=[ShieldInput(shield_id="${env.SAFETY_MODEL}", provider_id="nvidia")],

View file

@ -4,6 +4,7 @@ apis:
- agents
- datasetio
- eval
- files
- inference
- post_training
- safety
@ -88,6 +89,14 @@ providers:
tool_runtime:
- provider_id: rag-runtime
provider_type: inline::rag-runtime
files:
- provider_id: meta-reference-files
provider_type: inline::localfs
config:
storage_dir: ${env.FILES_STORAGE_DIR:=~/.llama/distributions/nvidia/files}
metadata_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/nvidia}/files_metadata.db
metadata_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/nvidia}/registry.db

View file

@ -4,6 +4,7 @@ apis:
- agents
- datasetio
- eval
- files
- inference
- post_training
- safety
@ -77,6 +78,14 @@ providers:
tool_runtime:
- provider_id: rag-runtime
provider_type: inline::rag-runtime
files:
- provider_id: meta-reference-files
provider_type: inline::localfs
config:
storage_dir: ${env.FILES_STORAGE_DIR:=~/.llama/distributions/nvidia/files}
metadata_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/nvidia}/files_metadata.db
metadata_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/nvidia}/registry.db

View file

@ -78,12 +78,12 @@ def get_remote_inference_providers() -> list[Provider]:
remote_providers = [
provider
for provider in available_providers()
if isinstance(provider, RemoteProviderSpec) and provider.adapter.adapter_type in ENABLED_INFERENCE_PROVIDERS
if isinstance(provider, RemoteProviderSpec) and provider.adapter_type in ENABLED_INFERENCE_PROVIDERS
]
inference_providers = []
for provider_spec in remote_providers:
provider_type = provider_spec.adapter.adapter_type
provider_type = provider_spec.adapter_type
if provider_type in INFERENCE_PROVIDER_IDS:
provider_id = INFERENCE_PROVIDER_IDS[provider_type]

View file

@ -131,6 +131,15 @@ class ProviderSpec(BaseModel):
""",
)
pip_packages: list[str] = Field(
default_factory=list,
description="The pip dependencies needed for this implementation",
)
provider_data_validator: str | None = Field(
default=None,
)
is_external: bool = Field(default=False, description="Notes whether this provider is an external provider.")
# used internally by the resolver; this is a hack for now
@ -145,45 +154,8 @@ class RoutingTable(Protocol):
async def get_provider_impl(self, routing_key: str) -> Any: ...
# TODO: this can now be inlined into RemoteProviderSpec
@json_schema_type
class AdapterSpec(BaseModel):
adapter_type: str = Field(
...,
description="Unique identifier for this adapter",
)
module: str = Field(
default_factory=str,
description="""
Fully-qualified name of the module to import. The module is expected to have:
- `get_adapter_impl(config, deps)`: returns the adapter implementation
""",
)
pip_packages: list[str] = Field(
default_factory=list,
description="The pip dependencies needed for this implementation",
)
config_class: str = Field(
description="Fully-qualified classname of the config for this provider",
)
provider_data_validator: str | None = Field(
default=None,
)
description: str | None = Field(
default=None,
description="""
A description of the provider. This is used to display in the documentation.
""",
)
@json_schema_type
class InlineProviderSpec(ProviderSpec):
pip_packages: list[str] = Field(
default_factory=list,
description="The pip dependencies needed for this implementation",
)
container_image: str | None = Field(
default=None,
description="""
@ -191,10 +163,6 @@ The container image to use for this implementation. If one is provided, pip_pack
If a provider depends on other providers, the dependencies MUST NOT specify a container image.
""",
)
# module field is inherited from ProviderSpec
provider_data_validator: str | None = Field(
default=None,
)
description: str | None = Field(
default=None,
description="""
@ -223,10 +191,15 @@ class RemoteProviderConfig(BaseModel):
@json_schema_type
class RemoteProviderSpec(ProviderSpec):
adapter: AdapterSpec = Field(
adapter_type: str = Field(
...,
description="Unique identifier for this adapter",
)
description: str | None = Field(
default=None,
description="""
If some code is needed to convert the remote responses into Llama Stack compatible
API responses, specify the adapter here.
A description of the provider. This is used to display in the documentation.
""",
)
@ -234,33 +207,6 @@ API responses, specify the adapter here.
def container_image(self) -> str | None:
return None
# module field is inherited from ProviderSpec
@property
def pip_packages(self) -> list[str]:
return self.adapter.pip_packages
@property
def provider_data_validator(self) -> str | None:
return self.adapter.provider_data_validator
def remote_provider_spec(
api: Api,
adapter: AdapterSpec,
api_dependencies: list[Api] | None = None,
optional_api_dependencies: list[Api] | None = None,
) -> RemoteProviderSpec:
return RemoteProviderSpec(
api=api,
provider_type=f"remote::{adapter.adapter_type}",
config_class=adapter.config_class,
module=adapter.module,
adapter=adapter,
api_dependencies=api_dependencies or [],
optional_api_dependencies=optional_api_dependencies or [],
)
class HealthStatus(StrEnum):
OK = "OK"

View file

@ -44,7 +44,7 @@ class LocalfsFilesImpl(Files):
storage_path.mkdir(parents=True, exist_ok=True)
# Initialize SQL store for metadata
self.sql_store = AuthorizedSqlStore(sqlstore_impl(self.config.metadata_store))
self.sql_store = AuthorizedSqlStore(sqlstore_impl(self.config.metadata_store), self.policy)
await self.sql_store.create_table(
"openai_files",
{
@ -74,7 +74,7 @@ class LocalfsFilesImpl(Files):
if not self.sql_store:
raise RuntimeError("Files provider not initialized")
row = await self.sql_store.fetch_one("openai_files", policy=self.policy, where={"id": file_id})
row = await self.sql_store.fetch_one("openai_files", where={"id": file_id})
if not row:
raise ResourceNotFoundError(file_id, "File", "client.files.list()")
@ -150,7 +150,6 @@ class LocalfsFilesImpl(Files):
paginated_result = await self.sql_store.fetch_all(
table="openai_files",
policy=self.policy,
where=where_conditions if where_conditions else None,
order_by=[("created_at", order.value)],
cursor=("id", after) if after else None,
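
Both this provider and the S3 implementation further down make the same change: the access policy is handed to `AuthorizedSqlStore` once at construction, so individual `fetch_one`/`fetch_all` calls drop their `policy=` argument. A hedged sketch of the new call shape (import paths are assumptions):

```python
# Sketch only: illustrates the constructor-takes-policy shape used above.
from llama_stack.providers.utils.sqlstore.authorized_sqlstore import AuthorizedSqlStore  # assumed path
from llama_stack.providers.utils.sqlstore.sqlstore import sqlstore_impl  # assumed path


async def lookup_file(config, policy, file_id: str):
    store = AuthorizedSqlStore(sqlstore_impl(config.metadata_store), policy)
    # previously: await store.fetch_one("openai_files", policy=policy, where={"id": file_id})
    return await store.fetch_one("openai_files", where={"id": file_id})
```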

View file

@ -6,11 +6,10 @@
from llama_stack.providers.datatypes import (
AdapterSpec,
Api,
InlineProviderSpec,
ProviderSpec,
remote_provider_spec,
RemoteProviderSpec,
)
@ -25,10 +24,10 @@ def available_providers() -> list[ProviderSpec]:
api_dependencies=[],
description="Local filesystem-based dataset I/O provider for reading and writing datasets to local storage.",
),
remote_provider_spec(
RemoteProviderSpec(
api=Api.datasetio,
adapter=AdapterSpec(
adapter_type="huggingface",
provider_type="remote::huggingface",
pip_packages=[
"datasets>=4.0.0",
],
@ -36,17 +35,15 @@ def available_providers() -> list[ProviderSpec]:
config_class="llama_stack.providers.remote.datasetio.huggingface.HuggingfaceDatasetIOConfig",
description="HuggingFace datasets provider for accessing and managing datasets from the HuggingFace Hub.",
),
),
remote_provider_spec(
RemoteProviderSpec(
api=Api.datasetio,
adapter=AdapterSpec(
adapter_type="nvidia",
provider_type="remote::nvidia",
module="llama_stack.providers.remote.datasetio.nvidia",
config_class="llama_stack.providers.remote.datasetio.nvidia.NvidiaDatasetIOConfig",
pip_packages=[
"datasets>=4.0.0",
],
module="llama_stack.providers.remote.datasetio.nvidia",
config_class="llama_stack.providers.remote.datasetio.nvidia.NvidiaDatasetIOConfig",
description="NVIDIA's dataset I/O provider for accessing datasets from NVIDIA's data platform.",
),
),
]

View file

@ -5,7 +5,7 @@
# the root directory of this source tree.
from llama_stack.providers.datatypes import AdapterSpec, Api, InlineProviderSpec, ProviderSpec, remote_provider_spec
from llama_stack.providers.datatypes import Api, InlineProviderSpec, ProviderSpec, RemoteProviderSpec
def available_providers() -> list[ProviderSpec]:
@ -25,17 +25,16 @@ def available_providers() -> list[ProviderSpec]:
],
description="Meta's reference implementation of evaluation tasks with support for multiple languages and evaluation metrics.",
),
remote_provider_spec(
RemoteProviderSpec(
api=Api.eval,
adapter=AdapterSpec(
adapter_type="nvidia",
pip_packages=[
"requests",
],
provider_type="remote::nvidia",
module="llama_stack.providers.remote.eval.nvidia",
config_class="llama_stack.providers.remote.eval.nvidia.NVIDIAEvalConfig",
description="NVIDIA's evaluation provider for running evaluation tasks on NVIDIA's platform.",
),
api_dependencies=[
Api.datasetio,
Api.datasets,

View file

@ -4,13 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.providers.datatypes import (
AdapterSpec,
Api,
InlineProviderSpec,
ProviderSpec,
remote_provider_spec,
)
from llama_stack.providers.datatypes import Api, InlineProviderSpec, ProviderSpec, RemoteProviderSpec
from llama_stack.providers.utils.sqlstore.sqlstore import sql_store_pip_packages
@ -25,14 +19,13 @@ def available_providers() -> list[ProviderSpec]:
config_class="llama_stack.providers.inline.files.localfs.config.LocalfsFilesImplConfig",
description="Local filesystem-based file storage provider for managing files and documents locally.",
),
remote_provider_spec(
RemoteProviderSpec(
api=Api.files,
adapter=AdapterSpec(
provider_type="remote::s3",
adapter_type="s3",
pip_packages=["boto3"] + sql_store_pip_packages,
module="llama_stack.providers.remote.files.s3",
config_class="llama_stack.providers.remote.files.s3.config.S3FilesImplConfig",
description="AWS S3-based file storage provider for scalable cloud file management with metadata persistence.",
),
),
]

View file

@ -6,11 +6,10 @@
from llama_stack.providers.datatypes import (
AdapterSpec,
Api,
InlineProviderSpec,
ProviderSpec,
remote_provider_spec,
RemoteProviderSpec,
)
META_REFERENCE_DEPS = [
@ -49,10 +48,10 @@ def available_providers() -> list[ProviderSpec]:
config_class="llama_stack.providers.inline.inference.sentence_transformers.config.SentenceTransformersInferenceConfig",
description="Sentence Transformers inference provider for text embeddings and similarity search.",
),
remote_provider_spec(
RemoteProviderSpec(
api=Api.inference,
adapter=AdapterSpec(
adapter_type="cerebras",
provider_type="remote::cerebras",
pip_packages=[
"cerebras_cloud_sdk",
],
@ -60,62 +59,56 @@ def available_providers() -> list[ProviderSpec]:
config_class="llama_stack.providers.remote.inference.cerebras.CerebrasImplConfig",
description="Cerebras inference provider for running models on Cerebras Cloud platform.",
),
),
remote_provider_spec(
RemoteProviderSpec(
api=Api.inference,
adapter=AdapterSpec(
adapter_type="ollama",
provider_type="remote::ollama",
pip_packages=["ollama", "aiohttp", "h11>=0.16.0"],
config_class="llama_stack.providers.remote.inference.ollama.OllamaImplConfig",
module="llama_stack.providers.remote.inference.ollama",
description="Ollama inference provider for running local models through the Ollama runtime.",
),
),
remote_provider_spec(
RemoteProviderSpec(
api=Api.inference,
adapter=AdapterSpec(
adapter_type="vllm",
provider_type="remote::vllm",
pip_packages=[],
module="llama_stack.providers.remote.inference.vllm",
config_class="llama_stack.providers.remote.inference.vllm.VLLMInferenceAdapterConfig",
provider_data_validator="llama_stack.providers.remote.inference.vllm.VLLMProviderDataValidator",
description="Remote vLLM inference provider for connecting to vLLM servers.",
),
),
remote_provider_spec(
RemoteProviderSpec(
api=Api.inference,
adapter=AdapterSpec(
adapter_type="tgi",
provider_type="remote::tgi",
pip_packages=["huggingface_hub", "aiohttp"],
module="llama_stack.providers.remote.inference.tgi",
config_class="llama_stack.providers.remote.inference.tgi.TGIImplConfig",
description="Text Generation Inference (TGI) provider for HuggingFace model serving.",
),
),
remote_provider_spec(
RemoteProviderSpec(
api=Api.inference,
adapter=AdapterSpec(
adapter_type="hf::serverless",
provider_type="remote::hf::serverless",
pip_packages=["huggingface_hub", "aiohttp"],
module="llama_stack.providers.remote.inference.tgi",
config_class="llama_stack.providers.remote.inference.tgi.InferenceAPIImplConfig",
description="HuggingFace Inference API serverless provider for on-demand model inference.",
),
),
remote_provider_spec(
RemoteProviderSpec(
api=Api.inference,
adapter=AdapterSpec(
provider_type="remote::hf::endpoint",
adapter_type="hf::endpoint",
pip_packages=["huggingface_hub", "aiohttp"],
module="llama_stack.providers.remote.inference.tgi",
config_class="llama_stack.providers.remote.inference.tgi.InferenceEndpointImplConfig",
description="HuggingFace Inference Endpoints provider for dedicated model serving.",
),
),
remote_provider_spec(
RemoteProviderSpec(
api=Api.inference,
adapter=AdapterSpec(
adapter_type="fireworks",
provider_type="remote::fireworks",
pip_packages=[
"fireworks-ai<=0.17.16",
],
@ -124,11 +117,10 @@ def available_providers() -> list[ProviderSpec]:
provider_data_validator="llama_stack.providers.remote.inference.fireworks.FireworksProviderDataValidator",
description="Fireworks AI inference provider for Llama models and other AI models on the Fireworks platform.",
),
),
remote_provider_spec(
RemoteProviderSpec(
api=Api.inference,
adapter=AdapterSpec(
adapter_type="together",
provider_type="remote::together",
pip_packages=[
"together",
],
@ -137,85 +129,82 @@ def available_providers() -> list[ProviderSpec]:
provider_data_validator="llama_stack.providers.remote.inference.together.TogetherProviderDataValidator",
description="Together AI inference provider for open-source models and collaborative AI development.",
),
),
remote_provider_spec(
RemoteProviderSpec(
api=Api.inference,
adapter=AdapterSpec(
adapter_type="bedrock",
provider_type="remote::bedrock",
pip_packages=["boto3"],
module="llama_stack.providers.remote.inference.bedrock",
config_class="llama_stack.providers.remote.inference.bedrock.BedrockConfig",
description="AWS Bedrock inference provider for accessing various AI models through AWS's managed service.",
),
),
remote_provider_spec(
RemoteProviderSpec(
api=Api.inference,
adapter=AdapterSpec(
adapter_type="databricks",
provider_type="remote::databricks",
pip_packages=[],
module="llama_stack.providers.remote.inference.databricks",
config_class="llama_stack.providers.remote.inference.databricks.DatabricksImplConfig",
description="Databricks inference provider for running models on Databricks' unified analytics platform.",
),
),
remote_provider_spec(
RemoteProviderSpec(
api=Api.inference,
adapter=AdapterSpec(
adapter_type="nvidia",
provider_type="remote::nvidia",
pip_packages=[],
module="llama_stack.providers.remote.inference.nvidia",
config_class="llama_stack.providers.remote.inference.nvidia.NVIDIAConfig",
description="NVIDIA inference provider for accessing NVIDIA NIM models and AI services.",
),
),
remote_provider_spec(
RemoteProviderSpec(
api=Api.inference,
adapter=AdapterSpec(
adapter_type="runpod",
provider_type="remote::runpod",
pip_packages=[],
module="llama_stack.providers.remote.inference.runpod",
config_class="llama_stack.providers.remote.inference.runpod.RunpodImplConfig",
description="RunPod inference provider for running models on RunPod's cloud GPU platform.",
),
),
remote_provider_spec(
RemoteProviderSpec(
api=Api.inference,
adapter=AdapterSpec(
adapter_type="openai",
provider_type="remote::openai",
pip_packages=["litellm"],
module="llama_stack.providers.remote.inference.openai",
config_class="llama_stack.providers.remote.inference.openai.OpenAIConfig",
provider_data_validator="llama_stack.providers.remote.inference.openai.config.OpenAIProviderDataValidator",
description="OpenAI inference provider for accessing GPT models and other OpenAI services.",
),
),
remote_provider_spec(
RemoteProviderSpec(
api=Api.inference,
adapter=AdapterSpec(
adapter_type="anthropic",
provider_type="remote::anthropic",
pip_packages=["litellm"],
module="llama_stack.providers.remote.inference.anthropic",
config_class="llama_stack.providers.remote.inference.anthropic.AnthropicConfig",
provider_data_validator="llama_stack.providers.remote.inference.anthropic.config.AnthropicProviderDataValidator",
description="Anthropic inference provider for accessing Claude models and Anthropic's AI services.",
),
),
remote_provider_spec(
RemoteProviderSpec(
api=Api.inference,
adapter=AdapterSpec(
adapter_type="gemini",
pip_packages=["litellm"],
provider_type="remote::gemini",
pip_packages=[
"litellm",
],
module="llama_stack.providers.remote.inference.gemini",
config_class="llama_stack.providers.remote.inference.gemini.GeminiConfig",
provider_data_validator="llama_stack.providers.remote.inference.gemini.config.GeminiProviderDataValidator",
description="Google Gemini inference provider for accessing Gemini models and Google's AI services.",
),
),
remote_provider_spec(
RemoteProviderSpec(
api=Api.inference,
adapter=AdapterSpec(
adapter_type="vertexai",
pip_packages=["litellm", "google-cloud-aiplatform"],
provider_type="remote::vertexai",
pip_packages=[
"litellm",
"google-cloud-aiplatform",
],
module="llama_stack.providers.remote.inference.vertexai",
config_class="llama_stack.providers.remote.inference.vertexai.VertexAIConfig",
provider_data_validator="llama_stack.providers.remote.inference.vertexai.config.VertexAIProviderDataValidator",
@ -240,65 +229,63 @@ Available Models:
- vertex_ai/gemini-2.5-flash
- vertex_ai/gemini-2.5-pro""",
),
),
remote_provider_spec(
RemoteProviderSpec(
api=Api.inference,
adapter=AdapterSpec(
adapter_type="groq",
pip_packages=["litellm"],
provider_type="remote::groq",
pip_packages=[
"litellm",
],
module="llama_stack.providers.remote.inference.groq",
config_class="llama_stack.providers.remote.inference.groq.GroqConfig",
provider_data_validator="llama_stack.providers.remote.inference.groq.config.GroqProviderDataValidator",
description="Groq inference provider for ultra-fast inference using Groq's LPU technology.",
),
),
remote_provider_spec(
RemoteProviderSpec(
api=Api.inference,
adapter=AdapterSpec(
adapter_type="llama-openai-compat",
provider_type="remote::llama-openai-compat",
pip_packages=["litellm"],
module="llama_stack.providers.remote.inference.llama_openai_compat",
config_class="llama_stack.providers.remote.inference.llama_openai_compat.config.LlamaCompatConfig",
provider_data_validator="llama_stack.providers.remote.inference.llama_openai_compat.config.LlamaProviderDataValidator",
description="Llama OpenAI-compatible provider for using Llama models with OpenAI API format.",
),
),
remote_provider_spec(
RemoteProviderSpec(
api=Api.inference,
adapter=AdapterSpec(
adapter_type="sambanova",
pip_packages=["litellm"],
provider_type="remote::sambanova",
pip_packages=[
"litellm",
],
module="llama_stack.providers.remote.inference.sambanova",
config_class="llama_stack.providers.remote.inference.sambanova.SambaNovaImplConfig",
provider_data_validator="llama_stack.providers.remote.inference.sambanova.config.SambaNovaProviderDataValidator",
description="SambaNova inference provider for running models on SambaNova's dataflow architecture.",
),
),
remote_provider_spec(
RemoteProviderSpec(
api=Api.inference,
adapter=AdapterSpec(
adapter_type="passthrough",
provider_type="remote::passthrough",
pip_packages=[],
module="llama_stack.providers.remote.inference.passthrough",
config_class="llama_stack.providers.remote.inference.passthrough.PassthroughImplConfig",
provider_data_validator="llama_stack.providers.remote.inference.passthrough.PassthroughProviderDataValidator",
description="Passthrough inference provider for connecting to any external inference service not directly supported.",
),
),
remote_provider_spec(
RemoteProviderSpec(
api=Api.inference,
adapter=AdapterSpec(
adapter_type="watsonx",
provider_type="remote::watsonx",
pip_packages=["ibm_watsonx_ai"],
module="llama_stack.providers.remote.inference.watsonx",
config_class="llama_stack.providers.remote.inference.watsonx.WatsonXConfig",
provider_data_validator="llama_stack.providers.remote.inference.watsonx.WatsonXProviderDataValidator",
description="IBM WatsonX inference provider for accessing AI models on IBM's WatsonX platform.",
),
),
remote_provider_spec(
RemoteProviderSpec(
api=Api.inference,
adapter=AdapterSpec(
provider_type="remote::azure",
adapter_type="azure",
pip_packages=["litellm"],
module="llama_stack.providers.remote.inference.azure",
@ -310,5 +297,4 @@ Provider documentation
https://learn.microsoft.com/en-us/azure/ai-foundry/openai/overview
""",
),
),
]

View file

@ -7,7 +7,7 @@
from typing import cast
from llama_stack.providers.datatypes import AdapterSpec, Api, InlineProviderSpec, ProviderSpec, remote_provider_spec
from llama_stack.providers.datatypes import Api, InlineProviderSpec, ProviderSpec, RemoteProviderSpec
# We provide two versions of these providers so that distributions can package the appropriate version of torch.
# The CPU version is used for distributions that don't have GPU support -- they result in smaller container images.
@ -57,14 +57,13 @@ def available_providers() -> list[ProviderSpec]:
],
description="HuggingFace-based post-training provider for fine-tuning models using the HuggingFace ecosystem.",
),
remote_provider_spec(
RemoteProviderSpec(
api=Api.post_training,
adapter=AdapterSpec(
adapter_type="nvidia",
provider_type="remote::nvidia",
pip_packages=["requests", "aiohttp"],
module="llama_stack.providers.remote.post_training.nvidia",
config_class="llama_stack.providers.remote.post_training.nvidia.NvidiaPostTrainingConfig",
description="NVIDIA's post-training provider for fine-tuning models on NVIDIA's platform.",
),
),
]

View file

@ -6,11 +6,10 @@
from llama_stack.providers.datatypes import (
AdapterSpec,
Api,
InlineProviderSpec,
ProviderSpec,
remote_provider_spec,
RemoteProviderSpec,
)
@ -48,35 +47,32 @@ def available_providers() -> list[ProviderSpec]:
config_class="llama_stack.providers.inline.safety.code_scanner.CodeScannerConfig",
description="Code Scanner safety provider for detecting security vulnerabilities and unsafe code patterns.",
),
remote_provider_spec(
RemoteProviderSpec(
api=Api.safety,
adapter=AdapterSpec(
adapter_type="bedrock",
provider_type="remote::bedrock",
pip_packages=["boto3"],
module="llama_stack.providers.remote.safety.bedrock",
config_class="llama_stack.providers.remote.safety.bedrock.BedrockSafetyConfig",
description="AWS Bedrock safety provider for content moderation using AWS's safety services.",
),
),
remote_provider_spec(
RemoteProviderSpec(
api=Api.safety,
adapter=AdapterSpec(
adapter_type="nvidia",
provider_type="remote::nvidia",
pip_packages=["requests"],
module="llama_stack.providers.remote.safety.nvidia",
config_class="llama_stack.providers.remote.safety.nvidia.NVIDIASafetyConfig",
description="NVIDIA's safety provider for content moderation and safety filtering.",
),
),
remote_provider_spec(
RemoteProviderSpec(
api=Api.safety,
adapter=AdapterSpec(
adapter_type="sambanova",
provider_type="remote::sambanova",
pip_packages=["litellm", "requests"],
module="llama_stack.providers.remote.safety.sambanova",
config_class="llama_stack.providers.remote.safety.sambanova.SambaNovaSafetyConfig",
provider_data_validator="llama_stack.providers.remote.safety.sambanova.config.SambaNovaProviderDataValidator",
description="SambaNova's safety provider for content moderation and safety filtering.",
),
),
]

View file

@ -6,11 +6,10 @@
from llama_stack.providers.datatypes import (
AdapterSpec,
Api,
InlineProviderSpec,
ProviderSpec,
remote_provider_spec,
RemoteProviderSpec,
)
@ -35,59 +34,54 @@ def available_providers() -> list[ProviderSpec]:
api_dependencies=[Api.vector_io, Api.inference, Api.files],
description="RAG (Retrieval-Augmented Generation) tool runtime for document ingestion, chunking, and semantic search.",
),
remote_provider_spec(
RemoteProviderSpec(
api=Api.tool_runtime,
adapter=AdapterSpec(
adapter_type="brave-search",
provider_type="remote::brave-search",
module="llama_stack.providers.remote.tool_runtime.brave_search",
config_class="llama_stack.providers.remote.tool_runtime.brave_search.config.BraveSearchToolConfig",
pip_packages=["requests"],
provider_data_validator="llama_stack.providers.remote.tool_runtime.brave_search.BraveSearchToolProviderDataValidator",
description="Brave Search tool for web search capabilities with privacy-focused results.",
),
),
remote_provider_spec(
RemoteProviderSpec(
api=Api.tool_runtime,
adapter=AdapterSpec(
adapter_type="bing-search",
provider_type="remote::bing-search",
module="llama_stack.providers.remote.tool_runtime.bing_search",
config_class="llama_stack.providers.remote.tool_runtime.bing_search.config.BingSearchToolConfig",
pip_packages=["requests"],
provider_data_validator="llama_stack.providers.remote.tool_runtime.bing_search.BingSearchToolProviderDataValidator",
description="Bing Search tool for web search capabilities using Microsoft's search engine.",
),
),
remote_provider_spec(
RemoteProviderSpec(
api=Api.tool_runtime,
adapter=AdapterSpec(
adapter_type="tavily-search",
provider_type="remote::tavily-search",
module="llama_stack.providers.remote.tool_runtime.tavily_search",
config_class="llama_stack.providers.remote.tool_runtime.tavily_search.config.TavilySearchToolConfig",
pip_packages=["requests"],
provider_data_validator="llama_stack.providers.remote.tool_runtime.tavily_search.TavilySearchToolProviderDataValidator",
description="Tavily Search tool for AI-optimized web search with structured results.",
),
),
remote_provider_spec(
RemoteProviderSpec(
api=Api.tool_runtime,
adapter=AdapterSpec(
adapter_type="wolfram-alpha",
provider_type="remote::wolfram-alpha",
module="llama_stack.providers.remote.tool_runtime.wolfram_alpha",
config_class="llama_stack.providers.remote.tool_runtime.wolfram_alpha.config.WolframAlphaToolConfig",
pip_packages=["requests"],
provider_data_validator="llama_stack.providers.remote.tool_runtime.wolfram_alpha.WolframAlphaToolProviderDataValidator",
description="Wolfram Alpha tool for computational knowledge and mathematical calculations.",
),
),
remote_provider_spec(
RemoteProviderSpec(
api=Api.tool_runtime,
adapter=AdapterSpec(
adapter_type="model-context-protocol",
provider_type="remote::model-context-protocol",
module="llama_stack.providers.remote.tool_runtime.model_context_protocol",
config_class="llama_stack.providers.remote.tool_runtime.model_context_protocol.config.MCPProviderConfig",
pip_packages=["mcp>=1.8.1"],
provider_data_validator="llama_stack.providers.remote.tool_runtime.model_context_protocol.config.MCPProviderDataValidator",
description="Model Context Protocol (MCP) tool for standardized tool calling and context management.",
),
),
]

View file

@ -6,11 +6,10 @@
from llama_stack.providers.datatypes import (
AdapterSpec,
Api,
InlineProviderSpec,
ProviderSpec,
remote_provider_spec,
RemoteProviderSpec,
)
@ -300,13 +299,15 @@ See [sqlite-vec's GitHub repo](https://github.com/asg017/sqlite-vec/tree/main) f
Please refer to the sqlite-vec provider documentation.
""",
),
remote_provider_spec(
Api.vector_io,
AdapterSpec(
RemoteProviderSpec(
api=Api.vector_io,
adapter_type="chromadb",
provider_type="remote::chromadb",
pip_packages=["chromadb-client"],
module="llama_stack.providers.remote.vector_io.chroma",
config_class="llama_stack.providers.remote.vector_io.chroma.ChromaVectorIOConfig",
api_dependencies=[Api.inference],
optional_api_dependencies=[Api.files],
description="""
[Chroma](https://www.trychroma.com/) is an inline and remote vector
database provider for Llama Stack. It allows you to store and query vectors directly within a Chroma database.
@ -341,9 +342,6 @@ pip install chromadb
See [Chroma's documentation](https://docs.trychroma.com/docs/overview/introduction) for more details about Chroma in general.
""",
),
api_dependencies=[Api.inference],
optional_api_dependencies=[Api.files],
),
InlineProviderSpec(
api=Api.vector_io,
provider_type="inline::chromadb",
@ -387,13 +385,15 @@ See [Chroma's documentation](https://docs.trychroma.com/docs/overview/introducti
""",
),
remote_provider_spec(
Api.vector_io,
AdapterSpec(
RemoteProviderSpec(
api=Api.vector_io,
adapter_type="pgvector",
provider_type="remote::pgvector",
pip_packages=["psycopg2-binary"],
module="llama_stack.providers.remote.vector_io.pgvector",
config_class="llama_stack.providers.remote.vector_io.pgvector.PGVectorVectorIOConfig",
api_dependencies=[Api.inference],
optional_api_dependencies=[Api.files],
description="""
[PGVector](https://github.com/pgvector/pgvector) is a remote vector database provider for Llama Stack. It
allows you to store and query vectors directly in memory.
@ -496,17 +496,16 @@ docker pull pgvector/pgvector:pg17
See [PGVector's documentation](https://github.com/pgvector/pgvector) for more details about PGVector in general.
""",
),
api_dependencies=[Api.inference],
optional_api_dependencies=[Api.files],
),
remote_provider_spec(
Api.vector_io,
AdapterSpec(
RemoteProviderSpec(
api=Api.vector_io,
adapter_type="weaviate",
provider_type="remote::weaviate",
pip_packages=["weaviate-client"],
module="llama_stack.providers.remote.vector_io.weaviate",
config_class="llama_stack.providers.remote.vector_io.weaviate.WeaviateVectorIOConfig",
provider_data_validator="llama_stack.providers.remote.vector_io.weaviate.WeaviateRequestProviderData",
api_dependencies=[Api.inference],
optional_api_dependencies=[Api.files],
description="""
[Weaviate](https://weaviate.io/) is a vector database provider for Llama Stack.
It allows you to store and query vectors directly within a Weaviate database.
@ -539,9 +538,6 @@ To install Weaviate see the [Weaviate quickstart documentation](https://weaviate
See [Weaviate's documentation](https://weaviate.io/developers/weaviate) for more details about Weaviate in general.
""",
),
api_dependencies=[Api.inference],
optional_api_dependencies=[Api.files],
),
InlineProviderSpec(
api=Api.vector_io,
provider_type="inline::qdrant",
@ -594,27 +590,28 @@ docker pull qdrant/qdrant
See the [Qdrant documentation](https://qdrant.tech/documentation/) for more details about Qdrant in general.
""",
),
remote_provider_spec(
Api.vector_io,
AdapterSpec(
RemoteProviderSpec(
api=Api.vector_io,
adapter_type="qdrant",
provider_type="remote::qdrant",
pip_packages=["qdrant-client"],
module="llama_stack.providers.remote.vector_io.qdrant",
config_class="llama_stack.providers.remote.vector_io.qdrant.QdrantVectorIOConfig",
api_dependencies=[Api.inference],
optional_api_dependencies=[Api.files],
description="""
Please refer to the inline provider documentation.
""",
),
api_dependencies=[Api.inference],
optional_api_dependencies=[Api.files],
),
remote_provider_spec(
Api.vector_io,
AdapterSpec(
RemoteProviderSpec(
api=Api.vector_io,
adapter_type="milvus",
provider_type="remote::milvus",
pip_packages=["pymilvus>=2.4.10"],
module="llama_stack.providers.remote.vector_io.milvus",
config_class="llama_stack.providers.remote.vector_io.milvus.MilvusVectorIOConfig",
api_dependencies=[Api.inference],
optional_api_dependencies=[Api.files],
description="""
[Milvus](https://milvus.io/) is an inline and remote vector database provider for Llama Stack. It
allows you to store and query vectors directly within a Milvus database.
@ -636,7 +633,13 @@ To use Milvus in your Llama Stack project, follow these steps:
## Installation
You can install Milvus using pymilvus:
If you want to use inline Milvus, you can install:
```bash
pip install pymilvus[milvus-lite]
```
If you want to use remote Milvus, you can install:
```bash
pip install pymilvus
@ -807,13 +810,10 @@ See the [Milvus documentation](https://milvus.io/docs/install-overview.md) for m
For more details on TLS configuration, refer to the [TLS setup guide](https://milvus.io/docs/tls.md).
""",
),
api_dependencies=[Api.inference],
optional_api_dependencies=[Api.files],
),
InlineProviderSpec(
api=Api.vector_io,
provider_type="inline::milvus",
pip_packages=["pymilvus>=2.4.10"],
pip_packages=["pymilvus[milvus-lite]>=2.4.10"],
module="llama_stack.providers.inline.vector_io.milvus",
config_class="llama_stack.providers.inline.vector_io.milvus.MilvusVectorIOConfig",
api_dependencies=[Api.inference],

View file

@ -137,7 +137,7 @@ class S3FilesImpl(Files):
where: dict[str, str | dict] = {"id": file_id}
if not return_expired:
where["expires_at"] = {">": self._now()}
if not (row := await self.sql_store.fetch_one("openai_files", policy=self.policy, where=where)):
if not (row := await self.sql_store.fetch_one("openai_files", where=where)):
raise ResourceNotFoundError(file_id, "File", "files.list()")
return row
@ -164,7 +164,7 @@ class S3FilesImpl(Files):
self._client = _create_s3_client(self._config)
await _create_bucket_if_not_exists(self._client, self._config)
self._sql_store = AuthorizedSqlStore(sqlstore_impl(self._config.metadata_store))
self._sql_store = AuthorizedSqlStore(sqlstore_impl(self._config.metadata_store), self.policy)
await self._sql_store.create_table(
"openai_files",
{
@ -268,7 +268,6 @@ class S3FilesImpl(Files):
paginated_result = await self.sql_store.fetch_all(
table="openai_files",
policy=self.policy,
where=where_conditions,
order_by=[("created_at", order.value)],
cursor=("id", after) if after else None,

View file

@ -4,15 +4,9 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from pydantic import BaseModel
from .config import AnthropicConfig
class AnthropicProviderDataValidator(BaseModel):
anthropic_api_key: str | None = None
async def get_adapter_impl(config: AnthropicConfig, _deps):
from .anthropic import AnthropicInferenceAdapter

View file

@ -4,11 +4,9 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from collections.abc import AsyncGenerator, AsyncIterator
from typing import Any
from collections.abc import AsyncGenerator
from fireworks.client import Fireworks
from openai import AsyncOpenAI
from llama_stack.apis.common.content_types import (
InterleavedContent,
@ -24,12 +22,6 @@ from llama_stack.apis.inference import (
Inference,
LogProbConfig,
Message,
OpenAIChatCompletion,
OpenAIChatCompletionChunk,
OpenAICompletion,
OpenAIEmbeddingsResponse,
OpenAIMessageParam,
OpenAIResponseFormatParam,
ResponseFormat,
ResponseFormatType,
SamplingParams,
@ -45,15 +37,14 @@ from llama_stack.providers.utils.inference.model_registry import (
ModelRegistryHelper,
)
from llama_stack.providers.utils.inference.openai_compat import (
OpenAIChatCompletionToLlamaStackMixin,
convert_message_to_openai_dict,
get_sampling_options,
prepare_openai_completion_params,
process_chat_completion_response,
process_chat_completion_stream_response,
process_completion_response,
process_completion_stream_response,
)
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
from llama_stack.providers.utils.inference.prompt_adapter import (
chat_completion_request_to_prompt,
completion_request_to_prompt,
@ -68,7 +59,7 @@ from .models import MODEL_ENTRIES
logger = get_logger(name=__name__, category="inference::fireworks")
class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProviderData):
class FireworksInferenceAdapter(OpenAIMixin, ModelRegistryHelper, Inference, NeedsRequestProviderData):
def __init__(self, config: FireworksImplConfig) -> None:
ModelRegistryHelper.__init__(self, MODEL_ENTRIES, config.allowed_models)
self.config = config
@ -79,7 +70,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
async def shutdown(self) -> None:
pass
def _get_api_key(self) -> str:
def get_api_key(self) -> str:
config_api_key = self.config.api_key.get_secret_value() if self.config.api_key else None
if config_api_key:
return config_api_key
@ -91,15 +82,18 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
)
return provider_data.fireworks_api_key
def _get_base_url(self) -> str:
def get_base_url(self) -> str:
return "https://api.fireworks.ai/inference/v1"
def _get_client(self) -> Fireworks:
fireworks_api_key = self._get_api_key()
fireworks_api_key = self.get_api_key()
return Fireworks(api_key=fireworks_api_key)
def _get_openai_client(self) -> AsyncOpenAI:
return AsyncOpenAI(base_url=self._get_base_url(), api_key=self._get_api_key())
def _preprocess_prompt_for_fireworks(self, prompt: str) -> str:
"""Remove BOS token as Fireworks automatically prepends it"""
if prompt.startswith("<|begin_of_text|>"):
return prompt[len("<|begin_of_text|>") :]
return prompt
async def completion(
self,
@ -285,153 +279,3 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
embeddings = [data.embedding for data in response.data]
return EmbeddingsResponse(embeddings=embeddings)
async def openai_embeddings(
self,
model: str,
input: str | list[str],
encoding_format: str | None = "float",
dimensions: int | None = None,
user: str | None = None,
) -> OpenAIEmbeddingsResponse:
raise NotImplementedError()
async def openai_completion(
self,
model: str,
prompt: str | list[str] | list[int] | list[list[int]],
best_of: int | None = None,
echo: bool | None = None,
frequency_penalty: float | None = None,
logit_bias: dict[str, float] | None = None,
logprobs: bool | None = None,
max_tokens: int | None = None,
n: int | None = None,
presence_penalty: float | None = None,
seed: int | None = None,
stop: str | list[str] | None = None,
stream: bool | None = None,
stream_options: dict[str, Any] | None = None,
temperature: float | None = None,
top_p: float | None = None,
user: str | None = None,
guided_choice: list[str] | None = None,
prompt_logprobs: int | None = None,
suffix: str | None = None,
) -> OpenAICompletion:
model_obj = await self.model_store.get_model(model)
# Fireworks always prepends with BOS
if isinstance(prompt, str) and prompt.startswith("<|begin_of_text|>"):
prompt = prompt[len("<|begin_of_text|>") :]
params = await prepare_openai_completion_params(
model=model_obj.provider_resource_id,
prompt=prompt,
best_of=best_of,
echo=echo,
frequency_penalty=frequency_penalty,
logit_bias=logit_bias,
logprobs=logprobs,
max_tokens=max_tokens,
n=n,
presence_penalty=presence_penalty,
seed=seed,
stop=stop,
stream=stream,
stream_options=stream_options,
temperature=temperature,
top_p=top_p,
user=user,
)
return await self._get_openai_client().completions.create(**params)
async def openai_chat_completion(
self,
model: str,
messages: list[OpenAIMessageParam],
frequency_penalty: float | None = None,
function_call: str | dict[str, Any] | None = None,
functions: list[dict[str, Any]] | None = None,
logit_bias: dict[str, float] | None = None,
logprobs: bool | None = None,
max_completion_tokens: int | None = None,
max_tokens: int | None = None,
n: int | None = None,
parallel_tool_calls: bool | None = None,
presence_penalty: float | None = None,
response_format: OpenAIResponseFormatParam | None = None,
seed: int | None = None,
stop: str | list[str] | None = None,
stream: bool | None = None,
stream_options: dict[str, Any] | None = None,
temperature: float | None = None,
tool_choice: str | dict[str, Any] | None = None,
tools: list[dict[str, Any]] | None = None,
top_logprobs: int | None = None,
top_p: float | None = None,
user: str | None = None,
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
model_obj = await self.model_store.get_model(model)
# Divert Llama Models through Llama Stack inference APIs because
# Fireworks chat completions OpenAI-compatible API does not support
# tool calls properly.
llama_model = self.get_llama_model(model_obj.provider_resource_id)
if llama_model:
return await OpenAIChatCompletionToLlamaStackMixin.openai_chat_completion(
self,
model=model,
messages=messages,
frequency_penalty=frequency_penalty,
function_call=function_call,
functions=functions,
logit_bias=logit_bias,
logprobs=logprobs,
max_completion_tokens=max_completion_tokens,
max_tokens=max_tokens,
n=n,
parallel_tool_calls=parallel_tool_calls,
presence_penalty=presence_penalty,
response_format=response_format,
seed=seed,
stop=stop,
stream=stream,
stream_options=stream_options,
temperature=temperature,
tool_choice=tool_choice,
tools=tools,
top_logprobs=top_logprobs,
top_p=top_p,
user=user,
)
params = await prepare_openai_completion_params(
messages=messages,
frequency_penalty=frequency_penalty,
function_call=function_call,
functions=functions,
logit_bias=logit_bias,
logprobs=logprobs,
max_completion_tokens=max_completion_tokens,
max_tokens=max_tokens,
n=n,
parallel_tool_calls=parallel_tool_calls,
presence_penalty=presence_penalty,
response_format=response_format,
seed=seed,
stop=stop,
stream=stream,
stream_options=stream_options,
temperature=temperature,
tool_choice=tool_choice,
tools=tools,
top_logprobs=top_logprobs,
top_p=top_p,
user=user,
)
logger.debug(f"fireworks params: {params}")
return await self._get_openai_client().chat.completions.create(model=model_obj.provider_resource_id, **params)
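A minimal sketch of the OpenAIMixin pattern this file now follows, assuming only what the diff shows: the mixin asks the adapter for `get_api_key()` and `get_base_url()` and supplies the OpenAI-compatible `openai_*` methods on top. The class and attribute names below are illustrative, not the real Fireworks adapter, and the mixin may expect further hooks not shown here.

```python
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin


class ExampleOpenAICompatAdapter(OpenAIMixin):
    """Hypothetical adapter, used only to illustrate the refactor above."""

    def __init__(self, api_key: str, base_url: str) -> None:
        self._api_key = api_key
        self._base_url = base_url

    def get_api_key(self) -> str:
        # Called by the mixin when it builds its OpenAI-compatible client.
        return self._api_key

    def get_base_url(self) -> str:
        # Must point at an OpenAI-compatible /v1 endpoint.
        return self._base_url
```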

View file

@ -4,15 +4,9 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from pydantic import BaseModel
from .config import GeminiConfig
class GeminiProviderDataValidator(BaseModel):
gemini_api_key: str | None = None
async def get_adapter_impl(config: GeminiConfig, _deps):
from .gemini import GeminiInferenceAdapter

View file

@ -7,12 +7,10 @@
import asyncio
import base64
import uuid
from collections.abc import AsyncGenerator, AsyncIterator
from typing import Any
from ollama import AsyncClient # type: ignore[attr-defined]
from openai import AsyncOpenAI
from ollama import AsyncClient as AsyncOllamaClient
from llama_stack.apis.common.content_types import (
ImageContentItem,
@ -37,9 +35,6 @@ from llama_stack.apis.inference import (
Message,
OpenAIChatCompletion,
OpenAIChatCompletionChunk,
OpenAICompletion,
OpenAIEmbeddingsResponse,
OpenAIEmbeddingUsage,
OpenAIMessageParam,
OpenAIResponseFormatParam,
ResponseFormat,
@ -64,15 +59,14 @@ from llama_stack.providers.utils.inference.model_registry import (
from llama_stack.providers.utils.inference.openai_compat import (
OpenAICompatCompletionChoice,
OpenAICompatCompletionResponse,
b64_encode_openai_embeddings_response,
get_sampling_options,
prepare_openai_completion_params,
prepare_openai_embeddings_params,
process_chat_completion_response,
process_chat_completion_stream_response,
process_completion_response,
process_completion_stream_response,
)
from llama_stack.providers.utils.inference.openai_mixin import OpenAIMixin
from llama_stack.providers.utils.inference.prompt_adapter import (
chat_completion_request_to_prompt,
completion_request_to_prompt,
@ -89,6 +83,7 @@ logger = get_logger(name=__name__, category="inference::ollama")
class OllamaInferenceAdapter(
OpenAIMixin,
InferenceProvider,
ModelsProtocolPrivate,
):
@ -98,23 +93,21 @@ class OllamaInferenceAdapter(
def __init__(self, config: OllamaImplConfig) -> None:
self.register_helper = ModelRegistryHelper(MODEL_ENTRIES)
self.config = config
self._clients: dict[asyncio.AbstractEventLoop, AsyncClient] = {}
self._openai_client = None
self._clients: dict[asyncio.AbstractEventLoop, AsyncOllamaClient] = {}
@property
def client(self) -> AsyncClient:
def ollama_client(self) -> AsyncOllamaClient:
# ollama client attaches itself to the current event loop (sadly?)
loop = asyncio.get_running_loop()
if loop not in self._clients:
self._clients[loop] = AsyncClient(host=self.config.url)
self._clients[loop] = AsyncOllamaClient(host=self.config.url)
return self._clients[loop]
@property
def openai_client(self) -> AsyncOpenAI:
if self._openai_client is None:
url = self.config.url.rstrip("/")
self._openai_client = AsyncOpenAI(base_url=f"{url}/v1", api_key="ollama")
return self._openai_client
def get_api_key(self):
return "NO_KEY"
def get_base_url(self):
return self.config.url.rstrip("/") + "/v1"
async def initialize(self) -> None:
logger.info(f"checking connectivity to Ollama at `{self.config.url}`...")
@ -129,7 +122,7 @@ class OllamaInferenceAdapter(
async def list_models(self) -> list[Model] | None:
provider_id = self.__provider_id__
response = await self.client.list()
response = await self.ollama_client.list()
# always add the two embedding models which can be pulled on demand
models = [
@ -189,7 +182,7 @@ class OllamaInferenceAdapter(
HealthResponse: A dictionary containing the health status.
"""
try:
await self.client.ps()
await self.ollama_client.ps()
return HealthResponse(status=HealthStatus.OK)
except Exception as e:
return HealthResponse(status=HealthStatus.ERROR, message=f"Health check failed: {str(e)}")
@ -238,7 +231,7 @@ class OllamaInferenceAdapter(
params = await self._get_params(request)
async def _generate_and_convert_to_openai_compat():
s = await self.client.generate(**params)
s = await self.ollama_client.generate(**params)
async for chunk in s:
choice = OpenAICompatCompletionChoice(
finish_reason=chunk["done_reason"] if chunk["done"] else None,
@ -254,7 +247,7 @@ class OllamaInferenceAdapter(
async def _nonstream_completion(self, request: CompletionRequest) -> CompletionResponse:
params = await self._get_params(request)
r = await self.client.generate(**params)
r = await self.ollama_client.generate(**params)
choice = OpenAICompatCompletionChoice(
finish_reason=r["done_reason"] if r["done"] else None,
@ -346,9 +339,9 @@ class OllamaInferenceAdapter(
async def _nonstream_chat_completion(self, request: ChatCompletionRequest) -> ChatCompletionResponse:
params = await self._get_params(request)
if "messages" in params:
r = await self.client.chat(**params)
r = await self.ollama_client.chat(**params)
else:
r = await self.client.generate(**params)
r = await self.ollama_client.generate(**params)
if "message" in r:
choice = OpenAICompatCompletionChoice(
@ -372,9 +365,9 @@ class OllamaInferenceAdapter(
async def _generate_and_convert_to_openai_compat():
if "messages" in params:
s = await self.client.chat(**params)
s = await self.ollama_client.chat(**params)
else:
s = await self.client.generate(**params)
s = await self.ollama_client.generate(**params)
async for chunk in s:
if "message" in chunk:
choice = OpenAICompatCompletionChoice(
@ -407,7 +400,7 @@ class OllamaInferenceAdapter(
assert all(not content_has_media(content) for content in contents), (
"Ollama does not support media for embeddings"
)
response = await self.client.embed(
response = await self.ollama_client.embed(
model=model.provider_resource_id,
input=[interleaved_content_as_str(content) for content in contents],
)
@ -422,14 +415,14 @@ class OllamaInferenceAdapter(
pass # Ignore statically unknown model, will check live listing
if model.model_type == ModelType.embedding:
response = await self.client.list()
response = await self.ollama_client.list()
if model.provider_resource_id not in [m.model for m in response.models]:
await self.client.pull(model.provider_resource_id)
await self.ollama_client.pull(model.provider_resource_id)
# we use list() here instead of ps() -
# - ps() only lists running models, not available models
# - models not currently running are run by the ollama server as needed
response = await self.client.list()
response = await self.ollama_client.list()
available_models = [m.model for m in response.models]
provider_resource_id = model.provider_resource_id
@ -448,90 +441,6 @@ class OllamaInferenceAdapter(
return model
async def openai_embeddings(
self,
model: str,
input: str | list[str],
encoding_format: str | None = "float",
dimensions: int | None = None,
user: str | None = None,
) -> OpenAIEmbeddingsResponse:
model_obj = await self._get_model(model)
if model_obj.provider_resource_id is None:
raise ValueError(f"Model {model} has no provider_resource_id set")
# Note, at the moment Ollama does not support encoding_format, dimensions, and user parameters
params = prepare_openai_embeddings_params(
model=model_obj.provider_resource_id,
input=input,
encoding_format=encoding_format,
dimensions=dimensions,
user=user,
)
response = await self.openai_client.embeddings.create(**params)
data = b64_encode_openai_embeddings_response(response.data, encoding_format)
usage = OpenAIEmbeddingUsage(
prompt_tokens=response.usage.prompt_tokens,
total_tokens=response.usage.total_tokens,
)
# TODO: Investigate why model_obj.identifier is used instead of response.model
return OpenAIEmbeddingsResponse(
data=data,
model=model_obj.identifier,
usage=usage,
)
async def openai_completion(
self,
model: str,
prompt: str | list[str] | list[int] | list[list[int]],
best_of: int | None = None,
echo: bool | None = None,
frequency_penalty: float | None = None,
logit_bias: dict[str, float] | None = None,
logprobs: bool | None = None,
max_tokens: int | None = None,
n: int | None = None,
presence_penalty: float | None = None,
seed: int | None = None,
stop: str | list[str] | None = None,
stream: bool | None = None,
stream_options: dict[str, Any] | None = None,
temperature: float | None = None,
top_p: float | None = None,
user: str | None = None,
guided_choice: list[str] | None = None,
prompt_logprobs: int | None = None,
suffix: str | None = None,
) -> OpenAICompletion:
if not isinstance(prompt, str):
raise ValueError("Ollama does not support non-string prompts for completion")
model_obj = await self._get_model(model)
params = await prepare_openai_completion_params(
model=model_obj.provider_resource_id,
prompt=prompt,
best_of=best_of,
echo=echo,
frequency_penalty=frequency_penalty,
logit_bias=logit_bias,
logprobs=logprobs,
max_tokens=max_tokens,
n=n,
presence_penalty=presence_penalty,
seed=seed,
stop=stop,
stream=stream,
stream_options=stream_options,
temperature=temperature,
top_p=top_p,
user=user,
suffix=suffix,
)
return await self.openai_client.completions.create(**params) # type: ignore
async def openai_chat_completion(
self,
model: str,
@ -599,25 +508,7 @@ class OllamaInferenceAdapter(
top_p=top_p,
user=user,
)
response = await self.openai_client.chat.completions.create(**params)
return await self._adjust_ollama_chat_completion_response_ids(response)
async def _adjust_ollama_chat_completion_response_ids(
self,
response: OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk],
) -> OpenAIChatCompletion | AsyncIterator[OpenAIChatCompletionChunk]:
id = f"chatcmpl-{uuid.uuid4()}"
if isinstance(response, AsyncIterator):
async def stream_with_chunk_ids() -> AsyncIterator[OpenAIChatCompletionChunk]:
async for chunk in response:
chunk.id = id
yield chunk
return stream_with_chunk_ids()
else:
response.id = id
return response
return await OpenAIMixin.openai_chat_completion(self, **params)
async def convert_message_to_openai_dict_for_ollama(message: Message) -> list[dict]:
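A minimal sketch of the per-event-loop caching behind the renamed `ollama_client` property above, assuming nothing beyond what the diff shows; the wrapper class and `host` argument are illustrative.

```python
import asyncio

from ollama import AsyncClient as AsyncOllamaClient


class PerLoopOllamaClient:
    """Illustrative wrapper: one native Ollama client per running event loop."""

    def __init__(self, host: str) -> None:
        self._host = host
        self._clients: dict[asyncio.AbstractEventLoop, AsyncOllamaClient] = {}

    @property
    def client(self) -> AsyncOllamaClient:
        # The ollama client attaches itself to the loop it was created on, so
        # keep one instance per running loop rather than sharing a single client.
        loop = asyncio.get_running_loop()
        if loop not in self._clients:
            self._clients[loop] = AsyncOllamaClient(host=self._host)
        return self._clients[loop]
```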

View file

@ -4,15 +4,9 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from pydantic import BaseModel
from .config import OpenAIConfig
class OpenAIProviderDataValidator(BaseModel):
openai_api_key: str | None = None
async def get_adapter_impl(config: OpenAIConfig, _deps):
from .openai import OpenAIInferenceAdapter

View file

@ -54,7 +54,7 @@ class InferenceStore:
async def initialize(self):
"""Create the necessary tables if they don't exist."""
self.sql_store = AuthorizedSqlStore(sqlstore_impl(self.sql_store_config))
self.sql_store = AuthorizedSqlStore(sqlstore_impl(self.sql_store_config), self.policy)
await self.sql_store.create_table(
"chat_completions",
{
@ -202,7 +202,6 @@ class InferenceStore:
order_by=[("created", order.value)],
cursor=("id", after) if after else None,
limit=limit,
policy=self.policy,
)
data = [
@ -229,7 +228,6 @@ class InferenceStore:
row = await self.sql_store.fetch_one(
table="chat_completions",
where={"id": completion_id},
policy=self.policy,
)
if not row:

View file

@ -103,7 +103,7 @@ class ModelRegistryHelper(ModelsProtocolPrivate):
Model(
identifier=id,
provider_resource_id=entry.provider_model_id,
model_type=ModelType.llm,
model_type=entry.model_type,
metadata=entry.metadata,
provider_id=self.__provider_id__,
)

View file

@ -28,8 +28,7 @@ class ResponsesStore:
sql_store_config = SqliteSqlStoreConfig(
db_path=(RUNTIME_BASE_DIR / "sqlstore.db").as_posix(),
)
self.sql_store = AuthorizedSqlStore(sqlstore_impl(sql_store_config))
self.policy = policy
self.sql_store = AuthorizedSqlStore(sqlstore_impl(sql_store_config), policy)
async def initialize(self):
"""Create the necessary tables if they don't exist."""
@ -87,7 +86,6 @@ class ResponsesStore:
order_by=[("created_at", order.value)],
cursor=("id", after) if after else None,
limit=limit,
policy=self.policy,
)
data = [OpenAIResponseObjectWithInput(**row["response_object"]) for row in paginated_result.data]
@ -105,7 +103,6 @@ class ResponsesStore:
row = await self.sql_store.fetch_one(
"openai_responses",
where={"id": response_id},
policy=self.policy,
)
if not row:
@ -116,7 +113,7 @@ class ResponsesStore:
return OpenAIResponseObjectWithInput(**row["response_object"])
async def delete_response_object(self, response_id: str) -> OpenAIDeleteResponseObject:
row = await self.sql_store.fetch_one("openai_responses", where={"id": response_id}, policy=self.policy)
row = await self.sql_store.fetch_one("openai_responses", where={"id": response_id})
if not row:
raise ValueError(f"Response with id {response_id} not found")
await self.sql_store.delete("openai_responses", where={"id": response_id})

View file

@ -53,13 +53,15 @@ class AuthorizedSqlStore:
access control policies, user attribute capture, and SQL filtering optimization.
"""
def __init__(self, sql_store: SqlStore):
def __init__(self, sql_store: SqlStore, policy: list[AccessRule]):
"""
Initialize the authorization layer.
:param sql_store: Base SqlStore implementation to wrap
:param policy: Access control policy to use for authorization
"""
self.sql_store = sql_store
self.policy = policy
self._detect_database_type()
self._validate_sql_optimized_policy()
@ -117,14 +119,13 @@ class AuthorizedSqlStore:
async def fetch_all(
self,
table: str,
policy: list[AccessRule],
where: Mapping[str, Any] | None = None,
limit: int | None = None,
order_by: list[tuple[str, Literal["asc", "desc"]]] | None = None,
cursor: tuple[str, str] | None = None,
) -> PaginatedResponse:
"""Fetch all rows with automatic access control filtering."""
access_where = self._build_access_control_where_clause(policy)
access_where = self._build_access_control_where_clause(self.policy)
rows = await self.sql_store.fetch_all(
table=table,
where=where,
@ -146,7 +147,7 @@ class AuthorizedSqlStore:
str(record_id), table, User(principal=stored_owner_principal, attributes=stored_access_attrs)
)
if is_action_allowed(policy, Action.READ, sql_record, current_user):
if is_action_allowed(self.policy, Action.READ, sql_record, current_user):
filtered_rows.append(row)
return PaginatedResponse(
@ -157,14 +158,12 @@ class AuthorizedSqlStore:
async def fetch_one(
self,
table: str,
policy: list[AccessRule],
where: Mapping[str, Any] | None = None,
order_by: list[tuple[str, Literal["asc", "desc"]]] | None = None,
) -> dict[str, Any] | None:
"""Fetch one row with automatic access control checking."""
results = await self.fetch_all(
table=table,
policy=policy,
where=where,
limit=1,
order_by=order_by,
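The net effect of this AuthorizedSqlStore change, as a hedged usage sketch: the access policy is bound once at construction, and reads drop their per-call `policy=` argument. The import path, table name, and helper function are assumptions for illustration, not code from this patch.

```python
from llama_stack.providers.utils.sqlstore.authorized_sqlstore import AuthorizedSqlStore  # path assumed


async def list_recent_chat_completions(base_store, policy, limit: int = 20):
    # The policy is supplied exactly once, when the base SqlStore is wrapped...
    store = AuthorizedSqlStore(base_store, policy)
    # ...so fetch_all / fetch_one no longer take a policy argument.
    page = await store.fetch_all(
        table="chat_completions",
        order_by=[("created", "desc")],
        limit=limit,
    )
    return page.data
```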

View file

@ -203,6 +203,11 @@ def _model_identifiers_digest(endpoint: str, response: dict[str, Any]) -> str:
- '/v1/models' (OpenAI): response body is: [ { id: ... }, ... ]
Returns a list of unique identifiers or None if structure doesn't match.
"""
if "models" in response["body"]:
# ollama
items = response["body"]["models"]
else:
# openai
items = response["body"]
idents = [m.model if endpoint == "/api/tags" else m.id for m in items]
return sorted(set(idents))
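A toy illustration of the two recorded body shapes this digest helper now accepts. The real recordings hold client objects exposing `.model` / `.id` attributes, so `SimpleNamespace` stands in for them here, the values are made up, and the guard is simplified relative to the patched helper.

```python
from types import SimpleNamespace

# Ollama-style recording: identifiers live under a "models" key.
ollama_body = {"models": [SimpleNamespace(model="llama3.2:3b"), SimpleNamespace(model="all-minilm:l6-v2")]}
# OpenAI-style recording: the body itself is the list of models.
openai_body = [SimpleNamespace(id="gpt-4o-mini"), SimpleNamespace(id="gpt-4o")]


def identifiers(endpoint: str, body) -> list[str]:
    items = body["models"] if isinstance(body, dict) and "models" in body else body
    return sorted({m.model if endpoint == "/api/tags" else m.id for m in items})


assert identifiers("/api/tags", ollama_body) == ["all-minilm:l6-v2", "llama3.2:3b"]
assert identifiers("/v1/models", openai_body) == ["gpt-4o", "gpt-4o-mini"]
```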

File diff suppressed because it is too large

View file

@ -14,7 +14,7 @@
},
"dependencies": {
"@radix-ui/react-collapsible": "^1.1.12",
"@radix-ui/react-dialog": "^1.1.13",
"@radix-ui/react-dialog": "^1.1.15",
"@radix-ui/react-dropdown-menu": "^2.1.16",
"@radix-ui/react-select": "^2.2.6",
"@radix-ui/react-separator": "^1.1.7",
@ -32,7 +32,7 @@
"react-dom": "^19.1.1",
"react-markdown": "^10.1.0",
"remark-gfm": "^4.0.1",
"remeda": "^2.30.0",
"remeda": "^2.32.0",
"shiki": "^1.29.2",
"sonner": "^2.0.7",
"tailwind-merge": "^3.3.1"
@ -52,7 +52,7 @@
"eslint-config-prettier": "^10.1.8",
"eslint-plugin-prettier": "^5.5.4",
"jest": "^29.7.0",
"jest-environment-jsdom": "^29.7.0",
"jest-environment-jsdom": "^30.1.2",
"prettier": "3.6.2",
"tailwindcss": "^4",
"ts-node": "^10.9.2",

View file

@ -1,4 +1,3 @@
adapter:
adapter_type: kaze
pip_packages: ["tests/external/llama-stack-provider-kaze"]
config_class: llama_stack_provider_kaze.config.KazeProviderConfig

View file

@ -6,7 +6,7 @@
from typing import Protocol
from llama_stack.providers.datatypes import AdapterSpec, Api, ProviderSpec, RemoteProviderSpec
from llama_stack.providers.datatypes import Api, ProviderSpec, RemoteProviderSpec
from llama_stack.schema_utils import webmethod
@ -16,12 +16,9 @@ def available_providers() -> list[ProviderSpec]:
api=Api.weather,
provider_type="remote::kaze",
config_class="llama_stack_provider_kaze.KazeProviderConfig",
adapter=AdapterSpec(
adapter_type="kaze",
module="llama_stack_provider_kaze",
pip_packages=["llama_stack_provider_kaze"],
config_class="llama_stack_provider_kaze.KazeProviderConfig",
),
),
]
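A hedged sketch of the flattened external provider spec after the nested AdapterSpec is removed: the fields that used to live on the adapter now sit directly on RemoteProviderSpec. Only fields exercised in the surrounding diffs are shown; the constructor may take additional arguments not visible here, and `Api.weather` exists only in this external-provider test context.

```python
from llama_stack.providers.datatypes import Api, RemoteProviderSpec

spec = RemoteProviderSpec(
    api=Api.weather,  # external test API defined by this example provider
    provider_type="remote::kaze",
    adapter_type="kaze",
    module="llama_stack_provider_kaze",
    pip_packages=["llama_stack_provider_kaze"],
    config_class="llama_stack_provider_kaze.KazeProviderConfig",
)
```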

View file

@ -33,6 +33,7 @@ def skip_if_model_doesnt_support_user_param(client, model_id):
provider = provider_from_model(client, model_id)
if provider.provider_type in (
"remote::together", # service returns 400
"remote::fireworks", # service returns 400 malformed input
):
pytest.skip(f"Model {model_id} hosted by {provider.provider_type} does not support user param.")
@ -41,6 +42,7 @@ def skip_if_model_doesnt_support_encoding_format_base64(client, model_id):
provider = provider_from_model(client, model_id)
if provider.provider_type in (
"remote::together", # param silently ignored, always returns floats
"remote::fireworks", # param silently ignored, always returns list of floats
):
pytest.skip(f"Model {model_id} hosted by {provider.provider_type} does not support encoding_format='base64'.")
@ -287,7 +289,6 @@ def test_openai_embeddings_base64_batch_processing(compat_client, client_with_mo
input=input_texts,
encoding_format="base64",
)
# Validate response structure
assert response.object == "list"
assert response.model == embedding_model_id

View file

@ -57,7 +57,7 @@ def authorized_store(backend_config):
config = config_func()
base_sqlstore = sqlstore_impl(config)
authorized_store = AuthorizedSqlStore(base_sqlstore)
authorized_store = AuthorizedSqlStore(base_sqlstore, default_policy())
yield authorized_store
@ -106,7 +106,7 @@ async def test_authorized_store_attributes(mock_get_authenticated_user, authoriz
await authorized_store.insert(table_name, {"id": "1", "data": "public_data"})
# Test fetching with no user - should not error on JSON comparison
result = await authorized_store.fetch_all(table_name, policy=default_policy())
result = await authorized_store.fetch_all(table_name)
assert len(result.data) == 1
assert result.data[0]["id"] == "1"
assert result.data[0]["access_attributes"] is None
@ -119,7 +119,7 @@ async def test_authorized_store_attributes(mock_get_authenticated_user, authoriz
await authorized_store.insert(table_name, {"id": "2", "data": "admin_data"})
# Fetch all - admin should see both
result = await authorized_store.fetch_all(table_name, policy=default_policy())
result = await authorized_store.fetch_all(table_name)
assert len(result.data) == 2
# Test with non-admin user
@ -127,7 +127,7 @@ async def test_authorized_store_attributes(mock_get_authenticated_user, authoriz
mock_get_authenticated_user.return_value = regular_user
# Should only see public record
result = await authorized_store.fetch_all(table_name, policy=default_policy())
result = await authorized_store.fetch_all(table_name)
assert len(result.data) == 1
assert result.data[0]["id"] == "1"
@ -156,7 +156,7 @@ async def test_authorized_store_attributes(mock_get_authenticated_user, authoriz
# Now test with the multi-user who has both roles=admin and teams=dev
mock_get_authenticated_user.return_value = multi_user
result = await authorized_store.fetch_all(table_name, policy=default_policy())
result = await authorized_store.fetch_all(table_name)
# Should see:
# - public record (1) - no access_attributes
@ -217,21 +217,24 @@ async def test_user_ownership_policy(mock_get_authenticated_user, authorized_sto
),
]
# Create a new authorized store with the owner-only policy
owner_only_store = AuthorizedSqlStore(authorized_store.sql_store, owner_only_policy)
# Test user1 access - should only see their own record
mock_get_authenticated_user.return_value = user1
result = await authorized_store.fetch_all(table_name, policy=owner_only_policy)
result = await owner_only_store.fetch_all(table_name)
assert len(result.data) == 1, f"Expected user1 to see 1 record, got {len(result.data)}"
assert result.data[0]["id"] == "1", f"Expected user1's record, got {result.data[0]['id']}"
# Test user2 access - should only see their own record
mock_get_authenticated_user.return_value = user2
result = await authorized_store.fetch_all(table_name, policy=owner_only_policy)
result = await owner_only_store.fetch_all(table_name)
assert len(result.data) == 1, f"Expected user2 to see 1 record, got {len(result.data)}"
assert result.data[0]["id"] == "2", f"Expected user2's record, got {result.data[0]['id']}"
# Test with anonymous user - should see no records
mock_get_authenticated_user.return_value = None
result = await authorized_store.fetch_all(table_name, policy=owner_only_policy)
result = await owner_only_store.fetch_all(table_name)
assert len(result.data) == 0, f"Expected anonymous user to see 0 records, got {len(result.data)}"
finally:

View file

@ -0,0 +1,990 @@
{
"request": {
"method": "POST",
"url": "http://0.0.0.0:11434/v1/v1/completions",
"headers": {},
"body": {
"model": "llama3.2:3b-instruct-fp16",
"prompt": "Respond to this question and explain your answer. Complete the sentence using one word: Roses are red, violets are ",
"max_tokens": 50,
"stream": true,
"extra_body": {}
},
"endpoint": "/v1/completions",
"model": "llama3.2:3b-instruct-fp16"
},
"response": {
"body": [
{
"__type__": "openai.types.completion.Completion",
"__data__": {
"id": "cmpl-439",
"choices": [
{
"finish_reason": null,
"index": 0,
"logprobs": null,
"text": "Blue"
}
],
"created": 1757857132,
"model": "llama3.2:3b-instruct-fp16",
"object": "text_completion",
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.completion.Completion",
"__data__": {
"id": "cmpl-439",
"choices": [
{
"finish_reason": null,
"index": 0,
"logprobs": null,
"text": ".\n\n"
}
],
"created": 1757857132,
"model": "llama3.2:3b-instruct-fp16",
"object": "text_completion",
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.completion.Completion",
"__data__": {
"id": "cmpl-439",
"choices": [
{
"finish_reason": null,
"index": 0,
"logprobs": null,
"text": "The"
}
],
"created": 1757857132,
"model": "llama3.2:3b-instruct-fp16",
"object": "text_completion",
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.completion.Completion",
"__data__": {
"id": "cmpl-439",
"choices": [
{
"finish_reason": null,
"index": 0,
"logprobs": null,
"text": " completed"
}
],
"created": 1757857132,
"model": "llama3.2:3b-instruct-fp16",
"object": "text_completion",
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.completion.Completion",
"__data__": {
"id": "cmpl-439",
"choices": [
{
"finish_reason": null,
"index": 0,
"logprobs": null,
"text": " sentence"
}
],
"created": 1757857132,
"model": "llama3.2:3b-instruct-fp16",
"object": "text_completion",
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.completion.Completion",
"__data__": {
"id": "cmpl-439",
"choices": [
{
"finish_reason": null,
"index": 0,
"logprobs": null,
"text": " is"
}
],
"created": 1757857132,
"model": "llama3.2:3b-instruct-fp16",
"object": "text_completion",
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.completion.Completion",
"__data__": {
"id": "cmpl-439",
"choices": [
{
"finish_reason": null,
"index": 0,
"logprobs": null,
"text": " a"
}
],
"created": 1757857132,
"model": "llama3.2:3b-instruct-fp16",
"object": "text_completion",
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.completion.Completion",
"__data__": {
"id": "cmpl-439",
"choices": [
{
"finish_reason": null,
"index": 0,
"logprobs": null,
"text": " well"
}
],
"created": 1757857132,
"model": "llama3.2:3b-instruct-fp16",
"object": "text_completion",
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.completion.Completion",
"__data__": {
"id": "cmpl-439",
"choices": [
{
"finish_reason": null,
"index": 0,
"logprobs": null,
"text": "-known"
}
],
"created": 1757857132,
"model": "llama3.2:3b-instruct-fp16",
"object": "text_completion",
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.completion.Completion",
"__data__": {
"id": "cmpl-439",
"choices": [
{
"finish_reason": null,
"index": 0,
"logprobs": null,
"text": " phrase"
}
],
"created": 1757857132,
"model": "llama3.2:3b-instruct-fp16",
"object": "text_completion",
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.completion.Completion",
"__data__": {
"id": "cmpl-439",
"choices": [
{
"finish_reason": null,
"index": 0,
"logprobs": null,
"text": " from"
}
],
"created": 1757857132,
"model": "llama3.2:3b-instruct-fp16",
"object": "text_completion",
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.completion.Completion",
"__data__": {
"id": "cmpl-439",
"choices": [
{
"finish_reason": null,
"index": 0,
"logprobs": null,
"text": " a"
}
],
"created": 1757857132,
"model": "llama3.2:3b-instruct-fp16",
"object": "text_completion",
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.completion.Completion",
"__data__": {
"id": "cmpl-439",
"choices": [
{
"finish_reason": null,
"index": 0,
"logprobs": null,
"text": " traditional"
}
],
"created": 1757857132,
"model": "llama3.2:3b-instruct-fp16",
"object": "text_completion",
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.completion.Completion",
"__data__": {
"id": "cmpl-439",
"choices": [
{
"finish_reason": null,
"index": 0,
"logprobs": null,
"text": " English"
}
],
"created": 1757857132,
"model": "llama3.2:3b-instruct-fp16",
"object": "text_completion",
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.completion.Completion",
"__data__": {
"id": "cmpl-439",
"choices": [
{
"finish_reason": null,
"index": 0,
"logprobs": null,
"text": " poem"
}
],
"created": 1757857132,
"model": "llama3.2:3b-instruct-fp16",
"object": "text_completion",
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.completion.Completion",
"__data__": {
"id": "cmpl-439",
"choices": [
{
"finish_reason": null,
"index": 0,
"logprobs": null,
"text": ":\n\n"
}
],
"created": 1757857132,
"model": "llama3.2:3b-instruct-fp16",
"object": "text_completion",
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.completion.Completion",
"__data__": {
"id": "cmpl-439",
"choices": [
{
"finish_reason": null,
"index": 0,
"logprobs": null,
"text": "\""
}
],
"created": 1757857132,
"model": "llama3.2:3b-instruct-fp16",
"object": "text_completion",
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.completion.Completion",
"__data__": {
"id": "cmpl-439",
"choices": [
{
"finish_reason": null,
"index": 0,
"logprobs": null,
"text": "R"
}
],
"created": 1757857132,
"model": "llama3.2:3b-instruct-fp16",
"object": "text_completion",
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.completion.Completion",
"__data__": {
"id": "cmpl-439",
"choices": [
{
"finish_reason": null,
"index": 0,
"logprobs": null,
"text": "oses"
}
],
"created": 1757857132,
"model": "llama3.2:3b-instruct-fp16",
"object": "text_completion",
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.completion.Completion",
"__data__": {
"id": "cmpl-439",
"choices": [
{
"finish_reason": null,
"index": 0,
"logprobs": null,
"text": " are"
}
],
"created": 1757857132,
"model": "llama3.2:3b-instruct-fp16",
"object": "text_completion",
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.completion.Completion",
"__data__": {
"id": "cmpl-439",
"choices": [
{
"finish_reason": null,
"index": 0,
"logprobs": null,
"text": " red"
}
],
"created": 1757857132,
"model": "llama3.2:3b-instruct-fp16",
"object": "text_completion",
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.completion.Completion",
"__data__": {
"id": "cmpl-439",
"choices": [
{
"finish_reason": null,
"index": 0,
"logprobs": null,
"text": ","
}
],
"created": 1757857132,
"model": "llama3.2:3b-instruct-fp16",
"object": "text_completion",
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.completion.Completion",
"__data__": {
"id": "cmpl-439",
"choices": [
{
"finish_reason": null,
"index": 0,
"logprobs": null,
"text": " v"
}
],
"created": 1757857132,
"model": "llama3.2:3b-instruct-fp16",
"object": "text_completion",
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.completion.Completion",
"__data__": {
"id": "cmpl-439",
"choices": [
{
"finish_reason": null,
"index": 0,
"logprobs": null,
"text": "io"
}
],
"created": 1757857132,
"model": "llama3.2:3b-instruct-fp16",
"object": "text_completion",
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.completion.Completion",
"__data__": {
"id": "cmpl-439",
"choices": [
{
"finish_reason": null,
"index": 0,
"logprobs": null,
"text": "lets"
}
],
"created": 1757857132,
"model": "llama3.2:3b-instruct-fp16",
"object": "text_completion",
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.completion.Completion",
"__data__": {
"id": "cmpl-439",
"choices": [
{
"finish_reason": null,
"index": 0,
"logprobs": null,
"text": " are"
}
],
"created": 1757857132,
"model": "llama3.2:3b-instruct-fp16",
"object": "text_completion",
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.completion.Completion",
"__data__": {
"id": "cmpl-439",
"choices": [
{
"finish_reason": null,
"index": 0,
"logprobs": null,
"text": " blue"
}
],
"created": 1757857132,
"model": "llama3.2:3b-instruct-fp16",
"object": "text_completion",
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.completion.Completion",
"__data__": {
"id": "cmpl-439",
"choices": [
{
"finish_reason": null,
"index": 0,
"logprobs": null,
"text": ",\n"
}
],
"created": 1757857132,
"model": "llama3.2:3b-instruct-fp16",
"object": "text_completion",
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.completion.Completion",
"__data__": {
"id": "cmpl-439",
"choices": [
{
"finish_reason": null,
"index": 0,
"logprobs": null,
"text": "Sugar"
}
],
"created": 1757857132,
"model": "llama3.2:3b-instruct-fp16",
"object": "text_completion",
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.completion.Completion",
"__data__": {
"id": "cmpl-439",
"choices": [
{
"finish_reason": null,
"index": 0,
"logprobs": null,
"text": " is"
}
],
"created": 1757857132,
"model": "llama3.2:3b-instruct-fp16",
"object": "text_completion",
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.completion.Completion",
"__data__": {
"id": "cmpl-439",
"choices": [
{
"finish_reason": null,
"index": 0,
"logprobs": null,
"text": " sweet"
}
],
"created": 1757857132,
"model": "llama3.2:3b-instruct-fp16",
"object": "text_completion",
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.completion.Completion",
"__data__": {
"id": "cmpl-439",
"choices": [
{
"finish_reason": null,
"index": 0,
"logprobs": null,
"text": ","
}
],
"created": 1757857132,
"model": "llama3.2:3b-instruct-fp16",
"object": "text_completion",
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.completion.Completion",
"__data__": {
"id": "cmpl-439",
"choices": [
{
"finish_reason": null,
"index": 0,
"logprobs": null,
"text": " and"
}
],
"created": 1757857132,
"model": "llama3.2:3b-instruct-fp16",
"object": "text_completion",
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.completion.Completion",
"__data__": {
"id": "cmpl-439",
"choices": [
{
"finish_reason": null,
"index": 0,
"logprobs": null,
"text": " so"
}
],
"created": 1757857132,
"model": "llama3.2:3b-instruct-fp16",
"object": "text_completion",
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.completion.Completion",
"__data__": {
"id": "cmpl-439",
"choices": [
{
"finish_reason": null,
"index": 0,
"logprobs": null,
"text": " are"
}
],
"created": 1757857132,
"model": "llama3.2:3b-instruct-fp16",
"object": "text_completion",
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.completion.Completion",
"__data__": {
"id": "cmpl-439",
"choices": [
{
"finish_reason": null,
"index": 0,
"logprobs": null,
"text": " you"
}
],
"created": 1757857132,
"model": "llama3.2:3b-instruct-fp16",
"object": "text_completion",
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.completion.Completion",
"__data__": {
"id": "cmpl-439",
"choices": [
{
"finish_reason": null,
"index": 0,
"logprobs": null,
"text": ".\""
}
],
"created": 1757857132,
"model": "llama3.2:3b-instruct-fp16",
"object": "text_completion",
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.completion.Completion",
"__data__": {
"id": "cmpl-439",
"choices": [
{
"finish_reason": null,
"index": 0,
"logprobs": null,
"text": " However"
}
],
"created": 1757857132,
"model": "llama3.2:3b-instruct-fp16",
"object": "text_completion",
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.completion.Completion",
"__data__": {
"id": "cmpl-439",
"choices": [
{
"finish_reason": null,
"index": 0,
"logprobs": null,
"text": ","
}
],
"created": 1757857132,
"model": "llama3.2:3b-instruct-fp16",
"object": "text_completion",
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.completion.Completion",
"__data__": {
"id": "cmpl-439",
"choices": [
{
"finish_reason": null,
"index": 0,
"logprobs": null,
"text": " in"
}
],
"created": 1757857132,
"model": "llama3.2:3b-instruct-fp16",
"object": "text_completion",
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.completion.Completion",
"__data__": {
"id": "cmpl-439",
"choices": [
{
"finish_reason": null,
"index": 0,
"logprobs": null,
"text": " many"
}
],
"created": 1757857132,
"model": "llama3.2:3b-instruct-fp16",
"object": "text_completion",
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.completion.Completion",
"__data__": {
"id": "cmpl-439",
"choices": [
{
"finish_reason": null,
"index": 0,
"logprobs": null,
"text": " variations"
}
],
"created": 1757857132,
"model": "llama3.2:3b-instruct-fp16",
"object": "text_completion",
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.completion.Completion",
"__data__": {
"id": "cmpl-439",
"choices": [
{
"finish_reason": null,
"index": 0,
"logprobs": null,
"text": " of"
}
],
"created": 1757857132,
"model": "llama3.2:3b-instruct-fp16",
"object": "text_completion",
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.completion.Completion",
"__data__": {
"id": "cmpl-439",
"choices": [
{
"finish_reason": null,
"index": 0,
"logprobs": null,
"text": " this"
}
],
"created": 1757857132,
"model": "llama3.2:3b-instruct-fp16",
"object": "text_completion",
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.completion.Completion",
"__data__": {
"id": "cmpl-439",
"choices": [
{
"finish_reason": null,
"index": 0,
"logprobs": null,
"text": " poem"
}
],
"created": 1757857132,
"model": "llama3.2:3b-instruct-fp16",
"object": "text_completion",
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.completion.Completion",
"__data__": {
"id": "cmpl-439",
"choices": [
{
"finish_reason": null,
"index": 0,
"logprobs": null,
"text": ","
}
],
"created": 1757857132,
"model": "llama3.2:3b-instruct-fp16",
"object": "text_completion",
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.completion.Completion",
"__data__": {
"id": "cmpl-439",
"choices": [
{
"finish_reason": null,
"index": 0,
"logprobs": null,
"text": " the"
}
],
"created": 1757857132,
"model": "llama3.2:3b-instruct-fp16",
"object": "text_completion",
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.completion.Completion",
"__data__": {
"id": "cmpl-439",
"choices": [
{
"finish_reason": null,
"index": 0,
"logprobs": null,
"text": " line"
}
],
"created": 1757857132,
"model": "llama3.2:3b-instruct-fp16",
"object": "text_completion",
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.completion.Completion",
"__data__": {
"id": "cmpl-439",
"choices": [
{
"finish_reason": null,
"index": 0,
"logprobs": null,
"text": " \""
}
],
"created": 1757857132,
"model": "llama3.2:3b-instruct-fp16",
"object": "text_completion",
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.completion.Completion",
"__data__": {
"id": "cmpl-439",
"choices": [
{
"finish_reason": null,
"index": 0,
"logprobs": null,
"text": "vio"
}
],
"created": 1757857132,
"model": "llama3.2:3b-instruct-fp16",
"object": "text_completion",
"system_fingerprint": "fp_ollama",
"usage": null
}
},
{
"__type__": "openai.types.completion.Completion",
"__data__": {
"id": "cmpl-439",
"choices": [
{
"finish_reason": "length",
"index": 0,
"logprobs": null,
"text": ""
}
],
"created": 1757857132,
"model": "llama3.2:3b-instruct-fp16",
"object": "text_completion",
"system_fingerprint": "fp_ollama",
"usage": null
}
}
],
"is_streaming": true
}
}

View file

@ -0,0 +1,43 @@
{
"request": {
"method": "POST",
"url": "http://0.0.0.0:11434/v1/v1/completions",
"headers": {},
"body": {
"model": "llama3.2:3b-instruct-fp16",
"prompt": "Respond to this question and explain your answer. Complete the sentence using one word: Roses are red, violets are ",
"stream": false,
"extra_body": {}
},
"endpoint": "/v1/completions",
"model": "llama3.2:3b-instruct-fp16"
},
"response": {
"body": {
"__type__": "openai.types.completion.Completion",
"__data__": {
"id": "cmpl-104",
"choices": [
{
"finish_reason": "stop",
"index": 0,
"logprobs": null,
"text": "blue.\n\nI completed the sentence with \"blue\" because it is a common completion used to complete the traditional nursery rhyme, which ends with:\n\nRoses are red,\nViolets are blue.\n\nThe complete rhyme is often remembered and recited as follows:\n\nRoses are red,\nViolets are blue,\nSugar is sweet,\nAnd so are you!"
}
],
"created": 1757857132,
"model": "llama3.2:3b-instruct-fp16",
"object": "text_completion",
"system_fingerprint": "fp_ollama",
"usage": {
"completion_tokens": 72,
"prompt_tokens": 50,
"total_tokens": 122,
"completion_tokens_details": null,
"prompt_tokens_details": null
}
}
},
"is_streaming": false
}
}

View file

@ -0,0 +1,43 @@
{
"request": {
"method": "POST",
"url": "http://0.0.0.0:11434/v1/v1/completions",
"headers": {},
"body": {
"model": "llama3.2:3b-instruct-fp16",
"prompt": "Say completions",
"max_tokens": 20,
"extra_body": {}
},
"endpoint": "/v1/completions",
"model": "llama3.2:3b-instruct-fp16"
},
"response": {
"body": {
"__type__": "openai.types.completion.Completion",
"__data__": {
"id": "cmpl-406",
"choices": [
{
"finish_reason": "length",
"index": 0,
"logprobs": null,
"text": "Sure, I'd be happy to provide some definitions and examples of related words or phrases.\n\nTo better"
}
],
"created": 1757857133,
"model": "llama3.2:3b-instruct-fp16",
"object": "text_completion",
"system_fingerprint": "fp_ollama",
"usage": {
"completion_tokens": 20,
"prompt_tokens": 28,
"total_tokens": 48,
"completion_tokens_details": null,
"prompt_tokens_details": null
}
}
},
"is_streaming": false
}
}

View file

@ -115,6 +115,15 @@ SETUP_DEFINITIONS: dict[str, Setup] = {
"text_model": "cerebras/llama-3.3-70b",
},
),
"fireworks": Setup(
name="fireworks",
description="Fireworks provider with a text model",
defaults={
"text_model": "accounts/fireworks/models/llama-v3p1-8b-instruct",
"vision_model": "accounts/fireworks/models/llama-v3p2-90b-vision-instruct",
"embedding_model": "nomic-ai/nomic-embed-text-v1.5",
},
),
}

View file

@ -66,7 +66,6 @@ def base_config(tmp_path):
def provider_spec_yaml():
"""Common provider spec YAML for testing."""
return """
adapter:
adapter_type: test_provider
config_class: test_provider.config.TestProviderConfig
module: test_provider
@ -182,9 +181,9 @@ class TestProviderRegistry:
assert Api.inference in registry
assert "remote::test_provider" in registry[Api.inference]
provider = registry[Api.inference]["remote::test_provider"]
assert provider.adapter.adapter_type == "test_provider"
assert provider.adapter.module == "test_provider"
assert provider.adapter.config_class == "test_provider.config.TestProviderConfig"
assert provider.adapter_type == "test_provider"
assert provider.module == "test_provider"
assert provider.config_class == "test_provider.config.TestProviderConfig"
assert Api.safety in provider.api_dependencies
def test_external_inline_providers(self, api_directories, mock_providers, base_config, inline_provider_spec_yaml):
@ -246,7 +245,6 @@ class TestProviderRegistry:
"""Test handling of malformed remote provider spec (missing required fields)."""
remote_dir, _ = api_directories
malformed_spec = """
adapter:
adapter_type: test_provider
# Missing required fields
api_dependencies:
@ -270,7 +268,7 @@ pip_packages:
with open(inline_dir / "malformed.yaml", "w") as f:
f.write(malformed_spec)
with pytest.raises(KeyError) as exc_info:
with pytest.raises(ValidationError) as exc_info:
get_provider_registry(base_config)
assert "config_class" in str(exc_info.value)
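For reference, a small sketch of the flattened remote provider spec these tests now parse, with the nested `adapter:` block gone; only fields visible in this diff are included, and anything else is deliberately omitted rather than guessed.

```python
import yaml

FLAT_REMOTE_SPEC = """
adapter_type: test_provider
config_class: test_provider.config.TestProviderConfig
module: test_provider
"""

spec = yaml.safe_load(FLAT_REMOTE_SPEC)
assert "adapter" not in spec                    # the nested block is gone
assert spec["adapter_type"] == "test_provider"  # its fields sit at the top level
```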

View file

@ -27,13 +27,17 @@ class TestLlamaStackAsLibraryClientAutoInitialization:
mock_impls = {}
mock_route_impls = RouteImpls({})
async def mock_construct_stack(config, custom_provider_registry):
return mock_impls
class MockStack:
def __init__(self, config, custom_provider_registry=None):
self.impls = mock_impls
async def initialize(self):
pass
def mock_initialize_route_impls(impls):
return mock_route_impls
monkeypatch.setattr("llama_stack.core.library_client.construct_stack", mock_construct_stack)
monkeypatch.setattr("llama_stack.core.library_client.Stack", MockStack)
monkeypatch.setattr("llama_stack.core.library_client.initialize_route_impls", mock_initialize_route_impls)
client = LlamaStackAsLibraryClient("ci-tests")
@ -46,13 +50,17 @@ class TestLlamaStackAsLibraryClientAutoInitialization:
mock_impls = {}
mock_route_impls = RouteImpls({})
async def mock_construct_stack(config, custom_provider_registry):
return mock_impls
class MockStack:
def __init__(self, config, custom_provider_registry=None):
self.impls = mock_impls
async def initialize(self):
pass
def mock_initialize_route_impls(impls):
return mock_route_impls
monkeypatch.setattr("llama_stack.core.library_client.construct_stack", mock_construct_stack)
monkeypatch.setattr("llama_stack.core.library_client.Stack", MockStack)
monkeypatch.setattr("llama_stack.core.library_client.initialize_route_impls", mock_initialize_route_impls)
client = AsyncLlamaStackAsLibraryClient("ci-tests")
@ -68,13 +76,17 @@ class TestLlamaStackAsLibraryClientAutoInitialization:
mock_impls = {}
mock_route_impls = RouteImpls({})
async def mock_construct_stack(config, custom_provider_registry):
return mock_impls
class MockStack:
def __init__(self, config, custom_provider_registry=None):
self.impls = mock_impls
async def initialize(self):
pass
def mock_initialize_route_impls(impls):
return mock_route_impls
monkeypatch.setattr("llama_stack.core.library_client.construct_stack", mock_construct_stack)
monkeypatch.setattr("llama_stack.core.library_client.Stack", MockStack)
monkeypatch.setattr("llama_stack.core.library_client.initialize_route_impls", mock_initialize_route_impls)
client = LlamaStackAsLibraryClient("ci-tests")
@ -90,13 +102,17 @@ class TestLlamaStackAsLibraryClientAutoInitialization:
mock_impls = {}
mock_route_impls = RouteImpls({})
async def mock_construct_stack(config, custom_provider_registry):
return mock_impls
class MockStack:
def __init__(self, config, custom_provider_registry=None):
self.impls = mock_impls
async def initialize(self):
pass
def mock_initialize_route_impls(impls):
return mock_route_impls
monkeypatch.setattr("llama_stack.core.library_client.construct_stack", mock_construct_stack)
monkeypatch.setattr("llama_stack.core.library_client.Stack", MockStack)
monkeypatch.setattr("llama_stack.core.library_client.initialize_route_impls", mock_initialize_route_impls)
client = AsyncLlamaStackAsLibraryClient("ci-tests")
@ -112,13 +128,17 @@ class TestLlamaStackAsLibraryClientAutoInitialization:
mock_impls = {}
mock_route_impls = RouteImpls({})
async def mock_construct_stack(config, custom_provider_registry):
return mock_impls
class MockStack:
def __init__(self, config, custom_provider_registry=None):
self.impls = mock_impls
async def initialize(self):
pass
def mock_initialize_route_impls(impls):
return mock_route_impls
monkeypatch.setattr("llama_stack.core.library_client.construct_stack", mock_construct_stack)
monkeypatch.setattr("llama_stack.core.library_client.Stack", MockStack)
monkeypatch.setattr("llama_stack.core.library_client.initialize_route_impls", mock_initialize_route_impls)
sync_client = LlamaStackAsLibraryClient("ci-tests")

View file

@ -26,7 +26,7 @@ async def test_authorized_fetch_with_where_sql_access_control(mock_get_authentic
db_path=tmp_dir + "/" + db_name,
)
)
sqlstore = AuthorizedSqlStore(base_sqlstore)
sqlstore = AuthorizedSqlStore(base_sqlstore, default_policy())
# Create table with access control
await sqlstore.create_table(
@ -56,24 +56,24 @@ async def test_authorized_fetch_with_where_sql_access_control(mock_get_authentic
mock_get_authenticated_user.return_value = admin_user
# Admin should see both documents
result = await sqlstore.fetch_all("documents", policy=default_policy(), where={"id": 1})
result = await sqlstore.fetch_all("documents", where={"id": 1})
assert len(result.data) == 1
assert result.data[0]["title"] == "Admin Document"
# User should only see their document
mock_get_authenticated_user.return_value = regular_user
result = await sqlstore.fetch_all("documents", policy=default_policy(), where={"id": 1})
result = await sqlstore.fetch_all("documents", where={"id": 1})
assert len(result.data) == 0
result = await sqlstore.fetch_all("documents", policy=default_policy(), where={"id": 2})
result = await sqlstore.fetch_all("documents", where={"id": 2})
assert len(result.data) == 1
assert result.data[0]["title"] == "User Document"
row = await sqlstore.fetch_one("documents", policy=default_policy(), where={"id": 1})
row = await sqlstore.fetch_one("documents", where={"id": 1})
assert row is None
row = await sqlstore.fetch_one("documents", policy=default_policy(), where={"id": 2})
row = await sqlstore.fetch_one("documents", where={"id": 2})
assert row is not None
assert row["title"] == "User Document"
@ -88,7 +88,7 @@ async def test_sql_policy_consistency(mock_get_authenticated_user):
db_path=tmp_dir + "/" + db_name,
)
)
sqlstore = AuthorizedSqlStore(base_sqlstore)
sqlstore = AuthorizedSqlStore(base_sqlstore, default_policy())
await sqlstore.create_table(
table="resources",
@ -144,7 +144,7 @@ async def test_sql_policy_consistency(mock_get_authenticated_user):
user = User(principal=user_data["principal"], attributes=user_data["attributes"])
mock_get_authenticated_user.return_value = user
sql_results = await sqlstore.fetch_all("resources", policy=policy)
sql_results = await sqlstore.fetch_all("resources")
sql_ids = {row["id"] for row in sql_results.data}
policy_ids = set()
for scenario in test_scenarios:
@ -174,7 +174,7 @@ async def test_authorized_store_user_attribute_capture(mock_get_authenticated_us
db_path=tmp_dir + "/" + db_name,
)
)
authorized_store = AuthorizedSqlStore(base_sqlstore)
authorized_store = AuthorizedSqlStore(base_sqlstore, default_policy())
await authorized_store.create_table(
table="user_data",