Merge branch 'main' into fiddlecube-guard

Kaushik Srinivasan 2025-02-16 21:32:47 -08:00 committed by GitHub
commit 7fdc50a837
324 changed files with 12802 additions and 3145 deletions

View file

@ -8,4 +8,3 @@
[Describe the tests you ran to verify your changes with result summaries. *Provide clear instructions so the plan can be easily re-executed.*]
[//]: # (## Documentation)
[//]: # (- [ ] Added a Changelog entry if the change is significant)

View file

@ -30,10 +30,7 @@ repos:
rev: v0.9.4
hooks:
- id: ruff
args: [
--fix,
--exit-non-zero-on-fix
]
exclude: ^llama_stack/strong_typing/.*$
- id: ruff-format
- repo: https://github.com/adamchainz/blacken-docs
@ -47,7 +44,13 @@ repos:
rev: 0.5.26
hooks:
- id: uv-export
args: ["--frozen", "--no-hashes", "--no-emit-project"]
args: [
"--frozen",
"--no-hashes",
"--no-emit-project",
"--output-file=requirements.txt"
]
files: ^pyproject\.toml$
- id: uv-sync
# - repo: https://github.com/pre-commit/mirrors-mypy

View file

@ -1,37 +0,0 @@
# Suggested config from pytorch that we can adapt
lint.select = ["B", "C", "E" , "F" , "N", "W", "B9"]
line-length = 120
# C408 ignored because we like the dict keyword argument syntax
# E501 is not flexible enough, we're using B950 instead
# N812 ignored because import torch.nn.functional as F is PyTorch convention
# N817 ignored because importing using acronyms is convention (DistributedDataParallel as DDP)
# E731 allow usage of assigning lambda expressions
# E701 let black auto-format statements on one line
# E704 let black auto-format statements on one line
lint.ignore = [
"E203", "E305", "E402", "E501", "E721", "E741", "F405", "F821", "F841",
"C408", "E302", "W291", "E303", "N812", "N817", "E731", "E701",
# These are the additional ones we started ignoring after moving to ruff. We should look into each one of them later.
"C901", "C405", "C414", "N803", "N999", "C403", "C416", "B028", "C419", "C401", "B023",
# shebang has extra meaning in fbcode lints, so I think it's not worth trying
# to line this up with executable bit
"EXE001",
# random naming hints don't need
"N802",
# these ignores are from flake8-bugbear; please fix!
"B007", "B008"
]
exclude = [
"./.git",
"./docs/*",
"./build",
"./scripts",
"./venv",
"*.pyi",
".pre-commit-config.yaml",
"*.md",
".flake8"
]

View file

@ -40,6 +40,7 @@ If you need help or guidance, comment on the issue. Issues that are extra friend
3. Ensure the test suite passes.
4. Make sure your code lints using `pre-commit`.
5. If you haven't already, complete the Contributor License Agreement ("CLA").
6. Ensure your pull request follows the [conventional commits format](https://www.conventionalcommits.org/en/v1.0.0/).
## Contributor License Agreement ("CLA")
In order to accept your pull request, we need you to submit a CLA. You only need
@ -98,7 +99,8 @@ $ uv sync
```
## Coding Style
* 2 spaces for indentation rather than tabs
* 4 spaces for indentation rather than tabs
* 80 character line length
* ...

View file

@ -7,13 +7,13 @@
[**Quick Start**](https://llama-stack.readthedocs.io/en/latest/getting_started/index.html) | [**Documentation**](https://llama-stack.readthedocs.io/en/latest/index.html) | [**Colab Notebook**](./docs/getting_started.ipynb)
Llama Stack defines and standardizes the core building blocks that simplify AI application development. It codified best practices across the Llama ecosystem. More specifically, it provides
Llama Stack standardizes the core building blocks that simplify AI application development. It codifies best practices across the Llama ecosystem. More specifically, it provides
- **Unified API layer** for Inference, RAG, Agents, Tools, Safety, Evals, and Telemetry.
- **Plugin architecture** to support the rich ecosystem of implementations of the different APIs in different environments like local development, on-premises, cloud, and mobile.
- **Prepackaged verified distributions** which offer a one-stop solution for developers to get started quickly and reliably in any environment
- **Multiple developer interfaces** like CLI and SDKs for Python, Typescript, iOS, and Android
- **Standalone applications** as examples for how to build production-grade AI applications with Llama Stack
- **Plugin architecture** to support the rich ecosystem of different API implementations in various environments, including local development, on-premises, cloud, and mobile.
- **Prepackaged verified distributions** which offer a one-stop solution for developers to get started quickly and reliably in any environment.
- **Multiple developer interfaces** like CLI and SDKs for Python, Typescript, iOS, and Android.
- **Standalone applications** as examples for how to build production-grade AI applications with Llama Stack.
<div style="text-align: center;">
<img
@ -25,14 +25,14 @@ Llama Stack defines and standardizes the core building blocks that simplify AI a
</div>
### Llama Stack Benefits
- **Flexible Options**: Developers can choose their preferred infrastructure without changing APIs and enjoy flexible deployment choice.
- **Consistent Experience**: With its unified APIs Llama Stack makes it easier to build, test, and deploy AI applications with consistent application behavior.
- **Flexible Options**: Developers can choose their preferred infrastructure without changing APIs and enjoy flexible deployment choices.
- **Consistent Experience**: With its unified APIs, Llama Stack makes it easier to build, test, and deploy AI applications with consistent application behavior.
- **Robust Ecosystem**: Llama Stack is already integrated with distribution partners (cloud providers, hardware vendors, and AI-focused companies) that offer tailored infrastructure, software, and services for deploying Llama models.
By reducing friction and complexity, Llama Stack empowers developers to focus on what they do best: building transformative generative AI applications.
### API Providers
Here is a list of the various API providers and available distributions to developers started easily,
Here is a list of the various API providers and available distributions that can help developers get started easily with Llama Stack.
| **API Provider Builder** | **Environments** | **Agents** | **Inference** | **Memory** | **Safety** | **Telemetry** |
|:------------------------:|:----------------------:|:----------:|:-------------:|:----------:|:----------:|:-------------:|
@ -71,15 +71,15 @@ A Llama Stack Distribution (or "distro") is a pre-configured bundle of provider
You have two ways to install this repository:
1. **Install as a package**:
* **Install as a package**:
You can install the repository directly from [PyPI](https://pypi.org/project/llama-stack/) by running the following command:
```bash
pip install llama-stack
```
2. **Install from source**:
* **Install from source**:
If you prefer to install from the source code, make sure you have [conda installed](https://docs.conda.io/projects/conda/en/stable).
Then, follow these steps:
Then, run the following commands:
```bash
mkdir -p ~/local
cd ~/local
@ -96,10 +96,11 @@ You have two ways to install this repository:
Please check out our [Documentation](https://llama-stack.readthedocs.io/en/latest/index.html) page for more details.
* [CLI reference](https://llama-stack.readthedocs.io/en/latest/references/llama_cli_reference/index.html)
* Guide using `llama` CLI to work with Llama models (download, study prompts), and building/starting a Llama Stack distribution.
* [Getting Started](https://llama-stack.readthedocs.io/en/latest/getting_started/index.html)
* Quick guide to start a Llama Stack server.
* CLI references
* [llama (server-side) CLI Reference](https://llama-stack.readthedocs.io/en/latest/references/llama_cli_reference/index.html): Guide for using the `llama` CLI to work with Llama models (download, study prompts) and for building/starting a Llama Stack distribution.
* [llama (client-side) CLI Reference](https://llama-stack.readthedocs.io/en/latest/references/llama_stack_client_cli_reference.html): Guide for using the `llama-stack-client` CLI, which allows you to query information about the distribution.
* Getting Started
* [Quick guide to start a Llama Stack server](https://llama-stack.readthedocs.io/en/latest/getting_started/index.html).
* [Jupyter notebook](./docs/getting_started.ipynb) that walks through how to use the llama_stack_client APIs for simple text and vision inference
* The complete Llama Stack lesson [Colab notebook](https://colab.research.google.com/drive/1dtVmxotBsI4cGZQNsJRYPrLiDeT0Wnwt) from the new [Llama 3.2 course on Deeplearning.ai](https://learn.deeplearning.ai/courses/introducing-multimodal-llama-3-2/lesson/8/llama-stack).
* A [Zero-to-Hero Guide](https://github.com/meta-llama/llama-stack/tree/main/docs/zero_to_hero_guide) that guides you through all the key components of Llama Stack with code samples.
@ -115,6 +116,6 @@ Please checkout our [Documentation](https://llama-stack.readthedocs.io/en/latest
| Typescript | [llama-stack-client-typescript](https://github.com/meta-llama/llama-stack-client-typescript) | [![NPM version](https://img.shields.io/npm/v/llama-stack-client.svg)](https://npmjs.org/package/llama-stack-client)
| Kotlin | [llama-stack-client-kotlin](https://github.com/meta-llama/llama-stack-client-kotlin) | [![Maven version](https://img.shields.io/maven-central/v/com.llama.llamastack/llama-stack-client-kotlin)](https://central.sonatype.com/artifact/com.llama.llamastack/llama-stack-client-kotlin)
Check out our client SDKs for connecting to Llama Stack server in your preferred language, you can choose from [python](https://github.com/meta-llama/llama-stack-client-python), [typescript](https://github.com/meta-llama/llama-stack-client-typescript), [swift](https://github.com/meta-llama/llama-stack-client-swift), and [kotlin](https://github.com/meta-llama/llama-stack-client-kotlin) programming languages to quickly build your applications.
Check out our client SDKs for connecting to a Llama Stack server in your preferred language; you can choose from [python](https://github.com/meta-llama/llama-stack-client-python), [typescript](https://github.com/meta-llama/llama-stack-client-typescript), [swift](https://github.com/meta-llama/llama-stack-client-swift), and [kotlin](https://github.com/meta-llama/llama-stack-client-kotlin) programming languages to quickly build your applications.
You can find more example scripts with client SDKs to talk with the Llama Stack server in our [llama-stack-apps](https://github.com/meta-llama/llama-stack-apps/tree/main/examples) repo.

File diff suppressed because it is too large.

File diff suppressed because it is too large.

View file

@ -324,7 +324,7 @@
"- vector_io\n",
"container_image: null\n",
"datasets: <span style=\"font-weight: bold\">[]</span>\n",
"eval_tasks: <span style=\"font-weight: bold\">[]</span>\n",
"benchmarks: <span style=\"font-weight: bold\">[]</span>\n",
"image_name: together\n",
"metadata_store:\n",
" db_path: <span style=\"color: #800080; text-decoration-color: #800080\">/Users/ashwin/.llama/distributions/together/</span><span style=\"color: #ff00ff; text-decoration-color: #ff00ff\">registry.db</span>\n",
@ -508,7 +508,7 @@
"- vector_io\n",
"container_image: null\n",
"datasets: \u001b[1m[\u001b[0m\u001b[1m]\u001b[0m\n",
"eval_tasks: \u001b[1m[\u001b[0m\u001b[1m]\u001b[0m\n",
"benchmarks: \u001b[1m[\u001b[0m\u001b[1m]\u001b[0m\n",
"image_name: together\n",
"metadata_store:\n",
" db_path: \u001b[35m/Users/ashwin/.llama/distributions/together/\u001b[0m\u001b[95mregistry.db\u001b[0m\n",
@ -3419,22 +3419,22 @@
},
{
"cell_type": "code",
"execution_count": null,
"id": "865fc5a8",
"metadata": {},
"outputs": [],
"source": [
"!pip install llama-stack-client==0.1.0"
]
},
{
"cell_type": "code",
"execution_count": null,
"execution_count": 3,
"id": "44e05e16",
"metadata": {},
"outputs": [],
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
" % Total % Received % Xferd Average Speed Time Time Time Current\n",
" Dload Upload Total Spent Left Speed\n",
"100 275k 100 275k 0 0 780k 0 --:--:-- --:--:-- --:--:-- 780k\n"
]
}
],
"source": [
"!wget https://raw.githubusercontent.com/meta-llama/llama-models/refs/heads/main/Llama_Repo.jpeg"
"!curl -O https://raw.githubusercontent.com/meta-llama/llama-models/refs/heads/main/Llama_Repo.jpeg"
]
},
{
@ -3444,6 +3444,7 @@
"metadata": {},
"outputs": [],
"source": [
"# NBVAL_SKIP\n",
"from PIL import Image\n",
"import matplotlib.pyplot as plt\n",
"\n",
@ -3580,6 +3581,7 @@
" model=LLAMA32_11B_INSTRUCT,\n",
" instructions=\"You are a helpful assistant\",\n",
" enable_session_persistence=False,\n",
" toolgroups=[],\n",
" )\n",
"\n",
" agent = Agent(client, agent_config)\n",
@ -3630,7 +3632,7 @@
"provenance": []
},
"kernelspec": {
"display_name": "toolchain",
"display_name": "master",
"language": "python",
"name": "python3"
},
@ -3644,7 +3646,7 @@
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.10.15"
"version": "3.10.16"
},
"widgets": {
"application/vnd.jupyter.widget-state+json": {

View file

@ -370,7 +370,7 @@
"- tool_runtime\n",
"datasets: <span style=\"font-weight: bold\">[]</span>\n",
"container_image: null\n",
"eval_tasks: <span style=\"font-weight: bold\">[]</span>\n",
"benchmarks: <span style=\"font-weight: bold\">[]</span>\n",
"image_name: together\n",
"memory_banks: <span style=\"font-weight: bold\">[]</span>\n",
"metadata_store:\n",
@ -551,7 +551,7 @@
"- tool_runtime\n",
"datasets: \u001b[1m[\u001b[0m\u001b[1m]\u001b[0m\n",
"container_image: null\n",
"eval_tasks: \u001b[1m[\u001b[0m\u001b[1m]\u001b[0m\n",
"benchmarks: \u001b[1m[\u001b[0m\u001b[1m]\u001b[0m\n",
"image_name: together\n",
"memory_banks: \u001b[1m[\u001b[0m\u001b[1m]\u001b[0m\n",
"metadata_store:\n",

View file

@ -1,4 +1,4 @@
The RFC Specification (OpenAPI format) is generated from the set of API endpoints located in `llama_stack/[<subdir>]/api/endpoints.py` using the `generate.py` utility.
The RFC Specification (OpenAPI format) is generated from the set of API endpoints located in `llama_stack/distribution/server/endpoints.py` using the `generate.py` utility.
Please install the following packages before running the script:
@ -6,4 +6,4 @@ Please install the following packages before running the script:
pip install python-openapi json-strong-typing fire PyYAML llama-models
```
Then simply run `sh run_openapi_generator.sh <OUTPUT_DIR>`
Then simply run `sh run_openapi_generator.sh`

View file

@ -16,18 +16,6 @@ from pathlib import Path
import fire
import ruamel.yaml as yaml
from llama_models import schema_utils
# We do some monkey-patching to ensure our definitions only use the minimal
# (json_schema_type, webmethod) definitions from the llama_models package. For
# generation though, we need the full definitions and implementations from the
# (json-strong-typing) package.
from .strong_typing.schema import json_schema_type, register_schema
schema_utils.json_schema_type = json_schema_type
schema_utils.register_schema = register_schema
from llama_stack.apis.version import LLAMA_STACK_API_VERSION # noqa: E402
from llama_stack.distribution.stack import LlamaStack # noqa: E402

View file

@ -10,9 +10,9 @@ import typing
from dataclasses import make_dataclass
from typing import Any, Dict, Set, Union
from ..strong_typing.core import JsonType
from ..strong_typing.docstring import Docstring, parse_type
from ..strong_typing.inspection import (
from llama_stack.strong_typing.core import JsonType
from llama_stack.strong_typing.docstring import Docstring, parse_type
from llama_stack.strong_typing.inspection import (
is_generic_list,
is_type_optional,
is_type_union,
@ -20,15 +20,15 @@ from ..strong_typing.inspection import (
unwrap_optional_type,
unwrap_union_types,
)
from ..strong_typing.name import python_type_to_name
from ..strong_typing.schema import (
from llama_stack.strong_typing.name import python_type_to_name
from llama_stack.strong_typing.schema import (
get_schema_identifier,
JsonSchemaGenerator,
register_schema,
Schema,
SchemaOptions,
)
from ..strong_typing.serialization import json_dump_string, object_to_json
from llama_stack.strong_typing.serialization import json_dump_string, object_to_json
from .operations import (
EndpointOperation,
@ -644,7 +644,10 @@ class Generator:
else:
callbacks = None
description = "\n".join(filter(None, [doc_string.short_description, doc_string.long_description]))
description = "\n".join(
filter(None, [doc_string.short_description, doc_string.long_description])
)
return Operation(
tags=[op.defining_class.__name__],
summary=None,
@ -654,6 +657,7 @@ class Generator:
requestBody=requestBody,
responses=responses,
callbacks=callbacks,
deprecated=True if "DEPRECATED" in op.func_name else None,
security=[] if op.public else None,
)
@ -681,6 +685,7 @@ class Generator:
raise NotImplementedError(f"unknown HTTP method: {op.http_method}")
route = op.get_route()
route = route.replace(":path", "")
print(f"route: {route}")
if route in paths:
paths[route].update(pathItem)
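The `deprecated` flag added above is driven entirely by the `DEPRECATED_` naming convention on protocol methods (see the renamed eval/benchmark methods later in this commit). A small self-contained sketch of that rule; the helper below is illustrative and not part of `llama_stack`:

```python
# Illustrative only: mirrors the generator's rule for flagging deprecated operations.
def is_deprecated_operation(func_name: str) -> bool:
    return "DEPRECATED" in func_name


for name in [
    "run_eval",               # new route: /eval/benchmarks/{benchmark_id}/jobs
    "DEPRECATED_run_eval",    # legacy route: /eval/tasks/{task_id}/jobs
    "DEPRECATED_job_result",  # legacy route: /eval/tasks/{task_id}/jobs/{job_id}/result
]:
    flag = True if is_deprecated_operation(name) else None
    print(f"{name}: deprecated={flag}")
```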

View file

@ -15,7 +15,7 @@ from llama_stack.apis.version import LLAMA_STACK_API_VERSION
from termcolor import colored
from ..strong_typing.inspection import get_signature
from llama_stack.strong_typing.inspection import get_signature
def split_prefix(
@ -130,6 +130,8 @@ class _FormatParameterExtractor:
def _get_route_parameters(route: str) -> List[str]:
extractor = _FormatParameterExtractor()
# Replace all occurrences of ":path" with empty string
route = route.replace(":path", "")
route.format_map(extractor)
return extractor.keys
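To make the stripped `:path` suffix concrete: `str.format_map` treats everything after `:` inside a placeholder as a format spec, so `{dataset_id:path}` cannot be parsed for parameter extraction and the suffix must be removed first. A standalone illustration; the collector class below is a simplified stand-in for `_FormatParameterExtractor`:

```python
class _KeyCollector(dict):
    """Simplified stand-in: records the placeholder names that format_map asks for."""

    def __init__(self):
        super().__init__()
        self.keys = []

    def __missing__(self, key):
        self.keys.append(key)
        return ""


route = "/v1/datasets/{dataset_id:path}"
collector = _KeyCollector()
route.replace(":path", "").format_map(collector)
print(collector.keys)  # ['dataset_id']

# Without the replace, format_map raises, because "path" is not a valid format spec:
# "/v1/datasets/{dataset_id:path}".format_map(_KeyCollector())  -> ValueError
```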

View file

@ -9,7 +9,7 @@ import enum
from dataclasses import dataclass
from typing import Any, ClassVar, Dict, List, Optional, Union
from ..strong_typing.schema import JsonType, Schema, StrictJsonType
from llama_stack.strong_typing.schema import JsonType, Schema, StrictJsonType
URL = str
@ -117,6 +117,7 @@ class Operation:
requestBody: Optional[RequestBody] = None
callbacks: Optional[Dict[str, "Callback"]] = None
security: Optional[List["SecurityRequirement"]] = None
deprecated: Optional[bool] = None
@dataclass

View file

@ -9,7 +9,7 @@ import typing
from pathlib import Path
from typing import TextIO
from ..strong_typing.schema import object_to_json, StrictJsonType
from llama_stack.strong_typing.schema import object_to_json, StrictJsonType
from .generator import Generator
from .options import Options

View file

@ -41,14 +41,14 @@ system_message = {
"content": SYSTEM_PROMPT_TEMPLATE,
}
client.eval_tasks.register(
eval_task_id="meta-reference::mmmu",
client.benchmarks.register(
benchmark_id="meta-reference::mmmu",
dataset_id=f"mmmu-{subset}-{split}",
scoring_functions=["basic::regex_parser_multiple_choice_answer"],
)
response = client.eval.evaluate_rows(
task_id="meta-reference::mmmu",
benchmark_id="meta-reference::mmmu",
input_rows=eval_rows,
scoring_functions=["basic::regex_parser_multiple_choice_answer"],
task_config={
@ -99,14 +99,14 @@ eval_rows = client.datasetio.get_rows_paginated(
```
```python
client.eval_tasks.register(
eval_task_id="meta-reference::simpleqa",
client.benchmarks.register(
benchmark_id="meta-reference::simpleqa",
dataset_id=simpleqa_dataset_id,
scoring_functions=["llm-as-judge::405b-simpleqa"],
)
response = client.eval.evaluate_rows(
task_id="meta-reference::simpleqa",
benchmark_id="meta-reference::simpleqa",
input_rows=eval_rows.rows,
scoring_functions=["llm-as-judge::405b-simpleqa"],
task_config={
@ -156,7 +156,7 @@ agent_config = {
}
response = client.eval.evaluate_rows(
task_id="meta-reference::simpleqa",
benchmark_id="meta-reference::simpleqa",
input_rows=eval_rows.rows,
scoring_functions=["llm-as-judge::405b-simpleqa"],
task_config={

View file

@ -10,15 +10,15 @@ Here's how to set up basic evaluation:
```python
# Create an evaluation task
response = client.eval_tasks.register(
eval_task_id="my_eval",
response = client.benchmarks.register(
benchmark_id="my_eval",
dataset_id="my_dataset",
scoring_functions=["accuracy", "relevance"],
)
# Run evaluation
job = client.eval.run_eval(
task_id="my_eval",
benchmark_id="my_eval",
task_config={
"type": "app",
"eval_candidate": {"type": "agent", "config": agent_config},
@ -26,5 +26,5 @@ job = client.eval.run_eval(
)
# Get results
result = client.eval.job_result(task_id="my_eval", job_id=job.job_id)
result = client.eval.job_result(benchmark_id="my_eval", job_id=job.job_id)
```
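If you want to inspect the job before (or instead of) pulling results directly, the client also exposes job helpers; a hedged sketch, with method names taken from the SDK reference updated elsewhere in this commit and the printed attributes treated as illustrative:

```python
# Check on the job (names follow the updated SDK reference; adjust to your client version)
status = client.eval.jobs.status(job.job_id, benchmark_id="my_eval")
print(status)

# EvaluateResponse carries per-row generations plus aggregated scores
result = client.eval.job_result(benchmark_id="my_eval", job_id=job.job_id)
print(result.generations[:2])
print(result.scores)
```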

View file

@ -5,7 +5,7 @@ The Llama Stack Evaluation flow allows you to run evaluations on your GenAI appl
We introduce a set of APIs in Llama Stack for supporting running evaluations of LLM applications.
- `/datasetio` + `/datasets` API
- `/scoring` + `/scoring_functions` API
- `/eval` + `/eval_tasks` API
- `/eval` + `/benchmarks` API
This guide goes over the sets of APIs and the developer experience flow of using Llama Stack to run evaluations for different use cases. Check out our Colab notebook with working examples of evaluations [here](https://colab.research.google.com/drive/10CHyykee9j2OigaIcRv47BKG9mrNm0tJ?usp=sharing).
@ -21,7 +21,7 @@ The Evaluation APIs are associated with a set of Resources as shown in the follo
- **Scoring**: evaluate outputs of the system.
- Associated with `ScoringFunction` resource. We provide a suite of out-of-the-box scoring functions and also the ability for you to add custom evaluators. These scoring functions are the core part of defining an evaluation task to output evaluation metrics.
- **Eval**: generate outputs (via Inference or Agents) and perform scoring.
- Associated with `EvalTask` resource.
- Associated with `Benchmark` resource.
Use the following decision tree to decide how to use LlamaStack Evaluation flow.
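To see how these resources fit together in code, here is a minimal hedged sketch (assuming an initialized `llama_stack_client` client); the calls mirror the evaluation examples updated elsewhere in this commit, while the dataset id and the `rows_in_page` argument are illustrative:

```python
# DatasetIO: pull a few rows from a registered dataset (arguments are illustrative)
eval_rows = client.datasetio.get_rows_paginated(
    dataset_id="mmlu",
    rows_in_page=5,
)

# Benchmarks: a Benchmark resource ties the dataset to the scoring functions used by Eval
client.benchmarks.register(
    benchmark_id="meta-reference::mmlu",
    dataset_id="mmlu",
    scoring_functions=["basic::regex_parser_multiple_choice_answer"],
)
```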

View file

@ -42,7 +42,7 @@ Some of these APIs are associated with a set of **Resources**. Here is the mappi
- **Tool Runtime** is associated with `ToolGroup` resources.
- **DatasetIO** is associated with `Dataset` resources.
- **Scoring** is associated with `ScoringFunction` resources.
- **Eval** is associated with `Model` and `EvalTask` resources.
- **Eval** is associated with `Model` and `Benchmark` resources.
Furthermore, we allow these resources to be **federated** across multiple providers. For example, you may have some Llama models served by Fireworks while others are served by AWS Bedrock. Regardless, they will all work seamlessly with the same uniform Inference API provided by Llama Stack.
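As a concrete illustration of the resource-to-provider pairing, here is a hedged sketch assuming a running stack and the `llama_stack_client` Python SDK; the provider ids and the model id are placeholders:

```python
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:8321")

# A Model resource can be registered against a specific Inference provider
client.models.register(
    model_id="meta-llama/Llama-3.1-8B-Instruct",
    provider_id="fireworks",  # placeholder: any configured inference provider
)

# A Benchmark resource pairs a Dataset with ScoringFunctions for the Eval API
client.benchmarks.register(
    benchmark_id="my-org::example-benchmark",
    dataset_id="my_dataset",
    scoring_functions=["basic::regex_parser_multiple_choice_answer"],
)

print([m.identifier for m in client.models.list()])
```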

View file

@ -180,12 +180,45 @@ After this step is successful, you should be able to find the built container im
### Running your Stack server
Now, let's start the Llama Stack Distribution Server. You will need the YAML configuration file that was written out at the end of the `llama stack build` step.
```
llama stack run -h
usage: llama stack run [-h] [--port PORT] [--image-name IMAGE_NAME] [--disable-ipv6] [--env KEY=VALUE] [--tls-keyfile TLS_KEYFILE]
[--tls-certfile TLS_CERTFILE] [--image-type {conda,container,venv}]
config
start the server for a Llama Stack Distribution. You should have already built (or downloaded) and configured the distribution.
positional arguments:
config Path to config file to use for the run
options:
-h, --help show this help message and exit
--port PORT Port to run the server on. Defaults to 8321
--image-name IMAGE_NAME
Name of the image to run. Defaults to the current conda environment
--disable-ipv6 Disable IPv6 support
--env KEY=VALUE Environment variables to pass to the server in KEY=VALUE format. Can be specified multiple times.
--tls-keyfile TLS_KEYFILE
Path to TLS key file for HTTPS
--tls-certfile TLS_CERTFILE
Path to TLS certificate file for HTTPS
--image-type {conda,container,venv}
Image Type used during the build. This can be either conda or container or venv.
```
```
# Start using template name
llama stack run tgi
# Start using config file
llama stack run ~/.llama/distributions/llamastack-my-local-stack/my-local-stack-run.yaml
# Start using a venv
llama stack run --image-type venv ~/.llama/distributions/llamastack-my-local-stack/my-local-stack-run.yaml
# Start using a conda environment
llama stack run --image-type conda ~/.llama/distributions/llamastack-my-local-stack/my-local-stack-run.yaml
```
```

View file

@ -2,7 +2,7 @@
```{admonition} News
:class: tip
Llama Stack 0.1.2 is now available! See the [release notes](https://github.com/meta-llama/llama-stack/releases/tag/v0.1.2) for more details.
Llama Stack 0.1.3 is now available! See the [release notes](https://github.com/meta-llama/llama-stack/releases/tag/v0.1.3) for more details.
```
# Llama Stack

View file

@ -64,7 +64,7 @@ Interactive pages for users to play with and explore Llama Stack API capabilitie
```
```bash
$ llama-stack-client eval_tasks register \
$ llama-stack-client benchmarks register \
--eval-task-id meta-reference-mmlu \
--provider-id meta-reference \
--dataset-id mmlu \
@ -86,7 +86,7 @@ Interactive pages for users to play with and explore Llama Stack API capabilitie
- Under the hood, it uses Llama Stack's `/providers` API to get information about the providers.
- **API Resources**: Inspect Llama Stack API resources
- This page allows you to inspect Llama Stack API resources (`models`, `datasets`, `memory_banks`, `eval_tasks`, `shields`).
- This page allows you to inspect Llama Stack API resources (`models`, `datasets`, `memory_banks`, `benchmarks`, `shields`).
- Under the hood, it uses Llama Stack's `/<resources>/list` API to get information about each resource.
- Please visit [Core Concepts](https://llama-stack.readthedocs.io/en/latest/concepts/index.html) for more details about the resources.

View file

@ -5,7 +5,7 @@ The Llama Stack Evaluation flow allows you to run evaluations on your GenAI appl
We introduce a set of APIs in Llama Stack for supporting running evaluations of LLM applications.
- `/datasetio` + `/datasets` API
- `/scoring` + `/scoring_functions` API
- `/eval` + `/eval_tasks` API
- `/eval` + `/benchmarks` API
This guide goes over the sets of APIs and the developer experience flow of using Llama Stack to run evaluations for different use cases. Check out our Colab notebook with working examples of evaluations [here](https://colab.research.google.com/drive/10CHyykee9j2OigaIcRv47BKG9mrNm0tJ?usp=sharing).
@ -21,7 +21,7 @@ The Evaluation APIs are associated with a set of Resources as shown in the follo
- **Scoring**: evaluate outputs of the system.
- Associated with `ScoringFunction` resource. We provide a suite of out-of-the-box scoring functions and also the ability for you to add custom evaluators. These scoring functions are the core part of defining an evaluation task to output evaluation metrics.
- **Eval**: generate outputs (via Inference or Agents) and perform scoring.
- Associated with `EvalTask` resource.
- Associated with `Benchmark` resource.
Use the following decision tree to decide how to use LlamaStack Evaluation flow.
@ -77,14 +77,14 @@ system_message = {
"content": SYSTEM_PROMPT_TEMPLATE,
}
client.eval_tasks.register(
eval_task_id="meta-reference::mmmu",
client.benchmarks.register(
benchmark_id="meta-reference::mmmu",
dataset_id=f"mmmu-{subset}-{split}",
scoring_functions=["basic::regex_parser_multiple_choice_answer"],
)
response = client.eval.evaluate_rows(
task_id="meta-reference::mmmu",
benchmark_id="meta-reference::mmmu",
input_rows=eval_rows,
scoring_functions=["basic::regex_parser_multiple_choice_answer"],
task_config={
@ -135,14 +135,14 @@ eval_rows = client.datasetio.get_rows_paginated(
```
```python
client.eval_tasks.register(
eval_task_id="meta-reference::simpleqa",
client.benchmarks.register(
benchmark_id="meta-reference::simpleqa",
dataset_id=simpleqa_dataset_id,
scoring_functions=["llm-as-judge::405b-simpleqa"],
)
response = client.eval.evaluate_rows(
task_id="meta-reference::simpleqa",
benchmark_id="meta-reference::simpleqa",
input_rows=eval_rows.rows,
scoring_functions=["llm-as-judge::405b-simpleqa"],
task_config={
@ -192,7 +192,7 @@ agent_config = {
}
response = client.eval.evaluate_rows(
task_id="meta-reference::simpleqa",
benchmark_id="meta-reference::simpleqa",
input_rows=eval_rows.rows,
scoring_functions=["llm-as-judge::405b-simpleqa"],
task_config={
@ -281,7 +281,7 @@ The following examples give the quick steps to start running evaluations using t
#### Benchmark Evaluation CLI
Usage: There are 2 inputs necessary for running a benchmark eval
- `eval-task-id`: the identifier associated with the eval task. Each `EvalTask` is parametrized by
- `eval-task-id`: the identifier associated with the eval task. Each `Benchmark` is parametrized by
- `dataset_id`: the identifier associated with the dataset.
- `List[scoring_function_id]`: list of scoring function identifiers.
- `eval-task-config`: specifies the configuration of the model / agent to evaluate on.
@ -289,7 +289,7 @@ Usage: There are 2 inputs necessary for running a benchmark eval
```
llama-stack-client eval run_benchmark <eval-task-id> \
--eval-task-config ~/eval_task_config.json \
--eval-task-config ~/benchmark_config.json \
--visualize
```
@ -309,15 +309,15 @@ llama-stack-client eval run_scoring <scoring_fn_id_1> <scoring_fn_id_2> ... <sco
--output-dir ./
```
#### Defining EvalTaskConfig
The `EvalTaskConfig` are user specified config to define:
#### Defining BenchmarkConfig
The `BenchmarkConfig` is a user-specified config that defines:
1. `EvalCandidate` to run generation on:
- `ModelCandidate`: The model will be used for generation through LlamaStack /inference API.
- `AgentCandidate`: The agentic system specified by AgentConfig will be used for generation through LlamaStack /agents API.
2. Optionally scoring function params to allow customization of scoring function behaviour. This is useful to parameterize generic scoring functions such as LLMAsJudge with custom `judge_model` / `judge_prompt`.
**Example Benchmark EvalTaskConfig**
**Example Benchmark BenchmarkConfig**
```json
{
"type": "benchmark",
@ -335,7 +335,7 @@ The `EvalTaskConfig` are user specified config to define:
}
```
**Example Application EvalTaskConfig**
**Example Application BenchmarkConfig**
```json
{
"type": "app",

View file

@ -161,14 +161,14 @@ Options:
## Eval Task Management
### `llama-stack-client eval_tasks list`
### `llama-stack-client benchmarks list`
```bash
$ llama-stack-client eval_tasks list
$ llama-stack-client benchmarks list
```
### `llama-stack-client eval_tasks register`
### `llama-stack-client benchmarks register`
```bash
$ llama-stack-client eval_tasks register --eval-task-id <eval-task-id> --dataset-id <dataset-id> --scoring-functions <function1> [<function2> ...] [--provider-id <provider-id>] [--provider-eval-task-id <provider-eval-task-id>] [--metadata <metadata>]
$ llama-stack-client benchmarks register --eval-task-id <eval-task-id> --dataset-id <dataset-id> --scoring-functions <function1> [<function2> ...] [--provider-id <provider-id>] [--provider-eval-task-id <provider-eval-task-id>] [--metadata <metadata>]
```
Options:
@ -191,7 +191,7 @@ Options:
- `--num-examples`: Optional. Number of examples to evaluate (useful for debugging)
- `--visualize`: Optional flag. If set, visualizes evaluation results after completion
Example eval_task_config.json:
Example benchmark_config.json:
```json
{
"type": "benchmark",

View file

@ -181,8 +181,8 @@ from llama_stack_client.types import EvaluateResponse, Job
Methods:
- <code title="post /v1/eval/tasks/{task_id}/evaluations">client.eval.<a href="./src/llama_stack_client/resources/eval/eval.py">evaluate_rows</a>(task_id, \*\*<a href="src/llama_stack_client/types/eval_evaluate_rows_params.py">params</a>) -> <a href="./src/llama_stack_client/types/evaluate_response.py">EvaluateResponse</a></code>
- <code title="post /v1/eval/tasks/{task_id}/jobs">client.eval.<a href="./src/llama_stack_client/resources/eval/eval.py">run_eval</a>(task_id, \*\*<a href="src/llama_stack_client/types/eval_run_eval_params.py">params</a>) -> <a href="./src/llama_stack_client/types/job.py">Job</a></code>
- <code title="post /v1/eval/tasks/{benchmark_id}/evaluations">client.eval.<a href="./src/llama_stack_client/resources/eval/eval.py">evaluate_rows</a>(benchmark_id, \*\*<a href="src/llama_stack_client/types/eval_evaluate_rows_params.py">params</a>) -> <a href="./src/llama_stack_client/types/evaluate_response.py">EvaluateResponse</a></code>
- <code title="post /v1/eval/tasks/{benchmark_id}/jobs">client.eval.<a href="./src/llama_stack_client/resources/eval/eval.py">run_eval</a>(benchmark_id, \*\*<a href="src/llama_stack_client/types/eval_run_eval_params.py">params</a>) -> <a href="./src/llama_stack_client/types/job.py">Job</a></code>
### Jobs
@ -194,9 +194,9 @@ from llama_stack_client.types.eval import JobStatusResponse
Methods:
- <code title="get /v1/eval/tasks/{task_id}/jobs/{job_id}/result">client.eval.jobs.<a href="./src/llama_stack_client/resources/eval/jobs.py">retrieve</a>(job_id, \*, task_id) -> <a href="./src/llama_stack_client/types/evaluate_response.py">EvaluateResponse</a></code>
- <code title="delete /v1/eval/tasks/{task_id}/jobs/{job_id}">client.eval.jobs.<a href="./src/llama_stack_client/resources/eval/jobs.py">cancel</a>(job_id, \*, task_id) -> None</code>
- <code title="get /v1/eval/tasks/{task_id}/jobs/{job_id}">client.eval.jobs.<a href="./src/llama_stack_client/resources/eval/jobs.py">status</a>(job_id, \*, task_id) -> Optional[JobStatusResponse]</code>
- <code title="get /v1/eval/tasks/{benchmark_id}/jobs/{job_id}/result">client.eval.jobs.<a href="./src/llama_stack_client/resources/eval/jobs.py">retrieve</a>(job_id, \*, benchmark_id) -> <a href="./src/llama_stack_client/types/evaluate_response.py">EvaluateResponse</a></code>
- <code title="delete /v1/eval/tasks/{benchmark_id}/jobs/{job_id}">client.eval.jobs.<a href="./src/llama_stack_client/resources/eval/jobs.py">cancel</a>(job_id, \*, benchmark_id) -> None</code>
- <code title="get /v1/eval/tasks/{benchmark_id}/jobs/{job_id}">client.eval.jobs.<a href="./src/llama_stack_client/resources/eval/jobs.py">status</a>(job_id, \*, benchmark_id) -> Optional[JobStatusResponse]</code>
## Inspect
@ -443,20 +443,20 @@ Methods:
- <code title="get /v1/scoring-functions">client.scoring_functions.<a href="./src/llama_stack_client/resources/scoring_functions.py">list</a>() -> <a href="./src/llama_stack_client/types/scoring_function_list_response.py">ScoringFunctionListResponse</a></code>
- <code title="post /v1/scoring-functions">client.scoring_functions.<a href="./src/llama_stack_client/resources/scoring_functions.py">register</a>(\*\*<a href="src/llama_stack_client/types/scoring_function_register_params.py">params</a>) -> None</code>
## EvalTasks
## Benchmarks
Types:
```python
from llama_stack_client.types import (
EvalTask,
ListEvalTasksResponse,
EvalTaskListResponse,
Benchmark,
ListBenchmarksResponse,
BenchmarkListResponse,
)
```
Methods:
- <code title="get /v1/eval-tasks/{eval_task_id}">client.eval_tasks.<a href="./src/llama_stack_client/resources/eval_tasks.py">retrieve</a>(eval_task_id) -> <a href="./src/llama_stack_client/types/eval_task.py">Optional[EvalTask]</a></code>
- <code title="get /v1/eval-tasks">client.eval_tasks.<a href="./src/llama_stack_client/resources/eval_tasks.py">list</a>() -> <a href="./src/llama_stack_client/types/eval_task_list_response.py">EvalTaskListResponse</a></code>
- <code title="post /v1/eval-tasks">client.eval_tasks.<a href="./src/llama_stack_client/resources/eval_tasks.py">register</a>(\*\*<a href="src/llama_stack_client/types/eval_task_register_params.py">params</a>) -> None</code>
- <code title="get /v1/eval-tasks/{benchmark_id}">client.benchmarks.<a href="./src/llama_stack_client/resources/benchmarks.py">retrieve</a>(benchmark_id) -> <a href="./src/llama_stack_client/types/benchmark.py">Optional[Benchmark]</a></code>
- <code title="get /v1/eval-tasks">client.benchmarks.<a href="./src/llama_stack_client/resources/benchmarks.py">list</a>() -> <a href="./src/llama_stack_client/types/benchmark_list_response.py">BenchmarkListResponse</a></code>
- <code title="post /v1/eval-tasks">client.benchmarks.<a href="./src/llama_stack_client/resources/benchmarks.py">register</a>(\*\*<a href="src/llama_stack_client/types/benchmark_register_params.py">params</a>) -> None</code>

View file

@ -15,29 +15,29 @@ from typing import (
Literal,
Optional,
Protocol,
runtime_checkable,
Union,
runtime_checkable,
)
from llama_models.schema_utils import json_schema_type, register_schema, webmethod
from pydantic import BaseModel, ConfigDict, Field
from llama_stack.apis.common.content_types import ContentDelta, InterleavedContent, URL
from llama_stack.apis.common.content_types import URL, ContentDelta, InterleavedContent
from llama_stack.apis.inference import (
CompletionMessage,
ResponseFormat,
SamplingParams,
ToolCall,
ToolChoice,
ToolConfig,
ToolPromptFormat,
ToolResponse,
ToolResponseMessage,
UserMessage,
ToolConfig,
)
from llama_stack.apis.safety import SafetyViolation
from llama_stack.apis.tools import ToolDef
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
class Attachment(BaseModel):
@ -154,7 +154,7 @@ class AgentConfigCommon(BaseModel):
output_shields: Optional[List[str]] = Field(default_factory=list)
toolgroups: Optional[List[AgentToolGroup]] = Field(default_factory=list)
client_tools: Optional[List[ToolDef]] = Field(default_factory=list)
tool_choice: Optional[ToolChoice] = Field(default=ToolChoice.auto, deprecated="use tool_config instead")
tool_choice: Optional[ToolChoice] = Field(default=None, deprecated="use tool_config instead")
tool_prompt_format: Optional[ToolPromptFormat] = Field(default=None, deprecated="use tool_config instead")
tool_config: Optional[ToolConfig] = Field(default=None)
@ -166,18 +166,20 @@ class AgentConfigCommon(BaseModel):
raise ValueError("tool_choice is deprecated. Use tool_choice in tool_config instead.")
if self.tool_prompt_format and self.tool_config.tool_prompt_format != self.tool_prompt_format:
raise ValueError("tool_prompt_format is deprecated. Use tool_prompt_format in tool_config instead.")
if self.tool_config is None:
self.tool_config = ToolConfig(
tool_choice=self.tool_choice,
tool_prompt_format=self.tool_prompt_format,
)
else:
params = {}
if self.tool_choice:
params["tool_choice"] = self.tool_choice
if self.tool_prompt_format:
params["tool_prompt_format"] = self.tool_prompt_format
self.tool_config = ToolConfig(**params)
@json_schema_type
class AgentConfig(AgentConfigCommon):
model: str
instructions: str
enable_session_persistence: bool
enable_session_persistence: Optional[bool] = False
response_format: Optional[ResponseFormat] = None
@ -333,7 +335,10 @@ class Agents(Protocol):
tool_config: Optional[ToolConfig] = None,
) -> Union[Turn, AsyncIterator[AgentTurnResponseStreamChunk]]: ...
@webmethod(route="/agents/{agent_id}/session/{session_id}/turn/{turn_id}", method="GET")
@webmethod(
route="/agents/{agent_id}/session/{session_id}/turn/{turn_id}",
method="GET",
)
async def get_agents_turn(
self,
agent_id: str,
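A hedged sketch of what the `tool_config` migration above means for callers, assuming `llama_stack` is importable and the fields keep the names shown in this diff (the model id is a placeholder):

```python
from llama_stack.apis.agents import AgentConfig
from llama_stack.apis.inference import ToolChoice

# Deprecated per-field knobs are folded into tool_config by the validator above;
# enable_session_persistence is now optional and defaults to False.
cfg = AgentConfig(
    model="meta-llama/Llama-3.2-11B-Vision-Instruct",  # placeholder model id
    instructions="You are a helpful assistant",
    tool_choice=ToolChoice.auto,  # deprecated: copied into tool_config automatically
)
print(cfg.tool_config)
```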

View file

@ -1,207 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Optional
from llama_models.llama3.api.datatypes import ToolPromptFormat
from llama_models.llama3.api.tool_utils import ToolUtils
from termcolor import cprint
from llama_stack.apis.agents import AgentTurnResponseEventType, StepType
from llama_stack.apis.common.content_types import ToolCallParseStatus
from llama_stack.apis.inference import ToolResponseMessage
from llama_stack.providers.utils.inference.prompt_adapter import (
interleaved_content_as_str,
)
class LogEvent:
def __init__(
self,
role: Optional[str] = None,
content: str = "",
end: str = "\n",
color="white",
):
self.role = role
self.content = content
self.color = color
self.end = "\n" if end is None else end
def __str__(self):
if self.role is not None:
return f"{self.role}> {self.content}"
else:
return f"{self.content}"
def print(self, flush=True):
cprint(f"{str(self)}", color=self.color, end=self.end, flush=flush)
EventType = AgentTurnResponseEventType
class EventLogger:
async def log(
self,
event_generator,
stream=True,
tool_prompt_format: ToolPromptFormat = ToolPromptFormat.json,
):
previous_event_type = None
previous_step_type = None
async for chunk in event_generator:
if not hasattr(chunk, "event"):
# Need to check for custom tool first
# since it does not produce event but instead
# a Message
if isinstance(chunk, ToolResponseMessage):
yield (
chunk,
LogEvent(role="CustomTool", content=chunk.content, color="grey"),
)
continue
event = chunk.event
event_type = event.payload.event_type
if event_type in {
EventType.turn_start.value,
EventType.turn_complete.value,
}:
# Currently not logging any turn related info
yield event, None
continue
step_type = event.payload.step_type
# handle safety
if step_type == StepType.shield_call and event_type == EventType.step_complete.value:
violation = event.payload.step_details.violation
if not violation:
yield (
event,
LogEvent(role=step_type, content="No Violation", color="magenta"),
)
else:
yield (
event,
LogEvent(
role=step_type,
content=f"{violation.metadata} {violation.user_message}",
color="red",
),
)
# handle inference
if step_type == StepType.inference:
if stream:
if event_type == EventType.step_start.value:
# TODO: Currently this event is never received
yield (
event,
LogEvent(role=step_type, content="", end="", color="yellow"),
)
elif event_type == EventType.step_progress.value:
# HACK: if previous was not step/event was not inference's step_progress
# this is the first time we are getting model inference response
# aka equivalent to step_start for inference. Hence,
# start with "Model>".
if (
previous_event_type != EventType.step_progress.value
and previous_step_type != StepType.inference
):
yield (
event,
LogEvent(role=step_type, content="", end="", color="yellow"),
)
delta = event.payload.delta
if delta.type == "tool_call":
if delta.parse_status == ToolCallParseStatus.succeeded:
yield (
event,
LogEvent(
role=None,
content=delta.tool_call,
end="",
color="cyan",
),
)
else:
yield (
event,
LogEvent(
role=None,
content=delta.text,
end="",
color="yellow",
),
)
else:
# step_complete
yield event, LogEvent(role=None, content="")
else:
# Not streaming
if event_type == EventType.step_complete.value:
response = event.payload.step_details.model_response
if response.tool_calls:
content = ToolUtils.encode_tool_call(response.tool_calls[0], tool_prompt_format)
else:
content = response.content
yield (
event,
LogEvent(
role=step_type,
content=content,
color="yellow",
),
)
# handle tool_execution
if (
step_type == StepType.tool_execution
and
# Only print tool calls and responses at the step_complete event
event_type == EventType.step_complete.value
):
details = event.payload.step_details
for t in details.tool_calls:
yield (
event,
LogEvent(
role=step_type,
content=f"Tool:{t.tool_name} Args:{t.arguments}",
color="green",
),
)
for r in details.tool_responses:
yield (
event,
LogEvent(
role=step_type,
content=f"Tool:{r.tool_name} Response:{r.content}",
color="green",
),
)
if step_type == StepType.memory_retrieval and event_type == EventType.step_complete.value:
details = event.payload.step_details
inserted_context = interleaved_content_as_str(details.inserted_context)
content = f"fetched {len(inserted_context)} bytes from {details.vector_db_ids}"
yield (
event,
LogEvent(
role=step_type,
content=content,
color="cyan",
),
)
previous_event_type = event_type
previous_step_type = step_type

View file

@ -6,7 +6,6 @@
from typing import List, Optional, Protocol, runtime_checkable
from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel
from llama_stack.apis.inference import (
@ -21,6 +20,7 @@ from llama_stack.apis.inference import (
ToolDefinition,
ToolPromptFormat,
)
from llama_stack.schema_utils import json_schema_type, webmethod
@json_schema_type

View file

@ -4,4 +4,4 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .eval_tasks import * # noqa: F401 F403
from .benchmarks import * # noqa: F401 F403

View file

@ -0,0 +1,86 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict, List, Literal, Optional, Protocol, runtime_checkable
from pydantic import BaseModel, Field
from llama_stack.apis.resource import Resource, ResourceType
from llama_stack.schema_utils import json_schema_type, webmethod
class CommonBenchmarkFields(BaseModel):
dataset_id: str
scoring_functions: List[str]
metadata: Dict[str, Any] = Field(
default_factory=dict,
description="Metadata for this evaluation task",
)
@json_schema_type
class Benchmark(CommonBenchmarkFields, Resource):
type: Literal[ResourceType.benchmark.value] = ResourceType.benchmark.value
@property
def benchmark_id(self) -> str:
return self.identifier
@property
def provider_benchmark_id(self) -> str:
return self.provider_resource_id
class BenchmarkInput(CommonBenchmarkFields, BaseModel):
benchmark_id: str
provider_id: Optional[str] = None
provider_benchmark_id: Optional[str] = None
class ListBenchmarksResponse(BaseModel):
data: List[Benchmark]
@runtime_checkable
class Benchmarks(Protocol):
@webmethod(route="/eval/benchmarks", method="GET")
async def list_benchmarks(self) -> ListBenchmarksResponse: ...
@webmethod(route="/eval/benchmarks/{benchmark_id}", method="GET")
async def get_benchmark(
self,
benchmark_id: str,
) -> Optional[Benchmark]: ...
@webmethod(route="/eval/benchmarks", method="POST")
async def register_benchmark(
self,
benchmark_id: str,
dataset_id: str,
scoring_functions: List[str],
provider_benchmark_id: Optional[str] = None,
provider_id: Optional[str] = None,
metadata: Optional[Dict[str, Any]] = None,
) -> None: ...
@webmethod(route="/eval-tasks", method="GET")
async def DEPRECATED_list_eval_tasks(self) -> ListBenchmarksResponse: ...
@webmethod(route="/eval-tasks/{eval_task_id}", method="GET")
async def DEPRECATED_get_eval_task(
self,
eval_task_id: str,
) -> Optional[Benchmark]: ...
@webmethod(route="/eval-tasks", method="POST")
async def DEPRECATED_register_eval_task(
self,
eval_task_id: str,
dataset_id: str,
scoring_functions: List[str],
provider_benchmark_id: Optional[str] = None,
provider_id: Optional[str] = None,
metadata: Optional[Dict[str, Any]] = None,
) -> None: ...
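A brief hedged sketch of how these models compose, assuming `llama_stack` is importable and that `Resource` keeps its `identifier` / `provider_id` / `provider_resource_id` fields; the ids below are placeholders:

```python
from llama_stack.apis.benchmarks import Benchmark, BenchmarkInput

# Declarative form, e.g. as it could appear in a distribution's run config
bench_input = BenchmarkInput(
    benchmark_id="meta-reference::mmlu",
    dataset_id="mmlu",
    scoring_functions=["basic::regex_parser_multiple_choice_answer"],
)

# Registered form: a Resource with identifier / provider_resource_id filled in
bench = Benchmark(
    identifier=bench_input.benchmark_id,
    provider_id="meta-reference",  # placeholder provider
    provider_resource_id=bench_input.benchmark_id,
    dataset_id=bench_input.dataset_id,
    scoring_functions=bench_input.scoring_functions,
)
print(bench.benchmark_id, bench.provider_benchmark_id, bench.type)
```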

View file

@ -7,11 +7,11 @@
from enum import Enum
from typing import Annotated, List, Literal, Optional, Union
from llama_models.llama3.api.datatypes import ToolCall
from llama_models.schema_utils import json_schema_type, register_schema
from pydantic import BaseModel, Field, model_validator
from llama_stack.models.llama.datatypes import ToolCall
from llama_stack.schema_utils import json_schema_type, register_schema
@json_schema_type
class URL(BaseModel):

View file

@ -7,11 +7,10 @@
from enum import Enum
from typing import Any, Dict, Optional
from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel
from llama_stack.apis.common.content_types import URL
from llama_stack.schema_utils import json_schema_type
@json_schema_type

View file

@ -5,9 +5,10 @@
# the root directory of this source tree.
from enum import Enum
from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel
from llama_stack.schema_utils import json_schema_type
@json_schema_type
class Job(BaseModel):

View file

@ -7,9 +7,10 @@
from datetime import datetime
from typing import Optional
from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel
from llama_stack.schema_utils import json_schema_type
@json_schema_type
class PostTrainingMetric(BaseModel):

View file

@ -6,10 +6,11 @@
from typing import Literal, Union
from llama_models.schema_utils import json_schema_type, register_schema
from pydantic import BaseModel, Field
from typing_extensions import Annotated
from llama_stack.schema_utils import json_schema_type, register_schema
@json_schema_type
class StringType(BaseModel):

View file

@ -6,10 +6,10 @@
from typing import Any, Dict, List, Optional, Protocol, runtime_checkable
from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel
from llama_stack.apis.datasets import Dataset
from llama_stack.schema_utils import json_schema_type, webmethod
@json_schema_type

View file

@ -6,12 +6,12 @@
from typing import Any, Dict, List, Literal, Optional, Protocol
from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel, Field
from llama_stack.apis.common.content_types import URL
from llama_stack.apis.common.type_system import ParamType
from llama_stack.apis.resource import Resource, ResourceType
from llama_stack.schema_utils import json_schema_type, webmethod
class CommonDatasetFields(BaseModel):
@ -58,7 +58,7 @@ class Datasets(Protocol):
metadata: Optional[Dict[str, Any]] = None,
) -> None: ...
@webmethod(route="/datasets/{dataset_id}", method="GET")
@webmethod(route="/datasets/{dataset_id:path}", method="GET")
async def get_dataset(
self,
dataset_id: str,
@ -67,7 +67,7 @@ class Datasets(Protocol):
@webmethod(route="/datasets", method="GET")
async def list_datasets(self) -> ListDatasetsResponse: ...
@webmethod(route="/datasets/{dataset_id}", method="DELETE")
@webmethod(route="/datasets/{dataset_id:path}", method="DELETE")
async def unregister_dataset(
self,
dataset_id: str,
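A hedged note on the `:path` routes introduced here: dataset identifiers may themselves contain slashes (for example Hugging Face-style `org/name` ids), which a plain `{dataset_id}` segment would not match, whereas `{dataset_id:path}` captures the full remainder of the URL. A tiny illustration with placeholder values:

```python
from urllib.parse import quote

# Placeholder id; the point is only that it contains a "/" inside the URL path.
dataset_id = "llamastack/mmlu"

print(f"GET /v1/datasets/{dataset_id}")                  # matched via {dataset_id:path}
print(f"GET /v1/datasets/{quote(dataset_id, safe='')}")  # what clients would need otherwise
```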

View file

@ -6,7 +6,7 @@
from enum import Enum
from llama_models.schema_utils import json_schema_type
from llama_stack.schema_utils import json_schema_type
@json_schema_type
@ -28,7 +28,7 @@ class Api(Enum):
vector_dbs = "vector_dbs"
datasets = "datasets"
scoring_functions = "scoring_functions"
eval_tasks = "eval_tasks"
benchmarks = "benchmarks"
tool_groups = "tool_groups"
# built-in API

View file

@ -6,7 +6,6 @@
from typing import Any, Dict, List, Literal, Optional, Protocol, Union
from llama_models.schema_utils import json_schema_type, register_schema, webmethod
from pydantic import BaseModel, Field
from typing_extensions import Annotated
@ -15,6 +14,7 @@ from llama_stack.apis.common.job_types import Job, JobStatus
from llama_stack.apis.inference import SamplingParams, SystemMessage
from llama_stack.apis.scoring import ScoringResult
from llama_stack.apis.scoring_functions import ScoringFnParams
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
@json_schema_type
@ -38,19 +38,9 @@ EvalCandidate = register_schema(
@json_schema_type
class BenchmarkEvalTaskConfig(BaseModel):
class BenchmarkConfig(BaseModel):
type: Literal["benchmark"] = "benchmark"
eval_candidate: EvalCandidate
num_examples: Optional[int] = Field(
description="Number of examples to evaluate (useful for testing), if not provided, all examples in the dataset will be evaluated",
default=None,
)
@json_schema_type
class AppEvalTaskConfig(BaseModel):
type: Literal["app"] = "app"
eval_candidate: EvalCandidate
scoring_params: Dict[str, ScoringFnParams] = Field(
description="Map between scoring function id and parameters for each scoring function you want to run",
default_factory=dict,
@ -62,12 +52,6 @@ class AppEvalTaskConfig(BaseModel):
# we could optionally add any specific dataset config here
EvalTaskConfig = register_schema(
Annotated[Union[BenchmarkEvalTaskConfig, AppEvalTaskConfig], Field(discriminator="type")],
name="EvalTaskConfig",
)
@json_schema_type
class EvaluateResponse(BaseModel):
generations: List[Dict[str, Any]]
@ -76,27 +60,52 @@ class EvaluateResponse(BaseModel):
class Eval(Protocol):
@webmethod(route="/eval/tasks/{task_id}/jobs", method="POST")
@webmethod(route="/eval/benchmarks/{benchmark_id}/jobs", method="POST")
async def run_eval(
self,
benchmark_id: str,
task_config: BenchmarkConfig,
) -> Job: ...
@webmethod(route="/eval/benchmarks/{benchmark_id}/evaluations", method="POST")
async def evaluate_rows(
self,
benchmark_id: str,
input_rows: List[Dict[str, Any]],
scoring_functions: List[str],
task_config: BenchmarkConfig,
) -> EvaluateResponse: ...
@webmethod(route="/eval/benchmarks/{benchmark_id}/jobs/{job_id}", method="GET")
async def job_status(self, benchmark_id: str, job_id: str) -> Optional[JobStatus]: ...
@webmethod(route="/eval/benchmarks/{benchmark_id}/jobs/{job_id}", method="DELETE")
async def job_cancel(self, benchmark_id: str, job_id: str) -> None: ...
@webmethod(route="/eval/benchmarks/{benchmark_id}/jobs/{job_id}/result", method="GET")
async def job_result(self, benchmark_id: str, job_id: str) -> EvaluateResponse: ...
@webmethod(route="/eval/tasks/{task_id}/jobs", method="POST")
async def DEPRECATED_run_eval(
self,
task_id: str,
task_config: EvalTaskConfig,
task_config: BenchmarkConfig,
) -> Job: ...
@webmethod(route="/eval/tasks/{task_id}/evaluations", method="POST")
async def evaluate_rows(
async def DEPRECATED_evaluate_rows(
self,
task_id: str,
input_rows: List[Dict[str, Any]],
scoring_functions: List[str],
task_config: EvalTaskConfig,
task_config: BenchmarkConfig,
) -> EvaluateResponse: ...
@webmethod(route="/eval/tasks/{task_id}/jobs/{job_id}", method="GET")
async def job_status(self, task_id: str, job_id: str) -> Optional[JobStatus]: ...
async def DEPRECATED_job_status(self, task_id: str, job_id: str) -> Optional[JobStatus]: ...
@webmethod(route="/eval/tasks/{task_id}/jobs/{job_id}", method="DELETE")
async def job_cancel(self, task_id: str, job_id: str) -> None: ...
async def DEPRECATED_job_cancel(self, task_id: str, job_id: str) -> None: ...
@webmethod(route="/eval/tasks/{task_id}/jobs/{job_id}/result", method="GET")
async def job_result(self, job_id: str, task_id: str) -> EvaluateResponse: ...
async def DEPRECATED_job_result(self, task_id: str, job_id: str) -> EvaluateResponse: ...
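A hedged construction sketch for the renamed config, assuming `ModelCandidate` keeps its `model` / `sampling_params` fields and that `llama_stack` is importable; the model id is a placeholder:

```python
from llama_stack.apis.eval import BenchmarkConfig, ModelCandidate
from llama_stack.apis.inference import SamplingParams

config = BenchmarkConfig(
    eval_candidate=ModelCandidate(
        model="meta-llama/Llama-3.1-8B-Instruct",  # placeholder model id
        sampling_params=SamplingParams(),
    ),
    num_examples=5,  # evaluate a small slice while iterating
)
print(config.type)  # "benchmark"
```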

View file

@ -1,66 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict, List, Literal, Optional, Protocol, runtime_checkable
from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel, Field
from llama_stack.apis.resource import Resource, ResourceType
class CommonEvalTaskFields(BaseModel):
dataset_id: str
scoring_functions: List[str]
metadata: Dict[str, Any] = Field(
default_factory=dict,
description="Metadata for this evaluation task",
)
@json_schema_type
class EvalTask(CommonEvalTaskFields, Resource):
type: Literal[ResourceType.eval_task.value] = ResourceType.eval_task.value
@property
def eval_task_id(self) -> str:
return self.identifier
@property
def provider_eval_task_id(self) -> str:
return self.provider_resource_id
class EvalTaskInput(CommonEvalTaskFields, BaseModel):
eval_task_id: str
provider_id: Optional[str] = None
provider_eval_task_id: Optional[str] = None
class ListEvalTasksResponse(BaseModel):
data: List[EvalTask]
@runtime_checkable
class EvalTasks(Protocol):
@webmethod(route="/eval-tasks", method="GET")
async def list_eval_tasks(self) -> ListEvalTasksResponse: ...
@webmethod(route="/eval-tasks/{eval_task_id}", method="GET")
async def get_eval_task(
self,
eval_task_id: str,
) -> Optional[EvalTask]: ...
@webmethod(route="/eval-tasks", method="POST")
async def register_eval_task(
self,
eval_task_id: str,
dataset_id: str,
scoring_functions: List[str],
provider_eval_task_id: Optional[str] = None,
provider_id: Optional[str] = None,
metadata: Optional[Dict[str, Any]] = None,
) -> None: ...

View file

@ -13,11 +13,17 @@ from typing import (
Literal,
Optional,
Protocol,
runtime_checkable,
Union,
runtime_checkable,
)
from llama_models.llama3.api.datatypes import (
from pydantic import BaseModel, Field, field_validator
from typing_extensions import Annotated
from llama_stack.apis.common.content_types import ContentDelta, InterleavedContent
from llama_stack.apis.models import Model
from llama_stack.apis.telemetry.telemetry import MetricResponseMixin
from llama_stack.models.llama.datatypes import (
BuiltinTool,
SamplingParams,
StopReason,
@ -25,13 +31,8 @@ from llama_models.llama3.api.datatypes import (
ToolDefinition,
ToolPromptFormat,
)
from llama_models.schema_utils import json_schema_type, register_schema, webmethod
from pydantic import BaseModel, Field, field_validator
from typing_extensions import Annotated
from llama_stack.apis.common.content_types import ContentDelta, InterleavedContent
from llama_stack.apis.models import Model
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
class LogProbConfig(BaseModel):
@ -357,7 +358,7 @@ class ChatCompletionRequest(BaseModel):
@json_schema_type
class ChatCompletionResponseStreamChunk(BaseModel):
class ChatCompletionResponseStreamChunk(MetricResponseMixin, BaseModel):
"""A chunk of a streamed chat completion response.
:param event: The event containing the new content
@ -367,7 +368,7 @@ class ChatCompletionResponseStreamChunk(BaseModel):
@json_schema_type
class ChatCompletionResponse(BaseModel):
class ChatCompletionResponse(MetricResponseMixin, BaseModel):
"""Response from a chat completion request.
:param completion_message: The complete response message

View file

@ -6,9 +6,10 @@
from typing import List, Protocol, runtime_checkable
from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel
from llama_stack.schema_utils import json_schema_type, webmethod
@json_schema_type
class ProviderInfo(BaseModel):

View file

@ -7,11 +7,11 @@
from enum import Enum
from typing import Any, Dict, List, Literal, Optional, Protocol, runtime_checkable
from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel, ConfigDict, Field
from llama_stack.apis.resource import Resource, ResourceType
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
from llama_stack.schema_utils import json_schema_type, webmethod
class CommonModelFields(BaseModel):
@ -62,7 +62,7 @@ class Models(Protocol):
@webmethod(route="/models", method="GET")
async def list_models(self) -> ListModelsResponse: ...
@webmethod(route="/models/{model_id}", method="GET")
@webmethod(route="/models/{model_id:path}", method="GET")
async def get_model(
self,
model_id: str,
@ -78,7 +78,7 @@ class Models(Protocol):
model_type: Optional[ModelType] = None,
) -> Model: ...
@webmethod(route="/models/{model_id}", method="DELETE")
@webmethod(route="/models/{model_id:path}", method="DELETE")
async def unregister_model(
self,
model_id: str,

View file

@ -8,13 +8,13 @@ from datetime import datetime
from enum import Enum
from typing import Any, Dict, List, Literal, Optional, Protocol, Union
from llama_models.schema_utils import json_schema_type, register_schema, webmethod
from pydantic import BaseModel, Field
from typing_extensions import Annotated
from llama_stack.apis.common.content_types import URL
from llama_stack.apis.common.job_types import JobStatus
from llama_stack.apis.common.training_types import Checkpoint
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
@json_schema_type

View file

@ -15,7 +15,7 @@ class ResourceType(Enum):
vector_db = "vector_db"
dataset = "dataset"
scoring_function = "scoring_function"
eval_task = "eval_task"
benchmark = "benchmark"
tool = "tool"
tool_group = "tool_group"

View file

@ -7,12 +7,12 @@
from enum import Enum
from typing import Any, Dict, List, Optional, Protocol, runtime_checkable
from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel, Field
from llama_stack.apis.inference import Message
from llama_stack.apis.shields import Shield
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
from llama_stack.schema_utils import json_schema_type, webmethod
@json_schema_type

View file

@ -6,10 +6,10 @@
from typing import Any, Dict, List, Optional, Protocol, runtime_checkable
from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel
from llama_stack.apis.scoring_functions import ScoringFn, ScoringFnParams
from llama_stack.schema_utils import json_schema_type, webmethod
# mapping of metric to value
ScoringResultRow = Dict[str, Any]

View file

@ -12,16 +12,16 @@ from typing import (
Literal,
Optional,
Protocol,
runtime_checkable,
Union,
runtime_checkable,
)
from llama_models.schema_utils import json_schema_type, register_schema, webmethod
from pydantic import BaseModel, Field
from typing_extensions import Annotated
from llama_stack.apis.common.type_system import ParamType
from llama_stack.apis.resource import Resource, ResourceType
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
# Perhaps more structure can be imposed on these functions. Maybe they could be associated
@ -134,7 +134,7 @@ class ScoringFunctions(Protocol):
@webmethod(route="/scoring-functions", method="GET")
async def list_scoring_functions(self) -> ListScoringFunctionsResponse: ...
@webmethod(route="/scoring-functions/{scoring_fn_id}", method="GET")
@webmethod(route="/scoring-functions/{scoring_fn_id:path}", method="GET")
async def get_scoring_function(self, scoring_fn_id: str, /) -> Optional[ScoringFn]: ...
@webmethod(route="/scoring-functions", method="POST")

View file

@ -6,11 +6,11 @@
from typing import Any, Dict, List, Literal, Optional, Protocol, runtime_checkable
from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel
from llama_stack.apis.resource import Resource, ResourceType
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
from llama_stack.schema_utils import json_schema_type, webmethod
class CommonShieldFields(BaseModel):
@ -48,7 +48,7 @@ class Shields(Protocol):
@webmethod(route="/shields", method="GET")
async def list_shields(self) -> ListShieldsResponse: ...
@webmethod(route="/shields/{identifier}", method="GET")
@webmethod(route="/shields/{identifier:path}", method="GET")
async def get_shield(self, identifier: str) -> Optional[Shield]: ...
@webmethod(route="/shields", method="POST")

View file

@ -5,14 +5,12 @@
# the root directory of this source tree.
from enum import Enum
from typing import Any, Dict, List, Optional, Protocol, Union
from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel
from llama_stack.apis.inference import Message
from llama_stack.schema_utils import json_schema_type, webmethod
class FilteringFunction(Enum):

View file

@ -13,14 +13,16 @@ from typing import (
Literal,
Optional,
Protocol,
runtime_checkable,
Union,
runtime_checkable,
)
from llama_models.schema_utils import json_schema_type, register_schema, webmethod
from pydantic import BaseModel, Field
from typing_extensions import Annotated
from llama_stack.models.llama.datatypes import Primitive
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
# Add this constant near the top of the file, after the imports
DEFAULT_TTL_DAYS = 7
@ -76,7 +78,7 @@ class EventCommon(BaseModel):
trace_id: str
span_id: str
timestamp: datetime
attributes: Optional[Dict[str, Any]] = Field(default_factory=dict)
attributes: Optional[Dict[str, Primitive]] = Field(default_factory=dict)
@json_schema_type
@ -94,6 +96,30 @@ class MetricEvent(EventCommon):
unit: str
# This is a short term solution to allow inference API to return metrics
# The ideal way to do this is to have a way for all response types to include metrics
# and all metric events logged to the telemetry API to be included with the response
# To do this, we will need to augment all response types with a metrics field.
# We have hit a blocker from stainless SDK that prevents us from doing this.
# The blocker is that if we were to augment the response types that have a data field
# in them like so
# class ListModelsResponse(BaseModel):
# metrics: Optional[List[MetricEvent]] = None
# data: List[Models]
# ...
# The client SDK will need to access the data by using a .data field, which is not
# ergonomic. Stainless SDK does support unwrapping the response type, but it
# requires that the response type have only a single field.
# We will need a way in the client SDK to signal that the metrics are needed
# and if they are needed, the client SDK has to return the full response type
# without unwrapping it.
class MetricResponseMixin(BaseModel):
metrics: Optional[List[MetricEvent]] = None
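
A minimal sketch of how a caller could consume the mixin. The `metric`/`value` field names are assumed from the MetricEvent shape implied above (only `unit` is shown here), and the helper name is illustrative:

```python
from typing import Any


def log_inference_metrics(response: Any) -> None:
    # Works for any response type that mixes in MetricResponseMixin,
    # e.g. ChatCompletionResponse above; metrics may be None.
    for event in response.metrics or []:
        print(f"{event.metric}={event.value} {event.unit}")
```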
@json_schema_type
class StructuredLogType(Enum):
SPAN_START = "span_start"
@ -199,13 +225,13 @@ class Telemetry(Protocol):
order_by: Optional[List[str]] = None,
) -> QueryTracesResponse: ...
@webmethod(route="/telemetry/traces/{trace_id}", method="GET")
@webmethod(route="/telemetry/traces/{trace_id:path}", method="GET")
async def get_trace(self, trace_id: str) -> Trace: ...
@webmethod(route="/telemetry/traces/{trace_id}/spans/{span_id}", method="GET")
@webmethod(route="/telemetry/traces/{trace_id:path}/spans/{span_id:path}", method="GET")
async def get_span(self, trace_id: str, span_id: str) -> Span: ...
@webmethod(route="/telemetry/spans/{span_id}/tree", method="GET")
@webmethod(route="/telemetry/spans/{span_id:path}/tree", method="GET")
async def get_span_tree(
self,
span_id: str,

View file

@ -4,5 +4,5 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .tools import * # noqa: F401 F403
from .rag_tool import * # noqa: F401 F403
from .tools import * # noqa: F401 F403

View file

@ -7,12 +7,12 @@
from enum import Enum
from typing import Any, Dict, List, Literal, Optional, Union
from llama_models.schema_utils import json_schema_type, register_schema, webmethod
from pydantic import BaseModel, Field
from typing_extensions import Annotated, Protocol, runtime_checkable
from llama_stack.apis.common.content_types import InterleavedContent, URL
from llama_stack.apis.common.content_types import URL, InterleavedContent
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
@json_schema_type

View file

@ -7,13 +7,13 @@
from enum import Enum
from typing import Any, Dict, List, Literal, Optional
from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel, Field
from typing_extensions import Protocol, runtime_checkable
from llama_stack.apis.common.content_types import InterleavedContent, URL
from llama_stack.apis.common.content_types import URL, InterleavedContent
from llama_stack.apis.resource import Resource, ResourceType
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
from llama_stack.schema_utils import json_schema_type, webmethod
from .rag_tool import RAGToolRuntime
@ -101,7 +101,7 @@ class ToolGroups(Protocol):
"""Register a tool group"""
...
@webmethod(route="/toolgroups/{toolgroup_id}", method="GET")
@webmethod(route="/toolgroups/{toolgroup_id:path}", method="GET")
async def get_tool_group(
self,
toolgroup_id: str,
@ -117,13 +117,13 @@ class ToolGroups(Protocol):
"""List tools with optional tool group"""
...
@webmethod(route="/tools/{tool_name}", method="GET")
@webmethod(route="/tools/{tool_name:path}", method="GET")
async def get_tool(
self,
tool_name: str,
) -> Tool: ...
@webmethod(route="/toolgroups/{toolgroup_id}", method="DELETE")
@webmethod(route="/toolgroups/{toolgroup_id:path}", method="DELETE")
async def unregister_toolgroup(
self,
toolgroup_id: str,

View file

@ -6,11 +6,11 @@
from typing import List, Literal, Optional, Protocol, runtime_checkable
from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel
from llama_stack.apis.resource import Resource, ResourceType
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
from llama_stack.schema_utils import json_schema_type, webmethod
@json_schema_type
@ -46,7 +46,7 @@ class VectorDBs(Protocol):
@webmethod(route="/vector-dbs", method="GET")
async def list_vector_dbs(self) -> ListVectorDBsResponse: ...
@webmethod(route="/vector-dbs/{vector_db_id}", method="GET")
@webmethod(route="/vector-dbs/{vector_db_id:path}", method="GET")
async def get_vector_db(
self,
vector_db_id: str,
@ -62,5 +62,5 @@ class VectorDBs(Protocol):
provider_vector_db_id: Optional[str] = None,
) -> VectorDB: ...
@webmethod(route="/vector-dbs/{vector_db_id}", method="DELETE")
@webmethod(route="/vector-dbs/{vector_db_id:path}", method="DELETE")
async def unregister_vector_db(self, vector_db_id: str) -> None: ...

View file

@ -10,12 +10,12 @@
# the root directory of this source tree.
from typing import Any, Dict, List, Optional, Protocol, runtime_checkable
from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel, Field
from llama_stack.apis.inference import InterleavedContent
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
from llama_stack.schema_utils import json_schema_type, webmethod
class Chunk(BaseModel):

View file

@ -16,11 +16,7 @@ from pathlib import Path
from typing import Dict, List, Optional
import httpx
from llama_models.datatypes import Model
from llama_models.sku_list import LlamaDownloadInfo
from pydantic import BaseModel, ConfigDict
from rich.console import Console
from rich.progress import (
BarColumn,
@ -33,6 +29,8 @@ from rich.progress import (
from termcolor import cprint
from llama_stack.cli.subcommand import Subcommand
from llama_stack.models.llama.datatypes import Model
from llama_stack.models.llama.sku_list import LlamaDownloadInfo
class Download(Subcommand):
@ -85,8 +83,7 @@ def setup_download_parser(parser: argparse.ArgumentParser) -> None:
type=str,
required=False,
default="*.safetensors",
help="""
For source=huggingface, files matching any of the patterns are not downloaded. Defaults to ignoring
help="""For source=huggingface, files matching any of the patterns are not downloaded. Defaults to ignoring
safetensors files to avoid downloading duplicate weights.
""",
)
@ -456,7 +453,7 @@ def run_download_cmd(args: argparse.Namespace, parser: argparse.ArgumentParser):
# Handle comma-separated model IDs
model_ids = [model_id.strip() for model_id in args.model_id.split(",")]
from llama_models.sku_list import llama_meta_net_info, resolve_model
from llama_stack.models.llama.sku_list import llama_meta_net_info, resolve_model
from .model.safety_models import (
prompt_guard_download_info,

View file

@ -7,12 +7,11 @@
import argparse
import json
from llama_models.sku_list import resolve_model
from termcolor import colored
from llama_stack.cli.subcommand import Subcommand
from llama_stack.cli.table import print_table
from llama_stack.models.llama.sku_list import resolve_model
class ModelDescribe(Subcommand):
@ -35,6 +34,7 @@ class ModelDescribe(Subcommand):
"--model-id",
type=str,
required=True,
help="See `llama model list` or `llama model list --show-all` for the list of available models",
)
def _run_model_describe_cmd(self, args: argparse.Namespace) -> None:

View file

@ -6,10 +6,9 @@
import argparse
from llama_models.sku_list import all_registered_models
from llama_stack.cli.subcommand import Subcommand
from llama_stack.cli.table import print_table
from llama_stack.models.llama.sku_list import all_registered_models
class ModelList(Subcommand):
@ -38,7 +37,7 @@ class ModelList(Subcommand):
headers = [
"Model Descriptor",
"Hugging Face Repo",
"Model ID",
"Context Length",
]

View file

@ -11,7 +11,6 @@ from llama_stack.cli.model.download import ModelDownload
from llama_stack.cli.model.list import ModelList
from llama_stack.cli.model.prompt_format import ModelPromptFormat
from llama_stack.cli.model.verify_download import ModelVerifyDownload
from llama_stack.cli.subcommand import Subcommand
@ -26,6 +25,8 @@ class ModelParser(Subcommand):
description="Work with llama models",
)
self.parser.set_defaults(func=lambda args: self.parser.print_help())
subparsers = self.parser.add_subparsers(title="model_subcommands")
# Add sub-commands

View file

@ -8,9 +8,8 @@ import argparse
import textwrap
from io import StringIO
from llama_models.datatypes import CoreModelId, is_multimodal, model_family, ModelFamily
from llama_stack.cli.subcommand import Subcommand
from llama_stack.models.llama.datatypes import CoreModelId, ModelFamily, is_multimodal, model_family
class ModelPromptFormat(Subcommand):

View file

@ -6,12 +6,11 @@
from typing import Any, Dict, Optional
from llama_models.datatypes import CheckpointQuantizationFormat
from llama_models.llama3.api.datatypes import SamplingParams
from llama_models.sku_list import LlamaDownloadInfo
from pydantic import BaseModel, ConfigDict, Field
from llama_stack.models.llama.datatypes import CheckpointQuantizationFormat, SamplingParams
from llama_stack.models.llama.sku_list import LlamaDownloadInfo
class PromptGuardModel(BaseModel):
"""Make a 'fake' Model-like object for Prompt Guard. Eventually this will be removed."""

View file

@ -21,12 +21,11 @@ from prompt_toolkit.validation import Validator
from termcolor import cprint
from llama_stack.cli.table import print_table
from llama_stack.distribution.build import (
SERVER_DEPENDENCIES,
ImageType,
build_image,
get_provider_dependencies,
ImageType,
SERVER_DEPENDENCIES,
)
from llama_stack.distribution.datatypes import (
BuildConfig,

View file

@ -56,9 +56,8 @@ class StackBuild(Subcommand):
"--image-name",
type=str,
help=textwrap.dedent(
"""[for image-type=conda] Name of the conda environment to use for the build. If
not specified, currently active Conda environment will be used. If no Conda
environment is active, you must specify a name.
"""[for image-type=conda|venv] Name of the conda or virtual environment to use for
the build. If not specified, the currently active Conda environment will be used if found.
"""
),
default=None,

View file

@ -17,7 +17,7 @@ class StackConfigure(Subcommand):
self.parser = subparsers.add_parser(
"configure",
prog="llama stack configure",
description="configure a llama stack distribution",
description="Configure a llama stack distribution",
formatter_class=argparse.RawTextHelpFormatter,
)
self._add_arguments()

View file

@ -21,15 +21,19 @@ class StackListProviders(Subcommand):
self._add_arguments()
self.parser.set_defaults(func=self._run_providers_list_cmd)
def _add_arguments(self):
@property
def providable_apis(self):
from llama_stack.distribution.distribution import providable_apis
api_values = [api.value for api in providable_apis()]
return [api.value for api in providable_apis()]
def _add_arguments(self):
self.parser.add_argument(
"api",
type=str,
choices=api_values,
help="API to list providers for (one of: {})".format(api_values),
choices=self.providable_apis,
nargs="?",
help="API to list providers for. List all if not specified.",
)
def _run_providers_list_cmd(self, args: argparse.Namespace) -> None:
@ -37,20 +41,29 @@ class StackListProviders(Subcommand):
from llama_stack.distribution.distribution import Api, get_provider_registry
all_providers = get_provider_registry()
providers_for_api = all_providers[Api(args.api)]
if args.api:
providers = [(args.api, all_providers[Api(args.api)])]
else:
providers = [(k.value, prov) for k, prov in all_providers.items()]
providers = [p for api, p in providers if api in self.providable_apis]
# eventually, this should query a registry at llama.meta.com/llamastack/distributions
headers = [
"API Type",
"Provider Type",
"PIP Package Dependencies",
]
rows = []
for spec in providers_for_api.values():
if spec.provider_type == "sample":
specs = [spec for p in providers for spec in p.values()]
for spec in specs:
if spec.is_sample:
continue
rows.append(
[
spec.api.value,
spec.provider_type,
",".join(spec.pip_packages),
]
@ -59,4 +72,5 @@ class StackListProviders(Subcommand):
rows,
headers,
separate_rows=True,
sort_by=(0, 1),
)

View file

@ -19,7 +19,7 @@ class StackRun(Subcommand):
self.parser = subparsers.add_parser(
"run",
prog="llama stack run",
description="""start the server for a Llama Stack Distribution. You should have already built (or downloaded) and configured the distribution.""",
description="""Start the server for a Llama Stack Distribution. You should have already built (or downloaded) and configured the distribution.""",
formatter_class=argparse.RawTextHelpFormatter,
)
self._add_arguments()
@ -65,6 +65,13 @@ class StackRun(Subcommand):
type=str,
help="Path to TLS certificate file for HTTPS",
)
self.parser.add_argument(
"--image-type",
type=str,
help="Image Type used during the build. This can be either conda or container or venv.",
choices=["conda", "container", "venv"],
default="conda",
)
def _run_stack_run_cmd(self, args: argparse.Namespace) -> None:
import importlib.resources
@ -118,11 +125,11 @@ class StackRun(Subcommand):
config_dict = yaml.safe_load(config_file.read_text())
config = parse_and_maybe_upgrade_config(config_dict)
if config.container_image:
if args.image_type == ImageType.container.value or config.container_image:
script = importlib.resources.files("llama_stack") / "distribution/start_container.sh"
image_name = f"distribution-{template_name}" if template_name else config.container_image
run_args = [script, image_name]
else:
elif args.image_type == ImageType.conda.value:
current_conda_env = os.environ.get("CONDA_DEFAULT_ENV")
image_name = args.image_name or current_conda_env
if not image_name:
@ -167,6 +174,15 @@ class StackRun(Subcommand):
script,
image_name,
]
else:
# else must be venv since that is the only valid option left.
current_venv = os.environ.get("VIRTUAL_ENV")
venv = args.image_name or current_venv
script = importlib.resources.files("llama_stack") / "distribution/start_venv.sh"
run_args = [
script,
venv,
]
run_args.extend([str(config_file), str(args.port)])
if args.disable_ipv6:

View file

@ -31,6 +31,8 @@ class StackParser(Subcommand):
version=f"{version('llama-stack')}",
)
self.parser.set_defaults(func=lambda args: self.parser.print_help())
subparsers = self.parser.add_subparsers(title="stack_subcommands")
# Add sub-commands

View file

@ -6,6 +6,7 @@
import re
import textwrap
from typing import Iterable
from termcolor import cprint
@ -25,13 +26,13 @@ def format_row(row, col_widths):
lines.extend(textwrap.wrap(line, width, break_long_words=False, replace_whitespace=False))
return lines
wrapped = [wrap(item, width) for item, width in zip(row, col_widths)]
wrapped = [wrap(item, width) for item, width in zip(row, col_widths, strict=False)]
max_lines = max(len(subrow) for subrow in wrapped)
lines = []
for i in range(max_lines):
line = []
for cell_lines, width in zip(wrapped, col_widths):
for cell_lines, width in zip(wrapped, col_widths, strict=False):
value = cell_lines[i] if i < len(cell_lines) else ""
line.append(value + " " * (width - len(strip_ansi_colors(value))))
lines.append("| " + (" | ".join(line)) + " |")
@ -39,20 +40,24 @@ def format_row(row, col_widths):
return "\n".join(lines)
def print_table(rows, headers=None, separate_rows: bool = False):
def print_table(rows, headers=None, separate_rows: bool = False, sort_by: Iterable[int] = tuple()):
def itemlen(item):
return max([len(line) for line in strip_ansi_colors(item).split("\n")])
rows = [[x or "" for x in row] for row in rows]
if sort_by:
rows.sort(key=lambda x: tuple(x[i] for i in sort_by))
if not headers:
col_widths = [max(itemlen(item) for item in col) for col in zip(*rows)]
col_widths = [max(itemlen(item) for item in col) for col in zip(*rows, strict=False)]
else:
col_widths = [
max(
itemlen(header),
max(itemlen(item) for item in col),
)
for header, col in zip(headers, zip(*rows))
for header, col in zip(headers, zip(*rows, strict=False), strict=False)
]
col_widths = [min(w, 80) for w in col_widths]
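
A hedged usage sketch of the new `sort_by` parameter of the `print_table` helper defined above; the row values are illustrative. Sorting happens on the raw cell strings before column widths are computed:

```python
rows = [
    ["inference", "remote::ollama", "ollama"],
    ["agents", "inline::meta-reference", "matplotlib,pillow,pandas"],
]
# Sort by API type first, then provider type, via the (0, 1) tuple of column indices.
print_table(
    rows,
    headers=["API Type", "Provider Type", "PIP Package Dependencies"],
    separate_rows=True,
    sort_by=(0, 1),
)
```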

View file

@ -8,6 +8,7 @@ from datetime import datetime
import pytest
import yaml
from llama_stack.distribution.configure import (
LLAMA_STACK_RUN_CONFIG_VERSION,
parse_and_maybe_upgrade_config,

View file

@ -8,7 +8,6 @@ import importlib.resources
import logging
import sys
from enum import Enum
from pathlib import Path
from typing import Dict, List
@ -16,11 +15,8 @@ from pydantic import BaseModel
from termcolor import cprint
from llama_stack.distribution.datatypes import BuildConfig, Provider
from llama_stack.distribution.distribution import get_provider_registry
from llama_stack.distribution.utils.config_dirs import BUILDS_BASE_DIR
from llama_stack.distribution.utils.exec import run_command, run_with_pty
from llama_stack.providers.datatypes import Api
@ -130,7 +126,6 @@ def build_image(
args = [
script,
str(image_name),
str(build_file_path),
" ".join(normal_deps),
]

View file

@ -24,23 +24,21 @@ if [ -n "$LLAMA_MODELS_DIR" ]; then
fi
if [ "$#" -lt 3 ]; then
echo "Usage: $0 <distribution_type> <build_name> <build_file_path> <pip_dependencies> [<special_pip_deps>]" >&2
echo "Usage: $0 <distribution_type> <build_name> <pip_dependencies> [<special_pip_deps>]" >&2
echo "Example: $0 <distribution_type> mybuild ./my-stack-build.yaml 'numpy pandas scipy'" >&2
exit 1
fi
special_pip_deps="$4"
special_pip_deps="$3"
set -euo pipefail
build_name="$1"
env_name="llamastack-$build_name"
build_file_path="$2"
pip_dependencies="$3"
pip_dependencies="$2"
# Define color codes
RED='\033[0;31m'
GREEN='\033[0;32m'
NC='\033[0m' # No Color
# this is set if we actually create a new conda in which case we need to clean up
@ -49,34 +47,63 @@ ENVNAME=""
SCRIPT_DIR=$(dirname "$(readlink -f "$0")")
source "$SCRIPT_DIR/common.sh"
# pre-run checks to make sure we can proceed with the installation
pre_run_checks() {
local env_name="$1"
if ! is_command_available uv; then
echo "uv is not installed, trying to install it."
if ! is_command_available pip; then
echo "pip is not installed, cannot automatically install 'uv'."
echo "Follow this link to install it:"
echo "https://docs.astral.sh/uv/getting-started/installation/"
exit 1
else
pip install uv
fi
fi
# checking if an environment with the same name already exists
if [ -d "$env_name" ]; then
echo "Environment '$env_name' already exists, re-using it."
fi
}
run() {
local env_name="$1"
local pip_dependencies="$2"
local special_pip_deps="$3"
pip install uv
echo "Using virtual environment $env_name"
uv venv "$env_name"
# shellcheck source=/dev/null
source "$env_name/bin/activate"
if [ -n "$TEST_PYPI_VERSION" ]; then
# these packages are damaged in test-pypi, so install them first
uv pip install fastapi libcst
# shellcheck disable=SC2086
# we are building a command line so word splitting is expected
uv pip install --extra-index-url https://test.pypi.org/simple/ \
llama-models==$TEST_PYPI_VERSION llama-stack==$TEST_PYPI_VERSION \
llama-models=="$TEST_PYPI_VERSION" llama-stack=="$TEST_PYPI_VERSION" \
$pip_dependencies
if [ -n "$special_pip_deps" ]; then
IFS='#' read -ra parts <<<"$special_pip_deps"
for part in "${parts[@]}"; do
echo "$part"
# shellcheck disable=SC2086
# we are building a command line so word splitting is expected
uv pip install $part
done
fi
else
# Re-installing llama-stack in the new conda environment
# Re-installing llama-stack in the new virtual environment
if [ -n "$LLAMA_STACK_DIR" ]; then
if [ ! -d "$LLAMA_STACK_DIR" ]; then
printf "${RED}Warning: LLAMA_STACK_DIR is set but directory does not exist: $LLAMA_STACK_DIR${NC}\n" >&2
printf "${RED}Warning: LLAMA_STACK_DIR is set but directory does not exist: %s${NC}\n" "$LLAMA_STACK_DIR" >&2
exit 1
fi
printf "Installing from LLAMA_STACK_DIR: $LLAMA_STACK_DIR\n"
printf "Installing from LLAMA_STACK_DIR: %s\n" "$LLAMA_STACK_DIR"
uv pip install --no-cache-dir -e "$LLAMA_STACK_DIR"
else
uv pip install --no-cache-dir llama-stack
@ -84,26 +111,31 @@ run() {
if [ -n "$LLAMA_MODELS_DIR" ]; then
if [ ! -d "$LLAMA_MODELS_DIR" ]; then
printf "${RED}Warning: LLAMA_MODELS_DIR is set but directory does not exist: $LLAMA_MODELS_DIR${NC}\n" >&2
printf "${RED}Warning: LLAMA_MODELS_DIR is set but directory does not exist: %s${NC}\n" "$LLAMA_MODELS_DIR" >&2
exit 1
fi
printf "Installing from LLAMA_MODELS_DIR: $LLAMA_MODELS_DIR\n"
printf "Installing from LLAMA_MODELS_DIR: %s\n" "$LLAMA_MODELS_DIR"
uv pip uninstall llama-models
uv pip install --no-cache-dir -e "$LLAMA_MODELS_DIR"
fi
# Install pip dependencies
printf "Installing pip dependencies\n"
# shellcheck disable=SC2086
# we are building a command line so word splitting is expected
uv pip install $pip_dependencies
if [ -n "$special_pip_deps" ]; then
IFS='#' read -ra parts <<<"$special_pip_deps"
for part in "${parts[@]}"; do
echo "$part"
# shellcheck disable=SC2086
# we are building a command line so word splitting is expected
uv pip install $part
done
fi
fi
}
pre_run_checks "$env_name"
run "$env_name" "$pip_dependencies" "$special_pip_deps"

View file

@ -5,18 +5,16 @@
# the root directory of this source tree.
import inspect
import json
from collections.abc import AsyncIterator
from enum import Enum
from typing import Any, get_args, get_origin, Type, Union
from typing import Any, Type, Union, get_args, get_origin
import httpx
from pydantic import BaseModel, parse_obj_as
from termcolor import cprint
from llama_stack.apis.version import LLAMA_STACK_API_VERSION
from llama_stack.providers.datatypes import RemoteProviderConfig
_CLIENT_CLASSES = {}
@ -188,33 +186,3 @@ def extract_async_iterator_type(type_hint):
inner_args = get_args(arg)
return inner_args[0]
return None
async def example(model: str = None):
from llama_stack.apis.inference import Inference, UserMessage # noqa: F403
from llama_stack.apis.inference.event_logger import EventLogger
client_class = create_api_client_class(Inference)
client = client_class("http://localhost:5003")
if not model:
model = "Llama3.2-3B-Instruct"
message = UserMessage(content="hello world, write me a 2 sentence poem about the moon")
cprint(f"User>{message.content}", "green")
stream = True
iterator = await client.chat_completion(
model=model,
messages=[message],
stream=stream,
)
async for log in EventLogger().log(iterator):
log.print()
if __name__ == "__main__":
import asyncio
asyncio.run(example())

View file

@ -38,3 +38,8 @@ setup_cleanup_handlers() {
conda deactivate
}
# check if a command is present
is_command_available() {
command -v "$1" &>/dev/null
}

View file

@ -5,12 +5,11 @@
# the root directory of this source tree.
import logging
import textwrap
from typing import Any, Dict
from llama_stack.distribution.datatypes import (
DistributionSpec,
LLAMA_STACK_RUN_CONFIG_VERSION,
DistributionSpec,
Provider,
StackRunConfig,
)
@ -20,7 +19,6 @@ from llama_stack.distribution.distribution import (
)
from llama_stack.distribution.utils.dynamic import instantiate_class_type
from llama_stack.distribution.utils.prompt_for_config import prompt_for_config
from llama_stack.providers.datatypes import Api, ProviderSpec
logger = logging.getLogger(__name__)

View file

@ -8,10 +8,10 @@ from typing import Annotated, Any, Dict, List, Optional, Union
from pydantic import BaseModel, Field
from llama_stack.apis.benchmarks import Benchmark, BenchmarkInput
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Dataset, DatasetInput
from llama_stack.apis.eval import Eval
from llama_stack.apis.eval_tasks import EvalTask, EvalTaskInput
from llama_stack.apis.inference import Inference
from llama_stack.apis.models import Model, ModelInput
from llama_stack.apis.safety import Safety
@ -37,7 +37,7 @@ RoutableObject = Union[
VectorDB,
Dataset,
ScoringFn,
EvalTask,
Benchmark,
Tool,
ToolGroup,
]
@ -50,7 +50,7 @@ RoutableObjectWithProvider = Annotated[
VectorDB,
Dataset,
ScoringFn,
EvalTask,
Benchmark,
Tool,
ToolGroup,
],
@ -173,7 +173,7 @@ a default SQLite store will be used.""",
vector_dbs: List[VectorDBInput] = Field(default_factory=list)
datasets: List[DatasetInput] = Field(default_factory=list)
scoring_fns: List[ScoringFnInput] = Field(default_factory=list)
eval_tasks: List[EvalTaskInput] = Field(default_factory=list)
benchmarks: List[BenchmarkInput] = Field(default_factory=list)
tool_groups: List[ToolGroupInput] = Field(default_factory=list)
server: ServerConfig = Field(

View file

@ -44,7 +44,7 @@ def builtin_automatically_routed_apis() -> List[AutoRoutedApiInfo]:
router_api=Api.scoring,
),
AutoRoutedApiInfo(
routing_table_api=Api.eval_tasks,
routing_table_api=Api.benchmarks,
router_api=Api.eval,
),
AutoRoutedApiInfo(

View file

@ -82,3 +82,6 @@ class DistributionInspectImpl(Inspect):
async def version(self) -> VersionInfo:
return VersionInfo(version=version("llama-stack"))
async def shutdown(self) -> None:
pass

View file

@ -13,10 +13,21 @@ import re
from concurrent.futures import ThreadPoolExecutor
from enum import Enum
from pathlib import Path
from typing import Any, get_args, get_origin, Optional, TypeVar
from typing import Any, Optional, TypeVar, get_args, get_origin
import httpx
import yaml
from llama_stack_client import (
NOT_GIVEN,
APIResponse,
AsyncAPIResponse,
AsyncLlamaStackClient,
AsyncStream,
LlamaStackClient,
)
from pydantic import BaseModel, TypeAdapter
from rich.console import Console
from termcolor import cprint
from llama_stack.distribution.build import print_pip_install_help
from llama_stack.distribution.configure import parse_and_maybe_upgrade_config
@ -35,17 +46,6 @@ from llama_stack.providers.utils.telemetry.tracing import (
setup_logger,
start_trace,
)
from llama_stack_client import (
APIResponse,
AsyncAPIResponse,
AsyncLlamaStackClient,
AsyncStream,
LlamaStackClient,
NOT_GIVEN,
)
from pydantic import BaseModel, TypeAdapter
from rich.console import Console
from termcolor import cprint
T = TypeVar("T")
@ -231,7 +231,13 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient):
def _convert_path_to_regex(path: str) -> str:
# Convert {param} to named capture groups
pattern = re.sub(r"{(\w+)}", r"(?P<\1>[^/]+)", path)
# handle {param:path} as well which allows for forward slashes in the param value
pattern = re.sub(
r"{(\w+)(?::path)?}",
lambda m: f"(?P<{m.group(1)}>{'[^/]+' if not m.group(0).endswith(':path') else '.+'})",
path,
)
return f"^{pattern}$"
for api, api_endpoints in endpoints.items():
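
A self-contained sketch of the same route-to-regex conversion, useful for checking that `{param}` stops at a slash while `{param:path}` spans several segments (the model ID below is illustrative):

```python
import re


def convert_path_to_regex(path: str) -> str:
    # {param} must not contain "/", {param:path} may span several path segments.
    def segment(m: re.Match) -> str:
        return f"(?P<{m.group(1)}>{'.+' if m.group(2) else '[^/]+'})"

    return "^" + re.sub(r"\{(\w+)(:path)?\}", segment, path) + "$"


assert convert_path_to_regex("/models/{model_id}") == "^/models/(?P<model_id>[^/]+)$"
pattern = convert_path_to_regex("/models/{model_id:path}")
match = re.match(pattern, "/models/meta-llama/Llama-3-8B")
assert match and match.group("model_id") == "meta-llama/Llama-3-8B"
```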

View file

@ -9,10 +9,10 @@ import logging
from typing import Any, Dict, List, Set
from llama_stack.apis.agents import Agents
from llama_stack.apis.benchmarks import Benchmarks
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Datasets
from llama_stack.apis.eval import Eval
from llama_stack.apis.eval_tasks import EvalTasks
from llama_stack.apis.inference import Inference
from llama_stack.apis.inspect import Inspect
from llama_stack.apis.models import Models
@ -37,8 +37,8 @@ from llama_stack.distribution.store import DistributionRegistry
from llama_stack.distribution.utils.dynamic import instantiate_class_type
from llama_stack.providers.datatypes import (
Api,
BenchmarksProtocolPrivate,
DatasetsProtocolPrivate,
EvalTasksProtocolPrivate,
InlineProviderSpec,
ModelsProtocolPrivate,
ProviderSpec,
@ -73,7 +73,7 @@ def api_protocol_map() -> Dict[Api, Any]:
Api.scoring: Scoring,
Api.scoring_functions: ScoringFunctions,
Api.eval: Eval,
Api.eval_tasks: EvalTasks,
Api.benchmarks: Benchmarks,
Api.post_training: PostTraining,
Api.tool_groups: ToolGroups,
Api.tool_runtime: ToolRuntime,
@ -92,7 +92,7 @@ def additional_protocols_map() -> Dict[Api, Any]:
ScoringFunctions,
Api.scoring_functions,
),
Api.eval: (EvalTasksProtocolPrivate, EvalTasks, Api.eval_tasks),
Api.eval: (BenchmarksProtocolPrivate, Benchmarks, Api.benchmarks),
}

View file

@ -7,13 +7,12 @@
from typing import Any, Dict
from llama_stack.distribution.datatypes import RoutedProtocol
from llama_stack.distribution.store import DistributionRegistry
from llama_stack.providers.datatypes import Api, RoutingTable
from .routing_tables import (
BenchmarksRoutingTable,
DatasetsRoutingTable,
EvalTasksRoutingTable,
ModelsRoutingTable,
ScoringFunctionsRoutingTable,
ShieldsRoutingTable,
@ -34,7 +33,7 @@ async def get_routing_table_impl(
"shields": ShieldsRoutingTable,
"datasets": DatasetsRoutingTable,
"scoring_functions": ScoringFunctionsRoutingTable,
"eval_tasks": EvalTasksRoutingTable,
"benchmarks": BenchmarksRoutingTable,
"tool_groups": ToolGroupsRoutingTable,
}

View file

@ -6,12 +6,11 @@
from typing import Any, AsyncGenerator, Dict, List, Optional
from llama_stack.apis.common.content_types import InterleavedContent, URL
from llama_stack.apis.common.content_types import URL, InterleavedContent
from llama_stack.apis.datasetio import DatasetIO, PaginatedRowsResult
from llama_stack.apis.eval import (
AppEvalTaskConfig,
BenchmarkConfig,
Eval,
EvalTaskConfig,
EvaluateResponse,
Job,
JobStatus,
@ -347,23 +346,23 @@ class EvalRouter(Eval):
async def run_eval(
self,
task_id: str,
task_config: AppEvalTaskConfig,
benchmark_id: str,
task_config: BenchmarkConfig,
) -> Job:
return await self.routing_table.get_provider_impl(task_id).run_eval(
task_id=task_id,
return await self.routing_table.get_provider_impl(benchmark_id).run_eval(
benchmark_id=benchmark_id,
task_config=task_config,
)
async def evaluate_rows(
self,
task_id: str,
benchmark_id: str,
input_rows: List[Dict[str, Any]],
scoring_functions: List[str],
task_config: EvalTaskConfig,
task_config: BenchmarkConfig,
) -> EvaluateResponse:
return await self.routing_table.get_provider_impl(task_id).evaluate_rows(
task_id=task_id,
return await self.routing_table.get_provider_impl(benchmark_id).evaluate_rows(
benchmark_id=benchmark_id,
input_rows=input_rows,
scoring_functions=scoring_functions,
task_config=task_config,
@ -371,30 +370,72 @@ class EvalRouter(Eval):
async def job_status(
self,
task_id: str,
benchmark_id: str,
job_id: str,
) -> Optional[JobStatus]:
return await self.routing_table.get_provider_impl(task_id).job_status(task_id, job_id)
return await self.routing_table.get_provider_impl(benchmark_id).job_status(benchmark_id, job_id)
async def job_cancel(
self,
task_id: str,
benchmark_id: str,
job_id: str,
) -> None:
await self.routing_table.get_provider_impl(task_id).job_cancel(
task_id,
await self.routing_table.get_provider_impl(benchmark_id).job_cancel(
benchmark_id,
job_id,
)
async def job_result(
self,
benchmark_id: str,
job_id: str,
) -> EvaluateResponse:
return await self.routing_table.get_provider_impl(benchmark_id).job_result(
benchmark_id,
job_id,
)
async def DEPRECATED_run_eval(
self,
task_id: str,
task_config: BenchmarkConfig,
) -> Job:
return await self.run_eval(benchmark_id=task_id, task_config=task_config)
async def DEPRECATED_evaluate_rows(
self,
task_id: str,
input_rows: List[Dict[str, Any]],
scoring_functions: List[str],
task_config: BenchmarkConfig,
) -> EvaluateResponse:
return await self.evaluate_rows(
benchmark_id=task_id,
input_rows=input_rows,
scoring_functions=scoring_functions,
task_config=task_config,
)
async def DEPRECATED_job_status(
self,
task_id: str,
job_id: str,
) -> Optional[JobStatus]:
return await self.job_status(benchmark_id=task_id, job_id=job_id)
async def DEPRECATED_job_cancel(
self,
task_id: str,
job_id: str,
) -> None:
return await self.job_cancel(benchmark_id=task_id, job_id=job_id)
async def DEPRECATED_job_result(
self,
task_id: str,
job_id: str,
) -> EvaluateResponse:
return await self.routing_table.get_provider_impl(task_id).job_result(
task_id,
job_id,
)
return await self.job_result(benchmark_id=task_id, job_id=job_id)
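
The DEPRECATED_* shims above all follow the same shape: warn, rename `task_id` to `benchmark_id`, and forward to the new method. A hedged sketch of how that pattern could be factored into a helper; this decorator is illustrative and not part of the change:

```python
import functools
import logging

logger = logging.getLogger(__name__)


def deprecated_alias(new_name: str, renamed_kwargs: dict):
    """Forward a deprecated async method to its renamed counterpart,
    translating old keyword names (e.g. task_id -> benchmark_id)."""

    def decorator(func):
        @functools.wraps(func)
        async def wrapper(self, *args, **kwargs):
            logger.warning("%s is deprecated, use %s instead", func.__name__, new_name)
            kwargs = {renamed_kwargs.get(k, k): v for k, v in kwargs.items()}
            return await getattr(self, new_name)(*args, **kwargs)

        return wrapper

    return decorator


# Hypothetical usage:
# @deprecated_alias("job_status", {"task_id": "benchmark_id"})
# async def DEPRECATED_job_status(self, task_id: str, job_id: str): ...
```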
class ToolRuntimeRouter(ToolRuntime):

View file

@ -4,14 +4,15 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import logging
from typing import Any, Dict, List, Optional
from pydantic import TypeAdapter
from llama_stack.apis.benchmarks import Benchmark, Benchmarks, ListBenchmarksResponse
from llama_stack.apis.common.content_types import URL
from llama_stack.apis.common.type_system import ParamType
from llama_stack.apis.datasets import Dataset, Datasets, ListDatasetsResponse
from llama_stack.apis.eval_tasks import EvalTask, EvalTasks, ListEvalTasksResponse
from llama_stack.apis.models import ListModelsResponse, Model, Models, ModelType
from llama_stack.apis.resource import ResourceType
from llama_stack.apis.scoring_functions import (
@ -38,6 +39,8 @@ from llama_stack.distribution.datatypes import (
from llama_stack.distribution.store import DistributionRegistry
from llama_stack.providers.datatypes import Api, RoutingTable
logger = logging.getLogger(__name__)
def get_impl_api(p: Any) -> Api:
return p.__provider_spec__.api
@ -60,7 +63,7 @@ async def register_object_with_provider(obj: RoutableObject, p: Any) -> Routable
elif api == Api.scoring:
return await p.register_scoring_function(obj)
elif api == Api.eval:
return await p.register_eval_task(obj)
return await p.register_benchmark(obj)
elif api == Api.tool_runtime:
return await p.register_tool(obj)
else:
@ -121,7 +124,7 @@ class CommonRoutingTableImpl(RoutingTable):
scoring_functions = await p.list_scoring_functions()
await add_objects(scoring_functions, pid, ScoringFn)
elif api == Api.eval:
p.eval_task_store = self
p.benchmark_store = self
elif api == Api.tool_runtime:
p.tool_store = self
@ -141,8 +144,8 @@ class CommonRoutingTableImpl(RoutingTable):
return ("DatasetIO", "dataset")
elif isinstance(self, ScoringFunctionsRoutingTable):
return ("Scoring", "scoring_function")
elif isinstance(self, EvalTasksRoutingTable):
return ("Eval", "eval_task")
elif isinstance(self, BenchmarksRoutingTable):
return ("Eval", "benchmark")
elif isinstance(self, ToolGroupsRoutingTable):
return ("Tools", "tool")
else:
@ -428,20 +431,20 @@ class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions):
await self.register_object(scoring_fn)
class EvalTasksRoutingTable(CommonRoutingTableImpl, EvalTasks):
async def list_eval_tasks(self) -> ListEvalTasksResponse:
return ListEvalTasksResponse(data=await self.get_all_with_type("eval_task"))
class BenchmarksRoutingTable(CommonRoutingTableImpl, Benchmarks):
async def list_benchmarks(self) -> ListBenchmarksResponse:
return ListBenchmarksResponse(data=await self.get_all_with_type("benchmark"))
async def get_eval_task(self, eval_task_id: str) -> Optional[EvalTask]:
return await self.get_object_by_identifier("eval_task", eval_task_id)
async def get_benchmark(self, benchmark_id: str) -> Optional[Benchmark]:
return await self.get_object_by_identifier("benchmark", benchmark_id)
async def register_eval_task(
async def register_benchmark(
self,
eval_task_id: str,
benchmark_id: str,
dataset_id: str,
scoring_functions: List[str],
metadata: Optional[Dict[str, Any]] = None,
provider_eval_task_id: Optional[str] = None,
provider_benchmark_id: Optional[str] = None,
provider_id: Optional[str] = None,
) -> None:
if metadata is None:
@ -453,17 +456,46 @@ class EvalTasksRoutingTable(CommonRoutingTableImpl, EvalTasks):
raise ValueError(
"No provider specified and multiple providers available. Please specify a provider_id."
)
if provider_eval_task_id is None:
provider_eval_task_id = eval_task_id
eval_task = EvalTask(
identifier=eval_task_id,
if provider_benchmark_id is None:
provider_benchmark_id = benchmark_id
benchmark = Benchmark(
identifier=benchmark_id,
dataset_id=dataset_id,
scoring_functions=scoring_functions,
metadata=metadata,
provider_id=provider_id,
provider_resource_id=provider_eval_task_id,
provider_resource_id=provider_benchmark_id,
)
await self.register_object(benchmark)
async def DEPRECATED_list_eval_tasks(self) -> ListBenchmarksResponse:
logger.warning("DEPRECATED: Use /eval/benchmarks instead")
return await self.list_benchmarks()
async def DEPRECATED_get_eval_task(
self,
eval_task_id: str,
) -> Optional[Benchmark]:
logger.warning("DEPRECATED: Use /eval/benchmarks instead")
return await self.get_benchmark(eval_task_id)
async def DEPRECATED_register_eval_task(
self,
eval_task_id: str,
dataset_id: str,
scoring_functions: List[str],
provider_benchmark_id: Optional[str] = None,
provider_id: Optional[str] = None,
metadata: Optional[Dict[str, Any]] = None,
) -> None:
logger.warning("DEPRECATED: Use /eval/benchmarks instead")
return await self.register_benchmark(
benchmark_id=eval_task_id,
dataset_id=dataset_id,
scoring_functions=scoring_functions,
metadata=metadata,
provider_benchmark_id=provider_benchmark_id,
)
await self.register_object(eval_task)
class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
@ -537,3 +569,6 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
for tool in tools:
await self.unregister_object(tool)
await self.unregister_object(tool_group)
async def shutdown(self) -> None:
pass
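
A short sketch of registering and retrieving a benchmark through the renamed routing-table methods above, run inside an async context. Here `routing_table` stands for a constructed BenchmarksRoutingTable, and the identifier, dataset, and scoring-function names are illustrative:

```python
await routing_table.register_benchmark(
    benchmark_id="meta-reference-mmlu",
    dataset_id="mmlu",
    scoring_functions=["basic::regex_parser_multiple_choice_answer"],
    metadata={"split": "test"},
)
benchmark = await routing_table.get_benchmark("meta-reference-mmlu")
assert benchmark is not None and benchmark.dataset_id == "mmlu"
```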

View file

@ -10,11 +10,8 @@ from typing import Dict, List
from pydantic import BaseModel
from llama_stack.apis.tools import RAGToolRuntime, SpecialToolGroup
from llama_stack.apis.version import LLAMA_STACK_API_VERSION
from llama_stack.distribution.resolver import api_protocol_map
from llama_stack.providers.datatypes import Api

View file

@ -9,6 +9,7 @@ import asyncio
import functools
import inspect
import json
import logging
import os
import signal
import sys
@ -20,7 +21,8 @@ from pathlib import Path
from typing import Any, List, Union
import yaml
from fastapi import Body, FastAPI, HTTPException, Path as FastapiPath, Request
from fastapi import Body, FastAPI, HTTPException, Request
from fastapi import Path as FastapiPath
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse, StreamingResponse
from pydantic import BaseModel, ValidationError
@ -52,6 +54,9 @@ from .endpoints import get_all_api_endpoints
REPO_ROOT = Path(__file__).parent.parent.parent.parent
logging.basicConfig(level=logging.INFO, format="%(levelname)s %(asctime)s %(name)s:%(lineno)d: %(message)s")
logger = logging.getLogger(__name__)
def warn_with_traceback(message, category, filename, lineno, file=None, line=None):
log = file if hasattr(file, "write") else sys.stderr
@ -112,21 +117,69 @@ def translate_exception(exc: Exception) -> Union[HTTPException, RequestValidatio
)
def handle_sigint(app, *args, **kwargs):
print("SIGINT or CTRL-C detected. Exiting gracefully...")
def handle_signal(app, signum, _) -> None:
"""
Handle incoming signals and initiate a graceful shutdown of the application.
async def run_shutdown():
for impl in app.__llama_stack_impls__.values():
print(f"Shutting down {impl}")
await impl.shutdown()
This function is intended to be used as a signal handler for various signals
(e.g., SIGINT, SIGTERM). Upon receiving a signal, it will print a message
indicating the received signal and initiate a shutdown process.
asyncio.run(run_shutdown())
Args:
app: The application instance containing implementations to be shut down.
signum (int): The signal number received.
frame: The current stack frame (not used in this function).
loop = asyncio.get_event_loop()
for task in asyncio.all_tasks(loop):
task.cancel()
The shutdown process involves:
- Shutting down all implementations registered in the application.
- Gathering all running asyncio tasks.
- Cancelling all gathered tasks.
- Waiting for all tasks to finish.
- Stopping the event loop.
loop.stop()
Note:
This function schedules the shutdown process as an asyncio task and does
not block the current execution.
"""
signame = signal.Signals(signum).name
print(f"Received signal {signame} ({signum}). Exiting gracefully...")
async def shutdown():
try:
# Gracefully shut down implementations
for impl in app.__llama_stack_impls__.values():
impl_name = impl.__class__.__name__
logger.info("Shutting down %s", impl_name)
try:
if hasattr(impl, "shutdown"):
await asyncio.wait_for(impl.shutdown(), timeout=5)
else:
logger.warning("No shutdown method for %s", impl_name)
except asyncio.TimeoutError:
logger.exception("Shutdown timeout for %s ", impl_name, exc_info=True)
except Exception as e:
logger.exception("Failed to shutdown %s: %s", impl_name, {e})
# Gather all running tasks
loop = asyncio.get_running_loop()
tasks = [task for task in asyncio.all_tasks(loop) if task is not asyncio.current_task()]
# Cancel all tasks
for task in tasks:
task.cancel()
# Wait for all tasks to finish
try:
await asyncio.wait_for(asyncio.gather(*tasks, return_exceptions=True), timeout=10)
except asyncio.TimeoutError:
logger.exception("Timeout while waiting for tasks to finish")
except asyncio.CancelledError:
pass
finally:
loop.stop()
loop = asyncio.get_running_loop()
loop.create_task(shutdown())
@asynccontextmanager
@ -386,7 +439,8 @@ def main():
print("")
app.exception_handler(RequestValidationError)(global_exception_handler)
app.exception_handler(Exception)(global_exception_handler)
signal.signal(signal.SIGINT, functools.partial(handle_sigint, app))
signal.signal(signal.SIGINT, functools.partial(handle_signal, app))
signal.signal(signal.SIGTERM, functools.partial(handle_signal, app))
app.__llama_stack_impls__ = impls

View file

@ -15,10 +15,10 @@ from termcolor import colored
from llama_stack.apis.agents import Agents
from llama_stack.apis.batch_inference import BatchInference
from llama_stack.apis.benchmarks import Benchmarks
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Datasets
from llama_stack.apis.eval import Eval
from llama_stack.apis.eval_tasks import EvalTasks
from llama_stack.apis.inference import Inference
from llama_stack.apis.inspect import Inspect
from llama_stack.apis.models import Models
@ -53,7 +53,7 @@ class LlamaStack(
PostTraining,
VectorIO,
Eval,
EvalTasks,
Benchmarks,
Scoring,
ScoringFunctions,
DatasetIO,
@ -78,7 +78,7 @@ RESOURCES = [
"register_scoring_function",
"list_scoring_functions",
),
("eval_tasks", Api.eval_tasks, "register_eval_task", "list_eval_tasks"),
("benchmarks", Api.benchmarks, "register_benchmark", "list_benchmarks"),
("tool_groups", Api.tool_groups, "register_tool_group", "list_tool_groups"),
]

View file

@ -0,0 +1,71 @@
#!/bin/bash
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
set -euo pipefail
RED='\033[0;31m'
NC='\033[0m' # No Color
error_handler() {
echo "Error occurred in script at line: ${1}" >&2
exit 1
}
trap 'error_handler ${LINENO}' ERR
if [ $# -lt 3 ]; then
echo "Usage: $0 <venv_path> <yaml_config> <port> <script_args...>"
exit 1
fi
venv_path="$1"
shift
yaml_config="$1"
shift
port="$1"
shift
# Initialize env_vars as an empty array
env_vars=""
other_args=""
# Process environment variables from --env arguments
while [[ $# -gt 0 ]]; do
case "$1" in
--env)
if [[ -n "$2" ]]; then
env_vars="$env_vars --env $2"
shift 2
else
echo -e "${RED}Error: --env requires a KEY=VALUE argument${NC}" >&2
exit 1
fi
;;
*)
other_args="$other_args $1"
shift
;;
esac
done
# Activate virtual environment
if [ ! -d "$venv_path" ]; then
echo -e "${RED}Error: Virtual environment not found at $venv_path${NC}" >&2
exit 1
fi
source "$venv_path/bin/activate"
set -x
python -m llama_stack.distribution.server.server \
--yaml-config "$yaml_config" \
--port "$port" \
$env_vars \
$other_args

View file

@ -8,9 +8,9 @@ import os
import pytest
import pytest_asyncio
from llama_stack.apis.inference import Model
from llama_stack.apis.vector_dbs import VectorDB
from llama_stack.distribution.store.registry import (
CachedDiskDistributionRegistry,
DiskDistributionRegistry,

View file

@ -26,7 +26,7 @@ $ llama-stack-client datasets register \
```
```bash
$ llama-stack-client eval_tasks register \
$ llama-stack-client benchmarks register \
--eval-task-id meta-reference-mmlu \
--provider-id meta-reference \
--dataset-id mmlu \

View file

@ -5,7 +5,6 @@
# the root directory of this source tree.
import os
from typing import Optional
from llama_stack_client import LlamaStackClient

View file

@ -8,12 +8,12 @@ import streamlit as st
from modules.api import llama_stack_api
def eval_tasks():
# Eval Tasks Section
st.header("Eval Tasks")
def benchmarks():
# Benchmarks Section
st.header("Benchmarks")
eval_tasks_info = {d.identifier: d.to_dict() for d in llama_stack_api.client.eval_tasks.list()}
benchmarks_info = {d.identifier: d.to_dict() for d in llama_stack_api.client.benchmarks.list()}
if len(eval_tasks_info) > 0:
selected_eval_task = st.selectbox("Select an eval task", list(eval_tasks_info.keys()), key="eval_task_inspect")
st.json(eval_tasks_info[selected_eval_task], expanded=True)
if len(benchmarks_info) > 0:
selected_benchmark = st.selectbox("Select an eval task", list(benchmarks_info.keys()), key="benchmark_inspect")
st.json(benchmarks_info[selected_benchmark], expanded=True)

View file

@ -4,13 +4,12 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from page.distribution.benchmarks import benchmarks
from page.distribution.datasets import datasets
from page.distribution.eval_tasks import eval_tasks
from page.distribution.models import models
from page.distribution.scoring_functions import scoring_functions
from page.distribution.shields import shields
from page.distribution.vector_dbs import vector_dbs
from streamlit_option_menu import option_menu
@ -21,7 +20,7 @@ def resources_page():
"Shields",
"Scoring Functions",
"Datasets",
"Eval Tasks",
"Benchmarks",
]
icons = ["magic", "memory", "shield", "file-bar-graph", "database", "list-task"]
selected_resource = option_menu(
@ -35,8 +34,8 @@ def resources_page():
},
},
)
if selected_resource == "Eval Tasks":
eval_tasks()
if selected_resource == "Benchmarks":
benchmarks()
elif selected_resource == "Vector Databases":
vector_dbs()
elif selected_resource == "Datasets":

View file

@ -8,7 +8,6 @@ import json
import pandas as pd
import streamlit as st
from modules.api import llama_stack_api
from modules.utils import process_dataset

View file

@ -7,34 +7,32 @@
import json
import pandas as pd
import streamlit as st
from modules.api import llama_stack_api
def select_eval_task_1():
# Select Eval Tasks
def select_benchmark_1():
# Select Benchmarks
st.subheader("1. Choose An Eval Task")
eval_tasks = llama_stack_api.client.eval_tasks.list()
eval_tasks = {et.identifier: et for et in eval_tasks}
eval_tasks_names = list(eval_tasks.keys())
selected_eval_task = st.selectbox(
benchmarks = llama_stack_api.client.benchmarks.list()
benchmarks = {et.identifier: et for et in benchmarks}
benchmarks_names = list(benchmarks.keys())
selected_benchmark = st.selectbox(
"Choose an eval task.",
options=eval_tasks_names,
options=benchmarks_names,
help="Choose an eval task. Each eval task is parameterized by a dataset, and list of scoring functions.",
)
with st.expander("View Eval Task"):
st.json(eval_tasks[selected_eval_task], expanded=True)
st.json(benchmarks[selected_benchmark], expanded=True)
st.session_state["selected_eval_task"] = selected_eval_task
st.session_state["eval_tasks"] = eval_tasks
st.session_state["selected_benchmark"] = selected_benchmark
st.session_state["benchmarks"] = benchmarks
if st.button("Confirm", key="confirm_1"):
st.session_state["selected_eval_task_1_next"] = True
st.session_state["selected_benchmark_1_next"] = True
def define_eval_candidate_2():
if not st.session_state.get("selected_eval_task_1_next", None):
if not st.session_state.get("selected_benchmark_1_next", None):
return
st.subheader("2. Define Eval Candidate")
@ -163,11 +161,11 @@ def run_evaluation_3():
Review the configurations that will be used for this evaluation run, make any necessary changes, and then click the "Run Evaluation" button.
"""
)
selected_eval_task = st.session_state["selected_eval_task"]
eval_tasks = st.session_state["eval_tasks"]
selected_benchmark = st.session_state["selected_benchmark"]
benchmarks = st.session_state["benchmarks"]
eval_candidate = st.session_state["eval_candidate"]
dataset_id = eval_tasks[selected_eval_task].dataset_id
dataset_id = benchmarks[selected_benchmark].dataset_id
rows = llama_stack_api.client.datasetio.get_rows_paginated(
dataset_id=dataset_id,
rows_in_page=-1,
@ -182,16 +180,16 @@ def run_evaluation_3():
help="Number of examples from the dataset to evaluate. ",
)
eval_task_config = {
benchmark_config = {
"type": "benchmark",
"eval_candidate": eval_candidate,
"scoring_params": {},
}
with st.expander("View Evaluation Task", expanded=True):
st.json(eval_tasks[selected_eval_task], expanded=True)
st.json(benchmarks[selected_benchmark], expanded=True)
with st.expander("View Evaluation Task Configuration", expanded=True):
st.json(eval_task_config, expanded=True)
st.json(benchmark_config, expanded=True)
# Add run button and handle evaluation
if st.button("Run Evaluation"):
@ -211,10 +209,10 @@ def run_evaluation_3():
progress_bar.progress(progress, text=progress_text)
# Run evaluation for current row
eval_res = llama_stack_api.client.eval.evaluate_rows(
task_id=selected_eval_task,
benchmark_id=selected_benchmark,
input_rows=[r],
scoring_functions=eval_tasks[selected_eval_task].scoring_functions,
task_config=eval_task_config,
scoring_functions=benchmarks[selected_benchmark].scoring_functions,
task_config=benchmark_config,
)
for k in r.keys():
@ -227,7 +225,7 @@ def run_evaluation_3():
output_res[k] = []
output_res[k].append(eval_res.generations[0][k])
for scoring_fn in eval_tasks[selected_eval_task].scoring_functions:
for scoring_fn in benchmarks[selected_benchmark].scoring_functions:
if scoring_fn not in output_res:
output_res[scoring_fn] = []
output_res[scoring_fn].append(eval_res.scores[scoring_fn].score_rows[0])
@ -247,7 +245,7 @@ def native_evaluation_page():
st.set_page_config(page_title="Evaluations (Generation + Scoring)", page_icon="🦙")
st.title("📊 Evaluations (Generation + Scoring)")
select_eval_task_1()
select_benchmark_1()
define_eval_candidate_2()
run_evaluation_3()

View file

@ -9,7 +9,6 @@ from llama_stack_client.lib.agents.agent import Agent
from llama_stack_client.lib.agents.event_logger import EventLogger
from llama_stack_client.types.agent_create_params import AgentConfig
from llama_stack_client.types.memory_insert_params import Document
from modules.api import llama_stack_api
from modules.utils import data_url_from_file

View file

@ -7,7 +7,6 @@
import os
from pathlib import Path
LLAMA_STACK_CONFIG_DIR = Path(os.getenv("LLAMA_STACK_CONFIG_DIR", os.path.expanduser("~/.llama/")))
DISTRIBS_BASE_DIR = LLAMA_STACK_CONFIG_DIR / "distributions"

View file

@ -8,13 +8,11 @@ import inspect
import json
import logging
from enum import Enum
from typing import Any, get_args, get_origin, List, Literal, Optional, Type, Union
from typing import Any, List, Literal, Optional, Type, Union, get_args, get_origin
from pydantic import BaseModel
from pydantic.fields import FieldInfo
from pydantic_core import PydanticUndefinedType
from typing_extensions import Annotated
log = logging.getLogger(__name__)

View file

@ -0,0 +1,277 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.
from enum import Enum
from typing import Any, Dict, Literal, Optional, Union
# import all for backwards compatibility
from llama_models.datatypes import * # noqa: F403
from pydantic import BaseModel, ConfigDict, Field, field_validator
from typing_extensions import Annotated
from llama_stack.schema_utils import json_schema_type, register_schema
register_schema(ToolCall)
@json_schema_type
class ToolParamDefinition(BaseModel):
param_type: str
description: Optional[str] = None
required: Optional[bool] = True
default: Optional[Any] = None
@json_schema_type
class ToolDefinition(BaseModel):
tool_name: Union[BuiltinTool, str]
description: Optional[str] = None
parameters: Optional[Dict[str, ToolParamDefinition]] = None
@field_validator("tool_name", mode="before")
@classmethod
def validate_field(cls, v):
if isinstance(v, str):
try:
return BuiltinTool(v)
except ValueError:
return v
return v
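A minimal sketch of the coercion the validator above performs, assuming this module is importable as llama_stack.models.llama.datatypes (the path used by another file later in this diff) and that BuiltinTool carries the usual llama_models members such as brave_search; the custom tool name is hypothetical.
from llama_stack.models.llama.datatypes import BuiltinTool, ToolDefinition
# Known builtin names are coerced to the enum member by validate_field.
builtin = ToolDefinition(tool_name="brave_search")
assert builtin.tool_name is BuiltinTool.brave_search
# Unknown names fall through the validator and stay plain strings (custom tools).
custom = ToolDefinition(tool_name="lookup_ticket", description="hypothetical custom tool")
assert custom.tool_name == "lookup_ticket"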
@json_schema_type
class GreedySamplingStrategy(BaseModel):
type: Literal["greedy"] = "greedy"
@json_schema_type
class TopPSamplingStrategy(BaseModel):
type: Literal["top_p"] = "top_p"
temperature: Optional[float] = Field(..., gt=0.0)
top_p: Optional[float] = 0.95
@json_schema_type
class TopKSamplingStrategy(BaseModel):
type: Literal["top_k"] = "top_k"
top_k: int = Field(..., ge=1)
SamplingStrategy = register_schema(
Annotated[
Union[GreedySamplingStrategy, TopPSamplingStrategy, TopKSamplingStrategy],
Field(discriminator="type"),
],
name="SamplingStrategy",
)
@json_schema_type
class SamplingParams(BaseModel):
strategy: SamplingStrategy = Field(default_factory=GreedySamplingStrategy)
max_tokens: Optional[int] = 0
repetition_penalty: Optional[float] = 1.0
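As a quick illustration of the discriminated union above (a sketch under the same module-path assumption, not part of the diff): pydantic reads the type field to decide which strategy class to instantiate when validating plain data.
from llama_stack.models.llama.datatypes import SamplingParams, TopPSamplingStrategy
params = SamplingParams.model_validate(
    {
        "strategy": {"type": "top_p", "temperature": 0.7, "top_p": 0.9},
        "max_tokens": 256,
    }
)
assert isinstance(params.strategy, TopPSamplingStrategy)
# Omitting the strategy falls back to greedy decoding via default_factory.
assert SamplingParams().strategy.type == "greedy"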
class CheckpointQuantizationFormat(Enum):
# default format
bf16 = "bf16"
# used for enabling fp8_rowwise inference, some weights are bf16
fp8_mixed = "fp8-mixed"
int8 = "int8"
int4 = "int4"
class ModelFamily(Enum):
llama2 = "llama2"
llama3 = "llama3"
llama3_1 = "llama3_1"
llama3_2 = "llama3_2"
llama3_3 = "llama3_3"
safety = "safety"
class CoreModelId(Enum):
"""Each of these models is a unique "SKU". These root models can be served in various garbs (especially by quantizing them)"""
# Llama 2 family
llama2_7b = "Llama-2-7b"
llama2_13b = "Llama-2-13b"
llama2_70b = "Llama-2-70b"
llama2_7b_chat = "Llama-2-7b-chat"
llama2_13b_chat = "Llama-2-13b-chat"
llama2_70b_chat = "Llama-2-70b-chat"
# Llama 3 family
llama3_8b = "Llama-3-8B"
llama3_70b = "Llama-3-70B"
llama3_8b_instruct = "Llama-3-8B-Instruct"
llama3_70b_instruct = "Llama-3-70B-Instruct"
# Llama 3.1 family
llama3_1_8b = "Llama3.1-8B"
llama3_1_70b = "Llama3.1-70B"
llama3_1_405b = "Llama3.1-405B"
llama3_1_8b_instruct = "Llama3.1-8B-Instruct"
llama3_1_70b_instruct = "Llama3.1-70B-Instruct"
llama3_1_405b_instruct = "Llama3.1-405B-Instruct"
# Llama 3.2 family
llama3_2_1b = "Llama3.2-1B"
llama3_2_3b = "Llama3.2-3B"
llama3_2_1b_instruct = "Llama3.2-1B-Instruct"
llama3_2_3b_instruct = "Llama3.2-3B-Instruct"
llama3_2_11b_vision = "Llama3.2-11B-Vision"
llama3_2_90b_vision = "Llama3.2-90B-Vision"
llama3_2_11b_vision_instruct = "Llama3.2-11B-Vision-Instruct"
llama3_2_90b_vision_instruct = "Llama3.2-90B-Vision-Instruct"
# Llama 3.3 family
llama3_3_70b_instruct = "Llama3.3-70B-Instruct"
# Safety models
llama_guard_3_8b = "Llama-Guard-3-8B"
llama_guard_2_8b = "Llama-Guard-2-8B"
llama_guard_3_11b_vision = "Llama-Guard-3-11B-Vision"
llama_guard_3_1b = "Llama-Guard-3-1B"
def is_multimodal(model_id) -> bool:
if model_id in [
CoreModelId.llama3_2_11b_vision,
CoreModelId.llama3_2_90b_vision,
CoreModelId.llama3_2_11b_vision_instruct,
CoreModelId.llama3_2_90b_vision_instruct,
]:
return True
else:
return False
def model_family(model_id) -> ModelFamily:
if model_id in [
CoreModelId.llama2_7b,
CoreModelId.llama2_13b,
CoreModelId.llama2_70b,
CoreModelId.llama2_7b_chat,
CoreModelId.llama2_13b_chat,
CoreModelId.llama2_70b_chat,
]:
return ModelFamily.llama2
elif model_id in [
CoreModelId.llama3_8b,
CoreModelId.llama3_70b,
CoreModelId.llama3_8b_instruct,
CoreModelId.llama3_70b_instruct,
]:
return ModelFamily.llama3
elif model_id in [
CoreModelId.llama3_1_8b,
CoreModelId.llama3_1_70b,
CoreModelId.llama3_1_405b,
CoreModelId.llama3_1_8b_instruct,
CoreModelId.llama3_1_70b_instruct,
CoreModelId.llama3_1_405b_instruct,
]:
return ModelFamily.llama3_1
elif model_id in [
CoreModelId.llama3_2_1b,
CoreModelId.llama3_2_3b,
CoreModelId.llama3_2_1b_instruct,
CoreModelId.llama3_2_3b_instruct,
CoreModelId.llama3_2_11b_vision,
CoreModelId.llama3_2_90b_vision,
CoreModelId.llama3_2_11b_vision_instruct,
CoreModelId.llama3_2_90b_vision_instruct,
]:
return ModelFamily.llama3_2
elif model_id in [
CoreModelId.llama3_3_70b_instruct,
]:
return ModelFamily.llama3_3
elif model_id in [
CoreModelId.llama_guard_3_8b,
CoreModelId.llama_guard_2_8b,
CoreModelId.llama_guard_3_11b_vision,
CoreModelId.llama_guard_3_1b,
]:
return ModelFamily.safety
else:
raise ValueError(f"Unknown model family for {model_id}")
class Model(BaseModel):
core_model_id: CoreModelId
description: str
huggingface_repo: Optional[str] = None
recommended_sampling_params: Optional[SamplingParams] = None
arch_args: Dict[str, Any]
variant: str = ""
quantization_format: CheckpointQuantizationFormat = CheckpointQuantizationFormat.bf16
pth_file_count: int
metadata: Optional[Dict[str, Any]] = Field(default_factory=dict)
# silence pydantic until we remove the `model_` fields
model_config = ConfigDict(protected_namespaces=())
@property
def model_family(self) -> ModelFamily:
return model_family(self.core_model_id)
# The SKU is uniquely identified by (model_id, variant) combo
def descriptor(self, shorten_default_variant: bool = True) -> str:
if not self.variant:
return self.core_model_id.value
return f"{self.core_model_id.value}:{self.variant}"
@property
def is_instruct_model(self) -> bool:
return "instruct" in self.id.name
# Featured models are shown in the non-exhaustive model list
@property
def is_featured(self) -> bool:
return self.model_family in [
ModelFamily.llama3_1,
ModelFamily.llama3_2,
ModelFamily.llama3_3,
ModelFamily.safety,
]
@property
def max_seq_length(self) -> int:
if self.model_family == ModelFamily.llama2:
return 4096
elif self.core_model_id == CoreModelId.llama_guard_2_8b:
return 4096
elif self.model_family == ModelFamily.llama3:
return 8192
elif self.model_family in [ModelFamily.llama3_1, ModelFamily.llama3_3]:
return 131072
elif self.model_family == ModelFamily.llama3_2:
if self.quantization_format == CheckpointQuantizationFormat.int4:
return 8192
return 131072
elif self.core_model_id in [
CoreModelId.llama_guard_3_8b,
CoreModelId.llama_guard_3_11b_vision,
CoreModelId.llama_guard_3_1b,
]:
return 131072
else:
raise ValueError(f"Unknown max_seq_len for {self.core_model_id}")

Binary file not shown (image, 39 KiB)

View file

@ -0,0 +1,257 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.
from pathlib import Path
from typing import List, Optional
from llama_models.datatypes import (
BuiltinTool,
RawMessage,
StopReason,
ToolCall,
ToolPromptFormat,
)
from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.tokenizer import Tokenizer
from termcolor import colored
from llama_stack.models.llama.datatypes import ToolDefinition
from . import template_data
from .prompt_templates import (
BuiltinToolGenerator,
FunctionTagCustomToolGenerator,
JsonCustomToolGenerator,
SystemDefaultGenerator,
ToolResponseGenerator,
)
THIS_DIR = Path(__file__).parent
class Template:
def __init__(
self,
role,
template_name,
data_provider=None,
notes=None,
):
self.role = role
self.template_name = template_name
self.data_provider = data_provider or ""
self._notes = notes or ""
@property
def notes(self):
default = "↵ represents newline"
notes = default
if self._notes:
notes += "\n"
notes += self._notes
return notes
TEMPLATES = [
Template(
"user",
"user-default",
"user_default",
),
Template(
"user",
"user-images",
"user_images",
),
Template("user", "user-interleaved-images", "user_interleaved_images"),
Template(
"assistant",
"assistant-builtin-tool-call",
"assistant_builtin_tool_call",
"Notice <|python_tag|>",
),
Template(
"assistant",
"assistant-custom-tool-call",
"assistant_custom_tool_call",
"Notice <function=...> format",
),
Template(
"assistant",
"assistant-default",
"assistant_default",
),
Template(
"system",
"system-builtin-and-custom-tools",
"system_message_builtin_and_custom_tools",
),
Template(
"system",
"system-builtin-tools-only",
"system_message_builtin_tools_only",
),
Template(
"system",
"system-custom-tools-only",
"system_message_custom_tools_only",
),
Template(
"system",
"system-default",
"system_default",
),
Template(
"tool",
"tool-success",
"tool_success",
"Note ipython header and [stdout]",
),
Template(
"tool",
"tool-failure",
"tool_failure",
"Note ipython header and [stderr]",
),
]
class LLama31Interface:
def __init__(self, tool_prompt_format: ToolPromptFormat = ToolPromptFormat.json):
self.tokenizer = Tokenizer.get_instance()
self.formatter = ChatFormat(self.tokenizer)
self.tool_prompt_format = tool_prompt_format
def get_tokens(self, messages: List[RawMessage]) -> List[int]:
model_input = self.formatter.encode_dialog_prompt(
messages,
self.tool_prompt_format,
)
return model_input.tokens
def tool_response_messages(self, *args, **kwargs):
template = ToolResponseGenerator().gen(*args, **kwargs)
return [
RawMessage(
role="tool",
content=template.render(),
)
]
def system_messages(
self,
builtin_tools: List[BuiltinTool],
custom_tools: List[ToolDefinition],
instruction: Optional[str] = None,
) -> List[RawMessage]:
messages = []
default_gen = SystemDefaultGenerator()
default_template = default_gen.gen()
sys_content = ""
tool_template = None
if builtin_tools or custom_tools:
tool_gen = BuiltinToolGenerator()
tool_template = tool_gen.gen(builtin_tools + custom_tools)
sys_content += tool_template.render()
sys_content += "\n"
sys_content += default_template.render()
if instruction:
sys_content += "\n\n"
sys_content += instruction
sys_content += "\n"
messages.append(RawMessage(role="system", content=sys_content))
if custom_tools:
if self.tool_prompt_format == ToolPromptFormat.json:
tool_gen = JsonCustomToolGenerator()
elif self.tool_prompt_format == ToolPromptFormat.function_tag:
tool_gen = FunctionTagCustomToolGenerator()
else:
raise ValueError(f"Non supported ToolPromptFormat {self.tool_prompt_format}")
custom_template = tool_gen.gen(custom_tools)
messages.append(RawMessage(role="user", content=custom_template.render()))
return messages
def assistant_response_messages(
self,
content: str,
stop_reason: StopReason,
tool_call: Optional[ToolCall] = None,
) -> List[RawMessage]:
tool_calls = []
if tool_call:
tool_calls.append(tool_call)
return [
RawMessage(
role="assistant",
content=content,
tool_calls=tool_calls,
stop_reason=stop_reason,
)
]
def user_message(self, content: str) -> List[RawMessage]:
return [RawMessage(role="user", content=content)]
def display_message_as_tokens(self, message: RawMessage) -> None:
"""Util to print tokenized string to shell"""
tokens = self.formatter.encode_message(message, self.tool_prompt_format)
on_colors = [
"on_red",
"on_green",
"on_yellow",
"on_blue",
"on_magenta",
"on_cyan",
]
for i, t in enumerate(tokens):
on_col = on_colors[i % len(on_colors)]
print(colored(self.tokenizer.decode([t]), "white", on_col), end="")
print("\n", end="")
def list_jinja_templates() -> List[Template]:
return TEMPLATES
def render_jinja_template(name: str, tool_prompt_format: ToolPromptFormat):
by_name = {t.template_name: t for t in TEMPLATES}
if name not in by_name:
raise ValueError(f"No template found for `{name}`")
template = by_name[name]
interface = LLama31Interface(tool_prompt_format)
data_func = getattr(template_data, template.data_provider)
if template.role == "system":
messages = interface.system_messages(**data_func())
elif template.role == "tool":
messages = interface.tool_response_messages(**data_func())
elif template.role == "assistant":
messages = interface.assistant_response_messages(**data_func())
elif template.role == "user":
messages = interface.user_message(**data_func())
tokens = interface.get_tokens(messages)
special_tokens = list(interface.tokenizer.special_tokens.values())
tokens = [(interface.tokenizer.decode([t]), t in special_tokens) for t in tokens]
return template, tokens
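A minimal usage sketch for the template helpers above, run alongside the definitions in this module (the module path is not shown in this diff); Tokenizer.get_instance() uses the tokenizer bundled with llama_models.
from llama_models.datatypes import ToolPromptFormat
for t in list_jinja_templates():
    print(f"{t.role:10s} {t.template_name}")
template, tokens = render_jinja_template("user-default", ToolPromptFormat.json)
print(template.notes)
for text, is_special in tokens:
    # Special tokens (e.g. header markers) are flagged with a leading asterisk.
    print(("*" if is_special else " "), repr(text))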

Some files were not shown because too many files have changed in this diff.