Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-08-03 17:29:01 +00:00)

Commit 3366937765: Merge branch 'main' into docs-4
168 changed files with 14921 additions and 1625 deletions
@@ -320,7 +320,7 @@ jobs:
       - name: "PR - Update comment"
         id: pr_update_comment
         if: github.event_name == 'pull_request_target'
-        uses: thollander/actions-comment-pull-request@65f9e5c9a1f2cd378bd74b2e057c9736982a8e74 # v3.0.1
+        uses: thollander/actions-comment-pull-request@24bffb9b452ba05a4f3f77933840a6a841d1b32b # v3.0.1
         with:
           filePath: test-summary.md
.github/workflows/test-external-providers.yml (new file, vendored, 93 lines)
@@ -0,0 +1,93 @@
name: Test External Providers

on:
  push:
    branches: [ main ]
  pull_request:
    branches: [ main ]

jobs:
  test-external-providers:
    runs-on: ubuntu-latest
    steps:
      - name: Checkout repository
        uses: actions/checkout@v4

      - name: Install uv
        uses: astral-sh/setup-uv@v5
        with:
          python-version: "3.10"

      - name: Install Ollama
        run: |
          curl -fsSL https://ollama.com/install.sh | sh

      - name: Pull Ollama image
        run: |
          ollama pull llama3.2:3b-instruct-fp16

      - name: Start Ollama in background
        run: |
          nohup ollama run llama3.2:3b-instruct-fp16 --keepalive=30m > ollama.log 2>&1 &

      - name: Set Up Environment and Install Dependencies
        run: |
          uv sync --extra dev --extra test
          uv pip install -e .

      - name: Install Ollama custom provider
        run: |
          mkdir -p tests/external-provider/llama-stack-provider-ollama/src/
          cp -a llama_stack/providers/remote/inference/ollama/ tests/external-provider/llama-stack-provider-ollama/src/llama_stack_provider_ollama
          uv pip install tests/external-provider/llama-stack-provider-ollama

      - name: Create provider configuration
        run: |
          mkdir -p /tmp/providers.d/remote/inference
          cp tests/external-provider/llama-stack-provider-ollama/custom_ollama.yaml /tmp/providers.d/remote/inference/custom_ollama.yaml

      - name: Wait for Ollama to start
        run: |
          echo "Waiting for Ollama..."
          for i in {1..30}; do
            if curl -s http://localhost:11434 | grep -q "Ollama is running"; then
              echo "Ollama is running!"
              exit 0
            fi
            sleep 1
          done
          echo "Ollama failed to start"
          ollama ps
          cat ollama.log
          exit 1

      - name: Start Llama Stack server in background
        env:
          INFERENCE_MODEL: "meta-llama/Llama-3.2-3B-Instruct"
        run: |
          source .venv/bin/activate
          nohup uv run llama stack run tests/external-provider/llama-stack-provider-ollama/run.yaml --image-type venv > server.log 2>&1 &

      - name: Wait for Llama Stack server to be ready
        run: |
          echo "Waiting for Llama Stack server..."
          for i in {1..30}; do
            if curl -s http://localhost:8321/v1/health | grep -q "OK"; then
              echo "Llama Stack server is up!"
              if grep -q "remote::custom_ollama from /tmp/providers.d/remote/inference/custom_ollama.yaml" server.log; then
                echo "Llama Stack server is using custom Ollama provider"
                exit 0
              else
                echo "Llama Stack server is not using custom Ollama provider"
                exit 1
              fi
            fi
            sleep 1
          done
          echo "Llama Stack server failed to start"
          cat server.log
          exit 1

      - name: run inference tests
        run: |
          uv run pytest -v tests/integration/inference/test_text_inference.py --stack-config="http://localhost:8321" --text-model="meta-llama/Llama-3.2-3B-Instruct" --embedding-model=all-MiniLM-L6-v2
CHANGELOG.md (37 lines changed)
@@ -1,5 +1,42 @@
 # Changelog

+# v0.2.1
+Published on: 2025-04-05T23:13:00Z
+
+---
+
+# v0.2.0
+Published on: 2025-04-05T19:04:29Z
+
+## Llama 4 Support
+
+Checkout more at https://www.llama.com
+
+---
+
+# v0.1.9
+Published on: 2025-03-29T00:52:23Z
+
+### Build and Test Agents
+* Agents: Entire document context with attachments
+* RAG: Documentation with sqlite-vec faiss comparison
+* Getting started: Fixes to getting started notebook.
+
+### Agent Evals and Model Customization
+* (**New**) Post-training: Add nemo customizer
+
+### Better Engineering
+* Moved sqlite-vec to non-blocking calls
+* Don't return a payload on file delete
+
+---
+
 # v0.1.8
 Published on: 2025-03-24T01:28:50Z
@@ -3,7 +3,7 @@
 [](https://pypi.org/project/llama_stack/)
 [](https://pypi.org/project/llama-stack/)
 [](https://github.com/meta-llama/llama-stack/blob/main/LICENSE)
 [](https://discord.gg/llama-stack)
 [](https://github.com/meta-llama/llama-stack/actions/workflows/unit-tests.yml?query=branch%3Amain)
 [](https://github.com/meta-llama/llama-stack/actions/workflows/integration-tests.yml?query=branch%3Amain)
@@ -11,7 +11,7 @@

 ### ✨🎉 Llama 4 Support 🎉✨
-We release [Version 0.2.0](https://github.com/meta-llama/llama-stack/releases/tag/v0.2.0) with support for the Llama 4 herd of models released by Meta.
+We released [Version 0.2.0](https://github.com/meta-llama/llama-stack/releases/tag/v0.2.0) with support for the Llama 4 herd of models released by Meta.

 You can now run Llama 4 models on Llama Stack.
docs/getting_started_llama4.ipynb (new file, 876 lines)
File diff suppressed because one or more lines are too long
@@ -17,7 +17,7 @@ client = LlamaStackAsLibraryClient(
     # provider_data is optional, but if you need to pass in any provider specific data, you can do so here.
     provider_data={"tavily_search_api_key": os.environ["TAVILY_SEARCH_API_KEY"]},
 )
-await client.initialize()
+client.initialize()
 ```

 This will parse your config and set up any inline implementations and remote clients needed for your implementation.
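For context, the hunk above sits in the library-client documentation. A minimal sketch of the surrounding snippet, updated to the synchronous `client.initialize()` call, might look like the following; the `"ollama"` config name and the import path are assumptions for illustration, the Tavily key handling comes from the hunk itself:

```python
import os

# Assumed import path for the library client; adjust to your installed version.
from llama_stack.distribution.library_client import LlamaStackAsLibraryClient

client = LlamaStackAsLibraryClient(
    "ollama",  # hypothetical distribution/config name
    # provider_data is optional; pass provider-specific secrets if a tool needs them.
    provider_data={"tavily_search_api_key": os.environ["TAVILY_SEARCH_API_KEY"]},
)
client.initialize()  # parses the config and sets up inline impls and remote clients
```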
@@ -46,6 +46,8 @@ The following models are available by default:
 - `accounts/fireworks/models/llama-v3p3-70b-instruct (aliases: meta-llama/Llama-3.3-70B-Instruct)`
 - `accounts/fireworks/models/llama-guard-3-8b (aliases: meta-llama/Llama-Guard-3-8B)`
 - `accounts/fireworks/models/llama-guard-3-11b-vision (aliases: meta-llama/Llama-Guard-3-11B-Vision)`
+- `accounts/fireworks/models/llama4-scout-instruct-basic (aliases: meta-llama/Llama-4-Scout-17B-16E-Instruct)`
+- `accounts/fireworks/models/llama4-maverick-instruct-basic (aliases: meta-llama/Llama-4-Maverick-17B-128E-Instruct)`
 - `nomic-ai/nomic-embed-text-v1.5`
@@ -42,6 +42,8 @@ The following models are available by default:
 - `groq/llama3-70b-8192 (aliases: meta-llama/Llama-3-70B-Instruct)`
 - `groq/llama-3.3-70b-versatile (aliases: meta-llama/Llama-3.3-70B-Instruct)`
 - `groq/llama-3.2-3b-preview (aliases: meta-llama/Llama-3.2-3B-Instruct)`
+- `groq/llama-4-scout-17b-16e-instruct (aliases: meta-llama/Llama-4-Scout-17B-16E-Instruct)`
+- `groq/llama-4-maverick-17b-128e-instruct (aliases: meta-llama/Llama-4-Maverick-17B-128E-Instruct)`

 ### Prerequisite: API Keys
@@ -41,6 +41,80 @@ The following environment variables can be configured:

 ## Setting up vLLM server

+Both AMD and NVIDIA GPUs can serve as accelerators for the vLLM server, which acts as both the LLM inference provider and the safety provider.
+
+### Setting up vLLM server on AMD GPU
+
+AMD provides two main vLLM container options:
+- rocm/vllm: Production-ready container
+- rocm/vllm-dev: Development container with the latest vLLM features
+
+Please check the [Blog about ROCm vLLM Usage](https://rocm.blogs.amd.com/software-tools-optimization/vllm-container/README.html) to get more details.
+
+Here is a sample script to start a ROCm vLLM server locally via Docker:
+
+```bash
+export INFERENCE_PORT=8000
+export INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct
+export CUDA_VISIBLE_DEVICES=0
+export VLLM_DIMG="rocm/vllm-dev:main"
+
+docker run \
+    --pull always \
+    --ipc=host \
+    --privileged \
+    --shm-size 16g \
+    --device=/dev/kfd \
+    --device=/dev/dri \
+    --group-add video \
+    --cap-add=SYS_PTRACE \
+    --cap-add=CAP_SYS_ADMIN \
+    --security-opt seccomp=unconfined \
+    --security-opt apparmor=unconfined \
+    --env "HUGGING_FACE_HUB_TOKEN=$HF_TOKEN" \
+    --env "HIP_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES" \
+    -p $INFERENCE_PORT:$INFERENCE_PORT \
+    -v ~/.cache/huggingface:/root/.cache/huggingface \
+    $VLLM_DIMG \
+    python -m vllm.entrypoints.openai.api_server \
+    --model $INFERENCE_MODEL \
+    --port $INFERENCE_PORT
+```
+
+Note that you'll also need to set `--enable-auto-tool-choice` and `--tool-call-parser` to [enable tool calling in vLLM](https://docs.vllm.ai/en/latest/features/tool_calling.html).
+
+If you are using Llama Stack Safety / Shield APIs, then you will need to also run another instance of a vLLM with a corresponding safety model like `meta-llama/Llama-Guard-3-1B` using a script like:
+
+```bash
+export SAFETY_PORT=8081
+export SAFETY_MODEL=meta-llama/Llama-Guard-3-1B
+export CUDA_VISIBLE_DEVICES=1
+export VLLM_DIMG="rocm/vllm-dev:main"
+
+docker run \
+    --pull always \
+    --ipc=host \
+    --privileged \
+    --shm-size 16g \
+    --device=/dev/kfd \
+    --device=/dev/dri \
+    --group-add video \
+    --cap-add=SYS_PTRACE \
+    --cap-add=CAP_SYS_ADMIN \
+    --security-opt seccomp=unconfined \
+    --security-opt apparmor=unconfined \
+    --env "HUGGING_FACE_HUB_TOKEN=$HF_TOKEN" \
+    --env "HIP_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES" \
+    -p $SAFETY_PORT:$SAFETY_PORT \
+    -v ~/.cache/huggingface:/root/.cache/huggingface \
+    $VLLM_DIMG \
+    python -m vllm.entrypoints.openai.api_server \
+    --model $SAFETY_MODEL \
+    --port $SAFETY_PORT
+```
+
+### Setting up vLLM server on NVIDIA GPU
+
 Please check the [vLLM Documentation](https://docs.vllm.ai/en/v0.5.5/serving/deploying_with_docker.html) to get a vLLM endpoint. Here is a sample script to start a vLLM server locally via Docker:

 ```bash
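The note above about tool calling mentions two extra vLLM flags but does not show them in context. As a hedged illustration (the `llama3_json` parser choice is an assumption, not something this diff specifies), the serve command could be extended roughly like this:

```bash
# Hypothetical extension of the inference server command above to enable tool calling.
# --tool-call-parser llama3_json is an assumed value; consult the vLLM tool-calling docs
# for the parser that matches your model.
python -m vllm.entrypoints.openai.api_server \
    --model $INFERENCE_MODEL \
    --port $INFERENCE_PORT \
    --enable-auto-tool-choice \
    --tool-call-parser llama3_json
```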
@@ -43,6 +43,7 @@ The following models are available by default:
 - `Llama-3.2-11B-Vision-Instruct (aliases: meta-llama/Llama-3.2-11B-Vision-Instruct)`
 - `Llama-3.2-90B-Vision-Instruct (aliases: meta-llama/Llama-3.2-90B-Vision-Instruct)`
 - `Meta-Llama-Guard-3-8B (aliases: meta-llama/Llama-Guard-3-8B)`
+- `Llama-4-Scout-17B-16E-Instruct (aliases: meta-llama/Llama-4-Scout-17B-16E-Instruct)`

 ### Prerequisite: API Keys
@@ -48,6 +48,8 @@ The following models are available by default:
 - `meta-llama/Llama-Guard-3-11B-Vision-Turbo (aliases: meta-llama/Llama-Guard-3-11B-Vision)`
 - `togethercomputer/m2-bert-80M-8k-retrieval`
 - `togethercomputer/m2-bert-80M-32k-retrieval`
+- `meta-llama/Llama-4-Scout-17B-16E-Instruct (aliases: meta-llama/Llama-4-Scout-17B-16E-Instruct, together/meta-llama/Llama-4-Scout-17B-16E-Instruct)`
+- `meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8 (aliases: meta-llama/Llama-4-Maverick-17B-128E-Instruct, together/meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8)`

 ### Prerequisite: API Keys
@@ -12,17 +12,21 @@ as the inference [provider](../providers/index.md#inference) for a Llama Model.

 ## Step 1: Installation and Setup

-### i. Install and Start Ollama for Inference
+### i. Install and Setup Ollama for Inference

 Install Ollama by following the instructions on the [Ollama website](https://ollama.com/download).

-To start Ollama run:
+Then download a Llama model with Ollama:
+```bash
+ollama pull llama3.2:3b-instruct-fp16
+```
+This will instruct the Ollama service to download the Llama 3.2 3B Instruct model, which we'll use in the rest of this guide.
+
+Then to start Ollama run:
 ```bash
 ollama run llama3.2:3b --keepalive 60m
 ```

+By default, Ollama keeps the model loaded in memory for 5 minutes, which can be too short. We set the `--keepalive` flag to 60 minutes to ensure the model remains loaded for some time.

 ### ii. Install `uv` to Manage your Python packages

 Install [uv](https://docs.astral.sh/uv/) to set up your virtual environment
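As a quick sanity check (not part of the original guide), you can confirm the model stays resident once `--keepalive` is set; `ollama ps` is the same command the new CI workflow earlier in this commit uses:

```bash
# Lists models currently loaded by the Ollama service; the llama3.2:3b entry
# should remain here for the duration of the --keepalive window.
ollama ps
```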
@@ -1,3 +1,8 @@
+```{admonition} Llama 4 is here!
+:class: tip
+
+Check out [Getting Started with Llama 4](https://colab.research.google.com/github/meta-llama/llama-stack/blob/main/docs/getting_started_llama4.ipynb)
+```
 ```{admonition} News
 :class: tip
docs/source/providers/external.md (new file, 234 lines)
@@ -0,0 +1,234 @@
# External Providers

Llama Stack supports external providers that live outside of the main codebase. This allows you to:
- Create and maintain your own providers independently
- Share providers with others without contributing to the main codebase
- Keep provider-specific code separate from the core Llama Stack code

## Configuration

To enable external providers, you need to configure the `external_providers_dir` in your Llama Stack configuration. This directory should contain your external provider specifications:

```yaml
external_providers_dir: /etc/llama-stack/providers.d/
```

## Directory Structure

The external providers directory should follow this structure:

```
providers.d/
  remote/
    inference/
      custom_ollama.yaml
      vllm.yaml
    vector_io/
      qdrant.yaml
    safety/
      llama-guard.yaml
  inline/
    inference/
      custom_ollama.yaml
      vllm.yaml
    vector_io/
      qdrant.yaml
    safety/
      llama-guard.yaml
```

Each YAML file in these directories defines a provider specification for that particular API.
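To make the layout concrete, here is a small sketch of how you might create it on disk; the `/etc/llama-stack` root and the single `custom_ollama.yaml` file are simply the examples used elsewhere in this document (the CI workflow earlier in this commit does the same thing under `/tmp`):

```bash
# Create the directory skeleton for one remote inference provider spec.
mkdir -p /etc/llama-stack/providers.d/remote/inference

# Drop the provider specification where Llama Stack will look for it.
cp custom_ollama.yaml /etc/llama-stack/providers.d/remote/inference/custom_ollama.yaml
```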
## Provider Types

Llama Stack supports two types of external providers:

1. **Remote Providers**: Providers that communicate with external services (e.g., cloud APIs)
2. **Inline Providers**: Providers that run locally within the Llama Stack process

## Known External Providers

Here's a list of known external providers that you can use with Llama Stack:

| Type | Name | Description | Repository |
|------|------|-------------|------------|
| Remote | KubeFlow Training | Train models with KubeFlow | [llama-stack-provider-kft](https://github.com/opendatahub-io/llama-stack-provider-kft) |

### Remote Provider Specification

Remote providers are used when you need to communicate with external services. Here's an example for a custom Ollama provider:

```yaml
adapter:
  adapter_type: custom_ollama
  pip_packages:
  - ollama
  - aiohttp
  config_class: llama_stack_ollama_provider.config.OllamaImplConfig
  module: llama_stack_ollama_provider
api_dependencies: []
optional_api_dependencies: []
```

#### Adapter Configuration

The `adapter` section defines how to load and configure the provider:

- `adapter_type`: A unique identifier for this adapter
- `pip_packages`: List of Python packages required by the provider
- `config_class`: The full path to the configuration class
- `module`: The Python module containing the provider implementation

### Inline Provider Specification

Inline providers run locally within the Llama Stack process. Here's an example for a custom vector store provider:

```yaml
module: llama_stack_vector_provider
config_class: llama_stack_vector_provider.config.VectorStoreConfig
pip_packages:
- faiss-cpu
- numpy
api_dependencies:
- inference
optional_api_dependencies:
- vector_io
provider_data_validator: llama_stack_vector_provider.validator.VectorStoreValidator
container_image: custom-vector-store:latest # optional
```

#### Inline Provider Fields

- `module`: The Python module containing the provider implementation
- `config_class`: The full path to the configuration class
- `pip_packages`: List of Python packages required by the provider
- `api_dependencies`: List of Llama Stack APIs that this provider depends on
- `optional_api_dependencies`: List of optional Llama Stack APIs that this provider can use
- `provider_data_validator`: Optional validator for provider data
- `container_image`: Optional container image to use instead of pip packages

## Required Implementation

### Remote Providers

Remote providers must expose a `get_adapter_impl()` function in their module that takes two arguments:
1. `config`: An instance of the provider's config class
2. `deps`: A dictionary of API dependencies

This function must return an instance of the provider's adapter class that implements the required protocol for the API.

Example:
```python
async def get_adapter_impl(
    config: OllamaImplConfig, deps: Dict[Api, Any]
) -> OllamaInferenceAdapter:
    return OllamaInferenceAdapter(config)
```
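As a hedged illustration of where that function lives, a minimal `llama_stack_ollama_provider/__init__.py` matching the remote spec above might look like this; the config fields and the adapter body are placeholders, not the real Ollama adapter:

```python
# llama_stack_ollama_provider/__init__.py -- minimal sketch of a remote provider module.
from typing import Any, Dict

from pydantic import BaseModel

from llama_stack.providers.datatypes import Api


class OllamaImplConfig(BaseModel):
    # Placeholder config: where the remote Ollama service is reachable.
    url: str = "http://localhost:11434"


class OllamaInferenceAdapter:
    # Placeholder adapter; a real one implements the Inference protocol methods.
    def __init__(self, config: OllamaImplConfig) -> None:
        self.config = config


async def get_adapter_impl(config: OllamaImplConfig, deps: Dict[Api, Any]) -> OllamaInferenceAdapter:
    return OllamaInferenceAdapter(config)
```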
### Inline Providers

Inline providers must expose a `get_provider_impl()` function in their module that takes two arguments:
1. `config`: An instance of the provider's config class
2. `deps`: A dictionary of API dependencies

Example:
```python
async def get_provider_impl(
    config: VectorStoreConfig, deps: Dict[Api, Any]
) -> VectorStoreImpl:
    impl = VectorStoreImpl(config, deps[Api.inference])
    await impl.initialize()
    return impl
```

## Dependencies

The provider package must be installed on the system. For example:

```bash
$ uv pip show llama-stack-ollama-provider
Name: llama-stack-ollama-provider
Version: 0.1.0
Location: /path/to/venv/lib/python3.10/site-packages
```

## Example: Custom Ollama Provider

Here's a complete example of creating and using a custom Ollama provider:

1. First, create the provider package:

```bash
mkdir -p llama-stack-provider-ollama
cd llama-stack-provider-ollama
git init
uv init
```

2. Edit `pyproject.toml`:

```toml
[project]
name = "llama-stack-provider-ollama"
version = "0.1.0"
description = "Ollama provider for Llama Stack"
requires-python = ">=3.10"
dependencies = ["llama-stack", "pydantic", "ollama", "aiohttp"]
```

3. Create the provider specification:

```yaml
# /etc/llama-stack/providers.d/remote/inference/custom_ollama.yaml
adapter:
  adapter_type: custom_ollama
  pip_packages: ["ollama", "aiohttp"]
  config_class: llama_stack_provider_ollama.config.OllamaImplConfig
  module: llama_stack_provider_ollama
api_dependencies: []
optional_api_dependencies: []
```

4. Install the provider:

```bash
uv pip install -e .
```

5. Configure Llama Stack to use external providers:

```yaml
external_providers_dir: /etc/llama-stack/providers.d/
```

The provider will now be available in Llama Stack with the type `remote::custom_ollama`.
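To show how the new provider type is then referenced, here is a hedged sketch of the relevant fragment of a run configuration; the `provider_id` and the `url` config field are assumptions for illustration, while the `provider_type` string is the one stated above:

```yaml
# Fragment of a hypothetical run.yaml that uses the external provider.
providers:
  inference:
  - provider_id: custom_ollama          # any name you choose
    provider_type: remote::custom_ollama
    config:
      url: http://localhost:11434       # assumed config field of the custom provider
```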
## Best Practices

1. **Package Naming**: Use the prefix `llama-stack-provider-` for your provider packages to make them easily identifiable.

2. **Version Management**: Keep your provider package versioned and compatible with the Llama Stack version you're using.

3. **Dependencies**: Only include the minimum required dependencies in your provider package.

4. **Documentation**: Include clear documentation in your provider package about:
   - Installation requirements
   - Configuration options
   - Usage examples
   - Any limitations or known issues

5. **Testing**: Include tests in your provider package to ensure it works correctly with Llama Stack.
   You can refer to the [integration tests guide](https://github.com/meta-llama/llama-stack/blob/main/tests/integration/README.md) for more information. Execute the test for the Provider type you are developing.

## Troubleshooting

If your external provider isn't being loaded:

1. Check that the `external_providers_dir` path is correct and accessible.
2. Verify that the YAML files are properly formatted.
3. Ensure all required Python packages are installed.
4. Check the Llama Stack server logs for any error messages - turn on debug logging to get more information using `LLAMA_STACK_LOGGING=all=debug`.
5. Verify that the provider package is installed in your Python environment.
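For step 4, a concrete invocation could look like the following sketch; the `./run.yaml` path is an assumption, while the `llama stack run ... --image-type venv` form mirrors the CI workflow earlier in this commit:

```bash
# Run the server with debug logging so external-provider loading is visible in the logs.
LLAMA_STACK_LOGGING=all=debug \
  llama stack run ./run.yaml --image-type venv
```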
@@ -11,6 +11,10 @@ Providers come in two flavors:

 Importantly, Llama Stack always strives to provide at least one fully inline provider for each API so you can iterate on a fully featured environment locally.

+## External Providers
+
+Llama Stack supports external providers that live outside of the main codebase. This allows you to create and maintain your own providers independently. See the [External Providers Guide](external) for details.
+
 ## Agents
 Run multi-step agentic workflows with LLMs with tool usage, memory (RAG), etc.

@@ -50,6 +54,7 @@ The following providers (i.e., databases) are available for Vector IO:
 ```{toctree}
 :maxdepth: 1

+external
 vector_io/faiss
 vector_io/sqlite-vec
 vector_io/chromadb
@@ -25,15 +25,64 @@ from llama_stack.apis.models import Model
 from llama_stack.apis.telemetry.telemetry import MetricResponseMixin
 from llama_stack.models.llama.datatypes import (
     BuiltinTool,
-    SamplingParams,
     StopReason,
     ToolCall,
     ToolDefinition,
+    ToolParamDefinition,
     ToolPromptFormat,
 )
 from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
 from llama_stack.schema_utils import json_schema_type, register_schema, webmethod

+register_schema(ToolCall)
+register_schema(ToolParamDefinition)
+register_schema(ToolDefinition)
+
+
+@json_schema_type
+class GreedySamplingStrategy(BaseModel):
+    type: Literal["greedy"] = "greedy"
+
+
+@json_schema_type
+class TopPSamplingStrategy(BaseModel):
+    type: Literal["top_p"] = "top_p"
+    temperature: Optional[float] = Field(..., gt=0.0)
+    top_p: Optional[float] = 0.95
+
+
+@json_schema_type
+class TopKSamplingStrategy(BaseModel):
+    type: Literal["top_k"] = "top_k"
+    top_k: int = Field(..., ge=1)
+
+
+SamplingStrategy = Annotated[
+    Union[GreedySamplingStrategy, TopPSamplingStrategy, TopKSamplingStrategy],
+    Field(discriminator="type"),
+]
+register_schema(SamplingStrategy, name="SamplingStrategy")
+
+
+@json_schema_type
+class SamplingParams(BaseModel):
+    """Sampling parameters.
+
+    :param strategy: The sampling strategy.
+    :param max_tokens: The maximum number of tokens that can be generated in the completion. The token count of
+        your prompt plus max_tokens cannot exceed the model's context length.
+    :param repetition_penalty: Number between -2.0 and 2.0. Positive values penalize new tokens
+        based on whether they appear in the text so far, increasing the model's likelihood to talk about new topics.
+    :param stop: Up to 4 sequences where the API will stop generating further tokens.
+        The returned text will not contain the stop sequence.
+    """
+
+    strategy: SamplingStrategy = Field(default_factory=GreedySamplingStrategy)
+
+    max_tokens: Optional[int] = 0
+    repetition_penalty: Optional[float] = 1.0
+    stop: Optional[List[str]] = None
+
+
 class LogProbConfig(BaseModel):
     """
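A short sketch of how these new models compose, based only on the definitions in the hunk above (the specific numeric values are arbitrary):

```python
# Build sampling parameters with a top-p strategy; the discriminated union
# dispatches on the "type" field defined in each strategy class.
params = SamplingParams(
    strategy=TopPSamplingStrategy(temperature=0.7, top_p=0.9),
    max_tokens=256,
    stop=["</answer>"],
)

# Leaving `strategy` unset falls back to greedy decoding via default_factory.
greedy = SamplingParams(max_tokens=64)
```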
@@ -48,18 +97,18 @@ class QuantizationType(Enum):
     """Type of model quantization to run inference with.

     :cvar bf16: BFloat16 typically this means _no_ quantization
-    :cvar fp8: 8-bit floating point quantization
-    :cvar int4: 4-bit integer quantization
+    :cvar fp8_mixed: 8-bit floating point quantization with mixed precision
+    :cvar int4_mixed: 4-bit integer quantization with mixed precision
     """

     bf16 = "bf16"
-    fp8 = "fp8"
-    int4 = "int4"
+    fp8_mixed = "fp8_mixed"
+    int4_mixed = "int4_mixed"


 @json_schema_type
 class Fp8QuantizationConfig(BaseModel):
-    type: Literal["fp8"] = "fp8"
+    type: Literal["fp8_mixed"] = "fp8_mixed"


 @json_schema_type
@@ -75,7 +124,7 @@ class Int4QuantizationConfig(BaseModel):
     :param scheme: Quantization scheme to use. Defaults to "int4_weight_int8_dynamic_activation"
     """

-    type: Literal["int4"] = "int4"
+    type: Literal["int4_mixed"] = "int4_mixed"
     scheme: Optional[str] = "int4_weight_int8_dynamic_activation"
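Because both the enum values and the `type` literals are renamed here, a client that previously serialized `{"type": "fp8"}` needs the new spelling. A minimal sketch of the updated usage, with field values exactly as defined above:

```python
# The discriminator literals now carry the "_mixed" suffix.
fp8_cfg = Fp8QuantizationConfig()    # serializes as {"type": "fp8_mixed"}
int4_cfg = Int4QuantizationConfig()  # {"type": "int4_mixed", "scheme": "int4_weight_int8_dynamic_activation"}
assert QuantizationType.fp8_mixed.value == "fp8_mixed"
```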
@@ -29,8 +29,8 @@ from rich.progress import (
 from termcolor import cprint

 from llama_stack.cli.subcommand import Subcommand
-from llama_stack.models.llama.datatypes import Model
 from llama_stack.models.llama.sku_list import LlamaDownloadInfo
+from llama_stack.models.llama.sku_types import Model


 class Download(Subcommand):
@@ -63,17 +63,6 @@ class ModelDescribe(Subcommand):
             ("Model params.json", json.dumps(model.arch_args, indent=4)),
         ]

-        if model.recommended_sampling_params is not None:
-            sampling_params = model.recommended_sampling_params.model_dump()
-            for k in ("max_tokens", "repetition_penalty"):
-                del sampling_params[k]
-            rows.append(
-                (
-                    "Recommended sampling params",
-                    json.dumps(sampling_params, indent=4),
-                )
-            )
-
         print_table(
             rows,
             headers,
@@ -11,7 +11,7 @@ from pathlib import Path

 from llama_stack.cli.subcommand import Subcommand
 from llama_stack.cli.table import print_table
-from llama_stack.models.llama.datatypes import CoreModelId, ModelFamily, is_multimodal, model_family
+from llama_stack.models.llama.sku_types import CoreModelId, ModelFamily, is_multimodal, model_family

 ROOT_DIR = Path(__file__).parent.parent.parent
@@ -4,12 +4,12 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-from typing import Any, Dict, Optional
+from typing import Any, Dict

 from pydantic import BaseModel, ConfigDict, Field

-from llama_stack.models.llama.datatypes import CheckpointQuantizationFormat, SamplingParams
 from llama_stack.models.llama.sku_list import LlamaDownloadInfo
+from llama_stack.models.llama.sku_types import CheckpointQuantizationFormat


 class PromptGuardModel(BaseModel):
@@ -23,7 +23,6 @@ class PromptGuardModel(BaseModel):
     is_instruct_model: bool = False
     quantization_format: CheckpointQuantizationFormat = CheckpointQuantizationFormat.bf16
     arch_args: Dict[str, Any] = Field(default_factory=dict)
-    recommended_sampling_params: Optional[SamplingParams] = None

     def descriptor(self) -> str:
         return self.model_id
@@ -312,6 +312,11 @@ a default SQLite store will be used.""",
         description="Configuration for the HTTP(S) server",
     )

+    external_providers_dir: Optional[str] = Field(
+        default=None,
+        description="Path to directory containing external provider implementations. The providers code and dependencies must be installed on the system.",
+    )
+

 class BuildConfig(BaseModel):
     version: str = LLAMA_STACK_BUILD_CONFIG_VERSION
@@ -4,12 +4,25 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

+import glob
 import importlib
-from typing import Dict, List
+import os
+from typing import Any, Dict, List

+import yaml
 from pydantic import BaseModel

-from llama_stack.providers.datatypes import Api, ProviderSpec
+from llama_stack.distribution.datatypes import StackRunConfig
+from llama_stack.log import get_logger
+from llama_stack.providers.datatypes import (
+    AdapterSpec,
+    Api,
+    InlineProviderSpec,
+    ProviderSpec,
+    remote_provider_spec,
+)
+
+logger = get_logger(name=__name__, category="core")
+

 def stack_apis() -> List[Api]:

@@ -59,11 +72,116 @@ def providable_apis() -> List[Api]:
     return [api for api in Api if api not in routing_table_apis and api != Api.inspect and api != Api.providers]


-def get_provider_registry() -> Dict[Api, Dict[str, ProviderSpec]]:
-    ret = {}
+def _load_remote_provider_spec(spec_data: Dict[str, Any], api: Api) -> ProviderSpec:
+    adapter = AdapterSpec(**spec_data["adapter"])
+    spec = remote_provider_spec(
+        api=api,
+        adapter=adapter,
+        api_dependencies=[Api(dep) for dep in spec_data.get("api_dependencies", [])],
+    )
+    return spec
+
+
+def _load_inline_provider_spec(spec_data: Dict[str, Any], api: Api, provider_name: str) -> ProviderSpec:
+    spec = InlineProviderSpec(
+        api=api,
+        provider_type=f"inline::{provider_name}",
+        pip_packages=spec_data.get("pip_packages", []),
+        module=spec_data["module"],
+        config_class=spec_data["config_class"],
+        api_dependencies=[Api(dep) for dep in spec_data.get("api_dependencies", [])],
+        optional_api_dependencies=[Api(dep) for dep in spec_data.get("optional_api_dependencies", [])],
+        provider_data_validator=spec_data.get("provider_data_validator"),
+        container_image=spec_data.get("container_image"),
+    )
+    return spec
+
+
+def get_provider_registry(config: StackRunConfig | None = None) -> Dict[Api, Dict[str, ProviderSpec]]:
+    """Get the provider registry, optionally including external providers.
+
+    This function loads both built-in providers and external providers from YAML files.
+    External providers are loaded from a directory structure like:
+
+        providers.d/
+          remote/
+            inference/
+              custom_ollama.yaml
+              vllm.yaml
+            vector_io/
+              qdrant.yaml
+            safety/
+              llama-guard.yaml
+          inline/
+            inference/
+              custom_ollama.yaml
+              vllm.yaml
+            vector_io/
+              qdrant.yaml
+            safety/
+              llama-guard.yaml
+
+    Args:
+        config: Optional StackRunConfig containing the external providers directory path
+
+    Returns:
+        A dictionary mapping APIs to their available providers
+
+    Raises:
+        FileNotFoundError: If the external providers directory doesn't exist
+        ValueError: If any provider spec is invalid
+    """
+
+    ret: Dict[Api, Dict[str, ProviderSpec]] = {}
     for api in providable_apis():
         name = api.name.lower()
-        module = importlib.import_module(f"llama_stack.providers.registry.{name}")
-        ret[api] = {a.provider_type: a for a in module.available_providers()}
+        logger.debug(f"Importing module {name}")
+        try:
+            module = importlib.import_module(f"llama_stack.providers.registry.{name}")
+            ret[api] = {a.provider_type: a for a in module.available_providers()}
+        except ImportError as e:
+            logger.warning(f"Failed to import module {name}: {e}")
+
+    if config and config.external_providers_dir:
+        external_providers_dir = os.path.abspath(config.external_providers_dir)
+        if not os.path.exists(external_providers_dir):
+            raise FileNotFoundError(f"External providers directory not found: {external_providers_dir}")
+        logger.info(f"Loading external providers from {external_providers_dir}")
+
+        for api in providable_apis():
+            api_name = api.name.lower()
+
+            # Process both remote and inline providers
+            for provider_type in ["remote", "inline"]:
+                api_dir = os.path.join(external_providers_dir, provider_type, api_name)
+                if not os.path.exists(api_dir):
+                    logger.debug(f"No {provider_type} provider directory found for {api_name}")
+                    continue
+
+                # Look for provider spec files in the API directory
+                for spec_path in glob.glob(os.path.join(api_dir, "*.yaml")):
+                    provider_name = os.path.splitext(os.path.basename(spec_path))[0]
+                    logger.info(f"Loading {provider_type} provider spec from {spec_path}")
+
+                    try:
+                        with open(spec_path) as f:
+                            spec_data = yaml.safe_load(f)
+
+                        if provider_type == "remote":
+                            spec = _load_remote_provider_spec(spec_data, api)
+                            provider_type_key = f"remote::{provider_name}"
+                        else:
+                            spec = _load_inline_provider_spec(spec_data, api, provider_name)
+                            provider_type_key = f"inline::{provider_name}"
+
+                        logger.info(f"Loaded {provider_type} provider spec for {provider_type_key} from {spec_path}")
+                        if provider_type_key in ret[api]:
+                            logger.warning(f"Overriding already registered provider {provider_type_key} for {api.name}")
+                        ret[api][provider_type_key] = spec
+                    except yaml.YAMLError as yaml_err:
+                        logger.error(f"Failed to parse YAML file {spec_path}: {yaml_err}")
+                        raise yaml_err
+                    except Exception as e:
+                        logger.error(f"Failed to load provider spec from {spec_path}: {e}")
+                        raise e
     return ret
@@ -351,6 +351,7 @@ async def instantiate_provider(
     if not hasattr(provider_spec, "module"):
         raise AttributeError(f"ProviderSpec of type {type(provider_spec)} does not have a 'module' attribute")

+    logger.debug(f"Instantiating provider {provider.provider_id} from {provider_spec.module}")
     module = importlib.import_module(provider_spec.module)
     args = []
     if isinstance(provider_spec, RemoteProviderSpec):
@@ -608,8 +608,8 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
         tool_group = await self.get_tool_group(toolgroup_id)
         if tool_group is None:
             raise ValueError(f"Tool group {toolgroup_id} not found")
-        tools = (await self.list_tools(toolgroup_id)).data
-        for tool in tools:
+        tools = await self.list_tools(toolgroup_id)
+        for tool in getattr(tools, "data", []):
             await self.unregister_object(tool)
         await self.unregister_object(tool_group)
@@ -218,7 +218,7 @@ async def construct_stack(
     run_config: StackRunConfig, provider_registry: Optional[ProviderRegistry] = None
 ) -> Dict[Api, Any]:
     dist_registry, _ = await create_dist_registry(run_config.metadata_store, run_config.image_name)
-    impls = await resolve_impls(run_config, provider_registry or get_provider_registry(), dist_registry)
+    impls = await resolve_impls(run_config, provider_registry or get_provider_registry(run_config), dist_registry)
     await register_resources(run_config, impls)
     return impls
@@ -1,7 +1,7 @@
 # More info on playground configuration can be found here:
 # https://llama-stack.readthedocs.io/en/latest/playground

-FROM python:3.9-slim
+FROM python:3.12-slim
 WORKDIR /app
 COPY . /app/
 RUN /usr/local/bin/python -m pip install --upgrade pip && \
@@ -24,6 +24,7 @@ def main():
     # Playground pages
     chat_page = st.Page("page/playground/chat.py", title="Chat", icon="💬", default=True)
     rag_page = st.Page("page/playground/rag.py", title="RAG", icon="💬", default=False)
+    tool_page = st.Page("page/playground/tools.py", title="Tools", icon="🛠", default=False)

     # Distribution pages
     resources_page = st.Page("page/distribution/resources.py", title="Resources", icon="🔍", default=False)
@@ -39,6 +40,7 @@ def main():
         "Playground": [
             chat_page,
             rag_page,
+            tool_page,
             application_evaluation_page,
             native_evaluation_page,
         ],
@@ -4,6 +4,8 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

+import uuid
+
 import streamlit as st
 from llama_stack_client import Agent, AgentEventLogger, RAGDocument

@@ -102,8 +104,8 @@ def rag_chat_page():

     # Add clear chat button to sidebar
     if st.button("Clear Chat", use_container_width=True):
-        st.session_state.messages = []
-        st.rerun()
+        st.session_state.clear()
+        st.cache_resource.clear()

     # Chat Interface
     if "messages" not in st.session_state:

@@ -123,23 +125,31 @@ def rag_chat_page():
     else:
         strategy = {"type": "greedy"}

-    agent = Agent(
-        llama_stack_api.client,
-        model=selected_model,
-        instructions=system_prompt,
-        sampling_params={
-            "strategy": strategy,
-        },
-        tools=[
-            dict(
-                name="builtin::rag/knowledge_search",
-                args={
-                    "vector_db_ids": list(selected_vector_dbs),
-                },
-            )
-        ],
-    )
-    session_id = agent.create_session("rag-session")
+    @st.cache_resource
+    def create_agent():
+        return Agent(
+            llama_stack_api.client,
+            model=selected_model,
+            instructions=system_prompt,
+            sampling_params={
+                "strategy": strategy,
+            },
+            tools=[
+                dict(
+                    name="builtin::rag/knowledge_search",
+                    args={
+                        "vector_db_ids": list(selected_vector_dbs),
+                    },
+                )
+            ],
+        )
+
+    agent = create_agent()
+
+    if "agent_session_id" not in st.session_state:
+        st.session_state["agent_session_id"] = agent.create_session(session_name=f"rag_demo_{uuid.uuid4()}")
+
+    session_id = st.session_state["agent_session_id"]

     # Chat input
     if prompt := st.chat_input("Ask a question about your documents"):
llama_stack/distribution/ui/page/playground/tools.py (new file, 116 lines)
@@ -0,0 +1,116 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import uuid

import streamlit as st
from llama_stack_client import Agent

from llama_stack.distribution.ui.modules.api import llama_stack_api


def tool_chat_page():
    st.title("🛠 Tools")

    client = llama_stack_api.client
    models = client.models.list()
    model_list = [model.identifier for model in models if model.api_model_type == "llm"]

    tool_groups = client.toolgroups.list()
    tool_groups_list = [tool_group.identifier for tool_group in tool_groups]
    mcp_tools_list = [tool for tool in tool_groups_list if tool.startswith("mcp::")]
    builtin_tools_list = [tool for tool in tool_groups_list if not tool.startswith("mcp::")]

    def reset_agent():
        st.session_state.clear()
        st.cache_resource.clear()

    with st.sidebar:
        st.subheader("Model")
        model = st.selectbox(label="models", options=model_list, on_change=reset_agent)

        st.subheader("Builtin Tools")
        toolgroup_selection = st.pills(
            label="Available ToolGroups", options=builtin_tools_list, selection_mode="multi", on_change=reset_agent
        )

        st.subheader("MCP Servers")
        mcp_selection = st.pills(
            label="Available MCP Servers", options=mcp_tools_list, selection_mode="multi", on_change=reset_agent
        )

        toolgroup_selection.extend(mcp_selection)

        active_tool_list = []
        for toolgroup_id in toolgroup_selection:
            active_tool_list.extend(
                [
                    f"{''.join(toolgroup_id.split('::')[1:])}:{t.identifier}"
                    for t in client.tools.list(toolgroup_id=toolgroup_id)
                ]
            )

        st.subheader(f"Active Tools: 🛠 {len(active_tool_list)}")
        st.json(active_tool_list)

    @st.cache_resource
    def create_agent():
        return Agent(
            client,
            model=model,
            instructions="You are a helpful assistant. When you use a tool always respond with a summary of the result.",
            tools=toolgroup_selection,
            sampling_params={
                "strategy": {"type": "greedy"},
            },
        )

    agent = create_agent()

    if "agent_session_id" not in st.session_state:
        st.session_state["agent_session_id"] = agent.create_session(session_name=f"tool_demo_{uuid.uuid4()}")

    session_id = st.session_state["agent_session_id"]

    if "messages" not in st.session_state:
        st.session_state["messages"] = [{"role": "assistant", "content": "How can I help you?"}]

    for msg in st.session_state.messages:
        with st.chat_message(msg["role"]):
            st.markdown(msg["content"])

    if prompt := st.chat_input(placeholder=""):
        with st.chat_message("user"):
            st.markdown(prompt)

        st.session_state.messages.append({"role": "user", "content": prompt})

        turn_response = agent.create_turn(
            session_id=session_id,
            messages=[{"role": "user", "content": prompt}],
            stream=True,
        )

        def response_generator(turn_response):
            for response in turn_response:
                if hasattr(response.event, "payload"):
                    print(response.event.payload)
                    if response.event.payload.event_type == "step_progress":
                        if hasattr(response.event.payload.delta, "text"):
                            yield response.event.payload.delta.text
                    if response.event.payload.event_type == "step_complete":
                        if response.event.payload.step_details.step_type == "tool_execution":
                            yield " 🛠 "
                else:
                    yield f"Error occurred in the Llama Stack Cluster: {response}"

        with st.chat_message("assistant"):
            response = st.write_stream(response_generator(turn_response))

        st.session_state.messages.append({"role": "assistant", "content": response})


tool_chat_page()
@@ -2,3 +2,4 @@ streamlit
 pandas
 llama-stack-client>=0.0.55
 streamlit-option-menu
+llama-stack>=0.1.9
@@ -29,6 +29,11 @@ def preserve_contexts_async_generator(
                 context_var.set(initial_context_values[context_var.name])

             item = await gen.__anext__()
+
+            # Update our tracked values with any changes made during this iteration
+            for context_var in context_vars:
+                initial_context_values[context_var.name] = context_var.get()
+
             yield item

         except StopAsyncIteration:
164 llama_stack/models/llama/checkpoint.py Normal file
@@ -0,0 +1,164 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import concurrent.futures
import re
from pathlib import Path
from typing import Any, Dict, List, Optional, Union

import numpy as np
import torch
from fairscale.nn.model_parallel.initialize import get_model_parallel_rank, get_model_parallel_world_size


def map_mp_rank(old_mp_size: int, new_mp_size: int, new_mp_rank: int) -> List[int]:
"""Map a new MP rank to a list of old MP ranks given a change in MP size."""
if new_mp_size % old_mp_size == 0:
# Read old MP shard and split it into smaller ones
return [new_mp_rank * old_mp_size // new_mp_size]
elif old_mp_size % new_mp_size == 0:
# Merge old MP shards into a single one
mp_factor = old_mp_size // new_mp_size
return list(range(new_mp_rank * mp_factor, (new_mp_rank + 1) * mp_factor))
else:
raise ValueError(
f"Either old MP size or new MP size should be a multiple of the other: "
f"{old_mp_size} % {new_mp_size} != 0 and {new_mp_size} % {old_mp_size} != 0"
)


def maybe_reshard_state_dict(
ckpt_paths: List[Path],
n_kv_heads: int,
moe_num_experts: Optional[int] = None,
map_location: Union[str, torch.device] = "cpu",
mmap: bool = True,
) -> Dict[str, torch.Tensor]:
if str(map_location) == "cpu":
torch.set_default_tensor_type(torch.BFloat16Tensor)
else:
torch.set_default_tensor_type(torch.cuda.BFloat16Tensor)

ckpt_paths = np.array(sorted(ckpt_paths))

new_mp_size, new_mp_rank = get_model_parallel_world_size(), get_model_parallel_rank()
old_mp_size = len(ckpt_paths)
old_mp_ranks = map_mp_rank(old_mp_size, new_mp_size, new_mp_rank)

print(f"Loading checkpoint shards:\n{str(ckpt_paths[old_mp_ranks])}") # type: ignore
paths = ckpt_paths[old_mp_ranks] # type: ignore
state_dicts = [torch.load(str(p), map_location=map_location, mmap=mmap) for p in paths]

if new_mp_size == old_mp_size:
return state_dicts[0] # type: ignore

if moe_num_experts is not None:
state_dicts = [convert_moe_weights(d, moe_num_experts) for d in state_dicts]

print(f"Resharding {len(state_dicts)} state dicts from MP size {old_mp_size} to MP size {new_mp_size}")
return reshard_mp(
state_dicts,
size=max(new_mp_size // old_mp_size, 1),
rank=new_mp_rank % max(new_mp_size // old_mp_size, 1),
repeat_qk_qv=max(new_mp_size // n_kv_heads, 1),
)


_WEIGHT_ROW_KEY = {
"feed_forward.w2",
"feed_forward.mlp.fc2",
"attention.wo",
"feed_forward.mlp.fc2_weight",
"feed_forward.w_out_shared_DF.weight",
"attn.wo.weight",
"mlp.c_proj.weight",
}
_MOE_WEIGHT_ROW_KEY = {"feed_forward.experts.(moe_w_in_eD_F|moe_w_swiglu_eD_F)"}

_WEIGHT_COLUMN_KEY = {
"output",
"feed_forward.(w1|w3)",
"feed_forward.mlp.(fc1|fc3)",
"feed_forward.mlp.fc1_weight",
"attention.(wk|wq|wv|wqkv).weight",
"feed_forward.(w_in_shared_FD|w_swiglu_FD)",
"attn.(wk|wq|wv).weight",
"attn.(wk|wq|wv).bias",
"mlp.c_fc.weight",
"mlp.c_fc.bias",
"conv1._linear.weight",
"tok_embeddings.weight",
"vision_projection.weight",
}
_MOE_WEIGHT_COLUMN_KEY = {"feed_forward.experts.moe_w_out_eF_D"}


def reshard_mp(
state_dicts: List[Dict[str, torch.Tensor]],
size: int,
rank: int,
repeat_qk_qv: int = 1,
) -> Dict[str, torch.Tensor]:
"""
Reshard a list of state dicts into a single state dict given a change in MP size.
If the list has more than one state dict, we concatenate the values of the same
key across all state dicts. Otherwise, we just slice it for the current MP rank.
"""

def concat_or_chunk(tensors: List[torch.Tensor], dim: int) -> torch.Tensor:
if len(tensors) > 1:
return torch.cat(tensors, dim=dim)
return tensors[0].chunk(size, dim=dim)[rank].clone()

def process_key(key: str) -> torch.Tensor:
if row_regex.search(key):
return concat_or_chunk([s[key] for s in state_dicts], dim=-1)
elif column_regex.search(key):
if "w13" in key or "fc1_weight" in key:
dims = state_dicts[0][key].size()
values = [s[key].view(2, dims[0] // 2, *dims[1:]) for s in state_dicts]
return concat_or_chunk(values, dim=1).flatten(0, 1)
elif "qkv" in key:
q_dim = state_dicts[0][key.replace("qkv", "o")].size(1)
kv_dim = (state_dicts[0][key].size(0) - q_dim) // 2
values = [s[key].split((q_dim, kv_dim, kv_dim)) for s in state_dicts]
return torch.cat([concat_or_chunk(x, dim=0) for x in zip(*values, strict=False)]) # type: ignore
elif "wk.weight" in key or "wv.weight" in key:
# Support MP > #kv_head
return concat_or_chunk([s[key].repeat(repeat_qk_qv, 1) for s in state_dicts], dim=0)
elif key == "output.bias" or key == "fc.weight":
return concat_or_chunk([s[key] for s in state_dicts], dim=0)
elif "w_" in key:
return concat_or_chunk([s[key] for s in state_dicts], dim=-2)
else:
return concat_or_chunk([s[key] for s in state_dicts], dim=0)
else:
return state_dicts[0][key].clone()

row_keys = _WEIGHT_ROW_KEY | _MOE_WEIGHT_ROW_KEY
column_keys = _WEIGHT_COLUMN_KEY | _MOE_WEIGHT_COLUMN_KEY

column_regex = re.compile("|".join(column_keys))
row_regex = re.compile("|".join(row_keys))

output: Dict[str, torch.Tensor] = {}
with concurrent.futures.ThreadPoolExecutor() as executor:
# Note: only processes keys in the first state dict.
# Assumes keys are the same across all state dicts.
mappings = {executor.submit(process_key, key): key for key in state_dicts[0]}
for future in concurrent.futures.as_completed(mappings):
output[mappings[future]] = future.result()
return output


def convert_moe_weights(state_dict: Dict[str, Any], num_experts: int) -> Dict[str, Any]:
routed_keys = _MOE_WEIGHT_ROW_KEY | _MOE_WEIGHT_COLUMN_KEY
routed_regex = re.compile("|".join(routed_keys))
keys = list(state_dict.keys())
for key in keys:
if routed_regex.search(key):
state_dict[key] = state_dict.pop(key).unflatten(0, (num_experts, -1)).squeeze(dim=0)
return state_dict
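A quick, illustrative sanity check of `map_mp_rank` (not part of the file): splitting maps each new rank to a fraction of one old shard, while merging maps it to a contiguous block of old shards.

# Splitting 2 shards across 4 ranks: new rank 3 reads half of old shard 1.
assert map_mp_rank(old_mp_size=2, new_mp_size=4, new_mp_rank=3) == [1]
# Merging 8 shards down to 2 ranks: new rank 1 concatenates old shards 4..7.
assert map_mp_rank(old_mp_size=8, new_mp_size=2, new_mp_rank=1) == [4, 5, 6, 7]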
@@ -4,13 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.

import base64
from enum import Enum
from io import BytesIO

@@ -19,8 +12,6 @@ from typing import Any, Dict, List, Literal, Optional, Union
from pydantic import BaseModel, ConfigDict, Field, field_serializer, field_validator
from typing_extensions import Annotated

from llama_stack.schema_utils import json_schema_type, register_schema

# The goal is that these set of types are relevant for all Llama models.
# That isn't the current state yet -- e.g., BuiltinTool is somewhat specific to
# the llama3 series of models.
@@ -98,6 +89,29 @@ class StopReason(Enum):
out_of_tokens = "out_of_tokens"


class ToolParamDefinition(BaseModel):
param_type: str
description: Optional[str] = None
required: Optional[bool] = True
default: Optional[Any] = None


class ToolDefinition(BaseModel):
tool_name: Union[BuiltinTool, str]
description: Optional[str] = None
parameters: Optional[Dict[str, ToolParamDefinition]] = None

@field_validator("tool_name", mode="before")
@classmethod
def validate_field(cls, v):
if isinstance(v, str):
try:
return BuiltinTool(v)
except ValueError:
return v
return v


class RawMediaItem(BaseModel):
type: Literal["image"] = "image"
data: bytes | BytesIO
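Illustrative usage of the `ToolDefinition` model added above (not part of the diff): the `mode="before"` validator coerces known builtin tool names into `BuiltinTool` members and leaves unknown names as plain strings. The `get_weather` tool below is a hypothetical example.

from llama_stack.models.llama.datatypes import BuiltinTool, ToolDefinition, ToolParamDefinition

t1 = ToolDefinition(tool_name="code_interpreter")   # coerced to BuiltinTool.code_interpreter
t2 = ToolDefinition(
    tool_name="get_weather",                        # hypothetical custom tool, stays a str
    description="Look up the current weather",
    parameters={"city": ToolParamDefinition(param_type="str", description="City name")},
)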
@ -140,292 +154,25 @@ class RawMessage(BaseModel):
|
||||||
tool_calls: List[ToolCall] = Field(default_factory=list)
|
tool_calls: List[ToolCall] = Field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
register_schema(ToolCall)
|
class GenerationResult(BaseModel):
|
||||||
|
token: int
|
||||||
|
text: str
|
||||||
|
logprobs: Optional[List[float]] = None
|
||||||
|
|
||||||
|
source: Literal["input"] | Literal["output"]
|
||||||
|
|
||||||
|
# index within the batch
|
||||||
|
batch_idx: int
|
||||||
|
# whether generation for this item is already finished. note that tokens can
|
||||||
|
# get returned even afterwards since other items in the batch can still be generating tokens
|
||||||
|
finished: bool
|
||||||
|
# because a batch is parallel processed, useful decoding for one item can correspond to processing
|
||||||
|
# pad tokens or tokens beyond EOS for other items. we could have decided to return None for this case
|
||||||
|
# but it's more convenient to return a list of GenerationResult and filter out the ignored tokens
|
||||||
|
ignore_token: bool
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
class QuantizationMode(str, Enum):
|
||||||
class ToolParamDefinition(BaseModel):
|
none = "none"
|
||||||
param_type: str
|
fp8_mixed = "fp8_mixed"
|
||||||
description: Optional[str] = None
|
int4_mixed = "int4_mixed"
|
||||||
required: Optional[bool] = True
|
|
||||||
default: Optional[Any] = None
|
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
|
||||||
class ToolDefinition(BaseModel):
|
|
||||||
tool_name: Union[BuiltinTool, str]
|
|
||||||
description: Optional[str] = None
|
|
||||||
parameters: Optional[Dict[str, ToolParamDefinition]] = None
|
|
||||||
|
|
||||||
@field_validator("tool_name", mode="before")
|
|
||||||
@classmethod
|
|
||||||
def validate_field(cls, v):
|
|
||||||
if isinstance(v, str):
|
|
||||||
try:
|
|
||||||
return BuiltinTool(v)
|
|
||||||
except ValueError:
|
|
||||||
return v
|
|
||||||
return v
|
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
|
||||||
class GreedySamplingStrategy(BaseModel):
|
|
||||||
type: Literal["greedy"] = "greedy"
|
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
|
||||||
class TopPSamplingStrategy(BaseModel):
|
|
||||||
type: Literal["top_p"] = "top_p"
|
|
||||||
temperature: Optional[float] = Field(..., gt=0.0)
|
|
||||||
top_p: Optional[float] = 0.95
|
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
|
||||||
class TopKSamplingStrategy(BaseModel):
|
|
||||||
type: Literal["top_k"] = "top_k"
|
|
||||||
top_k: int = Field(..., ge=1)
|
|
||||||
|
|
||||||
|
|
||||||
SamplingStrategy = Annotated[
|
|
||||||
Union[GreedySamplingStrategy, TopPSamplingStrategy, TopKSamplingStrategy],
|
|
||||||
Field(discriminator="type"),
|
|
||||||
]
|
|
||||||
register_schema(SamplingStrategy, name="SamplingStrategy")
|
|
||||||
|
|
||||||
|
|
||||||
@json_schema_type
|
|
||||||
class SamplingParams(BaseModel):
|
|
||||||
"""Sampling parameters.
|
|
||||||
|
|
||||||
:param strategy: The sampling strategy.
|
|
||||||
:param max_tokens: The maximum number of tokens that can be generated in the completion. The token count of
|
|
||||||
your prompt plus max_tokens cannot exceed the model's context length.
|
|
||||||
:param repetition_penalty: Number between -2.0 and 2.0. Positive values penalize new tokens
|
|
||||||
based on whether they appear in the text so far, increasing the model's likelihood to talk about new topics.
|
|
||||||
:param stop: Up to 4 sequences where the API will stop generating further tokens.
|
|
||||||
The returned text will not contain the stop sequence.
|
|
||||||
"""
|
|
||||||
|
|
||||||
strategy: SamplingStrategy = Field(default_factory=GreedySamplingStrategy)
|
|
||||||
|
|
||||||
max_tokens: Optional[int] = 0
|
|
||||||
repetition_penalty: Optional[float] = 1.0
|
|
||||||
stop: Optional[List[str]] = None
|
|
||||||
|
|
||||||
|
|
||||||
class CheckpointQuantizationFormat(Enum):
|
|
||||||
# default format
|
|
||||||
bf16 = "bf16"
|
|
||||||
|
|
||||||
# used for enabling fp8_rowwise inference, some weights are bf16
|
|
||||||
fp8_mixed = "fp8-mixed"
|
|
||||||
|
|
||||||
int8 = "int8"
|
|
||||||
|
|
||||||
int4 = "int4"
|
|
||||||
|
|
||||||
|
|
||||||
class ModelFamily(Enum):
|
|
||||||
llama2 = "llama2"
|
|
||||||
llama3 = "llama3"
|
|
||||||
llama3_1 = "llama3_1"
|
|
||||||
llama3_2 = "llama3_2"
|
|
||||||
llama3_3 = "llama3_3"
|
|
||||||
llama4 = "llama4"
|
|
||||||
safety = "safety"
|
|
||||||
|
|
||||||
|
|
||||||
class CoreModelId(Enum):
|
|
||||||
"""Each of these models is a unique "SKU". These root models can be served in various garbs (especially by quantizing them)"""
|
|
||||||
|
|
||||||
# Llama 2 family
|
|
||||||
llama2_7b = "Llama-2-7b"
|
|
||||||
llama2_13b = "Llama-2-13b"
|
|
||||||
llama2_70b = "Llama-2-70b"
|
|
||||||
llama2_7b_chat = "Llama-2-7b-chat"
|
|
||||||
llama2_13b_chat = "Llama-2-13b-chat"
|
|
||||||
llama2_70b_chat = "Llama-2-70b-chat"
|
|
||||||
|
|
||||||
# Llama 3 family
|
|
||||||
llama3_8b = "Llama-3-8B"
|
|
||||||
llama3_70b = "Llama-3-70B"
|
|
||||||
llama3_8b_instruct = "Llama-3-8B-Instruct"
|
|
||||||
llama3_70b_instruct = "Llama-3-70B-Instruct"
|
|
||||||
|
|
||||||
# Llama 3.1 family
|
|
||||||
llama3_1_8b = "Llama3.1-8B"
|
|
||||||
llama3_1_70b = "Llama3.1-70B"
|
|
||||||
llama3_1_405b = "Llama3.1-405B"
|
|
||||||
llama3_1_8b_instruct = "Llama3.1-8B-Instruct"
|
|
||||||
llama3_1_70b_instruct = "Llama3.1-70B-Instruct"
|
|
||||||
llama3_1_405b_instruct = "Llama3.1-405B-Instruct"
|
|
||||||
|
|
||||||
# Llama 3.2 family
|
|
||||||
llama3_2_1b = "Llama3.2-1B"
|
|
||||||
llama3_2_3b = "Llama3.2-3B"
|
|
||||||
llama3_2_1b_instruct = "Llama3.2-1B-Instruct"
|
|
||||||
llama3_2_3b_instruct = "Llama3.2-3B-Instruct"
|
|
||||||
llama3_2_11b_vision = "Llama3.2-11B-Vision"
|
|
||||||
llama3_2_90b_vision = "Llama3.2-90B-Vision"
|
|
||||||
llama3_2_11b_vision_instruct = "Llama3.2-11B-Vision-Instruct"
|
|
||||||
llama3_2_90b_vision_instruct = "Llama3.2-90B-Vision-Instruct"
|
|
||||||
|
|
||||||
# Llama 3.3 family
|
|
||||||
llama3_3_70b_instruct = "Llama3.3-70B-Instruct"
|
|
||||||
|
|
||||||
# Llama 4 family
|
|
||||||
llama4_scout_17b_16e = "Llama-4-Scout-17B-16E"
|
|
||||||
llama4_scout_17b_16e_instruct = "Llama-4-Scout-17B-16E-Instruct"
|
|
||||||
llama4_maverick_17b_128e = "Llama-4-Maverick-17B-128E"
|
|
||||||
llama4_maverick_17b_128e_instruct = "Llama-4-Maverick-17B-128E-Instruct"
|
|
||||||
|
|
||||||
# Safety models
|
|
||||||
llama_guard_3_8b = "Llama-Guard-3-8B"
|
|
||||||
llama_guard_2_8b = "Llama-Guard-2-8B"
|
|
||||||
llama_guard_3_11b_vision = "Llama-Guard-3-11B-Vision"
|
|
||||||
llama_guard_3_1b = "Llama-Guard-3-1B"
|
|
||||||
|
|
||||||
|
|
||||||
def is_multimodal(model_id) -> bool:
|
|
||||||
if model_id in [
|
|
||||||
CoreModelId.llama3_2_11b_vision,
|
|
||||||
CoreModelId.llama3_2_90b_vision,
|
|
||||||
CoreModelId.llama3_2_11b_vision_instruct,
|
|
||||||
CoreModelId.llama3_2_90b_vision_instruct,
|
|
||||||
]:
|
|
||||||
return True
|
|
||||||
else:
|
|
||||||
return False
|
|
||||||
|
|
||||||
|
|
||||||
def model_family(model_id) -> ModelFamily:
|
|
||||||
if model_id in [
|
|
||||||
CoreModelId.llama2_7b,
|
|
||||||
CoreModelId.llama2_13b,
|
|
||||||
CoreModelId.llama2_70b,
|
|
||||||
CoreModelId.llama2_7b_chat,
|
|
||||||
CoreModelId.llama2_13b_chat,
|
|
||||||
CoreModelId.llama2_70b_chat,
|
|
||||||
]:
|
|
||||||
return ModelFamily.llama2
|
|
||||||
elif model_id in [
|
|
||||||
CoreModelId.llama3_8b,
|
|
||||||
CoreModelId.llama3_70b,
|
|
||||||
CoreModelId.llama3_8b_instruct,
|
|
||||||
CoreModelId.llama3_70b_instruct,
|
|
||||||
]:
|
|
||||||
return ModelFamily.llama3
|
|
||||||
elif model_id in [
|
|
||||||
CoreModelId.llama3_1_8b,
|
|
||||||
CoreModelId.llama3_1_70b,
|
|
||||||
CoreModelId.llama3_1_405b,
|
|
||||||
CoreModelId.llama3_1_8b_instruct,
|
|
||||||
CoreModelId.llama3_1_70b_instruct,
|
|
||||||
CoreModelId.llama3_1_405b_instruct,
|
|
||||||
]:
|
|
||||||
return ModelFamily.llama3_1
|
|
||||||
elif model_id in [
|
|
||||||
CoreModelId.llama3_2_1b,
|
|
||||||
CoreModelId.llama3_2_3b,
|
|
||||||
CoreModelId.llama3_2_1b_instruct,
|
|
||||||
CoreModelId.llama3_2_3b_instruct,
|
|
||||||
CoreModelId.llama3_2_11b_vision,
|
|
||||||
CoreModelId.llama3_2_90b_vision,
|
|
||||||
CoreModelId.llama3_2_11b_vision_instruct,
|
|
||||||
CoreModelId.llama3_2_90b_vision_instruct,
|
|
||||||
]:
|
|
||||||
return ModelFamily.llama3_2
|
|
||||||
elif model_id in [
|
|
||||||
CoreModelId.llama3_3_70b_instruct,
|
|
||||||
]:
|
|
||||||
return ModelFamily.llama3_3
|
|
||||||
elif model_id in [
|
|
||||||
CoreModelId.llama4_scout_17b_16e,
|
|
||||||
CoreModelId.llama4_scout_17b_16e_instruct,
|
|
||||||
CoreModelId.llama4_maverick_17b_128e,
|
|
||||||
CoreModelId.llama4_maverick_17b_128e_instruct,
|
|
||||||
]:
|
|
||||||
return ModelFamily.llama4
|
|
||||||
elif model_id in [
|
|
||||||
CoreModelId.llama_guard_3_8b,
|
|
||||||
CoreModelId.llama_guard_2_8b,
|
|
||||||
CoreModelId.llama_guard_3_11b_vision,
|
|
||||||
CoreModelId.llama_guard_3_1b,
|
|
||||||
]:
|
|
||||||
return ModelFamily.safety
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Unknown model family for {model_id}")
|
|
||||||
|
|
||||||
|
|
||||||
class Model(BaseModel):
|
|
||||||
core_model_id: CoreModelId
|
|
||||||
description: str
|
|
||||||
huggingface_repo: Optional[str] = None
|
|
||||||
recommended_sampling_params: Optional[SamplingParams] = None
|
|
||||||
arch_args: Dict[str, Any]
|
|
||||||
variant: str = ""
|
|
||||||
|
|
||||||
quantization_format: CheckpointQuantizationFormat = CheckpointQuantizationFormat.bf16
|
|
||||||
pth_file_count: int
|
|
||||||
metadata: Optional[Dict[str, Any]] = Field(default_factory=dict)
|
|
||||||
|
|
||||||
# silence pydantic until we remove the `model_` fields
|
|
||||||
model_config = ConfigDict(protected_namespaces=())
|
|
||||||
|
|
||||||
@property
|
|
||||||
def model_family(self) -> ModelFamily:
|
|
||||||
return model_family(self.core_model_id)
|
|
||||||
|
|
||||||
# The SKU is uniquely identified by (model_id, variant) combo
|
|
||||||
def descriptor(self, shorten_default_variant: bool = True) -> str:
|
|
||||||
if not self.variant:
|
|
||||||
return self.core_model_id.value
|
|
||||||
return f"{self.core_model_id.value}:{self.variant}"
|
|
||||||
|
|
||||||
@property
|
|
||||||
def is_instruct_model(self) -> bool:
|
|
||||||
return "instruct" in self.id.name
|
|
||||||
|
|
||||||
# Featured models are shown in the non-exhaustive model list
|
|
||||||
@property
|
|
||||||
def is_featured(self) -> bool:
|
|
||||||
return self.model_family in [
|
|
||||||
ModelFamily.llama3_1,
|
|
||||||
ModelFamily.llama3_2,
|
|
||||||
ModelFamily.llama3_3,
|
|
||||||
ModelFamily.llama4,
|
|
||||||
ModelFamily.safety,
|
|
||||||
]
|
|
||||||
|
|
||||||
@property
|
|
||||||
def max_seq_length(self) -> int:
|
|
||||||
if self.model_family == ModelFamily.llama2:
|
|
||||||
return 4096
|
|
||||||
elif self.core_model_id == CoreModelId.llama_guard_2_8b:
|
|
||||||
return 4096
|
|
||||||
elif self.model_family == ModelFamily.llama3:
|
|
||||||
return 8192
|
|
||||||
elif self.model_family in [ModelFamily.llama3_1, ModelFamily.llama3_3]:
|
|
||||||
return 131072
|
|
||||||
elif self.model_family == ModelFamily.llama3_2:
|
|
||||||
if self.quantization_format == CheckpointQuantizationFormat.int4:
|
|
||||||
return 8192
|
|
||||||
return 131072
|
|
||||||
elif self.model_family == ModelFamily.llama4:
|
|
||||||
if self.core_model_id in {
|
|
||||||
CoreModelId.llama4_scout_17b_16e,
|
|
||||||
CoreModelId.llama4_maverick_17b_128e,
|
|
||||||
}:
|
|
||||||
return 262144
|
|
||||||
if self.core_model_id == CoreModelId.llama4_scout_17b_16e_instruct:
|
|
||||||
return 10485760
|
|
||||||
if self.core_model_id == CoreModelId.llama4_maverick_17b_128e_instruct:
|
|
||||||
return 1048576
|
|
||||||
elif self.core_model_id in [
|
|
||||||
CoreModelId.llama_guard_3_8b,
|
|
||||||
CoreModelId.llama_guard_3_11b_vision,
|
|
||||||
CoreModelId.llama_guard_3_1b,
|
|
||||||
]:
|
|
||||||
return 131072
|
|
||||||
else:
|
|
||||||
raise ValueError(f"Unknown max_seq_len for {self.core_model_id}")
|
|
||||||
|
|
|
@@ -4,13 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.

from dataclasses import dataclass
from enum import Enum
from typing import Optional
@@ -4,13 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.

import io
import json
import uuid

@@ -19,7 +12,7 @@ from typing import Dict, List, Optional, Tuple

from PIL import Image as PIL_Image

from llama_stack.models.llama.datatypes import (
from ..datatypes import (
BuiltinTool,
RawContent,
RawMediaItem,

@@ -30,7 +23,6 @@ from llama_stack.models.llama.datatypes import (
ToolCall,
ToolPromptFormat,
)

from .tokenizer import Tokenizer
from .tool_utils import ToolUtils
366 llama_stack/models/llama/llama3/generation.py Normal file
@@ -0,0 +1,366 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.

import json
import os
import sys
import time
from pathlib import Path
from typing import Callable, Generator, List, Optional

import torch
import torch.nn.functional as F
from fairscale.nn.model_parallel.initialize import (
initialize_model_parallel,
model_parallel_is_initialized,
)
from termcolor import cprint

from ..checkpoint import maybe_reshard_state_dict
from ..datatypes import GenerationResult, QuantizationMode, RawContent, RawMessage, ToolPromptFormat
from .args import ModelArgs
from .chat_format import ChatFormat, LLMInput
from .model import Transformer
from .multimodal.model import CrossAttentionTransformer
from .tokenizer import Tokenizer


class Llama3:
@staticmethod
def build(
ckpt_dir: str,
max_seq_len: int,
max_batch_size: int,
world_size: Optional[int] = None,
quantization_mode: Optional[QuantizationMode] = None,
seed: int = 1,
device: str = "cuda",
):
device = torch.device(device)
if (
device.type == "cuda"
and not torch.cuda.is_available()
or device.type == "xpu"
and not torch.xpu.is_available()
):
raise RuntimeError(f"PyTorch backend for {device.type} device type is not available")

if not torch.distributed.is_initialized():
if device.type == "cuda":
torch.distributed.init_process_group("nccl")
else:
torch.distributed.init_process_group("gloo")

if not model_parallel_is_initialized():
if world_size is None:
world_size = int(os.environ.get("WORLD_SIZE", 1))
initialize_model_parallel(world_size)

local_rank = int(os.environ.get("LOCAL_RANK", 0))
if device.type == "cuda":
torch.cuda.set_device(local_rank)
elif device.type == "xpu":
torch.xpu.set_device(local_rank)

torch.manual_seed(seed)

if local_rank > 0:
sys.stdout = open(os.devnull, "w")

start_time = time.time()

ckpt_paths = sorted(Path(ckpt_dir).glob("*.pth"))
assert len(ckpt_paths) > 0, f"no checkpoint files found in {ckpt_dir}"
print(f"Loading a checkpoint (shards={len(ckpt_paths)}, current-mp-size={world_size})")
with open(Path(ckpt_dir) / "params.json", "r") as f:
params = json.loads(f.read())

model_args: ModelArgs = ModelArgs(
max_seq_len=max_seq_len,
max_batch_size=max_batch_size,
**params,
)
tokenizer = Tokenizer.get_instance()

state_dict = maybe_reshard_state_dict(
ckpt_paths,
n_kv_heads=model_args.n_kv_heads if model_args.n_kv_heads else model_args.n_heads,
)

assert model_args.vocab_size == tokenizer.n_words

def build_model():
if model_args.vision_chunk_size > 0:
model = CrossAttentionTransformer(model_args)
model.setup_cache(model_args.max_batch_size, device=device, dtype=torch.get_default_dtype())
else:
model = Transformer(model_args)
return model

if quantization_mode == QuantizationMode.fp8_mixed or quantization_mode == QuantizationMode.int4_mixed:
from .quantization.loader import convert_to_quantized_model

torch.set_default_tensor_type(torch.BFloat16Tensor)
model = build_model()
print("Loading state dict...")
model.load_state_dict(state_dict, strict=False)
print("Done...")
model = convert_to_quantized_model(model, ckpt_dir, quantization_mode, device=device)
torch.set_default_device(device)
else:
print(f"Setting default device to {device}")
if device.type == "cuda":
if torch.cuda.is_bf16_supported():
torch.set_default_tensor_type(torch.cuda.BFloat16Tensor)
else:
torch.set_default_tensor_type(torch.cuda.Float16Tensor)
elif device.type == "xpu":
if torch.xpu.is_bf16_supported():
torch.set_default_tensor_type(torch.xpu.BFloat16Tensor)
else:
torch.set_default_tensor_type(torch.xpu.Float16Tensor)

model = build_model()
print("Loading state dict...")
model.load_state_dict(state_dict, strict=True)
model.to(device)
print("Done...")

print(f"Loaded in {time.time() - start_time:.2f} seconds")

return Llama3(model, tokenizer, model_args)

def __init__(self, model: Transformer | CrossAttentionTransformer, tokenizer: Tokenizer, args: ModelArgs):
self.args = args
self.model = model
self.tokenizer = tokenizer
self.formatter = ChatFormat(tokenizer)

@torch.inference_mode()
def generate(
self,
model_inputs: List[LLMInput],
temperature: float = 0.6,
top_p: float = 0.9,
max_gen_len: Optional[int] = None,
logprobs: bool = False,
echo: bool = False,
print_model_input: bool = False,
logits_processor: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
) -> Generator[List[GenerationResult], None, None]:
if max_gen_len is None or max_gen_len == 0 or max_gen_len >= self.args.max_seq_len:
max_gen_len = self.args.max_seq_len - 1
params = self.model.params

print_model_input = print_model_input or os.environ.get("LLAMA_MODELS_DEBUG", "0") == "1"
if print_model_input:
for inp in model_inputs:
tokens_to_print = [self.formatter.vision_token if t == 128256 else t for t in inp.tokens]
cprint(
"Input to model:\n" + self.tokenizer.decode(tokens_to_print) + "\n",
"red",
)
prompt_tokens = [inp.tokens for inp in model_inputs]

bsz = len(model_inputs)
assert bsz <= params.max_batch_size, (bsz, params.max_batch_size)

min_prompt_len = min(len(t) for t in prompt_tokens)
max_prompt_len = max(len(t) for t in prompt_tokens)

if max_prompt_len >= params.max_seq_len:
cprint(f"Out of token budget {max_prompt_len} vs {params.max_seq_len}", "red")
return

total_len = min(max_gen_len + max_prompt_len, params.max_seq_len)

pad_id = self.tokenizer.pad_id
tokens = torch.full((bsz, total_len), pad_id, dtype=torch.long)
for k, t in enumerate(prompt_tokens):
tokens[k, : len(t)] = torch.tensor(t, dtype=torch.long)
if logprobs:
token_logprobs = torch.zeros_like(tokens, dtype=torch.float)

is_vision = not isinstance(self.model, Transformer)
if is_vision:
images = [inp.vision.images if inp.vision is not None else [] for inp in model_inputs]
mask = [inp.vision.mask if inp.vision is not None else [] for inp in model_inputs]

xattn_caches, cross_attention_masks, full_text_row_masked_out_mask = self.model.compute_vision_tokens_masks(
batch_images=images,
batch_masks=mask,
total_len=total_len,
device=tokens.device,
)

eos_reached = torch.tensor([False] * bsz)
input_text_mask = tokens != pad_id

if echo:
for i in range(max_prompt_len):
results = []
for j, t in enumerate(tokens[:, i]):
results.append(
GenerationResult(
token=t.item(),
text=self.tokenizer.decode([t.item()]),
source="input",
logprobs=(token_logprobs[j, i : i + 1].tolist() if logprobs else None),
batch_idx=j,
finished=False,
ignore_token=t.item() == pad_id,
)
)
yield results

stop_tokens = torch.tensor(self.tokenizer.stop_tokens)

prev_pos = 0
for cur_pos in range(min_prompt_len, total_len):
if is_vision:
position_ids = torch.arange(prev_pos, cur_pos, dtype=torch.long)
text_only_inference = all(inp.vision is None for inp in model_inputs)
logits = self.model.forward(
position_ids,
tokens,
cross_attention_masks,
full_text_row_masked_out_mask,
xattn_caches,
text_only_inference,
)
else:
logits = self.model.forward(tokens[:, prev_pos:cur_pos], prev_pos)

if logits_processor is not None:
logits = logits_processor(tokens[:, :cur_pos], logits)

if temperature > 0:
probs = torch.softmax(logits[:, -1] / temperature, dim=-1)
next_token = sample_top_p(probs, top_p)
else:
next_token = torch.argmax(logits[:, -1], dim=-1)

next_token = next_token.reshape(-1)
# only replace token if prompt has already been generated
next_token = torch.where(input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token)
tokens[:, cur_pos] = next_token

target = tokens[:, prev_pos + 1 : cur_pos + 1]
if is_vision:
# the logits space (num_classes) is designed to never contain a media_token
# however our input token stream does contain them. we need to nuke them here
# or else the CUDA kernels will crash with an illegal memory access
vision_tokens = [self.tokenizer.special_tokens["<|image|>"], 128256]
masks = [target.eq(t) for t in vision_tokens]
if len(masks) > 1:
mask = torch.logical_or(*masks)
else:
mask = masks[0]
target[mask] = 0

if logprobs:
token_logprobs[:, prev_pos + 1 : cur_pos + 1] = -F.cross_entropy(
input=logits.transpose(1, 2),
target=target,
reduction="none",
ignore_index=pad_id,
)
eos_reached |= (~input_text_mask[:, cur_pos]) & (torch.isin(next_token, stop_tokens))
results = []
for idx, t in enumerate(next_token):
results.append(
GenerationResult(
token=t.item(),
text=self.tokenizer.decode([t.item()]),
source="output",
logprobs=(token_logprobs[idx, cur_pos : cur_pos + 1].tolist() if logprobs else None),
batch_idx=idx,
finished=eos_reached[idx],
ignore_token=cur_pos < len(prompt_tokens[idx]),
)
)
yield results

prev_pos = cur_pos
if all(eos_reached):
break

def completion(
self,
contents: List[RawContent],
temperature: float = 0.6,
top_p: float = 0.9,
max_gen_len: Optional[int] = None,
logprobs: bool = False,
echo: bool = False,
) -> Generator[List[GenerationResult], None, None]:
model_inputs = [self.formatter.encode_content(c) for c in contents]
for result in self.generate(
model_inputs=model_inputs,
temperature=temperature,
top_p=top_p,
max_gen_len=max_gen_len,
logprobs=logprobs,
echo=echo,
):
yield result
if all(r.finished for r in result):
break

def chat_completion(
self,
messages_batch: List[List[RawMessage]],
temperature: float = 0.6,
top_p: float = 0.9,
max_gen_len: Optional[int] = None,
logprobs: bool = False,
tool_prompt_format: ToolPromptFormat = ToolPromptFormat.json,
echo: bool = False,
) -> Generator[List[GenerationResult], None, None]:
model_inputs = [self.formatter.encode_dialog_prompt(messages) for messages in messages_batch]
for result in self.generate(
model_inputs=model_inputs,
temperature=temperature,
top_p=top_p,
max_gen_len=max_gen_len,
logprobs=logprobs,
echo=echo,
):
yield result
if all(r.finished for r in result):
break


def sample_top_p(probs, p):
"""
Perform top-p (nucleus) sampling on a probability distribution.

Args:
probs (torch.Tensor): Probability distribution tensor.
p (float): Probability threshold for top-p sampling.

Returns:
torch.Tensor: Sampled token indices.

Note:
Top-p sampling selects the smallest set of tokens whose cumulative probability mass
exceeds the threshold p. The distribution is renormalized based on the selected tokens.
"""
probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
probs_sum = torch.cumsum(probs_sort, dim=-1)
mask = probs_sum - probs_sort > p
probs_sort[mask] = 0.0
probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
next_token = torch.multinomial(probs_sort, num_samples=1)
next_token = torch.gather(probs_idx, -1, next_token)
return next_token
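A toy check of `sample_top_p` (illustrative, not part of the file): with `p=0.7`, only the two highest-probability tokens survive the cumulative-mass cutoff, so the sampled index is always 0 or 1.

import torch

probs = torch.tensor([[0.5, 0.3, 0.15, 0.05]])
# cumsum minus each sorted prob is [0.0, 0.5, 0.8, 0.95]; entries above 0.7 are zeroed out.
token = sample_top_p(probs, p=0.7)
assert token.item() in (0, 1)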
@@ -16,7 +16,7 @@ from typing import List, Optional

from termcolor import colored

from llama_stack.models.llama.datatypes import (
from ..datatypes import (
BuiltinTool,
RawMessage,
StopReason,

@@ -24,7 +24,6 @@ from llama_stack.models.llama.datatypes import (
ToolDefinition,
ToolPromptFormat,
)

from . import template_data
from .chat_format import ChatFormat
from .prompt_templates import (
@@ -4,16 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.

# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.

import math
from typing import Optional, Tuple

@@ -29,6 +19,10 @@ from torch import nn

from .args import ModelArgs

# **NOTE**: This code is not runnable without installing `torch` and `fairscale`
# dependencies. These dependencies are not part of the default dependencies
# (requirements.txt) of the `llama-models` package.


class RMSNorm(torch.nn.Module):
def __init__(self, dim: int, eps: float = 1e-6):

@@ -111,9 +105,9 @@ class Attention(nn.Module):
def __init__(self, args: ModelArgs):
super().__init__()
self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
model_parallel_size = fs_init.get_model_parallel_world_size()
world_size = fs_init.get_model_parallel_world_size()
self.n_local_heads = args.n_heads // model_parallel_size
self.n_local_heads = args.n_heads // world_size
self.n_local_kv_heads = self.n_kv_heads // model_parallel_size
self.n_local_kv_heads = self.n_kv_heads // world_size
self.n_rep = self.n_local_heads // self.n_local_kv_heads
self.head_dim = args.dim // args.n_heads
@ -4,16 +4,6 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
||||||
# All rights reserved.
|
|
||||||
#
|
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
|
||||||
# top-level folder for each specific model found within the models/ directory at
|
|
||||||
# the top-level of this source tree.
|
|
||||||
|
|
||||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
|
||||||
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
|
|
||||||
|
|
||||||
import logging
|
import logging
|
||||||
import math
|
import math
|
||||||
from functools import partial
|
from functools import partial
|
||||||
|
@ -180,14 +170,14 @@ class ImageAttention(nn.Module):
|
||||||
n_heads,
|
n_heads,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
model_parallel_size = fs_init.get_model_parallel_world_size()
|
world_size = fs_init.get_model_parallel_world_size()
|
||||||
qkvo_replication = 1
|
qkvo_replication = 1
|
||||||
if model_parallel_size > 16:
|
if world_size > 16:
|
||||||
qkvo_replication = model_parallel_size // 8
|
qkvo_replication = world_size // 8
|
||||||
|
|
||||||
self.n_kv_heads = n_heads
|
self.n_kv_heads = n_heads
|
||||||
self.n_local_heads = n_heads * qkvo_replication // model_parallel_size
|
self.n_local_heads = n_heads * qkvo_replication // world_size
|
||||||
self.n_local_kv_heads = self.n_kv_heads * qkvo_replication // model_parallel_size
|
self.n_local_kv_heads = self.n_kv_heads * qkvo_replication // world_size
|
||||||
self.n_rep = self.n_local_heads // self.n_local_kv_heads
|
self.n_rep = self.n_local_heads // self.n_local_kv_heads
|
||||||
self.head_dim = dim // n_heads
|
self.head_dim = dim // n_heads
|
||||||
|
|
||||||
|
@ -536,16 +526,16 @@ class Attention(nn.Module):
|
||||||
cache_v (torch.Tensor): Cached values for attention.
|
cache_v (torch.Tensor): Cached values for attention.
|
||||||
"""
|
"""
|
||||||
super().__init__()
|
super().__init__()
|
||||||
model_parallel_size = fs_init.get_model_parallel_world_size()
|
world_size = fs_init.get_model_parallel_world_size()
|
||||||
replication_factor = 1
|
replication_factor = 1
|
||||||
if model_parallel_size > 8:
|
if world_size > 8:
|
||||||
replication_factor = model_parallel_size // MP_SCALE
|
replication_factor = world_size // MP_SCALE
|
||||||
|
|
||||||
self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
|
self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
|
||||||
self.n_kv_heads *= replication_factor
|
self.n_kv_heads *= replication_factor
|
||||||
|
|
||||||
self.n_local_heads = args.n_heads // model_parallel_size
|
self.n_local_heads = args.n_heads // world_size
|
||||||
self.n_local_kv_heads = self.n_kv_heads // model_parallel_size
|
self.n_local_kv_heads = self.n_kv_heads // world_size
|
||||||
self.n_rep = self.n_local_heads // self.n_local_kv_heads
|
self.n_rep = self.n_local_heads // self.n_local_kv_heads
|
||||||
self.head_dim = args.dim // args.n_heads
|
self.head_dim = args.dim // args.n_heads
|
||||||
self.max_seq_len = args.max_seq_len
|
self.max_seq_len = args.max_seq_len
|
||||||
|
@ -587,13 +577,11 @@ class Attention(nn.Module):
|
||||||
self.n_local_kv_heads,
|
self.n_local_kv_heads,
|
||||||
self.head_dim,
|
self.head_dim,
|
||||||
)
|
)
|
||||||
device = next(self.parameters()).device
|
|
||||||
self.register_buffer(
|
self.register_buffer(
|
||||||
"key_cache",
|
"key_cache",
|
||||||
torch.zeros(
|
torch.zeros(
|
||||||
cache_shape,
|
cache_shape,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
device=device,
|
|
||||||
),
|
),
|
||||||
persistent=False,
|
persistent=False,
|
||||||
)
|
)
|
||||||
|
@ -602,7 +590,6 @@ class Attention(nn.Module):
|
||||||
torch.zeros(
|
torch.zeros(
|
||||||
cache_shape,
|
cache_shape,
|
||||||
dtype=dtype,
|
dtype=dtype,
|
||||||
device=device,
|
|
||||||
),
|
),
|
||||||
persistent=False,
|
persistent=False,
|
||||||
)
|
)
|
||||||
|
@ -614,6 +601,9 @@ class Attention(nn.Module):
|
||||||
freqs_cis: torch.Tensor,
|
freqs_cis: torch.Tensor,
|
||||||
position_ids: torch.LongTensor,
|
position_ids: torch.LongTensor,
|
||||||
):
|
):
|
||||||
|
self.key_cache = self.key_cache.to(x.device)
|
||||||
|
self.value_cache = self.value_cache.to(x.device)
|
||||||
|
|
||||||
xq, xk, xv = [F.linear(x, w) for w in [self.wq.weight, self.wk.weight, self.wv.weight]]
|
xq, xk, xv = [F.linear(x, w) for w in [self.wq.weight, self.wk.weight, self.wv.weight]]
|
||||||
|
|
||||||
bs, slen, _ = xq.shape
|
bs, slen, _ = xq.shape
|
||||||
|
@ -832,10 +822,10 @@ class CrossAttention(torch.nn.Module):
|
||||||
norm_eps: float,
|
norm_eps: float,
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.model_parallel_size = fs_init.get_model_parallel_world_size()
|
self.world_size = fs_init.get_model_parallel_world_size()
|
||||||
replication_factor = 1
|
replication_factor = 1
|
||||||
if self.model_parallel_size > 8:
|
if self.world_size > 8:
|
||||||
replication_factor = self.model_parallel_size // MP_SCALE
|
replication_factor = self.world_size // MP_SCALE
|
||||||
n_kv_heads *= replication_factor
|
n_kv_heads *= replication_factor
|
||||||
|
|
||||||
assert n_heads % n_kv_heads == 0
|
assert n_heads % n_kv_heads == 0
|
||||||
|
@ -889,10 +879,10 @@ class CrossAttention(torch.nn.Module):
|
||||||
# trunk LLM (i.e., group query attention) -- @dubeya
|
# trunk LLM (i.e., group query attention) -- @dubeya
|
||||||
# local heads
|
# local heads
|
||||||
assert self.n_heads % self.n_kv_heads == 0
|
assert self.n_heads % self.n_kv_heads == 0
|
||||||
assert self.n_heads % self.model_parallel_size == 0
|
assert self.n_heads % self.world_size == 0
|
||||||
assert self.n_kv_heads % self.model_parallel_size == 0
|
assert self.n_kv_heads % self.world_size == 0
|
||||||
self.n_local_heads = self.n_heads // self.model_parallel_size
|
self.n_local_heads = self.n_heads // self.world_size
|
||||||
self.n_local_kv_heads = self.n_kv_heads // self.model_parallel_size
|
self.n_local_kv_heads = self.n_kv_heads // self.world_size
|
||||||
self.n_rep = self.n_local_heads // self.n_local_kv_heads
|
self.n_rep = self.n_local_heads // self.n_local_kv_heads
|
||||||
|
|
||||||
def _compute_xattn_kv_cache(self, xattn_tokens: torch.Tensor) -> torch.Tensor:
|
def _compute_xattn_kv_cache(self, xattn_tokens: torch.Tensor) -> torch.Tensor:
|
||||||
|
@ -1041,7 +1031,7 @@ class CrossAttentionTransformerVision(torch.nn.Module):
|
||||||
self.image_res = args.vision_chunk_size
|
self.image_res = args.vision_chunk_size
|
||||||
self.max_num_chunks = args.vision_max_num_chunks
|
self.max_num_chunks = args.vision_max_num_chunks
|
||||||
if return_intermediate is not None:
|
if return_intermediate is not None:
|
||||||
return_intermediate = [int(level) for level in return_intermediate.split(",")]
|
return_intermediate = [int(layer) for layer in return_intermediate.split(",")]
|
||||||
self.vision_input_dim = (len(return_intermediate) + 1) * self.vision_input_dim
|
self.vision_input_dim = (len(return_intermediate) + 1) * self.vision_input_dim
|
||||||
self.patch_size = 14
|
self.patch_size = 14
|
||||||
self.vision_encoder = VisionEncoder(
|
self.vision_encoder = VisionEncoder(
|
||||||
|
@ -1076,15 +1066,15 @@ class CrossAttentionTransformerText(torch.nn.Module):
|
||||||
|
|
||||||
def __init__(self, args: ModelArgs) -> None:
|
def __init__(self, args: ModelArgs) -> None:
|
||||||
super().__init__()
|
super().__init__()
|
||||||
self.model_parallel_size = fs_init.get_model_parallel_world_size()
|
self.world_size = fs_init.get_model_parallel_world_size()
|
||||||
assert args.vocab_size > 0
|
assert args.vocab_size > 0
|
||||||
self.vocab_size = args.vocab_size
|
self.vocab_size = args.vocab_size
|
||||||
self.n_layers = args.n_layers
|
self.n_layers = args.n_layers
|
||||||
self.dim = args.dim
|
self.dim = args.dim
|
||||||
self.head_dim = args.dim // args.n_heads
|
self.head_dim = args.dim // args.n_heads
|
||||||
self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
|
self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
|
||||||
self.n_local_kv_heads = self.n_kv_heads // self.model_parallel_size
|
self.n_local_kv_heads = self.n_kv_heads // self.world_size
|
||||||
assert self.vocab_size % self.model_parallel_size == 0
|
assert self.vocab_size % self.world_size == 0
|
||||||
self.tok_embeddings = VocabParallelEmbedding(args.vocab_size, args.dim, init_method=lambda x: x)
|
self.tok_embeddings = VocabParallelEmbedding(args.vocab_size, args.dim, init_method=lambda x: x)
|
||||||
self.pos_embeddings = None
|
self.pos_embeddings = None
|
||||||
# final norm layer (not necessary for post-norm)
|
# final norm layer (not necessary for post-norm)
|
||||||
|
@ -1184,6 +1174,8 @@ class CrossAttentionTransformerText(torch.nn.Module):
|
||||||
text_only_inference: bool = False,
|
text_only_inference: bool = False,
|
||||||
):
|
):
|
||||||
assert self.cache_is_setup, "Please set up cache before calling forward"
|
assert self.cache_is_setup, "Please set up cache before calling forward"
|
||||||
|
self.mask_cache = self.mask_cache.to(h.device)
|
||||||
|
self.freqs_cis = self.freqs_cis.to(h.device)
|
||||||
mask = self.mask_cache.index_select(2, position_ids)
|
mask = self.mask_cache.index_select(2, position_ids)
|
||||||
freqs_cis = self.freqs_cis.index_select(0, position_ids)
|
freqs_cis = self.freqs_cis.index_select(0, position_ids)
|
||||||
|
|
||||||
|
@ -1212,9 +1204,8 @@ class CrossAttentionTransformerText(torch.nn.Module):
|
||||||
output = gather_from_tensor_model_parallel_region(output)
|
output = gather_from_tensor_model_parallel_region(output)
|
||||||
return output.float()
|
return output.float()
|
||||||
|
|
||||||
def setup_cache(self, max_batch_size: int, dtype=torch.bfloat16):
|
def setup_cache(self, max_batch_size: int, device: torch.device, dtype=torch.bfloat16):
|
||||||
# Set up the text kv caches
|
# Set up the text kv caches
|
||||||
device = next(self.parameters()).device
|
|
||||||
ones = torch.ones(
|
ones = torch.ones(
|
||||||
(self.max_seq_len, self.max_seq_len),
|
(self.max_seq_len, self.max_seq_len),
|
||||||
dtype=torch.bool,
|
dtype=torch.bool,
|
||||||
|
@ -1265,7 +1256,7 @@ class CrossAttentionTransformerText(torch.nn.Module):
|
||||||
|
|
||||||
return (
|
return (
|
||||||
cross_attention_masks.to(device=text_device, dtype=text_dtype),
|
cross_attention_masks.to(device=text_device, dtype=text_dtype),
|
||||||
full_text_row_masked_out_mask,
|
full_text_row_masked_out_mask.to(device=text_device),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -1284,14 +1275,15 @@ class CrossAttentionTransformer(torch.nn.Module):
|
||||||
max_num_chunks=args.vision_max_num_chunks,
|
max_num_chunks=args.vision_max_num_chunks,
|
||||||
)
|
)
|
||||||
|
|
||||||
def setup_cache(self, max_batch_size: int, dtype: torch.dtype):
|
def setup_cache(self, max_batch_size: int, device: torch.device, dtype: torch.dtype):
|
||||||
self.text_model.setup_cache(max_batch_size, dtype)
|
self.text_model.setup_cache(max_batch_size, device, dtype)
|
||||||
|
|
||||||
def compute_vision_tokens_masks(
|
def compute_vision_tokens_masks(
|
||||||
self,
|
self,
|
||||||
batch_images: List[List[PIL_Image.Image]],
|
batch_images: List[List[PIL_Image.Image]],
|
||||||
batch_masks: List[List[List[int]]],
|
batch_masks: List[List[List[int]]],
|
||||||
total_len: int,
|
total_len: int,
|
||||||
|
device: torch.device,
|
||||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||||
skip_vision_encoder = False
|
skip_vision_encoder = False
|
||||||
|
|
||||||
|
@ -1318,6 +1310,7 @@ class CrossAttentionTransformer(torch.nn.Module):
|
||||||
image_res=self.params.vision_chunk_size,
|
image_res=self.params.vision_chunk_size,
|
||||||
max_num_images=max_num_images,
|
max_num_images=max_num_images,
|
||||||
)
|
)
|
||||||
|
stacked_images = stacked_images.to(device=device)
|
||||||
|
|
||||||
if skip_vision_encoder:
|
if skip_vision_encoder:
|
||||||
vision_tokens = torch.zeros(
|
vision_tokens = torch.zeros(
|
||||||
|
@ -1330,7 +1323,7 @@ class CrossAttentionTransformer(torch.nn.Module):
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
vision_tokens = self.vision_model(stacked_images, aspect_ratios)
|
vision_tokens = self.vision_model(stacked_images, aspect_ratios).to(device=device)
|
||||||
|
|
||||||
bsz, nimg, nchunk, ntok, image_token_dim = tuple(vision_tokens.shape)
|
bsz, nimg, nchunk, ntok, image_token_dim = tuple(vision_tokens.shape)
|
||||||
xattn_caches = torch.stack(
|
xattn_caches = torch.stack(
|
|
@ -15,7 +15,7 @@ import textwrap
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Any, List, Optional
|
from typing import Any, List, Optional
|
||||||
|
|
||||||
from llama_stack.models.llama.datatypes import (
|
from llama_stack.apis.inference import (
|
||||||
BuiltinTool,
|
BuiltinTool,
|
||||||
ToolDefinition,
|
ToolDefinition,
|
||||||
ToolParamDefinition,
|
ToolParamDefinition,
|
||||||
|
@ -279,6 +279,10 @@ class PythonListCustomToolGenerator(PromptTemplateGeneratorBase): # noqa: N801
|
||||||
{% endif -%}
|
{% endif -%}
|
||||||
{%- endfor %}
|
{%- endfor %}
|
||||||
]
|
]
|
||||||
|
|
||||||
|
You can answer general questions or invoke tools when necessary.
|
||||||
|
In addition to tool calls, you should also augment your responses by using the tool outputs.
|
||||||
|
|
||||||
"""
|
"""
|
||||||
)
|
)
|
||||||
return PromptTemplate(
|
return PromptTemplate(
|
||||||
|
|
5 llama_stack/models/llama/llama3/quantization/__init__.py Normal file
@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
@@ -4,9 +4,6 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
-
 # type: ignore
 import os
 from typing import Any, Dict, List, Optional, cast
@@ -18,22 +15,15 @@ from fairscale.nn.model_parallel.mappings import reduce_from_model_parallel_regi
 from torch import Tensor, nn
 from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear

-from llama_stack.apis.inference import QuantizationType
-from llama_stack.log import get_logger
-from llama_stack.models.llama.datatypes import CheckpointQuantizationFormat
-from llama_stack.models.llama.sku_list import resolve_model
-from llama_stack.providers.inline.inference.meta_reference.quantize_impls import (
+from ...datatypes import QuantizationMode
+from ...quantize_impls import (
 Fp8ScaledWeights,
 ffn_swiglu,
 load_fp8,
 quantize_fp8,
 )

-from ...config import MetaReferenceQuantizedInferenceConfig
-from ..args import ModelArgs
 from ..model import Transformer, TransformerBlock
+from ..multimodal.model import CrossAttentionTransformer
-log = get_logger(__name__, category="quantization")


 def swiglu_wrapper(
@@ -44,30 +34,34 @@ def swiglu_wrapper(
 return reduce_from_model_parallel_region(out)


+def convert_to_quantized_model(
+model: Transformer | CrossAttentionTransformer,
+checkpoint_dir: str,
+quantization_mode: Optional[str] = None,
+fp8_activation_scale_ub: Optional[float] = 1200.0,
+device: Optional[torch.device] = None,
+) -> Transformer | CrossAttentionTransformer:
+if quantization_mode == QuantizationMode.fp8_mixed:
+return convert_to_fp8_quantized_model(model, checkpoint_dir, fp8_activation_scale_ub, device)
+elif quantization_mode == QuantizationMode.int4_mixed:
+return convert_to_int4_quantized_model(model, checkpoint_dir, device)
+else:
+raise ValueError(f"Unsupported quantization mode: {quantization_mode}")


 def convert_to_fp8_quantized_model(
 model: Transformer,
-config: MetaReferenceQuantizedInferenceConfig,
 checkpoint_dir: str,
 fp8_activation_scale_ub: Optional[float] = 1200.0,
+device: Optional[torch.device] = None,
 ) -> Transformer:
-if config.quantization.type == QuantizationType.bf16.value:
-return model
-
-elif config.quantization.type != QuantizationType.fp8.value:
-raise ValueError("Only FP8 quantization is supported")
-
-assert config.model is not None, "Model must be specified for quantized inference"
-llama_model = resolve_model(config.model)
-assert llama_model is not None, f"Model {config.model} not found"
-
 # Move weights to GPU with quantization
-if llama_model.quantization_format == CheckpointQuantizationFormat.fp8_mixed.value:
-log.info("Loading fp8 scales...")
-fp8_scales_path = os.path.join(checkpoint_dir, f"fp8_scales_{get_model_parallel_rank()}.pt")
-assert os.path.isfile(fp8_scales_path), f"fp8_scales_path not found for rank {get_model_parallel_rank()}"
+fp8_scales_path = os.path.join(checkpoint_dir, f"fp8_scales_{get_model_parallel_rank()}.pt")
+if os.path.isfile(fp8_scales_path):
+print("Loading fp8 scales...")
 fp8_scales = torch.load(fp8_scales_path, weights_only=True)

-for block in model.layers:
+for _, block in model.named_modules():
 if isinstance(block, TransformerBlock):
 if block.layer_id == 0 or block.layer_id == (model.n_layers - 1):
 continue
@@ -81,8 +75,8 @@ def convert_to_fp8_quantized_model(
 fp8_activation_scale_ub,
 )
 else:
-log.info("Quantizing fp8 weights from bf16...")
-for block in model.layers:
+print("Quantizing fp8 weights from bf16...")
+for _, block in model.named_modules():
 if isinstance(block, TransformerBlock):
 if block.layer_id == 0 or block.layer_id == (model.n_layers - 1):
 continue
@@ -92,12 +86,12 @@ def convert_to_fp8_quantized_model(
 param.weight = quantize_fp8(
 param.weight,
 fp8_activation_scale_ub,
-output_device=torch.device("cuda"),
+output_device=device,
 )

 for _, parameter in model.named_parameters():
 if not isinstance(parameter, Fp8ScaledWeights):
-parameter.data = parameter.to(device="cuda")
+parameter.data = parameter.to(device=device)
 return model


@@ -290,12 +284,12 @@ def _prepare_model_int4_weight_int8_dynamic_activation(


 def convert_to_int4_quantized_model(
-model: Transformer,
-model_args: ModelArgs,
-config: MetaReferenceQuantizedInferenceConfig,
-) -> Transformer:
+model: Transformer | CrossAttentionTransformer,
+checkpoint_dir: str,
+device: Optional[torch.device] = None,
+) -> Transformer | CrossAttentionTransformer:
 """Convert the model to int4 quantized model."""
+model_args = model.params
 assert model_args.quantization_args is not None, "Quantization args must be specified."
 quantization_args = model_args.quantization_args
 if quantization_args.scheme is None:
@@ -319,5 +313,4 @@ def convert_to_int4_quantized_model(
 lora_scale = model_args.lora_args.scale

 _prepare_model_int4_weight_int8_dynamic_activation(model, group_size, lora_rank, lora_scale)
-device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
-return cast(Transformer, model.to(device))
+return cast(Transformer | CrossAttentionTransformer, model.to(device=device))
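A rough usage sketch of the new checkpoint-dir based entry point above (not from the patch): the checkpoint path is a placeholder and the import location of QuantizationMode is an assumption.

    import torch
    from llama_stack.models.llama.datatypes import QuantizationMode  # assumed location

    # model: a Transformer or CrossAttentionTransformer already built from params.json (placeholder)
    model = convert_to_quantized_model(
        model,
        checkpoint_dir="/path/to/checkpoint",  # placeholder
        quantization_mode=QuantizationMode.fp8_mixed,
        device=torch.device("cuda"),
    )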
@@ -12,8 +12,7 @@
 # the top-level of this source tree.


-from llama_stack.models.llama.datatypes import BuiltinTool, StopReason, ToolCall
+from ..datatypes import BuiltinTool, StopReason, ToolCall

 from .prompt_templates import (
 BuiltinToolGenerator,
 JsonCustomToolGenerator,

@@ -4,16 +4,6 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
-
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# top-level folder for each specific model found within the models/ directory at
-# the top-level of this source tree.
-
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.

 import os
 from logging import getLogger
 from pathlib import Path
@@ -16,7 +16,8 @@ import re
 from typing import Optional, Tuple

 from llama_stack.log import get_logger
-from llama_stack.models.llama.datatypes import BuiltinTool, RecursiveType, ToolCall, ToolPromptFormat
+
+from ..datatypes import BuiltinTool, RecursiveType, ToolCall, ToolPromptFormat

 logger = get_logger(name=__name__, category="inference")

@@ -3,10 +3,3 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
-
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# top-level folder for each specific model found within the models/ directory at
-# the top-level of this source tree.

@@ -4,12 +4,6 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# top-level folder for each specific model found within the models/ directory at
-# the top-level of this source tree.
 import json
 import textwrap

@@ -4,13 +4,6 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# top-level folder for each specific model found within the models/ directory at
-# the top-level of this source tree.
-
 import textwrap
 from pathlib import Path

@@ -4,13 +4,6 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# top-level folder for each specific model found within the models/ directory at
-# the top-level of this source tree.
-
 from enum import Enum
 from typing import Optional
@@ -13,7 +13,7 @@ import torch
 from PIL import Image as PIL_Image

 # TODO: either fork these or move them to the common package
-from llama_stack.models.llama.datatypes import (
+from ..datatypes import (
 BuiltinTool,
 RawContent,
 RawMediaItem,
@@ -24,16 +24,10 @@ from llama_stack.models.llama.datatypes import (
 ToolCall,
 ToolPromptFormat,
 )
-from llama_stack.models.llama.llama3.tool_utils import ToolUtils
-from llama_stack.providers.inline.inference.meta_reference.llama4.args import VisionArgs
-from llama_stack.providers.inline.inference.meta_reference.llama4.datatypes import (
-LLMInput,
-)
-from llama_stack.providers.inline.inference.meta_reference.llama4.preprocess import (
-ResizeNormalizeImageTransform,
-VariableSizeImageTransform,
-)
+from ..llama3.tool_utils import ToolUtils
+from .args import VisionArgs
+from .datatypes import LLMInput
+from .preprocess import ResizeNormalizeImageTransform, VariableSizeImageTransform

 from .tokenizer import Tokenizer


@@ -54,7 +48,7 @@ class TransformedImage:
 aspect_ratio: Tuple[int, int]


-def convert_rgba_to_rgb(image: PIL_Image.Image, bg: Tuple[int, int, int] = (255, 255, 255)) -> PIL_Image.Image:
+def convert_image_to_rgb(image: PIL_Image.Image, bg: Tuple[int, int, int] = (255, 255, 255)) -> PIL_Image.Image:
 if image.mode == "RGBA":
 image.load() # for png.split()
 new_img = PIL_Image.new("RGB", image.size, bg)
@@ -171,7 +165,7 @@ class ChatFormat:

 bytes_io = io.BytesIO(c.data) if isinstance(c.data, bytes) else c.data
 image = PIL_Image.open(bytes_io)
-image = convert_rgba_to_rgb(image)
+image = convert_image_to_rgb(image)
 image_tiles, ar = self.dynamic_image_transform(image, max_num_chunks=self.max_num_chunks)

 if image_tiles.shape[0] > 1:
@@ -216,9 +210,12 @@ class ChatFormat:
 content = ToolUtils.encode_tool_call(t, tool_prompt_format)
 _process_content(content)

+# Tool calls and Tool Response messages should be eom
 eom = False
 if message.role == "assistant":
-eom = message.stop_reason == StopReason.end_of_message
+eom = message.stop_reason == StopReason.end_of_message or message.tool_calls
+elif message.role == "tool":
+eom = True

 tokens.append(self.tokenizer.special_tokens["<|eom|>" if eom else "<|eot|>"])
 return tokens, images
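To make the end-of-message rule above explicit, here is a standalone sketch (not part of the patch; the helper name and the fields it takes are illustrative, with StopReason assumed to come from the datatypes module used above):

    def uses_eom_token(role: str, stop_reason, tool_calls) -> bool:
        # Assistant turns end with <|eom|> when they stop at end_of_message
        # or when they carry tool calls; tool-response turns always use <|eom|>.
        if role == "assistant":
            return stop_reason == StopReason.end_of_message or bool(tool_calls)
        if role == "tool":
            return True
        return False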
@@ -4,13 +4,6 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# top-level folder for each specific model found within the models/ directory at
-# the top-level of this source tree.
-
 from dataclasses import dataclass
 from typing import List, Optional, Union
@@ -10,40 +10,28 @@ import json
 import os
 import sys
 import time
-from enum import Enum
 from pathlib import Path
 from typing import Callable, Generator, List, Optional

 import torch
 import torch.nn.functional as F
 from fairscale.nn.model_parallel.initialize import (
-get_model_parallel_rank,
 initialize_model_parallel,
 model_parallel_is_initialized,
 )
 from termcolor import cprint

-from llama_stack.models.llama.llama4.chat_format import (
-ChatFormat,
-RawContent,
-RawMessage,
-)
-from llama_stack.models.llama.llama4.tokenizer import Tokenizer
-
-from ..common import TokenResult
+from ..checkpoint import maybe_reshard_state_dict
+from ..datatypes import GenerationResult, QuantizationMode
 from .args import ModelArgs
+from .chat_format import ChatFormat, RawContent, RawMessage
 from .datatypes import LLMInput, MaskedEmbedding, TransformerInput
 from .model import Transformer
+from .tokenizer import Tokenizer

 torch.serialization.add_safe_globals([io.BytesIO, codecs.encode])


-class QuantizationMode(str, Enum):
-none = "none"
-fp8_mixed = "fp8_mixed"
-int4_mixed = "int4_mixed"
-
-
 class Llama4:
 @staticmethod
 def build(
@@ -51,7 +39,7 @@ class Llama4:
 max_seq_len: int,
 max_batch_size: int,
 world_size: Optional[int] = None,
-quantization_mode: Optional[str] = None,
+quantization_mode: Optional[QuantizationMode] = None,
 seed: int = 1,
 ):
 if not torch.distributed.is_initialized():
@@ -72,11 +60,9 @@ class Llama4:

 start_time = time.time()

-checkpoints = sorted(Path(ckpt_dir).glob("*.pth"))
-assert len(checkpoints) > 0, f"no checkpoint files found in {ckpt_dir}"
-assert world_size == len(checkpoints), (
-f"Loading a checkpoint for MP={len(checkpoints)} but world size is {world_size}"
-)
+ckpt_paths = sorted(Path(ckpt_dir).glob("*.pth"))
+assert len(ckpt_paths) > 0, f"no checkpoint files found in {ckpt_dir}"
+print(f"Loading a checkpoint (shards={len(ckpt_paths)}, current-mp-size={world_size})")
 with open(Path(ckpt_dir) / "params.json", "r") as f:
 params = json.loads(f.read())

@@ -93,10 +79,11 @@ class Llama4:
 assert model_args.vocab_size == tokenizer.n_words, f"{model_args.vocab_size=} vs. {tokenizer.n_words=} mismatch"
 print("Model args:\n", model_args.model_dump_json(indent=2))

-ckpt_path = checkpoints[get_model_parallel_rank()]
-print(f"Loading checkpoint from {ckpt_dir}...")
-with open(ckpt_path, "rb") as f:
-checkpoint = torch.load(f, map_location="cpu", weights_only=True)
+state_dict = maybe_reshard_state_dict(
+ckpt_paths,
+n_kv_heads=model_args.n_kv_heads if model_args.n_kv_heads else model_args.n_heads,
+moe_num_experts=model_args.moe_args.num_experts,
+)
 print("Loaded checkpoint")
 if quantization_mode == QuantizationMode.fp8_mixed or quantization_mode == QuantizationMode.int4_mixed:
 from .quantization.loader import convert_to_quantized_model
@@ -104,9 +91,9 @@ class Llama4:
 torch.set_default_tensor_type(torch.BFloat16Tensor)
 model = Transformer(model_args)
 print("Loading state dict...")
-model.load_state_dict(checkpoint, strict=False)
+model.load_state_dict(state_dict, strict=False)
 print("Done...")
-model = convert_to_quantized_model(model, ckpt_dir)
+model = convert_to_quantized_model(model, ckpt_dir, quantization_mode)
 else:
 if torch.cuda.is_bf16_supported():
 torch.set_default_tensor_type(torch.cuda.BFloat16Tensor)
@@ -115,7 +102,7 @@ class Llama4:

 model = Transformer(model_args)
 print("Loading state dict...")
-model.load_state_dict(checkpoint, strict=False)
+model.load_state_dict(state_dict, strict=False)
 print("Done...")
 print(f"Loaded in {time.time() - start_time:.2f} seconds")

@@ -130,7 +117,7 @@ class Llama4:
 @torch.inference_mode()
 def generate(
 self,
-llm_input: LLMInput,
+llm_inputs: List[LLMInput],
 temperature: float = 0.6,
 top_p: float = 0.9,
 max_gen_len: Optional[int] = None,
@@ -138,22 +125,20 @@ class Llama4:
 echo: bool = False,
 print_model_input: bool = False,
 logits_processor: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
-) -> Generator:
+) -> Generator[List[GenerationResult], None, None]:
 if max_gen_len is None or max_gen_len == 0 or max_gen_len >= self.model.args.max_seq_len:
 max_gen_len = self.model.args.max_seq_len - 1

 params = self.model.args

 print_model_input = print_model_input or os.environ.get("LLAMA_MODELS_DEBUG", "0") == "1"
-if print_model_input and get_model_parallel_rank() == 0:
-tokens_to_print = list(llm_input.tokens)
-cprint(
-"Input to model:\n" + self.tokenizer.decode(tokens_to_print) + "\n",
-"red",
-)
-prompt_tokens = [llm_input.tokens]
+if print_model_input:
+cprint("Input to model:\n", "yellow")
+for inp in llm_inputs:
+cprint(self.tokenizer.decode(inp.tokens), "grey")
+prompt_tokens = [inp.tokens for inp in llm_inputs]

-bsz = 1
+bsz = len(llm_inputs)
 assert bsz <= params.max_batch_size, (bsz, params.max_batch_size)

 min_prompt_len = min(len(t) for t in prompt_tokens)
@@ -176,24 +161,33 @@ class Llama4:
 input_text_mask = tokens != pad_id

 if echo:
-for i, t in enumerate(llm_input.tokens):
-yield TokenResult(
-token=t,
-text=self.tokenizer.decode([t]),
-logprobs=(token_logprobs[0, i : i + 1].tolist() if logprobs else None),
-)
+for i in range(max_prompt_len):
+results = []
+for j, t in enumerate(tokens[:, i]):
+results.append(
+GenerationResult(
+token=t.item(),
+text=self.tokenizer.decode([t.item()]),
+source="input",
+logprobs=(token_logprobs[j, i : i + 1].tolist() if logprobs else None),
+batch_idx=j,
+finished=False,
+ignore_token=t.item() == pad_id,
+)
+)
+yield results

 stop_tokens = torch.tensor(self.tokenizer.stop_tokens, device="cuda")

 prev_pos = 0
 for cur_pos in range(min_prompt_len, total_len):
 image_embedding = None
-if prev_pos == 0 and llm_input.images is not None and len(llm_input.images) > 0:
+if prev_pos == 0 and any(inp.images is not None and len(inp.images) > 0 for inp in llm_inputs):
 image_mask = tokens[:, prev_pos:cur_pos] == self.tokenizer.special_tokens["<|patch|>"]
 image_mask = image_mask.unsqueeze(-1)
 h = self.model.tok_embeddings(tokens[:, prev_pos:cur_pos])

-image_batch = [llm_input.images]
+image_batch = [inp.images if inp.images is not None else [] for inp in llm_inputs]
 image_embedding = MaskedEmbedding(
 embedding=self.model.vision_embeddings(image_batch, image_mask, h),
 mask=image_mask,
@@ -229,11 +223,21 @@ class Llama4:
 ignore_index=pad_id,
 )
 eos_reached |= (~input_text_mask[:, cur_pos]) & (torch.isin(next_token, stop_tokens))
-yield TokenResult(
-token=next_token[0].item(),
-text=self.tokenizer.decode(next_token.tolist()),
-logprobs=(token_logprobs[:, cur_pos : cur_pos + 1][0].tolist() if logprobs else None),
-)
+results = []
+for idx, t in enumerate(next_token):
+results.append(
+GenerationResult(
+token=t.item(),
+text=self.tokenizer.decode([t.item()]),
+source="output",
+logprobs=(token_logprobs[idx, cur_pos : cur_pos + 1].tolist() if logprobs else None),
+batch_idx=idx,
+finished=eos_reached[idx],
+ignore_token=cur_pos < len(prompt_tokens[idx]),
+)
+)
+yield results

 prev_pos = cur_pos
 if all(eos_reached):
@@ -241,68 +245,47 @@ class Llama4:

 def completion(
 self,
-content: RawContent,
+contents: List[RawContent],
 temperature: float = 0.6,
 top_p: float = 0.9,
 max_gen_len: Optional[int] = None,
 logprobs: bool = False,
 echo: bool = False,
-) -> Generator:
-llm_input = self.formatter.encode_content(content)
+) -> Generator[List[GenerationResult], None, None]:
+llm_inputs = [self.formatter.encode_content(c) for c in contents]
 for result in self.generate(
-llm_input=llm_input,
+llm_inputs=llm_inputs,
 temperature=temperature,
 top_p=top_p,
 max_gen_len=max_gen_len,
 logprobs=logprobs,
 echo=echo,
 ):
-if result.token in self.tokenizer.stop_tokens:
-break
 yield result
+if all(r.finished for r in result):
+break

 def chat_completion(
 self,
-messages: List[RawMessage],
+messages_batch: List[List[RawMessage]],
 temperature: float = 0.6,
 top_p: float = 0.9,
 max_gen_len: Optional[int] = None,
 logprobs: bool = False,
 echo: bool = False,
-) -> Generator:
-llm_input = self.formatter.encode_dialog_prompt(messages)
+) -> Generator[List[GenerationResult], None, None]:
+llm_inputs = [self.formatter.encode_dialog_prompt(messages) for messages in messages_batch]
 for result in self.generate(
-llm_input=llm_input,
+llm_inputs=llm_inputs,
 temperature=temperature,
 top_p=top_p,
 max_gen_len=max_gen_len,
 logprobs=logprobs,
 echo=echo,
 ):
-if result.token in self.tokenizer.stop_tokens:
-break
 yield result
+if all(r.finished for r in result):
+break

-def chat_completion_raw(
-self,
-messages: List[RawMessage],
-temperature: float = 0.6,
-top_p: float = 0.9,
-max_gen_len: Optional[int] = None,
-logprobs: bool = False,
-):
-llm_input = self.formatter.encode_dialog_prompt(messages)
-output_tokens = []
-for result in self.generate(
-llm_input=llm_input,
-temperature=temperature,
-top_p=top_p,
-max_gen_len=max_gen_len,
-logprobs=logprobs,
-):
-output_tokens.append(result.token)
-
-return llm_input.tokens, output_tokens


 def sample_top_p(probs, p):
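A hedged sketch of driving the batched API above (not from the patch): the Llama4.build arguments, checkpoint path, message contents, and accumulation logic are illustrative, with Llama4, RawMessage, and GenerationResult assumed to be imported from the module shown above.

    generator = Llama4.build(
        ckpt_dir="/path/to/ckpt",  # placeholder
        max_seq_len=512,
        max_batch_size=2,
    )
    dialogs = [
        [RawMessage(role="user", content="Hello")],
        [RawMessage(role="user", content="Write a haiku")],
    ]
    outputs = ["" for _ in dialogs]
    for results in generator.chat_completion(dialogs):
        # each step yields one GenerationResult per batch entry
        for r in results:
            if r.source == "output" and not r.ignore_token:
                outputs[r.batch_idx] += r.text
        if all(r.finished for r in results):
            break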
@@ -4,16 +4,6 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
-
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# top-level folder for each specific model found within the models/ directory at
-# the top-level of this source tree.
-
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.

 import math
 from typing import Any, Dict, List, Optional, Tuple

@@ -184,7 +174,6 @@ class Attention(nn.Module):
 self.head_dim,
 )
 ).cuda()
-
 self.qk_norm = None
 if self.use_qk_norm:
 self.qk_norm = L2Norm(args.norm_eps)
@@ -100,31 +100,21 @@ class Experts(nn.Module):

 class MoE(torch.nn.Module):
 """
-This EC implementation is modified from the original EC module.
-We refactored the token permutation and unpermutation logic and added support to tp and dp2ep sharding.
-This module supports 3 sharding methods of the experts:
-- tp: each TP rank has n_experts experts. Experts are sharded following the conventional row/column-parallel TP sharding.
-- tp2ep: each TP rank has n_experts/tp experts. Experts are not sharded.
-- dp2ep: each EP rank has n_experts/ep experts. Experts are sharded following the row/column-parallel TP sharding.
 Tensors used in this module are annotated with the suffixes that indicate the shape of the tensor.
 Several commonly used annotations include:
 - a: bsz*slen
 - E: number of experts
 - e: number of local experts per ep (n_experts/ep)
-- et: number of local experts per tp (n_experts/tp)
 - D: hidden dimension
 - d: D/tp
 - F: model dimension
-- f: F/tp (used in column/row-parallel linear)
 - G: number of tokens per expert (a * capacity_factor / E)
 - g: number of tokens per expert per TP rank (i.e., G/TP)
-- GG: G*EP (number of tokens per expert received via inter-EP a2a when ag_along_first_dim=False)
-- gg: g*EP (number of tokens per expert received via inter-EP a2a when ag_along_first_dim=True)

 Examples:
 x_aD [a, D]
 routed_in_etG_D [et*G, D]
-x_eGGD: [e, GG, D]
+x_eGD: [e, G, D]
 """

 def __init__(
@@ -207,13 +197,13 @@ class MoE(torch.nn.Module):
 routed_in_EG_D = routed_in_EG_D * router_scores.reshape(-1, 1)

 out_aD = self.shared_expert(x_aD)
-routed_out_egg_D = self.experts(routed_in_EG_D.detach())
+routed_out_eg_D = self.experts(routed_in_EG_D.detach())

 router_indices_EG_D = router_indices.reshape(-1, 1).expand(-1, D)
 out_aD.scatter_add_(
 dim=0,
 index=router_indices_EG_D,
-src=routed_out_egg_D.view(-1, D),
+src=routed_out_eg_D.view(-1, D),
 )
 out_aD = reduce_from_model_parallel_region(out_aD)
 return out_aD.view(-1, slen, D)
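For readers following the shape annotations kept in the MoE docstring above, a small worked example (numbers illustrative, not from the patch): with bsz=2 and slen=16, a = bsz*slen = 32 tokens; with E=16 experts and capacity_factor=1, G = a * capacity_factor / E = 2 tokens per expert. So routed_in_EG_D has shape [E*G, D] = [32, D], the experts' output routed_out_eg_D is scattered back into out_aD of shape [a, D] = [32, D], and the final view reshapes it to [bsz, slen, D].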
@@ -4,20 +4,13 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# top-level folder for each specific model found within the models/ directory at
-# the top-level of this source tree.
-
 import textwrap
 from io import BytesIO
 from pathlib import Path
 from typing import List

-from llama_stack.models.llama.datatypes import RawMediaItem, RawMessage, RawTextItem
-from llama_stack.models.llama.prompt_format import (
+from ..datatypes import RawMediaItem, RawMessage, RawTextItem
+from ..prompt_format import (
 Llama4UseCase,
 TextCompletionContent,
 UseCase,

5 llama_stack/models/llama/llama4/quantization/__init__.py Normal file
@@ -0,0 +1,5 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.
@@ -6,20 +6,29 @@

 import logging
 import os
-from typing import Optional
+from typing import Callable, Optional

 import torch
 from fairscale.nn.model_parallel.initialize import get_model_parallel_rank
-from torch import Tensor
+from torch import Tensor, nn
 from torch.nn import functional as F

-from ..generation import QuantizationMode
+from ...datatypes import QuantizationMode
 from ..model import Transformer, TransformerBlock
 from ..moe import MoE

 log = logging.getLogger(__name__)


+def swiglu_wrapper_no_reduce(
+self,
+x: Tensor,
+):
+from ...quantize_impls import ffn_swiglu
+
+return ffn_swiglu(x, self.w1.weight, self.w3.weight, self.w2.weight)
+
+
 def experts_batched_swiglu_wrapper(
 self,
 x: Tensor, # (e, g, D)
@@ -51,24 +60,30 @@ def convert_to_quantized_model(

 rank = get_model_parallel_rank()

+def should_quantize_block(block: nn.Module) -> bool:
+if not isinstance(block, TransformerBlock):
+return False
+
+is_moe = isinstance(block.feed_forward, MoE)
+if quantization_mode == QuantizationMode.fp8_mixed:
+# skip quantization on first and last layers
+return is_moe and not (block.layer_id == 0 or block.layer_id == (model.n_layers - 1))
+
+return is_moe
+
 use_rich_progress = use_rich_progress and rank == 0
-progress, log_status, update_status = logging_callbacks(use_rich_progress, rank, model)
+progress, log_status, update_status = logging_callbacks(use_rich_progress, rank, model, should_quantize_block)
 if quantization_mode == QuantizationMode.int4_mixed:
 int4_scales_path = os.path.join(checkpoint_dir, f"int4_scales_{rank}.pt")
-int4_zero_points_path = os.path.join(checkpoint_dir, f"int4_zero_points_{rank}.pt")
 if os.path.isfile(int4_scales_path):
 log_status(f"Rank {rank}: Loading int4 scales")
 int4_scales = torch.load(int4_scales_path, weights_only=True)
-int4_zero_points = torch.load(int4_zero_points_path, weights_only=True)

 def apply_quantization(key, weight):
 scale = int4_scales[key]
-zero_point = int4_zero_points[key]
 return load_int4(
 weight,
 scale,
-zero_point,
-fp8_activation_scale_ub,
 output_device=torch.device("cuda"),
 )

@@ -77,6 +92,7 @@ def convert_to_quantized_model(

 def apply_quantization(_, weight):
 return quantize_int4(weight, fp8_activation_scale_ub, output_device=torch.device("cuda"))

 else:
 fp8_scales_path = os.path.join(checkpoint_dir, f"fp8_scales_{rank}.pt")
 if os.path.isfile(fp8_scales_path):
@@ -104,33 +120,38 @@ def convert_to_quantized_model(
 progress.start()

 for _, block in model.named_modules():
-if isinstance(block, TransformerBlock):
-# Skip quantization on first and last layers
-if block.layer_id == 0 or block.layer_id == (model.n_layers - 1):
-continue
-
-# Skip quantization on dense layers
-if not isinstance(block.feed_forward, MoE):
-continue
+if not should_quantize_block(block):
+continue

 update_status(f"Rank {rank} - Layer {block.layer_id}")

 # Quantize only routed experts, not shared
 prefix = f"layers.{block.layer_id}.feed_forward"
 moe = block.feed_forward
 moe.experts.batched_swiglu = experts_batched_swiglu_wrapper.__get__(moe.experts)

 for key in ("w1", "w3", "w2"):
 param = getattr(moe.experts, key)
 update_status(f"Rank {rank} - Layer {block.layer_id} - MoE {key}")
 setattr(
 moe.experts,
 key,
-apply_quantization(f"{prefix}.experts.{key}", param.transpose(1, 2).contiguous()),
-)
+apply_quantization(
+f"{prefix}.experts.{key}",
+param.transpose(1, 2).contiguous(),
+),
+)
+
+if quantization_mode == QuantizationMode.int4_mixed:
+# Quantize shared experts
+moe.shared_expert.forward = swiglu_wrapper_no_reduce.__get__(moe.shared_expert)
+for key in ("w1", "w3", "w2"):
+param = getattr(moe.shared_expert, key)
+update_status(f"Rank {rank} - Layer {block.layer_id} - MoE shared expert {key}")
+param.weight = apply_quantization(f"{prefix}.shared_expert.{key}", param.weight)

 processed_blocks += 1
 update_status(message=None, completed=processed_blocks)

 update_status(f"Rank {rank} - Moving parameters to CUDA")

@@ -149,7 +170,12 @@ def convert_to_quantized_model(


 # fp8/int4 loading can be very slow so we add progress bars to make life slightly better
-def logging_callbacks(use_rich_progress: bool, rank: int, model: Transformer):
+def logging_callbacks(
+use_rich_progress: bool,
+rank: int,
+model: Transformer,
+should_quantize_block: Callable[[nn.Module], bool],
+):
 console = None
 if use_rich_progress:
 from rich.console import Console
@@ -162,15 +188,7 @@ def logging_callbacks(use_rich_progress: bool, rank: int, model: Transformer):
 elif rank == 0: # Only log from rank 0 for non-rich logging
 log.info(message)

-total_blocks = sum(
-1
-for _, block in model.named_modules()
-if (
-isinstance(block, TransformerBlock)
-and not (block.layer_id == 0 or block.layer_id == (model.n_layers - 1))
-and isinstance(block.feed_forward, MoE)
-)
-)
+total_blocks = sum(1 for _, block in model.named_modules() if should_quantize_block(block))
 progress = None
 if use_rich_progress:
 from rich.progress import (
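The patch above attaches plain functions as bound methods via the descriptor protocol (for example, experts_batched_swiglu_wrapper.__get__(moe.experts) and swiglu_wrapper_no_reduce.__get__(moe.shared_expert)). A minimal, self-contained illustration of that __get__ trick, with a toy class that is not from the patch:

    class Box:
        def __init__(self, value):
            self.value = value

    def doubled(self):
        # behaves like a bound method once attached via __get__
        return self.value * 2

    b = Box(21)
    b.show = doubled.__get__(b)  # bind `doubled` to this specific instance
    print(b.show())              # prints 42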
@@ -4,9 +4,6 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
-
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.

 import os
 from logging import getLogger
 from pathlib import Path
@@ -59,11 +56,11 @@ LLAMA4_TEXT_POST_TRAIN_SPECIAL_TOKENS = [
 "<|text_post_train_reserved_special_token_3|>",
 "<|text_post_train_reserved_special_token_4|>",
 "<|text_post_train_reserved_special_token_5|>",
-"<|python_start|>",
-"<|python_end|>",
+"<|text_post_train_reserved_special_token_6|>",
+"<|text_post_train_reserved_special_token_7|>",
 "<|finetune_right_pad|>",
 ] + get_reserved_special_tokens(
-"text_post_train", 61, 6
+"text_post_train", 61, 8
 ) # <|text_post_train_reserved_special_token_6|>, ..., <|text_post_train_reserved_special_token_66|>

 # 200080, ..., 201133
@@ -85,8 +82,23 @@ LLAMA4_VISION_SPECIAL_TOKENS = [
 "vision", 1041, 7
 ) # <|vision_reserved_special_token_7|>, ..., <|vision_reserved_special_token_1047|>

+# 201134, ..., 201143
+LLAMA4_REASONING_SPECIAL_TOKENS = [
+"<|reasoning_reserved_special_token_0|>",
+"<|reasoning_reserved_special_token_1|>",
+"<|reasoning_reserved_special_token_2|>",
+"<|reasoning_reserved_special_token_3|>",
+"<|reasoning_reserved_special_token_4|>",
+"<|reasoning_reserved_special_token_5|>",
+"<|reasoning_reserved_special_token_6|>",
+"<|reasoning_reserved_special_token_7|>",
+"<|reasoning_thinking_start|>",
+"<|reasoning_thinking_end|>",
+]
+
-LLAMA4_SPECIAL_TOKENS = LLAMA4_TEXT_POST_TRAIN_SPECIAL_TOKENS + LLAMA4_VISION_SPECIAL_TOKENS
+LLAMA4_SPECIAL_TOKENS = (
+LLAMA4_TEXT_POST_TRAIN_SPECIAL_TOKENS + LLAMA4_VISION_SPECIAL_TOKENS + LLAMA4_REASONING_SPECIAL_TOKENS
+)

 BASIC_SPECIAL_TOKENS = [
 "<|begin_of_text|>",
@@ -155,6 +167,9 @@ class Tokenizer:
 self.eot_id: int = self.special_tokens["<|eot|>"]
 self.eom_id: int = self.special_tokens["<|eom|>"]

+self.thinking_start_id: int = self.special_tokens["<|reasoning_thinking_start|>"]
+self.thinking_end_id: int = self.special_tokens["<|reasoning_thinking_end|>"]
+
 self.stop_tokens = [
 self.eos_id,
 self.special_tokens["<|eom|>"],
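A hedged sketch (not part of the patch) of how the new thinking-token ids might be used to hide a reasoning span before display; the id list is a placeholder and only decode and the two id attributes added above are relied on.

    tok = Tokenizer.get_instance()
    ids = [...]  # token ids produced by generation (placeholder)
    if tok.thinking_start_id in ids and tok.thinking_end_id in ids:
        start = ids.index(tok.thinking_start_id)
        end = ids.index(tok.thinking_end_id)
        ids = ids[:start] + ids[end + 1 :]  # drop the <|reasoning_thinking_*|> span
    visible_text = tok.decode(ids)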
@@ -4,13 +4,6 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# top-level folder for each specific model found within the models/ directory at
-# the top-level of this source tree.
-
 import math
 from typing import Any, Callable, Dict, List
@@ -28,9 +28,6 @@ from llama_stack.models.llama.datatypes import (
 ToolPromptFormat,
 )
 from llama_stack.models.llama.llama4.tokenizer import Tokenizer
-from llama_stack.providers.inline.inference.meta_reference.llama4.datatypes import (
-LLMInput,
-)

 from .llama3.interface import LLama31Interface
 from .llama3.template_data import (
@@ -76,21 +73,22 @@ class UseCase(BaseModel):
 text += dialog
 text += "\n\n"
 continue

-elif isinstance(dialog, TextCompletionContent):
-input_tokens, output_tokens = generator.text_completion_raw(
-dialog.content,
-temperature=0.1,
-top_p=0.95,
-max_gen_len=64,
-)
 else:
-input_tokens, output_tokens = generator.chat_completion_raw(
-dialog,
-temperature=0.0,
-top_p=0.95,
-max_gen_len=self.max_gen_len,
-)
+batch = [dialog]
+method = (
+generator.completion if isinstance(dialog, TextCompletionContent) else generator.chat_completion
+)
+input_tokens = []
+output_tokens = []
+for token_results in method(batch, echo=True, temperature=0.1, top_p=0.95):
+result = token_results[0]
+if result.source == "input":
+input_tokens.append(result.token)
+else:
+output_tokens.append(result.token)
+
+if result.finished:
+break
 text += "##### Input Prompt Format\n"

 # FIXME: This is added to undo the hack in chat_formatter where
@@ -126,27 +124,27 @@ class Llama4UseCase(UseCase):

 text = ""
 tokenizer = Tokenizer.get_instance()
-temperature = 0.0
 for dialog in self.dialogs:
 if isinstance(dialog, str):
 text += dialog
 text += "\n\n"
 continue

-elif isinstance(dialog, TextCompletionContent):
-# TODO pass the raw input and do the encoding in the text completion function
-input_tokens = tokenizer.encode(dialog.content, bos=True, eos=False)
-llm_input = LLMInput(tokens=input_tokens)
-output_tokens, decoded_tokens, token_logprobs = generator.text_completion_raw(
-llm_input, temperature=temperature, max_gen_len=self.max_gen_len
-)
 else:
-input_tokens, output_tokens = generator.chat_completion_raw(
-dialog,
-temperature=temperature,
-max_gen_len=self.max_gen_len,
-)
+batch = [dialog]
+method = (
+generator.completion if isinstance(dialog, TextCompletionContent) else generator.chat_completion
+)
+input_tokens = []
+output_tokens = []
+for token_results in method(batch, echo=True, temperature=0.0):
+result = token_results[0]
+if result.source == "input":
+input_tokens.append(result.token)
+else:
+output_tokens.append(result.token)
+
+if result.finished:
+break

 text += "##### Input Prompt Format\n"
 text += _code_block(tokenizer.decode(input_tokens))
@@ -4,24 +4,15 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.

from dataclasses import dataclass
from functools import lru_cache
from typing import List, Optional

from .datatypes import (
from .sku_types import (
CheckpointQuantizationFormat,
CoreModelId,
Model,
ModelFamily,
SamplingParams,
TopPSamplingStrategy,
)

LLAMA2_VOCAB_SIZE = 32000
@@ -47,15 +38,6 @@ def all_registered_models() -> List[Model]:
)


def recommended_sampling_params() -> SamplingParams:
return SamplingParams(
strategy=TopPSamplingStrategy(
temperature=1.0,
top_p=0.9,
)
)


def llama2_family() -> List[Model]:
return [
*llama2_base_models(),
@@ -150,7 +132,6 @@ def llama2_base_models() -> List[Model]:
core_model_id=CoreModelId.llama2_7b,
description="Llama 2 7b model",
huggingface_repo="meta-llama/Llama-2-7b",
recommended_sampling_params=recommended_sampling_params(),
arch_args={
"dim": 4096,
"n_layers": 32,
@@ -169,7 +150,6 @@ def llama2_base_models() -> List[Model]:
core_model_id=CoreModelId.llama2_13b,
description="Llama 2 13b model",
huggingface_repo="meta-llama/Llama-2-13b",
recommended_sampling_params=recommended_sampling_params(),
arch_args={
"dim": 5120,
"n_layers": 40,
@@ -188,7 +168,6 @@ def llama2_base_models() -> List[Model]:
core_model_id=CoreModelId.llama2_70b,
description="Llama 2 70b model",
huggingface_repo="meta-llama/Llama-2-70b",
recommended_sampling_params=recommended_sampling_params(),
arch_args={
"dim": 8192,
"n_layers": 80,
@@ -230,7 +209,6 @@ def llama3_base_models() -> List[Model]:
core_model_id=CoreModelId.llama3_70b,
description="Llama 3 70b model",
huggingface_repo="meta-llama/Llama-3-70B",
recommended_sampling_params=recommended_sampling_params(),
arch_args={
"dim": 8192,
"n_layers": 80,
@@ -254,7 +232,6 @@ def llama3_1_base_models() -> List[Model]:
core_model_id=CoreModelId.llama3_1_8b,
description="Llama 3.1 8b model",
huggingface_repo="meta-llama/Llama-3.1-8B",
recommended_sampling_params=recommended_sampling_params(),
arch_args={
"dim": 4096,
"n_layers": 32,
@@ -273,7 +250,6 @@ def llama3_1_base_models() -> List[Model]:
core_model_id=CoreModelId.llama3_1_70b,
description="Llama 3.1 70b model",
huggingface_repo="meta-llama/Llama-3.1-70B",
recommended_sampling_params=recommended_sampling_params(),
arch_args={
"dim": 8192,
"n_layers": 80,
@@ -293,7 +269,6 @@ def llama3_1_base_models() -> List[Model]:
variant="bf16-mp8",
description="Llama 3.1 405b model (BF16 weights)",
huggingface_repo="meta-llama/Llama-3.1-405B",
recommended_sampling_params=recommended_sampling_params(),
arch_args={
"dim": 16384,
"n_layers": 126,
@@ -313,7 +288,6 @@ def llama3_1_base_models() -> List[Model]:
description="Llama 3.1 405b model (FP8 quantized)",
huggingface_repo="meta-llama/Llama-3.1-405B-FP8",
quantization_format=CheckpointQuantizationFormat.fp8_mixed,
recommended_sampling_params=recommended_sampling_params(),
arch_args={
"dim": 16384,
"n_layers": 126,
@@ -333,7 +307,6 @@ def llama3_1_base_models() -> List[Model]:
variant="bf16-mp16",
description="Llama 3.1 405b model (BF16 weights for mp16)",
huggingface_repo="meta-llama/Llama-3.1-405B",
recommended_sampling_params=recommended_sampling_params(),
arch_args={
"dim": 16384,
"n_layers": 126,
@@ -357,7 +330,6 @@ def llama3_2_base_models() -> List[Model]:
core_model_id=CoreModelId.llama3_2_1b,
description="Llama 3.2 1b model",
huggingface_repo="meta-llama/Llama-3.2-1B",
recommended_sampling_params=recommended_sampling_params(),
arch_args={
"dim": 2048,
"n_layers": 16,
@@ -376,7 +348,6 @@ def llama3_2_base_models() -> List[Model]:
core_model_id=CoreModelId.llama3_2_3b,
description="Llama 3.2 3b model",
huggingface_repo="meta-llama/Llama-3.2-3B",
recommended_sampling_params=recommended_sampling_params(),
arch_args={
"dim": 3072,
"n_layers": 28,
@@ -395,7 +366,6 @@ def llama3_2_base_models() -> List[Model]:
core_model_id=CoreModelId.llama3_2_11b_vision,
description="Llama 3.2 11b vision model",
huggingface_repo="meta-llama/Llama-3.2-11B-Vision",
recommended_sampling_params=recommended_sampling_params(),
arch_args={
"dim": 4096,
"n_layers": 32,
@@ -417,7 +387,6 @@ def llama3_2_base_models() -> List[Model]:
core_model_id=CoreModelId.llama3_2_90b_vision,
description="Llama 3.2 90b vision model",
huggingface_repo="meta-llama/Llama-3.2-90B-Vision",
recommended_sampling_params=recommended_sampling_params(),
arch_args={
"dim": 8192,
"n_layers": 80,
@@ -444,7 +413,6 @@ def llama2_instruct_models() -> List[Model]:
core_model_id=CoreModelId.llama2_7b_chat,
description="Llama 2 7b chat model",
huggingface_repo="meta-llama/Llama-2-7b-chat",
recommended_sampling_params=recommended_sampling_params(),
arch_args={
"dim": 4096,
"n_layers": 32,
@@ -463,7 +431,6 @@ def llama2_instruct_models() -> List[Model]:
core_model_id=CoreModelId.llama2_13b_chat,
description="Llama 2 13b chat model",
huggingface_repo="meta-llama/Llama-2-13b-chat",
recommended_sampling_params=recommended_sampling_params(),
arch_args={
"dim": 5120,
"n_layers": 40,
@@ -482,7 +449,6 @@ def llama2_instruct_models() -> List[Model]:
core_model_id=CoreModelId.llama2_70b_chat,
description="Llama 2 70b chat model",
huggingface_repo="meta-llama/Llama-2-70b-chat",
recommended_sampling_params=recommended_sampling_params(),
arch_args={
"dim": 8192,
"n_layers": 80,
@@ -506,7 +472,6 @@ def llama3_instruct_models() -> List[Model]:
core_model_id=CoreModelId.llama3_8b_instruct,
description="Llama 3 8b instruct model",
huggingface_repo="meta-llama/Llama-3-8B-Instruct",
recommended_sampling_params=recommended_sampling_params(),
arch_args={
"dim": 4096,
"n_layers": 32,
@@ -525,7 +490,6 @@ def llama3_instruct_models() -> List[Model]:
core_model_id=CoreModelId.llama3_70b_instruct,
description="Llama 3 70b instruct model",
huggingface_repo="meta-llama/Llama-3-70B-Instruct",
recommended_sampling_params=recommended_sampling_params(),
arch_args={
"dim": 8192,
"n_layers": 80,
@@ -549,7 +513,6 @@ def llama3_1_instruct_models() -> List[Model]:
core_model_id=CoreModelId.llama3_1_8b_instruct,
description="Llama 3.1 8b instruct model",
huggingface_repo="meta-llama/Llama-3.1-8B-Instruct",
recommended_sampling_params=recommended_sampling_params(),
arch_args={
"dim": 4096,
"n_layers": 32,
@@ -568,7 +531,6 @@ def llama3_1_instruct_models() -> List[Model]:
core_model_id=CoreModelId.llama3_1_70b_instruct,
description="Llama 3.1 70b instruct model",
huggingface_repo="meta-llama/Llama-3.1-70B-Instruct",
recommended_sampling_params=recommended_sampling_params(),
arch_args={
"dim": 8192,
"n_layers": 80,
@@ -588,7 +550,6 @@ def llama3_1_instruct_models() -> List[Model]:
variant="bf16-mp8",
description="Llama 3.1 405b instruct model (BF16 weights)",
huggingface_repo="meta-llama/Llama-3.1-405B-Instruct",
recommended_sampling_params=recommended_sampling_params(),
arch_args={
"dim": 16384,
"n_layers": 126,
@@ -608,7 +569,6 @@ def llama3_1_instruct_models() -> List[Model]:
description="Llama 3.1 405b instruct model (FP8 quantized)",
huggingface_repo="meta-llama/Llama-3.1-405B-Instruct-FP8",
quantization_format=CheckpointQuantizationFormat.fp8_mixed,
recommended_sampling_params=recommended_sampling_params(),
arch_args={
"dim": 16384,
"n_layers": 126,
@@ -628,7 +588,6 @@ def llama3_1_instruct_models() -> List[Model]:
variant="bf16-mp16",
description="Llama 3.1 405b instruct model (BF16 weights for mp16)",
huggingface_repo="meta-llama/Llama-3.1-405B-Instruct",
recommended_sampling_params=recommended_sampling_params(),
arch_args={
"dim": 16384,
"n_layers": 126,
@@ -684,7 +643,6 @@ def llama3_2_quantized_models() -> List[Model]:
quantization_format=CheckpointQuantizationFormat.int4,
description="Llama 3.2 1b INT4 quantized LoRA",
huggingface_repo="meta-llama/Llama-3.2-1B-Instruct-QLORA_INT4_EO8",
recommended_sampling_params=recommended_sampling_params(),
arch_args={
**arch_args_1b(),
"quantization_args": {
@@ -703,7 +661,6 @@ def llama3_2_quantized_models() -> List[Model]:
quantization_format=CheckpointQuantizationFormat.int4,
description="Llama 3.2 1b INT4 quantized SpinQuant",
huggingface_repo="meta-llama/Llama-3.2-1B-Instruct-SpinQuant_INT4_EO8",
recommended_sampling_params=recommended_sampling_params(),
arch_args={
**arch_args_1b(),
"quantization_args": {
@@ -718,7 +675,6 @@ def llama3_2_quantized_models() -> List[Model]:
quantization_format=CheckpointQuantizationFormat.int4,
description="Llama 3.2 3b INT4 quantized LoRA",
huggingface_repo="meta-llama/Llama-3.2-3B-Instruct-QLORA_INT4_EO8",
recommended_sampling_params=recommended_sampling_params(),
arch_args={
**arch_args_3b(),
"quantization_args": {
@@ -737,7 +693,6 @@ def llama3_2_quantized_models() -> List[Model]:
quantization_format=CheckpointQuantizationFormat.int4,
description="Llama 3.2 3b INT4 quantized SpinQuant",
huggingface_repo="meta-llama/Llama-3.2-3B-Instruct-SpinQuant_INT4_EO8",
recommended_sampling_params=recommended_sampling_params(),
arch_args={
**arch_args_3b(),
"quantization_args": {
@@ -755,7 +710,6 @@ def llama3_2_instruct_models() -> List[Model]:
core_model_id=CoreModelId.llama3_2_1b_instruct,
description="Llama 3.2 1b instruct model",
huggingface_repo="meta-llama/Llama-3.2-1B-Instruct",
recommended_sampling_params=recommended_sampling_params(),
arch_args=arch_args_1b(),
pth_file_count=1,
),
@@ -763,7 +717,6 @@ def llama3_2_instruct_models() -> List[Model]:
core_model_id=CoreModelId.llama3_2_3b_instruct,
description="Llama 3.2 3b instruct model",
huggingface_repo="meta-llama/Llama-3.2-3B-Instruct",
recommended_sampling_params=recommended_sampling_params(),
arch_args=arch_args_3b(),
pth_file_count=1,
),
@@ -772,7 +725,6 @@ def llama3_2_instruct_models() -> List[Model]:
core_model_id=CoreModelId.llama3_2_11b_vision_instruct,
description="Llama 3.2 11b vision instruct model",
huggingface_repo="meta-llama/Llama-3.2-11B-Vision-Instruct",
recommended_sampling_params=recommended_sampling_params(),
arch_args={
"dim": 4096,
"n_layers": 32,
@@ -794,7 +746,6 @@ def llama3_2_instruct_models() -> List[Model]:
core_model_id=CoreModelId.llama3_2_90b_vision_instruct,
description="Llama 3.2 90b vision instruct model",
huggingface_repo="meta-llama/Llama-3.2-90B-Vision-Instruct",
recommended_sampling_params=recommended_sampling_params(),
arch_args={
"dim": 8192,
"n_layers": 80,
@@ -821,7 +772,6 @@ def llama3_3_instruct_models() -> List[Model]:
core_model_id=CoreModelId.llama3_3_70b_instruct,
description="Llama 3.3 70b instruct",
huggingface_repo="meta-llama/Llama-3.3-70B-Instruct",
recommended_sampling_params=recommended_sampling_params(),
arch_args={
"dim": 8192,
"n_layers": 80,
@@ -846,7 +796,6 @@ def safety_models() -> List[Model]:
core_model_id=CoreModelId.llama_guard_3_11b_vision,
description="Llama Guard v3 11b vision system safety model",
huggingface_repo="meta-llama/Llama-Guard-3-11B-Vision",
recommended_sampling_params=recommended_sampling_params(),
arch_args={
"dim": 4096,
"n_layers": 32,
@@ -870,7 +819,6 @@ def safety_models() -> List[Model]:
description="Llama Guard v3 1b 'int4' quantized system safety model",
huggingface_repo="meta-llama/Llama-Guard-3-1B-INT4",
quantization_format=CheckpointQuantizationFormat.int4,
recommended_sampling_params=recommended_sampling_params(),
arch_args={
"dim": 2048,
"n_layers": 12,
@@ -888,7 +836,6 @@ def safety_models() -> List[Model]:
core_model_id=CoreModelId.llama_guard_3_1b,
description="Llama Guard v3 1b system safety model",
huggingface_repo="meta-llama/Llama-Guard-3-1B",
recommended_sampling_params=recommended_sampling_params(),
arch_args={
"dim": 2048,
"n_layers": 16,
229
llama_stack/models/llama/sku_types.py
Normal file
@@ -0,0 +1,229 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from enum import Enum
from typing import Any, Dict, Optional

from pydantic import BaseModel, ConfigDict, Field


class CheckpointQuantizationFormat(Enum):
# default format
bf16 = "bf16"

# used for enabling fp8_rowwise inference, some weights are bf16
fp8_mixed = "fp8-mixed"

int8 = "int8"

int4 = "int4"


class ModelFamily(Enum):
llama2 = "llama2"
llama3 = "llama3"
llama3_1 = "llama3_1"
llama3_2 = "llama3_2"
llama3_3 = "llama3_3"
llama4 = "llama4"
safety = "safety"


class CoreModelId(Enum):
"""Each of these models is a unique "SKU". These root models can be served in various garbs (especially by quantizing them)"""

# Llama 2 family
llama2_7b = "Llama-2-7b"
llama2_13b = "Llama-2-13b"
llama2_70b = "Llama-2-70b"
llama2_7b_chat = "Llama-2-7b-chat"
llama2_13b_chat = "Llama-2-13b-chat"
llama2_70b_chat = "Llama-2-70b-chat"

# Llama 3 family
llama3_8b = "Llama-3-8B"
llama3_70b = "Llama-3-70B"
llama3_8b_instruct = "Llama-3-8B-Instruct"
llama3_70b_instruct = "Llama-3-70B-Instruct"

# Llama 3.1 family
llama3_1_8b = "Llama3.1-8B"
llama3_1_70b = "Llama3.1-70B"
llama3_1_405b = "Llama3.1-405B"
llama3_1_8b_instruct = "Llama3.1-8B-Instruct"
llama3_1_70b_instruct = "Llama3.1-70B-Instruct"
llama3_1_405b_instruct = "Llama3.1-405B-Instruct"

# Llama 3.2 family
llama3_2_1b = "Llama3.2-1B"
llama3_2_3b = "Llama3.2-3B"
llama3_2_1b_instruct = "Llama3.2-1B-Instruct"
llama3_2_3b_instruct = "Llama3.2-3B-Instruct"
llama3_2_11b_vision = "Llama3.2-11B-Vision"
llama3_2_90b_vision = "Llama3.2-90B-Vision"
llama3_2_11b_vision_instruct = "Llama3.2-11B-Vision-Instruct"
llama3_2_90b_vision_instruct = "Llama3.2-90B-Vision-Instruct"

# Llama 3.3 family
llama3_3_70b_instruct = "Llama3.3-70B-Instruct"

# Llama 4 family
llama4_scout_17b_16e = "Llama-4-Scout-17B-16E"
llama4_scout_17b_16e_instruct = "Llama-4-Scout-17B-16E-Instruct"
llama4_maverick_17b_128e = "Llama-4-Maverick-17B-128E"
llama4_maverick_17b_128e_instruct = "Llama-4-Maverick-17B-128E-Instruct"

# Safety models
llama_guard_3_8b = "Llama-Guard-3-8B"
llama_guard_2_8b = "Llama-Guard-2-8B"
llama_guard_3_11b_vision = "Llama-Guard-3-11B-Vision"
llama_guard_3_1b = "Llama-Guard-3-1B"


def is_multimodal(model_id) -> bool:
if model_id in [
CoreModelId.llama3_2_11b_vision,
CoreModelId.llama3_2_90b_vision,
CoreModelId.llama3_2_11b_vision_instruct,
CoreModelId.llama3_2_90b_vision_instruct,
]:
return True
else:
return False


def model_family(model_id) -> ModelFamily:
if model_id in [
CoreModelId.llama2_7b,
CoreModelId.llama2_13b,
CoreModelId.llama2_70b,
CoreModelId.llama2_7b_chat,
CoreModelId.llama2_13b_chat,
CoreModelId.llama2_70b_chat,
]:
return ModelFamily.llama2
elif model_id in [
CoreModelId.llama3_8b,
CoreModelId.llama3_70b,
CoreModelId.llama3_8b_instruct,
CoreModelId.llama3_70b_instruct,
]:
return ModelFamily.llama3
elif model_id in [
CoreModelId.llama3_1_8b,
CoreModelId.llama3_1_70b,
CoreModelId.llama3_1_405b,
CoreModelId.llama3_1_8b_instruct,
CoreModelId.llama3_1_70b_instruct,
CoreModelId.llama3_1_405b_instruct,
]:
return ModelFamily.llama3_1
elif model_id in [
CoreModelId.llama3_2_1b,
CoreModelId.llama3_2_3b,
CoreModelId.llama3_2_1b_instruct,
CoreModelId.llama3_2_3b_instruct,
CoreModelId.llama3_2_11b_vision,
CoreModelId.llama3_2_90b_vision,
CoreModelId.llama3_2_11b_vision_instruct,
CoreModelId.llama3_2_90b_vision_instruct,
]:
return ModelFamily.llama3_2
elif model_id in [
CoreModelId.llama3_3_70b_instruct,
]:
return ModelFamily.llama3_3
elif model_id in [
CoreModelId.llama4_scout_17b_16e,
CoreModelId.llama4_scout_17b_16e_instruct,
CoreModelId.llama4_maverick_17b_128e,
CoreModelId.llama4_maverick_17b_128e_instruct,
]:
return ModelFamily.llama4
elif model_id in [
CoreModelId.llama_guard_3_8b,
CoreModelId.llama_guard_2_8b,
CoreModelId.llama_guard_3_11b_vision,
CoreModelId.llama_guard_3_1b,
]:
return ModelFamily.safety
else:
raise ValueError(f"Unknown model family for {model_id}")


class Model(BaseModel):
core_model_id: CoreModelId
description: str
huggingface_repo: Optional[str] = None
arch_args: Dict[str, Any]
variant: str = ""

quantization_format: CheckpointQuantizationFormat = CheckpointQuantizationFormat.bf16
pth_file_count: int
metadata: Dict[str, Any] = Field(default_factory=dict)

# silence pydantic until we remove the `model_` fields
model_config = ConfigDict(protected_namespaces=())

@property
def model_family(self) -> ModelFamily:
return model_family(self.core_model_id)

# The SKU is uniquely identified by (model_id, variant) combo
def descriptor(self, shorten_default_variant: bool = True) -> str:
if not self.variant:
return self.core_model_id.value
return f"{self.core_model_id.value}:{self.variant}"

@property
def is_instruct_model(self) -> bool:
return "instruct" in self.core_model_id.value

# Featured models are shown in the non-exhaustive model list
@property
def is_featured(self) -> bool:
return self.model_family in [
ModelFamily.llama3_1,
ModelFamily.llama3_2,
ModelFamily.llama3_3,
ModelFamily.llama4,
ModelFamily.safety,
]

@property
def max_seq_length(self) -> int:
if self.model_family == ModelFamily.llama2:
return 4096
elif self.core_model_id == CoreModelId.llama_guard_2_8b:
return 4096
elif self.model_family == ModelFamily.llama3:
return 8192
elif self.model_family in [ModelFamily.llama3_1, ModelFamily.llama3_3]:
return 131072
elif self.model_family == ModelFamily.llama3_2:
if self.quantization_format == CheckpointQuantizationFormat.int4:
return 8192
return 131072
elif self.model_family == ModelFamily.llama4:
if self.core_model_id in {
CoreModelId.llama4_scout_17b_16e,
CoreModelId.llama4_maverick_17b_128e,
}:
return 262144
if self.core_model_id == CoreModelId.llama4_scout_17b_16e_instruct:
return 10485760
if self.core_model_id == CoreModelId.llama4_maverick_17b_128e_instruct:
return 1048576

raise AssertionError(f"Unexpected core model id: {self.core_model_id}")
elif self.core_model_id in [
CoreModelId.llama_guard_3_8b,
CoreModelId.llama_guard_3_11b_vision,
CoreModelId.llama_guard_3_1b,
]:
return 131072
else:
raise ValueError(f"Unknown max_seq_len for {self.core_model_id}")
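For orientation, a minimal illustrative sketch of how the new sku_types pieces fit together; the field values are copied from the SKU definitions shown in this diff, and the snippet itself is not part of the change set:

# Illustrative sketch (assumes llama_stack is installed from this branch).
from llama_stack.models.llama.sku_types import CoreModelId, Model, ModelFamily

model = Model(
    core_model_id=CoreModelId.llama3_2_3b_instruct,
    description="Llama 3.2 3b instruct model",
    huggingface_repo="meta-llama/Llama-3.2-3B-Instruct",
    arch_args={"dim": 3072, "n_layers": 28},
    pth_file_count=1,
)

# The SKU descriptor falls back to the core model id when no variant is set.
assert model.descriptor() == "Llama3.2-3B-Instruct"
assert model.model_family == ModelFamily.llama3_2
# Non-int4 Llama 3.2 checkpoints report the 128K context window.
assert model.max_seq_length == 131072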
@@ -52,6 +52,7 @@ from llama_stack.apis.inference import (
StopReason,
SystemMessage,
ToolDefinition,
ToolParamDefinition,
ToolResponse,
ToolResponseMessage,
UserMessage,
@@ -63,7 +64,6 @@ from llama_stack.log import get_logger
from llama_stack.models.llama.datatypes import (
BuiltinTool,
ToolCall,
ToolParamDefinition,
)
from llama_stack.providers.utils.kvstore import KVStore
from llama_stack.providers.utils.telemetry import tracing
@@ -89,7 +89,6 @@ class ChatAgent(ShieldRunnerMixin):
self,
agent_id: str,
agent_config: AgentConfig,
tempdir: str,
inference_api: Inference,
safety_api: Safety,
tool_runtime_api: ToolRuntime,
@@ -99,7 +98,6 @@ class ChatAgent(ShieldRunnerMixin):
):
self.agent_id = agent_id
self.agent_config = agent_config
self.tempdir = tempdir
self.inference_api = inference_api
self.safety_api = safety_api
self.vector_io_api = vector_io_api
@@ -7,7 +7,6 @@
import json
import logging
import shutil
import tempfile
import uuid
from typing import AsyncGenerator, List, Optional, Union

@@ -64,7 +63,6 @@ class MetaReferenceAgentsImpl(Agents):
self.tool_groups_api = tool_groups_api

self.in_memory_store = InmemoryKVStoreImpl()
self.tempdir = tempfile.mkdtemp()

async def initialize(self) -> None:
self.persistence_store = await kvstore_impl(self.config.persistence_store)
@@ -107,7 +105,6 @@ class MetaReferenceAgentsImpl(Agents):
return ChatAgent(
agent_id=agent_id,
agent_config=agent_config,
tempdir=self.tempdir,
inference_api=self.inference_api,
safety_api=self.safety_api,
vector_io_api=self.vector_io_api,
@@ -4,13 +4,13 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Any, Dict, Union
from typing import Any, Dict

from .config import MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig
from .config import MetaReferenceInferenceConfig


async def get_provider_impl(
config: Union[MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig],
config: MetaReferenceInferenceConfig,
_deps: Dict[str, Any],
):
from .inference import MetaReferenceInferenceImpl
@@ -5,19 +5,10 @@
# the root directory of this source tree.

from pathlib import Path
from typing import List, Optional

from pydantic import BaseModel

from llama_stack.distribution.utils.model_utils import model_local_dir


class TokenResult(BaseModel):
token: int
text: str
logprobs: Optional[List[float]] = None


def model_checkpoint_dir(model_id) -> str:
checkpoint_dir = Path(model_local_dir(model_id))
@@ -21,6 +21,7 @@ class MetaReferenceInferenceConfig(BaseModel):
torch_seed: Optional[int] = None
max_seq_len: int = 4096
max_batch_size: int = 1
model_parallel_size: Optional[int] = None

# when this is False, we assume that the distributed process group is setup by someone
# outside of this code (e.g., when run inside `torchrun`). that is useful for clients
@@ -31,6 +32,8 @@ class MetaReferenceInferenceConfig(BaseModel):
# can override by specifying the directory explicitly
checkpoint_dir: Optional[str] = None

quantization: Optional[QuantizationConfig] = None

@field_validator("model")
@classmethod
def validate_model(cls, model: str) -> str:
@@ -47,27 +50,16 @@ class MetaReferenceInferenceConfig(BaseModel):
cls,
model: str = "Llama3.2-3B-Instruct",
checkpoint_dir: str = "${env.CHECKPOINT_DIR:null}",
quantization_type: str = "${env.QUANTIZATION_TYPE:bf16}",
model_parallel_size: str = "${env.MODEL_PARALLEL_SIZE:0}",
**kwargs,
) -> Dict[str, Any]:
return {
"model": model,
"max_seq_len": 4096,
"checkpoint_dir": checkpoint_dir,
"quantization": {
"type": quantization_type,
},
"model_parallel_size": model_parallel_size,
}


class MetaReferenceQuantizedInferenceConfig(MetaReferenceInferenceConfig):
quantization: QuantizationConfig

@classmethod
def sample_run_config(
cls,
model: str = "Llama3.2-3B-Instruct",
checkpoint_dir: str = "${env.CHECKPOINT_DIR:null}",
**kwargs,
) -> Dict[str, Any]:
config = super().sample_run_config(model, checkpoint_dir, **kwargs)
config["quantization"] = {
"type": "fp8",
}
return config
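With the quantized-config subclass removed, quantization is now selected on the base config. As a sketch of what the updated sample_run_config produces with its defaults (values copied from the hunk above; illustrative only, not part of the diff):

# Illustrative: expected shape of the default meta-reference run config entry.
expected_run_config = {
    "model": "Llama3.2-3B-Instruct",
    "max_seq_len": 4096,
    "checkpoint_dir": "${env.CHECKPOINT_DIR:null}",
    "quantization": {
        "type": "${env.QUANTIZATION_TYPE:bf16}",
    },
    "model_parallel_size": "${env.MODEL_PARALLEL_SIZE:0}",
}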
@@ -11,19 +11,18 @@ import torch
from lmformatenforcer import JsonSchemaParser, TokenEnforcer, TokenEnforcerTokenizerData

from llama_stack.apis.inference import (
Fp8QuantizationConfig,
GreedySamplingStrategy,
Int4QuantizationConfig,
JsonSchemaResponseFormat,
ResponseFormat,
)
from llama_stack.models.llama.datatypes import (
GreedySamplingStrategy,
Model,
SamplingParams,
TopPSamplingStrategy,
)
from llama_stack.models.llama.datatypes import QuantizationMode
from llama_stack.models.llama.llama3.generation import Llama3
from llama_stack.models.llama.llama3.tokenizer import Tokenizer as Llama3Tokenizer
from llama_stack.models.llama.llama4.generation import Llama4
from llama_stack.models.llama.llama4.tokenizer import Tokenizer as Llama4Tokenizer
from llama_stack.models.llama.sku_types import Model
from llama_stack.providers.utils.inference.prompt_adapter import (
ChatCompletionRequestWithRawContent,
CompletionRequestWithRawContent,
@@ -31,10 +30,8 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
)

from .common import model_checkpoint_dir
from .config import MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig
from .config import MetaReferenceInferenceConfig
from .inference import resolve_model
from .llama3.generation import Llama3
from .llama4.generation import Llama4

Tokenizer = Llama4Tokenizer | Llama3Tokenizer
@@ -116,10 +113,11 @@ def _infer_tool_prompt_format(request: ChatCompletionRequestWithRawContent):
return get_default_tool_prompt_format(request.model)


# TODO: combine Llama3 and Llama4 generators since they are almost identical now
class Llama4Generator:
def __init__(
self,
config: MetaReferenceInferenceConfig | MetaReferenceQuantizedInferenceConfig,
config: MetaReferenceInferenceConfig,
model_id: str,
llama_model: Model,
):
@@ -134,11 +132,13 @@ class Llama4Generator:
# if the model is a native llama model, get the default checkpoint_dir based on model core_model_id value
ckpt_dir = model_checkpoint_dir(resolved_model.descriptor())

if isinstance(config, MetaReferenceQuantizedInferenceConfig):
if config.quantization:
if isinstance(config.quantization, Fp8QuantizationConfig):
if config.quantization.type == "fp8_mixed":
quantization_mode = "fp8_mixed"
quantization_mode = QuantizationMode.fp8_mixed
elif isinstance(config.quantization, Int4QuantizationConfig):
elif config.quantization.type == "int4_mixed":
quantization_mode = "int4_mixed"
quantization_mode = QuantizationMode.int4_mixed
elif config.quantization.type == "bf16":
quantization_mode = None
else:
raise ValueError(f"Unsupported quantization mode {config.quantization}")
else:
@@ -148,7 +148,7 @@ class Llama4Generator:
ckpt_dir=ckpt_dir,
max_seq_len=config.max_seq_len,
max_batch_size=config.max_batch_size,
world_size=llama_model.pth_file_count,
world_size=config.model_parallel_size or llama_model.pth_file_count,
quantization_mode=quantization_mode,
)

@@ -166,8 +166,8 @@ class Llama4Generator:
max_gen_len = self.args.max_seq_len - 1

temperature, top_p = _infer_sampling_params(sampling_params)
yield from self.inner_generator.generate(
for result in self.inner_generator.generate(
llm_input=self.formatter.encode_content(request.content),
llm_inputs=[self.formatter.encode_content(request.content)],
max_gen_len=max_gen_len,
temperature=temperature,
top_p=top_p,
@@ -178,7 +178,8 @@ class Llama4Generator:
self.args.vocab_size,
request.response_format,
),
)
):
yield result[0]

def chat_completion(
self,
@@ -190,8 +191,8 @@ class Llama4Generator:
max_gen_len = self.args.max_seq_len - 1

temperature, top_p = _infer_sampling_params(sampling_params)
yield from self.inner_generator.generate(
for result in self.inner_generator.generate(
llm_input=self.formatter.encode_dialog_prompt(request.messages, _infer_tool_prompt_format(request)),
llm_inputs=[self.formatter.encode_dialog_prompt(request.messages, _infer_tool_prompt_format(request))],
max_gen_len=max_gen_len,
temperature=temperature,
top_p=top_p,
@@ -202,20 +203,46 @@ class Llama4Generator:
self.args.vocab_size,
request.response_format,
),
)
):
yield result[0]


class Llama3Generator:
def __init__(
self,
config: MetaReferenceInferenceConfig | MetaReferenceQuantizedInferenceConfig,
config: MetaReferenceInferenceConfig,
model_id: str,
llama_model: Model,
):
if config.checkpoint_dir and config.checkpoint_dir != "null":
ckpt_dir = config.checkpoint_dir
else:
resolved_model = resolve_model(model_id)
if resolved_model is None:
# if the model is not a native llama model, get the default checkpoint_dir based on model id
ckpt_dir = model_checkpoint_dir(model_id)
else:
# if the model is a native llama model, get the default checkpoint_dir based on model core_model_id value
ckpt_dir = model_checkpoint_dir(resolved_model.descriptor())

if config.quantization:
if config.quantization.type == "fp8_mixed":
quantization_mode = QuantizationMode.fp8_mixed
elif config.quantization.type == "int4_mixed":
quantization_mode = QuantizationMode.int4_mixed
elif config.quantization.type == "bf16":
quantization_mode = None
else:
raise ValueError(f"Unsupported quantization mode {config.quantization}")
else:
quantization_mode = None

self.inner_generator = Llama3.build(
config=config,
ckpt_dir=ckpt_dir,
model_id=model_id,
max_seq_len=config.max_seq_len,
llama_model=llama_model,
max_batch_size=config.max_batch_size,
world_size=config.model_parallel_size or llama_model.pth_file_count,
quantization_mode=quantization_mode,
)
self.tokenizer = self.inner_generator.tokenizer
self.args = self.inner_generator.args
@@ -231,8 +258,8 @@ class Llama3Generator:
max_gen_len = self.args.max_seq_len - 1

temperature, top_p = _infer_sampling_params(sampling_params)
yield from self.inner_generator.generate(
for result in self.inner_generator.generate(
model_input=self.formatter.encode_content(request.content),
model_inputs=[self.formatter.encode_content(request.content)],
max_gen_len=max_gen_len,
temperature=temperature,
top_p=top_p,
@@ -243,7 +270,8 @@ class Llama3Generator:
self.args.vocab_size,
request.response_format,
),
)
):
yield result[0]

def chat_completion(
self,
@@ -255,8 +283,8 @@ class Llama3Generator:
max_gen_len = self.args.max_seq_len - 1

temperature, top_p = _infer_sampling_params(sampling_params)
yield from self.inner_generator.generate(
for result in self.inner_generator.generate(
model_input=self.formatter.encode_dialog_prompt(request.messages, _infer_tool_prompt_format(request)),
model_inputs=[self.formatter.encode_dialog_prompt(request.messages, _infer_tool_prompt_format(request))],
max_gen_len=max_gen_len,
temperature=temperature,
top_p=top_p,
@@ -267,4 +295,5 @@ class Llama3Generator:
self.args.vocab_size,
request.response_format,
),
)
):
yield result[0]
@@ -6,8 +6,11 @@

import asyncio
import logging
import os
from typing import AsyncGenerator, List, Optional, Union

from termcolor import cprint

from llama_stack.apis.common.content_types import (
TextDelta,
ToolCallDelta,
@@ -28,23 +31,21 @@ from llama_stack.apis.inference import (
LogProbConfig,
Message,
ResponseFormat,
SamplingParams,
StopReason,
TokenLogProbs,
ToolChoice,
ToolConfig,
)
from llama_stack.apis.models import Model, ModelType
from llama_stack.models.llama.datatypes import (
ModelFamily,
SamplingParams,
StopReason,
ToolDefinition,
ToolPromptFormat,
)
from llama_stack.apis.models import Model, ModelType
from llama_stack.models.llama.llama3.chat_format import ChatFormat as Llama3ChatFormat
from llama_stack.models.llama.llama3.tokenizer import Tokenizer as Llama3Tokenizer
from llama_stack.models.llama.llama4.chat_format import ChatFormat as Llama4ChatFormat
from llama_stack.models.llama.llama4.tokenizer import Tokenizer as Llama4Tokenizer
from llama_stack.models.llama.sku_list import resolve_model
from llama_stack.models.llama.sku_types import ModelFamily
from llama_stack.providers.datatypes import ModelsProtocolPrivate
from llama_stack.providers.utils.inference.embedding_mixin import (
SentenceTransformerEmbeddingMixin,
@@ -148,7 +149,7 @@ class MetaReferenceInferenceImpl(

if self.config.create_distributed_process_group:
self.generator = LlamaModelParallelGenerator(
model_parallel_size=llama_model.pth_file_count,
model_parallel_size=self.config.model_parallel_size or llama_model.pth_file_count,
builder_fn=builder_fn,
builder_params=builder_params,
formatter=(
@@ -338,6 +339,9 @@ class MetaReferenceInferenceImpl(
stop_reason = None

for token_result in self.generator.chat_completion(request):
if os.environ.get("LLAMA_MODELS_DEBUG", "0") == "1":
cprint(token_result.text, "cyan", end="")

tokens.append(token_result.token)

if token_result.token == tokenizer.eot_id:
@@ -386,6 +390,9 @@ class MetaReferenceInferenceImpl(
ipython = False

for token_result in self.generator.chat_completion(request):
if os.environ.get("LLAMA_MODELS_DEBUG", "0") == "1":
cprint(token_result.text, "cyan", end="")

tokens.append(token_result.token)

if not ipython and token_result.text.startswith("<|python_tag|>"):
@@ -1,346 +0,0 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.


import json
import os
import sys
import time
from pathlib import Path
from typing import Callable, Generator, Optional, Union

import torch
import torch.nn.functional as F
from fairscale.nn.model_parallel.initialize import (
get_model_parallel_rank,
initialize_model_parallel,
model_parallel_is_initialized,
)

from llama_stack.apis.inference import (
Fp8QuantizationConfig,
Int4QuantizationConfig,
)
from llama_stack.log import get_logger
from llama_stack.models.llama.datatypes import Model
from llama_stack.models.llama.llama3.chat_format import ChatFormat, LLMInput
from llama_stack.models.llama.llama3.tokenizer import Tokenizer
from llama_stack.models.llama.sku_list import resolve_model

from ..common import TokenResult, model_checkpoint_dir
from ..config import MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig
from .args import ModelArgs
from .model import Transformer
from .multimodal.model import CrossAttentionTransformer

log = get_logger(__name__, category="inference")


class Llama3:
@staticmethod
def build(
config: Union[MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig],
model_id: str,
llama_model: Model,
):
"""
Build a Llama instance by initializing and loading a model checkpoint.

Note:
This method initializes the distributed process group, sets the device to CUDA,
and loads the pre-trained model and tokenizer.
"""
if "DEVICE" in os.environ:
device = os.environ.get("DEVICE")
if device == "cuda":
assert torch.cuda.is_available(), "PyTorch CUDA backend not available"
if device == "xpu":
assert torch.xpu.is_available(), "PyTorch XPU backend not available"
else:
if torch.cuda.is_available():
device = "cuda"
elif torch.xpu.is_available():
device = "xpu"
else:
device = "cpu"
log.info(f"Using {device} device")

llama_model_id = llama_model.core_model_id.value
if not torch.distributed.is_initialized():
if device == "cuda":
torch.distributed.init_process_group("nccl")
else:
torch.distributed.init_process_group("gloo")

model_parallel_size = llama_model.pth_file_count

if not model_parallel_is_initialized():
initialize_model_parallel(model_parallel_size)

local_rank = int(os.environ.get("LOCAL_RANK", 0))
if device == "cuda":
torch.cuda.set_device(local_rank)
elif device == "xpu":
torch.xpu.set_device(local_rank)

# seed must be the same in all processes
if config.torch_seed is not None:
torch.manual_seed(config.torch_seed)

if local_rank > 0:
sys.stdout = open(os.devnull, "w")

start_time = time.time()
if config.checkpoint_dir and config.checkpoint_dir != "null":
ckpt_dir = config.checkpoint_dir
else:
resolved_model = resolve_model(model_id)
if resolved_model is None:
# if the model is not a native llama model, get the default checkpoint_dir based on model id
ckpt_dir = model_checkpoint_dir(model_id)
else:
# if the model is a native llama model, get the default checkpoint_dir based on model core_model_id value
ckpt_dir = model_checkpoint_dir(resolved_model.descriptor())

checkpoints = sorted(Path(ckpt_dir).glob("*.pth"))
assert len(checkpoints) > 0, f"no checkpoint files found in {ckpt_dir}"
assert model_parallel_size == len(checkpoints), (
f"Loading a checkpoint for MP={len(checkpoints)} but world size is {model_parallel_size}"
)
ckpt_path = checkpoints[get_model_parallel_rank()]
state_dict = torch.load(ckpt_path, map_location="cpu", weights_only=True)
with open(Path(ckpt_dir) / "params.json", "r") as f:
params = json.loads(f.read())

if "model" in params:
params = params["model"]

model_args: ModelArgs = ModelArgs(
max_seq_len=config.max_seq_len,
max_batch_size=config.max_batch_size,
**params,
)

tokenizer = Tokenizer.get_instance()
assert model_args.vocab_size == tokenizer.n_words, (
f"model_args vocab = {model_args.vocab_size} but tokenizer vocab = {tokenizer.n_words}"
)

if isinstance(config, MetaReferenceQuantizedInferenceConfig):
if isinstance(config.quantization, Fp8QuantizationConfig):
from .quantization.loader import convert_to_fp8_quantized_model

# load on CPU in bf16 so that fp8 conversion does not find an
# unexpected (fp32, e.g.) datatype
torch.set_default_tensor_type(torch.BFloat16Tensor)
if model_args.vision_chunk_size > 0:
model = CrossAttentionTransformer(model_args)
model.setup_cache(model_args.max_batch_size, torch.bfloat16)
else:
model = Transformer(model_args)
model.load_state_dict(state_dict, strict=False)
model = convert_to_fp8_quantized_model(model, config, ckpt_dir)
elif isinstance(config.quantization, Int4QuantizationConfig):
from .quantization.loader import convert_to_int4_quantized_model

model = Transformer(model_args)
model = convert_to_int4_quantized_model(model, model_args, config)
model.load_state_dict(state_dict, strict=True)

if model_args.quantization_args is not None and model_args.quantization_args.spinquant:
# Add a wrapper for adding hadamard transform for spinquant.
# This needs to be done after loading the state dict otherwise an error will be raised while
# loading the state dict.
from ..hadamard_utils import (
add_hadamard_transform_for_spinquant,
)

add_hadamard_transform_for_spinquant(model)
else:
raise NotImplementedError("Currently int4 and fp8 are the only supported quantization methods.")
else:
if device == "cuda":
if torch.cuda.is_bf16_supported():
torch.set_default_tensor_type(torch.cuda.BFloat16Tensor)
else:
torch.set_default_tensor_type(torch.cuda.HalfTensor)
else:
torch.set_default_device(device)
if device == "xpu" and torch.xpu.is_bf16_supported():
torch.set_default_dtype(torch.bfloat16)
else:
torch.set_default_dtype(torch.half)
if model_args.vision_chunk_size > 0:
model = CrossAttentionTransformer(model_args)
model.setup_cache(model_args.max_batch_size, torch.bfloat16)
else:
model = Transformer(model_args)
model.load_state_dict(state_dict, strict=False)

model.to(device)

log.info(f"Loaded in {time.time() - start_time:.2f} seconds")
return Llama3(model, tokenizer, model_args, llama_model_id)

def __init__(
self,
model: Transformer,
tokenizer: Tokenizer,
args: ModelArgs,
llama_model: str,
):
self.args = args
self.model = model
|
|
||||||
self.tokenizer = tokenizer
|
|
||||||
self.formatter = ChatFormat(tokenizer)
|
|
||||||
self.llama_model = llama_model
|
|
||||||
|
|
||||||
@torch.inference_mode()
|
|
||||||
def generate(
|
|
||||||
self,
|
|
||||||
model_input: LLMInput,
|
|
||||||
max_gen_len: int,
|
|
||||||
temperature: float = 0.6,
|
|
||||||
top_p: float = 0.9,
|
|
||||||
logprobs: bool = False,
|
|
||||||
echo: bool = False,
|
|
||||||
print_input_tokens: bool = False,
|
|
||||||
logits_processor: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
|
|
||||||
) -> Generator:
|
|
||||||
params = self.model.params
|
|
||||||
|
|
||||||
if print_input_tokens:
|
|
||||||
input_tokens = [self.formatter.vision_token if t == 128256 else t for t in model_input.tokens]
|
|
||||||
log.info("Input to model -> " + self.tokenizer.decode(input_tokens))
|
|
||||||
prompt_tokens = [model_input.tokens]
|
|
||||||
|
|
||||||
bsz = 1
|
|
||||||
assert bsz <= params.max_batch_size, (bsz, params.max_batch_size)
|
|
||||||
|
|
||||||
min_prompt_len = min(len(t) for t in prompt_tokens)
|
|
||||||
max_prompt_len = max(len(t) for t in prompt_tokens)
|
|
||||||
|
|
||||||
if max_prompt_len >= params.max_seq_len:
|
|
||||||
log.error(f"Out of token budget {max_prompt_len} vs {params.max_seq_len}")
|
|
||||||
return
|
|
||||||
|
|
||||||
total_len = min(max_gen_len + max_prompt_len, params.max_seq_len)
|
|
||||||
|
|
||||||
is_vision = isinstance(self.model, CrossAttentionTransformer)
|
|
||||||
if is_vision:
|
|
||||||
images = model_input.vision.images if model_input.vision is not None else []
|
|
||||||
mask = model_input.vision.mask if model_input.vision is not None else []
|
|
||||||
|
|
||||||
# the method works for bsz > 1 so add a batch dimension
|
|
||||||
xattn_caches, cross_attention_masks, full_text_row_masked_out_mask = self.model.compute_vision_tokens_masks(
|
|
||||||
batch_images=[images],
|
|
||||||
batch_masks=[mask],
|
|
||||||
total_len=total_len,
|
|
||||||
)
|
|
||||||
|
|
||||||
pad_id = self.tokenizer.pad_id
|
|
||||||
tokens = torch.full((bsz, total_len), pad_id, dtype=torch.long)
|
|
||||||
for k, t in enumerate(prompt_tokens):
|
|
||||||
tokens[k, : len(t)] = torch.tensor(t, dtype=torch.long)
|
|
||||||
if logprobs:
|
|
||||||
token_logprobs = torch.zeros_like(tokens)
|
|
||||||
|
|
||||||
prev_pos = 0
|
|
||||||
eos_reached = torch.tensor([False] * bsz)
|
|
||||||
input_text_mask = tokens != pad_id
|
|
||||||
if min_prompt_len == total_len:
|
|
||||||
# TODO(ashwin): unify this branch with the one below and figure out multimodal crap
|
|
||||||
logits = self.model.forward(tokens, prev_pos)
|
|
||||||
token_logprobs = -F.cross_entropy(
|
|
||||||
input=logits.transpose(1, 2),
|
|
||||||
target=tokens,
|
|
||||||
reduction="none",
|
|
||||||
ignore_index=pad_id,
|
|
||||||
)
|
|
||||||
|
|
||||||
stop_tokens = torch.tensor(self.tokenizer.stop_tokens)
|
|
||||||
for cur_pos in range(min_prompt_len, total_len):
|
|
||||||
if is_vision:
|
|
||||||
position_ids = torch.arange(prev_pos, cur_pos, dtype=torch.long)
|
|
||||||
logits = self.model.forward(
|
|
||||||
position_ids,
|
|
||||||
tokens,
|
|
||||||
cross_attention_masks,
|
|
||||||
full_text_row_masked_out_mask,
|
|
||||||
xattn_caches,
|
|
||||||
)
|
|
||||||
else:
|
|
||||||
logits = self.model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
|
|
||||||
|
|
||||||
if logits_processor is not None:
|
|
||||||
logits = logits_processor(tokens[:, :cur_pos], logits)
|
|
||||||
|
|
||||||
if temperature > 0:
|
|
||||||
probs = torch.softmax(logits[:, -1] / temperature, dim=-1)
|
|
||||||
next_token = sample_top_p(probs, top_p)
|
|
||||||
else:
|
|
||||||
next_token = torch.argmax(logits[:, -1], dim=-1)
|
|
||||||
|
|
||||||
next_token = next_token.reshape(-1)
|
|
||||||
# only replace token if prompt has already been generated
|
|
||||||
next_token = torch.where(input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token)
|
|
||||||
tokens[:, cur_pos] = next_token
|
|
||||||
|
|
||||||
target = tokens[:, prev_pos + 1 : cur_pos + 1]
|
|
||||||
if is_vision:
|
|
||||||
# the logits space (num_classes) is designed to never contain a media_token
|
|
||||||
# however our input token stream does contain them. we need to nuke them here
|
|
||||||
# or else the CUDA kernels will crash with an illegal memory access
|
|
||||||
vision_tokens = [self.tokenizer.special_tokens["<|image|>"], 128256]
|
|
||||||
masks = [target.eq(t) for t in vision_tokens]
|
|
||||||
if len(masks) > 1:
|
|
||||||
mask = torch.logical_or(*masks)
|
|
||||||
else:
|
|
||||||
mask = masks[0]
|
|
||||||
target[mask] = 0
|
|
||||||
|
|
||||||
if logprobs:
|
|
||||||
token_logprobs[:, prev_pos + 1 : cur_pos + 1] = -F.cross_entropy(
|
|
||||||
input=logits.transpose(1, 2),
|
|
||||||
target=tokens[:, prev_pos + 1 : cur_pos + 1],
|
|
||||||
reduction="none",
|
|
||||||
ignore_index=pad_id,
|
|
||||||
)
|
|
||||||
eos_reached |= (~input_text_mask[:, cur_pos]) & (torch.isin(next_token, stop_tokens))
|
|
||||||
yield TokenResult(
|
|
||||||
token=next_token[0].item(),
|
|
||||||
text=self.tokenizer.decode(next_token.tolist()),
|
|
||||||
logprobs=(token_logprobs[:, cur_pos : cur_pos + 1][0].tolist() if logprobs else None),
|
|
||||||
)
|
|
||||||
|
|
||||||
prev_pos = cur_pos
|
|
||||||
if all(eos_reached):
|
|
||||||
break
|
|
||||||
|
|
||||||
|
|
||||||
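The per-token logprobs in `generate` come from a negative cross-entropy over the logits, with padding positions zeroed out via `ignore_index`. A minimal, self-contained sketch of that computation on synthetic tensors (the shapes, vocabulary size, and `pad_id` below are made up for illustration):

import torch
import torch.nn.functional as F

pad_id = 0
bsz, seqlen, vocab = 1, 5, 32  # toy shapes
logits = torch.randn(bsz, seqlen, vocab)
tokens = torch.randint(1, vocab, (bsz, seqlen))
tokens[0, -1] = pad_id  # pretend the last position is padding

token_logprobs = -F.cross_entropy(
    input=logits.transpose(1, 2),  # cross_entropy expects (bsz, vocab, seqlen)
    target=tokens,
    reduction="none",
    ignore_index=pad_id,
)
print(token_logprobs.shape)  # torch.Size([1, 5]); the ignored pad position contributes 0.0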
def sample_top_p(probs, p):
    """
    Perform top-p (nucleus) sampling on a probability distribution.

    Args:
        probs (torch.Tensor): Probability distribution tensor.
        p (float): Probability threshold for top-p sampling.

    Returns:
        torch.Tensor: Sampled token indices.

    Note:
        Top-p sampling selects the smallest set of tokens whose cumulative probability mass
        exceeds the threshold p. The distribution is renormalized based on the selected tokens.
    """
    probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
    probs_sum = torch.cumsum(probs_sort, dim=-1)
    mask = probs_sum - probs_sort > p
    probs_sort[mask] = 0.0
    probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
    next_token = torch.multinomial(probs_sort, num_samples=1)
    next_token = torch.gather(probs_idx, -1, next_token)
    return next_token
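A toy demonstration of `sample_top_p` above, using a hand-built distribution; the seed and probabilities are arbitrary, chosen so that with `p=0.5` only the two most likely tokens can be drawn:

import torch

torch.manual_seed(0)
probs = torch.tensor([[0.40, 0.30, 0.15, 0.10, 0.05]])
# cumulative mass minus the current token: [0.0, 0.40, 0.70, 0.85, 0.95]
# with p=0.5 the mask keeps only the first two tokens, which are then renormalized
next_token = sample_top_p(probs, p=0.5)
print(next_token)  # tensor([[0]]) or tensor([[1]]); indices 2..4 are never sampled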
@@ -32,13 +32,12 @@ from pydantic import BaseModel, Field
 from torch.distributed.launcher.api import LaunchConfig, elastic_launch
 from typing_extensions import Annotated
 
+from llama_stack.models.llama.datatypes import GenerationResult
 from llama_stack.providers.utils.inference.prompt_adapter import (
     ChatCompletionRequestWithRawContent,
     CompletionRequestWithRawContent,
 )
 
-from .common import TokenResult
-
 log = logging.getLogger(__name__)
 
 
@@ -75,7 +74,7 @@ class TaskRequest(BaseModel):
 
 class TaskResponse(BaseModel):
     type: Literal[ProcessingMessageName.task_response] = ProcessingMessageName.task_response
-    result: TokenResult
+    result: GenerationResult
 
 
 class ExceptionResponse(BaseModel):
@@ -14,9 +14,10 @@ from llama_stack.apis.inference import (
     JsonSchemaResponseFormat,
     Message,
     ToolChoice,
+    ToolDefinition,
     UserMessage,
 )
-from llama_stack.models.llama.datatypes import BuiltinTool, ToolDefinition
+from llama_stack.models.llama.datatypes import BuiltinTool
 from llama_stack.providers.utils.inference.openai_compat import (
     convert_message_to_openai_dict,
     get_sampling_options,
@@ -46,6 +46,8 @@ from llama_stack.apis.inference import (
     TokenLogProbs,
     ToolChoice,
     ToolConfig,
+    TopKSamplingStrategy,
+    TopPSamplingStrategy,
 )
 from llama_stack.apis.models import Model
 from llama_stack.log import get_logger
@@ -55,8 +57,6 @@ from llama_stack.models.llama.datatypes import (
     ToolCall,
     ToolDefinition,
     ToolPromptFormat,
-    TopKSamplingStrategy,
-    TopPSamplingStrategy,
 )
 from llama_stack.models.llama.llama3.chat_format import ChatFormat
 from llama_stack.models.llama.llama3.tokenizer import Tokenizer
@@ -22,8 +22,8 @@ from torchtune.models.llama3_2 import lora_llama3_2_3b
 from torchtune.modules.transforms import Transform
 
 from llama_stack.apis.post_training import DatasetFormat
-from llama_stack.models.llama.datatypes import Model
 from llama_stack.models.llama.sku_list import resolve_model
+from llama_stack.models.llama.sku_types import Model
 
 BuildLoraModelCallable = Callable[..., torch.nn.Module]
 BuildTokenizerCallable = Callable[..., Llama3Tokenizer]
@@ -23,7 +23,8 @@ from llama_stack.apis.safety import (
 )
 from llama_stack.apis.shields import Shield
 from llama_stack.distribution.datatypes import Api
-from llama_stack.models.llama.datatypes import CoreModelId, Role
+from llama_stack.models.llama.datatypes import Role
+from llama_stack.models.llama.sku_types import CoreModelId
 from llama_stack.providers.datatypes import ShieldsProtocolPrivate
 from llama_stack.providers.utils.inference.prompt_adapter import (
     interleaved_content_as_str,
@@ -24,6 +24,8 @@ META_REFERENCE_DEPS = [
     "zmq",
     "lm-format-enforcer",
     "sentence-transformers",
+    "torchao==0.5.0",
+    "fbgemm-gpu-genai==1.1.2",
 ]
 
 
@@ -36,13 +38,6 @@ def available_providers() -> List[ProviderSpec]:
             module="llama_stack.providers.inline.inference.meta_reference",
            config_class="llama_stack.providers.inline.inference.meta_reference.MetaReferenceInferenceConfig",
         ),
-        InlineProviderSpec(
-            api=Api.inference,
-            provider_type="inline::meta-reference-quantized",
-            pip_packages=META_REFERENCE_DEPS + ["fbgemm-gpu", "torchao==0.5.0"],
-            module="llama_stack.providers.inline.inference.meta_reference",
-            config_class="llama_stack.providers.inline.inference.meta_reference.MetaReferenceQuantizedInferenceConfig",
-        ),
         InlineProviderSpec(
             api=Api.inference,
             provider_type="inline::vllm",
@@ -222,6 +217,56 @@ def available_providers() -> List[ProviderSpec]:
                 provider_data_validator="llama_stack.providers.remote.inference.groq.config.GroqProviderDataValidator",
             ),
         ),
+        remote_provider_spec(
+            api=Api.inference,
+            adapter=AdapterSpec(
+                adapter_type="fireworks-openai-compat",
+                pip_packages=["litellm"],
+                module="llama_stack.providers.remote.inference.fireworks_openai_compat",
+                config_class="llama_stack.providers.remote.inference.fireworks_openai_compat.config.FireworksCompatConfig",
+                provider_data_validator="llama_stack.providers.remote.inference.fireworks_openai_compat.config.FireworksProviderDataValidator",
+            ),
+        ),
+        remote_provider_spec(
+            api=Api.inference,
+            adapter=AdapterSpec(
+                adapter_type="together-openai-compat",
+                pip_packages=["litellm"],
+                module="llama_stack.providers.remote.inference.together_openai_compat",
+                config_class="llama_stack.providers.remote.inference.together_openai_compat.config.TogetherCompatConfig",
+                provider_data_validator="llama_stack.providers.remote.inference.together_openai_compat.config.TogetherProviderDataValidator",
+            ),
+        ),
+        remote_provider_spec(
+            api=Api.inference,
+            adapter=AdapterSpec(
+                adapter_type="groq-openai-compat",
+                pip_packages=["litellm"],
+                module="llama_stack.providers.remote.inference.groq_openai_compat",
+                config_class="llama_stack.providers.remote.inference.groq_openai_compat.config.GroqCompatConfig",
+                provider_data_validator="llama_stack.providers.remote.inference.groq_openai_compat.config.GroqProviderDataValidator",
+            ),
+        ),
+        remote_provider_spec(
+            api=Api.inference,
+            adapter=AdapterSpec(
+                adapter_type="sambanova-openai-compat",
+                pip_packages=["litellm"],
+                module="llama_stack.providers.remote.inference.sambanova_openai_compat",
+                config_class="llama_stack.providers.remote.inference.sambanova_openai_compat.config.SambaNovaCompatConfig",
+                provider_data_validator="llama_stack.providers.remote.inference.sambanova_openai_compat.config.SambaNovaProviderDataValidator",
+            ),
+        ),
+        remote_provider_spec(
+            api=Api.inference,
+            adapter=AdapterSpec(
+                adapter_type="cerebras-openai-compat",
+                pip_packages=["litellm"],
+                module="llama_stack.providers.remote.inference.cerebras_openai_compat",
+                config_class="llama_stack.providers.remote.inference.cerebras_openai_compat.config.CerebrasCompatConfig",
+                provider_data_validator="llama_stack.providers.remote.inference.cerebras_openai_compat.config.CerebrasProviderDataValidator",
+            ),
+        ),
         remote_provider_spec(
             api=Api.inference,
             adapter=AdapterSpec(
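A quick way to sanity-check that the new OpenAI-compat adapters are picked up is to enumerate the inference registry. This is a rough sketch, assuming the registry module path `llama_stack.providers.registry.inference` and that remote provider specs expose their `AdapterSpec` through an `adapter` attribute:

from llama_stack.providers.registry.inference import available_providers

# Collect the adapter types of all remote OpenAI-compat adapters; inline specs
# have no `adapter` attribute, so they are skipped by the getattr() guard.
compat_adapters = sorted(
    p.adapter.adapter_type
    for p in available_providers()
    if getattr(p, "adapter", None) is not None and p.adapter.adapter_type.endswith("-openai-compat")
)
print(compat_adapters)
# expected to include: cerebras-, fireworks-, groq-, sambanova-, and together-openai-compat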
@@ -4,7 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from llama_stack.models.llama.datatypes import CoreModelId
+from llama_stack.models.llama.sku_types import CoreModelId
 from llama_stack.providers.utils.inference.model_registry import (
     build_hf_repo_model_entry,
 )
@@ -28,8 +28,8 @@ from llama_stack.apis.inference import (
     ToolConfig,
     ToolDefinition,
     ToolPromptFormat,
+    TopKSamplingStrategy,
 )
-from llama_stack.models.llama.datatypes import TopKSamplingStrategy
 from llama_stack.providers.utils.inference.model_registry import (
     ModelRegistryHelper,
 )
@@ -4,7 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-from llama_stack.models.llama.datatypes import CoreModelId
+from llama_stack.models.llama.sku_types import CoreModelId
 from llama_stack.providers.utils.inference.model_registry import (
     build_hf_repo_model_entry,
 )
@@ -0,0 +1,17 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from llama_stack.apis.inference import Inference

from .config import CerebrasCompatConfig


async def get_adapter_impl(config: CerebrasCompatConfig, _deps) -> Inference:
    # import dynamically so the import is used only when it is needed
    from .cerebras import CerebrasCompatInferenceAdapter

    adapter = CerebrasCompatInferenceAdapter(config)
    return adapter
@@ -0,0 +1,30 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from llama_stack.providers.remote.inference.cerebras_openai_compat.config import CerebrasCompatConfig
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin

from ..cerebras.models import MODEL_ENTRIES


class CerebrasCompatInferenceAdapter(LiteLLMOpenAIMixin):
    _config: CerebrasCompatConfig

    def __init__(self, config: CerebrasCompatConfig):
        LiteLLMOpenAIMixin.__init__(
            self,
            model_entries=MODEL_ENTRIES,
            api_key_from_config=config.api_key,
            provider_data_api_key_field="cerebras_api_key",
            openai_compat_api_base=config.openai_compat_api_base,
        )
        self.config = config

    async def initialize(self):
        await super().initialize()

    async def shutdown(self):
        await super().shutdown()
@@ -0,0 +1,38 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Any, Dict, Optional

from pydantic import BaseModel, Field

from llama_stack.schema_utils import json_schema_type


class CerebrasProviderDataValidator(BaseModel):
    cerebras_api_key: Optional[str] = Field(
        default=None,
        description="API key for Cerebras models",
    )


@json_schema_type
class CerebrasCompatConfig(BaseModel):
    api_key: Optional[str] = Field(
        default=None,
        description="The Cerebras API key",
    )

    openai_compat_api_base: str = Field(
        default="https://api.cerebras.ai/v1",
        description="The URL for the Cerebras API server",
    )

    @classmethod
    def sample_run_config(cls, api_key: str = "${env.CEREBRAS_API_KEY}", **kwargs) -> Dict[str, Any]:
        return {
            "openai_compat_api_base": "https://api.cerebras.ai/v1",
            "api_key": api_key,
        }
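A rough sketch of wiring the new Cerebras OpenAI-compat pieces together, assuming `get_adapter_impl` is exported from the `cerebras_openai_compat` package as the registry entry's `module` path suggests; the API key below is a placeholder, not a real credential:

import asyncio

from llama_stack.providers.remote.inference.cerebras_openai_compat import get_adapter_impl
from llama_stack.providers.remote.inference.cerebras_openai_compat.config import CerebrasCompatConfig


async def main() -> None:
    # Placeholder key; in a real run.yaml this is normally injected via ${env.CEREBRAS_API_KEY}.
    config = CerebrasCompatConfig(api_key="csk-...", openai_compat_api_base="https://api.cerebras.ai/v1")
    adapter = await get_adapter_impl(config, {})  # the _deps argument is unused by this adapter
    await adapter.initialize()
    # ... issue Inference API calls (chat_completion, completion) through the adapter ...
    await adapter.shutdown()


asyncio.run(main())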
@@ -28,7 +28,7 @@ from llama_stack.apis.inference import (
     ToolDefinition,
     ToolPromptFormat,
 )
-from llama_stack.models.llama.datatypes import CoreModelId
+from llama_stack.models.llama.sku_types import CoreModelId
 from llama_stack.providers.utils.inference.model_registry import (
     ModelRegistryHelper,
     build_hf_repo_model_entry,
@@ -5,7 +5,7 @@
 # the root directory of this source tree.
 
 from llama_stack.apis.models.models import ModelType
-from llama_stack.models.llama.datatypes import CoreModelId
+from llama_stack.models.llama.sku_types import CoreModelId
 from llama_stack.providers.utils.inference.model_registry import (
     ProviderModelEntry,
     build_hf_repo_model_entry,
@@ -48,6 +48,14 @@ MODEL_ENTRIES = [
         "accounts/fireworks/models/llama-guard-3-11b-vision",
         CoreModelId.llama_guard_3_11b_vision.value,
     ),
+    build_hf_repo_model_entry(
+        "accounts/fireworks/models/llama4-scout-instruct-basic",
+        CoreModelId.llama4_scout_17b_16e_instruct.value,
+    ),
+    build_hf_repo_model_entry(
+        "accounts/fireworks/models/llama4-maverick-instruct-basic",
+        CoreModelId.llama4_maverick_17b_128e_instruct.value,
+    ),
     ProviderModelEntry(
         provider_model_id="nomic-ai/nomic-embed-text-v1.5",
         model_type=ModelType.embedding,
@@ -0,0 +1,17 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from llama_stack.apis.inference import Inference

from .config import FireworksCompatConfig


async def get_adapter_impl(config: FireworksCompatConfig, _deps) -> Inference:
    # import dynamically so the import is used only when it is needed
    from .fireworks import FireworksCompatInferenceAdapter

    adapter = FireworksCompatInferenceAdapter(config)
    return adapter
@@ -0,0 +1,38 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Any, Dict, Optional

from pydantic import BaseModel, Field

from llama_stack.schema_utils import json_schema_type


class FireworksProviderDataValidator(BaseModel):
    fireworks_api_key: Optional[str] = Field(
        default=None,
        description="API key for Fireworks models",
    )


@json_schema_type
class FireworksCompatConfig(BaseModel):
    api_key: Optional[str] = Field(
        default=None,
        description="The Fireworks API key",
    )

    openai_compat_api_base: str = Field(
        default="https://api.fireworks.ai/inference/v1",
        description="The URL for the Fireworks API server",
    )

    @classmethod
    def sample_run_config(cls, api_key: str = "${env.FIREWORKS_API_KEY}", **kwargs) -> Dict[str, Any]:
        return {
            "openai_compat_api_base": "https://api.fireworks.ai/inference/v1",
            "api_key": api_key,
        }
@@ -0,0 +1,30 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from llama_stack.providers.remote.inference.fireworks_openai_compat.config import FireworksCompatConfig
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin

from ..fireworks.models import MODEL_ENTRIES


class FireworksCompatInferenceAdapter(LiteLLMOpenAIMixin):
    _config: FireworksCompatConfig

    def __init__(self, config: FireworksCompatConfig):
        LiteLLMOpenAIMixin.__init__(
            self,
            model_entries=MODEL_ENTRIES,
            api_key_from_config=config.api_key,
            provider_data_api_key_field="fireworks_api_key",
            openai_compat_api_base=config.openai_compat_api_base,
        )
        self.config = config

    async def initialize(self):
        await super().initialize()

    async def shutdown(self):
        await super().shutdown()
@@ -35,4 +35,12 @@ MODEL_ENTRIES = [
         "groq/llama-3.2-3b-preview",
         CoreModelId.llama3_2_3b_instruct.value,
     ),
+    build_hf_repo_model_entry(
+        "groq/llama-4-scout-17b-16e-instruct",
+        CoreModelId.llama4_scout_17b_16e_instruct.value,
+    ),
+    build_hf_repo_model_entry(
+        "groq/llama-4-maverick-17b-128e-instruct",
+        CoreModelId.llama4_maverick_17b_128e_instruct.value,
+    ),
 ]
@@ -0,0 +1,17 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from llama_stack.apis.inference import Inference

from .config import GroqCompatConfig


async def get_adapter_impl(config: GroqCompatConfig, _deps) -> Inference:
    # import dynamically so the import is used only when it is needed
    from .groq import GroqCompatInferenceAdapter

    adapter = GroqCompatInferenceAdapter(config)
    return adapter
@@ -0,0 +1,38 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Any, Dict, Optional

from pydantic import BaseModel, Field

from llama_stack.schema_utils import json_schema_type


class GroqProviderDataValidator(BaseModel):
    groq_api_key: Optional[str] = Field(
        default=None,
        description="API key for Groq models",
    )


@json_schema_type
class GroqCompatConfig(BaseModel):
    api_key: Optional[str] = Field(
        default=None,
        description="The Groq API key",
    )

    openai_compat_api_base: str = Field(
        default="https://api.groq.com/openai/v1",
        description="The URL for the Groq API server",
    )

    @classmethod
    def sample_run_config(cls, api_key: str = "${env.GROQ_API_KEY}", **kwargs) -> Dict[str, Any]:
        return {
            "openai_compat_api_base": "https://api.groq.com/openai/v1",
            "api_key": api_key,
        }
@@ -0,0 +1,30 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from llama_stack.providers.remote.inference.groq_openai_compat.config import GroqCompatConfig
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin

from ..groq.models import MODEL_ENTRIES


class GroqCompatInferenceAdapter(LiteLLMOpenAIMixin):
    _config: GroqCompatConfig

    def __init__(self, config: GroqCompatConfig):
        LiteLLMOpenAIMixin.__init__(
            self,
            model_entries=MODEL_ENTRIES,
            api_key_from_config=config.api_key,
            provider_data_api_key_field="groq_api_key",
            openai_compat_api_base=config.openai_compat_api_base,
        )
        self.config = config

    async def initialize(self):
        await super().initialize()

    async def shutdown(self):
        await super().shutdown()
Some files were not shown because too many files have changed in this diff.