Merge branch 'main' into feat/litellm_sambanova_usage

This commit is contained in:
Jorge Piedrahita Ortiz 2025-03-18 12:06:58 -05:00 committed by GitHub
commit 5bd1bd30e2
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
76 changed files with 3534 additions and 2843 deletions

View file

@ -13,19 +13,16 @@ from llama_stack.schema_utils import json_schema_type, webmethod
@json_schema_type
class PaginatedRowsResult(BaseModel):
class IterrowsResponse(BaseModel):
"""
A paginated list of rows from a dataset.
:param rows: The rows in the current page.
:param total_count: The total number of rows in the dataset.
:param next_page_token: The token to get the next page of rows.
:param data: The rows in the current page.
:param next_start_index: Index into dataset for the first row in the next page. None if there are no more rows.
"""
# the rows obey the DatasetSchema for the given dataset
rows: List[Dict[str, Any]]
total_count: int
next_page_token: Optional[str] = None
data: List[Dict[str, Any]]
next_start_index: Optional[int] = None
class DatasetStore(Protocol):
@ -37,22 +34,21 @@ class DatasetIO(Protocol):
# keeping for aligning with inference/safety, but this is not used
dataset_store: DatasetStore
@webmethod(route="/datasetio/rows", method="GET")
async def get_rows_paginated(
# TODO(xiyan): there's a flakiness here where setting route to "/datasets/" here will not result in proper routing
@webmethod(route="/datasetio/iterrows/{dataset_id:path}", method="GET")
async def iterrows(
self,
dataset_id: str,
rows_in_page: int,
page_token: Optional[str] = None,
filter_condition: Optional[str] = None,
) -> PaginatedRowsResult:
"""Get a paginated list of rows from a dataset.
start_index: Optional[int] = None,
limit: Optional[int] = None,
) -> IterrowsResponse:
"""Get a paginated list of rows from a dataset. Uses cursor-based pagination.
:param dataset_id: The ID of the dataset to get the rows from.
:param rows_in_page: The number of rows to get per page.
:param page_token: The token to get the next page of rows.
:param filter_condition: (Optional) A condition to filter the rows by.
:param start_index: Index into dataset for the first row to get. Get all rows if None.
:param limit: The number of rows to get.
"""
...
@webmethod(route="/datasetio/rows", method="POST")
@webmethod(route="/datasetio/append-rows/{dataset_id:path}", method="POST")
async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None: ...

View file

@ -4,19 +4,102 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict, List, Literal, Optional, Protocol
from enum import Enum
from typing import Annotated, Any, Dict, List, Literal, Optional, Protocol, Union
from pydantic import BaseModel, Field
from llama_stack.apis.common.content_types import URL
from llama_stack.apis.common.type_system import ParamType
from llama_stack.apis.resource import Resource, ResourceType
from llama_stack.schema_utils import json_schema_type, webmethod
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
class DatasetPurpose(str, Enum):
"""
Purpose of the dataset. Each purpose has a required input data schema.
:cvar post-training/messages: The dataset contains messages used for post-training.
{
"messages": [
{"role": "user", "content": "Hello, world!"},
{"role": "assistant", "content": "Hello, world!"},
]
}
:cvar eval/question-answer: The dataset contains a question column and an answer column.
{
"question": "What is the capital of France?",
"answer": "Paris"
}
:cvar eval/messages-answer: The dataset contains a messages column with list of messages and an answer column.
{
"messages": [
{"role": "user", "content": "Hello, my name is John Doe."},
{"role": "assistant", "content": "Hello, John Doe. How can I help you today?"},
{"role": "user", "content": "What's my name?"},
],
"answer": "John Doe"
}
"""
post_training_messages = "post-training/messages"
eval_question_answer = "eval/question-answer"
eval_messages_answer = "eval/messages-answer"
# TODO: add more schemas here
class DatasetType(Enum):
"""
Type of the dataset source.
:cvar uri: The dataset can be obtained from a URI.
:cvar rows: The dataset is stored in rows.
"""
uri = "uri"
rows = "rows"
@json_schema_type
class URIDataSource(BaseModel):
"""A dataset that can be obtained from a URI.
:param uri: The dataset can be obtained from a URI. E.g.
- "https://mywebsite.com/mydata.jsonl"
- "lsfs://mydata.jsonl"
- "data:csv;base64,{base64_content}"
"""
type: Literal["uri"] = "uri"
uri: str
@json_schema_type
class RowsDataSource(BaseModel):
"""A dataset stored in rows.
:param rows: The dataset is stored in rows. E.g.
- [
{"messages": [{"role": "user", "content": "Hello, world!"}, {"role": "assistant", "content": "Hello, world!"}]}
]
"""
type: Literal["rows"] = "rows"
rows: List[Dict[str, Any]]
DataSource = register_schema(
Annotated[
Union[URIDataSource, RowsDataSource],
Field(discriminator="type"),
],
name="DataSource",
)
class CommonDatasetFields(BaseModel):
dataset_schema: Dict[str, ParamType]
url: URL
"""
Common fields for a dataset.
"""
purpose: DatasetPurpose
source: DataSource
metadata: Dict[str, Any] = Field(
default_factory=dict,
description="Any additional metadata for this dataset",
@ -50,13 +133,69 @@ class Datasets(Protocol):
@webmethod(route="/datasets", method="POST")
async def register_dataset(
self,
dataset_id: str,
dataset_schema: Dict[str, ParamType],
url: URL,
provider_dataset_id: Optional[str] = None,
provider_id: Optional[str] = None,
purpose: DatasetPurpose,
source: DataSource,
metadata: Optional[Dict[str, Any]] = None,
) -> None: ...
dataset_id: Optional[str] = None,
) -> Dataset:
"""
Register a new dataset.
:param purpose: The purpose of the dataset. One of
- "post-training/messages": The dataset contains a messages column with list of messages for post-training.
{
"messages": [
{"role": "user", "content": "Hello, world!"},
{"role": "assistant", "content": "Hello, world!"},
]
}
- "eval/question-answer": The dataset contains a question column and an answer column for evaluation.
{
"question": "What is the capital of France?",
"answer": "Paris"
}
- "eval/messages-answer": The dataset contains a messages column with list of messages and an answer column for evaluation.
{
"messages": [
{"role": "user", "content": "Hello, my name is John Doe."},
{"role": "assistant", "content": "Hello, John Doe. How can I help you today?"},
{"role": "user", "content": "What's my name?"},
],
"answer": "John Doe"
}
:param source: The data source of the dataset. Ensure that the data source schema is compatible with the purpose of the dataset. Examples:
- {
"type": "uri",
"uri": "https://mywebsite.com/mydata.jsonl"
}
- {
"type": "uri",
"uri": "lsfs://mydata.jsonl"
}
- {
"type": "uri",
"uri": "data:csv;base64,{base64_content}"
}
- {
"type": "uri",
"uri": "huggingface://llamastack/simpleqa?split=train"
}
- {
"type": "rows",
"rows": [
{
"messages": [
{"role": "user", "content": "Hello, world!"},
{"role": "assistant", "content": "Hello, world!"},
]
}
]
}
:param metadata: The metadata for the dataset.
- E.g. {"description": "My dataset"}
:param dataset_id: The ID of the dataset. If not provided, an ID will be generated.
"""
...
@webmethod(route="/datasets/{dataset_id:path}", method="GET")
async def get_dataset(

View file

@ -38,7 +38,7 @@ from llama_stack.distribution.distribution import get_provider_registry
from llama_stack.distribution.resolver import InvalidProviderError
from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR
from llama_stack.distribution.utils.dynamic import instantiate_class_type
from llama_stack.distribution.utils.exec import formulate_run_args, run_with_pty
from llama_stack.distribution.utils.exec import formulate_run_args, run_command
from llama_stack.distribution.utils.image_types import LlamaStackImageType
from llama_stack.providers.datatypes import Api
@ -213,7 +213,7 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
config = parse_and_maybe_upgrade_config(config_dict)
run_args = formulate_run_args(args.image_type, args.image_name, config, args.template)
run_args.extend([run_config, str(os.getenv("LLAMA_STACK_PORT", 8321))])
run_with_pty(run_args)
run_command(run_args)
def _generate_run_config(

View file

@ -82,7 +82,7 @@ class StackRun(Subcommand):
from llama_stack.distribution.configure import parse_and_maybe_upgrade_config
from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR
from llama_stack.distribution.utils.exec import formulate_run_args, run_with_pty
from llama_stack.distribution.utils.exec import formulate_run_args, run_command
config_file = Path(args.config)
has_yaml_suffix = args.config.endswith(".yaml")
@ -136,4 +136,4 @@ class StackRun(Subcommand):
if args.tls_keyfile and args.tls_certfile:
run_args.extend(["--tls-keyfile", args.tls_keyfile, "--tls-certfile", args.tls_certfile])
run_with_pty(run_args)
run_command(run_args)

View file

@ -6,7 +6,6 @@
import importlib.resources
import logging
import sys
from pathlib import Path
from typing import Dict, List
@ -15,7 +14,7 @@ from termcolor import cprint
from llama_stack.distribution.datatypes import BuildConfig, Provider
from llama_stack.distribution.distribution import get_provider_registry
from llama_stack.distribution.utils.exec import run_command, run_with_pty
from llama_stack.distribution.utils.exec import run_command
from llama_stack.distribution.utils.image_types import LlamaStackImageType
from llama_stack.providers.datatypes import Api
@ -123,11 +122,7 @@ def build_image(
if special_deps:
args.append("#".join(special_deps))
is_terminal = sys.stdin.isatty()
if is_terminal:
return_code = run_with_pty(args)
else:
return_code = run_command(args)
return_code = run_command(args)
if return_code != 0:
log.error(

View file

@ -43,7 +43,7 @@ RED='\033[0;31m'
NC='\033[0m' # No Color
CONTAINER_BINARY=${CONTAINER_BINARY:-docker}
CONTAINER_OPTS=${CONTAINER_OPTS:-}
CONTAINER_OPTS=${CONTAINER_OPTS:---progress=plain}
TEMP_DIR=$(mktemp -d)
@ -253,8 +253,7 @@ $CONTAINER_BINARY build \
"${CLI_ARGS[@]}" \
-t "$image_tag" \
-f "$TEMP_DIR/Containerfile" \
"." \
--progress=plain
"."
# clean up tmp/configs
set +x

View file

@ -8,10 +8,13 @@
from pydantic import BaseModel
from llama_stack.apis.providers import ListProvidersResponse, ProviderInfo, Providers
from llama_stack.log import get_logger
from .datatypes import StackRunConfig
from .stack import redact_sensitive_fields
logger = get_logger(name=__name__, category="core")
class ProviderImplConfig(BaseModel):
run_config: StackRunConfig
@ -31,6 +34,10 @@ class ProviderImpl(Providers):
async def initialize(self) -> None:
pass
async def shutdown(self) -> None:
logger.debug("ProviderImpl.shutdown")
pass
async def list_providers(self) -> ListProvidersResponse:
run_config = self.config.run_config
safe_config = StackRunConfig(**redact_sensitive_fields(run_config.model_dump()))

View file

@ -12,7 +12,8 @@ from llama_stack.apis.common.content_types import (
InterleavedContent,
InterleavedContentItem,
)
from llama_stack.apis.datasetio import DatasetIO, PaginatedRowsResult
from llama_stack.apis.datasetio import DatasetIO, IterrowsResponse
from llama_stack.apis.datasets import DatasetPurpose, DataSource
from llama_stack.apis.eval import (
BenchmarkConfig,
Eval,
@ -160,7 +161,11 @@ class InferenceRouter(Inference):
await self.routing_table.register_model(model_id, provider_model_id, provider_id, metadata, model_type)
def _construct_metrics(
self, prompt_tokens: int, completion_tokens: int, total_tokens: int, model: Model
self,
prompt_tokens: int,
completion_tokens: int,
total_tokens: int,
model: Model,
) -> List[MetricEvent]:
"""Constructs a list of MetricEvent objects containing token usage metrics.
@ -298,7 +303,12 @@ class InferenceRouter(Inference):
completion_text += chunk.event.delta.text
if chunk.event.event_type == ChatCompletionResponseEventType.complete:
completion_tokens = await self._count_tokens(
[CompletionMessage(content=completion_text, stop_reason=StopReason.end_of_turn)],
[
CompletionMessage(
content=completion_text,
stop_reason=StopReason.end_of_turn,
)
],
tool_config.tool_prompt_format,
)
total_tokens = (prompt_tokens or 0) + (completion_tokens or 0)
@ -471,21 +481,36 @@ class DatasetIORouter(DatasetIO):
logger.debug("DatasetIORouter.shutdown")
pass
async def get_rows_paginated(
async def register_dataset(
self,
purpose: DatasetPurpose,
source: DataSource,
metadata: Optional[Dict[str, Any]] = None,
dataset_id: Optional[str] = None,
) -> None:
logger.debug(
f"DatasetIORouter.register_dataset: {purpose=} {source=} {metadata=} {dataset_id=}",
)
await self.routing_table.register_dataset(
purpose=purpose,
source=source,
metadata=metadata,
dataset_id=dataset_id,
)
async def iterrows(
self,
dataset_id: str,
rows_in_page: int,
page_token: Optional[str] = None,
filter_condition: Optional[str] = None,
) -> PaginatedRowsResult:
start_index: Optional[int] = None,
limit: Optional[int] = None,
) -> IterrowsResponse:
logger.debug(
f"DatasetIORouter.get_rows_paginated: {dataset_id}, rows_in_page={rows_in_page}",
f"DatasetIORouter.iterrows: {dataset_id}, {start_index=} {limit=}",
)
return await self.routing_table.get_provider_impl(dataset_id).get_rows_paginated(
return await self.routing_table.get_provider_impl(dataset_id).iterrows(
dataset_id=dataset_id,
rows_in_page=rows_in_page,
page_token=page_token,
filter_condition=filter_condition,
start_index=start_index,
limit=limit,
)
async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None:

View file

@ -5,6 +5,7 @@
# the root directory of this source tree.
import logging
import uuid
from typing import Any, Dict, List, Optional
from pydantic import TypeAdapter
@ -12,7 +13,14 @@ from pydantic import TypeAdapter
from llama_stack.apis.benchmarks import Benchmark, Benchmarks, ListBenchmarksResponse
from llama_stack.apis.common.content_types import URL
from llama_stack.apis.common.type_system import ParamType
from llama_stack.apis.datasets import Dataset, Datasets, ListDatasetsResponse
from llama_stack.apis.datasets import (
Dataset,
DatasetPurpose,
Datasets,
DatasetType,
DataSource,
ListDatasetsResponse,
)
from llama_stack.apis.models import ListModelsResponse, Model, Models, ModelType
from llama_stack.apis.resource import ResourceType
from llama_stack.apis.scoring_functions import (
@ -352,34 +360,42 @@ class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
async def register_dataset(
self,
dataset_id: str,
dataset_schema: Dict[str, ParamType],
url: URL,
provider_dataset_id: Optional[str] = None,
provider_id: Optional[str] = None,
purpose: DatasetPurpose,
source: DataSource,
metadata: Optional[Dict[str, Any]] = None,
) -> None:
if provider_dataset_id is None:
provider_dataset_id = dataset_id
if provider_id is None:
# If provider_id not specified, use the only provider if it supports this dataset
if len(self.impls_by_provider_id) == 1:
provider_id = list(self.impls_by_provider_id.keys())[0]
dataset_id: Optional[str] = None,
) -> Dataset:
if not dataset_id:
dataset_id = f"dataset-{str(uuid.uuid4())}"
provider_dataset_id = dataset_id
# infer provider from source
if source.type == DatasetType.rows.value:
provider_id = "localfs"
elif source.type == DatasetType.uri.value:
# infer provider from uri
if source.uri.startswith("huggingface"):
provider_id = "huggingface"
else:
raise ValueError(
f"No provider specified and multiple providers available. Please specify a provider_id. Available providers: {self.impls_by_provider_id.keys()}"
)
provider_id = "localfs"
else:
raise ValueError(f"Unknown data source type: {source.type}")
if metadata is None:
metadata = {}
dataset = Dataset(
identifier=dataset_id,
provider_resource_id=provider_dataset_id,
provider_id=provider_id,
dataset_schema=dataset_schema,
url=url,
purpose=purpose,
source=source,
metadata=metadata,
)
await self.register_object(dataset)
return dataset
async def unregister_dataset(self, dataset_id: str) -> None:
dataset = await self.get_dataset(dataset_id)

View file

@ -0,0 +1,11 @@
# More info on playground configuration can be found here:
# https://llama-stack.readthedocs.io/en/latest/playground
FROM python:3.9-slim
WORKDIR /app
COPY . /app/
RUN /usr/local/bin/python -m pip install --upgrade pip && \
/usr/local/bin/pip3 install -r requirements.txt
EXPOSE 8501
ENTRYPOINT ["streamlit", "run", "app.py", "--server.port=8501", "--server.address=0.0.0.0"]

View file

@ -40,3 +40,13 @@ cd llama_stack/distribution/ui
pip install -r requirements.txt
streamlit run app.py
```
## Environment Variables
| Environment Variable | Description | Default Value |
|----------------------------|------------------------------------|---------------------------|
| LLAMA_STACK_ENDPOINT | The endpoint for the Llama Stack | http://localhost:8321 |
| FIREWORKS_API_KEY | API key for Fireworks provider | (empty string) |
| TOGETHER_API_KEY | API key for Together provider | (empty string) |
| SAMBANOVA_API_KEY | API key for SambaNova provider | (empty string) |
| OPENAI_API_KEY | API key for OpenAI provider | (empty string) |

View file

@ -166,11 +166,10 @@ def run_evaluation_3():
eval_candidate = st.session_state["eval_candidate"]
dataset_id = benchmarks[selected_benchmark].dataset_id
rows = llama_stack_api.client.datasetio.get_rows_paginated(
rows = llama_stack_api.client.datasets.iterrows(
dataset_id=dataset_id,
rows_in_page=-1,
)
total_rows = len(rows.rows)
total_rows = len(rows.data)
# Add number of examples control
num_rows = st.number_input(
"Number of Examples to Evaluate",
@ -195,7 +194,7 @@ def run_evaluation_3():
if st.button("Run Evaluation"):
progress_text = "Running evaluation..."
progress_bar = st.progress(0, text=progress_text)
rows = rows.rows
rows = rows.data
if num_rows < total_rows:
rows = rows[:num_rows]

View file

@ -4,13 +4,10 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import errno
import logging
import os
import select
import signal
import subprocess
import sys
from termcolor import cprint
@ -88,13 +85,6 @@ def formulate_run_args(image_type, image_name, config, template_name) -> list:
return run_args
def run_with_pty(command):
if sys.platform.startswith("win"):
return _run_with_pty_win(command)
else:
return _run_with_pty_unix(command)
def in_notebook():
try:
from IPython import get_ipython
@ -108,19 +98,19 @@ def in_notebook():
return True
# run a command in a pseudo-terminal, with interrupt handling,
# useful when you want to run interactive things
def _run_with_pty_unix(command):
import pty
import termios
def run_command(command: list[str]) -> int:
"""
Run a command with interrupt handling and output capture.
Uses subprocess.run with direct stream piping for better performance.
master, slave = pty.openpty()
Args:
command (list): The command to run.
old_settings = termios.tcgetattr(sys.stdin)
Returns:
int: The return code of the command.
"""
original_sigint = signal.getsignal(signal.SIGINT)
ctrl_c_pressed = False
process = None
def sigint_handler(signum, frame):
nonlocal ctrl_c_pressed
@ -131,106 +121,19 @@ def _run_with_pty_unix(command):
# Set up the signal handler
signal.signal(signal.SIGINT, sigint_handler)
new_settings = termios.tcgetattr(sys.stdin)
new_settings[3] = new_settings[3] & ~termios.ECHO # Disable echo
new_settings[3] = new_settings[3] & ~termios.ICANON # Disable canonical mode
termios.tcsetattr(sys.stdin, termios.TCSADRAIN, new_settings)
process = subprocess.Popen(
# Run the command with stdout/stderr piped directly to system streams
result = subprocess.run(
command,
stdin=slave,
stdout=slave,
stderr=slave,
universal_newlines=True,
preexec_fn=os.setsid,
text=True,
check=False,
)
# Close the slave file descriptor as it's now owned by the subprocess
os.close(slave)
def handle_io():
while not ctrl_c_pressed:
try:
rlist, _, _ = select.select([sys.stdin, master], [], [], 0.1)
if sys.stdin in rlist:
data = os.read(sys.stdin.fileno(), 1024)
if not data:
break
os.write(master, data)
if master in rlist:
data = os.read(master, 1024)
if not data:
break
sys.stdout.buffer.write(data)
sys.stdout.flush()
except KeyboardInterrupt:
# This will be raised when Ctrl+C is pressed
break
if process.poll() is not None:
break
handle_io()
except (EOFError, KeyboardInterrupt):
pass
except OSError as e:
if e.errno != errno.EIO:
raise
finally:
# Clean up
termios.tcsetattr(sys.stdin, termios.TCSADRAIN, old_settings)
signal.signal(signal.SIGINT, original_sigint)
os.close(master)
if process and process.poll() is None:
process.terminate()
process.wait()
return process.returncode
# run a command in a pseudo-terminal in windows, with interrupt handling,
def _run_with_pty_win(command):
"""
Runs a command with interactive support using subprocess directly.
"""
try:
# For shell scripts on Windows, use appropriate shell
if isinstance(command, (list, tuple)):
if command[0].endswith(".sh"):
if os.path.exists("/usr/bin/bash"): # WSL
command = ["bash"] + command
else:
# Use cmd.exe with bash while preserving all arguments
command = ["cmd.exe", "/c", "bash"] + command
process = subprocess.Popen(
command,
shell=True,
universal_newlines=True,
)
process.wait()
return result.returncode
except subprocess.SubprocessError as e:
log.error(f"Subprocess error: {e}")
return 1
except Exception as e:
print(f"Error: {str(e)}")
log.exception(f"Unexpected error: {e}")
return 1
finally:
if process and process.poll() is None:
process.terminate()
process.wait()
return process.returncode
def run_command(command):
try:
result = subprocess.run(command, capture_output=True, text=True, check=True)
print("Script Output\n", result.stdout)
return result.returncode
except subprocess.CalledProcessError as e:
print("Error running script:", e)
print("Error output:", e.stderr)
return e.returncode
# Restore the original signal handler
signal.signal(signal.SIGINT, original_sigint)

View file

@ -614,118 +614,133 @@ class ChatAgent(ShieldRunnerMixin):
logger.debug(f"completion message with EOM (iter: {n_iter}): {str(message)}")
input_messages = input_messages + [message]
else:
logger.debug(f"completion message (iter: {n_iter}) from the model: {str(message)}")
# 1. Start the tool execution step and progress
step_id = str(uuid.uuid4())
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepStartPayload(
step_type=StepType.tool_execution.value,
step_id=step_id,
)
)
)
tool_call = message.tool_calls[0]
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepProgressPayload(
step_type=StepType.tool_execution.value,
step_id=step_id,
tool_call=tool_call,
delta=ToolCallDelta(
parse_status=ToolCallParseStatus.in_progress,
tool_call=tool_call,
),
)
)
)
input_messages = input_messages + [message]
# If tool is a client tool, yield CompletionMessage and return
if tool_call.tool_name in client_tools:
# NOTE: mark end_of_message to indicate to client that it may
# call the tool and continue the conversation with the tool's response.
message.stop_reason = StopReason.end_of_message
# Process tool calls in the message
client_tool_calls = []
non_client_tool_calls = []
# Separate client and non-client tool calls
for tool_call in message.tool_calls:
if tool_call.tool_name in client_tools:
client_tool_calls.append(tool_call)
else:
non_client_tool_calls.append(tool_call)
# Process non-client tool calls first
for tool_call in non_client_tool_calls:
step_id = str(uuid.uuid4())
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepStartPayload(
step_type=StepType.tool_execution.value,
step_id=step_id,
)
)
)
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepProgressPayload(
step_type=StepType.tool_execution.value,
step_id=step_id,
delta=ToolCallDelta(
parse_status=ToolCallParseStatus.in_progress,
tool_call=tool_call,
),
)
)
)
# Execute the tool call
async with tracing.span(
"tool_execution",
{
"tool_name": tool_call.tool_name,
"input": message.model_dump_json(),
},
) as span:
tool_execution_start_time = datetime.now(timezone.utc).isoformat()
tool_result = await self.execute_tool_call_maybe(
session_id,
tool_call,
)
if tool_result.content is None:
raise ValueError(
f"Tool call result (id: {tool_call.call_id}, name: {tool_call.tool_name}) does not have any content"
)
result_message = ToolResponseMessage(
call_id=tool_call.call_id,
content=tool_result.content,
)
span.set_attribute("output", result_message.model_dump_json())
# Store tool execution step
tool_execution_step = ToolExecutionStep(
step_id=step_id,
turn_id=turn_id,
tool_calls=[tool_call],
tool_responses=[
ToolResponse(
call_id=tool_call.call_id,
tool_name=tool_call.tool_name,
content=tool_result.content,
metadata=tool_result.metadata,
)
],
started_at=tool_execution_start_time,
completed_at=datetime.now(timezone.utc).isoformat(),
)
# Yield the step completion event
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepCompletePayload(
step_type=StepType.tool_execution.value,
step_id=step_id,
step_details=tool_execution_step,
)
)
)
# Add the result message to input_messages for the next iteration
input_messages.append(result_message)
# TODO: add tool-input touchpoint and a "start" event for this step also
# but that needs a lot more refactoring of Tool code potentially
if (type(result_message.content) is str) and (
out_attachment := _interpret_content_as_attachment(result_message.content)
):
# NOTE: when we push this message back to the model, the model may ignore the
# attached file path etc. since the model is trained to only provide a user message
# with the summary. We keep all generated attachments and then attach them to final message
output_attachments.append(out_attachment)
# If there are client tool calls, yield a message with only those tool calls
if client_tool_calls:
await self.storage.set_in_progress_tool_call_step(
session_id,
turn_id,
ToolExecutionStep(
step_id=step_id,
turn_id=turn_id,
tool_calls=[tool_call],
tool_calls=client_tool_calls,
tool_responses=[],
started_at=datetime.now(timezone.utc).isoformat(),
),
)
yield message
# Create a copy of the message with only client tool calls
client_message = message.model_copy(deep=True)
client_message.tool_calls = client_tool_calls
# NOTE: mark end_of_message to indicate to client that it may
# call the tool and continue the conversation with the tool's response.
client_message.stop_reason = StopReason.end_of_message
# Yield the message with client tool calls
yield client_message
return
# If tool is a builtin server tool, execute it
tool_name = tool_call.tool_name
if isinstance(tool_name, BuiltinTool):
tool_name = tool_name.value
async with tracing.span(
"tool_execution",
{
"tool_name": tool_name,
"input": message.model_dump_json(),
},
) as span:
tool_execution_start_time = datetime.now(timezone.utc).isoformat()
tool_call = message.tool_calls[0]
tool_result = await self.execute_tool_call_maybe(
session_id,
tool_call,
)
if tool_result.content is None:
raise ValueError(
f"Tool call result (id: {tool_call.call_id}, name: {tool_call.tool_name}) does not have any content"
)
result_messages = [
ToolResponseMessage(
call_id=tool_call.call_id,
content=tool_result.content,
)
]
assert len(result_messages) == 1, "Currently not supporting multiple messages"
result_message = result_messages[0]
span.set_attribute("output", result_message.model_dump_json())
yield AgentTurnResponseStreamChunk(
event=AgentTurnResponseEvent(
payload=AgentTurnResponseStepCompletePayload(
step_type=StepType.tool_execution.value,
step_id=step_id,
step_details=ToolExecutionStep(
step_id=step_id,
turn_id=turn_id,
tool_calls=[tool_call],
tool_responses=[
ToolResponse(
call_id=result_message.call_id,
tool_name=tool_call.tool_name,
content=result_message.content,
metadata=tool_result.metadata,
)
],
started_at=tool_execution_start_time,
completed_at=datetime.now(timezone.utc).isoformat(),
),
)
)
)
# TODO: add tool-input touchpoint and a "start" event for this step also
# but that needs a lot more refactoring of Tool code potentially
if (type(result_message.content) is str) and (
out_attachment := _interpret_content_as_attachment(result_message.content)
):
# NOTE: when we push this message back to the model, the model may ignore the
# attached file path etc. since the model is trained to only provide a user message
# with the summary. We keep all generated attachments and then attach them to final message
output_attachments.append(out_attachment)
input_messages = input_messages + [message, result_message]
async def _initialize_tools(
self,
toolgroups_for_turn: Optional[List[AgentToolGroup]] = None,
@ -891,16 +906,14 @@ class ChatAgent(ShieldRunnerMixin):
if memory_tool and code_interpreter_tool:
# if both memory and code_interpreter are available, we download the URLs
# and attach the data to the last message.
msg = await attachment_message(self.tempdir, url_items)
input_messages.append(msg)
await attachment_message(self.tempdir, url_items, input_messages[-1])
# Since memory is present, add all the data to the memory bank
await self.add_to_session_vector_db(session_id, documents)
elif code_interpreter_tool:
# if only code_interpreter is available, we download the URLs to a tempdir
# and attach the path to them as a message to inference with the
# assumption that the model invokes the code_interpreter tool with the path
msg = await attachment_message(self.tempdir, url_items)
input_messages.append(msg)
await attachment_message(self.tempdir, url_items, input_messages[-1])
elif memory_tool:
# if only memory is available, we load the data from the URLs and content items to the memory bank
await self.add_to_session_vector_db(session_id, documents)
@ -967,8 +980,8 @@ async def load_data_from_urls(urls: List[URL]) -> List[str]:
return data
async def attachment_message(tempdir: str, urls: List[URL]) -> ToolResponseMessage:
content = []
async def attachment_message(tempdir: str, urls: List[URL], message: UserMessage) -> None:
contents = []
for url in urls:
uri = url.uri
@ -988,16 +1001,19 @@ async def attachment_message(tempdir: str, urls: List[URL]) -> ToolResponseMessa
else:
raise ValueError(f"Unsupported URL {url}")
content.append(
contents.append(
TextContentItem(
text=f'# User provided a file accessible to you at "{filepath}"\nYou can use code_interpreter to load and inspect it.'
)
)
return ToolResponseMessage(
call_id="",
content=content,
)
if isinstance(message.content, list):
message.content.extend(contents)
else:
if isinstance(message.content, str):
message.content = [TextContentItem(text=message.content)] + contents
else:
message.content = [message.content] + contents
def _interpret_content_as_attachment(

View file

@ -3,20 +3,14 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import base64
import os
from abc import ABC, abstractmethod
from dataclasses import dataclass
from typing import Any, Dict, List, Optional
from urllib.parse import urlparse
import pandas
from llama_stack.apis.common.content_types import URL
from llama_stack.apis.datasetio import DatasetIO, PaginatedRowsResult
from llama_stack.apis.datasetio import DatasetIO, IterrowsResponse
from llama_stack.apis.datasets import Dataset
from llama_stack.providers.datatypes import DatasetsProtocolPrivate
from llama_stack.providers.utils.datasetio.url_utils import get_dataframe_from_url
from llama_stack.providers.utils.datasetio.url_utils import get_dataframe_from_uri
from llama_stack.providers.utils.kvstore import kvstore_impl
from .config import LocalFSDatasetIOConfig
@ -24,30 +18,7 @@ from .config import LocalFSDatasetIOConfig
DATASETS_PREFIX = "localfs_datasets:"
class BaseDataset(ABC):
def __init__(self, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
@abstractmethod
def __len__(self) -> int:
raise NotImplementedError()
@abstractmethod
def __getitem__(self, idx):
raise NotImplementedError()
@abstractmethod
def load(self):
raise NotImplementedError()
@dataclass
class DatasetInfo:
dataset_def: Dataset
dataset_impl: BaseDataset
class PandasDataframeDataset(BaseDataset):
class PandasDataframeDataset:
def __init__(self, dataset_def: Dataset, *args, **kwargs) -> None:
super().__init__(*args, **kwargs)
self.dataset_def = dataset_def
@ -64,23 +35,19 @@ class PandasDataframeDataset(BaseDataset):
else:
return self.df.iloc[idx].to_dict()
def _validate_dataset_schema(self, df) -> pandas.DataFrame:
# note that we will drop any columns in dataset that are not in the schema
df = df[self.dataset_def.dataset_schema.keys()]
# check all columns in dataset schema are present
assert len(df.columns) == len(self.dataset_def.dataset_schema)
# TODO: type checking against column types in dataset schema
return df
def load(self) -> None:
if self.df is not None:
return
df = get_dataframe_from_url(self.dataset_def.url)
if df is None:
raise ValueError(f"Failed to load dataset from {self.dataset_def.url}")
if self.dataset_def.source.type == "uri":
self.df = get_dataframe_from_uri(self.dataset_def.source.uri)
elif self.dataset_def.source.type == "rows":
self.df = pandas.DataFrame(self.dataset_def.source.rows)
else:
raise ValueError(f"Unsupported dataset source type: {self.dataset_def.source.type}")
self.df = self._validate_dataset_schema(df)
if self.df is None:
raise ValueError(f"Failed to load dataset from {self.dataset_def.url}")
class LocalFSDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
@ -99,95 +66,55 @@ class LocalFSDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
for dataset in stored_datasets:
dataset = Dataset.model_validate_json(dataset)
dataset_impl = PandasDataframeDataset(dataset)
self.dataset_infos[dataset.identifier] = DatasetInfo(
dataset_def=dataset,
dataset_impl=dataset_impl,
)
self.dataset_infos[dataset.identifier] = dataset
async def shutdown(self) -> None: ...
async def register_dataset(
self,
dataset: Dataset,
dataset_def: Dataset,
) -> None:
# Store in kvstore
key = f"{DATASETS_PREFIX}{dataset.identifier}"
key = f"{DATASETS_PREFIX}{dataset_def.identifier}"
await self.kvstore.set(
key=key,
value=dataset.json(),
)
dataset_impl = PandasDataframeDataset(dataset)
self.dataset_infos[dataset.identifier] = DatasetInfo(
dataset_def=dataset,
dataset_impl=dataset_impl,
value=dataset_def.model_dump_json(),
)
self.dataset_infos[dataset_def.identifier] = dataset_def
async def unregister_dataset(self, dataset_id: str) -> None:
key = f"{DATASETS_PREFIX}{dataset_id}"
await self.kvstore.delete(key=key)
del self.dataset_infos[dataset_id]
async def get_rows_paginated(
async def iterrows(
self,
dataset_id: str,
rows_in_page: int,
page_token: Optional[str] = None,
filter_condition: Optional[str] = None,
) -> PaginatedRowsResult:
dataset_info = self.dataset_infos.get(dataset_id)
dataset_info.dataset_impl.load()
start_index: Optional[int] = None,
limit: Optional[int] = None,
) -> IterrowsResponse:
dataset_def = self.dataset_infos[dataset_id]
dataset_impl = PandasDataframeDataset(dataset_def)
dataset_impl.load()
if page_token and not page_token.isnumeric():
raise ValueError("Invalid page_token")
start_index = start_index or 0
if page_token is None or len(page_token) == 0:
next_page_token = 0
if limit is None or limit == -1:
end = len(dataset_impl)
else:
next_page_token = int(page_token)
end = min(start_index + limit, len(dataset_impl))
start = next_page_token
if rows_in_page == -1:
end = len(dataset_info.dataset_impl)
else:
end = min(start + rows_in_page, len(dataset_info.dataset_impl))
rows = dataset_impl[start_index:end]
rows = dataset_info.dataset_impl[start:end]
return PaginatedRowsResult(
rows=rows,
total_count=len(rows),
next_page_token=str(end),
return IterrowsResponse(
data=rows,
next_start_index=end if end < len(dataset_impl) else None,
)
async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None:
dataset_info = self.dataset_infos.get(dataset_id)
if dataset_info is None:
raise ValueError(f"Dataset with id {dataset_id} not found")
dataset_impl = dataset_info.dataset_impl
dataset_def = self.dataset_infos[dataset_id]
dataset_impl = PandasDataframeDataset(dataset_def)
dataset_impl.load()
new_rows_df = pandas.DataFrame(rows)
new_rows_df = dataset_impl._validate_dataset_schema(new_rows_df)
dataset_impl.df = pandas.concat([dataset_impl.df, new_rows_df], ignore_index=True)
url = str(dataset_info.dataset_def.url.uri)
parsed_url = urlparse(url)
if parsed_url.scheme == "file" or not parsed_url.scheme:
file_path = parsed_url.path
os.makedirs(os.path.dirname(file_path), exist_ok=True)
dataset_impl.df.to_csv(file_path, index=False)
elif parsed_url.scheme == "data":
# For data URLs, we need to update the base64-encoded content
if not parsed_url.path.startswith("text/csv;base64,"):
raise ValueError("Data URL must be a base64-encoded CSV")
csv_buffer = dataset_impl.df.to_csv(index=False)
base64_content = base64.b64encode(csv_buffer.encode("utf-8")).decode("utf-8")
dataset_info.dataset_def.url = URL(uri=f"data:text/csv;base64,{base64_content}")
else:
raise ValueError(
f"Unsupported URL scheme: {parsed_url.scheme}. Only file:// and data: URLs are supported for writing."
)

View file

@ -14,16 +14,11 @@ from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Datasets
from llama_stack.apis.inference import Inference, SystemMessage, UserMessage
from llama_stack.apis.scoring import Scoring
from llama_stack.distribution.datatypes import Api
from llama_stack.providers.datatypes import BenchmarksProtocolPrivate
from llama_stack.providers.inline.agents.meta_reference.agent_instance import (
MEMORY_QUERY_TOOL,
)
from llama_stack.providers.utils.common.data_schema_validator import (
ColumnName,
get_valid_schemas,
validate_dataset_schema,
)
from llama_stack.providers.utils.common.data_schema_validator import ColumnName
from llama_stack.providers.utils.kvstore import kvstore_impl
from .....apis.common.job_types import Job
@ -88,15 +83,17 @@ class MetaReferenceEvalImpl(
task_def = self.benchmarks[benchmark_id]
dataset_id = task_def.dataset_id
scoring_functions = task_def.scoring_functions
dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
validate_dataset_schema(dataset_def.dataset_schema, get_valid_schemas(Api.eval.value))
all_rows = await self.datasetio_api.get_rows_paginated(
# TODO (xiyan): validate dataset schema
# dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
all_rows = await self.datasetio_api.iterrows(
dataset_id=dataset_id,
rows_in_page=(-1 if benchmark_config.num_examples is None else benchmark_config.num_examples),
limit=(-1 if benchmark_config.num_examples is None else benchmark_config.num_examples),
)
res = await self.evaluate_rows(
benchmark_id=benchmark_id,
input_rows=all_rows.rows,
input_rows=all_rows.data,
scoring_functions=scoring_functions,
benchmark_config=benchmark_config,
)

View file

@ -328,13 +328,13 @@ class LoraFinetuningSingleDevice:
batch_size: int,
) -> Tuple[DistributedSampler, DataLoader]:
async def fetch_rows(dataset_id: str):
return await self.datasetio_api.get_rows_paginated(
return await self.datasetio_api.iterrows(
dataset_id=dataset_id,
rows_in_page=-1,
limit=-1,
)
all_rows = await fetch_rows(dataset_id)
rows = all_rows.rows
rows = all_rows.data
await validate_input_dataset_schema(
datasets_api=self.datasets_api,

View file

@ -227,13 +227,6 @@ class LlamaGuardShield:
if len(messages) >= 2 and (messages[0].role == Role.user.value and messages[1].role == Role.user.value):
messages = messages[1:]
for i in range(1, len(messages)):
if messages[i].role == messages[i - 1].role:
for i, m in enumerate(messages):
print(f"{i}: {m.role}: {m.content}")
raise ValueError(
f"Messages must alternate between user and assistant. Message {i} has the same role as message {i - 1}"
)
return messages
async def run(self, messages: List[Message]) -> RunShieldResponse:

View file

@ -24,7 +24,9 @@ from llama_stack.providers.utils.common.data_schema_validator import (
from .config import BasicScoringConfig
from .scoring_fn.bfcl_scoring_fn import BFCLScoringFn
from .scoring_fn.equality_scoring_fn import EqualityScoringFn
from .scoring_fn.regex_parser_math_response_scoring_fn import RegexParserMathResponseScoringFn
from .scoring_fn.regex_parser_math_response_scoring_fn import (
RegexParserMathResponseScoringFn,
)
from .scoring_fn.regex_parser_scoring_fn import RegexParserScoringFn
from .scoring_fn.subset_of_scoring_fn import SubsetOfScoringFn
@ -82,12 +84,12 @@ class BasicScoringImpl(
dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
validate_dataset_schema(dataset_def.dataset_schema, get_valid_schemas(Api.scoring.value))
all_rows = await self.datasetio_api.get_rows_paginated(
all_rows = await self.datasetio_api.iterrows(
dataset_id=dataset_id,
rows_in_page=-1,
limit=-1,
)
res = await self.score(
input_rows=all_rows.rows,
input_rows=all_rows.data,
scoring_functions=scoring_functions,
)
if save_results_dataset:

View file

@ -167,11 +167,11 @@ class BraintrustScoringImpl(
dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
validate_dataset_schema(dataset_def.dataset_schema, get_valid_schemas(Api.scoring.value))
all_rows = await self.datasetio_api.get_rows_paginated(
all_rows = await self.datasetio_api.iterrows(
dataset_id=dataset_id,
rows_in_page=-1,
limit=-1,
)
res = await self.score(input_rows=all_rows.rows, scoring_functions=scoring_functions)
res = await self.score(input_rows=all_rows.data, scoring_functions=scoring_functions)
if save_results_dataset:
# TODO: persist and register dataset on to server for reading
# self.datasets_api.register_dataset()

View file

@ -72,12 +72,12 @@ class LlmAsJudgeScoringImpl(
dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
validate_dataset_schema(dataset_def.dataset_schema, get_valid_schemas(Api.scoring.value))
all_rows = await self.datasetio_api.get_rows_paginated(
all_rows = await self.datasetio_api.iterrows(
dataset_id=dataset_id,
rows_in_page=-1,
limit=-1,
)
res = await self.score(
input_rows=all_rows.rows,
input_rows=all_rows.data,
scoring_functions=scoring_functions,
)
if save_results_dataset:

View file

@ -55,4 +55,13 @@ def available_providers() -> List[ProviderSpec]:
config_class="llama_stack.providers.remote.safety.bedrock.BedrockSafetyConfig",
),
),
remote_provider_spec(
api=Api.safety,
adapter=AdapterSpec(
adapter_type="nvidia",
pip_packages=["requests"],
module="llama_stack.providers.remote.safety.nvidia",
config_class="llama_stack.providers.remote.safety.nvidia.NVIDIASafetyConfig",
),
),
]

View file

@ -4,13 +4,13 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict, List, Optional
from urllib.parse import parse_qs, urlparse
import datasets as hf_datasets
from llama_stack.apis.datasetio import DatasetIO, PaginatedRowsResult
from llama_stack.apis.datasetio import DatasetIO, IterrowsResponse
from llama_stack.apis.datasets import Dataset
from llama_stack.providers.datatypes import DatasetsProtocolPrivate
from llama_stack.providers.utils.datasetio.url_utils import get_dataframe_from_url
from llama_stack.providers.utils.kvstore import kvstore_impl
from .config import HuggingfaceDatasetIOConfig
@ -18,22 +18,14 @@ from .config import HuggingfaceDatasetIOConfig
DATASETS_PREFIX = "datasets:"
def load_hf_dataset(dataset_def: Dataset):
if dataset_def.metadata.get("path", None):
dataset = hf_datasets.load_dataset(**dataset_def.metadata)
else:
df = get_dataframe_from_url(dataset_def.url)
def parse_hf_params(dataset_def: Dataset):
uri = dataset_def.source.uri
parsed_uri = urlparse(uri)
params = parse_qs(parsed_uri.query)
params = {k: v[0] for k, v in params.items()}
path = parsed_uri.path.lstrip("/")
if df is None:
raise ValueError(f"Failed to load dataset from {dataset_def.url}")
dataset = hf_datasets.Dataset.from_pandas(df)
# drop columns not specified by schema
if dataset_def.dataset_schema:
dataset = dataset.select_columns(list(dataset_def.dataset_schema.keys()))
return dataset
return path, params
class HuggingfaceDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
@ -64,7 +56,7 @@ class HuggingfaceDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
key = f"{DATASETS_PREFIX}{dataset_def.identifier}"
await self.kvstore.set(
key=key,
value=dataset_def.json(),
value=dataset_def.model_dump_json(),
)
self.dataset_infos[dataset_def.identifier] = dataset_def
@ -73,41 +65,34 @@ class HuggingfaceDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
await self.kvstore.delete(key=key)
del self.dataset_infos[dataset_id]
async def get_rows_paginated(
async def iterrows(
self,
dataset_id: str,
rows_in_page: int,
page_token: Optional[str] = None,
filter_condition: Optional[str] = None,
) -> PaginatedRowsResult:
start_index: Optional[int] = None,
limit: Optional[int] = None,
) -> IterrowsResponse:
dataset_def = self.dataset_infos[dataset_id]
loaded_dataset = load_hf_dataset(dataset_def)
path, params = parse_hf_params(dataset_def)
loaded_dataset = hf_datasets.load_dataset(path, **params)
if page_token and not page_token.isnumeric():
raise ValueError("Invalid page_token")
start_index = start_index or 0
if page_token is None or len(page_token) == 0:
next_page_token = 0
else:
next_page_token = int(page_token)
start = next_page_token
if rows_in_page == -1:
if limit is None or limit == -1:
end = len(loaded_dataset)
else:
end = min(start + rows_in_page, len(loaded_dataset))
end = min(start_index + limit, len(loaded_dataset))
rows = [loaded_dataset[i] for i in range(start, end)]
rows = [loaded_dataset[i] for i in range(start_index, end)]
return PaginatedRowsResult(
rows=rows,
total_count=len(rows),
next_page_token=str(end),
return IterrowsResponse(
data=rows,
next_start_index=end if end < len(loaded_dataset) else None,
)
async def append_rows(self, dataset_id: str, rows: List[Dict[str, Any]]) -> None:
dataset_def = self.dataset_infos[dataset_id]
loaded_dataset = load_hf_dataset(dataset_def)
path, params = parse_hf_params(dataset_def)
loaded_dataset = hf_datasets.load_dataset(path, **params)
# Convert rows to HF Dataset format
new_dataset = hf_datasets.Dataset.from_list(rows)

View file

@ -12,6 +12,7 @@ from llama_stack.apis.common.content_types import InterleavedContent
from llama_stack.apis.inference import (
ChatCompletionResponse,
ChatCompletionResponseStreamChunk,
CompletionMessage,
EmbeddingsResponse,
EmbeddingTaskType,
Inference,
@ -160,12 +161,14 @@ class PassthroughInferenceAdapter(Inference):
client = self._get_client()
response = await client.inference.chat_completion(**json_params)
response = response.to_dict()
# temporary hack to remove the metrics from the response
response["metrics"] = []
return convert_to_pydantic(ChatCompletionResponse, response)
return ChatCompletionResponse(
completion_message=CompletionMessage(
content=response.completion_message.content.text,
stop_reason=response.completion_message.stop_reason,
tool_calls=response.completion_message.tool_calls,
),
logprobs=response.logprobs,
)
async def _stream_chat_completion(self, json_params: Dict[str, Any]) -> AsyncGenerator:
client = self._get_client()

View file

@ -25,6 +25,10 @@ class VLLMInferenceAdapterConfig(BaseModel):
default="fake",
description="The API token",
)
tls_verify: bool = Field(
default=True,
description="Whether to verify TLS certificates",
)
@classmethod
def sample_run_config(
@ -36,4 +40,5 @@ class VLLMInferenceAdapterConfig(BaseModel):
"url": url,
"max_tokens": "${env.VLLM_MAX_TOKENS:4096}",
"api_token": "${env.VLLM_API_TOKEN:fake}",
"tls_verify": "${env.VLLM_TLS_VERIFY:true}",
}

View file

@ -7,6 +7,7 @@ import json
import logging
from typing import AsyncGenerator, List, Optional, Union
import httpx
from openai import AsyncOpenAI
from openai.types.chat.chat_completion_chunk import (
ChatCompletionChunk as OpenAIChatCompletionChunk,
@ -229,7 +230,11 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
async def initialize(self) -> None:
log.info(f"Initializing VLLM client with base_url={self.config.url}")
self.client = AsyncOpenAI(base_url=self.config.url, api_key=self.config.api_token)
self.client = AsyncOpenAI(
base_url=self.config.url,
api_key=self.config.api_token,
http_client=None if self.config.tls_verify else httpx.AsyncClient(verify=False),
)
async def shutdown(self) -> None:
pass

View file

@ -0,0 +1,18 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any
from .config import NVIDIASafetyConfig
async def get_adapter_impl(config: NVIDIASafetyConfig, _deps) -> Any:
from .nvidia import NVIDIASafetyAdapter
impl = NVIDIASafetyAdapter(config)
await impl.initialize()
return impl

View file

@ -0,0 +1,37 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import os
from typing import Any, Dict, Optional
from pydantic import BaseModel, Field
from llama_stack.schema_utils import json_schema_type
@json_schema_type
class NVIDIASafetyConfig(BaseModel):
"""
Configuration for the NVIDIA Guardrail microservice endpoint.
Attributes:
guardrails_service_url (str): A base url for accessing the NVIDIA guardrail endpoint, e.g. http://0.0.0.0:7331
config_id (str): The ID of the guardrails configuration to use from the configuration store
(https://developer.nvidia.com/docs/nemo-microservices/guardrails/source/guides/configuration-store-guide.html)
"""
guardrails_service_url: str = Field(
default_factory=lambda: os.getenv("GUARDRAILS_SERVICE_URL", "http://0.0.0.0:7331"),
description="The url for accessing the guardrails service",
)
config_id: Optional[str] = Field(default="self-check", description="Config ID to use from the config store")
@classmethod
def sample_run_config(cls, **kwargs) -> Dict[str, Any]:
return {
"guardrails_service_url": "${env.GUARDRAILS_SERVICE_URL:http://localhost:7331}",
"config_id": "self-check",
}

View file

@ -0,0 +1,154 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import logging
from typing import Any, List, Optional
import requests
from llama_stack.apis.inference import Message
from llama_stack.apis.safety import RunShieldResponse, Safety, SafetyViolation, ViolationLevel
from llama_stack.apis.shields import Shield
from llama_stack.distribution.library_client import convert_pydantic_to_json_value
from llama_stack.providers.datatypes import ShieldsProtocolPrivate
from .config import NVIDIASafetyConfig
logger = logging.getLogger(__name__)
class NVIDIASafetyAdapter(Safety, ShieldsProtocolPrivate):
def __init__(self, config: NVIDIASafetyConfig) -> None:
"""
Initialize the NVIDIASafetyAdapter with a given safety configuration.
Args:
config (NVIDIASafetyConfig): The configuration containing the guardrails service URL and config ID.
"""
print(f"Initializing NVIDIASafetyAdapter({config.guardrails_service_url})...")
self.config = config
async def initialize(self) -> None:
pass
async def shutdown(self) -> None:
pass
async def register_shield(self, shield: Shield) -> None:
if not shield.provider_resource_id:
raise ValueError("Shield model not provided.")
async def run_shield(
self, shield_id: str, messages: List[Message], params: Optional[dict[str, Any]] = None
) -> RunShieldResponse:
"""
Run a safety shield check against the provided messages.
Args:
shield_id (str): The unique identifier for the shield to be used.
messages (List[Message]): A list of Message objects representing the conversation history.
params (Optional[dict[str, Any]]): Additional parameters for the shield check.
Returns:
RunShieldResponse: The response containing safety violation details if any.
Raises:
ValueError: If the shield with the provided shield_id is not found.
"""
shield = await self.shield_store.get_shield(shield_id)
if not shield:
raise ValueError(f"Shield {shield_id} not found")
self.shield = NeMoGuardrails(self.config, shield.shield_id)
return await self.shield.run(messages)
class NeMoGuardrails:
"""
A class that encapsulates NVIDIA's guardrails safety logic.
Sends messages to the guardrails service and interprets the response to determine
if a safety violation has occurred.
"""
def __init__(
self,
config: NVIDIASafetyConfig,
model: str,
threshold: float = 0.9,
temperature: float = 1.0,
):
"""
Initialize a NeMoGuardrails instance with the provided parameters.
Args:
config (NVIDIASafetyConfig): The safety configuration containing the config ID and guardrails URL.
model (str): The identifier or name of the model to be used for safety checks.
threshold (float, optional): The threshold for flagging violations. Defaults to 0.9.
temperature (float, optional): The temperature setting for the underlying model. Must be greater than 0. Defaults to 1.0.
Raises:
ValueError: If temperature is less than or equal to 0.
AssertionError: If config_id is not provided in the configuration.
"""
self.config_id = config.config_id
self.model = model
assert self.config_id is not None, "Must provide config id"
if temperature <= 0:
raise ValueError("Temperature must be greater than 0")
self.temperature = temperature
self.threshold = threshold
self.guardrails_service_url = config.guardrails_service_url
async def run(self, messages: List[Message]) -> RunShieldResponse:
"""
Queries the /v1/guardrails/checks endpoint of the NeMo guardrails deployed API.
Args:
messages (List[Message]): A list of Message objects to be checked for safety violations.
Returns:
RunShieldResponse: If the response indicates a violation ("blocked" status), returns a
RunShieldResponse with a SafetyViolation; otherwise, returns a RunShieldResponse with violation set to None.
Raises:
requests.HTTPError: If the POST request fails.
"""
headers = {
"Accept": "application/json",
}
request_data = {
"model": self.model,
"messages": convert_pydantic_to_json_value(messages),
"temperature": self.temperature,
"top_p": 1,
"frequency_penalty": 0,
"presence_penalty": 0,
"max_tokens": 160,
"stream": False,
"guardrails": {
"config_id": self.config_id,
},
}
response = requests.post(
url=f"{self.guardrails_service_url}/v1/guardrail/checks", headers=headers, json=request_data
)
response.raise_for_status()
if "Content-Type" in response.headers and response.headers["Content-Type"].startswith("application/json"):
response_json = response.json()
if response_json["status"] == "blocked":
user_message = "Sorry I cannot do this."
metadata = response_json["rails_status"]
return RunShieldResponse(
violation=SafetyViolation(
user_message=user_message,
violation_level=ViolationLevel.ERROR,
metadata=metadata,
)
)
return RunShieldResponse(violation=None)

View file

@ -10,18 +10,17 @@ from urllib.parse import unquote
import pandas
from llama_stack.apis.common.content_types import URL
from llama_stack.providers.utils.memory.vector_store import parse_data_url
def get_dataframe_from_url(url: URL):
def get_dataframe_from_uri(uri: str):
df = None
if url.uri.endswith(".csv"):
df = pandas.read_csv(url.uri)
elif url.uri.endswith(".xlsx"):
df = pandas.read_excel(url.uri)
elif url.uri.startswith("data:"):
parts = parse_data_url(url.uri)
if uri.endswith(".csv"):
df = pandas.read_csv(uri)
elif uri.endswith(".xlsx"):
df = pandas.read_excel(uri)
elif uri.startswith("data:"):
parts = parse_data_url(uri)
data = parts["data"]
if parts["is_base64"]:
data = base64.b64decode(data)
@ -39,6 +38,6 @@ def get_dataframe_from_url(url: URL):
else:
df = pandas.read_excel(data_bytes)
else:
raise ValueError(f"Unsupported file type: {url}")
raise ValueError(f"Unsupported file type: {uri}")
return df

View file

@ -192,7 +192,11 @@ class LiteLLMOpenAIMixin(
if request.tools:
input_dict["tools"] = [convert_tooldef_to_openai_tool(tool) for tool in request.tools]
if request.tool_config.tool_choice:
input_dict["tool_choice"] = request.tool_config.tool_choice.value
input_dict["tool_choice"] = (
request.tool_config.tool_choice.value
if isinstance(request.tool_config.tool_choice, ToolChoice)
else request.tool_config.tool_choice
)
provider_data = self.get_request_provider_data()
key_field = self.provider_data_api_key_field

View file

@ -527,27 +527,31 @@ async def convert_message_to_openai_dict_new(
async def _convert_message_content(
content: InterleavedContent,
) -> Union[str, Iterable[OpenAIChatCompletionContentPartParam]]:
async def impl():
async def impl(
content_: InterleavedContent,
) -> Union[str, OpenAIChatCompletionContentPartParam, List[OpenAIChatCompletionContentPartParam]]:
# Llama Stack and OpenAI spec match for str and text input
if isinstance(content, str):
return content
elif isinstance(content, TextContentItem):
if isinstance(content_, str):
return content_
elif isinstance(content_, TextContentItem):
return OpenAIChatCompletionContentPartTextParam(
type="text",
text=content.text,
text=content_.text,
)
elif isinstance(content, ImageContentItem):
elif isinstance(content_, ImageContentItem):
return OpenAIChatCompletionContentPartImageParam(
type="image_url",
image_url=OpenAIImageURL(url=await convert_image_content_to_url(content)),
image_url=OpenAIImageURL(url=await convert_image_content_to_url(content_)),
)
elif isinstance(content, list):
return [await _convert_message_content(item) for item in content]
elif isinstance(content_, list):
return [await impl(item) for item in content_]
else:
raise ValueError(f"Unsupported content type: {type(content)}")
raise ValueError(f"Unsupported content type: {type(content_)}")
ret = await impl()
if isinstance(ret, str) or isinstance(ret, list) or isinstance(ret, dict):
ret = await impl(content)
# OpenAI*Message expects a str or list
if isinstance(ret, str) or isinstance(ret, list):
return ret
else:
return [ret]
@ -566,13 +570,14 @@ async def convert_message_to_openai_dict_new(
OpenAIChatCompletionMessageToolCall(
id=tool.call_id,
function=OpenAIFunction(
name=tool.tool_name,
name=tool.tool_name if not isinstance(tool.tool_name, BuiltinTool) else tool.tool_name.value,
arguments=json.dumps(tool.arguments),
),
type="function",
)
for tool in message.tool_calls
],
]
or None,
)
elif isinstance(message, ToolResponseMessage):
out = OpenAIChatCompletionToolMessage(
@ -858,7 +863,8 @@ async def convert_openai_chat_completion_stream(
event_type = ChatCompletionResponseEventType.progress
stop_reason = None
toolcall_buffer = {}
tool_call_idx_to_buffer = {}
async for chunk in stream:
choice = chunk.choices[0] # assuming only one choice per chunk
@ -868,7 +874,6 @@ async def convert_openai_chat_completion_stream(
# if there's a tool call, emit an event for each tool in the list
# if tool call and content, emit both separately
if choice.delta.tool_calls:
# the call may have content and a tool call. ChatCompletionResponseEvent
# does not support both, so we emit the content first
@ -889,44 +894,53 @@ async def convert_openai_chat_completion_stream(
)
if not enable_incremental_tool_calls:
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=next(event_type),
delta=ToolCallDelta(
tool_call=_convert_openai_tool_calls(choice.delta.tool_calls)[0],
parse_status=ToolCallParseStatus.succeeded,
),
logprobs=_convert_openai_logprobs(logprobs),
for tool_call in choice.delta.tool_calls:
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=event_type,
delta=ToolCallDelta(
tool_call=_convert_openai_tool_calls([tool_call])[0],
parse_status=ToolCallParseStatus.succeeded,
),
logprobs=_convert_openai_logprobs(logprobs),
)
)
)
else:
tool_call = choice.delta.tool_calls[0]
if "name" not in toolcall_buffer:
toolcall_buffer["call_id"] = tool_call.id
toolcall_buffer["name"] = None
toolcall_buffer["content"] = ""
if "arguments" not in toolcall_buffer:
toolcall_buffer["arguments"] = ""
for tool_call in choice.delta.tool_calls:
idx = tool_call.index if hasattr(tool_call, "index") else 0
if tool_call.function.name:
toolcall_buffer["name"] = tool_call.function.name
delta = f"{toolcall_buffer['name']}("
if tool_call.function.arguments:
toolcall_buffer["arguments"] += tool_call.function.arguments
delta = toolcall_buffer["arguments"]
if idx not in tool_call_idx_to_buffer:
tool_call_idx_to_buffer[idx] = {
"call_id": tool_call.id,
"name": None,
"arguments": "",
"content": "",
}
toolcall_buffer["content"] += delta
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=event_type,
delta=ToolCallDelta(
tool_call=delta,
parse_status=ToolCallParseStatus.in_progress,
),
logprobs=_convert_openai_logprobs(logprobs),
)
)
else:
buffer = tool_call_idx_to_buffer[idx]
if tool_call.function:
if tool_call.function.name:
buffer["name"] = tool_call.function.name
delta = f"{buffer['name']}("
buffer["content"] += delta
if tool_call.function.arguments:
delta = tool_call.function.arguments
buffer["arguments"] += delta
buffer["content"] += delta
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=event_type,
delta=ToolCallDelta(
tool_call=delta,
parse_status=ToolCallParseStatus.in_progress,
),
logprobs=_convert_openai_logprobs(logprobs),
)
)
elif choice.delta.content:
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=event_type,
@ -935,47 +949,51 @@ async def convert_openai_chat_completion_stream(
)
)
if toolcall_buffer:
delta = ")"
toolcall_buffer["content"] += delta
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=event_type,
delta=ToolCallDelta(
tool_call=delta,
parse_status=ToolCallParseStatus.in_progress,
),
logprobs=_convert_openai_logprobs(logprobs),
)
)
try:
arguments = json.loads(toolcall_buffer["arguments"])
tool_call = ToolCall(
call_id=toolcall_buffer["call_id"],
tool_name=toolcall_buffer["name"],
arguments=arguments,
)
for idx, buffer in tool_call_idx_to_buffer.items():
logger.debug(f"toolcall_buffer[{idx}]: {buffer}")
if buffer["name"]:
delta = ")"
buffer["content"] += delta
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
event_type=event_type,
delta=ToolCallDelta(
tool_call=tool_call,
parse_status=ToolCallParseStatus.succeeded,
tool_call=delta,
parse_status=ToolCallParseStatus.in_progress,
),
stop_reason=stop_reason,
logprobs=None,
)
)
except json.JSONDecodeError:
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.complete,
delta=ToolCallDelta(
tool_call=toolcall_buffer["content"],
parse_status=ToolCallParseStatus.failed,
),
stop_reason=stop_reason,
try:
arguments = json.loads(buffer["arguments"])
tool_call = ToolCall(
call_id=buffer["call_id"],
tool_name=buffer["name"],
arguments=arguments,
)
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=ToolCallDelta(
tool_call=tool_call,
parse_status=ToolCallParseStatus.succeeded,
),
stop_reason=stop_reason,
)
)
except json.JSONDecodeError as e:
print(f"Failed to parse arguments: {e}")
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.progress,
delta=ToolCallDelta(
tool_call=buffer["content"],
parse_status=ToolCallParseStatus.failed,
),
stop_reason=stop_reason,
)
)
)
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(

View file

@ -1,5 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -1,158 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import concurrent.futures
import importlib
import json
import subprocess
import sys
from functools import partial
from pathlib import Path
from typing import Iterable
from rich.progress import Progress, SpinnerColumn, TextColumn
from llama_stack.distribution.build import (
SERVER_DEPENDENCIES,
get_provider_dependencies,
)
REPO_ROOT = Path(__file__).parent.parent.parent
class ChangedPathTracker:
"""Track a list of paths we may have changed."""
def __init__(self):
self._changed_paths = []
def add_paths(self, *paths):
for path in paths:
path = str(path)
if path not in self._changed_paths:
self._changed_paths.append(path)
def changed_paths(self):
return self._changed_paths
def find_template_dirs(templates_dir: Path) -> Iterable[Path]:
"""Find immediate subdirectories in the templates folder."""
if not templates_dir.exists():
raise FileNotFoundError(f"Templates directory not found: {templates_dir}")
return sorted(d for d in templates_dir.iterdir() if d.is_dir() and d.name != "__pycache__")
def process_template(template_dir: Path, progress, change_tracker: ChangedPathTracker) -> None:
"""Process a single template directory."""
progress.print(f"Processing {template_dir.name}")
try:
# Import the module directly
module_name = f"llama_stack.templates.{template_dir.name}"
module = importlib.import_module(module_name)
# Get and save the distribution template
if template_func := getattr(module, "get_distribution_template", None):
template = template_func()
yaml_output_dir = REPO_ROOT / "llama_stack" / "templates" / template.name
doc_output_dir = REPO_ROOT / "docs/source/distributions" / f"{template.distro_type}_distro"
change_tracker.add_paths(yaml_output_dir, doc_output_dir)
template.save_distribution(
yaml_output_dir=yaml_output_dir,
doc_output_dir=doc_output_dir,
)
else:
progress.print(f"[yellow]Warning: {template_dir.name} has no get_distribution_template function")
except Exception as e:
progress.print(f"[red]Error processing {template_dir.name}: {str(e)}")
raise e
def check_for_changes(change_tracker: ChangedPathTracker) -> bool:
"""Check if there are any uncommitted changes."""
has_changes = False
for path in change_tracker.changed_paths():
result = subprocess.run(
["git", "diff", "--exit-code", path],
cwd=REPO_ROOT,
capture_output=True,
)
if result.returncode != 0:
print(f"Change detected in '{path}'.", file=sys.stderr)
has_changes = True
return has_changes
def collect_template_dependencies(template_dir: Path) -> tuple[str | None, list[str]]:
try:
module_name = f"llama_stack.templates.{template_dir.name}"
module = importlib.import_module(module_name)
if template_func := getattr(module, "get_distribution_template", None):
template = template_func()
normal_deps, special_deps = get_provider_dependencies(template.providers)
# Combine all dependencies in order: normal deps, special deps, server deps
all_deps = sorted(set(normal_deps + SERVER_DEPENDENCIES)) + sorted(set(special_deps))
return template.name, all_deps
except Exception:
return None, []
return None, []
def generate_dependencies_file(change_tracker: ChangedPathTracker):
templates_dir = REPO_ROOT / "llama_stack" / "templates"
distribution_deps = {}
for template_dir in find_template_dirs(templates_dir):
name, deps = collect_template_dependencies(template_dir)
if name:
distribution_deps[name] = deps
deps_file = REPO_ROOT / "distributions" / "dependencies.json"
change_tracker.add_paths(deps_file)
with open(deps_file, "w") as f:
f.write(json.dumps(distribution_deps, indent=2) + "\n")
def main():
templates_dir = REPO_ROOT / "llama_stack" / "templates"
change_tracker = ChangedPathTracker()
with Progress(
SpinnerColumn(),
TextColumn("[progress.description]{task.description}"),
) as progress:
template_dirs = list(find_template_dirs(templates_dir))
task = progress.add_task("Processing distribution templates...", total=len(template_dirs))
# Create a partial function with the progress bar
process_func = partial(process_template, progress=progress, change_tracker=change_tracker)
# Process templates in parallel
with concurrent.futures.ThreadPoolExecutor() as executor:
# Submit all tasks and wait for completion
list(executor.map(process_func, template_dirs))
progress.update(task, advance=len(template_dirs))
generate_dependencies_file(change_tracker)
if check_for_changes(change_tracker):
print(
"Distribution template changes detected. Please commit the changes.",
file=sys.stderr,
)
sys.exit(1)
sys.exit(0)
if __name__ == "__main__":
main()

View file

@ -1,72 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.
import importlib
from pathlib import Path
import fire
from llama_stack.models.llama.sku_list import resolve_model
from llama_stack.providers.inline.inference.meta_reference.config import MetaReferenceInferenceConfig
from llama_stack.providers.inline.inference.meta_reference.generation import Llama
THIS_DIR = Path(__file__).parent.resolve()
def run_main(
model_id: str,
checkpoint_dir: str,
module_name: str,
output_path: str,
):
module = importlib.import_module(module_name)
assert hasattr(module, "usecases"), f"Module {module_name} missing usecases function"
config = MetaReferenceInferenceConfig(
model=model_id,
max_seq_len=512,
max_batch_size=1,
checkpoint_dir=checkpoint_dir,
)
llama_model = resolve_model(model_id)
if not llama_model:
raise ValueError(f"Model {model_id} not found")
generator = Llama.build(
config=config,
model_id=model_id,
llama_model=llama_model,
)
use_cases = module.usecases()
text = ""
for u in use_cases:
if isinstance(u, str):
use_case_text = f"\n{u}\n"
else:
use_case_text = u.to_text(generator)
text += use_case_text
print(use_case_text)
text += "Thank You!\n"
with open(output_path, "w") as f:
f.write(text)
def main():
fire.Fire(run_main)
if __name__ == "__main__":
main()

View file

@ -1,66 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import argparse
import os
from pathlib import Path
import pytest
"""
Script for running api on AsyncLlamaStackAsLibraryClient with templates
Assuming directory structure:
- llama-stack
- llama_stack
- scripts
- tests
- api
Example command:
cd llama-stack
EXPORT TOGETHER_API_KEY=<..>
EXPORT FIREWORKS_API_KEY=<..>
python llama_stack/scripts/run_client_sdk_tests.py --templates together fireworks --report
"""
REPO_ROOT = Path(__file__).parent.parent.parent
CLIENT_SDK_TESTS_RELATIVE_PATH = "tests/api/"
def main(parser: argparse.ArgumentParser):
args = parser.parse_args()
templates_dir = REPO_ROOT / "llama_stack" / "templates"
user_specified_templates = [templates_dir / t for t in args.templates] if args.templates else []
for d in templates_dir.iterdir():
if d.is_dir() and d.name != "__pycache__":
template_configs = list(d.rglob("run.yaml"))
if len(template_configs) == 0:
continue
config = template_configs[0]
if user_specified_templates:
if not any(config.parent == t for t in user_specified_templates):
continue
os.environ["LLAMA_STACK_CONFIG"] = str(config)
pytest_args = "--report" if args.report else ""
pytest.main(
[
pytest_args,
"-s",
"-v",
str(REPO_ROOT / CLIENT_SDK_TESTS_RELATIVE_PATH),
]
)
if __name__ == "__main__":
parser = argparse.ArgumentParser(
prog="llama_test",
)
parser.add_argument("--templates", nargs="+")
parser.add_argument("--report", action="store_true")
main(parser)

View file

@ -1,15 +0,0 @@
#!/bin/bash
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
THIS_DIR="$(cd "$(dirname "$(readlink -f "${BASH_SOURCE[0]}")")" && pwd)"
set -euo pipefail
set -x
stack_dir=$(dirname $(dirname $THIS_DIR))
PYTHONPATH=$stack_dir pytest -p no:warnings --asyncio-mode auto --tb=short

View file

@ -1,13 +1,13 @@
version: '2'
distribution_spec:
description: Use NVIDIA NIM for running LLM inference
description: Use NVIDIA NIM for running LLM inference and safety
providers:
inference:
- remote::nvidia
vector_io:
- inline::faiss
safety:
- inline::llama-guard
- remote::nvidia
agents:
- inline::meta-reference
telemetry:
@ -15,16 +15,9 @@ distribution_spec:
eval:
- inline::meta-reference
datasetio:
- remote::huggingface
- inline::localfs
scoring:
- inline::basic
- inline::llm-as-judge
- inline::braintrust
tool_runtime:
- remote::brave-search
- remote::tavily-search
- inline::code-interpreter
- inline::rag-runtime
- remote::model-context-protocol
image_type: conda

View file

@ -6,9 +6,10 @@
from pathlib import Path
from llama_stack.distribution.datatypes import Provider, ToolGroupInput
from llama_stack.distribution.datatypes import ModelInput, Provider, ShieldInput, ToolGroupInput
from llama_stack.providers.remote.inference.nvidia import NVIDIAConfig
from llama_stack.providers.remote.inference.nvidia.models import MODEL_ENTRIES
from llama_stack.providers.remote.safety.nvidia import NVIDIASafetyConfig
from llama_stack.templates.template import DistributionTemplate, RunConfigSettings, get_model_registry
@ -16,19 +17,13 @@ def get_distribution_template() -> DistributionTemplate:
providers = {
"inference": ["remote::nvidia"],
"vector_io": ["inline::faiss"],
"safety": ["inline::llama-guard"],
"safety": ["remote::nvidia"],
"agents": ["inline::meta-reference"],
"telemetry": ["inline::meta-reference"],
"eval": ["inline::meta-reference"],
"datasetio": ["remote::huggingface", "inline::localfs"],
"scoring": ["inline::basic", "inline::llm-as-judge", "inline::braintrust"],
"tool_runtime": [
"remote::brave-search",
"remote::tavily-search",
"inline::code-interpreter",
"inline::rag-runtime",
"remote::model-context-protocol",
],
"datasetio": ["inline::localfs"],
"scoring": ["inline::basic"],
"tool_runtime": ["inline::rag-runtime"],
}
inference_provider = Provider(
@ -36,30 +31,35 @@ def get_distribution_template() -> DistributionTemplate:
provider_type="remote::nvidia",
config=NVIDIAConfig.sample_run_config(),
)
safety_provider = Provider(
provider_id="nvidia",
provider_type="remote::nvidia",
config=NVIDIASafetyConfig.sample_run_config(),
)
inference_model = ModelInput(
model_id="${env.INFERENCE_MODEL}",
provider_id="nvidia",
)
safety_model = ModelInput(
model_id="${env.SAFETY_MODEL}",
provider_id="nvidia",
)
available_models = {
"nvidia": MODEL_ENTRIES,
}
default_tool_groups = [
ToolGroupInput(
toolgroup_id="builtin::websearch",
provider_id="tavily-search",
),
ToolGroupInput(
toolgroup_id="builtin::rag",
provider_id="rag-runtime",
),
ToolGroupInput(
toolgroup_id="builtin::code_interpreter",
provider_id="code-interpreter",
),
]
default_models = get_model_registry(available_models)
return DistributionTemplate(
name="nvidia",
distro_type="remote_hosted",
description="Use NVIDIA NIM for running LLM inference",
description="Use NVIDIA NIM for running LLM inference and safety",
container_image=None,
template_path=Path(__file__).parent / "doc_template.md",
providers=providers,
@ -72,15 +72,34 @@ def get_distribution_template() -> DistributionTemplate:
default_models=default_models,
default_tool_groups=default_tool_groups,
),
"run-with-safety.yaml": RunConfigSettings(
provider_overrides={
"inference": [
inference_provider,
safety_provider,
]
},
default_models=[inference_model, safety_model],
default_shields=[ShieldInput(shield_id="${env.SAFETY_MODEL}", provider_id="nvidia")],
default_tool_groups=default_tool_groups,
),
},
run_config_env_vars={
"LLAMASTACK_PORT": (
"5001",
"Port for the Llama Stack distribution server",
),
"NVIDIA_API_KEY": (
"",
"NVIDIA API Key",
),
"GUARDRAILS_SERVICE_URL": (
"http://0.0.0.0:7331",
"URL for the NeMo Guardrails Service",
),
"INFERENCE_MODEL": (
"Llama3.1-8B-Instruct",
"Inference model",
),
"SAFETY_MODEL": (
"meta/llama-3.1-8b-instruct",
"Name of the model to use for safety",
),
},
)

View file

@ -0,0 +1,101 @@
version: '2'
image_name: nvidia
apis:
- agents
- datasetio
- eval
- inference
- safety
- scoring
- telemetry
- tool_runtime
- vector_io
providers:
inference:
- provider_id: nvidia
provider_type: remote::nvidia
config:
url: ${env.NVIDIA_BASE_URL:https://integrate.api.nvidia.com}
api_key: ${env.NVIDIA_API_KEY:}
- provider_id: nvidia
provider_type: remote::nvidia
config:
guardrails_service_url: ${env.GUARDRAILS_SERVICE_URL:http://localhost:7331}
config_id: self-check
vector_io:
- provider_id: faiss
provider_type: inline::faiss
config:
kvstore:
type: sqlite
namespace: null
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/nvidia}/faiss_store.db
safety:
- provider_id: nvidia
provider_type: remote::nvidia
config:
guardrails_service_url: ${env.GUARDRAILS_SERVICE_URL:http://localhost:7331}
config_id: self-check
agents:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
persistence_store:
type: sqlite
namespace: null
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/nvidia}/agents_store.db
telemetry:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
service_name: ${env.OTEL_SERVICE_NAME:llama-stack}
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/nvidia/trace_store.db}
eval:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
kvstore:
type: sqlite
namespace: null
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/nvidia}/meta_reference_eval.db
datasetio:
- provider_id: localfs
provider_type: inline::localfs
config:
kvstore:
type: sqlite
namespace: null
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/nvidia}/localfs_datasetio.db
scoring:
- provider_id: basic
provider_type: inline::basic
config: {}
tool_runtime:
- provider_id: rag-runtime
provider_type: inline::rag-runtime
config: {}
metadata_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/nvidia}/registry.db
models:
- metadata: {}
model_id: ${env.INFERENCE_MODEL}
provider_id: nvidia
model_type: llm
- metadata: {}
model_id: ${env.SAFETY_MODEL}
provider_id: nvidia
model_type: llm
shields:
- shield_id: ${env.SAFETY_MODEL}
provider_id: nvidia
vector_dbs: []
datasets: []
scoring_fns: []
benchmarks: []
tool_groups:
- toolgroup_id: builtin::rag
provider_id: rag-runtime
server:
port: 8321

View file

@ -26,10 +26,11 @@ providers:
namespace: null
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/nvidia}/faiss_store.db
safety:
- provider_id: llama-guard
provider_type: inline::llama-guard
- provider_id: nvidia
provider_type: remote::nvidia
config:
excluded_categories: []
guardrails_service_url: ${env.GUARDRAILS_SERVICE_URL:http://localhost:7331}
config_id: self-check
agents:
- provider_id: meta-reference
provider_type: inline::meta-reference
@ -54,13 +55,6 @@ providers:
namespace: null
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/nvidia}/meta_reference_eval.db
datasetio:
- provider_id: huggingface
provider_type: remote::huggingface
config:
kvstore:
type: sqlite
namespace: null
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/nvidia}/huggingface_datasetio.db
- provider_id: localfs
provider_type: inline::localfs
config:
@ -72,33 +66,10 @@ providers:
- provider_id: basic
provider_type: inline::basic
config: {}
- provider_id: llm-as-judge
provider_type: inline::llm-as-judge
config: {}
- provider_id: braintrust
provider_type: inline::braintrust
config:
openai_api_key: ${env.OPENAI_API_KEY:}
tool_runtime:
- provider_id: brave-search
provider_type: remote::brave-search
config:
api_key: ${env.BRAVE_SEARCH_API_KEY:}
max_results: 3
- provider_id: tavily-search
provider_type: remote::tavily-search
config:
api_key: ${env.TAVILY_SEARCH_API_KEY:}
max_results: 3
- provider_id: code-interpreter
provider_type: inline::code-interpreter
config: {}
- provider_id: rag-runtime
provider_type: inline::rag-runtime
config: {}
- provider_id: model-context-protocol
provider_type: remote::model-context-protocol
config: {}
metadata_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/nvidia}/registry.db
@ -227,11 +198,7 @@ datasets: []
scoring_fns: []
benchmarks: []
tool_groups:
- toolgroup_id: builtin::websearch
provider_id: tavily-search
- toolgroup_id: builtin::rag
provider_id: rag-runtime
- toolgroup_id: builtin::code_interpreter
provider_id: code-interpreter
server:
port: 8321

View file

@ -6,7 +6,7 @@
from typing import Dict, List, Tuple
from llama_stack.apis.common.content_types import URL
from llama_stack.apis.datasets import DatasetPurpose, URIDataSource
from llama_stack.apis.models.models import ModelType
from llama_stack.distribution.datatypes import (
BenchmarkInput,
@ -171,76 +171,42 @@ def get_distribution_template() -> DistributionTemplate:
DatasetInput(
dataset_id="simpleqa",
provider_id="huggingface",
url=URL(uri="https://huggingface.co/datasets/llamastack/simpleqa"),
metadata={
"path": "llamastack/simpleqa",
"split": "train",
},
dataset_schema={
"input_query": {"type": "string"},
"expected_answer": {"type": "string"},
"chat_completion_input": {"type": "string"},
},
purpose=DatasetPurpose.eval_messages_answer,
source=URIDataSource(
uri="huggingface://datasets/llamastack/simpleqa?split=train",
),
),
DatasetInput(
dataset_id="mmlu_cot",
provider_id="huggingface",
url=URL(uri="https://huggingface.co/datasets/llamastack/mmlu_cot"),
metadata={
"path": "llamastack/mmlu_cot",
"name": "all",
"split": "test",
},
dataset_schema={
"input_query": {"type": "string"},
"expected_answer": {"type": "string"},
"chat_completion_input": {"type": "string"},
},
purpose=DatasetPurpose.eval_messages_answer,
source=URIDataSource(
uri="huggingface://datasets/llamastack/mmlu_cot?split=test&name=all",
),
),
DatasetInput(
dataset_id="gpqa_cot",
provider_id="huggingface",
url=URL(uri="https://huggingface.co/datasets/llamastack/gpqa_0shot_cot"),
metadata={
"path": "llamastack/gpqa_0shot_cot",
"name": "gpqa_main",
"split": "train",
},
dataset_schema={
"input_query": {"type": "string"},
"expected_answer": {"type": "string"},
"chat_completion_input": {"type": "string"},
},
purpose=DatasetPurpose.eval_messages_answer,
source=URIDataSource(
uri="huggingface://datasets/llamastack/gpqa_0shot_cot?split=test&name=gpqa_main",
),
),
DatasetInput(
dataset_id="math_500",
provider_id="huggingface",
url=URL(uri="https://huggingface.co/datasets/llamastack/math_500"),
metadata={
"path": "llamastack/math_500",
"split": "test",
},
dataset_schema={
"input_query": {"type": "string"},
"expected_answer": {"type": "string"},
"chat_completion_input": {"type": "string"},
},
purpose=DatasetPurpose.eval_messages_answer,
source=URIDataSource(
uri="huggingface://datasets/llamastack/math_500?split=test",
),
),
DatasetInput(
dataset_id="bfcl",
provider_id="huggingface",
url=URL(uri="https://huggingface.co/datasets/llamastack/bfcl_v3"),
metadata={
"path": "llamastack/bfcl_v3",
"split": "train",
},
dataset_schema={
"function": {"type": "string"},
"language": {"type": "string"},
"ground_truth": {"type": "string"},
"id": {"type": "string"},
"chat_completion_input": {"type": "string"},
},
purpose=DatasetPurpose.eval_messages_answer,
source=URIDataSource(
uri="huggingface://datasets/llamastack/bfcl_v3?split=train",
),
),
]

View file

@ -158,80 +158,39 @@ shields:
- shield_id: meta-llama/Llama-Guard-3-8B
vector_dbs: []
datasets:
- dataset_schema:
input_query:
type: string
expected_answer:
type: string
chat_completion_input:
type: string
url:
uri: https://huggingface.co/datasets/llamastack/simpleqa
metadata:
path: llamastack/simpleqa
split: train
- purpose: eval/messages-answer
source:
type: uri
uri: huggingface://datasets/llamastack/simpleqa?split=train
metadata: {}
dataset_id: simpleqa
provider_id: huggingface
- dataset_schema:
input_query:
type: string
expected_answer:
type: string
chat_completion_input:
type: string
url:
uri: https://huggingface.co/datasets/llamastack/mmlu_cot
metadata:
path: llamastack/mmlu_cot
name: all
split: test
- purpose: eval/messages-answer
source:
type: uri
uri: huggingface://datasets/llamastack/mmlu_cot?split=test&name=all
metadata: {}
dataset_id: mmlu_cot
provider_id: huggingface
- dataset_schema:
input_query:
type: string
expected_answer:
type: string
chat_completion_input:
type: string
url:
uri: https://huggingface.co/datasets/llamastack/gpqa_0shot_cot
metadata:
path: llamastack/gpqa_0shot_cot
name: gpqa_main
split: train
- purpose: eval/messages-answer
source:
type: uri
uri: huggingface://datasets/llamastack/gpqa_0shot_cot?split=test&name=gpqa_main
metadata: {}
dataset_id: gpqa_cot
provider_id: huggingface
- dataset_schema:
input_query:
type: string
expected_answer:
type: string
chat_completion_input:
type: string
url:
uri: https://huggingface.co/datasets/llamastack/math_500
metadata:
path: llamastack/math_500
split: test
- purpose: eval/messages-answer
source:
type: uri
uri: huggingface://datasets/llamastack/math_500?split=test
metadata: {}
dataset_id: math_500
provider_id: huggingface
- dataset_schema:
function:
type: string
language:
type: string
ground_truth:
type: string
id:
type: string
chat_completion_input:
type: string
url:
uri: https://huggingface.co/datasets/llamastack/bfcl_v3
metadata:
path: llamastack/bfcl_v3
split: train
- purpose: eval/messages-answer
source:
type: uri
uri: huggingface://datasets/llamastack/bfcl_v3?split=train
metadata: {}
dataset_id: bfcl
provider_id: huggingface
scoring_fns: []

View file

@ -18,12 +18,14 @@ providers:
url: ${env.VLLM_URL}
max_tokens: ${env.VLLM_MAX_TOKENS:4096}
api_token: ${env.VLLM_API_TOKEN:fake}
tls_verify: ${env.VLLM_TLS_VERIFY:true}
- provider_id: vllm-safety
provider_type: remote::vllm
config:
url: ${env.SAFETY_VLLM_URL}
max_tokens: ${env.VLLM_MAX_TOKENS:4096}
api_token: ${env.VLLM_API_TOKEN:fake}
tls_verify: ${env.VLLM_TLS_VERIFY:true}
- provider_id: sentence-transformers
provider_type: inline::sentence-transformers
config: {}

View file

@ -18,6 +18,7 @@ providers:
url: ${env.VLLM_URL}
max_tokens: ${env.VLLM_MAX_TOKENS:4096}
api_token: ${env.VLLM_API_TOKEN:fake}
tls_verify: ${env.VLLM_TLS_VERIFY:true}
- provider_id: sentence-transformers
provider_type: inline::sentence-transformers
config: {}

View file

@ -11,6 +11,7 @@ import jinja2
import yaml
from pydantic import BaseModel, Field
from llama_stack.apis.datasets import DatasetPurpose
from llama_stack.apis.models.models import ModelType
from llama_stack.distribution.datatypes import (
Api,
@ -214,7 +215,9 @@ class DistributionTemplate(BaseModel):
# Register YAML representer for ModelType
yaml.add_representer(ModelType, enum_representer)
yaml.add_representer(DatasetPurpose, enum_representer)
yaml.SafeDumper.add_representer(ModelType, enum_representer)
yaml.SafeDumper.add_representer(DatasetPurpose, enum_representer)
for output_dir in [yaml_output_dir, doc_output_dir]:
output_dir.mkdir(parents=True, exist_ok=True)