Merge branch 'main' into max_infer_iters

Xi Yan 2025-02-28 12:31:31 -08:00
commit 4f94f5a708
62 changed files with 2590 additions and 324 deletions

File diff suppressed because it is too large

File diff suppressed because it is too large

View file

@ -1145,6 +1145,7 @@
} }
], ],
"source": [ "source": [
"# NBVAL_SKIP\n",
"from pydantic import BaseModel\n", "from pydantic import BaseModel\n",
"\n", "\n",
"\n", "\n",
@ -2885,7 +2886,6 @@
} }
], ],
"source": [ "source": [
"# NBVAL_SKIP\n",
"from llama_stack_client.lib.agents.agent import Agent\n", "from llama_stack_client.lib.agents.agent import Agent\n",
"from llama_stack_client.lib.agents.event_logger import EventLogger\n", "from llama_stack_client.lib.agents.event_logger import EventLogger\n",
"from llama_stack_client.types.agent_create_params import AgentConfig\n", "from llama_stack_client.types.agent_create_params import AgentConfig\n",
@ -4326,7 +4326,7 @@
"provenance": [] "provenance": []
}, },
"kernelspec": { "kernelspec": {
"display_name": "toolchain", "display_name": "master",
"language": "python", "language": "python",
"name": "python3" "name": "python3"
}, },

View file

@ -55,6 +55,7 @@ def main(output_dir: str):
a set of endpoints and their corresponding interfaces that are tailored to a set of endpoints and their corresponding interfaces that are tailored to
best leverage Llama Models.""", best leverage Llama Models.""",
), ),
include_standard_error_responses=True,
), ),
) )

View file

@ -10,6 +10,7 @@ import typing
from dataclasses import make_dataclass from dataclasses import make_dataclass
from typing import Any, Dict, Set, Union from typing import Any, Dict, Set, Union
from llama_stack.apis.datatypes import Error
from llama_stack.strong_typing.core import JsonType from llama_stack.strong_typing.core import JsonType
from llama_stack.strong_typing.docstring import Docstring, parse_type from llama_stack.strong_typing.docstring import Docstring, parse_type
from llama_stack.strong_typing.inspection import ( from llama_stack.strong_typing.inspection import (
@ -435,6 +436,75 @@ class Generator:
self.schema_builder = SchemaBuilder(schema_generator) self.schema_builder = SchemaBuilder(schema_generator)
self.responses = {} self.responses = {}
# Create standard error responses
self._create_standard_error_responses()
def _create_standard_error_responses(self) -> None:
"""
Creates standard error responses that can be reused across operations.
These will be added to the components.responses section of the OpenAPI document.
"""
# Get the Error schema
error_schema = self.schema_builder.classdef_to_ref(Error)
# Create standard error responses
self.responses["BadRequest400"] = Response(
description="The request was invalid or malformed",
content={
"application/json": MediaType(
schema=error_schema,
example={
"status": 400,
"title": "Bad Request",
"detail": "The request was invalid or malformed",
}
)
}
)
self.responses["TooManyRequests429"] = Response(
description="The client has sent too many requests in a given amount of time",
content={
"application/json": MediaType(
schema=error_schema,
example={
"status": 429,
"title": "Too Many Requests",
"detail": "You have exceeded the rate limit. Please try again later.",
}
)
}
)
self.responses["InternalServerError500"] = Response(
description="The server encountered an unexpected error",
content={
"application/json": MediaType(
schema=error_schema,
example={
"status": 500,
"title": "Internal Server Error",
"detail": "An unexpected error occurred. Our team has been notified.",
}
)
}
)
# Add a default error response for any unhandled error cases
self.responses["DefaultError"] = Response(
description="An unexpected error occurred",
content={
"application/json": MediaType(
schema=error_schema,
example={
"status": 0,
"title": "Error",
"detail": "An unexpected error occurred",
}
)
}
)
def _build_type_tag(self, ref: str, schema: Schema) -> Tag: def _build_type_tag(self, ref: str, schema: Schema) -> Tag:
# Don't include schema definition in the tag description because for one, # Don't include schema definition in the tag description because for one,
# it is not very valuable and for another, it causes string formatting # it is not very valuable and for another, it causes string formatting
@ -649,6 +719,18 @@ class Generator:
responses.update(response_builder.build_response(response_options)) responses.update(response_builder.build_response(response_options))
assert len(responses.keys()) > 0, f"No responses found for {op.name}" assert len(responses.keys()) > 0, f"No responses found for {op.name}"
# Add standard error response references
if self.options.include_standard_error_responses:
if "400" not in responses:
responses["400"] = ResponseRef("BadRequest400")
if "429" not in responses:
responses["429"] = ResponseRef("TooManyRequests429")
if "500" not in responses:
responses["500"] = ResponseRef("InternalServerError500")
if "default" not in responses:
responses["default"] = ResponseRef("DefaultError")
if op.event_type is not None: if op.event_type is not None:
builder = ContentBuilder(self.schema_builder) builder = ContentBuilder(self.schema_builder)
callbacks = { callbacks = {
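
The hunks above register reusable error responses and reference them from every operation that does not already define those status codes. A minimal sketch of the OpenAPI fragment this produces, assuming `ResponseRef` serializes to a standard `$ref` pointer into `components.responses` (the response names mirror the code above; the path is hypothetical):

```python
# Illustrative only: the shape of the generated document, not the generator's actual output API.
openapi_fragment = {
    "components": {
        "responses": {
            "BadRequest400": {
                "description": "The request was invalid or malformed",
                "content": {
                    "application/json": {
                        "schema": {"$ref": "#/components/schemas/Error"},
                        "example": {
                            "status": 400,
                            "title": "Bad Request",
                            "detail": "The request was invalid or malformed",
                        },
                    }
                },
            },
            # TooManyRequests429, InternalServerError500 and DefaultError follow the same shape.
        }
    },
    "paths": {
        "/v1/example": {  # hypothetical path for illustration
            "get": {
                "responses": {
                    "200": {"description": "OK"},
                    "400": {"$ref": "#/components/responses/BadRequest400"},
                    "429": {"$ref": "#/components/responses/TooManyRequests429"},
                    "500": {"$ref": "#/components/responses/InternalServerError500"},
                    "default": {"$ref": "#/components/responses/DefaultError"},
                }
            }
        }
    },
}
```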

View file

@ -35,6 +35,7 @@ class Options:
:param error_wrapper: True if errors are encapsulated in an error object wrapper. :param error_wrapper: True if errors are encapsulated in an error object wrapper.
:param property_description_fun: Custom transformation function to apply to class property documentation strings. :param property_description_fun: Custom transformation function to apply to class property documentation strings.
:param captions: User-defined captions for sections such as "Operations" or "Types", and (if applicable) groups of extra types. :param captions: User-defined captions for sections such as "Operations" or "Types", and (if applicable) groups of extra types.
:param include_standard_error_responses: Whether to include standard error responses (400, 429, 500, and a default) in all operations.
""" """
server: Server server: Server
@ -52,6 +53,7 @@ class Options:
error_wrapper: bool = False error_wrapper: bool = False
property_description_fun: Optional[Callable[[type, str, str], str]] = None property_description_fun: Optional[Callable[[type, str, str], str]] = None
captions: Optional[Dict[str, str]] = None captions: Optional[Dict[str, str]] = None
include_standard_error_responses: bool = True
default_captions: ClassVar[Dict[str, str]] = { default_captions: ClassVar[Dict[str, str]] = {
"Operations": "Operations", "Operations": "Operations",

View file

@ -106,7 +106,7 @@ It would be best to start with a template and understand the structure of the co
llama stack build llama stack build
> Enter a name for your Llama Stack (e.g. my-local-stack): my-stack > Enter a name for your Llama Stack (e.g. my-local-stack): my-stack
> Enter the image type you want your Llama Stack to be built as (container or conda): conda > Enter the image type you want your Llama Stack to be built as (container or conda or venv): conda
Llama Stack is composed of several APIs working together. Let's select Llama Stack is composed of several APIs working together. Let's select
the provider types (implementations) you want to use for these APIs. the provider types (implementations) you want to use for these APIs.
@ -187,7 +187,7 @@ usage: llama stack run [-h] [--port PORT] [--image-name IMAGE_NAME] [--disable-i
[--tls-certfile TLS_CERTFILE] [--image-type {conda,container,venv}] [--tls-certfile TLS_CERTFILE] [--image-type {conda,container,venv}]
config config
start the server for a Llama Stack Distribution. You should have already built (or downloaded) and configured the distribution. Start the server for a Llama Stack Distribution. You should have already built (or downloaded) and configured the distribution.
positional arguments: positional arguments:
config Path to config file to use for the run config Path to config file to use for the run

View file

@ -41,12 +41,31 @@ The following environment variables can be configured:
## Prerequisite: Downloading Models ## Prerequisite: Downloading Models
Please make sure you have llama model checkpoints downloaded in `~/.llama` before proceeding. See [installation guide](https://llama-stack.readthedocs.io/en/latest/references/llama_cli_reference/download_models.html) here to download the models. Run `llama model list` to see the available models to download, and `llama model download` to download the checkpoints. Please use `llama model list --downloaded` to check that you have llama model checkpoints downloaded in `~/.llama` before proceeding. See [installation guide](https://llama-stack.readthedocs.io/en/latest/references/llama_cli_reference/download_models.html) here to download the models. Run `llama model list` to see the available models to download, and `llama model download` to download the checkpoints.
``` ```
$ ls ~/.llama/checkpoints $ llama model list --downloaded
Llama3.1-8B Llama3.2-11B-Vision-Instruct Llama3.2-1B-Instruct Llama3.2-90B-Vision-Instruct Llama-Guard-3-8B ┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━┓
Llama3.1-8B-Instruct Llama3.2-1B Llama3.2-3B-Instruct Llama-Guard-3-1B Prompt-Guard-86M ┃ Model ┃ Size ┃ Modified Time ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━┩
│ Llama3.2-1B-Instruct:int4-qlora-eo8 │ 1.53 GB │ 2025-02-26 11:22:28 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama3.2-1B │ 2.31 GB │ 2025-02-18 21:48:52 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Prompt-Guard-86M │ 0.02 GB │ 2025-02-26 11:29:28 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama3.2-3B-Instruct:int4-spinquant-eo8 │ 3.69 GB │ 2025-02-26 11:37:41 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama3.2-3B │ 5.99 GB │ 2025-02-18 21:51:26 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama3.1-8B │ 14.97 GB │ 2025-02-16 10:36:37 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama3.2-1B-Instruct:int4-spinquant-eo8 │ 1.51 GB │ 2025-02-26 11:35:02 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama-Guard-3-1B │ 2.80 GB │ 2025-02-26 11:20:46 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama-Guard-3-1B:int4 │ 0.43 GB │ 2025-02-26 11:33:33 │
└─────────────────────────────────────────┴──────────┴─────────────────────┘
``` ```
## Running the Distribution ## Running the Distribution

View file

@ -41,12 +41,31 @@ The following environment variables can be configured:
## Prerequisite: Downloading Models ## Prerequisite: Downloading Models
Please make sure you have llama model checkpoints downloaded in `~/.llama` before proceeding. See [installation guide](https://llama-stack.readthedocs.io/en/latest/references/llama_cli_reference/download_models.html) here to download the models. Run `llama model list` to see the available models to download, and `llama model download` to download the checkpoints. Please use `llama model list --downloaded` to check that you have llama model checkpoints downloaded in `~/.llama` before proceeding. See [installation guide](https://llama-stack.readthedocs.io/en/latest/references/llama_cli_reference/download_models.html) here to download the models. Run `llama model list` to see the available models to download, and `llama model download` to download the checkpoints.
``` ```
$ ls ~/.llama/checkpoints $ llama model list --downloaded
Llama3.1-8B Llama3.2-11B-Vision-Instruct Llama3.2-1B-Instruct Llama3.2-90B-Vision-Instruct Llama-Guard-3-8B ┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━┓
Llama3.1-8B-Instruct Llama3.2-1B Llama3.2-3B-Instruct Llama-Guard-3-1B Prompt-Guard-86M ┃ Model ┃ Size ┃ Modified Time ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━┩
│ Llama3.2-1B-Instruct:int4-qlora-eo8 │ 1.53 GB │ 2025-02-26 11:22:28 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama3.2-1B │ 2.31 GB │ 2025-02-18 21:48:52 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Prompt-Guard-86M │ 0.02 GB │ 2025-02-26 11:29:28 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama3.2-3B-Instruct:int4-spinquant-eo8 │ 3.69 GB │ 2025-02-26 11:37:41 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama3.2-3B │ 5.99 GB │ 2025-02-18 21:51:26 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama3.1-8B │ 14.97 GB │ 2025-02-16 10:36:37 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama3.2-1B-Instruct:int4-spinquant-eo8 │ 1.51 GB │ 2025-02-26 11:35:02 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama-Guard-3-1B │ 2.80 GB │ 2025-02-26 11:20:46 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama-Guard-3-1B:int4 │ 0.43 GB │ 2025-02-26 11:33:33 │
└─────────────────────────────────────────┴──────────┴─────────────────────┘
``` ```
## Running the Distribution ## Running the Distribution

View file

@ -38,7 +38,7 @@ The API is **exactly identical** for both clients.
:::{dropdown} Starting up the Llama Stack server :::{dropdown} Starting up the Llama Stack server
The Llama Stack server can be configured flexibly so you can mix-and-match various providers for its individual API components -- beyond Inference, these include Vector IO, Agents, Telemetry, Evals, Post Training, etc. The Llama Stack server can be configured flexibly so you can mix-and-match various providers for its individual API components -- beyond Inference, these include Vector IO, Agents, Telemetry, Evals, Post Training, etc.
To get started quickly, we provide various container images for the server component that work with different inference providers out of the box. For this guide, we will use `llamastack/distribution-ollama` as the container image. To get started quickly, we provide various container images for the server component that work with different inference providers out of the box. For this guide, we will use `llamastack/distribution-ollama` as the container image. If you'd like to build your own image or customize the configurations, please check out [this guide](../references/index.md).
Lets setup some environment variables that we will use in the rest of the guide. Lets setup some environment variables that we will use in the rest of the guide.
```bash ```bash

View file

@ -129,3 +129,35 @@ llama download --source huggingface --model-id Prompt-Guard-86M --ignore-pattern
**Important:** Set your environment variable `HF_TOKEN` or pass in `--hf-token` to the command to validate your access. You can find your token at [https://huggingface.co/settings/tokens](https://huggingface.co/settings/tokens). **Important:** Set your environment variable `HF_TOKEN` or pass in `--hf-token` to the command to validate your access. You can find your token at [https://huggingface.co/settings/tokens](https://huggingface.co/settings/tokens).
> **Tip:** Default for `llama download` is to run with `--ignore-patterns *.safetensors` since we use the `.pth` files in the `original` folder. For Llama Guard and Prompt Guard, however, we need safetensors. Hence, please run with `--ignore-patterns original` so that safetensors are downloaded and `.pth` files are ignored. > **Tip:** Default for `llama download` is to run with `--ignore-patterns *.safetensors` since we use the `.pth` files in the `original` folder. For Llama Guard and Prompt Guard, however, we need safetensors. Hence, please run with `--ignore-patterns original` so that safetensors are downloaded and `.pth` files are ignored.
## List the downloaded models
You can list the downloaded models with the following command:
```
llama model list --downloaded
```
You should see a table like this:
```
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━┓
┃ Model ┃ Size ┃ Modified Time ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━┩
│ Llama3.2-1B-Instruct:int4-qlora-eo8 │ 1.53 GB │ 2025-02-26 11:22:28 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama3.2-1B │ 2.31 GB │ 2025-02-18 21:48:52 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Prompt-Guard-86M │ 0.02 GB │ 2025-02-26 11:29:28 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama3.2-3B-Instruct:int4-spinquant-eo8 │ 3.69 GB │ 2025-02-26 11:37:41 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama3.2-3B │ 5.99 GB │ 2025-02-18 21:51:26 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama3.1-8B │ 14.97 GB │ 2025-02-16 10:36:37 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama3.2-1B-Instruct:int4-spinquant-eo8 │ 1.51 GB │ 2025-02-26 11:35:02 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama-Guard-3-1B │ 2.80 GB │ 2025-02-26 11:20:46 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama-Guard-3-1B:int4 │ 0.43 GB │ 2025-02-26 11:33:33 │
└─────────────────────────────────────────┴──────────┴─────────────────────┘
```

View file

@ -154,6 +154,38 @@ llama download --source huggingface --model-id Prompt-Guard-86M --ignore-pattern
> **Tip:** Default for `llama download` is to run with `--ignore-patterns *.safetensors` since we use the `.pth` files in the `original` folder. For Llama Guard and Prompt Guard, however, we need safetensors. Hence, please run with `--ignore-patterns original` so that safetensors are downloaded and `.pth` files are ignored. > **Tip:** Default for `llama download` is to run with `--ignore-patterns *.safetensors` since we use the `.pth` files in the `original` folder. For Llama Guard and Prompt Guard, however, we need safetensors. Hence, please run with `--ignore-patterns original` so that safetensors are downloaded and `.pth` files are ignored.
## List the downloaded models
You can list the downloaded models with the following command:
```
llama model list --downloaded
```
You should see a table like this:
```
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━┓
┃ Model ┃ Size ┃ Modified Time ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━┩
│ Llama3.2-1B-Instruct:int4-qlora-eo8 │ 1.53 GB │ 2025-02-26 11:22:28 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama3.2-1B │ 2.31 GB │ 2025-02-18 21:48:52 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Prompt-Guard-86M │ 0.02 GB │ 2025-02-26 11:29:28 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama3.2-3B-Instruct:int4-spinquant-eo8 │ 3.69 GB │ 2025-02-26 11:37:41 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama3.2-3B │ 5.99 GB │ 2025-02-18 21:51:26 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama3.1-8B │ 14.97 GB │ 2025-02-16 10:36:37 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama3.2-1B-Instruct:int4-spinquant-eo8 │ 1.51 GB │ 2025-02-26 11:35:02 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama-Guard-3-1B │ 2.80 GB │ 2025-02-26 11:20:46 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama-Guard-3-1B:int4 │ 0.43 GB │ 2025-02-26 11:33:33 │
└─────────────────────────────────────────┴──────────┴─────────────────────┘
```
## Understand the models ## Understand the models
The `llama model` command helps you explore the models interface. The `llama model` command helps you explore the models interface.

View file

@ -5,6 +5,9 @@
# the root directory of this source tree. # the root directory of this source tree.
from enum import Enum from enum import Enum
from typing import Optional
from pydantic import BaseModel
from llama_stack.schema_utils import json_schema_type from llama_stack.schema_utils import json_schema_type
@ -33,3 +36,20 @@ class Api(Enum):
# built-in API # built-in API
inspect = "inspect" inspect = "inspect"
@json_schema_type
class Error(BaseModel):
"""
Error response from the API. Roughly follows RFC 7807.
:param status: HTTP status code
:param title: Error title, a short summary of the error which is invariant for an error type
:param detail: Error detail, a longer human-readable description of the error
:param instance: (Optional) A URL which can be used to retrieve more information about the specific occurrence of the error
"""
status: int
title: str
detail: str
instance: Optional[str] = None
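
A quick sketch of how this model serializes, assuming pydantic v2's `model_dump_json`; the `@json_schema_type` decorator is omitted here since it only affects schema generation:

```python
from typing import Optional

from pydantic import BaseModel


class Error(BaseModel):
    status: int
    title: str
    detail: str
    instance: Optional[str] = None


# RFC 7807-style payload, matching the example bodies used for the standard error responses.
err = Error(status=429, title="Too Many Requests", detail="Rate limit exceeded; retry later.")
print(err.model_dump_json(exclude_none=True))
# -> {"status":429,"title":"Too Many Requests","detail":"Rate limit exceeded; retry later."}
```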

View file

@ -9,6 +9,7 @@ import textwrap
from io import StringIO from io import StringIO
from llama_stack.cli.subcommand import Subcommand from llama_stack.cli.subcommand import Subcommand
from llama_stack.cli.table import print_table
from llama_stack.models.llama.datatypes import CoreModelId, ModelFamily, is_multimodal, model_family from llama_stack.models.llama.datatypes import CoreModelId, ModelFamily, is_multimodal, model_family
@ -48,7 +49,26 @@ class ModelPromptFormat(Subcommand):
supported_model_ids = [ supported_model_ids = [
m for m in CoreModelId if model_family(m) in {ModelFamily.llama3_1, ModelFamily.llama3_2} m for m in CoreModelId if model_family(m) in {ModelFamily.llama3_1, ModelFamily.llama3_2}
] ]
model_str = "\n".join([m.value for m in supported_model_ids])
model_list = [m.value for m in supported_model_ids]
model_str = "\n".join(model_list)
if args.list:
headers = ["Model(s)"]
rows = []
for m in model_list:
rows.append(
[
m,
]
)
print_table(
rows,
headers,
separate_rows=True,
)
return
try: try:
model_id = CoreModelId(args.model_name) model_id = CoreModelId(args.model_name)
except ValueError: except ValueError:

View file

@ -141,7 +141,7 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
completer=WordCompleter(available_providers), completer=WordCompleter(available_providers),
complete_while_typing=True, complete_while_typing=True,
validator=Validator.from_callable( validator=Validator.from_callable(
lambda x: x in available_providers, lambda x: x in available_providers, # noqa: B023 - see https://github.com/astral-sh/ruff/issues/7847
error_message="Invalid provider, use <TAB> to see options", error_message="Invalid provider, use <TAB> to see options",
), ),
) )

View file

@ -112,7 +112,7 @@ def test_parse_and_maybe_upgrade_config_old_format(old_config):
inference_providers = result.providers["inference"] inference_providers = result.providers["inference"]
assert len(inference_providers) == 2 assert len(inference_providers) == 2
assert set(x.provider_id for x in inference_providers) == { assert {x.provider_id for x in inference_providers} == {
"remote::ollama-00", "remote::ollama-00",
"meta-reference-01", "meta-reference-01",
} }

View file

@ -15,7 +15,6 @@ from termcolor import cprint
from llama_stack.distribution.datatypes import BuildConfig, Provider from llama_stack.distribution.datatypes import BuildConfig, Provider
from llama_stack.distribution.distribution import get_provider_registry from llama_stack.distribution.distribution import get_provider_registry
from llama_stack.distribution.utils.config_dirs import BUILDS_BASE_DIR
from llama_stack.distribution.utils.exec import run_command, run_with_pty from llama_stack.distribution.utils.exec import run_command, run_with_pty
from llama_stack.distribution.utils.image_types import ImageType from llama_stack.distribution.utils.image_types import ImageType
from llama_stack.providers.datatypes import Api from llama_stack.providers.datatypes import Api
@ -103,8 +102,6 @@ def build_image(
template_or_config, template_or_config,
image_name, image_name,
container_base, container_base,
str(build_file_path),
str(BUILDS_BASE_DIR / ImageType.container.value),
" ".join(normal_deps), " ".join(normal_deps),
] ]
elif build_config.image_type == ImageType.conda.value: elif build_config.image_type == ImageType.conda.value:

View file

@ -52,7 +52,7 @@ ensure_conda_env_python310() {
local python_version="3.10" local python_version="3.10"
# Check if conda command is available # Check if conda command is available
if ! command -v conda &>/dev/null; then if ! is_command_available conda; then
printf "${RED}Error: conda command not found. Is Conda installed and in your PATH?${NC}" >&2 printf "${RED}Error: conda command not found. Is Conda installed and in your PATH?${NC}" >&2
exit 1 exit 1
fi fi

View file

@ -1,4 +1,4 @@
#!/bin/bash #!/usr/bin/env bash
# Copyright (c) Meta Platforms, Inc. and affiliates. # Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved. # All rights reserved.
@ -20,26 +20,27 @@ UV_HTTP_TIMEOUT=${UV_HTTP_TIMEOUT:-500}
# mounting is not supported by docker buildx, so we use COPY instead # mounting is not supported by docker buildx, so we use COPY instead
USE_COPY_NOT_MOUNT=${USE_COPY_NOT_MOUNT:-} USE_COPY_NOT_MOUNT=${USE_COPY_NOT_MOUNT:-}
if [ "$#" -lt 6 ]; then if [ "$#" -lt 4 ]; then
# This only works for templates # This only works for templates
echo "Usage: $0 <template_or_config> <image_name> <container_base> <build_file_path> <host_build_dir> <pip_dependencies> [<special_pip_deps>]" >&2 echo "Usage: $0 <template_or_config> <image_name> <container_base> <pip_dependencies> [<special_pip_deps>]" >&2
exit 1 exit 1
fi fi
set -euo pipefail set -euo pipefail
template_or_config="$1" template_or_config="$1"
image_name="$2" shift
container_base="$3" image_name="$1"
build_file_path="$4" shift
host_build_dir="$5" container_base="$1"
pip_dependencies="$6" shift
special_pip_deps="${7:-}" pip_dependencies="$1"
shift
special_pip_deps="${1:-}"
# Define color codes # Define color codes
RED='\033[0;31m' RED='\033[0;31m'
GREEN='\033[0;32m'
NC='\033[0m' # No Color NC='\033[0m' # No Color
CONTAINER_BINARY=${CONTAINER_BINARY:-docker} CONTAINER_BINARY=${CONTAINER_BINARY:-docker}
@ -47,8 +48,10 @@ CONTAINER_OPTS=${CONTAINER_OPTS:-}
TEMP_DIR=$(mktemp -d) TEMP_DIR=$(mktemp -d)
SCRIPT_DIR=$(dirname "$(readlink -f "$0")")
source "$SCRIPT_DIR/common.sh"
add_to_container() { add_to_container() {
local input
output_file="$TEMP_DIR/Containerfile" output_file="$TEMP_DIR/Containerfile"
if [ -t 0 ]; then if [ -t 0 ]; then
printf '%s\n' "$1" >>"$output_file" printf '%s\n' "$1" >>"$output_file"
@ -58,15 +61,21 @@ add_to_container() {
fi fi
} }
# Check if container command is available
if ! is_command_available $CONTAINER_BINARY; then
printf "${RED}Error: ${CONTAINER_BINARY} command not found. Is ${CONTAINER_BINARY} installed and in your PATH?${NC}" >&2
exit 1
fi
# Update and install UBI9 components if UBI9 base image is used # Update and install UBI9 components if UBI9 base image is used
if [[ $container_base == *"registry.access.redhat.com/ubi9"* ]]; then if [[ $container_base == *"registry.access.redhat.com/ubi9"* ]]; then
add_to_container << EOF add_to_container << EOF
FROM $container_base FROM $container_base
WORKDIR /app WORKDIR /app
RUN microdnf -y update && microdnf install -y iputils net-tools wget \ RUN dnf -y update && dnf install -y iputils net-tools wget \
vim-minimal python3.11 python3.11-pip python3.11-wheel \ vim-minimal python3.11 python3.11-pip python3.11-wheel \
python3.11-setuptools && ln -s /bin/pip3.11 /bin/pip && ln -s /bin/python3.11 /bin/python && microdnf clean all python3.11-setuptools && ln -s /bin/pip3.11 /bin/pip && ln -s /bin/python3.11 /bin/python && dnf clean all
ENV UV_SYSTEM_PYTHON=1 ENV UV_SYSTEM_PYTHON=1
RUN pip install uv RUN pip install uv
@ -165,6 +174,11 @@ EOF
fi fi
fi fi
# remove uv after installation
add_to_container << EOF
RUN pip uninstall -y uv
EOF
# if template_or_config ends with .yaml, it is not a template and we should not use the --template flag # if template_or_config ends with .yaml, it is not a template and we should not use the --template flag
if [[ "$template_or_config" != *.yaml ]]; then if [[ "$template_or_config" != *.yaml ]]; then
add_to_container << EOF add_to_container << EOF
@ -185,26 +199,31 @@ RUN mkdir -p /.llama /.cache
RUN chmod -R g+rw /app /.llama /.cache RUN chmod -R g+rw /app /.llama /.cache
EOF EOF
printf "Containerfile created successfully in $TEMP_DIR/Containerfile\n\n" printf "Containerfile created successfully in %s/Containerfile\n\n" "$TEMP_DIR"
cat $TEMP_DIR/Containerfile cat "$TEMP_DIR"/Containerfile
printf "\n" printf "\n"
mounts="" # Start building the CLI arguments
CLI_ARGS=()
# Read CONTAINER_OPTS and put it in an array
read -ra CLI_ARGS <<< "$CONTAINER_OPTS"
if [ "$USE_COPY_NOT_MOUNT" != "true" ]; then if [ "$USE_COPY_NOT_MOUNT" != "true" ]; then
if [ -n "$LLAMA_STACK_DIR" ]; then if [ -n "$LLAMA_STACK_DIR" ]; then
mounts="$mounts -v $(readlink -f $LLAMA_STACK_DIR):$stack_mount" CLI_ARGS+=("-v" "$(readlink -f "$LLAMA_STACK_DIR"):$stack_mount")
fi fi
if [ -n "$LLAMA_MODELS_DIR" ]; then if [ -n "$LLAMA_MODELS_DIR" ]; then
mounts="$mounts -v $(readlink -f $LLAMA_MODELS_DIR):$models_mount" CLI_ARGS+=("-v" "$(readlink -f "$LLAMA_MODELS_DIR"):$models_mount")
fi fi
if [ -n "$LLAMA_STACK_CLIENT_DIR" ]; then if [ -n "$LLAMA_STACK_CLIENT_DIR" ]; then
mounts="$mounts -v $(readlink -f $LLAMA_STACK_CLIENT_DIR):$client_mount" CLI_ARGS+=("-v" "$(readlink -f "$LLAMA_STACK_CLIENT_DIR"):$client_mount")
fi fi
fi fi
if command -v selinuxenabled &>/dev/null && selinuxenabled; then if is_command_available selinuxenabled && selinuxenabled; then
# Disable SELinux labels -- we don't want to relabel the llama-stack source dir # Disable SELinux labels -- we don't want to relabel the llama-stack source dir
CONTAINER_OPTS="$CONTAINER_OPTS --security-opt label=disable" CLI_ARGS+=("--security-opt" "label=disable")
fi fi
# Set version tag based on PyPI version # Set version tag based on PyPI version
@ -225,11 +244,11 @@ image_tag="$image_name:$version_tag"
# Detect platform architecture # Detect platform architecture
ARCH=$(uname -m) ARCH=$(uname -m)
if [ -n "$BUILD_PLATFORM" ]; then if [ -n "$BUILD_PLATFORM" ]; then
PLATFORM="--platform $BUILD_PLATFORM" CLI_ARGS+=("--platform $BUILD_PLATFORM")
elif [ "$ARCH" = "arm64" ] || [ "$ARCH" = "aarch64" ]; then elif [ "$ARCH" = "arm64" ] || [ "$ARCH" = "aarch64" ]; then
PLATFORM="--platform linux/arm64" CLI_ARGS+=("--platform" "linux/arm64")
elif [ "$ARCH" = "x86_64" ]; then elif [ "$ARCH" = "x86_64" ]; then
PLATFORM="--platform linux/amd64" CLI_ARGS+=("--platform" "linux/amd64")
else else
echo "Unsupported architecture: $ARCH" echo "Unsupported architecture: $ARCH"
exit 1 exit 1
@ -238,8 +257,13 @@ fi
echo "PWD: $(pwd)" echo "PWD: $(pwd)"
echo "Containerfile: $TEMP_DIR/Containerfile" echo "Containerfile: $TEMP_DIR/Containerfile"
set -x set -x
$CONTAINER_BINARY build $CONTAINER_OPTS $PLATFORM -t $image_tag \
-f "$TEMP_DIR/Containerfile" "." $mounts --progress=plain $CONTAINER_BINARY build \
"${CLI_ARGS[@]}" \
-t "$image_tag" \
-f "$TEMP_DIR/Containerfile" \
"." \
--progress=plain
# clean up tmp/configs # clean up tmp/configs
set +x set +x

View file

@ -13,7 +13,7 @@ from llama_stack.providers.datatypes import Api, ProviderSpec
def stack_apis() -> List[Api]: def stack_apis() -> List[Api]:
return [v for v in Api] return list(Api)
class AutoRoutedApiInfo(BaseModel): class AutoRoutedApiInfo(BaseModel):
@ -55,7 +55,7 @@ def builtin_automatically_routed_apis() -> List[AutoRoutedApiInfo]:
def providable_apis() -> List[Api]: def providable_apis() -> List[Api]:
routing_table_apis = set(x.routing_table_api for x in builtin_automatically_routed_apis()) routing_table_apis = {x.routing_table_api for x in builtin_automatically_routed_apis()}
return [api for api in Api if api not in routing_table_apis and api != Api.inspect] return [api for api in Api if api not in routing_table_apis and api != Api.inspect]
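
These rewrites are behavior-preserving; a tiny illustration of why the literal and comprehension forms are preferred over wrapping a list or generator in `set(...)` (values are made up):

```python
values = ["inference", "safety", "inference"]

as_set_old = set([v for v in values])  # builds a throwaway list first
as_set_new = {v for v in values}       # same result, no intermediate list
assert as_set_old == as_set_new == {"inference", "safety"}
```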

View file

@ -115,8 +115,8 @@ async def resolve_impls(
- flatmaps, sorts and resolves the providers in dependency order - flatmaps, sorts and resolves the providers in dependency order
- for each API, produces either a (local, passthrough or router) implementation - for each API, produces either a (local, passthrough or router) implementation
""" """
routing_table_apis = set(x.routing_table_api for x in builtin_automatically_routed_apis()) routing_table_apis = {x.routing_table_api for x in builtin_automatically_routed_apis()}
router_apis = set(x.router_api for x in builtin_automatically_routed_apis()) router_apis = {x.router_api for x in builtin_automatically_routed_apis()}
providers_with_specs = {} providers_with_specs = {}

View file

@ -318,13 +318,6 @@ class VectorDBsRoutingTable(CommonRoutingTableImpl, VectorDBs):
) )
model = await self.get_object_by_identifier("model", embedding_model) model = await self.get_object_by_identifier("model", embedding_model)
if model is None: if model is None:
if embedding_model == "all-MiniLM-L6-v2":
raise ValueError(
"Embeddings are now served via Inference providers. "
"Please upgrade your run.yaml to include inline::sentence-transformer as an additional inference provider. "
"See https://github.com/meta-llama/llama-stack/blob/main/llama_stack/templates/together/run.yaml for an example."
)
else:
raise ValueError(f"Model {embedding_model} not found") raise ValueError(f"Model {embedding_model} not found")
if model.model_type != ModelType.embedding: if model.model_type != ModelType.embedding:
raise ValueError(f"Model {embedding_model} is not an embedding model") raise ValueError(f"Model {embedding_model} is not an embedding model")

View file

@ -134,7 +134,7 @@ def rag_chat_page():
dict( dict(
name="builtin::rag/knowledge_search", name="builtin::rag/knowledge_search",
args={ args={
"vector_db_ids": [vector_db_id for vector_db_id in selected_vector_dbs], "vector_db_ids": list(selected_vector_dbs),
}, },
) )
], ],

View file

@ -46,7 +46,7 @@ def formulate_run_args(image_type, image_name, config, template_name) -> list:
conda_env_info = json.loads(subprocess.check_output(["conda", "info", "--envs", "--json"]).decode()) conda_env_info = json.loads(subprocess.check_output(["conda", "info", "--envs", "--json"]).decode())
envs = conda_env_info["envs"] envs = conda_env_info["envs"]
for envpath in envs: for envpath in envs:
if envpath.endswith(env_name): if os.path.basename(envpath) == env_name:
return envpath return envpath
return None return None
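
The `endswith` check could match an unrelated conda environment whose name merely ends with the target string; comparing the basename is exact. A small illustration of the failure mode (paths are hypothetical):

```python
import os

envs = ["/opt/conda/envs/other-stack", "/opt/conda/envs/stack"]
env_name = "stack"

# endswith() is too loose: "other-stack" also ends with "stack".
assert [p for p in envs if p.endswith(env_name)] == envs
# Comparing the basename only matches the intended environment.
assert [p for p in envs if os.path.basename(p) == env_name] == ["/opt/conda/envs/stack"]
```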

View file

@ -226,10 +226,9 @@ class FunctionTagCustomToolGenerator(PromptTemplateGeneratorBase):
class PythonListCustomToolGenerator(PromptTemplateGeneratorBase): # noqa: N801 class PythonListCustomToolGenerator(PromptTemplateGeneratorBase): # noqa: N801
DEFAULT_PROMPT = textwrap.dedent( DEFAULT_PROMPT = textwrap.dedent(
""" """
You are a helpful assistant. You have access to functions, but you should only use them if they are required.
You are an expert in composing functions. You are given a question and a set of possible functions. You are an expert in composing functions. You are given a question and a set of possible functions.
Based on the question, you will need to make one or more function/tool calls to achieve the purpose. Based on the question, you may or may not need to make one function/tool call to achieve the purpose.
If none of the function can be used, point it out. If the given question lacks the parameters required by the function,
also point it out. You should only return the function call in tools call sections.
{{ function_description }} {{ function_description }}
""".strip("\n") """.strip("\n")

View file

@ -611,8 +611,17 @@ class ChatAgent(ShieldRunnerMixin):
if event.stop_reason is not None: if event.stop_reason is not None:
stop_reason = event.stop_reason stop_reason = event.stop_reason
span.set_attribute("stop_reason", stop_reason) span.set_attribute("stop_reason", stop_reason)
span.set_attribute("input", [m.model_dump_json() for m in input_messages]) span.set_attribute(
span.set_attribute("output", f"content: {content} tool_calls: {tool_calls}") "input",
json.dumps([json.loads(m.model_dump_json()) for m in input_messages]),
)
output_attr = json.dumps(
{
"content": content,
"tool_calls": [json.loads(t.model_dump_json()) for t in tool_calls],
}
)
span.set_attribute("output", output_attr)
n_iter += 1 n_iter += 1
await self.storage.set_num_infer_iters_in_turn(session_id, turn_id, n_iter) await self.storage.set_num_infer_iters_in_turn(session_id, turn_id, n_iter)
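
The change above serializes structured values to JSON strings before attaching them as span attributes, presumably because span attributes need to be plain, JSON-safe strings rather than lists of model objects. A minimal sketch with a stand-in pydantic model (not the real `ToolCall` type):

```python
import json

from pydantic import BaseModel


class FakeToolCall(BaseModel):  # stand-in for the real ToolCall type
    tool_name: str
    arguments: dict


tool_calls = [FakeToolCall(tool_name="get_weather", arguments={"city": "Paris"})]
content = "Let me check the weather."

# Round-trip each pydantic object through model_dump_json() so the final value
# is a single JSON string containing only JSON-safe types.
output_attr = json.dumps(
    {
        "content": content,
        "tool_calls": [json.loads(t.model_dump_json()) for t in tool_calls],
    }
)
assert json.loads(output_attr)["tool_calls"][0]["tool_name"] == "get_weather"
```
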
@ -796,10 +805,10 @@ class ChatAgent(ShieldRunnerMixin):
self, toolgroups_for_turn: Optional[List[AgentToolGroup]] = None self, toolgroups_for_turn: Optional[List[AgentToolGroup]] = None
) -> Tuple[List[ToolDefinition], Dict[str, str]]: ) -> Tuple[List[ToolDefinition], Dict[str, str]]:
# Determine which tools to include # Determine which tools to include
agent_config_toolgroups = set( agent_config_toolgroups = {
(toolgroup.name if isinstance(toolgroup, AgentToolGroupWithArgs) else toolgroup) toolgroup.name if isinstance(toolgroup, AgentToolGroupWithArgs) else toolgroup
for toolgroup in self.agent_config.toolgroups for toolgroup in self.agent_config.toolgroups
) }
toolgroups_for_turn_set = ( toolgroups_for_turn_set = (
agent_config_toolgroups agent_config_toolgroups
if toolgroups_for_turn is None if toolgroups_for_turn is None

View file

@ -3,6 +3,7 @@
# #
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
import json
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
from tqdm import tqdm from tqdm import tqdm
@ -86,7 +87,6 @@ class MetaReferenceEvalImpl(
) -> Job: ) -> Job:
task_def = self.benchmarks[benchmark_id] task_def = self.benchmarks[benchmark_id]
dataset_id = task_def.dataset_id dataset_id = task_def.dataset_id
candidate = task_config.eval_candidate
scoring_functions = task_def.scoring_functions scoring_functions = task_def.scoring_functions
dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id) dataset_def = await self.datasets_api.get_dataset(dataset_id=dataset_id)
validate_dataset_schema(dataset_def.dataset_schema, get_valid_schemas(Api.eval.value)) validate_dataset_schema(dataset_def.dataset_schema, get_valid_schemas(Api.eval.value))
@ -117,7 +117,7 @@ class MetaReferenceEvalImpl(
generations = [] generations = []
for i, x in tqdm(enumerate(input_rows)): for i, x in tqdm(enumerate(input_rows)):
assert ColumnName.chat_completion_input.value in x, "Invalid input row" assert ColumnName.chat_completion_input.value in x, "Invalid input row"
input_messages = eval(str(x[ColumnName.chat_completion_input.value])) input_messages = json.loads(x[ColumnName.chat_completion_input.value])
input_messages = [UserMessage(**x) for x in input_messages] input_messages = [UserMessage(**x) for x in input_messages]
# NOTE: only single-turn agent generation is supported. Create a new session for each input row # NOTE: only single-turn agent generation is supported. Create a new session for each input row
@ -159,7 +159,7 @@ class MetaReferenceEvalImpl(
generations = [] generations = []
for x in tqdm(input_rows): for x in tqdm(input_rows):
if ColumnName.completion_input.value in x: if ColumnName.completion_input.value in x:
input_content = eval(str(x[ColumnName.completion_input.value])) input_content = json.loads(x[ColumnName.completion_input.value])
response = await self.inference_api.completion( response = await self.inference_api.completion(
model=candidate.model, model=candidate.model,
content=input_content, content=input_content,
@ -167,9 +167,8 @@ class MetaReferenceEvalImpl(
) )
generations.append({ColumnName.generated_answer.value: response.completion_message.content}) generations.append({ColumnName.generated_answer.value: response.completion_message.content})
elif ColumnName.chat_completion_input.value in x: elif ColumnName.chat_completion_input.value in x:
chat_completion_input_str = str(x[ColumnName.chat_completion_input.value]) chat_completion_input_json = json.loads(x[ColumnName.chat_completion_input.value])
input_messages = eval(chat_completion_input_str) input_messages = [UserMessage(**x) for x in chat_completion_input_json]
input_messages = [UserMessage(**x) for x in input_messages]
messages = [] messages = []
if candidate.system_message: if candidate.system_message:
messages.append(candidate.system_message) messages.append(candidate.system_message)
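
Replacing `eval` with `json.loads` above means dataset rows are parsed as data rather than executed as Python expressions. A small sketch with a hypothetical row:

```python
import json

# Hypothetical dataset row: chat_completion_input is stored as a JSON string.
row = {"chat_completion_input": '[{"role": "user", "content": "What is 2 + 2?"}]'}

# json.loads only parses JSON; unlike eval(), it cannot run code embedded in the row.
input_messages = json.loads(row["chat_completion_input"])
assert input_messages == [{"role": "user", "content": "What is 2 + 2?"}]
# The real code then wraps each dict, e.g. [UserMessage(**m) for m in input_messages].
```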

View file

@ -208,7 +208,6 @@ class MetaReferenceInferenceImpl(
logprobs = [] logprobs = []
stop_reason = None stop_reason = None
tokenizer = self.generator.formatter.tokenizer
for token_result in self.generator.completion(request): for token_result in self.generator.completion(request):
tokens.append(token_result.token) tokens.append(token_result.token)
if token_result.text == "<|eot_id|>": if token_result.text == "<|eot_id|>":

View file

@ -207,7 +207,7 @@ def maybe_parse_message(maybe_json: Optional[str]) -> Optional[ProcessingMessage
return parse_message(maybe_json) return parse_message(maybe_json)
except json.JSONDecodeError: except json.JSONDecodeError:
return None return None
except ValueError as e: except ValueError:
return None return None
@ -352,7 +352,7 @@ class ModelParallelProcessGroup:
if isinstance(obj, TaskResponse): if isinstance(obj, TaskResponse):
yield obj.result yield obj.result
except GeneratorExit as e: except GeneratorExit:
self.request_socket.send(encode_msg(CancelSentinel())) self.request_socket.send(encode_msg(CancelSentinel()))
while True: while True:
obj_json = self.request_socket.send() obj_json = self.request_socket.send()

View file

@ -7,6 +7,9 @@
# Copyright (c) Meta Platforms, Inc. and affiliates. # Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement. # This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
# The file gets a special treatment for now?
# ruff: noqa: N803
import unittest import unittest
import torch import torch

View file

@ -10,16 +10,19 @@
# This source code is licensed under the BSD-style license found in the # This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree. # LICENSE file in the root directory of this source tree.
import json
from typing import Any, Mapping from typing import Any, Mapping
from llama_stack.providers.utils.common.data_schema_validator import ColumnName from llama_stack.providers.utils.common.data_schema_validator import ColumnName
def llama_stack_instruct_to_torchtune_instruct(sample: Mapping[str, Any]) -> Mapping[str, Any]: def llama_stack_instruct_to_torchtune_instruct(
sample: Mapping[str, Any],
) -> Mapping[str, Any]:
assert ColumnName.chat_completion_input.value in sample and ColumnName.expected_answer.value in sample, ( assert ColumnName.chat_completion_input.value in sample and ColumnName.expected_answer.value in sample, (
"Invalid input row" "Invalid input row"
) )
input_messages = eval(str(sample[ColumnName.chat_completion_input.value])) input_messages = json.loads(sample[ColumnName.chat_completion_input.value])
assert len(input_messages) == 1, "llama stack intruct dataset format only supports 1 user message" assert len(input_messages) == 1, "llama stack intruct dataset format only supports 1 user message"
input_message = input_messages[0] input_message = input_messages[0]
@ -37,7 +40,7 @@ def llama_stack_instruct_to_torchtune_instruct(sample: Mapping[str, Any]) -> Map
def llama_stack_chat_to_torchtune_chat(sample: Mapping[str, Any]) -> Mapping[str, Any]: def llama_stack_chat_to_torchtune_chat(sample: Mapping[str, Any]) -> Mapping[str, Any]:
assert ColumnName.dialog.value in sample, "Invalid input row" assert ColumnName.dialog.value in sample, "Invalid input row"
role_map = {"user": "human", "assistant": "gpt"} role_map = {"user": "human", "assistant": "gpt"}
dialog = eval(str(sample[ColumnName.dialog.value])) dialog = json.loads(sample[ColumnName.dialog.value])
assert len(dialog) > 1, "dialog must have at least 2 messagse" assert len(dialog) > 1, "dialog must have at least 2 messagse"
roles = [] roles = []

View file

@ -264,7 +264,7 @@ class LoraFinetuningSingleDevice:
) )
self.adapter_params = get_adapter_params(model) self.adapter_params = get_adapter_params(model)
self._is_dora = any(["magnitude" in k for k in self.adapter_params.keys()]) self._is_dora = any("magnitude" in k for k in self.adapter_params.keys())
set_trainable_params(model, self.adapter_params) set_trainable_params(model, self.adapter_params)

View file

@ -133,7 +133,7 @@ class BraintrustScoringImpl(
async def shutdown(self) -> None: ... async def shutdown(self) -> None: ...
async def list_scoring_functions(self) -> List[ScoringFn]: async def list_scoring_functions(self) -> List[ScoringFn]:
scoring_fn_defs_list = [x for x in self.supported_fn_defs_registry.values()] scoring_fn_defs_list = list(self.supported_fn_defs_registry.values())
for f in scoring_fn_defs_list: for f in scoring_fn_defs_list:
assert f.identifier.startswith("braintrust"), ( assert f.identifier.startswith("braintrust"), (
"All braintrust scoring fn must have identifier prefixed with 'braintrust'! " "All braintrust scoring fn must have identifier prefixed with 'braintrust'! "

View file

@ -198,7 +198,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
tool_config: Optional[ToolConfig] = None, tool_config: Optional[ToolConfig] = None,
) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]: ) -> Union[ChatCompletionResponse, AsyncIterator[ChatCompletionResponseStreamChunk]]:
if tool_prompt_format: if tool_prompt_format:
warnings.warn("tool_prompt_format is not supported by NVIDIA NIM, ignoring") warnings.warn("tool_prompt_format is not supported by NVIDIA NIM, ignoring", stacklevel=2)
await check_health(self._config) # this raises errors await check_health(self._config) # this raises errors

View file

@ -106,7 +106,7 @@ async def convert_chat_completion_request(
payload.update(temperature=strategy.temperature) payload.update(temperature=strategy.temperature)
elif isinstance(strategy, TopKSamplingStrategy): elif isinstance(strategy, TopKSamplingStrategy):
if strategy.top_k != -1 and strategy.top_k < 1: if strategy.top_k != -1 and strategy.top_k < 1:
warnings.warn("top_k must be -1 or >= 1") warnings.warn("top_k must be -1 or >= 1", stacklevel=2)
nvext.update(top_k=strategy.top_k) nvext.update(top_k=strategy.top_k)
elif isinstance(strategy, GreedySamplingStrategy): elif isinstance(strategy, GreedySamplingStrategy):
nvext.update(top_k=-1) nvext.update(top_k=-1)
@ -168,7 +168,7 @@ def convert_completion_request(
payload.update(top_p=request.sampling_params.top_p) payload.update(top_p=request.sampling_params.top_p)
elif request.sampling_params.strategy == "top_k": elif request.sampling_params.strategy == "top_k":
if request.sampling_params.top_k != -1 and request.sampling_params.top_k < 1: if request.sampling_params.top_k != -1 and request.sampling_params.top_k < 1:
warnings.warn("top_k must be -1 or >= 1") warnings.warn("top_k must be -1 or >= 1", stacklevel=2)
nvext.update(top_k=request.sampling_params.top_k) nvext.update(top_k=request.sampling_params.top_k)
elif request.sampling_params.strategy == "greedy": elif request.sampling_params.strategy == "greedy":
nvext.update(top_k=-1) nvext.update(top_k=-1)
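
Adding `stacklevel=2` makes the warning point at the caller of the conversion helper rather than at the `warnings.warn` line itself, which is what common lint rules encourage. A self-contained illustration (the function name is made up):

```python
import warnings


def convert_top_k(top_k: int) -> int:
    # stacklevel=2 attributes the warning to the code that called convert_top_k(),
    # so users see the line in *their* code that passed the bad value.
    if top_k != -1 and top_k < 1:
        warnings.warn("top_k must be -1 or >= 1", stacklevel=2)
    return top_k


convert_top_k(0)  # the warning is reported against this line, not the warn() call above
```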

View file

@ -270,6 +270,12 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
tool_config: Optional[ToolConfig] = None, tool_config: Optional[ToolConfig] = None,
) -> AsyncGenerator: ) -> AsyncGenerator:
model = await self.model_store.get_model(model_id) model = await self.model_store.get_model(model_id)
# This is to be consistent with OpenAI API and support vLLM <= v0.6.3
# References:
# * https://platform.openai.com/docs/api-reference/chat/create#chat-create-tool_choice
# * https://github.com/vllm-project/vllm/pull/10000
if not tools and tool_config is not None:
tool_config.tool_choice = ToolChoice.none
request = ChatCompletionRequest( request = ChatCompletionRequest(
model=model.provider_resource_id, model=model.provider_resource_id,
messages=messages, messages=messages,

View file

@ -39,12 +39,11 @@ class Testeval:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_eval_evaluate_rows(self, eval_stack, inference_model, judge_model): async def test_eval_evaluate_rows(self, eval_stack, inference_model, judge_model):
eval_impl, benchmarks_impl, datasetio_impl, datasets_impl, models_impl = ( eval_impl, benchmarks_impl, datasetio_impl, datasets_impl = (
eval_stack[Api.eval], eval_stack[Api.eval],
eval_stack[Api.benchmarks], eval_stack[Api.benchmarks],
eval_stack[Api.datasetio], eval_stack[Api.datasetio],
eval_stack[Api.datasets], eval_stack[Api.datasets],
eval_stack[Api.models],
) )
await register_dataset(datasets_impl, for_generation=True, dataset_id="test_dataset_for_eval") await register_dataset(datasets_impl, for_generation=True, dataset_id="test_dataset_for_eval")
@ -92,11 +91,10 @@ class Testeval:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_eval_run_eval(self, eval_stack, inference_model, judge_model): async def test_eval_run_eval(self, eval_stack, inference_model, judge_model):
eval_impl, benchmarks_impl, datasets_impl, models_impl = ( eval_impl, benchmarks_impl, datasets_impl = (
eval_stack[Api.eval], eval_stack[Api.eval],
eval_stack[Api.benchmarks], eval_stack[Api.benchmarks],
eval_stack[Api.datasets], eval_stack[Api.datasets],
eval_stack[Api.models],
) )
await register_dataset(datasets_impl, for_generation=True, dataset_id="test_dataset_for_eval") await register_dataset(datasets_impl, for_generation=True, dataset_id="test_dataset_for_eval")
@ -131,11 +129,10 @@ class Testeval:
@pytest.mark.asyncio @pytest.mark.asyncio
async def test_eval_run_benchmark_eval(self, eval_stack, inference_model): async def test_eval_run_benchmark_eval(self, eval_stack, inference_model):
eval_impl, benchmarks_impl, datasets_impl, models_impl = ( eval_impl, benchmarks_impl, datasets_impl = (
eval_stack[Api.eval], eval_stack[Api.eval],
eval_stack[Api.benchmarks], eval_stack[Api.benchmarks],
eval_stack[Api.datasets], eval_stack[Api.datasets],
eval_stack[Api.models],
) )
response = await datasets_impl.list_datasets() response = await datasets_impl.list_datasets()

View file

@ -18,8 +18,7 @@ from llama_stack.models.llama.sku_list import all_registered_models
INFERENCE_APIS = ["chat_completion"] INFERENCE_APIS = ["chat_completion"]
FUNCTIONALITIES = ["streaming", "structured_output", "tool_calling"] FUNCTIONALITIES = ["streaming", "structured_output", "tool_calling"]
SUPPORTED_MODELS = { SUPPORTED_MODELS = {
"ollama": set( "ollama": {
[
CoreModelId.llama3_1_8b_instruct.value, CoreModelId.llama3_1_8b_instruct.value,
CoreModelId.llama3_1_8b_instruct.value, CoreModelId.llama3_1_8b_instruct.value,
CoreModelId.llama3_1_70b_instruct.value, CoreModelId.llama3_1_70b_instruct.value,
@ -37,10 +36,8 @@ SUPPORTED_MODELS = {
CoreModelId.llama3_3_70b_instruct.value, CoreModelId.llama3_3_70b_instruct.value,
CoreModelId.llama_guard_3_8b.value, CoreModelId.llama_guard_3_8b.value,
CoreModelId.llama_guard_3_1b.value, CoreModelId.llama_guard_3_1b.value,
] },
), "fireworks": {
"fireworks": set(
[
CoreModelId.llama3_1_8b_instruct.value, CoreModelId.llama3_1_8b_instruct.value,
CoreModelId.llama3_1_70b_instruct.value, CoreModelId.llama3_1_70b_instruct.value,
CoreModelId.llama3_1_405b_instruct.value, CoreModelId.llama3_1_405b_instruct.value,
@ -51,10 +48,8 @@ SUPPORTED_MODELS = {
CoreModelId.llama3_3_70b_instruct.value, CoreModelId.llama3_3_70b_instruct.value,
CoreModelId.llama_guard_3_8b.value, CoreModelId.llama_guard_3_8b.value,
CoreModelId.llama_guard_3_11b_vision.value, CoreModelId.llama_guard_3_11b_vision.value,
] },
), "together": {
"together": set(
[
CoreModelId.llama3_1_8b_instruct.value, CoreModelId.llama3_1_8b_instruct.value,
CoreModelId.llama3_1_70b_instruct.value, CoreModelId.llama3_1_70b_instruct.value,
CoreModelId.llama3_1_405b_instruct.value, CoreModelId.llama3_1_405b_instruct.value,
@ -64,8 +59,7 @@ SUPPORTED_MODELS = {
CoreModelId.llama3_3_70b_instruct.value, CoreModelId.llama3_3_70b_instruct.value,
CoreModelId.llama_guard_3_8b.value, CoreModelId.llama_guard_3_8b.value,
CoreModelId.llama_guard_3_11b_vision.value, CoreModelId.llama_guard_3_11b_vision.value,
] },
),
} }

View file

@ -45,13 +45,11 @@ class TestScoring:
scoring_functions_impl, scoring_functions_impl,
datasetio_impl, datasetio_impl,
datasets_impl, datasets_impl,
models_impl,
) = ( ) = (
scoring_stack[Api.scoring], scoring_stack[Api.scoring],
scoring_stack[Api.scoring_functions], scoring_stack[Api.scoring_functions],
scoring_stack[Api.datasetio], scoring_stack[Api.datasetio],
scoring_stack[Api.datasets], scoring_stack[Api.datasets],
scoring_stack[Api.models],
) )
scoring_fns_list = await scoring_functions_impl.list_scoring_functions() scoring_fns_list = await scoring_functions_impl.list_scoring_functions()
provider_id = scoring_fns_list[0].provider_id provider_id = scoring_fns_list[0].provider_id
@ -102,13 +100,11 @@ class TestScoring:
scoring_functions_impl, scoring_functions_impl,
datasetio_impl, datasetio_impl,
datasets_impl, datasets_impl,
models_impl,
) = ( ) = (
scoring_stack[Api.scoring], scoring_stack[Api.scoring],
scoring_stack[Api.scoring_functions], scoring_stack[Api.scoring_functions],
scoring_stack[Api.datasetio], scoring_stack[Api.datasetio],
scoring_stack[Api.datasets], scoring_stack[Api.datasets],
scoring_stack[Api.models],
) )
await register_dataset(datasets_impl, for_rag=True) await register_dataset(datasets_impl, for_rag=True)
response = await datasets_impl.list_datasets() response = await datasets_impl.list_datasets()
@ -163,13 +159,11 @@ class TestScoring:
scoring_functions_impl, scoring_functions_impl,
datasetio_impl, datasetio_impl,
datasets_impl, datasets_impl,
models_impl,
) = ( ) = (
scoring_stack[Api.scoring], scoring_stack[Api.scoring],
scoring_stack[Api.scoring_functions], scoring_stack[Api.scoring_functions],
scoring_stack[Api.datasetio], scoring_stack[Api.datasetio],
scoring_stack[Api.datasets], scoring_stack[Api.datasets],
scoring_stack[Api.models],
) )
await register_dataset(datasets_impl, for_rag=True) await register_dataset(datasets_impl, for_rag=True)
rows = await datasetio_impl.get_rows_paginated( rows = await datasetio_impl.get_rows_paginated(

View file

@ -6,7 +6,7 @@
import json import json
import logging import logging
import warnings import warnings
from typing import AsyncGenerator, Dict, Generator, Iterable, List, Optional, Union from typing import AsyncGenerator, Dict, Iterable, List, Optional, Union
from openai import AsyncStream from openai import AsyncStream
from openai.types.chat import ( from openai.types.chat import (
@ -605,7 +605,7 @@ def convert_tool_call(
tool_name=tool_call.function.name, tool_name=tool_call.function.name,
arguments=json.loads(tool_call.function.arguments), arguments=json.loads(tool_call.function.arguments),
) )
except Exception as e: except Exception:
return UnparseableToolCall( return UnparseableToolCall(
call_id=tool_call.id or "", call_id=tool_call.id or "",
tool_name=tool_call.function.name or "", tool_name=tool_call.function.name or "",
@ -841,14 +841,13 @@ async def convert_openai_chat_completion_stream(
Convert a stream of OpenAI chat completion chunks into a stream Convert a stream of OpenAI chat completion chunks into a stream
of ChatCompletionResponseStreamChunk. of ChatCompletionResponseStreamChunk.
""" """
yield ChatCompletionResponseStreamChunk(
# generate a stream of ChatCompletionResponseEventType: start -> progress -> progress -> ... event=ChatCompletionResponseEvent(
def _event_type_generator() -> Generator[ChatCompletionResponseEventType, None, None]: event_type=ChatCompletionResponseEventType.start,
yield ChatCompletionResponseEventType.start delta=TextDelta(text=""),
while True: )
yield ChatCompletionResponseEventType.progress )
event_type = ChatCompletionResponseEventType.progress
event_type = _event_type_generator()
stop_reason = None stop_reason = None
toolcall_buffer = {} toolcall_buffer = {}
@ -868,7 +867,7 @@ async def convert_openai_chat_completion_stream(
if choice.delta.content: if choice.delta.content:
yield ChatCompletionResponseStreamChunk( yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent( event=ChatCompletionResponseEvent(
event_type=next(event_type), event_type=event_type,
delta=TextDelta(text=choice.delta.content), delta=TextDelta(text=choice.delta.content),
logprobs=_convert_openai_logprobs(logprobs), logprobs=_convert_openai_logprobs(logprobs),
) )
@ -877,7 +876,9 @@ async def convert_openai_chat_completion_stream(
# it is possible to have parallel tool calls in stream, but # it is possible to have parallel tool calls in stream, but
# ChatCompletionResponseEvent only supports one per stream # ChatCompletionResponseEvent only supports one per stream
if len(choice.delta.tool_calls) > 1: if len(choice.delta.tool_calls) > 1:
warnings.warn("multiple tool calls found in a single delta, using the first, ignoring the rest") warnings.warn(
"multiple tool calls found in a single delta, using the first, ignoring the rest", stacklevel=2
)
if not enable_incremental_tool_calls: if not enable_incremental_tool_calls:
yield ChatCompletionResponseStreamChunk( yield ChatCompletionResponseStreamChunk(
@ -909,7 +910,7 @@ async def convert_openai_chat_completion_stream(
toolcall_buffer["content"] += delta toolcall_buffer["content"] += delta
yield ChatCompletionResponseStreamChunk( yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent( event=ChatCompletionResponseEvent(
event_type=next(event_type), event_type=event_type,
delta=ToolCallDelta( delta=ToolCallDelta(
tool_call=delta, tool_call=delta,
parse_status=ToolCallParseStatus.in_progress, parse_status=ToolCallParseStatus.in_progress,
@ -920,7 +921,7 @@ async def convert_openai_chat_completion_stream(
else: else:
yield ChatCompletionResponseStreamChunk( yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent( event=ChatCompletionResponseEvent(
event_type=next(event_type), event_type=event_type,
delta=TextDelta(text=choice.delta.content or ""), delta=TextDelta(text=choice.delta.content or ""),
logprobs=_convert_openai_logprobs(logprobs), logprobs=_convert_openai_logprobs(logprobs),
) )
@ -931,7 +932,7 @@ async def convert_openai_chat_completion_stream(
toolcall_buffer["content"] += delta toolcall_buffer["content"] += delta
yield ChatCompletionResponseStreamChunk( yield ChatCompletionResponseStreamChunk(
event=ChatCompletionResponseEvent( event=ChatCompletionResponseEvent(
event_type=next(event_type), event_type=event_type,
delta=ToolCallDelta( delta=ToolCallDelta(
tool_call=delta, tool_call=delta,
parse_status=ToolCallParseStatus.in_progress, parse_status=ToolCallParseStatus.in_progress,
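The hunk above drops the `_event_type_generator` helper in favor of emitting one explicit `start` chunk and then tagging every later chunk with a constant `progress` event type. A minimal, self-contained sketch of that pattern follows; the enum and dataclass here are simplified stand-ins for the real llama_stack stream types, not the project's definitions.

```python
from dataclasses import dataclass
from enum import Enum
from typing import AsyncIterator


class EventType(str, Enum):
    start = "start"
    progress = "progress"
    complete = "complete"


@dataclass
class StreamChunk:
    event_type: EventType
    text: str


async def convert_stream(deltas: AsyncIterator[str]) -> AsyncIterator[StreamChunk]:
    # Emit one explicit "start" chunk up front ...
    yield StreamChunk(event_type=EventType.start, text="")
    # ... then reuse a constant "progress" event type instead of pulling the
    # next event type out of a generator on every iteration.
    event_type = EventType.progress
    async for delta in deltas:
        yield StreamChunk(event_type=event_type, text=delta)
    yield StreamChunk(event_type=EventType.complete, text="")
```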

View file

@ -36,7 +36,7 @@ class RedisKVStoreImpl(KVStore):
value = await self.redis.get(key) value = await self.redis.get(key)
if value is None: if value is None:
return None return None
ttl = await self.redis.ttl(key) await self.redis.ttl(key)
return value return value
async def delete(self, key: str) -> None: async def delete(self, key: str) -> None:

View file

@ -32,7 +32,7 @@ def aggregate_categorical_count(
scoring_results: List[ScoringResultRow], scoring_results: List[ScoringResultRow],
) -> Dict[str, Any]: ) -> Dict[str, Any]:
scores = [str(r["score"]) for r in scoring_results] scores = [str(r["score"]) for r in scoring_results]
unique_scores = sorted(list(set(scores))) unique_scores = sorted(set(scores))
return {"categorical_count": {s: scores.count(s) for s in unique_scores}} return {"categorical_count": {s: scores.count(s) for s in unique_scores}}
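As a quick illustration of what the aggregation above returns, here is the same comprehension run on a few made-up scoring rows:

```python
# Hypothetical rows; mirrors the body of aggregate_categorical_count.
scoring_results = [{"score": "A"}, {"score": "B"}, {"score": "A"}]
scores = [str(r["score"]) for r in scoring_results]
unique_scores = sorted(set(scores))
print({"categorical_count": {s: scores.count(s) for s in unique_scores}})
# -> {'categorical_count': {'A': 2, 'B': 1}}
```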

View file

@ -66,7 +66,7 @@ class RegisteredBaseScoringFn(BaseScoringFn):
return self.__class__.__name__ return self.__class__.__name__
def get_supported_scoring_fn_defs(self) -> List[ScoringFn]: def get_supported_scoring_fn_defs(self) -> List[ScoringFn]:
return [x for x in self.supported_fn_defs_registry.values()] return list(self.supported_fn_defs_registry.values())
def register_scoring_fn_def(self, scoring_fn: ScoringFn) -> None: def register_scoring_fn_def(self, scoring_fn: ScoringFn) -> None:
if scoring_fn.identifier in self.supported_fn_defs_registry: if scoring_fn.identifier in self.supported_fn_defs_registry:

View file

@ -6,6 +6,7 @@
import asyncio import asyncio
import inspect import inspect
import json
from functools import wraps from functools import wraps
from typing import Any, AsyncGenerator, Callable, Type, TypeVar from typing import Any, AsyncGenerator, Callable, Type, TypeVar
@ -17,6 +18,10 @@ T = TypeVar("T")
def serialize_value(value: Any) -> Primitive: def serialize_value(value: Any) -> Primitive:
return str(_prepare_for_json(value))
def _prepare_for_json(value: Any) -> str:
"""Serialize a single value into JSON-compatible format.""" """Serialize a single value into JSON-compatible format."""
if value is None: if value is None:
return "" return ""
@ -25,8 +30,16 @@ def serialize_value(value: Any) -> Primitive:
elif hasattr(value, "_name_"): elif hasattr(value, "_name_"):
return value._name_ return value._name_
elif isinstance(value, BaseModel): elif isinstance(value, BaseModel):
return value.model_dump_json() return json.loads(value.model_dump_json())
elif isinstance(value, (list, tuple, set)):
return [_prepare_for_json(item) for item in value]
elif isinstance(value, dict):
return {str(k): _prepare_for_json(v) for k, v in value.items()}
else: else:
try:
json.dumps(value)
return value
except Exception:
return str(value) return str(value)
@ -104,7 +117,8 @@ def trace_protocol(cls: Type[T]) -> Type[T]:
result = method(self, *args, **kwargs) result = method(self, *args, **kwargs)
span.set_attribute("output", serialize_value(result)) span.set_attribute("output", serialize_value(result))
return result return result
except Exception as _e: except Exception as e:
span.set_attribute("error", str(e))
raise raise
if is_async_gen: if is_async_gen:
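A self-contained sketch of the recursive JSON-safe conversion introduced above; it simplifies the real helper (the enum `_name_` branch is omitted) and assumes pydantic is installed.

```python
import json
from typing import Any

from pydantic import BaseModel


def prepare_for_json(value: Any) -> Any:
    """Simplified version of _prepare_for_json: coerce any value into JSON-compatible form."""
    if value is None:
        return ""
    if isinstance(value, BaseModel):
        # Round-trip through model_dump_json so nested models become plain dicts.
        return json.loads(value.model_dump_json())
    if isinstance(value, (list, tuple, set)):
        return [prepare_for_json(item) for item in value]
    if isinstance(value, dict):
        return {str(k): prepare_for_json(v) for k, v in value.items()}
    try:
        json.dumps(value)
        return value
    except Exception:
        # Anything json can't serialize falls back to its string representation.
        return str(value)


class Point(BaseModel):
    x: int
    y: int


print(prepare_for_json({"point": Point(x=1, y=2), "tags": {"a", "b"}}))
# -> {'point': {'x': 1, 'y': 2}, 'tags': ['a', 'b']}  (set order may vary)
```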

View file

@ -99,7 +99,7 @@ def collect_template_dependencies(template_dir: Path) -> tuple[str | None, list[
template = template_func() template = template_func()
normal_deps, special_deps = get_provider_dependencies(template.providers) normal_deps, special_deps = get_provider_dependencies(template.providers)
# Combine all dependencies in order: normal deps, special deps, server deps # Combine all dependencies in order: normal deps, special deps, server deps
all_deps = sorted(list(set(normal_deps + SERVER_DEPENDENCIES))) + sorted(list(set(special_deps))) all_deps = sorted(set(normal_deps + SERVER_DEPENDENCIES)) + sorted(set(special_deps))
return template.name, all_deps return template.name, all_deps
except Exception: except Exception:
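The dependency merge above is dedupe-and-sort in two groups (regular deps plus server deps, then special deps). A tiny illustration with hypothetical package names:

```python
SERVER_DEPENDENCIES = ["uvicorn", "aiosqlite"]  # hypothetical server deps
normal_deps = ["fastapi", "uvicorn", "fastapi"]  # from the template's providers
special_deps = ["torch --index-url https://download.pytorch.org/whl/cpu"]

all_deps = sorted(set(normal_deps + SERVER_DEPENDENCIES)) + sorted(set(special_deps))
print(all_deps)
# -> ['aiosqlite', 'fastapi', 'uvicorn', 'torch --index-url https://download.pytorch.org/whl/cpu']
```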

View file

@ -29,12 +29,31 @@ The following environment variables can be configured:
## Prerequisite: Downloading Models ## Prerequisite: Downloading Models
Please make sure you have llama model checkpoints downloaded in `~/.llama` before proceeding. See [installation guide](https://llama-stack.readthedocs.io/en/latest/references/llama_cli_reference/download_models.html) here to download the models. Run `llama model list` to see the available models to download, and `llama model download` to download the checkpoints. Please use `llama model list --downloaded` to check that you have llama model checkpoints downloaded in `~/.llama` before proceeding. See [installation guide](https://llama-stack.readthedocs.io/en/latest/references/llama_cli_reference/download_models.html) here to download the models. Run `llama model list` to see the available models to download, and `llama model download` to download the checkpoints.
``` ```
$ ls ~/.llama/checkpoints $ llama model list --downloaded
Llama3.1-8B Llama3.2-11B-Vision-Instruct Llama3.2-1B-Instruct Llama3.2-90B-Vision-Instruct Llama-Guard-3-8B ┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━┓
Llama3.1-8B-Instruct Llama3.2-1B Llama3.2-3B-Instruct Llama-Guard-3-1B Prompt-Guard-86M ┃ Model ┃ Size ┃ Modified Time ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━┩
│ Llama3.2-1B-Instruct:int4-qlora-eo8 │ 1.53 GB │ 2025-02-26 11:22:28 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama3.2-1B │ 2.31 GB │ 2025-02-18 21:48:52 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Prompt-Guard-86M │ 0.02 GB │ 2025-02-26 11:29:28 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama3.2-3B-Instruct:int4-spinquant-eo8 │ 3.69 GB │ 2025-02-26 11:37:41 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama3.2-3B │ 5.99 GB │ 2025-02-18 21:51:26 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama3.1-8B │ 14.97 GB │ 2025-02-16 10:36:37 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama3.2-1B-Instruct:int4-spinquant-eo8 │ 1.51 GB │ 2025-02-26 11:35:02 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama-Guard-3-1B │ 2.80 GB │ 2025-02-26 11:20:46 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama-Guard-3-1B:int4 │ 0.43 GB │ 2025-02-26 11:33:33 │
└─────────────────────────────────────────┴──────────┴─────────────────────┘
``` ```
## Running the Distribution ## Running the Distribution

View file

@ -31,12 +31,31 @@ The following environment variables can be configured:
## Prerequisite: Downloading Models ## Prerequisite: Downloading Models
Please make sure you have llama model checkpoints downloaded in `~/.llama` before proceeding. See [installation guide](https://llama-stack.readthedocs.io/en/latest/references/llama_cli_reference/download_models.html) here to download the models. Run `llama model list` to see the available models to download, and `llama model download` to download the checkpoints. Please use `llama model list --downloaded` to check that you have llama model checkpoints downloaded in `~/.llama` before proceeding. See [installation guide](https://llama-stack.readthedocs.io/en/latest/references/llama_cli_reference/download_models.html) here to download the models. Run `llama model list` to see the available models to download, and `llama model download` to download the checkpoints.
``` ```
$ ls ~/.llama/checkpoints $ llama model list --downloaded
Llama3.1-8B Llama3.2-11B-Vision-Instruct Llama3.2-1B-Instruct Llama3.2-90B-Vision-Instruct Llama-Guard-3-8B ┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━┓
Llama3.1-8B-Instruct Llama3.2-1B Llama3.2-3B-Instruct Llama-Guard-3-1B Prompt-Guard-86M ┃ Model ┃ Size ┃ Modified Time ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━┩
│ Llama3.2-1B-Instruct:int4-qlora-eo8 │ 1.53 GB │ 2025-02-26 11:22:28 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama3.2-1B │ 2.31 GB │ 2025-02-18 21:48:52 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Prompt-Guard-86M │ 0.02 GB │ 2025-02-26 11:29:28 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama3.2-3B-Instruct:int4-spinquant-eo8 │ 3.69 GB │ 2025-02-26 11:37:41 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama3.2-3B │ 5.99 GB │ 2025-02-18 21:51:26 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama3.1-8B │ 14.97 GB │ 2025-02-16 10:36:37 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama3.2-1B-Instruct:int4-spinquant-eo8 │ 1.51 GB │ 2025-02-26 11:35:02 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama-Guard-3-1B │ 2.80 GB │ 2025-02-26 11:20:46 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama-Guard-3-1B:int4 │ 0.43 GB │ 2025-02-26 11:33:33 │
└─────────────────────────────────────────┴──────────┴─────────────────────┘
``` ```
## Running the Distribution ## Running the Distribution

View file

@ -93,7 +93,7 @@ def get_distribution_template() -> DistributionTemplate:
"inference": [inference_provider], "inference": [inference_provider],
"vector_io": [vector_io_provider_sqlite], "vector_io": [vector_io_provider_sqlite],
}, },
default_models=[inference_model], default_models=[inference_model, embedding_model],
default_tool_groups=default_tool_groups, default_tool_groups=default_tool_groups,
), ),
"run-with-safety.yaml": RunConfigSettings( "run-with-safety.yaml": RunConfigSettings(

View file

@ -90,6 +90,12 @@ models:
model_id: ${env.INFERENCE_MODEL} model_id: ${env.INFERENCE_MODEL}
provider_id: ollama provider_id: ollama
model_type: llm model_type: llm
- metadata:
embedding_dimension: 384
model_id: all-MiniLM-L6-v2
provider_id: ollama
provider_model_id: all-minilm:latest
model_type: embedding
shields: [] shields: []
vector_dbs: [] vector_dbs: []
datasets: [] datasets: []
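With the `all-MiniLM-L6-v2` entry added to the ollama run config above, a client can request embeddings for it by model id. A hedged usage sketch; the base URL/port and the prompt text are assumptions, not taken from the config:

```python
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:8321")  # assumed local server address

response = client.inference.embeddings(
    model_id="all-MiniLM-L6-v2",
    contents=["What is the boiling point of polyjuice?"],
)
# The registered metadata advertises embedding_dimension: 384.
print(len(response.embeddings[0]))
```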

View file

@ -123,39 +123,16 @@ select = [
"I", # isort "I", # isort
] ]
ignore = [ ignore = [
"E203", # The following ignores are desired by the project maintainers.
"E305", "E402", # Module level import not at top of file
"E402", "E501", # Line too long
"E501", # line too long "F405", # Maybe undefined or defined from star import
"E721", "C408", # Ignored because we like the dict keyword argument syntax
"E741", "N812", # Ignored because import torch.nn.functional as F is PyTorch convention
"F405",
"F841",
"C408", # ignored because we like the dict keyword argument syntax
"E302",
"W291",
"E303",
"N812", # ignored because import torch.nn.functional as F is PyTorch convention
"N817", # ignored because importing using acronyms is convention (DistributedDataParallel as DDP)
"E731", # allow usage of assigning lambda expressions
# These are the additional ones we started ignoring after moving to ruff. We should look into each one of them later. # These are the additional ones we started ignoring after moving to ruff. We should look into each one of them later.
"C901", "C901", # Complexity of the function is too high
"C405",
"C414",
"N803",
"N999",
"C403",
"C416",
"B028",
"C419",
"C401",
"B023",
# shebang has extra meaning in fbcode lints, so I think it's not worth trying
# to line this up with executable bit
"EXE001",
"N802", # random naming hints don't need
# these ignores are from flake8-bugbear; please fix! # these ignores are from flake8-bugbear; please fix!
"B007",
"B008", "B008",
] ]

View file

@ -3,3 +3,4 @@
# #
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
# ruff: noqa: N999

View file

@ -3,3 +3,4 @@
# #
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
# ruff: noqa: N999

View file

@ -4,20 +4,15 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
import json
from typing import Dict, List
from uuid import uuid4 from uuid import uuid4
import pytest import pytest
from llama_stack_client.lib.agents.agent import Agent from llama_stack_client.lib.agents.agent import Agent
from llama_stack_client.lib.agents.client_tool import ClientTool from llama_stack_client.lib.agents.client_tool import client_tool
from llama_stack_client.lib.agents.event_logger import EventLogger from llama_stack_client.lib.agents.event_logger import EventLogger
from llama_stack_client.types import ToolResponseMessage
from llama_stack_client.types.agents.turn_create_params import Document as AgentDocument from llama_stack_client.types.agents.turn_create_params import Document as AgentDocument
from llama_stack_client.types.memory_insert_params import Document from llama_stack_client.types.memory_insert_params import Document
from llama_stack_client.types.shared.completion_message import CompletionMessage
from llama_stack_client.types.shared_params.agent_config import AgentConfig, ToolConfig from llama_stack_client.types.shared_params.agent_config import AgentConfig, ToolConfig
from llama_stack_client.types.tool_def_param import Parameter
from llama_stack.apis.agents.agents import ( from llama_stack.apis.agents.agents import (
AgentConfig as Server__AgentConfig, AgentConfig as Server__AgentConfig,
@ -27,56 +22,15 @@ from llama_stack.apis.agents.agents import (
) )
class TestClientTool(ClientTool): @client_tool
"""Tool to give boiling point of a liquid def get_boiling_point(liquid_name: str, celcius: bool = True) -> int:
Returns the correct value for polyjuice in Celcius and Fahrenheit
and returns -1 for other liquids
""" """
Returns the boiling point of a liquid in Celcius or Fahrenheit
def run(self, messages: List[CompletionMessage]) -> List[ToolResponseMessage]: :param liquid_name: The name of the liquid
assert len(messages) == 1, "Expected single message" :param celcius: Whether to return the boiling point in Celcius
:return: The boiling point of the liquid in Celcius or Fahrenheit
message = messages[0] """
tool_call = message.tool_calls[0]
try:
response = self.run_impl(**tool_call.arguments)
response_str = json.dumps(response, ensure_ascii=False)
except Exception as e:
response_str = f"Error when running tool: {e}"
message = ToolResponseMessage(
role="tool",
call_id=tool_call.call_id,
tool_name=tool_call.tool_name,
content=response_str,
)
return message
def get_name(self) -> str:
return "get_boiling_point"
def get_description(self) -> str:
return "Get the boiling point of imaginary liquids (eg. polyjuice)"
def get_params_definition(self) -> Dict[str, Parameter]:
return {
"liquid_name": Parameter(
name="liquid_name",
parameter_type="string",
description="The name of the liquid",
required=True,
),
"celcius": Parameter(
name="celcius",
parameter_type="boolean",
description="Whether to return the boiling point in Celcius",
required=False,
),
}
def run_impl(self, liquid_name: str, celcius: bool = True) -> int:
if liquid_name.lower() == "polyjuice": if liquid_name.lower() == "polyjuice":
if celcius: if celcius:
return -100 return -100
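The test above moves from subclassing `ClientTool` to the `@client_tool` decorator, which derives the tool name, description, and parameters from the function signature and the `:param:` lines of the docstring. A hedged sketch of that style with a made-up tool; how the tool is attached to an `Agent` is elided here:

```python
from llama_stack_client.lib.agents.client_tool import client_tool


@client_tool
def get_humidity(city: str, as_percent: bool = True) -> int:
    """
    Returns a made-up humidity reading for a city.

    :param city: The name of the city
    :param as_percent: Whether to return the reading as a percentage
    :return: The humidity reading
    """
    return 42 if as_percent else 0
```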
@ -298,7 +252,7 @@ def test_code_interpreter_for_attachments(llama_stack_client, agent_config):
def test_custom_tool(llama_stack_client, agent_config): def test_custom_tool(llama_stack_client, agent_config):
client_tool = TestClientTool() client_tool = get_boiling_point
agent_config = { agent_config = {
**agent_config, **agent_config,
"toolgroups": ["builtin::websearch"], "toolgroups": ["builtin::websearch"],
@ -326,7 +280,7 @@ def test_custom_tool(llama_stack_client, agent_config):
def test_tool_choice(llama_stack_client, agent_config): def test_tool_choice(llama_stack_client, agent_config):
def run_agent(tool_choice): def run_agent(tool_choice):
client_tool = TestClientTool() client_tool = get_boiling_point
test_agent_config = { test_agent_config = {
**agent_config, **agent_config,
@ -362,7 +316,7 @@ def test_tool_choice(llama_stack_client, agent_config):
# TODO: fix this flaky test # TODO: fix this flaky test
def xtest_override_system_message_behavior(llama_stack_client, agent_config): def xtest_override_system_message_behavior(llama_stack_client, agent_config):
client_tool = TestClientTool() client_tool = get_boiling_point
agent_config = { agent_config = {
**agent_config, **agent_config,
"instructions": "You are a pirate", "instructions": "You are a pirate",
@ -458,7 +412,6 @@ def test_rag_agent(llama_stack_client, agent_config, rag_tool_name):
vector_db_id=vector_db_id, vector_db_id=vector_db_id,
embedding_model="all-MiniLM-L6-v2", embedding_model="all-MiniLM-L6-v2",
embedding_dimension=384, embedding_dimension=384,
provider_id="faiss",
) )
llama_stack_client.tool_runtime.rag_tool.insert( llama_stack_client.tool_runtime.rag_tool.insert(
documents=documents, documents=documents,
@ -587,7 +540,7 @@ def test_rag_and_code_agent(llama_stack_client, agent_config):
def test_create_turn_response(llama_stack_client, agent_config): def test_create_turn_response(llama_stack_client, agent_config):
client_tool = TestClientTool() client_tool = get_boiling_point
agent_config = { agent_config = {
**agent_config, **agent_config,
"input_shields": [], "input_shields": [],

View file

@ -117,7 +117,7 @@ def client_with_models(llama_stack_client, text_model_id, vision_model_id, embed
assert len(providers) > 0, "No inference providers found" assert len(providers) > 0, "No inference providers found"
inference_providers = [p.provider_id for p in providers if p.provider_type != "inline::sentence-transformers"] inference_providers = [p.provider_id for p in providers if p.provider_type != "inline::sentence-transformers"]
model_ids = set(m.identifier for m in client.models.list()) model_ids = {m.identifier for m in client.models.list()}
model_ids.update(m.provider_resource_id for m in client.models.list()) model_ids.update(m.provider_resource_id for m in client.models.list())
if text_model_id and text_model_id not in model_ids: if text_model_id and text_model_id not in model_ids:

View file

@ -3,3 +3,4 @@
# #
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
# ruff: noqa: N999

View file

@ -75,6 +75,26 @@ DUMMY_IMAGE_URL = ImageContentItem(
image=ImageContentItemImage(url=ImageContentItemImageURL(uri="https://example.com/image.jpg")), type="image" image=ImageContentItemImage(url=ImageContentItemImageURL(uri="https://example.com/image.jpg")), type="image"
) )
DUMMY_IMAGE_BASE64 = ImageContentItem(image=ImageContentItemImage(data="base64string"), type="image") DUMMY_IMAGE_BASE64 = ImageContentItem(image=ImageContentItemImage(data="base64string"), type="image")
SUPPORTED_PROVIDERS = {"remote::nvidia"}
MODELS_SUPPORTING_MEDIA = {}
MODELS_SUPPORTING_OUTPUT_DIMENSION = {"nvidia/llama-3.2-nv-embedqa-1b-v2"}
MODELS_REQUIRING_TASK_TYPE = {
"nvidia/llama-3.2-nv-embedqa-1b-v2",
"nvidia/nv-embedqa-e5-v5",
"nvidia/nv-embedqa-mistral-7b-v2",
"snowflake/arctic-embed-l",
}
MODELS_SUPPORTING_TASK_TYPE = MODELS_REQUIRING_TASK_TYPE
def default_task_type(model_id):
"""
Some models require a task type parameter. This provides a default value for
testing those models.
"""
if model_id in MODELS_REQUIRING_TASK_TYPE:
return {"task_type": "query"}
return {}
@pytest.mark.parametrize( @pytest.mark.parametrize(
@ -88,8 +108,12 @@ DUMMY_IMAGE_BASE64 = ImageContentItem(image=ImageContentItemImage(data="base64st
"list[text]", "list[text]",
], ],
) )
def test_embedding_text(llama_stack_client, embedding_model_id, contents): def test_embedding_text(llama_stack_client, embedding_model_id, contents, inference_provider_type):
response = llama_stack_client.inference.embeddings(model_id=embedding_model_id, contents=contents) if inference_provider_type not in SUPPORTED_PROVIDERS:
pytest.xfail(f"{inference_provider_type} doesn't support embedding model yet")
response = llama_stack_client.inference.embeddings(
model_id=embedding_model_id, contents=contents, **default_task_type(embedding_model_id)
)
assert isinstance(response, EmbeddingsResponse) assert isinstance(response, EmbeddingsResponse)
assert len(response.embeddings) == sum(len(content) if isinstance(content, list) else 1 for content in contents) assert len(response.embeddings) == sum(len(content) if isinstance(content, list) else 1 for content in contents)
assert isinstance(response.embeddings[0], list) assert isinstance(response.embeddings[0], list)
@ -107,9 +131,14 @@ def test_embedding_text(llama_stack_client, embedding_model_id, contents):
"list[url,string,base64,text]", "list[url,string,base64,text]",
], ],
) )
@pytest.mark.xfail(reason="Media is not supported") def test_embedding_image(llama_stack_client, embedding_model_id, contents, inference_provider_type):
def test_embedding_image(llama_stack_client, embedding_model_id, contents): if inference_provider_type not in SUPPORTED_PROVIDERS:
response = llama_stack_client.inference.embeddings(model_id=embedding_model_id, contents=contents) pytest.xfail(f"{inference_provider_type} doesn't support embedding model yet")
if embedding_model_id not in MODELS_SUPPORTING_MEDIA:
pytest.xfail(f"{embedding_model_id} doesn't support media")
response = llama_stack_client.inference.embeddings(
model_id=embedding_model_id, contents=contents, **default_task_type(embedding_model_id)
)
assert isinstance(response, EmbeddingsResponse) assert isinstance(response, EmbeddingsResponse)
assert len(response.embeddings) == sum(len(content) if isinstance(content, list) else 1 for content in contents) assert len(response.embeddings) == sum(len(content) if isinstance(content, list) else 1 for content in contents)
assert isinstance(response.embeddings[0], list) assert isinstance(response.embeddings[0], list)
@ -134,9 +163,16 @@ def test_embedding_image(llama_stack_client, embedding_model_id, contents):
"short", "short",
], ],
) )
def test_embedding_truncation(llama_stack_client, embedding_model_id, text_truncation, contents): def test_embedding_truncation(
llama_stack_client, embedding_model_id, text_truncation, contents, inference_provider_type
):
if inference_provider_type not in SUPPORTED_PROVIDERS:
pytest.xfail(f"{inference_provider_type} doesn't support embedding model yet")
response = llama_stack_client.inference.embeddings( response = llama_stack_client.inference.embeddings(
model_id=embedding_model_id, contents=contents, text_truncation=text_truncation model_id=embedding_model_id,
contents=contents,
text_truncation=text_truncation,
**default_task_type(embedding_model_id),
) )
assert isinstance(response, EmbeddingsResponse) assert isinstance(response, EmbeddingsResponse)
assert len(response.embeddings) == 1 assert len(response.embeddings) == 1
@ -162,25 +198,43 @@ def test_embedding_truncation(llama_stack_client, embedding_model_id, text_trunc
"long-str", "long-str",
], ],
) )
def test_embedding_truncation_error(llama_stack_client, embedding_model_id, text_truncation, contents): def test_embedding_truncation_error(
with pytest.raises(BadRequestError) as excinfo: llama_stack_client, embedding_model_id, text_truncation, contents, inference_provider_type
):
if inference_provider_type not in SUPPORTED_PROVIDERS:
pytest.xfail(f"{inference_provider_type} doesn't support embedding model yet")
with pytest.raises(BadRequestError):
llama_stack_client.inference.embeddings( llama_stack_client.inference.embeddings(
model_id=embedding_model_id, contents=[DUMMY_LONG_TEXT], text_truncation=text_truncation model_id=embedding_model_id,
contents=[DUMMY_LONG_TEXT],
text_truncation=text_truncation,
**default_task_type(embedding_model_id),
) )
@pytest.mark.xfail(reason="Only valid for model supporting dimension reduction") def test_embedding_output_dimension(llama_stack_client, embedding_model_id, inference_provider_type):
def test_embedding_output_dimension(llama_stack_client, embedding_model_id): if inference_provider_type not in SUPPORTED_PROVIDERS:
base_response = llama_stack_client.inference.embeddings(model_id=embedding_model_id, contents=[DUMMY_STRING]) pytest.xfail(f"{inference_provider_type} doesn't support embedding model yet")
if embedding_model_id not in MODELS_SUPPORTING_OUTPUT_DIMENSION:
pytest.xfail(f"{embedding_model_id} doesn't support output_dimension")
base_response = llama_stack_client.inference.embeddings(
model_id=embedding_model_id, contents=[DUMMY_STRING], **default_task_type(embedding_model_id)
)
test_response = llama_stack_client.inference.embeddings( test_response = llama_stack_client.inference.embeddings(
model_id=embedding_model_id, contents=[DUMMY_STRING], output_dimension=32 model_id=embedding_model_id,
contents=[DUMMY_STRING],
**default_task_type(embedding_model_id),
output_dimension=32,
) )
assert len(base_response.embeddings[0]) != len(test_response.embeddings[0]) assert len(base_response.embeddings[0]) != len(test_response.embeddings[0])
assert len(test_response.embeddings[0]) == 32 assert len(test_response.embeddings[0]) == 32
@pytest.mark.xfail(reason="Only valid for model supporting task type") def test_embedding_task_type(llama_stack_client, embedding_model_id, inference_provider_type):
def test_embedding_task_type(llama_stack_client, embedding_model_id): if inference_provider_type not in SUPPORTED_PROVIDERS:
pytest.xfail(f"{inference_provider_type} doesn't support embedding model yet")
if embedding_model_id not in MODELS_SUPPORTING_TASK_TYPE:
pytest.xfail(f"{embedding_model_id} doesn't support task_type")
query_embedding = llama_stack_client.inference.embeddings( query_embedding = llama_stack_client.inference.embeddings(
model_id=embedding_model_id, contents=[DUMMY_STRING], task_type="query" model_id=embedding_model_id, contents=[DUMMY_STRING], task_type="query"
) )
@ -199,9 +253,14 @@ def test_embedding_task_type(llama_stack_client, embedding_model_id):
"start", "start",
], ],
) )
def test_embedding_text_truncation(llama_stack_client, embedding_model_id, text_truncation): def test_embedding_text_truncation(llama_stack_client, embedding_model_id, text_truncation, inference_provider_type):
if inference_provider_type not in SUPPORTED_PROVIDERS:
pytest.xfail(f"{inference_provider_type} doesn't support embedding model yet")
response = llama_stack_client.inference.embeddings( response = llama_stack_client.inference.embeddings(
model_id=embedding_model_id, contents=[DUMMY_STRING], text_truncation=text_truncation model_id=embedding_model_id,
contents=[DUMMY_STRING],
text_truncation=text_truncation,
**default_task_type(embedding_model_id),
) )
assert isinstance(response, EmbeddingsResponse) assert isinstance(response, EmbeddingsResponse)
assert len(response.embeddings) == 1 assert len(response.embeddings) == 1
@ -219,8 +278,15 @@ def test_embedding_text_truncation(llama_stack_client, embedding_model_id, text_
"right", "right",
], ],
) )
def test_embedding_text_truncation_error(llama_stack_client, embedding_model_id, text_truncation): def test_embedding_text_truncation_error(
with pytest.raises(BadRequestError) as excinfo: llama_stack_client, embedding_model_id, text_truncation, inference_provider_type
):
if inference_provider_type not in SUPPORTED_PROVIDERS:
pytest.xfail(f"{inference_provider_type} doesn't support embedding model yet")
with pytest.raises(BadRequestError):
llama_stack_client.inference.embeddings( llama_stack_client.inference.embeddings(
model_id=embedding_model_id, contents=[DUMMY_STRING], text_truncation=text_truncation model_id=embedding_model_id,
contents=[DUMMY_STRING],
text_truncation=text_truncation,
**default_task_type(embedding_model_id),
) )
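The `default_task_type` helper above only injects a `task_type` kwarg for models that require one; a self-contained sketch of that behavior:

```python
MODELS_REQUIRING_TASK_TYPE = {"nvidia/llama-3.2-nv-embedqa-1b-v2"}  # subset of the set above


def default_task_type(model_id: str) -> dict:
    # Models that require a task type get a default of "query"; others get nothing.
    return {"task_type": "query"} if model_id in MODELS_REQUIRING_TASK_TYPE else {}


def test_default_task_type():
    assert default_task_type("nvidia/llama-3.2-nv-embedqa-1b-v2") == {"task_type": "query"}
    assert default_task_type("all-MiniLM-L6-v2") == {}
```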

View file

@ -139,7 +139,7 @@ def test_text_completion_log_probs_streaming(client_with_models, text_model_id,
"top_k": 1, "top_k": 1,
}, },
) )
streamed_content = [chunk for chunk in response] streamed_content = list(response)
for chunk in streamed_content: for chunk in streamed_content:
if chunk.delta: # if there's a token, we expect logprobs if chunk.delta: # if there's a token, we expect logprobs
assert chunk.logprobs, "Logprobs should not be empty" assert chunk.logprobs, "Logprobs should not be empty"
@ -405,7 +405,7 @@ def test_text_chat_completion_tool_calling_tools_not_in_request(
assert delta.tool_call.tool_name == "get_object_namespace_list" assert delta.tool_call.tool_name == "get_object_namespace_list"
if delta.type == "tool_call" and delta.parse_status == "failed": if delta.type == "tool_call" and delta.parse_status == "failed":
# expect raw message that failed to parse in tool_call # expect raw message that failed to parse in tool_call
assert type(delta.tool_call) == str assert isinstance(delta.tool_call, str)
assert len(delta.tool_call) > 0 assert len(delta.tool_call) > 0
else: else:
for tc in response.completion_message.tool_calls: for tc in response.completion_message.tool_calls:

View file

@ -42,8 +42,7 @@ def featured_models():
SUPPORTED_MODELS = { SUPPORTED_MODELS = {
"ollama": set( "ollama": {
[
CoreModelId.llama3_1_8b_instruct.value, CoreModelId.llama3_1_8b_instruct.value,
CoreModelId.llama3_1_8b_instruct.value, CoreModelId.llama3_1_8b_instruct.value,
CoreModelId.llama3_1_70b_instruct.value, CoreModelId.llama3_1_70b_instruct.value,
@ -61,10 +60,9 @@ SUPPORTED_MODELS = {
CoreModelId.llama3_3_70b_instruct.value, CoreModelId.llama3_3_70b_instruct.value,
CoreModelId.llama_guard_3_8b.value, CoreModelId.llama_guard_3_8b.value,
CoreModelId.llama_guard_3_1b.value, CoreModelId.llama_guard_3_1b.value,
] },
), "tgi": {model.core_model_id.value for model in all_registered_models() if model.huggingface_repo},
"tgi": set([model.core_model_id.value for model in all_registered_models() if model.huggingface_repo]), "vllm": {model.core_model_id.value for model in all_registered_models() if model.huggingface_repo},
"vllm": set([model.core_model_id.value for model in all_registered_models() if model.huggingface_repo]),
} }

View file

@ -3,3 +3,4 @@
# #
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
# ruff: noqa: N999

View file

@ -42,7 +42,7 @@ def code_scanner_shield_id(available_shields):
@pytest.fixture(scope="session") @pytest.fixture(scope="session")
def model_providers(llama_stack_client): def model_providers(llama_stack_client):
return set([x.provider_id for x in llama_stack_client.providers.list() if x.api == "inference"]) return {x.provider_id for x in llama_stack_client.providers.list() if x.api == "inference"}
def test_unsafe_examples(llama_stack_client, llama_guard_text_shield_id): def test_unsafe_examples(llama_stack_client, llama_guard_text_shield_id):

View file

@ -24,7 +24,6 @@ def single_entry_vector_db_registry(llama_stack_client, empty_vector_db_registry
vector_db_id=vector_db_id, vector_db_id=vector_db_id,
embedding_model="all-MiniLM-L6-v2", embedding_model="all-MiniLM-L6-v2",
embedding_dimension=384, embedding_dimension=384,
provider_id="faiss",
) )
vector_dbs = [vector_db.identifier for vector_db in llama_stack_client.vector_dbs.list()] vector_dbs = [vector_db.identifier for vector_db in llama_stack_client.vector_dbs.list()]
return vector_dbs return vector_dbs
@ -121,7 +120,6 @@ def test_vector_db_insert_from_url_and_query(llama_stack_client, empty_vector_db
vector_db_id=vector_db_id, vector_db_id=vector_db_id,
embedding_model="all-MiniLM-L6-v2", embedding_model="all-MiniLM-L6-v2",
embedding_dimension=384, embedding_dimension=384,
provider_id="faiss",
) )
# list to check memory bank is successfully registered # list to check memory bank is successfully registered
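After dropping the explicit faiss provider_id above, registration relies on the stack's configured vector_io provider. A hedged sketch of the call; the client construction and the vector DB id are assumptions:

```python
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:8321")  # assumed local server

client.vector_dbs.register(
    vector_db_id="my-docs",  # hypothetical id
    embedding_model="all-MiniLM-L6-v2",
    embedding_dimension=384,
    # provider_id is omitted; the configured vector_io provider is used.
)
print([db.identifier for db in client.vector_dbs.list()])
```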

View file

@ -3,3 +3,4 @@
# #
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
# ruff: noqa: N999