Merge branch 'main' into merge_conflict

This commit is contained in:
karthikgutha 2024-10-25 07:45:36 -07:00 committed by GitHub
commit 14cd065b6c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
64 changed files with 1881 additions and 345 deletions

77
.github/ISSUE_TEMPLATE/bug.yml vendored Normal file
View file

@ -0,0 +1,77 @@
name: 🐛 Bug Report
description: Create a report to help us reproduce and fix the bug
body:
- type: markdown
attributes:
value: >
#### Before submitting a bug, please make sure the issue hasn't already been addressed by searching through [the
existing and past issues](https://github.com/meta-llama/llama-stack/issues).
- type: textarea
id: system-info
attributes:
label: System Info
description: |
Please share your system info with us. You can use the following command to capture your environment information
python -m "torch.utils.collect_env"
placeholder: |
PyTorch version, CUDA version, GPU type, #num of GPUs...
validations:
required: true
- type: checkboxes
id: information-scripts-examples
attributes:
label: Information
description: 'The problem arises when using:'
options:
- label: "The official example scripts"
- label: "My own modified scripts"
- type: textarea
id: bug-description
attributes:
label: 🐛 Describe the bug
description: |
Please provide a clear and concise description of what the bug is.
Please also paste or describe the results you observe instead of the expected results.
placeholder: |
A clear and concise description of what the bug is.
```llama stack
# Command that you used for running the examples
```
Description of the results
validations:
required: true
- type: textarea
attributes:
label: Error logs
description: |
If you observe an error, please paste the error message including the **full** traceback of the exception. It may be relevant to wrap error messages in ```` ```triple quotes blocks``` ````.
placeholder: |
```
The error message you got, with the full traceback.
```
validations:
required: true
- type: textarea
id: expected-behavior
validations:
required: true
attributes:
label: Expected behavior
description: "A clear and concise description of what you would expect to happen."
- type: markdown
attributes:
value: >
Thanks for contributing 🎉!

View file

@ -0,0 +1,31 @@
name: 🚀 Feature request
description: Submit a proposal/request for a new llama-stack feature
body:
- type: textarea
id: feature-pitch
attributes:
label: 🚀 The feature, motivation and pitch
description: >
A clear and concise description of the feature proposal. Please outline the motivation for the proposal. Is your feature request related to a specific problem? e.g., *"I'm working on X and would like Y to be possible"*. If this is related to another GitHub issue, please link here too.
validations:
required: true
- type: textarea
id: alternatives
attributes:
label: Alternatives
description: >
A description of any alternative solutions or features you've considered, if any.
- type: textarea
id: additional-context
attributes:
label: Additional context
description: >
Add any other context or screenshots about the feature request.
- type: markdown
attributes:
value: >
Thanks for contributing 🎉!

31
.github/PULL_REQUEST_TEMPLATE.md vendored Normal file
View file

@ -0,0 +1,31 @@
# What does this PR do?
Closes # (issue)
## Feature/Issue validation/testing/test plan
Please describe the tests that you ran to verify your changes and relevant result summary. Provide instructions so it can be reproduced.
Please also list any relevant details for your test configuration or test plan.
- [ ] Test A
Logs for Test A
- [ ] Test B
Logs for Test B
## Sources
Please link relevant resources if necessary.
## Before submitting
- [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
- [ ] Did you read the [contributor guideline](https://github.com/meta-llama/llama-stack/blob/main/CONTRIBUTING.md),
Pull Request section?
- [ ] Was this discussed/approved via a Github issue? Please add a link
to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes?
- [ ] Did you write any new necessary tests?
Thanks for contributing 🎉!

1
.gitignore vendored
View file

@ -13,6 +13,7 @@ xcuserdata/
Package.resolved
*.pte
*.ipynb_checkpoints*
.idea
.venv/
.idea
_build

View file

@ -1,4 +1,4 @@
include requirements.txt
include llama_stack/distribution/*.sh
include llama_stack/cli/scripts/*.sh
include llama_stack/distribution/templates/*.yaml
include distributions/*/build.yaml

View file

@ -65,23 +65,30 @@ A Distribution is where APIs and Providers are assembled together to provide a c
| Dell-TGI | [Local TGI + Chroma](https://hub.docker.com/repository/docker/llamastack/llamastack-local-tgi-chroma/general) | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |
## Installation
You can install this repository as a [package](https://pypi.org/project/llama-stack/) with `pip install llama-stack`
You have two ways to install this repository:
If you want to install from source:
1. **Install as a package**:
You can install the repository directly from [PyPI](https://pypi.org/project/llama-stack/) by running the following command:
```bash
pip install llama-stack
```
```bash
mkdir -p ~/local
cd ~/local
git clone git@github.com:meta-llama/llama-stack.git
2. **Install from source**:
If you prefer to install from the source code, follow these steps:
```bash
mkdir -p ~/local
cd ~/local
git clone git@github.com:meta-llama/llama-stack.git
conda create -n stack python=3.10
conda activate stack
conda create -n stack python=3.10
conda activate stack
cd llama-stack
$CONDA_PREFIX/bin/pip install -e .
```
cd llama-stack
$CONDA_PREFIX/bin/pip install -e .
```
## Documentations

View file

@ -7,6 +7,7 @@ A Distribution is where APIs and Providers are assembled together to provide a c
| **Distribution** | **Llama Stack Docker** | Start This Distribution | **Inference** | **Agents** | **Memory** | **Safety** | **Telemetry** |
|:----------------: |:------------------------------------------: |:-----------------------: |:------------------: |:------------------: |:------------------: |:------------------: |:------------------: |
| Meta Reference | [llamastack/distribution-meta-reference-gpu](https://hub.docker.com/repository/docker/llamastack/distribution-meta-reference-gpu/general) | [Guide](./meta-reference-gpu/) | meta-reference | meta-reference | meta-reference; remote::pgvector; remote::chromadb | meta-reference | meta-reference |
| Meta Reference Quantized | [llamastack/distribution-meta-reference-quantized-gpu](https://hub.docker.com/repository/docker/llamastack/distribution-meta-reference-quantized-gpu/general) | [Guide](./meta-reference-quantized-gpu/) | meta-reference-quantized | meta-reference | meta-reference; remote::pgvector; remote::chromadb | meta-reference | meta-reference |
| Ollama | [llamastack/distribution-ollama](https://hub.docker.com/repository/docker/llamastack/distribution-ollama/general) | [Guide](./ollama/) | remote::ollama | meta-reference | remote::pgvector; remote::chromadb | remote::ollama | meta-reference |
| TGI | [llamastack/distribution-tgi](https://hub.docker.com/repository/docker/llamastack/distribution-tgi/general) | [Guide](./tgi/) | remote::tgi | meta-reference | meta-reference; remote::pgvector; remote::chromadb | meta-reference | meta-reference |
| Together | [llamastack/distribution-together](https://hub.docker.com/repository/docker/llamastack/distribution-together/general) | [Guide](./together/) | remote::together | meta-reference | remote::weaviate | meta-reference | meta-reference |

View file

@ -1,5 +1,6 @@
name: meta-reference-gpu
distribution_spec:
docker_image: pytorch/pytorch
description: Use code from `llama_stack` itself to serve all llama stack APIs
providers:
inference: meta-reference

View file

@ -0,0 +1,34 @@
# Meta Reference Quantized Distribution
The `llamastack/distribution-meta-reference-quantized-gpu` distribution consists of the following provider configurations.
| **API** | **Inference** | **Agents** | **Memory** | **Safety** | **Telemetry** |
|----------------- |------------------------ |---------------- |-------------------------------------------------- |---------------- |---------------- |
| **Provider(s)** | meta-reference-quantized | meta-reference | meta-reference, remote::pgvector, remote::chromadb | meta-reference | meta-reference |
The only difference vs. the `meta-reference-gpu` distribution is that it has support for more efficient inference -- with fp8, int4 quantization, etc.
### Start the Distribution (Single Node GPU)
> [!NOTE]
> This assumes that the machine on which you start the local server has a GPU available.
> [!NOTE]
> `~/.llama` should be the path containing downloaded weights of Llama models.
To download and start running a pre-built docker container, you may use the following commands:
```
docker run -it -p 5000:5000 -v ~/.llama:/root/.llama \
-v ./run.yaml:/root/my-run.yaml \
--gpus=all \
distribution-meta-reference-quantized-gpu \
--yaml_config /root/my-run.yaml
```
### Alternative (Build and start distribution locally via conda)
- You may check out the [Getting Started](../../docs/getting_started.md) guide for more details on building locally via conda and starting up the distribution.

View file

@ -0,0 +1,14 @@
name: meta-reference-quantized-gpu
distribution_spec:
docker_image: pytorch/pytorch:2.5.0-cuda12.4-cudnn9-runtime
description: Use code from `llama_stack` itself to serve all llama stack APIs
providers:
inference: meta-reference-quantized
memory:
- meta-reference
- remote::chromadb
- remote::pgvector
safety: meta-reference
agents: meta-reference
telemetry: meta-reference
image_type: docker

View file

@ -0,0 +1,51 @@
version: '2'
built_at: '2024-10-08T17:40:45.325529'
image_name: local
docker_image: null
conda_env: local
apis:
- shields
- agents
- models
- memory
- memory_banks
- inference
- safety
providers:
inference:
- provider_id: meta0
provider_type: meta-reference-quantized
config:
model: Llama3.2-3B-Instruct:int4-qlora-eo8
quantization:
type: int4
torch_seed: null
max_seq_len: 2048
max_batch_size: 1
safety:
- provider_id: meta0
provider_type: meta-reference
config:
llama_guard_shield:
model: Llama-Guard-3-1B
excluded_categories: []
disable_input_check: false
disable_output_check: false
prompt_guard_shield:
model: Prompt-Guard-86M
memory:
- provider_id: meta0
provider_type: meta-reference
config: {}
agents:
- provider_id: meta0
provider_type: meta-reference
config:
persistence_store:
namespace: null
type: sqlite
db_path: ~/.llama/runtime/kvstore.db
telemetry:
- provider_id: meta0
provider_type: meta-reference
config: {}

View file

@ -5,159 +5,174 @@ This guide will walk you though the steps to get started on end-to-end flow for
## Installation
The `llama` CLI tool helps you setup and use the Llama toolchain & agentic systems. It should be available on your path after installing the `llama-stack` package.
You can install this repository as a [package](https://pypi.org/project/llama-stack/) with `pip install llama-stack`
You have two ways to install this repository:
If you want to install from source:
1. **Install as a package**:
You can install the repository directly from [PyPI](https://pypi.org/project/llama-stack/) by running the following command:
```bash
pip install llama-stack
```
```bash
mkdir -p ~/local
cd ~/local
git clone git@github.com:meta-llama/llama-stack.git
2. **Install from source**:
If you prefer to install from the source code, follow these steps:
```bash
mkdir -p ~/local
cd ~/local
git clone git@github.com:meta-llama/llama-stack.git
conda create -n stack python=3.10
conda activate stack
conda create -n stack python=3.10
conda activate stack
cd llama-stack
$CONDA_PREFIX/bin/pip install -e .
```
cd llama-stack
$CONDA_PREFIX/bin/pip install -e .
```
For what you can do with the Llama CLI, please refer to [CLI Reference](./cli_reference.md).
## Starting Up Llama Stack Server
#### Starting up server via docker
We provide 2 pre-built Docker images of the Llama Stack distribution, which can be found at the following links.
- [llamastack-local-gpu](https://hub.docker.com/repository/docker/llamastack/llamastack-local-gpu/general)
- This is a packaged version with our local meta-reference implementations, where you will be running inference locally with downloaded Llama model checkpoints.
- [llamastack-local-cpu](https://hub.docker.com/repository/docker/llamastack/llamastack-local-cpu/general)
- This is a lite version with remote inference where you can hook up to your favourite remote inference framework (e.g. ollama, fireworks, together, tgi) for running inference without GPU.
There are two ways to start up the Llama Stack server:
> [!NOTE]
> For GPU inference, you need to set these environment variables for specifying local directory containing your model checkpoints, and enable GPU inference to start running docker container.
```
export LLAMA_CHECKPOINT_DIR=~/.llama
```
1. **Starting up server via docker**:
> [!NOTE]
> `~/.llama` should be the path containing downloaded weights of Llama models.
We provide 2 pre-built Docker images of the Llama Stack distribution, which can be found at the following links.
- [llamastack-local-gpu](https://hub.docker.com/repository/docker/llamastack/llamastack-local-gpu/general)
- This is a packaged version with our local meta-reference implementations, where you will be running inference locally with downloaded Llama model checkpoints.
- [llamastack-local-cpu](https://hub.docker.com/repository/docker/llamastack/llamastack-local-cpu/general)
- This is a lite version with remote inference where you can hook up to your favourite remote inference framework (e.g. ollama, fireworks, together, tgi) for running inference without GPU.
> [!NOTE]
> For GPU inference, you need to set these environment variables for specifying local directory containing your model checkpoints, and enable GPU inference to start running docker container.
```
export LLAMA_CHECKPOINT_DIR=~/.llama
```
> [!NOTE]
> `~/.llama` should be the path containing downloaded weights of Llama models.
To download llama models, use
```
llama download --model-id Llama3.1-8B-Instruct
```
To download and start running a pre-built docker container, you may use the following commands:
```
docker run -it -p 5000:5000 -v ~/.llama:/root/.llama --gpus=all llamastack/llamastack-local-gpu
```
> [!TIP]
> Pro Tip: You can use `docker compose up` to start a distribution with remote providers (e.g. TGI) using [llamastack-local-cpu](https://hub.docker.com/repository/docker/llamastack/llamastack-local-cpu/general). Check out [these scripts](../distributions/) to help you get started.
To download and start running a pre-built docker container, you may use the following commands:
2. **Build->Configure->Run Llama Stack server via conda**:
```
docker run -it -p 5000:5000 -v ~/.llama:/root/.llama --gpus=all llamastack/llamastack-local-gpu
```
You may also build a LlamaStack distribution from scratch, configure it, and start running the distribution. This is useful for developing on LlamaStack.
> [!TIP]
> Pro Tip: You can use `docker compose up` to start a distribution with remote providers (e.g. TGI) using [llamastack-local-cpu](https://hub.docker.com/repository/docker/llamastack/llamastack-local-cpu/general). Check out [these scripts](../distributions/) to help you get started.
**`llama stack build`**
- You'll be prompted to enter build information interactively.
```
llama stack build
#### Build->Configure->Run Llama Stack server via conda
You may also build a LlamaStack distribution from scratch, configure it, and start running the distribution. This is useful for developing on LlamaStack.
> Enter an unique name for identifying your Llama Stack build distribution (e.g. my-local-stack): my-local-stack
> Enter the image type you want your distribution to be built with (docker or conda): conda
**`llama stack build`**
- You'll be prompted to enter build information interactively.
```
llama stack build
Llama Stack is composed of several APIs working together. Let's configure the providers (implementations) you want to use for these APIs.
> Enter the API provider for the inference API: (default=meta-reference): meta-reference
> Enter the API provider for the safety API: (default=meta-reference): meta-reference
> Enter the API provider for the agents API: (default=meta-reference): meta-reference
> Enter the API provider for the memory API: (default=meta-reference): meta-reference
> Enter the API provider for the telemetry API: (default=meta-reference): meta-reference
> Enter an unique name for identifying your Llama Stack build distribution (e.g. my-local-stack): my-local-stack
> Enter the image type you want your distribution to be built with (docker or conda): conda
> (Optional) Enter a short description for your Llama Stack distribution:
Llama Stack is composed of several APIs working together. Let's configure the providers (implementations) you want to use for these APIs.
> Enter the API provider for the inference API: (default=meta-reference): meta-reference
> Enter the API provider for the safety API: (default=meta-reference): meta-reference
> Enter the API provider for the agents API: (default=meta-reference): meta-reference
> Enter the API provider for the memory API: (default=meta-reference): meta-reference
> Enter the API provider for the telemetry API: (default=meta-reference): meta-reference
Build spec configuration saved at ~/.conda/envs/llamastack-my-local-stack/my-local-stack-build.yaml
You can now run `llama stack configure my-local-stack`
```
> (Optional) Enter a short description for your Llama Stack distribution:
**`llama stack configure`**
- Run `llama stack configure <name>` with the name you have previously defined in `build` step.
```
llama stack configure <name>
```
- You will be prompted to enter configurations for your Llama Stack
Build spec configuration saved at ~/.conda/envs/llamastack-my-local-stack/my-local-stack-build.yaml
You can now run `llama stack configure my-local-stack`
```
```
$ llama stack configure my-local-stack
**`llama stack configure`**
- Run `llama stack configure <name>` with the name you have previously defined in `build` step.
```
llama stack configure <name>
```
- You will be prompted to enter configurations for your Llama Stack
Could not find my-local-stack. Trying conda build name instead...
Configuring API `inference`...
=== Configuring provider `meta-reference` for API inference...
Enter value for model (default: Llama3.1-8B-Instruct) (required):
Do you want to configure quantization? (y/n): n
Enter value for torch_seed (optional):
Enter value for max_seq_len (default: 4096) (required):
Enter value for max_batch_size (default: 1) (required):
```
$ llama stack configure my-local-stack
Configuring API `safety`...
=== Configuring provider `meta-reference` for API safety...
Do you want to configure llama_guard_shield? (y/n): n
Do you want to configure prompt_guard_shield? (y/n): n
Could not find my-local-stack. Trying conda build name instead...
Configuring API `inference`...
=== Configuring provider `meta-reference` for API inference...
Enter value for model (default: Llama3.1-8B-Instruct) (required):
Do you want to configure quantization? (y/n): n
Enter value for torch_seed (optional):
Enter value for max_seq_len (default: 4096) (required):
Enter value for max_batch_size (default: 1) (required):
Configuring API `agents`...
=== Configuring provider `meta-reference` for API agents...
Enter `type` for persistence_store (options: redis, sqlite, postgres) (default: sqlite):
Configuring API `safety`...
=== Configuring provider `meta-reference` for API safety...
Do you want to configure llama_guard_shield? (y/n): n
Do you want to configure prompt_guard_shield? (y/n): n
Configuring SqliteKVStoreConfig:
Enter value for namespace (optional):
Enter value for db_path (default: /home/xiyan/.llama/runtime/kvstore.db) (required):
Configuring API `agents`...
=== Configuring provider `meta-reference` for API agents...
Enter `type` for persistence_store (options: redis, sqlite, postgres) (default: sqlite):
Configuring API `memory`...
=== Configuring provider `meta-reference` for API memory...
> Please enter the supported memory bank type your provider has for memory: vector
Configuring SqliteKVStoreConfig:
Enter value for namespace (optional):
Enter value for db_path (default: /home/xiyan/.llama/runtime/kvstore.db) (required):
Configuring API `telemetry`...
=== Configuring provider `meta-reference` for API telemetry...
Configuring API `memory`...
=== Configuring provider `meta-reference` for API memory...
> Please enter the supported memory bank type your provider has for memory: vector
> YAML configuration has been written to ~/.llama/builds/conda/my-local-stack-run.yaml.
You can now run `llama stack run my-local-stack --port PORT`
```
Configuring API `telemetry`...
=== Configuring provider `meta-reference` for API telemetry...
**`llama stack run`**
- Run `llama stack run <name>` with the name you have previously defined.
```
llama stack run my-local-stack
> YAML configuration has been written to ~/.llama/builds/conda/my-local-stack-run.yaml.
You can now run `llama stack run my-local-stack --port PORT`
```
**`llama stack run`**
- Run `llama stack run <name>` with the name you have previously defined.
```
llama stack run my-local-stack
...
> initializing model parallel with size 1
> initializing ddp with size 1
> initializing pipeline with size 1
...
Finished model load YES READY
Serving POST /inference/chat_completion
Serving POST /inference/completion
Serving POST /inference/embeddings
Serving POST /memory_banks/create
Serving DELETE /memory_bank/documents/delete
Serving DELETE /memory_banks/drop
Serving GET /memory_bank/documents/get
Serving GET /memory_banks/get
Serving POST /memory_bank/insert
Serving GET /memory_banks/list
Serving POST /memory_bank/query
Serving POST /memory_bank/update
Serving POST /safety/run_shield
Serving POST /agentic_system/create
Serving POST /agentic_system/session/create
Serving POST /agentic_system/turn/create
Serving POST /agentic_system/delete
Serving POST /agentic_system/session/delete
Serving POST /agentic_system/session/get
Serving POST /agentic_system/step/get
Serving POST /agentic_system/turn/get
Serving GET /telemetry/get_trace
Serving POST /telemetry/log_event
Listening on :::5000
INFO: Started server process [587053]
INFO: Waiting for application startup.
INFO: Application startup complete.
INFO: Uvicorn running on http://[::]:5000 (Press CTRL+C to quit)
```
...
> initializing model parallel with size 1
> initializing ddp with size 1
> initializing pipeline with size 1
...
Finished model load YES READY
Serving POST /inference/chat_completion
Serving POST /inference/completion
Serving POST /inference/embeddings
Serving POST /memory_banks/create
Serving DELETE /memory_bank/documents/delete
Serving DELETE /memory_banks/drop
Serving GET /memory_bank/documents/get
Serving GET /memory_banks/get
Serving POST /memory_bank/insert
Serving GET /memory_banks/list
Serving POST /memory_bank/query
Serving POST /memory_bank/update
Serving POST /safety/run_shield
Serving POST /agentic_system/create
Serving POST /agentic_system/session/create
Serving POST /agentic_system/turn/create
Serving POST /agentic_system/delete
Serving POST /agentic_system/session/delete
Serving POST /agentic_system/session/get
Serving POST /agentic_system/step/get
Serving POST /agentic_system/turn/get
Serving GET /telemetry/get_trace
Serving POST /telemetry/log_event
Listening on :::5000
INFO: Started server process [587053]
INFO: Waiting for application startup.
INFO: Application startup complete.
INFO: Uvicorn running on http://[::]:5000 (Press CTRL+C to quit)
```
## Testing with client

View file

@ -0,0 +1,103 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
import os
from pathlib import Path
from typing import Optional
import fire
import httpx
from termcolor import cprint
from llama_stack.apis.datasets import * # noqa: F403
from llama_stack.apis.datasetio import * # noqa: F403
from llama_stack.apis.common.type_system import * # noqa: F403
from llama_stack.apis.datasets.client import DatasetsClient
from llama_stack.providers.tests.datasetio.test_datasetio import data_url_from_file
class DatasetIOClient(DatasetIO):
    """Thin HTTP client for the ``/datasetio`` routes of a running server.

    Each call opens a short-lived ``httpx.AsyncClient``; the instance itself
    only stores the server's base URL.
    """

    def __init__(self, base_url: str):
        # Route paths are appended directly, so this should not end with "/".
        self.base_url = base_url

    async def initialize(self) -> None:
        """No-op: there is no persistent state to set up."""
        pass

    async def shutdown(self) -> None:
        """No-op: there is no persistent state to tear down."""
        pass

    async def get_rows_paginated(
        self,
        dataset_id: str,
        rows_in_page: int,
        page_token: Optional[str] = None,
        filter_condition: Optional[str] = None,
    ) -> PaginatedRowsResult:
        """Fetch one page of rows from a registered dataset.

        Args:
            dataset_id: Identifier of the dataset to read from.
            rows_in_page: Number of rows requested for this page.
            page_token: Optional token identifying the page to fetch.
            filter_condition: Optional filter expression forwarded to the server.

        Returns:
            The parsed ``PaginatedRowsResult``, or ``None`` when the server
            replies with an empty/falsy JSON body.

        Raises:
            httpx.HTTPStatusError: If the server answers with a 4xx/5xx status.
        """
        query = {
            "dataset_id": dataset_id,
            "rows_in_page": rows_in_page,
        }
        # httpx encodes a None query value as an empty string ("page_token="),
        # which the server would see as "" rather than "not provided". Only
        # send the optional parameters that were actually supplied.
        if page_token is not None:
            query["page_token"] = page_token
        if filter_condition is not None:
            query["filter_condition"] = filter_condition

        async with httpx.AsyncClient() as client:
            response = await client.get(
                f"{self.base_url}/datasetio/get_rows_paginated",
                params=query,
                headers={"Content-Type": "application/json"},
                timeout=60,
            )
            response.raise_for_status()
            # Decode the body once and reuse it for both the check and the parse.
            payload = response.json()
            if not payload:
                return None
            return PaginatedRowsResult(**payload)
async def run_main(host: str, port: int):
    """Exercise the datasets and datasetio endpoints of a running server.

    Registers the bundled test CSV as ``test-dataset``, prints the list of
    registered datasets, then fetches and prints the first page of rows.

    Args:
        host: Hostname or IP address of the server.
        port: TCP port the server listens on.
    """
    base_url = f"http://{host}:{port}"
    client = DatasetsClient(base_url)

    # register dataset
    test_file = (
        Path(os.path.abspath(__file__)).parent.parent.parent
        / "providers/tests/datasetio/test_dataset.csv"
    )
    test_url = data_url_from_file(str(test_file))
    # register_dataset returns None, so there is nothing to keep from this call.
    await client.register_dataset(
        DatasetDefWithProvider(
            identifier="test-dataset",
            provider_id="meta0",
            url=URL(
                uri=test_url,
            ),
            dataset_schema={
                "generated_answer": StringType(),
                "expected_answer": StringType(),
                "input_query": StringType(),
            },
        )
    )

    # list datasets
    list_dataset = await client.list_datasets()
    cprint(list_dataset, "blue")

    # datasetio client to get the rows
    datasetio_client = DatasetIOClient(base_url)
    response = await datasetio_client.get_rows_paginated(
        dataset_id="test-dataset",
        rows_in_page=4,
        page_token=None,
        filter_condition=None,
    )
    # get_rows_paginated yields None on an empty body; report that instead of
    # failing with AttributeError on `response.rows`.
    if response is None:
        cprint("No rows returned for dataset test-dataset", "red")
        return
    cprint(f"Returned {len(response.rows)} rows \n {response}", "green")
def main(host: str, port: int):
    """Synchronous entry point: drive ``run_main`` to completion on a new event loop."""
    demo = run_main(host, port)
    asyncio.run(demo)


if __name__ == "__main__":
    # Expose `main` as a command-line interface (host and port as arguments).
    fire.Fire(main)

View file

@ -29,7 +29,7 @@ class DatasetIO(Protocol):
# keeping for aligning with inference/safety, but this is not used
dataset_store: DatasetStore
@webmethod(route="/dataio/get_rows_paginated")
@webmethod(route="/datasetio/get_rows_paginated", method="GET")
async def get_rows_paginated(
self,
dataset_id: str,

View file

@ -0,0 +1,116 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
import json
import os
from pathlib import Path
from typing import Optional
import fire
import httpx
from termcolor import cprint
from .datasets import * # noqa: F403
from llama_stack.apis.datasets import * # noqa: F403
from llama_stack.apis.common.type_system import * # noqa: F403
from llama_stack.providers.tests.datasetio.test_datasetio import data_url_from_file
class DatasetsClient(Datasets):
    """Thin HTTP client for the ``/datasets`` routes of a running server.

    Each call opens a short-lived ``httpx.AsyncClient``; the instance itself
    only stores the server's base URL.
    """

    def __init__(self, base_url: str):
        # Route paths are appended directly, so this should not end with "/".
        self.base_url = base_url

    async def initialize(self) -> None:
        """No-op: there is no persistent state to set up."""
        pass

    async def shutdown(self) -> None:
        """No-op: there is no persistent state to tear down."""
        pass

    async def register_dataset(
        self,
        dataset_def: DatasetDefWithProvider,
    ) -> None:
        """Register a dataset definition with the server.

        Raises:
            httpx.HTTPStatusError: If the server answers with a 4xx/5xx status.
        """
        async with httpx.AsyncClient() as client:
            response = await client.post(
                f"{self.base_url}/datasets/register",
                json={
                    # Round-trip through the model's own JSON encoder so the
                    # payload is a plain, JSON-serializable dict.
                    "dataset_def": json.loads(dataset_def.json()),
                },
                headers={"Content-Type": "application/json"},
                timeout=60,
            )
            response.raise_for_status()
            return

    async def get_dataset(
        self,
        dataset_identifier: str,
    ) -> Optional[DatasetDefWithProvider]:
        """Look up a single dataset by identifier.

        Returns:
            The dataset definition, or ``None`` when the server replies with
            an empty/falsy JSON body.

        Raises:
            httpx.HTTPStatusError: If the server answers with a 4xx/5xx status.
        """
        async with httpx.AsyncClient() as client:
            response = await client.get(
                f"{self.base_url}/datasets/get",
                params={
                    "dataset_identifier": dataset_identifier,
                },
                headers={"Content-Type": "application/json"},
                timeout=60,
            )
            response.raise_for_status()
            payload = response.json()
            if not payload:
                return None
            return DatasetDefWithProvider(**payload)

    async def list_datasets(self) -> List[DatasetDefWithProvider]:
        """Return every dataset registered on the server.

        Returns:
            A list of dataset definitions; an empty list when the server
            replies with an empty/falsy JSON body.

        Raises:
            httpx.HTTPStatusError: If the server answers with a 4xx/5xx status.
        """
        async with httpx.AsyncClient() as client:
            response = await client.get(
                f"{self.base_url}/datasets/list",
                headers={"Content-Type": "application/json"},
                timeout=60,
            )
            response.raise_for_status()
            payload = response.json()
            if not payload:
                # Honor the declared List return type so callers can always
                # iterate or take len() of the result.
                return []
            return [DatasetDefWithProvider(**x) for x in payload]
async def run_main(host: str, port: int):
    """Register the bundled test CSV as ``test-dataset`` and print all datasets.

    Args:
        host: Hostname or IP address of the server.
        port: TCP port the server listens on.
    """
    client = DatasetsClient(f"http://{host}:{port}")

    # register dataset
    package_root = Path(os.path.abspath(__file__)).parent.parent.parent
    test_file = package_root / "providers/tests/datasetio/test_dataset.csv"
    test_url = data_url_from_file(str(test_file))
    schema = {
        "generated_answer": StringType(),
        "expected_answer": StringType(),
        "input_query": StringType(),
    }
    dataset_def = DatasetDefWithProvider(
        identifier="test-dataset",
        provider_id="meta0",
        url=URL(uri=test_url),
        dataset_schema=schema,
    )
    await client.register_dataset(dataset_def)

    # list datasets
    list_dataset = await client.list_datasets()
    cprint(list_dataset, "blue")
def main(host: str, port: int):
    """Synchronous entry point: drive ``run_main`` to completion on a new event loop."""
    demo = run_main(host, port)
    asyncio.run(demo)


if __name__ == "__main__":
    # Expose `main` as a command-line interface (host and port as arguments).
    fire.Fire(main)

View file

@ -20,7 +20,7 @@ class DatasetDef(BaseModel):
identifier: str = Field(
description="A unique name for the dataset",
)
columns_schema: Dict[str, ParamType] = Field(
dataset_schema: Dict[str, ParamType] = Field(
description="The schema definition for this dataset",
)
url: URL

View file

@ -172,7 +172,7 @@ async def run_mm_main(
],
)
cprint(f"User>{message.content}", "green")
iterator = client.chat_completion(
iterator = await client.chat_completion(
model=model,
messages=[message],
stream=stream,

View file

@ -25,6 +25,7 @@ class LogProbConfig(BaseModel):
class QuantizationType(Enum):
bf16 = "bf16"
fp8 = "fp8"
int4 = "int4"
@json_schema_type
@ -37,8 +38,14 @@ class Bf16QuantizationConfig(BaseModel):
type: Literal[QuantizationType.bf16.value] = QuantizationType.bf16.value
@json_schema_type
class Int4QuantizationConfig(BaseModel):
type: Literal[QuantizationType.int4.value] = QuantizationType.int4.value
scheme: Optional[str] = "int4_weight_int8_dynamic_activation"
QuantizationConfig = Annotated[
Union[Bf16QuantizationConfig, Fp8QuantizationConfig],
Union[Bf16QuantizationConfig, Fp8QuantizationConfig, Int4QuantizationConfig],
Field(discriminator="type"),
]
@ -228,8 +235,6 @@ class Inference(Protocol):
logprobs: Optional[LogProbConfig] = None,
) -> Union[CompletionResponse, CompletionResponseStreamChunk]: ...
# This method is not `async def` because it can result in either an
# `AsyncGenerator` or a `ChatCompletionResponse` depending on the value of `stream`.
@webmethod(route="/inference/chat_completion")
async def chat_completion(
self,

View file

@ -0,0 +1,132 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
import os
from pathlib import Path
import fire
import httpx
from termcolor import cprint
from llama_stack.apis.datasets import * # noqa: F403
from llama_stack.apis.scoring import * # noqa: F403
from llama_stack.apis.common.type_system import * # noqa: F403
from llama_stack.apis.datasetio.client import DatasetIOClient
from llama_stack.apis.datasets.client import DatasetsClient
from llama_stack.providers.tests.datasetio.test_datasetio import data_url_from_file
class ScoringClient(Scoring):
    """Thin HTTP client for the ``/scoring`` routes of a running server."""

    def __init__(self, base_url: str):
        self.base_url = base_url

    async def initialize(self) -> None:
        """No-op: there is no persistent state to set up."""
        pass

    async def shutdown(self) -> None:
        """No-op: there is no persistent state to tear down."""
        pass

    async def score_batch(
        self, dataset_id: str, scoring_functions: List[str]
    ) -> ScoreBatchResponse:
        """Score a registered dataset with the named scoring functions.

        Returns ``None`` when the server replies with an empty/falsy JSON
        body; raises ``httpx.HTTPStatusError`` on a 4xx/5xx status.
        """
        endpoint = f"{self.base_url}/scoring/score_batch"
        request_body = {
            "dataset_id": dataset_id,
            "scoring_functions": scoring_functions,
        }
        async with httpx.AsyncClient() as http:
            reply = await http.post(
                endpoint,
                json=request_body,
                headers={"Content-Type": "application/json"},
                timeout=60,
            )
            reply.raise_for_status()
            data = reply.json()
            if data:
                return ScoreBatchResponse(**data)
            return

    async def score(
        self, input_rows: List[Dict[str, Any]], scoring_functions: List[str]
    ) -> ScoreResponse:
        """Score the given rows with the named scoring functions.

        Returns ``None`` when the server replies with an empty/falsy JSON
        body; raises ``httpx.HTTPStatusError`` on a 4xx/5xx status.
        """
        endpoint = f"{self.base_url}/scoring/score"
        request_body = {
            "input_rows": input_rows,
            "scoring_functions": scoring_functions,
        }
        async with httpx.AsyncClient() as http:
            reply = await http.post(
                endpoint,
                json=request_body,
                headers={"Content-Type": "application/json"},
                timeout=60,
            )
            reply.raise_for_status()
            data = reply.json()
            if data:
                return ScoreResponse(**data)
            return
async def run_main(host: str, port: int):
    """Exercise the datasets, datasetio and scoring endpoints end to end.

    Registers the bundled test CSV as ``test-dataset``, lists datasets, fetches
    a page of rows, scores those rows, then scores the whole dataset via
    `score_batch`. Results are printed with `cprint`.

    Args:
        host: Hostname of the running server.
        port: Port of the running server.
    """
    base_url = f"http://{host}:{port}"

    # register dataset
    datasets_client = DatasetsClient(base_url)
    test_file = (
        Path(os.path.abspath(__file__)).parent.parent.parent
        / "providers/tests/datasetio/test_dataset.csv"
    )
    test_url = data_url_from_file(str(test_file))
    await datasets_client.register_dataset(
        DatasetDefWithProvider(
            identifier="test-dataset",
            provider_id="meta0",
            url=URL(
                uri=test_url,
            ),
            dataset_schema={
                "generated_answer": StringType(),
                "expected_answer": StringType(),
                "input_query": StringType(),
            },
        )
    )

    # list datasets
    list_dataset = await datasets_client.list_datasets()
    cprint(list_dataset, "blue")

    # datasetio client to get the rows
    datasetio_client = DatasetIOClient(base_url)
    rows_page = await datasetio_client.get_rows_paginated(
        dataset_id="test-dataset",
        rows_in_page=4,
        page_token=None,
        filter_condition=None,
    )
    cprint(f"Returned {len(rows_page.rows)} rows \n {rows_page}", "green")

    # scoring client to score the rows; the same client serves both calls below
    scoring_client = ScoringClient(base_url)
    score_response = await scoring_client.score(
        input_rows=rows_page.rows,
        scoring_functions=["equality"],
    )
    cprint(f"score response={score_response}", "blue")

    # test scoring batch using datasetio api
    batch_response = await scoring_client.score_batch(
        dataset_id="test-dataset",
        scoring_functions=["equality"],
    )
    cprint(f"score_batch response={batch_response}", "cyan")
def main(host: str, port: int):
    """Synchronous entry point: run `run_main` to completion on a new event loop."""
    coroutine = run_main(host, port)
    asyncio.run(coroutine)
if __name__ == "__main__":
fire.Fire(main)

View file

@ -13,18 +13,27 @@ from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.apis.scoring_functions import * # noqa: F403
ScoringResult = Dict[str, Any]
# mapping of metric to value
ScoringResultRow = Dict[str, Any]
@json_schema_type
class ScoringResult(BaseModel):
score_rows: List[ScoringResultRow]
# aggregated metrics to value
aggregated_results: Dict[str, Any]
@json_schema_type
class ScoreBatchResponse(BaseModel):
dataset_id: str
dataset_id: Optional[str] = None
results: Dict[str, ScoringResult]
@json_schema_type
class ScoreResponse(BaseModel):
# each key in the dict is a scoring function name
results: List[Dict[str, ScoringResult]]
results: Dict[str, ScoringResult]
class ScoringFunctionStore(Protocol):
@ -37,7 +46,10 @@ class Scoring(Protocol):
@webmethod(route="/scoring/score_batch")
async def score_batch(
self, dataset_id: str, scoring_functions: List[str]
self,
dataset_id: str,
scoring_functions: List[str],
save_results_dataset: bool = False,
) -> ScoreBatchResponse: ...
@webmethod(route="/scoring/score")

View file

@ -4,20 +4,10 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import (
Any,
Dict,
List,
Literal,
Optional,
Protocol,
runtime_checkable,
Union,
)
from typing import Any, Dict, List, Optional, Protocol, runtime_checkable
from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel, Field
from typing_extensions import Annotated
from llama_stack.apis.common.type_system import ParamType
@ -33,45 +23,37 @@ class Parameter(BaseModel):
# with standard metrics so they can be rolled up?
class LLMAsJudgeContext(BaseModel):
judge_model: str
prompt_template: Optional[str] = None
@json_schema_type
class CommonDef(BaseModel):
name: str
class ScoringFunctionDef(BaseModel):
identifier: str
description: Optional[str] = None
metadata: Dict[str, Any] = Field(
default_factory=dict,
description="Any additional metadata for this definition",
)
# Hack: same with memory_banks for union defs
provider_id: str = ""
@json_schema_type
class DeterministicFunctionDef(CommonDef):
type: Literal["deterministic"] = "deterministic"
parameters: List[Parameter] = Field(
description="List of parameters for the deterministic function",
default_factory=list,
)
return_type: ParamType = Field(
description="The return type of the deterministic function",
)
context: Optional[LLMAsJudgeContext] = None
# We can optionally add information here to support packaging of code, etc.
@json_schema_type
class LLMJudgeFunctionDef(CommonDef):
type: Literal["judge"] = "judge"
model: str = Field(
description="The LLM model to use for the judge function",
class ScoringFunctionDefWithProvider(ScoringFunctionDef):
    """A scoring function definition whose `provider_id` is required."""

    # Overrides the base-class field (which defaults to "") with a required one.
    provider_id: str = Field(
        description="ID of the provider which serves this scoring function",
    )
ScoringFunctionDef = Annotated[
Union[DeterministicFunctionDef, LLMJudgeFunctionDef], Field(discriminator="type")
]
ScoringFunctionDefWithProvider = ScoringFunctionDef
@runtime_checkable
class ScoringFunctions(Protocol):
@webmethod(route="/scoring_functions/list", method="GET")
@ -84,5 +66,5 @@ class ScoringFunctions(Protocol):
@webmethod(route="/scoring_functions/register", method="POST")
async def register_scoring_function(
self, function: ScoringFunctionDefWithProvider
self, function_def: ScoringFunctionDefWithProvider
) -> None: ...

View file

@ -97,7 +97,7 @@ if [ -n "$pip_dependencies" ]; then
fi
if [ -n "$special_pip_deps" ]; then
IFS='#' read -ra parts <<< "$special_pip_deps"
IFS='#' read -ra parts <<<"$special_pip_deps"
for part in "${parts[@]}"; do
add_to_docker "RUN pip install $part"
done
@ -127,7 +127,7 @@ if [ -n "$LLAMA_MODELS_DIR" ]; then
mounts="$mounts -v $(readlink -f $LLAMA_MODELS_DIR):$models_mount"
fi
if command -v selinuxenabled &> /dev/null && selinuxenabled; then
if command -v selinuxenabled &>/dev/null && selinuxenabled; then
# Disable SELinux labels -- we don't want to relabel the llama-stack source dir
DOCKER_OPTS="$DOCKER_OPTS --security-opt label=disable"
fi
@ -139,4 +139,4 @@ $DOCKER_BINARY build $DOCKER_OPTS -t $image_name -f "$TEMP_DIR/Dockerfile" "$REP
rm -rf $REPO_CONFIGS_DIR
set +x
echo "Success! You can run it with: $DOCKER_BINARY $DOCKER_OPTS run -p 5000:5000 $image_name"
echo "Success!"

View file

@ -15,10 +15,12 @@ from llama_stack.apis.models import * # noqa: F403
from llama_stack.apis.shields import * # noqa: F403
from llama_stack.apis.memory_banks import * # noqa: F403
from llama_stack.apis.datasets import * # noqa: F403
from llama_stack.apis.scoring_functions import * # noqa: F403
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.inference import Inference
from llama_stack.apis.memory import Memory
from llama_stack.apis.safety import Safety
from llama_stack.apis.scoring import Scoring
LLAMA_STACK_BUILD_CONFIG_VERSION = "2"
LLAMA_STACK_RUN_CONFIG_VERSION = "2"
@ -32,6 +34,7 @@ RoutableObject = Union[
ShieldDef,
MemoryBankDef,
DatasetDef,
ScoringFunctionDef,
]
RoutableObjectWithProvider = Union[
@ -39,6 +42,7 @@ RoutableObjectWithProvider = Union[
ShieldDefWithProvider,
MemoryBankDefWithProvider,
DatasetDefWithProvider,
ScoringFunctionDefWithProvider,
]
RoutedProtocol = Union[
@ -46,6 +50,7 @@ RoutedProtocol = Union[
Safety,
Memory,
DatasetIO,
Scoring,
]

View file

@ -39,6 +39,10 @@ def builtin_automatically_routed_apis() -> List[AutoRoutedApiInfo]:
routing_table_api=Api.datasets,
router_api=Api.datasetio,
),
AutoRoutedApiInfo(
routing_table_api=Api.scoring_functions,
router_api=Api.scoring,
),
]

View file

@ -20,6 +20,8 @@ from llama_stack.apis.memory import Memory
from llama_stack.apis.memory_banks import MemoryBanks
from llama_stack.apis.models import Models
from llama_stack.apis.safety import Safety
from llama_stack.apis.scoring import Scoring
from llama_stack.apis.scoring_functions import ScoringFunctions
from llama_stack.apis.shields import Shields
from llama_stack.apis.telemetry import Telemetry
from llama_stack.distribution.distribution import (
@ -42,6 +44,8 @@ def api_protocol_map() -> Dict[Api, Any]:
Api.telemetry: Telemetry,
Api.datasets: Datasets,
Api.datasetio: DatasetIO,
Api.scoring_functions: ScoringFunctions,
Api.scoring: Scoring,
}

View file

@ -11,6 +11,7 @@ from .routing_tables import (
DatasetsRoutingTable,
MemoryBanksRoutingTable,
ModelsRoutingTable,
ScoringFunctionsRoutingTable,
ShieldsRoutingTable,
)
@ -25,7 +26,9 @@ async def get_routing_table_impl(
"models": ModelsRoutingTable,
"shields": ShieldsRoutingTable,
"datasets": DatasetsRoutingTable,
"scoring_functions": ScoringFunctionsRoutingTable,
}
if api.value not in api_to_tables:
raise ValueError(f"API {api.value} not found in router map")
@ -35,13 +38,20 @@ async def get_routing_table_impl(
async def get_auto_router_impl(api: Api, routing_table: RoutingTable, _deps) -> Any:
from .routers import DatasetIORouter, InferenceRouter, MemoryRouter, SafetyRouter
from .routers import (
DatasetIORouter,
InferenceRouter,
MemoryRouter,
SafetyRouter,
ScoringRouter,
)
api_to_routers = {
"memory": MemoryRouter,
"inference": InferenceRouter,
"safety": SafetyRouter,
"datasetio": DatasetIORouter,
"scoring": ScoringRouter,
}
if api.value not in api_to_routers:
raise ValueError(f"API {api.value} not found in router map")

View file

@ -13,6 +13,7 @@ from llama_stack.apis.memory import * # noqa: F403
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.apis.safety import * # noqa: F403
from llama_stack.apis.datasetio import * # noqa: F403
from llama_stack.apis.scoring import * # noqa: F403
class MemoryRouter(Memory):
@ -198,3 +199,56 @@ class DatasetIORouter(DatasetIO):
page_token=page_token,
filter_condition=filter_condition,
)
class ScoringRouter(Scoring):
    """Routes each requested scoring function to the provider that serves it.

    Args:
        routing_table: Table used to look up the provider implementation for a
            scoring function identifier.
    """

    def __init__(
        self,
        routing_table: RoutingTable,
    ) -> None:
        self.routing_table = routing_table

    async def initialize(self) -> None:
        pass

    async def shutdown(self) -> None:
        pass

    async def score_batch(
        self,
        dataset_id: str,
        scoring_functions: List[str],
        save_results_dataset: bool = False,
    ) -> ScoreBatchResponse:
        """Score a dataset with each function and merge the per-function results.

        Raises:
            NotImplementedError: if `save_results_dataset` is true.
        """
        # Fail fast: reject the unsupported option before doing any scoring
        # work instead of discarding the results afterwards.
        if save_results_dataset:
            raise NotImplementedError("Save results dataset not implemented yet")

        res = {}
        for fn_identifier in scoring_functions:
            provider = self.routing_table.get_provider_impl(fn_identifier)
            score_response = await provider.score_batch(
                dataset_id=dataset_id,
                scoring_functions=[fn_identifier],
            )
            res.update(score_response.results)

        return ScoreBatchResponse(
            results=res,
        )

    async def score(
        self, input_rows: List[Dict[str, Any]], scoring_functions: List[str]
    ) -> ScoreResponse:
        """Score `input_rows` with each function and merge the per-function results."""
        res = {}
        # look up and map each scoring function to its provider impl
        for fn_identifier in scoring_functions:
            provider = self.routing_table.get_provider_impl(fn_identifier)
            score_response = await provider.score(
                input_rows=input_rows,
                scoring_functions=[fn_identifier],
            )
            res.update(score_response.results)

        return ScoreResponse(results=res)

View file

@ -30,6 +30,8 @@ async def register_object_with_provider(obj: RoutableObject, p: Any) -> None:
await p.register_memory_bank(obj)
elif api == Api.datasetio:
await p.register_dataset(obj)
elif api == Api.scoring:
await p.register_scoring_function(obj)
else:
raise ValueError(f"Unknown API {api} for registering object with provider")
@ -93,7 +95,15 @@ class CommonRoutingTableImpl(RoutingTable):
for d in datasets:
d.provider_id = pid
add_objects(datasets)
elif api == Api.scoring:
p.scoring_function_store = self
scoring_functions = await p.list_scoring_functions()
add_objects(
[
ScoringFunctionDefWithProvider(**s.dict(), provider_id=pid)
for s in scoring_functions
]
)
async def shutdown(self) -> None:
for p in self.impls_by_provider_id.values():
@ -109,6 +119,10 @@ class CommonRoutingTableImpl(RoutingTable):
return ("Safety", "shield")
elif isinstance(self, MemoryBanksRoutingTable):
return ("Memory", "memory_bank")
elif isinstance(self, DatasetsRoutingTable):
return ("DatasetIO", "dataset")
elif isinstance(self, ScoringFunctionsRoutingTable):
return ("Scoring", "scoring_function")
else:
raise ValueError("Unknown routing table type")
@ -218,7 +232,25 @@ class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
async def get_dataset(
self, dataset_identifier: str
) -> Optional[DatasetDefWithProvider]:
return self.get_object_by_identifier(identifier)
return self.get_object_by_identifier(dataset_identifier)
async def register_dataset(self, dataset_def: DatasetDefWithProvider) -> None:
await self.register_object(dataset_def)
class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, ScoringFunctions):
    """Routing table backing the scoring-function registry.

    This class implements the list/get/register operations of the
    `ScoringFunctions` API, so it derives from that protocol rather than from
    `Scoring` (the score/score_batch API).
    """

    async def list_scoring_functions(self) -> List[ScoringFunctionDefWithProvider]:
        """Return every registered scoring function, flattened across registry entries."""
        return [obj for objs in self.registry.values() for obj in objs]

    async def get_scoring_function(
        self, name: str
    ) -> Optional[ScoringFunctionDefWithProvider]:
        """Look up a scoring function by identifier via `get_object_by_identifier`."""
        return self.get_object_by_identifier(name)

    async def register_scoring_function(
        self, function_def: ScoringFunctionDefWithProvider
    ) -> None:
        """Register `function_def` through the shared `register_object` path."""
        await self.register_object(function_def)

View file

@ -337,7 +337,8 @@ def main(
import uvicorn
# FYI this does not do hot-reloads
listen_host = "::" if not disable_ipv6 else "0.0.0.0"
listen_host = ["::", "0.0.0.0"] if not disable_ipv6 else "0.0.0.0"
print(f"Listening on {listen_host}:{port}")
uvicorn.run(app, host=listen_host, port=port)

View file

@ -29,7 +29,7 @@ if [ $# -lt 3 ]; then
fi
build_name="$1"
docker_image="llamastack-$build_name"
docker_image="distribution-$build_name"
shift
yaml_config="$1"

View file

@ -116,7 +116,7 @@ class DatabricksInferenceAdapter(ModelRegistryHelper, Inference):
"model": self.map_to_provider_model(request.model),
"prompt": chat_completion_request_to_prompt(request, self.formatter),
"stream": request.stream,
**get_sampling_options(request),
**get_sampling_options(request.sampling_params),
}
async def embeddings(

View file

@ -116,7 +116,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference):
if prompt.startswith("<|begin_of_text|>"):
prompt = prompt[len("<|begin_of_text|>") :]
options = get_sampling_options(request)
options = get_sampling_options(request.sampling_params)
options.setdefault("max_tokens", 512)
if fmt := request.response_format:

View file

@ -110,7 +110,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
return await self._nonstream_completion(request)
def _get_params_for_completion(self, request: CompletionRequest) -> dict:
sampling_options = get_sampling_options(request)
sampling_options = get_sampling_options(request.sampling_params)
# This is needed since the Ollama API expects num_predict to be set
# for early truncation instead of max_tokens.
if sampling_options["max_tokens"] is not None:
@ -187,7 +187,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
return {
"model": OLLAMA_SUPPORTED_MODELS[request.model],
"prompt": chat_completion_request_to_prompt(request, self.formatter),
"options": get_sampling_options(request),
"options": get_sampling_options(request.sampling_params),
"raw": True,
"stream": request.stream,
}

View file

@ -24,9 +24,12 @@ from llama_stack.providers.utils.inference.openai_compat import (
OpenAICompatCompletionResponse,
process_chat_completion_response,
process_chat_completion_stream_response,
process_completion_response,
process_completion_stream_response,
)
from llama_stack.providers.utils.inference.prompt_adapter import (
chat_completion_request_to_model_input_info,
completion_request_to_prompt_model_input_info,
)
from .config import InferenceAPIImplConfig, InferenceEndpointImplConfig, TGIImplConfig
@ -75,7 +78,98 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> AsyncGenerator:
raise NotImplementedError()
request = CompletionRequest(
model=model,
content=content,
sampling_params=sampling_params,
response_format=response_format,
stream=stream,
logprobs=logprobs,
)
if stream:
return self._stream_completion(request)
else:
return await self._nonstream_completion(request)
def _get_max_new_tokens(self, sampling_params, input_tokens):
    """Return the `max_new_tokens` value to send for a request.

    The caller's `sampling_params.max_tokens` is used when truthy; otherwise
    the remaining budget (`self.max_tokens - input_tokens`) is used. Either way
    the result is capped at one less than the remaining budget.
    """
    remaining = self.max_tokens - input_tokens
    requested = sampling_params.max_tokens or remaining
    ceiling = remaining - 1
    return requested if requested < ceiling else ceiling
def _build_options(
    self,
    sampling_params: Optional[SamplingParams] = None,
    fmt: Optional[ResponseFormat] = None,
):
    """Build the extra generation options for a request.

    Args:
        sampling_params: Sampling parameters passed to `get_sampling_options`.
        fmt: Optional response format; only the JSON-schema type is supported.

    Returns:
        The sampling options without `max_tokens`, plus a `grammar` entry when
        a JSON-schema response format is requested.

    Raises:
        ValueError: for the grammar response format or any unknown format type.
    """
    options = get_sampling_options(sampling_params)
    # delete key "max_tokens" from options since it's not supported by the API
    options.pop("max_tokens", None)

    if not fmt:
        return options

    if fmt.type == ResponseFormatType.json_schema.value:
        options["grammar"] = {
            "type": "json",
            "value": fmt.schema,
        }
    elif fmt.type == ResponseFormatType.grammar.value:
        raise ValueError("Grammar response format not supported yet")
    else:
        raise ValueError(f"Unexpected response format: {fmt.type}")

    return options
def _get_params_for_completion(self, request: CompletionRequest) -> dict:
    """Assemble the keyword arguments for a text-generation call.

    The prompt and its input token count come from
    `completion_request_to_prompt_model_input_info`; the token count feeds the
    `max_new_tokens` computation.
    """
    prompt, input_tokens = completion_request_to_prompt_model_input_info(
        request, self.formatter
    )
    sampling = request.sampling_params
    token_budget = self._get_max_new_tokens(sampling, input_tokens)
    extra_options = self._build_options(sampling, request.response_format)

    return dict(
        prompt=prompt,
        stream=request.stream,
        details=True,
        max_new_tokens=token_budget,
        stop_sequences=["<|eom_id|>", "<|eot_id|>"],
        **extra_options,
    )
async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
    """Stream a completion, yielding chunks from `process_completion_stream_response`.

    Each event from the client's text-generation stream is wrapped in an
    `OpenAICompatCompletionResponse` before being handed to the shared
    stream processor.
    """
    generation_kwargs = self._get_params_for_completion(request)

    async def _generate_and_convert_to_openai_compat():
        token_stream = await self.client.text_generation(**generation_kwargs)
        async for event in token_stream:
            token = event.token
            # finish_reason is read only when the event carries details
            reason = event.details.finish_reason if event.details else None
            yield OpenAICompatCompletionResponse(
                choices=[
                    OpenAICompatCompletionChoice(
                        text=token.text, finish_reason=reason
                    )
                ],
            )

    converted_stream = _generate_and_convert_to_openai_compat()
    async for processed in process_completion_stream_response(
        converted_stream, self.formatter
    ):
        yield processed
async def _nonstream_completion(self, request: CompletionRequest) -> AsyncGenerator:
    """Run a single non-streaming completion and return the processed response.

    The generated text is rebuilt by concatenating the text of every token in
    the response details.
    """
    # NOTE(review): annotated as AsyncGenerator, but this returns the value of
    # process_completion_response directly — confirm the intended return type.
    generation_kwargs = self._get_params_for_completion(request)
    result = await self.client.text_generation(**generation_kwargs)

    details = result.details
    reason = details.finish_reason
    generated_text = "".join(token.text for token in details.tokens)

    wrapped = OpenAICompatCompletionResponse(
        choices=[
            OpenAICompatCompletionChoice(
                finish_reason=reason,
                text=generated_text,
            )
        ],
    )
    return process_completion_response(wrapped, self.formatter)
async def chat_completion(
self,
@ -146,29 +240,15 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
prompt, input_tokens = chat_completion_request_to_model_input_info(
request, self.formatter
)
max_new_tokens = min(
request.sampling_params.max_tokens or (self.max_tokens - input_tokens),
self.max_tokens - input_tokens - 1,
)
options = get_sampling_options(request)
if fmt := request.response_format:
if fmt.type == ResponseFormatType.json_schema.value:
options["grammar"] = {
"type": "json",
"value": fmt.schema,
}
elif fmt.type == ResponseFormatType.grammar.value:
raise ValueError("Grammar response format not supported yet")
else:
raise ValueError(f"Unexpected response format: {fmt.type}")
return dict(
prompt=prompt,
stream=request.stream,
details=True,
max_new_tokens=max_new_tokens,
max_new_tokens=self._get_max_new_tokens(
request.sampling_params, input_tokens
),
stop_sequences=["<|eom_id|>", "<|eot_id|>"],
**options,
**self._build_options(request.sampling_params, request.response_format),
)
async def embeddings(

View file

@ -131,7 +131,7 @@ class TogetherInferenceAdapter(
yield chunk
def _get_params(self, request: ChatCompletionRequest) -> dict:
options = get_sampling_options(request)
options = get_sampling_options(request.sampling_params)
if fmt := request.response_format:
if fmt.type == ResponseFormatType.json_schema.value:
options["response_format"] = {

View file

@ -143,7 +143,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
"model": VLLM_SUPPORTED_MODELS[request.model],
"prompt": chat_completion_request_to_prompt(request, self.formatter),
"stream": request.stream,
**get_sampling_options(request),
**get_sampling_options(request.sampling_params),
}
async def embeddings(

View file

@ -11,10 +11,9 @@ from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel, Field
from llama_stack.apis.datasets import DatasetDef
from llama_stack.apis.memory_banks import MemoryBankDef
from llama_stack.apis.models import ModelDef
from llama_stack.apis.scoring_functions import ScoringFunctionDef
from llama_stack.apis.shields import ShieldDef
@ -25,6 +24,7 @@ class Api(Enum):
agents = "agents"
memory = "memory"
datasetio = "datasetio"
scoring = "scoring"
telemetry = "telemetry"
@ -32,6 +32,7 @@ class Api(Enum):
shields = "shields"
memory_banks = "memory_banks"
datasets = "datasets"
scoring_functions = "scoring_functions"
# built-in API
inspect = "inspect"
@ -61,6 +62,14 @@ class DatasetsProtocolPrivate(Protocol):
async def register_datasets(self, dataset_def: DatasetDef) -> None: ...
class ScoringFunctionsProtocolPrivate(Protocol):
    """Provider-side protocol for listing and registering scoring functions."""

    async def list_scoring_functions(self) -> List[ScoringFunctionDef]:
        """Return the scoring function definitions known to the provider."""
        ...

    async def register_scoring_function(
        self, function_def: ScoringFunctionDef
    ) -> None:
        """Register `function_def` with the provider."""
        ...
@json_schema_type
class ProviderSpec(BaseModel):
api: Api

View file

@ -56,9 +56,20 @@ We're working on making LocalInference easier to set up. For now, you'll need t
## Preparing a model
1. Prepare a `.pte` file [following the executorch docs](https://github.com/pytorch/executorch/blob/main/examples/models/llama2/README.md#step-2-prepare-model)
1. Prepare a `.pte` file [following the executorch docs](https://github.com/pytorch/executorch/blob/main/examples/models/llama/README.md#step-2-prepare-model)
2. Bundle the `.pte` and `tokenizer.model` file into your app
We now support models quantized using SpinQuant and QAT+LoRA, which offer a significant performance boost (measured with the demo app on an iPhone 13 Pro):
| Llama 3.2 1B | Tokens / Second (total) | | Time-to-First-Token (sec) | |
| :---- | :---- | :---- | :---- | :---- |
| | Haiku | Paragraph | Haiku | Paragraph |
| BF16 | 2.2 | 2.5 | 2.3 | 1.9 |
| QAT+LoRA | 7.1 | 3.3 | 0.37 | 0.24 |
| SpinQuant | 10.1 | 5.2 | 0.2 | 0.2 |
## Using LocalInference
1. Instantiate LocalInference with a DispatchQueue. Optionally, pass it into your agents service:

View file

@ -169,7 +169,7 @@ class MetaReferenceAgentsImpl(Agents):
turn_ids: Optional[List[str]] = None,
) -> Session:
session = await self.persistence_store.get(f"session:{agent_id}:{session_id}")
session = Session(**json.loads(session))
session = Session(**json.loads(session), turns=[])
turns = []
if turn_ids:
for turn_id in turn_ids:

View file

@ -3,17 +3,20 @@
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import io
from typing import List, Optional
import pandas
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.apis.datasetio import * # noqa: F403
import base64
from abc import ABC, abstractmethod
from dataclasses import dataclass
from urllib.parse import unquote
from llama_stack.providers.datatypes import DatasetsProtocolPrivate
from llama_stack.providers.utils.memory.vector_store import parse_data_url
from .config import MetaReferenceDatasetIOConfig
@ -52,11 +55,20 @@ class PandasDataframeDataset(BaseDataset):
return len(self.df)
def __getitem__(self, idx):
    """Return one row as a dict, or a list of row dicts when `idx` is a slice.

    Raises:
        AssertionError: if the dataframe has not been loaded yet.
    """
    assert self.df is not None, "Dataset not loaded. Please call .load() first"

    selected = self.df.iloc[idx]
    if not isinstance(idx, slice):
        return selected.to_dict()
    return selected.to_dict(orient="records")
def _validate_dataset_schema(self, df) -> pandas.DataFrame:
    """Restrict `df` to the columns named in the dataset schema.

    Columns not listed in `self.dataset_def.dataset_schema` are dropped. The
    resulting column count must equal the number of schema entries.
    """
    schema = self.dataset_def.dataset_schema
    # note that we will drop any columns in dataset that are not in the schema
    projected = df[list(schema.keys())]
    # check all columns in dataset schema are present
    assert len(projected.columns) == len(schema)
    # TODO: type checking against column types in dataset schema
    return projected
def load(self) -> None:
if self.df is not None:
return
@ -87,7 +99,7 @@ class PandasDataframeDataset(BaseDataset):
else:
raise ValueError(f"Unsupported file type: {self.dataset_def.url}")
self.df = df
self.df = self._validate_dataset_schema(df)
class MetaReferenceDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
@ -123,7 +135,10 @@ class MetaReferenceDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
dataset_info = self.dataset_infos.get(dataset_id)
dataset_info.dataset_impl.load()
if page_token is None:
if page_token and not page_token.isnumeric():
raise ValueError("Invalid page_token")
if page_token is None or len(page_token) == 0:
next_page_token = 0
else:
next_page_token = int(page_token)

View file

@ -30,7 +30,6 @@ from llama_models.llama3.reference_impl.multimodal.model import (
CrossAttentionTransformer,
)
from llama_models.sku_list import resolve_model
from pydantic import BaseModel
from termcolor import cprint
@ -43,7 +42,12 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
chat_completion_request_to_messages,
)
from .config import MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig
from .config import (
Fp8QuantizationConfig,
Int4QuantizationConfig,
MetaReferenceInferenceConfig,
MetaReferenceQuantizedInferenceConfig,
)
def model_checkpoint_dir(model) -> str:
@ -131,18 +135,34 @@ class Llama:
), f"model_args vocab = {model_args.vocab_size} but tokenizer vocab = {tokenizer.n_words}"
if isinstance(config, MetaReferenceQuantizedInferenceConfig):
from .quantization.loader import convert_to_quantized_model
# load on CPU in bf16 so that fp8 conversion does not find an
# unexpected (fp32, e.g.) datatype
torch.set_default_tensor_type(torch.BFloat16Tensor)
if model_args.vision_chunk_size > 0:
model = CrossAttentionTransformer(model_args)
model.setup_cache(model_args.max_batch_size, torch.bfloat16)
else:
if isinstance(config.quantization, Fp8QuantizationConfig):
from .quantization.loader import convert_to_fp8_quantized_model
# load on CPU in bf16 so that fp8 conversion does not find an
# unexpected (fp32, e.g.) datatype
torch.set_default_tensor_type(torch.BFloat16Tensor)
if model_args.vision_chunk_size > 0:
model = CrossAttentionTransformer(model_args)
model.setup_cache(model_args.max_batch_size, torch.bfloat16)
else:
model = Transformer(model_args)
model.load_state_dict(state_dict, strict=False)
model = convert_to_fp8_quantized_model(model, config, ckpt_dir)
elif isinstance(config.quantization, Int4QuantizationConfig):
from .quantization.loader import convert_to_int4_quantized_model
assert (
config.quantization.scheme is not None
), "Please specify a quantization scheme."
model = Transformer(model_args)
model.load_state_dict(state_dict, strict=False)
model = convert_to_quantized_model(model, config, ckpt_dir)
model = convert_to_int4_quantized_model(model, model_args, config)
model.load_state_dict(state_dict, strict=True)
else:
raise NotImplementedError(
"Currently int4 and fp8 are the only supported quantization methods."
)
else:
if torch.cuda.is_bf16_supported():
torch.set_default_tensor_type(torch.cuda.BFloat16Tensor)

View file

@ -8,19 +8,25 @@
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
import os
from typing import Optional
from typing import Any, Dict, List, Optional
import torch
from fairscale.nn.model_parallel.layers import ColumnParallelLinear, RowParallelLinear
from fairscale.nn.model_parallel.mappings import reduce_from_model_parallel_region
from llama_models.datatypes import CheckpointQuantizationFormat
from llama_models.llama3.reference_impl.model import Transformer, TransformerBlock
from llama_models.datatypes import CheckpointQuantizationFormat
from llama_models.llama3.api.args import ModelArgs
from llama_models.llama3.reference_impl.model import Transformer, TransformerBlock
from llama_models.sku_list import resolve_model
from termcolor import cprint
from torch import Tensor
from torch import nn, Tensor
from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear
from llama_stack.apis.inference import QuantizationType
from llama_stack.apis.inference.inference import Int4QuantizationConfig
from llama_stack.providers.impls.meta_reference.inference.config import (
MetaReferenceQuantizedInferenceConfig,
@ -37,7 +43,7 @@ def swiglu_wrapper(
return reduce_from_model_parallel_region(out)
def convert_to_quantized_model(
def convert_to_fp8_quantized_model(
model: Transformer,
config: MetaReferenceQuantizedInferenceConfig,
checkpoint_dir: str,
@ -99,3 +105,241 @@ def convert_to_quantized_model(
if not isinstance(parameter, Fp8ScaledWeights):
parameter.data = parameter.to(device="cuda")
return model
class Int8DynActInt4WeightLinearLoRA(Int8DynActInt4WeightLinear):
    """
    Int8DynActInt4WeightLinear with an optional LoRA adaptor.

    Args:
        in_features: Number of input features.
        out_features: Number of output features.
        bias: Whether to use bias.
        device: Device to use.
        group_size: Group size for quantization.
        precision: Precision of quantization.
        scales_precision: Precision of scales.
        lora_rank: Rank of LoRA adaptor; no adaptor is built when None.
        lora_scale: Scale of LoRA adaptor; required when `lora_rank` is given.
    """

    def __init__(
        self,
        in_features: int,
        out_features: int,
        bias=False,
        device=None,
        # quantization parameters
        group_size: int = 256,
        precision: torch.dtype = torch.float32,
        scales_precision: torch.dtype = torch.float32,
        # LoRA parameters
        lora_rank: Optional[int] = None,
        lora_scale: Optional[float] = None,
    ) -> None:
        super().__init__(
            in_features,
            out_features,
            bias=bias,
            device=device,
            groupsize=group_size,
            precision=precision,
            scales_precision=scales_precision,
        )
        if lora_rank is None:
            self.adaptor = None
            self.lora_scale = None
        else:
            assert lora_scale is not None, "Please specify lora scale for LoRA."
            # Low-rank adaptation. See paper for more details: https://arxiv.org/abs/2106.09685
            # The submodule names "A" and "B" become part of the state-dict keys.
            low_rank = nn.Sequential()
            low_rank.add_module("A", nn.Linear(in_features, lora_rank, bias=False))
            low_rank.add_module("B", nn.Linear(lora_rank, out_features, bias=False))
            self.adaptor = low_rank
            self.lora_scale = lora_scale
        self._register_load_state_dict_pre_hook(self.load_hook)

    def load_hook(
        self,
        state_dict: Dict[str, Any],
        prefix: str,
        local_metadata: Dict[str, Any],
        strict: bool,
        missing_keys: List[str],
        unexpected_keys: List[str],
        error_msgs: List[str],
    ) -> None:
        """A hook to load the quantized weights from the state dict."""
        zeros_key = prefix + "zeros"
        if zeros_key in state_dict:
            return
        # Zero-point may not be saved in the state dict. In this case, we assume it's zero.
        scales_key = prefix + "scales"
        assert scales_key in state_dict
        state_dict[zeros_key] = torch.zeros_like(state_dict[scales_key])

    def forward(self, input_: torch.Tensor) -> torch.Tensor:
        """Return the quantized linear output, plus the scaled adaptor output if present."""
        module_out = super().forward(input_)
        if self.adaptor is None:
            return module_out
        adaptor_out = self.adaptor(input_) * self.lora_scale
        return module_out + adaptor_out
class Int8WeightEmbedding(torch.nn.Embedding):
    """An embedding layer to load int8 weights.

    Args:
        num_embeddings: Number of embeddings.
        embedding_dim: Embedding dimension.
        padding_idx: Padding index, or None for no padding entry (the value is
            forwarded unchanged to `torch.nn.Embedding`).
        device: Device passed through to `torch.nn.Embedding`.
    """

    def __init__(
        self,
        num_embeddings: int,
        embedding_dim: int,
        padding_idx: Optional[int],
        device=None,
    ) -> None:
        super().__init__(num_embeddings, embedding_dim, padding_idx, device=device)

        self._register_load_state_dict_pre_hook(self.load_hook)

    def load_hook(
        self,
        state_dict: Dict[str, Any],
        prefix: str,
        local_metadata: Dict[str, Any],
        strict: bool,
        missing_keys: List[str],
        unexpected_keys: List[str],
        error_msgs: List[str],
    ) -> None:
        """A hook to load the quantized embedding weight and scales from the state dict.

        Replaces `<prefix>weight` with `weight * scales` and removes
        `<prefix>scales` from `state_dict`, mutating it in place.
        """
        weights = state_dict.pop(prefix + "weight")
        scales = state_dict.pop(prefix + "scales")
        state_dict[prefix + "weight"] = weights * scales
class Int8WeightLinear(torch.nn.Linear):
    """A linear layer to load int8 weights.

    Args:
        in_features: Number of input features.
        out_features: Number of output features.
        bias: Whether to use bias.
        device: Device passed through to `torch.nn.Linear`.
    """

    def __init__(
        self, in_features: int, out_features: int, bias: bool = True, device=None
    ) -> None:
        super().__init__(in_features, out_features, bias, device=device)

        self._register_load_state_dict_pre_hook(self.load_hook)

    def load_hook(
        self,
        state_dict: Dict[str, Any],
        prefix: str,
        local_metadata: Dict[str, Any],
        strict: bool,
        missing_keys: List[str],
        unexpected_keys: List[str],
        error_msgs: List[str],
    ) -> None:
        """A hook to load the quantized linear weight and scales from the state dict.

        The `scales` entry is consumed and the `weight` entry is replaced by
        the product of the stored weight and scales.
        """
        weight_key = prefix + "weight"
        quantized = state_dict.pop(weight_key)
        scale_factors = state_dict.pop(prefix + "scales")
        state_dict[weight_key] = quantized * scale_factors
def _prepare_model_int4_weight_int8_dynamic_activation(
    model: torch.nn.Module,
    group_size: int,
    lora_rank: Optional[int],
    lora_scale: Optional[float],
):
    """Prepare the model for int4 weight and int8 dynamic activation quantization.

    Walks the direct children of ``model`` and swaps them in place:

    - a child named ``"output"`` becomes an ``Int8WeightLinear``;
    - a child named ``"tok_embeddings"`` becomes an ``Int8WeightEmbedding``;
    - any other ``ColumnParallelLinear`` / ``RowParallelLinear`` / ``nn.Linear``
      becomes a bias-free ``Int8DynActInt4WeightLinearLoRA``;
    - every other child is processed recursively.

    Note that the weights of embedding and output layers are quantized to int8.

    Args:
        model: Module whose children are replaced in place.
        group_size: Passed to ``Int8DynActInt4WeightLinearLoRA``.
        lora_rank: Passed to ``Int8DynActInt4WeightLinearLoRA``; may be ``None``.
        lora_scale: Passed to ``Int8DynActInt4WeightLinearLoRA``; may be ``None``.

    Returns:
        The same ``model`` object, after modification.
    """
    device = None
    for module_name, module in model.named_children():
        if module_name == "output":
            quantized_module = Int8WeightLinear(
                in_features=module.in_features,
                out_features=module.out_features,
                # The constructor expects a bool flag. Passing the bias tensor
                # itself makes its truth test fail for multi-element tensors.
                # NOTE(review): assumes `module.bias` is a tensor or None.
                bias=module.bias is not None,
                device=device,
            )
            del module
            setattr(model, module_name, quantized_module)
        elif module_name == "tok_embeddings":
            quantized_module = Int8WeightEmbedding(
                num_embeddings=module.num_embeddings,
                embedding_dim=module.embedding_dim,
                padding_idx=module.padding_idx,
                device=device,
            )
            del module
            setattr(model, module_name, quantized_module)
        elif isinstance(module, (ColumnParallelLinear, RowParallelLinear, nn.Linear)):
            quantized_module = Int8DynActInt4WeightLinearLoRA(
                in_features=module.in_features,
                out_features=module.out_features,
                bias=False,
                group_size=group_size,
                lora_rank=lora_rank,
                lora_scale=lora_scale,
                device=device,
            )
            del module
            setattr(model, module_name, quantized_module)
        else:
            # Not a leaf we replace: descend into its children.
            _prepare_model_int4_weight_int8_dynamic_activation(
                module, group_size, lora_rank, lora_scale
            )
    return model
def convert_to_int4_quantized_model(
    model: Transformer,
    model_args: ModelArgs,
    config: MetaReferenceQuantizedInferenceConfig,
) -> Transformer:
    """Convert the model to int4 quantized model.

    Validates the quantization settings, swaps the model's layers for their
    quantized counterparts and moves the result to CUDA when available,
    otherwise to CPU.

    Raises:
        ValueError: if the config is not an int4 quantization config, or if
            ``quantization_args`` or its ``group_size`` is ``None``.
        NotImplementedError: if the scheme is not
            ``int4_weight_int8_dynamic_activation``.
    """
    quant_config = config.quantization
    is_int4 = (
        isinstance(quant_config, Int4QuantizationConfig)
        and quant_config.type == QuantizationType.int4.value
    )
    if not is_int4:
        raise ValueError("Only int4 quantization is supported")

    if quant_config.scheme != "int4_weight_int8_dynamic_activation":
        raise NotImplementedError(
            "Only int4 quantization with 'int4_weight_int8_dynamic_activation' scheme is supported."
        )

    quantization_args = model_args.quantization_args
    if quantization_args is None:
        raise ValueError("'quantization_args' cannot be None. Please specify it.")

    group_size = quantization_args.group_size
    if group_size is None:
        raise ValueError(
            "'group_size' cannot be None in 'quantization_args'. Please specify it."
        )

    # Certain quantized models (e.g., SpinQuant) may not have LoRA.
    lora_args = model_args.lora_args
    lora_rank = None if lora_args is None else lora_args.rank
    lora_scale = None if lora_args is None else lora_args.scale

    _prepare_model_int4_weight_int8_dynamic_activation(
        model, group_size, lora_rank, lora_scale
    )

    target = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    return model.to(target)

View file

@ -47,7 +47,9 @@ class FaissIndex(EmbeddingIndex):
self.index.add(np.array(embeddings).astype(np.float32))
async def query(self, embedding: NDArray, k: int) -> QueryDocumentsResponse:
async def query(
self, embedding: NDArray, k: int, score_threshold: float
) -> QueryDocumentsResponse:
distances, indices = self.index.search(
embedding.reshape(1, -1).astype(np.float32), k
)

View file

@ -0,0 +1,21 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Dict
from llama_stack.distribution.datatypes import Api, ProviderSpec
from .config import MetaReferenceScoringConfig
async def get_provider_impl(
    config: MetaReferenceScoringConfig,
    deps: Dict[Api, ProviderSpec],
):
    """Build and initialize the meta-reference scoring implementation.

    ``deps`` must contain entries for ``Api.datasetio`` and ``Api.datasets``;
    both are handed to the implementation's constructor.
    """
    # Imported here rather than at module import time.
    from .scoring import MetaReferenceScoringImpl

    datasetio_api = deps[Api.datasetio]
    datasets_api = deps[Api.datasets]
    scoring_impl = MetaReferenceScoringImpl(config, datasetio_api, datasets_api)
    await scoring_impl.initialize()
    return scoring_impl

View file

@ -0,0 +1,9 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.scoring import * # noqa: F401, F403
# Configuration model for the meta-reference scoring provider; it currently declares no fields.
class MetaReferenceScoringConfig(BaseModel): ...

View file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -0,0 +1,37 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from abc import ABC, abstractmethod
from typing import Any, Dict, List
from llama_stack.apis.scoring_functions import * # noqa: F401, F403
from llama_stack.apis.scoring import * # noqa: F401, F403
class BaseScorer(ABC):
    """
    Abstract base for meta-reference scorers.

    Subclasses provide:
    - score_row(self, input_row): score a single row
    - aggregate(self, scoring_results): summarize a list of row scores
    and set the class attribute ``scoring_function_def``.
    """

    scoring_function_def: ScoringFunctionDef

    def __init__(self, *args, **kwargs) -> None:
        super().__init__(*args, **kwargs)

    def __str__(self) -> str:
        # The scorer is identified by its class name.
        return self.__class__.__name__

    @abstractmethod
    def score_row(self, input_row: Dict[str, Any]) -> ScoringResultRow:
        """Score one input row."""
        raise NotImplementedError()

    @abstractmethod
    def aggregate(self, scoring_results: List[ScoringResultRow]) -> Dict[str, Any]:
        """Combine per-row results into aggregate values."""
        raise NotImplementedError()

    def score(self, input_rows: List[Dict[str, Any]]) -> List[ScoringResultRow]:
        """Apply ``score_row`` to each row, in order, and return the results."""
        row_results = []
        for row in input_rows:
            row_results.append(self.score_row(row))
        return row_results

View file

@ -0,0 +1,49 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.providers.impls.meta_reference.scoring.scorer.base_scorer import (
BaseScorer,
)
from llama_stack.apis.scoring_functions import * # noqa: F401, F403
from llama_stack.apis.scoring import * # noqa: F401, F403
from llama_stack.apis.common.type_system import * # noqa: F403
class EqualityScorer(BaseScorer):
    """
    Scorer that gives a row 1.0 when its ``expected_answer`` equals its
    ``generated_answer`` and 0.0 otherwise.
    """

    scoring_function_def = ScoringFunctionDef(
        identifier="equality",
        description="Returns 1.0 if the input is equal to the target, 0.0 otherwise.",
        parameters=[],
        return_type=NumberType(),
    )

    def score_row(self, input_row: Dict[str, Any]) -> ScoringResultRow:
        """Return ``{"score": 1.0}`` on an exact match, ``{"score": 0.0}`` otherwise."""
        assert "expected_answer" in input_row, "Expected answer not found in input row."
        assert (
            "generated_answer" in input_row
        ), "Generated answer not found in input row."

        is_match = input_row["expected_answer"] == input_row["generated_answer"]
        return {"score": 1.0 if is_match else 0.0}

    def aggregate(self, scoring_results: List[ScoringResultRow]) -> Dict[str, Any]:
        """Return accuracy, the summed score and the row count for ``scoring_results``."""
        num_total = len(scoring_results)
        assert num_total > 0, "Empty scoring results provided."

        num_correct = sum(row_result["score"] for row_result in scoring_results)
        return {
            "accuracy": num_correct / num_total,
            "num_correct": num_correct,
            "num_total": num_total,
        }

View file

@ -0,0 +1,109 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import List
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_stack.apis.scoring import * # noqa: F403
from llama_stack.apis.scoring_functions import * # noqa: F403
from llama_stack.apis.common.type_system import * # noqa: F403
from llama_stack.apis.datasetio import * # noqa: F403
from llama_stack.apis.datasets import * # noqa: F403
from llama_stack.providers.datatypes import ScoringFunctionsProtocolPrivate
from llama_stack.providers.impls.meta_reference.scoring.scorer.equality_scorer import (
EqualityScorer,
)
from .config import MetaReferenceScoringConfig
# Scorer classes this provider exposes.
SUPPORTED_SCORERS = [
EqualityScorer,
]
# Lookup from each scorer's scoring-function identifier to its class.
SCORER_REGISTRY = {x.scoring_function_def.identifier: x for x in SUPPORTED_SCORERS}
class MetaReferenceScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
    """Scoring provider backed by the scorers in ``SUPPORTED_SCORERS``."""

    def __init__(
        self,
        config: MetaReferenceScoringConfig,
        datasetio_api: DatasetIO,
        datasets_api: Datasets,
    ) -> None:
        self.config = config
        self.datasetio_api = datasetio_api
        self.datasets_api = datasets_api

    async def initialize(self) -> None: ...

    async def shutdown(self) -> None: ...

    async def list_scoring_functions(self) -> List[ScoringFunctionDef]:
        """Return the scoring function definition of every supported scorer."""
        return [scorer_cls.scoring_function_def for scorer_cls in SUPPORTED_SCORERS]

    async def register_scoring_function(self, function_def: ScoringFunctionDef) -> None:
        """Always raises: scoring functions cannot be registered at runtime."""
        raise NotImplementedError(
            "Dynamically registering scoring functions is not supported"
        )

    async def validate_scoring_input_dataset_schema(self, dataset_id: str) -> None:
        """Raise ``ValueError`` unless the dataset schema has the required string columns."""
        dataset_def = await self.datasets_api.get_dataset(dataset_identifier=dataset_id)
        schema = dataset_def.dataset_schema
        if not schema or len(schema) == 0:
            raise ValueError(
                f"Dataset {dataset_id} does not have a schema defined. Please define a schema for the dataset."
            )

        for required_column in ("generated_answer", "expected_answer", "input_query"):
            if required_column not in schema:
                raise ValueError(
                    f"Dataset {dataset_id} does not have a '{required_column}' column."
                )
            if schema[required_column].type != "string":
                raise ValueError(
                    f"Dataset {dataset_id} does not have a '{required_column}' column of type 'string'."
                )

    async def score_batch(
        self,
        dataset_id: str,
        scoring_functions: List[str],
        save_results_dataset: bool = False,
    ) -> ScoreBatchResponse:
        """Validate the dataset, fetch its rows and score them.

        Raises ``NotImplementedError`` after scoring when
        ``save_results_dataset`` is true.
        """
        await self.validate_scoring_input_dataset_schema(dataset_id=dataset_id)

        # rows_in_page=-1 is what this call uses to request the rows in one page.
        page = await self.datasetio_api.get_rows_paginated(
            dataset_id=dataset_id,
            rows_in_page=-1,
        )
        score_response = await self.score(
            input_rows=page.rows, scoring_functions=scoring_functions
        )

        if save_results_dataset:
            # TODO: persist and register dataset on to server for reading
            # self.datasets_api.register_dataset()
            raise NotImplementedError("Save results dataset not implemented yet")

        return ScoreBatchResponse(
            results=score_response.results,
        )

    async def score(
        self, input_rows: List[Dict[str, Any]], scoring_functions: List[str]
    ) -> ScoreResponse:
        """Run each requested scorer over ``input_rows``, keyed by function id.

        Raises ``ValueError`` for an id that is not in ``SCORER_REGISTRY``.
        """
        results = {}
        for fn_id in scoring_functions:
            scorer_cls = SCORER_REGISTRY.get(fn_id)
            if scorer_cls is None:
                raise ValueError(f"Scoring function {fn_id} is not supported.")

            scorer = scorer_cls()
            row_scores = scorer.score(input_rows)
            results[fn_id] = ScoringResult(
                score_rows=row_scores,
                aggregated_results=scorer.aggregate(row_scores),
            )

        return ScoreResponse(
            results=results,
        )

View file

@ -15,13 +15,24 @@ class VLLMConfig(BaseModel):
"""Configuration for the vLLM inference provider."""
model: str = Field(
default="Llama3.1-8B-Instruct",
default="Llama3.2-3B-Instruct",
description="Model descriptor from `llama model list`",
)
tensor_parallel_size: int = Field(
default=1,
description="Number of tensor parallel replicas (number of GPUs to use).",
)
max_tokens: int = Field(
default=4096,
description="Maximum number of tokens to generate.",
)
enforce_eager: bool = Field(
default=False,
description="Whether to use eager mode for inference (otherwise cuda graphs are used).",
)
gpu_memory_utilization: float = Field(
default=0.3,
)
@field_validator("model")
@classmethod

View file

@ -7,11 +7,12 @@
import logging
import os
import uuid
from typing import Any, AsyncGenerator
from typing import AsyncGenerator, Optional
from llama_models.llama3.api.chat_format import ChatFormat
from llama_models.llama3.api.datatypes import * # noqa: F403
from llama_models.llama3.api.tokenizer import Tokenizer
from llama_models.sku_list import resolve_model
from vllm.engine.arg_utils import AsyncEngineArgs
from vllm.engine.async_llm_engine import AsyncLLMEngine
@ -19,7 +20,7 @@ from vllm.sampling_params import SamplingParams as VLLMSamplingParams
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
from llama_stack.providers.datatypes import ModelDef, ModelsProtocolPrivate
from llama_stack.providers.utils.inference.openai_compat import (
OpenAICompatCompletionChoice,
OpenAICompatCompletionResponse,
@ -40,74 +41,15 @@ def _random_uuid() -> str:
return str(uuid.uuid4().hex)
def _vllm_sampling_params(sampling_params: Any) -> VLLMSamplingParams:
"""Convert sampling params to vLLM sampling params."""
if sampling_params is None:
return VLLMSamplingParams()
# TODO convert what I saw in my first test ... but surely there's more to do here
kwargs = {
"temperature": sampling_params.temperature,
}
if sampling_params.top_k >= 1:
kwargs["top_k"] = sampling_params.top_k
if sampling_params.top_p:
kwargs["top_p"] = sampling_params.top_p
if sampling_params.max_tokens >= 1:
kwargs["max_tokens"] = sampling_params.max_tokens
if sampling_params.repetition_penalty > 0:
kwargs["repetition_penalty"] = sampling_params.repetition_penalty
return VLLMSamplingParams(**kwargs)
class VLLMInferenceImpl(ModelRegistryHelper, Inference):
class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
"""Inference implementation for vLLM."""
HF_MODEL_MAPPINGS = {
# TODO: seems like we should be able to build this table dynamically ...
"Llama3.1-8B": "meta-llama/Llama-3.1-8B",
"Llama3.1-70B": "meta-llama/Llama-3.1-70B",
"Llama3.1-405B:bf16-mp8": "meta-llama/Llama-3.1-405B",
"Llama3.1-405B": "meta-llama/Llama-3.1-405B-FP8",
"Llama3.1-405B:bf16-mp16": "meta-llama/Llama-3.1-405B",
"Llama3.1-8B-Instruct": "meta-llama/Llama-3.1-8B-Instruct",
"Llama3.1-70B-Instruct": "meta-llama/Llama-3.1-70B-Instruct",
"Llama3.1-405B-Instruct:bf16-mp8": "meta-llama/Llama-3.1-405B-Instruct",
"Llama3.1-405B-Instruct": "meta-llama/Llama-3.1-405B-Instruct-FP8",
"Llama3.1-405B-Instruct:bf16-mp16": "meta-llama/Llama-3.1-405B-Instruct",
"Llama3.2-1B": "meta-llama/Llama-3.2-1B",
"Llama3.2-3B": "meta-llama/Llama-3.2-3B",
"Llama3.2-11B-Vision": "meta-llama/Llama-3.2-11B-Vision",
"Llama3.2-90B-Vision": "meta-llama/Llama-3.2-90B-Vision",
"Llama3.2-1B-Instruct": "meta-llama/Llama-3.2-1B-Instruct",
"Llama3.2-3B-Instruct": "meta-llama/Llama-3.2-3B-Instruct",
"Llama3.2-11B-Vision-Instruct": "meta-llama/Llama-3.2-11B-Vision-Instruct",
"Llama3.2-90B-Vision-Instruct": "meta-llama/Llama-3.2-90B-Vision-Instruct",
"Llama-Guard-3-11B-Vision": "meta-llama/Llama-Guard-3-11B-Vision",
"Llama-Guard-3-1B:int4-mp1": "meta-llama/Llama-Guard-3-1B-INT4",
"Llama-Guard-3-1B": "meta-llama/Llama-Guard-3-1B",
"Llama-Guard-3-8B": "meta-llama/Llama-Guard-3-8B",
"Llama-Guard-3-8B:int8-mp1": "meta-llama/Llama-Guard-3-8B-INT8",
"Prompt-Guard-86M": "meta-llama/Prompt-Guard-86M",
"Llama-Guard-2-8B": "meta-llama/Llama-Guard-2-8B",
}
def __init__(self, config: VLLMConfig):
Inference.__init__(self)
ModelRegistryHelper.__init__(
self,
stack_to_provider_models_map=self.HF_MODEL_MAPPINGS,
)
self.config = config
self.engine = None
tokenizer = Tokenizer.get_instance()
self.formatter = ChatFormat(tokenizer)
self.formatter = ChatFormat(Tokenizer.get_instance())
async def initialize(self):
"""Initialize the vLLM inference adapter."""
log.info("Initializing vLLM inference adapter")
# Disable usage stats reporting. This would be a surprising thing for most
@ -116,15 +58,22 @@ class VLLMInferenceImpl(ModelRegistryHelper, Inference):
if "VLLM_NO_USAGE_STATS" not in os.environ:
os.environ["VLLM_NO_USAGE_STATS"] = "1"
hf_model = self.HF_MODEL_MAPPINGS.get(self.config.model)
model = resolve_model(self.config.model)
if model is None:
raise ValueError(f"Unknown model {self.config.model}")
if model.huggingface_repo is None:
raise ValueError(f"Model {self.config.model} needs a huggingface repo")
# TODO -- there are a ton of options supported here ...
engine_args = AsyncEngineArgs()
engine_args.model = hf_model
# We will need a new config item for this in the future if model support is more broad
# than it is today (llama only)
engine_args.tokenizer = hf_model
engine_args.tensor_parallel_size = self.config.tensor_parallel_size
engine_args = AsyncEngineArgs(
model=model.huggingface_repo,
tokenizer=model.huggingface_repo,
tensor_parallel_size=self.config.tensor_parallel_size,
enforce_eager=self.config.enforce_eager,
gpu_memory_utilization=self.config.gpu_memory_utilization,
guided_decoding_backend="lm-format-enforcer",
)
self.engine = AsyncLLMEngine.from_engine_args(engine_args)
@ -134,13 +83,47 @@ class VLLMInferenceImpl(ModelRegistryHelper, Inference):
if self.engine:
self.engine.shutdown_background_loop()
async def register_model(self, model: ModelDef) -> None:
    """Always raises ``ValueError``: this provider does not accept new models."""
    message = "You cannot dynamically add a model to a running vllm instance"
    raise ValueError(message)
async def list_models(self) -> List[ModelDef]:
    """Return a single-entry list describing the configured model."""
    configured = self.config.model
    only_model = ModelDef(
        identifier=configured,
        llama_model=configured,
    )
    return [only_model]
def _sampling_params(self, sampling_params: SamplingParams) -> VLLMSamplingParams:
    """Translate request sampling params into vLLM sampling params.

    ``max_tokens`` defaults to ``self.config.max_tokens`` and is replaced by
    the request's value when that value is truthy. ``top_k`` and ``top_p``
    are forwarded only when truthy, ``repetition_penalty`` only when > 0.
    """
    default_max_tokens = self.config.max_tokens
    if sampling_params is None:
        return VLLMSamplingParams(max_tokens=default_max_tokens)

    # TODO convert what I saw in my first test ... but surely there's more to do here
    kwargs = {
        "temperature": sampling_params.temperature,
        "max_tokens": default_max_tokens,
    }

    top_k = sampling_params.top_k
    if top_k:
        kwargs["top_k"] = top_k

    top_p = sampling_params.top_p
    if top_p:
        kwargs["top_p"] = top_p

    requested_max_tokens = sampling_params.max_tokens
    if requested_max_tokens:
        kwargs["max_tokens"] = requested_max_tokens

    repetition_penalty = sampling_params.repetition_penalty
    if repetition_penalty > 0:
        kwargs["repetition_penalty"] = repetition_penalty

    return VLLMSamplingParams(**kwargs)
async def completion(
self,
model: str,
content: InterleavedTextMedia,
sampling_params: Any | None = ...,
stream: bool | None = False,
logprobs: LogProbConfig | None = None,
sampling_params: Optional[SamplingParams] = SamplingParams(),
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> CompletionResponse | CompletionResponseStreamChunk:
log.info("vLLM completion")
messages = [UserMessage(content=content)]
@ -155,13 +138,14 @@ class VLLMInferenceImpl(ModelRegistryHelper, Inference):
async def chat_completion(
self,
model: str,
messages: list[Message],
sampling_params: Any | None = ...,
tools: list[ToolDefinition] | None = ...,
tool_choice: ToolChoice | None = ...,
tool_prompt_format: ToolPromptFormat | None = ...,
stream: bool | None = False,
logprobs: LogProbConfig | None = None,
messages: List[Message],
sampling_params: Optional[SamplingParams] = SamplingParams(),
tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> ChatCompletionResponse | ChatCompletionResponseStreamChunk:
log.info("vLLM chat completion")
@ -182,7 +166,7 @@ class VLLMInferenceImpl(ModelRegistryHelper, Inference):
request_id = _random_uuid()
prompt = chat_completion_request_to_prompt(request, self.formatter)
vllm_sampling_params = _vllm_sampling_params(request.sampling_params)
vllm_sampling_params = self._sampling_params(request.sampling_params)
results_generator = self.engine.generate(
prompt, vllm_sampling_params, request_id
)
@ -213,14 +197,19 @@ class VLLMInferenceImpl(ModelRegistryHelper, Inference):
self, request: ChatCompletionRequest, results_generator: AsyncGenerator
) -> AsyncGenerator:
async def _generate_and_convert_to_openai_compat():
cur = []
async for chunk in results_generator:
if not chunk.outputs:
log.warning("Empty chunk received")
continue
text = "".join([output.text for output in chunk.outputs])
output = chunk.outputs[-1]
new_tokens = output.token_ids[len(cur) :]
text = self.formatter.tokenizer.decode(new_tokens)
cur.extend(new_tokens)
choice = OpenAICompatCompletionChoice(
finish_reason=chunk.outputs[-1].stop_reason,
finish_reason=output.finish_reason,
text=text,
)
yield OpenAICompatCompletionResponse(

View file

@ -36,7 +36,8 @@ def available_providers() -> List[ProviderSpec]:
pip_packages=(
META_REFERENCE_DEPS
+ [
"fbgemm-gpu==0.8.0",
"fbgemm-gpu",
"torchao==0.5.0",
]
),
module="llama_stack.providers.impls.meta_reference.inference",

View file

@ -0,0 +1,25 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import List
from llama_stack.distribution.datatypes import * # noqa: F403
def available_providers() -> List[ProviderSpec]:
    """Return the provider specs registered for the scoring API."""
    meta_reference_spec = InlineProviderSpec(
        api=Api.scoring,
        provider_type="meta-reference",
        pip_packages=[],
        module="llama_stack.providers.impls.meta_reference.scoring",
        config_class="llama_stack.providers.impls.meta_reference.scoring.MetaReferenceScoringConfig",
        api_dependencies=[
            Api.datasetio,
            Api.datasets,
        ],
    )
    return [meta_reference_spec]

View file

@ -0,0 +1,6 @@
input_query,generated_answer,expected_answer
What is the capital of France?,London,Paris
Who is the CEO of Meta?,Mark Zuckerberg,Mark Zuckerberg
What is the largest planet in our solar system?,Jupiter,Jupiter
What is the smallest country in the world?,China,Vatican City
What is the currency of Japan?,Yen,Yen
1 input_query generated_answer expected_answer
2 What is the capital of France? London Paris
3 Who is the CEO of Meta? Mark Zuckerberg Mark Zuckerberg
4 What is the largest planet in our solar system? Jupiter Jupiter
5 What is the smallest country in the world? China Vatican City
6 What is the currency of Japan? Yen Yen

View file

@ -8,8 +8,13 @@ import os
import pytest
import pytest_asyncio
from llama_stack.apis.common.type_system import * # noqa: F403
from llama_stack.apis.datasetio import * # noqa: F403
from llama_stack.distribution.datatypes import * # noqa: F403
import base64
import mimetypes
from pathlib import Path
from llama_stack.providers.tests.resolver import resolve_impls_for_test
# How to run this test:
@ -41,14 +46,35 @@ async def datasetio_settings():
}
def data_url_from_file(file_path: str) -> str:
    """Encode a local file as a base64 ``data:`` URL.

    The MIME type is guessed from the file name. When it cannot be guessed,
    ``application/octet-stream`` is used so the URL never contains the
    literal text ``None`` as its media type.

    Args:
        file_path: Path of the file to encode.

    Returns:
        A string of the form ``data:<mime>;base64,<payload>``.

    Raises:
        FileNotFoundError: if ``file_path`` does not exist.
    """
    if not os.path.exists(file_path):
        raise FileNotFoundError(f"File not found: {file_path}")

    with open(file_path, "rb") as file:
        file_content = file.read()

    mime_type, _ = mimetypes.guess_type(file_path)
    if mime_type is None:
        # guess_type returns None for unknown extensions.
        mime_type = "application/octet-stream"

    base64_content = base64.b64encode(file_content).decode("utf-8")
    return f"data:{mime_type};base64,{base64_content}"
async def register_dataset(datasets_impl: Datasets):
test_file = Path(os.path.abspath(__file__)).parent / "test_dataset.csv"
test_url = data_url_from_file(str(test_file))
dataset = DatasetDefWithProvider(
identifier="test_dataset",
provider_id=os.environ["PROVIDER_ID"],
url=URL(
uri="https://openaipublic.blob.core.windows.net/simple-evals/mmlu.csv",
uri=test_url,
),
columns_schema={},
dataset_schema={
"generated_answer": StringType(),
"expected_answer": StringType(),
"input_query": StringType(),
},
)
await datasets_impl.register_dataset(dataset)
@ -100,10 +126,10 @@ async def test_get_rows_paginated(datasetio_settings):
# iterate over all rows
response = await datasetio_impl.get_rows_paginated(
dataset_id="test_dataset",
rows_in_page=10,
rows_in_page=2,
page_token=response.next_page_token,
)
assert isinstance(response.rows, list)
assert len(response.rows) == 10
assert response.next_page_token == "13"
assert len(response.rows) == 2
assert response.next_page_token == "5"

View file

@ -137,6 +137,7 @@ async def test_completion(inference_settings):
if provider.__provider_spec__.provider_type not in (
"meta-reference",
"remote::ollama",
"remote::tgi",
):
pytest.skip("Other inference providers don't support completion() yet")
@ -171,25 +172,43 @@ async def test_completion(inference_settings):
@pytest.mark.asyncio
async def test_embed(inference_settings):
async def test_completions_structured_output(inference_settings):
inference_impl = inference_settings["impl"]
params = inference_settings["common_params"]
provider = inference_impl.routing_table.get_provider_impl(params["model"])
if provider.__provider_spec__.provider_type not in ("remote::ollama",):
pytest.skip("Other inference providers don't support completion() yet")
if provider.__provider_spec__.provider_type not in (
"meta-reference",
"remote::tgi",
):
pytest.skip(
"Other inference providers don't support structured output in completions yet"
)
response = await inference_impl.embeddings(
contents=["Roses are red"],
class Output(BaseModel):
name: str
year_born: str
year_retired: str
user_input = "Michael Jordan was born in 1963. He played basketball for the Chicago Bulls. He retired in 2003."
response = await inference_impl.completion(
content=f"input: '{user_input}'. the schema for json: {Output.schema()}, the json is: ",
stream=False,
model=params["model"],
sampling_params=SamplingParams(
max_tokens=50,
),
response_format=JsonResponseFormat(
schema=Output.model_json_schema(),
),
)
assert isinstance(response, CompletionResponse)
assert isinstance(response.content, str)
assert isinstance(response, EmbeddingsResponse)
assert len(response.embeddings) > 0
answer = Output.parse_raw(response.content)
assert answer.name == "Michael Jordan"
assert answer.year_born == "1963"
assert answer.year_retired == "2003"
@pytest.mark.asyncio
async def test_chat_completion_non_streaming(inference_settings, sample_messages):
@ -382,3 +401,23 @@ async def test_chat_completion_with_tool_calling_streaming(
assert call.tool_name == "get_weather"
assert "location" in call.arguments
assert "San Francisco" in call.arguments["location"]
@pytest.mark.asyncio
async def test_embed(inference_settings):
    """Check that ``embeddings()`` returns a non-empty ``EmbeddingsResponse``.

    Skipped unless the provider type is ``remote::ollama``.
    """
    inference_impl = inference_settings["impl"]
    params = inference_settings["common_params"]

    provider = inference_impl.routing_table.get_provider_impl(params["model"])
    if provider.__provider_spec__.provider_type not in ("remote::ollama",):
        # This test exercises embeddings(), so the skip reason names that API.
        pytest.skip("Other inference providers don't support embeddings() yet")

    response = await inference_impl.embeddings(
        contents=["Roses are red"],
        model=params["model"],
        sampling_params=SamplingParams(
            max_tokens=50,
        ),
    )

    assert isinstance(response, EmbeddingsResponse)
    assert len(response.embeddings) > 0

View file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -0,0 +1,9 @@
providers:
datasetio:
- provider_id: test-meta
provider_type: meta-reference
config: {}
scoring:
- provider_id: test-meta
provider_type: meta-reference
config: {}

View file

@ -0,0 +1,69 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import pytest
import pytest_asyncio
from llama_stack.apis.common.type_system import * # noqa: F403
from llama_stack.apis.datasetio import * # noqa: F403
from llama_stack.distribution.datatypes import * # noqa: F403
from llama_stack.providers.tests.datasetio.test_datasetio import register_dataset
from llama_stack.providers.tests.resolver import resolve_impls_for_test
# How to run this test:
#
# 1. Ensure you have a conda environment with the right dependencies installed. This is a bit tricky
# since it depends on the provider you are testing. On top of that you need
# `pytest` and `pytest-asyncio` installed.
#
# 2. Copy and modify the provider_config_example.yaml depending on the provider you are testing.
#
# 3. Run:
#
# ```bash
# PROVIDER_ID=<your_provider> \
# PROVIDER_CONFIG=provider_config.yaml \
# pytest -s llama_stack/providers/tests/scoring/test_scoring.py \
# --tb=short --disable-warnings
# ```
@pytest_asyncio.fixture(scope="session")
async def scoring_settings():
    """Session fixture exposing the scoring, scoring-functions and datasets impls."""
    resolved = await resolve_impls_for_test(Api.scoring, deps=[Api.datasetio])

    settings = {
        "scoring_impl": resolved[Api.scoring],
        "scoring_functions_impl": resolved[Api.scoring_functions],
        "datasets_impl": resolved[Api.datasets],
    }
    return settings
@pytest.mark.asyncio
async def test_scoring_functions_list(scoring_settings):
    """The listed scoring functions form a non-empty list that includes ``equality``."""
    impl = scoring_settings["scoring_functions_impl"]
    listed = await impl.list_scoring_functions()

    assert isinstance(listed, list)
    assert len(listed) > 0

    identifiers = [fn_def.identifier for fn_def in listed]
    assert "equality" in identifiers
@pytest.mark.asyncio
async def test_scoring_score(scoring_settings):
    """Scoring the registered test dataset with ``equality`` yields exactly that result."""
    scoring_impl = scoring_settings["scoring_impl"]
    datasets_impl = scoring_settings["datasets_impl"]

    await register_dataset(datasets_impl)
    datasets = await datasets_impl.list_datasets()
    assert len(datasets) == 1

    score_response = await scoring_impl.score_batch(
        dataset_id=datasets[0].identifier,
        scoring_functions=["equality"],
    )

    assert len(score_response.results) == 1
    assert "equality" in score_response.results

View file

@ -29,9 +29,9 @@ class OpenAICompatCompletionResponse(BaseModel):
choices: List[OpenAICompatCompletionChoice]
def get_sampling_options(request: ChatCompletionRequest) -> dict:
def get_sampling_options(params: SamplingParams) -> dict:
options = {}
if params := request.sampling_params:
if params:
for attr in {"temperature", "top_p", "top_k", "max_tokens"}:
if getattr(params, attr):
options[attr] = getattr(params, attr)
@ -64,7 +64,18 @@ def process_completion_response(
response: OpenAICompatCompletionResponse, formatter: ChatFormat
) -> CompletionResponse:
choice = response.choices[0]
# drop suffix <eot_id> if present and return stop reason as end of turn
if choice.text.endswith("<|eot_id|>"):
return CompletionResponse(
stop_reason=StopReason.end_of_turn,
content=choice.text[: -len("<|eot_id|>")],
)
# drop suffix <eom_id> if present and return stop reason as end of message
if choice.text.endswith("<|eom_id|>"):
return CompletionResponse(
stop_reason=StopReason.end_of_message,
content=choice.text[: -len("<|eom_id|>")],
)
return CompletionResponse(
stop_reason=get_stop_reason(choice.finish_reason),
content=choice.text,
@ -95,13 +106,6 @@ async def process_completion_stream_response(
choice = chunk.choices[0]
finish_reason = choice.finish_reason
if finish_reason:
if finish_reason in ["stop", "eos", "eos_token"]:
stop_reason = StopReason.end_of_turn
elif finish_reason == "length":
stop_reason = StopReason.out_of_tokens
break
text = text_from_choice(choice)
if text == "<|eot_id|>":
stop_reason = StopReason.end_of_turn
@ -115,6 +119,12 @@ async def process_completion_stream_response(
delta=text,
stop_reason=stop_reason,
)
if finish_reason:
if finish_reason in ["stop", "eos", "eos_token"]:
stop_reason = StopReason.end_of_turn
elif finish_reason == "length":
stop_reason = StopReason.out_of_tokens
break
yield CompletionResponseStreamChunk(
delta="",

View file

@ -31,6 +31,13 @@ def completion_request_to_prompt(
return formatter.tokenizer.decode(model_input.tokens)
def completion_request_to_prompt_model_input_info(
    request: CompletionRequest, formatter: ChatFormat
) -> Tuple[str, int]:
    """Return the decoded prompt for ``request`` together with its token count.

    The request content is encoded with ``formatter.encode_content``; the
    resulting tokens are decoded back to text and counted.
    """
    tokens = formatter.encode_content(request.content).tokens
    prompt = formatter.tokenizer.decode(tokens)
    return (prompt, len(tokens))
def chat_completion_request_to_prompt(
request: ChatCompletionRequest, formatter: ChatFormat
) -> str:

View file

@ -2,7 +2,7 @@ blobfile
fire
httpx
huggingface-hub
llama-models>=0.0.43
llama-models>=0.0.45
prompt-toolkit
python-dotenv
pydantic>=2

View file

@ -16,7 +16,7 @@ def read_requirements():
setup(
name="llama_stack",
version="0.0.43",
version="0.0.45",
author="Meta Llama",
author_email="llama-oss@meta.com",
description="Llama Stack",

View file

@ -13,7 +13,12 @@ apis:
- inference
- datasets
- datasetio
- scoring
providers:
scoring:
- provider_id: meta0
provider_type: meta-reference
config: {}
datasetio:
- provider_id: meta0
provider_type: meta-reference