Merge branch 'main' into merge_conflict
This commit is contained in: commit 14cd065b6c
64 changed files with 1881 additions and 345 deletions
77  .github/ISSUE_TEMPLATE/bug.yml  (vendored, new file)
@@ -0,0 +1,77 @@
name: 🐛 Bug Report
description: Create a report to help us reproduce and fix the bug

body:
  - type: markdown
    attributes:
      value: >
        #### Before submitting a bug, please make sure the issue hasn't been already addressed by searching through [the
        existing and past issues](https://github.com/meta-llama/llama-stack/issues).

  - type: textarea
    id: system-info
    attributes:
      label: System Info
      description: |
        Please share your system info with us. You can use the following command to capture your environment information
        python -m "torch.utils.collect_env"

      placeholder: |
        PyTorch version, CUDA version, GPU type, #num of GPUs...
    validations:
      required: true

  - type: checkboxes
    id: information-scripts-examples
    attributes:
      label: Information
      description: 'The problem arises when using:'
      options:
        - label: "The official example scripts"
        - label: "My own modified scripts"

  - type: textarea
    id: bug-description
    attributes:
      label: 🐛 Describe the bug
      description: |
        Please provide a clear and concise description of what the bug is.

        Please also paste or describe the results you observe instead of the expected results.
      placeholder: |
        A clear and concise description of what the bug is.

        ```llama stack
        # Command that you used for running the examples
        ```
        Description of the results
    validations:
      required: true

  - type: textarea
    attributes:
      label: Error logs
      description: |
        If you observe an error, please paste the error message including the **full** traceback of the exception. It may be relevant to wrap error messages in ```` ```triple quotes blocks``` ````.

      placeholder: |
        ```
        The error message you got, with the full traceback.
        ```

    validations:
      required: true


  - type: textarea
    id: expected-behavior
    validations:
      required: true
    attributes:
      label: Expected behavior
      description: "A clear and concise description of what you would expect to happen."

  - type: markdown
    attributes:
      value: >
        Thanks for contributing 🎉!
31  .github/ISSUE_TEMPLATE/feature-request.yml  (vendored, new file)
@@ -0,0 +1,31 @@
name: 🚀 Feature request
description: Submit a proposal/request for a new llama-stack feature

body:
  - type: textarea
    id: feature-pitch
    attributes:
      label: 🚀 The feature, motivation and pitch
      description: >
        A clear and concise description of the feature proposal. Please outline the motivation for the proposal. Is your feature request related to a specific problem? e.g., *"I'm working on X and would like Y to be possible"*. If this is related to another GitHub issue, please link here too.
    validations:
      required: true

  - type: textarea
    id: alternatives
    attributes:
      label: Alternatives
      description: >
        A description of any alternative solutions or features you've considered, if any.

  - type: textarea
    id: additional-context
    attributes:
      label: Additional context
      description: >
        Add any other context or screenshots about the feature request.

  - type: markdown
    attributes:
      value: >
        Thanks for contributing 🎉!
31  .github/PULL_REQUEST_TEMPLATE.md  (vendored, new file)
@@ -0,0 +1,31 @@
# What does this PR do?

Closes # (issue)

## Feature/Issue validation/testing/test plan

Please describe the tests that you ran to verify your changes and relevant result summary. Provide instructions so it can be reproduced.
Please also list any relevant details for your test configuration or test plan.

- [ ] Test A
      Logs for Test A

- [ ] Test B
      Logs for Test B


## Sources

Please link relevant resources if necessary.


## Before submitting
- [ ] This PR fixes a typo or improves the docs (you can dismiss the other checks if that's the case).
- [ ] Did you read the [contributor guideline](https://github.com/meta-llama/llama-stack/blob/main/CONTRIBUTING.md),
      Pull Request section?
- [ ] Was this discussed/approved via a Github issue? Please add a link
      to it if that's the case.
- [ ] Did you make sure to update the documentation with your changes?
- [ ] Did you write any new necessary tests?

Thanks for contributing 🎉!
1  .gitignore  (vendored)
@@ -13,6 +13,7 @@ xcuserdata/
Package.resolved
*.pte
*.ipynb_checkpoints*
.idea
.venv/
.idea
_build
@@ -1,4 +1,4 @@
include requirements.txt
include llama_stack/distribution/*.sh
include llama_stack/cli/scripts/*.sh
include llama_stack/distribution/templates/*.yaml
include distributions/*/build.yaml
29  README.md
@@ -65,23 +65,30 @@ A Distribution is where APIs and Providers are assembled together to provide a c
| Dell-TGI | [Local TGI + Chroma](https://hub.docker.com/repository/docker/llamastack/llamastack-local-tgi-chroma/general) | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: | :heavy_check_mark: |


## Installation

You can install this repository as a [package](https://pypi.org/project/llama-stack/) with `pip install llama-stack`
You have two ways to install this repository:

If you want to install from source:
1. **Install as a package**:
You can install the repository directly from [PyPI](https://pypi.org/project/llama-stack/) by running the following command:
```bash
pip install llama-stack
```

```bash
mkdir -p ~/local
cd ~/local
git clone git@github.com:meta-llama/llama-stack.git
2. **Install from source**:
If you prefer to install from the source code, follow these steps:
```bash
mkdir -p ~/local
cd ~/local
git clone git@github.com:meta-llama/llama-stack.git

conda create -n stack python=3.10
conda activate stack
conda create -n stack python=3.10
conda activate stack

cd llama-stack
$CONDA_PREFIX/bin/pip install -e .
```
cd llama-stack
$CONDA_PREFIX/bin/pip install -e .
```

## Documentations
@@ -7,6 +7,7 @@ A Distribution is where APIs and Providers are assembled together to provide a c
| **Distribution** | **Llama Stack Docker** | Start This Distribution | **Inference** | **Agents** | **Memory** | **Safety** | **Telemetry** |
|:----------------: |:------------------------------------------: |:-----------------------: |:------------------: |:------------------: |:------------------: |:------------------: |:------------------: |
| Meta Reference | [llamastack/distribution-meta-reference-gpu](https://hub.docker.com/repository/docker/llamastack/distribution-meta-reference-gpu/general) | [Guide](./meta-reference-gpu/) | meta-reference | meta-reference | meta-reference; remote::pgvector; remote::chromadb | meta-reference | meta-reference |
| Meta Reference Quantized | [llamastack/distribution-meta-reference-quantized-gpu](https://hub.docker.com/repository/docker/llamastack/distribution-meta-reference-quantized-gpu/general) | [Guide](./meta-reference-quantized-gpu/) | meta-reference-quantized | meta-reference | meta-reference; remote::pgvector; remote::chromadb | meta-reference | meta-reference |
| Ollama | [llamastack/distribution-ollama](https://hub.docker.com/repository/docker/llamastack/distribution-ollama/general) | [Guide](./ollama/) | remote::ollama | meta-reference | remote::pgvector; remote::chromadb | remote::ollama | meta-reference |
| TGI | [llamastack/distribution-tgi](https://hub.docker.com/repository/docker/llamastack/distribution-tgi/general) | [Guide](./tgi/) | remote::tgi | meta-reference | meta-reference; remote::pgvector; remote::chromadb | meta-reference | meta-reference |
| Together | [llamastack/distribution-together](https://hub.docker.com/repository/docker/llamastack/distribution-together/general) | [Guide](./together/) | remote::together | meta-reference | remote::weaviate | meta-reference | meta-reference |
@@ -1,5 +1,6 @@
name: meta-reference-gpu
distribution_spec:
  docker_image: pytorch/pytorch
  description: Use code from `llama_stack` itself to serve all llama stack APIs
  providers:
    inference: meta-reference
34  distributions/meta-reference-quantized-gpu/README.md  (new file)
@@ -0,0 +1,34 @@
# Meta Reference Quantized Distribution

The `llamastack/distribution-meta-reference-quantized-gpu` distribution consists of the following provider configurations.


| **API** | **Inference** | **Agents** | **Memory** | **Safety** | **Telemetry** |
|----------------- |------------------------ |---------------- |-------------------------------------------------- |---------------- |---------------- |
| **Provider(s)** | meta-reference-quantized | meta-reference | meta-reference, remote::pgvector, remote::chroma | meta-reference | meta-reference |

The only difference vs. the `meta-reference-gpu` distribution is that it has support for more efficient inference -- with fp8, int4 quantization, etc.

### Start the Distribution (Single Node GPU)

> [!NOTE]
> This assumes you have access to a GPU to start a local server.


> [!NOTE]
> `~/.llama` should be the path containing downloaded weights of Llama models.


To download and start running a pre-built docker container, you may use the following commands:

```
docker run -it -p 5000:5000 -v ~/.llama:/root/.llama \
  -v ./run.yaml:/root/my-run.yaml \
  --gpus=all \
  distribution-meta-reference-quantized-gpu \
  --yaml_config /root/my-run.yaml
```

### Alternative (Build and start distribution locally via conda)

- You may check out the [Getting Started](../../docs/getting_started.md) guide for more details on building locally via conda and starting up the distribution.
14  distributions/meta-reference-quantized-gpu/build.yaml  (new file)
@@ -0,0 +1,14 @@
name: meta-reference-quantized-gpu
distribution_spec:
  docker_image: pytorch/pytorch:2.5.0-cuda12.4-cudnn9-runtime
  description: Use code from `llama_stack` itself to serve all llama stack APIs
  providers:
    inference: meta-reference-quantized
    memory:
      - meta-reference
      - remote::chromadb
      - remote::pgvector
    safety: meta-reference
    agents: meta-reference
    telemetry: meta-reference
image_type: docker
51  distributions/meta-reference-quantized-gpu/run.yaml  (new file)
@@ -0,0 +1,51 @@
version: '2'
built_at: '2024-10-08T17:40:45.325529'
image_name: local
docker_image: null
conda_env: local
apis:
- shields
- agents
- models
- memory
- memory_banks
- inference
- safety
providers:
  inference:
  - provider_id: meta0
    provider_type: meta-reference-quantized
    config:
      model: Llama3.2-3B-Instruct:int4-qlora-eo8
      quantization:
        type: int4
      torch_seed: null
      max_seq_len: 2048
      max_batch_size: 1
  safety:
  - provider_id: meta0
    provider_type: meta-reference
    config:
      llama_guard_shield:
        model: Llama-Guard-3-1B
        excluded_categories: []
        disable_input_check: false
        disable_output_check: false
      prompt_guard_shield:
        model: Prompt-Guard-86M
  memory:
  - provider_id: meta0
    provider_type: meta-reference
    config: {}
  agents:
  - provider_id: meta0
    provider_type: meta-reference
    config:
      persistence_store:
        namespace: null
        type: sqlite
        db_path: ~/.llama/runtime/kvstore.db
  telemetry:
  - provider_id: meta0
    provider_type: meta-reference
    config: {}
@@ -5,159 +5,174 @@ This guide will walk you through the steps to get started on an end-to-end flow for
## Installation
The `llama` CLI tool helps you setup and use the Llama toolchain & agentic systems. It should be available on your path after installing the `llama-stack` package.

You can install this repository as a [package](https://pypi.org/project/llama-stack/) with `pip install llama-stack`
You have two ways to install this repository:

If you want to install from source:
1. **Install as a package**:
You can install the repository directly from [PyPI](https://pypi.org/project/llama-stack/) by running the following command:
```bash
pip install llama-stack
```

```bash
mkdir -p ~/local
cd ~/local
git clone git@github.com:meta-llama/llama-stack.git
2. **Install from source**:
If you prefer to install from the source code, follow these steps:
```bash
mkdir -p ~/local
cd ~/local
git clone git@github.com:meta-llama/llama-stack.git

conda create -n stack python=3.10
conda activate stack
conda create -n stack python=3.10
conda activate stack

cd llama-stack
$CONDA_PREFIX/bin/pip install -e .
```
cd llama-stack
$CONDA_PREFIX/bin/pip install -e .
```

For what you can do with the Llama CLI, please refer to [CLI Reference](./cli_reference.md).

## Starting Up Llama Stack Server
#### Starting up server via docker

We provide 2 pre-built Docker images of Llama Stack distribution, which can be found in the following links.
- [llamastack-local-gpu](https://hub.docker.com/repository/docker/llamastack/llamastack-local-gpu/general)
- This is a packaged version with our local meta-reference implementations, where you will be running inference locally with downloaded Llama model checkpoints.
- [llamastack-local-cpu](https://hub.docker.com/repository/docker/llamastack/llamastack-local-cpu/general)
- This is a lite version with remote inference where you can hook up to your favourite remote inference framework (e.g. ollama, fireworks, together, tgi) for running inference without GPU.
You have two ways to start up Llama stack server:

> [!NOTE]
> For GPU inference, you need to set these environment variables for specifying the local directory containing your model checkpoints, and enable GPU inference to start running the docker container.
```
export LLAMA_CHECKPOINT_DIR=~/.llama
```
1. **Starting up server via docker**:

> [!NOTE]
> `~/.llama` should be the path containing downloaded weights of Llama models.
We provide 2 pre-built Docker images of Llama Stack distribution, which can be found in the following links.
- [llamastack-local-gpu](https://hub.docker.com/repository/docker/llamastack/llamastack-local-gpu/general)
- This is a packaged version with our local meta-reference implementations, where you will be running inference locally with downloaded Llama model checkpoints.
- [llamastack-local-cpu](https://hub.docker.com/repository/docker/llamastack/llamastack-local-cpu/general)
- This is a lite version with remote inference where you can hook up to your favourite remote inference framework (e.g. ollama, fireworks, together, tgi) for running inference without GPU.

> [!NOTE]
> For GPU inference, you need to set these environment variables for specifying the local directory containing your model checkpoints, and enable GPU inference to start running the docker container.
```
export LLAMA_CHECKPOINT_DIR=~/.llama
```

> [!NOTE]
> `~/.llama` should be the path containing downloaded weights of Llama models.

To download llama models, use
```
llama download --model-id Llama3.1-8B-Instruct
```

To download and start running a pre-built docker container, you may use the following commands:

```
docker run -it -p 5000:5000 -v ~/.llama:/root/.llama --gpus=all llamastack/llamastack-local-gpu
```

> [!TIP]
> Pro Tip: We may use `docker compose up` for starting up a distribution with remote providers (e.g. TGI) using [llamastack-local-cpu](https://hub.docker.com/repository/docker/llamastack/llamastack-local-cpu/general). You can check out [these scripts](../distributions/) to help you get started.


To download and start running a pre-built docker container, you may use the following commands:
2. **Build->Configure->Run Llama Stack server via conda**:

```
docker run -it -p 5000:5000 -v ~/.llama:/root/.llama --gpus=all llamastack/llamastack-local-gpu
```
You may also build a LlamaStack distribution from scratch, configure it, and start running the distribution. This is useful for developing on LlamaStack.

> [!TIP]
> Pro Tip: We may use `docker compose up` for starting up a distribution with remote providers (e.g. TGI) using [llamastack-local-cpu](https://hub.docker.com/repository/docker/llamastack/llamastack-local-cpu/general). You can check out [these scripts](../distributions/) to help you get started.
**`llama stack build`**
- You'll be prompted to enter build information interactively.
```
llama stack build

#### Build->Configure->Run Llama Stack server via conda
You may also build a LlamaStack distribution from scratch, configure it, and start running the distribution. This is useful for developing on LlamaStack.
> Enter an unique name for identifying your Llama Stack build distribution (e.g. my-local-stack): my-local-stack
> Enter the image type you want your distribution to be built with (docker or conda): conda

**`llama stack build`**
- You'll be prompted to enter build information interactively.
```
llama stack build
Llama Stack is composed of several APIs working together. Let's configure the providers (implementations) you want to use for these APIs.
> Enter the API provider for the inference API: (default=meta-reference): meta-reference
> Enter the API provider for the safety API: (default=meta-reference): meta-reference
> Enter the API provider for the agents API: (default=meta-reference): meta-reference
> Enter the API provider for the memory API: (default=meta-reference): meta-reference
> Enter the API provider for the telemetry API: (default=meta-reference): meta-reference

> Enter an unique name for identifying your Llama Stack build distribution (e.g. my-local-stack): my-local-stack
> Enter the image type you want your distribution to be built with (docker or conda): conda
> (Optional) Enter a short description for your Llama Stack distribution:

Llama Stack is composed of several APIs working together. Let's configure the providers (implementations) you want to use for these APIs.
> Enter the API provider for the inference API: (default=meta-reference): meta-reference
> Enter the API provider for the safety API: (default=meta-reference): meta-reference
> Enter the API provider for the agents API: (default=meta-reference): meta-reference
> Enter the API provider for the memory API: (default=meta-reference): meta-reference
> Enter the API provider for the telemetry API: (default=meta-reference): meta-reference
Build spec configuration saved at ~/.conda/envs/llamastack-my-local-stack/my-local-stack-build.yaml
You can now run `llama stack configure my-local-stack`
```

> (Optional) Enter a short description for your Llama Stack distribution:
**`llama stack configure`**
- Run `llama stack configure <name>` with the name you have previously defined in the `build` step.
```
llama stack configure <name>
```
- You will be prompted to enter configurations for your Llama Stack

Build spec configuration saved at ~/.conda/envs/llamastack-my-local-stack/my-local-stack-build.yaml
You can now run `llama stack configure my-local-stack`
```
```
$ llama stack configure my-local-stack

**`llama stack configure`**
- Run `llama stack configure <name>` with the name you have previously defined in the `build` step.
```
llama stack configure <name>
```
- You will be prompted to enter configurations for your Llama Stack
Could not find my-local-stack. Trying conda build name instead...
Configuring API `inference`...
=== Configuring provider `meta-reference` for API inference...
Enter value for model (default: Llama3.1-8B-Instruct) (required):
Do you want to configure quantization? (y/n): n
Enter value for torch_seed (optional):
Enter value for max_seq_len (default: 4096) (required):
Enter value for max_batch_size (default: 1) (required):

```
$ llama stack configure my-local-stack
Configuring API `safety`...
=== Configuring provider `meta-reference` for API safety...
Do you want to configure llama_guard_shield? (y/n): n
Do you want to configure prompt_guard_shield? (y/n): n

Could not find my-local-stack. Trying conda build name instead...
Configuring API `inference`...
=== Configuring provider `meta-reference` for API inference...
Enter value for model (default: Llama3.1-8B-Instruct) (required):
Do you want to configure quantization? (y/n): n
Enter value for torch_seed (optional):
Enter value for max_seq_len (default: 4096) (required):
Enter value for max_batch_size (default: 1) (required):
Configuring API `agents`...
=== Configuring provider `meta-reference` for API agents...
Enter `type` for persistence_store (options: redis, sqlite, postgres) (default: sqlite):

Configuring API `safety`...
=== Configuring provider `meta-reference` for API safety...
Do you want to configure llama_guard_shield? (y/n): n
Do you want to configure prompt_guard_shield? (y/n): n
Configuring SqliteKVStoreConfig:
Enter value for namespace (optional):
Enter value for db_path (default: /home/xiyan/.llama/runtime/kvstore.db) (required):

Configuring API `agents`...
=== Configuring provider `meta-reference` for API agents...
Enter `type` for persistence_store (options: redis, sqlite, postgres) (default: sqlite):
Configuring API `memory`...
=== Configuring provider `meta-reference` for API memory...
> Please enter the supported memory bank type your provider has for memory: vector

Configuring SqliteKVStoreConfig:
Enter value for namespace (optional):
Enter value for db_path (default: /home/xiyan/.llama/runtime/kvstore.db) (required):
Configuring API `telemetry`...
=== Configuring provider `meta-reference` for API telemetry...

Configuring API `memory`...
=== Configuring provider `meta-reference` for API memory...
> Please enter the supported memory bank type your provider has for memory: vector
> YAML configuration has been written to ~/.llama/builds/conda/my-local-stack-run.yaml.
You can now run `llama stack run my-local-stack --port PORT`
```

Configuring API `telemetry`...
=== Configuring provider `meta-reference` for API telemetry...
**`llama stack run`**
- Run `llama stack run <name>` with the name you have previously defined.
```
llama stack run my-local-stack

> YAML configuration has been written to ~/.llama/builds/conda/my-local-stack-run.yaml.
You can now run `llama stack run my-local-stack --port PORT`
```

**`llama stack run`**
- Run `llama stack run <name>` with the name you have previously defined.
```
llama stack run my-local-stack

...
> initializing model parallel with size 1
> initializing ddp with size 1
> initializing pipeline with size 1
...
Finished model load YES READY
Serving POST /inference/chat_completion
Serving POST /inference/completion
Serving POST /inference/embeddings
Serving POST /memory_banks/create
Serving DELETE /memory_bank/documents/delete
Serving DELETE /memory_banks/drop
Serving GET /memory_bank/documents/get
Serving GET /memory_banks/get
Serving POST /memory_bank/insert
Serving GET /memory_banks/list
Serving POST /memory_bank/query
Serving POST /memory_bank/update
Serving POST /safety/run_shield
Serving POST /agentic_system/create
Serving POST /agentic_system/session/create
Serving POST /agentic_system/turn/create
Serving POST /agentic_system/delete
Serving POST /agentic_system/session/delete
Serving POST /agentic_system/session/get
Serving POST /agentic_system/step/get
Serving POST /agentic_system/turn/get
Serving GET /telemetry/get_trace
Serving POST /telemetry/log_event
Listening on :::5000
INFO: Started server process [587053]
INFO: Waiting for application startup.
INFO: Application startup complete.
INFO: Uvicorn running on http://[::]:5000 (Press CTRL+C to quit)
```
...
> initializing model parallel with size 1
> initializing ddp with size 1
> initializing pipeline with size 1
...
Finished model load YES READY
Serving POST /inference/chat_completion
Serving POST /inference/completion
Serving POST /inference/embeddings
Serving POST /memory_banks/create
Serving DELETE /memory_bank/documents/delete
Serving DELETE /memory_banks/drop
Serving GET /memory_bank/documents/get
Serving GET /memory_banks/get
Serving POST /memory_bank/insert
Serving GET /memory_banks/list
Serving POST /memory_bank/query
Serving POST /memory_bank/update
Serving POST /safety/run_shield
Serving POST /agentic_system/create
Serving POST /agentic_system/session/create
Serving POST /agentic_system/turn/create
Serving POST /agentic_system/delete
Serving POST /agentic_system/session/delete
Serving POST /agentic_system/session/get
Serving POST /agentic_system/step/get
Serving POST /agentic_system/turn/get
Serving GET /telemetry/get_trace
Serving POST /telemetry/log_event
Listening on :::5000
INFO: Started server process [587053]
INFO: Waiting for application startup.
INFO: Application startup complete.
INFO: Uvicorn running on http://[::]:5000 (Press CTRL+C to quit)
```


## Testing with client
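The client-side walkthrough under `## Testing with client` is not shown in this hunk. As a rough, hypothetical illustration only: one way to exercise the `POST /inference/chat_completion` route listed in the server log above is a plain HTTP call; the request field names (`model`, `messages`, `stream`) are assumptions for the sketch and are not taken from this diff.

```python
# Hypothetical smoke test against a locally running stack on port 5000.
# The payload shape is assumed for illustration and may not match the
# actual Llama Stack request schema.
import asyncio

import httpx


async def smoke_test(host: str = "localhost", port: int = 5000) -> None:
    async with httpx.AsyncClient() as client:
        response = await client.post(
            f"http://{host}:{port}/inference/chat_completion",
            json={
                "model": "Llama3.1-8B-Instruct",
                "messages": [{"role": "user", "content": "Hello"}],
                "stream": False,
            },
            headers={"Content-Type": "application/json"},
            timeout=60,
        )
        response.raise_for_status()
        print(response.json())


if __name__ == "__main__":
    asyncio.run(smoke_test())
```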
103  llama_stack/apis/datasetio/client.py  (new file)
@@ -0,0 +1,103 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import asyncio
import os
from pathlib import Path
from typing import Optional

import fire
import httpx
from termcolor import cprint

from llama_stack.apis.datasets import *  # noqa: F403
from llama_stack.apis.datasetio import *  # noqa: F403
from llama_stack.apis.common.type_system import *  # noqa: F403
from llama_stack.apis.datasets.client import DatasetsClient
from llama_stack.providers.tests.datasetio.test_datasetio import data_url_from_file


class DatasetIOClient(DatasetIO):
    def __init__(self, base_url: str):
        self.base_url = base_url

    async def initialize(self) -> None:
        pass

    async def shutdown(self) -> None:
        pass

    async def get_rows_paginated(
        self,
        dataset_id: str,
        rows_in_page: int,
        page_token: Optional[str] = None,
        filter_condition: Optional[str] = None,
    ) -> PaginatedRowsResult:
        async with httpx.AsyncClient() as client:
            response = await client.get(
                f"{self.base_url}/datasetio/get_rows_paginated",
                params={
                    "dataset_id": dataset_id,
                    "rows_in_page": rows_in_page,
                    "page_token": page_token,
                    "filter_condition": filter_condition,
                },
                headers={"Content-Type": "application/json"},
                timeout=60,
            )
            response.raise_for_status()
            if not response.json():
                return

            return PaginatedRowsResult(**response.json())


async def run_main(host: str, port: int):
    client = DatasetsClient(f"http://{host}:{port}")

    # register dataset
    test_file = (
        Path(os.path.abspath(__file__)).parent.parent.parent
        / "providers/tests/datasetio/test_dataset.csv"
    )
    test_url = data_url_from_file(str(test_file))
    response = await client.register_dataset(
        DatasetDefWithProvider(
            identifier="test-dataset",
            provider_id="meta0",
            url=URL(
                uri=test_url,
            ),
            dataset_schema={
                "generated_answer": StringType(),
                "expected_answer": StringType(),
                "input_query": StringType(),
            },
        )
    )

    # list datasets
    list_dataset = await client.list_datasets()
    cprint(list_dataset, "blue")

    # datasetio client to get the rows
    datasetio_client = DatasetIOClient(f"http://{host}:{port}")
    response = await datasetio_client.get_rows_paginated(
        dataset_id="test-dataset",
        rows_in_page=4,
        page_token=None,
        filter_condition=None,
    )
    cprint(f"Returned {len(response.rows)} rows \n {response}", "green")


def main(host: str, port: int):
    asyncio.run(run_main(host, port))


if __name__ == "__main__":
    fire.Fire(main)
@@ -29,7 +29,7 @@ class DatasetIO(Protocol):
    # keeping for aligning with inference/safety, but this is not used
    dataset_store: DatasetStore

    @webmethod(route="/dataio/get_rows_paginated")
    @webmethod(route="/datasetio/get_rows_paginated", method="GET")
    async def get_rows_paginated(
        self,
        dataset_id: str,
116  llama_stack/apis/datasets/client.py  (new file)
@@ -0,0 +1,116 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import asyncio
import json
import os
from pathlib import Path
from typing import Optional

import fire
import httpx
from termcolor import cprint

from .datasets import *  # noqa: F403
from llama_stack.apis.datasets import *  # noqa: F403
from llama_stack.apis.common.type_system import *  # noqa: F403
from llama_stack.providers.tests.datasetio.test_datasetio import data_url_from_file


class DatasetsClient(Datasets):
    def __init__(self, base_url: str):
        self.base_url = base_url

    async def initialize(self) -> None:
        pass

    async def shutdown(self) -> None:
        pass

    async def register_dataset(
        self,
        dataset_def: DatasetDefWithProvider,
    ) -> None:
        async with httpx.AsyncClient() as client:
            response = await client.post(
                f"{self.base_url}/datasets/register",
                json={
                    "dataset_def": json.loads(dataset_def.json()),
                },
                headers={"Content-Type": "application/json"},
                timeout=60,
            )
            response.raise_for_status()
            return

    async def get_dataset(
        self,
        dataset_identifier: str,
    ) -> Optional[DatasetDefWithProvider]:
        async with httpx.AsyncClient() as client:
            response = await client.get(
                f"{self.base_url}/datasets/get",
                params={
                    "dataset_identifier": dataset_identifier,
                },
                headers={"Content-Type": "application/json"},
                timeout=60,
            )
            response.raise_for_status()
            if not response.json():
                return

            return DatasetDefWithProvider(**response.json())

    async def list_datasets(self) -> List[DatasetDefWithProvider]:
        async with httpx.AsyncClient() as client:
            response = await client.get(
                f"{self.base_url}/datasets/list",
                headers={"Content-Type": "application/json"},
                timeout=60,
            )
            response.raise_for_status()
            if not response.json():
                return

            return [DatasetDefWithProvider(**x) for x in response.json()]


async def run_main(host: str, port: int):
    client = DatasetsClient(f"http://{host}:{port}")

    # register dataset
    test_file = (
        Path(os.path.abspath(__file__)).parent.parent.parent
        / "providers/tests/datasetio/test_dataset.csv"
    )
    test_url = data_url_from_file(str(test_file))
    response = await client.register_dataset(
        DatasetDefWithProvider(
            identifier="test-dataset",
            provider_id="meta0",
            url=URL(
                uri=test_url,
            ),
            dataset_schema={
                "generated_answer": StringType(),
                "expected_answer": StringType(),
                "input_query": StringType(),
            },
        )
    )

    # list datasets
    list_dataset = await client.list_datasets()
    cprint(list_dataset, "blue")


def main(host: str, port: int):
    asyncio.run(run_main(host, port))


if __name__ == "__main__":
    fire.Fire(main)
@@ -20,7 +20,7 @@ class DatasetDef(BaseModel):
    identifier: str = Field(
        description="A unique name for the dataset",
    )
    columns_schema: Dict[str, ParamType] = Field(
    dataset_schema: Dict[str, ParamType] = Field(
        description="The schema definition for this dataset",
    )
    url: URL
@@ -172,7 +172,7 @@ async def run_mm_main(
        ],
    )
    cprint(f"User>{message.content}", "green")
    iterator = client.chat_completion(
    iterator = await client.chat_completion(
        model=model,
        messages=[message],
        stream=stream,
@@ -25,6 +25,7 @@ class LogProbConfig(BaseModel):
class QuantizationType(Enum):
    bf16 = "bf16"
    fp8 = "fp8"
    int4 = "int4"


@json_schema_type
@@ -37,8 +38,14 @@ class Bf16QuantizationConfig(BaseModel):
    type: Literal[QuantizationType.bf16.value] = QuantizationType.bf16.value


@json_schema_type
class Int4QuantizationConfig(BaseModel):
    type: Literal[QuantizationType.int4.value] = QuantizationType.int4.value
    scheme: Optional[str] = "int4_weight_int8_dynamic_activation"


QuantizationConfig = Annotated[
    Union[Bf16QuantizationConfig, Fp8QuantizationConfig],
    Union[Bf16QuantizationConfig, Fp8QuantizationConfig, Int4QuantizationConfig],
    Field(discriminator="type"),
]
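The hunk above adds `Int4QuantizationConfig` to the tagged `QuantizationConfig` union. A minimal standalone sketch of how that union dispatches on the `type` discriminator (the `InferenceConfigSketch` wrapper below is hypothetical, used only to drive parsing; the real models live in `llama_stack.apis.inference`):

```python
# Standalone sketch of the discriminated-union pattern shown in the diff.
from typing import Literal, Optional, Union

from pydantic import BaseModel, Field
from typing_extensions import Annotated


class Bf16QuantizationConfig(BaseModel):
    type: Literal["bf16"] = "bf16"


class Fp8QuantizationConfig(BaseModel):
    type: Literal["fp8"] = "fp8"


class Int4QuantizationConfig(BaseModel):
    type: Literal["int4"] = "int4"
    scheme: Optional[str] = "int4_weight_int8_dynamic_activation"


QuantizationConfig = Annotated[
    Union[Bf16QuantizationConfig, Fp8QuantizationConfig, Int4QuantizationConfig],
    Field(discriminator="type"),
]


class InferenceConfigSketch(BaseModel):
    # Hypothetical wrapper, not part of the diff; it exists only so pydantic
    # can validate a dict against the discriminated union.
    quantization: Optional[QuantizationConfig] = None


# A config block like `quantization: {type: int4}` (see run.yaml above) is
# dispatched to Int4QuantizationConfig by the `type` discriminator.
cfg = InferenceConfigSketch(quantization={"type": "int4"})
print(type(cfg.quantization).__name__, cfg.quantization.scheme)
```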
@@ -228,8 +235,6 @@ class Inference(Protocol):
        logprobs: Optional[LogProbConfig] = None,
    ) -> Union[CompletionResponse, CompletionResponseStreamChunk]: ...

    # This method is not `async def` because it can result in either an
    # `AsyncGenerator` or a `ChatCompletionResponse` depending on the value of `stream`.
    @webmethod(route="/inference/chat_completion")
    async def chat_completion(
        self,
132  llama_stack/apis/scoring/client.py  (new file)
@@ -0,0 +1,132 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import asyncio
import os
from pathlib import Path

import fire
import httpx
from termcolor import cprint

from llama_stack.apis.datasets import *  # noqa: F403
from llama_stack.apis.scoring import *  # noqa: F403
from llama_stack.apis.common.type_system import *  # noqa: F403
from llama_stack.apis.datasetio.client import DatasetIOClient
from llama_stack.apis.datasets.client import DatasetsClient
from llama_stack.providers.tests.datasetio.test_datasetio import data_url_from_file


class ScoringClient(Scoring):
    def __init__(self, base_url: str):
        self.base_url = base_url

    async def initialize(self) -> None:
        pass

    async def shutdown(self) -> None:
        pass

    async def score_batch(
        self, dataset_id: str, scoring_functions: List[str]
    ) -> ScoreBatchResponse:
        async with httpx.AsyncClient() as client:
            response = await client.post(
                f"{self.base_url}/scoring/score_batch",
                json={
                    "dataset_id": dataset_id,
                    "scoring_functions": scoring_functions,
                },
                headers={"Content-Type": "application/json"},
                timeout=60,
            )
            response.raise_for_status()
            if not response.json():
                return

            return ScoreBatchResponse(**response.json())

    async def score(
        self, input_rows: List[Dict[str, Any]], scoring_functions: List[str]
    ) -> ScoreResponse:
        async with httpx.AsyncClient() as client:
            response = await client.post(
                f"{self.base_url}/scoring/score",
                json={
                    "input_rows": input_rows,
                    "scoring_functions": scoring_functions,
                },
                headers={"Content-Type": "application/json"},
                timeout=60,
            )
            response.raise_for_status()
            if not response.json():
                return

            return ScoreResponse(**response.json())


async def run_main(host: str, port: int):
    client = DatasetsClient(f"http://{host}:{port}")

    # register dataset
    test_file = (
        Path(os.path.abspath(__file__)).parent.parent.parent
        / "providers/tests/datasetio/test_dataset.csv"
    )
    test_url = data_url_from_file(str(test_file))
    response = await client.register_dataset(
        DatasetDefWithProvider(
            identifier="test-dataset",
            provider_id="meta0",
            url=URL(
                uri=test_url,
            ),
            dataset_schema={
                "generated_answer": StringType(),
                "expected_answer": StringType(),
                "input_query": StringType(),
            },
        )
    )

    # list datasets
    list_dataset = await client.list_datasets()
    cprint(list_dataset, "blue")

    # datasetio client to get the rows
    datasetio_client = DatasetIOClient(f"http://{host}:{port}")
    response = await datasetio_client.get_rows_paginated(
        dataset_id="test-dataset",
        rows_in_page=4,
        page_token=None,
        filter_condition=None,
    )
    cprint(f"Returned {len(response.rows)} rows \n {response}", "green")

    # scoring client to score the rows
    scoring_client = ScoringClient(f"http://{host}:{port}")
    response = await scoring_client.score(
        input_rows=response.rows,
        scoring_functions=["equality"],
    )
    cprint(f"score response={response}", "blue")

    # test scoring batch using datasetio api
    scoring_client = ScoringClient(f"http://{host}:{port}")
    response = await scoring_client.score_batch(
        dataset_id="test-dataset",
        scoring_functions=["equality"],
    )
    cprint(f"score_batch response={response}", "cyan")


def main(host: str, port: int):
    asyncio.run(run_main(host, port))


if __name__ == "__main__":
    fire.Fire(main)
@@ -13,18 +13,27 @@ from llama_models.llama3.api.datatypes import *  # noqa: F403
from llama_stack.apis.scoring_functions import *  # noqa: F403


ScoringResult = Dict[str, Any]
# mapping of metric to value
ScoringResultRow = Dict[str, Any]


@json_schema_type
class ScoringResult(BaseModel):
    score_rows: List[ScoringResultRow]
    # aggregated metrics to value
    aggregated_results: Dict[str, Any]


@json_schema_type
class ScoreBatchResponse(BaseModel):
    dataset_id: str
    dataset_id: Optional[str] = None
    results: Dict[str, ScoringResult]


@json_schema_type
class ScoreResponse(BaseModel):
    # each key in the dict is a scoring function name
    results: List[Dict[str, ScoringResult]]
    results: Dict[str, ScoringResult]


class ScoringFunctionStore(Protocol):
@@ -37,7 +46,10 @@ class Scoring(Protocol):

    @webmethod(route="/scoring/score_batch")
    async def score_batch(
        self, dataset_id: str, scoring_functions: List[str]
        self,
        dataset_id: str,
        scoring_functions: List[str],
        save_results_dataset: bool = False,
    ) -> ScoreBatchResponse: ...

    @webmethod(route="/scoring/score")
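As a rough illustration of how the reshaped scoring types compose, here is a minimal sketch (not the shipped code): `ScoreResponse` now maps each scoring function identifier to a `ScoringResult` holding per-row scores plus aggregates. The per-row key `score` and the aggregate key `accuracy` below are placeholder names, not fields defined anywhere in this diff.

```python
# Minimal sketch of the reshaped scoring datatypes from the hunks above.
from typing import Any, Dict, List

from pydantic import BaseModel

ScoringResultRow = Dict[str, Any]


class ScoringResult(BaseModel):
    score_rows: List[ScoringResultRow]
    aggregated_results: Dict[str, Any]


class ScoreResponse(BaseModel):
    # keyed by scoring function identifier
    results: Dict[str, ScoringResult]


# "equality" matches the scoring function used by the client example earlier
# in this diff; the row/aggregate keys are placeholders.
response = ScoreResponse(
    results={
        "equality": ScoringResult(
            score_rows=[{"score": 1.0}, {"score": 0.0}],
            aggregated_results={"accuracy": 0.5},
        )
    }
)
print(response.results["equality"].aggregated_results)
```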
@@ -4,20 +4,10 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import (
    Any,
    Dict,
    List,
    Literal,
    Optional,
    Protocol,
    runtime_checkable,
    Union,
)
from typing import Any, Dict, List, Optional, Protocol, runtime_checkable

from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel, Field
from typing_extensions import Annotated

from llama_stack.apis.common.type_system import ParamType
@@ -33,45 +23,37 @@ class Parameter(BaseModel):
# with standard metrics so they can be rolled up?


class LLMAsJudgeContext(BaseModel):
    judge_model: str
    prompt_template: Optional[str] = None


@json_schema_type
class CommonDef(BaseModel):
    name: str
class ScoringFunctionDef(BaseModel):
    identifier: str
    description: Optional[str] = None
    metadata: Dict[str, Any] = Field(
        default_factory=dict,
        description="Any additional metadata for this definition",
    )
    # Hack: same with memory_banks for union defs
    provider_id: str = ""


@json_schema_type
class DeterministicFunctionDef(CommonDef):
    type: Literal["deterministic"] = "deterministic"
    parameters: List[Parameter] = Field(
        description="List of parameters for the deterministic function",
        default_factory=list,
    )
    return_type: ParamType = Field(
        description="The return type of the deterministic function",
    )
    context: Optional[LLMAsJudgeContext] = None
    # We can optionally add information here to support packaging of code, etc.


@json_schema_type
class LLMJudgeFunctionDef(CommonDef):
    type: Literal["judge"] = "judge"
    model: str = Field(
        description="The LLM model to use for the judge function",
class ScoringFunctionDefWithProvider(ScoringFunctionDef):
    provider_id: str = Field(
        description="ID of the provider which serves this dataset",
    )


ScoringFunctionDef = Annotated[
    Union[DeterministicFunctionDef, LLMJudgeFunctionDef], Field(discriminator="type")
]
ScoringFunctionDefWithProvider = ScoringFunctionDef


@runtime_checkable
class ScoringFunctions(Protocol):
    @webmethod(route="/scoring_functions/list", method="GET")
@@ -84,5 +66,5 @@ class ScoringFunctions(Protocol):

    @webmethod(route="/scoring_functions/register", method="POST")
    async def register_scoring_function(
        self, function: ScoringFunctionDefWithProvider
        self, function_def: ScoringFunctionDefWithProvider
    ) -> None: ...
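With the flattened `ScoringFunctionDef`, registering a definition over HTTP might look roughly like the sketch below, mirroring the style of the `DatasetsClient` earlier in this diff. The request-body key `"function_def"` follows the renamed parameter above but is an assumption, as are the illustrative description and `return_type`.

```python
# Hedged sketch only: register a scoring function against a running stack.
# The body key "function_def" and the chosen return_type are assumptions
# made for illustration, not confirmed by this diff.
import asyncio
import json

import httpx

from llama_stack.apis.common.type_system import StringType
from llama_stack.apis.scoring_functions import ScoringFunctionDefWithProvider


async def register_equality_fn(base_url: str) -> None:
    fn = ScoringFunctionDefWithProvider(
        identifier="equality",
        description="Exact-match scoring of generated vs. expected answers",
        return_type=StringType(),
        provider_id="meta0",
    )
    async with httpx.AsyncClient() as client:
        response = await client.post(
            f"{base_url}/scoring_functions/register",
            json={"function_def": json.loads(fn.json())},
            headers={"Content-Type": "application/json"},
            timeout=60,
        )
        response.raise_for_status()


if __name__ == "__main__":
    asyncio.run(register_equality_fn("http://localhost:5000"))
```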
@@ -97,7 +97,7 @@ if [ -n "$pip_dependencies" ]; then
fi

if [ -n "$special_pip_deps" ]; then
  IFS='#' read -ra parts <<< "$special_pip_deps"
  IFS='#' read -ra parts <<<"$special_pip_deps"
  for part in "${parts[@]}"; do
    add_to_docker "RUN pip install $part"
  done

@@ -127,7 +127,7 @@ if [ -n "$LLAMA_MODELS_DIR" ]; then
  mounts="$mounts -v $(readlink -f $LLAMA_MODELS_DIR):$models_mount"
fi

if command -v selinuxenabled &> /dev/null && selinuxenabled; then
if command -v selinuxenabled &>/dev/null && selinuxenabled; then
  # Disable SELinux labels -- we don't want to relabel the llama-stack source dir
  DOCKER_OPTS="$DOCKER_OPTS --security-opt label=disable"
fi

@@ -139,4 +139,4 @@ $DOCKER_BINARY build $DOCKER_OPTS -t $image_name -f "$TEMP_DIR/Dockerfile" "$REP
rm -rf $REPO_CONFIGS_DIR
set +x

echo "Success! You can run it with: $DOCKER_BINARY $DOCKER_OPTS run -p 5000:5000 $image_name"
echo "Success!"
@@ -15,10 +15,12 @@ from llama_stack.apis.models import *  # noqa: F403
from llama_stack.apis.shields import *  # noqa: F403
from llama_stack.apis.memory_banks import *  # noqa: F403
from llama_stack.apis.datasets import *  # noqa: F403
from llama_stack.apis.scoring_functions import *  # noqa: F403
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.inference import Inference
from llama_stack.apis.memory import Memory
from llama_stack.apis.safety import Safety
from llama_stack.apis.scoring import Scoring

LLAMA_STACK_BUILD_CONFIG_VERSION = "2"
LLAMA_STACK_RUN_CONFIG_VERSION = "2"

@@ -32,6 +34,7 @@ RoutableObject = Union[
    ShieldDef,
    MemoryBankDef,
    DatasetDef,
    ScoringFunctionDef,
]

RoutableObjectWithProvider = Union[

@@ -39,6 +42,7 @@ RoutableObjectWithProvider = Union[
    ShieldDefWithProvider,
    MemoryBankDefWithProvider,
    DatasetDefWithProvider,
    ScoringFunctionDefWithProvider,
]

RoutedProtocol = Union[

@@ -46,6 +50,7 @@ RoutedProtocol = Union[
    Safety,
    Memory,
    DatasetIO,
    Scoring,
]
@@ -39,6 +39,10 @@ def builtin_automatically_routed_apis() -> List[AutoRoutedApiInfo]:
            routing_table_api=Api.datasets,
            router_api=Api.datasetio,
        ),
        AutoRoutedApiInfo(
            routing_table_api=Api.scoring_functions,
            router_api=Api.scoring,
        ),
    ]
@@ -20,6 +20,8 @@ from llama_stack.apis.memory import Memory
from llama_stack.apis.memory_banks import MemoryBanks
from llama_stack.apis.models import Models
from llama_stack.apis.safety import Safety
from llama_stack.apis.scoring import Scoring
from llama_stack.apis.scoring_functions import ScoringFunctions
from llama_stack.apis.shields import Shields
from llama_stack.apis.telemetry import Telemetry
from llama_stack.distribution.distribution import (

@@ -42,6 +44,8 @@ def api_protocol_map() -> Dict[Api, Any]:
        Api.telemetry: Telemetry,
        Api.datasets: Datasets,
        Api.datasetio: DatasetIO,
        Api.scoring_functions: ScoringFunctions,
        Api.scoring: Scoring,
    }
@@ -11,6 +11,7 @@ from .routing_tables import (
    DatasetsRoutingTable,
    MemoryBanksRoutingTable,
    ModelsRoutingTable,
    ScoringFunctionsRoutingTable,
    ShieldsRoutingTable,
)

@@ -25,7 +26,9 @@ async def get_routing_table_impl(
        "models": ModelsRoutingTable,
        "shields": ShieldsRoutingTable,
        "datasets": DatasetsRoutingTable,
        "scoring_functions": ScoringFunctionsRoutingTable,
    }

    if api.value not in api_to_tables:
        raise ValueError(f"API {api.value} not found in router map")

@@ -35,13 +38,20 @@ async def get_routing_table_impl(


async def get_auto_router_impl(api: Api, routing_table: RoutingTable, _deps) -> Any:
    from .routers import DatasetIORouter, InferenceRouter, MemoryRouter, SafetyRouter
    from .routers import (
        DatasetIORouter,
        InferenceRouter,
        MemoryRouter,
        SafetyRouter,
        ScoringRouter,
    )

    api_to_routers = {
        "memory": MemoryRouter,
        "inference": InferenceRouter,
        "safety": SafetyRouter,
        "datasetio": DatasetIORouter,
        "scoring": ScoringRouter,
    }
    if api.value not in api_to_routers:
        raise ValueError(f"API {api.value} not found in router map")
@@ -13,6 +13,7 @@ from llama_stack.apis.memory import *  # noqa: F403
from llama_stack.apis.inference import *  # noqa: F403
from llama_stack.apis.safety import *  # noqa: F403
from llama_stack.apis.datasetio import *  # noqa: F403
from llama_stack.apis.scoring import *  # noqa: F403


class MemoryRouter(Memory):

@@ -198,3 +199,56 @@ class DatasetIORouter(DatasetIO):
            page_token=page_token,
            filter_condition=filter_condition,
        )


class ScoringRouter(Scoring):
    def __init__(
        self,
        routing_table: RoutingTable,
    ) -> None:
        self.routing_table = routing_table

    async def initialize(self) -> None:
        pass

    async def shutdown(self) -> None:
        pass

    async def score_batch(
        self,
        dataset_id: str,
        scoring_functions: List[str],
        save_results_dataset: bool = False,
    ) -> ScoreBatchResponse:
        res = {}
        for fn_identifier in scoring_functions:
            score_response = await self.routing_table.get_provider_impl(
                fn_identifier
            ).score_batch(
                dataset_id=dataset_id,
                scoring_functions=[fn_identifier],
            )
            res.update(score_response.results)

        if save_results_dataset:
            raise NotImplementedError("Save results dataset not implemented yet")

        return ScoreBatchResponse(
            results=res,
        )

    async def score(
        self, input_rows: List[Dict[str, Any]], scoring_functions: List[str]
    ) -> ScoreResponse:
        res = {}
        # look up and map each scoring function to its provider impl
        for fn_identifier in scoring_functions:
            score_response = await self.routing_table.get_provider_impl(
                fn_identifier
            ).score(
                input_rows=input_rows,
                scoring_functions=[fn_identifier],
            )
            res.update(score_response.results)

        return ScoreResponse(results=res)
@@ -30,6 +30,8 @@ async def register_object_with_provider(obj: RoutableObject, p: Any) -> None:
        await p.register_memory_bank(obj)
    elif api == Api.datasetio:
        await p.register_dataset(obj)
    elif api == Api.scoring:
        await p.register_scoring_function(obj)
    else:
        raise ValueError(f"Unknown API {api} for registering object with provider")

@@ -93,7 +95,15 @@ class CommonRoutingTableImpl(RoutingTable):
                for d in datasets:
                    d.provider_id = pid

                add_objects(datasets)
            elif api == Api.scoring:
                p.scoring_function_store = self
                scoring_functions = await p.list_scoring_functions()
                add_objects(
                    [
                        ScoringFunctionDefWithProvider(**s.dict(), provider_id=pid)
                        for s in scoring_functions
                    ]
                )

    async def shutdown(self) -> None:
        for p in self.impls_by_provider_id.values():

@@ -109,6 +119,10 @@ class CommonRoutingTableImpl(RoutingTable):
            return ("Safety", "shield")
        elif isinstance(self, MemoryBanksRoutingTable):
            return ("Memory", "memory_bank")
        elif isinstance(self, DatasetsRoutingTable):
            return ("DatasetIO", "dataset")
        elif isinstance(self, ScoringFunctionsRoutingTable):
            return ("Scoring", "scoring_function")
        else:
            raise ValueError("Unknown routing table type")

@@ -218,7 +232,25 @@ class DatasetsRoutingTable(CommonRoutingTableImpl, Datasets):
    async def get_dataset(
        self, dataset_identifier: str
    ) -> Optional[DatasetDefWithProvider]:
        return self.get_object_by_identifier(identifier)
        return self.get_object_by_identifier(dataset_identifier)

    async def register_dataset(self, dataset_def: DatasetDefWithProvider) -> None:
        await self.register_object(dataset_def)


class ScoringFunctionsRoutingTable(CommonRoutingTableImpl, Scoring):
    async def list_scoring_functions(self) -> List[ScoringFunctionDefWithProvider]:
        objects = []
        for objs in self.registry.values():
            objects.extend(objs)
        return objects

    async def get_scoring_function(
        self, name: str
    ) -> Optional[ScoringFunctionDefWithProvider]:
        return self.get_object_by_identifier(name)

    async def register_scoring_function(
        self, function_def: ScoringFunctionDefWithProvider
    ) -> None:
        await self.register_object(function_def)
@@ -337,7 +337,8 @@ def main(
    import uvicorn

    # FYI this does not do hot-reloads
    listen_host = "::" if not disable_ipv6 else "0.0.0.0"

    listen_host = ["::", "0.0.0.0"] if not disable_ipv6 else "0.0.0.0"
    print(f"Listening on {listen_host}:{port}")
    uvicorn.run(app, host=listen_host, port=port)
@@ -29,7 +29,7 @@ if [ $# -lt 3 ]; then
fi

build_name="$1"
docker_image="llamastack-$build_name"
docker_image="distribution-$build_name"
shift

yaml_config="$1"
@@ -116,7 +116,7 @@ class DatabricksInferenceAdapter(ModelRegistryHelper, Inference):
            "model": self.map_to_provider_model(request.model),
            "prompt": chat_completion_request_to_prompt(request, self.formatter),
            "stream": request.stream,
            **get_sampling_options(request),
            **get_sampling_options(request.sampling_params),
        }

    async def embeddings(
@@ -116,7 +116,7 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference):
        if prompt.startswith("<|begin_of_text|>"):
            prompt = prompt[len("<|begin_of_text|>") :]

        options = get_sampling_options(request)
        options = get_sampling_options(request.sampling_params)
        options.setdefault("max_tokens", 512)

        if fmt := request.response_format:
@@ -110,7 +110,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
return await self._nonstream_completion(request)
|
||||
|
||||
def _get_params_for_completion(self, request: CompletionRequest) -> dict:
|
||||
sampling_options = get_sampling_options(request)
|
||||
sampling_options = get_sampling_options(request.sampling_params)
|
||||
# This is needed since the Ollama API expects num_predict to be set
|
||||
# for early truncation instead of max_tokens.
|
||||
if sampling_options["max_tokens"] is not None:
|
||||
|
@@ -187,7 +187,7 @@ class OllamaInferenceAdapter(Inference, ModelsProtocolPrivate):
return {
|
||||
"model": OLLAMA_SUPPORTED_MODELS[request.model],
|
||||
"prompt": chat_completion_request_to_prompt(request, self.formatter),
|
||||
"options": get_sampling_options(request),
|
||||
"options": get_sampling_options(request.sampling_params),
|
||||
"raw": True,
|
||||
"stream": request.stream,
|
||||
}
|
||||
|
|
|
@@ -24,9 +24,12 @@ from llama_stack.providers.utils.inference.openai_compat import (
OpenAICompatCompletionResponse,
|
||||
process_chat_completion_response,
|
||||
process_chat_completion_stream_response,
|
||||
process_completion_response,
|
||||
process_completion_stream_response,
|
||||
)
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
chat_completion_request_to_model_input_info,
|
||||
completion_request_to_prompt_model_input_info,
|
||||
)
|
||||
|
||||
from .config import InferenceAPIImplConfig, InferenceEndpointImplConfig, TGIImplConfig
|
||||
|
@@ -75,7 +78,98 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
) -> AsyncGenerator:
|
||||
raise NotImplementedError()
|
||||
request = CompletionRequest(
|
||||
model=model,
|
||||
content=content,
|
||||
sampling_params=sampling_params,
|
||||
response_format=response_format,
|
||||
stream=stream,
|
||||
logprobs=logprobs,
|
||||
)
|
||||
if stream:
|
||||
return self._stream_completion(request)
|
||||
else:
|
||||
return await self._nonstream_completion(request)
|
||||
|
||||
def _get_max_new_tokens(self, sampling_params, input_tokens):
|
||||
return min(
|
||||
sampling_params.max_tokens or (self.max_tokens - input_tokens),
|
||||
self.max_tokens - input_tokens - 1,
|
||||
)
|
||||
|
||||
def _build_options(
|
||||
self,
|
||||
sampling_params: Optional[SamplingParams] = None,
|
||||
fmt: ResponseFormat = None,
|
||||
):
|
||||
options = get_sampling_options(sampling_params)
|
||||
# delete key "max_tokens" from options since it's not supported by the API
|
||||
options.pop("max_tokens", None)
|
||||
if fmt:
|
||||
if fmt.type == ResponseFormatType.json_schema.value:
|
||||
options["grammar"] = {
|
||||
"type": "json",
|
||||
"value": fmt.schema,
|
||||
}
|
||||
elif fmt.type == ResponseFormatType.grammar.value:
|
||||
raise ValueError("Grammar response format not supported yet")
|
||||
else:
|
||||
raise ValueError(f"Unexpected response format: {fmt.type}")
|
||||
|
||||
return options
|
||||
|
||||
def _get_params_for_completion(self, request: CompletionRequest) -> dict:
|
||||
prompt, input_tokens = completion_request_to_prompt_model_input_info(
|
||||
request, self.formatter
|
||||
)
|
||||
|
||||
return dict(
|
||||
prompt=prompt,
|
||||
stream=request.stream,
|
||||
details=True,
|
||||
max_new_tokens=self._get_max_new_tokens(
|
||||
request.sampling_params, input_tokens
|
||||
),
|
||||
stop_sequences=["<|eom_id|>", "<|eot_id|>"],
|
||||
**self._build_options(request.sampling_params, request.response_format),
|
||||
)
|
||||
|
||||
async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
|
||||
params = self._get_params_for_completion(request)
|
||||
|
||||
async def _generate_and_convert_to_openai_compat():
|
||||
s = await self.client.text_generation(**params)
|
||||
async for chunk in s:
|
||||
token_result = chunk.token
|
||||
finish_reason = None
|
||||
if chunk.details:
|
||||
finish_reason = chunk.details.finish_reason
|
||||
|
||||
choice = OpenAICompatCompletionChoice(
|
||||
text=token_result.text, finish_reason=finish_reason
|
||||
)
|
||||
yield OpenAICompatCompletionResponse(
|
||||
choices=[choice],
|
||||
)
|
||||
|
||||
stream = _generate_and_convert_to_openai_compat()
|
||||
async for chunk in process_completion_stream_response(stream, self.formatter):
|
||||
yield chunk
|
||||
|
||||
async def _nonstream_completion(self, request: CompletionRequest) -> AsyncGenerator:
|
||||
params = self._get_params_for_completion(request)
|
||||
r = await self.client.text_generation(**params)
|
||||
|
||||
choice = OpenAICompatCompletionChoice(
|
||||
finish_reason=r.details.finish_reason,
|
||||
text="".join(t.text for t in r.details.tokens),
|
||||
)
|
||||
|
||||
response = OpenAICompatCompletionResponse(
|
||||
choices=[choice],
|
||||
)
|
||||
|
||||
return process_completion_response(response, self.formatter)
|
||||
|
||||
async def chat_completion(
|
||||
self,
|
||||
|
@@ -146,29 +240,15 @@ class _HfAdapter(Inference, ModelsProtocolPrivate):
prompt, input_tokens = chat_completion_request_to_model_input_info(
|
||||
request, self.formatter
|
||||
)
|
||||
max_new_tokens = min(
|
||||
request.sampling_params.max_tokens or (self.max_tokens - input_tokens),
|
||||
self.max_tokens - input_tokens - 1,
|
||||
)
|
||||
options = get_sampling_options(request)
|
||||
if fmt := request.response_format:
|
||||
if fmt.type == ResponseFormatType.json_schema.value:
|
||||
options["grammar"] = {
|
||||
"type": "json",
|
||||
"value": fmt.schema,
|
||||
}
|
||||
elif fmt.type == ResponseFormatType.grammar.value:
|
||||
raise ValueError("Grammar response format not supported yet")
|
||||
else:
|
||||
raise ValueError(f"Unexpected response format: {fmt.type}")
|
||||
|
||||
return dict(
|
||||
prompt=prompt,
|
||||
stream=request.stream,
|
||||
details=True,
|
||||
max_new_tokens=max_new_tokens,
|
||||
max_new_tokens=self._get_max_new_tokens(
|
||||
request.sampling_params, input_tokens
|
||||
),
|
||||
stop_sequences=["<|eom_id|>", "<|eot_id|>"],
|
||||
**options,
|
||||
**self._build_options(request.sampling_params, request.response_format),
|
||||
)
|
||||
|
||||
async def embeddings(
|
||||
|
|
|
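The new `_get_max_new_tokens` helper above caps generation so that prompt plus completion stays inside the TGI context window. A minimal standalone sketch of that rule, assuming a 4096-token window purely for illustration:

```python
# Sketch of the budget rule in _get_max_new_tokens; `window` stands in for
# the adapter's self.max_tokens and its default value here is an assumption.
from typing import Optional


def max_new_tokens(requested: Optional[int], input_tokens: int, window: int = 4096) -> int:
    # Fall back to "whatever is left" when no explicit max_tokens was requested,
    # then clamp so prompt + completion never fills the whole window.
    return min(requested or (window - input_tokens), window - input_tokens - 1)


assert max_new_tokens(None, 1000) == 3095   # no request: use the remaining budget
assert max_new_tokens(512, 1000) == 512     # explicit request that fits is kept
assert max_new_tokens(8000, 1000) == 3095   # oversized request is clamped
```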
@@ -131,7 +131,7 @@ class TogetherInferenceAdapter(
yield chunk
|
||||
|
||||
def _get_params(self, request: ChatCompletionRequest) -> dict:
|
||||
options = get_sampling_options(request)
|
||||
options = get_sampling_options(request.sampling_params)
|
||||
if fmt := request.response_format:
|
||||
if fmt.type == ResponseFormatType.json_schema.value:
|
||||
options["response_format"] = {
|
||||
|
|
|
@@ -143,7 +143,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
"model": VLLM_SUPPORTED_MODELS[request.model],
|
||||
"prompt": chat_completion_request_to_prompt(request, self.formatter),
|
||||
"stream": request.stream,
|
||||
**get_sampling_options(request),
|
||||
**get_sampling_options(request.sampling_params),
|
||||
}
|
||||
|
||||
async def embeddings(
|
||||
|
|
|
@@ -11,10 +11,9 @@ from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.apis.datasets import DatasetDef
|
||||
|
||||
from llama_stack.apis.memory_banks import MemoryBankDef
|
||||
|
||||
from llama_stack.apis.models import ModelDef
|
||||
from llama_stack.apis.scoring_functions import ScoringFunctionDef
|
||||
from llama_stack.apis.shields import ShieldDef
|
||||
|
||||
|
||||
|
@@ -25,6 +24,7 @@ class Api(Enum):
agents = "agents"
|
||||
memory = "memory"
|
||||
datasetio = "datasetio"
|
||||
scoring = "scoring"
|
||||
|
||||
telemetry = "telemetry"
|
||||
|
||||
|
@@ -32,6 +32,7 @@ class Api(Enum):
shields = "shields"
|
||||
memory_banks = "memory_banks"
|
||||
datasets = "datasets"
|
||||
scoring_functions = "scoring_functions"
|
||||
|
||||
# built-in API
|
||||
inspect = "inspect"
|
||||
|
@@ -61,6 +62,14 @@ class DatasetsProtocolPrivate(Protocol):
async def register_datasets(self, dataset_def: DatasetDef) -> None: ...
|
||||
|
||||
|
||||
class ScoringFunctionsProtocolPrivate(Protocol):
|
||||
async def list_scoring_functions(self) -> List[ScoringFunctionDef]: ...
|
||||
|
||||
async def register_scoring_function(
|
||||
self, function_def: ScoringFunctionDef
|
||||
) -> None: ...
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ProviderSpec(BaseModel):
|
||||
api: Api
|
||||
|
|
|
@@ -56,9 +56,20 @@ We're working on making LocalInference easier to set up. For now, you'll need t
|
||||
## Preparing a model
|
||||
|
||||
1. Prepare a `.pte` file [following the executorch docs](https://github.com/pytorch/executorch/blob/main/examples/models/llama2/README.md#step-2-prepare-model)
|
||||
1. Prepare a `.pte` file [following the executorch docs](https://github.com/pytorch/executorch/blob/main/examples/models/llama/README.md#step-2-prepare-model)
|
||||
2. Bundle the `.pte` and `tokenizer.model` file into your app
|
||||
|
||||
We now support models quantized using SpinQuant and QAT-LoRA which offer a significant performance boost (demo app on iPhone 13 Pro):
|
||||
|
||||
|
||||
| Llama 3.2 1B | Tokens / Second (total) | | Time-to-First-Token (sec) | |
|
||||
| :---- | :---- | :---- | :---- | :---- |
|
||||
| | Haiku | Paragraph | Haiku | Paragraph |
|
||||
| BF16 | 2.2 | 2.5 | 2.3 | 1.9 |
|
||||
| QAT+LoRA | 7.1 | 3.3 | 0.37 | 0.24 |
|
||||
| SpinQuant | 10.1 | 5.2 | 0.2 | 0.2 |
|
||||
|
||||
|
||||
## Using LocalInference
|
||||
|
||||
1. Instantiate LocalInference with a DispatchQueue. Optionally, pass it into your agents service:
|
||||
|
|
|
@@ -169,7 +169,7 @@ class MetaReferenceAgentsImpl(Agents):
turn_ids: Optional[List[str]] = None,
|
||||
) -> Session:
|
||||
session = await self.persistence_store.get(f"session:{agent_id}:{session_id}")
|
||||
session = Session(**json.loads(session))
|
||||
session = Session(**json.loads(session), turns=[])
|
||||
turns = []
|
||||
if turn_ids:
|
||||
for turn_id in turn_ids:
|
||||
|
|
|
@@ -3,17 +3,20 @@
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
import io
|
||||
from typing import List, Optional
|
||||
|
||||
import pandas
|
||||
|
||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||
|
||||
from llama_stack.apis.datasetio import * # noqa: F403
|
||||
import base64
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from urllib.parse import unquote
|
||||
|
||||
from llama_stack.providers.datatypes import DatasetsProtocolPrivate
|
||||
from llama_stack.providers.utils.memory.vector_store import parse_data_url
|
||||
|
||||
from .config import MetaReferenceDatasetIOConfig
|
||||
|
||||
|
@@ -52,11 +55,20 @@ class PandasDataframeDataset(BaseDataset):
return len(self.df)
|
||||
|
||||
def __getitem__(self, idx):
|
||||
assert self.df is not None, "Dataset not loaded. Please call .load() first"
|
||||
if isinstance(idx, slice):
|
||||
return self.df.iloc[idx].to_dict(orient="records")
|
||||
else:
|
||||
return self.df.iloc[idx].to_dict()
|
||||
|
||||
def _validate_dataset_schema(self, df) -> pandas.DataFrame:
|
||||
# note that we will drop any columns in dataset that are not in the schema
|
||||
df = df[self.dataset_def.dataset_schema.keys()]
|
||||
# check all columns in dataset schema are present
|
||||
assert len(df.columns) == len(self.dataset_def.dataset_schema)
|
||||
# TODO: type checking against column types in dataset schema
|
||||
return df
|
||||
|
||||
def load(self) -> None:
|
||||
if self.df is not None:
|
||||
return
|
||||
|
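`_validate_dataset_schema` above keeps only the columns declared in the dataset schema and requires that all of them are present. A small sketch of that check on a toy dataframe:

```python
# Toy illustration of the column filtering done in _validate_dataset_schema;
# the schema keys mirror the scoring dataset used later in this commit.
import pandas as pd

dataset_schema = {"input_query": "string", "expected_answer": "string", "generated_answer": "string"}
df = pd.DataFrame(
    [{"input_query": "2+2?", "expected_answer": "4", "generated_answer": "4", "extra": "dropped"}]
)

df = df[list(dataset_schema.keys())]            # drop columns not in the schema
assert len(df.columns) == len(dataset_schema)   # every schema column must be present
print(df.to_dict(orient="records"))
```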
@@ -87,7 +99,7 @@ class PandasDataframeDataset(BaseDataset):
else:
|
||||
raise ValueError(f"Unsupported file type: {self.dataset_def.url}")
|
||||
|
||||
self.df = df
|
||||
self.df = self._validate_dataset_schema(df)
|
||||
|
||||
|
||||
class MetaReferenceDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
|
||||
|
@@ -123,7 +135,10 @@ class MetaReferenceDatasetIOImpl(DatasetIO, DatasetsProtocolPrivate):
dataset_info = self.dataset_infos.get(dataset_id)
|
||||
dataset_info.dataset_impl.load()
|
||||
|
||||
if page_token is None:
|
||||
if page_token and not page_token.isnumeric():
|
||||
raise ValueError("Invalid page_token")
|
||||
|
||||
if page_token is None or len(page_token) == 0:
|
||||
next_page_token = 0
|
||||
else:
|
||||
next_page_token = int(page_token)
|
||||
|
|
|
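The page-token handling above treats a missing or empty `page_token` as the start of the dataset and rejects non-numeric tokens. A compact sketch of that cursor logic (the function is a stand-in, not the provider implementation):

```python
# Simplified cursor handling mirroring the page_token checks above;
# rows_in_page=-1 returns everything that is left.
from typing import Any, Dict, List, Optional, Tuple


def paginate(
    rows: List[Dict[str, Any]], rows_in_page: int, page_token: Optional[str] = None
) -> Tuple[List[Dict[str, Any]], str]:
    if page_token and not page_token.isnumeric():
        raise ValueError("Invalid page_token")
    start = int(page_token) if page_token else 0
    end = len(rows) if rows_in_page == -1 else min(start + rows_in_page, len(rows))
    return rows[start:end], str(end)


data = [{"i": i} for i in range(5)]
page, token = paginate(data, rows_in_page=2)                    # rows 0-1, next token "2"
page, token = paginate(data, rows_in_page=2, page_token=token)  # rows 2-3, next token "4"
```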
@@ -30,7 +30,6 @@ from llama_models.llama3.reference_impl.multimodal.model import (
CrossAttentionTransformer,
|
||||
)
|
||||
from llama_models.sku_list import resolve_model
|
||||
|
||||
from pydantic import BaseModel
|
||||
from termcolor import cprint
|
||||
|
||||
|
@@ -43,7 +42,12 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
chat_completion_request_to_messages,
|
||||
)
|
||||
|
||||
from .config import MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig
|
||||
from .config import (
|
||||
Fp8QuantizationConfig,
|
||||
Int4QuantizationConfig,
|
||||
MetaReferenceInferenceConfig,
|
||||
MetaReferenceQuantizedInferenceConfig,
|
||||
)
|
||||
|
||||
|
||||
def model_checkpoint_dir(model) -> str:
|
||||
|
@@ -131,18 +135,34 @@ class Llama:
), f"model_args vocab = {model_args.vocab_size} but tokenizer vocab = {tokenizer.n_words}"
|
||||
|
||||
if isinstance(config, MetaReferenceQuantizedInferenceConfig):
|
||||
from .quantization.loader import convert_to_quantized_model
|
||||
|
||||
# load on CPU in bf16 so that fp8 conversion does not find an
|
||||
# unexpected (fp32, e.g.) datatype
|
||||
torch.set_default_tensor_type(torch.BFloat16Tensor)
|
||||
if model_args.vision_chunk_size > 0:
|
||||
model = CrossAttentionTransformer(model_args)
|
||||
model.setup_cache(model_args.max_batch_size, torch.bfloat16)
|
||||
else:
|
||||
if isinstance(config.quantization, Fp8QuantizationConfig):
|
||||
from .quantization.loader import convert_to_fp8_quantized_model
|
||||
|
||||
# load on CPU in bf16 so that fp8 conversion does not find an
|
||||
# unexpected (fp32, e.g.) datatype
|
||||
torch.set_default_tensor_type(torch.BFloat16Tensor)
|
||||
if model_args.vision_chunk_size > 0:
|
||||
model = CrossAttentionTransformer(model_args)
|
||||
model.setup_cache(model_args.max_batch_size, torch.bfloat16)
|
||||
else:
|
||||
model = Transformer(model_args)
|
||||
model.load_state_dict(state_dict, strict=False)
|
||||
model = convert_to_fp8_quantized_model(model, config, ckpt_dir)
|
||||
elif isinstance(config.quantization, Int4QuantizationConfig):
|
||||
from .quantization.loader import convert_to_int4_quantized_model
|
||||
|
||||
assert (
|
||||
config.quantization.scheme is not None
|
||||
), "Please specify a quantization scheme."
|
||||
|
||||
model = Transformer(model_args)
|
||||
model.load_state_dict(state_dict, strict=False)
|
||||
model = convert_to_quantized_model(model, config, ckpt_dir)
|
||||
model = convert_to_int4_quantized_model(model, model_args, config)
|
||||
model.load_state_dict(state_dict, strict=True)
|
||||
else:
|
||||
raise NotImplementedError(
|
||||
"Currently int4 and fp8 are the only supported quantization methods."
|
||||
)
|
||||
else:
|
||||
if torch.cuda.is_bf16_supported():
|
||||
torch.set_default_tensor_type(torch.cuda.BFloat16Tensor)
|
||||
|
|
|
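The loader now branches on the type of `config.quantization` (fp8 vs. int4) rather than only on the config class. A trimmed sketch of that dispatch with stand-in config types:

```python
# Stand-in config classes for illustration; the real ones come from the
# meta-reference inference config and llama_stack.apis.inference.
from dataclasses import dataclass
from typing import Optional, Union


@dataclass
class Fp8QuantizationConfig:
    pass


@dataclass
class Int4QuantizationConfig:
    scheme: Optional[str] = "int4_weight_int8_dynamic_activation"


def pick_converter(quantization: Union[Fp8QuantizationConfig, Int4QuantizationConfig]) -> str:
    if isinstance(quantization, Fp8QuantizationConfig):
        return "convert_to_fp8_quantized_model"
    if isinstance(quantization, Int4QuantizationConfig):
        assert quantization.scheme is not None, "Please specify a quantization scheme."
        return "convert_to_int4_quantized_model"
    raise NotImplementedError("Currently int4 and fp8 are the only supported quantization methods.")


print(pick_converter(Int4QuantizationConfig()))   # convert_to_int4_quantized_model
```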
@@ -8,19 +8,25 @@
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
|
||||
|
||||
import os
|
||||
from typing import Optional
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
import torch
|
||||
|
||||
from fairscale.nn.model_parallel.layers import ColumnParallelLinear, RowParallelLinear
|
||||
from fairscale.nn.model_parallel.mappings import reduce_from_model_parallel_region
|
||||
from llama_models.datatypes import CheckpointQuantizationFormat
|
||||
from llama_models.llama3.reference_impl.model import Transformer, TransformerBlock
|
||||
|
||||
from llama_models.datatypes import CheckpointQuantizationFormat
|
||||
|
||||
from llama_models.llama3.api.args import ModelArgs
|
||||
from llama_models.llama3.reference_impl.model import Transformer, TransformerBlock
|
||||
from llama_models.sku_list import resolve_model
|
||||
from termcolor import cprint
|
||||
from torch import Tensor
|
||||
from torch import nn, Tensor
|
||||
|
||||
from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear
|
||||
|
||||
from llama_stack.apis.inference import QuantizationType
|
||||
from llama_stack.apis.inference.inference import Int4QuantizationConfig
|
||||
|
||||
from llama_stack.providers.impls.meta_reference.inference.config import (
|
||||
MetaReferenceQuantizedInferenceConfig,
|
||||
|
@@ -37,7 +43,7 @@ def swiglu_wrapper(
return reduce_from_model_parallel_region(out)
|
||||
|
||||
|
||||
def convert_to_quantized_model(
|
||||
def convert_to_fp8_quantized_model(
|
||||
model: Transformer,
|
||||
config: MetaReferenceQuantizedInferenceConfig,
|
||||
checkpoint_dir: str,
|
||||
|
@@ -99,3 +105,241 @@ def convert_to_quantized_model(
if not isinstance(parameter, Fp8ScaledWeights):
|
||||
parameter.data = parameter.to(device="cuda")
|
||||
return model
|
||||
|
||||
|
||||
class Int8DynActInt4WeightLinearLoRA(Int8DynActInt4WeightLinear):
|
||||
"""
|
||||
Int8DynActInt4WeightLinear with LoRA adaptor.
|
||||
|
||||
Args:
|
||||
in_features: Number of input features.
|
||||
out_features: Number of output features.
|
||||
bias: Whether to use bias.
|
||||
device: Device to use.
|
||||
group_size: Group size for quantization.
|
||||
precision: Precision of quantization.
|
||||
scales_precision: Precision of scales.
|
||||
lora_rank: Rank of LoRA adaptor.
|
||||
lora_scale: Scale of LoRA adaptor.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
in_features: int,
|
||||
out_features: int,
|
||||
bias=False,
|
||||
device=None,
|
||||
# quantization parameters
|
||||
group_size: int = 256,
|
||||
precision: torch.dtype = torch.float32,
|
||||
scales_precision: torch.dtype = torch.float32,
|
||||
# LoRA parameters
|
||||
lora_rank: Optional[int] = None,
|
||||
lora_scale: Optional[float] = None,
|
||||
) -> None:
|
||||
super().__init__(
|
||||
in_features,
|
||||
out_features,
|
||||
bias=bias,
|
||||
device=device,
|
||||
groupsize=group_size,
|
||||
precision=precision,
|
||||
scales_precision=scales_precision,
|
||||
)
|
||||
if lora_rank is not None:
|
||||
assert lora_scale is not None, "Please specify lora scale for LoRA."
|
||||
# Low-rank adaptation. See paper for more details: https://arxiv.org/abs/2106.09685
|
||||
self.adaptor = nn.Sequential()
|
||||
self.adaptor.add_module("A", nn.Linear(in_features, lora_rank, bias=False))
|
||||
self.adaptor.add_module("B", nn.Linear(lora_rank, out_features, bias=False))
|
||||
self.lora_scale = lora_scale
|
||||
else:
|
||||
self.adaptor = None
|
||||
self.lora_scale = None
|
||||
self._register_load_state_dict_pre_hook(self.load_hook)
|
||||
|
||||
def load_hook(
|
||||
self,
|
||||
state_dict: Dict[str, Any],
|
||||
prefix: str,
|
||||
local_metadata: Dict[str, Any],
|
||||
strict: bool,
|
||||
missing_keys: List[str],
|
||||
unexpected_keys: List[str],
|
||||
error_msgs: List[str],
|
||||
) -> None:
|
||||
"""A hook to load the quantized weights from the state dict."""
|
||||
if prefix + "zeros" not in state_dict:
|
||||
# Zero-point may not be saved in the state dict. In this case, we assume it's zero.
|
||||
assert prefix + "scales" in state_dict
|
||||
state_dict[prefix + "zeros"] = torch.zeros_like(
|
||||
state_dict[prefix + "scales"]
|
||||
)
|
||||
|
||||
def forward(self, input_: torch.Tensor) -> torch.Tensor:
|
||||
module_out = super().forward(input_)
|
||||
if self.adaptor is not None:
|
||||
adaptor_out = self.adaptor(input_) * self.lora_scale
|
||||
return module_out + adaptor_out
|
||||
return module_out
|
||||
|
||||
|
||||
class Int8WeightEmbedding(torch.nn.Embedding):
|
||||
"""An embedding layer to load int8 weights.
|
||||
|
||||
Args:
|
||||
num_embeddings: Number of embeddings.
|
||||
embedding_dim: Embedding dimension.
|
||||
padding_idx: Padding index.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
num_embeddings: int,
|
||||
embedding_dim: int,
|
||||
padding_idx: int,
|
||||
device=None,
|
||||
) -> None:
|
||||
super().__init__(num_embeddings, embedding_dim, padding_idx, device=device)
|
||||
|
||||
self._register_load_state_dict_pre_hook(self.load_hook)
|
||||
|
||||
def load_hook(
|
||||
self,
|
||||
state_dict: Dict[str, Any],
|
||||
prefix: str,
|
||||
local_metadata: Dict[str, Any],
|
||||
strict: bool,
|
||||
missing_keys: List[str],
|
||||
unexpected_keys: List[str],
|
||||
error_msgs: List[str],
|
||||
) -> None:
|
||||
"""A hook to load the quantized embedding weight and scales from the state dict."""
|
||||
weights = state_dict.pop(prefix + "weight")
|
||||
scales = state_dict.pop(prefix + "scales")
|
||||
state_dict[prefix + "weight"] = weights * scales
|
||||
|
||||
|
||||
class Int8WeightLinear(torch.nn.Linear):
|
||||
"""A linear layer to load int8 weights.
|
||||
|
||||
Args:
|
||||
in_features: Number of input features.
|
||||
out_features: Number of output features.
|
||||
bias: Whether to use bias.
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
self, in_features: int, out_features: int, bias: bool = True, device=None
|
||||
) -> None:
|
||||
super().__init__(in_features, out_features, bias, device=device)
|
||||
|
||||
self._register_load_state_dict_pre_hook(self.load_hook)
|
||||
|
||||
def load_hook(
|
||||
self,
|
||||
state_dict: Dict[str, Any],
|
||||
prefix: str,
|
||||
local_metadata: Dict[str, Any],
|
||||
strict: bool,
|
||||
missing_keys: List[str],
|
||||
unexpected_keys: List[str],
|
||||
error_msgs: List[str],
|
||||
) -> None:
|
||||
"""A hook to load the quantized linear weight and scales from the state dict."""
|
||||
weights = state_dict.pop(prefix + "weight")
|
||||
scales = state_dict.pop(prefix + "scales")
|
||||
state_dict[prefix + "weight"] = weights * scales
|
||||
|
||||
|
||||
def _prepare_model_int4_weight_int8_dynamic_activation(
|
||||
model: torch.nn.Module,
|
||||
group_size: int,
|
||||
lora_rank: Optional[int],
|
||||
lora_scale: Optional[float],
|
||||
):
|
||||
"""Prepare the model for int4 weight and int8 dynamic activation quantization.
|
||||
|
||||
Note that the weights of embedding and output layers are quantized to int8.
|
||||
"""
|
||||
device = None
|
||||
for module_name, module in model.named_children():
|
||||
if module_name == "output":
|
||||
quantized_module = Int8WeightLinear(
|
||||
in_features=module.in_features,
|
||||
out_features=module.out_features,
|
||||
bias=module.bias,
|
||||
device=device,
|
||||
)
|
||||
del module
|
||||
setattr(model, module_name, quantized_module)
|
||||
elif module_name == "tok_embeddings":
|
||||
quantized_module = Int8WeightEmbedding(
|
||||
num_embeddings=module.num_embeddings,
|
||||
embedding_dim=module.embedding_dim,
|
||||
padding_idx=module.padding_idx,
|
||||
device=device,
|
||||
)
|
||||
del module
|
||||
setattr(model, module_name, quantized_module)
|
||||
elif isinstance(module, (ColumnParallelLinear, RowParallelLinear, nn.Linear)):
|
||||
quantized_module = Int8DynActInt4WeightLinearLoRA(
|
||||
in_features=module.in_features,
|
||||
out_features=module.out_features,
|
||||
bias=False,
|
||||
group_size=group_size,
|
||||
lora_rank=lora_rank,
|
||||
lora_scale=lora_scale,
|
||||
device=device,
|
||||
)
|
||||
del module
|
||||
setattr(model, module_name, quantized_module)
|
||||
else:
|
||||
_prepare_model_int4_weight_int8_dynamic_activation(
|
||||
module, group_size, lora_rank, lora_scale
|
||||
)
|
||||
|
||||
return model
|
||||
|
||||
|
||||
def convert_to_int4_quantized_model(
|
||||
model: Transformer,
|
||||
model_args: ModelArgs,
|
||||
config: MetaReferenceQuantizedInferenceConfig,
|
||||
) -> Transformer:
|
||||
"""Convert the model to int4 quantized model."""
|
||||
|
||||
quant_config = config.quantization
|
||||
if not isinstance(quant_config, Int4QuantizationConfig):
|
||||
raise ValueError("Only int4 quantization is supported")
|
||||
|
||||
if quant_config.type != QuantizationType.int4.value:
|
||||
raise ValueError("Only int4 quantization is supported")
|
||||
|
||||
if quant_config.scheme != "int4_weight_int8_dynamic_activation":
|
||||
raise NotImplementedError(
|
||||
"Only int4 quantization with 'int4_weight_int8_dynamic_activation' scheme is supported."
|
||||
)
|
||||
|
||||
if model_args.quantization_args is None:
|
||||
raise ValueError("'quantization_args' cannot be None. Please specify it.")
|
||||
|
||||
group_size = model_args.quantization_args.group_size
|
||||
if group_size is None:
|
||||
raise ValueError(
|
||||
"'group_size' cannot be None in 'quantization_args'. Please specify it."
|
||||
)
|
||||
|
||||
if model_args.lora_args is None:
|
||||
# Certain quantized models (e.g., SpinQuant) may not have LoRA.
|
||||
lora_rank = None
|
||||
lora_scale = None
|
||||
else:
|
||||
lora_rank = model_args.lora_args.rank
|
||||
lora_scale = model_args.lora_args.scale
|
||||
|
||||
_prepare_model_int4_weight_int8_dynamic_activation(
|
||||
model, group_size, lora_rank, lora_scale
|
||||
)
|
||||
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
||||
return model.to(device)
|
||||
|
|
|
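`Int8DynActInt4WeightLinearLoRA` above adds a scaled low-rank correction (LoRA, https://arxiv.org/abs/2106.09685) on top of the quantized matmul. A minimal PyTorch sketch of that forward pattern, with a plain `nn.Linear` standing in for the torchao quantized layer:

```python
# Minimal LoRA-on-top-of-a-frozen-layer sketch; the base layer here is a
# regular Linear, whereas the real code uses Int8DynActInt4WeightLinear.
import torch
from torch import nn


class LinearWithLoRA(nn.Module):
    def __init__(self, in_features: int, out_features: int, rank: int, scale: float):
        super().__init__()
        self.base = nn.Linear(in_features, out_features, bias=False)   # quantized in the real code
        self.adaptor = nn.Sequential(
            nn.Linear(in_features, rank, bias=False),    # "A"
            nn.Linear(rank, out_features, bias=False),   # "B"
        )
        self.scale = scale

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # quantized output + scaled low-rank correction
        return self.base(x) + self.adaptor(x) * self.scale


layer = LinearWithLoRA(in_features=64, out_features=64, rank=8, scale=2.0)
print(layer(torch.randn(2, 64)).shape)   # torch.Size([2, 64])
```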
@@ -47,7 +47,9 @@ class FaissIndex(EmbeddingIndex):
|
||||
self.index.add(np.array(embeddings).astype(np.float32))
|
||||
|
||||
async def query(self, embedding: NDArray, k: int) -> QueryDocumentsResponse:
|
||||
async def query(
|
||||
self, embedding: NDArray, k: int, score_threshold: float
|
||||
) -> QueryDocumentsResponse:
|
||||
distances, indices = self.index.search(
|
||||
embedding.reshape(1, -1).astype(np.float32), k
|
||||
)
|
||||
|
|
|
@@ -0,0 +1,21 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
from typing import Dict
|
||||
|
||||
from llama_stack.distribution.datatypes import Api, ProviderSpec
|
||||
|
||||
from .config import MetaReferenceScoringConfig
|
||||
|
||||
|
||||
async def get_provider_impl(
|
||||
config: MetaReferenceScoringConfig,
|
||||
deps: Dict[Api, ProviderSpec],
|
||||
):
|
||||
from .scoring import MetaReferenceScoringImpl
|
||||
|
||||
impl = MetaReferenceScoringImpl(config, deps[Api.datasetio], deps[Api.datasets])
|
||||
await impl.initialize()
|
||||
return impl
|
|
@@ -0,0 +1,9 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
from llama_stack.apis.scoring import * # noqa: F401, F403
|
||||
|
||||
|
||||
class MetaReferenceScoringConfig(BaseModel): ...
|
|
@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
|
@@ -0,0 +1,37 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, List
|
||||
from llama_stack.apis.scoring_functions import * # noqa: F401, F403
|
||||
from llama_stack.apis.scoring import * # noqa: F401, F403
|
||||
|
||||
|
||||
class BaseScorer(ABC):
|
||||
"""
|
||||
Base interface class for all meta-reference scorers.
|
||||
Each scorer needs to implement the following methods:
|
||||
- score_row(self, row)
|
||||
- aggregate(self, scorer_results)
|
||||
"""
|
||||
|
||||
scoring_function_def: ScoringFunctionDef
|
||||
|
||||
def __init__(self, *args, **kwargs) -> None:
|
||||
super().__init__(*args, **kwargs)
|
||||
|
||||
def __str__(self) -> str:
|
||||
return self.__class__.__name__
|
||||
|
||||
@abstractmethod
|
||||
def score_row(self, input_row: Dict[str, Any]) -> ScoringResultRow:
|
||||
raise NotImplementedError()
|
||||
|
||||
@abstractmethod
|
||||
def aggregate(self, scoring_results: List[ScoringResultRow]) -> Dict[str, Any]:
|
||||
raise NotImplementedError()
|
||||
|
||||
def score(self, input_rows: List[Dict[str, Any]]) -> List[ScoringResultRow]:
|
||||
return [self.score_row(input_row) for input_row in input_rows]
|
|
@@ -0,0 +1,49 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.providers.impls.meta_reference.scoring.scorer.base_scorer import (
|
||||
BaseScorer,
|
||||
)
|
||||
from llama_stack.apis.scoring_functions import * # noqa: F401, F403
|
||||
from llama_stack.apis.scoring import * # noqa: F401, F403
|
||||
from llama_stack.apis.common.type_system import * # noqa: F403
|
||||
|
||||
|
||||
class EqualityScorer(BaseScorer):
|
||||
"""
|
||||
A scorer that assigns a score of 1.0 if the input string matches the target string, and 0.0 otherwise.
|
||||
"""
|
||||
|
||||
scoring_function_def = ScoringFunctionDef(
|
||||
identifier="equality",
|
||||
description="Returns 1.0 if the input is equal to the target, 0.0 otherwise.",
|
||||
parameters=[],
|
||||
return_type=NumberType(),
|
||||
)
|
||||
|
||||
def score_row(self, input_row: Dict[str, Any]) -> ScoringResultRow:
|
||||
assert "expected_answer" in input_row, "Expected answer not found in input row."
|
||||
assert (
|
||||
"generated_answer" in input_row
|
||||
), "Generated answer not found in input row."
|
||||
|
||||
expected_answer = input_row["expected_answer"]
|
||||
generated_answer = input_row["generated_answer"]
|
||||
score = 1.0 if expected_answer == generated_answer else 0.0
|
||||
return {
|
||||
"score": score,
|
||||
}
|
||||
|
||||
def aggregate(self, scoring_results: List[ScoringResultRow]) -> Dict[str, Any]:
|
||||
assert len(scoring_results) > 0, "Empty scoring results provided."
|
||||
num_correct = sum(result["score"] for result in scoring_results)
|
||||
avg_score = num_correct / len(scoring_results)
|
||||
|
||||
return {
|
||||
"accuracy": avg_score,
|
||||
"num_correct": num_correct,
|
||||
"num_total": len(scoring_results),
|
||||
}
|
109
llama_stack/providers/impls/meta_reference/scoring/scoring.py
Normal file
|
@@ -0,0 +1,109 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
from typing import List
|
||||
|
||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||
from llama_stack.apis.scoring import * # noqa: F403
|
||||
from llama_stack.apis.scoring_functions import * # noqa: F403
|
||||
from llama_stack.apis.common.type_system import * # noqa: F403
|
||||
from llama_stack.apis.datasetio import * # noqa: F403
|
||||
from llama_stack.apis.datasets import * # noqa: F403
|
||||
|
||||
from llama_stack.providers.datatypes import ScoringFunctionsProtocolPrivate
|
||||
from llama_stack.providers.impls.meta_reference.scoring.scorer.equality_scorer import (
|
||||
EqualityScorer,
|
||||
)
|
||||
|
||||
from .config import MetaReferenceScoringConfig
|
||||
|
||||
SUPPORTED_SCORERS = [
|
||||
EqualityScorer,
|
||||
]
|
||||
|
||||
SCORER_REGISTRY = {x.scoring_function_def.identifier: x for x in SUPPORTED_SCORERS}
|
||||
|
||||
|
||||
class MetaReferenceScoringImpl(Scoring, ScoringFunctionsProtocolPrivate):
|
||||
def __init__(
|
||||
self,
|
||||
config: MetaReferenceScoringConfig,
|
||||
datasetio_api: DatasetIO,
|
||||
datasets_api: Datasets,
|
||||
) -> None:
|
||||
self.config = config
|
||||
self.datasetio_api = datasetio_api
|
||||
self.datasets_api = datasets_api
|
||||
|
||||
async def initialize(self) -> None: ...
|
||||
|
||||
async def shutdown(self) -> None: ...
|
||||
|
||||
async def list_scoring_functions(self) -> List[ScoringFunctionDef]:
|
||||
return [x.scoring_function_def for x in SUPPORTED_SCORERS]
|
||||
|
||||
async def register_scoring_function(self, function_def: ScoringFunctionDef) -> None:
|
||||
raise NotImplementedError(
|
||||
"Dynamically registering scoring functions is not supported"
|
||||
)
|
||||
|
||||
async def validate_scoring_input_dataset_schema(self, dataset_id: str) -> None:
|
||||
dataset_def = await self.datasets_api.get_dataset(dataset_identifier=dataset_id)
|
||||
if not dataset_def.dataset_schema or len(dataset_def.dataset_schema) == 0:
|
||||
raise ValueError(
|
||||
f"Dataset {dataset_id} does not have a schema defined. Please define a schema for the dataset."
|
||||
)
|
||||
|
||||
for required_column in ["generated_answer", "expected_answer", "input_query"]:
|
||||
if required_column not in dataset_def.dataset_schema:
|
||||
raise ValueError(
|
||||
f"Dataset {dataset_id} does not have a '{required_column}' column."
|
||||
)
|
||||
if dataset_def.dataset_schema[required_column].type != "string":
|
||||
raise ValueError(
|
||||
f"Dataset {dataset_id} does not have a '{required_column}' column of type 'string'."
|
||||
)
|
||||
|
||||
async def score_batch(
|
||||
self,
|
||||
dataset_id: str,
|
||||
scoring_functions: List[str],
|
||||
save_results_dataset: bool = False,
|
||||
) -> ScoreBatchResponse:
|
||||
await self.validate_scoring_input_dataset_schema(dataset_id=dataset_id)
|
||||
all_rows = await self.datasetio_api.get_rows_paginated(
|
||||
dataset_id=dataset_id,
|
||||
rows_in_page=-1,
|
||||
)
|
||||
res = await self.score(
|
||||
input_rows=all_rows.rows, scoring_functions=scoring_functions
|
||||
)
|
||||
if save_results_dataset:
|
||||
# TODO: persist and register dataset on to server for reading
|
||||
# self.datasets_api.register_dataset()
|
||||
raise NotImplementedError("Save results dataset not implemented yet")
|
||||
|
||||
return ScoreBatchResponse(
|
||||
results=res.results,
|
||||
)
|
||||
|
||||
async def score(
|
||||
self, input_rows: List[Dict[str, Any]], scoring_functions: List[str]
|
||||
) -> ScoreResponse:
|
||||
res = {}
|
||||
for scoring_fn_id in scoring_functions:
|
||||
if scoring_fn_id not in SCORER_REGISTRY:
|
||||
raise ValueError(f"Scoring function {scoring_fn_id} is not supported.")
|
||||
scorer = SCORER_REGISTRY[scoring_fn_id]()
|
||||
score_results = scorer.score(input_rows)
|
||||
agg_results = scorer.aggregate(score_results)
|
||||
res[scoring_fn_id] = ScoringResult(
|
||||
score_rows=score_results,
|
||||
aggregated_results=agg_results,
|
||||
)
|
||||
|
||||
return ScoreResponse(
|
||||
results=res,
|
||||
)
|
|
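The meta-reference scoring flow above is: score each row independently with the selected scorer, then aggregate the per-row results. A stripped-down sketch of the `EqualityScorer` behaviour on plain dicts:

```python
# Reduced sketch of the score_row / aggregate pattern above, without the
# ScoringResultRow / ScoringResult wrapper types.
from typing import Any, Dict, List


def score_row(row: Dict[str, Any]) -> Dict[str, Any]:
    return {"score": 1.0 if row["expected_answer"] == row["generated_answer"] else 0.0}


def aggregate(rows: List[Dict[str, Any]]) -> Dict[str, Any]:
    num_correct = sum(r["score"] for r in rows)
    return {"accuracy": num_correct / len(rows), "num_correct": num_correct, "num_total": len(rows)}


rows = [
    {"input_query": "What is the capital of France?", "generated_answer": "London", "expected_answer": "Paris"},
    {"input_query": "What is the currency of Japan?", "generated_answer": "Yen", "expected_answer": "Yen"},
]
print(aggregate([score_row(r) for r in rows]))   # {'accuracy': 0.5, 'num_correct': 1.0, 'num_total': 2}
```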
@@ -15,13 +15,24 @@ class VLLMConfig(BaseModel):
"""Configuration for the vLLM inference provider."""
|
||||
|
||||
model: str = Field(
|
||||
default="Llama3.1-8B-Instruct",
|
||||
default="Llama3.2-3B-Instruct",
|
||||
description="Model descriptor from `llama model list`",
|
||||
)
|
||||
tensor_parallel_size: int = Field(
|
||||
default=1,
|
||||
description="Number of tensor parallel replicas (number of GPUs to use).",
|
||||
)
|
||||
max_tokens: int = Field(
|
||||
default=4096,
|
||||
description="Maximum number of tokens to generate.",
|
||||
)
|
||||
enforce_eager: bool = Field(
|
||||
default=False,
|
||||
description="Whether to use eager mode for inference (otherwise cuda graphs are used).",
|
||||
)
|
||||
gpu_memory_utilization: float = Field(
|
||||
default=0.3,
|
||||
)
|
||||
|
||||
@field_validator("model")
|
||||
@classmethod
|
||||
|
|
|
@@ -7,11 +7,12 @@
import logging
|
||||
import os
|
||||
import uuid
|
||||
from typing import Any, AsyncGenerator
|
||||
from typing import AsyncGenerator, Optional
|
||||
|
||||
from llama_models.llama3.api.chat_format import ChatFormat
|
||||
from llama_models.llama3.api.datatypes import * # noqa: F403
|
||||
from llama_models.llama3.api.tokenizer import Tokenizer
|
||||
from llama_models.sku_list import resolve_model
|
||||
|
||||
from vllm.engine.arg_utils import AsyncEngineArgs
|
||||
from vllm.engine.async_llm_engine import AsyncLLMEngine
|
||||
|
@@ -19,7 +20,7 @@ from vllm.sampling_params import SamplingParams as VLLMSamplingParams
|
||||
from llama_stack.apis.inference import * # noqa: F403
|
||||
|
||||
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
|
||||
from llama_stack.providers.datatypes import ModelDef, ModelsProtocolPrivate
|
||||
from llama_stack.providers.utils.inference.openai_compat import (
|
||||
OpenAICompatCompletionChoice,
|
||||
OpenAICompatCompletionResponse,
|
||||
|
@@ -40,74 +41,15 @@ def _random_uuid() -> str:
return str(uuid.uuid4().hex)
|
||||
|
||||
|
||||
def _vllm_sampling_params(sampling_params: Any) -> VLLMSamplingParams:
|
||||
"""Convert sampling params to vLLM sampling params."""
|
||||
if sampling_params is None:
|
||||
return VLLMSamplingParams()
|
||||
|
||||
# TODO convert what I saw in my first test ... but surely there's more to do here
|
||||
kwargs = {
|
||||
"temperature": sampling_params.temperature,
|
||||
}
|
||||
if sampling_params.top_k >= 1:
|
||||
kwargs["top_k"] = sampling_params.top_k
|
||||
if sampling_params.top_p:
|
||||
kwargs["top_p"] = sampling_params.top_p
|
||||
if sampling_params.max_tokens >= 1:
|
||||
kwargs["max_tokens"] = sampling_params.max_tokens
|
||||
if sampling_params.repetition_penalty > 0:
|
||||
kwargs["repetition_penalty"] = sampling_params.repetition_penalty
|
||||
|
||||
return VLLMSamplingParams(**kwargs)
|
||||
|
||||
|
||||
class VLLMInferenceImpl(ModelRegistryHelper, Inference):
|
||||
class VLLMInferenceImpl(Inference, ModelsProtocolPrivate):
|
||||
"""Inference implementation for vLLM."""
|
||||
|
||||
HF_MODEL_MAPPINGS = {
|
||||
# TODO: seems like we should be able to build this table dynamically ...
|
||||
"Llama3.1-8B": "meta-llama/Llama-3.1-8B",
|
||||
"Llama3.1-70B": "meta-llama/Llama-3.1-70B",
|
||||
"Llama3.1-405B:bf16-mp8": "meta-llama/Llama-3.1-405B",
|
||||
"Llama3.1-405B": "meta-llama/Llama-3.1-405B-FP8",
|
||||
"Llama3.1-405B:bf16-mp16": "meta-llama/Llama-3.1-405B",
|
||||
"Llama3.1-8B-Instruct": "meta-llama/Llama-3.1-8B-Instruct",
|
||||
"Llama3.1-70B-Instruct": "meta-llama/Llama-3.1-70B-Instruct",
|
||||
"Llama3.1-405B-Instruct:bf16-mp8": "meta-llama/Llama-3.1-405B-Instruct",
|
||||
"Llama3.1-405B-Instruct": "meta-llama/Llama-3.1-405B-Instruct-FP8",
|
||||
"Llama3.1-405B-Instruct:bf16-mp16": "meta-llama/Llama-3.1-405B-Instruct",
|
||||
"Llama3.2-1B": "meta-llama/Llama-3.2-1B",
|
||||
"Llama3.2-3B": "meta-llama/Llama-3.2-3B",
|
||||
"Llama3.2-11B-Vision": "meta-llama/Llama-3.2-11B-Vision",
|
||||
"Llama3.2-90B-Vision": "meta-llama/Llama-3.2-90B-Vision",
|
||||
"Llama3.2-1B-Instruct": "meta-llama/Llama-3.2-1B-Instruct",
|
||||
"Llama3.2-3B-Instruct": "meta-llama/Llama-3.2-3B-Instruct",
|
||||
"Llama3.2-11B-Vision-Instruct": "meta-llama/Llama-3.2-11B-Vision-Instruct",
|
||||
"Llama3.2-90B-Vision-Instruct": "meta-llama/Llama-3.2-90B-Vision-Instruct",
|
||||
"Llama-Guard-3-11B-Vision": "meta-llama/Llama-Guard-3-11B-Vision",
|
||||
"Llama-Guard-3-1B:int4-mp1": "meta-llama/Llama-Guard-3-1B-INT4",
|
||||
"Llama-Guard-3-1B": "meta-llama/Llama-Guard-3-1B",
|
||||
"Llama-Guard-3-8B": "meta-llama/Llama-Guard-3-8B",
|
||||
"Llama-Guard-3-8B:int8-mp1": "meta-llama/Llama-Guard-3-8B-INT8",
|
||||
"Prompt-Guard-86M": "meta-llama/Prompt-Guard-86M",
|
||||
"Llama-Guard-2-8B": "meta-llama/Llama-Guard-2-8B",
|
||||
}
|
||||
|
||||
def __init__(self, config: VLLMConfig):
|
||||
Inference.__init__(self)
|
||||
ModelRegistryHelper.__init__(
|
||||
self,
|
||||
stack_to_provider_models_map=self.HF_MODEL_MAPPINGS,
|
||||
)
|
||||
self.config = config
|
||||
self.engine = None
|
||||
|
||||
tokenizer = Tokenizer.get_instance()
|
||||
self.formatter = ChatFormat(tokenizer)
|
||||
self.formatter = ChatFormat(Tokenizer.get_instance())
|
||||
|
||||
async def initialize(self):
|
||||
"""Initialize the vLLM inference adapter."""
|
||||
|
||||
log.info("Initializing vLLM inference adapter")
|
||||
|
||||
# Disable usage stats reporting. This would be a surprising thing for most
|
||||
|
@@ -116,15 +58,22 @@ class VLLMInferenceImpl(ModelRegistryHelper, Inference):
if "VLLM_NO_USAGE_STATS" not in os.environ:
|
||||
os.environ["VLLM_NO_USAGE_STATS"] = "1"
|
||||
|
||||
hf_model = self.HF_MODEL_MAPPINGS.get(self.config.model)
|
||||
model = resolve_model(self.config.model)
|
||||
if model is None:
|
||||
raise ValueError(f"Unknown model {self.config.model}")
|
||||
|
||||
if model.huggingface_repo is None:
|
||||
raise ValueError(f"Model {self.config.model} needs a huggingface repo")
|
||||
|
||||
# TODO -- there are a ton of options supported here ...
|
||||
engine_args = AsyncEngineArgs()
|
||||
engine_args.model = hf_model
|
||||
# We will need a new config item for this in the future if model support is more broad
|
||||
# than it is today (llama only)
|
||||
engine_args.tokenizer = hf_model
|
||||
engine_args.tensor_parallel_size = self.config.tensor_parallel_size
|
||||
engine_args = AsyncEngineArgs(
|
||||
model=model.huggingface_repo,
|
||||
tokenizer=model.huggingface_repo,
|
||||
tensor_parallel_size=self.config.tensor_parallel_size,
|
||||
enforce_eager=self.config.enforce_eager,
|
||||
gpu_memory_utilization=self.config.gpu_memory_utilization,
|
||||
guided_decoding_backend="lm-format-enforcer",
|
||||
)
|
||||
|
||||
self.engine = AsyncLLMEngine.from_engine_args(engine_args)
|
||||
|
||||
|
@@ -134,13 +83,47 @@ class VLLMInferenceImpl(ModelRegistryHelper, Inference):
if self.engine:
|
||||
self.engine.shutdown_background_loop()
|
||||
|
||||
async def register_model(self, model: ModelDef) -> None:
|
||||
raise ValueError(
|
||||
"You cannot dynamically add a model to a running vllm instance"
|
||||
)
|
||||
|
||||
async def list_models(self) -> List[ModelDef]:
|
||||
return [
|
||||
ModelDef(
|
||||
identifier=self.config.model,
|
||||
llama_model=self.config.model,
|
||||
)
|
||||
]
|
||||
|
||||
def _sampling_params(self, sampling_params: SamplingParams) -> VLLMSamplingParams:
|
||||
if sampling_params is None:
|
||||
return VLLMSamplingParams(max_tokens=self.config.max_tokens)
|
||||
|
||||
# TODO convert what I saw in my first test ... but surely there's more to do here
|
||||
kwargs = {
|
||||
"temperature": sampling_params.temperature,
|
||||
"max_tokens": self.config.max_tokens,
|
||||
}
|
||||
if sampling_params.top_k:
|
||||
kwargs["top_k"] = sampling_params.top_k
|
||||
if sampling_params.top_p:
|
||||
kwargs["top_p"] = sampling_params.top_p
|
||||
if sampling_params.max_tokens:
|
||||
kwargs["max_tokens"] = sampling_params.max_tokens
|
||||
if sampling_params.repetition_penalty > 0:
|
||||
kwargs["repetition_penalty"] = sampling_params.repetition_penalty
|
||||
|
||||
return VLLMSamplingParams(**kwargs)
|
||||
|
||||
async def completion(
|
||||
self,
|
||||
model: str,
|
||||
content: InterleavedTextMedia,
|
||||
sampling_params: Any | None = ...,
|
||||
stream: bool | None = False,
|
||||
logprobs: LogProbConfig | None = None,
|
||||
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||
response_format: Optional[ResponseFormat] = None,
|
||||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
) -> CompletionResponse | CompletionResponseStreamChunk:
|
||||
log.info("vLLM completion")
|
||||
messages = [UserMessage(content=content)]
|
||||
|
@@ -155,13 +138,14 @@ class VLLMInferenceImpl(ModelRegistryHelper, Inference):
async def chat_completion(
|
||||
self,
|
||||
model: str,
|
||||
messages: list[Message],
|
||||
sampling_params: Any | None = ...,
|
||||
tools: list[ToolDefinition] | None = ...,
|
||||
tool_choice: ToolChoice | None = ...,
|
||||
tool_prompt_format: ToolPromptFormat | None = ...,
|
||||
stream: bool | None = False,
|
||||
logprobs: LogProbConfig | None = None,
|
||||
messages: List[Message],
|
||||
sampling_params: Optional[SamplingParams] = SamplingParams(),
|
||||
tools: Optional[List[ToolDefinition]] = None,
|
||||
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
|
||||
tool_prompt_format: Optional[ToolPromptFormat] = ToolPromptFormat.json,
|
||||
response_format: Optional[ResponseFormat] = None,
|
||||
stream: Optional[bool] = False,
|
||||
logprobs: Optional[LogProbConfig] = None,
|
||||
) -> ChatCompletionResponse | ChatCompletionResponseStreamChunk:
|
||||
log.info("vLLM chat completion")
|
||||
|
||||
|
@@ -182,7 +166,7 @@ class VLLMInferenceImpl(ModelRegistryHelper, Inference):
request_id = _random_uuid()
|
||||
|
||||
prompt = chat_completion_request_to_prompt(request, self.formatter)
|
||||
vllm_sampling_params = _vllm_sampling_params(request.sampling_params)
|
||||
vllm_sampling_params = self._sampling_params(request.sampling_params)
|
||||
results_generator = self.engine.generate(
|
||||
prompt, vllm_sampling_params, request_id
|
||||
)
|
||||
|
@@ -213,14 +197,19 @@ class VLLMInferenceImpl(ModelRegistryHelper, Inference):
self, request: ChatCompletionRequest, results_generator: AsyncGenerator
|
||||
) -> AsyncGenerator:
|
||||
async def _generate_and_convert_to_openai_compat():
|
||||
cur = []
|
||||
async for chunk in results_generator:
|
||||
if not chunk.outputs:
|
||||
log.warning("Empty chunk received")
|
||||
continue
|
||||
|
||||
text = "".join([output.text for output in chunk.outputs])
|
||||
output = chunk.outputs[-1]
|
||||
|
||||
new_tokens = output.token_ids[len(cur) :]
|
||||
text = self.formatter.tokenizer.decode(new_tokens)
|
||||
cur.extend(new_tokens)
|
||||
choice = OpenAICompatCompletionChoice(
|
||||
finish_reason=chunk.outputs[-1].stop_reason,
|
||||
finish_reason=output.finish_reason,
|
||||
text=text,
|
||||
)
|
||||
yield OpenAICompatCompletionResponse(
|
||||
|
|
|
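The streaming change above (inside `_generate_and_convert_to_openai_compat`) decodes only the token ids vLLM appended since the previous chunk, instead of re-joining cumulative output text. A toy sketch of that bookkeeping, with a character-level `decode` standing in for the real tokenizer:

```python
# Incremental-decode sketch: keep a running list of token ids and only decode
# the suffix that is new in each cumulative chunk.
from typing import Iterable, List


def decode(token_ids: List[int]) -> str:
    return "".join(chr(ord("a") + t) for t in token_ids)   # toy stand-in tokenizer


def stream_deltas(cumulative_token_ids: Iterable[List[int]]):
    cur: List[int] = []
    for token_ids in cumulative_token_ids:
        new_tokens = token_ids[len(cur):]    # only what the engine appended this step
        cur.extend(new_tokens)
        yield decode(new_tokens)


chunks = [[7], [7, 4], [7, 4, 11, 11, 14]]
print(list(stream_deltas(chunks)))   # ['h', 'e', 'llo']
```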
@@ -36,7 +36,8 @@ def available_providers() -> List[ProviderSpec]:
pip_packages=(
|
||||
META_REFERENCE_DEPS
|
||||
+ [
|
||||
"fbgemm-gpu==0.8.0",
|
||||
"fbgemm-gpu",
|
||||
"torchao==0.5.0",
|
||||
]
|
||||
),
|
||||
module="llama_stack.providers.impls.meta_reference.inference",
|
||||
|
|
25
llama_stack/providers/registry/scoring.py
Normal file
|
@@ -0,0 +1,25 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import List
|
||||
|
||||
from llama_stack.distribution.datatypes import * # noqa: F403
|
||||
|
||||
|
||||
def available_providers() -> List[ProviderSpec]:
|
||||
return [
|
||||
InlineProviderSpec(
|
||||
api=Api.scoring,
|
||||
provider_type="meta-reference",
|
||||
pip_packages=[],
|
||||
module="llama_stack.providers.impls.meta_reference.scoring",
|
||||
config_class="llama_stack.providers.impls.meta_reference.scoring.MetaReferenceScoringConfig",
|
||||
api_dependencies=[
|
||||
Api.datasetio,
|
||||
Api.datasets,
|
||||
],
|
||||
),
|
||||
]
|
6
llama_stack/providers/tests/datasetio/test_dataset.csv
Normal file
|
@@ -0,0 +1,6 @@
input_query,generated_answer,expected_answer
|
||||
What is the capital of France?,London,Paris
|
||||
Who is the CEO of Meta?,Mark Zuckerberg,Mark Zuckerberg
|
||||
What is the largest planet in our solar system?,Jupiter,Jupiter
|
||||
What is the smallest country in the world?,China,Vatican City
|
||||
What is the currency of Japan?,Yen,Yen
|
|
|
@@ -8,8 +8,13 @@ import os
import pytest
|
||||
import pytest_asyncio
|
||||
|
||||
from llama_stack.apis.common.type_system import * # noqa: F403
|
||||
from llama_stack.apis.datasetio import * # noqa: F403
|
||||
from llama_stack.distribution.datatypes import * # noqa: F403
|
||||
import base64
|
||||
import mimetypes
|
||||
from pathlib import Path
|
||||
|
||||
from llama_stack.providers.tests.resolver import resolve_impls_for_test
|
||||
|
||||
# How to run this test:
|
||||
|
@@ -41,14 +46,35 @@ async def datasetio_settings():
}
|
||||
|
||||
|
||||
def data_url_from_file(file_path: str) -> str:
|
||||
if not os.path.exists(file_path):
|
||||
raise FileNotFoundError(f"File not found: {file_path}")
|
||||
|
||||
with open(file_path, "rb") as file:
|
||||
file_content = file.read()
|
||||
|
||||
base64_content = base64.b64encode(file_content).decode("utf-8")
|
||||
mime_type, _ = mimetypes.guess_type(file_path)
|
||||
|
||||
data_url = f"data:{mime_type};base64,{base64_content}"
|
||||
|
||||
return data_url
|
||||
|
||||
|
||||
async def register_dataset(datasets_impl: Datasets):
|
||||
test_file = Path(os.path.abspath(__file__)).parent / "test_dataset.csv"
|
||||
test_url = data_url_from_file(str(test_file))
|
||||
dataset = DatasetDefWithProvider(
|
||||
identifier="test_dataset",
|
||||
provider_id=os.environ["PROVIDER_ID"],
|
||||
url=URL(
|
||||
uri="https://openaipublic.blob.core.windows.net/simple-evals/mmlu.csv",
|
||||
uri=test_url,
|
||||
),
|
||||
columns_schema={},
|
||||
dataset_schema={
|
||||
"generated_answer": StringType(),
|
||||
"expected_answer": StringType(),
|
||||
"input_query": StringType(),
|
||||
},
|
||||
)
|
||||
await datasets_impl.register_dataset(dataset)
|
||||
|
||||
|
@@ -100,10 +126,10 @@ async def test_get_rows_paginated(datasetio_settings):
# iterate over all rows
|
||||
response = await datasetio_impl.get_rows_paginated(
|
||||
dataset_id="test_dataset",
|
||||
rows_in_page=10,
|
||||
rows_in_page=2,
|
||||
page_token=response.next_page_token,
|
||||
)
|
||||
|
||||
assert isinstance(response.rows, list)
|
||||
assert len(response.rows) == 10
|
||||
assert response.next_page_token == "13"
|
||||
assert len(response.rows) == 2
|
||||
assert response.next_page_token == "5"
|
||||
|
|
|
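`data_url_from_file` above lets the test register a local CSV without hosting it anywhere, by inlining the file as a base64 `data:` URL. A short usage sketch (the path is illustrative):

```python
# Build a data: URL from a local file, mirroring data_url_from_file above.
import base64
import mimetypes

path = "llama_stack/providers/tests/datasetio/test_dataset.csv"
with open(path, "rb") as f:
    payload = base64.b64encode(f.read()).decode("utf-8")
mime_type, _ = mimetypes.guess_type(path)       # "text/csv"
data_url = f"data:{mime_type};base64,{payload}"
print(data_url[:30] + "...")                    # "data:text/csv;base64,..." (truncated)
```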
@@ -137,6 +137,7 @@ async def test_completion(inference_settings):
if provider.__provider_spec__.provider_type not in (
|
||||
"meta-reference",
|
||||
"remote::ollama",
|
||||
"remote::tgi",
|
||||
):
|
||||
pytest.skip("Other inference providers don't support completion() yet")
|
||||
|
||||
|
@@ -171,25 +172,43 @@ async def test_completion(inference_settings):
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_embed(inference_settings):
|
||||
async def test_completions_structured_output(inference_settings):
|
||||
inference_impl = inference_settings["impl"]
|
||||
params = inference_settings["common_params"]
|
||||
|
||||
provider = inference_impl.routing_table.get_provider_impl(params["model"])
|
||||
if provider.__provider_spec__.provider_type not in ("remote::ollama",):
|
||||
pytest.skip("Other inference providers don't support completion() yet")
|
||||
if provider.__provider_spec__.provider_type not in (
|
||||
"meta-reference",
|
||||
"remote::tgi",
|
||||
):
|
||||
pytest.skip(
|
||||
"Other inference providers don't support structured output in completions yet"
|
||||
)
|
||||
|
||||
response = await inference_impl.embeddings(
|
||||
contents=["Roses are red"],
|
||||
class Output(BaseModel):
|
||||
name: str
|
||||
year_born: str
|
||||
year_retired: str
|
||||
|
||||
user_input = "Michael Jordan was born in 1963. He played basketball for the Chicago Bulls. He retired in 2003."
|
||||
response = await inference_impl.completion(
|
||||
content=f"input: '{user_input}'. the schema for json: {Output.schema()}, the json is: ",
|
||||
stream=False,
|
||||
model=params["model"],
|
||||
sampling_params=SamplingParams(
|
||||
max_tokens=50,
|
||||
),
|
||||
response_format=JsonResponseFormat(
|
||||
schema=Output.model_json_schema(),
|
||||
),
|
||||
)
|
||||
assert isinstance(response, CompletionResponse)
|
||||
assert isinstance(response.content, str)
|
||||
|
||||
assert isinstance(response, EmbeddingsResponse)
|
||||
assert len(response.embeddings) > 0
|
||||
|
||||
answer = Output.parse_raw(response.content)
|
||||
assert answer.name == "Michael Jordan"
|
||||
assert answer.year_born == "1963"
|
||||
assert answer.year_retired == "2003"
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_chat_completion_non_streaming(inference_settings, sample_messages):
|
||||
|
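The structured-output test above sends `Output.model_json_schema()` as the response format and then parses the raw completion back into the model. The round-trip it asserts, in isolation:

```python
# Minimal reproduction of the parse-and-assert step in the test; the raw JSON
# string here is what the constrained decoder is expected to produce.
from pydantic import BaseModel


class Output(BaseModel):
    name: str
    year_born: str
    year_retired: str


raw = '{"name": "Michael Jordan", "year_born": "1963", "year_retired": "2003"}'
answer = Output.parse_raw(raw)   # same (pydantic v1-style) API the test uses
assert answer.name == "Michael Jordan"
assert answer.year_retired == "2003"
```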
@@ -382,3 +401,23 @@ async def test_chat_completion_with_tool_calling_streaming(
assert call.tool_name == "get_weather"
|
||||
assert "location" in call.arguments
|
||||
assert "San Francisco" in call.arguments["location"]
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_embed(inference_settings):
|
||||
inference_impl = inference_settings["impl"]
|
||||
params = inference_settings["common_params"]
|
||||
|
||||
provider = inference_impl.routing_table.get_provider_impl(params["model"])
|
||||
if provider.__provider_spec__.provider_type not in ("remote::ollama",):
|
||||
pytest.skip("Other inference providers don't support completion() yet")
|
||||
|
||||
response = await inference_impl.embeddings(
|
||||
contents=["Roses are red"],
|
||||
model=params["model"],
|
||||
sampling_params=SamplingParams(
|
||||
max_tokens=50,
|
||||
),
|
||||
)
|
||||
|
||||
assert isinstance(response, EmbeddingsResponse)
|
||||
assert len(response.embeddings) > 0
|
||||
|
|
5
llama_stack/providers/tests/scoring/__init__.py
Normal file
|
@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
|
@@ -0,0 +1,9 @@
providers:
|
||||
datasetio:
|
||||
- provider_id: test-meta
|
||||
provider_type: meta-reference
|
||||
config: {}
|
||||
scoring:
|
||||
- provider_id: test-meta
|
||||
provider_type: meta-reference
|
||||
config: {}
|
69
llama_stack/providers/tests/scoring/test_scoring.py
Normal file
|
@@ -0,0 +1,69 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
import pytest
|
||||
import pytest_asyncio
|
||||
|
||||
from llama_stack.apis.common.type_system import * # noqa: F403
|
||||
from llama_stack.apis.datasetio import * # noqa: F403
|
||||
from llama_stack.distribution.datatypes import * # noqa: F403
|
||||
|
||||
from llama_stack.providers.tests.datasetio.test_datasetio import register_dataset
|
||||
from llama_stack.providers.tests.resolver import resolve_impls_for_test
|
||||
|
||||
# How to run this test:
|
||||
#
|
||||
# 1. Ensure you have a conda with the right dependencies installed. This is a bit tricky
|
||||
# since it depends on the provider you are testing. On top of that you need
|
||||
# `pytest` and `pytest-asyncio` installed.
|
||||
#
|
||||
# 2. Copy and modify the provider_config_example.yaml depending on the provider you are testing.
|
||||
#
|
||||
# 3. Run:
|
||||
#
|
||||
# ```bash
|
||||
# PROVIDER_ID=<your_provider> \
|
||||
# PROVIDER_CONFIG=provider_config.yaml \
|
||||
# pytest -s llama_stack/providers/tests/scoring/test_scoring.py \
|
||||
# --tb=short --disable-warnings
|
||||
# ```
|
||||
|
||||
|
||||
@pytest_asyncio.fixture(scope="session")
|
||||
async def scoring_settings():
|
||||
impls = await resolve_impls_for_test(Api.scoring, deps=[Api.datasetio])
|
||||
return {
|
||||
"scoring_impl": impls[Api.scoring],
|
||||
"scoring_functions_impl": impls[Api.scoring_functions],
|
||||
"datasets_impl": impls[Api.datasets],
|
||||
}
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_scoring_functions_list(scoring_settings):
|
||||
scoring_functions_impl = scoring_settings["scoring_functions_impl"]
|
||||
scoring_functions = await scoring_functions_impl.list_scoring_functions()
|
||||
assert isinstance(scoring_functions, list)
|
||||
assert len(scoring_functions) > 0
|
||||
function_ids = [f.identifier for f in scoring_functions]
|
||||
assert "equality" in function_ids
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_scoring_score(scoring_settings):
|
||||
scoring_impl = scoring_settings["scoring_impl"]
|
||||
datasets_impl = scoring_settings["datasets_impl"]
|
||||
await register_dataset(datasets_impl)
|
||||
|
||||
response = await datasets_impl.list_datasets()
|
||||
assert len(response) == 1
|
||||
|
||||
response = await scoring_impl.score_batch(
|
||||
dataset_id=response[0].identifier,
|
||||
scoring_functions=["equality"],
|
||||
)
|
||||
|
||||
assert len(response.results) == 1
|
||||
assert "equality" in response.results
|
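
The test only asserts that an `equality` scorer is registered and that `score_batch` returns a result keyed by it. As a rough illustration of what an equality-style scorer computes, here is a standalone sketch; it is not the meta-reference provider's implementation, and the column names `generated_answer` and `expected_answer` are assumptions:

```python
# Illustrative only: an "equality"-style scorer as exact string match.
# Column names and the returned dict shape are assumptions for this sketch.
from typing import Any, Dict, List


def score_equality(rows: List[Dict[str, Any]]) -> Dict[str, Any]:
    score_rows = []
    for row in rows:
        generated = str(row.get("generated_answer", "")).strip()
        expected = str(row.get("expected_answer", "")).strip()
        score_rows.append({"score": 1.0 if generated == expected else 0.0})
    accuracy = sum(r["score"] for r in score_rows) / max(len(score_rows), 1)
    return {"score_rows": score_rows, "aggregated_results": {"accuracy": accuracy}}


if __name__ == "__main__":
    demo = [
        {"generated_answer": "Paris", "expected_answer": "Paris"},
        {"generated_answer": "Lyon", "expected_answer": "Paris"},
    ]
    print(score_equality(demo))  # accuracy == 0.5
```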

@@ -29,9 +29,9 @@ class OpenAICompatCompletionResponse(BaseModel):
    choices: List[OpenAICompatCompletionChoice]


-def get_sampling_options(request: ChatCompletionRequest) -> dict:
+def get_sampling_options(params: SamplingParams) -> dict:
    options = {}
-    if params := request.sampling_params:
+    if params:
        for attr in {"temperature", "top_p", "top_k", "max_tokens"}:
            if getattr(params, attr):
                options[attr] = getattr(params, attr)
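
This hunk decouples `get_sampling_options` from `ChatCompletionRequest`: it now takes the `SamplingParams` object directly, so completion and chat-completion paths can share it. Below is a self-contained sketch of the refactored helper with a stand-in dataclass for `SamplingParams`; the call-site shape `get_sampling_options(request.sampling_params)` is an assumption about how callers adapt, not quoted from the diff:

```python
# Standalone sketch mirroring the hunk above; the dataclass is a stand-in for
# llama-stack's SamplingParams, not the real class.
from dataclasses import dataclass
from typing import Optional


@dataclass
class SamplingParams:  # stand-in for illustration
    temperature: Optional[float] = None
    top_p: Optional[float] = None
    top_k: Optional[int] = None
    max_tokens: Optional[int] = None


def get_sampling_options(params: Optional[SamplingParams]) -> dict:
    options = {}
    if params:
        for attr in ("temperature", "top_p", "top_k", "max_tokens"):
            if getattr(params, attr):
                options[attr] = getattr(params, attr)
    return options


# Callers would now pass the params object explicitly, e.g.:
#   options = get_sampling_options(request.sampling_params)
print(get_sampling_options(SamplingParams(temperature=0.7, max_tokens=128)))
```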

@@ -64,7 +64,18 @@ def process_completion_response(
    response: OpenAICompatCompletionResponse, formatter: ChatFormat
) -> CompletionResponse:
    choice = response.choices[0]
-
+    # drop suffix <eot_id> if present and return stop reason as end of turn
+    if choice.text.endswith("<|eot_id|>"):
+        return CompletionResponse(
+            stop_reason=StopReason.end_of_turn,
+            content=choice.text[: -len("<|eot_id|>")],
+        )
+    # drop suffix <eom_id> if present and return stop reason as end of message
+    if choice.text.endswith("<|eom_id|>"):
+        return CompletionResponse(
+            stop_reason=StopReason.end_of_message,
+            content=choice.text[: -len("<|eom_id|>")],
+        )
    return CompletionResponse(
        stop_reason=get_stop_reason(choice.finish_reason),
        content=choice.text,
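
The added branches strip a trailing `<|eot_id|>` or `<|eom_id|>` terminator from the completion text and map it to an explicit stop reason instead of relying on the provider's `finish_reason`. A minimal standalone sketch of that pattern, using plain strings and a tuple in place of `CompletionResponse`:

```python
# Sketch only: strip a trailing terminator token and map it to a stop reason.
# The terminator literals match the diff; the return type here is simplified.
_TERMINATORS = {
    "<|eot_id|>": "end_of_turn",
    "<|eom_id|>": "end_of_message",
}


def strip_terminator(text: str, fallback_reason: str = "end_of_turn"):
    for suffix, reason in _TERMINATORS.items():
        if text.endswith(suffix):
            return text[: -len(suffix)], reason
    return text, fallback_reason


print(strip_terminator("Roses are red<|eot_id|>"))   # ('Roses are red', 'end_of_turn')
print(strip_terminator("Calling a tool<|eom_id|>"))  # ('Calling a tool', 'end_of_message')
```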

@@ -95,13 +106,6 @@ async def process_completion_stream_response(
        choice = chunk.choices[0]
        finish_reason = choice.finish_reason

-        if finish_reason:
-            if finish_reason in ["stop", "eos", "eos_token"]:
-                stop_reason = StopReason.end_of_turn
-            elif finish_reason == "length":
-                stop_reason = StopReason.out_of_tokens
-            break
-
        text = text_from_choice(choice)
        if text == "<|eot_id|>":
            stop_reason = StopReason.end_of_turn

@@ -115,6 +119,12 @@ async def process_completion_stream_response(
            delta=text,
            stop_reason=stop_reason,
        )
+        if finish_reason:
+            if finish_reason in ["stop", "eos", "eos_token"]:
+                stop_reason = StopReason.end_of_turn
+            elif finish_reason == "length":
+                stop_reason = StopReason.out_of_tokens
+            break

    yield CompletionResponseStreamChunk(
        delta="",
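
These two hunks move the `finish_reason` check from before the chunk's text is emitted to after it, so a provider chunk that carries both the final piece of text and a finish reason still has its delta yielded before the loop breaks. A toy generator illustrating that ordering, with plain dicts standing in for `CompletionResponseStreamChunk`:

```python
# Toy sketch of the emit-then-break ordering; chunks are (text, finish_reason) pairs.
def stream_deltas(chunks):
    stop_reason = None
    for text, finish_reason in chunks:
        if text:
            # The text delta is yielded before finish_reason is inspected.
            yield {"delta": text, "stop_reason": stop_reason}
        if finish_reason:
            if finish_reason in ("stop", "eos", "eos_token"):
                stop_reason = "end_of_turn"
            elif finish_reason == "length":
                stop_reason = "out_of_tokens"
            break
    yield {"delta": "", "stop_reason": stop_reason}


# The last chunk carries text and a finish reason; its delta is not swallowed.
print(list(stream_deltas([("Hello", None), (" world", "stop")])))
```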

@@ -31,6 +31,13 @@ def completion_request_to_prompt(
    return formatter.tokenizer.decode(model_input.tokens)


+def completion_request_to_prompt_model_input_info(
+    request: CompletionRequest, formatter: ChatFormat
+) -> Tuple[str, int]:
+    model_input = formatter.encode_content(request.content)
+    return (formatter.tokenizer.decode(model_input.tokens), len(model_input.tokens))
+
+
def chat_completion_request_to_prompt(
    request: ChatCompletionRequest, formatter: ChatFormat
) -> str:
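
The new `completion_request_to_prompt_model_input_info` returns the decoded prompt together with its token count, which lets a provider reason about context budget before dispatching a request. The sketch below is purely illustrative: `MAX_CONTEXT` and `budget_max_tokens` are made-up names and are not part of this change.

```python
# Illustrative only: use the (prompt, token_count) pair to cap generation length
# against an assumed context window. Nothing here is llama-stack API.
MAX_CONTEXT = 4096  # made-up limit for the sketch


def budget_max_tokens(prompt_and_len, requested_max_tokens: int) -> int:
    prompt, prompt_tokens = prompt_and_len
    remaining = max(MAX_CONTEXT - prompt_tokens, 0)
    return min(requested_max_tokens, remaining)


print(budget_max_tokens(("Roses are red,", 6), 512))          # 512, plenty of room
print(budget_max_tokens(("a very long prompt", 4000), 512))   # 96, clipped to budget
```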

@@ -2,7 +2,7 @@ blobfile
fire
httpx
huggingface-hub
-llama-models>=0.0.43
+llama-models>=0.0.45
prompt-toolkit
python-dotenv
pydantic>=2

2
setup.py

@@ -16,7 +16,7 @@ def read_requirements():

setup(
    name="llama_stack",
-    version="0.0.43",
+    version="0.0.45",
    author="Meta Llama",
    author_email="llama-oss@meta.com",
    description="Llama Stack",

@@ -13,7 +13,12 @@ apis:
- inference
- datasets
- datasetio
+- scoring
providers:
+  scoring:
+  - provider_id: meta0
+    provider_type: meta-reference
+    config: {}
  datasetio:
  - provider_id: meta0
    provider_type: meta-reference