Merge branch 'main' into feat/litellm_sambanova_usage

This commit is contained in:
jhpiedrahitao 2025-04-01 07:57:21 -05:00
commit 9c9f9577e2
173 changed files with 3073 additions and 3118 deletions

View file

@ -15,6 +15,7 @@ class JobStatus(Enum):
in_progress = "in_progress"
failed = "failed"
scheduled = "scheduled"
cancelled = "cancelled"
@json_schema_type

View file

@ -0,0 +1,23 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict, List
from pydantic import BaseModel
from llama_stack.schema_utils import json_schema_type
@json_schema_type
class PaginatedResponse(BaseModel):
"""A generic paginated response that follows a simple format.
:param data: The list of items for the current page
:param has_more: Whether there are more items available after this set
"""
data: List[Dict[str, Any]]
has_more: bool

View file

@ -6,23 +6,9 @@
from typing import Any, Dict, List, Optional, Protocol, runtime_checkable
from pydantic import BaseModel
from llama_stack.apis.common.responses import PaginatedResponse
from llama_stack.apis.datasets import Dataset
from llama_stack.schema_utils import json_schema_type, webmethod
@json_schema_type
class IterrowsResponse(BaseModel):
"""
A paginated list of rows from a dataset.
:param data: The rows in the current page.
:param next_start_index: Index into dataset for the first row in the next page. None if there are no more rows.
"""
data: List[Dict[str, Any]]
next_start_index: Optional[int] = None
from llama_stack.schema_utils import webmethod
class DatasetStore(Protocol):
@ -34,15 +20,22 @@ class DatasetIO(Protocol):
# keeping for aligning with inference/safety, but this is not used
dataset_store: DatasetStore
# TODO(xiyan): there's a flakiness here where setting route to "/datasets/" here will not result in proper routing
@webmethod(route="/datasetio/iterrows/{dataset_id:path}", method="GET")
async def iterrows(
self,
dataset_id: str,
start_index: Optional[int] = None,
limit: Optional[int] = None,
) -> IterrowsResponse:
"""Get a paginated list of rows from a dataset. Uses cursor-based pagination.
) -> PaginatedResponse:
"""Get a paginated list of rows from a dataset.
Uses offset-based pagination where:
- start_index: The starting index (0-based). If None, starts from beginning.
- limit: Number of items to return. If None or -1, returns all items.
The response includes:
- data: List of items for the current page
- has_more: Whether there are more items available after this set
:param dataset_id: The ID of the dataset to get the rows from.
:param start_index: Index into dataset for the first row to get. Get all rows if None.

View file

@ -34,6 +34,7 @@ class Api(Enum):
scoring_functions = "scoring_functions"
benchmarks = "benchmarks"
tool_groups = "tool_groups"
files = "files"
# built-in API
inspect = "inspect"

View file

@ -164,7 +164,7 @@ class Files(Protocol):
self,
bucket: str,
key: str,
) -> FileResponse:
) -> None:
"""
Delete a file identified by a bucket and key.

View file

@ -88,6 +88,10 @@ class ListToolsResponse(BaseModel):
data: List[Tool]
class ListToolDefsResponse(BaseModel):
data: list[ToolDef]
@runtime_checkable
@trace_protocol
class ToolGroups(Protocol):
@ -148,7 +152,7 @@ class ToolRuntime(Protocol):
@webmethod(route="/tool-runtime/list-tools", method="GET")
async def list_runtime_tools(
self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
) -> List[ToolDef]: ...
) -> ListToolDefsResponse: ...
@webmethod(route="/tool-runtime/invoke", method="POST")
async def invoke_tool(self, tool_name: str, kwargs: Dict[str, Any]) -> ToolInvocationResult:

View file

@ -21,6 +21,7 @@ from prompt_toolkit.completion import WordCompleter
from prompt_toolkit.validation import Validator
from termcolor import cprint
from llama_stack.cli.stack.utils import ImageType
from llama_stack.cli.table import print_table
from llama_stack.distribution.build import (
SERVER_DEPENDENCIES,
@ -62,10 +63,10 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
if args.list_templates:
return _run_template_list_cmd()
if args.image_type == "venv":
if args.image_type == ImageType.VENV.value:
current_venv = os.environ.get("VIRTUAL_ENV")
image_name = args.image_name or current_venv
elif args.image_type == "conda":
elif args.image_type == ImageType.CONDA.value:
current_conda_env = os.environ.get("CONDA_DEFAULT_ENV")
image_name = args.image_name or current_conda_env
else:
@ -84,7 +85,7 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
build_config.image_type = args.image_type
else:
cprint(
f"Please specify a image-type (container | conda | venv) for {args.template}",
f"Please specify a image-type ({' | '.join(e.value for e in ImageType)}) for {args.template}",
color="red",
)
sys.exit(1)
@ -98,15 +99,15 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
)
image_type = prompt(
"> Enter the image type you want your Llama Stack to be built as (container or conda or venv): ",
f"> Enter the image type you want your Llama Stack to be built as ({' or '.join(e.value for e in ImageType)}): ",
validator=Validator.from_callable(
lambda x: x in ["container", "conda", "venv"],
error_message="Invalid image type, please enter conda or container or venv",
lambda x: x in [e.value for e in ImageType],
error_message=f"Invalid image type, please enter {' or '.join(e.value for e in ImageType)}",
),
default="conda",
default=ImageType.CONDA.value,
)
if image_type == "conda":
if image_type == ImageType.CONDA.value:
if not image_name:
cprint(
f"No current conda environment detected or specified, will create a new conda environment with the name `llamastack-{name}`",
@ -136,6 +137,8 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
providers = dict()
for api, providers_for_api in get_provider_registry().items():
available_providers = [x for x in providers_for_api.keys() if x not in ("remote", "remote::sample")]
if not available_providers:
continue
api_provider = prompt(
"> Enter provider for API {}: ".format(api.value),
completer=WordCompleter(available_providers),

View file

@ -6,6 +6,7 @@
import argparse
import textwrap
from llama_stack.cli.stack.utils import ImageType
from llama_stack.cli.subcommand import Subcommand
@ -46,16 +47,16 @@ class StackBuild(Subcommand):
self.parser.add_argument(
"--image-type",
type=str,
help="Image Type to use for the build. This can be either conda or container or venv. If not specified, will use the image type from the template config.",
choices=["conda", "container", "venv"],
default="conda",
help="Image Type to use for the build. If not specified, will use the image type from the template config.",
choices=[e.value for e in ImageType],
default=ImageType.CONDA.value,
)
self.parser.add_argument(
"--image-name",
type=str,
help=textwrap.dedent(
"""[for image-type=conda|venv] Name of the conda or virtual environment to use for
f"""[for image-type={"|".join(e.value for e in ImageType)}] Name of the conda or virtual environment to use for
the build. If not specified, currently active Conda environment will be used if found.
"""
),

View file

@ -8,6 +8,7 @@ import argparse
import os
from pathlib import Path
from llama_stack.cli.stack.utils import ImageType
from llama_stack.cli.subcommand import Subcommand
from llama_stack.log import get_logger
@ -56,7 +57,6 @@ class StackRun(Subcommand):
"--env",
action="append",
help="Environment variables to pass to the server in KEY=VALUE format. Can be specified multiple times.",
default=[],
metavar="KEY=VALUE",
)
self.parser.add_argument(
@ -73,10 +73,24 @@ class StackRun(Subcommand):
"--image-type",
type=str,
help="Image Type used during the build. This can be either conda or container or venv.",
choices=["conda", "container", "venv"],
default="conda",
choices=[e.value for e in ImageType],
)
# If neither image type nor image name is provided, but at the same time
# the current environment has conda breadcrumbs, then assume what the user
# wants to use conda mode and not the usual default mode (using
# pre-installed system packages).
#
# Note: yes, this is hacky. It's implemented this way to keep the existing
# conda users unaffected by the switch of the default behavior to using
# system packages.
def _get_image_type_and_name(self, args: argparse.Namespace) -> tuple[str, str]:
conda_env = os.environ.get("CONDA_DEFAULT_ENV")
if conda_env and args.image_name == conda_env:
logger.warning(f"Conda detected. Using conda environment {conda_env} for the run.")
return ImageType.CONDA.value, args.image_name
return args.image_type, args.image_name
def _run_stack_run_cmd(self, args: argparse.Namespace) -> None:
import yaml
@ -120,20 +134,44 @@ class StackRun(Subcommand):
except AttributeError as e:
self.parser.error(f"failed to parse config file '{config_file}':\n {e}")
run_args = formulate_run_args(args.image_type, args.image_name, config, template_name)
image_type, image_name = self._get_image_type_and_name(args)
run_args.extend([str(config_file), str(args.port)])
if args.disable_ipv6:
run_args.append("--disable-ipv6")
# If neither image type nor image name is provided, assume the server should be run directly
# using the current environment packages.
if not image_type and not image_name:
logger.info("No image type or image name provided. Assuming environment packages.")
from llama_stack.distribution.server.server import main as server_main
for env_var in args.env:
if "=" not in env_var:
self.parser.error(f"Environment variable '{env_var}' must be in KEY=VALUE format")
key, value = env_var.split("=", 1) # split on first = only
if not key:
self.parser.error(f"Environment variable '{env_var}' has empty key")
run_args.extend(["--env", f"{key}={value}"])
# Build the server args from the current args passed to the CLI
server_args = argparse.Namespace()
for arg in vars(args):
# If this is a function, avoid passing it
# "args" contains:
# func=<bound method StackRun._run_stack_run_cmd of <llama_stack.cli.stack.run.StackRun object at 0x10484b010>>
if callable(getattr(args, arg)):
continue
setattr(server_args, arg, getattr(args, arg))
if args.tls_keyfile and args.tls_certfile:
run_args.extend(["--tls-keyfile", args.tls_keyfile, "--tls-certfile", args.tls_certfile])
run_command(run_args)
# Run the server
server_main(server_args)
else:
run_args = formulate_run_args(image_type, image_name, config, template_name)
run_args.extend([str(config_file), str(args.port)])
if args.disable_ipv6:
run_args.append("--disable-ipv6")
if args.env:
for env_var in args.env:
if "=" not in env_var:
self.parser.error(f"Environment variable '{env_var}' must be in KEY=VALUE format")
return
key, value = env_var.split("=", 1) # split on first = only
if not key:
self.parser.error(f"Environment variable '{env_var}' has empty key")
return
run_args.extend(["--env", f"{key}={value}"])
if args.tls_keyfile and args.tls_certfile:
run_args.extend(["--tls-keyfile", args.tls_keyfile, "--tls-certfile", args.tls_certfile])
run_command(run_args)

View file

@ -4,6 +4,14 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from enum import Enum
class ImageType(Enum):
CONDA = "conda"
CONTAINER = "container"
VENV = "venv"
def print_subcommand_description(parser, subparsers):
"""Print descriptions of subcommands."""

View file

@ -328,8 +328,9 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
body = self._convert_body(path, options.method, body)
await start_trace(route, {"__location__": "library_client"})
async def gen():
await start_trace(route, {"__location__": "library_client"})
try:
async for chunk in await func(**body):
data = json.dumps(convert_pydantic_to_json_value(chunk))

View file

@ -12,6 +12,7 @@ from llama_stack.apis.benchmarks import Benchmarks
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Datasets
from llama_stack.apis.eval import Eval
from llama_stack.apis.files import Files
from llama_stack.apis.inference import Inference
from llama_stack.apis.inspect import Inspect
from llama_stack.apis.models import Models
@ -79,6 +80,7 @@ def api_protocol_map() -> Dict[Api, Any]:
Api.post_training: PostTraining,
Api.tool_groups: ToolGroups,
Api.tool_runtime: ToolRuntime,
Api.files: Files,
}

View file

@ -12,7 +12,8 @@ from llama_stack.apis.common.content_types import (
InterleavedContent,
InterleavedContentItem,
)
from llama_stack.apis.datasetio import DatasetIO, IterrowsResponse
from llama_stack.apis.common.responses import PaginatedResponse
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import DatasetPurpose, DataSource
from llama_stack.apis.eval import BenchmarkConfig, Eval, EvaluateResponse, Job
from llama_stack.apis.inference import (
@ -45,11 +46,11 @@ from llama_stack.apis.scoring import (
from llama_stack.apis.shields import Shield
from llama_stack.apis.telemetry import MetricEvent, MetricInResponse, Telemetry
from llama_stack.apis.tools import (
ListToolDefsResponse,
RAGDocument,
RAGQueryConfig,
RAGQueryResult,
RAGToolRuntime,
ToolDef,
ToolRuntime,
)
from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
@ -497,7 +498,7 @@ class DatasetIORouter(DatasetIO):
dataset_id: str,
start_index: Optional[int] = None,
limit: Optional[int] = None,
) -> IterrowsResponse:
) -> PaginatedResponse:
logger.debug(
f"DatasetIORouter.iterrows: {dataset_id}, {start_index=} {limit=}",
)
@ -706,6 +707,6 @@ class ToolRuntimeRouter(ToolRuntime):
async def list_runtime_tools(
self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
) -> List[ToolDef]:
) -> ListToolDefsResponse:
logger.debug(f"ToolRuntimeRouter.list_runtime_tools: {tool_group_id}")
return await self.routing_table.get_provider_impl(tool_group_id).list_tools(tool_group_id, mcp_endpoint)

View file

@ -568,7 +568,7 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
tool_defs = await self.impls_by_provider_id[provider_id].list_runtime_tools(toolgroup_id, mcp_endpoint)
tool_host = ToolHost.model_context_protocol if mcp_endpoint else ToolHost.distribution
for tool_def in tool_defs:
for tool_def in tool_defs.data:
tools.append(
ToolWithACL(
identifier=tool_def.name,

View file

@ -15,7 +15,7 @@ import warnings
from contextlib import asynccontextmanager
from importlib.metadata import version as parse_version
from pathlib import Path
from typing import Any, List, Union
from typing import Any, List, Optional, Union
import yaml
from fastapi import Body, FastAPI, HTTPException, Request
@ -294,11 +294,17 @@ class ClientVersionMiddleware:
return await self.app(scope, receive, send)
def main():
def main(args: Optional[argparse.Namespace] = None):
"""Start the LlamaStack server."""
parser = argparse.ArgumentParser(description="Start the LlamaStack server.")
parser.add_argument(
"--yaml-config",
dest="config",
help="(Deprecated) Path to YAML configuration file - use --config instead",
)
parser.add_argument(
"--config",
dest="config",
help="Path to YAML configuration file",
)
parser.add_argument(
@ -328,12 +334,24 @@ def main():
required="--tls-keyfile" in sys.argv,
)
args = parser.parse_args()
# Determine whether the server args are being passed by the "run" command, if this is the case
# the args will be passed as a Namespace object to the main function, otherwise they will be
# parsed from the command line
if args is None:
args = parser.parse_args()
# Check for deprecated argument usage
if "--yaml-config" in sys.argv:
warnings.warn(
"The '--yaml-config' argument is deprecated and will be removed in a future version. Use '--config' instead.",
DeprecationWarning,
stacklevel=2,
)
log_line = ""
if args.yaml_config:
if args.config:
# if the user provided a config file, use it, even if template was specified
config_file = Path(args.yaml_config)
config_file = Path(args.config)
if not config_file.exists():
raise ValueError(f"Config file {config_file} does not exist")
log_line = f"Using config file: {config_file}"

View file

@ -13,6 +13,7 @@ LLAMA_CHECKPOINT_DIR=${LLAMA_CHECKPOINT_DIR:-}
LLAMA_STACK_DIR=${LLAMA_STACK_DIR:-}
TEST_PYPI_VERSION=${TEST_PYPI_VERSION:-}
PYPI_VERSION=${PYPI_VERSION:-}
VIRTUAL_ENV=${VIRTUAL_ENV:-}
set -euo pipefail
@ -69,22 +70,25 @@ while [[ $# -gt 0 ]]; do
;;
esac
done
PYTHON_BINARY="python"
case "$env_type" in
"venv")
# Activate virtual environment
if [ ! -d "$env_path_or_name" ]; then
echo -e "${RED}Error: Virtual environment not found at $env_path_or_name${NC}" >&2
exit 1
fi
if [ -n "$VIRTUAL_ENV" && "$VIRTUAL_ENV" == "$env_path_or_name" ]; then
echo -e "${GREEN}Virtual environment already activated${NC}" >&2
else
# Activate virtual environment
if [ ! -d "$env_path_or_name" ]; then
echo -e "${RED}Error: Virtual environment not found at $env_path_or_name${NC}" >&2
exit 1
fi
if [ ! -f "$env_path_or_name/bin/activate" ]; then
echo -e "${RED}Error: Virtual environment activate binary not found at $env_path_or_name/bin/activate" >&2
exit 1
fi
if [ ! -f "$env_path_or_name/bin/activate" ]; then
echo -e "${RED}Error: Virtual environment activate binary not found at $env_path_or_name/bin/activate" >&2
exit 1
fi
source "$env_path_or_name/bin/activate"
source "$env_path_or_name/bin/activate"
fi
;;
"conda")
if ! is_command_available conda; then

View file

@ -58,6 +58,7 @@ def rag_chat_page():
llama_stack_api.client.tool_runtime.rag_tool.insert(
vector_db_id=vector_db_name, # Use the user-provided name
documents=documents,
chunk_size_in_tokens=512,
)
st.success("Vector database created successfully!")

View file

@ -18,15 +18,19 @@ def preserve_contexts_async_generator(
This is needed because we start a new asyncio event loop for each streaming request,
and we need to preserve the context across the event loop boundary.
"""
# Capture initial context values
initial_context_values = {context_var.name: context_var.get() for context_var in context_vars}
async def wrapper() -> AsyncGenerator[T, None]:
while True:
try:
item = await gen.__anext__()
context_values = {context_var.name: context_var.get() for context_var in context_vars}
yield item
# Restore context values before any await
for context_var in context_vars:
_ = context_var.set(context_values[context_var.name])
context_var.set(initial_context_values[context_var.name])
item = await gen.__anext__()
yield item
except StopAsyncIteration:
break

View file

@ -139,7 +139,7 @@ def setup_logging(category_levels: Dict[str, int], log_file: str | None) -> None
category_levels (Dict[str, int]): A dictionary mapping categories to their log levels.
log_file (str): Path to a log file to additionally pipe the logs into
"""
log_format = "[dim]%(asctime)s %(name)s:%(lineno)d[/] [yellow dim]%(category)s[/]: %(message)s"
log_format = "%(asctime)s %(name)s:%(lineno)d %(category)s: %(message)s"
class CategoryFilter(logging.Filter):
"""Ensure category is always present in log records."""

View file

@ -195,10 +195,22 @@ register_schema(SamplingStrategy, name="SamplingStrategy")
@json_schema_type
class SamplingParams(BaseModel):
"""Sampling parameters.
:param strategy: The sampling strategy.
:param max_tokens: The maximum number of tokens that can be generated in the completion. The token count of
your prompt plus max_tokens cannot exceed the model's context length.
:param repetition_penalty: Number between -2.0 and 2.0. Positive values penalize new tokens
based on whether they appear in the text so far, increasing the model's likelihood to talk about new topics.
:param stop: Up to 4 sequences where the API will stop generating further tokens.
The returned text will not contain the stop sequence.
"""
strategy: SamplingStrategy = Field(default_factory=GreedySamplingStrategy)
max_tokens: Optional[int] = 0
repetition_penalty: Optional[float] = 1.0
stop: Optional[List[str]] = None
class CheckpointQuantizationFormat(Enum):

View file

@ -57,11 +57,7 @@ from llama_stack.apis.inference import (
UserMessage,
)
from llama_stack.apis.safety import Safety
from llama_stack.apis.tools import (
ToolGroups,
ToolInvocationResult,
ToolRuntime,
)
from llama_stack.apis.tools import ToolGroups, ToolInvocationResult, ToolRuntime
from llama_stack.apis.vector_io import VectorIO
from llama_stack.log import get_logger
from llama_stack.models.llama.datatypes import (
@ -459,7 +455,15 @@ class ChatAgent(ShieldRunnerMixin):
contexts.append(raw_document_text)
attached_context = "\n".join(contexts)
input_messages[-1].context = attached_context
if isinstance(input_messages[-1].content, str):
input_messages[-1].content += attached_context
elif isinstance(input_messages[-1].content, list):
input_messages[-1].content.append(TextContentItem(text=attached_context))
else:
input_messages[-1].content = [
input_messages[-1].content,
TextContentItem(text=attached_context),
]
session_info = await self.storage.get_session_info(session_id)
# if the session has a memory bank id, let the memory tool use it

View file

@ -7,9 +7,11 @@ from typing import Any, Dict, List, Optional
import pandas
from llama_stack.apis.datasetio import DatasetIO, IterrowsResponse
from llama_stack.apis.common.responses import PaginatedResponse
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Dataset
from llama_stack.providers.datatypes import DatasetsProtocolPrivate
from llama_stack.providers.utils.datasetio.pagination import paginate_records
from llama_stack.providers.utils.datasetio.url_utils import get_dataframe_from_uri
from llama_stack.providers.utils.kvstore import kvstore_impl
@ -92,24 +94,13 @@ class LocalFSDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
dataset_id: str,
start_index: Optional[int] = None,
limit: Optional[int] = None,
) -> IterrowsResponse:
) -> PaginatedResponse:
dataset_def = self.dataset_infos[dataset_id]
dataset_impl = PandasDataframeDataset(dataset_def)
await dataset_impl.load()
start_index = start_index or 0
if limit is None or limit == -1:
end = len(dataset_impl)
else:
end = min(start_index + limit, len(dataset_impl))
rows = dataset_impl[start_index:end]
return IterrowsResponse(
data=rows,
next_start_index=end if end < len(dataset_impl) else None,
)
records = dataset_impl.df.to_dict("records")
return paginate_records(records, start_index, limit)
async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None:
dataset_def = self.dataset_infos[dataset_id]

View file

@ -28,6 +28,11 @@ class TelemetryConfig(BaseModel):
default="http://localhost:4318/v1/metrics",
description="The OpenTelemetry collector endpoint URL for metrics",
)
service_name: str = Field(
# service name is always the same, use zero-width space to avoid clutter
default="",
description="The service name to use for telemetry",
)
sinks: List[TelemetrySink] = Field(
default=[TelemetrySink.CONSOLE, TelemetrySink.SQLITE],
description="List of telemetry sinks to enable (possible values: otel, sqlite, console)",
@ -47,6 +52,7 @@ class TelemetryConfig(BaseModel):
@classmethod
def sample_run_config(cls, __distro_dir__: str, db_name: str = "trace_store.db") -> Dict[str, Any]:
return {
"service_name": "${env.OTEL_SERVICE_NAME:}",
"sinks": "${env.TELEMETRY_SINKS:console,sqlite}",
"sqlite_db_path": "${env.SQLITE_DB_PATH:" + __distro_dir__ + "/" + db_name + "}",
}

View file

@ -67,8 +67,7 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
resource = Resource.create(
{
# service name is always the same, use zero-width space to avoid clutter
ResourceAttributes.SERVICE_NAME: "",
ResourceAttributes.SERVICE_NAME: self.config.service_name,
}
)
@ -204,16 +203,6 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
parent_span = _GLOBAL_STORAGE["active_spans"].get(parent_span_id)
context = trace.set_span_in_context(parent_span)
else:
context = trace.set_span_in_context(
trace.NonRecordingSpan(
trace.SpanContext(
trace_id=int(event.trace_id, 16),
span_id=span_id,
is_remote=False,
trace_flags=trace.TraceFlags(trace.TraceFlags.SAMPLED),
)
)
)
event.attributes["__root_span__"] = "true"
span = tracer.start_span(

View file

@ -9,10 +9,11 @@ import asyncio
import logging
import os
import tempfile
from typing import Any, Dict, List, Optional
from typing import Any, Dict, Optional
from llama_stack.apis.common.content_types import URL
from llama_stack.apis.tools import (
ListToolDefsResponse,
Tool,
ToolDef,
ToolInvocationResult,
@ -46,20 +47,22 @@ class CodeInterpreterToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
async def list_runtime_tools(
self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
) -> List[ToolDef]:
return [
ToolDef(
name="code_interpreter",
description="Execute code",
parameters=[
ToolParameter(
name="code",
description="The code to execute",
parameter_type="string",
),
],
)
]
) -> ListToolDefsResponse:
return ListToolDefsResponse(
data=[
ToolDef(
name="code_interpreter",
description="Execute code",
parameters=[
ToolParameter(
name="code",
description="The code to execute",
parameter_type="string",
),
],
)
]
)
async def invoke_tool(self, tool_name: str, kwargs: Dict[str, Any]) -> ToolInvocationResult:
script = kwargs["code"]

View file

@ -20,6 +20,7 @@ from llama_stack.apis.common.content_types import (
)
from llama_stack.apis.inference import Inference
from llama_stack.apis.tools import (
ListToolDefsResponse,
RAGDocument,
RAGQueryConfig,
RAGQueryResult,
@ -162,27 +163,29 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):
async def list_runtime_tools(
self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
) -> List[ToolDef]:
) -> ListToolDefsResponse:
# Parameters are not listed since these methods are not yet invoked automatically
# by the LLM. The method is only implemented so things like /tools can list without
# encountering fatals.
return [
ToolDef(
name="insert_into_memory",
description="Insert documents into memory",
),
ToolDef(
name="knowledge_search",
description="Search for information in a database.",
parameters=[
ToolParameter(
name="query",
description="The query to search for. Can be a natural language sentence or keywords.",
parameter_type="string",
),
],
),
]
return ListToolDefsResponse(
data=[
ToolDef(
name="insert_into_memory",
description="Insert documents into memory",
),
ToolDef(
name="knowledge_search",
description="Search for information in a database.",
parameters=[
ToolParameter(
name="query",
description="The query to search for. Can be a natural language sentence or keywords.",
parameter_type="string",
),
],
),
]
)
async def invoke_tool(self, tool_name: str, kwargs: Dict[str, Any]) -> ToolInvocationResult:
vector_db_ids = kwargs.get("vector_db_ids", [])

View file

@ -0,0 +1,11 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.providers.datatypes import ProviderSpec
def available_providers() -> list[ProviderSpec]:
return []

View file

@ -6,7 +6,7 @@
from typing import List
from llama_stack.providers.datatypes import Api, InlineProviderSpec, ProviderSpec
from llama_stack.providers.datatypes import AdapterSpec, Api, InlineProviderSpec, ProviderSpec, remote_provider_spec
def available_providers() -> List[ProviderSpec]:
@ -22,4 +22,13 @@ def available_providers() -> List[ProviderSpec]:
Api.datasets,
],
),
remote_provider_spec(
api=Api.post_training,
adapter=AdapterSpec(
adapter_type="nvidia",
pip_packages=["requests", "aiohttp"],
module="llama_stack.providers.remote.post_training.nvidia",
config_class="llama_stack.providers.remote.post_training.nvidia.NvidiaPostTrainingConfig",
),
),
]

View file

@ -8,9 +8,11 @@ from urllib.parse import parse_qs, urlparse
import datasets as hf_datasets
from llama_stack.apis.datasetio import DatasetIO, IterrowsResponse
from llama_stack.apis.common.responses import PaginatedResponse
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Dataset
from llama_stack.providers.datatypes import DatasetsProtocolPrivate
from llama_stack.providers.utils.datasetio.pagination import paginate_records
from llama_stack.providers.utils.kvstore import kvstore_impl
from .config import HuggingfaceDatasetIOConfig
@ -70,24 +72,13 @@ class HuggingfaceDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
dataset_id: str,
start_index: Optional[int] = None,
limit: Optional[int] = None,
) -> IterrowsResponse:
) -> PaginatedResponse:
dataset_def = self.dataset_infos[dataset_id]
path, params = parse_hf_params(dataset_def)
loaded_dataset = hf_datasets.load_dataset(path, **params)
start_index = start_index or 0
if limit is None or limit == -1:
end = len(loaded_dataset)
else:
end = min(start_index + limit, len(loaded_dataset))
rows = [loaded_dataset[i] for i in range(start_index, end)]
return IterrowsResponse(
data=rows,
next_start_index=end if end < len(loaded_dataset) else None,
)
records = [loaded_dataset[i] for i in range(len(loaded_dataset))]
return paginate_records(records, start_index, limit)
async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None:
dataset_def = self.dataset_infos[dataset_id]

View file

@ -55,7 +55,7 @@ from .openai_utils import (
convert_openai_completion_choice,
convert_openai_completion_stream,
)
from .utils import _is_nvidia_hosted, check_health
from .utils import _is_nvidia_hosted
logger = logging.getLogger(__name__)
@ -134,7 +134,9 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
if content_has_media(content):
raise NotImplementedError("Media is not supported")
await check_health(self._config) # this raises errors
# ToDo: check health of NeMo endpoints and enable this
# removing this health check as NeMo customizer endpoint health check is returning 404
# await check_health(self._config) # this raises errors
provider_model_id = self.get_provider_model_id(model_id)
request = convert_completion_request(
@ -203,7 +205,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
extra_body["input_type"] = task_type_options[task_type]
try:
response = await self._client.embeddings.create(
response = await self._get_client(model).embeddings.create(
model=model,
input=input,
extra_body=extra_body,
@ -236,7 +238,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
if tool_prompt_format:
warnings.warn("tool_prompt_format is not supported by NVIDIA NIM, ignoring", stacklevel=2)
await check_health(self._config) # this raises errors
# await check_health(self._config) # this raises errors
provider_model_id = self.get_provider_model_id(model_id)
request = await convert_chat_completion_request(

View file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -0,0 +1,138 @@
# NVIDIA Post-Training Provider for LlamaStack
This provider enables fine-tuning of LLMs using NVIDIA's NeMo Customizer service.
## Features
- Supervised fine-tuning of Llama models
- LoRA fine-tuning support
- Job management and status tracking
## Getting Started
### Prerequisites
- LlamaStack with NVIDIA configuration
- Access to Hosted NVIDIA NeMo Customizer service
- Dataset registered in the Hosted NVIDIA NeMo Customizer service
- Base model downloaded and available in the Hosted NVIDIA NeMo Customizer service
### Setup
Build the NVIDIA environment:
```bash
llama stack build --template nvidia --image-type conda
```
### Basic Usage using the LlamaStack Python Client
### Create Customization Job
#### Initialize the client
```python
import os
os.environ["NVIDIA_API_KEY"] = "your-api-key"
os.environ["NVIDIA_CUSTOMIZER_URL"] = "http://nemo.test"
os.environ["NVIDIA_USER_ID"] = "llama-stack-user"
os.environ["NVIDIA_DATASET_NAMESPACE"] = "default"
os.environ["NVIDIA_PROJECT_ID"] = "test-project"
os.environ["NVIDIA_OUTPUT_MODEL_DIR"] = "test-example-model@v1"
from llama_stack.distribution.library_client import LlamaStackAsLibraryClient
client = LlamaStackAsLibraryClient("nvidia")
client.initialize()
```
#### Configure fine-tuning parameters
```python
from llama_stack_client.types.post_training_supervised_fine_tune_params import (
TrainingConfig,
TrainingConfigDataConfig,
TrainingConfigOptimizerConfig,
)
from llama_stack_client.types.algorithm_config_param import LoraFinetuningConfig
```
#### Set up LoRA configuration
```python
algorithm_config = LoraFinetuningConfig(type="LoRA", adapter_dim=16)
```
#### Configure training data
```python
data_config = TrainingConfigDataConfig(
dataset_id="your-dataset-id", # Use client.datasets.list() to see available datasets
batch_size=16,
)
```
#### Configure optimizer
```python
optimizer_config = TrainingConfigOptimizerConfig(
lr=0.0001,
)
```
#### Set up training configuration
```python
training_config = TrainingConfig(
n_epochs=2,
data_config=data_config,
optimizer_config=optimizer_config,
)
```
#### Start fine-tuning job
```python
training_job = client.post_training.supervised_fine_tune(
job_uuid="unique-job-id",
model="meta-llama/Llama-3.1-8B-Instruct",
checkpoint_dir="",
algorithm_config=algorithm_config,
training_config=training_config,
logger_config={},
hyperparam_search_config={},
)
```
### List all jobs
```python
jobs = client.post_training.job.list()
```
### Check job status
```python
job_status = client.post_training.job.status(job_uuid="your-job-id")
```
### Cancel a job
```python
client.post_training.job.cancel(job_uuid="your-job-id")
```
### Inference with the fine-tuned model
```python
response = client.inference.completion(
content="Complete the sentence using one word: Roses are red, violets are ",
stream=False,
model_id="test-example-model@v1",
sampling_params={
"max_tokens": 50,
},
)
print(response.content)
```

View file

@ -0,0 +1,23 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .config import NvidiaPostTrainingConfig
async def get_adapter_impl(
config: NvidiaPostTrainingConfig,
_deps,
):
from .post_training import NvidiaPostTrainingAdapter
if not isinstance(config, NvidiaPostTrainingConfig):
raise RuntimeError(f"Unexpected config type: {type(config)}")
impl = NvidiaPostTrainingAdapter(config)
return impl
__all__ = ["get_adapter_impl", "NvidiaPostTrainingAdapter"]

View file

@ -0,0 +1,113 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import os
from typing import Any, Dict, Optional
from pydantic import BaseModel, Field
# TODO: add default values for all fields
class NvidiaPostTrainingConfig(BaseModel):
"""Configuration for NVIDIA Post Training implementation."""
api_key: Optional[str] = Field(
default_factory=lambda: os.getenv("NVIDIA_API_KEY"),
description="The NVIDIA API key.",
)
dataset_namespace: Optional[str] = Field(
default_factory=lambda: os.getenv("NVIDIA_DATASET_NAMESPACE", "default"),
description="The NVIDIA dataset namespace.",
)
project_id: Optional[str] = Field(
default_factory=lambda: os.getenv("NVIDIA_PROJECT_ID", "test-example-model@v1"),
description="The NVIDIA project ID.",
)
# ToDO: validate this, add default value
customizer_url: Optional[str] = Field(
default_factory=lambda: os.getenv("NVIDIA_CUSTOMIZER_URL"),
description="Base URL for the NeMo Customizer API",
)
timeout: int = Field(
default=300,
description="Timeout for the NVIDIA Post Training API",
)
max_retries: int = Field(
default=3,
description="Maximum number of retries for the NVIDIA Post Training API",
)
# ToDo: validate this
output_model_dir: str = Field(
default_factory=lambda: os.getenv("NVIDIA_OUTPUT_MODEL_DIR", "test-example-model@v1"),
description="Directory to save the output model",
)
@classmethod
def sample_run_config(cls, **kwargs) -> Dict[str, Any]:
return {
"api_key": "${env.NVIDIA_API_KEY:}",
"dataset_namespace": "${env.NVIDIA_DATASET_NAMESPACE:default}",
"project_id": "${env.NVIDIA_PROJECT_ID:test-project}",
"customizer_url": "${env.NVIDIA_CUSTOMIZER_URL:http://nemo.test}",
}
class SFTLoRADefaultConfig(BaseModel):
"""NVIDIA-specific training configuration with default values."""
# ToDo: split into SFT and LoRA configs??
# General training parameters
n_epochs: int = 50
# NeMo customizer specific parameters
log_every_n_steps: Optional[int] = None
val_check_interval: float = 0.25
sequence_packing_enabled: bool = False
weight_decay: float = 0.01
lr: float = 0.0001
# SFT specific parameters
hidden_dropout: Optional[float] = None
attention_dropout: Optional[float] = None
ffn_dropout: Optional[float] = None
# LoRA default parameters
lora_adapter_dim: int = 8
lora_adapter_dropout: Optional[float] = None
lora_alpha: int = 16
# Data config
batch_size: int = 8
@classmethod
def sample_config(cls) -> Dict[str, Any]:
"""Return a sample configuration for NVIDIA training."""
return {
"n_epochs": 50,
"log_every_n_steps": 10,
"val_check_interval": 0.25,
"sequence_packing_enabled": False,
"weight_decay": 0.01,
"hidden_dropout": 0.1,
"attention_dropout": 0.1,
"lora_adapter_dim": 8,
"lora_alpha": 16,
"data_config": {
"dataset_id": "default",
"batch_size": 8,
},
"optimizer_config": {
"lr": 0.0001,
},
}

View file

@ -0,0 +1,24 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import List
from llama_stack.models.llama.datatypes import CoreModelId
from llama_stack.providers.utils.inference.model_registry import (
ProviderModelEntry,
build_hf_repo_model_entry,
)
_MODEL_ENTRIES = [
build_hf_repo_model_entry(
"meta/llama-3.1-8b-instruct",
CoreModelId.llama3_1_8b_instruct.value,
)
]
def get_model_entries() -> List[ProviderModelEntry]:
return _MODEL_ENTRIES

View file

@ -0,0 +1,439 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import warnings
from datetime import datetime
from typing import Any, Dict, List, Literal, Optional
import aiohttp
from pydantic import BaseModel, ConfigDict
from llama_stack.apis.post_training import (
AlgorithmConfig,
DPOAlignmentConfig,
JobStatus,
PostTrainingJob,
PostTrainingJobArtifactsResponse,
PostTrainingJobStatusResponse,
TrainingConfig,
)
from llama_stack.providers.remote.post_training.nvidia.config import NvidiaPostTrainingConfig
from llama_stack.providers.remote.post_training.nvidia.utils import warn_unsupported_params
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
from .models import _MODEL_ENTRIES
# Map API status to JobStatus enum
STATUS_MAPPING = {
"running": "in_progress",
"completed": "completed",
"failed": "failed",
"cancelled": "cancelled",
"pending": "scheduled",
}
class NvidiaPostTrainingJob(PostTrainingJob):
"""Parse the response from the Customizer API.
Inherits job_uuid from PostTrainingJob.
Adds status, created_at, updated_at parameters.
Passes through all other parameters from data field in the response.
"""
model_config = ConfigDict(extra="allow")
status: JobStatus
created_at: datetime
updated_at: datetime
class ListNvidiaPostTrainingJobs(BaseModel):
data: List[NvidiaPostTrainingJob]
class NvidiaPostTrainingJobStatusResponse(PostTrainingJobStatusResponse):
model_config = ConfigDict(extra="allow")
class NvidiaPostTrainingAdapter(ModelRegistryHelper):
def __init__(self, config: NvidiaPostTrainingConfig):
self.config = config
self.headers = {}
if config.api_key:
self.headers["Authorization"] = f"Bearer {config.api_key}"
self.timeout = aiohttp.ClientTimeout(total=config.timeout)
# TODO: filter by available models based on /config endpoint
ModelRegistryHelper.__init__(self, model_entries=_MODEL_ENTRIES)
self.session = aiohttp.ClientSession(headers=self.headers, timeout=self.timeout)
self.customizer_url = config.customizer_url
if not self.customizer_url:
warnings.warn("Customizer URL is not set, using default value: http://nemo.test", stacklevel=2)
self.customizer_url = "http://nemo.test"
async def _make_request(
self,
method: str,
path: str,
headers: Optional[Dict[str, Any]] = None,
params: Optional[Dict[str, Any]] = None,
json: Optional[Dict[str, Any]] = None,
**kwargs,
) -> Dict[str, Any]:
"""Helper method to make HTTP requests to the Customizer API."""
url = f"{self.customizer_url}{path}"
request_headers = self.headers.copy()
if headers:
request_headers.update(headers)
# Add content-type header for JSON requests
if json and "Content-Type" not in request_headers:
request_headers["Content-Type"] = "application/json"
for _ in range(self.config.max_retries):
async with self.session.request(method, url, params=params, json=json, **kwargs) as response:
if response.status >= 400:
error_data = await response.json()
raise Exception(f"API request failed: {error_data}")
return await response.json()
async def get_training_jobs(
self,
page: Optional[int] = 1,
page_size: Optional[int] = 10,
sort: Optional[Literal["created_at", "-created_at"]] = "created_at",
) -> ListNvidiaPostTrainingJobs:
"""Get all customization jobs.
Updated the base class return type from ListPostTrainingJobsResponse to ListNvidiaPostTrainingJobs.
Returns a ListNvidiaPostTrainingJobs object with the following fields:
- data: List[NvidiaPostTrainingJob] - List of NvidiaPostTrainingJob objects
ToDo: Support for schema input for filtering.
"""
params = {"page": page, "page_size": page_size, "sort": sort}
response = await self._make_request("GET", "/v1/customization/jobs", params=params)
jobs = []
for job in response.get("data", []):
job_id = job.pop("id")
job_status = job.pop("status", "unknown").lower()
mapped_status = STATUS_MAPPING.get(job_status, "unknown")
# Convert string timestamps to datetime objects
created_at = (
datetime.fromisoformat(job.pop("created_at"))
if "created_at" in job
else datetime.now(tz=datetime.timezone.utc)
)
updated_at = (
datetime.fromisoformat(job.pop("updated_at"))
if "updated_at" in job
else datetime.now(tz=datetime.timezone.utc)
)
# Create NvidiaPostTrainingJob instance
jobs.append(
NvidiaPostTrainingJob(
job_uuid=job_id,
status=JobStatus(mapped_status),
created_at=created_at,
updated_at=updated_at,
**job,
)
)
return ListNvidiaPostTrainingJobs(data=jobs)
async def get_training_job_status(self, job_uuid: str) -> NvidiaPostTrainingJobStatusResponse:
"""Get the status of a customization job.
Updated the base class return type from PostTrainingJobResponse to NvidiaPostTrainingJob.
Returns a NvidiaPostTrainingJob object with the following fields:
- job_uuid: str - Unique identifier for the job
- status: JobStatus - Current status of the job (in_progress, completed, failed, cancelled, scheduled)
- created_at: datetime - The time when the job was created
- updated_at: datetime - The last time the job status was updated
Additional fields that may be included:
- steps_completed: Optional[int] - Number of training steps completed
- epochs_completed: Optional[int] - Number of epochs completed
- percentage_done: Optional[float] - Percentage of training completed (0-100)
- best_epoch: Optional[int] - The epoch with the best performance
- train_loss: Optional[float] - Training loss of the best checkpoint
- val_loss: Optional[float] - Validation loss of the best checkpoint
- metrics: Optional[Dict] - Additional training metrics
- status_logs: Optional[List] - Detailed logs of status changes
"""
response = await self._make_request(
"GET",
f"/v1/customization/jobs/{job_uuid}/status",
params={"job_id": job_uuid},
)
api_status = response.pop("status").lower()
mapped_status = STATUS_MAPPING.get(api_status, "unknown")
return NvidiaPostTrainingJobStatusResponse(
status=JobStatus(mapped_status),
job_uuid=job_uuid,
started_at=datetime.fromisoformat(response.pop("created_at")),
updated_at=datetime.fromisoformat(response.pop("updated_at")),
**response,
)
async def cancel_training_job(self, job_uuid: str) -> None:
await self._make_request(
method="POST", path=f"/v1/customization/jobs/{job_uuid}/cancel", params={"job_id": job_uuid}
)
async def get_training_job_artifacts(self, job_uuid: str) -> PostTrainingJobArtifactsResponse:
raise NotImplementedError("Job artifacts are not implemented yet")
async def get_post_training_artifacts(self, job_uuid: str) -> PostTrainingJobArtifactsResponse:
raise NotImplementedError("Job artifacts are not implemented yet")
async def supervised_fine_tune(
self,
job_uuid: str,
training_config: Dict[str, Any],
hyperparam_search_config: Dict[str, Any],
logger_config: Dict[str, Any],
model: str,
checkpoint_dir: Optional[str],
algorithm_config: Optional[AlgorithmConfig] = None,
extra_json: Optional[Dict[str, Any]] = None,
params: Optional[Dict[str, Any]] = None,
headers: Optional[Dict[str, Any]] = None,
**kwargs,
) -> NvidiaPostTrainingJob:
"""
Fine-tunes a model on a dataset.
Currently only supports Lora finetuning for standlone docker container.
Assumptions:
- nemo microservice is running and endpoint is set in config.customizer_url
- dataset is registered separately in nemo datastore
- model checkpoint is downloaded as per nemo customizer requirements
Parameters:
training_config: TrainingConfig - Configuration for training
model: str - Model identifier
algorithm_config: Optional[AlgorithmConfig] - Algorithm-specific configuration
checkpoint_dir: Optional[str] - Directory containing model checkpoints, ignored atm
job_uuid: str - Unique identifier for the job, ignored atm
hyperparam_search_config: Dict[str, Any] - Configuration for hyperparameter search, ignored atm
logger_config: Dict[str, Any] - Configuration for logging, ignored atm
Environment Variables:
- NVIDIA_API_KEY: str - API key for the NVIDIA API
Default: None
- NVIDIA_DATASET_NAMESPACE: str - Namespace of the dataset
Default: "default"
- NVIDIA_CUSTOMIZER_URL: str - URL of the NeMo Customizer API
Default: "http://nemo.test"
- NVIDIA_PROJECT_ID: str - ID of the project
Default: "test-project"
- NVIDIA_OUTPUT_MODEL_DIR: str - Directory to save the output model
Default: "test-example-model@v1"
Supported models:
- meta/llama-3.1-8b-instruct
Supported algorithm configs:
- LoRA, SFT
Supported Parameters:
- TrainingConfig:
- n_epochs: int - Number of epochs to train
Default: 50
- data_config: DataConfig - Configuration for the dataset
- optimizer_config: OptimizerConfig - Configuration for the optimizer
- dtype: str - Data type for training
not supported (users are informed via warnings)
- efficiency_config: EfficiencyConfig - Configuration for efficiency
not supported
- max_steps_per_epoch: int - Maximum number of steps per epoch
Default: 1000
## NeMo customizer specific parameters
- log_every_n_steps: int - Log every n steps
Default: None
- val_check_interval: float - Validation check interval
Default: 0.25
- sequence_packing_enabled: bool - Sequence packing enabled
Default: False
## NeMo customizer specific SFT parameters
- hidden_dropout: float - Hidden dropout
Default: None (0.0-1.0)
- attention_dropout: float - Attention dropout
Default: None (0.0-1.0)
- ffn_dropout: float - FFN dropout
Default: None (0.0-1.0)
- DataConfig:
- dataset_id: str - Dataset ID
- batch_size: int - Batch size
Default: 8
- OptimizerConfig:
- lr: float - Learning rate
Default: 0.0001
## NeMo customizer specific parameter
- weight_decay: float - Weight decay
Default: 0.01
- LoRA config:
## NeMo customizer specific LoRA parameters
- adapter_dim: int - Adapter dimension
Default: 8 (supports powers of 2)
- adapter_dropout: float - Adapter dropout
Default: None (0.0-1.0)
- alpha: int - Scaling factor for the LoRA update
Default: 16
Note:
- checkpoint_dir, hyperparam_search_config, logger_config are not supported (users are informed via warnings)
- Some parameters from TrainingConfig, DataConfig, OptimizerConfig are not supported (users are informed via warnings)
User is informed about unsupported parameters via warnings.
"""
# Map model to nvidia model name
# ToDo: only supports llama-3.1-8b-instruct now, need to update this to support other models
nvidia_model = self.get_provider_model_id(model)
# Check for unsupported method parameters
unsupported_method_params = []
if checkpoint_dir:
unsupported_method_params.append(f"checkpoint_dir={checkpoint_dir}")
if hyperparam_search_config:
unsupported_method_params.append("hyperparam_search_config")
if logger_config:
unsupported_method_params.append("logger_config")
if unsupported_method_params:
warnings.warn(
f"Parameters: {', '.join(unsupported_method_params)} are not supported and will be ignored",
stacklevel=2,
)
# Define all supported parameters
supported_params = {
"training_config": {
"n_epochs",
"data_config",
"optimizer_config",
"log_every_n_steps",
"val_check_interval",
"sequence_packing_enabled",
"hidden_dropout",
"attention_dropout",
"ffn_dropout",
},
"data_config": {"dataset_id", "batch_size"},
"optimizer_config": {"lr", "weight_decay"},
"lora_config": {"type", "adapter_dim", "adapter_dropout", "alpha"},
}
# Validate all parameters at once
warn_unsupported_params(training_config, supported_params["training_config"], "TrainingConfig")
warn_unsupported_params(training_config["data_config"], supported_params["data_config"], "DataConfig")
warn_unsupported_params(
training_config["optimizer_config"], supported_params["optimizer_config"], "OptimizerConfig"
)
output_model = self.config.output_model_dir
# Prepare base job configuration
job_config = {
"config": nvidia_model,
"dataset": {
"name": training_config["data_config"]["dataset_id"],
"namespace": self.config.dataset_namespace,
},
"hyperparameters": {
"training_type": "sft",
"finetuning_type": "lora",
**{
k: v
for k, v in {
"epochs": training_config.get("n_epochs"),
"batch_size": training_config["data_config"].get("batch_size"),
"learning_rate": training_config["optimizer_config"].get("lr"),
"weight_decay": training_config["optimizer_config"].get("weight_decay"),
"log_every_n_steps": training_config.get("log_every_n_steps"),
"val_check_interval": training_config.get("val_check_interval"),
"sequence_packing_enabled": training_config.get("sequence_packing_enabled"),
}.items()
if v is not None
},
},
"project": self.config.project_id,
# TODO: ignored ownership, add it later
# "ownership": {"created_by": self.config.user_id, "access_policies": self.config.access_policies},
"output_model": output_model,
}
# Handle SFT-specific optional parameters
job_config["hyperparameters"]["sft"] = {
k: v
for k, v in {
"ffn_dropout": training_config.get("ffn_dropout"),
"hidden_dropout": training_config.get("hidden_dropout"),
"attention_dropout": training_config.get("attention_dropout"),
}.items()
if v is not None
}
# Remove the sft dictionary if it's empty
if not job_config["hyperparameters"]["sft"]:
job_config["hyperparameters"].pop("sft")
# Handle LoRA-specific configuration
if algorithm_config:
if isinstance(algorithm_config, dict) and algorithm_config.get("type") == "LoRA":
warn_unsupported_params(algorithm_config, supported_params["lora_config"], "LoRA config")
job_config["hyperparameters"]["lora"] = {
k: v
for k, v in {
"adapter_dim": algorithm_config.get("adapter_dim"),
"alpha": algorithm_config.get("alpha"),
"adapter_dropout": algorithm_config.get("adapter_dropout"),
}.items()
if v is not None
}
else:
raise NotImplementedError(f"Unsupported algorithm config: {algorithm_config}")
# Create the customization job
response = await self._make_request(
method="POST",
path="/v1/customization/jobs",
headers={"Accept": "application/json"},
json=job_config,
)
job_uuid = response["id"]
response.pop("status")
created_at = datetime.fromisoformat(response.pop("created_at"))
updated_at = datetime.fromisoformat(response.pop("updated_at"))
return NvidiaPostTrainingJob(
job_uuid=job_uuid, status=JobStatus.in_progress, created_at=created_at, updated_at=updated_at, **response
)
async def preference_optimize(
self,
job_uuid: str,
finetuned_model: str,
algorithm_config: DPOAlignmentConfig,
training_config: TrainingConfig,
hyperparam_search_config: Dict[str, Any],
logger_config: Dict[str, Any],
) -> PostTrainingJob:
"""Optimize a model based on preference data."""
raise NotImplementedError("Preference optimization is not implemented yet")
async def get_training_job_container_logs(self, job_uuid: str) -> PostTrainingJobStatusResponse:
raise NotImplementedError("Job logs are not implemented yet")

View file

@ -0,0 +1,63 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import logging
import warnings
from typing import Any, Dict, Set, Tuple
from pydantic import BaseModel
from llama_stack.apis.post_training import TrainingConfig
from llama_stack.providers.remote.post_training.nvidia.config import SFTLoRADefaultConfig
from .config import NvidiaPostTrainingConfig
logger = logging.getLogger(__name__)
def warn_unsupported_params(config_dict: Any, supported_keys: Set[str], config_name: str) -> None:
keys = set(config_dict.__annotations__.keys()) if isinstance(config_dict, BaseModel) else config_dict.keys()
unsupported_params = [k for k in keys if k not in supported_keys]
if unsupported_params:
warnings.warn(
f"Parameters: {unsupported_params} in `{config_name}` not supported and will be ignored.", stacklevel=2
)
def validate_training_params(
training_config: Dict[str, Any], supported_keys: Set[str], config_name: str = "TrainingConfig"
) -> None:
"""
Validates training parameters against supported keys.
Args:
training_config: Dictionary containing training configuration parameters
supported_keys: Set of supported parameter keys
config_name: Name of the configuration for warning messages
"""
sft_lora_fields = set(SFTLoRADefaultConfig.__annotations__.keys())
training_config_fields = set(TrainingConfig.__annotations__.keys())
# Check for not supported parameters:
# - not in either of configs
# - in TrainingConfig but not in SFTLoRADefaultConfig
unsupported_params = []
for key in training_config:
if isinstance(key, str) and key not in (supported_keys.union(sft_lora_fields)):
if key in (not sft_lora_fields or training_config_fields):
unsupported_params.append(key)
if unsupported_params:
warnings.warn(
f"Parameters: {unsupported_params} in `{config_name}` are not supported and will be ignored.", stacklevel=2
)
# ToDo: implement post health checks for customizer are enabled
async def _get_health(url: str) -> Tuple[bool, bool]: ...
async def check_health(config: NvidiaPostTrainingConfig) -> None: ...

View file

@ -5,12 +5,13 @@
# the root directory of this source tree.
import json
from typing import Any, Dict, List, Optional
from typing import Any, Dict, Optional
import httpx
from llama_stack.apis.common.content_types import URL
from llama_stack.apis.tools import (
ListToolDefsResponse,
Tool,
ToolDef,
ToolInvocationResult,
@ -50,20 +51,22 @@ class BingSearchToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsRequestP
async def list_runtime_tools(
self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
) -> List[ToolDef]:
return [
ToolDef(
name="web_search",
description="Search the web using Bing Search API",
parameters=[
ToolParameter(
name="query",
description="The query to search for",
parameter_type="string",
)
],
)
]
) -> ListToolDefsResponse:
return ListToolDefsResponse(
data=[
ToolDef(
name="web_search",
description="Search the web using Bing Search API",
parameters=[
ToolParameter(
name="query",
description="The query to search for",
parameter_type="string",
)
],
)
]
)
async def invoke_tool(self, tool_name: str, kwargs: Dict[str, Any]) -> ToolInvocationResult:
api_key = self._get_api_key()

View file

@ -4,12 +4,13 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict, List, Optional
from typing import Any, Dict, Optional
import httpx
from llama_stack.apis.common.content_types import URL
from llama_stack.apis.tools import (
ListToolDefsResponse,
Tool,
ToolDef,
ToolInvocationResult,
@ -49,21 +50,23 @@ class BraveSearchToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsRequest
async def list_runtime_tools(
self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
) -> List[ToolDef]:
return [
ToolDef(
name="web_search",
description="Search the web for information",
parameters=[
ToolParameter(
name="query",
description="The query to search for",
parameter_type="string",
)
],
built_in_type=BuiltinTool.brave_search,
)
]
) -> ListToolDefsResponse:
return ListToolDefsResponse(
data=[
ToolDef(
name="web_search",
description="Search the web for information",
parameters=[
ToolParameter(
name="query",
description="The query to search for",
parameter_type="string",
)
],
built_in_type=BuiltinTool.brave_search,
)
]
)
async def invoke_tool(self, tool_name: str, kwargs: Dict[str, Any]) -> ToolInvocationResult:
api_key = self._get_api_key()

View file

@ -4,7 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict, List, Optional
from typing import Any, Dict, Optional
from urllib.parse import urlparse
from mcp import ClientSession
@ -12,6 +12,7 @@ from mcp.client.sse import sse_client
from llama_stack.apis.common.content_types import URL
from llama_stack.apis.tools import (
ListToolDefsResponse,
ToolDef,
ToolInvocationResult,
ToolParameter,
@ -31,7 +32,7 @@ class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
async def list_runtime_tools(
self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
) -> List[ToolDef]:
) -> ListToolDefsResponse:
if mcp_endpoint is None:
raise ValueError("mcp_endpoint is required")
@ -60,7 +61,7 @@ class ModelContextProtocolToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime):
},
)
)
return tools
return ListToolDefsResponse(data=tools)
async def invoke_tool(self, tool_name: str, kwargs: Dict[str, Any]) -> ToolInvocationResult:
tool = await self.tool_store.get_tool(tool_name)

View file

@ -5,12 +5,13 @@
# the root directory of this source tree.
import json
from typing import Any, Dict, List, Optional
from typing import Any, Dict, Optional
import httpx
from llama_stack.apis.common.content_types import URL
from llama_stack.apis.tools import (
ListToolDefsResponse,
Tool,
ToolDef,
ToolInvocationResult,
@ -49,20 +50,22 @@ class TavilySearchToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsReques
async def list_runtime_tools(
self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
) -> List[ToolDef]:
return [
ToolDef(
name="web_search",
description="Search the web for information",
parameters=[
ToolParameter(
name="query",
description="The query to search for",
parameter_type="string",
)
],
)
]
) -> ListToolDefsResponse:
return ListToolDefsResponse(
data=[
ToolDef(
name="web_search",
description="Search the web for information",
parameters=[
ToolParameter(
name="query",
description="The query to search for",
parameter_type="string",
)
],
)
]
)
async def invoke_tool(self, tool_name: str, kwargs: Dict[str, Any]) -> ToolInvocationResult:
api_key = self._get_api_key()

View file

@ -5,12 +5,13 @@
# the root directory of this source tree.
import json
from typing import Any, Dict, List, Optional
from typing import Any, Dict, Optional
import httpx
from llama_stack.apis.common.content_types import URL
from llama_stack.apis.tools import (
ListToolDefsResponse,
Tool,
ToolDef,
ToolInvocationResult,
@ -50,20 +51,22 @@ class WolframAlphaToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, NeedsReques
async def list_runtime_tools(
self, tool_group_id: Optional[str] = None, mcp_endpoint: Optional[URL] = None
) -> List[ToolDef]:
return [
ToolDef(
name="wolfram_alpha",
description="Query WolframAlpha for computational knowledge",
parameters=[
ToolParameter(
name="query",
description="The query to compute",
parameter_type="string",
)
],
)
]
) -> ListToolDefsResponse:
return ListToolDefsResponse(
data=[
ToolDef(
name="wolfram_alpha",
description="Query WolframAlpha for computational knowledge",
parameters=[
ToolParameter(
name="query",
description="The query to compute",
parameter_type="string",
)
],
)
]
)
async def invoke_tool(self, tool_name: str, kwargs: Dict[str, Any]) -> ToolInvocationResult:
api_key = self._get_api_key()

View file

@ -4,6 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
import hashlib
import logging
import os
@ -35,15 +36,16 @@ class MilvusIndex(EmbeddingIndex):
self.consistency_level = consistency_level
async def delete(self):
if self.client.has_collection(self.collection_name):
self.client.drop_collection(collection_name=self.collection_name)
if await asyncio.to_thread(self.client.has_collection, self.collection_name):
await asyncio.to_thread(self.client.drop_collection, collection_name=self.collection_name)
async def add_chunks(self, chunks: List[Chunk], embeddings: NDArray):
assert len(chunks) == len(embeddings), (
f"Chunk length {len(chunks)} does not match embedding length {len(embeddings)}"
)
if not self.client.has_collection(self.collection_name):
self.client.create_collection(
if not await asyncio.to_thread(self.client.has_collection, self.collection_name):
await asyncio.to_thread(
self.client.create_collection,
self.collection_name,
dimension=len(embeddings[0]),
auto_id=True,
@ -62,7 +64,8 @@ class MilvusIndex(EmbeddingIndex):
}
)
try:
self.client.insert(
await asyncio.to_thread(
self.client.insert,
self.collection_name,
data=data,
)
@ -71,7 +74,8 @@ class MilvusIndex(EmbeddingIndex):
raise e
async def query(self, embedding: NDArray, k: int, score_threshold: float) -> QueryChunksResponse:
search_res = self.client.search(
search_res = await asyncio.to_thread(
self.client.search,
collection_name=self.collection_name,
data=[embedding],
limit=k,

View file

@ -0,0 +1,43 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict, List
from llama_stack.apis.common.responses import PaginatedResponse
def paginate_records(
records: List[Dict[str, Any]],
start_index: int | None = None,
limit: int | None = None,
) -> PaginatedResponse:
"""Helper function to handle pagination of records consistently across implementations.
Inspired by stripe's pagination: https://docs.stripe.com/api/pagination
:param records: List of records to paginate
:param start_index: The starting index (0-based). If None, starts from beginning.
:param limit: Number of items to return. If None or -1, returns all items.
:return: PaginatedResponse with the paginated data
"""
# Handle special case for fetching all rows
if limit is None or limit == -1:
return PaginatedResponse(
data=records,
has_more=False,
)
# Use offset-based pagination
start_index = start_index or 0
end_index = min(start_index + limit, len(records))
page_data = records[start_index:end_index]
# Calculate if there are more records
has_more = end_index < len(records)
return PaginatedResponse(
data=page_data,
has_more=has_more,
)

View file

@ -147,6 +147,9 @@ def get_sampling_options(params: SamplingParams) -> dict:
if params.repetition_penalty is not None and params.repetition_penalty != 1.0:
options["repeat_penalty"] = params.repetition_penalty
if params.stop is not None:
options["stop"] = params.stop
return options

View file

@ -39,6 +39,7 @@ providers:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/bedrock/trace_store.db}
eval:

View file

@ -79,6 +79,7 @@ providers:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/cerebras/trace_store.db}
tool_runtime:

View file

@ -42,6 +42,7 @@ providers:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/ci-tests/trace_store.db}
eval:

View file

@ -45,6 +45,7 @@ providers:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/dell/trace_store.db}
eval:

View file

@ -41,6 +41,7 @@ providers:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/dell/trace_store.db}
eval:

View file

@ -0,0 +1,770 @@
{
"bedrock": [
"aiosqlite",
"autoevals",
"blobfile",
"boto3",
"chardet",
"chromadb-client",
"datasets",
"emoji",
"faiss-cpu",
"fastapi",
"fire",
"httpx",
"langdetect",
"matplotlib",
"mcp",
"nltk",
"numpy",
"openai",
"opentelemetry-exporter-otlp-proto-http",
"opentelemetry-sdk",
"pandas",
"pillow",
"psycopg2-binary",
"pymongo",
"pypdf",
"pythainlp",
"redis",
"requests",
"scikit-learn",
"scipy",
"sentencepiece",
"tqdm",
"transformers",
"tree_sitter",
"uvicorn"
],
"cerebras": [
"aiosqlite",
"autoevals",
"blobfile",
"cerebras_cloud_sdk",
"chardet",
"chromadb-client",
"datasets",
"emoji",
"faiss-cpu",
"fastapi",
"fire",
"httpx",
"langdetect",
"matplotlib",
"nltk",
"numpy",
"openai",
"opentelemetry-exporter-otlp-proto-http",
"opentelemetry-sdk",
"pandas",
"pillow",
"psycopg2-binary",
"pymongo",
"pypdf",
"pythainlp",
"redis",
"requests",
"scikit-learn",
"scipy",
"sentencepiece",
"tqdm",
"transformers",
"tree_sitter",
"uvicorn",
"sentence-transformers --no-deps",
"torch torchvision --index-url https://download.pytorch.org/whl/cpu"
],
"ci-tests": [
"aiosqlite",
"autoevals",
"blobfile",
"chardet",
"chromadb-client",
"datasets",
"emoji",
"fastapi",
"fire",
"fireworks-ai",
"httpx",
"langdetect",
"matplotlib",
"mcp",
"nltk",
"numpy",
"openai",
"opentelemetry-exporter-otlp-proto-http",
"opentelemetry-sdk",
"pandas",
"pillow",
"psycopg2-binary",
"pymongo",
"pypdf",
"pythainlp",
"redis",
"requests",
"scikit-learn",
"scipy",
"sentencepiece",
"sqlite-vec",
"tqdm",
"transformers",
"tree_sitter",
"uvicorn",
"sentence-transformers --no-deps",
"torch torchvision --index-url https://download.pytorch.org/whl/cpu"
],
"dell": [
"aiohttp",
"aiosqlite",
"autoevals",
"blobfile",
"chardet",
"chromadb-client",
"datasets",
"emoji",
"faiss-cpu",
"fastapi",
"fire",
"httpx",
"huggingface_hub",
"langdetect",
"matplotlib",
"nltk",
"numpy",
"openai",
"opentelemetry-exporter-otlp-proto-http",
"opentelemetry-sdk",
"pandas",
"pillow",
"psycopg2-binary",
"pymongo",
"pypdf",
"pythainlp",
"redis",
"requests",
"scikit-learn",
"scipy",
"sentencepiece",
"tqdm",
"transformers",
"tree_sitter",
"uvicorn",
"sentence-transformers --no-deps",
"torch torchvision --index-url https://download.pytorch.org/whl/cpu"
],
"dev": [
"aiosqlite",
"autoevals",
"blobfile",
"chardet",
"chromadb-client",
"datasets",
"emoji",
"fastapi",
"fire",
"fireworks-ai",
"httpx",
"langdetect",
"litellm",
"matplotlib",
"mcp",
"nltk",
"numpy",
"openai",
"opentelemetry-exporter-otlp-proto-http",
"opentelemetry-sdk",
"pandas",
"pillow",
"psycopg2-binary",
"pymongo",
"pypdf",
"pythainlp",
"redis",
"requests",
"scikit-learn",
"scipy",
"sentencepiece",
"sqlite-vec",
"tqdm",
"transformers",
"tree_sitter",
"uvicorn",
"sentence-transformers --no-deps",
"torch torchvision --index-url https://download.pytorch.org/whl/cpu"
],
"fireworks": [
"aiosqlite",
"autoevals",
"blobfile",
"chardet",
"chromadb-client",
"datasets",
"emoji",
"faiss-cpu",
"fastapi",
"fire",
"fireworks-ai",
"httpx",
"langdetect",
"matplotlib",
"mcp",
"nltk",
"numpy",
"openai",
"opentelemetry-exporter-otlp-proto-http",
"opentelemetry-sdk",
"pandas",
"pillow",
"psycopg2-binary",
"pymongo",
"pypdf",
"pythainlp",
"redis",
"requests",
"scikit-learn",
"scipy",
"sentencepiece",
"tqdm",
"transformers",
"tree_sitter",
"uvicorn",
"sentence-transformers --no-deps",
"torch torchvision --index-url https://download.pytorch.org/whl/cpu"
],
"groq": [
"aiosqlite",
"autoevals",
"blobfile",
"chardet",
"datasets",
"emoji",
"faiss-cpu",
"fastapi",
"fire",
"httpx",
"langdetect",
"litellm",
"matplotlib",
"nltk",
"numpy",
"openai",
"opentelemetry-exporter-otlp-proto-http",
"opentelemetry-sdk",
"pandas",
"pillow",
"psycopg2-binary",
"pymongo",
"pypdf",
"pythainlp",
"redis",
"requests",
"scikit-learn",
"scipy",
"sentencepiece",
"tqdm",
"transformers",
"tree_sitter",
"uvicorn"
],
"hf-endpoint": [
"aiohttp",
"aiosqlite",
"autoevals",
"blobfile",
"chardet",
"chromadb-client",
"datasets",
"emoji",
"faiss-cpu",
"fastapi",
"fire",
"httpx",
"huggingface_hub",
"langdetect",
"matplotlib",
"mcp",
"nltk",
"numpy",
"openai",
"opentelemetry-exporter-otlp-proto-http",
"opentelemetry-sdk",
"pandas",
"pillow",
"psycopg2-binary",
"pymongo",
"pypdf",
"pythainlp",
"redis",
"requests",
"scikit-learn",
"scipy",
"sentencepiece",
"tqdm",
"transformers",
"tree_sitter",
"uvicorn"
],
"hf-serverless": [
"aiohttp",
"aiosqlite",
"autoevals",
"blobfile",
"chardet",
"chromadb-client",
"datasets",
"emoji",
"faiss-cpu",
"fastapi",
"fire",
"httpx",
"huggingface_hub",
"langdetect",
"matplotlib",
"mcp",
"nltk",
"numpy",
"openai",
"opentelemetry-exporter-otlp-proto-http",
"opentelemetry-sdk",
"pandas",
"pillow",
"psycopg2-binary",
"pymongo",
"pypdf",
"pythainlp",
"redis",
"requests",
"scikit-learn",
"scipy",
"sentencepiece",
"tqdm",
"transformers",
"tree_sitter",
"uvicorn",
"sentence-transformers --no-deps",
"torch torchvision --index-url https://download.pytorch.org/whl/cpu"
],
"meta-reference-gpu": [
"accelerate",
"aiosqlite",
"autoevals",
"blobfile",
"chardet",
"chromadb-client",
"datasets",
"emoji",
"fairscale",
"faiss-cpu",
"fastapi",
"fire",
"httpx",
"langdetect",
"lm-format-enforcer",
"matplotlib",
"mcp",
"nltk",
"numpy",
"openai",
"opentelemetry-exporter-otlp-proto-http",
"opentelemetry-sdk",
"pandas",
"pillow",
"psycopg2-binary",
"pymongo",
"pypdf",
"pythainlp",
"redis",
"requests",
"scikit-learn",
"scipy",
"sentence-transformers",
"sentencepiece",
"torch",
"torchvision",
"tqdm",
"transformers",
"tree_sitter",
"uvicorn",
"zmq"
],
"meta-reference-quantized-gpu": [
"accelerate",
"aiosqlite",
"autoevals",
"blobfile",
"chardet",
"chromadb-client",
"datasets",
"emoji",
"fairscale",
"faiss-cpu",
"fastapi",
"fbgemm-gpu",
"fire",
"httpx",
"langdetect",
"lm-format-enforcer",
"matplotlib",
"mcp",
"nltk",
"numpy",
"openai",
"opentelemetry-exporter-otlp-proto-http",
"opentelemetry-sdk",
"pandas",
"pillow",
"psycopg2-binary",
"pymongo",
"pypdf",
"pythainlp",
"redis",
"requests",
"scikit-learn",
"scipy",
"sentence-transformers",
"sentencepiece",
"torch",
"torchao==0.5.0",
"torchvision",
"tqdm",
"transformers",
"tree_sitter",
"uvicorn",
"zmq"
],
"nvidia": [
"aiohttp",
"aiosqlite",
"blobfile",
"chardet",
"emoji",
"faiss-cpu",
"fastapi",
"fire",
"httpx",
"langdetect",
"matplotlib",
"nltk",
"numpy",
"openai",
"opentelemetry-exporter-otlp-proto-http",
"opentelemetry-sdk",
"pandas",
"pillow",
"psycopg2-binary",
"pymongo",
"pypdf",
"pythainlp",
"redis",
"requests",
"scikit-learn",
"scipy",
"sentencepiece",
"tqdm",
"transformers",
"tree_sitter",
"uvicorn"
],
"ollama": [
"aiohttp",
"aiosqlite",
"autoevals",
"blobfile",
"chardet",
"chromadb-client",
"datasets",
"emoji",
"faiss-cpu",
"fastapi",
"fire",
"httpx",
"langdetect",
"matplotlib",
"mcp",
"nltk",
"numpy",
"ollama",
"openai",
"opentelemetry-exporter-otlp-proto-http",
"opentelemetry-sdk",
"pandas",
"pillow",
"psycopg2-binary",
"pymongo",
"pypdf",
"pythainlp",
"redis",
"requests",
"scikit-learn",
"scipy",
"sentencepiece",
"tqdm",
"transformers",
"tree_sitter",
"uvicorn"
],
"open-benchmark": [
"aiosqlite",
"autoevals",
"blobfile",
"chardet",
"chromadb-client",
"datasets",
"emoji",
"fastapi",
"fire",
"httpx",
"langdetect",
"litellm",
"matplotlib",
"mcp",
"nltk",
"numpy",
"openai",
"opentelemetry-exporter-otlp-proto-http",
"opentelemetry-sdk",
"pandas",
"pillow",
"psycopg2-binary",
"pymongo",
"pypdf",
"pythainlp",
"redis",
"requests",
"scikit-learn",
"scipy",
"sentencepiece",
"sqlite-vec",
"together",
"tqdm",
"transformers",
"tree_sitter",
"uvicorn"
],
"passthrough": [
"aiosqlite",
"autoevals",
"blobfile",
"chardet",
"chromadb-client",
"datasets",
"emoji",
"faiss-cpu",
"fastapi",
"fire",
"httpx",
"langdetect",
"matplotlib",
"mcp",
"nltk",
"numpy",
"openai",
"opentelemetry-exporter-otlp-proto-http",
"opentelemetry-sdk",
"pandas",
"pillow",
"psycopg2-binary",
"pymongo",
"pypdf",
"pythainlp",
"redis",
"requests",
"scikit-learn",
"scipy",
"sentencepiece",
"tqdm",
"transformers",
"tree_sitter",
"uvicorn",
"sentence-transformers --no-deps",
"torch torchvision --index-url https://download.pytorch.org/whl/cpu"
],
"remote-vllm": [
"aiosqlite",
"autoevals",
"blobfile",
"chardet",
"chromadb-client",
"datasets",
"emoji",
"faiss-cpu",
"fastapi",
"fire",
"httpx",
"langdetect",
"matplotlib",
"mcp",
"nltk",
"numpy",
"openai",
"opentelemetry-exporter-otlp-proto-http",
"opentelemetry-sdk",
"pandas",
"pillow",
"psycopg2-binary",
"pymongo",
"pypdf",
"pythainlp",
"redis",
"requests",
"scikit-learn",
"scipy",
"sentencepiece",
"tqdm",
"transformers",
"tree_sitter",
"uvicorn",
"sentence-transformers --no-deps",
"torch torchvision --index-url https://download.pytorch.org/whl/cpu"
],
"sambanova": [
"aiosqlite",
"blobfile",
"chardet",
"chromadb-client",
"faiss-cpu",
"fastapi",
"fire",
"httpx",
"litellm",
"matplotlib",
"mcp",
"nltk",
"numpy",
"opentelemetry-exporter-otlp-proto-http",
"opentelemetry-sdk",
"pandas",
"pillow",
"psycopg2-binary",
"pymongo",
"pypdf",
"redis",
"requests",
"scikit-learn",
"scipy",
"sentencepiece",
"tqdm",
"transformers",
"uvicorn",
"sentence-transformers --no-deps",
"torch torchvision --index-url https://download.pytorch.org/whl/cpu"
],
"tgi": [
"aiohttp",
"aiosqlite",
"autoevals",
"blobfile",
"chardet",
"chromadb-client",
"datasets",
"emoji",
"faiss-cpu",
"fastapi",
"fire",
"httpx",
"huggingface_hub",
"langdetect",
"matplotlib",
"mcp",
"nltk",
"numpy",
"openai",
"opentelemetry-exporter-otlp-proto-http",
"opentelemetry-sdk",
"pandas",
"pillow",
"psycopg2-binary",
"pymongo",
"pypdf",
"pythainlp",
"redis",
"requests",
"scikit-learn",
"scipy",
"sentencepiece",
"tqdm",
"transformers",
"tree_sitter",
"uvicorn",
"sentence-transformers --no-deps",
"torch torchvision --index-url https://download.pytorch.org/whl/cpu"
],
"together": [
"aiosqlite",
"autoevals",
"blobfile",
"chardet",
"chromadb-client",
"datasets",
"emoji",
"faiss-cpu",
"fastapi",
"fire",
"httpx",
"langdetect",
"matplotlib",
"mcp",
"nltk",
"numpy",
"openai",
"opentelemetry-exporter-otlp-proto-http",
"opentelemetry-sdk",
"pandas",
"pillow",
"psycopg2-binary",
"pymongo",
"pypdf",
"pythainlp",
"redis",
"requests",
"scikit-learn",
"scipy",
"sentencepiece",
"together",
"tqdm",
"transformers",
"tree_sitter",
"uvicorn",
"sentence-transformers --no-deps",
"torch torchvision --index-url https://download.pytorch.org/whl/cpu"
],
"vllm-gpu": [
"aiosqlite",
"autoevals",
"blobfile",
"chardet",
"chromadb-client",
"datasets",
"emoji",
"faiss-cpu",
"fastapi",
"fire",
"httpx",
"langdetect",
"matplotlib",
"mcp",
"nltk",
"numpy",
"openai",
"opentelemetry-exporter-otlp-proto-http",
"opentelemetry-sdk",
"pandas",
"pillow",
"psycopg2-binary",
"pymongo",
"pypdf",
"pythainlp",
"redis",
"requests",
"scikit-learn",
"scipy",
"sentencepiece",
"tqdm",
"transformers",
"tree_sitter",
"uvicorn",
"vllm",
"sentence-transformers --no-deps",
"torch torchvision --index-url https://download.pytorch.org/whl/cpu"
]
}

View file

@ -76,6 +76,7 @@ providers:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/dev/trace_store.db}
eval:

View file

@ -50,6 +50,7 @@ providers:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/fireworks/trace_store.db}
eval:

View file

@ -45,6 +45,7 @@ providers:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/fireworks/trace_store.db}
eval:

View file

@ -45,6 +45,7 @@ providers:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/groq/trace_store.db}
eval:

View file

@ -50,6 +50,7 @@ providers:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/hf-endpoint/trace_store.db}
eval:

View file

@ -45,6 +45,7 @@ providers:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/hf-endpoint/trace_store.db}
eval:

View file

@ -50,6 +50,7 @@ providers:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/hf-serverless/trace_store.db}
eval:

View file

@ -45,6 +45,7 @@ providers:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/hf-serverless/trace_store.db}
eval:

View file

@ -52,6 +52,7 @@ providers:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/meta-reference-gpu/trace_store.db}
eval:

View file

@ -46,6 +46,7 @@ providers:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/meta-reference-gpu/trace_store.db}
eval:

View file

@ -48,6 +48,7 @@ providers:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/meta-reference-quantized-gpu/trace_store.db}
eval:

View file

@ -14,6 +14,8 @@ distribution_spec:
- inline::meta-reference
eval:
- inline::meta-reference
post_training:
- remote::nvidia
datasetio:
- inline::localfs
scoring:

View file

@ -21,6 +21,7 @@ def get_distribution_template() -> DistributionTemplate:
"agents": ["inline::meta-reference"],
"telemetry": ["inline::meta-reference"],
"eval": ["inline::meta-reference"],
"post_training": ["remote::nvidia"],
"datasetio": ["inline::localfs"],
"scoring": ["inline::basic"],
"tool_runtime": ["inline::rag-runtime"],
@ -89,6 +90,31 @@ def get_distribution_template() -> DistributionTemplate:
"",
"NVIDIA API Key",
),
## Nemo Customizer related variables
"NVIDIA_USER_ID": (
"llama-stack-user",
"NVIDIA User ID",
),
"NVIDIA_DATASET_NAMESPACE": (
"default",
"NVIDIA Dataset Namespace",
),
"NVIDIA_ACCESS_POLICIES": (
"{}",
"NVIDIA Access Policies",
),
"NVIDIA_PROJECT_ID": (
"test-project",
"NVIDIA Project ID",
),
"NVIDIA_CUSTOMIZER_URL": (
"https://customizer.api.nvidia.com",
"NVIDIA Customizer URL",
),
"NVIDIA_OUTPUT_MODEL_DIR": (
"test-example-model@v1",
"NVIDIA Output Model Directory",
),
"GUARDRAILS_SERVICE_URL": (
"http://0.0.0.0:7331",
"URL for the NeMo Guardrails Service",

View file

@ -5,6 +5,7 @@ apis:
- datasetio
- eval
- inference
- post_training
- safety
- scoring
- telemetry
@ -48,6 +49,7 @@ providers:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/nvidia/trace_store.db}
eval:
@ -58,6 +60,14 @@ providers:
type: sqlite
namespace: null
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/nvidia}/meta_reference_eval.db
post_training:
- provider_id: nvidia
provider_type: remote::nvidia
config:
api_key: ${env.NVIDIA_API_KEY:}
dataset_namespace: ${env.NVIDIA_DATASET_NAMESPACE:default}
project_id: ${env.NVIDIA_PROJECT_ID:test-project}
customizer_url: ${env.NVIDIA_CUSTOMIZER_URL:http://nemo.test}
datasetio:
- provider_id: localfs
provider_type: inline::localfs

View file

@ -5,6 +5,7 @@ apis:
- datasetio
- eval
- inference
- post_training
- safety
- scoring
- telemetry
@ -43,6 +44,7 @@ providers:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/nvidia/trace_store.db}
eval:
@ -53,6 +55,14 @@ providers:
type: sqlite
namespace: null
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/nvidia}/meta_reference_eval.db
post_training:
- provider_id: nvidia
provider_type: remote::nvidia
config:
api_key: ${env.NVIDIA_API_KEY:}
dataset_namespace: ${env.NVIDIA_DATASET_NAMESPACE:default}
project_id: ${env.NVIDIA_PROJECT_ID:test-project}
customizer_url: ${env.NVIDIA_CUSTOMIZER_URL:http://nemo.test}
datasetio:
- provider_id: localfs
provider_type: inline::localfs

View file

@ -43,6 +43,7 @@ providers:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/ollama/trace_store.db}
eval:

View file

@ -41,6 +41,7 @@ providers:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/ollama/trace_store.db}
eval:

View file

@ -68,6 +68,7 @@ providers:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/open-benchmark/trace_store.db}
eval:

View file

@ -50,6 +50,7 @@ providers:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/passthrough/trace_store.db}
eval:

View file

@ -45,6 +45,7 @@ providers:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/passthrough/trace_store.db}
eval:

View file

@ -85,11 +85,14 @@ export INFERENCE_PORT=8000
export INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct
export LLAMA_STACK_PORT=8321
# You need a local checkout of llama-stack to run this, get it using
# git clone https://github.com/meta-llama/llama-stack.git
cd /path/to/llama-stack
docker run \
-it \
--pull always \
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
-v ./run.yaml:/root/my-run.yaml \
-v ./llama_stack/templates/remote-vllm/run.yaml:/root/my-run.yaml \
llamastack/distribution-{{ name }} \
--yaml-config /root/my-run.yaml \
--port $LLAMA_STACK_PORT \
@ -108,7 +111,6 @@ export SAFETY_MODEL=meta-llama/Llama-Guard-3-1B
cd /path/to/llama-stack
docker run \
-it \
--pull always \
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
-v ~/.llama:/root/.llama \

View file

@ -88,6 +88,7 @@ providers:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/remote-vllm/trace_store.db}
tool_runtime:

View file

@ -81,6 +81,7 @@ providers:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/remote-vllm/trace_store.db}
tool_runtime:

View file

@ -54,6 +54,7 @@ providers:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/sambanova/trace_store.db}
tool_runtime:

View file

@ -45,6 +45,7 @@ providers:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/tgi/trace_store.db}
eval:

View file

@ -44,6 +44,7 @@ providers:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/tgi/trace_store.db}
eval:

View file

@ -50,6 +50,7 @@ providers:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/together/trace_store.db}
eval:

View file

@ -45,6 +45,7 @@ providers:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/together/trace_store.db}
eval:

View file

@ -49,6 +49,7 @@ providers:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/vllm-gpu/trace_store.db}
eval: