Merge branch 'main' into max_infer_iters

This commit is contained in:
Xi Yan 2025-02-28 12:31:31 -08:00
commit 4f94f5a708
62 changed files with 2590 additions and 324 deletions

File diff suppressed because it is too large Load diff

File diff suppressed because it is too large Load diff

View file

@ -1145,6 +1145,7 @@
}
],
"source": [
"# NBVAL_SKIP\n",
"from pydantic import BaseModel\n",
"\n",
"\n",
@ -2885,7 +2886,6 @@
}
],
"source": [
"# NBVAL_SKIP\n",
"from llama_stack_client.lib.agents.agent import Agent\n",
"from llama_stack_client.lib.agents.event_logger import EventLogger\n",
"from llama_stack_client.types.agent_create_params import AgentConfig\n",
@ -4326,7 +4326,7 @@
"provenance": []
},
"kernelspec": {
"display_name": "toolchain",
"display_name": "master",
"language": "python",
"name": "python3"
},

View file

@ -55,6 +55,7 @@ def main(output_dir: str):
a set of endpoints and their corresponding interfaces that are tailored to
best leverage Llama Models.""",
),
include_standard_error_responses=True,
),
)

View file

@ -10,6 +10,7 @@ import typing
from dataclasses import make_dataclass
from typing import Any, Dict, Set, Union
from llama_stack.apis.datatypes import Error
from llama_stack.strong_typing.core import JsonType
from llama_stack.strong_typing.docstring import Docstring, parse_type
from llama_stack.strong_typing.inspection import (
@ -434,6 +435,75 @@ class Generator:
)
self.schema_builder = SchemaBuilder(schema_generator)
self.responses = {}
# Create standard error responses
self._create_standard_error_responses()
def _create_standard_error_responses(self) -> None:
"""
Creates standard error responses that can be reused across operations.
These will be added to the components.responses section of the OpenAPI document.
"""
# Get the Error schema
error_schema = self.schema_builder.classdef_to_ref(Error)
# Create standard error responses
self.responses["BadRequest400"] = Response(
description="The request was invalid or malformed",
content={
"application/json": MediaType(
schema=error_schema,
example={
"status": 400,
"title": "Bad Request",
"detail": "The request was invalid or malformed",
}
)
}
)
self.responses["TooManyRequests429"] = Response(
description="The client has sent too many requests in a given amount of time",
content={
"application/json": MediaType(
schema=error_schema,
example={
"status": 429,
"title": "Too Many Requests",
"detail": "You have exceeded the rate limit. Please try again later.",
}
)
}
)
self.responses["InternalServerError500"] = Response(
description="The server encountered an unexpected error",
content={
"application/json": MediaType(
schema=error_schema,
example={
"status": 500,
"title": "Internal Server Error",
"detail": "An unexpected error occurred. Our team has been notified.",
}
)
}
)
# Add a default error response for any unhandled error cases
self.responses["DefaultError"] = Response(
description="An unexpected error occurred",
content={
"application/json": MediaType(
schema=error_schema,
example={
"status": 0,
"title": "Error",
"detail": "An unexpected error occurred",
}
)
}
)
def _build_type_tag(self, ref: str, schema: Schema) -> Tag:
# Don't include schema definition in the tag description because for one,
@ -649,6 +719,18 @@ class Generator:
responses.update(response_builder.build_response(response_options))
assert len(responses.keys()) > 0, f"No responses found for {op.name}"
# Add standard error response references
if self.options.include_standard_error_responses:
if "400" not in responses:
responses["400"] = ResponseRef("BadRequest400")
if "429" not in responses:
responses["429"] = ResponseRef("TooManyRequests429")
if "500" not in responses:
responses["500"] = ResponseRef("InternalServerError500")
if "default" not in responses:
responses["default"] = ResponseRef("DefaultError")
if op.event_type is not None:
builder = ContentBuilder(self.schema_builder)
callbacks = {

View file

@ -35,6 +35,7 @@ class Options:
:param error_wrapper: True if errors are encapsulated in an error object wrapper.
:param property_description_fun: Custom transformation function to apply to class property documentation strings.
:param captions: User-defined captions for sections such as "Operations" or "Types", and (if applicable) groups of extra types.
:param include_standard_error_responses: Whether to include standard error responses (400, 429, 500, 503) in all operations.
"""
server: Server
@ -52,6 +53,7 @@ class Options:
error_wrapper: bool = False
property_description_fun: Optional[Callable[[type, str, str], str]] = None
captions: Optional[Dict[str, str]] = None
include_standard_error_responses: bool = True
default_captions: ClassVar[Dict[str, str]] = {
"Operations": "Operations",

View file

@ -106,7 +106,7 @@ It would be best to start with a template and understand the structure of the co
llama stack build
> Enter a name for your Llama Stack (e.g. my-local-stack): my-stack
> Enter the image type you want your Llama Stack to be built as (container or conda): conda
> Enter the image type you want your Llama Stack to be built as (container or conda or venv): conda
Llama Stack is composed of several APIs working together. Let's select
the provider types (implementations) you want to use for these APIs.
@ -187,7 +187,7 @@ usage: llama stack run [-h] [--port PORT] [--image-name IMAGE_NAME] [--disable-i
[--tls-certfile TLS_CERTFILE] [--image-type {conda,container,venv}]
config
start the server for a Llama Stack Distribution. You should have already built (or downloaded) and configured the distribution.
Start the server for a Llama Stack Distribution. You should have already built (or downloaded) and configured the distribution.
positional arguments:
config Path to config file to use for the run

View file

@ -41,12 +41,31 @@ The following environment variables can be configured:
## Prerequisite: Downloading Models
Please make sure you have llama model checkpoints downloaded in `~/.llama` before proceeding. See [installation guide](https://llama-stack.readthedocs.io/en/latest/references/llama_cli_reference/download_models.html) here to download the models. Run `llama model list` to see the available models to download, and `llama model download` to download the checkpoints.
Please use `llama model list --downloaded` to check that you have llama model checkpoints downloaded in `~/.llama` before proceeding. See [installation guide](https://llama-stack.readthedocs.io/en/latest/references/llama_cli_reference/download_models.html) here to download the models. Run `llama model list` to see the available models to download, and `llama model download` to download the checkpoints.
```
$ ls ~/.llama/checkpoints
Llama3.1-8B Llama3.2-11B-Vision-Instruct Llama3.2-1B-Instruct Llama3.2-90B-Vision-Instruct Llama-Guard-3-8B
Llama3.1-8B-Instruct Llama3.2-1B Llama3.2-3B-Instruct Llama-Guard-3-1B Prompt-Guard-86M
$ llama model list --downloaded
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━┓
┃ Model ┃ Size ┃ Modified Time ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━┩
│ Llama3.2-1B-Instruct:int4-qlora-eo8 │ 1.53 GB │ 2025-02-26 11:22:28 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama3.2-1B │ 2.31 GB │ 2025-02-18 21:48:52 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Prompt-Guard-86M │ 0.02 GB │ 2025-02-26 11:29:28 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama3.2-3B-Instruct:int4-spinquant-eo8 │ 3.69 GB │ 2025-02-26 11:37:41 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama3.2-3B │ 5.99 GB │ 2025-02-18 21:51:26 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama3.1-8B │ 14.97 GB │ 2025-02-16 10:36:37 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama3.2-1B-Instruct:int4-spinquant-eo8 │ 1.51 GB │ 2025-02-26 11:35:02 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama-Guard-3-1B │ 2.80 GB │ 2025-02-26 11:20:46 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama-Guard-3-1B:int4 │ 0.43 GB │ 2025-02-26 11:33:33 │
└─────────────────────────────────────────┴──────────┴─────────────────────┘
```
## Running the Distribution

View file

@ -41,12 +41,31 @@ The following environment variables can be configured:
## Prerequisite: Downloading Models
Please make sure you have llama model checkpoints downloaded in `~/.llama` before proceeding. See [installation guide](https://llama-stack.readthedocs.io/en/latest/references/llama_cli_reference/download_models.html) here to download the models. Run `llama model list` to see the available models to download, and `llama model download` to download the checkpoints.
Please use `llama model list --downloaded` to check that you have llama model checkpoints downloaded in `~/.llama` before proceeding. See [installation guide](https://llama-stack.readthedocs.io/en/latest/references/llama_cli_reference/download_models.html) here to download the models. Run `llama model list` to see the available models to download, and `llama model download` to download the checkpoints.
```
$ ls ~/.llama/checkpoints
Llama3.1-8B Llama3.2-11B-Vision-Instruct Llama3.2-1B-Instruct Llama3.2-90B-Vision-Instruct Llama-Guard-3-8B
Llama3.1-8B-Instruct Llama3.2-1B Llama3.2-3B-Instruct Llama-Guard-3-1B Prompt-Guard-86M
$ llama model list --downloaded
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━┓
┃ Model ┃ Size ┃ Modified Time ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━┩
│ Llama3.2-1B-Instruct:int4-qlora-eo8 │ 1.53 GB │ 2025-02-26 11:22:28 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama3.2-1B │ 2.31 GB │ 2025-02-18 21:48:52 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Prompt-Guard-86M │ 0.02 GB │ 2025-02-26 11:29:28 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama3.2-3B-Instruct:int4-spinquant-eo8 │ 3.69 GB │ 2025-02-26 11:37:41 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama3.2-3B │ 5.99 GB │ 2025-02-18 21:51:26 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama3.1-8B │ 14.97 GB │ 2025-02-16 10:36:37 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama3.2-1B-Instruct:int4-spinquant-eo8 │ 1.51 GB │ 2025-02-26 11:35:02 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama-Guard-3-1B │ 2.80 GB │ 2025-02-26 11:20:46 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama-Guard-3-1B:int4 │ 0.43 GB │ 2025-02-26 11:33:33 │
└─────────────────────────────────────────┴──────────┴─────────────────────┘
```
## Running the Distribution

View file

@ -38,7 +38,7 @@ The API is **exactly identical** for both clients.
:::{dropdown} Starting up the Llama Stack server
The Llama Stack server can be configured flexibly so you can mix-and-match various providers for its individual API components -- beyond Inference, these include Vector IO, Agents, Telemetry, Evals, Post Training, etc.
To get started quickly, we provide various container images for the server component that work with different inference providers out of the box. For this guide, we will use `llamastack/distribution-ollama` as the container image.
To get started quickly, we provide various container images for the server component that work with different inference providers out of the box. For this guide, we will use `llamastack/distribution-ollama` as the container image. If you'd like to build your own image or customize the configurations, please check out [this guide](../references/index.md).
Lets setup some environment variables that we will use in the rest of the guide.
```bash

View file

@ -129,3 +129,35 @@ llama download --source huggingface --model-id Prompt-Guard-86M --ignore-pattern
**Important:** Set your environment variable `HF_TOKEN` or pass in `--hf-token` to the command to validate your access. You can find your token at [https://huggingface.co/settings/tokens](https://huggingface.co/settings/tokens).
> **Tip:** Default for `llama download` is to run with `--ignore-patterns *.safetensors` since we use the `.pth` files in the `original` folder. For Llama Guard and Prompt Guard, however, we need safetensors. Hence, please run with `--ignore-patterns original` so that safetensors are downloaded and `.pth` files are ignored.
## List the downloaded models
To list the downloaded models with the following command:
```
llama model list --downloaded
```
You should see a table like this:
```
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━┓
┃ Model ┃ Size ┃ Modified Time ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━┩
│ Llama3.2-1B-Instruct:int4-qlora-eo8 │ 1.53 GB │ 2025-02-26 11:22:28 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama3.2-1B │ 2.31 GB │ 2025-02-18 21:48:52 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Prompt-Guard-86M │ 0.02 GB │ 2025-02-26 11:29:28 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama3.2-3B-Instruct:int4-spinquant-eo8 │ 3.69 GB │ 2025-02-26 11:37:41 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama3.2-3B │ 5.99 GB │ 2025-02-18 21:51:26 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama3.1-8B │ 14.97 GB │ 2025-02-16 10:36:37 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama3.2-1B-Instruct:int4-spinquant-eo8 │ 1.51 GB │ 2025-02-26 11:35:02 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama-Guard-3-1B │ 2.80 GB │ 2025-02-26 11:20:46 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama-Guard-3-1B:int4 │ 0.43 GB │ 2025-02-26 11:33:33 │
└─────────────────────────────────────────┴──────────┴─────────────────────┘
```

View file

@ -154,6 +154,38 @@ llama download --source huggingface --model-id Prompt-Guard-86M --ignore-pattern
> **Tip:** Default for `llama download` is to run with `--ignore-patterns *.safetensors` since we use the `.pth` files in the `original` folder. For Llama Guard and Prompt Guard, however, we need safetensors. Hence, please run with `--ignore-patterns original` so that safetensors are downloaded and `.pth` files are ignored.
## List the downloaded models
To list the downloaded models with the following command:
```
llama model list --downloaded
```
You should see a table like this:
```
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━┓
┃ Model ┃ Size ┃ Modified Time ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━┩
│ Llama3.2-1B-Instruct:int4-qlora-eo8 │ 1.53 GB │ 2025-02-26 11:22:28 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama3.2-1B │ 2.31 GB │ 2025-02-18 21:48:52 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Prompt-Guard-86M │ 0.02 GB │ 2025-02-26 11:29:28 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama3.2-3B-Instruct:int4-spinquant-eo8 │ 3.69 GB │ 2025-02-26 11:37:41 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama3.2-3B │ 5.99 GB │ 2025-02-18 21:51:26 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama3.1-8B │ 14.97 GB │ 2025-02-16 10:36:37 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama3.2-1B-Instruct:int4-spinquant-eo8 │ 1.51 GB │ 2025-02-26 11:35:02 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama-Guard-3-1B │ 2.80 GB │ 2025-02-26 11:20:46 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama-Guard-3-1B:int4 │ 0.43 GB │ 2025-02-26 11:33:33 │
└─────────────────────────────────────────┴──────────┴─────────────────────┘
```
## Understand the models
The `llama model` command helps you explore the models interface.

View file

@ -5,6 +5,9 @@
# the root directory of this source tree.
from enum import Enum
from typing import Optional
from pydantic import BaseModel
from llama_stack.schema_utils import json_schema_type
@ -33,3 +36,20 @@ class Api(Enum):
# built-in API
inspect = "inspect"
@json_schema_type
class Error(BaseModel):
"""
Error response from the API. Roughly follows RFC 7807.
:param status: HTTP status code
:param title: Error title, a short summary of the error which is invariant for an error type
:param detail: Error detail, a longer human-readable description of the error
:param instance: (Optional) A URL which can be used to retrieve more information about the specific occurrence of the error
"""
status: int
title: str
detail: str
instance: Optional[str] = None

View file

@ -9,6 +9,7 @@ import textwrap
from io import StringIO
from llama_stack.cli.subcommand import Subcommand
from llama_stack.cli.table import print_table
from llama_stack.models.llama.datatypes import CoreModelId, ModelFamily, is_multimodal, model_family
@ -48,7 +49,26 @@ class ModelPromptFormat(Subcommand):
supported_model_ids = [
m for m in CoreModelId if model_family(m) in {ModelFamily.llama3_1, ModelFamily.llama3_2}
]
model_str = "\n".join([m.value for m in supported_model_ids])
model_list = [m.value for m in supported_model_ids]
model_str = "\n".join(model_list)
if args.list:
headers = ["Model(s)"]
rows = []
for m in model_list:
rows.append(
[
m,
]
)
print_table(
rows,
headers,
separate_rows=True,
)
return
try:
model_id = CoreModelId(args.model_name)
except ValueError:

View file

@ -141,7 +141,7 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
completer=WordCompleter(available_providers),
complete_while_typing=True,
validator=Validator.from_callable(
lambda x: x in available_providers,
lambda x: x in available_providers, # noqa: B023 - see https://github.com/astral-sh/ruff/issues/7847
error_message="Invalid provider, use <TAB> to see options",
),
)

View file

@ -112,7 +112,7 @@ def test_parse_and_maybe_upgrade_config_old_format(old_config):
inference_providers = result.providers["inference"]
assert len(inference_providers) == 2
assert set(x.provider_id for x in inference_providers) == {
assert {x.provider_id for x in inference_providers} == {
"remote::ollama-00",
"meta-reference-01",
}

View file

@ -15,7 +15,6 @@ from termcolor import cprint
from llama_stack.distribution.datatypes import BuildConfig, Provider
from llama_stack.distribution.distribution import get_provider_registry
from llama_stack.distribution.utils.config_dirs import BUILDS_BASE_DIR
from llama_stack.distribution.utils.exec import run_command, run_with_pty
from llama_stack.distribution.utils.image_types import ImageType
from llama_stack.providers.datatypes import Api
@ -103,8 +102,6 @@ def build_image(
template_or_config,
image_name,
container_base,
str(build_file_path),
str(BUILDS_BASE_DIR / ImageType.container.value),
" ".join(normal_deps),
]
elif build_config.image_type == ImageType.conda.value:

View file

@ -52,7 +52,7 @@ ensure_conda_env_python310() {
local python_version="3.10"
# Check if conda command is available
if ! command -v conda &>/dev/null; then
if ! is_command_available conda; then
printf "${RED}Error: conda command not found. Is Conda installed and in your PATH?${NC}" >&2
exit 1
fi

View file

@ -1,4 +1,4 @@
#!/bin/bash
#!/usr/bin/env bash
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
@ -20,26 +20,27 @@ UV_HTTP_TIMEOUT=${UV_HTTP_TIMEOUT:-500}
# mounting is not supported by docker buildx, so we use COPY instead
USE_COPY_NOT_MOUNT=${USE_COPY_NOT_MOUNT:-}
if [ "$#" -lt 6 ]; then
if [ "$#" -lt 4 ]; then
# This only works for templates
echo "Usage: $0 <template_or_config> <image_name> <container_base> <build_file_path> <host_build_dir> <pip_dependencies> [<special_pip_deps>]" >&2
echo "Usage: $0 <template_or_config> <image_name> <container_base> <pip_dependencies> [<special_pip_deps>]" >&2
exit 1
fi
set -euo pipefail
template_or_config="$1"
image_name="$2"
container_base="$3"
build_file_path="$4"
host_build_dir="$5"
pip_dependencies="$6"
special_pip_deps="${7:-}"
shift
image_name="$1"
shift
container_base="$1"
shift
pip_dependencies="$1"
shift
special_pip_deps="${1:-}"
# Define color codes
RED='\033[0;31m'
GREEN='\033[0;32m'
NC='\033[0m' # No Color
CONTAINER_BINARY=${CONTAINER_BINARY:-docker}
@ -47,8 +48,10 @@ CONTAINER_OPTS=${CONTAINER_OPTS:-}
TEMP_DIR=$(mktemp -d)
SCRIPT_DIR=$(dirname "$(readlink -f "$0")")
source "$SCRIPT_DIR/common.sh"
add_to_container() {
local input
output_file="$TEMP_DIR/Containerfile"
if [ -t 0 ]; then
printf '%s\n' "$1" >>"$output_file"
@ -58,15 +61,21 @@ add_to_container() {
fi
}
# Check if container command is available
if ! is_command_available $CONTAINER_BINARY; then
printf "${RED}Error: ${CONTAINER_BINARY} command not found. Is ${CONTAINER_BINARY} installed and in your PATH?${NC}" >&2
exit 1
fi
# Update and install UBI9 components if UBI9 base image is used
if [[ $container_base == *"registry.access.redhat.com/ubi9"* ]]; then
add_to_container << EOF
FROM $container_base
WORKDIR /app
RUN microdnf -y update && microdnf install -y iputils net-tools wget \
RUN dnf -y update && dnf install -y iputils net-tools wget \
vim-minimal python3.11 python3.11-pip python3.11-wheel \
python3.11-setuptools && ln -s /bin/pip3.11 /bin/pip && ln -s /bin/python3.11 /bin/python && microdnf clean all
python3.11-setuptools && ln -s /bin/pip3.11 /bin/pip && ln -s /bin/python3.11 /bin/python && dnf clean all
ENV UV_SYSTEM_PYTHON=1
RUN pip install uv
@ -165,6 +174,11 @@ EOF
fi
fi
# remove uv after installation
add_to_container << EOF
RUN pip uninstall -y uv
EOF
# if template_or_config ends with .yaml, it is not a template and we should not use the --template flag
if [[ "$template_or_config" != *.yaml ]]; then
add_to_container << EOF
@ -185,26 +199,31 @@ RUN mkdir -p /.llama /.cache
RUN chmod -R g+rw /app /.llama /.cache
EOF
printf "Containerfile created successfully in $TEMP_DIR/Containerfile\n\n"
cat $TEMP_DIR/Containerfile
printf "Containerfile created successfully in %s/Containerfile\n\n" "$TEMP_DIR"
cat "$TEMP_DIR"/Containerfile
printf "\n"
mounts=""
# Start building the CLI arguments
CLI_ARGS=()
# Read CONTAINER_OPTS and put it in an array
read -ra CLI_ARGS <<< "$CONTAINER_OPTS"
if [ "$USE_COPY_NOT_MOUNT" != "true" ]; then
if [ -n "$LLAMA_STACK_DIR" ]; then
mounts="$mounts -v $(readlink -f $LLAMA_STACK_DIR):$stack_mount"
CLI_ARGS+=("-v" "$(readlink -f "$LLAMA_STACK_DIR"):$stack_mount")
fi
if [ -n "$LLAMA_MODELS_DIR" ]; then
mounts="$mounts -v $(readlink -f $LLAMA_MODELS_DIR):$models_mount"
CLI_ARGS+=("-v" "$(readlink -f "$LLAMA_MODELS_DIR"):$models_mount")
fi
if [ -n "$LLAMA_STACK_CLIENT_DIR" ]; then
mounts="$mounts -v $(readlink -f $LLAMA_STACK_CLIENT_DIR):$client_mount"
CLI_ARGS+=("-v" "$(readlink -f "$LLAMA_STACK_CLIENT_DIR"):$client_mount")
fi
fi
if command -v selinuxenabled &>/dev/null && selinuxenabled; then
if is_command_available selinuxenabled && selinuxenabled; then
# Disable SELinux labels -- we don't want to relabel the llama-stack source dir
CONTAINER_OPTS="$CONTAINER_OPTS --security-opt label=disable"
CLI_ARGS+=("--security-opt" "label=disable")
fi
# Set version tag based on PyPI version
@ -225,11 +244,11 @@ image_tag="$image_name:$version_tag"
# Detect platform architecture
ARCH=$(uname -m)
if [ -n "$BUILD_PLATFORM" ]; then
PLATFORM="--platform $BUILD_PLATFORM"
CLI_ARGS+=("--platform $BUILD_PLATFORM")
elif [ "$ARCH" = "arm64" ] || [ "$ARCH" = "aarch64" ]; then
PLATFORM="--platform linux/arm64"
CLI_ARGS+=("--platform" "linux/arm64")
elif [ "$ARCH" = "x86_64" ]; then
PLATFORM="--platform linux/amd64"
CLI_ARGS+=("--platform" "linux/amd64")
else
echo "Unsupported architecture: $ARCH"
exit 1
@ -238,8 +257,13 @@ fi
echo "PWD: $(pwd)"
echo "Containerfile: $TEMP_DIR/Containerfile"
set -x
$CONTAINER_BINARY build $CONTAINER_OPTS $PLATFORM -t $image_tag \
-f "$TEMP_DIR/Containerfile" "." $mounts --progress=plain
$CONTAINER_BINARY build \
"${CLI_ARGS[@]}" \
-t "$image_tag" \
-f "$TEMP_DIR/Containerfile" \
"." \
--progress=plain
# clean up tmp/configs
set +x

View file

@ -13,7 +13,7 @@ from llama_stack.providers.datatypes import Api, ProviderSpec
def stack_apis() -> List[Api]:
return [v for v in Api]
return list(Api)
class AutoRoutedApiInfo(BaseModel):
@ -55,7 +55,7 @@ def builtin_automatically_routed_apis() -> List[AutoRoutedApiInfo]:
def providable_apis() -> List[Api]:
routing_table_apis = set(x.routing_table_api for x in builtin_automatically_routed_apis())
routing_table_apis = {x.routing_table_api for x in builtin_automatically_routed_apis()}
return [api for api in Api if api not in routing_table_apis and api != Api.inspect]

View file

@ -115,8 +115,8 @@ async def resolve_impls(
- flatmaps, sorts and resolves the providers in dependency order
- for each API, produces either a (local, passthrough or router) implementation
"""
routing_table_apis = set(x.routing_table_api for x in builtin_automatically_routed_apis())
router_apis = set(x.router_api for x in builtin_automatically_routed_apis())
routing_table_apis = {x.routing_table_api for x in builtin_automatically_routed_apis()}
router_apis = {x.router_api for x in builtin_automatically_routed_apis()}
providers_with_specs = {}

View file

@ -318,14 +318,7 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
)
model = await self.get_object_by_identifier("model", embedding_model)
if model is None:
if embedding_model == "all-MiniLM-L6-v2":
raise ValueError(
"Embeddings are now served via Inference providers. "
"Please upgrade your run.yaml to include inline::sentence-transformer as an additional inference provider. "
"See https://github.com/meta-llama/llama-stack/blob/main/llama_stack/templates/together/run.yaml for an example."
)
else:
raise ValueError(f"Model {embedding_model} not found")
raise ValueError(f"Model {embedding_model} not found")
if model.model_type != ModelType.embedding:
raise ValueError(f"Model {embedding_model} is not an embedding model")
if "embedding_dimension" not in model.metadata:

View file

@ -134,7 +134,7 @@ def rag_chat_page():
dict(
name="builtin::rag/knowledge_search",
args={
"vector_db_ids": [vector_db_id for vector_db_id in selected_vector_dbs],
"vector_db_ids": list(selected_vector_dbs),
},
)
],

View file

@ -46,7 +46,7 @@ def formulate_run_args(image_type, image_name, config, template_name) -> list:
conda_env_info = json.loads(subprocess.check_output(["conda", "info", "--envs", "--json"]).decode())
envs = conda_env_info["envs"]
for envpath in envs:
if envpath.endswith(env_name):
if os.path.basename(envpath) == env_name:
return envpath
return None

View file

@ -226,10 +226,9 @@ class FunctionTagCustomToolGenerator(PromptTemplateGeneratorBase):
class PythonListCustomToolGenerator(PromptTemplateGeneratorBase): # noqa: N801
DEFAULT_PROMPT = textwrap.dedent(
"""
You are a helpful assistant. You have access to functions, but you should only use them if they are required.
You are an expert in composing functions. You are given a question and a set of possible functions.
Based on the question, you will need to make one or more function/tool calls to achieve the purpose.
If none of the function can be used, point it out. If the given question lacks the parameters required by the function,
also point it out. You should only return the function call in tools call sections.
Based on the question, you may or may not need to make one function/tool call to achieve the purpose.
{{ function_description }}
""".strip("\n")

View file

@ -611,8 +611,17 @@ class ChatAgent(ShieldRunnerMixin):
if event.stop_reason is not None:
stop_reason = event.stop_reason
span.set_attribute("stop_reason", stop_reason)
span.set_attribute("input", [m.model_dump_json() for m in input_messages])
span.set_attribute("output", f"content: {content} tool_calls: {tool_calls}")
span.set_attribute(
"input",
json.dumps([json.loads(m.model_dump_json()) for m in input_messages]),
)
output_attr = json.dumps(
{
"content": content,
"tool_calls": [json.loads(t.model_dump_json()) for t in tool_calls],
}
)
span.set_attribute("output", output_attr)
n_iter += 1
await self.storage.set_num_infer_iters_in_turn(session_id, turn_id, n_iter)
@ -796,10 +805,10 @@ class ChatAgent(ShieldRunnerMixin):
self, toolgroups_for_turn: Optional[List[AgentToolGroup]] = None
) -> Tuple[List[ToolDefinition], Dict[str, str]]:
# Determine which tools to include
agent_config_toolgroups = set(
(toolgroup.name if isinstance(toolgroup, AgentToolGroupWithArgs) else toolgroup)
agent_config_toolgroups = {
toolgroup.name if isinstance(toolgroup, AgentToolGroupWithArgs) else toolgroup
for toolgroup in self.agent_config.toolgroups
)
}
toolgroups_for_turn_set = (
agent_config_toolgroups
if toolgroups_for_turn is None

View file

@ -3,6 +3,7 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
from typing import Any, Dict, List, Optional
from tqdm import tqdm
@ -86,7 +87,6 @@ class MetaReferenceEvalImpl(
) -> Job:
task_def = self.benchmarks[benchmark_id]
dataset_id = task_def.dataset_id
candidate = task_config.eval_candidate
scoring_functions = task_def.scoring_functions
dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
validate_dataset_schema(dataset_def.dataset_schema, get_valid_schemas(Api.eval.value))
@ -117,7 +117,7 @@ class MetaReferenceEvalImpl(
generations = []
for i, x in tqdm(enumerate(input_rows)):
assert ColumnName.chat_completion_input.value in x, "Invalid input row"
input_messages = eval(str(x[ColumnName.chat_completion_input.value]))
input_messages = json.loads(x[ColumnName.chat_completion_input.value])
input_messages = [UserMessage(**x) for x in input_messages]
# NOTE: only single-turn agent generation is supported. Create a new session for each input row
@ -159,7 +159,7 @@ class MetaReferenceEvalImpl(
generations = []
for x in tqdm(input_rows):
if ColumnName.completion_input.value in x:
input_content = eval(str(x[ColumnName.completion_input.value]))
input_content = json.loads(x[ColumnName.completion_input.value])
response = await self.inference_api.completion(
model=candidate.model,
content=input_content,
@ -167,9 +167,8 @@ class MetaReferenceEvalImpl(
)
generations.append({ColumnName.generated_answer.value: response.completion_message.content})
elif ColumnName.chat_completion_input.value in x:
chat_completion_input_str = str(x[ColumnName.chat_completion_input.value])
input_messages = eval(chat_completion_input_str)
input_messages = [UserMessage(**x) for x in input_messages]
chat_completion_input_json = json.loads(x[ColumnName.chat_completion_input.value])
input_messages = [UserMessage(**x) for x in chat_completion_input_json]
messages = []
if candidate.system_message:
messages.append(candidate.system_message)

View file

@ -208,7 +208,6 @@ class MetaReferenceInferenceImpl(
logprobs = []
stop_reason = None
tokenizer = self.generator.formatter.tokenizer
for token_result in self.generator.completion(request):
tokens.append(token_result.token)
if token_result.text == "<|eot_id|>":

View file

@ -207,7 +207,7 @@ def maybe_parse_message(maybe_json: Optional[str]) -> Optional[ProcessingMessage
return parse_message(maybe_json)
except json.JSONDecodeError:
return None
except ValueError as e:
except ValueError:
return None
@ -352,7 +352,7 @@ class ModelParallelProcessGroup:
if isinstance(obj, TaskResponse):
yield obj.result
except GeneratorExit as e:
except GeneratorExit:
self.request_socket.send(encode_msg(CancelSentinel()))
while True:
obj_json = self.request_socket.send()

View file

@ -7,6 +7,9 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
# The file gets a special treatment for now?
# ruff: noqa: N803
import unittest
import torch

View file

@ -10,16 +10,19 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
import json
from typing import Any, Mapping
from llama_stack.providers.utils.common.data_schema_validator import ColumnName
def llama_stack_instruct_to_torchtune_instruct(sample: Mapping[str, Any]) -> Mapping[str, Any]:
def llama_stack_instruct_to_torchtune_instruct(
sample: Mapping[str, Any],
) -> Mapping[str, Any]:
assert ColumnName.chat_completion_input.value in sample and ColumnName.expected_answer.value in sample, (
"Invalid input row"
)
input_messages = eval(str(sample[ColumnName.chat_completion_input.value]))
input_messages = json.loads(sample[ColumnName.chat_completion_input.value])
assert len(input_messages) == 1, "llama stack intruct dataset format only supports 1 user message"
input_message = input_messages[0]
@ -37,7 +40,7 @@ def llama_stack_instruct_to_torchtune_instruct(sample: Mapping[str, Any]) -> Map
def llama_stack_chat_to_torchtune_chat(sample: Mapping[str, Any]) -> Mapping[str, Any]:
assert ColumnName.dialog.value in sample, "Invalid input row"
role_map = {"user": "human", "assistant": "gpt"}
dialog = eval(str(sample[ColumnName.dialog.value]))
dialog = json.loads(sample[ColumnName.dialog.value])
assert len(dialog) > 1, "dialog must have at least 2 messagse"
roles = []

View file

@ -264,7 +264,7 @@ class LoraFinetuningSingleDevice:
)
self.adapter_params = get_adapter_params(model)
self._is_dora = any(["magnitude" in k for k in self.adapter_params.keys()])
self._is_dora = any("magnitude" in k for k in self.adapter_params.keys())
set_trainable_params(model, self.adapter_params)

View file

@ -133,7 +133,7 @@ class BraintrustScoringImpl(
async def shutdown(self) -> None: ...
async def list_scoring_functions(self) -> List[ScoringFn]:
scoring_fn_defs_list = [x for x in self.supported_fn_defs_registry.values()]
scoring_fn_defs_list = list(self.supported_fn_defs_registry.values())
for f in scoring_fn_defs_list:
assert f.identifier.startswith("braintrust"), (
"All braintrust scoring fn must have identifier prefixed with 'braintrust'! "

View file

@ -198,7 +198,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
tool_config: Optional[ToolConfig] = None,
) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]:
if tool_prompt_format:
warnings.warn("tool_prompt_format is not supported by NVIDIA NIM, ignoring")
warnings.warn("tool_prompt_format is not supported by NVIDIA NIM, ignoring", stacklevel=2)
await check_health(self._config) # this raises errors

View file

@ -106,7 +106,7 @@ async def convert_chat_completion_request(
payload.update(temperature=strategy.temperature)
elif isinstance(strategy, TopKSamplingStrategy):
if strategy.top_k != -1 and strategy.top_k < 1:
warnings.warn("top_k must be -1 or >= 1")
warnings.warn("top_k must be -1 or >= 1", stacklevel=2)
nvext.update(top_k=strategy.top_k)
elif isinstance(strategy, GreedySamplingStrategy):
nvext.update(top_k=-1)
@ -168,7 +168,7 @@ def convert_completion_request(
payload.update(top_p=request.sampling_params.top_p)
elif request.sampling_params.strategy == "top_k":
if request.sampling_params.top_k != -1 and request.sampling_params.top_k < 1:
warnings.warn("top_k must be -1 or >= 1")
warnings.warn("top_k must be -1 or >= 1", stacklevel=2)
nvext.update(top_k=request.sampling_params.top_k)
elif request.sampling_params.strategy == "greedy":
nvext.update(top_k=-1)

View file

@ -270,6 +270,12 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
tool_config: Optional[ToolConfig] = None,
) -> AsyncGenerator:
model = await self.model_store.get_model(model_id)
# This is to be consistent with OpenAI API and support vLLM <= v0.6.3
# References:
# * https://platform.openai.com/docs/api-reference/chat/create#chat-create-tool_choice
# * https://github.com/vllm-project/vllm/pull/10000
if not tools and tool_config is not None:
tool_config.tool_choice = ToolChoice.none
request = ChatCompletionRequest(
model=model.provider_resource_id,
messages=messages,

View file

@ -39,12 +39,11 @@ class Testeval:
@pytest.mark.asyncio
async def test_eval_evaluate_rows(self, eval_stack, inference_model, judge_model):
eval_impl, benchmarks_impl, datasetio_impl, datasets_impl, models_impl = (
eval_impl, benchmarks_impl, datasetio_impl, datasets_impl = (
eval_stack[Api.eval],
eval_stack[Api.benchmarks],
eval_stack[Api.datasetio],
eval_stack[Api.datasets],
eval_stack[Api.models],
)
await register_dataset(datasets_impl, for_generation=True, dataset_id="test_dataset_for_eval")
@ -92,11 +91,10 @@ class Testeval:
@pytest.mark.asyncio
async def test_eval_run_eval(self, eval_stack, inference_model, judge_model):
eval_impl, benchmarks_impl, datasets_impl, models_impl = (
eval_impl, benchmarks_impl, datasets_impl = (
eval_stack[Api.eval],
eval_stack[Api.benchmarks],
eval_stack[Api.datasets],
eval_stack[Api.models],
)
await register_dataset(datasets_impl, for_generation=True, dataset_id="test_dataset_for_eval")
@ -131,11 +129,10 @@ class Testeval:
@pytest.mark.asyncio
async def test_eval_run_benchmark_eval(self, eval_stack, inference_model):
eval_impl, benchmarks_impl, datasets_impl, models_impl = (
eval_impl, benchmarks_impl, datasets_impl = (
eval_stack[Api.eval],
eval_stack[Api.benchmarks],
eval_stack[Api.datasets],
eval_stack[Api.models],
)
response = await datasets_impl.list_datasets()

View file

@ -18,54 +18,48 @@ from llama_stack.models.llama.sku_list import all_registered_models
INFERENCE_APIS = ["chat_completion"]
FUNCTIONALITIES = ["streaming", "structured_output", "tool_calling"]
SUPPORTED_MODELS = {
"ollama": set(
[
CoreModelId.llama3_1_8b_instruct.value,
CoreModelId.llama3_1_8b_instruct.value,
CoreModelId.llama3_1_70b_instruct.value,
CoreModelId.llama3_1_70b_instruct.value,
CoreModelId.llama3_1_405b_instruct.value,
CoreModelId.llama3_1_405b_instruct.value,
CoreModelId.llama3_2_1b_instruct.value,
CoreModelId.llama3_2_1b_instruct.value,
CoreModelId.llama3_2_3b_instruct.value,
CoreModelId.llama3_2_3b_instruct.value,
CoreModelId.llama3_2_11b_vision_instruct.value,
CoreModelId.llama3_2_11b_vision_instruct.value,
CoreModelId.llama3_2_90b_vision_instruct.value,
CoreModelId.llama3_2_90b_vision_instruct.value,
CoreModelId.llama3_3_70b_instruct.value,
CoreModelId.llama_guard_3_8b.value,
CoreModelId.llama_guard_3_1b.value,
]
),
"fireworks": set(
[
CoreModelId.llama3_1_8b_instruct.value,
CoreModelId.llama3_1_70b_instruct.value,
CoreModelId.llama3_1_405b_instruct.value,
CoreModelId.llama3_2_1b_instruct.value,
CoreModelId.llama3_2_3b_instruct.value,
CoreModelId.llama3_2_11b_vision_instruct.value,
CoreModelId.llama3_2_90b_vision_instruct.value,
CoreModelId.llama3_3_70b_instruct.value,
CoreModelId.llama_guard_3_8b.value,
CoreModelId.llama_guard_3_11b_vision.value,
]
),
"together": set(
[
CoreModelId.llama3_1_8b_instruct.value,
CoreModelId.llama3_1_70b_instruct.value,
CoreModelId.llama3_1_405b_instruct.value,
CoreModelId.llama3_2_3b_instruct.value,
CoreModelId.llama3_2_11b_vision_instruct.value,
CoreModelId.llama3_2_90b_vision_instruct.value,
CoreModelId.llama3_3_70b_instruct.value,
CoreModelId.llama_guard_3_8b.value,
CoreModelId.llama_guard_3_11b_vision.value,
]
),
"ollama": {
CoreModelId.llama3_1_8b_instruct.value,
CoreModelId.llama3_1_8b_instruct.value,
CoreModelId.llama3_1_70b_instruct.value,
CoreModelId.llama3_1_70b_instruct.value,
CoreModelId.llama3_1_405b_instruct.value,
CoreModelId.llama3_1_405b_instruct.value,
CoreModelId.llama3_2_1b_instruct.value,
CoreModelId.llama3_2_1b_instruct.value,
CoreModelId.llama3_2_3b_instruct.value,
CoreModelId.llama3_2_3b_instruct.value,
CoreModelId.llama3_2_11b_vision_instruct.value,
CoreModelId.llama3_2_11b_vision_instruct.value,
CoreModelId.llama3_2_90b_vision_instruct.value,
CoreModelId.llama3_2_90b_vision_instruct.value,
CoreModelId.llama3_3_70b_instruct.value,
CoreModelId.llama_guard_3_8b.value,
CoreModelId.llama_guard_3_1b.value,
},
"fireworks": {
CoreModelId.llama3_1_8b_instruct.value,
CoreModelId.llama3_1_70b_instruct.value,
CoreModelId.llama3_1_405b_instruct.value,
CoreModelId.llama3_2_1b_instruct.value,
CoreModelId.llama3_2_3b_instruct.value,
CoreModelId.llama3_2_11b_vision_instruct.value,
CoreModelId.llama3_2_90b_vision_instruct.value,
CoreModelId.llama3_3_70b_instruct.value,
CoreModelId.llama_guard_3_8b.value,
CoreModelId.llama_guard_3_11b_vision.value,
},
"together": {
CoreModelId.llama3_1_8b_instruct.value,
CoreModelId.llama3_1_70b_instruct.value,
CoreModelId.llama3_1_405b_instruct.value,
CoreModelId.llama3_2_3b_instruct.value,
CoreModelId.llama3_2_11b_vision_instruct.value,
CoreModelId.llama3_2_90b_vision_instruct.value,
CoreModelId.llama3_3_70b_instruct.value,
CoreModelId.llama_guard_3_8b.value,
CoreModelId.llama_guard_3_11b_vision.value,
},
}

View file

@ -45,13 +45,11 @@ class TestScoring:
scoring_functions_impl,
datasetio_impl,
datasets_impl,
models_impl,
) = (
scoring_stack[Api.scoring],
scoring_stack[Api.scoring_functions],
scoring_stack[Api.datasetio],
scoring_stack[Api.datasets],
scoring_stack[Api.models],
)
scoring_fns_list = await scoring_functions_impl.list_scoring_functions()
provider_id = scoring_fns_list[0].provider_id
@ -102,13 +100,11 @@ class TestScoring:
scoring_functions_impl,
datasetio_impl,
datasets_impl,
models_impl,
) = (
scoring_stack[Api.scoring],
scoring_stack[Api.scoring_functions],
scoring_stack[Api.datasetio],
scoring_stack[Api.datasets],
scoring_stack[Api.models],
)
await register_dataset(datasets_impl, for_rag=True)
response = await datasets_impl.list_datasets()
@ -163,13 +159,11 @@ class TestScoring:
scoring_functions_impl,
datasetio_impl,
datasets_impl,
models_impl,
) = (
scoring_stack[Api.scoring],
scoring_stack[Api.scoring_functions],
scoring_stack[Api.datasetio],
scoring_stack[Api.datasets],
scoring_stack[Api.models],
)
await register_dataset(datasets_impl, for_rag=True)
rows = await datasetio_impl.get_rows_paginated(

View file

@ -6,7 +6,7 @@
import json
import logging
import warnings
from typing import AsyncGenerator, Dict, Generator, Iterable, List, Optional, Union
from typing import AsyncGenerator, Dict, Iterable, List, Optional, Union
from openai import AsyncStream
from openai.types.chat import (
@ -605,7 +605,7 @@ def convert_tool_call(
tool_name=tool_call.function.name,
arguments=json.loads(tool_call.function.arguments),
)
except Exception as e:
except Exception:
return UnparseableToolCall(
call_id=tool_call.id or "",
tool_name=tool_call.function.name or "",
@ -841,14 +841,13 @@ async def convert_openai_chat_completion_stream(
Convert a stream of OpenAI chat completion chunks into a stream
of ChatCompletionResponseStreamChunk.
"""
# generate a stream of ChatCompletionResponseEventType: start -> progress -> progress -> ...
def _event_type_generator() -> Generator[ChatCompletionResponseEventType, None, None]:
yield ChatCompletionResponseEventType.start
while True:
yield ChatCompletionResponseEventType.progress
event_type = _event_type_generator()
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=ChatCompletionResponseEventType.start,
delta=TextDelta(text=""),
)
)
event_type = ChatCompletionResponseEventType.progress
stop_reason = None
toolcall_buffer = {}
@ -868,7 +867,7 @@ async def convert_openai_chat_completion_stream(
if choice.delta.content:
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=next(event_type),
event_type=event_type,
delta=TextDelta(text=choice.delta.content),
logprobs=_convert_openai_logprobs(logprobs),
)
@ -877,7 +876,9 @@ async def convert_openai_chat_completion_stream(
# it is possible to have parallel tool calls in stream, but
# ChatCompletionResponseEvent only supports one per stream
if len(choice.delta.tool_calls) > 1:
warnings.warn("multiple tool calls found in a single delta, using the first, ignoring the rest")
warnings.warn(
"multiple tool calls found in a single delta, using the first, ignoring the rest", stacklevel=2
)
if not enable_incremental_tool_calls:
yield ChatCompletionResponseStreamChunk(
@ -909,7 +910,7 @@ async def convert_openai_chat_completion_stream(
toolcall_buffer["content"] += delta
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=next(event_type),
event_type=event_type,
delta=ToolCallDelta(
tool_call=delta,
parse_status=ToolCallParseStatus.in_progress,
@ -920,7 +921,7 @@ async def convert_openai_chat_completion_stream(
else:
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=next(event_type),
event_type=event_type,
delta=TextDelta(text=choice.delta.content or ""),
logprobs=_convert_openai_logprobs(logprobs),
)
@ -931,7 +932,7 @@ async def convert_openai_chat_completion_stream(
toolcall_buffer["content"] += delta
yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent(
event_type=next(event_type),
event_type=event_type,
delta=ToolCallDelta(
tool_call=delta,
parse_status=ToolCallParseStatus.in_progress,

View file

@ -36,7 +36,7 @@ class RedisKVStoreImpl(KVStore):
value = await self.redis.get(key)
if value is None:
return None
ttl = await self.redis.ttl(key)
await self.redis.ttl(key)
return value
async def delete(self, key: str) -> None:

View file

@ -32,7 +32,7 @@ def aggregate_categorical_count(
scoring_results: List[ScoringResultRow],
) -> Dict[str, Any]:
scores = [str(r["score"]) for r in scoring_results]
unique_scores = sorted(list(set(scores)))
unique_scores = sorted(set(scores))
return {"categorical_count": {s: scores.count(s) for s in unique_scores}}

View file

@ -66,7 +66,7 @@ class RegisteredBaseScoringFn(BaseScoringFn):
return self.__class__.__name__
def get_supported_scoring_fn_defs(self) -> List[ScoringFn]:
return [x for x in self.supported_fn_defs_registry.values()]
return list(self.supported_fn_defs_registry.values())
def register_scoring_fn_def(self, scoring_fn: ScoringFn) -> None:
if scoring_fn.identifier in self.supported_fn_defs_registry:

View file

@ -6,6 +6,7 @@
import asyncio
import inspect
import json
from functools import wraps
from typing import Any, AsyncGenerator, Callable, Type, TypeVar
@ -17,6 +18,10 @@ T = TypeVar("T")
def serialize_value(value: Any) -> Primitive:
return str(_prepare_for_json(value))
def _prepare_for_json(value: Any) -> str:
"""Serialize a single value into JSON-compatible format."""
if value is None:
return ""
@ -25,9 +30,17 @@ def serialize_value(value: Any) -> Primitive:
elif hasattr(value, "_name_"):
return value._name_
elif isinstance(value, BaseModel):
return value.model_dump_json()
return json.loads(value.model_dump_json())
elif isinstance(value, (list, tuple, set)):
return [_prepare_for_json(item) for item in value]
elif isinstance(value, dict):
return {str(k): _prepare_for_json(v) for k, v in value.items()}
else:
return str(value)
try:
json.dumps(value)
return value
except Exception:
return str(value)
def trace_protocol(cls: Type[T]) -> Type[T]:
@ -104,7 +117,8 @@ def trace_protocol(cls: Type[T]) -> Type[T]:
result = method(self, *args, **kwargs)
span.set_attribute("output", serialize_value(result))
return result
except Exception as _e:
except Exception as e:
span.set_attribute("error", str(e))
raise
if is_async_gen:

View file

@ -99,7 +99,7 @@ def collect_template_dependencies(template_dir: Path) -> tuple[str | None, list[
template = template_func()
normal_deps, special_deps = get_provider_dependencies(template.providers)
# Combine all dependencies in order: normal deps, special deps, server deps
all_deps = sorted(list(set(normal_deps + SERVER_DEPENDENCIES))) + sorted(list(set(special_deps)))
all_deps = sorted(set(normal_deps + SERVER_DEPENDENCIES)) + sorted(set(special_deps))
return template.name, all_deps
except Exception:

View file

@ -29,12 +29,31 @@ The following environment variables can be configured:
## Prerequisite: Downloading Models
Please make sure you have llama model checkpoints downloaded in `~/.llama` before proceeding. See [installation guide](https://llama-stack.readthedocs.io/en/latest/references/llama_cli_reference/download_models.html) here to download the models. Run `llama model list` to see the available models to download, and `llama model download` to download the checkpoints.
Please use `llama model list --downloaded` to check that you have llama model checkpoints downloaded in `~/.llama` before proceeding. See [installation guide](https://llama-stack.readthedocs.io/en/latest/references/llama_cli_reference/download_models.html) here to download the models. Run `llama model list` to see the available models to download, and `llama model download` to download the checkpoints.
```
$ ls ~/.llama/checkpoints
Llama3.1-8B Llama3.2-11B-Vision-Instruct Llama3.2-1B-Instruct Llama3.2-90B-Vision-Instruct Llama-Guard-3-8B
Llama3.1-8B-Instruct Llama3.2-1B Llama3.2-3B-Instruct Llama-Guard-3-1B Prompt-Guard-86M
$ llama model list --downloaded
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━┓
┃ Model ┃ Size ┃ Modified Time ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━┩
│ Llama3.2-1B-Instruct:int4-qlora-eo8 │ 1.53 GB │ 2025-02-26 11:22:28 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama3.2-1B │ 2.31 GB │ 2025-02-18 21:48:52 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Prompt-Guard-86M │ 0.02 GB │ 2025-02-26 11:29:28 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama3.2-3B-Instruct:int4-spinquant-eo8 │ 3.69 GB │ 2025-02-26 11:37:41 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama3.2-3B │ 5.99 GB │ 2025-02-18 21:51:26 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama3.1-8B │ 14.97 GB │ 2025-02-16 10:36:37 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama3.2-1B-Instruct:int4-spinquant-eo8 │ 1.51 GB │ 2025-02-26 11:35:02 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama-Guard-3-1B │ 2.80 GB │ 2025-02-26 11:20:46 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama-Guard-3-1B:int4 │ 0.43 GB │ 2025-02-26 11:33:33 │
└─────────────────────────────────────────┴──────────┴─────────────────────┘
```
## Running the Distribution

View file

@ -31,12 +31,31 @@ The following environment variables can be configured:
## Prerequisite: Downloading Models
Please make sure you have llama model checkpoints downloaded in `~/.llama` before proceeding. See [installation guide](https://llama-stack.readthedocs.io/en/latest/references/llama_cli_reference/download_models.html) here to download the models. Run `llama model list` to see the available models to download, and `llama model download` to download the checkpoints.
Please use `llama model list --downloaded` to check that you have llama model checkpoints downloaded in `~/.llama` before proceeding. See [installation guide](https://llama-stack.readthedocs.io/en/latest/references/llama_cli_reference/download_models.html) here to download the models. Run `llama model list` to see the available models to download, and `llama model download` to download the checkpoints.
```
$ ls ~/.llama/checkpoints
Llama3.1-8B Llama3.2-11B-Vision-Instruct Llama3.2-1B-Instruct Llama3.2-90B-Vision-Instruct Llama-Guard-3-8B
Llama3.1-8B-Instruct Llama3.2-1B Llama3.2-3B-Instruct Llama-Guard-3-1B Prompt-Guard-86M
$ llama model list --downloaded
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━┓
┃ Model ┃ Size ┃ Modified Time ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━┩
│ Llama3.2-1B-Instruct:int4-qlora-eo8 │ 1.53 GB │ 2025-02-26 11:22:28 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama3.2-1B │ 2.31 GB │ 2025-02-18 21:48:52 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Prompt-Guard-86M │ 0.02 GB │ 2025-02-26 11:29:28 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama3.2-3B-Instruct:int4-spinquant-eo8 │ 3.69 GB │ 2025-02-26 11:37:41 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama3.2-3B │ 5.99 GB │ 2025-02-18 21:51:26 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama3.1-8B │ 14.97 GB │ 2025-02-16 10:36:37 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama3.2-1B-Instruct:int4-spinquant-eo8 │ 1.51 GB │ 2025-02-26 11:35:02 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama-Guard-3-1B │ 2.80 GB │ 2025-02-26 11:20:46 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama-Guard-3-1B:int4 │ 0.43 GB │ 2025-02-26 11:33:33 │
└─────────────────────────────────────────┴──────────┴─────────────────────┘
```
## Running the Distribution

View file

@ -93,7 +93,7 @@ def get_distribution_template() -> DistributionTemplate:
"inference": [inference_provider],
"vector_io": [vector_io_provider_sqlite],
},
default_models=[inference_model],
default_models=[inference_model, embedding_model],
default_tool_groups=default_tool_groups,
),
"run-with-safety.yaml": RunConfigSettings(

View file

@ -90,6 +90,12 @@ models:
model_id: ${env.INFERENCE_MODEL}
provider_id: ollama
model_type: llm
- metadata:
embedding_dimension: 384
model_id: all-MiniLM-L6-v2
provider_id: ollama
provider_model_id: all-minilm:latest
model_type: embedding
shields: []
vector_dbs: []
datasets: []

View file

@ -123,39 +123,16 @@ select = [
"I", # isort
]
ignore = [
"E203",
"E305",
"E402",
"E501", # line too long
"E721",
"E741",
"F405",
"F841",
"C408", # ignored because we like the dict keyword argument syntax
"E302",
"W291",
"E303",
"N812", # ignored because import torch.nn.functional as F is PyTorch convention
"N817", # ignored because importing using acronyms is convention (DistributedDataParallel as DDP)
"E731", # allow usage of assigning lambda expressions
# The following ignores are desired by the project maintainers.
"E402", # Module level import not at top of file
"E501", # Line too long
"F405", # Maybe undefined or defined from star import
"C408", # Ignored because we like the dict keyword argument syntax
"N812", # Ignored because import torch.nn.functional as F is PyTorch convention
# These are the additional ones we started ignoring after moving to ruff. We should look into each one of them later.
"C901",
"C405",
"C414",
"N803",
"N999",
"C403",
"C416",
"B028",
"C419",
"C401",
"B023",
# shebang has extra meaning in fbcode lints, so I think it's not worth trying
# to line this up with executable bit
"EXE001",
"N802", # random naming hints don't need
"C901", # Complexity of the function is too high
# these ignores are from flake8-bugbear; please fix!
"B007",
"B008",
]

View file

@ -3,3 +3,4 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# ruff: noqa: N999

View file

@ -3,3 +3,4 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# ruff: noqa: N999

View file

@ -4,20 +4,15 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import json
from typing import Dict, List
from uuid import uuid4
import pytest
from llama_stack_client.lib.agents.agent import Agent
from llama_stack_client.lib.agents.client_tool import ClientTool
from llama_stack_client.lib.agents.client_tool import client_tool
from llama_stack_client.lib.agents.event_logger import EventLogger
from llama_stack_client.types import ToolResponseMessage
from llama_stack_client.types.agents.turn_create_params import Document as AgentDocument
from llama_stack_client.types.memory_insert_params import Document
from llama_stack_client.types.shared.completion_message import CompletionMessage
from llama_stack_client.types.shared_params.agent_config import AgentConfig, ToolConfig
from llama_stack_client.types.tool_def_param import Parameter
from llama_stack.apis.agents.agents import (
AgentConfig as Server__AgentConfig,
@ -27,63 +22,22 @@ from llama_stack.apis.agents.agents import (
)
class TestClientTool(ClientTool):
"""Tool to give boiling point of a liquid
Returns the correct value for polyjuice in Celcius and Fahrenheit
and returns -1 for other liquids
@client_tool
def get_boiling_point(liquid_name: str, celcius: bool = True) -> int:
"""
Returns the boiling point of a liquid in Celcius or Fahrenheit
def run(self, messages: List[CompletionMessage]) -> List[ToolResponseMessage]:
assert len(messages) == 1, "Expected single message"
message = messages[0]
tool_call = message.tool_calls[0]
try:
response = self.run_impl(**tool_call.arguments)
response_str = json.dumps(response, ensure_ascii=False)
except Exception as e:
response_str = f"Error when running tool: {e}"
message = ToolResponseMessage(
role="tool",
call_id=tool_call.call_id,
tool_name=tool_call.tool_name,
content=response_str,
)
return message
def get_name(self) -> str:
return "get_boiling_point"
def get_description(self) -> str:
return "Get the boiling point of imaginary liquids (eg. polyjuice)"
def get_params_definition(self) -> Dict[str, Parameter]:
return {
"liquid_name": Parameter(
name="liquid_name",
parameter_type="string",
description="The name of the liquid",
required=True,
),
"celcius": Parameter(
name="celcius",
parameter_type="boolean",
description="Whether to return the boiling point in Celcius",
required=False,
),
}
def run_impl(self, liquid_name: str, celcius: bool = True) -> int:
if liquid_name.lower() == "polyjuice":
if celcius:
return -100
else:
return -212
:param liquid_name: The name of the liquid
:param celcius: Whether to return the boiling point in Celcius
:return: The boiling point of the liquid in Celcius or Fahrenheit
"""
if liquid_name.lower() == "polyjuice":
if celcius:
return -100
else:
return -1
return -212
else:
return -1
@pytest.fixture(scope="session")
@ -298,7 +252,7 @@ def test_code_interpreter_for_attachments(llama_stack_client, agent_config):
def test_custom_tool(llama_stack_client, agent_config):
client_tool = TestClientTool()
client_tool = get_boiling_point
agent_config = {
**agent_config,
"toolgroups": ["builtin::websearch"],
@ -326,7 +280,7 @@ def test_custom_tool(llama_stack_client, agent_config):
def test_tool_choice(llama_stack_client, agent_config):
def run_agent(tool_choice):
client_tool = TestClientTool()
client_tool = get_boiling_point
test_agent_config = {
**agent_config,
@ -362,7 +316,7 @@ def test_tool_choice(llama_stack_client, agent_config):
# TODO: fix this flaky test
def xtest_override_system_message_behavior(llama_stack_client, agent_config):
client_tool = TestClientTool()
client_tool = get_boiling_point
agent_config = {
**agent_config,
"instructions": "You are a pirate",
@ -458,7 +412,6 @@ def test_rag_agent(llama_stack_client, agent_config, rag_tool_name):
vector_db_id=vector_db_id,
embedding_model="all-MiniLM-L6-v2",
embedding_dimension=384,
provider_id="faiss",
)
llama_stack_client.tool_runtime.rag_tool.insert(
documents=documents,
@ -587,7 +540,7 @@ def test_rag_and_code_agent(llama_stack_client, agent_config):
def test_create_turn_response(llama_stack_client, agent_config):
client_tool = TestClientTool()
client_tool = get_boiling_point
agent_config = {
**agent_config,
"input_shields": [],

View file

@ -117,7 +117,7 @@ def client_with_models(llama_stack_client, text_model_id, vision_model_id, embed
assert len(providers) > 0, "No inference providers found"
inference_providers = [p.provider_id for p in providers if p.provider_type != "inline::sentence-transformers"]
model_ids = set(m.identifier for m in client.models.list())
model_ids = {m.identifier for m in client.models.list()}
model_ids.update(m.provider_resource_id for m in client.models.list())
if text_model_id and text_model_id not in model_ids:

View file

@ -3,3 +3,4 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# ruff: noqa: N999

View file

@ -75,6 +75,26 @@ DUMMY_IMAGE_URL = ImageContentItem(
image=ImageContentItemImage(url=ImageContentItemImageURL(uri="https://example.com/image.jpg")), type="image"
)
DUMMY_IMAGE_BASE64 = ImageContentItem(image=ImageContentItemImage(data="base64string"), type="image")
SUPPORTED_PROVIDERS = {"remote::nvidia"}
MODELS_SUPPORTING_MEDIA = {}
MODELS_SUPPORTING_OUTPUT_DIMENSION = {"nvidia/llama-3.2-nv-embedqa-1b-v2"}
MODELS_REQUIRING_TASK_TYPE = {
"nvidia/llama-3.2-nv-embedqa-1b-v2",
"nvidia/nv-embedqa-e5-v5",
"nvidia/nv-embedqa-mistral-7b-v2",
"snowflake/arctic-embed-l",
}
MODELS_SUPPORTING_TASK_TYPE = MODELS_REQUIRING_TASK_TYPE
def default_task_type(model_id):
"""
Some models require a task type parameter. This provides a default value for
testing those models.
"""
if model_id in MODELS_REQUIRING_TASK_TYPE:
return {"task_type": "query"}
return {}
@pytest.mark.parametrize(
@ -88,8 +108,12 @@ DUMMY_IMAGE_BASE64 = ImageContentItem(image=ImageContentItemImage(data="base64st
"list[text]",
],
)
def test_embedding_text(llama_stack_client, embedding_model_id, contents):
response = llama_stack_client.inference.embeddings(model_id=embedding_model_id, contents=contents)
def test_embedding_text(llama_stack_client, embedding_model_id, contents, inference_provider_type):
if inference_provider_type not in SUPPORTED_PROVIDERS:
pytest.xfail(f"{inference_provider_type} doesn't support embedding model yet")
response = llama_stack_client.inference.embeddings(
model_id=embedding_model_id, contents=contents, **default_task_type(embedding_model_id)
)
assert isinstance(response, EmbeddingsResponse)
assert len(response.embeddings) == sum(len(content) if isinstance(content, list) else 1 for content in contents)
assert isinstance(response.embeddings[0], list)
@ -107,9 +131,14 @@ def test_embedding_text(llama_stack_client, embedding_model_id, contents):
"list[url,string,base64,text]",
],
)
@pytest.mark.xfail(reason="Media is not supported")
def test_embedding_image(llama_stack_client, embedding_model_id, contents):
response = llama_stack_client.inference.embeddings(model_id=embedding_model_id, contents=contents)
def test_embedding_image(llama_stack_client, embedding_model_id, contents, inference_provider_type):
if inference_provider_type not in SUPPORTED_PROVIDERS:
pytest.xfail(f"{inference_provider_type} doesn't support embedding model yet")
if embedding_model_id not in MODELS_SUPPORTING_MEDIA:
pytest.xfail(f"{embedding_model_id} doesn't support media")
response = llama_stack_client.inference.embeddings(
model_id=embedding_model_id, contents=contents, **default_task_type(embedding_model_id)
)
assert isinstance(response, EmbeddingsResponse)
assert len(response.embeddings) == sum(len(content) if isinstance(content, list) else 1 for content in contents)
assert isinstance(response.embeddings[0], list)
@ -134,9 +163,16 @@ def test_embedding_image(llama_stack_client, embedding_model_id, contents):
"short",
],
)
def test_embedding_truncation(llama_stack_client, embedding_model_id, text_truncation, contents):
def test_embedding_truncation(
llama_stack_client, embedding_model_id, text_truncation, contents, inference_provider_type
):
if inference_provider_type not in SUPPORTED_PROVIDERS:
pytest.xfail(f"{inference_provider_type} doesn't support embedding model yet")
response = llama_stack_client.inference.embeddings(
model_id=embedding_model_id, contents=contents, text_truncation=text_truncation
model_id=embedding_model_id,
contents=contents,
text_truncation=text_truncation,
**default_task_type(embedding_model_id),
)
assert isinstance(response, EmbeddingsResponse)
assert len(response.embeddings) == 1
@ -162,25 +198,43 @@ def test_embedding_truncation(llama_stack_client, embedding_model_id, text_trunc
"long-str",
],
)
def test_embedding_truncation_error(llama_stack_client, embedding_model_id, text_truncation, contents):
with pytest.raises(BadRequestError) as excinfo:
def test_embedding_truncation_error(
llama_stack_client, embedding_model_id, text_truncation, contents, inference_provider_type
):
if inference_provider_type not in SUPPORTED_PROVIDERS:
pytest.xfail(f"{inference_provider_type} doesn't support embedding model yet")
with pytest.raises(BadRequestError):
llama_stack_client.inference.embeddings(
model_id=embedding_model_id, contents=[DUMMY_LONG_TEXT], text_truncation=text_truncation
model_id=embedding_model_id,
contents=[DUMMY_LONG_TEXT],
text_truncation=text_truncation,
**default_task_type(embedding_model_id),
)
@pytest.mark.xfail(reason="Only valid for model supporting dimension reduction")
def test_embedding_output_dimension(llama_stack_client, embedding_model_id):
base_response = llama_stack_client.inference.embeddings(model_id=embedding_model_id, contents=[DUMMY_STRING])
def test_embedding_output_dimension(llama_stack_client, embedding_model_id, inference_provider_type):
if inference_provider_type not in SUPPORTED_PROVIDERS:
pytest.xfail(f"{inference_provider_type} doesn't support embedding model yet")
if embedding_model_id not in MODELS_SUPPORTING_OUTPUT_DIMENSION:
pytest.xfail(f"{embedding_model_id} doesn't support output_dimension")
base_response = llama_stack_client.inference.embeddings(
model_id=embedding_model_id, contents=[DUMMY_STRING], **default_task_type(embedding_model_id)
)
test_response = llama_stack_client.inference.embeddings(
model_id=embedding_model_id, contents=[DUMMY_STRING], output_dimension=32
model_id=embedding_model_id,
contents=[DUMMY_STRING],
**default_task_type(embedding_model_id),
output_dimension=32,
)
assert len(base_response.embeddings[0]) != len(test_response.embeddings[0])
assert len(test_response.embeddings[0]) == 32
@pytest.mark.xfail(reason="Only valid for model supporting task type")
def test_embedding_task_type(llama_stack_client, embedding_model_id):
def test_embedding_task_type(llama_stack_client, embedding_model_id, inference_provider_type):
if inference_provider_type not in SUPPORTED_PROVIDERS:
pytest.xfail(f"{inference_provider_type} doesn't support embedding model yet")
if embedding_model_id not in MODELS_SUPPORTING_TASK_TYPE:
pytest.xfail(f"{embedding_model_id} doesn't support task_type")
query_embedding = llama_stack_client.inference.embeddings(
model_id=embedding_model_id, contents=[DUMMY_STRING], task_type="query"
)
@ -199,9 +253,14 @@ def test_embedding_task_type(llama_stack_client, embedding_model_id):
"start",
],
)
def test_embedding_text_truncation(llama_stack_client, embedding_model_id, text_truncation):
def test_embedding_text_truncation(llama_stack_client, embedding_model_id, text_truncation, inference_provider_type):
if inference_provider_type not in SUPPORTED_PROVIDERS:
pytest.xfail(f"{inference_provider_type} doesn't support embedding model yet")
response = llama_stack_client.inference.embeddings(
model_id=embedding_model_id, contents=[DUMMY_STRING], text_truncation=text_truncation
model_id=embedding_model_id,
contents=[DUMMY_STRING],
text_truncation=text_truncation,
**default_task_type(embedding_model_id),
)
assert isinstance(response, EmbeddingsResponse)
assert len(response.embeddings) == 1
@ -219,8 +278,15 @@ def test_embedding_text_truncation(llama_stack_client, embedding_model_id, text_
"right",
],
)
def test_embedding_text_truncation_error(llama_stack_client, embedding_model_id, text_truncation):
with pytest.raises(BadRequestError) as excinfo:
def test_embedding_text_truncation_error(
llama_stack_client, embedding_model_id, text_truncation, inference_provider_type
):
if inference_provider_type not in SUPPORTED_PROVIDERS:
pytest.xfail(f"{inference_provider_type} doesn't support embedding model yet")
with pytest.raises(BadRequestError):
llama_stack_client.inference.embeddings(
model_id=embedding_model_id, contents=[DUMMY_STRING], text_truncation=text_truncation
model_id=embedding_model_id,
contents=[DUMMY_STRING],
text_truncation=text_truncation,
**default_task_type(embedding_model_id),
)

View file

@ -139,7 +139,7 @@ def test_text_completion_log_probs_streaming(client_with_models, text_model_id,
"top_k": 1,
},
)
streamed_content = [chunk for chunk in response]
streamed_content = list(response)
for chunk in streamed_content:
if chunk.delta: # if there's a token, we expect logprobs
assert chunk.logprobs, "Logprobs should not be empty"
@ -405,7 +405,7 @@ def test_text_chat_completion_tool_calling_tools_not_in_request(
assert delta.tool_call.tool_name == "get_object_namespace_list"
if delta.type == "tool_call" and delta.parse_status == "failed":
# expect raw message that failed to parse in tool_call
assert type(delta.tool_call) == str
assert isinstance(delta.tool_call, str)
assert len(delta.tool_call) > 0
else:
for tc in response.completion_message.tool_calls:

View file

@ -42,29 +42,27 @@ def featured_models():
SUPPORTED_MODELS = {
"ollama": set(
[
CoreModelId.llama3_1_8b_instruct.value,
CoreModelId.llama3_1_8b_instruct.value,
CoreModelId.llama3_1_70b_instruct.value,
CoreModelId.llama3_1_70b_instruct.value,
CoreModelId.llama3_1_405b_instruct.value,
CoreModelId.llama3_1_405b_instruct.value,
CoreModelId.llama3_2_1b_instruct.value,
CoreModelId.llama3_2_1b_instruct.value,
CoreModelId.llama3_2_3b_instruct.value,
CoreModelId.llama3_2_3b_instruct.value,
CoreModelId.llama3_2_11b_vision_instruct.value,
CoreModelId.llama3_2_11b_vision_instruct.value,
CoreModelId.llama3_2_90b_vision_instruct.value,
CoreModelId.llama3_2_90b_vision_instruct.value,
CoreModelId.llama3_3_70b_instruct.value,
CoreModelId.llama_guard_3_8b.value,
CoreModelId.llama_guard_3_1b.value,
]
),
"tgi": set([model.core_model_id.value for model in all_registered_models() if model.huggingface_repo]),
"vllm": set([model.core_model_id.value for model in all_registered_models() if model.huggingface_repo]),
"ollama": {
CoreModelId.llama3_1_8b_instruct.value,
CoreModelId.llama3_1_8b_instruct.value,
CoreModelId.llama3_1_70b_instruct.value,
CoreModelId.llama3_1_70b_instruct.value,
CoreModelId.llama3_1_405b_instruct.value,
CoreModelId.llama3_1_405b_instruct.value,
CoreModelId.llama3_2_1b_instruct.value,
CoreModelId.llama3_2_1b_instruct.value,
CoreModelId.llama3_2_3b_instruct.value,
CoreModelId.llama3_2_3b_instruct.value,
CoreModelId.llama3_2_11b_vision_instruct.value,
CoreModelId.llama3_2_11b_vision_instruct.value,
CoreModelId.llama3_2_90b_vision_instruct.value,
CoreModelId.llama3_2_90b_vision_instruct.value,
CoreModelId.llama3_3_70b_instruct.value,
CoreModelId.llama_guard_3_8b.value,
CoreModelId.llama_guard_3_1b.value,
},
"tgi": {model.core_model_id.value for model in all_registered_models() if model.huggingface_repo},
"vllm": {model.core_model_id.value for model in all_registered_models() if model.huggingface_repo},
}

View file

@ -3,3 +3,4 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# ruff: noqa: N999

View file

@ -42,7 +42,7 @@ def code_scanner_shield_id(available_shields):
@pytest.fixture(scope="session")
def model_providers(llama_stack_client):
return set([x.provider_id for x in llama_stack_client.providers.list() if x.api == "inference"])
return {x.provider_id for x in llama_stack_client.providers.list() if x.api == "inference"}
def test_unsafe_examples(llama_stack_client, llama_guard_text_shield_id):

View file

@ -24,7 +24,6 @@ def single_entry_vector_db_registry(llama_stack_client, empty_vector_db_registry
vector_db_id=vector_db_id,
embedding_model="all-MiniLM-L6-v2",
embedding_dimension=384,
provider_id="faiss",
)
vector_dbs = [vector_db.identifier for vector_db in llama_stack_client.vector_dbs.list()]
return vector_dbs
@ -121,7 +120,6 @@ def test_vector_db_insert_from_url_and_query(llama_stack_client, empty_vector_db
vector_db_id=vector_db_id,
embedding_model="all-MiniLM-L6-v2",
embedding_dimension=384,
provider_id="faiss",
)
# list to check memory bank is successfully registered

View file

@ -3,3 +3,4 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# ruff: noqa: N999