Merge branch 'main' into pr1334

commit e111a7cb4e
Author: ehhuang
Date: 2025-02-28 16:28:51 -08:00 (committed via GitHub)
22 changed files with 2373 additions and 135 deletions

File diff suppressed because it is too large

File diff suppressed because it is too large


@@ -55,6 +55,7 @@ def main(output_dir: str):
a set of endpoints and their corresponding interfaces that are tailored to
best leverage Llama Models.""",
),
+ include_standard_error_responses=True,
),
)


@@ -10,6 +10,7 @@ import typing
from dataclasses import make_dataclass
from typing import Any, Dict, Set, Union
+ from llama_stack.apis.datatypes import Error
from llama_stack.strong_typing.core import JsonType
from llama_stack.strong_typing.docstring import Docstring, parse_type
from llama_stack.strong_typing.inspection import (
@@ -434,6 +435,75 @@ class Generator:
)
self.schema_builder = SchemaBuilder(schema_generator)
self.responses = {}
# Create standard error responses
self._create_standard_error_responses()
def _create_standard_error_responses(self) -> None:
"""
Creates standard error responses that can be reused across operations.
These will be added to the components.responses section of the OpenAPI document.
"""
# Get the Error schema
error_schema = self.schema_builder.classdef_to_ref(Error)
# Create standard error responses
self.responses["BadRequest400"] = Response(
description="The request was invalid or malformed",
content={
"application/json": MediaType(
schema=error_schema,
example={
"status": 400,
"title": "Bad Request",
"detail": "The request was invalid or malformed",
}
)
}
)
self.responses["TooManyRequests429"] = Response(
description="The client has sent too many requests in a given amount of time",
content={
"application/json": MediaType(
schema=error_schema,
example={
"status": 429,
"title": "Too Many Requests",
"detail": "You have exceeded the rate limit. Please try again later.",
}
)
}
)
self.responses["InternalServerError500"] = Response(
description="The server encountered an unexpected error",
content={
"application/json": MediaType(
schema=error_schema,
example={
"status": 500,
"title": "Internal Server Error",
"detail": "An unexpected error occurred. Our team has been notified.",
}
)
}
)
# Add a default error response for any unhandled error cases
self.responses["DefaultError"] = Response(
description="An unexpected error occurred",
content={
"application/json": MediaType(
schema=error_schema,
example={
"status": 0,
"title": "Error",
"detail": "An unexpected error occurred",
}
)
}
)
def _build_type_tag(self, ref: str, schema: Schema) -> Tag:
# Don't include schema definition in the tag description because for one,
@@ -649,6 +719,18 @@ class Generator:
responses.update(response_builder.build_response(response_options))
assert len(responses.keys()) > 0, f"No responses found for {op.name}"
# Add standard error response references
if self.options.include_standard_error_responses:
if "400" not in responses:
responses["400"] = ResponseRef("BadRequest400")
if "429" not in responses:
responses["429"] = ResponseRef("TooManyRequests429")
if "500" not in responses:
responses["500"] = ResponseRef("InternalServerError500")
if "default" not in responses:
responses["default"] = ResponseRef("DefaultError")
if op.event_type is not None:
builder = ContentBuilder(self.schema_builder)
callbacks = {
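
For orientation, the reusable responses registered above end up under `components.responses` of the generated OpenAPI document and are referenced per operation via `$ref`. A minimal sketch of the resulting structure, written as a Python dict (the path and operation names below are placeholders, not taken from the actual generated spec):

```python
# Hypothetical excerpt of the generated OpenAPI document, expressed as a Python dict.
openapi_excerpt = {
    "paths": {
        "/v1/example": {  # placeholder path
            "post": {
                "responses": {
                    "200": {"description": "OK"},
                    # Standard errors are $ref'd, so every operation shares one definition.
                    "400": {"$ref": "#/components/responses/BadRequest400"},
                    "429": {"$ref": "#/components/responses/TooManyRequests429"},
                    "500": {"$ref": "#/components/responses/InternalServerError500"},
                    "default": {"$ref": "#/components/responses/DefaultError"},
                }
            }
        }
    },
    "components": {
        "responses": {
            "BadRequest400": {
                "description": "The request was invalid or malformed",
                "content": {
                    "application/json": {
                        "schema": {"$ref": "#/components/schemas/Error"},
                        "example": {
                            "status": 400,
                            "title": "Bad Request",
                            "detail": "The request was invalid or malformed",
                        },
                    }
                },
            },
            # TooManyRequests429, InternalServerError500 and DefaultError follow the same shape.
        }
    },
}
```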


@@ -35,6 +35,7 @@ class Options:
:param error_wrapper: True if errors are encapsulated in an error object wrapper.
:param property_description_fun: Custom transformation function to apply to class property documentation strings.
:param captions: User-defined captions for sections such as "Operations" or "Types", and (if applicable) groups of extra types.
+ :param include_standard_error_responses: Whether to include standard error responses (400, 429, 500, and a default error) in all operations.
"""
server: Server
@@ -52,6 +53,7 @@ class Options:
error_wrapper: bool = False
property_description_fun: Optional[Callable[[type, str, str], str]] = None
captions: Optional[Dict[str, str]] = None
+ include_standard_error_responses: bool = True
default_captions: ClassVar[Dict[str, str]] = {
"Operations": "Operations",


@@ -106,7 +106,7 @@ It would be best to start with a template and understand the structure of the co
llama stack build
> Enter a name for your Llama Stack (e.g. my-local-stack): my-stack
- > Enter the image type you want your Llama Stack to be built as (container or conda): conda
+ > Enter the image type you want your Llama Stack to be built as (container or conda or venv): conda
Llama Stack is composed of several APIs working together. Let's select
the provider types (implementations) you want to use for these APIs.
@@ -187,7 +187,7 @@ usage: llama stack run [-h] [--port PORT] [--image-name IMAGE_NAME] [--disable-i
[--tls-certfile TLS_CERTFILE] [--image-type {conda,container,venv}]
config
- start the server for a Llama Stack Distribution. You should have already built (or downloaded) and configured the distribution.
+ Start the server for a Llama Stack Distribution. You should have already built (or downloaded) and configured the distribution.
positional arguments:
config Path to config file to use for the run


@@ -41,12 +41,31 @@ The following environment variables can be configured:
## Prerequisite: Downloading Models
- Please make sure you have llama model checkpoints downloaded in `~/.llama` before proceeding. See [installation guide](https://llama-stack.readthedocs.io/en/latest/references/llama_cli_reference/download_models.html) here to download the models. Run `llama model list` to see the available models to download, and `llama model download` to download the checkpoints.
+ Please use `llama model list --downloaded` to check that you have llama model checkpoints downloaded in `~/.llama` before proceeding. See the [installation guide](https://llama-stack.readthedocs.io/en/latest/references/llama_cli_reference/download_models.html) to download the models. Run `llama model list` to see the available models to download, and `llama model download` to download the checkpoints.
```
- $ ls ~/.llama/checkpoints
- Llama3.1-8B Llama3.2-11B-Vision-Instruct Llama3.2-1B-Instruct Llama3.2-90B-Vision-Instruct Llama-Guard-3-8B
- Llama3.1-8B-Instruct Llama3.2-1B Llama3.2-3B-Instruct Llama-Guard-3-1B Prompt-Guard-86M
+ $ llama model list --downloaded
+ ┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━┓
+ ┃ Model ┃ Size ┃ Modified Time ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━┩
│ Llama3.2-1B-Instruct:int4-qlora-eo8 │ 1.53 GB │ 2025-02-26 11:22:28 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama3.2-1B │ 2.31 GB │ 2025-02-18 21:48:52 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Prompt-Guard-86M │ 0.02 GB │ 2025-02-26 11:29:28 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama3.2-3B-Instruct:int4-spinquant-eo8 │ 3.69 GB │ 2025-02-26 11:37:41 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama3.2-3B │ 5.99 GB │ 2025-02-18 21:51:26 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama3.1-8B │ 14.97 GB │ 2025-02-16 10:36:37 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama3.2-1B-Instruct:int4-spinquant-eo8 │ 1.51 GB │ 2025-02-26 11:35:02 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama-Guard-3-1B │ 2.80 GB │ 2025-02-26 11:20:46 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama-Guard-3-1B:int4 │ 0.43 GB │ 2025-02-26 11:33:33 │
└─────────────────────────────────────────┴──────────┴─────────────────────┘
```
## Running the Distribution


@@ -41,12 +41,31 @@ The following environment variables can be configured:
## Prerequisite: Downloading Models
- Please make sure you have llama model checkpoints downloaded in `~/.llama` before proceeding. See [installation guide](https://llama-stack.readthedocs.io/en/latest/references/llama_cli_reference/download_models.html) here to download the models. Run `llama model list` to see the available models to download, and `llama model download` to download the checkpoints.
+ Please use `llama model list --downloaded` to check that you have llama model checkpoints downloaded in `~/.llama` before proceeding. See the [installation guide](https://llama-stack.readthedocs.io/en/latest/references/llama_cli_reference/download_models.html) to download the models. Run `llama model list` to see the available models to download, and `llama model download` to download the checkpoints.
```
- $ ls ~/.llama/checkpoints
- Llama3.1-8B Llama3.2-11B-Vision-Instruct Llama3.2-1B-Instruct Llama3.2-90B-Vision-Instruct Llama-Guard-3-8B
- Llama3.1-8B-Instruct Llama3.2-1B Llama3.2-3B-Instruct Llama-Guard-3-1B Prompt-Guard-86M
+ $ llama model list --downloaded
+ ┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━┓
+ ┃ Model ┃ Size ┃ Modified Time ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━┩
│ Llama3.2-1B-Instruct:int4-qlora-eo8 │ 1.53 GB │ 2025-02-26 11:22:28 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama3.2-1B │ 2.31 GB │ 2025-02-18 21:48:52 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Prompt-Guard-86M │ 0.02 GB │ 2025-02-26 11:29:28 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama3.2-3B-Instruct:int4-spinquant-eo8 │ 3.69 GB │ 2025-02-26 11:37:41 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama3.2-3B │ 5.99 GB │ 2025-02-18 21:51:26 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama3.1-8B │ 14.97 GB │ 2025-02-16 10:36:37 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama3.2-1B-Instruct:int4-spinquant-eo8 │ 1.51 GB │ 2025-02-26 11:35:02 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama-Guard-3-1B │ 2.80 GB │ 2025-02-26 11:20:46 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama-Guard-3-1B:int4 │ 0.43 GB │ 2025-02-26 11:33:33 │
└─────────────────────────────────────────┴──────────┴─────────────────────┘
```
## Running the Distribution


@@ -129,3 +129,35 @@ llama download --source huggingface --model-id Prompt-Guard-86M --ignore-pattern
**Important:** Set your environment variable `HF_TOKEN` or pass in `--hf-token` to the command to validate your access. You can find your token at [https://huggingface.co/settings/tokens](https://huggingface.co/settings/tokens).
> **Tip:** Default for `llama download` is to run with `--ignore-patterns *.safetensors` since we use the `.pth` files in the `original` folder. For Llama Guard and Prompt Guard, however, we need safetensors. Hence, please run with `--ignore-patterns original` so that safetensors are downloaded and `.pth` files are ignored.
## List the downloaded models
To list the downloaded models, use the following command:
```
llama model list --downloaded
```
You should see a table like this:
```
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━┓
┃ Model ┃ Size ┃ Modified Time ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━┩
│ Llama3.2-1B-Instruct:int4-qlora-eo8 │ 1.53 GB │ 2025-02-26 11:22:28 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama3.2-1B │ 2.31 GB │ 2025-02-18 21:48:52 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Prompt-Guard-86M │ 0.02 GB │ 2025-02-26 11:29:28 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama3.2-3B-Instruct:int4-spinquant-eo8 │ 3.69 GB │ 2025-02-26 11:37:41 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama3.2-3B │ 5.99 GB │ 2025-02-18 21:51:26 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama3.1-8B │ 14.97 GB │ 2025-02-16 10:36:37 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama3.2-1B-Instruct:int4-spinquant-eo8 │ 1.51 GB │ 2025-02-26 11:35:02 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama-Guard-3-1B │ 2.80 GB │ 2025-02-26 11:20:46 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama-Guard-3-1B:int4 │ 0.43 GB │ 2025-02-26 11:33:33 │
└─────────────────────────────────────────┴──────────┴─────────────────────┘
```


@@ -154,6 +154,38 @@ llama download --source huggingface --model-id Prompt-Guard-86M --ignore-pattern
> **Tip:** Default for `llama download` is to run with `--ignore-patterns *.safetensors` since we use the `.pth` files in the `original` folder. For Llama Guard and Prompt Guard, however, we need safetensors. Hence, please run with `--ignore-patterns original` so that safetensors are downloaded and `.pth` files are ignored.
## List the downloaded models
To list the downloaded models, use the following command:
```
llama model list --downloaded
```
You should see a table like this:
```
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━┓
┃ Model ┃ Size ┃ Modified Time ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━┩
│ Llama3.2-1B-Instruct:int4-qlora-eo8 │ 1.53 GB │ 2025-02-26 11:22:28 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama3.2-1B │ 2.31 GB │ 2025-02-18 21:48:52 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Prompt-Guard-86M │ 0.02 GB │ 2025-02-26 11:29:28 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama3.2-3B-Instruct:int4-spinquant-eo8 │ 3.69 GB │ 2025-02-26 11:37:41 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama3.2-3B │ 5.99 GB │ 2025-02-18 21:51:26 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama3.1-8B │ 14.97 GB │ 2025-02-16 10:36:37 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama3.2-1B-Instruct:int4-spinquant-eo8 │ 1.51 GB │ 2025-02-26 11:35:02 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama-Guard-3-1B │ 2.80 GB │ 2025-02-26 11:20:46 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama-Guard-3-1B:int4 │ 0.43 GB │ 2025-02-26 11:33:33 │
└─────────────────────────────────────────┴──────────┴─────────────────────┘
```
## Understand the models
The `llama model` command helps you explore the models interface.


@@ -73,7 +73,7 @@ If you're looking for more specific topics, we have a [Zero to Hero Guide](#next
Open a new terminal and install `llama-stack`:
```bash
conda activate ollama
- pip install llama-stack==0.1.0
+ pip install -U llama-stack
```
---


@@ -5,6 +5,9 @@
# the root directory of this source tree.
from enum import Enum
+ from typing import Optional
+ from pydantic import BaseModel
from llama_stack.schema_utils import json_schema_type
@@ -33,3 +36,20 @@ class Api(Enum):
# built-in API
inspect = "inspect"
@json_schema_type
class Error(BaseModel):
"""
Error response from the API. Roughly follows RFC 7807.
:param status: HTTP status code
:param title: Error title, a short summary of the error which is invariant for an error type
:param detail: Error detail, a longer human-readable description of the error
:param instance: (Optional) A URL which can be used to retrieve more information about the specific occurrence of the error
"""
status: int
title: str
detail: str
instance: Optional[str] = None
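
As a quick serialization check, the new `Error` model round-trips to the JSON shape used in the response examples above. A minimal sketch, assuming pydantic v2 (`model_dump_json`; on pydantic v1 use `.json()` instead), with the class re-declared locally so the snippet runs standalone:

```python
from typing import Optional

from pydantic import BaseModel


class Error(BaseModel):
    """Error response from the API. Roughly follows RFC 7807."""

    status: int
    title: str
    detail: str
    instance: Optional[str] = None


# A 429 error as it would appear in a response body; `instance` is omitted when unset.
err = Error(
    status=429,
    title="Too Many Requests",
    detail="You have exceeded the rate limit. Please try again later.",
)
print(err.model_dump_json(exclude_none=True))
# {"status":429,"title":"Too Many Requests","detail":"You have exceeded the rate limit. Please try again later."}
```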


@@ -15,7 +15,6 @@ from termcolor import cprint
from llama_stack.distribution.datatypes import BuildConfig, Provider
from llama_stack.distribution.distribution import get_provider_registry
- from llama_stack.distribution.utils.config_dirs import BUILDS_BASE_DIR
from llama_stack.distribution.utils.exec import run_command, run_with_pty
from llama_stack.distribution.utils.image_types import ImageType
from llama_stack.providers.datatypes import Api
@@ -103,8 +102,6 @@ def build_image(
template_or_config,
image_name,
container_base,
- str(build_file_path),
- str(BUILDS_BASE_DIR / ImageType.container.value),
" ".join(normal_deps),
]
elif build_config.image_type == ImageType.conda.value:


@@ -52,7 +52,7 @@ ensure_conda_env_python310() {
local python_version="3.10"
# Check if conda command is available
- if ! command -v conda &>/dev/null; then
+ if ! is_command_available conda; then
printf "${RED}Error: conda command not found. Is Conda installed and in your PATH?${NC}" >&2
exit 1
fi


@@ -1,4 +1,4 @@
- #!/bin/bash
+ #!/usr/bin/env bash
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
@@ -20,26 +20,27 @@ UV_HTTP_TIMEOUT=${UV_HTTP_TIMEOUT:-500}
# mounting is not supported by docker buildx, so we use COPY instead
USE_COPY_NOT_MOUNT=${USE_COPY_NOT_MOUNT:-}
- if [ "$#" -lt 6 ]; then
+ if [ "$#" -lt 4 ]; then
# This only works for templates
- echo "Usage: $0 <template_or_config> <image_name> <container_base> <build_file_path> <host_build_dir> <pip_dependencies> [<special_pip_deps>]" >&2
+ echo "Usage: $0 <template_or_config> <image_name> <container_base> <pip_dependencies> [<special_pip_deps>]" >&2
exit 1
fi
set -euo pipefail
template_or_config="$1"
- image_name="$2"
- container_base="$3"
- build_file_path="$4"
- host_build_dir="$5"
- pip_dependencies="$6"
- special_pip_deps="${7:-}"
+ shift
+ image_name="$1"
+ shift
+ container_base="$1"
+ shift
+ pip_dependencies="$1"
+ shift
+ special_pip_deps="${1:-}"
# Define color codes
RED='\033[0;31m'
- GREEN='\033[0;32m'
NC='\033[0m' # No Color
CONTAINER_BINARY=${CONTAINER_BINARY:-docker}
@@ -47,8 +48,10 @@ CONTAINER_OPTS=${CONTAINER_OPTS:-}
TEMP_DIR=$(mktemp -d)
+ SCRIPT_DIR=$(dirname "$(readlink -f "$0")")
+ source "$SCRIPT_DIR/common.sh"
add_to_container() {
+ local input
output_file="$TEMP_DIR/Containerfile"
if [ -t 0 ]; then
printf '%s\n' "$1" >>"$output_file"
@@ -58,15 +61,21 @@ add_to_container() {
fi
}
+ # Check if container command is available
+ if ! is_command_available $CONTAINER_BINARY; then
+ printf "${RED}Error: ${CONTAINER_BINARY} command not found. Is ${CONTAINER_BINARY} installed and in your PATH?${NC}" >&2
+ exit 1
+ fi
# Update and install UBI9 components if UBI9 base image is used
if [[ $container_base == *"registry.access.redhat.com/ubi9"* ]]; then
add_to_container << EOF
FROM $container_base
WORKDIR /app
- RUN microdnf -y update && microdnf install -y iputils net-tools wget \
+ RUN dnf -y update && dnf install -y iputils net-tools wget \
vim-minimal python3.11 python3.11-pip python3.11-wheel \
- python3.11-setuptools && ln -s /bin/pip3.11 /bin/pip && ln -s /bin/python3.11 /bin/python && microdnf clean all
+ python3.11-setuptools && ln -s /bin/pip3.11 /bin/pip && ln -s /bin/python3.11 /bin/python && dnf clean all
ENV UV_SYSTEM_PYTHON=1
RUN pip install uv
@@ -150,12 +159,12 @@ EOF
add_to_container << EOF
RUN uv pip install --no-cache --extra-index-url https://test.pypi.org/simple/ \
--index-strategy unsafe-best-match \
- llama-models==$TEST_PYPI_VERSION llama-stack-client==$TEST_PYPI_VERSION llama-stack==$TEST_PYPI_VERSION
+ llama-stack==$TEST_PYPI_VERSION
EOF
else
if [ -n "$PYPI_VERSION" ]; then
- SPEC_VERSION="llama-stack==${PYPI_VERSION} llama-models==${PYPI_VERSION} llama-stack-client==${PYPI_VERSION}"
+ SPEC_VERSION="llama-stack==${PYPI_VERSION}"
else
SPEC_VERSION="llama-stack"
fi
@@ -165,6 +174,11 @@ EOF
fi
fi
+ # remove uv after installation
+ add_to_container << EOF
+ RUN pip uninstall -y uv
+ EOF
# if template_or_config ends with .yaml, it is not a template and we should not use the --template flag
if [[ "$template_or_config" != *.yaml ]]; then
add_to_container << EOF
@@ -185,26 +199,31 @@ RUN mkdir -p /.llama /.cache
RUN chmod -R g+rw /app /.llama /.cache
EOF
- printf "Containerfile created successfully in $TEMP_DIR/Containerfile\n\n"
+ printf "Containerfile created successfully in %s/Containerfile\n\n" "$TEMP_DIR"
- cat $TEMP_DIR/Containerfile
+ cat "$TEMP_DIR"/Containerfile
printf "\n"
- mounts=""
+ # Start building the CLI arguments
+ CLI_ARGS=()
+ # Read CONTAINER_OPTS and put it in an array
+ read -ra CLI_ARGS <<< "$CONTAINER_OPTS"
if [ "$USE_COPY_NOT_MOUNT" != "true" ]; then
if [ -n "$LLAMA_STACK_DIR" ]; then
- mounts="$mounts -v $(readlink -f $LLAMA_STACK_DIR):$stack_mount"
+ CLI_ARGS+=("-v" "$(readlink -f "$LLAMA_STACK_DIR"):$stack_mount")
fi
if [ -n "$LLAMA_MODELS_DIR" ]; then
- mounts="$mounts -v $(readlink -f $LLAMA_MODELS_DIR):$models_mount"
+ CLI_ARGS+=("-v" "$(readlink -f "$LLAMA_MODELS_DIR"):$models_mount")
fi
if [ -n "$LLAMA_STACK_CLIENT_DIR" ]; then
- mounts="$mounts -v $(readlink -f $LLAMA_STACK_CLIENT_DIR):$client_mount"
+ CLI_ARGS+=("-v" "$(readlink -f "$LLAMA_STACK_CLIENT_DIR"):$client_mount")
fi
fi
- if command -v selinuxenabled &>/dev/null && selinuxenabled; then
+ if is_command_available selinuxenabled && selinuxenabled; then
# Disable SELinux labels -- we don't want to relabel the llama-stack source dir
- CONTAINER_OPTS="$CONTAINER_OPTS --security-opt label=disable"
+ CLI_ARGS+=("--security-opt" "label=disable")
fi
# Set version tag based on PyPI version
@@ -225,11 +244,11 @@ image_tag="$image_name:$version_tag"
# Detect platform architecture
ARCH=$(uname -m)
if [ -n "$BUILD_PLATFORM" ]; then
- PLATFORM="--platform $BUILD_PLATFORM"
+ CLI_ARGS+=("--platform $BUILD_PLATFORM")
elif [ "$ARCH" = "arm64" ] || [ "$ARCH" = "aarch64" ]; then
- PLATFORM="--platform linux/arm64"
+ CLI_ARGS+=("--platform" "linux/arm64")
elif [ "$ARCH" = "x86_64" ]; then
- PLATFORM="--platform linux/amd64"
+ CLI_ARGS+=("--platform" "linux/amd64")
else
echo "Unsupported architecture: $ARCH"
exit 1
@@ -238,8 +257,13 @@ fi
echo "PWD: $(pwd)"
echo "Containerfile: $TEMP_DIR/Containerfile"
set -x
- $CONTAINER_BINARY build $CONTAINER_OPTS $PLATFORM -t $image_tag \
- -f "$TEMP_DIR/Containerfile" "." $mounts --progress=plain
+ $CONTAINER_BINARY build \
+ "${CLI_ARGS[@]}" \
+ -t "$image_tag" \
+ -f "$TEMP_DIR/Containerfile" \
+ "." \
+ --progress=plain
# clean up tmp/configs
set +x


@@ -3,6 +3,7 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
+ import json
from typing import Any, Dict, List, Optional
from tqdm import tqdm
@@ -116,7 +117,7 @@ class MetaReferenceEvalImpl(
generations = []
for i, x in tqdm(enumerate(input_rows)):
assert ColumnName.chat_completion_input.value in x, "Invalid input row"
- input_messages = eval(str(x[ColumnName.chat_completion_input.value]))
+ input_messages = json.loads(x[ColumnName.chat_completion_input.value])
input_messages = [UserMessage(**x) for x in input_messages]
# NOTE: only single-turn agent generation is supported. Create a new session for each input row
@@ -158,7 +159,7 @@ class MetaReferenceEvalImpl(
generations = []
for x in tqdm(input_rows):
if ColumnName.completion_input.value in x:
- input_content = eval(str(x[ColumnName.completion_input.value]))
+ input_content = json.loads(x[ColumnName.completion_input.value])
response = await self.inference_api.completion(
model=candidate.model,
content=input_content,
@@ -166,9 +167,8 @@ class MetaReferenceEvalImpl(
)
generations.append({ColumnName.generated_answer.value: response.completion_message.content})
elif ColumnName.chat_completion_input.value in x:
- chat_completion_input_str = str(x[ColumnName.chat_completion_input.value])
- input_messages = eval(chat_completion_input_str)
- input_messages = [UserMessage(**x) for x in input_messages]
+ chat_completion_input_json = json.loads(x[ColumnName.chat_completion_input.value])
+ input_messages = [UserMessage(**x) for x in chat_completion_input_json]
messages = []
if candidate.system_message:
messages.append(candidate.system_message)
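
The move from `eval` to `json.loads` is not just stylistic: the dataset columns hold JSON-encoded strings, and `eval` would execute arbitrary Python found in them. A small sketch with made-up sample data:

```python
import json

# A chat_completion_input cell as stored in the dataset: a JSON-encoded list of messages.
cell = '[{"role": "user", "content": "What is the boiling point of polyjuice?"}]'

# json.loads parses the string safely into plain Python objects...
input_messages = json.loads(cell)
assert input_messages[0]["role"] == "user"

# ...whereas eval() on untrusted rows would run any expression embedded in them,
# e.g. '__import__("os").system("...")', which is exactly what this change avoids.
```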


@@ -10,16 +10,19 @@
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
+ import json
from typing import Any, Mapping
from llama_stack.providers.utils.common.data_schema_validator import ColumnName
- def llama_stack_instruct_to_torchtune_instruct(sample: Mapping[str, Any]) -> Mapping[str, Any]:
+ def llama_stack_instruct_to_torchtune_instruct(
+ sample: Mapping[str, Any],
+ ) -> Mapping[str, Any]:
assert ColumnName.chat_completion_input.value in sample and ColumnName.expected_answer.value in sample, (
"Invalid input row"
)
- input_messages = eval(str(sample[ColumnName.chat_completion_input.value]))
+ input_messages = json.loads(sample[ColumnName.chat_completion_input.value])
assert len(input_messages) == 1, "llama stack intruct dataset format only supports 1 user message"
input_message = input_messages[0]
@@ -37,7 +40,7 @@ def llama_stack_instruct_to_torchtune_instruct(sample: Mapping[str, Any]) -> Map
def llama_stack_chat_to_torchtune_chat(sample: Mapping[str, Any]) -> Mapping[str, Any]:
assert ColumnName.dialog.value in sample, "Invalid input row"
role_map = {"user": "human", "assistant": "gpt"}
- dialog = eval(str(sample[ColumnName.dialog.value]))
+ dialog = json.loads(sample[ColumnName.dialog.value])
assert len(dialog) > 1, "dialog must have at least 2 messagse"
roles = []


@@ -270,6 +270,12 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
tool_config: Optional[ToolConfig] = None,
) -> AsyncGenerator:
model = await self.model_store.get_model(model_id)
# This is to be consistent with OpenAI API and support vLLM <= v0.6.3
# References:
# * https://platform.openai.com/docs/api-reference/chat/create#chat-create-tool_choice
# * https://github.com/vllm-project/vllm/pull/10000
if not tools and tool_config is not None:
tool_config.tool_choice = ToolChoice.none
request = ChatCompletionRequest(
model=model.provider_resource_id,
messages=messages,
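
The guard added above only matters when no tools are passed: older vLLM releases (<= v0.6.3) reject a request that sets `tool_choice` without `tools`, matching OpenAI's semantics. A self-contained illustration of the same logic using hypothetical stand-in types (not the actual llama_stack classes):

```python
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class ToolConfigSketch:  # stand-in for the real ToolConfig
    tool_choice: str = "auto"


def normalize_tool_config(
    tools: Optional[List[dict]], tool_config: Optional[ToolConfigSketch]
) -> Optional[ToolConfigSketch]:
    # Without tools, any tool_choice is meaningless and can break older vLLM servers,
    # so reset it to "none" before building the request.
    if not tools and tool_config is not None:
        tool_config.tool_choice = "none"
    return tool_config


print(normalize_tool_config(None, ToolConfigSketch()).tool_choice)  # -> none
```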


@@ -29,12 +29,31 @@ The following environment variables can be configured:
## Prerequisite: Downloading Models
- Please make sure you have llama model checkpoints downloaded in `~/.llama` before proceeding. See [installation guide](https://llama-stack.readthedocs.io/en/latest/references/llama_cli_reference/download_models.html) here to download the models. Run `llama model list` to see the available models to download, and `llama model download` to download the checkpoints.
+ Please use `llama model list --downloaded` to check that you have llama model checkpoints downloaded in `~/.llama` before proceeding. See the [installation guide](https://llama-stack.readthedocs.io/en/latest/references/llama_cli_reference/download_models.html) to download the models. Run `llama model list` to see the available models to download, and `llama model download` to download the checkpoints.
```
- $ ls ~/.llama/checkpoints
- Llama3.1-8B Llama3.2-11B-Vision-Instruct Llama3.2-1B-Instruct Llama3.2-90B-Vision-Instruct Llama-Guard-3-8B
- Llama3.1-8B-Instruct Llama3.2-1B Llama3.2-3B-Instruct Llama-Guard-3-1B Prompt-Guard-86M
+ $ llama model list --downloaded
+ ┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━┓
+ ┃ Model ┃ Size ┃ Modified Time ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━┩
│ Llama3.2-1B-Instruct:int4-qlora-eo8 │ 1.53 GB │ 2025-02-26 11:22:28 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama3.2-1B │ 2.31 GB │ 2025-02-18 21:48:52 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Prompt-Guard-86M │ 0.02 GB │ 2025-02-26 11:29:28 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama3.2-3B-Instruct:int4-spinquant-eo8 │ 3.69 GB │ 2025-02-26 11:37:41 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama3.2-3B │ 5.99 GB │ 2025-02-18 21:51:26 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama3.1-8B │ 14.97 GB │ 2025-02-16 10:36:37 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama3.2-1B-Instruct:int4-spinquant-eo8 │ 1.51 GB │ 2025-02-26 11:35:02 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama-Guard-3-1B │ 2.80 GB │ 2025-02-26 11:20:46 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama-Guard-3-1B:int4 │ 0.43 GB │ 2025-02-26 11:33:33 │
└─────────────────────────────────────────┴──────────┴─────────────────────┘
```
## Running the Distribution


@@ -31,12 +31,31 @@ The following environment variables can be configured:
## Prerequisite: Downloading Models
- Please make sure you have llama model checkpoints downloaded in `~/.llama` before proceeding. See [installation guide](https://llama-stack.readthedocs.io/en/latest/references/llama_cli_reference/download_models.html) here to download the models. Run `llama model list` to see the available models to download, and `llama model download` to download the checkpoints.
+ Please use `llama model list --downloaded` to check that you have llama model checkpoints downloaded in `~/.llama` before proceeding. See the [installation guide](https://llama-stack.readthedocs.io/en/latest/references/llama_cli_reference/download_models.html) to download the models. Run `llama model list` to see the available models to download, and `llama model download` to download the checkpoints.
```
- $ ls ~/.llama/checkpoints
- Llama3.1-8B Llama3.2-11B-Vision-Instruct Llama3.2-1B-Instruct Llama3.2-90B-Vision-Instruct Llama-Guard-3-8B
- Llama3.1-8B-Instruct Llama3.2-1B Llama3.2-3B-Instruct Llama-Guard-3-1B Prompt-Guard-86M
+ $ llama model list --downloaded
+ ┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━┓
+ ┃ Model ┃ Size ┃ Modified Time ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━┩
│ Llama3.2-1B-Instruct:int4-qlora-eo8 │ 1.53 GB │ 2025-02-26 11:22:28 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama3.2-1B │ 2.31 GB │ 2025-02-18 21:48:52 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Prompt-Guard-86M │ 0.02 GB │ 2025-02-26 11:29:28 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama3.2-3B-Instruct:int4-spinquant-eo8 │ 3.69 GB │ 2025-02-26 11:37:41 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama3.2-3B │ 5.99 GB │ 2025-02-18 21:51:26 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama3.1-8B │ 14.97 GB │ 2025-02-16 10:36:37 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama3.2-1B-Instruct:int4-spinquant-eo8 │ 1.51 GB │ 2025-02-26 11:35:02 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama-Guard-3-1B │ 2.80 GB │ 2025-02-26 11:20:46 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama-Guard-3-1B:int4 │ 0.43 GB │ 2025-02-26 11:33:33 │
└─────────────────────────────────────────┴──────────┴─────────────────────┘
```
## Running the Distribution


@@ -4,20 +4,15 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
- import json
- from typing import Dict, List
from uuid import uuid4
import pytest
from llama_stack_client.lib.agents.agent import Agent
- from llama_stack_client.lib.agents.client_tool import ClientTool
+ from llama_stack_client.lib.agents.client_tool import client_tool
from llama_stack_client.lib.agents.event_logger import EventLogger
- from llama_stack_client.types import ToolResponseMessage
from llama_stack_client.types.agents.turn_create_params import Document as AgentDocument
from llama_stack_client.types.memory_insert_params import Document
- from llama_stack_client.types.shared.completion_message import CompletionMessage
from llama_stack_client.types.shared_params.agent_config import AgentConfig, ToolConfig
- from llama_stack_client.types.tool_def_param import Parameter
from llama_stack.apis.agents.agents import (
AgentConfig as Server__AgentConfig,
@@ -27,63 +22,22 @@ from llama_stack.apis.agents.agents import (
)
- class TestClientTool(ClientTool):
- """Tool to give boiling point of a liquid
- Returns the correct value for polyjuice in Celcius and Fahrenheit
- and returns -1 for other liquids
- """
- def run(self, messages: List[CompletionMessage]) -> List[ToolResponseMessage]:
- assert len(messages) == 1, "Expected single message"
- message = messages[0]
- tool_call = message.tool_calls[0]
- try:
- response = self.run_impl(**tool_call.arguments)
- response_str = json.dumps(response, ensure_ascii=False)
- except Exception as e:
- response_str = f"Error when running tool: {e}"
- message = ToolResponseMessage(
- role="tool",
- call_id=tool_call.call_id,
- tool_name=tool_call.tool_name,
- content=response_str,
- )
- return message
- def get_name(self) -> str:
- return "get_boiling_point"
- def get_description(self) -> str:
- return "Get the boiling point of imaginary liquids (eg. polyjuice)"
- def get_params_definition(self) -> Dict[str, Parameter]:
- return {
- "liquid_name": Parameter(
- name="liquid_name",
- parameter_type="string",
- description="The name of the liquid",
- required=True,
- ),
- "celcius": Parameter(
- name="celcius",
- parameter_type="boolean",
- description="Whether to return the boiling point in Celcius",
- required=False,
- ),
- }
- def run_impl(self, liquid_name: str, celcius: bool = True) -> int:
- if liquid_name.lower() == "polyjuice":
- if celcius:
- return -100
- else:
- return -212
- else:
- return -1
+ @client_tool
+ def get_boiling_point(liquid_name: str, celcius: bool = True) -> int:
+ """
+ Returns the boiling point of a liquid in Celcius or Fahrenheit
+ :param liquid_name: The name of the liquid
+ :param celcius: Whether to return the boiling point in Celcius
+ :return: The boiling point of the liquid in Celcius or Fahrenheit
+ """
+ if liquid_name.lower() == "polyjuice":
+ if celcius:
+ return -100
+ else:
+ return -212
+ else:
+ return -1
@pytest.fixture(scope="session")
@@ -298,7 +252,7 @@ def test_code_interpreter_for_attachments(llama_stack_client, agent_config):
def test_custom_tool(llama_stack_client, agent_config):
- client_tool = TestClientTool()
+ client_tool = get_boiling_point
agent_config = {
**agent_config,
"toolgroups": ["builtin::websearch"],
@@ -326,7 +280,7 @@ def test_custom_tool(llama_stack_client, agent_config):
def test_tool_choice(llama_stack_client, agent_config):
def run_agent(tool_choice):
- client_tool = TestClientTool()
+ client_tool = get_boiling_point
test_agent_config = {
**agent_config,
@@ -505,7 +459,7 @@ def test_rag_and_code_agent(llama_stack_client, agent_config):
def test_create_turn_response(llama_stack_client, agent_config):
- client_tool = TestClientTool()
+ client_tool = get_boiling_point
agent_config = {
**agent_config,
"input_shields": [],


@@ -76,6 +76,25 @@ DUMMY_IMAGE_URL = ImageContentItem(
)
DUMMY_IMAGE_BASE64 = ImageContentItem(image=ImageContentItemImage(data="base64string"), type="image")
SUPPORTED_PROVIDERS = {"remote::nvidia"}
MODELS_SUPPORTING_MEDIA = {}
MODELS_SUPPORTING_OUTPUT_DIMENSION = {"nvidia/llama-3.2-nv-embedqa-1b-v2"}
MODELS_REQUIRING_TASK_TYPE = {
"nvidia/llama-3.2-nv-embedqa-1b-v2",
"nvidia/nv-embedqa-e5-v5",
"nvidia/nv-embedqa-mistral-7b-v2",
"snowflake/arctic-embed-l",
}
MODELS_SUPPORTING_TASK_TYPE = MODELS_REQUIRING_TASK_TYPE
def default_task_type(model_id):
"""
Some models require a task type parameter. This provides a default value for
testing those models.
"""
if model_id in MODELS_REQUIRING_TASK_TYPE:
return {"task_type": "query"}
return {}
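
How the helper feeds into the calls below: its return value is unpacked with `**`, so only models that need it receive a `task_type` keyword. A self-contained sketch (the `embeddings` function here is a stand-in for the client call, not the real API):

```python
MODELS_REQUIRING_TASK_TYPE = {"nvidia/llama-3.2-nv-embedqa-1b-v2"}  # subset, for the sketch


def default_task_type(model_id: str) -> dict:
    # Mirrors the helper above: only models that require it get a task_type kwarg.
    return {"task_type": "query"} if model_id in MODELS_REQUIRING_TASK_TYPE else {}


def embeddings(model_id: str, contents: list, **kwargs) -> dict:
    # Stand-in that just echoes the keyword arguments it received.
    return {"model_id": model_id, **kwargs}


print(embeddings("nvidia/llama-3.2-nv-embedqa-1b-v2", ["hi"], **default_task_type("nvidia/llama-3.2-nv-embedqa-1b-v2")))
# {'model_id': 'nvidia/llama-3.2-nv-embedqa-1b-v2', 'task_type': 'query'}
print(embeddings("some/other-model", ["hi"], **default_task_type("some/other-model")))
# {'model_id': 'some/other-model'}
```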
@pytest.mark.parametrize(
@@ -92,7 +111,9 @@ SUPPORTED_PROVIDERS = {"remote::nvidia"}
def test_embedding_text(llama_stack_client, embedding_model_id, contents, inference_provider_type):
if inference_provider_type not in SUPPORTED_PROVIDERS:
pytest.xfail(f"{inference_provider_type} doesn't support embedding model yet")
- response = llama_stack_client.inference.embeddings(model_id=embedding_model_id, contents=contents)
+ response = llama_stack_client.inference.embeddings(
+ model_id=embedding_model_id, contents=contents, **default_task_type(embedding_model_id)
+ )
assert isinstance(response, EmbeddingsResponse)
assert len(response.embeddings) == sum(len(content) if isinstance(content, list) else 1 for content in contents)
assert isinstance(response.embeddings[0], list)
@@ -110,11 +131,14 @@ def test_embedding_text(llama_stack_client, embedding_model_id, contents, infere
"list[url,string,base64,text]",
],
)
- @pytest.mark.xfail(reason="Media is not supported")
def test_embedding_image(llama_stack_client, embedding_model_id, contents, inference_provider_type):
if inference_provider_type not in SUPPORTED_PROVIDERS:
pytest.xfail(f"{inference_provider_type} doesn't support embedding model yet")
- response = llama_stack_client.inference.embeddings(model_id=embedding_model_id, contents=contents)
+ if embedding_model_id not in MODELS_SUPPORTING_MEDIA:
+ pytest.xfail(f"{embedding_model_id} doesn't support media")
+ response = llama_stack_client.inference.embeddings(
+ model_id=embedding_model_id, contents=contents, **default_task_type(embedding_model_id)
+ )
assert isinstance(response, EmbeddingsResponse)
assert len(response.embeddings) == sum(len(content) if isinstance(content, list) else 1 for content in contents)
assert isinstance(response.embeddings[0], list)
@@ -145,7 +169,10 @@ def test_embedding_truncation(
if inference_provider_type not in SUPPORTED_PROVIDERS:
pytest.xfail(f"{inference_provider_type} doesn't support embedding model yet")
response = llama_stack_client.inference.embeddings(
- model_id=embedding_model_id, contents=contents, text_truncation=text_truncation
+ model_id=embedding_model_id,
+ contents=contents,
+ text_truncation=text_truncation,
+ **default_task_type(embedding_model_id),
)
assert isinstance(response, EmbeddingsResponse)
assert len(response.embeddings) == 1
@@ -178,26 +205,36 @@ def test_embedding_truncation_error(
pytest.xfail(f"{inference_provider_type} doesn't support embedding model yet")
with pytest.raises(BadRequestError):
llama_stack_client.inference.embeddings(
- model_id=embedding_model_id, contents=[DUMMY_LONG_TEXT], text_truncation=text_truncation
+ model_id=embedding_model_id,
+ contents=[DUMMY_LONG_TEXT],
+ text_truncation=text_truncation,
+ **default_task_type(embedding_model_id),
)
- @pytest.mark.xfail(reason="Only valid for model supporting dimension reduction")
def test_embedding_output_dimension(llama_stack_client, embedding_model_id, inference_provider_type):
if inference_provider_type not in SUPPORTED_PROVIDERS:
pytest.xfail(f"{inference_provider_type} doesn't support embedding model yet")
- base_response = llama_stack_client.inference.embeddings(model_id=embedding_model_id, contents=[DUMMY_STRING])
+ if embedding_model_id not in MODELS_SUPPORTING_OUTPUT_DIMENSION:
+ pytest.xfail(f"{embedding_model_id} doesn't support output_dimension")
+ base_response = llama_stack_client.inference.embeddings(
+ model_id=embedding_model_id, contents=[DUMMY_STRING], **default_task_type(embedding_model_id)
+ )
test_response = llama_stack_client.inference.embeddings(
- model_id=embedding_model_id, contents=[DUMMY_STRING], output_dimension=32
+ model_id=embedding_model_id,
+ contents=[DUMMY_STRING],
+ **default_task_type(embedding_model_id),
+ output_dimension=32,
)
assert len(base_response.embeddings[0]) != len(test_response.embeddings[0])
assert len(test_response.embeddings[0]) == 32
- @pytest.mark.xfail(reason="Only valid for model supporting task type")
def test_embedding_task_type(llama_stack_client, embedding_model_id, inference_provider_type):
if inference_provider_type not in SUPPORTED_PROVIDERS:
pytest.xfail(f"{inference_provider_type} doesn't support embedding model yet")
+ if embedding_model_id not in MODELS_SUPPORTING_TASK_TYPE:
+ pytest.xfail(f"{embedding_model_id} doesn't support task_type")
query_embedding = llama_stack_client.inference.embeddings(
model_id=embedding_model_id, contents=[DUMMY_STRING], task_type="query"
)
@@ -220,7 +257,10 @@ def test_embedding_text_truncation(llama_stack_client, embedding_model_id, text_
if inference_provider_type not in SUPPORTED_PROVIDERS:
pytest.xfail(f"{inference_provider_type} doesn't support embedding model yet")
response = llama_stack_client.inference.embeddings(
- model_id=embedding_model_id, contents=[DUMMY_STRING], text_truncation=text_truncation
+ model_id=embedding_model_id,
+ contents=[DUMMY_STRING],
+ text_truncation=text_truncation,
+ **default_task_type(embedding_model_id),
)
assert isinstance(response, EmbeddingsResponse)
assert len(response.embeddings) == 1
@@ -245,5 +285,8 @@ def test_embedding_text_truncation_error(
pytest.xfail(f"{inference_provider_type} doesn't support embedding model yet")
with pytest.raises(BadRequestError):
llama_stack_client.inference.embeddings(
- model_id=embedding_model_id, contents=[DUMMY_STRING], text_truncation=text_truncation
+ model_id=embedding_model_id,
+ contents=[DUMMY_STRING],
+ text_truncation=text_truncation,
+ **default_task_type(embedding_model_id),
)