Merge branch 'meta-llama:main' into main

Shrinit Goyal 2024-12-16 18:14:20 +05:30 committed by GitHub
commit 54e48d555d
110 changed files with 12606 additions and 747 deletions


@ -38,7 +38,7 @@ Alongside these APIs, we also related APIs for operating with associated resourc
- Models - Models
- Shields - Shields
- Memory Banks - Memory Banks
- EvalTasks - Eval Tasks
- Datasets - Datasets
- Scoring Functions - Scoring Functions
@ -84,26 +84,26 @@ Additionally, we have designed every element of the Stack such that APIs as well
| Fireworks | Hosted | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | | | | Fireworks | Hosted | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | | |
| AWS Bedrock | Hosted | | :heavy_check_mark: | | :heavy_check_mark: | | | AWS Bedrock | Hosted | | :heavy_check_mark: | | :heavy_check_mark: | |
| Together | Hosted | :heavy_check_mark: | :heavy_check_mark: | | :heavy_check_mark: | | | Together | Hosted | :heavy_check_mark: | :heavy_check_mark: | | :heavy_check_mark: | |
| Ollama | Single Node | | :heavy_check_mark: | | | | Ollama | Single Node | | :heavy_check_mark: | | | |
| TGI | Hosted and Single Node | | :heavy_check_mark: | | | | TGI | Hosted and Single Node | | :heavy_check_mark: | | | |
| [NVIDIA NIM](https://build.nvidia.com/nim?filters=nimType%3Anim_type_run_anywhere&q=llama) | Hosted and Single Node | | :heavy_check_mark: | | | | [NVIDIA NIM](https://build.nvidia.com/nim?filters=nimType%3Anim_type_run_anywhere&q=llama) | Hosted and Single Node | | :heavy_check_mark: | | | |
| Chroma | Single Node | | | :heavy_check_mark: | | | | Chroma | Single Node | | | :heavy_check_mark: | | |
| PG Vector | Single Node | | | :heavy_check_mark: | | | | PG Vector | Single Node | | | :heavy_check_mark: | | |
| PyTorch ExecuTorch | On-device iOS | :heavy_check_mark: | :heavy_check_mark: | | | | PyTorch ExecuTorch | On-device iOS | :heavy_check_mark: | :heavy_check_mark: | | | |
| [vLLM](https://github.com/vllm-project/vllm) | | | :heavy_check_mark: | | | | [vLLM](https://github.com/vllm-project/vllm) | Hosted and Single Node | | :heavy_check_mark: | | | |
### Distributions ### Distributions
| **Distribution** | **Llama Stack Docker** | Start This Distribution | | **Distribution** | **Llama Stack Docker** | Start This Distribution |
|:----------------------------------------------------:|:---------------------------------------------------------------------------------------------------------------------------------------------------------------:|:--------------------------------------------------------------------------------------------------------------------------------:| |:---------------------------------------------:|:-------------------------------------------------------------------------------------------------------------------------------------------------------------:|:------------------------------------------------------------------------------------------------------------------------:|
| Meta Reference | [llamastack/distribution-meta-reference-gpu](https://hub.docker.com/repository/docker/llamastack/distribution-meta-reference-gpu/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/distributions/self_hosted_distro/meta-reference-gpu.html) | | Meta Reference | [llamastack/distribution-meta-reference-gpu](https://hub.docker.com/repository/docker/llamastack/distribution-meta-reference-gpu/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/distributions/self_hosted_distro/meta-reference-gpu.html) |
| Meta Reference Quantized | [llamastack/distribution-meta-reference-quantized-gpu](https://hub.docker.com/repository/docker/llamastack/distribution-meta-reference-quantized-gpu/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/distributions/self_hosted_distro/meta-reference-quantized-gpu.html) | | Meta Reference Quantized | [llamastack/distribution-meta-reference-quantized-gpu](https://hub.docker.com/repository/docker/llamastack/distribution-meta-reference-quantized-gpu/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/distributions/self_hosted_distro/meta-reference-quantized-gpu.html) |
| Cerebras | [llamastack/distribution-cerebras](https://hub.docker.com/repository/docker/llamastack/distribution-cerebras/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/getting_started/distributions/self_hosted_distro/cerebras.html) | | Cerebras | [llamastack/distribution-cerebras](https://hub.docker.com/repository/docker/llamastack/distribution-cerebras/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/getting_started/distributions/self_hosted_distro/cerebras.html) |
| Ollama | [llamastack/distribution-ollama](https://hub.docker.com/repository/docker/llamastack/distribution-ollama/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/distributions/self_hosted_distro/ollama.html) | | Ollama | [llamastack/distribution-ollama](https://hub.docker.com/repository/docker/llamastack/distribution-ollama/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/distributions/self_hosted_distro/ollama.html) |
| TGI | [llamastack/distribution-tgi](https://hub.docker.com/repository/docker/llamastack/distribution-tgi/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/distributions/self_hosted_distro/tgi.html) | | TGI | [llamastack/distribution-tgi](https://hub.docker.com/repository/docker/llamastack/distribution-tgi/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/distributions/self_hosted_distro/tgi.html) |
| Together | [llamastack/distribution-together](https://hub.docker.com/repository/docker/llamastack/distribution-together/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/distributions/self_hosted_distro/together.html) | | Together | [llamastack/distribution-together](https://hub.docker.com/repository/docker/llamastack/distribution-together/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/distributions/self_hosted_distro/together.html) |
| Fireworks | [llamastack/distribution-fireworks](https://hub.docker.com/repository/docker/llamastack/distribution-fireworks/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/distributions/self_hosted_distro/fireworks.html) | | Fireworks | [llamastack/distribution-fireworks](https://hub.docker.com/repository/docker/llamastack/distribution-fireworks/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/distributions/self_hosted_distro/fireworks.html) |
| [vLLM](https://github.com/vllm-project/vllm) | [llamastack/distribution-remote-vllm](https://hub.docker.com/repository/docker/llamastack/distribution-remote-vllm/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/distributions/self_hosted_distro/remote-vllm.html) | | [vLLM](https://github.com/vllm-project/vllm) | [llamastack/distribution-remote-vllm](https://hub.docker.com/repository/docker/llamastack/distribution-remote-vllm/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/distributions/self_hosted_distro/remote-vllm.html) |
## Installation ## Installation


@ -249,6 +249,7 @@
"redis", "redis",
"scikit-learn", "scikit-learn",
"scipy", "scipy",
"sentence-transformers",
"sentencepiece", "sentencepiece",
"torch", "torch",
"torchvision", "torchvision",
@ -287,6 +288,7 @@
"redis", "redis",
"scikit-learn", "scikit-learn",
"scipy", "scipy",
"sentence-transformers",
"sentencepiece", "sentencepiece",
"torch", "torch",
"torchao==0.5.0", "torchao==0.5.0",

File diff suppressed because one or more lines are too long

File diff suppressed because it is too large

File diff suppressed because one or more lines are too long


@ -10,3 +10,4 @@ sphinx-design
sphinxcontrib-openapi sphinxcontrib-openapi
sphinxcontrib-redoc sphinxcontrib-redoc
sphinxcontrib-mermaid sphinxcontrib-mermaid
sphinxcontrib-video


@ -0,0 +1,167 @@
# Benchmark Evaluations
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/10CHyykee9j2OigaIcRv47BKG9mrNm0tJ?usp=sharing)
Llama Stack provides the building blocks needed to run benchmark and application evaluations. This guide will walk you through how to use these components to run open benchmark evaluations. Visit our [Evaluation Concepts](../concepts/evaluation_concepts.md) guide for more details on how evaluations work in Llama Stack, and our [Evaluation Reference](../references/evals_reference/index.md) guide for a comprehensive reference on the APIs. Check out our [Colab notebook](https://colab.research.google.com/drive/10CHyykee9j2OigaIcRv47BKG9mrNm0tJ?usp=sharing) for working examples of how you can use Llama Stack to run benchmark evaluations.
### 1. Open Benchmark Model Evaluation
This first example walks you through how to evaluate a model candidate served by Llama Stack on open benchmarks. We will use the following benchmarks:
- [MMMU](https://arxiv.org/abs/2311.16502) (A Massive Multi-discipline Multimodal Understanding and Reasoning Benchmark for Expert AGI): Benchmark designed to evaluate multimodal models.
- [SimpleQA](https://openai.com/index/introducing-simpleqa/): Benchmark designed to assess a model's ability to answer short, fact-seeking questions.
#### 1.1 Running MMMU
- We will use a pre-processed MMMU dataset from [llamastack/mmmu](https://huggingface.co/datasets/llamastack/mmmu). The preprocessing code is shown in this [GitHub Gist](https://gist.github.com/yanxi0830/118e9c560227d27132a7fd10e2c92840). The dataset is obtained by transforming the original [MMMU/MMMU](https://huggingface.co/datasets/MMMU/MMMU) dataset into the format expected by the `inference/chat-completion` API.
```python
import datasets
ds = datasets.load_dataset(path="llamastack/mmmu", name="Agriculture", split="dev")
ds = ds.select_columns(["chat_completion_input", "input_query", "expected_answer"])
eval_rows = ds.to_pandas().to_dict(orient="records")
```
- Next, we will run an evaluation on a model candidate. To do so, we will need to:
- Define a system prompt
- Define an EvalCandidate
- Run the evaluation on the dataset
```python
SYSTEM_PROMPT_TEMPLATE = """
You are an expert in Agriculture whose job is to answer questions from the user using images.
First, reason about the correct answer.
Then write the answer in the following format where X is exactly one of A,B,C,D:
Answer: X
Make sure X is one of A,B,C,D.
If you are uncertain of the correct answer, guess the most likely one.
"""
system_message = {
"role": "system",
"content": SYSTEM_PROMPT_TEMPLATE,
}
# Assumes `client` is an initialized Llama Stack client and that a dataset named
# f"mmmu-{subset}-{split}" has already been registered (see the Colab notebook).
subset = "Agriculture"  # matches the dataset slice loaded above
split = "dev"

client.eval_tasks.register(
eval_task_id="meta-reference::mmmu",
dataset_id=f"mmmu-{subset}-{split}",
scoring_functions=["basic::regex_parser_multiple_choice_answer"]
)
response = client.eval.evaluate_rows(
task_id="meta-reference::mmmu",
input_rows=eval_rows,
scoring_functions=["basic::regex_parser_multiple_choice_answer"],
task_config={
"type": "benchmark",
"eval_candidate": {
"type": "model",
"model": "meta-llama/Llama-3.2-90B-Vision-Instruct",
"sampling_params": {
"temperature": 0.0,
"max_tokens": 4096,
"top_p": 0.9,
"repeat_penalty": 1.0,
},
"system_message": system_message
}
}
)
```
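The response bundles the model generations with per-scoring-function results. A minimal sketch for inspecting them, assuming the response object exposes `generations` and `scores` (names may differ across client versions):
```python
# Inspect the evaluation results returned above.
for scoring_fn, result in response.scores.items():
    print(f"=== {scoring_fn} ===")
    print(result.aggregated_results)  # aggregate metrics for this run
    print(result.score_rows[:2])  # per-row scores for the first two examples
```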
#### 1.2. Running SimpleQA
- We will use a pre-processed SimpleQA dataset from [llamastack/evals](https://huggingface.co/datasets/llamastack/evals/viewer/evals__simpleqa), which is obtained by transforming the input query into the format accepted by the `inference/chat-completion` API.
- Since we will be using this same dataset in our next example for Agentic evaluation, we will register it using the `/datasets` API and interact with it through the `/datasetio` API.
```python
simpleqa_dataset_id = "huggingface::simpleqa"
_ = client.datasets.register(
dataset_id=simpleqa_dataset_id,
provider_id="huggingface",
url={"uri": "https://huggingface.co/datasets/llamastack/evals"},
metadata={
"path": "llamastack/evals",
"name": "evals__simpleqa",
"split": "train",
},
dataset_schema={
"input_query": {"type": "string"},
"expected_answer": {"type": "string"},
"chat_completion_input": {"type": "chat_completion_input"},
}
)
eval_rows = client.datasetio.get_rows_paginated(
dataset_id=simpleqa_dataset_id,
rows_in_page=5,
)
```
```python
client.eval_tasks.register(
eval_task_id="meta-reference::simpleqa",
dataset_id=simpleqa_dataset_id,
scoring_functions=["llm-as-judge::405b-simpleqa"]
)
response = client.eval.evaluate_rows(
task_id="meta-reference::simpleqa",
input_rows=eval_rows.rows,
scoring_functions=["llm-as-judge::405b-simpleqa"],
task_config={
"type": "benchmark",
"eval_candidate": {
"type": "model",
"model": "meta-llama/Llama-3.2-90B-Vision-Instruct",
"sampling_params": {
"temperature": 0.0,
"max_tokens": 4096,
"top_p": 0.9,
"repeat_penalty": 1.0,
},
}
}
)
```
### 2. Agentic Evaluation
- In this example, we will demonstrate how to evaluate an agent candidate served by Llama Stack via the `/agents` API.
- We will continue to use the SimpleQA dataset we used in the previous example.
- Instead of running the evaluation on a model, we will run it on a search agent with access to a search tool. We will define our agent evaluation candidate through `AgentConfig`.
```python
agent_config = {
"model": "meta-llama/Llama-3.1-405B-Instruct",
"instructions": "You are a helpful assistant",
"sampling_params": {
"strategy": "greedy",
"temperature": 0.0,
"top_p": 0.95,
},
"tools": [
{
"type": "brave_search",
"engine": "tavily",
"api_key": userdata.get("TAVILY_SEARCH_API_KEY")
}
],
"tool_choice": "auto",
"tool_prompt_format": "json",
"input_shields": [],
"output_shields": [],
"enable_session_persistence": False
}
response = client.eval.evaluate_rows(
task_id="meta-reference::simpleqa",
input_rows=eval_rows.rows,
scoring_functions=["llm-as-judge::405b-simpleqa"],
task_config={
"type": "benchmark",
"eval_candidate": {
"type": "agent",
"config": agent_config,
}
}
)
```


@ -1,6 +1,8 @@
# Building AI Applications # Building AI Applications
Llama Stack provides all the building blocks needed to create sophisticated AI applications. This guide will walk you through how to use these components effectively. [![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/1F2ksmkoGQPa4pzRjMOE6BXWeOxWFIW6n?usp=sharing)
Llama Stack provides all the building blocks needed to create sophisticated AI applications. This guide will walk you through how to use these components effectively. Check out our Colab notebook to follow along with working examples of how you can build LLM-powered agentic applications using Llama Stack.
## Basic Inference ## Basic Inference


@ -0,0 +1,40 @@
# Evaluation Concepts
The Llama Stack Evaluation flow allows you to run evaluations on your GenAI application datasets or pre-registered benchmarks.
We introduce a set of APIs in Llama Stack for supporting running evaluations of LLM applications.
- `/datasetio` + `/datasets` API
- `/scoring` + `/scoring_functions` API
- `/eval` + `/eval_tasks` API
This guide goes over these sets of APIs and the developer experience of using Llama Stack to run evaluations for different use cases. Check out our Colab notebook with working examples of evaluations [here](https://colab.research.google.com/drive/10CHyykee9j2OigaIcRv47BKG9mrNm0tJ?usp=sharing).
## Evaluation Concepts
The Evaluation APIs are associated with a set of Resources as shown in the following diagram. Please visit the Resources section in our [Core Concepts](../concepts/index.md) guide for a high-level understanding of these resources.
![Eval Concepts](../references/evals_reference/resources/eval-concept.png)
- **DatasetIO**: defines the interface for datasets and data loaders.
- Associated with `Dataset` resource.
- **Scoring**: evaluate outputs of the system.
- Associated with the `ScoringFunction` resource. We provide a suite of out-of-the-box scoring functions and the ability to add custom evaluators. These scoring functions are the core of defining an evaluation task that outputs evaluation metrics.
- **Eval**: generate outputs (via Inference or Agents) and perform scoring.
- Associated with `EvalTask` resource.
Use the following decision tree to decide how to use the Llama Stack evaluation flow.
![Eval Flow](../references/evals_reference/resources/eval-flow.png)
```{admonition} Note on Benchmark vs. Application Evaluation
:class: tip
- **Benchmark Evaluation** is a well-defined eval task consisting of a `dataset` and a `scoring_function`. The generation step (inference or agent) is run as part of the evaluation.
- **Application Evaluation** assumes users already have app inputs and generated outputs. Evaluation then focuses purely on scoring the generated outputs via scoring functions (e.g. LLM-as-judge); see the sketch below.
```
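As a concrete example of the application-evaluation path, pre-generated outputs can be scored directly with the `/scoring` API. A minimal sketch, assuming an initialized Llama Stack client and the `basic::subset_of` scoring function shown in the Evaluation Reference:
```python
# Score pre-annotated application rows without running any generation step.
rows = [
    {
        "input_query": "What is the capital of France?",
        "generated_answer": "The capital of France is Paris.",
        "expected_answer": "Paris",
    },
]
response = client.scoring.score(
    input_rows=rows,
    scoring_functions={"basic::subset_of": None},  # None -> use the function's defaults
)
```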
## What's Next?
- Check out our Colab notebook with working examples of evaluations [here](https://colab.research.google.com/drive/10CHyykee9j2OigaIcRv47BKG9mrNm0tJ?usp=sharing).
- Check out our [Evaluation Reference](../references/evals_reference/index.md) for more details on the APIs.


@ -62,3 +62,13 @@ While there is a lot of flexibility to mix-and-match providers, often users will
**On-device Distro**: Finally, you may want to run Llama Stack directly on an edge device (mobile phone or a tablet.) We provide Distros for iOS and Android (coming soon.) **On-device Distro**: Finally, you may want to run Llama Stack directly on an edge device (mobile phone or a tablet.) We provide Distros for iOS and Android (coming soon.)
## More Concepts
- [Evaluation Concepts](evaluation_concepts.md)
```{toctree}
:maxdepth: 1
:hidden:
evaluation_concepts
```


@ -29,6 +29,7 @@ extensions = [
"sphinx_design", "sphinx_design",
"sphinxcontrib.redoc", "sphinxcontrib.redoc",
"sphinxcontrib.mermaid", "sphinxcontrib.mermaid",
"sphinxcontrib.video",
] ]
myst_enable_extensions = ["colon_fence"] myst_enable_extensions = ["colon_fence"]


@ -1,123 +0,0 @@
# Evaluations
The Llama Stack Evaluation flow allows you to run evaluations on your GenAI application datasets or pre-registered benchmarks.
We introduce a set of APIs in Llama Stack for supporting running evaluations of LLM applications.
- `/datasetio` + `/datasets` API
- `/scoring` + `/scoring_functions` API
- `/eval` + `/eval_tasks` API
This guide goes over the sets of APIs and developer experience flow of using Llama Stack to run evaluations for different use cases.
## Evaluation Concepts
The Evaluation APIs are associated with a set of Resources as shown in the following diagram. Please visit the Resources section in our [Core Concepts](../concepts/index.md) guide for better high-level understanding.
![Eval Concepts](./resources/eval-concept.png)
- **DatasetIO**: defines interface with datasets and data loaders.
- Associated with `Dataset` resource.
- **Scoring**: evaluate outputs of the system.
- Associated with `ScoringFunction` resource. We provide a suite of out-of-the box scoring functions and also the ability for you to add custom evaluators. These scoring functions are the core part of defining an evaluation task to output evaluation metrics.
- **Eval**: generate outputs (via Inference or Agents) and perform scoring.
- Associated with `EvalTask` resource.
## Running Evaluations
Use the following decision tree to decide how to use LlamaStack Evaluation flow.
![Eval Flow](./resources/eval-flow.png)
```{admonition} Note on Benchmark v.s. Application Evaluation
:class: tip
- **Benchmark Evaluation** is a well-defined eval-task consisting of `dataset` and `scoring_function`. The generation (inference or agent) will be done as part of evaluation.
- **Application Evaluation** assumes users already have app inputs & generated outputs. Evaluation will purely focus on scoring the generated outputs via scoring functions (e.g. LLM-as-judge).
```
The following examples give the quick steps to start running evaluations using the llama-stack-client CLI.
#### Benchmark Evaluation CLI
Usage: There are 2 inputs necessary for running a benchmark eval
- `eval-task-id`: the identifier associated with the eval task. Each `EvalTask` is parametrized by
- `dataset_id`: the identifier associated with the dataset.
- `List[scoring_function_id]`: list of scoring function identifiers.
- `eval-task-config`: specifies the configuration of the model / agent to evaluate on.
```
llama-stack-client eval run_benchmark <eval-task-id> \
--eval-task-config ~/eval_task_config.json \
--visualize
```
#### Application Evaluation CLI
Usage: For running application evals, you will already have available datasets in hand from your application. You will need to specify:
- `scoring-fn-id`: List of ScoringFunction identifiers you wish to use to run on your application.
- `Dataset` used for evaluation:
- (1) `--dataset-path`: path to local file system containing datasets to run evaluation on
- (2) `--dataset-id`: pre-registered dataset in Llama Stack
- (Optional) `--scoring-params-config`: optionally parameterize scoring functions with custom params (e.g. `judge_prompt`, `judge_model`, `parsing_regexes`).
```
llama-stack-client eval run_scoring <scoring_fn_id_1> <scoring_fn_id_2> ... <scoring_fn_id_n>
--dataset-path <path-to-local-dataset> \
--output-dir ./
```
#### Defining EvalTaskConfig
The `EvalTaskConfig` are user specified config to define:
1. `EvalCandidate` to run generation on:
- `ModelCandidate`: The model will be used for generation through LlamaStack /inference API.
- `AgentCandidate`: The agentic system specified by AgentConfig will be used for generation through LlamaStack /agents API.
2. Optionally scoring function params to allow customization of scoring function behaviour. This is useful to parameterize generic scoring functions such as LLMAsJudge with custom `judge_model` / `judge_prompt`.
**Example Benchmark EvalTaskConfig**
```json
{
"type": "benchmark",
"eval_candidate": {
"type": "model",
"model": "Llama3.2-3B-Instruct",
"sampling_params": {
"strategy": "greedy",
"temperature": 0,
"top_p": 0.95,
"top_k": 0,
"max_tokens": 0,
"repetition_penalty": 1.0
}
}
}
```
**Example Application EvalTaskConfig**
```json
{
"type": "app",
"eval_candidate": {
"type": "model",
"model": "Llama3.1-405B-Instruct",
"sampling_params": {
"strategy": "greedy",
"temperature": 0,
"top_p": 0.95,
"top_k": 0,
"max_tokens": 0,
"repetition_penalty": 1.0
}
},
"scoring_params": {
"llm-as-judge::llm_as_judge_base": {
"type": "llm_as_judge",
"judge_model": "meta-llama/Llama-3.1-8B-Instruct",
"prompt_template": "Your job is to look at a question, a gold target ........",
"judge_score_regexes": [
"(A|B|C)"
]
}
}
}
```


@ -1,9 +0,0 @@
# Cookbooks
- [Evaluations Flow](evals.md)
```{toctree}
:maxdepth: 2
:hidden:
evals.md
```


@ -8,12 +8,14 @@ Features:
- Remote Inferencing: Perform inferencing tasks remotely with Llama models hosted on a remote connection (or serverless localhost). - Remote Inferencing: Perform inferencing tasks remotely with Llama models hosted on a remote connection (or serverless localhost).
- Simple Integration: With easy-to-use APIs, a developer can quickly integrate Llama Stack in their Android app. The difference with local vs remote inferencing is also minimal. - Simple Integration: With easy-to-use APIs, a developer can quickly integrate Llama Stack in their Android app. The difference with local vs remote inferencing is also minimal.
Latest Release Notes: [v0.0.54.1](https://github.com/meta-llama/llama-stack-client-kotlin/releases/tag/v0.0.54.1) Latest Release Notes: [v0.0.58](https://github.com/meta-llama/llama-stack-client-kotlin/releases/tag/v0.0.58)
*Tagged releases are stable versions of the project. While we strive to maintain a stable main branch, it's not guaranteed to be free of bugs or issues.*
## Android Demo App ## Android Demo App
Check out our demo app to see how to integrate Llama Stack into your Android app: [Android Demo App](https://github.com/meta-llama/llama-stack-apps/tree/main/examples/android_app) Check out our demo app to see how to integrate Llama Stack into your Android app: [Android Demo App](https://github.com/meta-llama/llama-stack-apps/tree/android-kotlin-app-latest/examples/android_app)
The key files in the app are `LlamaStackLocalInference.kt`, `LlamaStackRemoteInference.kts`, and `MainActivity.java`. With encompassed business logic, the app shows how to use Llama Stack for both the environments. The key files in the app are `ExampleLlamaStackLocalInference.kt`, `ExampleLlamaStackRemoteInference.kts`, and `MainActivity.java`. With encompassed business logic, the app shows how to use Llama Stack for both the environments.
## Quick Start ## Quick Start
@ -22,7 +24,7 @@ The key files in the app are `LlamaStackLocalInference.kt`, `LlamaStackRemoteInf
Add the following dependency in your `build.gradle.kts` file: Add the following dependency in your `build.gradle.kts` file:
``` ```
dependencies { dependencies {
implementation("com.llama.llamastack:llama-stack-client-kotlin:0.0.54.1") implementation("com.llama.llamastack:llama-stack-client-kotlin:0.0.58")
} }
``` ```
This will download jar files in your gradle cache in a directory like `~/.gradle/caches/modules-2/files-2.1/com.llama.llamastack/` This will download jar files in your gradle cache in a directory like `~/.gradle/caches/modules-2/files-2.1/com.llama.llamastack/`
@ -34,10 +36,10 @@ If you plan on doing remote inferencing this is sufficient to get started.
For local inferencing, it is required to include the ExecuTorch library into your app. For local inferencing, it is required to include the ExecuTorch library into your app.
Include the ExecuTorch library by: Include the ExecuTorch library by:
1. Download the `download-prebuilt-et-lib.sh` script file from the [llama-stack-client-kotlin-client-local](https://github.com/meta-llama/llama-stack-client-kotlin/blob/release/0.0.54.1/llama-stack-client-kotlin-client-local/download-prebuilt-et-lib.sh) directory to your local machine. 1. Download the `download-prebuilt-et-lib.sh` script file from the [llama-stack-client-kotlin-client-local](https://github.com/meta-llama/llama-stack-client-kotlin/blob/release/0.0.58/llama-stack-client-kotlin-client-local/download-prebuilt-et-lib.sh) directory to your local machine.
2. Move the script to the top level of your Android app where the app directory resides: 2. Move the script to the top level of your Android app where the app directory resides:
<p align="center"> <p align="center">
<img src="https://raw.githubusercontent.com/meta-llama/llama-stack-client-kotlin/refs/heads/release/0.0.54.1/doc/img/example_android_app_directory.png" style="width:300px"> <img src="https://raw.githubusercontent.com/meta-llama/llama-stack-client-kotlin/refs/heads/release/0.0.58/doc/img/example_android_app_directory.png" style="width:300px">
</p> </p>
3. Run `sh download-prebuilt-et-lib.sh` to create an `app/libs` directory and download the `executorch.aar` in that path. This generates an ExecuTorch library for the XNNPACK delegate with commit: [0a12e33](https://github.com/pytorch/executorch/commit/0a12e33d22a3d44d1aa2af5f0d0673d45b962553). 3. Run `sh download-prebuilt-et-lib.sh` to create an `app/libs` directory and download the `executorch.aar` in that path. This generates an ExecuTorch library for the XNNPACK delegate with commit: [0a12e33](https://github.com/pytorch/executorch/commit/0a12e33d22a3d44d1aa2af5f0d0673d45b962553).
@ -58,12 +60,14 @@ Start a Llama Stack server on localhost. Here is an example of how you can do th
``` ```
conda create -n stack-fireworks python=3.10 conda create -n stack-fireworks python=3.10
conda activate stack-fireworks conda activate stack-fireworks
pip install llama-stack=0.0.54 pip install llama-stack==0.0.58
llama stack build --template fireworks --image-type conda llama stack build --template fireworks --image-type conda
export FIREWORKS_API_KEY=<SOME_KEY> export FIREWORKS_API_KEY=<SOME_KEY>
llama stack run /Users/<your_username>/.llama/distributions/llamastack-fireworks/fireworks-run.yaml --port=5050 llama stack run /Users/<your_username>/.llama/distributions/llamastack-fireworks/fireworks-run.yaml --port=5050
``` ```
Ensure the Llama Stack server version is the same as the Kotlin SDK Library for maximum compatibility.
Other inference providers: [Table](https://llama-stack.readthedocs.io/en/latest/index.html#supported-llama-stack-implementations) Other inference providers: [Table](https://llama-stack.readthedocs.io/en/latest/index.html#supported-llama-stack-implementations)
How to set remote localhost in Demo App: [Settings](https://github.com/meta-llama/llama-stack-apps/tree/main/examples/android_app#settings) How to set remote localhost in Demo App: [Settings](https://github.com/meta-llama/llama-stack-apps/tree/main/examples/android_app#settings)
@ -109,7 +113,6 @@ With the Kotlin Library managing all the major operational logic, there are mini
val result = client!!.inference().chatCompletion( val result = client!!.inference().chatCompletion(
InferenceChatCompletionParams.builder() InferenceChatCompletionParams.builder()
.modelId(modelName) .modelId(modelName)
.putAdditionalQueryParam("seq_len", sequenceLength.toString())
.messages(listOfMessages) .messages(listOfMessages)
.build() .build()
) )
@ -118,9 +121,23 @@ val result = client!!.inference().chatCompletion(
var response = result.asChatCompletionResponse().completionMessage().content().string(); var response = result.asChatCompletionResponse().completionMessage().content().string();
``` ```
### Setup Tool Calling [Remote only] For inference with a streaming response:
Android demo app for more details: [Tool Calling](https://github.com/meta-llama/llama-stack-apps/tree/main/examples/android_app#tool-calling) ```
val result = client!!.inference().chatCompletionStreaming(
InferenceChatCompletionParams.builder()
.modelId(modelName)
.messages(listOfMessages)
.build()
)
// Response can be received as an asChatCompletionResponseStreamChunk as part of a callback.
// See Android demo app for a detailed implementation example.
```
### Setup Custom Tool Calling
See the Android demo app for more details: [Custom Tool Calling](https://github.com/meta-llama/llama-stack-apps/tree/main/examples/android_app#tool-calling)
## Advanced Users ## Advanced Users
@ -129,7 +146,7 @@ The purpose of this section is to share more details with users that would like
### Prerequisite ### Prerequisite
You must complete the following steps: You must complete the following steps:
1. Clone the repo (`git clone https://github.com/meta-llama/llama-stack-client-kotlin.git -b release/0.0.54.1`) 1. Clone the repo (`git clone https://github.com/meta-llama/llama-stack-client-kotlin.git -b release/0.0.58`)
2. Port the appropriate ExecuTorch libraries over into your Llama Stack Kotlin library environment. 2. Port the appropriate ExecuTorch libraries over into your Llama Stack Kotlin library environment.
``` ```
cd llama-stack-client-kotlin-client-local cd llama-stack-client-kotlin-client-local


@ -102,7 +102,7 @@ Make sure you have done `pip install llama-stack` and have the Llama Stack CLI a
export LLAMA_STACK_PORT=5001 export LLAMA_STACK_PORT=5001
llama stack build --template ollama --image-type conda llama stack build --template ollama --image-type conda
llama stack run ./run.yaml \ llama stack run ./distributions/ollama/run.yaml \
--port $LLAMA_STACK_PORT \ --port $LLAMA_STACK_PORT \
--env INFERENCE_MODEL=$INFERENCE_MODEL \ --env INFERENCE_MODEL=$INFERENCE_MODEL \
--env OLLAMA_URL=http://localhost:11434 --env OLLAMA_URL=http://localhost:11434


@ -59,7 +59,8 @@ getting_started/index
concepts/index concepts/index
distributions/index distributions/index
building_applications/index building_applications/index
benchmark_evaluations/index
playground/index
contributing/index contributing/index
references/index references/index
cookbooks/index
``` ```


@ -0,0 +1,109 @@
# Llama Stack Playground
```{note}
The Llama Stack Playground is currently experimental and subject to change. We welcome feedback and contributions to help improve it.
```
The Llama Stack Playground is a simple interface that aims to:
- Showcase **capabilities** and **concepts** of Llama Stack in an interactive environment
- Demo **end-to-end** application code to help users get started building their own applications
- Provide a **UI** to help users inspect and understand Llama Stack API providers and resources
## Key Features
#### Playground
Interactive pages for users to play with and explore Llama Stack API capabilities.
##### Chatbot
```{eval-rst}
.. video:: https://github.com/user-attachments/assets/6ca617e8-32ca-49b2-9774-185020ff5204
:autoplay:
:playsinline:
:muted:
:loop:
:width: 100%
```
- **Chat**: Chat with Llama models.
- This page is a simple chatbot that allows you to chat with Llama models. Under the hood, it uses the `/inference/chat-completion` streaming API to send messages to the model and receive responses; a minimal client-side sketch of this flow follows this list.
- **RAG**: Upload documents to memory banks and chat with a RAG agent
- This page allows you to upload documents as a `memory_bank` and then chat with a RAG agent to query information about the uploaded documents.
- Under the hood, it uses Llama Stack's `/agents` API to define and create a RAG agent and chat with it in a session.
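For reference, the chat page boils down to a streaming loop like the sketch below, using the Python client directly instead of Streamlit. The model id and port are assumptions, and the shape of the streamed chunks may differ across client versions:
```python
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:5001")  # match your server's port

# Stream a chat completion and print tokens as they arrive.
response = client.inference.chat_completion(
    model_id="meta-llama/Llama-3.2-3B-Instruct",  # assumed model id
    messages=[{"role": "user", "content": "Tell me about Llama Stack."}],
    stream=True,
)
for chunk in response:
    print(chunk.event.delta, end="", flush=True)
```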
##### Evaluations
```{eval-rst}
.. video:: https://github.com/user-attachments/assets/6cc1659f-eba4-49ca-a0a5-7c243557b4f5
:autoplay:
:playsinline:
:muted:
:loop:
:width: 100%
```
- **Evaluations (Scoring)**: Run evaluations on your AI application datasets.
- This page demonstrates the flow of using the evaluation API to run evaluations on your custom AI application datasets. You may upload your own evaluation datasets and run evaluations using the available scoring functions.
- Under the hood, it uses Llama Stack's `/scoring` API to run evaluations on selected scoring functions.
```{eval-rst}
.. video:: https://github.com/user-attachments/assets/345845c7-2a2b-4095-960a-9ae40f6a93cf
:autoplay:
:playsinline:
:muted:
:loop:
:width: 100%
```
- **Evaluations (Generation + Scoring)**: Use pre-registered evaluation tasks to evaluate a model or agent candidate
- This page demonstrates the flow of using the evaluation API to evaluate a model or agent candidate on pre-defined evaluation tasks. An evaluation task is a combination of a dataset and scoring functions.
- Under the hood, it uses Llama Stack's `/eval` API to run generation and scoring with the specified evaluation configs.
- To use this page, you may first need to register evaluation tasks and datasets as resources with the following commands.
```bash
$ llama-stack-client datasets register \
--dataset-id "mmlu" \
--provider-id "huggingface" \
--url "https://huggingface.co/datasets/llamastack/evals" \
--metadata '{"path": "llamastack/evals", "name": "evals__mmlu__details", "split": "train"}' \
--schema '{"input_query": {"type": "string"}, "expected_answer": {"type": "string"}, "chat_completion_input": {"type": "string"}}'
```
```bash
$ llama-stack-client eval_tasks register \
--eval-task-id meta-reference-mmlu \
--provider-id meta-reference \
--dataset-id mmlu \
--scoring-functions basic::regex_parser_multiple_choice_answer
```
##### Inspect
```{eval-rst}
.. video:: https://github.com/user-attachments/assets/01d52b2d-92af-4e3a-b623-a9b8ba22ba99
:autoplay:
:playsinline:
:muted:
:loop:
:width: 100%
```
- **API Providers**: Inspect Llama Stack API providers
- This page allows you to inspect Llama Stack API providers and resources.
- Under the hood, it uses Llama Stack's `/providers` API to get information about the providers.
- **API Resources**: Inspect Llama Stack API resources
- This page allows you to inspect Llama Stack API resources (`models`, `datasets`, `memory_banks`, `eval_tasks`, `shields`).
- Under the hood, it uses Llama Stack's `/<resources>/list` API to get information about each resource.
- Please visit [Core Concepts](https://llama-stack.readthedocs.io/en/latest/concepts/index.html) for more details about the resources.
## Starting the Llama Stack Playground
To start the Llama Stack Playground, run the following commands:
1. Start up the Llama Stack API server
```bash
llama stack build --template together --image-type conda
llama stack run together
```
2. Start Streamlit UI
```bash
cd llama_stack/distribution/ui
pip install -r requirements.txt
streamlit run app.py
```


@ -0,0 +1,359 @@
# Evaluations
The Llama Stack Evaluation flow allows you to run evaluations on your GenAI application datasets or pre-registered benchmarks.
We introduce a set of APIs in Llama Stack for supporting running evaluations of LLM applications.
- `/datasetio` + `/datasets` API
- `/scoring` + `/scoring_functions` API
- `/eval` + `/eval_tasks` API
This guide goes over these sets of APIs and the developer experience of using Llama Stack to run evaluations for different use cases. Check out our Colab notebook with working examples of evaluations [here](https://colab.research.google.com/drive/10CHyykee9j2OigaIcRv47BKG9mrNm0tJ?usp=sharing).
## Evaluation Concepts
The Evaluation APIs are associated with a set of Resources as shown in the following diagram. Please visit the Resources section in our [Core Concepts](../concepts/index.md) guide for a high-level understanding of these resources.
![Eval Concepts](./resources/eval-concept.png)
- **DatasetIO**: defines the interface for datasets and data loaders.
- Associated with `Dataset` resource.
- **Scoring**: evaluate outputs of the system.
- Associated with the `ScoringFunction` resource. We provide a suite of out-of-the-box scoring functions and the ability to add custom evaluators. These scoring functions are the core of defining an evaluation task that outputs evaluation metrics.
- **Eval**: generate outputs (via Inference or Agents) and perform scoring.
- Associated with `EvalTask` resource.
Use the following decision tree to decide how to use the Llama Stack evaluation flow.
![Eval Flow](./resources/eval-flow.png)
```{admonition} Note on Benchmark vs. Application Evaluation
:class: tip
- **Benchmark Evaluation** is a well-defined eval task consisting of a `dataset` and a `scoring_function`. The generation step (inference or agent) is run as part of the evaluation.
- **Application Evaluation** assumes users already have app inputs and generated outputs. Evaluation then focuses purely on scoring the generated outputs via scoring functions (e.g. LLM-as-judge).
```
## Evaluation Examples Walkthrough
[![Open In Colab](https://colab.research.google.com/assets/colab-badge.svg)](https://colab.research.google.com/drive/10CHyykee9j2OigaIcRv47BKG9mrNm0tJ?usp=sharing)
It is best to open this notebook in Colab to follow along with the examples.
### 1. Open Benchmark Model Evaluation
This first example walks you through how to evaluate a model candidate served by Llama Stack on open benchmarks. We will use the following benchmarks:
- [MMMU](https://arxiv.org/abs/2311.16502) (A Massive Multi-discipline Multimodal Understanding and Reasoning Benchmark for Expert AGI): Benchmark designed to evaluate multimodal models.
- [SimpleQA](https://openai.com/index/introducing-simpleqa/): Benchmark designed to assess a model's ability to answer short, fact-seeking questions.
#### 1.1 Running MMMU
- We will use a pre-processed MMMU dataset from [llamastack/mmmu](https://huggingface.co/datasets/llamastack/mmmu). The preprocessing code is shown in this [GitHub Gist](https://gist.github.com/yanxi0830/118e9c560227d27132a7fd10e2c92840). The dataset is obtained by transforming the original [MMMU/MMMU](https://huggingface.co/datasets/MMMU/MMMU) dataset into the format expected by the `inference/chat-completion` API.
```python
import datasets
ds = datasets.load_dataset(path="llamastack/mmmu", name="Agriculture", split="dev")
ds = ds.select_columns(["chat_completion_input", "input_query", "expected_answer"])
eval_rows = ds.to_pandas().to_dict(orient="records")
```
- Next, we will run an evaluation on a model candidate. To do so, we will need to:
- Define a system prompt
- Define an EvalCandidate
- Run the evaluation on the dataset
```python
SYSTEM_PROMPT_TEMPLATE = """
You are an expert in Agriculture whose job is to answer questions from the user using images.
First, reason about the correct answer.
Then write the answer in the following format where X is exactly one of A,B,C,D:
Answer: X
Make sure X is one of A,B,C,D.
If you are uncertain of the correct answer, guess the most likely one.
"""
system_message = {
"role": "system",
"content": SYSTEM_PROMPT_TEMPLATE,
}
# Assumes `client` is an initialized Llama Stack client and that a dataset named
# f"mmmu-{subset}-{split}" has already been registered (see the Colab notebook).
subset = "Agriculture"  # matches the dataset slice loaded above
split = "dev"

client.eval_tasks.register(
eval_task_id="meta-reference::mmmu",
dataset_id=f"mmmu-{subset}-{split}",
scoring_functions=["basic::regex_parser_multiple_choice_answer"]
)
response = client.eval.evaluate_rows(
task_id="meta-reference::mmmu",
input_rows=eval_rows,
scoring_functions=["basic::regex_parser_multiple_choice_answer"],
task_config={
"type": "benchmark",
"eval_candidate": {
"type": "model",
"model": "meta-llama/Llama-3.2-90B-Vision-Instruct",
"sampling_params": {
"temperature": 0.0,
"max_tokens": 4096,
"top_p": 0.9,
"repeat_penalty": 1.0,
},
"system_message": system_message
}
}
)
```
#### 1.2. Running SimpleQA
- We will use a pre-processed SimpleQA dataset from [llamastack/evals](https://huggingface.co/datasets/llamastack/evals/viewer/evals__simpleqa), which is obtained by transforming the input query into the format accepted by the `inference/chat-completion` API.
- Since we will be using this same dataset in our next example for Agentic evaluation, we will register it using the `/datasets` API and interact with it through the `/datasetio` API.
```python
simpleqa_dataset_id = "huggingface::simpleqa"
_ = client.datasets.register(
dataset_id=simpleqa_dataset_id,
provider_id="huggingface",
url={"uri": "https://huggingface.co/datasets/llamastack/evals"},
metadata={
"path": "llamastack/evals",
"name": "evals__simpleqa",
"split": "train",
},
dataset_schema={
"input_query": {"type": "string"},
"expected_answer": {"type": "string"},
"chat_completion_input": {"type": "chat_completion_input"},
}
)
eval_rows = client.datasetio.get_rows_paginated(
dataset_id=simpleqa_dataset_id,
rows_in_page=5,
)
```
```python
client.eval_tasks.register(
eval_task_id="meta-reference::simpleqa",
dataset_id=simpleqa_dataset_id,
scoring_functions=["llm-as-judge::405b-simpleqa"]
)
response = client.eval.evaluate_rows(
task_id="meta-reference::simpleqa",
input_rows=eval_rows.rows,
scoring_functions=["llm-as-judge::405b-simpleqa"],
task_config={
"type": "benchmark",
"eval_candidate": {
"type": "model",
"model": "meta-llama/Llama-3.2-90B-Vision-Instruct",
"sampling_params": {
"temperature": 0.0,
"max_tokens": 4096,
"top_p": 0.9,
"repeat_penalty": 1.0,
},
}
}
)
```
### 2. Agentic Evaluation
- In this example, we will demonstrate how to evaluate an agent candidate served by Llama Stack via the `/agents` API.
- We will continue to use the SimpleQA dataset we used in the previous example.
- Instead of running the evaluation on a model, we will run it on a search agent with access to a search tool. We will define our agent evaluation candidate through `AgentConfig`.
```python
agent_config = {
"model": "meta-llama/Llama-3.1-405B-Instruct",
"instructions": "You are a helpful assistant",
"sampling_params": {
"strategy": "greedy",
"temperature": 0.0,
"top_p": 0.95,
},
"tools": [
{
"type": "brave_search",
"engine": "tavily",
"api_key": userdata.get("TAVILY_SEARCH_API_KEY")
}
],
"tool_choice": "auto",
"tool_prompt_format": "json",
"input_shields": [],
"output_shields": [],
"enable_session_persistence": False
}
response = client.eval.evaluate_rows(
task_id="meta-reference::simpleqa",
input_rows=eval_rows.rows,
scoring_functions=["llm-as-judge::405b-simpleqa"],
task_config={
"type": "benchmark",
"eval_candidate": {
"type": "agent",
"config": agent_config,
}
}
)
```
### 3. Agentic Application Dataset Scoring
- Llama Stack offers a library of scoring functions and the `/scoring` API, allowing you to run evaluations on your pre-annotated AI application datasets.
- In this example, we will work with an example RAG dataset and a couple of scoring functions for evaluation.
- `llm-as-judge::base`: LLM-As-Judge with custom judge prompt & model.
- `braintrust::factuality`: Factuality scorer from [braintrust](https://github.com/braintrustdata/autoevals).
- `basic::subset_of`: Basic check of whether the generated answer is a subset of the expected answer.
- Please check out our [Llama Stack Playground](https://llama-stack.readthedocs.io/en/latest/playground/index.html) for an interactive interface to upload datasets and run scoring.
```python
judge_model_id = "meta-llama/Llama-3.1-405B-Instruct-FP8"
JUDGE_PROMPT = """
Given a QUESTION and GENERATED_RESPONSE and EXPECTED_RESPONSE.
Compare the factual content of the GENERATED_RESPONSE with the EXPECTED_RESPONSE. Ignore any differences in style, grammar, or punctuation.
The GENERATED_RESPONSE may either be a subset or superset of the EXPECTED_RESPONSE, or it may conflict with it. Determine which case applies. Answer the question by selecting one of the following options:
(A) The GENERATED_RESPONSE is a subset of the EXPECTED_RESPONSE and is fully consistent with it.
(B) The GENERATED_RESPONSE is a superset of the EXPECTED_RESPONSE and is fully consistent with it.
(C) The GENERATED_RESPONSE contains all the same details as the EXPECTED_RESPONSE.
(D) There is a disagreement between the GENERATED_RESPONSE and the EXPECTED_RESPONSE.
(E) The answers differ, but these differences don't matter from the perspective of factuality.
Give your answer in the format "Answer: One of ABCDE, Explanation: ".
Your actual task:
QUESTION: {input_query}
GENERATED_RESPONSE: {generated_answer}
EXPECTED_RESPONSE: {expected_answer}
"""
input_query = "What are the top 5 topics that were explained? Only list succinct bullet points."
generated_answer = """
Here are the top 5 topics that were explained in the documentation for Torchtune:
* What is LoRA and how does it work?
* Fine-tuning with LoRA: memory savings and parameter-efficient finetuning
* Running a LoRA finetune with Torchtune: overview and recipe
* Experimenting with different LoRA configurations: rank, alpha, and attention modules
* LoRA finetuning
"""
expected_answer = """LoRA"""
dataset_rows = [
{
"input_query": input_query,
"generated_answer": generated_answer,
"expected_answer": expected_answer,
},
]
scoring_params = {
"llm-as-judge::base": {
"judge_model": judge_model_id,
"prompt_template": JUDGE_PROMPT,
"type": "llm_as_judge",
"judge_score_regexes": ["Answer: (A|B|C|D|E)"],
},
"basic::subset_of": None,
"braintrust::factuality": None,
}
response = client.scoring.score(input_rows=dataset_rows, scoring_functions=scoring_params)
```
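The scoring response groups results per scoring function. A minimal sketch for inspecting them, assuming the response exposes a `results` mapping with `score_rows` and `aggregated_results` (names may differ across client versions):
```python
# Print per-row scores and aggregated metrics for each scoring function.
for scoring_fn, result in response.results.items():
    print(f"=== {scoring_fn} ===")
    for row in result.score_rows:
        print(row)  # each row is a dict of score fields for one input row
    print("aggregated:", result.aggregated_results)
```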
## Running Evaluations via CLI
The following examples give the quick steps to start running evaluations using the llama-stack-client CLI.
#### Benchmark Evaluation CLI
Usage: There are 2 inputs necessary for running a benchmark eval
- `eval-task-id`: the identifier associated with the eval task. Each `EvalTask` is parametrized by
- `dataset_id`: the identifier associated with the dataset.
- `List[scoring_function_id]`: list of scoring function identifiers.
- `eval-task-config`: specifies the configuration of the model / agent to evaluate on.
```
llama-stack-client eval run_benchmark <eval-task-id> \
--eval-task-config ~/eval_task_config.json \
--visualize
```
#### Application Evaluation CLI
Usage: For running application evals, you will already have available datasets in hand from your application. You will need to specify:
- `scoring-fn-id`: List of ScoringFunction identifiers you wish to use to run on your application.
- `Dataset` used for evaluation:
- (1) `--dataset-path`: path to local file system containing datasets to run evaluation on
- (2) `--dataset-id`: pre-registered dataset in Llama Stack
- (Optional) `--scoring-params-config`: optionally parameterize scoring functions with custom params (e.g. `judge_prompt`, `judge_model`, `parsing_regexes`).
```
llama-stack-client eval run_scoring <scoring_fn_id_1> <scoring_fn_id_2> ... <scoring_fn_id_n> \
--dataset-path <path-to-local-dataset> \
--output-dir ./
```
#### Defining EvalTaskConfig
The `EvalTaskConfig` is a user-specified config that defines:
1. `EvalCandidate` to run generation on:
- `ModelCandidate`: The model to be used for generation through the LlamaStack /inference API.
- `AgentCandidate`: The agentic system specified by `AgentConfig`, used for generation through the LlamaStack /agents API.
2. Optional scoring function params to customize scoring function behaviour. This is useful for parameterizing generic scoring functions such as LLMAsJudge with a custom `judge_model` / `judge_prompt`.
**Example Benchmark EvalTaskConfig**
```json
{
"type": "benchmark",
"eval_candidate": {
"type": "model",
"model": "Llama3.2-3B-Instruct",
"sampling_params": {
"strategy": "greedy",
"temperature": 0,
"top_p": 0.95,
"top_k": 0,
"max_tokens": 0,
"repetition_penalty": 1.0
}
}
}
```
**Example Application EvalTaskConfig**
```json
{
"type": "app",
"eval_candidate": {
"type": "model",
"model": "Llama3.1-405B-Instruct",
"sampling_params": {
"strategy": "greedy",
"temperature": 0,
"top_p": 0.95,
"top_k": 0,
"max_tokens": 0,
"repetition_penalty": 1.0
}
},
"scoring_params": {
"llm-as-judge::llm_as_judge_base": {
"type": "llm_as_judge",
"judge_model": "meta-llama/Llama-3.1-8B-Instruct",
"prompt_template": "Your job is to look at a question, a gold target ........",
"judge_score_regexes": [
"(A|B|C)"
]
}
}
}
```


Binary image file changed (68 KiB before and after).


Binary image file changed (249 KiB before and after).


@ -14,4 +14,5 @@ python_sdk_reference/index
llama_cli_reference/index llama_cli_reference/index
llama_stack_client_cli_reference llama_stack_client_cli_reference
llama_cli_reference/download_models llama_cli_reference/download_models
evals_reference/index
``` ```


@ -3,5 +3,8 @@
# #
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
#
# from .distribution.library_client import LlamaStackAsLibraryClient, AsyncLlamaStackAsLibraryClient from llama_stack.distribution.library_client import ( # noqa: F401
AsyncLlamaStackAsLibraryClient,
LlamaStackAsLibraryClient,
)
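This change re-exports the library-mode client from the package root. A minimal usage sketch, assuming a locally built `ollama` template (the template name is an assumption):
```python
from llama_stack import LlamaStackAsLibraryClient

# Run the distribution in-process instead of talking to a separate HTTP server.
client = LlamaStackAsLibraryClient("ollama")  # name of a previously built template/config
client.initialize()

print([m.identifier for m in client.models.list()])
```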


@ -18,3 +18,5 @@ class Job(BaseModel):
class JobStatus(Enum): class JobStatus(Enum):
completed = "completed" completed = "completed"
in_progress = "in_progress" in_progress = "in_progress"
failed = "failed"
scheduled = "scheduled"


@ -4,13 +4,26 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from llama_models.llama3.api.datatypes import URL from datetime import datetime
from typing import Optional
from llama_models.schema_utils import json_schema_type from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel from pydantic import BaseModel
@json_schema_type
class PostTrainingMetric(BaseModel):
epoch: int
train_loss: float
validation_loss: float
perplexity: float
@json_schema_type(schema={"description": "Checkpoint created during training runs"}) @json_schema_type(schema={"description": "Checkpoint created during training runs"})
class Checkpoint(BaseModel): class Checkpoint(BaseModel):
iters: int identifier: str
path: URL created_at: datetime
epoch: int epoch: int
post_training_job_id: str
path: str
training_metrics: Optional[PostTrainingMetric] = None


@ -89,6 +89,7 @@ class VectorMemoryBank(MemoryBankResourceMixin):
memory_bank_type: Literal[MemoryBankType.vector.value] = MemoryBankType.vector.value memory_bank_type: Literal[MemoryBankType.vector.value] = MemoryBankType.vector.value
embedding_model: str embedding_model: str
chunk_size_in_tokens: int chunk_size_in_tokens: int
embedding_dimension: Optional[int] = 384 # default to minilm-l6-v2
overlap_size_in_tokens: Optional[int] = None overlap_size_in_tokens: Optional[int] = None


@ -4,6 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from enum import Enum
from typing import Any, Dict, List, Literal, Optional, Protocol, runtime_checkable from typing import Any, Dict, List, Literal, Optional, Protocol, runtime_checkable
from llama_models.schema_utils import json_schema_type, webmethod from llama_models.schema_utils import json_schema_type, webmethod
@ -20,6 +21,12 @@ class CommonModelFields(BaseModel):
) )
@json_schema_type
class ModelType(str, Enum):
llm = "llm"
embedding = "embedding"
@json_schema_type @json_schema_type
class Model(CommonModelFields, Resource): class Model(CommonModelFields, Resource):
type: Literal[ResourceType.model.value] = ResourceType.model.value type: Literal[ResourceType.model.value] = ResourceType.model.value
@ -34,12 +41,14 @@ class Model(CommonModelFields, Resource):
model_config = ConfigDict(protected_namespaces=()) model_config = ConfigDict(protected_namespaces=())
model_type: ModelType = Field(default=ModelType.llm)
class ModelInput(CommonModelFields): class ModelInput(CommonModelFields):
model_id: str model_id: str
provider_id: Optional[str] = None provider_id: Optional[str] = None
provider_model_id: Optional[str] = None provider_model_id: Optional[str] = None
model_type: Optional[ModelType] = ModelType.llm
model_config = ConfigDict(protected_namespaces=()) model_config = ConfigDict(protected_namespaces=())
@ -59,6 +68,7 @@ class Models(Protocol):
provider_model_id: Optional[str] = None, provider_model_id: Optional[str] = None,
provider_id: Optional[str] = None, provider_id: Optional[str] = None,
metadata: Optional[Dict[str, Any]] = None, metadata: Optional[Dict[str, Any]] = None,
model_type: Optional[ModelType] = None,
) -> Model: ... ) -> Model: ...
@webmethod(route="/models/unregister", method="POST") @webmethod(route="/models/unregister", method="POST")


@ -7,68 +7,85 @@
from datetime import datetime from datetime import datetime
from enum import Enum from enum import Enum
from typing import Any, Dict, List, Optional, Protocol from typing import Any, Dict, List, Optional, Protocol, Union
from llama_models.schema_utils import json_schema_type, webmethod from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel, Field from pydantic import BaseModel, Field
from typing_extensions import Annotated
from llama_models.llama3.api.datatypes import * # noqa: F403 from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.apis.common.job_types import JobStatus
from llama_stack.apis.datasets import * # noqa: F403 from llama_stack.apis.datasets import * # noqa: F403
from llama_stack.apis.common.training_types import * # noqa: F403 from llama_stack.apis.common.training_types import * # noqa: F403
@json_schema_type
class OptimizerType(Enum): class OptimizerType(Enum):
adam = "adam" adam = "adam"
adamw = "adamw" adamw = "adamw"
sgd = "sgd" sgd = "sgd"
@json_schema_type
class DataConfig(BaseModel):
dataset_id: str
batch_size: int
shuffle: bool
validation_dataset_id: Optional[str] = None
packed: Optional[bool] = False
train_on_input: Optional[bool] = False
@json_schema_type @json_schema_type
class OptimizerConfig(BaseModel): class OptimizerConfig(BaseModel):
optimizer_type: OptimizerType optimizer_type: OptimizerType
lr: float lr: float
lr_min: float
weight_decay: float weight_decay: float
num_warmup_steps: int
@json_schema_type
class EfficiencyConfig(BaseModel):
enable_activation_checkpointing: Optional[bool] = False
enable_activation_offloading: Optional[bool] = False
memory_efficient_fsdp_wrap: Optional[bool] = False
fsdp_cpu_offload: Optional[bool] = False
@json_schema_type @json_schema_type
class TrainingConfig(BaseModel): class TrainingConfig(BaseModel):
n_epochs: int n_epochs: int
batch_size: int max_steps_per_epoch: int
shuffle: bool gradient_accumulation_steps: int
n_iters: int data_config: DataConfig
optimizer_config: OptimizerConfig
enable_activation_checkpointing: bool efficiency_config: Optional[EfficiencyConfig] = None
memory_efficient_fsdp_wrap: bool dtype: Optional[str] = "bf16"
fsdp_cpu_offload: bool
@json_schema_type
class FinetuningAlgorithm(Enum):
full = "full"
lora = "lora"
qlora = "qlora"
dora = "dora"
@json_schema_type @json_schema_type
class LoraFinetuningConfig(BaseModel): class LoraFinetuningConfig(BaseModel):
type: Literal["LoRA"] = "LoRA"
lora_attn_modules: List[str] lora_attn_modules: List[str]
apply_lora_to_mlp: bool apply_lora_to_mlp: bool
apply_lora_to_output: bool apply_lora_to_output: bool
rank: int rank: int
alpha: int alpha: int
use_dora: Optional[bool] = False
quantize_base: Optional[bool] = False
@json_schema_type @json_schema_type
class QLoraFinetuningConfig(LoraFinetuningConfig): class QATFinetuningConfig(BaseModel):
pass type: Literal["QAT"] = "QAT"
quantizer_name: str
group_size: int
@json_schema_type AlgorithmConfig = Annotated[
class DoraFinetuningConfig(LoraFinetuningConfig): Union[LoraFinetuningConfig, QATFinetuningConfig], Field(discriminator="type")
pass ]
@json_schema_type @json_schema_type
@ -79,14 +96,6 @@ class PostTrainingJobLogStream(BaseModel):
log_lines: List[str] log_lines: List[str]
@json_schema_type
class PostTrainingJobStatus(Enum):
running = "running"
completed = "completed"
failed = "failed"
scheduled = "scheduled"
@json_schema_type @json_schema_type
class RLHFAlgorithm(Enum): class RLHFAlgorithm(Enum):
dpo = "dpo" dpo = "dpo"
@ -100,29 +109,6 @@ class DPOAlignmentConfig(BaseModel):
gamma: float gamma: float
@json_schema_type
class PostTrainingSFTRequest(BaseModel):
"""Request to finetune a model."""
job_uuid: str
model: str
dataset_id: str
validation_dataset_id: str
algorithm: FinetuningAlgorithm
algorithm_config: Union[
LoraFinetuningConfig, QLoraFinetuningConfig, DoraFinetuningConfig
]
optimizer_config: OptimizerConfig
training_config: TrainingConfig
# TODO: define these
hyperparam_search_config: Dict[str, Any]
logger_config: Dict[str, Any]
@json_schema_type @json_schema_type
class PostTrainingRLHFRequest(BaseModel): class PostTrainingRLHFRequest(BaseModel):
"""Request to finetune a model.""" """Request to finetune a model."""
@ -135,7 +121,7 @@ class PostTrainingRLHFRequest(BaseModel):
validation_dataset_id: str validation_dataset_id: str
algorithm: RLHFAlgorithm algorithm: RLHFAlgorithm
algorithm_config: Union[DPOAlignmentConfig] algorithm_config: DPOAlignmentConfig
optimizer_config: OptimizerConfig optimizer_config: OptimizerConfig
training_config: TrainingConfig training_config: TrainingConfig
@ -154,7 +140,7 @@ class PostTrainingJobStatusResponse(BaseModel):
"""Status of a finetuning job.""" """Status of a finetuning job."""
job_uuid: str job_uuid: str
status: PostTrainingJobStatus status: JobStatus
scheduled_at: Optional[datetime] = None scheduled_at: Optional[datetime] = None
started_at: Optional[datetime] = None started_at: Optional[datetime] = None
@ -176,54 +162,44 @@ class PostTrainingJobArtifactsResponse(BaseModel):
class PostTraining(Protocol): class PostTraining(Protocol):
@webmethod(route="/post-training/supervised-fine-tune") @webmethod(route="/post-training/supervised-fine-tune", method="POST")
def supervised_fine_tune( async def supervised_fine_tune(
self, self,
job_uuid: str, job_uuid: str,
model: str, training_config: TrainingConfig,
dataset_id: str, hyperparam_search_config: Dict[str, Any],
validation_dataset_id: str, logger_config: Dict[str, Any],
algorithm: FinetuningAlgorithm, model: str = Field(
algorithm_config: Union[ default="Llama3.2-3B-Instruct",
LoraFinetuningConfig, QLoraFinetuningConfig, DoraFinetuningConfig description="Model descriptor from `llama model list`",
], ),
optimizer_config: OptimizerConfig, checkpoint_dir: Optional[str] = None,
algorithm_config: Optional[AlgorithmConfig] = None,
) -> PostTrainingJob: ...
@webmethod(route="/post-training/preference-optimize", method="POST")
async def preference_optimize(
self,
job_uuid: str,
finetuned_model: str,
algorithm_config: DPOAlignmentConfig,
training_config: TrainingConfig, training_config: TrainingConfig,
hyperparam_search_config: Dict[str, Any], hyperparam_search_config: Dict[str, Any],
logger_config: Dict[str, Any], logger_config: Dict[str, Any],
) -> PostTrainingJob: ... ) -> PostTrainingJob: ...
@webmethod(route="/post-training/preference-optimize") @webmethod(route="/post-training/jobs", method="GET")
def preference_optimize( async def get_training_jobs(self) -> List[PostTrainingJob]: ...
self,
job_uuid: str,
finetuned_model: URL,
dataset_id: str,
validation_dataset_id: str,
algorithm: RLHFAlgorithm,
algorithm_config: Union[DPOAlignmentConfig],
optimizer_config: OptimizerConfig,
training_config: TrainingConfig,
hyperparam_search_config: Dict[str, Any],
logger_config: Dict[str, Any],
) -> PostTrainingJob: ...
@webmethod(route="/post-training/jobs") @webmethod(route="/post-training/job/status", method="GET")
def get_training_jobs(self) -> List[PostTrainingJob]: ... async def get_training_job_status(
# sends SSE stream of logs
@webmethod(route="/post-training/job/logs")
def get_training_job_logstream(self, job_uuid: str) -> PostTrainingJobLogStream: ...
@webmethod(route="/post-training/job/status")
def get_training_job_status(
self, job_uuid: str self, job_uuid: str
) -> PostTrainingJobStatusResponse: ... ) -> Optional[PostTrainingJobStatusResponse]: ...
@webmethod(route="/post-training/job/cancel") @webmethod(route="/post-training/job/cancel", method="POST")
def cancel_training_job(self, job_uuid: str) -> None: ... async def cancel_training_job(self, job_uuid: str) -> None: ...
@webmethod(route="/post-training/job/artifacts") @webmethod(route="/post-training/job/artifacts", method="GET")
def get_training_job_artifacts( async def get_training_job_artifacts(
self, job_uuid: str self, job_uuid: str
) -> PostTrainingJobArtifactsResponse: ... ) -> Optional[PostTrainingJobArtifactsResponse]: ...
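
Putting the reworked schema together, a minimal sketch of a supervised_fine_tune call against some implementation of this protocol; the dataset id, LoRA module names, and hyperparameter values are illustrative assumptions, and the import path assumes the usual re-export from llama_stack.apis.post_training.

from llama_stack.apis.post_training import (
    DataConfig,
    LoraFinetuningConfig,
    OptimizerConfig,
    OptimizerType,
    TrainingConfig,
)

lora_config = LoraFinetuningConfig(
    lora_attn_modules=["q_proj", "v_proj", "output_proj"],  # illustrative module names
    apply_lora_to_mlp=True,
    apply_lora_to_output=False,
    rank=8,
    alpha=16,
)
training_config = TrainingConfig(
    n_epochs=1,
    max_steps_per_epoch=100,
    gradient_accumulation_steps=1,
    data_config=DataConfig(dataset_id="alpaca", batch_size=4, shuffle=True),
    optimizer_config=OptimizerConfig(
        optimizer_type=OptimizerType.adamw,
        lr=3e-4,
        lr_min=3e-5,
        weight_decay=0.1,
        num_warmup_steps=100,
    ),
)

async def run(post_training) -> None:  # post_training: any PostTraining implementation
    job = await post_training.supervised_fine_tune(
        job_uuid="sft-job-0001",
        training_config=training_config,
        hyperparam_search_config={},
        logger_config={},
        model="Llama3.2-3B-Instruct",
        checkpoint_dir=None,
        algorithm_config=lora_config,
    )
    print(job.job_uuid)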


@ -24,6 +24,7 @@ from llama_stack.apis.inspect import Inspect
from llama_stack.apis.memory import Memory from llama_stack.apis.memory import Memory
from llama_stack.apis.memory_banks import MemoryBanks from llama_stack.apis.memory_banks import MemoryBanks
from llama_stack.apis.models import Models from llama_stack.apis.models import Models
from llama_stack.apis.post_training import PostTraining
from llama_stack.apis.safety import Safety from llama_stack.apis.safety import Safety
from llama_stack.apis.scoring import Scoring from llama_stack.apis.scoring import Scoring
from llama_stack.apis.scoring_functions import ScoringFunctions from llama_stack.apis.scoring_functions import ScoringFunctions
@ -58,6 +59,7 @@ def api_protocol_map() -> Dict[Api, Any]:
Api.scoring_functions: ScoringFunctions, Api.scoring_functions: ScoringFunctions,
Api.eval: Eval, Api.eval: Eval,
Api.eval_tasks: EvalTasks, Api.eval_tasks: EvalTasks,
Api.post_training: PostTraining,
} }


@ -88,9 +88,10 @@ class InferenceRouter(Inference):
provider_model_id: Optional[str] = None, provider_model_id: Optional[str] = None,
provider_id: Optional[str] = None, provider_id: Optional[str] = None,
metadata: Optional[Dict[str, Any]] = None, metadata: Optional[Dict[str, Any]] = None,
model_type: Optional[ModelType] = None,
) -> None: ) -> None:
await self.routing_table.register_model( await self.routing_table.register_model(
model_id, provider_model_id, provider_id, metadata model_id, provider_model_id, provider_id, metadata, model_type
) )
async def chat_completion( async def chat_completion(
@ -105,6 +106,13 @@ class InferenceRouter(Inference):
stream: Optional[bool] = False, stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None, logprobs: Optional[LogProbConfig] = None,
) -> AsyncGenerator: ) -> AsyncGenerator:
model = await self.routing_table.get_model(model_id)
if model is None:
raise ValueError(f"Model '{model_id}' not found")
if model.model_type == ModelType.embedding:
raise ValueError(
f"Model '{model_id}' is an embedding model and does not support chat completions"
)
params = dict( params = dict(
model_id=model_id, model_id=model_id,
messages=messages, messages=messages,
@ -131,6 +139,13 @@ class InferenceRouter(Inference):
stream: Optional[bool] = False, stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None, logprobs: Optional[LogProbConfig] = None,
) -> AsyncGenerator: ) -> AsyncGenerator:
model = await self.routing_table.get_model(model_id)
if model is None:
raise ValueError(f"Model '{model_id}' not found")
if model.model_type == ModelType.embedding:
raise ValueError(
f"Model '{model_id}' is an embedding model and does not support chat completions"
)
provider = self.routing_table.get_provider_impl(model_id) provider = self.routing_table.get_provider_impl(model_id)
params = dict( params = dict(
model_id=model_id, model_id=model_id,
@ -150,6 +165,13 @@ class InferenceRouter(Inference):
model_id: str, model_id: str,
contents: List[InterleavedTextMedia], contents: List[InterleavedTextMedia],
) -> EmbeddingsResponse: ) -> EmbeddingsResponse:
model = await self.routing_table.get_model(model_id)
if model is None:
raise ValueError(f"Model '{model_id}' not found")
if model.model_type == ModelType.llm:
raise ValueError(
f"Model '{model_id}' is an LLM model and does not support embeddings"
)
return await self.routing_table.get_provider_impl(model_id).embeddings( return await self.routing_table.get_provider_impl(model_id).embeddings(
model_id=model_id, model_id=model_id,
contents=contents, contents=contents,
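
In effect, mismatched calls now fail at the router rather than inside a provider. A small behavioural sketch, with the router instance and message type assumed to be in scope:

async def demo(inference_router) -> None:
    try:
        await inference_router.chat_completion(
            model_id="all-MiniLM-L6-v2",  # registered with model_type=embedding
            messages=[UserMessage(content="hello")],
        )
    except ValueError as err:
        print(err)  # embedding model rejected before reaching any provider

    # the embeddings route, conversely, only accepts embedding models
    await inference_router.embeddings(
        model_id="all-MiniLM-L6-v2",
        contents=["hello world"],
    )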


@ -209,6 +209,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
provider_model_id: Optional[str] = None, provider_model_id: Optional[str] = None,
provider_id: Optional[str] = None, provider_id: Optional[str] = None,
metadata: Optional[Dict[str, Any]] = None, metadata: Optional[Dict[str, Any]] = None,
model_type: Optional[ModelType] = None,
) -> Model: ) -> Model:
if provider_model_id is None: if provider_model_id is None:
provider_model_id = model_id provider_model_id = model_id
@ -222,11 +223,18 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
) )
if metadata is None: if metadata is None:
metadata = {} metadata = {}
if model_type is None:
model_type = ModelType.llm
if "embedding_dimension" not in metadata and model_type == ModelType.embedding:
raise ValueError(
"Embedding model must have an embedding dimension in its metadata"
)
model = Model( model = Model(
identifier=model_id, identifier=model_id,
provider_resource_id=provider_model_id, provider_resource_id=provider_model_id,
provider_id=provider_id, provider_id=provider_id,
metadata=metadata, metadata=metadata,
model_type=model_type,
) )
registered_model = await self.register_object(model) registered_model = await self.register_object(model)
return registered_model return registered_model
@ -298,16 +306,36 @@ class MemoryBanksRoutingTable(CommonRoutingTableImpl, MemoryBanks):
raise ValueError( raise ValueError(
"No provider specified and multiple providers available. Please specify a provider_id." "No provider specified and multiple providers available. Please specify a provider_id."
) )
memory_bank = parse_obj_as( model = await self.get_object_by_identifier("model", params.embedding_model)
MemoryBank, if model is None:
{ if params.embedding_model == "all-MiniLM-L6-v2":
"identifier": memory_bank_id, raise ValueError(
"type": ResourceType.memory_bank.value, "Embeddings are now served via Inference providers. "
"provider_id": provider_id, "Please upgrade your run.yaml to include inline::sentence-transformer as an additional inference provider. "
"provider_resource_id": provider_memory_bank_id, "See https://github.com/meta-llama/llama-stack/blob/main/llama_stack/templates/together/run.yaml for an example."
**params.model_dump(), )
}, else:
) raise ValueError(f"Model {params.embedding_model} not found")
if model.model_type != ModelType.embedding:
raise ValueError(
f"Model {params.embedding_model} is not an embedding model"
)
if "embedding_dimension" not in model.metadata:
raise ValueError(
f"Model {params.embedding_model} does not have an embedding dimension"
)
memory_bank_data = {
"identifier": memory_bank_id,
"type": ResourceType.memory_bank.value,
"provider_id": provider_id,
"provider_resource_id": provider_memory_bank_id,
**params.model_dump(),
}
if params.memory_bank_type == MemoryBankType.vector.value:
memory_bank_data["embedding_dimension"] = model.metadata[
"embedding_dimension"
]
memory_bank = parse_obj_as(MemoryBank, memory_bank_data)
await self.register_object(memory_bank) await self.register_object(memory_bank)
return memory_bank return memory_bank
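
The practical consequence is an ordering requirement: register the embedding model (with an embedding_dimension) before any vector memory bank that references it. A hypothetical client-side sequence, reusing the client object from the earlier models sketch; method and parameter names are assumed rather than taken from this diff:

client.models.register(
    model_id="all-MiniLM-L6-v2",
    provider_id="sentence-transformers",
    model_type="embedding",
    metadata={"embedding_dimension": 384},
)
client.memory_banks.register(
    memory_bank_id="docs",
    params={
        "memory_bank_type": "vector",
        "embedding_model": "all-MiniLM-L6-v2",
        "chunk_size_in_tokens": 512,
        "overlap_size_in_tokens": 64,
    },
)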


@ -40,7 +40,7 @@ class DistributionRegistry(Protocol):
REGISTER_PREFIX = "distributions:registry" REGISTER_PREFIX = "distributions:registry"
KEY_VERSION = "v2" KEY_VERSION = "v3"
KEY_FORMAT = f"{REGISTER_PREFIX}:{KEY_VERSION}::" + "{type}:{identifier}" KEY_FORMAT = f"{REGISTER_PREFIX}:{KEY_VERSION}::" + "{type}:{identifier}"


@ -28,6 +28,7 @@ class Api(Enum):
datasetio = "datasetio" datasetio = "datasetio"
scoring = "scoring" scoring = "scoring"
eval = "eval" eval = "eval"
post_training = "post_training"
telemetry = "telemetry" telemetry = "telemetry"
@ -200,10 +201,13 @@ API responses, specify the adapter here.
return self.adapter.provider_data_validator return self.adapter.provider_data_validator
def remote_provider_spec(api: Api, adapter: AdapterSpec) -> RemoteProviderSpec: def remote_provider_spec(
api: Api, adapter: AdapterSpec, api_dependencies: Optional[List[Api]] = None
) -> RemoteProviderSpec:
return RemoteProviderSpec( return RemoteProviderSpec(
api=api, api=api,
provider_type=f"remote::{adapter.adapter_type}", provider_type=f"remote::{adapter.adapter_type}",
config_class=adapter.config_class, config_class=adapter.config_class,
adapter=adapter, adapter=adapter,
api_dependencies=api_dependencies or [],
) )
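
For illustration, a remote registry entry can now thread a dependency such as inference through in the same way; every adapter field below is a placeholder, not a real provider:

remote_provider_spec(
    api=Api.memory,
    adapter=AdapterSpec(
        adapter_type="example-vector-store",           # placeholder
        pip_packages=["example-vector-store-client"],  # placeholder
        module="llama_stack.providers.remote.memory.example",
        config_class="llama_stack.providers.remote.memory.example.ExampleConfig",
    ),
    api_dependencies=[Api.inference],
)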


@ -16,12 +16,14 @@ from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.providers.utils.inference.model_registry import build_model_alias from llama_stack.providers.utils.inference.model_registry import build_model_alias
from llama_stack.apis.inference import * # noqa: F403 from llama_stack.apis.inference import * # noqa: F403
from llama_stack.providers.datatypes import ModelsProtocolPrivate from llama_stack.providers.datatypes import ModelsProtocolPrivate
from llama_stack.providers.utils.inference.embedding_mixin import (
SentenceTransformerEmbeddingMixin,
)
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
from llama_stack.providers.utils.inference.prompt_adapter import ( from llama_stack.providers.utils.inference.prompt_adapter import (
convert_image_media_to_url, convert_image_media_to_url,
request_has_media, request_has_media,
) )
from .config import MetaReferenceInferenceConfig from .config import MetaReferenceInferenceConfig
from .generation import Llama from .generation import Llama
from .model_parallel import LlamaModelParallelGenerator from .model_parallel import LlamaModelParallelGenerator
@ -32,12 +34,17 @@ log = logging.getLogger(__name__)
SEMAPHORE = asyncio.Semaphore(1) SEMAPHORE = asyncio.Semaphore(1)
class MetaReferenceInferenceImpl(Inference, ModelRegistryHelper, ModelsProtocolPrivate): class MetaReferenceInferenceImpl(
SentenceTransformerEmbeddingMixin,
Inference,
ModelsProtocolPrivate,
):
def __init__(self, config: MetaReferenceInferenceConfig) -> None: def __init__(self, config: MetaReferenceInferenceConfig) -> None:
self.config = config self.config = config
model = resolve_model(config.model) model = resolve_model(config.model)
ModelRegistryHelper.__init__( if model is None:
self, raise RuntimeError(f"Unknown model: {config.model}, Run `llama model list`")
self.model_registry_helper = ModelRegistryHelper(
[ [
build_model_alias( build_model_alias(
model.descriptor(), model.descriptor(),
@ -45,8 +52,6 @@ class MetaReferenceInferenceImpl(Inference, ModelRegistryHelper, ModelsProtocolP
) )
], ],
) )
if model is None:
raise RuntimeError(f"Unknown model: {config.model}, Run `llama model list`")
self.model = model self.model = model
# verify that the checkpoint actually is for this model lol # verify that the checkpoint actually is for this model lol
@ -76,6 +81,12 @@ class MetaReferenceInferenceImpl(Inference, ModelRegistryHelper, ModelsProtocolP
async def unregister_model(self, model_id: str) -> None: async def unregister_model(self, model_id: str) -> None:
pass pass
async def register_model(self, model: Model) -> Model:
model = await self.model_registry_helper.register_model(model)
if model.model_type == ModelType.embedding:
self._load_sentence_transformer_model(model.provider_resource_id)
return model
async def completion( async def completion(
self, self,
model_id: str, model_id: str,
@ -394,13 +405,6 @@ class MetaReferenceInferenceImpl(Inference, ModelRegistryHelper, ModelsProtocolP
for x in impl(): for x in impl():
yield x yield x
async def embeddings(
self,
model_id: str,
contents: List[InterleavedTextMedia],
) -> EmbeddingsResponse:
raise NotImplementedError()
async def request_with_localized_media( async def request_with_localized_media(
request: Union[ChatCompletionRequest, CompletionRequest], request: Union[ChatCompletionRequest, CompletionRequest],


@ -0,0 +1,20 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.providers.inline.inference.sentence_transformers.config import (
SentenceTransformersInferenceConfig,
)
async def get_provider_impl(
config: SentenceTransformersInferenceConfig,
_deps,
):
from .sentence_transformers import SentenceTransformersInferenceImpl
impl = SentenceTransformersInferenceImpl(config)
await impl.initialize()
return impl


@ -0,0 +1,16 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict
from pydantic import BaseModel
class SentenceTransformersInferenceConfig(BaseModel):
@classmethod
def sample_run_config(cls) -> Dict[str, Any]:
return {}


@ -0,0 +1,74 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import logging
from typing import AsyncGenerator, List, Optional, Union
from llama_stack.apis.inference import (
CompletionResponse,
Inference,
LogProbConfig,
Message,
ResponseFormat,
SamplingParams,
ToolChoice,
ToolDefinition,
ToolPromptFormat,
)
from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate
from llama_stack.providers.utils.inference.embedding_mixin import (
SentenceTransformerEmbeddingMixin,
)
from .config import SentenceTransformersInferenceConfig
log = logging.getLogger(__name__)
class SentenceTransformersInferenceImpl(
SentenceTransformerEmbeddingMixin,
Inference,
ModelsProtocolPrivate,
):
def __init__(self, config: SentenceTransformersInferenceConfig) -> None:
self.config = config
async def initialize(self) -> None:
pass
async def shutdown(self) -> None:
pass
async def register_model(self, model: Model) -> Model:
_ = self._load_sentence_transformer_model(model.provider_resource_id)
return model
async def unregister_model(self, model_id: str) -> None:
pass
async def completion(
self,
model_id: str,
content: str,
sampling_params: Optional[SamplingParams] = SamplingParams(),
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> Union[CompletionResponse, AsyncGenerator]:
raise ValueError("Sentence transformers don't support completion")
async def chat_completion(
self,
model_id: str,
messages: List[Message],
sampling_params: Optional[SamplingParams] = SamplingParams(),
response_format: Optional[ResponseFormat] = None,
tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> AsyncGenerator:
raise ValueError("Sentence transformers don't support chat completion")


@ -4,16 +4,19 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from typing import Dict
from llama_stack.providers.datatypes import Api, ProviderSpec
from .config import FaissImplConfig from .config import FaissImplConfig
async def get_provider_impl(config: FaissImplConfig, _deps): async def get_provider_impl(config: FaissImplConfig, deps: Dict[Api, ProviderSpec]):
from .faiss import FaissMemoryImpl from .faiss import FaissMemoryImpl
assert isinstance( assert isinstance(
config, FaissImplConfig config, FaissImplConfig
), f"Unexpected config type: {type(config)}" ), f"Unexpected config type: {type(config)}"
impl = FaissMemoryImpl(config) impl = FaissMemoryImpl(config, deps[Api.inference])
await impl.initialize() await impl.initialize()
return impl return impl


@ -19,11 +19,10 @@ from numpy.typing import NDArray
from llama_models.llama3.api.datatypes import * # noqa: F403 from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.apis.memory import * # noqa: F403 from llama_stack.apis.memory import * # noqa: F403
from llama_stack.providers.datatypes import MemoryBanksProtocolPrivate from llama_stack.providers.datatypes import Api, MemoryBanksProtocolPrivate
from llama_stack.providers.utils.kvstore import kvstore_impl from llama_stack.providers.utils.kvstore import kvstore_impl
from llama_stack.providers.utils.memory.vector_store import ( from llama_stack.providers.utils.memory.vector_store import (
ALL_MINILM_L6_V2_DIMENSION,
BankWithIndex, BankWithIndex,
EmbeddingIndex, EmbeddingIndex,
) )
@ -32,7 +31,8 @@ from .config import FaissImplConfig
logger = logging.getLogger(__name__) logger = logging.getLogger(__name__)
MEMORY_BANKS_PREFIX = "memory_banks:v1::" MEMORY_BANKS_PREFIX = "memory_banks:v2::"
FAISS_INDEX_PREFIX = "faiss_index:v2::"
class FaissIndex(EmbeddingIndex): class FaissIndex(EmbeddingIndex):
@ -56,7 +56,7 @@ class FaissIndex(EmbeddingIndex):
if not self.kvstore: if not self.kvstore:
return return
index_key = f"faiss_index:v1::{self.bank_id}" index_key = f"{FAISS_INDEX_PREFIX}{self.bank_id}"
stored_data = await self.kvstore.get(index_key) stored_data = await self.kvstore.get(index_key)
if stored_data: if stored_data:
@ -85,16 +85,25 @@ class FaissIndex(EmbeddingIndex):
"faiss_index": base64.b64encode(buffer.getvalue()).decode("utf-8"), "faiss_index": base64.b64encode(buffer.getvalue()).decode("utf-8"),
} }
index_key = f"faiss_index:v1::{self.bank_id}" index_key = f"{FAISS_INDEX_PREFIX}{self.bank_id}"
await self.kvstore.set(key=index_key, value=json.dumps(data)) await self.kvstore.set(key=index_key, value=json.dumps(data))
async def delete(self): async def delete(self):
if not self.kvstore or not self.bank_id: if not self.kvstore or not self.bank_id:
return return
await self.kvstore.delete(f"faiss_index:v1::{self.bank_id}") await self.kvstore.delete(f"{FAISS_INDEX_PREFIX}{self.bank_id}")
async def add_chunks(self, chunks: List[Chunk], embeddings: NDArray): async def add_chunks(self, chunks: List[Chunk], embeddings: NDArray):
# Add dimension check
embedding_dim = (
embeddings.shape[1] if len(embeddings.shape) > 1 else embeddings.shape[0]
)
if embedding_dim != self.index.d:
raise ValueError(
f"Embedding dimension mismatch. Expected {self.index.d}, got {embedding_dim}"
)
indexlen = len(self.id_by_index) indexlen = len(self.id_by_index)
for i, chunk in enumerate(chunks): for i, chunk in enumerate(chunks):
self.chunk_by_index[indexlen + i] = chunk self.chunk_by_index[indexlen + i] = chunk
@ -124,8 +133,9 @@ class FaissIndex(EmbeddingIndex):
class FaissMemoryImpl(Memory, MemoryBanksProtocolPrivate): class FaissMemoryImpl(Memory, MemoryBanksProtocolPrivate):
def __init__(self, config: FaissImplConfig) -> None: def __init__(self, config: FaissImplConfig, inference_api: Api.inference) -> None:
self.config = config self.config = config
self.inference_api = inference_api
self.cache = {} self.cache = {}
self.kvstore = None self.kvstore = None
@ -139,10 +149,11 @@ class FaissMemoryImpl(Memory, MemoryBanksProtocolPrivate):
for bank_data in stored_banks: for bank_data in stored_banks:
bank = VectorMemoryBank.model_validate_json(bank_data) bank = VectorMemoryBank.model_validate_json(bank_data)
index = BankWithIndex( index = BankWithIndex(
bank=bank, bank,
index=await FaissIndex.create( await FaissIndex.create(
ALL_MINILM_L6_V2_DIMENSION, self.kvstore, bank.identifier bank.embedding_dimension, self.kvstore, bank.identifier
), ),
self.inference_api,
) )
self.cache[bank.identifier] = index self.cache[bank.identifier] = index
@ -166,13 +177,13 @@ class FaissMemoryImpl(Memory, MemoryBanksProtocolPrivate):
) )
# Store in cache # Store in cache
index = BankWithIndex( self.cache[memory_bank.identifier] = BankWithIndex(
bank=memory_bank, memory_bank,
index=await FaissIndex.create( await FaissIndex.create(
ALL_MINILM_L6_V2_DIMENSION, self.kvstore, memory_bank.identifier memory_bank.embedding_dimension, self.kvstore, memory_bank.identifier
), ),
self.inference_api,
) )
self.cache[memory_bank.identifier] = index
async def list_memory_banks(self) -> List[MemoryBank]: async def list_memory_banks(self) -> List[MemoryBank]:
return [i.bank for i in self.cache.values()] return [i.bank for i in self.cache.values()]
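
The new dimension guard is easy to reason about in isolation; a standalone restatement with plain numpy, where index_dim stands in for self.index.d:

import numpy as np

def check_embedding_dim(embeddings: np.ndarray, index_dim: int) -> None:
    # mirrors the check added to FaissIndex.add_chunks
    embedding_dim = embeddings.shape[1] if embeddings.ndim > 1 else embeddings.shape[0]
    if embedding_dim != index_dim:
        raise ValueError(
            f"Embedding dimension mismatch. Expected {index_dim}, got {embedding_dim}"
        )

check_embedding_dim(np.zeros((4, 384), dtype=np.float32), 384)    # passes
# check_embedding_dim(np.zeros((4, 768), dtype=np.float32), 384)  # would raise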


@ -0,0 +1,27 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Dict
from llama_stack.distribution.datatypes import Api, ProviderSpec
from .config import TorchtunePostTrainingConfig
# The post_training API and the torchtune provider are still experimental and under heavy development
async def get_provider_impl(
config: TorchtunePostTrainingConfig,
deps: Dict[Api, ProviderSpec],
):
from .post_training import TorchtunePostTrainingImpl
impl = TorchtunePostTrainingImpl(
config,
deps[Api.datasetio],
deps[Api.datasets],
)
return impl


@ -0,0 +1,157 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import os
import shutil
from pathlib import Path
from typing import Any, Dict, List
import torch
from torchtune import training
from torchtune.models import convert_weights
from torchtune.training.checkpointing._utils import ModelType, safe_torch_load
from torchtune.utils._logging import get_logger
logger = get_logger("DEBUG")
class TorchtuneCheckpointer:
def __init__(
self,
model_id: str,
training_algorithm: str,
checkpoint_dir: str,
checkpoint_files: List[str],
output_dir: str,
model_type: str,
) -> None:
# Fail fast if ``checkpoint_files`` is invalid
# TODO: support loading more than one file
if len(checkpoint_files) != 1:
raise ValueError(
"Currently we only support reading from a single torchtune checkpoint file. "
f"Got {len(checkpoint_files)} files instead."
)
self._checkpoint_file = checkpoint_files[0]
self._model_id = model_id
self._training_algorithm = training_algorithm
self._checkpoint_dir = Path(checkpoint_dir)
self._model_type = ModelType[model_type]
self._output_dir = output_dir
# get ckpt paths
self._checkpoint_path = Path.joinpath(
self._checkpoint_dir, self._checkpoint_file
)
def load_checkpoint(self) -> Dict[str, Any]:
"""
Load Meta checkpoint from file. Currently only loading from a single file is supported.
"""
state_dict: Dict[str, Any] = {}
model_state_dict = safe_torch_load(self._checkpoint_path)
if self._model_type == ModelType.LLAMA3_VISION:
from torchtune.models.llama3_2_vision._convert_weights import (
llama3_vision_meta_to_tune,
)
state_dict[training.MODEL_KEY] = llama3_vision_meta_to_tune(
model_state_dict
)
else:
state_dict[training.MODEL_KEY] = convert_weights.meta_to_tune(
model_state_dict
)
# llama3_2 has tied weights, so we need to remove the output.weight key
if self._model_type == ModelType.LLAMA3_2:
logger.info(
"Identified model_type = Llama3_2. Ignoring output.weight in"
" checkpoint in favor of the tok_embedding.weight"
" tied weights."
)
state_dict[training.MODEL_KEY].pop("output.weight")
return state_dict
def save_checkpoint(
self,
state_dict: Dict[str, Any],
epoch: int,
adapter_only: bool = False,
) -> str:
model_file_path = (
Path(self._output_dir)
/ f"{self._model_id}-{self._training_algorithm}-{epoch}"
)
model_file_path.mkdir(parents=True, exist_ok=True)
# copy the related files for inference
shutil.copy(
Path.joinpath(self._checkpoint_dir, "params.json"),
Path.joinpath(model_file_path, "params.json"),
)
shutil.copy(
Path.joinpath(self._checkpoint_dir, "tokenizer.model"),
Path.joinpath(model_file_path, "tokenizer.model"),
)
shutil.copy(
Path.joinpath(self._checkpoint_dir, "orig_params.json"),
Path.joinpath(model_file_path, "orig_params.json"),
)
if not adapter_only:
model_state_dict = state_dict[training.MODEL_KEY]
if self._model_type == ModelType.LLAMA3_VISION:
from torchtune.models.llama3_2_vision._convert_weights import (
llama3_vision_tune_to_meta,
)
state_dict[training.MODEL_KEY] = llama3_vision_tune_to_meta(
model_state_dict
)
else:
# llama3_2 has tied weights, so we need to add the output.weight key
if (
self._model_type == ModelType.LLAMA3_2
and "output.weight" not in model_state_dict
):
model_state_dict["output.weight"] = model_state_dict[
"tok_embeddings.weight"
]
state_dict[training.MODEL_KEY] = convert_weights.tune_to_meta(
model_state_dict
)
model_file_name = Path.joinpath(model_file_path, "consolidated.00.pth")
torch.save(state_dict[training.MODEL_KEY], model_file_name)
logger.info(
"Model checkpoint of size "
f"{os.path.getsize(model_file_name) / 1000**3:.2f} GB "
f"saved to {model_file_name}"
)
if training.ADAPTER_KEY in state_dict:
adapter_file_path = model_file_path / "adapter"
adapter_file_path.mkdir(parents=True, exist_ok=True)
adapter_file_name = Path.joinpath(adapter_file_path, "adapter.pth")
torch.save(state_dict[training.ADAPTER_KEY], adapter_file_name)
logger.info(
"Adapter checkpoint of size "
f"{os.path.getsize(adapter_file_name) / 1000**3:.2f} GB "
f"saved to {adapter_file_name}"
)
elif adapter_only:
raise ValueError(
"Adapter checkpoint not found in state_dict. Please ensure that the state_dict contains adapter weights."
)
print("model_file_path", str(model_file_path))
return str(model_file_path)
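
For reference, constructing the checkpointer against a locally downloaded Meta-format checkpoint might look like the following; both paths are placeholders:

checkpointer = TorchtuneCheckpointer(
    model_id="Llama3.2-3B-Instruct",
    training_algorithm="sft",
    checkpoint_dir="/home/user/.llama/checkpoints/Llama3.2-3B-Instruct",
    checkpoint_files=["consolidated.00.pth"],
    output_dir="/home/user/.llama/checkpoints/post_training",
    model_type="LLAMA3_2",
)
state_dict = checkpointer.load_checkpoint()
# ... training loop updates state_dict[training.MODEL_KEY] ...
output_path = checkpointer.save_checkpoint(state_dict, epoch=0)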


@ -0,0 +1,139 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from enum import Enum
from typing import Any, Callable, Dict, List
import torch
from llama_stack.apis.datasets import Datasets
from llama_stack.apis.common.type_system import * # noqa
from llama_models.datatypes import Model
from llama_models.sku_list import resolve_model
from llama_stack.apis.common.type_system import ParamType
from torchtune.models.llama3 import llama3_tokenizer, lora_llama3_8b
from torchtune.models.llama3._tokenizer import Llama3Tokenizer
from torchtune.models.llama3_2 import lora_llama3_2_3b
class ColumnName(Enum):
instruction = "instruction"
input = "input"
output = "output"
text = "text"
class ModelConfig(BaseModel):
model_definition: Any
tokenizer_type: Any
checkpoint_type: str
class DatasetSchema(BaseModel):
alpaca: List[Dict[str, ParamType]]
MODEL_CONFIGS: Dict[str, ModelConfig] = {
"Llama3.2-3B-Instruct": ModelConfig(
model_definition=lora_llama3_2_3b,
tokenizer_type=llama3_tokenizer,
checkpoint_type="LLAMA3_2",
),
"Llama-3-8B-Instruct": ModelConfig(
model_definition=lora_llama3_8b,
tokenizer_type=llama3_tokenizer,
checkpoint_type="LLAMA3",
),
}
EXPECTED_DATASET_SCHEMA = DatasetSchema(
alpaca=[
{
ColumnName.instruction.value: StringType(),
ColumnName.input.value: StringType(),
ColumnName.output.value: StringType(),
ColumnName.text.value: StringType(),
},
{
ColumnName.instruction.value: StringType(),
ColumnName.input.value: StringType(),
ColumnName.output.value: StringType(),
},
{
ColumnName.instruction.value: StringType(),
ColumnName.output.value: StringType(),
},
]
)
BuildLoraModelCallable = Callable[..., torch.nn.Module]
BuildTokenizerCallable = Callable[..., Llama3Tokenizer]
def _validate_model_id(model_id: str) -> Model:
model = resolve_model(model_id)
if model is None or model.core_model_id.value not in MODEL_CONFIGS:
raise ValueError(f"Model {model_id} is not supported.")
return model
async def get_model_definition(
model_id: str,
) -> BuildLoraModelCallable:
model = _validate_model_id(model_id)
model_config = MODEL_CONFIGS[model.core_model_id.value]
if not hasattr(model_config, "model_definition"):
raise ValueError(f"Model {model_id} does not have model definition.")
return model_config.model_definition
async def get_tokenizer_type(
model_id: str,
) -> BuildTokenizerCallable:
model = _validate_model_id(model_id)
model_config = MODEL_CONFIGS[model.core_model_id.value]
if not hasattr(model_config, "tokenizer_type"):
raise ValueError(f"Model {model_id} does not have tokenizer_type.")
return model_config.tokenizer_type
async def get_checkpointer_model_type(
model_id: str,
) -> str:
"""
The checkpointer model type is used by the checkpointer to apply special handling for specific model types,
for example the tied weights of llama3.2 (https://github.com/pytorch/torchtune/blob/main/torchtune/training/checkpointing/_checkpointer.py#L1041)
"""
model = _validate_model_id(model_id)
model_config = MODEL_CONFIGS[model.core_model_id.value]
if not hasattr(model_config, "checkpoint_type"):
raise ValueError(f"Model {model_id} does not have checkpoint_type.")
return model_config.checkpoint_type
async def validate_input_dataset_schema(
datasets_api: Datasets,
dataset_id: str,
dataset_type: str,
) -> None:
dataset_def = await datasets_api.get_dataset(dataset_id=dataset_id)
if not dataset_def.dataset_schema or len(dataset_def.dataset_schema) == 0:
raise ValueError(f"Dataset {dataset_id} does not have a schema defined.")
if not hasattr(EXPECTED_DATASET_SCHEMA, dataset_type):
raise ValueError(f"Dataset type {dataset_type} is not supported.")
if dataset_def.dataset_schema not in getattr(EXPECTED_DATASET_SCHEMA, dataset_type):
raise ValueError(
f"Dataset {dataset_id} does not have a correct input schema in {getattr(EXPECTED_DATASET_SCHEMA, dataset_type)}"
)
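
How the recipe is expected to use these helpers, with the return values (noted in comments) following directly from MODEL_CONFIGS; the datasets_api handle is assumed to be in scope:

async def resolve_builders(datasets_api: Datasets) -> None:
    model_def = await get_model_definition("Llama3.2-3B-Instruct")         # lora_llama3_2_3b
    tokenizer_type = await get_tokenizer_type("Llama3.2-3B-Instruct")      # llama3_tokenizer
    ckpt_type = await get_checkpointer_model_type("Llama3.2-3B-Instruct")  # "LLAMA3_2"
    await validate_input_dataset_schema(
        datasets_api=datasets_api,
        dataset_id="alpaca",
        dataset_type="alpaca",
    )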


@ -0,0 +1,13 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Optional
from pydantic import BaseModel
class TorchtunePostTrainingConfig(BaseModel):
torch_seed: Optional[int] = None


@ -0,0 +1,66 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.
from typing import Any, Dict, List, Mapping
import numpy as np
from torch.utils.data import Dataset
from torchtune.data._common import CROSS_ENTROPY_IGNORE_IDX
from torchtune.data._messages import validate_messages
from torchtune.modules.transforms import Transform
class SFTDataset(Dataset):
def __init__(
self,
rows: List[Dict[str, Any]],
message_transform: Transform,
model_transform: Transform,
) -> None:
self._rows = rows
self._message_transform = message_transform
self._model_transform = model_transform
def __len__(self):
return len(self._rows)
def __getitem__(self, index: int) -> Dict[str, Any]:
sample = self._rows[index]
return self._prepare_sample(sample)
def _prepare_sample(self, sample: Mapping[str, Any]) -> Dict[str, Any]:
transformed_sample = self._message_transform(sample)
if "messages" in transformed_sample:
validate_messages(transformed_sample["messages"])
tokenized_dict = self._model_transform(transformed_sample)
if not ("tokens" in tokenized_dict and "mask" in tokenized_dict):
keys_str = ", ".join(tokenized_dict.keys())
error_message = (
"model_transform returned the following keys: "
f"{keys_str}. Must return 'tokens' and 'mask' as keys."
)
raise ValueError(error_message)
# Wherever mask == True, set to CROSS_ENTROPY_IGNORE_IDX. Otherwise keep as tokens
tokenized_dict["labels"] = list(
np.where(
tokenized_dict["mask"],
CROSS_ENTROPY_IGNORE_IDX,
tokenized_dict["tokens"],
)
)
assert len(tokenized_dict["tokens"]) == len(tokenized_dict["labels"])
return tokenized_dict


@ -0,0 +1,126 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.providers.inline.post_training.torchtune.config import (
TorchtunePostTrainingConfig,
)
from llama_stack.apis.post_training import * # noqa
from llama_stack.providers.inline.post_training.torchtune.recipes.lora_finetuning_single_device import (
LoraFinetuningSingleDevice,
)
class TorchtunePostTrainingImpl:
def __init__(
self,
config: TorchtunePostTrainingConfig,
datasetio_api: DatasetIO,
datasets: Datasets,
) -> None:
self.config = config
self.datasetio_api = datasetio_api
self.datasets_api = datasets
# TODO: assume sync job, will need jobs API for async scheduling
self.jobs_status = {}
self.jobs_list = []
self.checkpoints_dict = {}
async def supervised_fine_tune(
self,
job_uuid: str,
training_config: TrainingConfig,
hyperparam_search_config: Dict[str, Any],
logger_config: Dict[str, Any],
model: str,
checkpoint_dir: Optional[str],
algorithm_config: Optional[AlgorithmConfig],
) -> PostTrainingJob:
for job in self.jobs_list:
if job_uuid == job.job_uuid:
raise ValueError(f"Job {job_uuid} already exists")
post_training_job = PostTrainingJob(job_uuid=job_uuid)
job_status_response = PostTrainingJobStatusResponse(
job_uuid=job_uuid,
status=JobStatus.scheduled,
scheduled_at=datetime.now(),
)
self.jobs_list.append(post_training_job)
if isinstance(algorithm_config, LoraFinetuningConfig):
try:
recipe = LoraFinetuningSingleDevice(
self.config,
job_uuid,
training_config,
hyperparam_search_config,
logger_config,
model,
checkpoint_dir,
algorithm_config,
self.datasetio_api,
self.datasets_api,
)
job_status_response.status = JobStatus.in_progress
job_status_response.started_at = datetime.now()
await recipe.setup()
resources_allocated, checkpoints = await recipe.train()
self.checkpoints_dict[job_uuid] = checkpoints
job_status_response.resources_allocated = resources_allocated
job_status_response.checkpoints = checkpoints
job_status_response.status = JobStatus.completed
job_status_response.completed_at = datetime.now()
except Exception:
job_status_response.status = JobStatus.failed
raise
else:
raise NotImplementedError()
self.jobs_status[job_uuid] = job_status_response
return post_training_job
async def preference_optimize(
self,
job_uuid: str,
finetuned_model: str,
algorithm_config: DPOAlignmentConfig,
training_config: TrainingConfig,
hyperparam_search_config: Dict[str, Any],
logger_config: Dict[str, Any],
) -> PostTrainingJob: ...
async def get_training_jobs(self) -> List[PostTrainingJob]:
return self.jobs_list
@webmethod(route="/post-training/job/status")
async def get_training_job_status(
self, job_uuid: str
) -> Optional[PostTrainingJobStatusResponse]:
if job_uuid in self.jobs_status:
return self.jobs_status[job_uuid]
return None
@webmethod(route="/post-training/job/cancel")
async def cancel_training_job(self, job_uuid: str) -> None:
raise NotImplementedError("Job cancel is not implemented yet")
@webmethod(route="/post-training/job/artifacts")
async def get_training_job_artifacts(
self, job_uuid: str
) -> Optional[PostTrainingJobArtifactsResponse]:
if job_uuid in self.checkpoints_dict:
checkpoints = self.checkpoints_dict.get(job_uuid, [])
return PostTrainingJobArtifactsResponse(
job_uuid=job_uuid, checkpoints=checkpoints
)
return None


@ -0,0 +1,596 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import logging
import os
import time
from functools import partial
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple
import torch
from llama_models.sku_list import resolve_model
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.distribution.utils.config_dirs import DEFAULT_CHECKPOINT_DIR
from llama_stack.providers.inline.post_training.torchtune.common.checkpointer import (
TorchtuneCheckpointer,
)
from torch import nn
from torchtune import utils as torchtune_utils
from torchtune.training.metric_logging import DiskLogger
from tqdm import tqdm
from llama_stack.apis.post_training import * # noqa
from llama_stack.distribution.utils.model_utils import model_local_dir
from llama_stack.providers.inline.post_training.torchtune.common import utils
from llama_stack.providers.inline.post_training.torchtune.config import (
TorchtunePostTrainingConfig,
)
from llama_stack.providers.inline.post_training.torchtune.datasets.sft import SFTDataset
from torch.optim import Optimizer
from torch.utils.data import DataLoader, DistributedSampler
from torchtune import modules, training
from torchtune.data import AlpacaToMessages, padded_collate_sft
from torchtune.modules.loss import CEWithChunkedOutputLoss
from torchtune.modules.peft import (
get_adapter_params,
get_adapter_state_dict,
get_lora_module_names,
get_merged_lora_ckpt,
load_dora_magnitudes,
set_trainable_params,
validate_missing_and_unexpected_for_lora,
)
from torchtune.training.lr_schedulers import get_cosine_schedule_with_warmup
log = logging.getLogger(__name__)
from torchtune.models.llama3._tokenizer import Llama3Tokenizer
class LoraFinetuningSingleDevice:
# This recipe only supports GPU training
# This recipe doesn't include several training efficiency settings from the original torchtune repo, including
# - compile
# - activation offloading
# Resuming from a checkpoint is not supported yet
# Validation is not supported yet
# Currently, logging only writes limited training metrics to local disk;
# more logging and its integration with telemetry will be worked out in future PRs
def __init__(
self,
config: TorchtunePostTrainingConfig,
job_uuid: str,
training_config: TrainingConfig,
hyperparam_search_config: Dict[str, Any],
logger_config: Dict[str, Any],
model: str,
checkpoint_dir: Optional[str],
algorithm_config: Optional[AlgorithmConfig],
datasetio_api: DatasetIO,
datasets_api: Datasets,
) -> None:
self.job_uuid = job_uuid
self.training_config = training_config
if not isinstance(algorithm_config, LoraFinetuningConfig):
raise ValueError(
"You need to speicifc LoraFinetuningConfig for LoRA finetuning"
)
self.algorithm_config = algorithm_config
self._device = torchtune_utils.get_device(device="cuda")
self._dtype = training.get_dtype(training_config.dtype, device=self._device)
self.model_id = model
def model_checkpoint_dir(model) -> str:
checkpoint_dir = Path(model_local_dir(model.descriptor()))
paths = [
Path(checkpoint_dir / f"consolidated.{ext}")
for ext in ["pth", "00.pth"]
]
if not any(p.exists() for p in paths):
checkpoint_dir = checkpoint_dir / "original"
assert checkpoint_dir.exists(), (
f"Could not find checkpoints in: {model_local_dir(model.descriptor())}. "
f"Please download model using `llama download --model-id {model.descriptor()}`"
)
return str(checkpoint_dir)
if checkpoint_dir and checkpoint_dir != "null":
self.checkpoint_dir = checkpoint_dir
else:
model = resolve_model(self.model_id)
self.checkpoint_dir = model_checkpoint_dir(model)
self._output_dir = str(DEFAULT_CHECKPOINT_DIR)
self.seed = training.set_seed(seed=config.torch_seed)
self.epochs_run = 0
self.total_epochs = training_config.n_epochs
self._shuffle = training_config.data_config.shuffle
self._batch_size = training_config.data_config.batch_size
# this is important for debugging purpose
self.max_steps_per_epoch = training_config.max_steps_per_epoch
self.global_step = 0
self._gradient_accumulation_steps = training_config.gradient_accumulation_steps
self._clip_grad_norm = 1.0
self._enable_activation_checkpointing = (
(training_config.efficiency_config.enable_activation_checkpointing)
if training_config.efficiency_config
else False
)
self._enable_activation_offloading = (
(training_config.efficiency_config.enable_activation_offloading)
if training_config.efficiency_config
else False
)
self.datasetio_api = datasetio_api
self.datasets_api = datasets_api
async def load_checkpoint(self):
def get_checkpoint_files(checkpoint_dir: str) -> List[str]:
try:
# List all files in the given directory
files = os.listdir(checkpoint_dir)
# Filter files that end with .pth
pth_files = [file for file in files if file.endswith(".pth")]
return pth_files
except FileNotFoundError:
return [f"Error: The directory '{checkpoint_dir}' does not exist."]
self._checkpointer = TorchtuneCheckpointer(
model_id=self.model_id,
training_algorithm="sft",
checkpoint_dir=self.checkpoint_dir,
checkpoint_files=get_checkpoint_files(self.checkpoint_dir),
output_dir=self._output_dir,
model_type=await utils.get_checkpointer_model_type(self.model_id),
)
checkpoint_dict = self._checkpointer.load_checkpoint()
return checkpoint_dict
async def setup(self) -> None:
checkpoint_dict = await self.load_checkpoint()
self._model = await self._setup_model(
enable_activation_checkpointing=self._enable_activation_checkpointing,
enable_activation_offloading=self._enable_activation_offloading,
base_model_state_dict=checkpoint_dict[training.MODEL_KEY],
lora_weights_state_dict=None,
)
log.info(f"Model is initialized with precision {self._dtype}.")
self._tokenizer = await self._setup_tokenizer()
log.info("Tokenizer is initialized.")
self._optimizer = await self._setup_optimizer(
optimizer_config=self.training_config.optimizer_config
)
log.info("Optimizer is initialized.")
self._loss_fn = CEWithChunkedOutputLoss()
self._model.set_num_output_chunks(self._loss_fn.num_output_chunks)
log.info("Loss is initialized.")
self._training_sampler, self._training_dataloader = await self._setup_data(
dataset_id=self.training_config.data_config.dataset_id,
tokenizer=self._tokenizer,
shuffle=self._shuffle,
batch_size=self._batch_size,
)
if self.training_config.data_config.validation_dataset_id:
_, self._validation_dataloader = await self._setup_data(
dataset_id=self.training_config.data_config.validation_dataset_id,
tokenizer=self._tokenizer,
shuffle=False,
batch_size=self._batch_size,
)
log.info("Dataset and Sampler are initialized.")
# Number of training steps in each epoch depends on the number of batches produced
# by the dataloader and the max_steps_per_epoch param set by the user and is used
# for logging and tracking training state. This should be computed after the dataloader
# has been setup
self._steps_per_epoch = (
len(self._training_dataloader) // self._gradient_accumulation_steps
)
if (
self.max_steps_per_epoch is not None
and self.max_steps_per_epoch < self._steps_per_epoch
):
self._steps_per_epoch = self.max_steps_per_epoch
self.global_step = self.epochs_run * self._steps_per_epoch
# Learning rate scheduler can only be set up after number of steps
# has been computed
self._lr_scheduler = await self._setup_lr_scheduler(
num_warmup_steps=self.training_config.optimizer_config.num_warmup_steps,
num_training_steps=self.total_epochs * self._steps_per_epoch,
last_epoch=self.global_step - 1,
)
log.info("Learning rate scheduler is initialized.")
# Used to ignore labels for loss computation
self.ignore_labels_cache = torch.full(
(self._batch_size, 1), self._loss_fn.ignore_index, device=self._device
)
async def _setup_model(
self,
enable_activation_checkpointing: bool,
enable_activation_offloading: bool,
base_model_state_dict: Dict[str, Any],
lora_weights_state_dict: Optional[Dict[str, Any]] = None,
) -> nn.Module:
self._lora_rank = self.algorithm_config.rank
self._lora_alpha = self.algorithm_config.alpha
self._lora_attn_modules = list(self.algorithm_config.lora_attn_modules)
self._apply_lora_to_mlp = self.algorithm_config.apply_lora_to_mlp
self._apply_lora_to_output = self.algorithm_config.apply_lora_to_output
self._use_dora = self.algorithm_config.use_dora or False
with training.set_default_dtype(self._dtype), self._device:
model_type = await utils.get_model_definition(self.model_id)
model = model_type(
lora_attn_modules=self._lora_attn_modules,
apply_lora_to_mlp=self._apply_lora_to_mlp,
apply_lora_to_output=self._apply_lora_to_output,
lora_rank=self._lora_rank,
lora_alpha=self._lora_alpha,
quantize_base=False,
use_dora=self._use_dora,
)
self.adapter_params = get_adapter_params(model)
self._is_dora = any(["magnitude" in k for k in self.adapter_params.keys()])
set_trainable_params(model, self.adapter_params)
if enable_activation_checkpointing:
training.set_activation_checkpointing(
model, auto_wrap_policy={modules.TransformerSelfAttentionLayer}
)
base_missing, base_unexpected = model.load_state_dict(
base_model_state_dict, strict=False
)
# This is for any adapters that need to be initialized after base weights
# have been loaded (e.g. DoRA).
if self._is_dora:
for m in model.modules():
if hasattr(m, "initialize_dora_magnitude"):
m.initialize_dora_magnitude()
load_dora_magnitudes(model)
if lora_weights_state_dict:
lora_missing, lora_unexpected = model.load_state_dict(
lora_weights_state_dict, strict=False
)
else:
lora_missing, lora_unexpected = None, None
validate_missing_and_unexpected_for_lora(
lora_attn_modules=self._lora_attn_modules,
apply_lora_to_mlp=self._apply_lora_to_mlp,
apply_lora_to_output=self._apply_lora_to_output,
base_missing=base_missing,
base_unexpected=base_unexpected,
lora_missing=lora_missing,
lora_unexpected=lora_unexpected,
)
# Validate model adapter params were loaded in with the expected dtype
training.validate_expected_param_dtype(
self.adapter_params.items(), dtype=self._dtype
)
# activation offloading
self.activations_handling_ctx = training.get_act_offloading_ctx_manager(
model, enable_activation_offloading
)
memory_stats = training.get_memory_stats(device=self._device)
training.log_memory_stats(memory_stats)
return model
async def _setup_tokenizer(
self,
) -> Llama3Tokenizer:
tokenizer_path = self.checkpoint_dir + "/tokenizer.model"
tokenizer_type = await utils.get_tokenizer_type(self.model_id)
return tokenizer_type(path=tokenizer_path)
async def _setup_optimizer(self, optimizer_config: OptimizerConfig) -> Optimizer:
optimizer = torch.optim.AdamW(
params=self._model.parameters(),
lr=optimizer_config.lr,
betas=(0.9, 0.95),
eps=1e-8,
weight_decay=0.1,
)
return optimizer
async def _setup_data(
self,
dataset_id: str,
tokenizer: Llama3Tokenizer,
shuffle: bool,
batch_size: int,
) -> Tuple[DistributedSampler, DataLoader]:
async def fetch_rows(dataset_id: str):
return await self.datasetio_api.get_rows_paginated(
dataset_id=dataset_id,
rows_in_page=-1,
)
all_rows = await fetch_rows(dataset_id)
rows = all_rows.rows
# Currently only the alpaca instruct dataset is supported
# TODO @SLR722 make the message_transform swappable and support more dataset types
# TODO @SLR722 make the input dataset schema more flexible by exposing column_map
await utils.validate_input_dataset_schema(
datasets_api=self.datasets_api,
dataset_id=dataset_id,
dataset_type="alpaca",
)
ds = SFTDataset(
rows,
message_transform=AlpacaToMessages(train_on_input=False),
model_transform=tokenizer,
)
sampler = DistributedSampler(
ds,
num_replicas=1,
rank=0,
shuffle=shuffle,
seed=0,
)
dataloader = DataLoader(
dataset=ds,
sampler=sampler,
batch_size=batch_size,
# dropping last avoids shape issues with compile + flex attention
drop_last=True,
collate_fn=(
partial(
padded_collate_sft,
padding_idx=self._tokenizer.pad_id,
ignore_idx=self._loss_fn.ignore_index,
)
),
)
return sampler, dataloader
async def _setup_lr_scheduler(
self,
num_warmup_steps: int,
num_training_steps: int,
last_epoch: int,
) -> Optimizer:
lr_scheduler = get_cosine_schedule_with_warmup(
self._optimizer,
num_warmup_steps=num_warmup_steps,
num_training_steps=num_training_steps,
last_epoch=last_epoch,
)
return lr_scheduler
async def save_checkpoint(self, epoch: int) -> str:
ckpt_dict = {}
adapter_state_dict = get_adapter_state_dict(self._model.state_dict())
ckpt_dict.update({training.ADAPTER_KEY: adapter_state_dict})
# Construct the full state dict with LoRA weights merged into base LLM weights
# Move to CPU to avoid a copy on GPU
state_dict = {k: v.cpu() for k, v in self._model.state_dict().items()}
merged_state_dict = get_merged_lora_ckpt(
state_dict,
rank=self._lora_rank,
alpha=self._lora_alpha,
)
ckpt_dict.update({training.MODEL_KEY: merged_state_dict})
adapter_config = {
"r": self._lora_rank,
"lora_alpha": self._lora_alpha,
"target_modules": get_lora_module_names(
self._lora_attn_modules,
self._apply_lora_to_mlp,
self._apply_lora_to_output,
),
"peft_type": "LORA",
}
ckpt_dict.update({training.ADAPTER_CONFIG: adapter_config})
return self._checkpointer.save_checkpoint(
ckpt_dict,
epoch=epoch,
)
async def _loss_step(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
# Shape [b, s], needed for the loss not the model
labels = batch.pop("labels")
# run model
with self.activations_handling_ctx:
logits = self._model(**batch)
# Shift labels to compute loss
# equivalent to doing labels[..., 1:] and logits[..., :-1, :]
# But this way we don't need to slice the logits. We just add an ignore index to the labels.
labels = torch.hstack(
(labels[..., 1:], self.ignore_labels_cache[: labels.shape[0]])
)
if not isinstance(logits, list):
labels = labels.reshape(-1)
logits = logits.reshape(-1, logits.size(-1))
loss = self._loss_fn(logits, labels)
# free logits otherwise it peaks backward memory
del logits
return loss
async def train(self) -> Tuple[Dict[str, Any], List[Checkpoint]]:
"""
The core training loop.
"""
# Initialize tokens count and running loss (for grad accumulation)
t0 = time.perf_counter()
running_loss = 0
num_tokens = 0
# training artifacts
checkpoints = []
memory_stats = {}
# self.epochs_run should be non-zero when we're resuming from a checkpoint
for curr_epoch in range(self.epochs_run, self.total_epochs):
# Update the sampler to ensure data is correctly shuffled across epochs
# in case shuffle is True
metric_logger = DiskLogger(
log_dir=self._output_dir + f"/{self.model_id}-sft-{curr_epoch}"
)
self._training_sampler.set_epoch(curr_epoch)
loss_to_log = 0.0
pbar = tqdm(total=self._steps_per_epoch)
for idx, batch in enumerate(self._training_dataloader):
if (
self.max_steps_per_epoch is not None
and (idx // self._gradient_accumulation_steps)
== self.max_steps_per_epoch
):
break
torchtune_utils.batch_to_device(batch, self._device)
# Calculate the number of unmasked tokens in the current batch
# and increment the total number of tokens seen in the step
current_num_tokens = (
batch["labels"] != self._loss_fn.ignore_index
).sum()
num_tokens += current_num_tokens
# Loss is normalized by default so we multiply by the number of tokens
# This way we can normalize by the total number of tokens if we're accumulating gradients
current_loss = await self._loss_step(batch) * current_num_tokens
running_loss += current_loss
current_loss.backward()
# Step with optimizer
if (idx + 1) % self._gradient_accumulation_steps == 0:
training.scale_grads(self._model, 1 / num_tokens)
grad_norm = torch.nn.utils.clip_grad_norm_(
self._model.parameters(),
max_norm=float(self._clip_grad_norm),
)
self._optimizer.step()
self._optimizer.zero_grad(set_to_none=True)
self._lr_scheduler.step()
# Update the number of steps when the weights are updated
self.global_step += 1
loss_to_log = running_loss.item() / num_tokens
pbar.update(1)
pbar.set_description(
f"{curr_epoch + 1}|{self.global_step}|Loss: {loss_to_log}"
)
time_per_step = time.perf_counter() - t0
log_dict = {
"loss": loss_to_log,
"lr": self._optimizer.param_groups[0]["lr"],
"tokens_per_second_per_gpu": num_tokens / time_per_step,
}
memory_stats = training.get_memory_stats(device=self._device)
log_dict.update(memory_stats)
if self._clip_grad_norm is not None:
log_dict.update({"grad_norm": grad_norm})
metric_logger.log_dict(
log_dict,
step=self.global_step,
)
# Reset running stats for the next step
running_loss = 0
num_tokens = 0
t0 = time.perf_counter()
self.epochs_run += 1
log.info("Starting checkpoint save...")
checkpoint_path = await self.save_checkpoint(epoch=curr_epoch)
checkpoint = Checkpoint(
identifier=f"{self.model_id}-sft-{curr_epoch}",
created_at=datetime.now(),
epoch=curr_epoch,
post_training_job_id=self.job_uuid,
path=checkpoint_path,
)
if self.training_config.data_config.validation_dataset_id:
validation_loss, perplexity = await self.validation()
training_metrics = PostTrainingMetric(
epoch=curr_epoch,
train_loss=loss_to_log,
validation_loss=validation_loss,
perplexity=perplexity,
)
checkpoint.training_metrics = training_metrics
checkpoints.append(checkpoint)
return (memory_stats, checkpoints)
async def validation(self) -> Tuple[float, float]:
total_loss = 0.0
total_tokens = 0
log.info("Starting validation...")
pbar = tqdm(total=len(self._validation_dataloader))
for idx, batch in enumerate(self._validation_dataloader):
if idx == 10:
break
torchtune_utils.batch_to_device(batch, self._device)
# Calculate the number of unmasked tokens in the current batch
# and increment the total number of tokens seen in the step
num_tokens = (batch["labels"] != self._loss_fn.ignore_index).sum()
# Loss is normalized by default so we multiply by the number of tokens
# This way we can normalize by the total number of tokens if we're accumulating gradients
loss = await self._loss_step(batch) * num_tokens
total_loss += loss
total_tokens += num_tokens
pbar.update(1)
pbar.set_description(f"validation step: {idx}")
mean_loss = total_loss / total_tokens
perplexity = torch.exp(torch.tensor(mean_loss))
return mean_loss, perplexity.item()
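The validation pass above reports perplexity as the exponential of the token-averaged cross-entropy. A minimal standalone sketch of that relationship, using hypothetical per-batch loss sums and token counts rather than the recipe's own data:

import torch

# Hypothetical per-batch sums of token-level cross-entropy and token counts.
batch_loss_sums = torch.tensor([412.7, 395.1, 401.9])
batch_token_counts = torch.tensor([160, 152, 155])

# Mean negative log-likelihood per token, then perplexity = exp(mean NLL).
mean_loss = batch_loss_sums.sum() / batch_token_counts.sum()
perplexity = torch.exp(mean_loss)
print(f"mean loss: {mean_loss.item():.4f}, perplexity: {perplexity.item():.2f}")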


@@ -18,6 +18,7 @@ META_REFERENCE_DEPS = [
     "transformers",
     "zmq",
     "lm-format-enforcer",
+    "sentence-transformers",
 ]
@@ -52,6 +53,13 @@ def available_providers() -> List[ProviderSpec]:
         module="llama_stack.providers.inline.inference.vllm",
         config_class="llama_stack.providers.inline.inference.vllm.VLLMConfig",
     ),
+    InlineProviderSpec(
+        api=Api.inference,
+        provider_type="inline::sentence-transformers",
+        pip_packages=["sentence-transformers"],
+        module="llama_stack.providers.inline.inference.sentence_transformers",
+        config_class="llama_stack.providers.inline.inference.sentence_transformers.config.SentenceTransformersInferenceConfig",
+    ),
     remote_provider_spec(
         api=Api.inference,
         adapter=AdapterSpec(


@@ -39,6 +39,7 @@ def available_providers() -> List[ProviderSpec]:
         module="llama_stack.providers.inline.memory.faiss",
         config_class="llama_stack.providers.inline.memory.faiss.FaissImplConfig",
         deprecation_warning="Please use the `inline::faiss` provider instead.",
+        api_dependencies=[Api.inference],
     ),
     InlineProviderSpec(
         api=Api.memory,
@@ -46,6 +47,7 @@ def available_providers() -> List[ProviderSpec]:
         pip_packages=EMBEDDING_DEPS + ["faiss-cpu"],
         module="llama_stack.providers.inline.memory.faiss",
         config_class="llama_stack.providers.inline.memory.faiss.FaissImplConfig",
+        api_dependencies=[Api.inference],
     ),
     remote_provider_spec(
         Api.memory,
@@ -55,6 +57,7 @@ def available_providers() -> List[ProviderSpec]:
             module="llama_stack.providers.remote.memory.chroma",
            config_class="llama_stack.providers.remote.memory.chroma.ChromaRemoteImplConfig",
         ),
+        api_dependencies=[Api.inference],
     ),
     InlineProviderSpec(
         api=Api.memory,
@@ -71,6 +74,7 @@ def available_providers() -> List[ProviderSpec]:
             module="llama_stack.providers.remote.memory.pgvector",
             config_class="llama_stack.providers.remote.memory.pgvector.PGVectorConfig",
         ),
+        api_dependencies=[Api.inference],
     ),
     remote_provider_spec(
         Api.memory,
@@ -81,6 +85,7 @@ def available_providers() -> List[ProviderSpec]:
             config_class="llama_stack.providers.remote.memory.weaviate.WeaviateConfig",
             provider_data_validator="llama_stack.providers.remote.memory.weaviate.WeaviateRequestProviderData",
         ),
+        api_dependencies=[Api.inference],
     ),
     remote_provider_spec(
         api=Api.memory,
@@ -90,6 +95,7 @@ def available_providers() -> List[ProviderSpec]:
             module="llama_stack.providers.remote.memory.sample",
             config_class="llama_stack.providers.remote.memory.sample.SampleConfig",
         ),
+        api_dependencies=[],
     ),
     remote_provider_spec(
         Api.memory,
@@ -99,5 +105,6 @@ def available_providers() -> List[ProviderSpec]:
             module="llama_stack.providers.remote.memory.qdrant",
             config_class="llama_stack.providers.remote.memory.qdrant.QdrantConfig",
         ),
+        api_dependencies=[Api.inference],
     ),
 ]


@@ -0,0 +1,25 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import List
from llama_stack.distribution.datatypes import * # noqa: F403
def available_providers() -> List[ProviderSpec]:
return [
InlineProviderSpec(
api=Api.post_training,
provider_type="inline::torchtune",
pip_packages=["torch", "torchtune", "torchao", "numpy"],
module="llama_stack.providers.inline.post_training.torchtune",
config_class="llama_stack.providers.inline.post_training.torchtune.TorchtunePostTrainingConfig",
api_dependencies=[
Api.datasetio,
Api.datasets,
],
),
]
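The api_dependencies list is what hands the datasetio and datasets implementations to the torchtune provider at construction time. A hedged sketch of the corresponding wiring hook, modeled on the memory adapters' get_adapter_impl functions later in this diff (the function name, class name, and keyword arguments are assumptions, not taken from the change itself):

# Hypothetical wiring sketch; the actual torchtune module may differ.
from typing import Dict
from llama_stack.providers.datatypes import Api, ProviderSpec


async def get_provider_impl(config, deps: Dict[Api, ProviderSpec]):
    from .post_training import TorchtunePostTrainingImpl  # assumed class name

    # Dependencies declared in api_dependencies arrive resolved in `deps`.
    impl = TorchtunePostTrainingImpl(
        config,
        datasetio_api=deps[Api.datasetio],
        datasets_api=deps[Api.datasets],
    )
    return impl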


@@ -21,14 +21,19 @@ DATASETS_PREFIX = "datasets:"
 def load_hf_dataset(dataset_def: Dataset):
     if dataset_def.metadata.get("path", None):
-        return hf_datasets.load_dataset(**dataset_def.metadata)
-
-    df = get_dataframe_from_url(dataset_def.url)
-
-    if df is None:
-        raise ValueError(f"Failed to load dataset from {dataset_def.url}")
-
-    dataset = hf_datasets.Dataset.from_pandas(df)
+        dataset = hf_datasets.load_dataset(**dataset_def.metadata)
+    else:
+        df = get_dataframe_from_url(dataset_def.url)
+
+        if df is None:
+            raise ValueError(f"Failed to load dataset from {dataset_def.url}")
+
+        dataset = hf_datasets.Dataset.from_pandas(df)
+
+    # drop columns not specified by schema
+    if dataset_def.dataset_schema:
+        dataset = dataset.select_columns(list(dataset_def.dataset_schema.keys()))
+
     return dataset
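The reworked loader now trims any columns that are not declared in the dataset schema. A small hedged sketch of that trimming step in isolation (the in-memory dataset and column names are illustrative):

# Illustrative sketch of select_columns() on a tiny in-memory dataset.
import datasets as hf_datasets

dataset = hf_datasets.Dataset.from_dict(
    {"instruction": ["Say hi"], "output": ["Hi!"], "extra": ["dropped"]}
)
dataset_schema = {"instruction": None, "output": None}  # stand-in for the declared schema
dataset = dataset.select_columns(list(dataset_schema.keys()))
print(dataset.column_names)  # ['instruction', 'output']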


@@ -5,6 +5,7 @@
 # the root directory of this source tree.
 from typing import *  # noqa: F403
+import json
 from botocore.client import BaseClient
 from llama_models.datatypes import CoreModelId
@@ -19,8 +20,10 @@ from llama_stack.providers.utils.inference.model_registry import (
 from llama_stack.apis.inference import *  # noqa: F403
 from llama_stack.providers.remote.inference.bedrock.config import BedrockConfig
 from llama_stack.providers.utils.bedrock.client import create_bedrock_client
+from llama_stack.providers.utils.inference.prompt_adapter import content_has_media
 model_aliases = [
@@ -448,4 +451,21 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
         model_id: str,
         contents: List[InterleavedTextMedia],
     ) -> EmbeddingsResponse:
-        raise NotImplementedError()
+        model = await self.model_store.get_model(model_id)
+        embeddings = []
+        for content in contents:
+            assert not content_has_media(
+                content
+            ), "Bedrock does not support media for embeddings"
+            input_text = interleaved_text_media_as_str(content)
+            input_body = {"inputText": input_text}
+            body = json.dumps(input_body)
+            response = self.client.invoke_model(
+                body=body,
+                modelId=model.provider_resource_id,
+                accept="application/json",
+                contentType="application/json",
+            )
+            response_body = json.loads(response.get("body").read())
+            embeddings.append(response_body.get("embedding"))
+        return EmbeddingsResponse(embeddings=embeddings)


@@ -13,7 +13,7 @@ from pydantic import BaseModel, Field
 @json_schema_type
 class FireworksImplConfig(BaseModel):
     url: str = Field(
-        default="https://api.fireworks.ai/inference",
+        default="https://api.fireworks.ai/inference/v1",
         description="The URL for the Fireworks server",
     )
     api_key: Optional[str] = Field(
@@ -24,6 +24,6 @@ class FireworksImplConfig(BaseModel):
     @classmethod
     def sample_run_config(cls) -> Dict[str, Any]:
         return {
-            "url": "https://api.fireworks.ai/inference",
+            "url": "https://api.fireworks.ai/inference/v1",
             "api_key": "${env.FIREWORKS_API_KEY}",
         }


@@ -4,7 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
-from typing import AsyncGenerator
+from typing import AsyncGenerator, List, Optional, Union
 from fireworks.client import Fireworks
 from llama_models.datatypes import CoreModelId
@@ -28,6 +28,7 @@ from llama_stack.providers.utils.inference.openai_compat import (
 from llama_stack.providers.utils.inference.prompt_adapter import (
     chat_completion_request_to_prompt,
     completion_request_to_prompt,
+    content_has_media,
     convert_message_to_dict,
     request_has_media,
 )
@@ -89,17 +90,19 @@ class FireworksInferenceAdapter(
     async def shutdown(self) -> None:
         pass
-    def _get_client(self) -> Fireworks:
-        fireworks_api_key = None
+    def _get_api_key(self) -> str:
         if self.config.api_key is not None:
-            fireworks_api_key = self.config.api_key
+            return self.config.api_key
         else:
             provider_data = self.get_request_provider_data()
             if provider_data is None or not provider_data.fireworks_api_key:
                 raise ValueError(
                     'Pass Fireworks API Key in the header X-LlamaStack-ProviderData as { "fireworks_api_key": <your api key>}'
                 )
-            fireworks_api_key = provider_data.fireworks_api_key
+            return provider_data.fireworks_api_key
+    def _get_client(self) -> Fireworks:
+        fireworks_api_key = self._get_api_key()
         return Fireworks(api_key=fireworks_api_key)
     async def completion(
@@ -264,4 +267,19 @@ class FireworksInferenceAdapter(
         model_id: str,
         contents: List[InterleavedTextMedia],
     ) -> EmbeddingsResponse:
-        raise NotImplementedError()
+        model = await self.model_store.get_model(model_id)
+        kwargs = {}
+        if model.metadata.get("embedding_dimensions"):
+            kwargs["dimensions"] = model.metadata.get("embedding_dimensions")
+        assert all(
+            not content_has_media(content) for content in contents
+        ), "Fireworks does not support media for embeddings"
+        response = self._get_client().embeddings.create(
+            model=model.provider_resource_id,
+            input=[interleaved_text_media_as_str(content) for content in contents],
+            **kwargs,
+        )
+        embeddings = [data.embedding for data in response.data]
+        return EmbeddingsResponse(embeddings=embeddings)
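The Fireworks path forwards an optional embedding_dimensions entry from the model's metadata as the dimensions argument of the embeddings call. A hedged sketch of registering a model that would exercise that branch (the model id and dimension are illustrative, not taken from the diff):

# Hypothetical registration sketch; mirrors the ModelInput usage in the test
# fixtures later in this change set.
from llama_stack.apis.models import ModelInput, ModelType

embedding_model = ModelInput(
    model_id="nomic-ai/nomic-embed-text-v1.5",   # assumed provider model id
    model_type=ModelType.embedding,
    metadata={"embedding_dimensions": 768},       # forwarded as the `dimensions` kwarg
)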


@@ -36,6 +36,7 @@ from llama_stack.providers.utils.inference.openai_compat import (
 from llama_stack.providers.utils.inference.prompt_adapter import (
     chat_completion_request_to_prompt,
     completion_request_to_prompt,
+    content_has_media,
     convert_image_media_to_url,
     request_has_media,
 )
@@ -321,9 +322,30 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
         model_id: str,
         contents: List[InterleavedTextMedia],
     ) -> EmbeddingsResponse:
-        raise NotImplementedError()
+        model = await self.model_store.get_model(model_id)
+        assert all(
+            not content_has_media(content) for content in contents
+        ), "Ollama does not support media for embeddings"
+        response = await self.client.embed(
+            model=model.provider_resource_id,
+            input=[interleaved_text_media_as_str(content) for content in contents],
+        )
+        embeddings = response["embeddings"]
+        return EmbeddingsResponse(embeddings=embeddings)
     async def register_model(self, model: Model) -> Model:
+        # ollama does not have embedding models running. Check if the model is in list of available models.
+        if model.model_type == ModelType.embedding:
+            response = await self.client.list()
+            available_models = [m["model"] for m in response["models"]]
+            if model.provider_resource_id not in available_models:
+                raise ValueError(
+                    f"Model '{model.provider_resource_id}' is not available in Ollama. "
+                    f"Available models: {', '.join(available_models)}"
+                )
+            return model
         model = await self.register_helper.register_model(model)
         models = await self.client.ps()
         available_models = [m["model"] for m in models["models"]]


@@ -31,6 +31,7 @@ from llama_stack.providers.utils.inference.openai_compat import (
 from llama_stack.providers.utils.inference.prompt_adapter import (
     chat_completion_request_to_prompt,
     completion_request_to_prompt,
+    content_has_media,
     convert_message_to_dict,
     request_has_media,
 )
@@ -253,4 +254,13 @@ class TogetherInferenceAdapter(
         model_id: str,
         contents: List[InterleavedTextMedia],
     ) -> EmbeddingsResponse:
-        raise NotImplementedError()
+        model = await self.model_store.get_model(model_id)
+        assert all(
+            not content_has_media(content) for content in contents
+        ), "Together does not support media for embeddings"
+        r = self._get_client().embeddings.create(
+            model=model.provider_resource_id,
+            input=[interleaved_text_media_as_str(content) for content in contents],
+        )
+        embeddings = [item.embedding for item in r.data]
+        return EmbeddingsResponse(embeddings=embeddings)


@@ -29,6 +29,7 @@ from llama_stack.providers.utils.inference.openai_compat import (
 from llama_stack.providers.utils.inference.prompt_adapter import (
     chat_completion_request_to_prompt,
     completion_request_to_prompt,
+    content_has_media,
     convert_message_to_dict,
     request_has_media,
 )
@@ -203,4 +204,20 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
         model_id: str,
         contents: List[InterleavedTextMedia],
     ) -> EmbeddingsResponse:
-        raise NotImplementedError()
+        model = await self.model_store.get_model(model_id)
+        kwargs = {}
+        assert model.model_type == ModelType.embedding
+        assert model.metadata.get("embedding_dimensions")
+        kwargs["dimensions"] = model.metadata.get("embedding_dimensions")
+        assert all(
+            not content_has_media(content) for content in contents
+        ), "VLLM does not support media for embeddings"
+        response = self.client.embeddings.create(
+            model=model.provider_resource_id,
+            input=[interleaved_text_media_as_str(content) for content in contents],
+            **kwargs,
+        )
+        embeddings = [data.embedding for data in response.data]
+        return EmbeddingsResponse(embeddings=embeddings)


@@ -4,12 +4,18 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
+from typing import Dict
+from llama_stack.providers.datatypes import Api, ProviderSpec
 from .config import ChromaRemoteImplConfig
-async def get_adapter_impl(config: ChromaRemoteImplConfig, _deps):
+async def get_adapter_impl(
+    config: ChromaRemoteImplConfig, deps: Dict[Api, ProviderSpec]
+):
     from .chroma import ChromaMemoryAdapter
-    impl = ChromaMemoryAdapter(config)
+    impl = ChromaMemoryAdapter(config, deps[Api.inference])
     await impl.initialize()
     return impl


@@ -13,8 +13,7 @@ import chromadb
 from numpy.typing import NDArray
 from llama_stack.apis.memory import *  # noqa: F403
-from llama_stack.providers.datatypes import MemoryBanksProtocolPrivate
+from llama_stack.providers.datatypes import Api, MemoryBanksProtocolPrivate
 from llama_stack.providers.inline.memory.chroma import ChromaInlineImplConfig
 from llama_stack.providers.utils.memory.vector_store import (
     BankWithIndex,
@@ -87,10 +86,14 @@ class ChromaIndex(EmbeddingIndex):
 class ChromaMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
     def __init__(
-        self, config: Union[ChromaRemoteImplConfig, ChromaInlineImplConfig]
+        self,
+        config: Union[ChromaRemoteImplConfig, ChromaInlineImplConfig],
+        inference_api: Api.inference,
     ) -> None:
         log.info(f"Initializing ChromaMemoryAdapter with url: {config}")
         self.config = config
+        self.inference_api = inference_api
         self.client = None
         self.cache = {}
@@ -127,10 +130,9 @@ class ChromaMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
                 metadata={"bank": memory_bank.model_dump_json()},
             )
         )
-        bank_index = BankWithIndex(
-            bank=memory_bank, index=ChromaIndex(self.client, collection)
+        self.cache[memory_bank.identifier] = BankWithIndex(
+            memory_bank, ChromaIndex(self.client, collection), self.inference_api
         )
-        self.cache[memory_bank.identifier] = bank_index
     async def unregister_memory_bank(self, memory_bank_id: str) -> None:
         await self.cache[memory_bank_id].index.delete()
@@ -166,6 +168,8 @@ class ChromaMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
         collection = await maybe_await(self.client.get_collection(bank_id))
         if not collection:
             raise ValueError(f"Bank {bank_id} not found in Chroma")
-        index = BankWithIndex(bank=bank, index=ChromaIndex(self.client, collection))
+        index = BankWithIndex(
+            bank, ChromaIndex(self.client, collection), self.inference_api
+        )
         self.cache[bank_id] = index
         return index


@@ -4,12 +4,16 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
+from typing import Dict
+from llama_stack.providers.datatypes import Api, ProviderSpec
 from .config import PGVectorConfig
-async def get_adapter_impl(config: PGVectorConfig, _deps):
+async def get_adapter_impl(config: PGVectorConfig, deps: Dict[Api, ProviderSpec]):
     from .pgvector import PGVectorMemoryAdapter
-    impl = PGVectorMemoryAdapter(config)
+    impl = PGVectorMemoryAdapter(config, deps[Api.inference])
     await impl.initialize()
     return impl


@@ -16,9 +16,9 @@ from pydantic import BaseModel, parse_obj_as
 from llama_stack.apis.memory import *  # noqa: F403
-from llama_stack.providers.datatypes import MemoryBanksProtocolPrivate
+from llama_stack.providers.datatypes import Api, MemoryBanksProtocolPrivate
 from llama_stack.providers.utils.memory.vector_store import (
-    ALL_MINILM_L6_V2_DIMENSION,
     BankWithIndex,
     EmbeddingIndex,
 )
@@ -120,8 +120,9 @@ class PGVectorIndex(EmbeddingIndex):
 class PGVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
-    def __init__(self, config: PGVectorConfig) -> None:
+    def __init__(self, config: PGVectorConfig, inference_api: Api.inference) -> None:
         self.config = config
+        self.inference_api = inference_api
         self.cursor = None
         self.conn = None
         self.cache = {}
@@ -160,27 +161,17 @@ class PGVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
     async def shutdown(self) -> None:
         pass
-    async def register_memory_bank(
-        self,
-        memory_bank: MemoryBank,
-    ) -> None:
+    async def register_memory_bank(self, memory_bank: MemoryBank) -> None:
         assert (
             memory_bank.memory_bank_type == MemoryBankType.vector.value
         ), f"Only vector banks are supported {memory_bank.memory_bank_type}"
-        upsert_models(
-            self.cursor,
-            [
-                (memory_bank.identifier, memory_bank),
-            ],
-        )
-        index = BankWithIndex(
-            bank=memory_bank,
-            index=PGVectorIndex(memory_bank, ALL_MINILM_L6_V2_DIMENSION, self.cursor),
-        )
-        self.cache[memory_bank.identifier] = index
+        upsert_models(self.cursor, [(memory_bank.identifier, memory_bank)])
+        index = PGVectorIndex(memory_bank, memory_bank.embedding_dimension, self.cursor)
+        self.cache[memory_bank.identifier] = BankWithIndex(
+            memory_bank, index, self.inference_api
+        )
     async def unregister_memory_bank(self, memory_bank_id: str) -> None:
         await self.cache[memory_bank_id].index.delete()
         del self.cache[memory_bank_id]
@@ -203,14 +194,13 @@ class PGVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
         index = await self._get_and_cache_bank_index(bank_id)
         return await index.query_documents(query, params)
+        self.inference_api = inference_api
     async def _get_and_cache_bank_index(self, bank_id: str) -> BankWithIndex:
         if bank_id in self.cache:
             return self.cache[bank_id]
         bank = await self.memory_bank_store.get_memory_bank(bank_id)
-        index = BankWithIndex(
-            bank=bank,
-            index=PGVectorIndex(bank, ALL_MINILM_L6_V2_DIMENSION, self.cursor),
-        )
-        self.cache[bank_id] = index
-        return index
+        index = PGVectorIndex(bank, bank.embedding_dimension, self.cursor)
+        self.cache[bank_id] = BankWithIndex(bank, index, self.inference_api)
+        return self.cache[bank_id]


@@ -4,12 +4,16 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
+from typing import Dict
+from llama_stack.providers.datatypes import Api, ProviderSpec
 from .config import QdrantConfig
-async def get_adapter_impl(config: QdrantConfig, _deps):
+async def get_adapter_impl(config: QdrantConfig, deps: Dict[Api, ProviderSpec]):
     from .qdrant import QdrantVectorMemoryAdapter
-    impl = QdrantVectorMemoryAdapter(config)
+    impl = QdrantVectorMemoryAdapter(config, deps[Api.inference])
     await impl.initialize()
     return impl


@@ -101,10 +101,11 @@ class QdrantIndex(EmbeddingIndex):
 class QdrantVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
-    def __init__(self, config: QdrantConfig) -> None:
+    def __init__(self, config: QdrantConfig, inference_api: Api.inference) -> None:
         self.config = config
         self.client = AsyncQdrantClient(**self.config.model_dump(exclude_none=True))
         self.cache = {}
+        self.inference_api = inference_api
     async def initialize(self) -> None:
         pass
@@ -123,6 +124,7 @@ class QdrantVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
         index = BankWithIndex(
             bank=memory_bank,
             index=QdrantIndex(self.client, memory_bank.identifier),
+            inference_api=self.inference_api,
         )
         self.cache[memory_bank.identifier] = index
@@ -138,6 +140,7 @@ class QdrantVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
         index = BankWithIndex(
             bank=bank,
             index=QdrantIndex(client=self.client, collection_name=bank_id),
+            inference_api=self.inference_api,
         )
         self.cache[bank_id] = index
         return index


@@ -4,12 +4,16 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
+from typing import Dict
+from llama_stack.providers.datatypes import Api, ProviderSpec
 from .config import WeaviateConfig, WeaviateRequestProviderData  # noqa: F401
-async def get_adapter_impl(config: WeaviateConfig, _deps):
+async def get_adapter_impl(config: WeaviateConfig, deps: Dict[Api, ProviderSpec]):
     from .weaviate import WeaviateMemoryAdapter
-    impl = WeaviateMemoryAdapter(config)
+    impl = WeaviateMemoryAdapter(config, deps[Api.inference])
     await impl.initialize()
     return impl


@@ -12,10 +12,11 @@ import weaviate
 import weaviate.classes as wvc
 from numpy.typing import NDArray
 from weaviate.classes.init import Auth
+from weaviate.classes.query import Filter
 from llama_stack.apis.memory import *  # noqa: F403
 from llama_stack.distribution.request_headers import NeedsRequestProviderData
-from llama_stack.providers.datatypes import MemoryBanksProtocolPrivate
+from llama_stack.providers.datatypes import Api, MemoryBanksProtocolPrivate
 from llama_stack.providers.utils.memory.vector_store import (
     BankWithIndex,
     EmbeddingIndex,
@@ -80,12 +81,21 @@ class WeaviateIndex(EmbeddingIndex):
         return QueryDocumentsResponse(chunks=chunks, scores=scores)
+    async def delete(self, chunk_ids: List[str]) -> None:
+        collection = self.client.collections.get(self.collection_name)
+        collection.data.delete_many(
+            where=Filter.by_property("id").contains_any(chunk_ids)
+        )
 class WeaviateMemoryAdapter(
-    Memory, NeedsRequestProviderData, MemoryBanksProtocolPrivate
+    Memory,
+    NeedsRequestProviderData,
+    MemoryBanksProtocolPrivate,
 ):
-    def __init__(self, config: WeaviateConfig) -> None:
+    def __init__(self, config: WeaviateConfig, inference_api: Api.inference) -> None:
         self.config = config
+        self.inference_api = inference_api
         self.client_cache = {}
         self.cache = {}
@@ -117,7 +127,7 @@ class WeaviateMemoryAdapter(
         memory_bank: MemoryBank,
     ) -> None:
         assert (
-            memory_bank.memory_bank_type == MemoryBankType.vector
+            memory_bank.memory_bank_type == MemoryBankType.vector.value
         ), f"Only vector banks are supported {memory_bank.memory_bank_type}"
         client = self._get_client()
@@ -135,11 +145,11 @@ class WeaviateMemoryAdapter(
             ],
         )
-        index = BankWithIndex(
-            bank=memory_bank,
-            index=WeaviateIndex(client=client, collection_name=memory_bank.identifier),
+        self.cache[memory_bank.identifier] = BankWithIndex(
+            memory_bank,
+            WeaviateIndex(client=client, collection_name=memory_bank.identifier),
+            self.inference_api,
         )
-        self.cache[memory_bank.identifier] = index
     async def _get_and_cache_bank_index(self, bank_id: str) -> Optional[BankWithIndex]:
         if bank_id in self.cache:
@@ -156,6 +166,7 @@ class WeaviateMemoryAdapter(
         index = BankWithIndex(
             bank=bank,
             index=WeaviateIndex(client=client, collection_name=bank_id),
+            inference_api=self.inference_api,
         )
         self.cache[bank_id] = index
         return index


@@ -156,4 +156,5 @@ pytest_plugins = [
     "llama_stack.providers.tests.datasetio.fixtures",
     "llama_stack.providers.tests.scoring.fixtures",
     "llama_stack.providers.tests.eval.fixtures",
+    "llama_stack.providers.tests.post_training.fixtures",
 ]


@@ -10,6 +10,7 @@ import pytest_asyncio
 from llama_stack.distribution.datatypes import Api, Provider
 from llama_stack.providers.tests.resolver import construct_stack_for_test
+from ..conftest import ProviderFixture, remote_stack_fixture


@@ -18,6 +18,12 @@ def pytest_addoption(parser):
         default=None,
         help="Specify the inference model to use for testing",
     )
+    parser.addoption(
+        "--embedding-model",
+        action="store",
+        default=None,
+        help="Specify the embedding model to use for testing",
+    )
 def pytest_configure(config):


@@ -9,9 +9,9 @@ import os
 import pytest
 import pytest_asyncio
-from llama_stack.apis.models import ModelInput
+from llama_stack.apis.models import ModelInput, ModelType
 from llama_stack.distribution.datatypes import Api, Provider
 from llama_stack.providers.inline.inference.meta_reference import (
     MetaReferenceInferenceConfig,
 )
@@ -47,6 +47,9 @@ def inference_meta_reference(inference_model) -> ProviderFixture:
     inference_model = (
         [inference_model] if isinstance(inference_model, str) else inference_model
     )
+    # If embedding dimension is set, use the 8B model for testing
+    if os.getenv("EMBEDDING_DIMENSION"):
+        inference_model = ["meta-llama/Llama-3.1-8B-Instruct"]
     return ProviderFixture(
         providers=[
@@ -85,7 +88,7 @@ def inference_ollama(inference_model) -> ProviderFixture:
     inference_model = (
         [inference_model] if isinstance(inference_model, str) else inference_model
     )
-    if "Llama3.1-8B-Instruct" in inference_model:
+    if inference_model and "Llama3.1-8B-Instruct" in inference_model:
         pytest.skip("Ollama only supports Llama3.2-3B-Instruct for testing")
     return ProviderFixture(
@@ -232,11 +235,23 @@ INFERENCE_FIXTURES = [
 async def inference_stack(request, inference_model):
     fixture_name = request.param
     inference_fixture = request.getfixturevalue(f"inference_{fixture_name}")
+    model_type = ModelType.llm
+    metadata = {}
+    if os.getenv("EMBEDDING_DIMENSION"):
+        model_type = ModelType.embedding
+        metadata["embedding_dimension"] = get_env_or_fail("EMBEDDING_DIMENSION")
     test_stack = await construct_stack_for_test(
         [Api.inference],
         {"inference": inference_fixture.providers},
         inference_fixture.provider_data,
-        models=[ModelInput(model_id=inference_model)],
+        models=[
+            ModelInput(
+                model_id=inference_model,
+                model_type=model_type,
+                metadata=metadata,
+            )
+        ],
     )
     return test_stack.impls[Api.inference], test_stack.impls[Api.models]


@@ -0,0 +1,62 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import pytest
from llama_stack.apis.inference import EmbeddingsResponse, ModelType
# How to run this test:
# pytest -v -s llama_stack/providers/tests/inference/test_embeddings.py
class TestEmbeddings:
@pytest.mark.asyncio
async def test_embeddings(self, inference_model, inference_stack):
inference_impl, models_impl = inference_stack
model = await models_impl.get_model(inference_model)
if model.model_type != ModelType.embedding:
pytest.skip("This test is only applicable for embedding models")
response = await inference_impl.embeddings(
model_id=inference_model,
contents=["Hello, world!"],
)
assert isinstance(response, EmbeddingsResponse)
assert len(response.embeddings) > 0
assert all(isinstance(embedding, list) for embedding in response.embeddings)
assert all(
isinstance(value, float)
for embedding in response.embeddings
for value in embedding
)
@pytest.mark.asyncio
async def test_batch_embeddings(self, inference_model, inference_stack):
inference_impl, models_impl = inference_stack
model = await models_impl.get_model(inference_model)
if model.model_type != ModelType.embedding:
pytest.skip("This test is only applicable for embedding models")
texts = ["Hello, world!", "This is a test", "Testing embeddings"]
response = await inference_impl.embeddings(
model_id=inference_model,
contents=texts,
)
assert isinstance(response, EmbeddingsResponse)
assert len(response.embeddings) == len(texts)
assert all(isinstance(embedding, list) for embedding in response.embeddings)
assert all(
isinstance(value, float)
for embedding in response.embeddings
for value in embedding
)
embedding_dim = len(response.embeddings[0])
assert all(len(embedding) == embedding_dim for embedding in response.embeddings)


@@ -128,6 +128,61 @@ class TestInference:
         last = chunks[-1]
         assert last.stop_reason == StopReason.out_of_tokens
+    @pytest.mark.asyncio
+    async def test_completion_logprobs(self, inference_model, inference_stack):
+        inference_impl, _ = inference_stack
+        provider = inference_impl.routing_table.get_provider_impl(inference_model)
+        if provider.__provider_spec__.provider_type not in (
+            # "remote::nvidia", -- provider doesn't provide all logprobs
+        ):
+            pytest.skip("Other inference providers don't support completion() yet")
+        response = await inference_impl.completion(
+            content="Micheael Jordan is born in ",
+            stream=False,
+            model_id=inference_model,
+            sampling_params=SamplingParams(
+                max_tokens=5,
+            ),
+            logprobs=LogProbConfig(
+                top_k=3,
+            ),
+        )
+        assert isinstance(response, CompletionResponse)
+        assert 1 <= len(response.logprobs) <= 5
+        assert response.logprobs, "Logprobs should not be empty"
+        assert all(len(logprob.logprobs_by_token) == 3 for logprob in response.logprobs)
+        chunks = [
+            r
+            async for r in await inference_impl.completion(
+                content="Roses are red,",
+                stream=True,
+                model_id=inference_model,
+                sampling_params=SamplingParams(
+                    max_tokens=5,
+                ),
+                logprobs=LogProbConfig(
+                    top_k=3,
+                ),
+            )
+        ]
+        assert all(isinstance(chunk, CompletionResponseStreamChunk) for chunk in chunks)
+        assert (
+            1 <= len(chunks) <= 6
+        )  # why 6 and not 5? the response may have an extra closing chunk, e.g. for usage or stop_reason
+        for chunk in chunks:
+            if chunk.delta:  # if there's a token, we expect logprobs
+                assert chunk.logprobs, "Logprobs should not be empty"
+                assert all(
+                    len(logprob.logprobs_by_token) == 3 for logprob in chunk.logprobs
+                )
+            else:  # no token, no logprobs
+                assert not chunk.logprobs, "Logprobs should be empty"
     @pytest.mark.asyncio
     @pytest.mark.skip("This test is not quite robust")
     async def test_completion_structured_output(self, inference_model, inference_stack):


@@ -6,9 +6,65 @@
 import pytest
+from ..conftest import get_provider_fixture_overrides
+from ..inference.fixtures import INFERENCE_FIXTURES
 from .fixtures import MEMORY_FIXTURES
+DEFAULT_PROVIDER_COMBINATIONS = [
+    pytest.param(
+        {
+            "inference": "meta_reference",
+            "memory": "faiss",
+        },
+        id="meta_reference",
+        marks=pytest.mark.meta_reference,
+    ),
+    pytest.param(
+        {
+            "inference": "ollama",
+            "memory": "pgvector",
+        },
+        id="ollama",
+        marks=pytest.mark.ollama,
+    ),
+    pytest.param(
+        {
+            "inference": "together",
+            "memory": "chroma",
+        },
+        id="chroma",
+        marks=pytest.mark.chroma,
+    ),
+    pytest.param(
+        {
+            "inference": "bedrock",
+            "memory": "qdrant",
+        },
+        id="qdrant",
+        marks=pytest.mark.qdrant,
+    ),
+    pytest.param(
+        {
+            "inference": "fireworks",
+            "memory": "weaviate",
+        },
+        id="weaviate",
+        marks=pytest.mark.weaviate,
+    ),
+]
+def pytest_addoption(parser):
+    parser.addoption(
+        "--inference-model",
+        action="store",
+        default=None,
+        help="Specify the inference model to use for testing",
+    )
 def pytest_configure(config):
     for fixture_name in MEMORY_FIXTURES:
         config.addinivalue_line(
@@ -18,12 +74,22 @@ def pytest_configure(config):
 def pytest_generate_tests(metafunc):
+    if "inference_model" in metafunc.fixturenames:
+        model = metafunc.config.getoption("--inference-model")
+        if not model:
+            raise ValueError(
+                "No inference model specified. Please provide a valid inference model."
+            )
+        params = [pytest.param(model, id="")]
+        metafunc.parametrize("inference_model", params, indirect=True)
     if "memory_stack" in metafunc.fixturenames:
-        metafunc.parametrize(
-            "memory_stack",
-            [
-                pytest.param(fixture_name, marks=getattr(pytest.mark, fixture_name))
-                for fixture_name in MEMORY_FIXTURES
-            ],
-            indirect=True,
-        )
+        available_fixtures = {
+            "inference": INFERENCE_FIXTURES,
+            "memory": MEMORY_FIXTURES,
+        }
+        combinations = (
+            get_provider_fixture_overrides(metafunc.config, available_fixtures)
+            or DEFAULT_PROVIDER_COMBINATIONS
+        )
+        metafunc.parametrize("memory_stack", combinations, indirect=True)


@@ -10,6 +10,8 @@ import tempfile
 import pytest
 import pytest_asyncio
+from llama_stack.apis.inference import ModelInput, ModelType
 from llama_stack.distribution.datatypes import Api, Provider
 from llama_stack.providers.inline.memory.chroma import ChromaInlineImplConfig
 from llama_stack.providers.inline.memory.faiss import FaissImplConfig
@@ -105,14 +107,30 @@ MEMORY_FIXTURES = ["faiss", "pgvector", "weaviate", "remote", "chroma"]
 @pytest_asyncio.fixture(scope="session")
-async def memory_stack(request):
-    fixture_name = request.param
-    fixture = request.getfixturevalue(f"memory_{fixture_name}")
+async def memory_stack(inference_model, request):
+    fixture_dict = request.param
+    providers = {}
+    provider_data = {}
+    for key in ["inference", "memory"]:
+        fixture = request.getfixturevalue(f"{key}_{fixture_dict[key]}")
+        providers[key] = fixture.providers
+        if fixture.provider_data:
+            provider_data.update(fixture.provider_data)
     test_stack = await construct_stack_for_test(
-        [Api.memory],
-        {"memory": fixture.providers},
-        fixture.provider_data,
+        [Api.memory, Api.inference],
+        providers,
+        provider_data,
+        models=[
+            ModelInput(
+                model_id=inference_model,
+                model_type=ModelType.embedding,
+                metadata={
+                    "embedding_dimension": get_env_or_fail("EMBEDDING_DIMENSION"),
+                },
+            )
+        ],
     )
     return test_stack.impls[Api.memory], test_stack.impls[Api.memory_banks]


@@ -45,12 +45,14 @@ def sample_documents():
     ]
-async def register_memory_bank(banks_impl: MemoryBanks) -> MemoryBank:
+async def register_memory_bank(
+    banks_impl: MemoryBanks, inference_model: str
+) -> MemoryBank:
     bank_id = f"test_bank_{uuid.uuid4().hex}"
     return await banks_impl.register_memory_bank(
         memory_bank_id=bank_id,
         params=VectorMemoryBankParams(
-            embedding_model="all-MiniLM-L6-v2",
+            embedding_model=inference_model,
             chunk_size_in_tokens=512,
             overlap_size_in_tokens=64,
         ),
@@ -59,11 +61,11 @@ async def register_memory_bank(banks_impl: MemoryBanks) -> MemoryBank:
 class TestMemory:
     @pytest.mark.asyncio
-    async def test_banks_list(self, memory_stack):
+    async def test_banks_list(self, memory_stack, inference_model):
         _, banks_impl = memory_stack
         # Register a test bank
-        registered_bank = await register_memory_bank(banks_impl)
+        registered_bank = await register_memory_bank(banks_impl, inference_model)
         try:
             # Verify our bank shows up in list
@@ -84,7 +86,7 @@ class TestMemory:
             )
     @pytest.mark.asyncio
-    async def test_banks_register(self, memory_stack):
+    async def test_banks_register(self, memory_stack, inference_model):
         _, banks_impl = memory_stack
         bank_id = f"test_bank_{uuid.uuid4().hex}"
@@ -94,7 +96,7 @@ class TestMemory:
             await banks_impl.register_memory_bank(
                 memory_bank_id=bank_id,
                 params=VectorMemoryBankParams(
-                    embedding_model="all-MiniLM-L6-v2",
+                    embedding_model=inference_model,
                     chunk_size_in_tokens=512,
                     overlap_size_in_tokens=64,
                 ),
@@ -109,7 +111,7 @@ class TestMemory:
             await banks_impl.register_memory_bank(
                 memory_bank_id=bank_id,
                 params=VectorMemoryBankParams(
-                    embedding_model="all-MiniLM-L6-v2",
+                    embedding_model=inference_model,
                     chunk_size_in_tokens=512,
                     overlap_size_in_tokens=64,
                 ),
@@ -126,13 +128,15 @@ class TestMemory:
             await banks_impl.unregister_memory_bank(bank_id)
     @pytest.mark.asyncio
-    async def test_query_documents(self, memory_stack, sample_documents):
+    async def test_query_documents(
+        self, memory_stack, inference_model, sample_documents
+    ):
         memory_impl, banks_impl = memory_stack
         with pytest.raises(ValueError):
             await memory_impl.insert_documents("test_bank", sample_documents)
-        registered_bank = await register_memory_bank(banks_impl)
+        registered_bank = await register_memory_bank(banks_impl, inference_model)
         await memory_impl.insert_documents(
             registered_bank.memory_bank_id, sample_documents
         )
@@ -165,13 +169,13 @@ class TestMemory:
         # Test case 5: Query with threshold on similarity score
         query5 = "quantum computing"  # Not directly related to any document
-        params5 = {"score_threshold": 0.2}
+        params5 = {"score_threshold": 0.01}
         response5 = await memory_impl.query_documents(
             registered_bank.memory_bank_id, query5, params5
         )
         assert_valid_response(response5)
         print("The scores are:", response5.scores)
-        assert all(score >= 0.2 for score in response5.scores)
+        assert all(score >= 0.01 for score in response5.scores)
 def assert_valid_response(response: QueryDocumentsResponse):


@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.


@@ -0,0 +1,45 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import pytest
from ..conftest import get_provider_fixture_overrides
from ..datasetio.fixtures import DATASETIO_FIXTURES
from .fixtures import POST_TRAINING_FIXTURES
DEFAULT_PROVIDER_COMBINATIONS = [
pytest.param(
{
"post_training": "torchtune",
"datasetio": "huggingface",
},
id="torchtune_post_training_huggingface_datasetio",
marks=pytest.mark.torchtune_post_training_huggingface_datasetio,
),
]
def pytest_configure(config):
combined_fixtures = "torchtune_post_training_huggingface_datasetio"
config.addinivalue_line(
"markers",
f"{combined_fixtures}: marks tests as {combined_fixtures} specific",
)
def pytest_generate_tests(metafunc):
if "post_training_stack" in metafunc.fixturenames:
available_fixtures = {
"eval": POST_TRAINING_FIXTURES,
"datasetio": DATASETIO_FIXTURES,
}
combinations = (
get_provider_fixture_overrides(metafunc.config, available_fixtures)
or DEFAULT_PROVIDER_COMBINATIONS
)
metafunc.parametrize("post_training_stack", combinations, indirect=True)


@@ -0,0 +1,74 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import pytest
import pytest_asyncio
from llama_models.llama3.api.datatypes import URL
from llama_stack.apis.common.type_system import * # noqa: F403
from llama_stack.apis.datasets import DatasetInput
from llama_stack.apis.models import ModelInput
from llama_stack.distribution.datatypes import Api, Provider
from llama_stack.providers.tests.resolver import construct_stack_for_test
from ..conftest import ProviderFixture
@pytest.fixture(scope="session")
def post_training_torchtune() -> ProviderFixture:
return ProviderFixture(
providers=[
Provider(
provider_id="torchtune",
provider_type="inline::torchtune",
config={},
)
],
)
POST_TRAINING_FIXTURES = ["torchtune"]
@pytest_asyncio.fixture(scope="session")
async def post_training_stack(request):
fixture_dict = request.param
providers = {}
provider_data = {}
for key in ["post_training", "datasetio"]:
fixture = request.getfixturevalue(f"{key}_{fixture_dict[key]}")
providers[key] = fixture.providers
if fixture.provider_data:
provider_data.update(fixture.provider_data)
test_stack = await construct_stack_for_test(
[Api.post_training, Api.datasetio],
providers,
provider_data,
models=[ModelInput(model_id="meta-llama/Llama-3.2-3B-Instruct")],
datasets=[
DatasetInput(
dataset_id="alpaca",
provider_id="huggingface",
url=URL(uri="https://huggingface.co/datasets/tatsu-lab/alpaca"),
metadata={
"path": "tatsu-lab/alpaca",
"split": "train",
},
dataset_schema={
"instruction": StringType(),
"input": StringType(),
"output": StringType(),
"text": StringType(),
},
),
],
)
return test_stack.impls[Api.post_training]


@@ -0,0 +1,92 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import pytest
from llama_stack.apis.common.type_system import * # noqa: F403
from llama_stack.apis.post_training import * # noqa: F403
from llama_stack.distribution.datatypes import * # noqa: F403
# How to run this test:
#
# pytest llama_stack/providers/tests/post_training/test_post_training.py
# -m "torchtune_post_training_huggingface_datasetio"
# -v -s --tb=short --disable-warnings
class TestPostTraining:
@pytest.mark.asyncio
async def test_supervised_fine_tune(self, post_training_stack):
algorithm_config = LoraFinetuningConfig(
type="LoRA",
lora_attn_modules=["q_proj", "v_proj", "output_proj"],
apply_lora_to_mlp=True,
apply_lora_to_output=False,
rank=8,
alpha=16,
)
data_config = DataConfig(
dataset_id="alpaca",
batch_size=1,
shuffle=False,
)
optimizer_config = OptimizerConfig(
optimizer_type="adamw",
lr=3e-4,
lr_min=3e-5,
weight_decay=0.1,
num_warmup_steps=100,
)
training_config = TrainingConfig(
n_epochs=1,
data_config=data_config,
optimizer_config=optimizer_config,
max_steps_per_epoch=1,
gradient_accumulation_steps=1,
)
post_training_impl = post_training_stack
response = await post_training_impl.supervised_fine_tune(
job_uuid="1234",
model="Llama3.2-3B-Instruct",
algorithm_config=algorithm_config,
training_config=training_config,
hyperparam_search_config={},
logger_config={},
checkpoint_dir="null",
)
assert isinstance(response, PostTrainingJob)
assert response.job_uuid == "1234"
@pytest.mark.asyncio
async def test_get_training_jobs(self, post_training_stack):
post_training_impl = post_training_stack
jobs_list = await post_training_impl.get_training_jobs()
assert isinstance(jobs_list, List)
assert jobs_list[0].job_uuid == "1234"
@pytest.mark.asyncio
async def test_get_training_job_status(self, post_training_stack):
post_training_impl = post_training_stack
job_status = await post_training_impl.get_training_job_status("1234")
assert isinstance(job_status, PostTrainingJobStatusResponse)
assert job_status.job_uuid == "1234"
assert job_status.status == JobStatus.completed
assert isinstance(job_status.checkpoints[0], Checkpoint)
@pytest.mark.asyncio
async def test_get_training_job_artifacts(self, post_training_stack):
post_training_impl = post_training_stack
job_artifacts = await post_training_impl.get_training_job_artifacts("1234")
assert isinstance(job_artifacts, PostTrainingJobArtifactsResponse)
assert job_artifacts.job_uuid == "1234"
assert isinstance(job_artifacts.checkpoints[0], Checkpoint)
assert job_artifacts.checkpoints[0].identifier == "Llama3.2-3B-Instruct-sft-0"
assert job_artifacts.checkpoints[0].epoch == 0
assert (
"/.llama/checkpoints/Llama3.2-3B-Instruct-sft-0"
in job_artifacts.checkpoints[0].path
)


@@ -0,0 +1,47 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import logging
from typing import List
from llama_models.llama3.api.datatypes import InterleavedTextMedia
from llama_stack.apis.inference.inference import EmbeddingsResponse, ModelStore
EMBEDDING_MODELS = {}
log = logging.getLogger(__name__)
class SentenceTransformerEmbeddingMixin:
model_store: ModelStore
async def embeddings(
self,
model_id: str,
contents: List[InterleavedTextMedia],
) -> EmbeddingsResponse:
model = await self.model_store.get_model(model_id)
embedding_model = self._load_sentence_transformer_model(
model.provider_resource_id
)
embeddings = embedding_model.encode(contents)
return EmbeddingsResponse(embeddings=embeddings)
def _load_sentence_transformer_model(self, model: str) -> "SentenceTransformer":
global EMBEDDING_MODELS
loaded_model = EMBEDDING_MODELS.get(model)
if loaded_model is not None:
return loaded_model
log.info(f"Loading sentence transformer for {model}...")
from sentence_transformers import SentenceTransformer
loaded_model = SentenceTransformer(model)
EMBEDDING_MODELS[model] = loaded_model
return loaded_model
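A provider that wants this behaviour only needs a model store alongside the mixin. A hedged sketch of such a composition (the class is an assumption for illustration, not the provider actually added by this change):

# Hypothetical provider sketch built on the mixin above; the real
# sentence-transformers provider may be organized differently.
from llama_stack.apis.inference.inference import ModelStore


class MinimalSentenceTransformersProvider(SentenceTransformerEmbeddingMixin):
    def __init__(self, config, model_store: ModelStore) -> None:
        self.config = config
        self.model_store = model_store  # the mixin resolves provider model ids from here

    async def initialize(self) -> None:
        pass

    async def shutdown(self) -> None:
        pass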


@@ -9,6 +9,7 @@ from typing import List, Optional
 from llama_models.sku_list import all_registered_models
+from llama_stack.apis.models.models import ModelType
 from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate
 from llama_stack.providers.utils.inference import (
@@ -77,7 +78,13 @@ class ModelRegistryHelper(ModelsProtocolPrivate):
         return None
     async def register_model(self, model: Model) -> Model:
-        provider_resource_id = self.get_provider_model_id(model.provider_resource_id)
+        if model.model_type == ModelType.embedding:
+            # embedding models are always registered by their provider model id and does not need to be mapped to a llama model
+            provider_resource_id = model.provider_resource_id
+        else:
+            provider_resource_id = self.get_provider_model_id(
+                model.provider_resource_id
+            )
         if provider_resource_id:
             model.provider_resource_id = provider_resource_id
         else:


@@ -22,28 +22,10 @@ from llama_models.llama3.api.datatypes import * # noqa: F403
 from llama_models.llama3.api.tokenizer import Tokenizer

 from llama_stack.apis.memory import * # noqa: F403
+from llama_stack.providers.datatypes import Api

 log = logging.getLogger(__name__)

-ALL_MINILM_L6_V2_DIMENSION = 384
-EMBEDDING_MODELS = {}
-
-
-def get_embedding_model(model: str) -> "SentenceTransformer":
-    global EMBEDDING_MODELS
-
-    loaded_model = EMBEDDING_MODELS.get(model)
-    if loaded_model is not None:
-        return loaded_model
-
-    log.info(f"Loading sentence transformer for {model}...")
-    from sentence_transformers import SentenceTransformer
-
-    loaded_model = SentenceTransformer(model)
-    EMBEDDING_MODELS[model] = loaded_model
-    return loaded_model
-
-
 def parse_pdf(data: bytes) -> str:
     # For PDF and DOC/DOCX files, we can't reliably convert to string
@@ -166,12 +148,12 @@ class EmbeddingIndex(ABC):
 class BankWithIndex:
     bank: VectorMemoryBank
     index: EmbeddingIndex
+    inference_api: Api.inference

     async def insert_documents(
         self,
         documents: List[MemoryBankDocument],
     ) -> None:
-        model = get_embedding_model(self.bank.embedding_model)
         for doc in documents:
             content = await content_from_doc(doc)
             chunks = make_overlapped_chunks(
@@ -183,7 +165,10 @@ class BankWithIndex:
             )
             if not chunks:
                 continue
-            embeddings = model.encode([x.content for x in chunks]).astype(np.float32)
+            embeddings_response = await self.inference_api.embeddings(
+                self.bank.embedding_model, [x.content for x in chunks]
+            )
+            embeddings = np.array(embeddings_response.embeddings)

             await self.index.add_chunks(chunks, embeddings)
@@ -208,6 +193,8 @@ class BankWithIndex:
         else:
             query_str = _process(query)

-        model = get_embedding_model(self.bank.embedding_model)
-        query_vector = model.encode([query_str])[0].astype(np.float32)
+        embeddings_response = await self.inference_api.embeddings(
+            self.bank.embedding_model, [query_str]
+        )
+        query_vector = np.array(embeddings_response.embeddings[0], dtype=np.float32)
         return await self.index.query(query_vector, k, score_threshold)
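With this change, `BankWithIndex` no longer loads sentence-transformers itself; all it needs is an object exposing `await embeddings(model_id, contents)` that returns something with an `embeddings` attribute. The sketch below illustrates just that contract with a stand-in `FakeInference` class of our own, not the stack's real inference implementation.

```python
# Illustrative sketch of the dependency BankWithIndex now takes (hypothetical).
import asyncio

import numpy as np


class FakeInference:
    """Stand-in for the stack's inference API, for illustration only."""

    async def embeddings(self, model_id: str, contents: list[str]):
        # Deterministic 384-dim vectors shaped like the real response object.
        vectors = [np.full(384, float(len(text))) for text in contents]
        return type("EmbeddingsResponse", (), {"embeddings": vectors})()


async def main():
    inference_api = FakeInference()
    response = await inference_api.embeddings("all-MiniLM-L6-v2", ["hello memory bank"])
    query_vector = np.array(response.embeddings[0], dtype=np.float32)
    print(query_vector.shape)  # (384,)


asyncio.run(main())
```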

View file

@@ -8,10 +8,14 @@ from pathlib import Path
 from llama_models.sku_list import all_registered_models

+from llama_stack.apis.models.models import ModelType
 from llama_stack.distribution.datatypes import ModelInput, Provider, ShieldInput
+from llama_stack.providers.inline.inference.sentence_transformers import (
+    SentenceTransformersInferenceConfig,
+)
 from llama_stack.providers.remote.inference.cerebras import CerebrasImplConfig
 from llama_stack.providers.remote.inference.cerebras.cerebras import model_aliases
 from llama_stack.templates.template import DistributionTemplate, RunConfigSettings
@@ -29,6 +33,11 @@ def get_distribution_template() -> DistributionTemplate:
         provider_type="remote::cerebras",
         config=CerebrasImplConfig.sample_run_config(),
     )
+    embedding_provider = Provider(
+        provider_id="sentence-transformers",
+        provider_type="inline::sentence-transformers",
+        config=SentenceTransformersInferenceConfig.sample_run_config(),
+    )

     core_model_to_hf_repo = {
         m.descriptor(): m.huggingface_repo for m in all_registered_models()
@@ -37,9 +46,18 @@ def get_distribution_template() -> DistributionTemplate:
         ModelInput(
             model_id=core_model_to_hf_repo[m.llama_model],
             provider_model_id=m.provider_model_id,
+            provider_id="cerebras",
         )
         for m in model_aliases
     ]
+    embedding_model = ModelInput(
+        model_id="all-MiniLM-L6-v2",
+        provider_id="sentence-transformers",
+        model_type=ModelType.embedding,
+        metadata={
+            "embedding_dimension": 384,
+        },
+    )

     return DistributionTemplate(
         name="cerebras",
@@ -52,9 +70,9 @@ def get_distribution_template() -> DistributionTemplate:
         run_configs={
             "run.yaml": RunConfigSettings(
                 provider_overrides={
-                    "inference": [inference_provider],
+                    "inference": [inference_provider, embedding_provider],
                 },
-                default_models=default_models,
+                default_models=default_models + [embedding_model],
                 default_shields=[ShieldInput(shield_id="meta-llama/Llama-Guard-3-8B")],
             ),
         },

View file

@@ -15,6 +15,9 @@ providers:
     config:
       base_url: https://api.cerebras.ai
       api_key: ${env.CEREBRAS_API_KEY}
+  - provider_id: sentence-transformers
+    provider_type: inline::sentence-transformers
+    config: {}
   safety:
   - provider_id: llama-guard
     provider_type: inline::llama-guard
@@ -49,12 +52,20 @@ metadata_store:
 models:
 - metadata: {}
   model_id: meta-llama/Llama-3.1-8B-Instruct
-  provider_id: null
+  provider_id: cerebras
   provider_model_id: llama3.1-8b
+  model_type: llm
 - metadata: {}
   model_id: meta-llama/Llama-3.1-70B-Instruct
-  provider_id: null
+  provider_id: cerebras
   provider_model_id: llama3.1-70b
+  model_type: llm
+- metadata:
+    embedding_dimension: 384
+  model_id: all-MiniLM-L6-v2
+  provider_id: sentence-transformers
+  provider_model_id: null
+  model_type: embedding
 shields:
 - params: null
   shield_id: meta-llama/Llama-Guard-3-8B
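Once a distribution built from this run.yaml is serving, the registered embedding model is reachable through the same inference endpoint as the LLMs. The snippet below is a hedged client-side sketch using the llama-stack-client package; the base URL, port, and exact method signature are assumptions to verify against your installed client version.

```python
# Hypothetical client sketch: assumes a stack server built from this run.yaml
# is listening on http://localhost:5000 (adjust to your --port) and that
# llama-stack-client is installed.
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:5000")

# "all-MiniLM-L6-v2" is the embedding model registered above (384 dimensions).
response = client.inference.embeddings(
    model_id="all-MiniLM-L6-v2",
    contents=["Llama Stack distributions now bundle an inline embedding model."],
)
print(len(response.embeddings[0]))  # expected: 384
```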

View file

@@ -0,0 +1,13 @@
version: '2'
name: experimental-post-training
distribution_spec:
  description: Experimental template for post training
  docker_image: null
  providers:
    post_training:
    - inline::torchtune
    datasetio:
    - remote::huggingface
    telemetry:
    - inline::meta-reference
image_type: conda

View file

@@ -0,0 +1,53 @@
version: '2'
image_name: experimental-post-training
docker_image: null
conda_env: experimental-post-training
apis:
- telemetry
- datasetio
- post_training
providers:
  datasetio:
  - provider_id: huggingface-0
    provider_type: remote::huggingface
    config: {}
  telemetry:
  - provider_id: meta-reference
    provider_type: inline::meta-reference
    config: {}
  post_training:
  - provider_id: torchtune-post-training
    provider_type: inline::torchtune
    config: {}
metadata_store:
  namespace: null
  type: sqlite
  db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/meta-reference-gpu}/registry.db
models:
- metadata: {}
  model_id: ${env.POST_TRAINING_MODEL}
  provider_id: meta-reference-inference
  provider_model_id: null
shields: []
memory_banks: []
datasets:
- dataset_id: alpaca
  provider_id: huggingface-0
  url:
    uri: https://huggingface.co/datasets/tatsu-lab/alpaca
  metadata:
    path: tatsu-lab/alpaca
    name:
    split: train
  dataset_schema:
    instruction:
      type: string
    input:
      type: string
    output:
      type: string
    text:
      type: string
scoring_fns: []
eval_tasks: []

View file

@@ -8,11 +8,15 @@ from pathlib import Path
 from llama_models.sku_list import all_registered_models

+from llama_stack.apis.models.models import ModelType
 from llama_stack.distribution.datatypes import ModelInput, Provider, ShieldInput
+from llama_stack.providers.inline.inference.sentence_transformers import (
+    SentenceTransformersInferenceConfig,
+)
 from llama_stack.providers.inline.memory.faiss.config import FaissImplConfig
 from llama_stack.providers.remote.inference.fireworks import FireworksImplConfig
 from llama_stack.providers.remote.inference.fireworks.fireworks import MODEL_ALIASES
 from llama_stack.templates.template import DistributionTemplate, RunConfigSettings
@@ -35,6 +39,11 @@ def get_distribution_template() -> DistributionTemplate:
         provider_type="remote::fireworks",
         config=FireworksImplConfig.sample_run_config(),
     )
+    embedding_provider = Provider(
+        provider_id="sentence-transformers",
+        provider_type="inline::sentence-transformers",
+        config=SentenceTransformersInferenceConfig.sample_run_config(),
+    )
     memory_provider = Provider(
         provider_id="faiss",
         provider_type="inline::faiss",
@@ -48,9 +57,18 @@ def get_distribution_template() -> DistributionTemplate:
         ModelInput(
             model_id=core_model_to_hf_repo[m.llama_model],
             provider_model_id=m.provider_model_id,
+            provider_id="fireworks",
         )
         for m in MODEL_ALIASES
     ]
+    embedding_model = ModelInput(
+        model_id="all-MiniLM-L6-v2",
+        provider_id="sentence-transformers",
+        model_type=ModelType.embedding,
+        metadata={
+            "embedding_dimension": 384,
+        },
+    )

     return DistributionTemplate(
         name=name,
@@ -63,10 +81,10 @@ def get_distribution_template() -> DistributionTemplate:
         run_configs={
             "run.yaml": RunConfigSettings(
                 provider_overrides={
-                    "inference": [inference_provider],
+                    "inference": [inference_provider, embedding_provider],
                     "memory": [memory_provider],
                 },
-                default_models=default_models,
+                default_models=default_models + [embedding_model],
                 default_shields=[ShieldInput(shield_id="meta-llama/Llama-Guard-3-8B")],
             ),
         },

View file

@@ -16,8 +16,11 @@ providers:
   - provider_id: fireworks
     provider_type: remote::fireworks
     config:
-      url: https://api.fireworks.ai/inference
+      url: https://api.fireworks.ai/inference/v1
       api_key: ${env.FIREWORKS_API_KEY}
+  - provider_id: sentence-transformers
+    provider_type: inline::sentence-transformers
+    config: {}
   memory:
   - provider_id: faiss
     provider_type: inline::faiss
@@ -74,40 +77,55 @@ metadata_store:
 models:
 - metadata: {}
   model_id: meta-llama/Llama-3.1-8B-Instruct
-  provider_id: null
+  provider_id: fireworks
   provider_model_id: fireworks/llama-v3p1-8b-instruct
+  model_type: llm
 - metadata: {}
   model_id: meta-llama/Llama-3.1-70B-Instruct
-  provider_id: null
+  provider_id: fireworks
   provider_model_id: fireworks/llama-v3p1-70b-instruct
+  model_type: llm
 - metadata: {}
   model_id: meta-llama/Llama-3.1-405B-Instruct-FP8
-  provider_id: null
+  provider_id: fireworks
   provider_model_id: fireworks/llama-v3p1-405b-instruct
+  model_type: llm
 - metadata: {}
   model_id: meta-llama/Llama-3.2-1B-Instruct
-  provider_id: null
+  provider_id: fireworks
   provider_model_id: fireworks/llama-v3p2-1b-instruct
+  model_type: llm
 - metadata: {}
   model_id: meta-llama/Llama-3.2-3B-Instruct
-  provider_id: null
+  provider_id: fireworks
   provider_model_id: fireworks/llama-v3p2-3b-instruct
+  model_type: llm
 - metadata: {}
   model_id: meta-llama/Llama-3.2-11B-Vision-Instruct
-  provider_id: null
+  provider_id: fireworks
   provider_model_id: fireworks/llama-v3p2-11b-vision-instruct
+  model_type: llm
 - metadata: {}
   model_id: meta-llama/Llama-3.2-90B-Vision-Instruct
-  provider_id: null
+  provider_id: fireworks
   provider_model_id: fireworks/llama-v3p2-90b-vision-instruct
+  model_type: llm
 - metadata: {}
   model_id: meta-llama/Llama-Guard-3-8B
-  provider_id: null
+  provider_id: fireworks
   provider_model_id: fireworks/llama-guard-3-8b
+  model_type: llm
 - metadata: {}
   model_id: meta-llama/Llama-Guard-3-11B-Vision
-  provider_id: null
+  provider_id: fireworks
   provider_model_id: fireworks/llama-guard-3-11b-vision
+  model_type: llm
+- metadata:
+    embedding_dimension: 384
+  model_id: all-MiniLM-L6-v2
+  provider_id: sentence-transformers
+  provider_model_id: null
+  model_type: embedding
 shields:
 - params: null
   shield_id: meta-llama/Llama-Guard-3-8B

View file

@@ -4,7 +4,11 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

+from llama_stack.apis.models.models import ModelType
 from llama_stack.distribution.datatypes import ModelInput, Provider, ShieldInput
+from llama_stack.providers.inline.inference.sentence_transformers import (
+    SentenceTransformersInferenceConfig,
+)
 from llama_stack.providers.inline.memory.faiss.config import FaissImplConfig
 from llama_stack.providers.remote.inference.tgi import InferenceEndpointImplConfig
 from llama_stack.templates.template import DistributionTemplate, RunConfigSettings
@@ -27,6 +31,11 @@ def get_distribution_template() -> DistributionTemplate:
         provider_type="remote::hf::endpoint",
         config=InferenceEndpointImplConfig.sample_run_config(),
     )
+    embedding_provider = Provider(
+        provider_id="sentence-transformers",
+        provider_type="inline::sentence-transformers",
+        config=SentenceTransformersInferenceConfig.sample_run_config(),
+    )
     memory_provider = Provider(
         provider_id="faiss",
         provider_type="inline::faiss",
@@ -41,6 +50,14 @@ def get_distribution_template() -> DistributionTemplate:
         model_id="${env.SAFETY_MODEL}",
         provider_id="hf-endpoint-safety",
     )
+    embedding_model = ModelInput(
+        model_id="all-MiniLM-L6-v2",
+        provider_id="sentence-transformers",
+        model_type=ModelType.embedding,
+        metadata={
+            "embedding_dimension": 384,
+        },
+    )

     return DistributionTemplate(
         name=name,
@@ -53,15 +70,16 @@ def get_distribution_template() -> DistributionTemplate:
         run_configs={
             "run.yaml": RunConfigSettings(
                 provider_overrides={
-                    "inference": [inference_provider],
+                    "inference": [inference_provider, embedding_provider],
                     "memory": [memory_provider],
                 },
-                default_models=[inference_model],
+                default_models=[inference_model, embedding_model],
             ),
             "run-with-safety.yaml": RunConfigSettings(
                 provider_overrides={
                     "inference": [
                         inference_provider,
+                        embedding_provider,
                         Provider(
                             provider_id="hf-endpoint-safety",
                             provider_type="remote::hf::endpoint",
@@ -75,6 +93,7 @@ def get_distribution_template() -> DistributionTemplate:
                 default_models=[
                     inference_model,
                     safety_model,
+                    embedding_model,
                 ],
                 default_shields=[ShieldInput(shield_id="${env.SAFETY_MODEL}")],
             ),

View file

@@ -18,6 +18,9 @@ providers:
     config:
       endpoint_name: ${env.INFERENCE_ENDPOINT_NAME}
      api_token: ${env.HF_API_TOKEN}
+  - provider_id: sentence-transformers
+    provider_type: inline::sentence-transformers
+    config: {}
   - provider_id: hf-endpoint-safety
     provider_type: remote::hf::endpoint
     config:
@@ -81,10 +84,18 @@ models:
   model_id: ${env.INFERENCE_MODEL}
   provider_id: hf-endpoint
   provider_model_id: null
+  model_type: llm
 - metadata: {}
   model_id: ${env.SAFETY_MODEL}
   provider_id: hf-endpoint-safety
   provider_model_id: null
+  model_type: llm
+- metadata:
+    embedding_dimension: 384
+  model_id: all-MiniLM-L6-v2
+  provider_id: sentence-transformers
+  provider_model_id: null
+  model_type: embedding
 shields:
 - params: null
   shield_id: ${env.SAFETY_MODEL}

View file

@@ -18,6 +18,9 @@ providers:
     config:
       endpoint_name: ${env.INFERENCE_ENDPOINT_NAME}
       api_token: ${env.HF_API_TOKEN}
+  - provider_id: sentence-transformers
+    provider_type: inline::sentence-transformers
+    config: {}
   memory:
   - provider_id: faiss
     provider_type: inline::faiss
@@ -76,6 +79,13 @@ models:
   model_id: ${env.INFERENCE_MODEL}
   provider_id: hf-endpoint
   provider_model_id: null
+  model_type: llm
+- metadata:
+    embedding_dimension: 384
+  model_id: all-MiniLM-L6-v2
+  provider_id: sentence-transformers
+  provider_model_id: null
+  model_type: embedding
 shields: []
 memory_banks: []
 datasets: []

View file

@@ -4,7 +4,11 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

+from llama_stack.apis.models.models import ModelType
 from llama_stack.distribution.datatypes import ModelInput, Provider, ShieldInput
+from llama_stack.providers.inline.inference.sentence_transformers import (
+    SentenceTransformersInferenceConfig,
+)
 from llama_stack.providers.inline.memory.faiss.config import FaissImplConfig
 from llama_stack.providers.remote.inference.tgi import InferenceAPIImplConfig
 from llama_stack.templates.template import DistributionTemplate, RunConfigSettings
@@ -28,6 +32,11 @@ def get_distribution_template() -> DistributionTemplate:
         provider_type="remote::hf::serverless",
         config=InferenceAPIImplConfig.sample_run_config(),
     )
+    embedding_provider = Provider(
+        provider_id="sentence-transformers",
+        provider_type="inline::sentence-transformers",
+        config=SentenceTransformersInferenceConfig.sample_run_config(),
+    )
     memory_provider = Provider(
         provider_id="faiss",
         provider_type="inline::faiss",
@@ -42,6 +51,14 @@ def get_distribution_template() -> DistributionTemplate:
         model_id="${env.SAFETY_MODEL}",
         provider_id="hf-serverless-safety",
     )
+    embedding_model = ModelInput(
+        model_id="all-MiniLM-L6-v2",
+        provider_id="sentence-transformers",
+        model_type=ModelType.embedding,
+        metadata={
+            "embedding_dimension": 384,
+        },
+    )

     return DistributionTemplate(
         name=name,
@@ -54,15 +71,16 @@ def get_distribution_template() -> DistributionTemplate:
         run_configs={
             "run.yaml": RunConfigSettings(
                 provider_overrides={
-                    "inference": [inference_provider],
+                    "inference": [inference_provider, embedding_provider],
                     "memory": [memory_provider],
                 },
-                default_models=[inference_model],
+                default_models=[inference_model, embedding_model],
             ),
             "run-with-safety.yaml": RunConfigSettings(
                 provider_overrides={
                     "inference": [
                         inference_provider,
+                        embedding_provider,
                         Provider(
                             provider_id="hf-serverless-safety",
                             provider_type="remote::hf::serverless",
@@ -76,6 +94,7 @@ def get_distribution_template() -> DistributionTemplate:
                 default_models=[
                     inference_model,
                     safety_model,
+                    embedding_model,
                 ],
                 default_shields=[ShieldInput(shield_id="${env.SAFETY_MODEL}")],
             ),

View file

@@ -18,6 +18,9 @@ providers:
     config:
       huggingface_repo: ${env.INFERENCE_MODEL}
       api_token: ${env.HF_API_TOKEN}
+  - provider_id: sentence-transformers
+    provider_type: inline::sentence-transformers
+    config: {}
   - provider_id: hf-serverless-safety
     provider_type: remote::hf::serverless
     config:
@@ -81,10 +84,18 @@ models:
   model_id: ${env.INFERENCE_MODEL}
   provider_id: hf-serverless
   provider_model_id: null
+  model_type: llm
 - metadata: {}
   model_id: ${env.SAFETY_MODEL}
   provider_id: hf-serverless-safety
   provider_model_id: null
+  model_type: llm
+- metadata:
+    embedding_dimension: 384
+  model_id: all-MiniLM-L6-v2
+  provider_id: sentence-transformers
+  provider_model_id: null
+  model_type: embedding
 shields:
 - params: null
   shield_id: ${env.SAFETY_MODEL}

View file

@@ -18,6 +18,9 @@ providers:
     config:
       huggingface_repo: ${env.INFERENCE_MODEL}
       api_token: ${env.HF_API_TOKEN}
+  - provider_id: sentence-transformers
+    provider_type: inline::sentence-transformers
+    config: {}
   memory:
   - provider_id: faiss
     provider_type: inline::faiss
@@ -76,6 +79,13 @@ models:
   model_id: ${env.INFERENCE_MODEL}
   provider_id: hf-serverless
   provider_model_id: null
+  model_type: llm
+- metadata:
+    embedding_dimension: 384
+  model_id: all-MiniLM-L6-v2
+  provider_id: sentence-transformers
+  provider_model_id: null
+  model_type: embedding
 shields: []
 memory_banks: []
 datasets: []

View file

@@ -6,10 +6,15 @@
 from pathlib import Path

+from llama_stack.apis.models.models import ModelType
 from llama_stack.distribution.datatypes import ModelInput, Provider, ShieldInput
 from llama_stack.providers.inline.inference.meta_reference import (
     MetaReferenceInferenceConfig,
 )
+from llama_stack.providers.inline.inference.sentence_transformers import (
+    SentenceTransformersInferenceConfig,
+)
 from llama_stack.providers.inline.memory.faiss.config import FaissImplConfig
 from llama_stack.templates.template import DistributionTemplate, RunConfigSettings
@@ -34,6 +39,11 @@ def get_distribution_template() -> DistributionTemplate:
             checkpoint_dir="${env.INFERENCE_CHECKPOINT_DIR:null}",
         ),
     )
+    embedding_provider = Provider(
+        provider_id="sentence-transformers",
+        provider_type="inline::sentence-transformers",
+        config=SentenceTransformersInferenceConfig.sample_run_config(),
+    )
     memory_provider = Provider(
         provider_id="faiss",
         provider_type="inline::faiss",
@@ -44,6 +54,14 @@ def get_distribution_template() -> DistributionTemplate:
         model_id="${env.INFERENCE_MODEL}",
         provider_id="meta-reference-inference",
     )
+    embedding_model = ModelInput(
+        model_id="all-MiniLM-L6-v2",
+        provider_id="sentence-transformers",
+        model_type=ModelType.embedding,
+        metadata={
+            "embedding_dimension": 384,
+        },
+    )
     safety_model = ModelInput(
         model_id="${env.SAFETY_MODEL}",
         provider_id="meta-reference-safety",
@@ -59,15 +77,16 @@ def get_distribution_template() -> DistributionTemplate:
         run_configs={
             "run.yaml": RunConfigSettings(
                 provider_overrides={
-                    "inference": [inference_provider],
+                    "inference": [inference_provider, embedding_provider],
                     "memory": [memory_provider],
                 },
-                default_models=[inference_model],
+                default_models=[inference_model, embedding_model],
             ),
             "run-with-safety.yaml": RunConfigSettings(
                 provider_overrides={
                     "inference": [
                         inference_provider,
+                        embedding_provider,
                         Provider(
                             provider_id="meta-reference-safety",
                             provider_type="inline::meta-reference",
@@ -82,6 +101,7 @@ def get_distribution_template() -> DistributionTemplate:
                 default_models=[
                     inference_model,
                     safety_model,
+                    embedding_model,
                 ],
                 default_shields=[ShieldInput(shield_id="${env.SAFETY_MODEL}")],
             ),

View file

@@ -19,6 +19,9 @@ providers:
       model: ${env.INFERENCE_MODEL}
       max_seq_len: 4096
       checkpoint_dir: ${env.INFERENCE_CHECKPOINT_DIR:null}
+  - provider_id: sentence-transformers
+    provider_type: inline::sentence-transformers
+    config: {}
   - provider_id: meta-reference-safety
     provider_type: inline::meta-reference
     config:
@@ -83,10 +86,18 @@ models:
   model_id: ${env.INFERENCE_MODEL}
   provider_id: meta-reference-inference
   provider_model_id: null
+  model_type: llm
 - metadata: {}
   model_id: ${env.SAFETY_MODEL}
   provider_id: meta-reference-safety
   provider_model_id: null
+  model_type: llm
+- metadata:
+    embedding_dimension: 384
+  model_id: all-MiniLM-L6-v2
+  provider_id: sentence-transformers
+  provider_model_id: null
+  model_type: embedding
 shields:
 - params: null
   shield_id: ${env.SAFETY_MODEL}

View file

@@ -19,6 +19,9 @@ providers:
       model: ${env.INFERENCE_MODEL}
       max_seq_len: 4096
       checkpoint_dir: ${env.INFERENCE_CHECKPOINT_DIR:null}
+  - provider_id: sentence-transformers
+    provider_type: inline::sentence-transformers
+    config: {}
   memory:
   - provider_id: faiss
     provider_type: inline::faiss
@@ -77,6 +80,13 @@ models:
   model_id: ${env.INFERENCE_MODEL}
   provider_id: meta-reference-inference
   provider_model_id: null
+  model_type: llm
+- metadata:
+    embedding_dimension: 384
+  model_id: all-MiniLM-L6-v2
+  provider_id: sentence-transformers
+  provider_model_id: null
+  model_type: embedding
 shields: []
 memory_banks: []
 datasets: []

View file

@@ -6,10 +6,15 @@
 from pathlib import Path

+from llama_stack.apis.models.models import ModelType
 from llama_stack.distribution.datatypes import ModelInput, Provider
 from llama_stack.providers.inline.inference.meta_reference import (
     MetaReferenceQuantizedInferenceConfig,
 )
+from llama_stack.providers.inline.inference.sentence_transformers import (
+    SentenceTransformersInferenceConfig,
+)
 from llama_stack.providers.inline.memory.faiss.config import FaissImplConfig
 from llama_stack.templates.template import DistributionTemplate, RunConfigSettings
@@ -34,6 +39,11 @@ def get_distribution_template() -> DistributionTemplate:
             checkpoint_dir="${env.INFERENCE_CHECKPOINT_DIR:null}",
         ),
     )
+    embedding_provider = Provider(
+        provider_id="sentence-transformers",
+        provider_type="inline::sentence-transformers",
+        config=SentenceTransformersInferenceConfig.sample_run_config(),
+    )
     memory_provider = Provider(
         provider_id="faiss",
         provider_type="inline::faiss",
@@ -44,6 +54,14 @@ def get_distribution_template() -> DistributionTemplate:
         model_id="${env.INFERENCE_MODEL}",
         provider_id="meta-reference-inference",
     )
+    embedding_model = ModelInput(
+        model_id="all-MiniLM-L6-v2",
+        provider_id="sentence-transformers",
+        model_type=ModelType.embedding,
+        metadata={
+            "embedding_dimension": 384,
+        },
+    )
     return DistributionTemplate(
         name=name,
         distro_type="self_hosted",
@@ -54,10 +72,10 @@ def get_distribution_template() -> DistributionTemplate:
         run_configs={
             "run.yaml": RunConfigSettings(
                 provider_overrides={
-                    "inference": [inference_provider],
+                    "inference": [inference_provider, embedding_provider],
                     "memory": [memory_provider],
                 },
-                default_models=[inference_model],
+                default_models=[inference_model, embedding_model],
             ),
         },
         run_config_env_vars={

View file

@@ -21,6 +21,9 @@ providers:
       checkpoint_dir: ${env.INFERENCE_CHECKPOINT_DIR:null}
       quantization:
         type: fp8
+  - provider_id: sentence-transformers
+    provider_type: inline::sentence-transformers
+    config: {}
   memory:
   - provider_id: faiss
     provider_type: inline::faiss
@@ -79,6 +82,13 @@ models:
   model_id: ${env.INFERENCE_MODEL}
   provider_id: meta-reference-inference
   provider_model_id: null
+  model_type: llm
+- metadata:
+    embedding_dimension: 384
+  model_id: all-MiniLM-L6-v2
+  provider_id: sentence-transformers
+  provider_model_id: null
+  model_type: embedding
 shields: []
 memory_banks: []
 datasets: []

View file

@@ -6,7 +6,12 @@
 from pathlib import Path

+from llama_stack.apis.models.models import ModelType
 from llama_stack.distribution.datatypes import ModelInput, Provider, ShieldInput
+from llama_stack.providers.inline.inference.sentence_transformers import (
+    SentenceTransformersInferenceConfig,
+)
 from llama_stack.providers.inline.memory.faiss.config import FaissImplConfig
 from llama_stack.providers.remote.inference.ollama import OllamaImplConfig
 from llama_stack.templates.template import DistributionTemplate, RunConfigSettings
@@ -29,6 +34,11 @@ def get_distribution_template() -> DistributionTemplate:
         provider_type="remote::ollama",
         config=OllamaImplConfig.sample_run_config(),
     )
+    embedding_provider = Provider(
+        provider_id="sentence-transformers",
+        provider_type="inline::sentence-transformers",
+        config=SentenceTransformersInferenceConfig.sample_run_config(),
+    )
     memory_provider = Provider(
         provider_id="faiss",
         provider_type="inline::faiss",
@@ -43,6 +53,14 @@ def get_distribution_template() -> DistributionTemplate:
         model_id="${env.SAFETY_MODEL}",
         provider_id="ollama",
     )
+    embedding_model = ModelInput(
+        model_id="all-MiniLM-L6-v2",
+        provider_id="sentence-transformers",
+        model_type=ModelType.embedding,
+        metadata={
+            "embedding_dimension": 384,
+        },
+    )

     return DistributionTemplate(
         name=name,
@@ -55,21 +73,23 @@ def get_distribution_template() -> DistributionTemplate:
         run_configs={
             "run.yaml": RunConfigSettings(
                 provider_overrides={
-                    "inference": [inference_provider],
+                    "inference": [inference_provider, embedding_provider],
                     "memory": [memory_provider],
                 },
-                default_models=[inference_model],
+                default_models=[inference_model, embedding_model],
             ),
             "run-with-safety.yaml": RunConfigSettings(
                 provider_overrides={
                     "inference": [
                         inference_provider,
+                        embedding_provider,
                     ],
                     "memory": [memory_provider],
                 },
                 default_models=[
                     inference_model,
                     safety_model,
+                    embedding_model,
                 ],
                 default_shields=[ShieldInput(shield_id="${env.SAFETY_MODEL}")],
             ),

View file

@@ -17,6 +17,9 @@ providers:
     provider_type: remote::ollama
     config:
       url: ${env.OLLAMA_URL:http://localhost:11434}
+  - provider_id: sentence-transformers
+    provider_type: inline::sentence-transformers
+    config: {}
   memory:
   - provider_id: faiss
     provider_type: inline::faiss
@@ -75,10 +78,18 @@ models:
   model_id: ${env.INFERENCE_MODEL}
   provider_id: ollama
   provider_model_id: null
+  model_type: llm
 - metadata: {}
   model_id: ${env.SAFETY_MODEL}
   provider_id: ollama
   provider_model_id: null
+  model_type: llm
+- metadata:
+    embedding_dimension: 384
+  model_id: all-MiniLM-L6-v2
+  provider_id: sentence-transformers
+  provider_model_id: null
+  model_type: embedding
 shields:
 - params: null
   shield_id: ${env.SAFETY_MODEL}

View file

@@ -17,6 +17,9 @@ providers:
     provider_type: remote::ollama
     config:
       url: ${env.OLLAMA_URL:http://localhost:11434}
+  - provider_id: sentence-transformers
+    provider_type: inline::sentence-transformers
+    config: {}
   memory:
   - provider_id: faiss
     provider_type: inline::faiss
@@ -75,6 +78,13 @@ models:
   model_id: ${env.INFERENCE_MODEL}
   provider_id: ollama
   provider_model_id: null
+  model_type: llm
+- metadata:
+    embedding_dimension: 384
+  model_id: all-MiniLM-L6-v2
+  provider_id: sentence-transformers
+  provider_model_id: null
+  model_type: embedding
 shields: []
 memory_banks: []
 datasets: []

View file

@@ -22,6 +22,9 @@ providers:
       url: ${env.SAFETY_VLLM_URL}
       max_tokens: ${env.VLLM_MAX_TOKENS:4096}
       api_token: ${env.VLLM_API_TOKEN:fake}
+  - provider_id: sentence-transformers
+    provider_type: inline::sentence-transformers
+    config: {}
   memory:
   - provider_id: faiss
     provider_type: inline::faiss
@@ -58,10 +61,18 @@ models:
   model_id: ${env.INFERENCE_MODEL}
   provider_id: vllm-inference
   provider_model_id: null
+  model_type: llm
 - metadata: {}
   model_id: ${env.SAFETY_MODEL}
   provider_id: vllm-safety
   provider_model_id: null
+  model_type: llm
+- metadata:
+    embedding_dimension: 384
+  model_id: all-MiniLM-L6-v2
+  provider_id: sentence-transformers
+  provider_model_id: null
+  model_type: embedding
 shields:
 - params: null
   shield_id: ${env.SAFETY_MODEL}

Some files were not shown because too many files have changed in this diff.