Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-08-03 17:29:01 +00:00)

Merge branch 'meta-llama:main' into main

Commit 54e48d555d: 110 changed files with 12606 additions and 747 deletions
README.md (12 changes)
@@ -84,18 +84,18 @@ Additionally, we have designed every element of the Stack such that APIs as well
| Fireworks | Hosted | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | | |
| AWS Bedrock | Hosted | | :heavy_check_mark: | | :heavy_check_mark: | |
| Together | Hosted | :heavy_check_mark: | :heavy_check_mark: | | :heavy_check_mark: | |
-| Ollama | Single Node | | :heavy_check_mark: | | |
+| Ollama | Single Node | | :heavy_check_mark: | | | |
-| TGI | Hosted and Single Node | | :heavy_check_mark: | | |
+| TGI | Hosted and Single Node | | :heavy_check_mark: | | | |
-| [NVIDIA NIM](https://build.nvidia.com/nim?filters=nimType%3Anim_type_run_anywhere&q=llama) | Hosted and Single Node | | :heavy_check_mark: | | |
+| [NVIDIA NIM](https://build.nvidia.com/nim?filters=nimType%3Anim_type_run_anywhere&q=llama) | Hosted and Single Node | | :heavy_check_mark: | | | |
| Chroma | Single Node | | | :heavy_check_mark: | | |
| PG Vector | Single Node | | | :heavy_check_mark: | | |
-| PyTorch ExecuTorch | On-device iOS | :heavy_check_mark: | :heavy_check_mark: | | |
+| PyTorch ExecuTorch | On-device iOS | :heavy_check_mark: | :heavy_check_mark: | | | |
-| [vLLM](https://github.com/vllm-project/vllm) | | | :heavy_check_mark: | | |
+| [vLLM](https://github.com/vllm-project/vllm) | Hosted and Single Node | | :heavy_check_mark: | | | |

### Distributions

| **Distribution** | **Llama Stack Docker** | Start This Distribution |
-|:----------------------------------------------------:|:-----------------------------------------------------------------------------------------:|:------------------------------------------------------------------:|
+|:---------------------------------------------:|:-----------------------------------------------------------------------------------:|:--------------------------------------------------------------:|
| Meta Reference | [llamastack/distribution-meta-reference-gpu](https://hub.docker.com/repository/docker/llamastack/distribution-meta-reference-gpu/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/distributions/self_hosted_distro/meta-reference-gpu.html) |
| Meta Reference Quantized | [llamastack/distribution-meta-reference-quantized-gpu](https://hub.docker.com/repository/docker/llamastack/distribution-meta-reference-quantized-gpu/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/distributions/self_hosted_distro/meta-reference-quantized-gpu.html) |
| Cerebras | [llamastack/distribution-cerebras](https://hub.docker.com/repository/docker/llamastack/distribution-cerebras/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/getting_started/distributions/self_hosted_distro/cerebras.html) |
@@ -249,6 +249,7 @@
"redis",
"scikit-learn",
"scipy",
+"sentence-transformers",
"sentencepiece",
"torch",
"torchvision",
@@ -287,6 +288,7 @@
"redis",
"scikit-learn",
"scipy",
+"sentence-transformers",
"sentencepiece",
"torch",
"torchao==0.5.0",
File diff suppressed because one or more lines are too long

docs/notebooks/Llama_Stack_Benchmark_Evals.ipynb (new file, 4485 lines; diff suppressed because it is too large)

docs/notebooks/Llama_Stack_Building_AI_Applications.ipynb (new file, 4658 lines; diff suppressed because one or more lines are too long)

@@ -10,3 +10,4 @@ sphinx-design
sphinxcontrib-openapi
sphinxcontrib-redoc
sphinxcontrib-mermaid
+sphinxcontrib-video
docs/source/benchmark_evaluations/index.md (new file, 167 lines)

@@ -0,0 +1,167 @@
# Benchmark Evaluations

[Open In Colab](https://colab.research.google.com/drive/10CHyykee9j2OigaIcRv47BKG9mrNm0tJ?usp=sharing)

Llama Stack provides the building blocks needed to run benchmark and application evaluations. This guide will walk you through how to use these components to run open benchmark evaluations. Visit our [Evaluation Concepts](../concepts/evaluation_concepts.md) guide for more details on how evaluations work in Llama Stack, and our [Evaluation Reference](../references/evals_reference/index.md) guide for a comprehensive reference on the APIs. Check out our [Colab notebook](https://colab.research.google.com/drive/10CHyykee9j2OigaIcRv47BKG9mrNm0tJ?usp=sharing) for working examples of how you can use Llama Stack to run benchmark evaluations.

### 1. Open Benchmark Model Evaluation

This first example walks you through how to evaluate a model candidate served by Llama Stack on open benchmarks. We will use the following benchmarks:
- [MMMU](https://arxiv.org/abs/2311.16502) (A Massive Multi-discipline Multimodal Understanding and Reasoning Benchmark for Expert AGI): a benchmark designed to evaluate multimodal models.
- [SimpleQA](https://openai.com/index/introducing-simpleqa/): a benchmark designed to assess models' ability to answer short, fact-seeking questions.

#### 1.1 Running MMMU
- We will use a pre-processed MMMU dataset from [llamastack/mmmu](https://huggingface.co/datasets/llamastack/mmmu). The preprocessing code is shown in this [GitHub Gist](https://gist.github.com/yanxi0830/118e9c560227d27132a7fd10e2c92840). The dataset is obtained by transforming the original [MMMU/MMMU](https://huggingface.co/datasets/MMMU/MMMU) dataset into the format expected by the `inference/chat-completion` API.

```python
import datasets

ds = datasets.load_dataset(path="llamastack/mmmu", name="Agriculture", split="dev")
ds = ds.select_columns(["chat_completion_input", "input_query", "expected_answer"])
eval_rows = ds.to_pandas().to_dict(orient="records")
```

- Next, we will run an evaluation on a model candidate. To do so, we will need to:
  - Define a system prompt
  - Define an EvalCandidate
  - Run evaluation on the dataset

```python
SYSTEM_PROMPT_TEMPLATE = """
You are an expert in Agriculture whose job is to answer questions from the user using images.
First, reason about the correct answer.
Then write the answer in the following format where X is exactly one of A,B,C,D:
Answer: X
Make sure X is one of A,B,C,D.
If you are uncertain of the correct answer, guess the most likely one.
"""

system_message = {
    "role": "system",
    "content": SYSTEM_PROMPT_TEMPLATE,
}

client.eval_tasks.register(
    eval_task_id="meta-reference::mmmu",
    dataset_id=f"mmmu-{subset}-{split}",
    scoring_functions=["basic::regex_parser_multiple_choice_answer"],
)

response = client.eval.evaluate_rows(
    task_id="meta-reference::mmmu",
    input_rows=eval_rows,
    scoring_functions=["basic::regex_parser_multiple_choice_answer"],
    task_config={
        "type": "benchmark",
        "eval_candidate": {
            "type": "model",
            "model": "meta-llama/Llama-3.2-90B-Vision-Instruct",
            "sampling_params": {
                "temperature": 0.0,
                "max_tokens": 4096,
                "top_p": 0.9,
                "repeat_penalty": 1.0,
            },
            "system_message": system_message,
        },
    },
)
```
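
The call above returns both the raw generations and the per-row scores. As a rough illustration of reading them back (this sketch assumes the response exposes `generations` and a `scores` mapping keyed by scoring function id, as returned by the client in the Colab notebook; adjust to your client version):

```python
# Inspect the first generation produced during the eval run.
print(response.generations[0])

# Scores are grouped per scoring function; here we look at the multiple-choice parser.
mmmu_result = response.scores["basic::regex_parser_multiple_choice_answer"]
print(mmmu_result.aggregated_results)  # e.g. aggregate accuracy over the eval rows
for row in mmmu_result.score_rows:
    print(row)  # per-row score entries
```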

#### 1.2. Running SimpleQA
- We will use a pre-processed SimpleQA dataset from [llamastack/evals](https://huggingface.co/datasets/llamastack/evals/viewer/evals__simpleqa), which is obtained by transforming the input query into the format accepted by the `inference/chat-completion` API.
- Since we will be using this same dataset in our next example for Agentic evaluation, we will register it using the `/datasets` API and interact with it through the `/datasetio` API.

```python
simpleqa_dataset_id = "huggingface::simpleqa"

_ = client.datasets.register(
    dataset_id=simpleqa_dataset_id,
    provider_id="huggingface",
    url={"uri": "https://huggingface.co/datasets/llamastack/evals"},
    metadata={
        "path": "llamastack/evals",
        "name": "evals__simpleqa",
        "split": "train",
    },
    dataset_schema={
        "input_query": {"type": "string"},
        "expected_answer": {"type": "string"},
        "chat_completion_input": {"type": "chat_completion_input"},
    },
)

eval_rows = client.datasetio.get_rows_paginated(
    dataset_id=simpleqa_dataset_id,
    rows_in_page=5,
)
```
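
Before registering the eval task, it can help to sanity-check what the `/datasetio` call returned; the `rows` attribute is what gets passed as eval input below (a small illustrative sketch, assuming the paginated result exposes `rows`):

```python
# Each row follows the dataset_schema registered above.
print(len(eval_rows.rows))
sample = eval_rows.rows[0]
print(sample["input_query"])
print(sample["expected_answer"])
```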

```python
client.eval_tasks.register(
    eval_task_id="meta-reference::simpleqa",
    dataset_id=simpleqa_dataset_id,
    scoring_functions=["llm-as-judge::405b-simpleqa"],
)

response = client.eval.evaluate_rows(
    task_id="meta-reference::simpleqa",
    input_rows=eval_rows.rows,
    scoring_functions=["llm-as-judge::405b-simpleqa"],
    task_config={
        "type": "benchmark",
        "eval_candidate": {
            "type": "model",
            "model": "meta-llama/Llama-3.2-90B-Vision-Instruct",
            "sampling_params": {
                "temperature": 0.0,
                "max_tokens": 4096,
                "top_p": 0.9,
                "repeat_penalty": 1.0,
            },
        },
    },
)
```

### 2. Agentic Evaluation
- In this example, we will demonstrate how to evaluate an agent candidate served by Llama Stack via the `/agent` API.
- We will continue to use the SimpleQA dataset we used in the previous example.
- Instead of running evaluation on a model, we will run the evaluation on a Search Agent with access to a search tool. We will define our agent evaluation candidate through `AgentConfig`.

```python
agent_config = {
    "model": "meta-llama/Llama-3.1-405B-Instruct",
    "instructions": "You are a helpful assistant",
    "sampling_params": {
        "strategy": "greedy",
        "temperature": 0.0,
        "top_p": 0.95,
    },
    "tools": [
        {
            "type": "brave_search",
            "engine": "tavily",
            "api_key": userdata.get("TAVILY_SEARCH_API_KEY"),
        }
    ],
    "tool_choice": "auto",
    "tool_prompt_format": "json",
    "input_shields": [],
    "output_shields": [],
    "enable_session_persistence": False,
}

response = client.eval.evaluate_rows(
    task_id="meta-reference::simpleqa",
    input_rows=eval_rows.rows,
    scoring_functions=["llm-as-judge::405b-simpleqa"],
    task_config={
        "type": "benchmark",
        "eval_candidate": {
            "type": "agent",
            "config": agent_config,
        },
    },
)
```
@@ -1,6 +1,8 @@
# Building AI Applications

+[Open In Colab](https://colab.research.google.com/drive/1F2ksmkoGQPa4pzRjMOE6BXWeOxWFIW6n?usp=sharing)
+
-Llama Stack provides all the building blocks needed to create sophisticated AI applications. This guide will walk you through how to use these components effectively.
+Llama Stack provides all the building blocks needed to create sophisticated AI applications. This guide will walk you through how to use these components effectively. Check out our Colab notebook to follow along with working examples of how you can build LLM-powered agentic applications using Llama Stack.

## Basic Inference
docs/source/concepts/evaluation_concepts.md (new file, 40 lines)

@@ -0,0 +1,40 @@
# Evaluation Concepts

The Llama Stack Evaluation flow allows you to run evaluations on your GenAI application datasets or pre-registered benchmarks.

We introduce a set of APIs in Llama Stack for supporting running evaluations of LLM applications.
- `/datasetio` + `/datasets` API
- `/scoring` + `/scoring_functions` API
- `/eval` + `/eval_tasks` API

This guide goes over the set of APIs and the developer experience flow of using Llama Stack to run evaluations for different use cases. Check out our Colab notebook with working evaluation examples [here](https://colab.research.google.com/drive/10CHyykee9j2OigaIcRv47BKG9mrNm0tJ?usp=sharing).

## Evaluation Concepts

The Evaluation APIs are associated with a set of Resources as shown in the following diagram. Please visit the Resources section in our [Core Concepts](../concepts/index.md) guide for a better high-level understanding.

- **DatasetIO**: defines the interface with datasets and data loaders.
  - Associated with the `Dataset` resource.
- **Scoring**: evaluates the outputs of the system.
  - Associated with the `ScoringFunction` resource. We provide a suite of out-of-the-box scoring functions and also the ability for you to add custom evaluators. These scoring functions are the core part of defining an evaluation task to output evaluation metrics.
- **Eval**: generates outputs (via Inference or Agents) and performs scoring.
  - Associated with the `EvalTask` resource.

Use the following decision tree to decide how to use the LlamaStack Evaluation flow.

```{admonition} Note on Benchmark vs. Application Evaluation
:class: tip
- **Benchmark Evaluation** is a well-defined eval task consisting of a `dataset` and a `scoring_function`. The generation (inference or agent) will be done as part of evaluation.
- **Application Evaluation** assumes users already have app inputs & generated outputs. Evaluation will purely focus on scoring the generated outputs via scoring functions (e.g. LLM-as-judge).
```
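
To make the distinction concrete, here is a minimal sketch of both flows using the `llama-stack-client` Python client. It assumes an initialized `client`, a registered eval task, and fetched input rows, as in the [Evaluation Reference](../references/evals_reference/index.md); the row values are made up for illustration.

```python
# Benchmark evaluation: /eval generates outputs with the candidate model, then scores them.
benchmark_response = client.eval.evaluate_rows(
    task_id="meta-reference::mmmu",
    input_rows=eval_rows,
    scoring_functions=["basic::regex_parser_multiple_choice_answer"],
    task_config={
        "type": "benchmark",
        "eval_candidate": {
            "type": "model",
            "model": "meta-llama/Llama-3.2-90B-Vision-Instruct",
            "sampling_params": {"temperature": 0.0, "max_tokens": 4096},
        },
    },
)

# Application evaluation: the app already produced outputs, so only /scoring runs.
scoring_response = client.scoring.score(
    input_rows=[
        {
            "input_query": "What is LoRA?",
            "generated_answer": "LoRA is a parameter-efficient fine-tuning technique.",
            "expected_answer": "LoRA",
        }
    ],
    scoring_functions={"basic::subset_of": None},
)
```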

## What's Next?

- Check out our Colab notebook with working evaluation examples [here](https://colab.research.google.com/drive/10CHyykee9j2OigaIcRv47BKG9mrNm0tJ?usp=sharing).
- Check out our [Evaluation Reference](../references/evals_reference/index.md) for more details on the APIs.

@@ -62,3 +62,13 @@ While there is a lot of flexibility to mix-and-match providers, often users will
**On-device Distro**: Finally, you may want to run Llama Stack directly on an edge device (mobile phone or a tablet.) We provide Distros for iOS and Android (coming soon.)
+
+## More Concepts
+- [Evaluation Concepts](evaluation_concepts.md)
+
+```{toctree}
+:maxdepth: 1
+:hidden:
+
+evaluation_concepts
+```

@@ -29,6 +29,7 @@ extensions = [
    "sphinx_design",
    "sphinxcontrib.redoc",
    "sphinxcontrib.mermaid",
+    "sphinxcontrib.video",
]
myst_enable_extensions = ["colon_fence"]

@@ -1,123 +0,0 @@
# Evaluations

The Llama Stack Evaluation flow allows you to run evaluations on your GenAI application datasets or pre-registered benchmarks.

We introduce a set of APIs in Llama Stack for supporting running evaluations of LLM applications.
- `/datasetio` + `/datasets` API
- `/scoring` + `/scoring_functions` API
- `/eval` + `/eval_tasks` API

This guide goes over the sets of APIs and developer experience flow of using Llama Stack to run evaluations for different use cases.

## Evaluation Concepts

The Evaluation APIs are associated with a set of Resources as shown in the following diagram. Please visit the Resources section in our [Core Concepts](../concepts/index.md) guide for better high-level understanding.

- **DatasetIO**: defines interface with datasets and data loaders.
  - Associated with `Dataset` resource.
- **Scoring**: evaluate outputs of the system.
  - Associated with `ScoringFunction` resource. We provide a suite of out-of-the box scoring functions and also the ability for you to add custom evaluators. These scoring functions are the core part of defining an evaluation task to output evaluation metrics.
- **Eval**: generate outputs (via Inference or Agents) and perform scoring.
  - Associated with `EvalTask` resource.

## Running Evaluations
Use the following decision tree to decide how to use LlamaStack Evaluation flow.

```{admonition} Note on Benchmark v.s. Application Evaluation
:class: tip
- **Benchmark Evaluation** is a well-defined eval-task consisting of `dataset` and `scoring_function`. The generation (inference or agent) will be done as part of evaluation.
- **Application Evaluation** assumes users already have app inputs & generated outputs. Evaluation will purely focus on scoring the generated outputs via scoring functions (e.g. LLM-as-judge).
```

The following examples give the quick steps to start running evaluations using the llama-stack-client CLI.

#### Benchmark Evaluation CLI
Usage: There are 2 inputs necessary for running a benchmark eval
- `eval-task-id`: the identifier associated with the eval task. Each `EvalTask` is parametrized by
  - `dataset_id`: the identifier associated with the dataset.
  - `List[scoring_function_id]`: list of scoring function identifiers.
- `eval-task-config`: specifies the configuration of the model / agent to evaluate on.

```
llama-stack-client eval run_benchmark <eval-task-id> \
--eval-task-config ~/eval_task_config.json \
--visualize
```

#### Application Evaluation CLI
Usage: For running application evals, you will already have available datasets in hand from your application. You will need to specify:
- `scoring-fn-id`: List of ScoringFunction identifiers you wish to use to run on your application.
- `Dataset` used for evaluation:
  - (1) `--dataset-path`: path to local file system containing datasets to run evaluation on
  - (2) `--dataset-id`: pre-registered dataset in Llama Stack
- (Optional) `--scoring-params-config`: optionally parameterize scoring functions with custom params (e.g. `judge_prompt`, `judge_model`, `parsing_regexes`).

```
llama-stack-client eval run_scoring <scoring_fn_id_1> <scoring_fn_id_2> ... <scoring_fn_id_n>
--dataset-path <path-to-local-dataset> \
--output-dir ./
```

#### Defining EvalTaskConfig
The `EvalTaskConfig` are user specified config to define:
1. `EvalCandidate` to run generation on:
   - `ModelCandidate`: The model will be used for generation through LlamaStack /inference API.
   - `AgentCandidate`: The agentic system specified by AgentConfig will be used for generation through LlamaStack /agents API.
2. Optionally scoring function params to allow customization of scoring function behaviour. This is useful to parameterize generic scoring functions such as LLMAsJudge with custom `judge_model` / `judge_prompt`.

**Example Benchmark EvalTaskConfig**
```json
{
    "type": "benchmark",
    "eval_candidate": {
        "type": "model",
        "model": "Llama3.2-3B-Instruct",
        "sampling_params": {
            "strategy": "greedy",
            "temperature": 0,
            "top_p": 0.95,
            "top_k": 0,
            "max_tokens": 0,
            "repetition_penalty": 1.0
        }
    }
}
```

**Example Application EvalTaskConfig**
```json
{
    "type": "app",
    "eval_candidate": {
        "type": "model",
        "model": "Llama3.1-405B-Instruct",
        "sampling_params": {
            "strategy": "greedy",
            "temperature": 0,
            "top_p": 0.95,
            "top_k": 0,
            "max_tokens": 0,
            "repetition_penalty": 1.0
        }
    },
    "scoring_params": {
        "llm-as-judge::llm_as_judge_base": {
            "type": "llm_as_judge",
            "judge_model": "meta-llama/Llama-3.1-8B-Instruct",
            "prompt_template": "Your job is to look at a question, a gold target ........",
            "judge_score_regexes": [
                "(A|B|C)"
            ]
        }
    }
}
```

@@ -1,9 +0,0 @@
# Cookbooks

- [Evaluations Flow](evals.md)

```{toctree}
:maxdepth: 2
:hidden:
evals.md
```

@@ -8,12 +8,14 @@ Features:
- Remote Inferencing: Perform inferencing tasks remotely with Llama models hosted on a remote connection (or serverless localhost).
- Simple Integration: With easy-to-use APIs, a developer can quickly integrate Llama Stack in their Android app. The difference with local vs remote inferencing is also minimal.

-Latest Release Notes: [v0.0.54.1](https://github.com/meta-llama/llama-stack-client-kotlin/releases/tag/v0.0.54.1)
+Latest Release Notes: [v0.0.58](https://github.com/meta-llama/llama-stack-client-kotlin/releases/tag/v0.0.58)
+
+*Tagged releases are stable versions of the project. While we strive to maintain a stable main branch, it's not guaranteed to be free of bugs or issues.*

## Android Demo App
-Check out our demo app to see how to integrate Llama Stack into your Android app: [Android Demo App](https://github.com/meta-llama/llama-stack-apps/tree/main/examples/android_app)
+Check out our demo app to see how to integrate Llama Stack into your Android app: [Android Demo App](https://github.com/meta-llama/llama-stack-apps/tree/android-kotlin-app-latest/examples/android_app)

-The key files in the app are `LlamaStackLocalInference.kt`, `LlamaStackRemoteInference.kts`, and `MainActivity.java`. With encompassed business logic, the app shows how to use Llama Stack for both the environments.
+The key files in the app are `ExampleLlamaStackLocalInference.kt`, `ExampleLlamaStackRemoteInference.kts`, and `MainActivity.java`. With encompassed business logic, the app shows how to use Llama Stack for both the environments.

## Quick Start

@@ -22,7 +24,7 @@ The key files in the app are `LlamaStackLocalInference.kt`, `LlamaStackRemoteInf
Add the following dependency in your `build.gradle.kts` file:
```
dependencies {
-    implementation("com.llama.llamastack:llama-stack-client-kotlin:0.0.54.1")
+    implementation("com.llama.llamastack:llama-stack-client-kotlin:0.0.58")
}
```
This will download jar files in your gradle cache in a directory like `~/.gradle/caches/modules-2/files-2.1/com.llama.llamastack/`

@@ -34,10 +36,10 @@ If you plan on doing remote inferencing this is sufficient to get started.
For local inferencing, it is required to include the ExecuTorch library into your app.

Include the ExecuTorch library by:
-1. Download the `download-prebuilt-et-lib.sh` script file from the [llama-stack-client-kotlin-client-local](https://github.com/meta-llama/llama-stack-client-kotlin/blob/release/0.0.54.1/llama-stack-client-kotlin-client-local/download-prebuilt-et-lib.sh) directory to your local machine.
+1. Download the `download-prebuilt-et-lib.sh` script file from the [llama-stack-client-kotlin-client-local](https://github.com/meta-llama/llama-stack-client-kotlin/blob/release/0.0.58/llama-stack-client-kotlin-client-local/download-prebuilt-et-lib.sh) directory to your local machine.
2. Move the script to the top level of your Android app where the app directory resides:
<p align="center">
-<img src="https://raw.githubusercontent.com/meta-llama/llama-stack-client-kotlin/refs/heads/release/0.0.54.1/doc/img/example_android_app_directory.png" style="width:300px">
+<img src="https://raw.githubusercontent.com/meta-llama/llama-stack-client-kotlin/refs/heads/release/0.0.58/doc/img/example_android_app_directory.png" style="width:300px">
</p>

3. Run `sh download-prebuilt-et-lib.sh` to create an `app/libs` directory and download the `executorch.aar` in that path. This generates an ExecuTorch library for the XNNPACK delegate with commit: [0a12e33](https://github.com/pytorch/executorch/commit/0a12e33d22a3d44d1aa2af5f0d0673d45b962553).

@@ -58,12 +60,14 @@ Start a Llama Stack server on localhost. Here is an example of how you can do th
```
conda create -n stack-fireworks python=3.10
conda activate stack-fireworks
-pip install llama-stack=0.0.54
+pip install llama-stack=0.0.58
llama stack build --template fireworks --image-type conda
export FIREWORKS_API_KEY=<SOME_KEY>
llama stack run /Users/<your_username>/.llama/distributions/llamastack-fireworks/fireworks-run.yaml --port=5050
```

+Ensure the Llama Stack server version is the same as the Kotlin SDK Library for maximum compatibility.

Other inference providers: [Table](https://llama-stack.readthedocs.io/en/latest/index.html#supported-llama-stack-implementations)

How to set remote localhost in Demo App: [Settings](https://github.com/meta-llama/llama-stack-apps/tree/main/examples/android_app#settings)

@@ -109,7 +113,6 @@ With the Kotlin Library managing all the major operational logic, there are mini
val result = client!!.inference().chatCompletion(
    InferenceChatCompletionParams.builder()
        .modelId(modelName)
-        .putAdditionalQueryParam("seq_len", sequenceLength.toString())
        .messages(listOfMessages)
        .build()
)

@@ -118,9 +121,23 @@ val result = client!!.inference().chatCompletion(
var response = result.asChatCompletionResponse().completionMessage().content().string();
```

-### Setup Tool Calling
+[Remote only] For inference with a streaming response:

-Android demo app for more details: [Tool Calling](https://github.com/meta-llama/llama-stack-apps/tree/main/examples/android_app#tool-calling)
+```
+val result = client!!.inference().chatCompletionStreaming(
+    InferenceChatCompletionParams.builder()
+        .modelId(modelName)
+        .messages(listOfMessages)
+        .build()
+)
+
+// Response can be received as a asChatCompletionResponseStreamChunk as part of a callback.
+// See Android demo app for a detailed implementation example.
+```
+
+### Setup Custom Tool Calling
+
+Android demo app for more details: [Custom Tool Calling](https://github.com/meta-llama/llama-stack-apps/tree/main/examples/android_app#tool-calling)

## Advanced Users

@@ -129,7 +146,7 @@ The purpose of this section is to share more details with users that would like
### Prerequisite

You must complete the following steps:
-1. Clone the repo (`git clone https://github.com/meta-llama/llama-stack-client-kotlin.git -b release/0.0.54.1`)
+1. Clone the repo (`git clone https://github.com/meta-llama/llama-stack-client-kotlin.git -b release/0.0.58`)
2. Port the appropriate ExecuTorch libraries over into your Llama Stack Kotlin library environment.
```
cd llama-stack-client-kotlin-client-local
@@ -102,7 +102,7 @@ Make sure you have done `pip install llama-stack` and have the Llama Stack CLI a
export LLAMA_STACK_PORT=5001

llama stack build --template ollama --image-type conda
-llama stack run ./run.yaml \
+llama stack run ./distributions/ollama/run.yaml \
  --port $LLAMA_STACK_PORT \
  --env INFERENCE_MODEL=$INFERENCE_MODEL \
  --env OLLAMA_URL=http://localhost:11434

@@ -59,7 +59,8 @@ getting_started/index
concepts/index
distributions/index
building_applications/index
+benchmark_evaluations/index
+playground/index
contributing/index
references/index
-cookbooks/index
```
docs/source/playground/index.md (new file, 109 lines)

@@ -0,0 +1,109 @@
# Llama Stack Playground

```{note}
The Llama Stack Playground is currently experimental and subject to change. We welcome feedback and contributions to help improve it.
```

The Llama Stack Playground is a simple interface which aims to:
- Showcase **capabilities** and **concepts** of Llama Stack in an interactive environment
- Demo **end-to-end** application code to help users get started building their own applications
- Provide a **UI** to help users inspect and understand Llama Stack API providers and resources

## Key Features

#### Playground
Interactive pages for users to play with and explore Llama Stack API capabilities.

##### Chatbot
```{eval-rst}
.. video:: https://github.com/user-attachments/assets/6ca617e8-32ca-49b2-9774-185020ff5204
   :autoplay:
   :playsinline:
   :muted:
   :loop:
   :width: 100%
```
- **Chat**: Chat with Llama models.
  - This page is a simple chatbot that allows you to chat with Llama models. Under the hood, it uses the `/inference/chat-completion` streaming API to send messages to the model and receive responses.
- **RAG**: Upload documents to memory_banks and chat with a RAG agent
  - This page allows you to upload documents as a `memory_bank` and then chat with a RAG agent to query information about the uploaded documents.
  - Under the hood, it uses Llama Stack's `/agents` API to define and create a RAG agent and chat with it in a session.

##### Evaluations
```{eval-rst}
.. video:: https://github.com/user-attachments/assets/6cc1659f-eba4-49ca-a0a5-7c243557b4f5
   :autoplay:
   :playsinline:
   :muted:
   :loop:
   :width: 100%
```
- **Evaluations (Scoring)**: Run evaluations on your AI application datasets.
  - This page demonstrates the flow of the evaluation API to run evaluations on your custom AI application datasets. You may upload your own evaluation datasets and run evaluations using available scoring functions.
  - Under the hood, it uses Llama Stack's `/scoring` API to run evaluations on selected scoring functions.

```{eval-rst}
.. video:: https://github.com/user-attachments/assets/345845c7-2a2b-4095-960a-9ae40f6a93cf
   :autoplay:
   :playsinline:
   :muted:
   :loop:
   :width: 100%
```
- **Evaluations (Generation + Scoring)**: Use pre-registered evaluation tasks to evaluate a model or agent candidate
  - This page demonstrates the flow of the evaluation API to evaluate a model or agent candidate on pre-defined evaluation tasks. An evaluation task is a combination of a dataset and scoring functions.
  - Under the hood, it uses Llama Stack's `/eval` API to run generations and scorings on specified evaluation configs.
  - In order to run this page, you may need to register evaluation tasks and datasets as resources first through the following commands (a Python-client alternative is sketched after these commands).

```bash
$ llama-stack-client datasets register \
  --dataset-id "mmlu" \
  --provider-id "huggingface" \
  --url "https://huggingface.co/datasets/llamastack/evals" \
  --metadata '{"path": "llamastack/evals", "name": "evals__mmlu__details", "split": "train"}' \
  --schema '{"input_query": {"type": "string"}, "expected_answer": {"type": "string"}, "chat_completion_input": {"type": "string"}}'
```

```bash
$ llama-stack-client eval_tasks register \
  --eval-task-id meta-reference-mmlu \
  --provider-id meta-reference \
  --dataset-id mmlu \
  --scoring-functions basic::regex_parser_multiple_choice_answer
```
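
If you prefer to register these resources from Python rather than the CLI, a roughly equivalent sketch with the `llama_stack_client` library looks like the following; the URL, metadata, and schema mirror the CLI flags above, and the exact client method signatures should be treated as an assumption for your installed client version.

```python
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:5001")  # point at your running stack

client.datasets.register(
    dataset_id="mmlu",
    provider_id="huggingface",
    url={"uri": "https://huggingface.co/datasets/llamastack/evals"},
    metadata={"path": "llamastack/evals", "name": "evals__mmlu__details", "split": "train"},
    dataset_schema={
        "input_query": {"type": "string"},
        "expected_answer": {"type": "string"},
        "chat_completion_input": {"type": "string"},
    },
)

client.eval_tasks.register(
    eval_task_id="meta-reference-mmlu",
    provider_id="meta-reference",
    dataset_id="mmlu",
    scoring_functions=["basic::regex_parser_multiple_choice_answer"],
)
```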

##### Inspect
```{eval-rst}
.. video:: https://github.com/user-attachments/assets/01d52b2d-92af-4e3a-b623-a9b8ba22ba99
   :autoplay:
   :playsinline:
   :muted:
   :loop:
   :width: 100%
```
- **API Providers**: Inspect Llama Stack API providers
  - This page allows you to inspect Llama Stack API providers and resources.
  - Under the hood, it uses Llama Stack's `/providers` API to get information about the providers.

- **API Resources**: Inspect Llama Stack API resources
  - This page allows you to inspect Llama Stack API resources (`models`, `datasets`, `memory_banks`, `eval_tasks`, `shields`).
  - Under the hood, it uses Llama Stack's `/<resources>/list` API to get information about each resource.
  - Please visit [Core Concepts](https://llama-stack.readthedocs.io/en/latest/concepts/index.html) for more details about the resources.

## Starting the Llama Stack Playground

To start the Llama Stack Playground, run the following commands:

1. Start up the Llama Stack API server

```bash
llama stack build --template together --image-type conda
llama stack run together
```

2. Start Streamlit UI
```bash
cd llama_stack/distribution/ui
pip install -r requirements.txt
streamlit run app.py
```
docs/source/references/evals_reference/index.md (new file, 359 lines)

@@ -0,0 +1,359 @@
# Evaluations

The Llama Stack Evaluation flow allows you to run evaluations on your GenAI application datasets or pre-registered benchmarks.

We introduce a set of APIs in Llama Stack for supporting running evaluations of LLM applications.
- `/datasetio` + `/datasets` API
- `/scoring` + `/scoring_functions` API
- `/eval` + `/eval_tasks` API

This guide goes over the set of APIs and the developer experience flow of using Llama Stack to run evaluations for different use cases. Check out our Colab notebook with working evaluation examples [here](https://colab.research.google.com/drive/10CHyykee9j2OigaIcRv47BKG9mrNm0tJ?usp=sharing).

## Evaluation Concepts

The Evaluation APIs are associated with a set of Resources as shown in the following diagram. Please visit the Resources section in our [Core Concepts](../concepts/index.md) guide for a better high-level understanding.

- **DatasetIO**: defines the interface with datasets and data loaders.
  - Associated with the `Dataset` resource.
- **Scoring**: evaluates the outputs of the system.
  - Associated with the `ScoringFunction` resource. We provide a suite of out-of-the-box scoring functions and also the ability for you to add custom evaluators. These scoring functions are the core part of defining an evaluation task to output evaluation metrics.
- **Eval**: generates outputs (via Inference or Agents) and performs scoring.
  - Associated with the `EvalTask` resource.

Use the following decision tree to decide how to use the LlamaStack Evaluation flow.

```{admonition} Note on Benchmark vs. Application Evaluation
:class: tip
- **Benchmark Evaluation** is a well-defined eval task consisting of a `dataset` and a `scoring_function`. The generation (inference or agent) will be done as part of evaluation.
- **Application Evaluation** assumes users already have app inputs & generated outputs. Evaluation will purely focus on scoring the generated outputs via scoring functions (e.g. LLM-as-judge).
```

## Evaluation Examples Walkthrough

[Open In Colab](https://colab.research.google.com/drive/10CHyykee9j2OigaIcRv47BKG9mrNm0tJ?usp=sharing)

It is best to open this notebook in Colab to follow along with the examples.

### 1. Open Benchmark Model Evaluation

This first example walks you through how to evaluate a model candidate served by Llama Stack on open benchmarks. We will use the following benchmarks:
- [MMMU](https://arxiv.org/abs/2311.16502) (A Massive Multi-discipline Multimodal Understanding and Reasoning Benchmark for Expert AGI): a benchmark designed to evaluate multimodal models.
- [SimpleQA](https://openai.com/index/introducing-simpleqa/): a benchmark designed to assess models' ability to answer short, fact-seeking questions.

#### 1.1 Running MMMU
- We will use a pre-processed MMMU dataset from [llamastack/mmmu](https://huggingface.co/datasets/llamastack/mmmu). The preprocessing code is shown in this [GitHub Gist](https://gist.github.com/yanxi0830/118e9c560227d27132a7fd10e2c92840). The dataset is obtained by transforming the original [MMMU/MMMU](https://huggingface.co/datasets/MMMU/MMMU) dataset into the format expected by the `inference/chat-completion` API.

```python
import datasets

ds = datasets.load_dataset(path="llamastack/mmmu", name="Agriculture", split="dev")
ds = ds.select_columns(["chat_completion_input", "input_query", "expected_answer"])
eval_rows = ds.to_pandas().to_dict(orient="records")
```

- Next, we will run an evaluation on a model candidate. To do so, we will need to:
  - Define a system prompt
  - Define an EvalCandidate
  - Run evaluation on the dataset

```python
SYSTEM_PROMPT_TEMPLATE = """
You are an expert in Agriculture whose job is to answer questions from the user using images.
First, reason about the correct answer.
Then write the answer in the following format where X is exactly one of A,B,C,D:
Answer: X
Make sure X is one of A,B,C,D.
If you are uncertain of the correct answer, guess the most likely one.
"""

system_message = {
    "role": "system",
    "content": SYSTEM_PROMPT_TEMPLATE,
}

client.eval_tasks.register(
    eval_task_id="meta-reference::mmmu",
    dataset_id=f"mmmu-{subset}-{split}",
    scoring_functions=["basic::regex_parser_multiple_choice_answer"],
)

response = client.eval.evaluate_rows(
    task_id="meta-reference::mmmu",
    input_rows=eval_rows,
    scoring_functions=["basic::regex_parser_multiple_choice_answer"],
    task_config={
        "type": "benchmark",
        "eval_candidate": {
            "type": "model",
            "model": "meta-llama/Llama-3.2-90B-Vision-Instruct",
            "sampling_params": {
                "temperature": 0.0,
                "max_tokens": 4096,
                "top_p": 0.9,
                "repeat_penalty": 1.0,
            },
            "system_message": system_message,
        },
    },
)
```

#### 1.2. Running SimpleQA
- We will use a pre-processed SimpleQA dataset from [llamastack/evals](https://huggingface.co/datasets/llamastack/evals/viewer/evals__simpleqa), which is obtained by transforming the input query into the format accepted by the `inference/chat-completion` API.
- Since we will be using this same dataset in our next example for Agentic evaluation, we will register it using the `/datasets` API and interact with it through the `/datasetio` API.

```python
simpleqa_dataset_id = "huggingface::simpleqa"

_ = client.datasets.register(
    dataset_id=simpleqa_dataset_id,
    provider_id="huggingface",
    url={"uri": "https://huggingface.co/datasets/llamastack/evals"},
    metadata={
        "path": "llamastack/evals",
        "name": "evals__simpleqa",
        "split": "train",
    },
    dataset_schema={
        "input_query": {"type": "string"},
        "expected_answer": {"type": "string"},
        "chat_completion_input": {"type": "chat_completion_input"},
    },
)

eval_rows = client.datasetio.get_rows_paginated(
    dataset_id=simpleqa_dataset_id,
    rows_in_page=5,
)
```

```python
client.eval_tasks.register(
    eval_task_id="meta-reference::simpleqa",
    dataset_id=simpleqa_dataset_id,
    scoring_functions=["llm-as-judge::405b-simpleqa"],
)

response = client.eval.evaluate_rows(
    task_id="meta-reference::simpleqa",
    input_rows=eval_rows.rows,
    scoring_functions=["llm-as-judge::405b-simpleqa"],
    task_config={
        "type": "benchmark",
        "eval_candidate": {
            "type": "model",
            "model": "meta-llama/Llama-3.2-90B-Vision-Instruct",
            "sampling_params": {
                "temperature": 0.0,
                "max_tokens": 4096,
                "top_p": 0.9,
                "repeat_penalty": 1.0,
            },
        },
    },
)
```

### 2. Agentic Evaluation
- In this example, we will demonstrate how to evaluate an agent candidate served by Llama Stack via the `/agent` API.
- We will continue to use the SimpleQA dataset we used in the previous example.
- Instead of running evaluation on a model, we will run the evaluation on a Search Agent with access to a search tool. We will define our agent evaluation candidate through `AgentConfig`.

```python
agent_config = {
    "model": "meta-llama/Llama-3.1-405B-Instruct",
    "instructions": "You are a helpful assistant",
    "sampling_params": {
        "strategy": "greedy",
        "temperature": 0.0,
        "top_p": 0.95,
    },
    "tools": [
        {
            "type": "brave_search",
            "engine": "tavily",
            "api_key": userdata.get("TAVILY_SEARCH_API_KEY"),
        }
    ],
    "tool_choice": "auto",
    "tool_prompt_format": "json",
    "input_shields": [],
    "output_shields": [],
    "enable_session_persistence": False,
}

response = client.eval.evaluate_rows(
    task_id="meta-reference::simpleqa",
    input_rows=eval_rows.rows,
    scoring_functions=["llm-as-judge::405b-simpleqa"],
    task_config={
        "type": "benchmark",
        "eval_candidate": {
            "type": "agent",
            "config": agent_config,
        },
    },
)
```

### 3. Agentic Application Dataset Scoring
- Llama Stack offers a library of scoring functions and the `/scoring` API, allowing you to run evaluations on your pre-annotated AI application datasets.

- In this example, we will work with an example RAG dataset and a couple of scoring functions for evaluation.
  - `llm-as-judge::base`: LLM-As-Judge with a custom judge prompt & model.
  - `braintrust::factuality`: Factuality scorer from [braintrust](https://github.com/braintrustdata/autoevals).
  - `basic::subset_of`: Basic check of whether the generated answer is a subset of the expected answer.

- Please check out our [Llama Stack Playground](https://llama-stack.readthedocs.io/en/latest/playground/index.html) for an interactive interface to upload datasets and run scorings.

```python
judge_model_id = "meta-llama/Llama-3.1-405B-Instruct-FP8"

JUDGE_PROMPT = """
Given a QUESTION and GENERATED_RESPONSE and EXPECTED_RESPONSE.

Compare the factual content of the GENERATED_RESPONSE with the EXPECTED_RESPONSE. Ignore any differences in style, grammar, or punctuation.
The GENERATED_RESPONSE may either be a subset or superset of the EXPECTED_RESPONSE, or it may conflict with it. Determine which case applies. Answer the question by selecting one of the following options:
(A) The GENERATED_RESPONSE is a subset of the EXPECTED_RESPONSE and is fully consistent with it.
(B) The GENERATED_RESPONSE is a superset of the EXPECTED_RESPONSE and is fully consistent with it.
(C) The GENERATED_RESPONSE contains all the same details as the EXPECTED_RESPONSE.
(D) There is a disagreement between the GENERATED_RESPONSE and the EXPECTED_RESPONSE.
(E) The answers differ, but these differences don't matter from the perspective of factuality.

Give your answer in the format "Answer: One of ABCDE, Explanation: ".

Your actual task:

QUESTION: {input_query}
GENERATED_RESPONSE: {generated_answer}
EXPECTED_RESPONSE: {expected_answer}
"""

input_query = "What are the top 5 topics that were explained? Only list succinct bullet points."
generated_answer = """
Here are the top 5 topics that were explained in the documentation for Torchtune:

* What is LoRA and how does it work?
* Fine-tuning with LoRA: memory savings and parameter-efficient finetuning
* Running a LoRA finetune with Torchtune: overview and recipe
* Experimenting with different LoRA configurations: rank, alpha, and attention modules
* LoRA finetuning
"""
expected_answer = """LoRA"""

dataset_rows = [
    {
        "input_query": input_query,
        "generated_answer": generated_answer,
        "expected_answer": expected_answer,
    },
]

scoring_params = {
    "llm-as-judge::base": {
        "judge_model": judge_model_id,
        "prompt_template": JUDGE_PROMPT,
        "type": "llm_as_judge",
        "judge_score_regexes": ["Answer: (A|B|C|D|E)"],
    },
    "basic::subset_of": None,
    "braintrust::factuality": None,
}

response = client.scoring.score(input_rows=dataset_rows, scoring_functions=scoring_params)
```
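
The scoring response groups results by scoring function. A minimal sketch of reading them back, assuming the response exposes a `results` mapping whose entries carry `score_rows` and `aggregated_results` (adjust to your client's actual types):

```python
for scoring_fn_id, result in response.results.items():
    print(scoring_fn_id, result.aggregated_results)
    for row in result.score_rows:
        # Per-row judge grades / factuality scores / subset checks.
        print("  ", row)
```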

## Running Evaluations via CLI

The following examples give the quick steps to start running evaluations using the llama-stack-client CLI.

#### Benchmark Evaluation CLI
Usage: There are 2 inputs necessary for running a benchmark eval
- `eval-task-id`: the identifier associated with the eval task. Each `EvalTask` is parametrized by
  - `dataset_id`: the identifier associated with the dataset.
  - `List[scoring_function_id]`: list of scoring function identifiers.
- `eval-task-config`: specifies the configuration of the model / agent to evaluate on.

```
llama-stack-client eval run_benchmark <eval-task-id> \
  --eval-task-config ~/eval_task_config.json \
  --visualize
```

#### Application Evaluation CLI
Usage: For running application evals, you will already have available datasets in hand from your application. You will need to specify:
- `scoring-fn-id`: List of ScoringFunction identifiers you wish to run on your application.
- `Dataset` used for evaluation:
  - (1) `--dataset-path`: path to a local file system containing datasets to run evaluation on
  - (2) `--dataset-id`: pre-registered dataset in Llama Stack
- (Optional) `--scoring-params-config`: optionally parameterize scoring functions with custom params (e.g. `judge_prompt`, `judge_model`, `parsing_regexes`).

```
llama-stack-client eval run_scoring <scoring_fn_id_1> <scoring_fn_id_2> ... <scoring_fn_id_n> \
  --dataset-path <path-to-local-dataset> \
  --output-dir ./
```

#### Defining EvalTaskConfig
The `EvalTaskConfig` is a user-specified config that defines:
1. The `EvalCandidate` to run generation on:
   - `ModelCandidate`: the model will be used for generation through the LlamaStack /inference API.
   - `AgentCandidate`: the agentic system specified by AgentConfig will be used for generation through the LlamaStack /agents API.
2. Optionally, scoring function params to allow customization of scoring function behaviour. This is useful to parameterize generic scoring functions such as LLMAsJudge with a custom `judge_model` / `judge_prompt`.

**Example Benchmark EvalTaskConfig**
```json
{
    "type": "benchmark",
    "eval_candidate": {
        "type": "model",
        "model": "Llama3.2-3B-Instruct",
        "sampling_params": {
            "strategy": "greedy",
            "temperature": 0,
            "top_p": 0.95,
            "top_k": 0,
            "max_tokens": 0,
            "repetition_penalty": 1.0
        }
    }
}
```

**Example Application EvalTaskConfig**
```json
{
    "type": "app",
    "eval_candidate": {
        "type": "model",
        "model": "Llama3.1-405B-Instruct",
        "sampling_params": {
            "strategy": "greedy",
            "temperature": 0,
            "top_p": 0.95,
            "top_k": 0,
            "max_tokens": 0,
            "repetition_penalty": 1.0
        }
    },
    "scoring_params": {
        "llm-as-judge::llm_as_judge_base": {
            "type": "llm_as_judge",
            "judge_model": "meta-llama/Llama-3.1-8B-Instruct",
            "prompt_template": "Your job is to look at a question, a gold target ........",
            "judge_score_regexes": [
                "(A|B|C)"
            ]
        }
    }
}
```
|
Before Width: | Height: | Size: 68 KiB After Width: | Height: | Size: 68 KiB |
Before Width: | Height: | Size: 249 KiB After Width: | Height: | Size: 249 KiB |
|
@ -14,4 +14,5 @@ python_sdk_reference/index
|
||||||
llama_cli_reference/index
|
llama_cli_reference/index
|
||||||
llama_stack_client_cli_reference
|
llama_stack_client_cli_reference
|
||||||
llama_cli_reference/download_models
|
llama_cli_reference/download_models
|
||||||
|
evals_reference/index
|
||||||
```
|
```
|
||||||
|
|
|
@ -3,5 +3,8 @@
|
||||||
#
|
#
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
#
|
|
||||||
# from .distribution.library_client import LlamaStackAsLibraryClient, AsyncLlamaStackAsLibraryClient
|
from llama_stack.distribution.library_client import ( # noqa: F401
|
||||||
|
AsyncLlamaStackAsLibraryClient,
|
||||||
|
LlamaStackAsLibraryClient,
|
||||||
|
)
|
||||||
|
|
|
@ -18,3 +18,5 @@ class Job(BaseModel):
|
||||||
class JobStatus(Enum):
|
class JobStatus(Enum):
|
||||||
completed = "completed"
|
completed = "completed"
|
||||||
in_progress = "in_progress"
|
in_progress = "in_progress"
|
||||||
|
failed = "failed"
|
||||||
|
scheduled = "scheduled"
|
||||||
|
|
|
@ -4,13 +4,26 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
from llama_models.llama3.api.datatypes import URL
|
from datetime import datetime
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
from llama_models.schema_utils import json_schema_type
|
from llama_models.schema_utils import json_schema_type
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class PostTrainingMetric(BaseModel):
|
||||||
|
epoch: int
|
||||||
|
train_loss: float
|
||||||
|
validation_loss: float
|
||||||
|
perplexity: float
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type(schema={"description": "Checkpoint created during training runs"})
|
@json_schema_type(schema={"description": "Checkpoint created during training runs"})
|
||||||
class Checkpoint(BaseModel):
|
class Checkpoint(BaseModel):
|
||||||
iters: int
|
identifier: str
|
||||||
path: URL
|
created_at: datetime
|
||||||
epoch: int
|
epoch: int
|
||||||
|
post_training_job_id: str
|
||||||
|
path: str
|
||||||
|
training_metrics: Optional[PostTrainingMetric] = None
|
||||||
|
|
|
@ -89,6 +89,7 @@ class VectorMemoryBank(MemoryBankResourceMixin):
|
||||||
memory_bank_type: Literal[MemoryBankType.vector.value] = MemoryBankType.vector.value
|
memory_bank_type: Literal[MemoryBankType.vector.value] = MemoryBankType.vector.value
|
||||||
embedding_model: str
|
embedding_model: str
|
||||||
chunk_size_in_tokens: int
|
chunk_size_in_tokens: int
|
||||||
|
embedding_dimension: Optional[int] = 384 # default to minilm-l6-v2
|
||||||
overlap_size_in_tokens: Optional[int] = None
|
overlap_size_in_tokens: Optional[int] = None
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -4,6 +4,7 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
from enum import Enum
|
||||||
from typing import Any, Dict, List, Literal, Optional, Protocol, runtime_checkable
|
from typing import Any, Dict, List, Literal, Optional, Protocol, runtime_checkable
|
||||||
|
|
||||||
from llama_models.schema_utils import json_schema_type, webmethod
|
from llama_models.schema_utils import json_schema_type, webmethod
|
||||||
|
@ -20,6 +21,12 @@ class CommonModelFields(BaseModel):
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class ModelType(str, Enum):
|
||||||
|
llm = "llm"
|
||||||
|
embedding = "embedding"
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
class Model(CommonModelFields, Resource):
|
class Model(CommonModelFields, Resource):
|
||||||
type: Literal[ResourceType.model.value] = ResourceType.model.value
|
type: Literal[ResourceType.model.value] = ResourceType.model.value
|
||||||
|
@ -34,12 +41,14 @@ class Model(CommonModelFields, Resource):
|
||||||
|
|
||||||
model_config = ConfigDict(protected_namespaces=())
|
model_config = ConfigDict(protected_namespaces=())
|
||||||
|
|
||||||
|
model_type: ModelType = Field(default=ModelType.llm)
|
||||||
|
|
||||||
|
|
||||||
class ModelInput(CommonModelFields):
|
class ModelInput(CommonModelFields):
|
||||||
model_id: str
|
model_id: str
|
||||||
provider_id: Optional[str] = None
|
provider_id: Optional[str] = None
|
||||||
provider_model_id: Optional[str] = None
|
provider_model_id: Optional[str] = None
|
||||||
|
model_type: Optional[ModelType] = ModelType.llm
|
||||||
model_config = ConfigDict(protected_namespaces=())
|
model_config = ConfigDict(protected_namespaces=())
|
||||||
|
|
||||||
|
|
||||||
|
@ -59,6 +68,7 @@ class Models(Protocol):
|
||||||
provider_model_id: Optional[str] = None,
|
provider_model_id: Optional[str] = None,
|
||||||
provider_id: Optional[str] = None,
|
provider_id: Optional[str] = None,
|
||||||
metadata: Optional[Dict[str, Any]] = None,
|
metadata: Optional[Dict[str, Any]] = None,
|
||||||
|
model_type: Optional[ModelType] = None,
|
||||||
) -> Model: ...
|
) -> Model: ...
|
||||||
|
|
||||||
@webmethod(route="/models/unregister", method="POST")
|
@webmethod(route="/models/unregister", method="POST")
|
||||||
|
|
|
@ -7,68 +7,85 @@
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
|
|
||||||
from typing import Any, Dict, List, Optional, Protocol
|
from typing import Any, Dict, List, Optional, Protocol, Union
|
||||||
|
|
||||||
from llama_models.schema_utils import json_schema_type, webmethod
|
from llama_models.schema_utils import json_schema_type, webmethod
|
||||||
|
|
||||||
from pydantic import BaseModel, Field
|
from pydantic import BaseModel, Field
|
||||||
|
from typing_extensions import Annotated
|
||||||
|
|
||||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||||
|
from llama_stack.apis.common.job_types import JobStatus
|
||||||
from llama_stack.apis.datasets import * # noqa: F403
|
from llama_stack.apis.datasets import * # noqa: F403
|
||||||
from llama_stack.apis.common.training_types import * # noqa: F403
|
from llama_stack.apis.common.training_types import * # noqa: F403
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
class OptimizerType(Enum):
|
class OptimizerType(Enum):
|
||||||
adam = "adam"
|
adam = "adam"
|
||||||
adamw = "adamw"
|
adamw = "adamw"
|
||||||
sgd = "sgd"
|
sgd = "sgd"
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class DataConfig(BaseModel):
|
||||||
|
dataset_id: str
|
||||||
|
batch_size: int
|
||||||
|
shuffle: bool
|
||||||
|
validation_dataset_id: Optional[str] = None
|
||||||
|
packed: Optional[bool] = False
|
||||||
|
train_on_input: Optional[bool] = False
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
class OptimizerConfig(BaseModel):
|
class OptimizerConfig(BaseModel):
|
||||||
optimizer_type: OptimizerType
|
optimizer_type: OptimizerType
|
||||||
lr: float
|
lr: float
|
||||||
lr_min: float
|
|
||||||
weight_decay: float
|
weight_decay: float
|
||||||
|
num_warmup_steps: int
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class EfficiencyConfig(BaseModel):
|
||||||
|
enable_activation_checkpointing: Optional[bool] = False
|
||||||
|
enable_activation_offloading: Optional[bool] = False
|
||||||
|
memory_efficient_fsdp_wrap: Optional[bool] = False
|
||||||
|
fsdp_cpu_offload: Optional[bool] = False
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
class TrainingConfig(BaseModel):
|
class TrainingConfig(BaseModel):
|
||||||
n_epochs: int
|
n_epochs: int
|
||||||
batch_size: int
|
max_steps_per_epoch: int
|
||||||
shuffle: bool
|
gradient_accumulation_steps: int
|
||||||
n_iters: int
|
data_config: DataConfig
|
||||||
|
optimizer_config: OptimizerConfig
|
||||||
enable_activation_checkpointing: bool
|
efficiency_config: Optional[EfficiencyConfig] = None
|
||||||
memory_efficient_fsdp_wrap: bool
|
dtype: Optional[str] = "bf16"
|
||||||
fsdp_cpu_offload: bool
|
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
|
||||||
class FinetuningAlgorithm(Enum):
|
|
||||||
full = "full"
|
|
||||||
lora = "lora"
|
|
||||||
qlora = "qlora"
|
|
||||||
dora = "dora"
|
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
class LoraFinetuningConfig(BaseModel):
|
class LoraFinetuningConfig(BaseModel):
|
||||||
|
type: Literal["LoRA"] = "LoRA"
|
||||||
lora_attn_modules: List[str]
|
lora_attn_modules: List[str]
|
||||||
apply_lora_to_mlp: bool
|
apply_lora_to_mlp: bool
|
||||||
apply_lora_to_output: bool
|
apply_lora_to_output: bool
|
||||||
rank: int
|
rank: int
|
||||||
alpha: int
|
alpha: int
|
||||||
|
use_dora: Optional[bool] = False
|
||||||
|
quantize_base: Optional[bool] = False
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
class QLoraFinetuningConfig(LoraFinetuningConfig):
|
class QATFinetuningConfig(BaseModel):
|
||||||
pass
|
type: Literal["QAT"] = "QAT"
|
||||||
|
quantizer_name: str
|
||||||
|
group_size: int
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
AlgorithmConfig = Annotated[
|
||||||
class DoraFinetuningConfig(LoraFinetuningConfig):
|
Union[LoraFinetuningConfig, QATFinetuningConfig], Field(discriminator="type")
|
||||||
pass
|
]
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
|
@ -79,14 +96,6 @@ class PostTrainingJobLogStream(BaseModel):
|
||||||
log_lines: List[str]
|
log_lines: List[str]
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
|
||||||
class PostTrainingJobStatus(Enum):
|
|
||||||
running = "running"
|
|
||||||
completed = "completed"
|
|
||||||
failed = "failed"
|
|
||||||
scheduled = "scheduled"
|
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
class RLHFAlgorithm(Enum):
|
class RLHFAlgorithm(Enum):
|
||||||
dpo = "dpo"
|
dpo = "dpo"
|
||||||
|
@ -100,29 +109,6 @@ class DPOAlignmentConfig(BaseModel):
|
||||||
gamma: float
|
gamma: float
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
|
||||||
class PostTrainingSFTRequest(BaseModel):
|
|
||||||
"""Request to finetune a model."""
|
|
||||||
|
|
||||||
job_uuid: str
|
|
||||||
|
|
||||||
model: str
|
|
||||||
dataset_id: str
|
|
||||||
validation_dataset_id: str
|
|
||||||
|
|
||||||
algorithm: FinetuningAlgorithm
|
|
||||||
algorithm_config: Union[
|
|
||||||
LoraFinetuningConfig, QLoraFinetuningConfig, DoraFinetuningConfig
|
|
||||||
]
|
|
||||||
|
|
||||||
optimizer_config: OptimizerConfig
|
|
||||||
training_config: TrainingConfig
|
|
||||||
|
|
||||||
# TODO: define these
|
|
||||||
hyperparam_search_config: Dict[str, Any]
|
|
||||||
logger_config: Dict[str, Any]
|
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
@json_schema_type
|
||||||
class PostTrainingRLHFRequest(BaseModel):
|
class PostTrainingRLHFRequest(BaseModel):
|
||||||
"""Request to finetune a model."""
|
"""Request to finetune a model."""
|
||||||
|
@ -135,7 +121,7 @@ class PostTrainingRLHFRequest(BaseModel):
|
||||||
validation_dataset_id: str
|
validation_dataset_id: str
|
||||||
|
|
||||||
algorithm: RLHFAlgorithm
|
algorithm: RLHFAlgorithm
|
||||||
algorithm_config: Union[DPOAlignmentConfig]
|
algorithm_config: DPOAlignmentConfig
|
||||||
|
|
||||||
optimizer_config: OptimizerConfig
|
optimizer_config: OptimizerConfig
|
||||||
training_config: TrainingConfig
|
training_config: TrainingConfig
|
||||||
|
@ -154,7 +140,7 @@ class PostTrainingJobStatusResponse(BaseModel):
|
||||||
"""Status of a finetuning job."""
|
"""Status of a finetuning job."""
|
||||||
|
|
||||||
job_uuid: str
|
job_uuid: str
|
||||||
status: PostTrainingJobStatus
|
status: JobStatus
|
||||||
|
|
||||||
scheduled_at: Optional[datetime] = None
|
scheduled_at: Optional[datetime] = None
|
||||||
started_at: Optional[datetime] = None
|
started_at: Optional[datetime] = None
|
||||||
|
@ -176,54 +162,44 @@ class PostTrainingJobArtifactsResponse(BaseModel):
|
||||||
|
|
||||||
|
|
||||||
class PostTraining(Protocol):
|
class PostTraining(Protocol):
|
||||||
@webmethod(route="/post-training/supervised-fine-tune")
|
@webmethod(route="/post-training/supervised-fine-tune", method="POST")
|
||||||
def supervised_fine_tune(
|
async def supervised_fine_tune(
|
||||||
self,
|
self,
|
||||||
job_uuid: str,
|
job_uuid: str,
|
||||||
model: str,
|
training_config: TrainingConfig,
|
||||||
dataset_id: str,
|
hyperparam_search_config: Dict[str, Any],
|
||||||
validation_dataset_id: str,
|
logger_config: Dict[str, Any],
|
||||||
algorithm: FinetuningAlgorithm,
|
model: str = Field(
|
||||||
algorithm_config: Union[
|
default="Llama3.2-3B-Instruct",
|
||||||
LoraFinetuningConfig, QLoraFinetuningConfig, DoraFinetuningConfig
|
description="Model descriptor from `llama model list`",
|
||||||
],
|
),
|
||||||
optimizer_config: OptimizerConfig,
|
checkpoint_dir: Optional[str] = None,
|
||||||
|
algorithm_config: Optional[AlgorithmConfig] = None,
|
||||||
|
) -> PostTrainingJob: ...
|
||||||
|
|
||||||
|
@webmethod(route="/post-training/preference-optimize", method="POST")
|
||||||
|
async def preference_optimize(
|
||||||
|
self,
|
||||||
|
job_uuid: str,
|
||||||
|
finetuned_model: str,
|
||||||
|
algorithm_config: DPOAlignmentConfig,
|
||||||
training_config: TrainingConfig,
|
training_config: TrainingConfig,
|
||||||
hyperparam_search_config: Dict[str, Any],
|
hyperparam_search_config: Dict[str, Any],
|
||||||
logger_config: Dict[str, Any],
|
logger_config: Dict[str, Any],
|
||||||
) -> PostTrainingJob: ...
|
) -> PostTrainingJob: ...
|
||||||
|
|
||||||
@webmethod(route="/post-training/preference-optimize")
|
@webmethod(route="/post-training/jobs", method="GET")
|
||||||
def preference_optimize(
|
async def get_training_jobs(self) -> List[PostTrainingJob]: ...
|
||||||
self,
|
|
||||||
job_uuid: str,
|
|
||||||
finetuned_model: URL,
|
|
||||||
dataset_id: str,
|
|
||||||
validation_dataset_id: str,
|
|
||||||
algorithm: RLHFAlgorithm,
|
|
||||||
algorithm_config: Union[DPOAlignmentConfig],
|
|
||||||
optimizer_config: OptimizerConfig,
|
|
||||||
training_config: TrainingConfig,
|
|
||||||
hyperparam_search_config: Dict[str, Any],
|
|
||||||
logger_config: Dict[str, Any],
|
|
||||||
) -> PostTrainingJob: ...
|
|
||||||
|
|
||||||
@webmethod(route="/post-training/jobs")
|
@webmethod(route="/post-training/job/status", method="GET")
|
||||||
def get_training_jobs(self) -> List[PostTrainingJob]: ...
|
async def get_training_job_status(
|
||||||
|
|
||||||
# sends SSE stream of logs
|
|
||||||
@webmethod(route="/post-training/job/logs")
|
|
||||||
def get_training_job_logstream(self, job_uuid: str) -> PostTrainingJobLogStream: ...
|
|
||||||
|
|
||||||
@webmethod(route="/post-training/job/status")
|
|
||||||
def get_training_job_status(
|
|
||||||
self, job_uuid: str
|
self, job_uuid: str
|
||||||
) -> PostTrainingJobStatusResponse: ...
|
) -> Optional[PostTrainingJobStatusResponse]: ...
|
||||||
|
|
||||||
@webmethod(route="/post-training/job/cancel")
|
@webmethod(route="/post-training/job/cancel", method="POST")
|
||||||
def cancel_training_job(self, job_uuid: str) -> None: ...
|
async def cancel_training_job(self, job_uuid: str) -> None: ...
|
||||||
|
|
||||||
@webmethod(route="/post-training/job/artifacts")
|
@webmethod(route="/post-training/job/artifacts", method="GET")
|
||||||
def get_training_job_artifacts(
|
async def get_training_job_artifacts(
|
||||||
self, job_uuid: str
|
self, job_uuid: str
|
||||||
) -> PostTrainingJobArtifactsResponse: ...
|
) -> Optional[PostTrainingJobArtifactsResponse]: ...
|
||||||
|
|
|
@ -24,6 +24,7 @@ from llama_stack.apis.inspect import Inspect
|
||||||
from llama_stack.apis.memory import Memory
|
from llama_stack.apis.memory import Memory
|
||||||
from llama_stack.apis.memory_banks import MemoryBanks
|
from llama_stack.apis.memory_banks import MemoryBanks
|
||||||
from llama_stack.apis.models import Models
|
from llama_stack.apis.models import Models
|
||||||
|
from llama_stack.apis.post_training import PostTraining
|
||||||
from llama_stack.apis.safety import Safety
|
from llama_stack.apis.safety import Safety
|
||||||
from llama_stack.apis.scoring import Scoring
|
from llama_stack.apis.scoring import Scoring
|
||||||
from llama_stack.apis.scoring_functions import ScoringFunctions
|
from llama_stack.apis.scoring_functions import ScoringFunctions
|
||||||
|
@ -58,6 +59,7 @@ def api_protocol_map() -> Dict[Api, Any]:
|
||||||
Api.scoring_functions: ScoringFunctions,
|
Api.scoring_functions: ScoringFunctions,
|
||||||
Api.eval: Eval,
|
Api.eval: Eval,
|
||||||
Api.eval_tasks: EvalTasks,
|
Api.eval_tasks: EvalTasks,
|
||||||
|
Api.post_training: PostTraining,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -88,9 +88,10 @@ class InferenceRouter(Inference):
|
||||||
provider_model_id: Optional[str] = None,
|
provider_model_id: Optional[str] = None,
|
||||||
provider_id: Optional[str] = None,
|
provider_id: Optional[str] = None,
|
||||||
metadata: Optional[Dict[str, Any]] = None,
|
metadata: Optional[Dict[str, Any]] = None,
|
||||||
|
model_type: Optional[ModelType] = None,
|
||||||
) -> None:
|
) -> None:
|
||||||
await self.routing_table.register_model(
|
await self.routing_table.register_model(
|
||||||
model_id, provider_model_id, provider_id, metadata
|
model_id, provider_model_id, provider_id, metadata, model_type
|
||||||
)
|
)
|
||||||
|
|
||||||
async def chat_completion(
|
async def chat_completion(
|
||||||
|
@ -105,6 +106,13 @@ class InferenceRouter(Inference):
|
||||||
stream: Optional[bool] = False,
|
stream: Optional[bool] = False,
|
||||||
logprobs: Optional[LogProbConfig] = None,
|
logprobs: Optional[LogProbConfig] = None,
|
||||||
) -> AsyncGenerator:
|
) -> AsyncGenerator:
|
||||||
|
model = await self.routing_table.get_model(model_id)
|
||||||
|
if model is None:
|
||||||
|
raise ValueError(f"Model '{model_id}' not found")
|
||||||
|
if model.model_type == ModelType.embedding:
|
||||||
|
raise ValueError(
|
||||||
|
f"Model '{model_id}' is an embedding model and does not support chat completions"
|
||||||
|
)
|
||||||
params = dict(
|
params = dict(
|
||||||
model_id=model_id,
|
model_id=model_id,
|
||||||
messages=messages,
|
messages=messages,
|
||||||
|
@ -131,6 +139,13 @@ class InferenceRouter(Inference):
|
||||||
stream: Optional[bool] = False,
|
stream: Optional[bool] = False,
|
||||||
logprobs: Optional[LogProbConfig] = None,
|
logprobs: Optional[LogProbConfig] = None,
|
||||||
) -> AsyncGenerator:
|
) -> AsyncGenerator:
|
||||||
|
model = await self.routing_table.get_model(model_id)
|
||||||
|
if model is None:
|
||||||
|
raise ValueError(f"Model '{model_id}' not found")
|
||||||
|
if model.model_type == ModelType.embedding:
|
||||||
|
raise ValueError(
|
||||||
|
f"Model '{model_id}' is an embedding model and does not support chat completions"
|
||||||
|
)
|
||||||
provider = self.routing_table.get_provider_impl(model_id)
|
provider = self.routing_table.get_provider_impl(model_id)
|
||||||
params = dict(
|
params = dict(
|
||||||
model_id=model_id,
|
model_id=model_id,
|
||||||
|
@ -150,6 +165,13 @@ class InferenceRouter(Inference):
|
||||||
model_id: str,
|
model_id: str,
|
||||||
contents: List[InterleavedTextMedia],
|
contents: List[InterleavedTextMedia],
|
||||||
) -> EmbeddingsResponse:
|
) -> EmbeddingsResponse:
|
||||||
|
model = await self.routing_table.get_model(model_id)
|
||||||
|
if model is None:
|
||||||
|
raise ValueError(f"Model '{model_id}' not found")
|
||||||
|
if model.model_type == ModelType.llm:
|
||||||
|
raise ValueError(
|
||||||
|
f"Model '{model_id}' is an LLM model and does not support embeddings"
|
||||||
|
)
|
||||||
return await self.routing_table.get_provider_impl(model_id).embeddings(
|
return await self.routing_table.get_provider_impl(model_id).embeddings(
|
||||||
model_id=model_id,
|
model_id=model_id,
|
||||||
contents=contents,
|
contents=contents,
|
||||||
|
|
|
@ -209,6 +209,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
|
||||||
provider_model_id: Optional[str] = None,
|
provider_model_id: Optional[str] = None,
|
||||||
provider_id: Optional[str] = None,
|
provider_id: Optional[str] = None,
|
||||||
metadata: Optional[Dict[str, Any]] = None,
|
metadata: Optional[Dict[str, Any]] = None,
|
||||||
|
model_type: Optional[ModelType] = None,
|
||||||
) -> Model:
|
) -> Model:
|
||||||
if provider_model_id is None:
|
if provider_model_id is None:
|
||||||
provider_model_id = model_id
|
provider_model_id = model_id
|
||||||
|
@ -222,11 +223,18 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
|
||||||
)
|
)
|
||||||
if metadata is None:
|
if metadata is None:
|
||||||
metadata = {}
|
metadata = {}
|
||||||
|
if model_type is None:
|
||||||
|
model_type = ModelType.llm
|
||||||
|
if "embedding_dimension" not in metadata and model_type == ModelType.embedding:
|
||||||
|
raise ValueError(
|
||||||
|
"Embedding model must have an embedding dimension in its metadata"
|
||||||
|
)
|
||||||
model = Model(
|
model = Model(
|
||||||
identifier=model_id,
|
identifier=model_id,
|
||||||
provider_resource_id=provider_model_id,
|
provider_resource_id=provider_model_id,
|
||||||
provider_id=provider_id,
|
provider_id=provider_id,
|
||||||
metadata=metadata,
|
metadata=metadata,
|
||||||
|
model_type=model_type,
|
||||||
)
|
)
|
||||||
registered_model = await self.register_object(model)
|
registered_model = await self.register_object(model)
|
||||||
return registered_model
|
return registered_model
|
||||||
|
@ -298,16 +306,36 @@ class MemoryBanksRoutingTable(CommonRoutingTableImpl, MemoryBanks):
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
"No provider specified and multiple providers available. Please specify a provider_id."
|
"No provider specified and multiple providers available. Please specify a provider_id."
|
||||||
)
|
)
|
||||||
memory_bank = parse_obj_as(
|
model = await self.get_object_by_identifier("model", params.embedding_model)
|
||||||
MemoryBank,
|
if model is None:
|
||||||
{
|
if params.embedding_model == "all-MiniLM-L6-v2":
|
||||||
|
raise ValueError(
|
||||||
|
"Embeddings are now served via Inference providers. "
|
||||||
|
"Please upgrade your run.yaml to include inline::sentence-transformer as an additional inference provider. "
|
||||||
|
"See https://github.com/meta-llama/llama-stack/blob/main/llama_stack/templates/together/run.yaml for an example."
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Model {params.embedding_model} not found")
|
||||||
|
if model.model_type != ModelType.embedding:
|
||||||
|
raise ValueError(
|
||||||
|
f"Model {params.embedding_model} is not an embedding model"
|
||||||
|
)
|
||||||
|
if "embedding_dimension" not in model.metadata:
|
||||||
|
raise ValueError(
|
||||||
|
f"Model {params.embedding_model} does not have an embedding dimension"
|
||||||
|
)
|
||||||
|
memory_bank_data = {
|
||||||
"identifier": memory_bank_id,
|
"identifier": memory_bank_id,
|
||||||
"type": ResourceType.memory_bank.value,
|
"type": ResourceType.memory_bank.value,
|
||||||
"provider_id": provider_id,
|
"provider_id": provider_id,
|
||||||
"provider_resource_id": provider_memory_bank_id,
|
"provider_resource_id": provider_memory_bank_id,
|
||||||
**params.model_dump(),
|
**params.model_dump(),
|
||||||
},
|
}
|
||||||
)
|
if params.memory_bank_type == MemoryBankType.vector.value:
|
||||||
|
memory_bank_data["embedding_dimension"] = model.metadata[
|
||||||
|
"embedding_dimension"
|
||||||
|
]
|
||||||
|
memory_bank = parse_obj_as(MemoryBank, memory_bank_data)
|
||||||
await self.register_object(memory_bank)
|
await self.register_object(memory_bank)
|
||||||
return memory_bank
|
return memory_bank
|
||||||
|
|
||||||
|
|
|
@ -40,7 +40,7 @@ class DistributionRegistry(Protocol):
|
||||||
|
|
||||||
|
|
||||||
REGISTER_PREFIX = "distributions:registry"
|
REGISTER_PREFIX = "distributions:registry"
|
||||||
KEY_VERSION = "v2"
|
KEY_VERSION = "v3"
|
||||||
KEY_FORMAT = f"{REGISTER_PREFIX}:{KEY_VERSION}::" + "{type}:{identifier}"
|
KEY_FORMAT = f"{REGISTER_PREFIX}:{KEY_VERSION}::" + "{type}:{identifier}"
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -28,6 +28,7 @@ class Api(Enum):
|
||||||
datasetio = "datasetio"
|
datasetio = "datasetio"
|
||||||
scoring = "scoring"
|
scoring = "scoring"
|
||||||
eval = "eval"
|
eval = "eval"
|
||||||
|
post_training = "post_training"
|
||||||
|
|
||||||
telemetry = "telemetry"
|
telemetry = "telemetry"
|
||||||
|
|
||||||
|
@ -200,10 +201,13 @@ API responses, specify the adapter here.
|
||||||
return self.adapter.provider_data_validator
|
return self.adapter.provider_data_validator
|
||||||
|
|
||||||
|
|
||||||
def remote_provider_spec(api: Api, adapter: AdapterSpec) -> RemoteProviderSpec:
|
def remote_provider_spec(
|
||||||
|
api: Api, adapter: AdapterSpec, api_dependencies: Optional[List[Api]] = None
|
||||||
|
) -> RemoteProviderSpec:
|
||||||
return RemoteProviderSpec(
|
return RemoteProviderSpec(
|
||||||
api=api,
|
api=api,
|
||||||
provider_type=f"remote::{adapter.adapter_type}",
|
provider_type=f"remote::{adapter.adapter_type}",
|
||||||
config_class=adapter.config_class,
|
config_class=adapter.config_class,
|
||||||
adapter=adapter,
|
adapter=adapter,
|
||||||
|
api_dependencies=api_dependencies or [],
|
||||||
)
|
)
|
||||||
|
|
|
@ -16,12 +16,14 @@ from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||||
from llama_stack.providers.utils.inference.model_registry import build_model_alias
|
from llama_stack.providers.utils.inference.model_registry import build_model_alias
|
||||||
from llama_stack.apis.inference import * # noqa: F403
|
from llama_stack.apis.inference import * # noqa: F403
|
||||||
from llama_stack.providers.datatypes import ModelsProtocolPrivate
|
from llama_stack.providers.datatypes import ModelsProtocolPrivate
|
||||||
|
from llama_stack.providers.utils.inference.embedding_mixin import (
|
||||||
|
SentenceTransformerEmbeddingMixin,
|
||||||
|
)
|
||||||
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
|
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
|
||||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||||
convert_image_media_to_url,
|
convert_image_media_to_url,
|
||||||
request_has_media,
|
request_has_media,
|
||||||
)
|
)
|
||||||
|
|
||||||
from .config import MetaReferenceInferenceConfig
|
from .config import MetaReferenceInferenceConfig
|
||||||
from .generation import Llama
|
from .generation import Llama
|
||||||
from .model_parallel import LlamaModelParallelGenerator
|
from .model_parallel import LlamaModelParallelGenerator
|
||||||
|
@ -32,12 +34,17 @@ log = logging.getLogger(__name__)
|
||||||
SEMAPHORE = asyncio.Semaphore(1)
|
SEMAPHORE = asyncio.Semaphore(1)
|
||||||
|
|
||||||
|
|
||||||
class MetaReferenceInferenceImpl(Inference, ModelRegistryHelper, ModelsProtocolPrivate):
|
class MetaReferenceInferenceImpl(
|
||||||
|
SentenceTransformerEmbeddingMixin,
|
||||||
|
Inference,
|
||||||
|
ModelsProtocolPrivate,
|
||||||
|
):
|
||||||
def __init__(self, config: MetaReferenceInferenceConfig) -> None:
|
def __init__(self, config: MetaReferenceInferenceConfig) -> None:
|
||||||
self.config = config
|
self.config = config
|
||||||
model = resolve_model(config.model)
|
model = resolve_model(config.model)
|
||||||
ModelRegistryHelper.__init__(
|
if model is None:
|
||||||
self,
|
raise RuntimeError(f"Unknown model: {config.model}, Run `llama model list`")
|
||||||
|
self.model_registry_helper = ModelRegistryHelper(
|
||||||
[
|
[
|
||||||
build_model_alias(
|
build_model_alias(
|
||||||
model.descriptor(),
|
model.descriptor(),
|
||||||
|
@ -45,8 +52,6 @@ class MetaReferenceInferenceImpl(Inference, ModelRegistryHelper, ModelsProtocolP
|
||||||
)
|
)
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
if model is None:
|
|
||||||
raise RuntimeError(f"Unknown model: {config.model}, Run `llama model list`")
|
|
||||||
self.model = model
|
self.model = model
|
||||||
# verify that the checkpoint actually is for this model lol
|
# verify that the checkpoint actually is for this model lol
|
||||||
|
|
||||||
|
@ -76,6 +81,12 @@ class MetaReferenceInferenceImpl(Inference, ModelRegistryHelper, ModelsProtocolP
|
||||||
async def unregister_model(self, model_id: str) -> None:
|
async def unregister_model(self, model_id: str) -> None:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
|
async def register_model(self, model: Model) -> Model:
|
||||||
|
model = await self.model_registry_helper.register_model(model)
|
||||||
|
if model.model_type == ModelType.embedding:
|
||||||
|
self._load_sentence_transformer_model(model.provider_resource_id)
|
||||||
|
return model
|
||||||
|
|
||||||
async def completion(
|
async def completion(
|
||||||
self,
|
self,
|
||||||
model_id: str,
|
model_id: str,
|
||||||
|
@ -394,13 +405,6 @@ class MetaReferenceInferenceImpl(Inference, ModelRegistryHelper, ModelsProtocolP
|
||||||
for x in impl():
|
for x in impl():
|
||||||
yield x
|
yield x
|
||||||
|
|
||||||
async def embeddings(
|
|
||||||
self,
|
|
||||||
model_id: str,
|
|
||||||
contents: List[InterleavedTextMedia],
|
|
||||||
) -> EmbeddingsResponse:
|
|
||||||
raise NotImplementedError()
|
|
||||||
|
|
||||||
|
|
||||||
async def request_with_localized_media(
|
async def request_with_localized_media(
|
||||||
request: Union[ChatCompletionRequest, CompletionRequest],
|
request: Union[ChatCompletionRequest, CompletionRequest],
|
||||||
|
|
|
@ -0,0 +1,20 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
from llama_stack.providers.inline.inference.sentence_transformers.config import (
|
||||||
|
SentenceTransformersInferenceConfig,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
async def get_provider_impl(
|
||||||
|
config: SentenceTransformersInferenceConfig,
|
||||||
|
_deps,
|
||||||
|
):
|
||||||
|
from .sentence_transformers import SentenceTransformersInferenceImpl
|
||||||
|
|
||||||
|
impl = SentenceTransformersInferenceImpl(config)
|
||||||
|
await impl.initialize()
|
||||||
|
return impl
|
|
@ -0,0 +1,16 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
from typing import Any, Dict
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
|
class SentenceTransformersInferenceConfig(BaseModel):
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def sample_run_config(cls) -> Dict[str, Any]:
|
||||||
|
return {}
|
|
@ -0,0 +1,74 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
import logging
|
||||||
|
from typing import AsyncGenerator, List, Optional, Union
|
||||||
|
|
||||||
|
from llama_stack.apis.inference import (
|
||||||
|
CompletionResponse,
|
||||||
|
Inference,
|
||||||
|
LogProbConfig,
|
||||||
|
Message,
|
||||||
|
ResponseFormat,
|
||||||
|
SamplingParams,
|
||||||
|
ToolChoice,
|
||||||
|
ToolDefinition,
|
||||||
|
ToolPromptFormat,
|
||||||
|
)
|
||||||
|
from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate
|
||||||
|
from llama_stack.providers.utils.inference.embedding_mixin import (
|
||||||
|
SentenceTransformerEmbeddingMixin,
|
||||||
|
)
|
||||||
|
from .config import SentenceTransformersInferenceConfig
|
||||||
|
|
||||||
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class SentenceTransformersInferenceImpl(
|
||||||
|
SentenceTransformerEmbeddingMixin,
|
||||||
|
Inference,
|
||||||
|
ModelsProtocolPrivate,
|
||||||
|
):
|
||||||
|
def __init__(self, config: SentenceTransformersInferenceConfig) -> None:
|
||||||
|
self.config = config
|
||||||
|
|
||||||
|
async def initialize(self) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def shutdown(self) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def register_model(self, model: Model) -> None:
|
||||||
|
_ = self._load_sentence_transformer_model(model.provider_resource_id)
|
||||||
|
return model
|
||||||
|
|
||||||
|
async def unregister_model(self, model_id: str) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def completion(
|
||||||
|
self,
|
||||||
|
model_id: str,
|
||||||
|
content: str,
|
||||||
|
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||||
|
response_format: Optional[ResponseFormat] = None,
|
||||||
|
stream: Optional[bool] = False,
|
||||||
|
logprobs: Optional[LogProbConfig] = None,
|
||||||
|
) -> Union[CompletionResponse, AsyncGenerator]:
|
||||||
|
raise ValueError("Sentence transformers don't support completion")
|
||||||
|
|
||||||
|
async def chat_completion(
|
||||||
|
self,
|
||||||
|
model_id: str,
|
||||||
|
messages: List[Message],
|
||||||
|
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||||
|
response_format: Optional[ResponseFormat] = None,
|
||||||
|
tools: Optional[List[ToolDefinition]] = None,
|
||||||
|
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
|
||||||
|
tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
|
||||||
|
stream: Optional[bool] = False,
|
||||||
|
logprobs: Optional[LogProbConfig] = None,
|
||||||
|
) -> AsyncGenerator:
|
||||||
|
raise ValueError("Sentence transformers don't support chat completion")
|
|
@ -4,16 +4,19 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
from typing import Dict
|
||||||
|
|
||||||
|
from llama_stack.providers.datatypes import Api, ProviderSpec
|
||||||
from .config import FaissImplConfig
|
from .config import FaissImplConfig
|
||||||
|
|
||||||
|
|
||||||
async def get_provider_impl(config: FaissImplConfig, _deps):
|
async def get_provider_impl(config: FaissImplConfig, deps: Dict[Api, ProviderSpec]):
|
||||||
from .faiss import FaissMemoryImpl
|
from .faiss import FaissMemoryImpl
|
||||||
|
|
||||||
assert isinstance(
|
assert isinstance(
|
||||||
config, FaissImplConfig
|
config, FaissImplConfig
|
||||||
), f"Unexpected config type: {type(config)}"
|
), f"Unexpected config type: {type(config)}"
|
||||||
|
|
||||||
impl = FaissMemoryImpl(config)
|
impl = FaissMemoryImpl(config, deps[Api.inference])
|
||||||
await impl.initialize()
|
await impl.initialize()
|
||||||
return impl
|
return impl
|
||||||
|
|
|
@ -19,11 +19,10 @@ from numpy.typing import NDArray
|
||||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||||
|
|
||||||
from llama_stack.apis.memory import * # noqa: F403
|
from llama_stack.apis.memory import * # noqa: F403
|
||||||
from llama_stack.providers.datatypes import MemoryBanksProtocolPrivate
|
from llama_stack.providers.datatypes import Api, MemoryBanksProtocolPrivate
|
||||||
from llama_stack.providers.utils.kvstore import kvstore_impl
|
from llama_stack.providers.utils.kvstore import kvstore_impl
|
||||||
|
|
||||||
from llama_stack.providers.utils.memory.vector_store import (
|
from llama_stack.providers.utils.memory.vector_store import (
|
||||||
ALL_MINILM_L6_V2_DIMENSION,
|
|
||||||
BankWithIndex,
|
BankWithIndex,
|
||||||
EmbeddingIndex,
|
EmbeddingIndex,
|
||||||
)
|
)
|
||||||
|
@ -32,7 +31,8 @@ from .config import FaissImplConfig
|
||||||
|
|
||||||
logger = logging.getLogger(__name__)
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
MEMORY_BANKS_PREFIX = "memory_banks:v1::"
|
MEMORY_BANKS_PREFIX = "memory_banks:v2::"
|
||||||
|
FAISS_INDEX_PREFIX = "faiss_index:v2::"
|
||||||
|
|
||||||
|
|
||||||
class FaissIndex(EmbeddingIndex):
|
class FaissIndex(EmbeddingIndex):
|
||||||
|
@ -56,7 +56,7 @@ class FaissIndex(EmbeddingIndex):
|
||||||
if not self.kvstore:
|
if not self.kvstore:
|
||||||
return
|
return
|
||||||
|
|
||||||
index_key = f"faiss_index:v1::{self.bank_id}"
|
index_key = f"{FAISS_INDEX_PREFIX}{self.bank_id}"
|
||||||
stored_data = await self.kvstore.get(index_key)
|
stored_data = await self.kvstore.get(index_key)
|
||||||
|
|
||||||
if stored_data:
|
if stored_data:
|
||||||
|
@ -85,16 +85,25 @@ class FaissIndex(EmbeddingIndex):
|
||||||
"faiss_index": base64.b64encode(buffer.getvalue()).decode("utf-8"),
|
"faiss_index": base64.b64encode(buffer.getvalue()).decode("utf-8"),
|
||||||
}
|
}
|
||||||
|
|
||||||
index_key = f"faiss_index:v1::{self.bank_id}"
|
index_key = f"{FAISS_INDEX_PREFIX}{self.bank_id}"
|
||||||
await self.kvstore.set(key=index_key, value=json.dumps(data))
|
await self.kvstore.set(key=index_key, value=json.dumps(data))
|
||||||
|
|
||||||
async def delete(self):
|
async def delete(self):
|
||||||
if not self.kvstore or not self.bank_id:
|
if not self.kvstore or not self.bank_id:
|
||||||
return
|
return
|
||||||
|
|
||||||
await self.kvstore.delete(f"faiss_index:v1::{self.bank_id}")
|
await self.kvstore.delete(f"{FAISS_INDEX_PREFIX}{self.bank_id}")
|
||||||
|
|
||||||
async def add_chunks(self, chunks: List[Chunk], embeddings: NDArray):
|
async def add_chunks(self, chunks: List[Chunk], embeddings: NDArray):
|
||||||
|
# Add dimension check
|
||||||
|
embedding_dim = (
|
||||||
|
embeddings.shape[1] if len(embeddings.shape) > 1 else embeddings.shape[0]
|
||||||
|
)
|
||||||
|
if embedding_dim != self.index.d:
|
||||||
|
raise ValueError(
|
||||||
|
f"Embedding dimension mismatch. Expected {self.index.d}, got {embedding_dim}"
|
||||||
|
)
|
||||||
|
|
||||||
indexlen = len(self.id_by_index)
|
indexlen = len(self.id_by_index)
|
||||||
for i, chunk in enumerate(chunks):
|
for i, chunk in enumerate(chunks):
|
||||||
self.chunk_by_index[indexlen + i] = chunk
|
self.chunk_by_index[indexlen + i] = chunk
|
||||||
|
@ -124,8 +133,9 @@ class FaissIndex(EmbeddingIndex):
|
||||||
|
|
||||||
|
|
||||||
class FaissMemoryImpl(Memory, MemoryBanksProtocolPrivate):
|
class FaissMemoryImpl(Memory, MemoryBanksProtocolPrivate):
|
||||||
def __init__(self, config: FaissImplConfig) -> None:
|
def __init__(self, config: FaissImplConfig, inference_api: Api.inference) -> None:
|
||||||
self.config = config
|
self.config = config
|
||||||
|
self.inference_api = inference_api
|
||||||
self.cache = {}
|
self.cache = {}
|
||||||
self.kvstore = None
|
self.kvstore = None
|
||||||
|
|
||||||
|
@ -139,10 +149,11 @@ class FaissMemoryImpl(Memory, MemoryBanksProtocolPrivate):
|
||||||
for bank_data in stored_banks:
|
for bank_data in stored_banks:
|
||||||
bank = VectorMemoryBank.model_validate_json(bank_data)
|
bank = VectorMemoryBank.model_validate_json(bank_data)
|
||||||
index = BankWithIndex(
|
index = BankWithIndex(
|
||||||
bank=bank,
|
bank,
|
||||||
index=await FaissIndex.create(
|
await FaissIndex.create(
|
||||||
ALL_MINILM_L6_V2_DIMENSION, self.kvstore, bank.identifier
|
bank.embedding_dimension, self.kvstore, bank.identifier
|
||||||
),
|
),
|
||||||
|
self.inference_api,
|
||||||
)
|
)
|
||||||
self.cache[bank.identifier] = index
|
self.cache[bank.identifier] = index
|
||||||
|
|
||||||
|
@ -166,13 +177,13 @@ class FaissMemoryImpl(Memory, MemoryBanksProtocolPrivate):
|
||||||
)
|
)
|
||||||
|
|
||||||
# Store in cache
|
# Store in cache
|
||||||
index = BankWithIndex(
|
self.cache[memory_bank.identifier] = BankWithIndex(
|
||||||
bank=memory_bank,
|
memory_bank,
|
||||||
index=await FaissIndex.create(
|
await FaissIndex.create(
|
||||||
ALL_MINILM_L6_V2_DIMENSION, self.kvstore, memory_bank.identifier
|
memory_bank.embedding_dimension, self.kvstore, memory_bank.identifier
|
||||||
),
|
),
|
||||||
|
self.inference_api,
|
||||||
)
|
)
|
||||||
self.cache[memory_bank.identifier] = index
|
|
||||||
|
|
||||||
async def list_memory_banks(self) -> List[MemoryBank]:
|
async def list_memory_banks(self) -> List[MemoryBank]:
|
||||||
return [i.bank for i in self.cache.values()]
|
return [i.bank for i in self.cache.values()]
|
||||||
|
|
|
@ -0,0 +1,27 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
from typing import Dict
|
||||||
|
|
||||||
|
from llama_stack.distribution.datatypes import Api, ProviderSpec
|
||||||
|
|
||||||
|
from .config import TorchtunePostTrainingConfig
|
||||||
|
|
||||||
|
# post_training api and the torchtune provider is still experimental and under heavy development
|
||||||
|
|
||||||
|
|
||||||
|
async def get_provider_impl(
|
||||||
|
config: TorchtunePostTrainingConfig,
|
||||||
|
deps: Dict[Api, ProviderSpec],
|
||||||
|
):
|
||||||
|
from .post_training import TorchtunePostTrainingImpl
|
||||||
|
|
||||||
|
impl = TorchtunePostTrainingImpl(
|
||||||
|
config,
|
||||||
|
deps[Api.datasetio],
|
||||||
|
deps[Api.datasets],
|
||||||
|
)
|
||||||
|
return impl
|
|
@ -0,0 +1,157 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
import os
|
||||||
|
import shutil
|
||||||
|
from pathlib import Path
|
||||||
|
from typing import Any, Dict, List
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from torchtune import training
|
||||||
|
from torchtune.models import convert_weights
|
||||||
|
from torchtune.training.checkpointing._utils import ModelType, safe_torch_load
|
||||||
|
from torchtune.utils._logging import get_logger
|
||||||
|
|
||||||
|
logger = get_logger("DEBUG")
|
||||||
|
|
||||||
|
|
||||||
|
class TorchtuneCheckpointer:
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
model_id: str,
|
||||||
|
training_algorithm: str,
|
||||||
|
checkpoint_dir: str,
|
||||||
|
checkpoint_files: List[str],
|
||||||
|
output_dir: str,
|
||||||
|
model_type: str,
|
||||||
|
) -> None:
|
||||||
|
# Fail fast if ``checkpoint_files`` is invalid
|
||||||
|
# TODO: support loading more than one file
|
||||||
|
if len(checkpoint_files) != 1:
|
||||||
|
raise ValueError(
|
||||||
|
"Currently we only support reading from a single torchtune checkpoint file. "
|
||||||
|
f"Got {len(checkpoint_files)} files instead."
|
||||||
|
)
|
||||||
|
self._checkpoint_file = checkpoint_files[0]
|
||||||
|
self._model_id = model_id
|
||||||
|
self._training_algorithm = training_algorithm
|
||||||
|
self._checkpoint_dir = Path(checkpoint_dir)
|
||||||
|
self._model_type = ModelType[model_type]
|
||||||
|
self._output_dir = output_dir
|
||||||
|
# get ckpt paths
|
||||||
|
self._checkpoint_path = Path.joinpath(
|
||||||
|
self._checkpoint_dir, self._checkpoint_file
|
||||||
|
)
|
||||||
|
|
||||||
|
def load_checkpoint(self) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Load Meta checkpoint from file. Currently only loading from a single file is supported.
|
||||||
|
"""
|
||||||
|
state_dict: Dict[str:Any] = {}
|
||||||
|
model_state_dict = safe_torch_load(self._checkpoint_path)
|
||||||
|
if self._model_type == ModelType.LLAMA3_VISION:
|
||||||
|
from torchtune.models.llama3_2_vision._convert_weights import (
|
||||||
|
llama3_vision_meta_to_tune,
|
||||||
|
)
|
||||||
|
|
||||||
|
state_dict[training.MODEL_KEY] = llama3_vision_meta_to_tune(
|
||||||
|
model_state_dict
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
state_dict[training.MODEL_KEY] = convert_weights.meta_to_tune(
|
||||||
|
model_state_dict
|
||||||
|
)
|
||||||
|
|
||||||
|
# llama3_2 has tied weights, so we need to remove the output.weight key
|
||||||
|
if self._model_type == ModelType.LLAMA3_2:
|
||||||
|
logger.info(
|
||||||
|
"Identified model_type = Llama3_2. Ignoring output.weight in"
|
||||||
|
" checkpoint in favor of the tok_embedding.weight"
|
||||||
|
" tied weights."
|
||||||
|
)
|
||||||
|
state_dict[training.MODEL_KEY].pop("output.weight")
|
||||||
|
|
||||||
|
return state_dict
|
||||||
|
|
||||||
|
def save_checkpoint(
|
||||||
|
self,
|
||||||
|
state_dict: Dict[str, Any],
|
||||||
|
epoch: int,
|
||||||
|
adapter_only: bool = False,
|
||||||
|
) -> str:
|
||||||
|
model_file_path = (
|
||||||
|
Path(self._output_dir)
|
||||||
|
/ f"{self._model_id}-{self._training_algorithm}-{epoch}"
|
||||||
|
)
|
||||||
|
|
||||||
|
model_file_path.mkdir(parents=True, exist_ok=True)
|
||||||
|
|
||||||
|
# copy the related files for inference
|
||||||
|
shutil.copy(
|
||||||
|
Path.joinpath(self._checkpoint_dir, "params.json"),
|
||||||
|
Path.joinpath(model_file_path, "params.json"),
|
||||||
|
)
|
||||||
|
shutil.copy(
|
||||||
|
Path.joinpath(self._checkpoint_dir, "tokenizer.model"),
|
||||||
|
Path.joinpath(model_file_path, "tokenizer.model"),
|
||||||
|
)
|
||||||
|
shutil.copy(
|
||||||
|
Path.joinpath(self._checkpoint_dir, "orig_params.json"),
|
||||||
|
Path.joinpath(model_file_path, "orig_params.json"),
|
||||||
|
)
|
||||||
|
|
||||||
|
if not adapter_only:
|
||||||
|
model_state_dict = state_dict[training.MODEL_KEY]
|
||||||
|
if self._model_type == ModelType.LLAMA3_VISION:
|
||||||
|
from torchtune.models.llama3_2_vision._convert_weights import (
|
||||||
|
llama3_vision_tune_to_meta,
|
||||||
|
)
|
||||||
|
|
||||||
|
state_dict[training.MODEL_KEY] = llama3_vision_tune_to_meta(
|
||||||
|
model_state_dict
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# llama3_2 has tied weights, so we need to add the output.weight key
|
||||||
|
if (
|
||||||
|
self._model_type == ModelType.LLAMA3_2
|
||||||
|
and "output.weight" not in model_state_dict
|
||||||
|
):
|
||||||
|
model_state_dict["output.weight"] = model_state_dict[
|
||||||
|
"tok_embeddings.weight"
|
||||||
|
]
|
||||||
|
|
||||||
|
state_dict[training.MODEL_KEY] = convert_weights.tune_to_meta(
|
||||||
|
model_state_dict
|
||||||
|
)
|
||||||
|
|
||||||
|
model_file_name = Path.joinpath(model_file_path, "consolidated.00.pth")
|
||||||
|
|
||||||
|
torch.save(state_dict[training.MODEL_KEY], model_file_name)
|
||||||
|
logger.info(
|
||||||
|
"Model checkpoint of size "
|
||||||
|
f"{os.path.getsize(model_file_name) / 1000**3:.2f} GB "
|
||||||
|
f"saved to {model_file_name}"
|
||||||
|
)
|
||||||
|
|
||||||
|
if training.ADAPTER_KEY in state_dict:
|
||||||
|
adapter_file_path = model_file_path / "adapter"
|
||||||
|
adapter_file_path.mkdir(parents=True, exist_ok=True)
|
||||||
|
adapter_file_name = Path.joinpath(adapter_file_path, "adapter.pth")
|
||||||
|
torch.save(state_dict[training.ADAPTER_KEY], adapter_file_name)
|
||||||
|
logger.info(
|
||||||
|
"Adapter checkpoint of size "
|
||||||
|
f"{os.path.getsize(adapter_file_name) / 1000**3:.2f} GB "
|
||||||
|
f"saved to {adapter_file_name}"
|
||||||
|
)
|
||||||
|
|
||||||
|
elif adapter_only:
|
||||||
|
raise ValueError(
|
||||||
|
"Adapter checkpoint not found in state_dict. Please ensure that the state_dict contains adapter weights."
|
||||||
|
)
|
||||||
|
|
||||||
|
print("model_file_path", str(model_file_path))
|
||||||
|
|
||||||
|
return str(model_file_path)
|
|
@ -0,0 +1,139 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
# Copyright (c) Meta Platforms, IAny, nc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
from enum import Enum
|
||||||
|
from typing import Any, Callable, Dict, List
|
||||||
|
|
||||||
|
import torch
|
||||||
|
from llama_stack.apis.datasets import Datasets
|
||||||
|
from llama_stack.apis.common.type_system import * # noqa
|
||||||
|
from llama_models.datatypes import Model
|
||||||
|
from llama_models.sku_list import resolve_model
|
||||||
|
from llama_stack.apis.common.type_system import ParamType
|
||||||
|
|
||||||
|
from torchtune.models.llama3 import llama3_tokenizer, lora_llama3_8b
|
||||||
|
from torchtune.models.llama3._tokenizer import Llama3Tokenizer
|
||||||
|
from torchtune.models.llama3_2 import lora_llama3_2_3b
|
||||||
|
|
||||||
|
|
||||||
|
class ColumnName(Enum):
|
||||||
|
instruction = "instruction"
|
||||||
|
input = "input"
|
||||||
|
output = "output"
|
||||||
|
text = "text"
|
||||||
|
|
||||||
|
|
||||||
|
class ModelConfig(BaseModel):
|
||||||
|
model_definition: Any
|
||||||
|
tokenizer_type: Any
|
||||||
|
checkpoint_type: str
|
||||||
|
|
||||||
|
|
||||||
|
class DatasetSchema(BaseModel):
|
||||||
|
alpaca: List[Dict[str, ParamType]]
|
||||||
|
|
||||||
|
|
||||||
|
MODEL_CONFIGS: Dict[str, ModelConfig] = {
|
||||||
|
"Llama3.2-3B-Instruct": ModelConfig(
|
||||||
|
model_definition=lora_llama3_2_3b,
|
||||||
|
tokenizer_type=llama3_tokenizer,
|
||||||
|
checkpoint_type="LLAMA3_2",
|
||||||
|
),
|
||||||
|
"Llama-3-8B-Instruct": ModelConfig(
|
||||||
|
model_definition=lora_llama3_8b,
|
||||||
|
tokenizer_type=llama3_tokenizer,
|
||||||
|
checkpoint_type="LLAMA3",
|
||||||
|
),
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
EXPECTED_DATASET_SCHEMA = DatasetSchema(
|
||||||
|
alpaca=[
|
||||||
|
{
|
||||||
|
ColumnName.instruction.value: StringType(),
|
||||||
|
ColumnName.input.value: StringType(),
|
||||||
|
ColumnName.output.value: StringType(),
|
||||||
|
ColumnName.text.value: StringType(),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ColumnName.instruction.value: StringType(),
|
||||||
|
ColumnName.input.value: StringType(),
|
||||||
|
ColumnName.output.value: StringType(),
|
||||||
|
},
|
||||||
|
{
|
||||||
|
ColumnName.instruction.value: StringType(),
|
||||||
|
ColumnName.output.value: StringType(),
|
||||||
|
},
|
||||||
|
]
|
||||||
|
)
|
||||||
|
|
||||||
|
BuildLoraModelCallable = Callable[..., torch.nn.Module]
|
||||||
|
BuildTokenizerCallable = Callable[..., Llama3Tokenizer]
|
||||||
|
|
||||||
|
|
||||||
|
def _validate_model_id(model_id: str) -> Model:
|
||||||
|
model = resolve_model(model_id)
|
||||||
|
if model is None or model.core_model_id.value not in MODEL_CONFIGS:
|
||||||
|
raise ValueError(f"Model {model_id} is not supported.")
|
||||||
|
return model
|
||||||
|
|
||||||
|
|
||||||
|
async def get_model_definition(
|
||||||
|
model_id: str,
|
||||||
|
) -> BuildLoraModelCallable:
|
||||||
|
model = _validate_model_id(model_id)
|
||||||
|
model_config = MODEL_CONFIGS[model.core_model_id.value]
|
||||||
|
if not hasattr(model_config, "model_definition"):
|
||||||
|
raise ValueError(f"Model {model_id} does not have model definition.")
|
||||||
|
return model_config.model_definition
|
||||||
|
|
||||||
|
|
||||||
|
async def get_tokenizer_type(
|
||||||
|
model_id: str,
|
||||||
|
) -> BuildTokenizerCallable:
|
||||||
|
model = _validate_model_id(model_id)
|
||||||
|
model_config = MODEL_CONFIGS[model.core_model_id.value]
|
||||||
|
if not hasattr(model_config, "tokenizer_type"):
|
||||||
|
raise ValueError(f"Model {model_id} does not have tokenizer_type.")
|
||||||
|
return model_config.tokenizer_type
|
||||||
|
|
||||||
|
|
||||||
|
async def get_checkpointer_model_type(
|
||||||
|
model_id: str,
|
||||||
|
) -> str:
|
||||||
|
"""
|
||||||
|
checkpointer model type is used in checkpointer for some special treatment on some specific model types
|
||||||
|
For example, llama3.2 model tied weights (https://github.com/pytorch/torchtune/blob/main/torchtune/training/checkpointing/_checkpointer.py#L1041)
|
||||||
|
"""
|
||||||
|
model = _validate_model_id(model_id)
|
||||||
|
model_config = MODEL_CONFIGS[model.core_model_id.value]
|
||||||
|
if not hasattr(model_config, "checkpoint_type"):
|
||||||
|
raise ValueError(f"Model {model_id} does not have checkpoint_type.")
|
||||||
|
return model_config.checkpoint_type
|
||||||
|
|
||||||
|
|
||||||
|
async def validate_input_dataset_schema(
|
||||||
|
datasets_api: Datasets,
|
||||||
|
dataset_id: str,
|
||||||
|
dataset_type: str,
|
||||||
|
) -> None:
|
||||||
|
dataset_def = await datasets_api.get_dataset(dataset_id=dataset_id)
|
||||||
|
if not dataset_def.dataset_schema or len(dataset_def.dataset_schema) == 0:
|
||||||
|
raise ValueError(f"Dataset {dataset_id} does not have a schema defined.")
|
||||||
|
|
||||||
|
if not hasattr(EXPECTED_DATASET_SCHEMA, dataset_type):
|
||||||
|
raise ValueError(f"Dataset type {dataset_type} is not supported.")
|
||||||
|
|
||||||
|
if dataset_def.dataset_schema not in getattr(EXPECTED_DATASET_SCHEMA, dataset_type):
|
||||||
|
raise ValueError(
|
||||||
|
f"Dataset {dataset_id} does not have a correct input schema in {getattr(EXPECTED_DATASET_SCHEMA, dataset_type)}"
|
||||||
|
)
|
|
@ -0,0 +1,13 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from pydantic import BaseModel
|
||||||
|
|
||||||
|
|
||||||
|
class TorchtunePostTrainingConfig(BaseModel):
|
||||||
|
torch_seed: Optional[int] = None
|
|
@@ -0,0 +1,66 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the BSD-style license found in the
# LICENSE file in the root directory of this source tree.

from typing import Any, Dict, List, Mapping

import numpy as np

from torch.utils.data import Dataset
from torchtune.data._common import CROSS_ENTROPY_IGNORE_IDX
from torchtune.data._messages import validate_messages
from torchtune.modules.transforms import Transform


class SFTDataset(Dataset):
    def __init__(
        self,
        rows: List[Dict[str, Any]],
        message_transform: Transform,
        model_transform: Transform,
    ) -> None:
        self._rows = rows
        self._message_transform = message_transform
        self._model_transform = model_transform

    def __len__(self):
        return len(self._rows)

    def __getitem__(self, index: int) -> Dict[str, Any]:
        sample = self._rows[index]
        return self._prepare_sample(sample)

    def _prepare_sample(self, sample: Mapping[str, Any]) -> Dict[str, Any]:
        transformed_sample = self._message_transform(sample)
        if "messages" in transformed_sample:
            validate_messages(transformed_sample["messages"])

        tokenized_dict = self._model_transform(transformed_sample)

        if not ("tokens" in tokenized_dict and "mask" in tokenized_dict):
            keys_str = ", ".join(tokenized_dict.keys())
            error_message = (
                "model_transform returned the following keys: "
                f"{keys_str}. Must return 'tokens' and 'mask' as keys."
            )
            raise ValueError(error_message)

        # Wherever mask == True, set to CROSS_ENTROPY_IGNORE_IDX. Otherwise keep as tokens
        tokenized_dict["labels"] = list(
            np.where(
                tokenized_dict["mask"],
                CROSS_ENTROPY_IGNORE_IDX,
                tokenized_dict["tokens"],
            )
        )
        assert len(tokenized_dict["tokens"]) == len(tokenized_dict["labels"])

        return tokenized_dict
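A toy usage sketch of SFTDataset follows. The two inline transforms are stand-ins for AlpacaToMessages and a real tokenizer (which is how the recipe later in this diff wires it up); the token ids and mask values are made up for illustration.

# --- illustrative sketch, not part of the diff ---
rows = [{"instruction": "Say hi", "input": "", "output": "hi"}]

def fake_message_transform(sample):
    # A real transform would build torchtune Message objects under a
    # "messages" key; returning the sample untouched skips validate_messages.
    return dict(sample)

def fake_model_transform(sample):
    # A real tokenizer returns token ids plus a prompt mask of equal length.
    return {"tokens": [1, 2, 3, 4], "mask": [True, True, False, False]}

ds = SFTDataset(rows, message_transform=fake_message_transform, model_transform=fake_model_transform)
item = ds[0]
# Prompt positions (mask == True) become CROSS_ENTROPY_IGNORE_IDX in "labels";
# completion positions keep their token ids.
print(item["labels"])  # e.g. [-100, -100, 3, 4]
# --- end sketch ---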
@@ -0,0 +1,126 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.providers.inline.post_training.torchtune.config import (
    TorchtunePostTrainingConfig,
)
from llama_stack.apis.post_training import *  # noqa
from llama_stack.providers.inline.post_training.torchtune.recipes.lora_finetuning_single_device import (
    LoraFinetuningSingleDevice,
)


class TorchtunePostTrainingImpl:
    def __init__(
        self,
        config: TorchtunePostTrainingConfig,
        datasetio_api: DatasetIO,
        datasets: Datasets,
    ) -> None:
        self.config = config
        self.datasetio_api = datasetio_api
        self.datasets_api = datasets

        # TODO: assume sync job, will need jobs API for async scheduling
        self.jobs_status = {}
        self.jobs_list = []
        self.checkpoints_dict = {}

    async def supervised_fine_tune(
        self,
        job_uuid: str,
        training_config: TrainingConfig,
        hyperparam_search_config: Dict[str, Any],
        logger_config: Dict[str, Any],
        model: str,
        checkpoint_dir: Optional[str],
        algorithm_config: Optional[AlgorithmConfig],
    ) -> PostTrainingJob:
        for job in self.jobs_list:
            if job_uuid == job.job_uuid:
                raise ValueError(f"Job {job_uuid} already exists")

        post_training_job = PostTrainingJob(job_uuid=job_uuid)

        job_status_response = PostTrainingJobStatusResponse(
            job_uuid=job_uuid,
            status=JobStatus.scheduled,
            scheduled_at=datetime.now(),
        )

        self.jobs_list.append(post_training_job)
        if isinstance(algorithm_config, LoraFinetuningConfig):
            try:
                recipe = LoraFinetuningSingleDevice(
                    self.config,
                    job_uuid,
                    training_config,
                    hyperparam_search_config,
                    logger_config,
                    model,
                    checkpoint_dir,
                    algorithm_config,
                    self.datasetio_api,
                    self.datasets_api,
                )

                job_status_response.status = JobStatus.in_progress
                job_status_response.started_at = datetime.now()

                await recipe.setup()
                resources_allocated, checkpoints = await recipe.train()

                self.checkpoints_dict[job_uuid] = checkpoints
                job_status_response.resources_allocated = resources_allocated
                job_status_response.checkpoints = checkpoints
                job_status_response.status = JobStatus.completed
                job_status_response.completed_at = datetime.now()

            except Exception:
                job_status_response.status = JobStatus.failed
                raise
        else:
            raise NotImplementedError()

        self.jobs_status[job_uuid] = job_status_response

        return post_training_job

    async def preference_optimize(
        self,
        job_uuid: str,
        finetuned_model: str,
        algorithm_config: DPOAlignmentConfig,
        training_config: TrainingConfig,
        hyperparam_search_config: Dict[str, Any],
        logger_config: Dict[str, Any],
    ) -> PostTrainingJob: ...

    async def get_training_jobs(self) -> List[PostTrainingJob]:
        return self.jobs_list

    @webmethod(route="/post-training/job/status")
    async def get_training_job_status(
        self, job_uuid: str
    ) -> Optional[PostTrainingJobStatusResponse]:
        if job_uuid in self.jobs_status:
            return self.jobs_status[job_uuid]
        return None

    @webmethod(route="/post-training/job/cancel")
    async def cancel_training_job(self, job_uuid: str) -> None:
        raise NotImplementedError("Job cancel is not implemented yet")

    @webmethod(route="/post-training/job/artifacts")
    async def get_training_job_artifacts(
        self, job_uuid: str
    ) -> Optional[PostTrainingJobArtifactsResponse]:
        if job_uuid in self.checkpoints_dict:
            checkpoints = self.checkpoints_dict.get(job_uuid, [])
            return PostTrainingJobArtifactsResponse(
                job_uuid=job_uuid, checkpoints=checkpoints
            )
        return None
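For context, a hedged sketch of the configuration fields the recipe below actually reads. The stand-in dataclasses here only mirror those fields; the real TrainingConfig, DataConfig, and OptimizerConfig are pydantic models in llama_stack.apis.post_training and may carry additional required fields.

# --- illustrative sketch, not part of the diff ---
from dataclasses import dataclass
from typing import Optional

@dataclass
class FakeDataConfig:
    dataset_id: str
    batch_size: int
    shuffle: bool
    validation_dataset_id: Optional[str] = None

@dataclass
class FakeOptimizerConfig:
    lr: float
    num_warmup_steps: int

@dataclass
class FakeTrainingConfig:
    n_epochs: int
    max_steps_per_epoch: int
    gradient_accumulation_steps: int
    dtype: str
    data_config: FakeDataConfig
    optimizer_config: FakeOptimizerConfig
    efficiency_config: Optional[object] = None

cfg = FakeTrainingConfig(
    n_epochs=1,
    max_steps_per_epoch=50,
    gradient_accumulation_steps=1,
    dtype="bf16",
    data_config=FakeDataConfig(dataset_id="alpaca", batch_size=1, shuffle=False),
    optimizer_config=FakeOptimizerConfig(lr=3e-4, num_warmup_steps=100),
)
print(cfg.data_config.dataset_id)  # "alpaca"
# --- end sketch ---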
@@ -0,0 +1,596 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import logging
import os
import time
from functools import partial
from pathlib import Path
from typing import Any, Dict, List, Optional, Tuple

import torch
from llama_models.sku_list import resolve_model

from llama_stack.apis.datasetio import DatasetIO

from llama_stack.distribution.utils.config_dirs import DEFAULT_CHECKPOINT_DIR
from llama_stack.providers.inline.post_training.torchtune.common.checkpointer import (
    TorchtuneCheckpointer,
)
from torch import nn
from torchtune import utils as torchtune_utils
from torchtune.training.metric_logging import DiskLogger
from tqdm import tqdm
from llama_stack.apis.post_training import *  # noqa
from llama_stack.distribution.utils.model_utils import model_local_dir

from llama_stack.providers.inline.post_training.torchtune.common import utils
from llama_stack.providers.inline.post_training.torchtune.config import (
    TorchtunePostTrainingConfig,
)
from llama_stack.providers.inline.post_training.torchtune.datasets.sft import SFTDataset
from torch.optim import Optimizer
from torch.utils.data import DataLoader, DistributedSampler
from torchtune import modules, training
from torchtune.data import AlpacaToMessages, padded_collate_sft

from torchtune.modules.loss import CEWithChunkedOutputLoss
from torchtune.modules.peft import (
    get_adapter_params,
    get_adapter_state_dict,
    get_lora_module_names,
    get_merged_lora_ckpt,
    load_dora_magnitudes,
    set_trainable_params,
    validate_missing_and_unexpected_for_lora,
)
from torchtune.training.lr_schedulers import get_cosine_schedule_with_warmup

log = logging.getLogger(__name__)

from torchtune.models.llama3._tokenizer import Llama3Tokenizer

class LoraFinetuningSingleDevice:
    # This recipe only supports GPU training.

    # This recipe doesn't include several training-efficiency settings from the
    # original torchtune repo, including:
    # - compile
    # - activation offloading

    # Resuming from a checkpoint isn't supported yet.
    # Validation isn't supported yet.

    # Logging currently writes only a limited set of training metrics to local disk;
    # richer logging and telemetry integration will come in future PRs.
    def __init__(
        self,
        config: TorchtunePostTrainingConfig,
        job_uuid: str,
        training_config: TrainingConfig,
        hyperparam_search_config: Dict[str, Any],
        logger_config: Dict[str, Any],
        model: str,
        checkpoint_dir: Optional[str],
        algorithm_config: Optional[AlgorithmConfig],
        datasetio_api: DatasetIO,
        datasets_api: Datasets,
    ) -> None:
        self.job_uuid = job_uuid
        self.training_config = training_config
        if not isinstance(algorithm_config, LoraFinetuningConfig):
            raise ValueError(
                "You need to specify LoraFinetuningConfig for LoRA finetuning"
            )
        self.algorithm_config = algorithm_config
        self._device = torchtune_utils.get_device(device="cuda")
        self._dtype = training.get_dtype(training_config.dtype, device=self._device)
        self.model_id = model

        def model_checkpoint_dir(model) -> str:
            checkpoint_dir = Path(model_local_dir(model.descriptor()))

            paths = [
                Path(checkpoint_dir / f"consolidated.{ext}")
                for ext in ["pth", "00.pth"]
            ]
            if not any(p.exists() for p in paths):
                checkpoint_dir = checkpoint_dir / "original"

            assert checkpoint_dir.exists(), (
                f"Could not find checkpoints in: {model_local_dir(model.descriptor())}. "
                f"Please download model using `llama download --model-id {model.descriptor()}`"
            )
            return str(checkpoint_dir)

        if checkpoint_dir and checkpoint_dir != "null":
            self.checkpoint_dir = config.checkpoint_dir
        else:
            model = resolve_model(self.model_id)
            self.checkpoint_dir = model_checkpoint_dir(model)

        self._output_dir = str(DEFAULT_CHECKPOINT_DIR)

        self.seed = training.set_seed(seed=config.torch_seed)
        self.epochs_run = 0
        self.total_epochs = training_config.n_epochs
        self._shuffle = training_config.data_config.shuffle
        self._batch_size = training_config.data_config.batch_size

        # this is important for debugging purposes
        self.max_steps_per_epoch = training_config.max_steps_per_epoch
        self.global_step = 0

        self._gradient_accumulation_steps = training_config.gradient_accumulation_steps

        self._clip_grad_norm = 1.0
        self._enable_activation_checkpointing = (
            (training_config.efficiency_config.enable_activation_checkpointing)
            if training_config.efficiency_config
            else False
        )
        self._enable_activation_offloading = (
            (training_config.efficiency_config.enable_activation_offloading)
            if training_config.efficiency_config
            else False
        )

        self.datasetio_api = datasetio_api
        self.datasets_api = datasets_api

    async def load_checkpoint(self):
        def get_checkpoint_files(checkpoint_dir: str) -> List[str]:
            try:
                # List all files in the given directory
                files = os.listdir(checkpoint_dir)
                # Filter files that end with .pth
                pth_files = [file for file in files if file.endswith(".pth")]
                return pth_files
            except FileNotFoundError:
                return [f"Error: The directory '{checkpoint_dir}' does not exist."]

        self._checkpointer = TorchtuneCheckpointer(
            model_id=self.model_id,
            training_algorithm="sft",
            checkpoint_dir=self.checkpoint_dir,
            checkpoint_files=get_checkpoint_files(self.checkpoint_dir),
            output_dir=self._output_dir,
            model_type=await utils.get_checkpointer_model_type(self.model_id),
        )
        checkpoint_dict = self._checkpointer.load_checkpoint()
        return checkpoint_dict

    async def setup(self) -> None:
        checkpoint_dict = await self.load_checkpoint()

        self._model = await self._setup_model(
            enable_activation_checkpointing=self._enable_activation_checkpointing,
            enable_activation_offloading=self._enable_activation_offloading,
            base_model_state_dict=checkpoint_dict[training.MODEL_KEY],
            lora_weights_state_dict=None,
        )
        log.info(f"Model is initialized with precision {self._dtype}.")

        self._tokenizer = await self._setup_tokenizer()
        log.info("Tokenizer is initialized.")

        self._optimizer = await self._setup_optimizer(
            optimizer_config=self.training_config.optimizer_config
        )
        log.info("Optimizer is initialized.")

        self._loss_fn = CEWithChunkedOutputLoss()
        self._model.set_num_output_chunks(self._loss_fn.num_output_chunks)
        log.info("Loss is initialized.")

        self._training_sampler, self._training_dataloader = await self._setup_data(
            dataset_id=self.training_config.data_config.dataset_id,
            tokenizer=self._tokenizer,
            shuffle=self._shuffle,
            batch_size=self._batch_size,
        )

        if self.training_config.data_config.validation_dataset_id:
            _, self._validation_dataloader = await self._setup_data(
                dataset_id=self.training_config.data_config.validation_dataset_id,
                tokenizer=self._tokenizer,
                shuffle=False,
                batch_size=self._batch_size,
            )

        log.info("Dataset and Sampler are initialized.")

        # Number of training steps in each epoch depends on the number of batches produced
        # by the dataloader and the max_steps_per_epoch param set by the user and is used
        # for logging and tracking training state. This should be computed after the dataloader
        # has been setup
        self._steps_per_epoch = (
            len(self._training_dataloader) // self._gradient_accumulation_steps
        )
        if (
            self.max_steps_per_epoch is not None
            and self.max_steps_per_epoch < self._steps_per_epoch
        ):
            self._steps_per_epoch = self.max_steps_per_epoch
            self.global_step = self.epochs_run * self._steps_per_epoch

        # Learning rate scheduler can only be set up after number of steps
        # has been computed
        self._lr_scheduler = await self._setup_lr_scheduler(
            num_warmup_steps=self.training_config.optimizer_config.num_warmup_steps,
            num_training_steps=self.total_epochs * self._steps_per_epoch,
            last_epoch=self.global_step - 1,
        )
        log.info("Learning rate scheduler is initialized.")

        # Used to ignore labels for loss computation
        self.ignore_labels_cache = torch.full(
            (self._batch_size, 1), self._loss_fn.ignore_index, device=self._device
        )

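A quick worked example of the steps-per-epoch arithmetic in setup() above, with made-up numbers:

# --- illustrative sketch, not part of the diff ---
dataloader_len = 1000           # batches produced per epoch
gradient_accumulation_steps = 4
max_steps_per_epoch = 200       # user-provided cap

steps_per_epoch = dataloader_len // gradient_accumulation_steps  # 250 optimizer steps
if max_steps_per_epoch is not None and max_steps_per_epoch < steps_per_epoch:
    steps_per_epoch = max_steps_per_epoch                        # capped to 200

total_epochs = 3
num_training_steps = total_epochs * steps_per_epoch              # 600, fed to the LR scheduler
print(steps_per_epoch, num_training_steps)                       # 200 600
# --- end sketch ---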
    async def _setup_model(
        self,
        enable_activation_checkpointing: bool,
        enable_activation_offloading: bool,
        base_model_state_dict: Dict[str, Any],
        lora_weights_state_dict: Optional[Dict[str, Any]] = None,
    ) -> nn.Module:
        self._lora_rank = self.algorithm_config.rank
        self._lora_alpha = self.algorithm_config.alpha
        self._lora_attn_modules = list(self.algorithm_config.lora_attn_modules)
        self._apply_lora_to_mlp = self.algorithm_config.apply_lora_to_mlp
        self._apply_lora_to_output = self.algorithm_config.apply_lora_to_output
        self._use_dora = self.algorithm_config.use_dora or False

        with training.set_default_dtype(self._dtype), self._device:
            model_type = await utils.get_model_definition(self.model_id)
            model = model_type(
                lora_attn_modules=self._lora_attn_modules,
                apply_lora_to_mlp=self._apply_lora_to_mlp,
                apply_lora_to_output=self._apply_lora_to_output,
                lora_rank=self._lora_rank,
                lora_alpha=self._lora_alpha,
                quantize_base=False,
                use_dora=self._use_dora,
            )

        self.adapter_params = get_adapter_params(model)
        self._is_dora = any(["magnitude" in k for k in self.adapter_params.keys()])

        set_trainable_params(model, self.adapter_params)

        if enable_activation_checkpointing:
            training.set_activation_checkpointing(
                model, auto_wrap_policy={modules.TransformerSelfAttentionLayer}
            )

        base_missing, base_unexpected = model.load_state_dict(
            base_model_state_dict, strict=False
        )

        # This is for any adapters that need to be initialized after base weights
        # have been loaded (e.g. DoRA).
        if self._is_dora:
            for m in model.modules():
                if hasattr(m, "initialize_dora_magnitude"):
                    m.initialize_dora_magnitude()
            load_dora_magnitudes(model)
        if lora_weights_state_dict:
            lora_missing, lora_unexpected = model.load_state_dict(
                lora_weights_state_dict, strict=False
            )
        else:
            lora_missing, lora_unexpected = None, None
        validate_missing_and_unexpected_for_lora(
            lora_attn_modules=self._lora_attn_modules,
            apply_lora_to_mlp=self._apply_lora_to_mlp,
            apply_lora_to_output=self._apply_lora_to_output,
            base_missing=base_missing,
            base_unexpected=base_unexpected,
            lora_missing=lora_missing,
            lora_unexpected=lora_unexpected,
        )

        # Validate model adapter params were loaded in with the expected dtype
        training.validate_expected_param_dtype(
            self.adapter_params.items(), dtype=self._dtype
        )

        # activation offloading
        self.activations_handling_ctx = training.get_act_offloading_ctx_manager(
            model, enable_activation_offloading
        )

        memory_stats = training.get_memory_stats(device=self._device)
        training.log_memory_stats(memory_stats)

        return model

    async def _setup_tokenizer(
        self,
    ) -> Llama3Tokenizer:
        tokenizer_path = self.checkpoint_dir + "/tokenizer.model"
        tokenizer_type = await utils.get_tokenizer_type(self.model_id)
        return tokenizer_type(path=tokenizer_path)

    async def _setup_optimizer(self, optimizer_config: OptimizerConfig) -> Optimizer:
        optimizer = torch.optim.AdamW(
            params=self._model.parameters(),
            lr=optimizer_config.lr,
            betas=(0.9, 0.95),
            eps=1e-8,
            weight_decay=0.1,
        )
        return optimizer

    async def _setup_data(
        self,
        dataset_id: str,
        tokenizer: Llama3Tokenizer,
        shuffle: bool,
        batch_size: int,
    ) -> Tuple[DistributedSampler, DataLoader]:
        async def fetch_rows(dataset_id: str):
            return await self.datasetio_api.get_rows_paginated(
                dataset_id=dataset_id,
                rows_in_page=-1,
            )

        all_rows = await fetch_rows(dataset_id)
        rows = all_rows.rows

        # Currently only supports the alpaca instruct dataset
        # TODO @SLR722 make the message_transform swappable and support more dataset types
        # TODO @SLR722 make the input dataset schema more flexible by exposing column_map
        await utils.validate_input_dataset_schema(
            datasets_api=self.datasets_api,
            dataset_id=dataset_id,
            dataset_type="alpaca",
        )
        ds = SFTDataset(
            rows,
            message_transform=AlpacaToMessages(train_on_input=False),
            model_transform=tokenizer,
        )

        sampler = DistributedSampler(
            ds,
            num_replicas=1,
            rank=0,
            shuffle=shuffle,
            seed=0,
        )
        dataloader = DataLoader(
            dataset=ds,
            sampler=sampler,
            batch_size=batch_size,
            # dropping last avoids shape issues with compile + flex attention
            drop_last=True,
            collate_fn=(
                partial(
                    padded_collate_sft,
                    padding_idx=self._tokenizer.pad_id,
                    ignore_idx=self._loss_fn.ignore_index,
                )
            ),
        )

        return sampler, dataloader

    async def _setup_lr_scheduler(
        self,
        num_warmup_steps: int,
        num_training_steps: int,
        last_epoch: int,
    ) -> Optimizer:
        lr_scheduler = get_cosine_schedule_with_warmup(
            self._optimizer,
            num_warmup_steps=num_warmup_steps,
            num_training_steps=num_training_steps,
            last_epoch=last_epoch,
        )
        return lr_scheduler

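To illustrate what the padded collate step is for: sequences in a batch have different lengths, so tokens are padded with the tokenizer's pad id and labels with the loss's ignore index. This is a minimal hand-rolled sketch of that idea, not torchtune's padded_collate_sft itself, whose exact behavior should be checked against the torchtune docs.

# --- illustrative sketch, not part of the diff ---
import torch

PAD_ID, IGNORE_IDX = 0, -100
batch = [
    {"tokens": [5, 6, 7], "labels": [-100, 6, 7]},
    {"tokens": [8, 9], "labels": [-100, 9]},
]
max_len = max(len(x["tokens"]) for x in batch)
tokens = torch.tensor([x["tokens"] + [PAD_ID] * (max_len - len(x["tokens"])) for x in batch])
labels = torch.tensor([x["labels"] + [IGNORE_IDX] * (max_len - len(x["labels"])) for x in batch])
print(tokens)  # [[5, 6, 7], [8, 9, 0]]
print(labels)  # [[-100, 6, 7], [-100, 9, -100]]
# --- end sketch ---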
    async def save_checkpoint(self, epoch: int) -> str:
        ckpt_dict = {}

        adapter_state_dict = get_adapter_state_dict(self._model.state_dict())
        ckpt_dict.update({training.ADAPTER_KEY: adapter_state_dict})

        # Construct the full state dict with LoRA weights merged into base LLM weights
        # Move to CPU to avoid a copy on GPU
        state_dict = {k: v.cpu() for k, v in self._model.state_dict().items()}

        merged_state_dict = get_merged_lora_ckpt(
            state_dict,
            rank=self._lora_rank,
            alpha=self._lora_alpha,
        )

        ckpt_dict.update({training.MODEL_KEY: merged_state_dict})

        adapter_config = {
            "r": self._lora_rank,
            "lora_alpha": self._lora_alpha,
            "target_modules": get_lora_module_names(
                self._lora_attn_modules,
                self._apply_lora_to_mlp,
                self._apply_lora_to_output,
            ),
            "peft_type": "LORA",
        }
        ckpt_dict.update({training.ADAPTER_CONFIG: adapter_config})

        return self._checkpointer.save_checkpoint(
            ckpt_dict,
            epoch=epoch,
        )

    async def _loss_step(self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
        # Shape [b, s], needed for the loss not the model
        labels = batch.pop("labels")
        # run model
        with self.activations_handling_ctx:
            logits = self._model(**batch)

        # Shift labels to compute loss.
        # This is equivalent to doing labels[..., 1:] and logits[..., :-1, :];
        # this way we don't need to slice the logits, we just add an ignore index to labels.
        labels = torch.hstack(
            (labels[..., 1:], self.ignore_labels_cache[: labels.shape[0]])
        )
        if not isinstance(logits, list):
            labels = labels.reshape(-1)
            logits = logits.reshape(-1, logits.size(-1))

        loss = self._loss_fn(logits, labels)

        # free logits, otherwise it peaks backward memory
        del logits

        return loss

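The label shift in _loss_step above can be seen with a tiny example: appending the ignore index produces the same alignment as slicing labels[..., 1:] against logits[..., :-1, :], without slicing the logits.

# --- illustrative sketch, not part of the diff ---
import torch

IGNORE_IDX = -100
labels = torch.tensor([[10, 11, 12, 13]])
ignore_cache = torch.full((1, 1), IGNORE_IDX)

shifted = torch.hstack((labels[..., 1:], ignore_cache[: labels.shape[0]]))
print(shifted)  # tensor([[  11,   12,   13, -100]])
# Position t of the (unsliced) logits is now scored against token t+1, and the
# final position is ignored by the loss.
# --- end sketch ---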
    async def train(self) -> Tuple[Dict[str, Any], List[Checkpoint]]:
        """
        The core training loop.
        """
        # Initialize tokens count and running loss (for grad accumulation)
        t0 = time.perf_counter()
        running_loss = 0
        num_tokens = 0

        # training artifacts
        checkpoints = []
        memory_stats = {}

        # self.epochs_run should be non-zero when we're resuming from a checkpoint
        for curr_epoch in range(self.epochs_run, self.total_epochs):
            # Update the sampler to ensure data is correctly shuffled across epochs
            # in case shuffle is True
            metric_logger = DiskLogger(
                log_dir=self._output_dir + f"/{self.model_id}-sft-{curr_epoch}"
            )
            self._training_sampler.set_epoch(curr_epoch)
            loss_to_log = 0.0

            pbar = tqdm(total=self._steps_per_epoch)
            for idx, batch in enumerate(self._training_dataloader):
                if (
                    self.max_steps_per_epoch is not None
                    and (idx // self._gradient_accumulation_steps)
                    == self.max_steps_per_epoch
                ):
                    break

                torchtune_utils.batch_to_device(batch, self._device)

                # Calculate the number of unmasked tokens in the current batch
                # and increment the total number of tokens seen in the step
                current_num_tokens = (
                    batch["labels"] != self._loss_fn.ignore_index
                ).sum()
                num_tokens += current_num_tokens

                # Loss is normalized by default so we multiply by the number of tokens
                # This way we can normalize by the total number of tokens if we're accumulating gradients
                current_loss = await self._loss_step(batch) * current_num_tokens
                running_loss += current_loss
                current_loss.backward()

                # Step with optimizer
                if (idx + 1) % self._gradient_accumulation_steps == 0:
                    training.scale_grads(self._model, 1 / num_tokens)
                    grad_norm = torch.nn.utils.clip_grad_norm_(
                        self._model.parameters(),
                        max_norm=float(self._clip_grad_norm),
                    )
                    self._optimizer.step()
                    self._optimizer.zero_grad(set_to_none=True)
                    self._lr_scheduler.step()
                    # Update the number of steps when the weights are updated
                    self.global_step += 1

                    loss_to_log = running_loss.item() / num_tokens

                    pbar.update(1)
                    pbar.set_description(
                        f"{curr_epoch + 1}|{self.global_step}|Loss: {loss_to_log}"
                    )

                    time_per_step = time.perf_counter() - t0
                    log_dict = {
                        "loss": loss_to_log,
                        "lr": self._optimizer.param_groups[0]["lr"],
                        "tokens_per_second_per_gpu": num_tokens / time_per_step,
                    }

                    memory_stats = training.get_memory_stats(device=self._device)
                    log_dict.update(memory_stats)

                    if self._clip_grad_norm is not None:
                        log_dict.update({"grad_norm": grad_norm})

                    metric_logger.log_dict(
                        log_dict,
                        step=self.global_step,
                    )

                    # Reset running stats for the next step
                    running_loss = 0
                    num_tokens = 0
                    t0 = time.perf_counter()

            self.epochs_run += 1
            log.info("Starting checkpoint save...")
            checkpoint_path = await self.save_checkpoint(epoch=curr_epoch)
            checkpoint = Checkpoint(
                identifier=f"{self.model_id}-sft-{curr_epoch}",
                created_at=datetime.now(),
                epoch=curr_epoch,
                post_training_job_id=self.job_uuid,
                path=checkpoint_path,
            )
            if self.training_config.data_config.validation_dataset_id:
                validation_loss, perplexity = await self.validation()
                training_metrics = PostTrainingMetric(
                    epoch=curr_epoch,
                    train_loss=loss_to_log,
                    validation_loss=validation_loss,
                    perplexity=perplexity,
                )
                checkpoint.training_metrics = training_metrics
            checkpoints.append(checkpoint)

        return (memory_stats, checkpoints)

    async def validation(self) -> Tuple[float, float]:
        total_loss = 0.0
        total_tokens = 0
        log.info("Starting validation...")
        pbar = tqdm(total=len(self._validation_dataloader))
        for idx, batch in enumerate(self._validation_dataloader):
            if idx == 10:
                break
            torchtune_utils.batch_to_device(batch, self._device)

            # Calculate the number of unmasked tokens in the current batch
            # and increment the total number of tokens seen in the step
            num_tokens = (batch["labels"] != self._loss_fn.ignore_index).sum()

            # Loss is normalized by default so we multiply by the number of tokens
            # This way we can normalize by the total number of tokens if we're accumulating gradients
            loss = await self._loss_step(batch) * num_tokens

            total_loss += loss
            total_tokens += num_tokens

            pbar.update(1)
            pbar.set_description(f"validation step: {idx}")

        mean_loss = total_loss / total_tokens
        perplexity = torch.exp(torch.tensor(mean_loss))

        return mean_loss, perplexity.item()
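One detail of the loop above worth spelling out: the per-batch loss is multiplied by its unmasked token count, and gradients are later rescaled by 1 / num_tokens, so the reported loss is averaged over all tokens seen across the accumulation window rather than per micro-batch. With made-up numbers:

# --- illustrative sketch, not part of the diff ---
# Two accumulated micro-batches with different numbers of unmasked tokens.
per_token_loss = [2.0, 1.0]   # mean loss returned by the loss fn per micro-batch
tokens = [30, 10]             # unmasked tokens per micro-batch

running_loss = sum(l * n for l, n in zip(per_token_loss, tokens))  # 70.0
num_tokens = sum(tokens)                                           # 40
loss_to_log = running_loss / num_tokens                            # 1.75, a true per-token average
print(loss_to_log)
# Gradients accumulated from the summed loss are scaled by 1 / num_tokens,
# matching the same normalization.
# --- end sketch ---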
@@ -18,6 +18,7 @@ META_REFERENCE_DEPS = [
     "transformers",
     "zmq",
     "lm-format-enforcer",
+    "sentence-transformers",
 ]


@@ -52,6 +53,13 @@ def available_providers() -> List[ProviderSpec]:
             module="llama_stack.providers.inline.inference.vllm",
             config_class="llama_stack.providers.inline.inference.vllm.VLLMConfig",
         ),
+        InlineProviderSpec(
+            api=Api.inference,
+            provider_type="inline::sentence-transformers",
+            pip_packages=["sentence-transformers"],
+            module="llama_stack.providers.inline.inference.sentence_transformers",
+            config_class="llama_stack.providers.inline.inference.sentence_transformers.config.SentenceTransformersInferenceConfig",
+        ),
         remote_provider_spec(
             api=Api.inference,
             adapter=AdapterSpec(
@@ -39,6 +39,7 @@ def available_providers() -> List[ProviderSpec]:
             module="llama_stack.providers.inline.memory.faiss",
             config_class="llama_stack.providers.inline.memory.faiss.FaissImplConfig",
             deprecation_warning="Please use the `inline::faiss` provider instead.",
+            api_dependencies=[Api.inference],
         ),
         InlineProviderSpec(
             api=Api.memory,
@@ -46,6 +47,7 @@ def available_providers() -> List[ProviderSpec]:
             pip_packages=EMBEDDING_DEPS + ["faiss-cpu"],
             module="llama_stack.providers.inline.memory.faiss",
             config_class="llama_stack.providers.inline.memory.faiss.FaissImplConfig",
+            api_dependencies=[Api.inference],
         ),
         remote_provider_spec(
             Api.memory,
@@ -55,6 +57,7 @@ def available_providers() -> List[ProviderSpec]:
                 module="llama_stack.providers.remote.memory.chroma",
                 config_class="llama_stack.providers.remote.memory.chroma.ChromaRemoteImplConfig",
             ),
+            api_dependencies=[Api.inference],
         ),
         InlineProviderSpec(
             api=Api.memory,
@@ -71,6 +74,7 @@ def available_providers() -> List[ProviderSpec]:
                 module="llama_stack.providers.remote.memory.pgvector",
                 config_class="llama_stack.providers.remote.memory.pgvector.PGVectorConfig",
             ),
+            api_dependencies=[Api.inference],
         ),
         remote_provider_spec(
             Api.memory,
@@ -81,6 +85,7 @@ def available_providers() -> List[ProviderSpec]:
                 config_class="llama_stack.providers.remote.memory.weaviate.WeaviateConfig",
                 provider_data_validator="llama_stack.providers.remote.memory.weaviate.WeaviateRequestProviderData",
             ),
+            api_dependencies=[Api.inference],
         ),
         remote_provider_spec(
             api=Api.memory,
@@ -90,6 +95,7 @@ def available_providers() -> List[ProviderSpec]:
                 module="llama_stack.providers.remote.memory.sample",
                 config_class="llama_stack.providers.remote.memory.sample.SampleConfig",
             ),
+            api_dependencies=[],
         ),
         remote_provider_spec(
             Api.memory,
@@ -99,5 +105,6 @@ def available_providers() -> List[ProviderSpec]:
                 module="llama_stack.providers.remote.memory.qdrant",
                 config_class="llama_stack.providers.remote.memory.qdrant.QdrantConfig",
             ),
+            api_dependencies=[Api.inference],
         ),
     ]
llama_stack/providers/registry/post_training.py (new file, 25 lines)
@@ -0,0 +1,25 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import List

from llama_stack.distribution.datatypes import *  # noqa: F403


def available_providers() -> List[ProviderSpec]:
    return [
        InlineProviderSpec(
            api=Api.post_training,
            provider_type="inline::torchtune",
            pip_packages=["torch", "torchtune", "torchao", "numpy"],
            module="llama_stack.providers.inline.post_training.torchtune",
            config_class="llama_stack.providers.inline.post_training.torchtune.TorchtunePostTrainingConfig",
            api_dependencies=[
                Api.datasetio,
                Api.datasets,
            ],
        ),
    ]
@@ -21,14 +21,19 @@ DATASETS_PREFIX = "datasets:"

 def load_hf_dataset(dataset_def: Dataset):
     if dataset_def.metadata.get("path", None):
-        return hf_datasets.load_dataset(**dataset_def.metadata)
+        dataset = hf_datasets.load_dataset(**dataset_def.metadata)
+    else:
         df = get_dataframe_from_url(dataset_def.url)

         if df is None:
             raise ValueError(f"Failed to load dataset from {dataset_def.url}")

         dataset = hf_datasets.Dataset.from_pandas(df)

+    # drop columns not specified by schema
+    if dataset_def.dataset_schema:
+        dataset = dataset.select_columns(list(dataset_def.dataset_schema.keys()))
+
     return dataset
@@ -5,6 +5,7 @@
 # the root directory of this source tree.

 from typing import *  # noqa: F403
+import json

 from botocore.client import BaseClient
 from llama_models.datatypes import CoreModelId
@@ -19,8 +20,10 @@ from llama_stack.providers.utils.inference.model_registry import (

 from llama_stack.apis.inference import *  # noqa: F403

 from llama_stack.providers.remote.inference.bedrock.config import BedrockConfig
 from llama_stack.providers.utils.bedrock.client import create_bedrock_client
+from llama_stack.providers.utils.inference.prompt_adapter import content_has_media


 model_aliases = [
@@ -448,4 +451,21 @@ class BedrockInferenceAdapter(ModelRegistryHelper, Inference):
         model_id: str,
         contents: List[InterleavedTextMedia],
     ) -> EmbeddingsResponse:
-        raise NotImplementedError()
+        model = await self.model_store.get_model(model_id)
+        embeddings = []
+        for content in contents:
+            assert not content_has_media(
+                content
+            ), "Bedrock does not support media for embeddings"
+            input_text = interleaved_text_media_as_str(content)
+            input_body = {"inputText": input_text}
+            body = json.dumps(input_body)
+            response = self.client.invoke_model(
+                body=body,
+                modelId=model.provider_resource_id,
+                accept="application/json",
+                contentType="application/json",
+            )
+            response_body = json.loads(response.get("body").read())
+            embeddings.append(response_body.get("embedding"))
+        return EmbeddingsResponse(embeddings=embeddings)
@@ -13,7 +13,7 @@ from pydantic import BaseModel, Field
 @json_schema_type
 class FireworksImplConfig(BaseModel):
     url: str = Field(
-        default="https://api.fireworks.ai/inference",
+        default="https://api.fireworks.ai/inference/v1",
         description="The URL for the Fireworks server",
     )
     api_key: Optional[str] = Field(
@@ -24,6 +24,6 @@ class FireworksImplConfig(BaseModel):
     @classmethod
     def sample_run_config(cls) -> Dict[str, Any]:
         return {
-            "url": "https://api.fireworks.ai/inference",
+            "url": "https://api.fireworks.ai/inference/v1",
             "api_key": "${env.FIREWORKS_API_KEY}",
         }
@@ -4,7 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from typing import AsyncGenerator
+from typing import AsyncGenerator, List, Optional, Union

 from fireworks.client import Fireworks
 from llama_models.datatypes import CoreModelId
@@ -28,6 +28,7 @@ from llama_stack.providers.utils.inference.openai_compat import (
 from llama_stack.providers.utils.inference.prompt_adapter import (
     chat_completion_request_to_prompt,
     completion_request_to_prompt,
+    content_has_media,
     convert_message_to_dict,
     request_has_media,
 )
@@ -89,17 +90,19 @@ class FireworksInferenceAdapter(
     async def shutdown(self) -> None:
         pass

-    def _get_client(self) -> Fireworks:
-        fireworks_api_key = None
+    def _get_api_key(self) -> str:
         if self.config.api_key is not None:
-            fireworks_api_key = self.config.api_key
+            return self.config.api_key
         else:
             provider_data = self.get_request_provider_data()
             if provider_data is None or not provider_data.fireworks_api_key:
                 raise ValueError(
                     'Pass Fireworks API Key in the header X-LlamaStack-ProviderData as { "fireworks_api_key": <your api key>}'
                 )
-            fireworks_api_key = provider_data.fireworks_api_key
+            return provider_data.fireworks_api_key
+
+    def _get_client(self) -> Fireworks:
+        fireworks_api_key = self._get_api_key()
         return Fireworks(api_key=fireworks_api_key)

     async def completion(
@@ -264,4 +267,19 @@ class FireworksInferenceAdapter(
         model_id: str,
         contents: List[InterleavedTextMedia],
     ) -> EmbeddingsResponse:
-        raise NotImplementedError()
+        model = await self.model_store.get_model(model_id)
+
+        kwargs = {}
+        if model.metadata.get("embedding_dimensions"):
+            kwargs["dimensions"] = model.metadata.get("embedding_dimensions")
+        assert all(
+            not content_has_media(content) for content in contents
+        ), "Fireworks does not support media for embeddings"
+        response = self._get_client().embeddings.create(
+            model=model.provider_resource_id,
+            input=[interleaved_text_media_as_str(content) for content in contents],
+            **kwargs,
+        )
+
+        embeddings = [data.embedding for data in response.data]
+        return EmbeddingsResponse(embeddings=embeddings)
@@ -36,6 +36,7 @@ from llama_stack.providers.utils.inference.openai_compat import (
 from llama_stack.providers.utils.inference.prompt_adapter import (
     chat_completion_request_to_prompt,
     completion_request_to_prompt,
+    content_has_media,
     convert_image_media_to_url,
     request_has_media,
 )
@@ -321,9 +322,30 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
         model_id: str,
         contents: List[InterleavedTextMedia],
     ) -> EmbeddingsResponse:
-        raise NotImplementedError()
+        model = await self.model_store.get_model(model_id)
+
+        assert all(
+            not content_has_media(content) for content in contents
+        ), "Ollama does not support media for embeddings"
+        response = await self.client.embed(
+            model=model.provider_resource_id,
+            input=[interleaved_text_media_as_str(content) for content in contents],
+        )
+        embeddings = response["embeddings"]
+
+        return EmbeddingsResponse(embeddings=embeddings)

     async def register_model(self, model: Model) -> Model:
+        # ollama does not have embedding models running. Check if the model is in list of available models.
+        if model.model_type == ModelType.embedding:
+            response = await self.client.list()
+            available_models = [m["model"] for m in response["models"]]
+            if model.provider_resource_id not in available_models:
+                raise ValueError(
+                    f"Model '{model.provider_resource_id}' is not available in Ollama. "
+                    f"Available models: {', '.join(available_models)}"
+                )
+            return model
         model = await self.register_helper.register_model(model)
         models = await self.client.ps()
         available_models = [m["model"] for m in models["models"]]
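For the availability check added to register_model above, the adapter only needs the model names Ollama reports. A small self-contained sketch of that logic, using a hypothetical response shaped like the one the adapter reads:

# --- illustrative sketch, not part of the diff ---
response = {"models": [{"model": "all-minilm:latest"}, {"model": "llama3.2:1b"}]}
available = [m["model"] for m in response["models"]]

requested = "all-minilm:latest"
if requested not in available:
    raise ValueError(f"Model '{requested}' is not available in Ollama. Available models: {', '.join(available)}")
print("ok")
# --- end sketch ---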
@@ -31,6 +31,7 @@ from llama_stack.providers.utils.inference.openai_compat import (
 from llama_stack.providers.utils.inference.prompt_adapter import (
     chat_completion_request_to_prompt,
     completion_request_to_prompt,
+    content_has_media,
     convert_message_to_dict,
     request_has_media,
 )
@@ -253,4 +254,13 @@ class TogetherInferenceAdapter(
         model_id: str,
         contents: List[InterleavedTextMedia],
     ) -> EmbeddingsResponse:
-        raise NotImplementedError()
+        model = await self.model_store.get_model(model_id)
+        assert all(
+            not content_has_media(content) for content in contents
+        ), "Together does not support media for embeddings"
+        r = self._get_client().embeddings.create(
+            model=model.provider_resource_id,
+            input=[interleaved_text_media_as_str(content) for content in contents],
+        )
+        embeddings = [item.embedding for item in r.data]
+        return EmbeddingsResponse(embeddings=embeddings)
@@ -29,6 +29,7 @@ from llama_stack.providers.utils.inference.openai_compat import (
 from llama_stack.providers.utils.inference.prompt_adapter import (
     chat_completion_request_to_prompt,
     completion_request_to_prompt,
+    content_has_media,
     convert_message_to_dict,
     request_has_media,
 )
@@ -203,4 +204,20 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
         model_id: str,
         contents: List[InterleavedTextMedia],
     ) -> EmbeddingsResponse:
-        raise NotImplementedError()
+        model = await self.model_store.get_model(model_id)
+
+        kwargs = {}
+        assert model.model_type == ModelType.embedding
+        assert model.metadata.get("embedding_dimensions")
+        kwargs["dimensions"] = model.metadata.get("embedding_dimensions")
+        assert all(
+            not content_has_media(content) for content in contents
+        ), "VLLM does not support media for embeddings"
+        response = self.client.embeddings.create(
+            model=model.provider_resource_id,
+            input=[interleaved_text_media_as_str(content) for content in contents],
+            **kwargs,
+        )
+
+        embeddings = [data.embedding for data in response.data]
+        return EmbeddingsResponse(embeddings=embeddings)
@@ -4,12 +4,18 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

+from typing import Dict
+
+from llama_stack.providers.datatypes import Api, ProviderSpec
+
 from .config import ChromaRemoteImplConfig


-async def get_adapter_impl(config: ChromaRemoteImplConfig, _deps):
+async def get_adapter_impl(
+    config: ChromaRemoteImplConfig, deps: Dict[Api, ProviderSpec]
+):
     from .chroma import ChromaMemoryAdapter

-    impl = ChromaMemoryAdapter(config)
+    impl = ChromaMemoryAdapter(config, deps[Api.inference])
     await impl.initialize()
     return impl
@@ -13,8 +13,7 @@ import chromadb
 from numpy.typing import NDArray

 from llama_stack.apis.memory import *  # noqa: F403
-from llama_stack.providers.datatypes import MemoryBanksProtocolPrivate
+from llama_stack.providers.datatypes import Api, MemoryBanksProtocolPrivate
 from llama_stack.providers.inline.memory.chroma import ChromaInlineImplConfig
 from llama_stack.providers.utils.memory.vector_store import (
     BankWithIndex,
@@ -87,10 +86,14 @@ class ChromaIndex(EmbeddingIndex):


 class ChromaMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
     def __init__(
-        self, config: Union[ChromaRemoteImplConfig, ChromaInlineImplConfig]
+        self,
+        config: Union[ChromaRemoteImplConfig, ChromaInlineImplConfig],
+        inference_api: Api.inference,
     ) -> None:
         log.info(f"Initializing ChromaMemoryAdapter with url: {config}")
         self.config = config
+        self.inference_api = inference_api

         self.client = None
         self.cache = {}
@@ -127,10 +130,9 @@ class ChromaMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
                 metadata={"bank": memory_bank.model_dump_json()},
             )
         )
-        bank_index = BankWithIndex(
-            bank=memory_bank, index=ChromaIndex(self.client, collection)
+        self.cache[memory_bank.identifier] = BankWithIndex(
+            memory_bank, ChromaIndex(self.client, collection), self.inference_api
         )
-        self.cache[memory_bank.identifier] = bank_index

     async def unregister_memory_bank(self, memory_bank_id: str) -> None:
         await self.cache[memory_bank_id].index.delete()
@@ -166,6 +168,8 @@ class ChromaMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
         collection = await maybe_await(self.client.get_collection(bank_id))
         if not collection:
             raise ValueError(f"Bank {bank_id} not found in Chroma")
-        index = BankWithIndex(bank=bank, index=ChromaIndex(self.client, collection))
+        index = BankWithIndex(
+            bank, ChromaIndex(self.client, collection), self.inference_api
+        )
         self.cache[bank_id] = index
         return index
@@ -4,12 +4,16 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

+from typing import Dict
+
+from llama_stack.providers.datatypes import Api, ProviderSpec
+
 from .config import PGVectorConfig


-async def get_adapter_impl(config: PGVectorConfig, _deps):
+async def get_adapter_impl(config: PGVectorConfig, deps: Dict[Api, ProviderSpec]):
     from .pgvector import PGVectorMemoryAdapter

-    impl = PGVectorMemoryAdapter(config)
+    impl = PGVectorMemoryAdapter(config, deps[Api.inference])
     await impl.initialize()
     return impl
@ -16,9 +16,9 @@ from pydantic import BaseModel, parse_obj_as
|
||||||
|
|
||||||
from llama_stack.apis.memory import * # noqa: F403
|
from llama_stack.apis.memory import * # noqa: F403
|
||||||
|
|
||||||
from llama_stack.providers.datatypes import MemoryBanksProtocolPrivate
|
from llama_stack.providers.datatypes import Api, MemoryBanksProtocolPrivate
|
||||||
|
|
||||||
from llama_stack.providers.utils.memory.vector_store import (
|
from llama_stack.providers.utils.memory.vector_store import (
|
||||||
ALL_MINILM_L6_V2_DIMENSION,
|
|
||||||
BankWithIndex,
|
BankWithIndex,
|
||||||
EmbeddingIndex,
|
EmbeddingIndex,
|
||||||
)
|
)
|
||||||
|
@ -120,8 +120,9 @@ class PGVectorIndex(EmbeddingIndex):
|
||||||
|
|
||||||
|
|
||||||
class PGVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
|
class PGVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
|
||||||
def __init__(self, config: PGVectorConfig) -> None:
|
def __init__(self, config: PGVectorConfig, inference_api: Api.inference) -> None:
|
||||||
self.config = config
|
self.config = config
|
||||||
|
self.inference_api = inference_api
|
||||||
self.cursor = None
|
self.cursor = None
|
||||||
self.conn = None
|
self.conn = None
|
||||||
self.cache = {}
|
self.cache = {}
|
||||||
|
@@ -160,27 +161,17 @@ class PGVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
     async def shutdown(self) -> None:
         pass
 
-    async def register_memory_bank(
-        self,
-        memory_bank: MemoryBank,
-    ) -> None:
+    async def register_memory_bank(self, memory_bank: MemoryBank) -> None:
         assert (
             memory_bank.memory_bank_type == MemoryBankType.vector.value
         ), f"Only vector banks are supported {memory_bank.memory_bank_type}"
 
-        upsert_models(
-            self.cursor,
-            [
-                (memory_bank.identifier, memory_bank),
-            ],
-        )
-
-        index = BankWithIndex(
-            bank=memory_bank,
-            index=PGVectorIndex(memory_bank, ALL_MINILM_L6_V2_DIMENSION, self.cursor),
-        )
-        self.cache[memory_bank.identifier] = index
+        upsert_models(self.cursor, [(memory_bank.identifier, memory_bank)])
+        index = PGVectorIndex(memory_bank, memory_bank.embedding_dimension, self.cursor)
+        self.cache[memory_bank.identifier] = BankWithIndex(
+            memory_bank, index, self.inference_api
+        )
 
     async def unregister_memory_bank(self, memory_bank_id: str) -> None:
         await self.cache[memory_bank_id].index.delete()
         del self.cache[memory_bank_id]
@@ -203,14 +194,13 @@ class PGVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
         index = await self._get_and_cache_bank_index(bank_id)
         return await index.query_documents(query, params)
 
+        self.inference_api = inference_api
+
     async def _get_and_cache_bank_index(self, bank_id: str) -> BankWithIndex:
         if bank_id in self.cache:
             return self.cache[bank_id]
 
         bank = await self.memory_bank_store.get_memory_bank(bank_id)
-        index = BankWithIndex(
-            bank=bank,
-            index=PGVectorIndex(bank, ALL_MINILM_L6_V2_DIMENSION, self.cursor),
-        )
-        self.cache[bank_id] = index
-        return index
+        index = PGVectorIndex(bank, bank.embedding_dimension, self.cursor)
+        self.cache[bank_id] = BankWithIndex(bank, index, self.inference_api)
+        return self.cache[bank_id]
@@ -4,12 +4,16 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
+from typing import Dict
+
+from llama_stack.providers.datatypes import Api, ProviderSpec
+
 from .config import QdrantConfig
 
 
-async def get_adapter_impl(config: QdrantConfig, _deps):
+async def get_adapter_impl(config: QdrantConfig, deps: Dict[Api, ProviderSpec]):
     from .qdrant import QdrantVectorMemoryAdapter
 
-    impl = QdrantVectorMemoryAdapter(config)
+    impl = QdrantVectorMemoryAdapter(config, deps[Api.inference])
     await impl.initialize()
     return impl
@@ -101,10 +101,11 @@ class QdrantIndex(EmbeddingIndex):
 
 
 class QdrantVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
-    def __init__(self, config: QdrantConfig) -> None:
+    def __init__(self, config: QdrantConfig, inference_api: Api.inference) -> None:
         self.config = config
         self.client = AsyncQdrantClient(**self.config.model_dump(exclude_none=True))
         self.cache = {}
+        self.inference_api = inference_api
 
     async def initialize(self) -> None:
         pass
@@ -123,6 +124,7 @@ class QdrantVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
         index = BankWithIndex(
             bank=memory_bank,
             index=QdrantIndex(self.client, memory_bank.identifier),
+            inference_api=self.inference_api,
         )
 
         self.cache[memory_bank.identifier] = index
@@ -138,6 +140,7 @@ class QdrantVectorMemoryAdapter(Memory, MemoryBanksProtocolPrivate):
         index = BankWithIndex(
             bank=bank,
             index=QdrantIndex(client=self.client, collection_name=bank_id),
+            inference_api=self.inference_api,
         )
         self.cache[bank_id] = index
         return index
@@ -4,12 +4,16 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
+from typing import Dict
+
+from llama_stack.providers.datatypes import Api, ProviderSpec
+
 from .config import WeaviateConfig, WeaviateRequestProviderData  # noqa: F401
 
 
-async def get_adapter_impl(config: WeaviateConfig, _deps):
+async def get_adapter_impl(config: WeaviateConfig, deps: Dict[Api, ProviderSpec]):
     from .weaviate import WeaviateMemoryAdapter
 
-    impl = WeaviateMemoryAdapter(config)
+    impl = WeaviateMemoryAdapter(config, deps[Api.inference])
     await impl.initialize()
     return impl
@@ -12,10 +12,11 @@ import weaviate
 import weaviate.classes as wvc
 from numpy.typing import NDArray
 from weaviate.classes.init import Auth
+from weaviate.classes.query import Filter
 
 from llama_stack.apis.memory import *  # noqa: F403
 from llama_stack.distribution.request_headers import NeedsRequestProviderData
-from llama_stack.providers.datatypes import MemoryBanksProtocolPrivate
+from llama_stack.providers.datatypes import Api, MemoryBanksProtocolPrivate
 from llama_stack.providers.utils.memory.vector_store import (
     BankWithIndex,
     EmbeddingIndex,
@@ -80,12 +81,21 @@ class WeaviateIndex(EmbeddingIndex):
 
         return QueryDocumentsResponse(chunks=chunks, scores=scores)
 
+    async def delete(self, chunk_ids: List[str]) -> None:
+        collection = self.client.collections.get(self.collection_name)
+        collection.data.delete_many(
+            where=Filter.by_property("id").contains_any(chunk_ids)
+        )
+
 
 class WeaviateMemoryAdapter(
-    Memory, NeedsRequestProviderData, MemoryBanksProtocolPrivate
+    Memory,
+    NeedsRequestProviderData,
+    MemoryBanksProtocolPrivate,
 ):
-    def __init__(self, config: WeaviateConfig) -> None:
+    def __init__(self, config: WeaviateConfig, inference_api: Api.inference) -> None:
         self.config = config
+        self.inference_api = inference_api
         self.client_cache = {}
         self.cache = {}
 
@@ -117,7 +127,7 @@ class WeaviateMemoryAdapter(
         memory_bank: MemoryBank,
     ) -> None:
         assert (
-            memory_bank.memory_bank_type == MemoryBankType.vector
+            memory_bank.memory_bank_type == MemoryBankType.vector.value
         ), f"Only vector banks are supported {memory_bank.memory_bank_type}"
 
         client = self._get_client()
@@ -135,11 +145,11 @@ class WeaviateMemoryAdapter(
             ],
         )
 
-        index = BankWithIndex(
-            bank=memory_bank,
-            index=WeaviateIndex(client=client, collection_name=memory_bank.identifier),
+        self.cache[memory_bank.identifier] = BankWithIndex(
+            memory_bank,
+            WeaviateIndex(client=client, collection_name=memory_bank.identifier),
+            self.inference_api,
         )
-        self.cache[memory_bank.identifier] = index
 
     async def _get_and_cache_bank_index(self, bank_id: str) -> Optional[BankWithIndex]:
         if bank_id in self.cache:
@@ -156,6 +166,7 @@ class WeaviateMemoryAdapter(
         index = BankWithIndex(
             bank=bank,
             index=WeaviateIndex(client=client, collection_name=bank_id),
+            inference_api=self.inference_api,
         )
         self.cache[bank_id] = index
         return index
@@ -156,4 +156,5 @@ pytest_plugins = [
     "llama_stack.providers.tests.datasetio.fixtures",
     "llama_stack.providers.tests.scoring.fixtures",
     "llama_stack.providers.tests.eval.fixtures",
+    "llama_stack.providers.tests.post_training.fixtures",
 ]
@@ -10,6 +10,7 @@ import pytest_asyncio
 from llama_stack.distribution.datatypes import Api, Provider
 
 from llama_stack.providers.tests.resolver import construct_stack_for_test
 
 from ..conftest import ProviderFixture, remote_stack_fixture
 
+
@@ -18,6 +18,12 @@ def pytest_addoption(parser):
         default=None,
         help="Specify the inference model to use for testing",
     )
+    parser.addoption(
+        "--embedding-model",
+        action="store",
+        default=None,
+        help="Specify the embedding model to use for testing",
+    )
 
 
 def pytest_configure(config):
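A note on the hunk above: it only registers the new `--embedding-model` flag. A minimal sketch of how a session fixture could read it is below; the fixture name and placement are assumptions for illustration, not part of this change.

```python
# Illustrative sketch only -- not part of this diff. Assumes a conftest.py that
# already calls parser.addoption("--embedding-model", ...) as shown above.
import pytest


@pytest.fixture(scope="session")
def embedding_model(request):
    # getoption returns the CLI value, or the declared default (None) when omitted.
    return request.config.getoption("--embedding-model")
```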
@@ -9,9 +9,9 @@ import os
 import pytest
 import pytest_asyncio
 
-from llama_stack.apis.models import ModelInput
+from llama_stack.apis.models import ModelInput, ModelType
 
 from llama_stack.distribution.datatypes import Api, Provider
 
 from llama_stack.providers.inline.inference.meta_reference import (
     MetaReferenceInferenceConfig,
 )
@@ -47,6 +47,9 @@ def inference_meta_reference(inference_model) -> ProviderFixture:
     inference_model = (
         [inference_model] if isinstance(inference_model, str) else inference_model
     )
+    # If embedding dimension is set, use the 8B model for testing
+    if os.getenv("EMBEDDING_DIMENSION"):
+        inference_model = ["meta-llama/Llama-3.1-8B-Instruct"]
 
     return ProviderFixture(
         providers=[
@@ -85,7 +88,7 @@ def inference_ollama(inference_model) -> ProviderFixture:
     inference_model = (
         [inference_model] if isinstance(inference_model, str) else inference_model
     )
-    if "Llama3.1-8B-Instruct" in inference_model:
+    if inference_model and "Llama3.1-8B-Instruct" in inference_model:
         pytest.skip("Ollama only supports Llama3.2-3B-Instruct for testing")
 
     return ProviderFixture(
@@ -232,11 +235,23 @@ INFERENCE_FIXTURES = [
 async def inference_stack(request, inference_model):
     fixture_name = request.param
     inference_fixture = request.getfixturevalue(f"inference_{fixture_name}")
+    model_type = ModelType.llm
+    metadata = {}
+    if os.getenv("EMBEDDING_DIMENSION"):
+        model_type = ModelType.embedding
+        metadata["embedding_dimension"] = get_env_or_fail("EMBEDDING_DIMENSION")
+
     test_stack = await construct_stack_for_test(
         [Api.inference],
         {"inference": inference_fixture.providers},
         inference_fixture.provider_data,
-        models=[ModelInput(model_id=inference_model)],
+        models=[
+            ModelInput(
+                model_id=inference_model,
+                model_type=model_type,
+                metadata=metadata,
+            )
+        ],
     )
 
     return test_stack.impls[Api.inference], test_stack.impls[Api.models]
llama_stack/providers/tests/inference/test_embeddings.py (new file)
@@ -0,0 +1,62 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import pytest

from llama_stack.apis.inference import EmbeddingsResponse, ModelType

# How to run this test:
# pytest -v -s llama_stack/providers/tests/inference/test_embeddings.py


class TestEmbeddings:
    @pytest.mark.asyncio
    async def test_embeddings(self, inference_model, inference_stack):
        inference_impl, models_impl = inference_stack
        model = await models_impl.get_model(inference_model)

        if model.model_type != ModelType.embedding:
            pytest.skip("This test is only applicable for embedding models")

        response = await inference_impl.embeddings(
            model_id=inference_model,
            contents=["Hello, world!"],
        )
        assert isinstance(response, EmbeddingsResponse)
        assert len(response.embeddings) > 0
        assert all(isinstance(embedding, list) for embedding in response.embeddings)
        assert all(
            isinstance(value, float)
            for embedding in response.embeddings
            for value in embedding
        )

    @pytest.mark.asyncio
    async def test_batch_embeddings(self, inference_model, inference_stack):
        inference_impl, models_impl = inference_stack
        model = await models_impl.get_model(inference_model)

        if model.model_type != ModelType.embedding:
            pytest.skip("This test is only applicable for embedding models")

        texts = ["Hello, world!", "This is a test", "Testing embeddings"]

        response = await inference_impl.embeddings(
            model_id=inference_model,
            contents=texts,
        )

        assert isinstance(response, EmbeddingsResponse)
        assert len(response.embeddings) == len(texts)
        assert all(isinstance(embedding, list) for embedding in response.embeddings)
        assert all(
            isinstance(value, float)
            for embedding in response.embeddings
            for value in embedding
        )

        embedding_dim = len(response.embeddings[0])
        assert all(len(embedding) == embedding_dim for embedding in response.embeddings)
@@ -128,6 +128,61 @@ class TestInference:
         last = chunks[-1]
         assert last.stop_reason == StopReason.out_of_tokens
 
+    @pytest.mark.asyncio
+    async def test_completion_logprobs(self, inference_model, inference_stack):
+        inference_impl, _ = inference_stack
+
+        provider = inference_impl.routing_table.get_provider_impl(inference_model)
+        if provider.__provider_spec__.provider_type not in (
+            # "remote::nvidia", -- provider doesn't provide all logprobs
+        ):
+            pytest.skip("Other inference providers don't support completion() yet")
+
+        response = await inference_impl.completion(
+            content="Micheael Jordan is born in ",
+            stream=False,
+            model_id=inference_model,
+            sampling_params=SamplingParams(
+                max_tokens=5,
+            ),
+            logprobs=LogProbConfig(
+                top_k=3,
+            ),
+        )
+
+        assert isinstance(response, CompletionResponse)
+        assert 1 <= len(response.logprobs) <= 5
+        assert response.logprobs, "Logprobs should not be empty"
+        assert all(len(logprob.logprobs_by_token) == 3 for logprob in response.logprobs)
+
+        chunks = [
+            r
+            async for r in await inference_impl.completion(
+                content="Roses are red,",
+                stream=True,
+                model_id=inference_model,
+                sampling_params=SamplingParams(
+                    max_tokens=5,
+                ),
+                logprobs=LogProbConfig(
+                    top_k=3,
+                ),
+            )
+        ]
+
+        assert all(isinstance(chunk, CompletionResponseStreamChunk) for chunk in chunks)
+        assert (
+            1 <= len(chunks) <= 6
+        )  # why 6 and not 5? the response may have an extra closing chunk, e.g. for usage or stop_reason
+        for chunk in chunks:
+            if chunk.delta:  # if there's a token, we expect logprobs
+                assert chunk.logprobs, "Logprobs should not be empty"
+                assert all(
+                    len(logprob.logprobs_by_token) == 3 for logprob in chunk.logprobs
+                )
+            else:  # no token, no logprobs
+                assert not chunk.logprobs, "Logprobs should be empty"
+
     @pytest.mark.asyncio
     @pytest.mark.skip("This test is not quite robust")
     async def test_completion_structured_output(self, inference_model, inference_stack):
@@ -6,9 +6,65 @@
 
 import pytest
 
+from ..conftest import get_provider_fixture_overrides
+
+from ..inference.fixtures import INFERENCE_FIXTURES
 from .fixtures import MEMORY_FIXTURES
 
 
+DEFAULT_PROVIDER_COMBINATIONS = [
+    pytest.param(
+        {
+            "inference": "meta_reference",
+            "memory": "faiss",
+        },
+        id="meta_reference",
+        marks=pytest.mark.meta_reference,
+    ),
+    pytest.param(
+        {
+            "inference": "ollama",
+            "memory": "pgvector",
+        },
+        id="ollama",
+        marks=pytest.mark.ollama,
+    ),
+    pytest.param(
+        {
+            "inference": "together",
+            "memory": "chroma",
+        },
+        id="chroma",
+        marks=pytest.mark.chroma,
+    ),
+    pytest.param(
+        {
+            "inference": "bedrock",
+            "memory": "qdrant",
+        },
+        id="qdrant",
+        marks=pytest.mark.qdrant,
+    ),
+    pytest.param(
+        {
+            "inference": "fireworks",
+            "memory": "weaviate",
+        },
+        id="weaviate",
+        marks=pytest.mark.weaviate,
+    ),
+]
+
+
+def pytest_addoption(parser):
+    parser.addoption(
+        "--inference-model",
+        action="store",
+        default=None,
+        help="Specify the inference model to use for testing",
+    )
+
+
 def pytest_configure(config):
     for fixture_name in MEMORY_FIXTURES:
         config.addinivalue_line(
@@ -18,12 +74,22 @@ def pytest_configure(config):
 
 
 def pytest_generate_tests(metafunc):
-    if "memory_stack" in metafunc.fixturenames:
-        metafunc.parametrize(
-            "memory_stack",
-            [
-                pytest.param(fixture_name, marks=getattr(pytest.mark, fixture_name))
-                for fixture_name in MEMORY_FIXTURES
-            ],
-            indirect=True,
+    if "inference_model" in metafunc.fixturenames:
+        model = metafunc.config.getoption("--inference-model")
+        if not model:
+            raise ValueError(
+                "No inference model specified. Please provide a valid inference model."
+            )
+        params = [pytest.param(model, id="")]
+
+        metafunc.parametrize("inference_model", params, indirect=True)
+    if "memory_stack" in metafunc.fixturenames:
+        available_fixtures = {
+            "inference": INFERENCE_FIXTURES,
+            "memory": MEMORY_FIXTURES,
+        }
+        combinations = (
+            get_provider_fixture_overrides(metafunc.config, available_fixtures)
+            or DEFAULT_PROVIDER_COMBINATIONS
         )
+        metafunc.parametrize("memory_stack", combinations, indirect=True)
@@ -10,6 +10,8 @@ import tempfile
 import pytest
 import pytest_asyncio
 
+from llama_stack.apis.inference import ModelInput, ModelType
+
 from llama_stack.distribution.datatypes import Api, Provider
 from llama_stack.providers.inline.memory.chroma import ChromaInlineImplConfig
 from llama_stack.providers.inline.memory.faiss import FaissImplConfig
@@ -105,14 +107,30 @@ MEMORY_FIXTURES = ["faiss", "pgvector", "weaviate", "remote", "chroma"]
 
 
 @pytest_asyncio.fixture(scope="session")
-async def memory_stack(request):
-    fixture_name = request.param
-    fixture = request.getfixturevalue(f"memory_{fixture_name}")
+async def memory_stack(inference_model, request):
+    fixture_dict = request.param
+
+    providers = {}
+    provider_data = {}
+    for key in ["inference", "memory"]:
+        fixture = request.getfixturevalue(f"{key}_{fixture_dict[key]}")
+        providers[key] = fixture.providers
+        if fixture.provider_data:
+            provider_data.update(fixture.provider_data)
 
     test_stack = await construct_stack_for_test(
-        [Api.memory],
-        {"memory": fixture.providers},
-        fixture.provider_data,
+        [Api.memory, Api.inference],
+        providers,
+        provider_data,
+        models=[
+            ModelInput(
+                model_id=inference_model,
+                model_type=ModelType.embedding,
+                metadata={
+                    "embedding_dimension": get_env_or_fail("EMBEDDING_DIMENSION"),
+                },
+            )
+        ],
     )
 
     return test_stack.impls[Api.memory], test_stack.impls[Api.memory_banks]
@@ -45,12 +45,14 @@ def sample_documents():
     ]
 
 
-async def register_memory_bank(banks_impl: MemoryBanks) -> MemoryBank:
+async def register_memory_bank(
+    banks_impl: MemoryBanks, inference_model: str
+) -> MemoryBank:
     bank_id = f"test_bank_{uuid.uuid4().hex}"
     return await banks_impl.register_memory_bank(
         memory_bank_id=bank_id,
         params=VectorMemoryBankParams(
-            embedding_model="all-MiniLM-L6-v2",
+            embedding_model=inference_model,
             chunk_size_in_tokens=512,
             overlap_size_in_tokens=64,
         ),
@@ -59,11 +61,11 @@ async def register_memory_bank(banks_impl: MemoryBanks) -> MemoryBank:
 
 class TestMemory:
     @pytest.mark.asyncio
-    async def test_banks_list(self, memory_stack):
+    async def test_banks_list(self, memory_stack, inference_model):
         _, banks_impl = memory_stack
 
         # Register a test bank
-        registered_bank = await register_memory_bank(banks_impl)
+        registered_bank = await register_memory_bank(banks_impl, inference_model)
 
         try:
             # Verify our bank shows up in list
@@ -84,7 +86,7 @@ class TestMemory:
             )
 
     @pytest.mark.asyncio
-    async def test_banks_register(self, memory_stack):
+    async def test_banks_register(self, memory_stack, inference_model):
         _, banks_impl = memory_stack
 
         bank_id = f"test_bank_{uuid.uuid4().hex}"
@@ -94,7 +96,7 @@ class TestMemory:
             await banks_impl.register_memory_bank(
                 memory_bank_id=bank_id,
                 params=VectorMemoryBankParams(
-                    embedding_model="all-MiniLM-L6-v2",
+                    embedding_model=inference_model,
                     chunk_size_in_tokens=512,
                     overlap_size_in_tokens=64,
                 ),
@@ -109,7 +111,7 @@ class TestMemory:
             await banks_impl.register_memory_bank(
                 memory_bank_id=bank_id,
                 params=VectorMemoryBankParams(
-                    embedding_model="all-MiniLM-L6-v2",
+                    embedding_model=inference_model,
                    chunk_size_in_tokens=512,
                     overlap_size_in_tokens=64,
                 ),
@@ -126,13 +128,15 @@ class TestMemory:
             await banks_impl.unregister_memory_bank(bank_id)
 
     @pytest.mark.asyncio
-    async def test_query_documents(self, memory_stack, sample_documents):
+    async def test_query_documents(
+        self, memory_stack, inference_model, sample_documents
+    ):
         memory_impl, banks_impl = memory_stack
 
         with pytest.raises(ValueError):
             await memory_impl.insert_documents("test_bank", sample_documents)
 
-        registered_bank = await register_memory_bank(banks_impl)
+        registered_bank = await register_memory_bank(banks_impl, inference_model)
         await memory_impl.insert_documents(
             registered_bank.memory_bank_id, sample_documents
         )
@@ -165,13 +169,13 @@ class TestMemory:
 
         # Test case 5: Query with threshold on similarity score
        query5 = "quantum computing"  # Not directly related to any document
-        params5 = {"score_threshold": 0.2}
+        params5 = {"score_threshold": 0.01}
         response5 = await memory_impl.query_documents(
             registered_bank.memory_bank_id, query5, params5
         )
         assert_valid_response(response5)
         print("The scores are:", response5.scores)
-        assert all(score >= 0.2 for score in response5.scores)
+        assert all(score >= 0.01 for score in response5.scores)
 
 
 def assert_valid_response(response: QueryDocumentsResponse):
llama_stack/providers/tests/post_training/__init__.py (new file)
@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
llama_stack/providers/tests/post_training/conftest.py (new file)
@@ -0,0 +1,45 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import pytest

from ..conftest import get_provider_fixture_overrides

from ..datasetio.fixtures import DATASETIO_FIXTURES

from .fixtures import POST_TRAINING_FIXTURES

DEFAULT_PROVIDER_COMBINATIONS = [
    pytest.param(
        {
            "post_training": "torchtune",
            "datasetio": "huggingface",
        },
        id="torchtune_post_training_huggingface_datasetio",
        marks=pytest.mark.torchtune_post_training_huggingface_datasetio,
    ),
]


def pytest_configure(config):
    combined_fixtures = "torchtune_post_training_huggingface_datasetio"
    config.addinivalue_line(
        "markers",
        f"{combined_fixtures}: marks tests as {combined_fixtures} specific",
    )


def pytest_generate_tests(metafunc):
    if "post_training_stack" in metafunc.fixturenames:
        available_fixtures = {
            "eval": POST_TRAINING_FIXTURES,
            "datasetio": DATASETIO_FIXTURES,
        }
        combinations = (
            get_provider_fixture_overrides(metafunc.config, available_fixtures)
            or DEFAULT_PROVIDER_COMBINATIONS
        )
        metafunc.parametrize("post_training_stack", combinations, indirect=True)
llama_stack/providers/tests/post_training/fixtures.py (new file)
@@ -0,0 +1,74 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import pytest
import pytest_asyncio

from llama_models.llama3.api.datatypes import URL
from llama_stack.apis.common.type_system import *  # noqa: F403
from llama_stack.apis.datasets import DatasetInput
from llama_stack.apis.models import ModelInput

from llama_stack.distribution.datatypes import Api, Provider

from llama_stack.providers.tests.resolver import construct_stack_for_test

from ..conftest import ProviderFixture


@pytest.fixture(scope="session")
def post_training_torchtune() -> ProviderFixture:
    return ProviderFixture(
        providers=[
            Provider(
                provider_id="torchtune",
                provider_type="inline::torchtune",
                config={},
            )
        ],
    )


POST_TRAINING_FIXTURES = ["torchtune"]


@pytest_asyncio.fixture(scope="session")
async def post_training_stack(request):
    fixture_dict = request.param

    providers = {}
    provider_data = {}
    for key in ["post_training", "datasetio"]:
        fixture = request.getfixturevalue(f"{key}_{fixture_dict[key]}")
        providers[key] = fixture.providers
        if fixture.provider_data:
            provider_data.update(fixture.provider_data)

    test_stack = await construct_stack_for_test(
        [Api.post_training, Api.datasetio],
        providers,
        provider_data,
        models=[ModelInput(model_id="meta-llama/Llama-3.2-3B-Instruct")],
        datasets=[
            DatasetInput(
                dataset_id="alpaca",
                provider_id="huggingface",
                url=URL(uri="https://huggingface.co/datasets/tatsu-lab/alpaca"),
                metadata={
                    "path": "tatsu-lab/alpaca",
                    "split": "train",
                },
                dataset_schema={
                    "instruction": StringType(),
                    "input": StringType(),
                    "output": StringType(),
                    "text": StringType(),
                },
            ),
        ],
    )

    return test_stack.impls[Api.post_training]
llama_stack/providers/tests/post_training/test_post_training.py (new file)
@@ -0,0 +1,92 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import pytest
from llama_stack.apis.common.type_system import *  # noqa: F403
from llama_stack.apis.post_training import *  # noqa: F403
from llama_stack.distribution.datatypes import *  # noqa: F403

# How to run this test:
#
# pytest llama_stack/providers/tests/post_training/test_post_training.py
# -m "torchtune_post_training_huggingface_datasetio"
# -v -s --tb=short --disable-warnings


class TestPostTraining:
    @pytest.mark.asyncio
    async def test_supervised_fine_tune(self, post_training_stack):
        algorithm_config = LoraFinetuningConfig(
            type="LoRA",
            lora_attn_modules=["q_proj", "v_proj", "output_proj"],
            apply_lora_to_mlp=True,
            apply_lora_to_output=False,
            rank=8,
            alpha=16,
        )

        data_config = DataConfig(
            dataset_id="alpaca",
            batch_size=1,
            shuffle=False,
        )

        optimizer_config = OptimizerConfig(
            optimizer_type="adamw",
            lr=3e-4,
            lr_min=3e-5,
            weight_decay=0.1,
            num_warmup_steps=100,
        )

        training_config = TrainingConfig(
            n_epochs=1,
            data_config=data_config,
            optimizer_config=optimizer_config,
            max_steps_per_epoch=1,
            gradient_accumulation_steps=1,
        )
        post_training_impl = post_training_stack
        response = await post_training_impl.supervised_fine_tune(
            job_uuid="1234",
            model="Llama3.2-3B-Instruct",
            algorithm_config=algorithm_config,
            training_config=training_config,
            hyperparam_search_config={},
            logger_config={},
            checkpoint_dir="null",
        )
        assert isinstance(response, PostTrainingJob)
        assert response.job_uuid == "1234"

    @pytest.mark.asyncio
    async def test_get_training_jobs(self, post_training_stack):
        post_training_impl = post_training_stack
        jobs_list = await post_training_impl.get_training_jobs()
        assert isinstance(jobs_list, List)
        assert jobs_list[0].job_uuid == "1234"

    @pytest.mark.asyncio
    async def test_get_training_job_status(self, post_training_stack):
        post_training_impl = post_training_stack
        job_status = await post_training_impl.get_training_job_status("1234")
        assert isinstance(job_status, PostTrainingJobStatusResponse)
        assert job_status.job_uuid == "1234"
        assert job_status.status == JobStatus.completed
        assert isinstance(job_status.checkpoints[0], Checkpoint)

    @pytest.mark.asyncio
    async def test_get_training_job_artifacts(self, post_training_stack):
        post_training_impl = post_training_stack
        job_artifacts = await post_training_impl.get_training_job_artifacts("1234")
        assert isinstance(job_artifacts, PostTrainingJobArtifactsResponse)
        assert job_artifacts.job_uuid == "1234"
        assert isinstance(job_artifacts.checkpoints[0], Checkpoint)
        assert job_artifacts.checkpoints[0].identifier == "Llama3.2-3B-Instruct-sft-0"
        assert job_artifacts.checkpoints[0].epoch == 0
        assert (
            "/.llama/checkpoints/Llama3.2-3B-Instruct-sft-0"
            in job_artifacts.checkpoints[0].path
        )
llama_stack/providers/utils/inference/embedding_mixin.py (new file)
@@ -0,0 +1,47 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import logging
from typing import List

from llama_models.llama3.api.datatypes import InterleavedTextMedia

from llama_stack.apis.inference.inference import EmbeddingsResponse, ModelStore

EMBEDDING_MODELS = {}


log = logging.getLogger(__name__)


class SentenceTransformerEmbeddingMixin:
    model_store: ModelStore

    async def embeddings(
        self,
        model_id: str,
        contents: List[InterleavedTextMedia],
    ) -> EmbeddingsResponse:
        model = await self.model_store.get_model(model_id)
        embedding_model = self._load_sentence_transformer_model(
            model.provider_resource_id
        )
        embeddings = embedding_model.encode(contents)
        return EmbeddingsResponse(embeddings=embeddings)

    def _load_sentence_transformer_model(self, model: str) -> "SentenceTransformer":
        global EMBEDDING_MODELS

        loaded_model = EMBEDDING_MODELS.get(model)
        if loaded_model is not None:
            return loaded_model

        log.info(f"Loading sentence transformer for {model}...")
        from sentence_transformers import SentenceTransformer

        loaded_model = SentenceTransformer(model)
        EMBEDDING_MODELS[model] = loaded_model
        return loaded_model
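A rough sketch of how a provider class might reuse the mixin added above; the provider class and constructor here are assumptions for illustration, not code from this change.

```python
# Illustrative sketch only -- not part of this diff.
from llama_stack.providers.utils.inference.embedding_mixin import (
    SentenceTransformerEmbeddingMixin,
)


class ToyEmbeddingProvider(SentenceTransformerEmbeddingMixin):
    """Hypothetical provider that serves only the embeddings() API."""

    def __init__(self, model_store) -> None:
        # The mixin resolves model_id through this store and uses the returned
        # model.provider_resource_id (e.g. "all-MiniLM-L6-v2") as the
        # sentence-transformers checkpoint to load and cache.
        self.model_store = model_store
```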
@@ -9,6 +9,7 @@ from typing import List, Optional
 
 from llama_models.sku_list import all_registered_models
 
+from llama_stack.apis.models.models import ModelType
 from llama_stack.providers.datatypes import Model, ModelsProtocolPrivate
 
 from llama_stack.providers.utils.inference import (
@@ -77,7 +78,13 @@ class ModelRegistryHelper(ModelsProtocolPrivate):
         return None
 
     async def register_model(self, model: Model) -> Model:
-        provider_resource_id = self.get_provider_model_id(model.provider_resource_id)
+        if model.model_type == ModelType.embedding:
+            # embedding models are always registered by their provider model id and does not need to be mapped to a llama model
+            provider_resource_id = model.provider_resource_id
+        else:
+            provider_resource_id = self.get_provider_model_id(
+                model.provider_resource_id
+            )
         if provider_resource_id:
             model.provider_resource_id = provider_resource_id
         else:
@@ -22,28 +22,10 @@ from llama_models.llama3.api.datatypes import *  # noqa: F403
 from llama_models.llama3.api.tokenizer import Tokenizer
 
 from llama_stack.apis.memory import *  # noqa: F403
+from llama_stack.providers.datatypes import Api
 
 log = logging.getLogger(__name__)
 
-ALL_MINILM_L6_V2_DIMENSION = 384
-
-EMBEDDING_MODELS = {}
-
-
-def get_embedding_model(model: str) -> "SentenceTransformer":
-    global EMBEDDING_MODELS
-
-    loaded_model = EMBEDDING_MODELS.get(model)
-    if loaded_model is not None:
-        return loaded_model
-
-    log.info(f"Loading sentence transformer for {model}...")
-    from sentence_transformers import SentenceTransformer
-
-    loaded_model = SentenceTransformer(model)
-    EMBEDDING_MODELS[model] = loaded_model
-    return loaded_model
-
 
 def parse_pdf(data: bytes) -> str:
     # For PDF and DOC/DOCX files, we can't reliably convert to string
@@ -166,12 +148,12 @@ class EmbeddingIndex(ABC):
 class BankWithIndex:
     bank: VectorMemoryBank
     index: EmbeddingIndex
+    inference_api: Api.inference
 
     async def insert_documents(
         self,
         documents: List[MemoryBankDocument],
     ) -> None:
-        model = get_embedding_model(self.bank.embedding_model)
         for doc in documents:
             content = await content_from_doc(doc)
             chunks = make_overlapped_chunks(
@@ -183,7 +165,10 @@ class BankWithIndex:
             )
             if not chunks:
                 continue
-            embeddings = model.encode([x.content for x in chunks]).astype(np.float32)
+            embeddings_response = await self.inference_api.embeddings(
+                self.bank.embedding_model, [x.content for x in chunks]
+            )
+            embeddings = np.array(embeddings_response.embeddings)
 
             await self.index.add_chunks(chunks, embeddings)
 
@@ -208,6 +193,8 @@ class BankWithIndex:
         else:
             query_str = _process(query)
 
-        model = get_embedding_model(self.bank.embedding_model)
-        query_vector = model.encode([query_str])[0].astype(np.float32)
+        embeddings_response = await self.inference_api.embeddings(
+            self.bank.embedding_model, [query_str]
+        )
+        query_vector = np.array(embeddings_response.embeddings[0], dtype=np.float32)
         return await self.index.query(query_vector, k, score_threshold)
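The hunks above move embedding generation out of `vector_store.py` and route it through the inference API. A minimal sketch of the new call shape, assuming only what the diff itself relies on: `inference_api.embeddings(model, contents)` is awaitable and returns an object with an `.embeddings` list.

```python
# Illustrative sketch only -- not part of this diff.
import numpy as np


async def embed_texts(inference_api, embedding_model: str, texts: list) -> np.ndarray:
    # One async round-trip to the inference provider replaces the old local
    # SentenceTransformer.encode() call.
    response = await inference_api.embeddings(embedding_model, texts)
    return np.array(response.embeddings, dtype=np.float32)
```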
@@ -8,10 +8,14 @@ from pathlib import Path
 
 from llama_models.sku_list import all_registered_models
 
+from llama_stack.apis.models.models import ModelType
+
 from llama_stack.distribution.datatypes import ModelInput, Provider, ShieldInput
+from llama_stack.providers.inline.inference.sentence_transformers import (
+    SentenceTransformersInferenceConfig,
+)
 from llama_stack.providers.remote.inference.cerebras import CerebrasImplConfig
 from llama_stack.providers.remote.inference.cerebras.cerebras import model_aliases
 
 from llama_stack.templates.template import DistributionTemplate, RunConfigSettings
 
 
@@ -29,6 +33,11 @@ def get_distribution_template() -> DistributionTemplate:
         provider_type="remote::cerebras",
         config=CerebrasImplConfig.sample_run_config(),
     )
+    embedding_provider = Provider(
+        provider_id="sentence-transformers",
+        provider_type="inline::sentence-transformers",
+        config=SentenceTransformersInferenceConfig.sample_run_config(),
+    )
 
     core_model_to_hf_repo = {
         m.descriptor(): m.huggingface_repo for m in all_registered_models()
@@ -37,9 +46,18 @@ def get_distribution_template() -> DistributionTemplate:
         ModelInput(
             model_id=core_model_to_hf_repo[m.llama_model],
             provider_model_id=m.provider_model_id,
+            provider_id="cerebras",
         )
         for m in model_aliases
     ]
+    embedding_model = ModelInput(
+        model_id="all-MiniLM-L6-v2",
+        provider_id="sentence-transformers",
+        model_type=ModelType.embedding,
+        metadata={
+            "embedding_dimension": 384,
+        },
+    )
 
     return DistributionTemplate(
         name="cerebras",
@@ -52,9 +70,9 @@ def get_distribution_template() -> DistributionTemplate:
         run_configs={
             "run.yaml": RunConfigSettings(
                 provider_overrides={
-                    "inference": [inference_provider],
+                    "inference": [inference_provider, embedding_provider],
                 },
-                default_models=default_models,
+                default_models=default_models + [embedding_model],
                 default_shields=[ShieldInput(shield_id="meta-llama/Llama-Guard-3-8B")],
             ),
         },
@@ -15,6 +15,9 @@ providers:
     config:
       base_url: https://api.cerebras.ai
       api_key: ${env.CEREBRAS_API_KEY}
+  - provider_id: sentence-transformers
+    provider_type: inline::sentence-transformers
+    config: {}
   safety:
   - provider_id: llama-guard
     provider_type: inline::llama-guard
@@ -49,12 +52,20 @@ metadata_store:
 models:
 - metadata: {}
   model_id: meta-llama/Llama-3.1-8B-Instruct
-  provider_id: null
+  provider_id: cerebras
   provider_model_id: llama3.1-8b
+  model_type: llm
 - metadata: {}
   model_id: meta-llama/Llama-3.1-70B-Instruct
-  provider_id: null
+  provider_id: cerebras
   provider_model_id: llama3.1-70b
+  model_type: llm
+- metadata:
+    embedding_dimension: 384
+  model_id: all-MiniLM-L6-v2
+  provider_id: sentence-transformers
+  provider_model_id: null
+  model_type: embedding
 shields:
 - params: null
   shield_id: meta-llama/Llama-Guard-3-8B
llama_stack/templates/experimental-post-training/build.yaml (new file)
@@ -0,0 +1,13 @@
version: '2'
name: experimental-post-training
distribution_spec:
  description: Experimental template for post training
  docker_image: null
  providers:
    post_training:
    - inline::torchtune
    datasetio:
    - remote::huggingface
    telemetry:
    - inline::meta-reference
image_type: conda
llama_stack/templates/experimental-post-training/run.yaml (new file)
@@ -0,0 +1,53 @@
version: '2'
image_name: experimental-post-training
docker_image: null
conda_env: experimental-post-training
apis:
- telemetry
- datasetio
- post_training
providers:
  datasetio:
  - provider_id: huggingface-0
    provider_type: remote::huggingface
    config: {}
  telemetry:
  - provider_id: meta-reference
    provider_type: inline::meta-reference
    config: {}
  post_training:
  - provider_id: torchtune-post-training
    provider_type: inline::torchtune
    config: {}

metadata_store:
  namespace: null
  type: sqlite
  db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/meta-reference-gpu}/registry.db
models:
- metadata: {}
  model_id: ${env.POST_TRAINING_MODEL}
  provider_id: meta-reference-inference
  provider_model_id: null
shields: []
memory_banks: []
datasets:
- dataset_id: alpaca
  provider_id: huggingface-0
  url:
    uri: https://huggingface.co/datasets/tatsu-lab/alpaca
  metadata:
    path: tatsu-lab/alpaca
    name:
    split: train
  dataset_schema:
    instruction:
      type: string
    input:
      type: string
    output:
      type: string
    text:
      type: string
scoring_fns: []
eval_tasks: []
@@ -8,11 +8,15 @@ from pathlib import Path
 
 from llama_models.sku_list import all_registered_models
 
+from llama_stack.apis.models.models import ModelType
+
 from llama_stack.distribution.datatypes import ModelInput, Provider, ShieldInput
+from llama_stack.providers.inline.inference.sentence_transformers import (
+    SentenceTransformersInferenceConfig,
+)
 from llama_stack.providers.inline.memory.faiss.config import FaissImplConfig
 from llama_stack.providers.remote.inference.fireworks import FireworksImplConfig
 from llama_stack.providers.remote.inference.fireworks.fireworks import MODEL_ALIASES
 
 from llama_stack.templates.template import DistributionTemplate, RunConfigSettings
 
 
|
||||||
provider_type="remote::fireworks",
|
provider_type="remote::fireworks",
|
||||||
config=FireworksImplConfig.sample_run_config(),
|
config=FireworksImplConfig.sample_run_config(),
|
||||||
)
|
)
|
||||||
|
embedding_provider = Provider(
|
||||||
|
provider_id="sentence-transformers",
|
||||||
|
provider_type="inline::sentence-transformers",
|
||||||
|
config=SentenceTransformersInferenceConfig.sample_run_config(),
|
||||||
|
)
|
||||||
memory_provider = Provider(
|
memory_provider = Provider(
|
||||||
provider_id="faiss",
|
provider_id="faiss",
|
||||||
provider_type="inline::faiss",
|
provider_type="inline::faiss",
|
||||||
|
@ -48,9 +57,18 @@ def get_distribution_template() -> DistributionTemplate:
|
||||||
ModelInput(
|
ModelInput(
|
||||||
model_id=core_model_to_hf_repo[m.llama_model],
|
model_id=core_model_to_hf_repo[m.llama_model],
|
||||||
provider_model_id=m.provider_model_id,
|
provider_model_id=m.provider_model_id,
|
||||||
|
provider_id="fireworks",
|
||||||
)
|
)
|
||||||
for m in MODEL_ALIASES
|
for m in MODEL_ALIASES
|
||||||
]
|
]
|
||||||
|
embedding_model = ModelInput(
|
||||||
|
model_id="all-MiniLM-L6-v2",
|
||||||
|
provider_id="sentence-transformers",
|
||||||
|
model_type=ModelType.embedding,
|
||||||
|
metadata={
|
||||||
|
"embedding_dimension": 384,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
return DistributionTemplate(
|
return DistributionTemplate(
|
||||||
name=name,
|
name=name,
|
||||||
|
@ -63,10 +81,10 @@ def get_distribution_template() -> DistributionTemplate:
|
||||||
run_configs={
|
run_configs={
|
||||||
"run.yaml": RunConfigSettings(
|
"run.yaml": RunConfigSettings(
|
||||||
provider_overrides={
|
provider_overrides={
|
||||||
"inference": [inference_provider],
|
"inference": [inference_provider, embedding_provider],
|
||||||
"memory": [memory_provider],
|
"memory": [memory_provider],
|
||||||
},
|
},
|
||||||
default_models=default_models,
|
default_models=default_models + [embedding_model],
|
||||||
default_shields=[ShieldInput(shield_id="meta-llama/Llama-Guard-3-8B")],
|
default_shields=[ShieldInput(shield_id="meta-llama/Llama-Guard-3-8B")],
|
||||||
),
|
),
|
||||||
},
|
},
|
||||||
|
|
|
@@ -16,8 +16,11 @@ providers:
   - provider_id: fireworks
     provider_type: remote::fireworks
     config:
-      url: https://api.fireworks.ai/inference
+      url: https://api.fireworks.ai/inference/v1
       api_key: ${env.FIREWORKS_API_KEY}
+  - provider_id: sentence-transformers
+    provider_type: inline::sentence-transformers
+    config: {}
   memory:
   - provider_id: faiss
     provider_type: inline::faiss
@@ -74,40 +77,55 @@ metadata_store:
 models:
 - metadata: {}
   model_id: meta-llama/Llama-3.1-8B-Instruct
-  provider_id: null
+  provider_id: fireworks
   provider_model_id: fireworks/llama-v3p1-8b-instruct
+  model_type: llm
 - metadata: {}
   model_id: meta-llama/Llama-3.1-70B-Instruct
-  provider_id: null
+  provider_id: fireworks
   provider_model_id: fireworks/llama-v3p1-70b-instruct
+  model_type: llm
 - metadata: {}
   model_id: meta-llama/Llama-3.1-405B-Instruct-FP8
-  provider_id: null
+  provider_id: fireworks
   provider_model_id: fireworks/llama-v3p1-405b-instruct
+  model_type: llm
 - metadata: {}
   model_id: meta-llama/Llama-3.2-1B-Instruct
-  provider_id: null
+  provider_id: fireworks
   provider_model_id: fireworks/llama-v3p2-1b-instruct
+  model_type: llm
 - metadata: {}
   model_id: meta-llama/Llama-3.2-3B-Instruct
-  provider_id: null
+  provider_id: fireworks
   provider_model_id: fireworks/llama-v3p2-3b-instruct
+  model_type: llm
 - metadata: {}
   model_id: meta-llama/Llama-3.2-11B-Vision-Instruct
-  provider_id: null
+  provider_id: fireworks
   provider_model_id: fireworks/llama-v3p2-11b-vision-instruct
+  model_type: llm
 - metadata: {}
   model_id: meta-llama/Llama-3.2-90B-Vision-Instruct
-  provider_id: null
+  provider_id: fireworks
   provider_model_id: fireworks/llama-v3p2-90b-vision-instruct
+  model_type: llm
 - metadata: {}
   model_id: meta-llama/Llama-Guard-3-8B
-  provider_id: null
+  provider_id: fireworks
   provider_model_id: fireworks/llama-guard-3-8b
+  model_type: llm
 - metadata: {}
   model_id: meta-llama/Llama-Guard-3-11B-Vision
-  provider_id: null
+  provider_id: fireworks
   provider_model_id: fireworks/llama-guard-3-11b-vision
+  model_type: llm
+- metadata:
+    embedding_dimension: 384
+  model_id: all-MiniLM-L6-v2
+  provider_id: sentence-transformers
+  provider_model_id: null
+  model_type: embedding
 shields:
 - params: null
   shield_id: meta-llama/Llama-Guard-3-8B
@@ -4,7 +4,11 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
+from llama_stack.apis.models.models import ModelType
 from llama_stack.distribution.datatypes import ModelInput, Provider, ShieldInput
+from llama_stack.providers.inline.inference.sentence_transformers import (
+    SentenceTransformersInferenceConfig,
+)
 from llama_stack.providers.inline.memory.faiss.config import FaissImplConfig
 from llama_stack.providers.remote.inference.tgi import InferenceEndpointImplConfig
 from llama_stack.templates.template import DistributionTemplate, RunConfigSettings
@@ -27,6 +31,11 @@ def get_distribution_template() -> DistributionTemplate:
         provider_type="remote::hf::endpoint",
         config=InferenceEndpointImplConfig.sample_run_config(),
     )
+    embedding_provider = Provider(
+        provider_id="sentence-transformers",
+        provider_type="inline::sentence-transformers",
+        config=SentenceTransformersInferenceConfig.sample_run_config(),
+    )
     memory_provider = Provider(
         provider_id="faiss",
         provider_type="inline::faiss",
@@ -41,6 +50,14 @@ def get_distribution_template() -> DistributionTemplate:
         model_id="${env.SAFETY_MODEL}",
         provider_id="hf-endpoint-safety",
     )
+    embedding_model = ModelInput(
+        model_id="all-MiniLM-L6-v2",
+        provider_id="sentence-transformers",
+        model_type=ModelType.embedding,
+        metadata={
+            "embedding_dimension": 384,
+        },
+    )
 
     return DistributionTemplate(
         name=name,
@@ -53,15 +70,16 @@ def get_distribution_template() -> DistributionTemplate:
         run_configs={
             "run.yaml": RunConfigSettings(
                 provider_overrides={
-                    "inference": [inference_provider],
+                    "inference": [inference_provider, embedding_provider],
                     "memory": [memory_provider],
                 },
-                default_models=[inference_model],
+                default_models=[inference_model, embedding_model],
             ),
             "run-with-safety.yaml": RunConfigSettings(
                 provider_overrides={
                     "inference": [
                         inference_provider,
+                        embedding_provider,
                         Provider(
                             provider_id="hf-endpoint-safety",
                             provider_type="remote::hf::endpoint",
@@ -75,6 +93,7 @@ def get_distribution_template() -> DistributionTemplate:
                 default_models=[
                     inference_model,
                     safety_model,
+                    embedding_model,
                 ],
                 default_shields=[ShieldInput(shield_id="${env.SAFETY_MODEL}")],
             ),
@@ -18,6 +18,9 @@ providers:
     config:
       endpoint_name: ${env.INFERENCE_ENDPOINT_NAME}
       api_token: ${env.HF_API_TOKEN}
+  - provider_id: sentence-transformers
+    provider_type: inline::sentence-transformers
+    config: {}
   - provider_id: hf-endpoint-safety
     provider_type: remote::hf::endpoint
     config:
@@ -81,10 +84,18 @@ models:
   model_id: ${env.INFERENCE_MODEL}
   provider_id: hf-endpoint
   provider_model_id: null
+  model_type: llm
 - metadata: {}
   model_id: ${env.SAFETY_MODEL}
   provider_id: hf-endpoint-safety
   provider_model_id: null
+  model_type: llm
+- metadata:
+    embedding_dimension: 384
+  model_id: all-MiniLM-L6-v2
+  provider_id: sentence-transformers
+  provider_model_id: null
+  model_type: embedding
 shields:
 - params: null
   shield_id: ${env.SAFETY_MODEL}
@@ -18,6 +18,9 @@ providers:
     config:
       endpoint_name: ${env.INFERENCE_ENDPOINT_NAME}
       api_token: ${env.HF_API_TOKEN}
+  - provider_id: sentence-transformers
+    provider_type: inline::sentence-transformers
+    config: {}
   memory:
   - provider_id: faiss
     provider_type: inline::faiss
@@ -76,6 +79,13 @@ models:
   model_id: ${env.INFERENCE_MODEL}
   provider_id: hf-endpoint
   provider_model_id: null
+  model_type: llm
+- metadata:
+    embedding_dimension: 384
+  model_id: all-MiniLM-L6-v2
+  provider_id: sentence-transformers
+  provider_model_id: null
+  model_type: embedding
 shields: []
 memory_banks: []
 datasets: []
@@ -4,7 +4,11 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
+from llama_stack.apis.models.models import ModelType
 from llama_stack.distribution.datatypes import ModelInput, Provider, ShieldInput
+from llama_stack.providers.inline.inference.sentence_transformers import (
+    SentenceTransformersInferenceConfig,
+)
 from llama_stack.providers.inline.memory.faiss.config import FaissImplConfig
 from llama_stack.providers.remote.inference.tgi import InferenceAPIImplConfig
 from llama_stack.templates.template import DistributionTemplate, RunConfigSettings
@@ -28,6 +32,11 @@ def get_distribution_template() -> DistributionTemplate:
         provider_type="remote::hf::serverless",
         config=InferenceAPIImplConfig.sample_run_config(),
     )
+    embedding_provider = Provider(
+        provider_id="sentence-transformers",
+        provider_type="inline::sentence-transformers",
+        config=SentenceTransformersInferenceConfig.sample_run_config(),
+    )
     memory_provider = Provider(
         provider_id="faiss",
         provider_type="inline::faiss",
@@ -42,6 +51,14 @@ def get_distribution_template() -> DistributionTemplate:
         model_id="${env.SAFETY_MODEL}",
         provider_id="hf-serverless-safety",
     )
+    embedding_model = ModelInput(
+        model_id="all-MiniLM-L6-v2",
+        provider_id="sentence-transformers",
+        model_type=ModelType.embedding,
+        metadata={
+            "embedding_dimension": 384,
+        },
+    )
 
     return DistributionTemplate(
         name=name,
@@ -54,15 +71,16 @@ def get_distribution_template() -> DistributionTemplate:
         run_configs={
             "run.yaml": RunConfigSettings(
                 provider_overrides={
-                    "inference": [inference_provider],
+                    "inference": [inference_provider, embedding_provider],
                     "memory": [memory_provider],
                 },
-                default_models=[inference_model],
+                default_models=[inference_model, embedding_model],
             ),
             "run-with-safety.yaml": RunConfigSettings(
                 provider_overrides={
                     "inference": [
                         inference_provider,
+                        embedding_provider,
                         Provider(
                             provider_id="hf-serverless-safety",
                             provider_type="remote::hf::serverless",
@@ -76,6 +94,7 @@ def get_distribution_template() -> DistributionTemplate:
                 default_models=[
                     inference_model,
                     safety_model,
+                    embedding_model,
                 ],
                 default_shields=[ShieldInput(shield_id="${env.SAFETY_MODEL}")],
             ),
@@ -18,6 +18,9 @@ providers:
     config:
       huggingface_repo: ${env.INFERENCE_MODEL}
       api_token: ${env.HF_API_TOKEN}
+  - provider_id: sentence-transformers
+    provider_type: inline::sentence-transformers
+    config: {}
   - provider_id: hf-serverless-safety
     provider_type: remote::hf::serverless
     config:
@@ -81,10 +84,18 @@ models:
   model_id: ${env.INFERENCE_MODEL}
   provider_id: hf-serverless
   provider_model_id: null
+  model_type: llm
 - metadata: {}
   model_id: ${env.SAFETY_MODEL}
   provider_id: hf-serverless-safety
   provider_model_id: null
+  model_type: llm
+- metadata:
+    embedding_dimension: 384
+  model_id: all-MiniLM-L6-v2
+  provider_id: sentence-transformers
+  provider_model_id: null
+  model_type: embedding
 shields:
 - params: null
   shield_id: ${env.SAFETY_MODEL}
@@ -18,6 +18,9 @@ providers:
     config:
       huggingface_repo: ${env.INFERENCE_MODEL}
      api_token: ${env.HF_API_TOKEN}
+  - provider_id: sentence-transformers
+    provider_type: inline::sentence-transformers
+    config: {}
   memory:
   - provider_id: faiss
     provider_type: inline::faiss
@@ -76,6 +79,13 @@ models:
   model_id: ${env.INFERENCE_MODEL}
   provider_id: hf-serverless
   provider_model_id: null
+  model_type: llm
+- metadata:
+    embedding_dimension: 384
+  model_id: all-MiniLM-L6-v2
+  provider_id: sentence-transformers
+  provider_model_id: null
+  model_type: embedding
 shields: []
 memory_banks: []
 datasets: []
@@ -6,10 +6,15 @@
 
 from pathlib import Path
 
+from llama_stack.apis.models.models import ModelType
+
 from llama_stack.distribution.datatypes import ModelInput, Provider, ShieldInput
 from llama_stack.providers.inline.inference.meta_reference import (
     MetaReferenceInferenceConfig,
 )
+from llama_stack.providers.inline.inference.sentence_transformers import (
+    SentenceTransformersInferenceConfig,
+)
 from llama_stack.providers.inline.memory.faiss.config import FaissImplConfig
 from llama_stack.templates.template import DistributionTemplate, RunConfigSettings
@@ -34,6 +39,11 @@ def get_distribution_template() -> DistributionTemplate:
             checkpoint_dir="${env.INFERENCE_CHECKPOINT_DIR:null}",
         ),
     )
+    embedding_provider = Provider(
+        provider_id="sentence-transformers",
+        provider_type="inline::sentence-transformers",
+        config=SentenceTransformersInferenceConfig.sample_run_config(),
+    )
     memory_provider = Provider(
         provider_id="faiss",
         provider_type="inline::faiss",
@@ -44,6 +54,14 @@ def get_distribution_template() -> DistributionTemplate:
         model_id="${env.INFERENCE_MODEL}",
         provider_id="meta-reference-inference",
     )
+    embedding_model = ModelInput(
+        model_id="all-MiniLM-L6-v2",
+        provider_id="sentence-transformers",
+        model_type=ModelType.embedding,
+        metadata={
+            "embedding_dimension": 384,
+        },
+    )
     safety_model = ModelInput(
         model_id="${env.SAFETY_MODEL}",
         provider_id="meta-reference-safety",
@@ -59,15 +77,16 @@ def get_distribution_template() -> DistributionTemplate:
         run_configs={
             "run.yaml": RunConfigSettings(
                 provider_overrides={
-                    "inference": [inference_provider],
+                    "inference": [inference_provider, embedding_provider],
                     "memory": [memory_provider],
                 },
-                default_models=[inference_model],
+                default_models=[inference_model, embedding_model],
             ),
             "run-with-safety.yaml": RunConfigSettings(
                 provider_overrides={
                     "inference": [
                         inference_provider,
+                        embedding_provider,
                         Provider(
                             provider_id="meta-reference-safety",
                             provider_type="inline::meta-reference",
@@ -82,6 +101,7 @@ def get_distribution_template() -> DistributionTemplate:
                 default_models=[
                     inference_model,
                     safety_model,
+                    embedding_model,
                 ],
                 default_shields=[ShieldInput(shield_id="${env.SAFETY_MODEL}")],
             ),
@@ -19,6 +19,9 @@ providers:
       model: ${env.INFERENCE_MODEL}
       max_seq_len: 4096
       checkpoint_dir: ${env.INFERENCE_CHECKPOINT_DIR:null}
+  - provider_id: sentence-transformers
+    provider_type: inline::sentence-transformers
+    config: {}
   - provider_id: meta-reference-safety
     provider_type: inline::meta-reference
     config:
@@ -83,10 +86,18 @@ models:
   model_id: ${env.INFERENCE_MODEL}
   provider_id: meta-reference-inference
   provider_model_id: null
+  model_type: llm
 - metadata: {}
   model_id: ${env.SAFETY_MODEL}
   provider_id: meta-reference-safety
   provider_model_id: null
+  model_type: llm
+- metadata:
+    embedding_dimension: 384
+  model_id: all-MiniLM-L6-v2
+  provider_id: sentence-transformers
+  provider_model_id: null
+  model_type: embedding
 shields:
 - params: null
   shield_id: ${env.SAFETY_MODEL}
@@ -19,6 +19,9 @@ providers:
       model: ${env.INFERENCE_MODEL}
       max_seq_len: 4096
       checkpoint_dir: ${env.INFERENCE_CHECKPOINT_DIR:null}
+  - provider_id: sentence-transformers
+    provider_type: inline::sentence-transformers
+    config: {}
   memory:
   - provider_id: faiss
     provider_type: inline::faiss
@@ -77,6 +80,13 @@ models:
   model_id: ${env.INFERENCE_MODEL}
   provider_id: meta-reference-inference
   provider_model_id: null
+  model_type: llm
+- metadata:
+    embedding_dimension: 384
+  model_id: all-MiniLM-L6-v2
+  provider_id: sentence-transformers
+  provider_model_id: null
+  model_type: embedding
 shields: []
 memory_banks: []
 datasets: []
@@ -6,10 +6,15 @@
 
 from pathlib import Path
 
+from llama_stack.apis.models.models import ModelType
+
 from llama_stack.distribution.datatypes import ModelInput, Provider
 from llama_stack.providers.inline.inference.meta_reference import (
     MetaReferenceQuantizedInferenceConfig,
 )
+from llama_stack.providers.inline.inference.sentence_transformers import (
+    SentenceTransformersInferenceConfig,
+)
 from llama_stack.providers.inline.memory.faiss.config import FaissImplConfig
 from llama_stack.templates.template import DistributionTemplate, RunConfigSettings
@@ -34,6 +39,11 @@ def get_distribution_template() -> DistributionTemplate:
             checkpoint_dir="${env.INFERENCE_CHECKPOINT_DIR:null}",
         ),
     )
+    embedding_provider = Provider(
+        provider_id="sentence-transformers",
+        provider_type="inline::sentence-transformers",
+        config=SentenceTransformersInferenceConfig.sample_run_config(),
+    )
     memory_provider = Provider(
         provider_id="faiss",
         provider_type="inline::faiss",
@@ -44,6 +54,14 @@ def get_distribution_template() -> DistributionTemplate:
         model_id="${env.INFERENCE_MODEL}",
         provider_id="meta-reference-inference",
     )
+    embedding_model = ModelInput(
+        model_id="all-MiniLM-L6-v2",
+        provider_id="sentence-transformers",
+        model_type=ModelType.embedding,
+        metadata={
+            "embedding_dimension": 384,
+        },
+    )
     return DistributionTemplate(
         name=name,
         distro_type="self_hosted",
@@ -54,10 +72,10 @@ def get_distribution_template() -> DistributionTemplate:
         run_configs={
             "run.yaml": RunConfigSettings(
                 provider_overrides={
-                    "inference": [inference_provider],
+                    "inference": [inference_provider, embedding_provider],
                     "memory": [memory_provider],
                 },
-                default_models=[inference_model],
+                default_models=[inference_model, embedding_model],
             ),
         },
         run_config_env_vars={
@@ -21,6 +21,9 @@ providers:
       checkpoint_dir: ${env.INFERENCE_CHECKPOINT_DIR:null}
       quantization:
         type: fp8
+  - provider_id: sentence-transformers
+    provider_type: inline::sentence-transformers
+    config: {}
   memory:
   - provider_id: faiss
     provider_type: inline::faiss
@@ -79,6 +82,13 @@ models:
   model_id: ${env.INFERENCE_MODEL}
   provider_id: meta-reference-inference
   provider_model_id: null
+  model_type: llm
+- metadata:
+    embedding_dimension: 384
+  model_id: all-MiniLM-L6-v2
+  provider_id: sentence-transformers
+  provider_model_id: null
+  model_type: embedding
 shields: []
 memory_banks: []
 datasets: []
@@ -6,7 +6,12 @@
 
 from pathlib import Path
 
+from llama_stack.apis.models.models import ModelType
+
 from llama_stack.distribution.datatypes import ModelInput, Provider, ShieldInput
+from llama_stack.providers.inline.inference.sentence_transformers import (
+    SentenceTransformersInferenceConfig,
+)
 from llama_stack.providers.inline.memory.faiss.config import FaissImplConfig
 from llama_stack.providers.remote.inference.ollama import OllamaImplConfig
 from llama_stack.templates.template import DistributionTemplate, RunConfigSettings
@@ -29,6 +34,11 @@ def get_distribution_template() -> DistributionTemplate:
         provider_type="remote::ollama",
         config=OllamaImplConfig.sample_run_config(),
     )
+    embedding_provider = Provider(
+        provider_id="sentence-transformers",
+        provider_type="inline::sentence-transformers",
+        config=SentenceTransformersInferenceConfig.sample_run_config(),
+    )
     memory_provider = Provider(
         provider_id="faiss",
         provider_type="inline::faiss",
@@ -43,6 +53,14 @@ def get_distribution_template() -> DistributionTemplate:
         model_id="${env.SAFETY_MODEL}",
         provider_id="ollama",
     )
+    embedding_model = ModelInput(
+        model_id="all-MiniLM-L6-v2",
+        provider_id="sentence-transformers",
+        model_type=ModelType.embedding,
+        metadata={
+            "embedding_dimension": 384,
+        },
+    )
 
     return DistributionTemplate(
         name=name,
@@ -55,21 +73,23 @@ def get_distribution_template() -> DistributionTemplate:
         run_configs={
             "run.yaml": RunConfigSettings(
                 provider_overrides={
-                    "inference": [inference_provider],
+                    "inference": [inference_provider, embedding_provider],
                     "memory": [memory_provider],
                 },
-                default_models=[inference_model],
+                default_models=[inference_model, embedding_model],
             ),
             "run-with-safety.yaml": RunConfigSettings(
                 provider_overrides={
                     "inference": [
                         inference_provider,
+                        embedding_provider,
                     ],
                     "memory": [memory_provider],
                 },
                 default_models=[
                     inference_model,
                     safety_model,
+                    embedding_model,
                 ],
                 default_shields=[ShieldInput(shield_id="${env.SAFETY_MODEL}")],
             ),
@@ -17,6 +17,9 @@ providers:
     provider_type: remote::ollama
     config:
       url: ${env.OLLAMA_URL:http://localhost:11434}
+  - provider_id: sentence-transformers
+    provider_type: inline::sentence-transformers
+    config: {}
   memory:
   - provider_id: faiss
     provider_type: inline::faiss
@@ -75,10 +78,18 @@ models:
   model_id: ${env.INFERENCE_MODEL}
   provider_id: ollama
   provider_model_id: null
+  model_type: llm
 - metadata: {}
   model_id: ${env.SAFETY_MODEL}
   provider_id: ollama
   provider_model_id: null
+  model_type: llm
+- metadata:
+    embedding_dimension: 384
+  model_id: all-MiniLM-L6-v2
+  provider_id: sentence-transformers
+  provider_model_id: null
+  model_type: embedding
 shields:
 - params: null
   shield_id: ${env.SAFETY_MODEL}
@@ -17,6 +17,9 @@ providers:
     provider_type: remote::ollama
     config:
       url: ${env.OLLAMA_URL:http://localhost:11434}
+  - provider_id: sentence-transformers
+    provider_type: inline::sentence-transformers
+    config: {}
   memory:
   - provider_id: faiss
     provider_type: inline::faiss
@@ -75,6 +78,13 @@ models:
   model_id: ${env.INFERENCE_MODEL}
   provider_id: ollama
   provider_model_id: null
+  model_type: llm
+- metadata:
+    embedding_dimension: 384
+  model_id: all-MiniLM-L6-v2
+  provider_id: sentence-transformers
+  provider_model_id: null
+  model_type: embedding
 shields: []
 memory_banks: []
 datasets: []
@@ -22,6 +22,9 @@ providers:
       url: ${env.SAFETY_VLLM_URL}
       max_tokens: ${env.VLLM_MAX_TOKENS:4096}
       api_token: ${env.VLLM_API_TOKEN:fake}
+  - provider_id: sentence-transformers
+    provider_type: inline::sentence-transformers
+    config: {}
   memory:
   - provider_id: faiss
     provider_type: inline::faiss
@@ -58,10 +61,18 @@ models:
   model_id: ${env.INFERENCE_MODEL}
   provider_id: vllm-inference
   provider_model_id: null
+  model_type: llm
 - metadata: {}
   model_id: ${env.SAFETY_MODEL}
   provider_id: vllm-safety
   provider_model_id: null
+  model_type: llm
+- metadata:
+    embedding_dimension: 384
+  model_id: all-MiniLM-L6-v2
+  provider_id: sentence-transformers
+  provider_model_id: null
+  model_type: embedding
 shields:
 - params: null
   shield_id: ${env.SAFETY_MODEL}
Some files were not shown because too many files have changed in this diff.
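The `embedding_dimension: 384` metadata that recurs in the model entries above matches the output size of the `all-MiniLM-L6-v2` model that the new `inline::sentence-transformers` provider wraps. As a quick sanity check (not part of this diff; it assumes the `sentence-transformers` package is installed locally), one can verify the dimension directly:

# Illustrative only: confirm all-MiniLM-L6-v2 produces 384-dimensional vectors,
# matching the embedding_dimension metadata registered in the run.yaml files above.
from sentence_transformers import SentenceTransformer

model = SentenceTransformer("all-MiniLM-L6-v2")
vectors = model.encode(["Llama Stack registers this model as an embedding model."])
print(vectors.shape)  # expected: (1, 384)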