mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-03 09:21:45 +00:00
Merge branch 'main' into docs-4
This commit is contained in:
commit
3366937765
168 changed files with 14921 additions and 1625 deletions
|
@ -320,7 +320,7 @@ jobs:
|
|||
- name: "PR - Update comment"
|
||||
id: pr_update_comment
|
||||
if: github.event_name == 'pull_request_target'
|
||||
uses: thollander/actions-comment-pull-request@65f9e5c9a1f2cd378bd74b2e057c9736982a8e74 # v3.0.1
|
||||
uses: thollander/actions-comment-pull-request@24bffb9b452ba05a4f3f77933840a6a841d1b32b # v3.0.1
|
||||
with:
|
||||
filePath: test-summary.md
|
||||
|
||||
|
|
93
.github/workflows/test-external-providers.yml
vendored
Normal file
93
.github/workflows/test-external-providers.yml
vendored
Normal file
|
@ -0,0 +1,93 @@
|
|||
name: Test External Providers
|
||||
|
||||
on:
|
||||
push:
|
||||
branches: [ main ]
|
||||
pull_request:
|
||||
branches: [ main ]
|
||||
|
||||
jobs:
|
||||
test-external-providers:
|
||||
runs-on: ubuntu-latest
|
||||
steps:
|
||||
- name: Checkout repository
|
||||
uses: actions/checkout@v4
|
||||
|
||||
- name: Install uv
|
||||
uses: astral-sh/setup-uv@v5
|
||||
with:
|
||||
python-version: "3.10"
|
||||
|
||||
- name: Install Ollama
|
||||
run: |
|
||||
curl -fsSL https://ollama.com/install.sh | sh
|
||||
|
||||
- name: Pull Ollama image
|
||||
run: |
|
||||
ollama pull llama3.2:3b-instruct-fp16
|
||||
|
||||
- name: Start Ollama in background
|
||||
run: |
|
||||
nohup ollama run llama3.2:3b-instruct-fp16 --keepalive=30m > ollama.log 2>&1 &
|
||||
|
||||
- name: Set Up Environment and Install Dependencies
|
||||
run: |
|
||||
uv sync --extra dev --extra test
|
||||
uv pip install -e .
|
||||
|
||||
- name: Install Ollama custom provider
|
||||
run: |
|
||||
mkdir -p tests/external-provider/llama-stack-provider-ollama/src/
|
||||
cp -a llama_stack/providers/remote/inference/ollama/ tests/external-provider/llama-stack-provider-ollama/src/llama_stack_provider_ollama
|
||||
uv pip install tests/external-provider/llama-stack-provider-ollama
|
||||
|
||||
- name: Create provider configuration
|
||||
run: |
|
||||
mkdir -p /tmp/providers.d/remote/inference
|
||||
cp tests/external-provider/llama-stack-provider-ollama/custom_ollama.yaml /tmp/providers.d/remote/inference/custom_ollama.yaml
|
||||
|
||||
- name: Wait for Ollama to start
|
||||
run: |
|
||||
echo "Waiting for Ollama..."
|
||||
for i in {1..30}; do
|
||||
if curl -s http://localhost:11434 | grep -q "Ollama is running"; then
|
||||
echo "Ollama is running!"
|
||||
exit 0
|
||||
fi
|
||||
sleep 1
|
||||
done
|
||||
echo "Ollama failed to start"
|
||||
ollama ps
|
||||
ollama.log
|
||||
exit 1
|
||||
|
||||
- name: Start Llama Stack server in background
|
||||
env:
|
||||
INFERENCE_MODEL: "meta-llama/Llama-3.2-3B-Instruct"
|
||||
run: |
|
||||
source .venv/bin/activate
|
||||
nohup uv run llama stack run tests/external-provider/llama-stack-provider-ollama/run.yaml --image-type venv > server.log 2>&1 &
|
||||
|
||||
- name: Wait for Llama Stack server to be ready
|
||||
run: |
|
||||
echo "Waiting for Llama Stack server..."
|
||||
for i in {1..30}; do
|
||||
if curl -s http://localhost:8321/v1/health | grep -q "OK"; then
|
||||
echo "Llama Stack server is up!"
|
||||
if grep -q "remote::custom_ollama from /tmp/providers.d/remote/inference/custom_ollama.yaml" server.log; then
|
||||
echo "Llama Stack server is using custom Ollama provider"
|
||||
exit 0
|
||||
else
|
||||
echo "Llama Stack server is not using custom Ollama provider"
|
||||
exit 1
|
||||
fi
|
||||
fi
|
||||
sleep 1
|
||||
done
|
||||
echo "Llama Stack server failed to start"
|
||||
cat server.log
|
||||
exit 1
|
||||
|
||||
- name: run inference tests
|
||||
run: |
|
||||
uv run pytest -v tests/integration/inference/test_text_inference.py --stack-config="http://localhost:8321" --text-model="meta-llama/Llama-3.2-3B-Instruct" --embedding-model=all-MiniLM-L6-v2
|
37
CHANGELOG.md
37
CHANGELOG.md
|
@ -1,5 +1,42 @@
|
|||
# Changelog
|
||||
|
||||
# v0.2.1
|
||||
Published on: 2025-04-05T23:13:00Z
|
||||
|
||||
|
||||
|
||||
---
|
||||
|
||||
# v0.2.0
|
||||
Published on: 2025-04-05T19:04:29Z
|
||||
|
||||
## Llama 4 Support
|
||||
|
||||
Checkout more at https://www.llama.com
|
||||
|
||||
|
||||
|
||||
---
|
||||
|
||||
# v0.1.9
|
||||
Published on: 2025-03-29T00:52:23Z
|
||||
|
||||
### Build and Test Agents
|
||||
* Agents: Entire document context with attachments
|
||||
* RAG: Documentation with sqlite-vec faiss comparison
|
||||
* Getting started: Fixes to getting started notebook.
|
||||
|
||||
### Agent Evals and Model Customization
|
||||
* (**New**) Post-training: Add nemo customizer
|
||||
|
||||
### Better Engineering
|
||||
* Moved sqlite-vec to non-blocking calls
|
||||
* Don't return a payload on file delete
|
||||
|
||||
|
||||
|
||||
---
|
||||
|
||||
# v0.1.8
|
||||
Published on: 2025-03-24T01:28:50Z
|
||||
|
||||
|
|
|
@ -3,7 +3,7 @@
|
|||
[](https://pypi.org/project/llama_stack/)
|
||||
[](https://pypi.org/project/llama-stack/)
|
||||
[](https://github.com/meta-llama/llama-stack/blob/main/LICENSE)
|
||||
[](https://discord.gg/llama-stack)
|
||||
[](https://discord.gg/llama-stack)
|
||||
[](https://github.com/meta-llama/llama-stack/actions/workflows/unit-tests.yml?query=branch%3Amain)
|
||||
[](https://github.com/meta-llama/llama-stack/actions/workflows/integration-tests.yml?query=branch%3Amain)
|
||||
|
||||
|
@ -11,7 +11,7 @@
|
|||
|
||||
|
||||
### ✨🎉 Llama 4 Support 🎉✨
|
||||
We release [Version 0.2.0](https://github.com/meta-llama/llama-stack/releases/tag/v0.2.0) with support for the Llama 4 herd of models released by Meta.
|
||||
We released [Version 0.2.0](https://github.com/meta-llama/llama-stack/releases/tag/v0.2.0) with support for the Llama 4 herd of models released by Meta.
|
||||
|
||||
You can now run Llama 4 models on Llama Stack.
|
||||
|
||||
|
|
876
docs/getting_started_llama4.ipynb
Normal file
876
docs/getting_started_llama4.ipynb
Normal file
File diff suppressed because one or more lines are too long
|
@ -17,7 +17,7 @@ client = LlamaStackAsLibraryClient(
|
|||
# provider_data is optional, but if you need to pass in any provider specific data, you can do so here.
|
||||
provider_data={"tavily_search_api_key": os.environ["TAVILY_SEARCH_API_KEY"]},
|
||||
)
|
||||
await client.initialize()
|
||||
client.initialize()
|
||||
```
|
||||
|
||||
This will parse your config and set up any inline implementations and remote clients needed for your implementation.
|
||||
|
|
|
@ -46,6 +46,8 @@ The following models are available by default:
|
|||
- `accounts/fireworks/models/llama-v3p3-70b-instruct (aliases: meta-llama/Llama-3.3-70B-Instruct)`
|
||||
- `accounts/fireworks/models/llama-guard-3-8b (aliases: meta-llama/Llama-Guard-3-8B)`
|
||||
- `accounts/fireworks/models/llama-guard-3-11b-vision (aliases: meta-llama/Llama-Guard-3-11B-Vision)`
|
||||
- `accounts/fireworks/models/llama4-scout-instruct-basic (aliases: meta-llama/Llama-4-Scout-17B-16E-Instruct)`
|
||||
- `accounts/fireworks/models/llama4-maverick-instruct-basic (aliases: meta-llama/Llama-4-Maverick-17B-128E-Instruct)`
|
||||
- `nomic-ai/nomic-embed-text-v1.5 `
|
||||
|
||||
|
||||
|
|
|
@ -42,6 +42,8 @@ The following models are available by default:
|
|||
- `groq/llama3-70b-8192 (aliases: meta-llama/Llama-3-70B-Instruct)`
|
||||
- `groq/llama-3.3-70b-versatile (aliases: meta-llama/Llama-3.3-70B-Instruct)`
|
||||
- `groq/llama-3.2-3b-preview (aliases: meta-llama/Llama-3.2-3B-Instruct)`
|
||||
- `groq/llama-4-scout-17b-16e-instruct (aliases: meta-llama/Llama-4-Scout-17B-16E-Instruct)`
|
||||
- `groq/llama-4-maverick-17b-128e-instruct (aliases: meta-llama/Llama-4-Maverick-17B-128E-Instruct)`
|
||||
|
||||
|
||||
### Prerequisite: API Keys
|
||||
|
|
|
@ -41,6 +41,80 @@ The following environment variables can be configured:
|
|||
|
||||
## Setting up vLLM server
|
||||
|
||||
Both AMD and NVIDIA GPUs can serve as accelerators for the vLLM server, which acts as both the LLM inference provider and the safety provider.
|
||||
|
||||
### Setting up vLLM server on AMD GPU
|
||||
|
||||
AMD provides two main vLLM container options:
|
||||
- rocm/vllm: Production-ready container
|
||||
- rocm/vllm-dev: Development container with the latest vLLM features
|
||||
|
||||
Please check the [Blog about ROCm vLLM Usage](https://rocm.blogs.amd.com/software-tools-optimization/vllm-container/README.html) to get more details.
|
||||
|
||||
Here is a sample script to start a ROCm vLLM server locally via Docker:
|
||||
|
||||
```bash
|
||||
export INFERENCE_PORT=8000
|
||||
export INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct
|
||||
export CUDA_VISIBLE_DEVICES=0
|
||||
export VLLM_DIMG="rocm/vllm-dev:main"
|
||||
|
||||
docker run \
|
||||
--pull always \
|
||||
--ipc=host \
|
||||
--privileged \
|
||||
--shm-size 16g \
|
||||
--device=/dev/kfd \
|
||||
--device=/dev/dri \
|
||||
--group-add video \
|
||||
--cap-add=SYS_PTRACE \
|
||||
--cap-add=CAP_SYS_ADMIN \
|
||||
--security-opt seccomp=unconfined \
|
||||
--security-opt apparmor=unconfined \
|
||||
--env "HUGGING_FACE_HUB_TOKEN=$HF_TOKEN" \
|
||||
--env "HIP_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES" \
|
||||
-p $INFERENCE_PORT:$INFERENCE_PORT \
|
||||
-v ~/.cache/huggingface:/root/.cache/huggingface \
|
||||
$VLLM_DIMG \
|
||||
python -m vllm.entrypoints.openai.api_server \
|
||||
--model $INFERENCE_MODEL \
|
||||
--port $INFERENCE_PORT
|
||||
```
|
||||
|
||||
Note that you'll also need to set `--enable-auto-tool-choice` and `--tool-call-parser` to [enable tool calling in vLLM](https://docs.vllm.ai/en/latest/features/tool_calling.html).
|
||||
|
||||
If you are using Llama Stack Safety / Shield APIs, then you will need to also run another instance of a vLLM with a corresponding safety model like `meta-llama/Llama-Guard-3-1B` using a script like:
|
||||
|
||||
```bash
|
||||
export SAFETY_PORT=8081
|
||||
export SAFETY_MODEL=meta-llama/Llama-Guard-3-1B
|
||||
export CUDA_VISIBLE_DEVICES=1
|
||||
export VLLM_DIMG="rocm/vllm-dev:main"
|
||||
|
||||
docker run \
|
||||
--pull always \
|
||||
--ipc=host \
|
||||
--privileged \
|
||||
--shm-size 16g \
|
||||
--device=/dev/kfd \
|
||||
--device=/dev/dri \
|
||||
--group-add video \
|
||||
--cap-add=SYS_PTRACE \
|
||||
--cap-add=CAP_SYS_ADMIN \
|
||||
--security-opt seccomp=unconfined \
|
||||
--security-opt apparmor=unconfined \
|
||||
--env "HUGGING_FACE_HUB_TOKEN=$HF_TOKEN" \
|
||||
--env "HIP_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES" \
|
||||
-p $SAFETY_PORT:$SAFETY_PORT \
|
||||
-v ~/.cache/huggingface:/root/.cache/huggingface \
|
||||
$VLLM_DIMG \
|
||||
python -m vllm.entrypoints.openai.api_server \
|
||||
--model $SAFETY_MODEL \
|
||||
--port $SAFETY_PORT
|
||||
```
|
||||
|
||||
### Setting up vLLM server on NVIDIA GPU
|
||||
|
||||
Please check the [vLLM Documentation](https://docs.vllm.ai/en/v0.5.5/serving/deploying_with_docker.html) to get a vLLM endpoint. Here is a sample script to start a vLLM server locally via Docker:
|
||||
|
||||
```bash
|
||||
|
|
|
@ -43,6 +43,7 @@ The following models are available by default:
|
|||
- `Llama-3.2-11B-Vision-Instruct (aliases: meta-llama/Llama-3.2-11B-Vision-Instruct)`
|
||||
- `Llama-3.2-90B-Vision-Instruct (aliases: meta-llama/Llama-3.2-90B-Vision-Instruct)`
|
||||
- `Meta-Llama-Guard-3-8B (aliases: meta-llama/Llama-Guard-3-8B)`
|
||||
- `Llama-4-Scout-17B-16E-Instruct (aliases: meta-llama/Llama-4-Scout-17B-16E-Instruct)`
|
||||
|
||||
|
||||
### Prerequisite: API Keys
|
||||
|
|
|
@ -48,6 +48,8 @@ The following models are available by default:
|
|||
- `meta-llama/Llama-Guard-3-11B-Vision-Turbo (aliases: meta-llama/Llama-Guard-3-11B-Vision)`
|
||||
- `togethercomputer/m2-bert-80M-8k-retrieval `
|
||||
- `togethercomputer/m2-bert-80M-32k-retrieval `
|
||||
- `meta-llama/Llama-4-Scout-17B-16E-Instruct (aliases: meta-llama/Llama-4-Scout-17B-16E-Instruct, together/meta-llama/Llama-4-Scout-17B-16E-Instruct)`
|
||||
- `meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8 (aliases: meta-llama/Llama-4-Maverick-17B-128E-Instruct, together/meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8)`
|
||||
|
||||
|
||||
### Prerequisite: API Keys
|
||||
|
|
|
@ -12,17 +12,21 @@ as the inference [provider](../providers/index.md#inference) for a Llama Model.
|
|||
|
||||
## Step 1: Installation and Setup
|
||||
|
||||
### i. Install and Start Ollama for Inference
|
||||
### i. Install and Setup Ollama for Inference
|
||||
|
||||
Install Ollama by following the instructions on the [Ollama website](https://ollama.com/download).
|
||||
|
||||
To start Ollama run:
|
||||
Then download a Llama model with Ollama
|
||||
```bash
|
||||
ollama pull llama3.2:3b-instruct-fp16
|
||||
```
|
||||
This will instruct the Ollama service to download the Llama 3.2 3B Instruct model, which we'll use in the rest of this guide.
|
||||
|
||||
Then to start Ollama run:
|
||||
```bash
|
||||
ollama run llama3.2:3b --keepalive 60m
|
||||
```
|
||||
|
||||
By default, Ollama keeps the model loaded in memory for 5 minutes which can be too short. We set the `--keepalive` flag to 60 minutes to ensure the model remains loaded for sometime.
|
||||
|
||||
### ii. Install `uv` to Manage your Python packages
|
||||
|
||||
Install [uv](https://docs.astral.sh/uv/) to setup your virtual environment
|
||||
|
|
|
@ -1,3 +1,8 @@
|
|||
```{admonition} Llama 4 is here!
|
||||
:class: tip
|
||||
|
||||
Check out [Getting Started with Llama 4](https://colab.research.google.com/github/meta-llama/llama-stack/blob/main/docs/getting_started_llama4.ipynb)
|
||||
```
|
||||
```{admonition} News
|
||||
:class: tip
|
||||
|
||||
|
|
234
docs/source/providers/external.md
Normal file
234
docs/source/providers/external.md
Normal file
|
@ -0,0 +1,234 @@
|
|||
# External Providers
|
||||
|
||||
Llama Stack supports external providers that live outside of the main codebase. This allows you to:
|
||||
- Create and maintain your own providers independently
|
||||
- Share providers with others without contributing to the main codebase
|
||||
- Keep provider-specific code separate from the core Llama Stack code
|
||||
|
||||
## Configuration
|
||||
|
||||
To enable external providers, you need to configure the `external_providers_dir` in your Llama Stack configuration. This directory should contain your external provider specifications:
|
||||
|
||||
```yaml
|
||||
external_providers_dir: /etc/llama-stack/providers.d/
|
||||
```
|
||||
|
||||
## Directory Structure
|
||||
|
||||
The external providers directory should follow this structure:
|
||||
|
||||
```
|
||||
providers.d/
|
||||
remote/
|
||||
inference/
|
||||
custom_ollama.yaml
|
||||
vllm.yaml
|
||||
vector_io/
|
||||
qdrant.yaml
|
||||
safety/
|
||||
llama-guard.yaml
|
||||
inline/
|
||||
inference/
|
||||
custom_ollama.yaml
|
||||
vllm.yaml
|
||||
vector_io/
|
||||
qdrant.yaml
|
||||
safety/
|
||||
llama-guard.yaml
|
||||
```
|
||||
|
||||
Each YAML file in these directories defines a provider specification for that particular API.
|
||||
|
||||
## Provider Types
|
||||
|
||||
Llama Stack supports two types of external providers:
|
||||
|
||||
1. **Remote Providers**: Providers that communicate with external services (e.g., cloud APIs)
|
||||
2. **Inline Providers**: Providers that run locally within the Llama Stack process
|
||||
|
||||
## Known External Providers
|
||||
|
||||
Here's a list of known external providers that you can use with Llama Stack:
|
||||
|
||||
| Type | Name | Description | Repository |
|
||||
|------|------|-------------|------------|
|
||||
| Remote | KubeFlow Training | Train models with KubeFlow | [llama-stack-provider-kft](https://github.com/opendatahub-io/llama-stack-provider-kft) |
|
||||
|
||||
### Remote Provider Specification
|
||||
|
||||
Remote providers are used when you need to communicate with external services. Here's an example for a custom Ollama provider:
|
||||
|
||||
```yaml
|
||||
adapter:
|
||||
adapter_type: custom_ollama
|
||||
pip_packages:
|
||||
- ollama
|
||||
- aiohttp
|
||||
config_class: llama_stack_ollama_provider.config.OllamaImplConfig
|
||||
module: llama_stack_ollama_provider
|
||||
api_dependencies: []
|
||||
optional_api_dependencies: []
|
||||
```
|
||||
|
||||
#### Adapter Configuration
|
||||
|
||||
The `adapter` section defines how to load and configure the provider:
|
||||
|
||||
- `adapter_type`: A unique identifier for this adapter
|
||||
- `pip_packages`: List of Python packages required by the provider
|
||||
- `config_class`: The full path to the configuration class
|
||||
- `module`: The Python module containing the provider implementation
|
||||
|
||||
### Inline Provider Specification
|
||||
|
||||
Inline providers run locally within the Llama Stack process. Here's an example for a custom vector store provider:
|
||||
|
||||
```yaml
|
||||
module: llama_stack_vector_provider
|
||||
config_class: llama_stack_vector_provider.config.VectorStoreConfig
|
||||
pip_packages:
|
||||
- faiss-cpu
|
||||
- numpy
|
||||
api_dependencies:
|
||||
- inference
|
||||
optional_api_dependencies:
|
||||
- vector_io
|
||||
provider_data_validator: llama_stack_vector_provider.validator.VectorStoreValidator
|
||||
container_image: custom-vector-store:latest # optional
|
||||
```
|
||||
|
||||
#### Inline Provider Fields
|
||||
|
||||
- `module`: The Python module containing the provider implementation
|
||||
- `config_class`: The full path to the configuration class
|
||||
- `pip_packages`: List of Python packages required by the provider
|
||||
- `api_dependencies`: List of Llama Stack APIs that this provider depends on
|
||||
- `optional_api_dependencies`: List of optional Llama Stack APIs that this provider can use
|
||||
- `provider_data_validator`: Optional validator for provider data
|
||||
- `container_image`: Optional container image to use instead of pip packages
|
||||
|
||||
## Required Implementation
|
||||
|
||||
### Remote Providers
|
||||
|
||||
Remote providers must expose a `get_adapter_impl()` function in their module that takes two arguments:
|
||||
1. `config`: An instance of the provider's config class
|
||||
2. `deps`: A dictionary of API dependencies
|
||||
|
||||
This function must return an instance of the provider's adapter class that implements the required protocol for the API.
|
||||
|
||||
Example:
|
||||
```python
|
||||
async def get_adapter_impl(
|
||||
config: OllamaImplConfig, deps: Dict[Api, Any]
|
||||
) -> OllamaInferenceAdapter:
|
||||
return OllamaInferenceAdapter(config)
|
||||
```
|
||||
|
||||
### Inline Providers
|
||||
|
||||
Inline providers must expose a `get_provider_impl()` function in their module that takes two arguments:
|
||||
1. `config`: An instance of the provider's config class
|
||||
2. `deps`: A dictionary of API dependencies
|
||||
|
||||
Example:
|
||||
```python
|
||||
async def get_provider_impl(
|
||||
config: VectorStoreConfig, deps: Dict[Api, Any]
|
||||
) -> VectorStoreImpl:
|
||||
impl = VectorStoreImpl(config, deps[Api.inference])
|
||||
await impl.initialize()
|
||||
return impl
|
||||
```
|
||||
|
||||
## Dependencies
|
||||
|
||||
The provider package must be installed on the system. For example:
|
||||
|
||||
```bash
|
||||
$ uv pip show llama-stack-ollama-provider
|
||||
Name: llama-stack-ollama-provider
|
||||
Version: 0.1.0
|
||||
Location: /path/to/venv/lib/python3.10/site-packages
|
||||
```
|
||||
|
||||
## Example: Custom Ollama Provider
|
||||
|
||||
Here's a complete example of creating and using a custom Ollama provider:
|
||||
|
||||
1. First, create the provider package:
|
||||
|
||||
```bash
|
||||
mkdir -p llama-stack-provider-ollama
|
||||
cd llama-stack-provider-ollama
|
||||
git init
|
||||
uv init
|
||||
```
|
||||
|
||||
2. Edit `pyproject.toml`:
|
||||
|
||||
```toml
|
||||
[project]
|
||||
name = "llama-stack-provider-ollama"
|
||||
version = "0.1.0"
|
||||
description = "Ollama provider for Llama Stack"
|
||||
requires-python = ">=3.10"
|
||||
dependencies = ["llama-stack", "pydantic", "ollama", "aiohttp"]
|
||||
```
|
||||
|
||||
3. Create the provider specification:
|
||||
|
||||
```yaml
|
||||
# /etc/llama-stack/providers.d/remote/inference/custom_ollama.yaml
|
||||
adapter:
|
||||
adapter_type: custom_ollama
|
||||
pip_packages: ["ollama", "aiohttp"]
|
||||
config_class: llama_stack_provider_ollama.config.OllamaImplConfig
|
||||
module: llama_stack_provider_ollama
|
||||
api_dependencies: []
|
||||
optional_api_dependencies: []
|
||||
```
|
||||
|
||||
4. Install the provider:
|
||||
|
||||
```bash
|
||||
uv pip install -e .
|
||||
```
|
||||
|
||||
5. Configure Llama Stack to use external providers:
|
||||
|
||||
```yaml
|
||||
external_providers_dir: /etc/llama-stack/providers.d/
|
||||
```
|
||||
|
||||
The provider will now be available in Llama Stack with the type `remote::custom_ollama`.
|
||||
|
||||
## Best Practices
|
||||
|
||||
1. **Package Naming**: Use the prefix `llama-stack-provider-` for your provider packages to make them easily identifiable.
|
||||
|
||||
2. **Version Management**: Keep your provider package versioned and compatible with the Llama Stack version you're using.
|
||||
|
||||
3. **Dependencies**: Only include the minimum required dependencies in your provider package.
|
||||
|
||||
4. **Documentation**: Include clear documentation in your provider package about:
|
||||
- Installation requirements
|
||||
- Configuration options
|
||||
- Usage examples
|
||||
- Any limitations or known issues
|
||||
|
||||
5. **Testing**: Include tests in your provider package to ensure it works correctly with Llama Stack.
|
||||
You can refer to the [integration tests
|
||||
guide](https://github.com/meta-llama/llama-stack/blob/main/tests/integration/README.md) for more
|
||||
information. Execute the test for the Provider type you are developing.
|
||||
|
||||
## Troubleshooting
|
||||
|
||||
If your external provider isn't being loaded:
|
||||
|
||||
1. Check that the `external_providers_dir` path is correct and accessible.
|
||||
2. Verify that the YAML files are properly formatted.
|
||||
3. Ensure all required Python packages are installed.
|
||||
4. Check the Llama Stack server logs for any error messages - turn on debug logging to get more
|
||||
information using `LLAMA_STACK_LOGGING=all=debug`.
|
||||
5. Verify that the provider package is installed in your Python environment.
|
|
@ -11,6 +11,10 @@ Providers come in two flavors:
|
|||
|
||||
Importantly, Llama Stack always strives to provide at least one fully inline provider for each API so you can iterate on a fully featured environment locally.
|
||||
|
||||
## External Providers
|
||||
|
||||
Llama Stack supports external providers that live outside of the main codebase. This allows you to create and maintain your own providers independently. See the [External Providers Guide](external) for details.
|
||||
|
||||
## Agents
|
||||
Run multi-step agentic workflows with LLMs with tool usage, memory (RAG), etc.
|
||||
|
||||
|
@ -50,6 +54,7 @@ The following providers (i.e., databases) are available for Vector IO:
|
|||
```{toctree}
|
||||
:maxdepth: 1
|
||||
|
||||
external
|
||||
vector_io/faiss
|
||||
vector_io/sqlite-vec
|
||||
vector_io/chromadb
|
||||
|
|
|
@ -25,15 +25,64 @@ from llama_stack.apis.models import Model
|
|||
from llama_stack.apis.telemetry.telemetry import MetricResponseMixin
|
||||
from llama_stack.models.llama.datatypes import (
|
||||
BuiltinTool,
|
||||
SamplingParams,
|
||||
StopReason,
|
||||
ToolCall,
|
||||
ToolDefinition,
|
||||
ToolParamDefinition,
|
||||
ToolPromptFormat,
|
||||
)
|
||||
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
|
||||
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod
|
||||
|
||||
register_schema(ToolCall)
|
||||
register_schema(ToolParamDefinition)
|
||||
register_schema(ToolDefinition)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class GreedySamplingStrategy(BaseModel):
|
||||
type: Literal["greedy"] = "greedy"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class TopPSamplingStrategy(BaseModel):
|
||||
type: Literal["top_p"] = "top_p"
|
||||
temperature: Optional[float] = Field(..., gt=0.0)
|
||||
top_p: Optional[float] = 0.95
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class TopKSamplingStrategy(BaseModel):
|
||||
type: Literal["top_k"] = "top_k"
|
||||
top_k: int = Field(..., ge=1)
|
||||
|
||||
|
||||
SamplingStrategy = Annotated[
|
||||
Union[GreedySamplingStrategy, TopPSamplingStrategy, TopKSamplingStrategy],
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
register_schema(SamplingStrategy, name="SamplingStrategy")
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class SamplingParams(BaseModel):
|
||||
"""Sampling parameters.
|
||||
|
||||
:param strategy: The sampling strategy.
|
||||
:param max_tokens: The maximum number of tokens that can be generated in the completion. The token count of
|
||||
your prompt plus max_tokens cannot exceed the model's context length.
|
||||
:param repetition_penalty: Number between -2.0 and 2.0. Positive values penalize new tokens
|
||||
based on whether they appear in the text so far, increasing the model's likelihood to talk about new topics.
|
||||
:param stop: Up to 4 sequences where the API will stop generating further tokens.
|
||||
The returned text will not contain the stop sequence.
|
||||
"""
|
||||
|
||||
strategy: SamplingStrategy = Field(default_factory=GreedySamplingStrategy)
|
||||
|
||||
max_tokens: Optional[int] = 0
|
||||
repetition_penalty: Optional[float] = 1.0
|
||||
stop: Optional[List[str]] = None
|
||||
|
||||
|
||||
class LogProbConfig(BaseModel):
|
||||
"""
|
||||
|
@ -48,18 +97,18 @@ class QuantizationType(Enum):
|
|||
"""Type of model quantization to run inference with.
|
||||
|
||||
:cvar bf16: BFloat16 typically this means _no_ quantization
|
||||
:cvar fp8: 8-bit floating point quantization
|
||||
:cvar int4: 4-bit integer quantization
|
||||
:cvar fp8_mixed: 8-bit floating point quantization with mixed precision
|
||||
:cvar int4_mixed: 4-bit integer quantization with mixed precision
|
||||
"""
|
||||
|
||||
bf16 = "bf16"
|
||||
fp8 = "fp8"
|
||||
int4 = "int4"
|
||||
fp8_mixed = "fp8_mixed"
|
||||
int4_mixed = "int4_mixed"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class Fp8QuantizationConfig(BaseModel):
|
||||
type: Literal["fp8"] = "fp8"
|
||||
type: Literal["fp8_mixed"] = "fp8_mixed"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
@ -75,7 +124,7 @@ class Int4QuantizationConfig(BaseModel):
|
|||
:param scheme: Quantization scheme to use. Defaults to "int4_weight_int8_dynamic_activation"
|
||||
"""
|
||||
|
||||
type: Literal["int4"] = "int4"
|
||||
type: Literal["int4_mixed"] = "int4_mixed"
|
||||
scheme: Optional[str] = "int4_weight_int8_dynamic_activation"
|
||||
|
||||
|
||||
|
|
|
@ -29,8 +29,8 @@ from rich.progress import (
|
|||
from termcolor import cprint
|
||||
|
||||
from llama_stack.cli.subcommand import Subcommand
|
||||
from llama_stack.models.llama.datatypes import Model
|
||||
from llama_stack.models.llama.sku_list import LlamaDownloadInfo
|
||||
from llama_stack.models.llama.sku_types import Model
|
||||
|
||||
|
||||
class Download(Subcommand):
|
||||
|
|
|
@ -63,17 +63,6 @@ class ModelDescribe(Subcommand):
|
|||
("Model params.json", json.dumps(model.arch_args, indent=4)),
|
||||
]
|
||||
|
||||
if model.recommended_sampling_params is not None:
|
||||
sampling_params = model.recommended_sampling_params.model_dump()
|
||||
for k in ("max_tokens", "repetition_penalty"):
|
||||
del sampling_params[k]
|
||||
rows.append(
|
||||
(
|
||||
"Recommended sampling params",
|
||||
json.dumps(sampling_params, indent=4),
|
||||
)
|
||||
)
|
||||
|
||||
print_table(
|
||||
rows,
|
||||
headers,
|
||||
|
|
|
@ -11,7 +11,7 @@ from pathlib import Path
|
|||
|
||||
from llama_stack.cli.subcommand import Subcommand
|
||||
from llama_stack.cli.table import print_table
|
||||
from llama_stack.models.llama.datatypes import CoreModelId, ModelFamily, is_multimodal, model_family
|
||||
from llama_stack.models.llama.sku_types import CoreModelId, ModelFamily, is_multimodal, model_family
|
||||
|
||||
ROOT_DIR = Path(__file__).parent.parent.parent
|
||||
|
||||
|
|
|
@ -4,12 +4,12 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Dict, Optional
|
||||
from typing import Any, Dict
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
from llama_stack.models.llama.datatypes import CheckpointQuantizationFormat, SamplingParams
|
||||
from llama_stack.models.llama.sku_list import LlamaDownloadInfo
|
||||
from llama_stack.models.llama.sku_types import CheckpointQuantizationFormat
|
||||
|
||||
|
||||
class PromptGuardModel(BaseModel):
|
||||
|
@ -23,7 +23,6 @@ class PromptGuardModel(BaseModel):
|
|||
is_instruct_model: bool = False
|
||||
quantization_format: CheckpointQuantizationFormat = CheckpointQuantizationFormat.bf16
|
||||
arch_args: Dict[str, Any] = Field(default_factory=dict)
|
||||
recommended_sampling_params: Optional[SamplingParams] = None
|
||||
|
||||
def descriptor(self) -> str:
|
||||
return self.model_id
|
||||
|
|
|
@ -312,6 +312,11 @@ a default SQLite store will be used.""",
|
|||
description="Configuration for the HTTP(S) server",
|
||||
)
|
||||
|
||||
external_providers_dir: Optional[str] = Field(
|
||||
default=None,
|
||||
description="Path to directory containing external provider implementations. The providers code and dependencies must be installed on the system.",
|
||||
)
|
||||
|
||||
|
||||
class BuildConfig(BaseModel):
|
||||
version: str = LLAMA_STACK_BUILD_CONFIG_VERSION
|
||||
|
|
|
@ -4,12 +4,25 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import glob
|
||||
import importlib
|
||||
from typing import Dict, List
|
||||
import os
|
||||
from typing import Any, Dict, List
|
||||
|
||||
import yaml
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.providers.datatypes import Api, ProviderSpec
|
||||
from llama_stack.distribution.datatypes import StackRunConfig
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.providers.datatypes import (
|
||||
AdapterSpec,
|
||||
Api,
|
||||
InlineProviderSpec,
|
||||
ProviderSpec,
|
||||
remote_provider_spec,
|
||||
)
|
||||
|
||||
logger = get_logger(name=__name__, category="core")
|
||||
|
||||
|
||||
def stack_apis() -> List[Api]:
|
||||
|
@ -59,11 +72,116 @@ def providable_apis() -> List[Api]:
|
|||
return [api for api in Api if api not in routing_table_apis and api != Api.inspect and api != Api.providers]
|
||||
|
||||
|
||||
def get_provider_registry() -> Dict[Api, Dict[str, ProviderSpec]]:
|
||||
ret = {}
|
||||
def _load_remote_provider_spec(spec_data: Dict[str, Any], api: Api) -> ProviderSpec:
|
||||
adapter = AdapterSpec(**spec_data["adapter"])
|
||||
spec = remote_provider_spec(
|
||||
api=api,
|
||||
adapter=adapter,
|
||||
api_dependencies=[Api(dep) for dep in spec_data.get("api_dependencies", [])],
|
||||
)
|
||||
return spec
|
||||
|
||||
|
||||
def _load_inline_provider_spec(spec_data: Dict[str, Any], api: Api, provider_name: str) -> ProviderSpec:
|
||||
spec = InlineProviderSpec(
|
||||
api=api,
|
||||
provider_type=f"inline::{provider_name}",
|
||||
pip_packages=spec_data.get("pip_packages", []),
|
||||
module=spec_data["module"],
|
||||
config_class=spec_data["config_class"],
|
||||
api_dependencies=[Api(dep) for dep in spec_data.get("api_dependencies", [])],
|
||||
optional_api_dependencies=[Api(dep) for dep in spec_data.get("optional_api_dependencies", [])],
|
||||
provider_data_validator=spec_data.get("provider_data_validator"),
|
||||
container_image=spec_data.get("container_image"),
|
||||
)
|
||||
return spec
|
||||
|
||||
|
||||
def get_provider_registry(config: StackRunConfig | None = None) -> Dict[Api, Dict[str, ProviderSpec]]:
|
||||
"""Get the provider registry, optionally including external providers.
|
||||
|
||||
This function loads both built-in providers and external providers from YAML files.
|
||||
External providers are loaded from a directory structure like:
|
||||
|
||||
providers.d/
|
||||
remote/
|
||||
inference/
|
||||
custom_ollama.yaml
|
||||
vllm.yaml
|
||||
vector_io/
|
||||
qdrant.yaml
|
||||
safety/
|
||||
llama-guard.yaml
|
||||
inline/
|
||||
inference/
|
||||
custom_ollama.yaml
|
||||
vllm.yaml
|
||||
vector_io/
|
||||
qdrant.yaml
|
||||
safety/
|
||||
llama-guard.yaml
|
||||
|
||||
Args:
|
||||
config: Optional StackRunConfig containing the external providers directory path
|
||||
|
||||
Returns:
|
||||
A dictionary mapping APIs to their available providers
|
||||
|
||||
Raises:
|
||||
FileNotFoundError: If the external providers directory doesn't exist
|
||||
ValueError: If any provider spec is invalid
|
||||
"""
|
||||
|
||||
ret: Dict[Api, Dict[str, ProviderSpec]] = {}
|
||||
for api in providable_apis():
|
||||
name = api.name.lower()
|
||||
module = importlib.import_module(f"llama_stack.providers.registry.{name}")
|
||||
ret[api] = {a.provider_type: a for a in module.available_providers()}
|
||||
logger.debug(f"Importing module {name}")
|
||||
try:
|
||||
module = importlib.import_module(f"llama_stack.providers.registry.{name}")
|
||||
ret[api] = {a.provider_type: a for a in module.available_providers()}
|
||||
except ImportError as e:
|
||||
logger.warning(f"Failed to import module {name}: {e}")
|
||||
|
||||
if config and config.external_providers_dir:
|
||||
external_providers_dir = os.path.abspath(config.external_providers_dir)
|
||||
if not os.path.exists(external_providers_dir):
|
||||
raise FileNotFoundError(f"External providers directory not found: {external_providers_dir}")
|
||||
logger.info(f"Loading external providers from {external_providers_dir}")
|
||||
|
||||
for api in providable_apis():
|
||||
api_name = api.name.lower()
|
||||
|
||||
# Process both remote and inline providers
|
||||
for provider_type in ["remote", "inline"]:
|
||||
api_dir = os.path.join(external_providers_dir, provider_type, api_name)
|
||||
if not os.path.exists(api_dir):
|
||||
logger.debug(f"No {provider_type} provider directory found for {api_name}")
|
||||
continue
|
||||
|
||||
# Look for provider spec files in the API directory
|
||||
for spec_path in glob.glob(os.path.join(api_dir, "*.yaml")):
|
||||
provider_name = os.path.splitext(os.path.basename(spec_path))[0]
|
||||
logger.info(f"Loading {provider_type} provider spec from {spec_path}")
|
||||
|
||||
try:
|
||||
with open(spec_path) as f:
|
||||
spec_data = yaml.safe_load(f)
|
||||
|
||||
if provider_type == "remote":
|
||||
spec = _load_remote_provider_spec(spec_data, api)
|
||||
provider_type_key = f"remote::{provider_name}"
|
||||
else:
|
||||
spec = _load_inline_provider_spec(spec_data, api, provider_name)
|
||||
provider_type_key = f"inline::{provider_name}"
|
||||
|
||||
logger.info(f"Loaded {provider_type} provider spec for {provider_type_key} from {spec_path}")
|
||||
if provider_type_key in ret[api]:
|
||||
logger.warning(f"Overriding already registered provider {provider_type_key} for {api.name}")
|
||||
ret[api][provider_type_key] = spec
|
||||
except yaml.YAMLError as yaml_err:
|
||||
logger.error(f"Failed to parse YAML file {spec_path}: {yaml_err}")
|
||||
raise yaml_err
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to load provider spec from {spec_path}: {e}")
|
||||
raise e
|
||||
return ret
|
||||
|
|
|
@ -351,6 +351,7 @@ async def instantiate_provider(
|
|||
if not hasattr(provider_spec, "module"):
|
||||
raise AttributeError(f"ProviderSpec of type {type(provider_spec)} does not have a 'module' attribute")
|
||||
|
||||
logger.debug(f"Instantiating provider {provider.provider_id} from {provider_spec.module}")
|
||||
module = importlib.import_module(provider_spec.module)
|
||||
args = []
|
||||
if isinstance(provider_spec, RemoteProviderSpec):
|
||||
|
|
|
@ -608,8 +608,8 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
|
|||
tool_group = await self.get_tool_group(toolgroup_id)
|
||||
if tool_group is None:
|
||||
raise ValueError(f"Tool group {toolgroup_id} not found")
|
||||
tools = (await self.list_tools(toolgroup_id)).data
|
||||
for tool in tools:
|
||||
tools = await self.list_tools(toolgroup_id)
|
||||
for tool in getattr(tools, "data", []):
|
||||
await self.unregister_object(tool)
|
||||
await self.unregister_object(tool_group)
|
||||
|
||||
|
|
|
@ -218,7 +218,7 @@ async def construct_stack(
|
|||
run_config: StackRunConfig, provider_registry: Optional[ProviderRegistry] = None
|
||||
) -> Dict[Api, Any]:
|
||||
dist_registry, _ = await create_dist_registry(run_config.metadata_store, run_config.image_name)
|
||||
impls = await resolve_impls(run_config, provider_registry or get_provider_registry(), dist_registry)
|
||||
impls = await resolve_impls(run_config, provider_registry or get_provider_registry(run_config), dist_registry)
|
||||
await register_resources(run_config, impls)
|
||||
return impls
|
||||
|
||||
|
|
|
@ -1,7 +1,7 @@
|
|||
# More info on playground configuration can be found here:
|
||||
# https://llama-stack.readthedocs.io/en/latest/playground
|
||||
|
||||
FROM python:3.9-slim
|
||||
FROM python:3.12-slim
|
||||
WORKDIR /app
|
||||
COPY . /app/
|
||||
RUN /usr/local/bin/python -m pip install --upgrade pip && \
|
||||
|
|
|
@ -24,6 +24,7 @@ def main():
|
|||
# Playground pages
|
||||
chat_page = st.Page("page/playground/chat.py", title="Chat", icon="💬", default=True)
|
||||
rag_page = st.Page("page/playground/rag.py", title="RAG", icon="💬", default=False)
|
||||
tool_page = st.Page("page/playground/tools.py", title="Tools", icon="🛠", default=False)
|
||||
|
||||
# Distribution pages
|
||||
resources_page = st.Page("page/distribution/resources.py", title="Resources", icon="🔍", default=False)
|
||||
|
@ -39,6 +40,7 @@ def main():
|
|||
"Playground": [
|
||||
chat_page,
|
||||
rag_page,
|
||||
tool_page,
|
||||
application_evaluation_page,
|
||||
native_evaluation_page,
|
||||
],
|
||||
|
|
|
@ -4,6 +4,8 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import uuid
|
||||
|
||||
import streamlit as st
|
||||
from llama_stack_client import Agent, AgentEventLogger, RAGDocument
|
||||
|
||||
|
@ -102,8 +104,8 @@ def rag_chat_page():
|
|||
|
||||
# Add clear chat button to sidebar
|
||||
if st.button("Clear Chat", use_container_width=True):
|
||||
st.session_state.messages = []
|
||||
st.rerun()
|
||||
st.session_state.clear()
|
||||
st.cache_resource.clear()
|
||||
|
||||
# Chat Interface
|
||||
if "messages" not in st.session_state:
|
||||
|
@ -123,23 +125,31 @@ def rag_chat_page():
|
|||
else:
|
||||
strategy = {"type": "greedy"}
|
||||
|
||||
agent = Agent(
|
||||
llama_stack_api.client,
|
||||
model=selected_model,
|
||||
instructions=system_prompt,
|
||||
sampling_params={
|
||||
"strategy": strategy,
|
||||
},
|
||||
tools=[
|
||||
dict(
|
||||
name="builtin::rag/knowledge_search",
|
||||
args={
|
||||
"vector_db_ids": list(selected_vector_dbs),
|
||||
},
|
||||
)
|
||||
],
|
||||
)
|
||||
session_id = agent.create_session("rag-session")
|
||||
@st.cache_resource
|
||||
def create_agent():
|
||||
return Agent(
|
||||
llama_stack_api.client,
|
||||
model=selected_model,
|
||||
instructions=system_prompt,
|
||||
sampling_params={
|
||||
"strategy": strategy,
|
||||
},
|
||||
tools=[
|
||||
dict(
|
||||
name="builtin::rag/knowledge_search",
|
||||
args={
|
||||
"vector_db_ids": list(selected_vector_dbs),
|
||||
},
|
||||
)
|
||||
],
|
||||
)
|
||||
|
||||
agent = create_agent()
|
||||
|
||||
if "agent_session_id" not in st.session_state:
|
||||
st.session_state["agent_session_id"] = agent.create_session(session_name=f"rag_demo_{uuid.uuid4()}")
|
||||
|
||||
session_id = st.session_state["agent_session_id"]
|
||||
|
||||
# Chat input
|
||||
if prompt := st.chat_input("Ask a question about your documents"):
|
||||
|
|
116
llama_stack/distribution/ui/page/playground/tools.py
Normal file
116
llama_stack/distribution/ui/page/playground/tools.py
Normal file
|
@ -0,0 +1,116 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import uuid
|
||||
|
||||
import streamlit as st
|
||||
from llama_stack_client import Agent
|
||||
|
||||
from llama_stack.distribution.ui.modules.api import llama_stack_api
|
||||
|
||||
|
||||
def tool_chat_page():
|
||||
st.title("🛠 Tools")
|
||||
|
||||
client = llama_stack_api.client
|
||||
models = client.models.list()
|
||||
model_list = [model.identifier for model in models if model.api_model_type == "llm"]
|
||||
|
||||
tool_groups = client.toolgroups.list()
|
||||
tool_groups_list = [tool_group.identifier for tool_group in tool_groups]
|
||||
mcp_tools_list = [tool for tool in tool_groups_list if tool.startswith("mcp::")]
|
||||
builtin_tools_list = [tool for tool in tool_groups_list if not tool.startswith("mcp::")]
|
||||
|
||||
def reset_agent():
|
||||
st.session_state.clear()
|
||||
st.cache_resource.clear()
|
||||
|
||||
with st.sidebar:
|
||||
st.subheader("Model")
|
||||
model = st.selectbox(label="models", options=model_list, on_change=reset_agent)
|
||||
|
||||
st.subheader("Builtin Tools")
|
||||
toolgroup_selection = st.pills(
|
||||
label="Available ToolGroups", options=builtin_tools_list, selection_mode="multi", on_change=reset_agent
|
||||
)
|
||||
|
||||
st.subheader("MCP Servers")
|
||||
mcp_selection = st.pills(
|
||||
label="Available MCP Servers", options=mcp_tools_list, selection_mode="multi", on_change=reset_agent
|
||||
)
|
||||
|
||||
toolgroup_selection.extend(mcp_selection)
|
||||
|
||||
active_tool_list = []
|
||||
for toolgroup_id in toolgroup_selection:
|
||||
active_tool_list.extend(
|
||||
[
|
||||
f"{''.join(toolgroup_id.split('::')[1:])}:{t.identifier}"
|
||||
for t in client.tools.list(toolgroup_id=toolgroup_id)
|
||||
]
|
||||
)
|
||||
|
||||
st.subheader(f"Active Tools: 🛠 {len(active_tool_list)}")
|
||||
st.json(active_tool_list)
|
||||
|
||||
@st.cache_resource
|
||||
def create_agent():
|
||||
return Agent(
|
||||
client,
|
||||
model=model,
|
||||
instructions="You are a helpful assistant. When you use a tool always respond with a summary of the result.",
|
||||
tools=toolgroup_selection,
|
||||
sampling_params={
|
||||
"strategy": {"type": "greedy"},
|
||||
},
|
||||
)
|
||||
|
||||
agent = create_agent()
|
||||
|
||||
if "agent_session_id" not in st.session_state:
|
||||
st.session_state["agent_session_id"] = agent.create_session(session_name=f"tool_demo_{uuid.uuid4()}")
|
||||
|
||||
session_id = st.session_state["agent_session_id"]
|
||||
|
||||
if "messages" not in st.session_state:
|
||||
st.session_state["messages"] = [{"role": "assistant", "content": "How can I help you?"}]
|
||||
|
||||
for msg in st.session_state.messages:
|
||||
with st.chat_message(msg["role"]):
|
||||
st.markdown(msg["content"])
|
||||
|
||||
if prompt := st.chat_input(placeholder=""):
|
||||
with st.chat_message("user"):
|
||||
st.markdown(prompt)
|
||||
|
||||
st.session_state.messages.append({"role": "user", "content": prompt})
|
||||
|
||||
turn_response = agent.create_turn(
|
||||
session_id=session_id,
|
||||
messages=[{"role": "user", "content": prompt}],
|
||||
stream=True,
|
||||
)
|
||||
|
||||
def response_generator(turn_response):
|
||||
for response in turn_response:
|
||||
if hasattr(response.event, "payload"):
|
||||
print(response.event.payload)
|
||||
if response.event.payload.event_type == "step_progress":
|
||||
if hasattr(response.event.payload.delta, "text"):
|
||||
yield response.event.payload.delta.text
|
||||
if response.event.payload.event_type == "step_complete":
|
||||
if response.event.payload.step_details.step_type == "tool_execution":
|
||||
yield " 🛠 "
|
||||
else:
|
||||
yield f"Error occurred in the Llama Stack Cluster: {response}"
|
||||
|
||||
with st.chat_message("assistant"):
|
||||
response = st.write_stream(response_generator(turn_response))
|
||||
|
||||
st.session_state.messages.append({"role": "assistant", "content": response})
|
||||
|
||||
|
||||
tool_chat_page()
|
|
@ -2,3 +2,4 @@ streamlit
|
|||
pandas
|
||||
llama-stack-client>=0.0.55
|
||||
streamlit-option-menu
|
||||
llama-stack>=0.1.9
|
||||
|
|
|
@ -29,6 +29,11 @@ def preserve_contexts_async_generator(
|
|||
context_var.set(initial_context_values[context_var.name])
|
||||
|
||||
item = await gen.__anext__()
|
||||
|
||||
# Update our tracked values with any changes made during this iteration
|
||||
for context_var in context_vars:
|
||||
initial_context_values[context_var.name] = context_var.get()
|
||||
|
||||
yield item
|
||||
|
||||
except StopAsyncIteration:
|
||||
|
|
164
llama_stack/models/llama/checkpoint.py
Normal file
164
llama_stack/models/llama/checkpoint.py
Normal file
|
@ -0,0 +1,164 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import concurrent.futures
|
||||
import re
|
||||
from pathlib import Path
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
import numpy as np
|
||||
import torch
|
||||
from fairscale.nn.model_parallel.initialize import get_model_parallel_rank, get_model_parallel_world_size
|
||||
|
||||
|
||||
def map_mp_rank(old_mp_size: int, new_mp_size: int, new_mp_rank: int) -> List[int]:
|
||||
"""Map a new MP rank to a list of old MP ranks given a change in MP size."""
|
||||
if new_mp_size % old_mp_size == 0:
|
||||
# Read old MP shard and split it into smaller ones
|
||||
return [new_mp_rank * old_mp_size // new_mp_size]
|
||||
elif old_mp_size % new_mp_size == 0:
|
||||
# Merge old MP shards into a single one
|
||||
mp_factor = old_mp_size // new_mp_size
|
||||
return list(range(new_mp_rank * mp_factor, (new_mp_rank + 1) * mp_factor))
|
||||
else:
|
||||
raise ValueError(
|
||||
f"Either old MP size or new MP size should be a multiple of the other: "
|
||||
f"{old_mp_size} % {new_mp_size} != 0 and {new_mp_size} % {old_mp_size} != 0"
|
||||
)
|
||||
|
||||
|
||||
def maybe_reshard_state_dict(
|
||||
ckpt_paths: List[Path],
|
||||
n_kv_heads: int,
|
||||
moe_num_experts: Optional[int] = None,
|
||||
map_location: Union[str, torch.device] = "cpu",
|
||||
mmap: bool = True,
|
||||
) -> Dict[str, torch.Tensor]:
|
||||
if str(map_location) == "cpu":
|
||||
torch.set_default_tensor_type(torch.BFloat16Tensor)
|
||||
else:
|
||||
torch.set_default_tensor_type(torch.cuda.BFloat16Tensor)
|
||||
|
||||
ckpt_paths = np.array(sorted(ckpt_paths))
|
||||
|
||||
new_mp_size, new_mp_rank = get_model_parallel_world_size(), get_model_parallel_rank()
|
||||
old_mp_size = len(ckpt_paths)
|
||||
old_mp_ranks = map_mp_rank(old_mp_size, new_mp_size, new_mp_rank)
|
||||
|
||||
print(f"Loading checkpoint shards:\n{str(ckpt_paths[old_mp_ranks])}") # type: ignore
|
||||
paths = ckpt_paths[old_mp_ranks] # type: ignore
|
||||
state_dicts = [torch.load(str(p), map_location=map_location, mmap=mmap) for p in paths]
|
||||
|
||||
if new_mp_size == old_mp_size:
|
||||
return state_dicts[0] # type: ignore
|
||||
|
||||
if moe_num_experts is not None:
|
||||
state_dicts = [convert_moe_weights(d, moe_num_experts) for d in state_dicts]
|
||||
|
||||
print(f"Resharding {len(state_dicts)} state dicts from MP size {old_mp_size} to MP size {new_mp_size}")
|
||||
return reshard_mp(
|
||||
state_dicts,
|
||||
size=max(new_mp_size // old_mp_size, 1),
|
||||
rank=new_mp_rank % max(new_mp_size // old_mp_size, 1),
|
||||
repeat_qk_qv=max(new_mp_size // n_kv_heads, 1),
|
||||
)
|
||||
|
||||
|
||||
_WEIGHT_ROW_KEY = {
|
||||
"feed_forward.w2",
|
||||
"feed_forward.mlp.fc2",
|
||||
"attention.wo",
|
||||
"feed_forward.mlp.fc2_weight",
|
||||
"feed_forward.w_out_shared_DF.weight",
|
||||
"attn.wo.weight",
|
||||
"mlp.c_proj.weight",
|
||||
}
|
||||
_MOE_WEIGHT_ROW_KEY = {"feed_forward.experts.(moe_w_in_eD_F|moe_w_swiglu_eD_F)"}
|
||||
|
||||
_WEIGHT_COLUMN_KEY = {
|
||||
"output",
|
||||
"feed_forward.(w1|w3)",
|
||||
"feed_forward.mlp.(fc1|fc3)",
|
||||
"feed_forward.mlp.fc1_weight",
|
||||
"attention.(wk|wq|wv|wqkv).weight",
|
||||
"feed_forward.(w_in_shared_FD|w_swiglu_FD)",
|
||||
"attn.(wk|wq|wv).weight",
|
||||
"attn.(wk|wq|wv).bias",
|
||||
"mlp.c_fc.weight",
|
||||
"mlp.c_fc.bias",
|
||||
"conv1._linear.weight",
|
||||
"tok_embeddings.weight",
|
||||
"vision_projection.weight",
|
||||
}
|
||||
_MOE_WEIGHT_COLUMN_KEY = {"feed_forward.experts.moe_w_out_eF_D"}
|
||||
|
||||
|
||||
def reshard_mp(
|
||||
state_dicts: List[Dict[str, torch.Tensor]],
|
||||
size: int,
|
||||
rank: int,
|
||||
repeat_qk_qv: int = 1,
|
||||
) -> Dict[str, torch.Tensor]:
|
||||
"""
|
||||
Reshard a list of state dicts into a single state dict given a change in MP size.
|
||||
If the list has more than one state dict, we concatenate the values of the same
|
||||
key across all state dicts. Otherwise, we just slice it for the current MP rank.
|
||||
"""
|
||||
|
||||
def concat_or_chunk(tensors: List[torch.Tensor], dim: int) -> torch.Tensor:
|
||||
if len(tensors) > 1:
|
||||
return torch.cat(tensors, dim=dim)
|
||||
return tensors[0].chunk(size, dim=dim)[rank].clone()
|
||||
|
||||
def process_key(key: str) -> torch.Tensor:
|
||||
if row_regex.search(key):
|
||||
return concat_or_chunk([s[key] for s in state_dicts], dim=-1)
|
||||
elif column_regex.search(key):
|
||||
if "w13" in key or "fc1_weight" in key:
|
||||
dims = state_dicts[0][key].size()
|
||||
values = [s[key].view(2, dims[0] // 2, *dims[1:]) for s in state_dicts]
|
||||
return concat_or_chunk(values, dim=1).flatten(0, 1)
|
||||
elif "qkv" in key:
|
||||
q_dim = state_dicts[0][key.replace("qkv", "o")].size(1)
|
||||
kv_dim = (state_dicts[0][key].size(0) - q_dim) // 2
|
||||
values = [s[key].split((q_dim, kv_dim, kv_dim)) for s in state_dicts]
|
||||
return torch.cat([concat_or_chunk(x, dim=0) for x in zip(*values, strict=False)]) # type: ignore
|
||||
elif "wk.weight" in key or "wv.weight" in key:
|
||||
# Support MP > #kv_head
|
||||
return concat_or_chunk([s[key].repeat(repeat_qk_qv, 1) for s in state_dicts], dim=0)
|
||||
elif key == "output.bias" or key == "fc.weight":
|
||||
return concat_or_chunk([s[key] for s in state_dicts], dim=0)
|
||||
elif "w_" in key:
|
||||
return concat_or_chunk([s[key] for s in state_dicts], dim=-2)
|
||||
else:
|
||||
return concat_or_chunk([s[key] for s in state_dicts], dim=0)
|
||||
else:
|
||||
return state_dicts[0][key].clone()
|
||||
|
||||
row_keys = _WEIGHT_ROW_KEY | _MOE_WEIGHT_ROW_KEY
|
||||
column_keys = _WEIGHT_COLUMN_KEY | _MOE_WEIGHT_COLUMN_KEY
|
||||
|
||||
column_regex = re.compile("|".join(column_keys))
|
||||
row_regex = re.compile("|".join(row_keys))
|
||||
|
||||
output: Dict[str, torch.Tensor] = {}
|
||||
with concurrent.futures.ThreadPoolExecutor() as executor:
|
||||
# Note: only processes keys in the first state dict.
|
||||
# Assumes keys are the same across all state dicts.
|
||||
mappings = {executor.submit(process_key, key): key for key in state_dicts[0]}
|
||||
for future in concurrent.futures.as_completed(mappings):
|
||||
output[mappings[future]] = future.result()
|
||||
return output
|
||||
|
||||
|
||||
def convert_moe_weights(state_dict: Dict[str, Any], num_experts: int) -> Dict[str, Any]:
|
||||
routed_keys = _MOE_WEIGHT_ROW_KEY | _MOE_WEIGHT_COLUMN_KEY
|
||||
routed_regex = re.compile("|".join(routed_keys))
|
||||
keys = list(state_dict.keys())
|
||||
for key in keys:
|
||||
if routed_regex.search(key):
|
||||
state_dict[key] = state_dict.pop(key).unflatten(0, (num_experts, -1)).squeeze(dim=0)
|
||||
return state_dict
|
|
@ -4,13 +4,6 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# top-level folder for each specific model found within the models/ directory at
|
||||
# the top-level of this source tree.
|
||||
|
||||
import base64
|
||||
from enum import Enum
|
||||
from io import BytesIO
|
||||
|
@ -19,8 +12,6 @@ from typing import Any, Dict, List, Literal, Optional, Union
|
|||
from pydantic import BaseModel, ConfigDict, Field, field_serializer, field_validator
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from llama_stack.schema_utils import json_schema_type, register_schema
|
||||
|
||||
# The goal is that these set of types are relevant for all Llama models.
|
||||
# That isn't the current state yet -- e.g., BuiltinTool is somewhat specific to
|
||||
# the llama3 series of models.
|
||||
|
@ -98,6 +89,29 @@ class StopReason(Enum):
|
|||
out_of_tokens = "out_of_tokens"
|
||||
|
||||
|
||||
class ToolParamDefinition(BaseModel):
|
||||
param_type: str
|
||||
description: Optional[str] = None
|
||||
required: Optional[bool] = True
|
||||
default: Optional[Any] = None
|
||||
|
||||
|
||||
class ToolDefinition(BaseModel):
|
||||
tool_name: Union[BuiltinTool, str]
|
||||
description: Optional[str] = None
|
||||
parameters: Optional[Dict[str, ToolParamDefinition]] = None
|
||||
|
||||
@field_validator("tool_name", mode="before")
|
||||
@classmethod
|
||||
def validate_field(cls, v):
|
||||
if isinstance(v, str):
|
||||
try:
|
||||
return BuiltinTool(v)
|
||||
except ValueError:
|
||||
return v
|
||||
return v
|
||||
|
||||
|
||||
class RawMediaItem(BaseModel):
|
||||
type: Literal["image"] = "image"
|
||||
data: bytes | BytesIO
|
||||
|
@ -140,292 +154,25 @@ class RawMessage(BaseModel):
|
|||
tool_calls: List[ToolCall] = Field(default_factory=list)
|
||||
|
||||
|
||||
register_schema(ToolCall)
|
||||
class GenerationResult(BaseModel):
|
||||
token: int
|
||||
text: str
|
||||
logprobs: Optional[List[float]] = None
|
||||
|
||||
source: Literal["input"] | Literal["output"]
|
||||
|
||||
# index within the batch
|
||||
batch_idx: int
|
||||
# whether generation for this item is already finished. note that tokens can
|
||||
# get returned even afterwards since other items in the batch can still be generating tokens
|
||||
finished: bool
|
||||
# because a batch is parallel processed, useful decoding for one item can correspond to processing
|
||||
# pad tokens or tokens beyond EOS for other items. we could have decided to return None for this case
|
||||
# but it's more convenient to return a list of GenerationResult and filter out the ignored tokens
|
||||
ignore_token: bool
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ToolParamDefinition(BaseModel):
|
||||
param_type: str
|
||||
description: Optional[str] = None
|
||||
required: Optional[bool] = True
|
||||
default: Optional[Any] = None
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class ToolDefinition(BaseModel):
|
||||
tool_name: Union[BuiltinTool, str]
|
||||
description: Optional[str] = None
|
||||
parameters: Optional[Dict[str, ToolParamDefinition]] = None
|
||||
|
||||
@field_validator("tool_name", mode="before")
|
||||
@classmethod
|
||||
def validate_field(cls, v):
|
||||
if isinstance(v, str):
|
||||
try:
|
||||
return BuiltinTool(v)
|
||||
except ValueError:
|
||||
return v
|
||||
return v
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class GreedySamplingStrategy(BaseModel):
|
||||
type: Literal["greedy"] = "greedy"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class TopPSamplingStrategy(BaseModel):
|
||||
type: Literal["top_p"] = "top_p"
|
||||
temperature: Optional[float] = Field(..., gt=0.0)
|
||||
top_p: Optional[float] = 0.95
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class TopKSamplingStrategy(BaseModel):
|
||||
type: Literal["top_k"] = "top_k"
|
||||
top_k: int = Field(..., ge=1)
|
||||
|
||||
|
||||
SamplingStrategy = Annotated[
|
||||
Union[GreedySamplingStrategy, TopPSamplingStrategy, TopKSamplingStrategy],
|
||||
Field(discriminator="type"),
|
||||
]
|
||||
register_schema(SamplingStrategy, name="SamplingStrategy")
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class SamplingParams(BaseModel):
|
||||
"""Sampling parameters.
|
||||
|
||||
:param strategy: The sampling strategy.
|
||||
:param max_tokens: The maximum number of tokens that can be generated in the completion. The token count of
|
||||
your prompt plus max_tokens cannot exceed the model's context length.
|
||||
:param repetition_penalty: Number between -2.0 and 2.0. Positive values penalize new tokens
|
||||
based on whether they appear in the text so far, increasing the model's likelihood to talk about new topics.
|
||||
:param stop: Up to 4 sequences where the API will stop generating further tokens.
|
||||
The returned text will not contain the stop sequence.
|
||||
"""
|
||||
|
||||
strategy: SamplingStrategy = Field(default_factory=GreedySamplingStrategy)
|
||||
|
||||
max_tokens: Optional[int] = 0
|
||||
repetition_penalty: Optional[float] = 1.0
|
||||
stop: Optional[List[str]] = None
|
||||
|
||||
|
||||
class CheckpointQuantizationFormat(Enum):
|
||||
# default format
|
||||
bf16 = "bf16"
|
||||
|
||||
# used for enabling fp8_rowwise inference, some weights are bf16
|
||||
fp8_mixed = "fp8-mixed"
|
||||
|
||||
int8 = "int8"
|
||||
|
||||
int4 = "int4"
|
||||
|
||||
|
||||
class ModelFamily(Enum):
|
||||
llama2 = "llama2"
|
||||
llama3 = "llama3"
|
||||
llama3_1 = "llama3_1"
|
||||
llama3_2 = "llama3_2"
|
||||
llama3_3 = "llama3_3"
|
||||
llama4 = "llama4"
|
||||
safety = "safety"
|
||||
|
||||
|
||||
class CoreModelId(Enum):
|
||||
"""Each of these models is a unique "SKU". These root models can be served in various garbs (especially by quantizing them)"""
|
||||
|
||||
# Llama 2 family
|
||||
llama2_7b = "Llama-2-7b"
|
||||
llama2_13b = "Llama-2-13b"
|
||||
llama2_70b = "Llama-2-70b"
|
||||
llama2_7b_chat = "Llama-2-7b-chat"
|
||||
llama2_13b_chat = "Llama-2-13b-chat"
|
||||
llama2_70b_chat = "Llama-2-70b-chat"
|
||||
|
||||
# Llama 3 family
|
||||
llama3_8b = "Llama-3-8B"
|
||||
llama3_70b = "Llama-3-70B"
|
||||
llama3_8b_instruct = "Llama-3-8B-Instruct"
|
||||
llama3_70b_instruct = "Llama-3-70B-Instruct"
|
||||
|
||||
# Llama 3.1 family
|
||||
llama3_1_8b = "Llama3.1-8B"
|
||||
llama3_1_70b = "Llama3.1-70B"
|
||||
llama3_1_405b = "Llama3.1-405B"
|
||||
llama3_1_8b_instruct = "Llama3.1-8B-Instruct"
|
||||
llama3_1_70b_instruct = "Llama3.1-70B-Instruct"
|
||||
llama3_1_405b_instruct = "Llama3.1-405B-Instruct"
|
||||
|
||||
# Llama 3.2 family
|
||||
llama3_2_1b = "Llama3.2-1B"
|
||||
llama3_2_3b = "Llama3.2-3B"
|
||||
llama3_2_1b_instruct = "Llama3.2-1B-Instruct"
|
||||
llama3_2_3b_instruct = "Llama3.2-3B-Instruct"
|
||||
llama3_2_11b_vision = "Llama3.2-11B-Vision"
|
||||
llama3_2_90b_vision = "Llama3.2-90B-Vision"
|
||||
llama3_2_11b_vision_instruct = "Llama3.2-11B-Vision-Instruct"
|
||||
llama3_2_90b_vision_instruct = "Llama3.2-90B-Vision-Instruct"
|
||||
|
||||
# Llama 3.3 family
|
||||
llama3_3_70b_instruct = "Llama3.3-70B-Instruct"
|
||||
|
||||
# Llama 4 family
|
||||
llama4_scout_17b_16e = "Llama-4-Scout-17B-16E"
|
||||
llama4_scout_17b_16e_instruct = "Llama-4-Scout-17B-16E-Instruct"
|
||||
llama4_maverick_17b_128e = "Llama-4-Maverick-17B-128E"
|
||||
llama4_maverick_17b_128e_instruct = "Llama-4-Maverick-17B-128E-Instruct"
|
||||
|
||||
# Safety models
|
||||
llama_guard_3_8b = "Llama-Guard-3-8B"
|
||||
llama_guard_2_8b = "Llama-Guard-2-8B"
|
||||
llama_guard_3_11b_vision = "Llama-Guard-3-11B-Vision"
|
||||
llama_guard_3_1b = "Llama-Guard-3-1B"
|
||||
|
||||
|
||||
def is_multimodal(model_id) -> bool:
|
||||
if model_id in [
|
||||
CoreModelId.llama3_2_11b_vision,
|
||||
CoreModelId.llama3_2_90b_vision,
|
||||
CoreModelId.llama3_2_11b_vision_instruct,
|
||||
CoreModelId.llama3_2_90b_vision_instruct,
|
||||
]:
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
|
||||
def model_family(model_id) -> ModelFamily:
|
||||
if model_id in [
|
||||
CoreModelId.llama2_7b,
|
||||
CoreModelId.llama2_13b,
|
||||
CoreModelId.llama2_70b,
|
||||
CoreModelId.llama2_7b_chat,
|
||||
CoreModelId.llama2_13b_chat,
|
||||
CoreModelId.llama2_70b_chat,
|
||||
]:
|
||||
return ModelFamily.llama2
|
||||
elif model_id in [
|
||||
CoreModelId.llama3_8b,
|
||||
CoreModelId.llama3_70b,
|
||||
CoreModelId.llama3_8b_instruct,
|
||||
CoreModelId.llama3_70b_instruct,
|
||||
]:
|
||||
return ModelFamily.llama3
|
||||
elif model_id in [
|
||||
CoreModelId.llama3_1_8b,
|
||||
CoreModelId.llama3_1_70b,
|
||||
CoreModelId.llama3_1_405b,
|
||||
CoreModelId.llama3_1_8b_instruct,
|
||||
CoreModelId.llama3_1_70b_instruct,
|
||||
CoreModelId.llama3_1_405b_instruct,
|
||||
]:
|
||||
return ModelFamily.llama3_1
|
||||
elif model_id in [
|
||||
CoreModelId.llama3_2_1b,
|
||||
CoreModelId.llama3_2_3b,
|
||||
CoreModelId.llama3_2_1b_instruct,
|
||||
CoreModelId.llama3_2_3b_instruct,
|
||||
CoreModelId.llama3_2_11b_vision,
|
||||
CoreModelId.llama3_2_90b_vision,
|
||||
CoreModelId.llama3_2_11b_vision_instruct,
|
||||
CoreModelId.llama3_2_90b_vision_instruct,
|
||||
]:
|
||||
return ModelFamily.llama3_2
|
||||
elif model_id in [
|
||||
CoreModelId.llama3_3_70b_instruct,
|
||||
]:
|
||||
return ModelFamily.llama3_3
|
||||
elif model_id in [
|
||||
CoreModelId.llama4_scout_17b_16e,
|
||||
CoreModelId.llama4_scout_17b_16e_instruct,
|
||||
CoreModelId.llama4_maverick_17b_128e,
|
||||
CoreModelId.llama4_maverick_17b_128e_instruct,
|
||||
]:
|
||||
return ModelFamily.llama4
|
||||
elif model_id in [
|
||||
CoreModelId.llama_guard_3_8b,
|
||||
CoreModelId.llama_guard_2_8b,
|
||||
CoreModelId.llama_guard_3_11b_vision,
|
||||
CoreModelId.llama_guard_3_1b,
|
||||
]:
|
||||
return ModelFamily.safety
|
||||
else:
|
||||
raise ValueError(f"Unknown model family for {model_id}")
|
||||
|
||||
|
||||
class Model(BaseModel):
|
||||
core_model_id: CoreModelId
|
||||
description: str
|
||||
huggingface_repo: Optional[str] = None
|
||||
recommended_sampling_params: Optional[SamplingParams] = None
|
||||
arch_args: Dict[str, Any]
|
||||
variant: str = ""
|
||||
|
||||
quantization_format: CheckpointQuantizationFormat = CheckpointQuantizationFormat.bf16
|
||||
pth_file_count: int
|
||||
metadata: Optional[Dict[str, Any]] = Field(default_factory=dict)
|
||||
|
||||
# silence pydantic until we remove the `model_` fields
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
@property
|
||||
def model_family(self) -> ModelFamily:
|
||||
return model_family(self.core_model_id)
|
||||
|
||||
# The SKU is uniquely identified by (model_id, variant) combo
|
||||
def descriptor(self, shorten_default_variant: bool = True) -> str:
|
||||
if not self.variant:
|
||||
return self.core_model_id.value
|
||||
return f"{self.core_model_id.value}:{self.variant}"
|
||||
|
||||
@property
|
||||
def is_instruct_model(self) -> bool:
|
||||
return "instruct" in self.id.name
|
||||
|
||||
# Featured models are shown in the non-exhaustive model list
|
||||
@property
|
||||
def is_featured(self) -> bool:
|
||||
return self.model_family in [
|
||||
ModelFamily.llama3_1,
|
||||
ModelFamily.llama3_2,
|
||||
ModelFamily.llama3_3,
|
||||
ModelFamily.llama4,
|
||||
ModelFamily.safety,
|
||||
]
|
||||
|
||||
@property
|
||||
def max_seq_length(self) -> int:
|
||||
if self.model_family == ModelFamily.llama2:
|
||||
return 4096
|
||||
elif self.core_model_id == CoreModelId.llama_guard_2_8b:
|
||||
return 4096
|
||||
elif self.model_family == ModelFamily.llama3:
|
||||
return 8192
|
||||
elif self.model_family in [ModelFamily.llama3_1, ModelFamily.llama3_3]:
|
||||
return 131072
|
||||
elif self.model_family == ModelFamily.llama3_2:
|
||||
if self.quantization_format == CheckpointQuantizationFormat.int4:
|
||||
return 8192
|
||||
return 131072
|
||||
elif self.model_family == ModelFamily.llama4:
|
||||
if self.core_model_id in {
|
||||
CoreModelId.llama4_scout_17b_16e,
|
||||
CoreModelId.llama4_maverick_17b_128e,
|
||||
}:
|
||||
return 262144
|
||||
if self.core_model_id == CoreModelId.llama4_scout_17b_16e_instruct:
|
||||
return 10485760
|
||||
if self.core_model_id == CoreModelId.llama4_maverick_17b_128e_instruct:
|
||||
return 1048576
|
||||
elif self.core_model_id in [
|
||||
CoreModelId.llama_guard_3_8b,
|
||||
CoreModelId.llama_guard_3_11b_vision,
|
||||
CoreModelId.llama_guard_3_1b,
|
||||
]:
|
||||
return 131072
|
||||
else:
|
||||
raise ValueError(f"Unknown max_seq_len for {self.core_model_id}")
|
||||
class QuantizationMode(str, Enum):
|
||||
none = "none"
|
||||
fp8_mixed = "fp8_mixed"
|
||||
int4_mixed = "int4_mixed"
|
||||
|
|
|
@ -4,13 +4,6 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# top-level folder for each specific model found within the models/ directory at
|
||||
# the top-level of this source tree.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import Optional
|
|
@ -4,13 +4,6 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# top-level folder for each specific model found within the models/ directory at
|
||||
# the top-level of this source tree.
|
||||
|
||||
import io
|
||||
import json
|
||||
import uuid
|
||||
|
@ -19,7 +12,7 @@ from typing import Dict, List, Optional, Tuple
|
|||
|
||||
from PIL import Image as PIL_Image
|
||||
|
||||
from llama_stack.models.llama.datatypes import (
|
||||
from ..datatypes import (
|
||||
BuiltinTool,
|
||||
RawContent,
|
||||
RawMediaItem,
|
||||
|
@ -30,7 +23,6 @@ from llama_stack.models.llama.datatypes import (
|
|||
ToolCall,
|
||||
ToolPromptFormat,
|
||||
)
|
||||
|
||||
from .tokenizer import Tokenizer
|
||||
from .tool_utils import ToolUtils
|
||||
|
||||
|
|
366
llama_stack/models/llama/llama3/generation.py
Normal file
366
llama_stack/models/llama/llama3/generation.py
Normal file
|
@ -0,0 +1,366 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# top-level folder for each specific model found within the models/ directory at
|
||||
# the top-level of this source tree.
|
||||
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Callable, Generator, List, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from fairscale.nn.model_parallel.initialize import (
|
||||
initialize_model_parallel,
|
||||
model_parallel_is_initialized,
|
||||
)
|
||||
from termcolor import cprint
|
||||
|
||||
from ..checkpoint import maybe_reshard_state_dict
|
||||
from ..datatypes import GenerationResult, QuantizationMode, RawContent, RawMessage, ToolPromptFormat
|
||||
from .args import ModelArgs
|
||||
from .chat_format import ChatFormat, LLMInput
|
||||
from .model import Transformer
|
||||
from .multimodal.model import CrossAttentionTransformer
|
||||
from .tokenizer import Tokenizer
|
||||
|
||||
|
||||
class Llama3:
|
||||
@staticmethod
|
||||
def build(
|
||||
ckpt_dir: str,
|
||||
max_seq_len: int,
|
||||
max_batch_size: int,
|
||||
world_size: Optional[int] = None,
|
||||
quantization_mode: Optional[QuantizationMode] = None,
|
||||
seed: int = 1,
|
||||
device: str = "cuda",
|
||||
):
|
||||
device = torch.device(device)
|
||||
if (
|
||||
device.type == "cuda"
|
||||
and not torch.cuda.is_available()
|
||||
or device.type == "xpu"
|
||||
and not torch.xpu.is_available()
|
||||
):
|
||||
raise RuntimeError(f"PyTorch backend for {device.type} device type is not available")
|
||||
|
||||
if not torch.distributed.is_initialized():
|
||||
if device.type == "cuda":
|
||||
torch.distributed.init_process_group("nccl")
|
||||
else:
|
||||
torch.distributed.init_process_group("gloo")
|
||||
|
||||
if not model_parallel_is_initialized():
|
||||
if world_size is None:
|
||||
world_size = int(os.environ.get("WORLD_SIZE", 1))
|
||||
initialize_model_parallel(world_size)
|
||||
|
||||
local_rank = int(os.environ.get("LOCAL_RANK", 0))
|
||||
if device.type == "cuda":
|
||||
torch.cuda.set_device(local_rank)
|
||||
elif device.type == "xpu":
|
||||
torch.xpu.set_device(local_rank)
|
||||
|
||||
torch.manual_seed(seed)
|
||||
|
||||
if local_rank > 0:
|
||||
sys.stdout = open(os.devnull, "w")
|
||||
|
||||
start_time = time.time()
|
||||
|
||||
ckpt_paths = sorted(Path(ckpt_dir).glob("*.pth"))
|
||||
assert len(ckpt_paths) > 0, f"no checkpoint files found in {ckpt_dir}"
|
||||
print(f"Loading a checkpoint (shards={len(ckpt_paths)}, current-mp-size={world_size})")
|
||||
with open(Path(ckpt_dir) / "params.json", "r") as f:
|
||||
params = json.loads(f.read())
|
||||
|
||||
model_args: ModelArgs = ModelArgs(
|
||||
max_seq_len=max_seq_len,
|
||||
max_batch_size=max_batch_size,
|
||||
**params,
|
||||
)
|
||||
tokenizer = Tokenizer.get_instance()
|
||||
|
||||
state_dict = maybe_reshard_state_dict(
|
||||
ckpt_paths,
|
||||
n_kv_heads=model_args.n_kv_heads if model_args.n_kv_heads else model_args.n_heads,
|
||||
)
|
||||
|
||||
assert model_args.vocab_size == tokenizer.n_words
|
||||
|
||||
def build_model():
|
||||
if model_args.vision_chunk_size > 0:
|
||||
model = CrossAttentionTransformer(model_args)
|
||||
model.setup_cache(model_args.max_batch_size, device=device, dtype=torch.get_default_dtype())
|
||||
else:
|
||||
model = Transformer(model_args)
|
||||
return model
|
||||
|
||||
if quantization_mode == QuantizationMode.fp8_mixed or quantization_mode == QuantizationMode.int4_mixed:
|
||||
from .quantization.loader import convert_to_quantized_model
|
||||
|
||||
torch.set_default_tensor_type(torch.BFloat16Tensor)
|
||||
model = build_model()
|
||||
print("Loading state dict...")
|
||||
model.load_state_dict(state_dict, strict=False)
|
||||
print("Done...")
|
||||
model = convert_to_quantized_model(model, ckpt_dir, quantization_mode, device=device)
|
||||
torch.set_default_device(device)
|
||||
else:
|
||||
print(f"Setting default device to {device}")
|
||||
if device.type == "cuda":
|
||||
if torch.cuda.is_bf16_supported():
|
||||
torch.set_default_tensor_type(torch.cuda.BFloat16Tensor)
|
||||
else:
|
||||
torch.set_default_tensor_type(torch.cuda.Float16Tensor)
|
||||
elif device.type == "xpu":
|
||||
if torch.xpu.is_bf16_supported():
|
||||
torch.set_default_tensor_type(torch.xpu.BFloat16Tensor)
|
||||
else:
|
||||
torch.set_default_tensor_type(torch.xpu.Float16Tensor)
|
||||
|
||||
model = build_model()
|
||||
print("Loading state dict...")
|
||||
model.load_state_dict(state_dict, strict=True)
|
||||
model.to(device)
|
||||
print("Done...")
|
||||
|
||||
print(f"Loaded in {time.time() - start_time:.2f} seconds")
|
||||
|
||||
return Llama3(model, tokenizer, model_args)
|
||||
|
||||
def __init__(self, model: Transformer | CrossAttentionTransformer, tokenizer: Tokenizer, args: ModelArgs):
|
||||
self.args = args
|
||||
self.model = model
|
||||
self.tokenizer = tokenizer
|
||||
self.formatter = ChatFormat(tokenizer)
|
||||
|
||||
@torch.inference_mode()
|
||||
def generate(
|
||||
self,
|
||||
model_inputs: List[LLMInput],
|
||||
temperature: float = 0.6,
|
||||
top_p: float = 0.9,
|
||||
max_gen_len: Optional[int] = None,
|
||||
logprobs: bool = False,
|
||||
echo: bool = False,
|
||||
print_model_input: bool = False,
|
||||
logits_processor: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
|
||||
) -> Generator[List[GenerationResult], None, None]:
|
||||
if max_gen_len is None or max_gen_len == 0 or max_gen_len >= self.args.max_seq_len:
|
||||
max_gen_len = self.args.max_seq_len - 1
|
||||
params = self.model.params
|
||||
|
||||
print_model_input = print_model_input or os.environ.get("LLAMA_MODELS_DEBUG", "0") == "1"
|
||||
if print_model_input:
|
||||
for inp in model_inputs:
|
||||
tokens_to_print = [self.formatter.vision_token if t == 128256 else t for t in inp.tokens]
|
||||
cprint(
|
||||
"Input to model:\n" + self.tokenizer.decode(tokens_to_print) + "\n",
|
||||
"red",
|
||||
)
|
||||
prompt_tokens = [inp.tokens for inp in model_inputs]
|
||||
|
||||
bsz = len(model_inputs)
|
||||
assert bsz <= params.max_batch_size, (bsz, params.max_batch_size)
|
||||
|
||||
min_prompt_len = min(len(t) for t in prompt_tokens)
|
||||
max_prompt_len = max(len(t) for t in prompt_tokens)
|
||||
|
||||
if max_prompt_len >= params.max_seq_len:
|
||||
cprint(f"Out of token budget {max_prompt_len} vs {params.max_seq_len}", "red")
|
||||
return
|
||||
|
||||
total_len = min(max_gen_len + max_prompt_len, params.max_seq_len)
|
||||
|
||||
pad_id = self.tokenizer.pad_id
|
||||
tokens = torch.full((bsz, total_len), pad_id, dtype=torch.long)
|
||||
for k, t in enumerate(prompt_tokens):
|
||||
tokens[k, : len(t)] = torch.tensor(t, dtype=torch.long)
|
||||
if logprobs:
|
||||
token_logprobs = torch.zeros_like(tokens, dtype=torch.float)
|
||||
|
||||
is_vision = not isinstance(self.model, Transformer)
|
||||
if is_vision:
|
||||
images = [inp.vision.images if inp.vision is not None else [] for inp in model_inputs]
|
||||
mask = [inp.vision.mask if inp.vision is not None else [] for inp in model_inputs]
|
||||
|
||||
xattn_caches, cross_attention_masks, full_text_row_masked_out_mask = self.model.compute_vision_tokens_masks(
|
||||
batch_images=images,
|
||||
batch_masks=mask,
|
||||
total_len=total_len,
|
||||
device=tokens.device,
|
||||
)
|
||||
|
||||
eos_reached = torch.tensor([False] * bsz)
|
||||
input_text_mask = tokens != pad_id
|
||||
|
||||
if echo:
|
||||
for i in range(max_prompt_len):
|
||||
results = []
|
||||
for j, t in enumerate(tokens[:, i]):
|
||||
results.append(
|
||||
GenerationResult(
|
||||
token=t.item(),
|
||||
text=self.tokenizer.decode([t.item()]),
|
||||
source="input",
|
||||
logprobs=(token_logprobs[j, i : i + 1].tolist() if logprobs else None),
|
||||
batch_idx=j,
|
||||
finished=False,
|
||||
ignore_token=t.item() == pad_id,
|
||||
)
|
||||
)
|
||||
yield results
|
||||
|
||||
stop_tokens = torch.tensor(self.tokenizer.stop_tokens)
|
||||
|
||||
prev_pos = 0
|
||||
for cur_pos in range(min_prompt_len, total_len):
|
||||
if is_vision:
|
||||
position_ids = torch.arange(prev_pos, cur_pos, dtype=torch.long)
|
||||
text_only_inference = all(inp.vision is None for inp in model_inputs)
|
||||
logits = self.model.forward(
|
||||
position_ids,
|
||||
tokens,
|
||||
cross_attention_masks,
|
||||
full_text_row_masked_out_mask,
|
||||
xattn_caches,
|
||||
text_only_inference,
|
||||
)
|
||||
else:
|
||||
logits = self.model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
|
||||
|
||||
if logits_processor is not None:
|
||||
logits = logits_processor(tokens[:, :cur_pos], logits)
|
||||
|
||||
if temperature > 0:
|
||||
probs = torch.softmax(logits[:, -1] / temperature, dim=-1)
|
||||
next_token = sample_top_p(probs, top_p)
|
||||
else:
|
||||
next_token = torch.argmax(logits[:, -1], dim=-1)
|
||||
|
||||
next_token = next_token.reshape(-1)
|
||||
# only replace token if prompt has already been generated
|
||||
next_token = torch.where(input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token)
|
||||
tokens[:, cur_pos] = next_token
|
||||
|
||||
target = tokens[:, prev_pos + 1 : cur_pos + 1]
|
||||
if is_vision:
|
||||
# the logits space (num_classes) is designed to never contain a media_token
|
||||
# however our input token stream does contain them. we need to nuke them here
|
||||
# or else the CUDA kernels will crash with an illegal memory access
|
||||
vision_tokens = [self.tokenizer.special_tokens["<|image|>"], 128256]
|
||||
masks = [target.eq(t) for t in vision_tokens]
|
||||
if len(masks) > 1:
|
||||
mask = torch.logical_or(*masks)
|
||||
else:
|
||||
mask = masks[0]
|
||||
target[mask] = 0
|
||||
|
||||
if logprobs:
|
||||
token_logprobs[:, prev_pos + 1 : cur_pos + 1] = -F.cross_entropy(
|
||||
input=logits.transpose(1, 2),
|
||||
target=target,
|
||||
reduction="none",
|
||||
ignore_index=pad_id,
|
||||
)
|
||||
eos_reached |= (~input_text_mask[:, cur_pos]) & (torch.isin(next_token, stop_tokens))
|
||||
results = []
|
||||
for idx, t in enumerate(next_token):
|
||||
results.append(
|
||||
GenerationResult(
|
||||
token=t.item(),
|
||||
text=self.tokenizer.decode([t.item()]),
|
||||
source="output",
|
||||
logprobs=(token_logprobs[idx, cur_pos : cur_pos + 1].tolist() if logprobs else None),
|
||||
batch_idx=idx,
|
||||
finished=eos_reached[idx],
|
||||
ignore_token=cur_pos < len(prompt_tokens[idx]),
|
||||
)
|
||||
)
|
||||
yield results
|
||||
|
||||
prev_pos = cur_pos
|
||||
if all(eos_reached):
|
||||
break
|
||||
|
||||
def completion(
|
||||
self,
|
||||
contents: List[RawContent],
|
||||
temperature: float = 0.6,
|
||||
top_p: float = 0.9,
|
||||
max_gen_len: Optional[int] = None,
|
||||
logprobs: bool = False,
|
||||
echo: bool = False,
|
||||
) -> Generator[List[GenerationResult], None, None]:
|
||||
model_inputs = [self.formatter.encode_content(c) for c in contents]
|
||||
for result in self.generate(
|
||||
model_inputs=model_inputs,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
max_gen_len=max_gen_len,
|
||||
logprobs=logprobs,
|
||||
echo=echo,
|
||||
):
|
||||
yield result
|
||||
if all(r.finished for r in result):
|
||||
break
|
||||
|
||||
def chat_completion(
|
||||
self,
|
||||
messages_batch: List[List[RawMessage]],
|
||||
temperature: float = 0.6,
|
||||
top_p: float = 0.9,
|
||||
max_gen_len: Optional[int] = None,
|
||||
logprobs: bool = False,
|
||||
tool_prompt_format: ToolPromptFormat = ToolPromptFormat.json,
|
||||
echo: bool = False,
|
||||
) -> Generator[List[GenerationResult], None, None]:
|
||||
model_inputs = [self.formatter.encode_dialog_prompt(messages) for messages in messages_batch]
|
||||
for result in self.generate(
|
||||
model_inputs=model_inputs,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
max_gen_len=max_gen_len,
|
||||
logprobs=logprobs,
|
||||
echo=echo,
|
||||
):
|
||||
yield result
|
||||
if all(r.finished for r in result):
|
||||
break
|
||||
|
||||
|
||||
def sample_top_p(probs, p):
|
||||
"""
|
||||
Perform top-p (nucleus) sampling on a probability distribution.
|
||||
|
||||
Args:
|
||||
probs (torch.Tensor): Probability distribution tensor.
|
||||
p (float): Probability threshold for top-p sampling.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Sampled token indices.
|
||||
|
||||
Note:
|
||||
Top-p sampling selects the smallest set of tokens whose cumulative probability mass
|
||||
exceeds the threshold p. The distribution is renormalized based on the selected tokens.
|
||||
"""
|
||||
probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
|
||||
probs_sum = torch.cumsum(probs_sort, dim=-1)
|
||||
mask = probs_sum - probs_sort > p
|
||||
probs_sort[mask] = 0.0
|
||||
probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
|
||||
next_token = torch.multinomial(probs_sort, num_samples=1)
|
||||
next_token = torch.gather(probs_idx, -1, next_token)
|
||||
return next_token
|
|
@ -16,7 +16,7 @@ from typing import List, Optional
|
|||
|
||||
from termcolor import colored
|
||||
|
||||
from llama_stack.models.llama.datatypes import (
|
||||
from ..datatypes import (
|
||||
BuiltinTool,
|
||||
RawMessage,
|
||||
StopReason,
|
||||
|
@ -24,7 +24,6 @@ from llama_stack.models.llama.datatypes import (
|
|||
ToolDefinition,
|
||||
ToolPromptFormat,
|
||||
)
|
||||
|
||||
from . import template_data
|
||||
from .chat_format import ChatFormat
|
||||
from .prompt_templates import (
|
||||
|
|
|
@ -4,16 +4,6 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# top-level folder for each specific model found within the models/ directory at
|
||||
# the top-level of this source tree.
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
|
||||
|
||||
import math
|
||||
from typing import Optional, Tuple
|
||||
|
||||
|
@ -29,6 +19,10 @@ from torch import nn
|
|||
|
||||
from .args import ModelArgs
|
||||
|
||||
# **NOTE**: This code is not runnable without installing `torch` and `fairscale`
|
||||
# dependencies. These dependencies are not part of the default dependencies
|
||||
# (requirements.txt) of the `llama-models` package.
|
||||
|
||||
|
||||
class RMSNorm(torch.nn.Module):
|
||||
def __init__(self, dim: int, eps: float = 1e-6):
|
||||
|
@ -111,9 +105,9 @@ class Attention(nn.Module):
|
|||
def __init__(self, args: ModelArgs):
|
||||
super().__init__()
|
||||
self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
|
||||
model_parallel_size = fs_init.get_model_parallel_world_size()
|
||||
self.n_local_heads = args.n_heads // model_parallel_size
|
||||
self.n_local_kv_heads = self.n_kv_heads // model_parallel_size
|
||||
world_size = fs_init.get_model_parallel_world_size()
|
||||
self.n_local_heads = args.n_heads // world_size
|
||||
self.n_local_kv_heads = self.n_kv_heads // world_size
|
||||
self.n_rep = self.n_local_heads // self.n_local_kv_heads
|
||||
self.head_dim = args.dim // args.n_heads
|
||||
|
|
@ -4,16 +4,6 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# top-level folder for each specific model found within the models/ directory at
|
||||
# the top-level of this source tree.
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.
|
||||
|
||||
import logging
|
||||
import math
|
||||
from functools import partial
|
||||
|
@ -180,14 +170,14 @@ class ImageAttention(nn.Module):
|
|||
n_heads,
|
||||
):
|
||||
super().__init__()
|
||||
model_parallel_size = fs_init.get_model_parallel_world_size()
|
||||
world_size = fs_init.get_model_parallel_world_size()
|
||||
qkvo_replication = 1
|
||||
if model_parallel_size > 16:
|
||||
qkvo_replication = model_parallel_size // 8
|
||||
if world_size > 16:
|
||||
qkvo_replication = world_size // 8
|
||||
|
||||
self.n_kv_heads = n_heads
|
||||
self.n_local_heads = n_heads * qkvo_replication // model_parallel_size
|
||||
self.n_local_kv_heads = self.n_kv_heads * qkvo_replication // model_parallel_size
|
||||
self.n_local_heads = n_heads * qkvo_replication // world_size
|
||||
self.n_local_kv_heads = self.n_kv_heads * qkvo_replication // world_size
|
||||
self.n_rep = self.n_local_heads // self.n_local_kv_heads
|
||||
self.head_dim = dim // n_heads
|
||||
|
||||
|
@ -536,16 +526,16 @@ class Attention(nn.Module):
|
|||
cache_v (torch.Tensor): Cached values for attention.
|
||||
"""
|
||||
super().__init__()
|
||||
model_parallel_size = fs_init.get_model_parallel_world_size()
|
||||
world_size = fs_init.get_model_parallel_world_size()
|
||||
replication_factor = 1
|
||||
if model_parallel_size > 8:
|
||||
replication_factor = model_parallel_size // MP_SCALE
|
||||
if world_size > 8:
|
||||
replication_factor = world_size // MP_SCALE
|
||||
|
||||
self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
|
||||
self.n_kv_heads *= replication_factor
|
||||
|
||||
self.n_local_heads = args.n_heads // model_parallel_size
|
||||
self.n_local_kv_heads = self.n_kv_heads // model_parallel_size
|
||||
self.n_local_heads = args.n_heads // world_size
|
||||
self.n_local_kv_heads = self.n_kv_heads // world_size
|
||||
self.n_rep = self.n_local_heads // self.n_local_kv_heads
|
||||
self.head_dim = args.dim // args.n_heads
|
||||
self.max_seq_len = args.max_seq_len
|
||||
|
@ -587,13 +577,11 @@ class Attention(nn.Module):
|
|||
self.n_local_kv_heads,
|
||||
self.head_dim,
|
||||
)
|
||||
device = next(self.parameters()).device
|
||||
self.register_buffer(
|
||||
"key_cache",
|
||||
torch.zeros(
|
||||
cache_shape,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
),
|
||||
persistent=False,
|
||||
)
|
||||
|
@ -602,7 +590,6 @@ class Attention(nn.Module):
|
|||
torch.zeros(
|
||||
cache_shape,
|
||||
dtype=dtype,
|
||||
device=device,
|
||||
),
|
||||
persistent=False,
|
||||
)
|
||||
|
@ -614,6 +601,9 @@ class Attention(nn.Module):
|
|||
freqs_cis: torch.Tensor,
|
||||
position_ids: torch.LongTensor,
|
||||
):
|
||||
self.key_cache = self.key_cache.to(x.device)
|
||||
self.value_cache = self.value_cache.to(x.device)
|
||||
|
||||
xq, xk, xv = [F.linear(x, w) for w in [self.wq.weight, self.wk.weight, self.wv.weight]]
|
||||
|
||||
bs, slen, _ = xq.shape
|
||||
|
@ -832,10 +822,10 @@ class CrossAttention(torch.nn.Module):
|
|||
norm_eps: float,
|
||||
):
|
||||
super().__init__()
|
||||
self.model_parallel_size = fs_init.get_model_parallel_world_size()
|
||||
self.world_size = fs_init.get_model_parallel_world_size()
|
||||
replication_factor = 1
|
||||
if self.model_parallel_size > 8:
|
||||
replication_factor = self.model_parallel_size // MP_SCALE
|
||||
if self.world_size > 8:
|
||||
replication_factor = self.world_size // MP_SCALE
|
||||
n_kv_heads *= replication_factor
|
||||
|
||||
assert n_heads % n_kv_heads == 0
|
||||
|
@ -889,10 +879,10 @@ class CrossAttention(torch.nn.Module):
|
|||
# trunk LLM (i.e., group query attention) -- @dubeya
|
||||
# local heads
|
||||
assert self.n_heads % self.n_kv_heads == 0
|
||||
assert self.n_heads % self.model_parallel_size == 0
|
||||
assert self.n_kv_heads % self.model_parallel_size == 0
|
||||
self.n_local_heads = self.n_heads // self.model_parallel_size
|
||||
self.n_local_kv_heads = self.n_kv_heads // self.model_parallel_size
|
||||
assert self.n_heads % self.world_size == 0
|
||||
assert self.n_kv_heads % self.world_size == 0
|
||||
self.n_local_heads = self.n_heads // self.world_size
|
||||
self.n_local_kv_heads = self.n_kv_heads // self.world_size
|
||||
self.n_rep = self.n_local_heads // self.n_local_kv_heads
|
||||
|
||||
def _compute_xattn_kv_cache(self, xattn_tokens: torch.Tensor) -> torch.Tensor:
|
||||
|
@ -1041,7 +1031,7 @@ class CrossAttentionTransformerVision(torch.nn.Module):
|
|||
self.image_res = args.vision_chunk_size
|
||||
self.max_num_chunks = args.vision_max_num_chunks
|
||||
if return_intermediate is not None:
|
||||
return_intermediate = [int(level) for level in return_intermediate.split(",")]
|
||||
return_intermediate = [int(layer) for layer in return_intermediate.split(",")]
|
||||
self.vision_input_dim = (len(return_intermediate) + 1) * self.vision_input_dim
|
||||
self.patch_size = 14
|
||||
self.vision_encoder = VisionEncoder(
|
||||
|
@ -1076,15 +1066,15 @@ class CrossAttentionTransformerText(torch.nn.Module):
|
|||
|
||||
def __init__(self, args: ModelArgs) -> None:
|
||||
super().__init__()
|
||||
self.model_parallel_size = fs_init.get_model_parallel_world_size()
|
||||
self.world_size = fs_init.get_model_parallel_world_size()
|
||||
assert args.vocab_size > 0
|
||||
self.vocab_size = args.vocab_size
|
||||
self.n_layers = args.n_layers
|
||||
self.dim = args.dim
|
||||
self.head_dim = args.dim // args.n_heads
|
||||
self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
|
||||
self.n_local_kv_heads = self.n_kv_heads // self.model_parallel_size
|
||||
assert self.vocab_size % self.model_parallel_size == 0
|
||||
self.n_local_kv_heads = self.n_kv_heads // self.world_size
|
||||
assert self.vocab_size % self.world_size == 0
|
||||
self.tok_embeddings = VocabParallelEmbedding(args.vocab_size, args.dim, init_method=lambda x: x)
|
||||
self.pos_embeddings = None
|
||||
# final norm layer (not necessary for post-norm)
|
||||
|
@ -1184,6 +1174,8 @@ class CrossAttentionTransformerText(torch.nn.Module):
|
|||
text_only_inference: bool = False,
|
||||
):
|
||||
assert self.cache_is_setup, "Please set up cache before calling forward"
|
||||
self.mask_cache = self.mask_cache.to(h.device)
|
||||
self.freqs_cis = self.freqs_cis.to(h.device)
|
||||
mask = self.mask_cache.index_select(2, position_ids)
|
||||
freqs_cis = self.freqs_cis.index_select(0, position_ids)
|
||||
|
||||
|
@ -1212,9 +1204,8 @@ class CrossAttentionTransformerText(torch.nn.Module):
|
|||
output = gather_from_tensor_model_parallel_region(output)
|
||||
return output.float()
|
||||
|
||||
def setup_cache(self, max_batch_size: int, dtype=torch.bfloat16):
|
||||
def setup_cache(self, max_batch_size: int, device: torch.device, dtype=torch.bfloat16):
|
||||
# Set up the text kv caches
|
||||
device = next(self.parameters()).device
|
||||
ones = torch.ones(
|
||||
(self.max_seq_len, self.max_seq_len),
|
||||
dtype=torch.bool,
|
||||
|
@ -1265,7 +1256,7 @@ class CrossAttentionTransformerText(torch.nn.Module):
|
|||
|
||||
return (
|
||||
cross_attention_masks.to(device=text_device, dtype=text_dtype),
|
||||
full_text_row_masked_out_mask,
|
||||
full_text_row_masked_out_mask.to(device=text_device),
|
||||
)
|
||||
|
||||
|
||||
|
@ -1284,14 +1275,15 @@ class CrossAttentionTransformer(torch.nn.Module):
|
|||
max_num_chunks=args.vision_max_num_chunks,
|
||||
)
|
||||
|
||||
def setup_cache(self, max_batch_size: int, dtype: torch.dtype):
|
||||
self.text_model.setup_cache(max_batch_size, dtype)
|
||||
def setup_cache(self, max_batch_size: int, device: torch.device, dtype: torch.dtype):
|
||||
self.text_model.setup_cache(max_batch_size, device, dtype)
|
||||
|
||||
def compute_vision_tokens_masks(
|
||||
self,
|
||||
batch_images: List[List[PIL_Image.Image]],
|
||||
batch_masks: List[List[List[int]]],
|
||||
total_len: int,
|
||||
device: torch.device,
|
||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||
skip_vision_encoder = False
|
||||
|
||||
|
@ -1318,6 +1310,7 @@ class CrossAttentionTransformer(torch.nn.Module):
|
|||
image_res=self.params.vision_chunk_size,
|
||||
max_num_images=max_num_images,
|
||||
)
|
||||
stacked_images = stacked_images.to(device=device)
|
||||
|
||||
if skip_vision_encoder:
|
||||
vision_tokens = torch.zeros(
|
||||
|
@ -1330,7 +1323,7 @@ class CrossAttentionTransformer(torch.nn.Module):
|
|||
),
|
||||
)
|
||||
else:
|
||||
vision_tokens = self.vision_model(stacked_images, aspect_ratios)
|
||||
vision_tokens = self.vision_model(stacked_images, aspect_ratios).to(device=device)
|
||||
|
||||
bsz, nimg, nchunk, ntok, image_token_dim = tuple(vision_tokens.shape)
|
||||
xattn_caches = torch.stack(
|
|
@ -15,7 +15,7 @@ import textwrap
|
|||
from datetime import datetime
|
||||
from typing import Any, List, Optional
|
||||
|
||||
from llama_stack.models.llama.datatypes import (
|
||||
from llama_stack.apis.inference import (
|
||||
BuiltinTool,
|
||||
ToolDefinition,
|
||||
ToolParamDefinition,
|
||||
|
@ -279,6 +279,10 @@ class PythonListCustomToolGenerator(PromptTemplateGeneratorBase): # noqa: N801
|
|||
{% endif -%}
|
||||
{%- endfor %}
|
||||
]
|
||||
|
||||
You can answer general questions or invoke tools when necessary.
|
||||
In addition to tool calls, you should also augment your responses by using the tool outputs.
|
||||
|
||||
"""
|
||||
)
|
||||
return PromptTemplate(
|
||||
|
|
5
llama_stack/models/llama/llama3/quantization/__init__.py
Normal file
5
llama_stack/models/llama/llama3/quantization/__init__.py
Normal file
|
@ -0,0 +1,5 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
|
@ -4,9 +4,6 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
|
||||
|
||||
# type: ignore
|
||||
import os
|
||||
from typing import Any, Dict, List, Optional, cast
|
||||
|
@ -18,22 +15,15 @@ from fairscale.nn.model_parallel.mappings import reduce_from_model_parallel_regi
|
|||
from torch import Tensor, nn
|
||||
from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear
|
||||
|
||||
from llama_stack.apis.inference import QuantizationType
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.models.llama.datatypes import CheckpointQuantizationFormat
|
||||
from llama_stack.models.llama.sku_list import resolve_model
|
||||
from llama_stack.providers.inline.inference.meta_reference.quantize_impls import (
|
||||
from ...datatypes import QuantizationMode
|
||||
from ...quantize_impls import (
|
||||
Fp8ScaledWeights,
|
||||
ffn_swiglu,
|
||||
load_fp8,
|
||||
quantize_fp8,
|
||||
)
|
||||
|
||||
from ...config import MetaReferenceQuantizedInferenceConfig
|
||||
from ..args import ModelArgs
|
||||
from ..model import Transformer, TransformerBlock
|
||||
|
||||
log = get_logger(__name__, category="quantization")
|
||||
from ..multimodal.model import CrossAttentionTransformer
|
||||
|
||||
|
||||
def swiglu_wrapper(
|
||||
|
@ -44,30 +34,34 @@ def swiglu_wrapper(
|
|||
return reduce_from_model_parallel_region(out)
|
||||
|
||||
|
||||
def convert_to_quantized_model(
|
||||
model: Transformer | CrossAttentionTransformer,
|
||||
checkpoint_dir: str,
|
||||
quantization_mode: Optional[str] = None,
|
||||
fp8_activation_scale_ub: Optional[float] = 1200.0,
|
||||
device: Optional[torch.device] = None,
|
||||
) -> Transformer | CrossAttentionTransformer:
|
||||
if quantization_mode == QuantizationMode.fp8_mixed:
|
||||
return convert_to_fp8_quantized_model(model, checkpoint_dir, fp8_activation_scale_ub, device)
|
||||
elif quantization_mode == QuantizationMode.int4_mixed:
|
||||
return convert_to_int4_quantized_model(model, checkpoint_dir, device)
|
||||
else:
|
||||
raise ValueError(f"Unsupported quantization mode: {quantization_mode}")
|
||||
|
||||
|
||||
def convert_to_fp8_quantized_model(
|
||||
model: Transformer,
|
||||
config: MetaReferenceQuantizedInferenceConfig,
|
||||
checkpoint_dir: str,
|
||||
fp8_activation_scale_ub: Optional[float] = 1200.0,
|
||||
device: Optional[torch.device] = None,
|
||||
) -> Transformer:
|
||||
if config.quantization.type == QuantizationType.bf16.value:
|
||||
return model
|
||||
|
||||
elif config.quantization.type != QuantizationType.fp8.value:
|
||||
raise ValueError("Only FP8 quantization is supported")
|
||||
|
||||
assert config.model is not None, "Model must be specified for quantized inference"
|
||||
llama_model = resolve_model(config.model)
|
||||
assert llama_model is not None, f"Model {config.model} not found"
|
||||
|
||||
# Move weights to GPU with quantization
|
||||
if llama_model.quantization_format == CheckpointQuantizationFormat.fp8_mixed.value:
|
||||
log.info("Loading fp8 scales...")
|
||||
fp8_scales_path = os.path.join(checkpoint_dir, f"fp8_scales_{get_model_parallel_rank()}.pt")
|
||||
assert os.path.isfile(fp8_scales_path), f"fp8_scales_path not found for rank {get_model_parallel_rank()}"
|
||||
fp8_scales_path = os.path.join(checkpoint_dir, f"fp8_scales_{get_model_parallel_rank()}.pt")
|
||||
if os.path.isfile(fp8_scales_path):
|
||||
print("Loading fp8 scales...")
|
||||
fp8_scales = torch.load(fp8_scales_path, weights_only=True)
|
||||
|
||||
for block in model.layers:
|
||||
for _, block in model.named_modules():
|
||||
if isinstance(block, TransformerBlock):
|
||||
if block.layer_id == 0 or block.layer_id == (model.n_layers - 1):
|
||||
continue
|
||||
|
@ -81,8 +75,8 @@ def convert_to_fp8_quantized_model(
|
|||
fp8_activation_scale_ub,
|
||||
)
|
||||
else:
|
||||
log.info("Quantizing fp8 weights from bf16...")
|
||||
for block in model.layers:
|
||||
print("Quantizing fp8 weights from bf16...")
|
||||
for _, block in model.named_modules():
|
||||
if isinstance(block, TransformerBlock):
|
||||
if block.layer_id == 0 or block.layer_id == (model.n_layers - 1):
|
||||
continue
|
||||
|
@ -92,12 +86,12 @@ def convert_to_fp8_quantized_model(
|
|||
param.weight = quantize_fp8(
|
||||
param.weight,
|
||||
fp8_activation_scale_ub,
|
||||
output_device=torch.device("cuda"),
|
||||
output_device=device,
|
||||
)
|
||||
|
||||
for _, parameter in model.named_parameters():
|
||||
if not isinstance(parameter, Fp8ScaledWeights):
|
||||
parameter.data = parameter.to(device="cuda")
|
||||
parameter.data = parameter.to(device=device)
|
||||
return model
|
||||
|
||||
|
||||
|
@ -290,12 +284,12 @@ def _prepare_model_int4_weight_int8_dynamic_activation(
|
|||
|
||||
|
||||
def convert_to_int4_quantized_model(
|
||||
model: Transformer,
|
||||
model_args: ModelArgs,
|
||||
config: MetaReferenceQuantizedInferenceConfig,
|
||||
) -> Transformer:
|
||||
model: Transformer | CrossAttentionTransformer,
|
||||
checkpoint_dir: str,
|
||||
device: Optional[torch.device] = None,
|
||||
) -> Transformer | CrossAttentionTransformer:
|
||||
"""Convert the model to int4 quantized model."""
|
||||
|
||||
model_args = model.params
|
||||
assert model_args.quantization_args is not None, "Quantization args must be specified."
|
||||
quantization_args = model_args.quantization_args
|
||||
if quantization_args.scheme is None:
|
||||
|
@ -319,5 +313,4 @@ def convert_to_int4_quantized_model(
|
|||
lora_scale = model_args.lora_args.scale
|
||||
|
||||
_prepare_model_int4_weight_int8_dynamic_activation(model, group_size, lora_rank, lora_scale)
|
||||
device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
|
||||
return cast(Transformer, model.to(device))
|
||||
return cast(Transformer | CrossAttentionTransformer, model.to(device=device))
|
|
@ -12,8 +12,7 @@
|
|||
# the top-level of this source tree.
|
||||
|
||||
|
||||
from llama_stack.models.llama.datatypes import BuiltinTool, StopReason, ToolCall
|
||||
|
||||
from ..datatypes import BuiltinTool, StopReason, ToolCall
|
||||
from .prompt_templates import (
|
||||
BuiltinToolGenerator,
|
||||
JsonCustomToolGenerator,
|
||||
|
|
|
@ -4,16 +4,6 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# top-level folder for each specific model found within the models/ directory at
|
||||
# the top-level of this source tree.
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
|
||||
|
||||
import os
|
||||
from logging import getLogger
|
||||
from pathlib import Path
|
||||
|
|
|
@ -16,7 +16,8 @@ import re
|
|||
from typing import Optional, Tuple
|
||||
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.models.llama.datatypes import BuiltinTool, RecursiveType, ToolCall, ToolPromptFormat
|
||||
|
||||
from ..datatypes import BuiltinTool, RecursiveType, ToolCall, ToolPromptFormat
|
||||
|
||||
logger = get_logger(name=__name__, category="inference")
|
||||
|
||||
|
|
|
@ -3,10 +3,3 @@
|
|||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# top-level folder for each specific model found within the models/ directory at
|
||||
# the top-level of this source tree.
|
||||
|
|
|
@ -4,12 +4,6 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# top-level folder for each specific model found within the models/ directory at
|
||||
# the top-level of this source tree.
|
||||
import json
|
||||
import textwrap
|
||||
|
||||
|
|
|
@ -4,13 +4,6 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# top-level folder for each specific model found within the models/ directory at
|
||||
# the top-level of this source tree.
|
||||
|
||||
import textwrap
|
||||
from pathlib import Path
|
||||
|
||||
|
|
|
@ -4,13 +4,6 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# top-level folder for each specific model found within the models/ directory at
|
||||
# the top-level of this source tree.
|
||||
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
|
|
@ -13,7 +13,7 @@ import torch
|
|||
from PIL import Image as PIL_Image
|
||||
|
||||
# TODO: either fork these or move them to the common package
|
||||
from llama_stack.models.llama.datatypes import (
|
||||
from ..datatypes import (
|
||||
BuiltinTool,
|
||||
RawContent,
|
||||
RawMediaItem,
|
||||
|
@ -24,16 +24,10 @@ from llama_stack.models.llama.datatypes import (
|
|||
ToolCall,
|
||||
ToolPromptFormat,
|
||||
)
|
||||
from llama_stack.models.llama.llama3.tool_utils import ToolUtils
|
||||
from llama_stack.providers.inline.inference.meta_reference.llama4.args import VisionArgs
|
||||
from llama_stack.providers.inline.inference.meta_reference.llama4.datatypes import (
|
||||
LLMInput,
|
||||
)
|
||||
from llama_stack.providers.inline.inference.meta_reference.llama4.preprocess import (
|
||||
ResizeNormalizeImageTransform,
|
||||
VariableSizeImageTransform,
|
||||
)
|
||||
|
||||
from ..llama3.tool_utils import ToolUtils
|
||||
from .args import VisionArgs
|
||||
from .datatypes import LLMInput
|
||||
from .preprocess import ResizeNormalizeImageTransform, VariableSizeImageTransform
|
||||
from .tokenizer import Tokenizer
|
||||
|
||||
|
||||
|
@ -54,7 +48,7 @@ class TransformedImage:
|
|||
aspect_ratio: Tuple[int, int]
|
||||
|
||||
|
||||
def convert_rgba_to_rgb(image: PIL_Image.Image, bg: Tuple[int, int, int] = (255, 255, 255)) -> PIL_Image.Image:
|
||||
def convert_image_to_rgb(image: PIL_Image.Image, bg: Tuple[int, int, int] = (255, 255, 255)) -> PIL_Image.Image:
|
||||
if image.mode == "RGBA":
|
||||
image.load() # for png.split()
|
||||
new_img = PIL_Image.new("RGB", image.size, bg)
|
||||
|
@ -171,7 +165,7 @@ class ChatFormat:
|
|||
|
||||
bytes_io = io.BytesIO(c.data) if isinstance(c.data, bytes) else c.data
|
||||
image = PIL_Image.open(bytes_io)
|
||||
image = convert_rgba_to_rgb(image)
|
||||
image = convert_image_to_rgb(image)
|
||||
image_tiles, ar = self.dynamic_image_transform(image, max_num_chunks=self.max_num_chunks)
|
||||
|
||||
if image_tiles.shape[0] > 1:
|
||||
|
@ -216,9 +210,12 @@ class ChatFormat:
|
|||
content = ToolUtils.encode_tool_call(t, tool_prompt_format)
|
||||
_process_content(content)
|
||||
|
||||
# Tool calls and Tool Response messages should be eom
|
||||
eom = False
|
||||
if message.role == "assistant":
|
||||
eom = message.stop_reason == StopReason.end_of_message
|
||||
eom = message.stop_reason == StopReason.end_of_message or message.tool_calls
|
||||
elif message.role == "tool":
|
||||
eom = True
|
||||
|
||||
tokens.append(self.tokenizer.special_tokens["<|eom|>" if eom else "<|eot|>"])
|
||||
return tokens, images
|
||||
|
|
|
@ -4,13 +4,6 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# top-level folder for each specific model found within the models/ directory at
|
||||
# the top-level of this source tree.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from typing import List, Optional, Union
|
||||
|
|
@ -10,40 +10,28 @@ import json
|
|||
import os
|
||||
import sys
|
||||
import time
|
||||
from enum import Enum
|
||||
from pathlib import Path
|
||||
from typing import Callable, Generator, List, Optional
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from fairscale.nn.model_parallel.initialize import (
|
||||
get_model_parallel_rank,
|
||||
initialize_model_parallel,
|
||||
model_parallel_is_initialized,
|
||||
)
|
||||
from termcolor import cprint
|
||||
|
||||
from llama_stack.models.llama.llama4.chat_format import (
|
||||
ChatFormat,
|
||||
RawContent,
|
||||
RawMessage,
|
||||
)
|
||||
from llama_stack.models.llama.llama4.tokenizer import Tokenizer
|
||||
|
||||
from ..common import TokenResult
|
||||
from ..checkpoint import maybe_reshard_state_dict
|
||||
from ..datatypes import GenerationResult, QuantizationMode
|
||||
from .args import ModelArgs
|
||||
from .chat_format import ChatFormat, RawContent, RawMessage
|
||||
from .datatypes import LLMInput, MaskedEmbedding, TransformerInput
|
||||
from .model import Transformer
|
||||
from .tokenizer import Tokenizer
|
||||
|
||||
torch.serialization.add_safe_globals([io.BytesIO, codecs.encode])
|
||||
|
||||
|
||||
class QuantizationMode(str, Enum):
|
||||
none = "none"
|
||||
fp8_mixed = "fp8_mixed"
|
||||
int4_mixed = "int4_mixed"
|
||||
|
||||
|
||||
class Llama4:
|
||||
@staticmethod
|
||||
def build(
|
||||
|
@ -51,7 +39,7 @@ class Llama4:
|
|||
max_seq_len: int,
|
||||
max_batch_size: int,
|
||||
world_size: Optional[int] = None,
|
||||
quantization_mode: Optional[str] = None,
|
||||
quantization_mode: Optional[QuantizationMode] = None,
|
||||
seed: int = 1,
|
||||
):
|
||||
if not torch.distributed.is_initialized():
|
||||
|
@ -72,11 +60,9 @@ class Llama4:
|
|||
|
||||
start_time = time.time()
|
||||
|
||||
checkpoints = sorted(Path(ckpt_dir).glob("*.pth"))
|
||||
assert len(checkpoints) > 0, f"no checkpoint files found in {ckpt_dir}"
|
||||
assert world_size == len(checkpoints), (
|
||||
f"Loading a checkpoint for MP={len(checkpoints)} but world size is {world_size}"
|
||||
)
|
||||
ckpt_paths = sorted(Path(ckpt_dir).glob("*.pth"))
|
||||
assert len(ckpt_paths) > 0, f"no checkpoint files found in {ckpt_dir}"
|
||||
print(f"Loading a checkpoint (shards={len(ckpt_paths)}, current-mp-size={world_size})")
|
||||
with open(Path(ckpt_dir) / "params.json", "r") as f:
|
||||
params = json.loads(f.read())
|
||||
|
||||
|
@ -93,10 +79,11 @@ class Llama4:
|
|||
assert model_args.vocab_size == tokenizer.n_words, f"{model_args.vocab_size=} vs. {tokenizer.n_words=} mismatch"
|
||||
print("Model args:\n", model_args.model_dump_json(indent=2))
|
||||
|
||||
ckpt_path = checkpoints[get_model_parallel_rank()]
|
||||
print(f"Loading checkpoint from {ckpt_dir}...")
|
||||
with open(ckpt_path, "rb") as f:
|
||||
checkpoint = torch.load(f, map_location="cpu", weights_only=True)
|
||||
state_dict = maybe_reshard_state_dict(
|
||||
ckpt_paths,
|
||||
n_kv_heads=model_args.n_kv_heads if model_args.n_kv_heads else model_args.n_heads,
|
||||
moe_num_experts=model_args.moe_args.num_experts,
|
||||
)
|
||||
print("Loaded checkpoint")
|
||||
if quantization_mode == QuantizationMode.fp8_mixed or quantization_mode == QuantizationMode.int4_mixed:
|
||||
from .quantization.loader import convert_to_quantized_model
|
||||
|
@ -104,9 +91,9 @@ class Llama4:
|
|||
torch.set_default_tensor_type(torch.BFloat16Tensor)
|
||||
model = Transformer(model_args)
|
||||
print("Loading state dict...")
|
||||
model.load_state_dict(checkpoint, strict=False)
|
||||
model.load_state_dict(state_dict, strict=False)
|
||||
print("Done...")
|
||||
model = convert_to_quantized_model(model, ckpt_dir)
|
||||
model = convert_to_quantized_model(model, ckpt_dir, quantization_mode)
|
||||
else:
|
||||
if torch.cuda.is_bf16_supported():
|
||||
torch.set_default_tensor_type(torch.cuda.BFloat16Tensor)
|
||||
|
@ -115,7 +102,7 @@ class Llama4:
|
|||
|
||||
model = Transformer(model_args)
|
||||
print("Loading state dict...")
|
||||
model.load_state_dict(checkpoint, strict=False)
|
||||
model.load_state_dict(state_dict, strict=False)
|
||||
print("Done...")
|
||||
print(f"Loaded in {time.time() - start_time:.2f} seconds")
|
||||
|
||||
|
@ -130,7 +117,7 @@ class Llama4:
|
|||
@torch.inference_mode()
|
||||
def generate(
|
||||
self,
|
||||
llm_input: LLMInput,
|
||||
llm_inputs: List[LLMInput],
|
||||
temperature: float = 0.6,
|
||||
top_p: float = 0.9,
|
||||
max_gen_len: Optional[int] = None,
|
||||
|
@ -138,22 +125,20 @@ class Llama4:
|
|||
echo: bool = False,
|
||||
print_model_input: bool = False,
|
||||
logits_processor: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
|
||||
) -> Generator:
|
||||
) -> Generator[List[GenerationResult], None, None]:
|
||||
if max_gen_len is None or max_gen_len == 0 or max_gen_len >= self.model.args.max_seq_len:
|
||||
max_gen_len = self.model.args.max_seq_len - 1
|
||||
|
||||
params = self.model.args
|
||||
|
||||
print_model_input = print_model_input or os.environ.get("LLAMA_MODELS_DEBUG", "0") == "1"
|
||||
if print_model_input and get_model_parallel_rank() == 0:
|
||||
tokens_to_print = list(llm_input.tokens)
|
||||
cprint(
|
||||
"Input to model:\n" + self.tokenizer.decode(tokens_to_print) + "\n",
|
||||
"red",
|
||||
)
|
||||
prompt_tokens = [llm_input.tokens]
|
||||
if print_model_input:
|
||||
cprint("Input to model:\n", "yellow")
|
||||
for inp in llm_inputs:
|
||||
cprint(self.tokenizer.decode(inp.tokens), "grey")
|
||||
prompt_tokens = [inp.tokens for inp in llm_inputs]
|
||||
|
||||
bsz = 1
|
||||
bsz = len(llm_inputs)
|
||||
assert bsz <= params.max_batch_size, (bsz, params.max_batch_size)
|
||||
|
||||
min_prompt_len = min(len(t) for t in prompt_tokens)
|
||||
|
@ -176,24 +161,33 @@ class Llama4:
|
|||
input_text_mask = tokens != pad_id
|
||||
|
||||
if echo:
|
||||
for i, t in enumerate(llm_input.tokens):
|
||||
yield TokenResult(
|
||||
token=t,
|
||||
text=self.tokenizer.decode([t]),
|
||||
logprobs=(token_logprobs[0, i : i + 1].tolist() if logprobs else None),
|
||||
)
|
||||
for i in range(max_prompt_len):
|
||||
results = []
|
||||
for j, t in enumerate(tokens[:, i]):
|
||||
results.append(
|
||||
GenerationResult(
|
||||
token=t.item(),
|
||||
text=self.tokenizer.decode([t.item()]),
|
||||
source="input",
|
||||
logprobs=(token_logprobs[j, i : i + 1].tolist() if logprobs else None),
|
||||
batch_idx=j,
|
||||
finished=False,
|
||||
ignore_token=t.item() == pad_id,
|
||||
)
|
||||
)
|
||||
yield results
|
||||
|
||||
stop_tokens = torch.tensor(self.tokenizer.stop_tokens, device="cuda")
|
||||
|
||||
prev_pos = 0
|
||||
for cur_pos in range(min_prompt_len, total_len):
|
||||
image_embedding = None
|
||||
if prev_pos == 0 and llm_input.images is not None and len(llm_input.images) > 0:
|
||||
if prev_pos == 0 and any(inp.images is not None and len(inp.images) > 0 for inp in llm_inputs):
|
||||
image_mask = tokens[:, prev_pos:cur_pos] == self.tokenizer.special_tokens["<|patch|>"]
|
||||
image_mask = image_mask.unsqueeze(-1)
|
||||
h = self.model.tok_embeddings(tokens[:, prev_pos:cur_pos])
|
||||
|
||||
image_batch = [llm_input.images]
|
||||
image_batch = [inp.images if inp.images is not None else [] for inp in llm_inputs]
|
||||
image_embedding = MaskedEmbedding(
|
||||
embedding=self.model.vision_embeddings(image_batch, image_mask, h),
|
||||
mask=image_mask,
|
||||
|
@ -229,11 +223,21 @@ class Llama4:
|
|||
ignore_index=pad_id,
|
||||
)
|
||||
eos_reached |= (~input_text_mask[:, cur_pos]) & (torch.isin(next_token, stop_tokens))
|
||||
yield TokenResult(
|
||||
token=next_token[0].item(),
|
||||
text=self.tokenizer.decode(next_token.tolist()),
|
||||
logprobs=(token_logprobs[:, cur_pos : cur_pos + 1][0].tolist() if logprobs else None),
|
||||
)
|
||||
|
||||
results = []
|
||||
for idx, t in enumerate(next_token):
|
||||
results.append(
|
||||
GenerationResult(
|
||||
token=t.item(),
|
||||
text=self.tokenizer.decode([t.item()]),
|
||||
source="output",
|
||||
logprobs=(token_logprobs[idx, cur_pos : cur_pos + 1].tolist() if logprobs else None),
|
||||
batch_idx=idx,
|
||||
finished=eos_reached[idx],
|
||||
ignore_token=cur_pos < len(prompt_tokens[idx]),
|
||||
)
|
||||
)
|
||||
yield results
|
||||
|
||||
prev_pos = cur_pos
|
||||
if all(eos_reached):
|
||||
|
@ -241,68 +245,47 @@ class Llama4:
|
|||
|
||||
def completion(
|
||||
self,
|
||||
content: RawContent,
|
||||
contents: List[RawContent],
|
||||
temperature: float = 0.6,
|
||||
top_p: float = 0.9,
|
||||
max_gen_len: Optional[int] = None,
|
||||
logprobs: bool = False,
|
||||
echo: bool = False,
|
||||
) -> Generator:
|
||||
llm_input = self.formatter.encode_content(content)
|
||||
) -> Generator[List[GenerationResult], None, None]:
|
||||
llm_inputs = [self.formatter.encode_content(c) for c in contents]
|
||||
for result in self.generate(
|
||||
llm_input=llm_input,
|
||||
llm_inputs=llm_inputs,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
max_gen_len=max_gen_len,
|
||||
logprobs=logprobs,
|
||||
echo=echo,
|
||||
):
|
||||
if result.token in self.tokenizer.stop_tokens:
|
||||
break
|
||||
yield result
|
||||
if all(r.finished for r in result):
|
||||
break
|
||||
|
||||
def chat_completion(
|
||||
self,
|
||||
messages: List[RawMessage],
|
||||
messages_batch: List[List[RawMessage]],
|
||||
temperature: float = 0.6,
|
||||
top_p: float = 0.9,
|
||||
max_gen_len: Optional[int] = None,
|
||||
logprobs: bool = False,
|
||||
echo: bool = False,
|
||||
) -> Generator:
|
||||
llm_input = self.formatter.encode_dialog_prompt(messages)
|
||||
) -> Generator[List[GenerationResult], None, None]:
|
||||
llm_inputs = [self.formatter.encode_dialog_prompt(messages) for messages in messages_batch]
|
||||
for result in self.generate(
|
||||
llm_input=llm_input,
|
||||
llm_inputs=llm_inputs,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
max_gen_len=max_gen_len,
|
||||
logprobs=logprobs,
|
||||
echo=echo,
|
||||
):
|
||||
if result.token in self.tokenizer.stop_tokens:
|
||||
break
|
||||
yield result
|
||||
|
||||
def chat_completion_raw(
|
||||
self,
|
||||
messages: List[RawMessage],
|
||||
temperature: float = 0.6,
|
||||
top_p: float = 0.9,
|
||||
max_gen_len: Optional[int] = None,
|
||||
logprobs: bool = False,
|
||||
):
|
||||
llm_input = self.formatter.encode_dialog_prompt(messages)
|
||||
output_tokens = []
|
||||
for result in self.generate(
|
||||
llm_input=llm_input,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
max_gen_len=max_gen_len,
|
||||
logprobs=logprobs,
|
||||
):
|
||||
output_tokens.append(result.token)
|
||||
|
||||
return llm_input.tokens, output_tokens
|
||||
if all(r.finished for r in result):
|
||||
break
|
||||
|
||||
|
||||
def sample_top_p(probs, p):
|
|
@ -4,16 +4,6 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# top-level folder for each specific model found within the models/ directory at
|
||||
# the top-level of this source tree.
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
|
||||
|
||||
import math
|
||||
from typing import Any, Dict, List, Optional, Tuple
|
||||
|
||||
|
@ -184,7 +174,6 @@ class Attention(nn.Module):
|
|||
self.head_dim,
|
||||
)
|
||||
).cuda()
|
||||
|
||||
self.qk_norm = None
|
||||
if self.use_qk_norm:
|
||||
self.qk_norm = L2Norm(args.norm_eps)
|
|
@ -100,31 +100,21 @@ class Experts(nn.Module):
|
|||
|
||||
class MoE(torch.nn.Module):
|
||||
"""
|
||||
This EC implementation is modified from the original EC module.
|
||||
We refactored the token permutation and unpermutation logic and added support to tp and dp2ep sharding.
|
||||
This module supports 3 sharding methods of the experts:
|
||||
- tp: each TP rank has n_experts experts. Experts are sharded following the conventional row/column-parallel TP sharding.
|
||||
- tp2ep: each TP rank has n_experts/tp experts. Experts are not sharded.
|
||||
- dp2ep: each EP rank has n_experts/ep experts. Experts are sharded following the row/column-parallel TP sharding.
|
||||
Tensors used in this module are annotated with the suffixes that indicate the shape of the tensor.
|
||||
Several commonly used annotations include:
|
||||
- a: bsz*slen
|
||||
- E: number of experts
|
||||
- e: number of local experts per ep (n_experts/ep)
|
||||
- et: number of local experts per tp (n_experts/tp)
|
||||
- D: hidden dimension
|
||||
- d: D/tp
|
||||
- F: model dimension
|
||||
- f: F/tp (used in column/row-parallel linear)
|
||||
- G: number of tokens per expert (a * capacity_factor / E)
|
||||
- g: number of tokens per expert per TP rank (i.e., G/TP)
|
||||
- GG: G*EP (number of tokens per expert received via inter-EP a2a when ag_along_first_dim=False)
|
||||
- gg: g*EP (number of tokens per expert received via inter-EP a2a when ag_along_first_dim=True)
|
||||
|
||||
Examples:
|
||||
x_aD [a, D]
|
||||
routed_in_etG_D [et*G, D]
|
||||
x_eGGD: [e, GG, D]
|
||||
x_eGD: [e, G, D]
|
||||
"""
|
||||
|
||||
def __init__(
|
||||
|
@ -207,13 +197,13 @@ class MoE(torch.nn.Module):
|
|||
routed_in_EG_D = routed_in_EG_D * router_scores.reshape(-1, 1)
|
||||
|
||||
out_aD = self.shared_expert(x_aD)
|
||||
routed_out_egg_D = self.experts(routed_in_EG_D.detach())
|
||||
routed_out_eg_D = self.experts(routed_in_EG_D.detach())
|
||||
|
||||
router_indices_EG_D = router_indices.reshape(-1, 1).expand(-1, D)
|
||||
out_aD.scatter_add_(
|
||||
dim=0,
|
||||
index=router_indices_EG_D,
|
||||
src=routed_out_egg_D.view(-1, D),
|
||||
src=routed_out_eg_D.view(-1, D),
|
||||
)
|
||||
out_aD = reduce_from_model_parallel_region(out_aD)
|
||||
return out_aD.view(-1, slen, D)
|
|
@ -4,20 +4,13 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# top-level folder for each specific model found within the models/ directory at
|
||||
# the top-level of this source tree.
|
||||
|
||||
import textwrap
|
||||
from io import BytesIO
|
||||
from pathlib import Path
|
||||
from typing import List
|
||||
|
||||
from llama_stack.models.llama.datatypes import RawMediaItem, RawMessage, RawTextItem
|
||||
from llama_stack.models.llama.prompt_format import (
|
||||
from ..datatypes import RawMediaItem, RawMessage, RawTextItem
|
||||
from ..prompt_format import (
|
||||
Llama4UseCase,
|
||||
TextCompletionContent,
|
||||
UseCase,
|
||||
|
|
5
llama_stack/models/llama/llama4/quantization/__init__.py
Normal file
5
llama_stack/models/llama/llama4/quantization/__init__.py
Normal file
|
@ -0,0 +1,5 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
|
@ -6,20 +6,29 @@
|
|||
|
||||
import logging
|
||||
import os
|
||||
from typing import Optional
|
||||
from typing import Callable, Optional
|
||||
|
||||
import torch
|
||||
from fairscale.nn.model_parallel.initialize import get_model_parallel_rank
|
||||
from torch import Tensor
|
||||
from torch import Tensor, nn
|
||||
from torch.nn import functional as F
|
||||
|
||||
from ..generation import QuantizationMode
|
||||
from ...datatypes import QuantizationMode
|
||||
from ..model import Transformer, TransformerBlock
|
||||
from ..moe import MoE
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def swiglu_wrapper_no_reduce(
|
||||
self,
|
||||
x: Tensor,
|
||||
):
|
||||
from ...quantize_impls import ffn_swiglu
|
||||
|
||||
return ffn_swiglu(x, self.w1.weight, self.w3.weight, self.w2.weight)
|
||||
|
||||
|
||||
def experts_batched_swiglu_wrapper(
|
||||
self,
|
||||
x: Tensor, # (e, g, D)
|
||||
|
@ -51,24 +60,30 @@ def convert_to_quantized_model(
|
|||
|
||||
rank = get_model_parallel_rank()
|
||||
|
||||
def should_quantize_block(block: nn.Module) -> bool:
|
||||
if not isinstance(block, TransformerBlock):
|
||||
return False
|
||||
|
||||
is_moe = isinstance(block.feed_forward, MoE)
|
||||
if quantization_mode == QuantizationMode.fp8_mixed:
|
||||
# skip quantization on first and last layers
|
||||
return is_moe and not (block.layer_id == 0 or block.layer_id == (model.n_layers - 1))
|
||||
|
||||
return is_moe
|
||||
|
||||
use_rich_progress = use_rich_progress and rank == 0
|
||||
progress, log_status, update_status = logging_callbacks(use_rich_progress, rank, model)
|
||||
progress, log_status, update_status = logging_callbacks(use_rich_progress, rank, model, should_quantize_block)
|
||||
if quantization_mode == QuantizationMode.int4_mixed:
|
||||
int4_scales_path = os.path.join(checkpoint_dir, f"int4_scales_{rank}.pt")
|
||||
int4_zero_points_path = os.path.join(checkpoint_dir, f"int4_zero_points_{rank}.pt")
|
||||
if os.path.isfile(int4_scales_path):
|
||||
log_status(f"Rank {rank}: Loading int4 scales")
|
||||
int4_scales = torch.load(int4_scales_path, weights_only=True)
|
||||
int4_zero_points = torch.load(int4_zero_points_path, weights_only=True)
|
||||
|
||||
def apply_quantization(key, weight):
|
||||
scale = int4_scales[key]
|
||||
zero_point = int4_zero_points[key]
|
||||
return load_int4(
|
||||
weight,
|
||||
scale,
|
||||
zero_point,
|
||||
fp8_activation_scale_ub,
|
||||
output_device=torch.device("cuda"),
|
||||
)
|
||||
|
||||
|
@ -77,6 +92,7 @@ def convert_to_quantized_model(
|
|||
|
||||
def apply_quantization(_, weight):
|
||||
return quantize_int4(weight, fp8_activation_scale_ub, output_device=torch.device("cuda"))
|
||||
|
||||
else:
|
||||
fp8_scales_path = os.path.join(checkpoint_dir, f"fp8_scales_{rank}.pt")
|
||||
if os.path.isfile(fp8_scales_path):
|
||||
|
@ -104,33 +120,38 @@ def convert_to_quantized_model(
|
|||
progress.start()
|
||||
|
||||
for _, block in model.named_modules():
|
||||
if isinstance(block, TransformerBlock):
|
||||
# Skip quantization on first and last layers
|
||||
if block.layer_id == 0 or block.layer_id == (model.n_layers - 1):
|
||||
continue
|
||||
if not should_quantize_block(block):
|
||||
continue
|
||||
|
||||
# Skip quantization on dense layers
|
||||
if not isinstance(block.feed_forward, MoE):
|
||||
continue
|
||||
update_status(f"Rank {rank} - Layer {block.layer_id}")
|
||||
|
||||
update_status(f"Rank {rank} - Layer {block.layer_id}")
|
||||
# Quantize only routed experts, not shared
|
||||
prefix = f"layers.{block.layer_id}.feed_forward"
|
||||
moe = block.feed_forward
|
||||
moe.experts.batched_swiglu = experts_batched_swiglu_wrapper.__get__(moe.experts)
|
||||
|
||||
# Quantize only routed experts, not shared
|
||||
prefix = f"layers.{block.layer_id}.feed_forward"
|
||||
moe = block.feed_forward
|
||||
moe.experts.batched_swiglu = experts_batched_swiglu_wrapper.__get__(moe.experts)
|
||||
for key in ("w1", "w3", "w2"):
|
||||
param = getattr(moe.experts, key)
|
||||
update_status(f"Rank {rank} - Layer {block.layer_id} - MoE {key}")
|
||||
setattr(
|
||||
moe.experts,
|
||||
key,
|
||||
apply_quantization(
|
||||
f"{prefix}.experts.{key}",
|
||||
param.transpose(1, 2).contiguous(),
|
||||
),
|
||||
)
|
||||
|
||||
if quantization_mode == QuantizationMode.int4_mixed:
|
||||
# Quantize shared experts
|
||||
moe.shared_expert.forward = swiglu_wrapper_no_reduce.__get__(moe.shared_expert)
|
||||
for key in ("w1", "w3", "w2"):
|
||||
param = getattr(moe.experts, key)
|
||||
update_status(f"Rank {rank} - Layer {block.layer_id} - MoE {key}")
|
||||
setattr(
|
||||
moe.experts,
|
||||
key,
|
||||
apply_quantization(f"{prefix}.experts.{key}", param.transpose(1, 2).contiguous()),
|
||||
)
|
||||
param = getattr(moe.shared_expert, key)
|
||||
update_status(f"Rank {rank} - Layer {block.layer_id} - MoE shared expert {key}")
|
||||
param.weight = apply_quantization(f"{prefix}.shared_expert.{key}", param.weight)
|
||||
|
||||
processed_blocks += 1
|
||||
update_status(message=None, completed=processed_blocks)
|
||||
processed_blocks += 1
|
||||
update_status(message=None, completed=processed_blocks)
|
||||
|
||||
update_status(f"Rank {rank} - Moving parameters to CUDA")
|
||||
|
||||
|
@ -149,7 +170,12 @@ def convert_to_quantized_model(
|
|||
|
||||
|
||||
# fp8/int4 loading can be very slow so we add progress bars to make life slightly better
|
||||
def logging_callbacks(use_rich_progress: bool, rank: int, model: Transformer):
|
||||
def logging_callbacks(
|
||||
use_rich_progress: bool,
|
||||
rank: int,
|
||||
model: Transformer,
|
||||
should_quantize_block: Callable[[nn.Module], bool],
|
||||
):
|
||||
console = None
|
||||
if use_rich_progress:
|
||||
from rich.console import Console
|
||||
|
@ -162,15 +188,7 @@ def logging_callbacks(use_rich_progress: bool, rank: int, model: Transformer):
|
|||
elif rank == 0: # Only log from rank 0 for non-rich logging
|
||||
log.info(message)
|
||||
|
||||
total_blocks = sum(
|
||||
1
|
||||
for _, block in model.named_modules()
|
||||
if (
|
||||
isinstance(block, TransformerBlock)
|
||||
and not (block.layer_id == 0 or block.layer_id == (model.n_layers - 1))
|
||||
and isinstance(block.feed_forward, MoE)
|
||||
)
|
||||
)
|
||||
total_blocks = sum(1 for _, block in model.named_modules() if should_quantize_block(block))
|
||||
progress = None
|
||||
if use_rich_progress:
|
||||
from rich.progress import (
|
|
@ -4,9 +4,6 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
|
||||
|
||||
import os
|
||||
from logging import getLogger
|
||||
from pathlib import Path
|
||||
|
@ -59,11 +56,11 @@ LLAMA4_TEXT_POST_TRAIN_SPECIAL_TOKENS = [
|
|||
"<|text_post_train_reserved_special_token_3|>",
|
||||
"<|text_post_train_reserved_special_token_4|>",
|
||||
"<|text_post_train_reserved_special_token_5|>",
|
||||
"<|python_start|>",
|
||||
"<|python_end|>",
|
||||
"<|text_post_train_reserved_special_token_6|>",
|
||||
"<|text_post_train_reserved_special_token_7|>",
|
||||
"<|finetune_right_pad|>",
|
||||
] + get_reserved_special_tokens(
|
||||
"text_post_train", 61, 6
|
||||
"text_post_train", 61, 8
|
||||
) # <|text_post_train_reserved_special_token_6|>, ..., <|text_post_train_reserved_special_token_66|>
|
||||
|
||||
# 200080, ..., 201133
|
||||
|
@ -85,8 +82,23 @@ LLAMA4_VISION_SPECIAL_TOKENS = [
|
|||
"vision", 1041, 7
|
||||
) # <|vision_reserved_special_token_7|>, ..., <|vision_reserved_special_token_1047|>
|
||||
|
||||
# 201134, ..., 201143
|
||||
LLAMA4_REASONING_SPECIAL_TOKENS = [
|
||||
"<|reasoning_reserved_special_token_0|>",
|
||||
"<|reasoning_reserved_special_token_1|>",
|
||||
"<|reasoning_reserved_special_token_2|>",
|
||||
"<|reasoning_reserved_special_token_3|>",
|
||||
"<|reasoning_reserved_special_token_4|>",
|
||||
"<|reasoning_reserved_special_token_5|>",
|
||||
"<|reasoning_reserved_special_token_6|>",
|
||||
"<|reasoning_reserved_special_token_7|>",
|
||||
"<|reasoning_thinking_start|>",
|
||||
"<|reasoning_thinking_end|>",
|
||||
]
|
||||
|
||||
LLAMA4_SPECIAL_TOKENS = LLAMA4_TEXT_POST_TRAIN_SPECIAL_TOKENS + LLAMA4_VISION_SPECIAL_TOKENS
|
||||
LLAMA4_SPECIAL_TOKENS = (
|
||||
LLAMA4_TEXT_POST_TRAIN_SPECIAL_TOKENS + LLAMA4_VISION_SPECIAL_TOKENS + LLAMA4_REASONING_SPECIAL_TOKENS
|
||||
)
|
||||
|
||||
BASIC_SPECIAL_TOKENS = [
|
||||
"<|begin_of_text|>",
|
||||
|
@ -155,6 +167,9 @@ class Tokenizer:
|
|||
self.eot_id: int = self.special_tokens["<|eot|>"]
|
||||
self.eom_id: int = self.special_tokens["<|eom|>"]
|
||||
|
||||
self.thinking_start_id: int = self.special_tokens["<|reasoning_thinking_start|>"]
|
||||
self.thinking_end_id: int = self.special_tokens["<|reasoning_thinking_end|>"]
|
||||
|
||||
self.stop_tokens = [
|
||||
self.eos_id,
|
||||
self.special_tokens["<|eom|>"],
|
||||
|
|
|
@ -4,13 +4,6 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# top-level folder for each specific model found within the models/ directory at
|
||||
# the top-level of this source tree.
|
||||
|
||||
import math
|
||||
from typing import Any, Callable, Dict, List
|
||||
|
|
@ -28,9 +28,6 @@ from llama_stack.models.llama.datatypes import (
|
|||
ToolPromptFormat,
|
||||
)
|
||||
from llama_stack.models.llama.llama4.tokenizer import Tokenizer
|
||||
from llama_stack.providers.inline.inference.meta_reference.llama4.datatypes import (
|
||||
LLMInput,
|
||||
)
|
||||
|
||||
from .llama3.interface import LLama31Interface
|
||||
from .llama3.template_data import (
|
||||
|
@ -76,21 +73,22 @@ class UseCase(BaseModel):
|
|||
text += dialog
|
||||
text += "\n\n"
|
||||
continue
|
||||
|
||||
elif isinstance(dialog, TextCompletionContent):
|
||||
input_tokens, output_tokens = generator.text_completion_raw(
|
||||
dialog.content,
|
||||
temperature=0.1,
|
||||
top_p=0.95,
|
||||
max_gen_len=64,
|
||||
)
|
||||
else:
|
||||
input_tokens, output_tokens = generator.chat_completion_raw(
|
||||
dialog,
|
||||
temperature=0.0,
|
||||
top_p=0.95,
|
||||
max_gen_len=self.max_gen_len,
|
||||
batch = [dialog]
|
||||
method = (
|
||||
generator.completion if isinstance(dialog, TextCompletionContent) else generator.chat_completion
|
||||
)
|
||||
input_tokens = []
|
||||
output_tokens = []
|
||||
for token_results in method(batch, echo=True, temperature=0.1, top_p=0.95):
|
||||
result = token_results[0]
|
||||
if result.source == "input":
|
||||
input_tokens.append(result.token)
|
||||
else:
|
||||
output_tokens.append(result.token)
|
||||
|
||||
if result.finished:
|
||||
break
|
||||
text += "##### Input Prompt Format\n"
|
||||
|
||||
# FIXME: This is added to undo the hack in chat_formatter where
|
||||
|
@ -126,27 +124,27 @@ class Llama4UseCase(UseCase):
|
|||
|
||||
text = ""
|
||||
tokenizer = Tokenizer.get_instance()
|
||||
temperature = 0.0
|
||||
for dialog in self.dialogs:
|
||||
if isinstance(dialog, str):
|
||||
text += dialog
|
||||
text += "\n\n"
|
||||
continue
|
||||
|
||||
elif isinstance(dialog, TextCompletionContent):
|
||||
# TODO pass the raw input and do the encoding in the text completion function
|
||||
input_tokens = tokenizer.encode(dialog.content, bos=True, eos=False)
|
||||
llm_input = LLMInput(tokens=input_tokens)
|
||||
output_tokens, decoded_tokens, token_logprobs = generator.text_completion_raw(
|
||||
llm_input, temperature=temperature, max_gen_len=self.max_gen_len
|
||||
)
|
||||
|
||||
else:
|
||||
input_tokens, output_tokens = generator.chat_completion_raw(
|
||||
dialog,
|
||||
temperature=temperature,
|
||||
max_gen_len=self.max_gen_len,
|
||||
batch = [dialog]
|
||||
method = (
|
||||
generator.completion if isinstance(dialog, TextCompletionContent) else generator.chat_completion
|
||||
)
|
||||
input_tokens = []
|
||||
output_tokens = []
|
||||
for token_results in method(batch, echo=True, temperature=0.0):
|
||||
result = token_results[0]
|
||||
if result.source == "input":
|
||||
input_tokens.append(result.token)
|
||||
else:
|
||||
output_tokens.append(result.token)
|
||||
|
||||
if result.finished:
|
||||
break
|
||||
|
||||
text += "##### Input Prompt Format\n"
|
||||
text += _code_block(tokenizer.decode(input_tokens))
|
||||
|
|
|
@ -4,24 +4,15 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# top-level folder for each specific model found within the models/ directory at
|
||||
# the top-level of this source tree.
|
||||
|
||||
from dataclasses import dataclass
|
||||
from functools import lru_cache
|
||||
from typing import List, Optional
|
||||
|
||||
from .datatypes import (
|
||||
from .sku_types import (
|
||||
CheckpointQuantizationFormat,
|
||||
CoreModelId,
|
||||
Model,
|
||||
ModelFamily,
|
||||
SamplingParams,
|
||||
TopPSamplingStrategy,
|
||||
)
|
||||
|
||||
LLAMA2_VOCAB_SIZE = 32000
|
||||
|
@ -47,15 +38,6 @@ def all_registered_models() -> List[Model]:
|
|||
)
|
||||
|
||||
|
||||
def recommended_sampling_params() -> SamplingParams:
|
||||
return SamplingParams(
|
||||
strategy=TopPSamplingStrategy(
|
||||
temperature=1.0,
|
||||
top_p=0.9,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
def llama2_family() -> List[Model]:
|
||||
return [
|
||||
*llama2_base_models(),
|
||||
|
@ -150,7 +132,6 @@ def llama2_base_models() -> List[Model]:
|
|||
core_model_id=CoreModelId.llama2_7b,
|
||||
description="Llama 2 7b model",
|
||||
huggingface_repo="meta-llama/Llama-2-7b",
|
||||
recommended_sampling_params=recommended_sampling_params(),
|
||||
arch_args={
|
||||
"dim": 4096,
|
||||
"n_layers": 32,
|
||||
|
@ -169,7 +150,6 @@ def llama2_base_models() -> List[Model]:
|
|||
core_model_id=CoreModelId.llama2_13b,
|
||||
description="Llama 2 13b model",
|
||||
huggingface_repo="meta-llama/Llama-2-13b",
|
||||
recommended_sampling_params=recommended_sampling_params(),
|
||||
arch_args={
|
||||
"dim": 5120,
|
||||
"n_layers": 40,
|
||||
|
@ -188,7 +168,6 @@ def llama2_base_models() -> List[Model]:
|
|||
core_model_id=CoreModelId.llama2_70b,
|
||||
description="Llama 2 70b model",
|
||||
huggingface_repo="meta-llama/Llama-2-70b",
|
||||
recommended_sampling_params=recommended_sampling_params(),
|
||||
arch_args={
|
||||
"dim": 8192,
|
||||
"n_layers": 80,
|
||||
|
@ -230,7 +209,6 @@ def llama3_base_models() -> List[Model]:
|
|||
core_model_id=CoreModelId.llama3_70b,
|
||||
description="Llama 3 70b model",
|
||||
huggingface_repo="meta-llama/Llama-3-70B",
|
||||
recommended_sampling_params=recommended_sampling_params(),
|
||||
arch_args={
|
||||
"dim": 8192,
|
||||
"n_layers": 80,
|
||||
|
@ -254,7 +232,6 @@ def llama3_1_base_models() -> List[Model]:
|
|||
core_model_id=CoreModelId.llama3_1_8b,
|
||||
description="Llama 3.1 8b model",
|
||||
huggingface_repo="meta-llama/Llama-3.1-8B",
|
||||
recommended_sampling_params=recommended_sampling_params(),
|
||||
arch_args={
|
||||
"dim": 4096,
|
||||
"n_layers": 32,
|
||||
|
@ -273,7 +250,6 @@ def llama3_1_base_models() -> List[Model]:
|
|||
core_model_id=CoreModelId.llama3_1_70b,
|
||||
description="Llama 3.1 70b model",
|
||||
huggingface_repo="meta-llama/Llama-3.1-70B",
|
||||
recommended_sampling_params=recommended_sampling_params(),
|
||||
arch_args={
|
||||
"dim": 8192,
|
||||
"n_layers": 80,
|
||||
|
@ -293,7 +269,6 @@ def llama3_1_base_models() -> List[Model]:
|
|||
variant="bf16-mp8",
|
||||
description="Llama 3.1 405b model (BF16 weights)",
|
||||
huggingface_repo="meta-llama/Llama-3.1-405B",
|
||||
recommended_sampling_params=recommended_sampling_params(),
|
||||
arch_args={
|
||||
"dim": 16384,
|
||||
"n_layers": 126,
|
||||
|
@ -313,7 +288,6 @@ def llama3_1_base_models() -> List[Model]:
|
|||
description="Llama 3.1 405b model (FP8 quantized)",
|
||||
huggingface_repo="meta-llama/Llama-3.1-405B-FP8",
|
||||
quantization_format=CheckpointQuantizationFormat.fp8_mixed,
|
||||
recommended_sampling_params=recommended_sampling_params(),
|
||||
arch_args={
|
||||
"dim": 16384,
|
||||
"n_layers": 126,
|
||||
|
@ -333,7 +307,6 @@ def llama3_1_base_models() -> List[Model]:
|
|||
variant="bf16-mp16",
|
||||
description="Llama 3.1 405b model (BF16 weights for mp16)",
|
||||
huggingface_repo="meta-llama/Llama-3.1-405B",
|
||||
recommended_sampling_params=recommended_sampling_params(),
|
||||
arch_args={
|
||||
"dim": 16384,
|
||||
"n_layers": 126,
|
||||
|
@ -357,7 +330,6 @@ def llama3_2_base_models() -> List[Model]:
|
|||
core_model_id=CoreModelId.llama3_2_1b,
|
||||
description="Llama 3.2 1b model",
|
||||
huggingface_repo="meta-llama/Llama-3.2-1B",
|
||||
recommended_sampling_params=recommended_sampling_params(),
|
||||
arch_args={
|
||||
"dim": 2048,
|
||||
"n_layers": 16,
|
||||
|
@ -376,7 +348,6 @@ def llama3_2_base_models() -> List[Model]:
|
|||
core_model_id=CoreModelId.llama3_2_3b,
|
||||
description="Llama 3.2 3b model",
|
||||
huggingface_repo="meta-llama/Llama-3.2-3B",
|
||||
recommended_sampling_params=recommended_sampling_params(),
|
||||
arch_args={
|
||||
"dim": 3072,
|
||||
"n_layers": 28,
|
||||
|
@ -395,7 +366,6 @@ def llama3_2_base_models() -> List[Model]:
|
|||
core_model_id=CoreModelId.llama3_2_11b_vision,
|
||||
description="Llama 3.2 11b vision model",
|
||||
huggingface_repo="meta-llama/Llama-3.2-11B-Vision",
|
||||
recommended_sampling_params=recommended_sampling_params(),
|
||||
arch_args={
|
||||
"dim": 4096,
|
||||
"n_layers": 32,
|
||||
|
@ -417,7 +387,6 @@ def llama3_2_base_models() -> List[Model]:
|
|||
core_model_id=CoreModelId.llama3_2_90b_vision,
|
||||
description="Llama 3.2 90b vision model",
|
||||
huggingface_repo="meta-llama/Llama-3.2-90B-Vision",
|
||||
recommended_sampling_params=recommended_sampling_params(),
|
||||
arch_args={
|
||||
"dim": 8192,
|
||||
"n_layers": 80,
|
||||
|
@ -444,7 +413,6 @@ def llama2_instruct_models() -> List[Model]:
|
|||
core_model_id=CoreModelId.llama2_7b_chat,
|
||||
description="Llama 2 7b chat model",
|
||||
huggingface_repo="meta-llama/Llama-2-7b-chat",
|
||||
recommended_sampling_params=recommended_sampling_params(),
|
||||
arch_args={
|
||||
"dim": 4096,
|
||||
"n_layers": 32,
|
||||
|
@ -463,7 +431,6 @@ def llama2_instruct_models() -> List[Model]:
|
|||
core_model_id=CoreModelId.llama2_13b_chat,
|
||||
description="Llama 2 13b chat model",
|
||||
huggingface_repo="meta-llama/Llama-2-13b-chat",
|
||||
recommended_sampling_params=recommended_sampling_params(),
|
||||
arch_args={
|
||||
"dim": 5120,
|
||||
"n_layers": 40,
|
||||
|
@ -482,7 +449,6 @@ def llama2_instruct_models() -> List[Model]:
|
|||
core_model_id=CoreModelId.llama2_70b_chat,
|
||||
description="Llama 2 70b chat model",
|
||||
huggingface_repo="meta-llama/Llama-2-70b-chat",
|
||||
recommended_sampling_params=recommended_sampling_params(),
|
||||
arch_args={
|
||||
"dim": 8192,
|
||||
"n_layers": 80,
|
||||
|
@ -506,7 +472,6 @@ def llama3_instruct_models() -> List[Model]:
|
|||
core_model_id=CoreModelId.llama3_8b_instruct,
|
||||
description="Llama 3 8b instruct model",
|
||||
huggingface_repo="meta-llama/Llama-3-8B-Instruct",
|
||||
recommended_sampling_params=recommended_sampling_params(),
|
||||
arch_args={
|
||||
"dim": 4096,
|
||||
"n_layers": 32,
|
||||
|
@ -525,7 +490,6 @@ def llama3_instruct_models() -> List[Model]:
|
|||
core_model_id=CoreModelId.llama3_70b_instruct,
|
||||
description="Llama 3 70b instruct model",
|
||||
huggingface_repo="meta-llama/Llama-3-70B-Instruct",
|
||||
recommended_sampling_params=recommended_sampling_params(),
|
||||
arch_args={
|
||||
"dim": 8192,
|
||||
"n_layers": 80,
|
||||
|
@ -549,7 +513,6 @@ def llama3_1_instruct_models() -> List[Model]:
|
|||
core_model_id=CoreModelId.llama3_1_8b_instruct,
|
||||
description="Llama 3.1 8b instruct model",
|
||||
huggingface_repo="meta-llama/Llama-3.1-8B-Instruct",
|
||||
recommended_sampling_params=recommended_sampling_params(),
|
||||
arch_args={
|
||||
"dim": 4096,
|
||||
"n_layers": 32,
|
||||
|
@ -568,7 +531,6 @@ def llama3_1_instruct_models() -> List[Model]:
|
|||
core_model_id=CoreModelId.llama3_1_70b_instruct,
|
||||
description="Llama 3.1 70b instruct model",
|
||||
huggingface_repo="meta-llama/Llama-3.1-70B-Instruct",
|
||||
recommended_sampling_params=recommended_sampling_params(),
|
||||
arch_args={
|
||||
"dim": 8192,
|
||||
"n_layers": 80,
|
||||
|
@ -588,7 +550,6 @@ def llama3_1_instruct_models() -> List[Model]:
|
|||
variant="bf16-mp8",
|
||||
description="Llama 3.1 405b instruct model (BF16 weights)",
|
||||
huggingface_repo="meta-llama/Llama-3.1-405B-Instruct",
|
||||
recommended_sampling_params=recommended_sampling_params(),
|
||||
arch_args={
|
||||
"dim": 16384,
|
||||
"n_layers": 126,
|
||||
|
@ -608,7 +569,6 @@ def llama3_1_instruct_models() -> List[Model]:
|
|||
description="Llama 3.1 405b instruct model (FP8 quantized)",
|
||||
huggingface_repo="meta-llama/Llama-3.1-405B-Instruct-FP8",
|
||||
quantization_format=CheckpointQuantizationFormat.fp8_mixed,
|
||||
recommended_sampling_params=recommended_sampling_params(),
|
||||
arch_args={
|
||||
"dim": 16384,
|
||||
"n_layers": 126,
|
||||
|
@ -628,7 +588,6 @@ def llama3_1_instruct_models() -> List[Model]:
|
|||
variant="bf16-mp16",
|
||||
description="Llama 3.1 405b instruct model (BF16 weights for mp16)",
|
||||
huggingface_repo="meta-llama/Llama-3.1-405B-Instruct",
|
||||
recommended_sampling_params=recommended_sampling_params(),
|
||||
arch_args={
|
||||
"dim": 16384,
|
||||
"n_layers": 126,
|
||||
|
@ -684,7 +643,6 @@ def llama3_2_quantized_models() -> List[Model]:
|
|||
quantization_format=CheckpointQuantizationFormat.int4,
|
||||
description="Llama 3.2 1b INT4 quantized LoRA",
|
||||
huggingface_repo="meta-llama/Llama-3.2-1B-Instruct-QLORA_INT4_EO8",
|
||||
recommended_sampling_params=recommended_sampling_params(),
|
||||
arch_args={
|
||||
**arch_args_1b(),
|
||||
"quantization_args": {
|
||||
|
@ -703,7 +661,6 @@ def llama3_2_quantized_models() -> List[Model]:
|
|||
quantization_format=CheckpointQuantizationFormat.int4,
|
||||
description="Llama 3.2 1b INT4 quantized SpinQuant",
|
||||
huggingface_repo="meta-llama/Llama-3.2-1B-Instruct-SpinQuant_INT4_EO8",
|
||||
recommended_sampling_params=recommended_sampling_params(),
|
||||
arch_args={
|
||||
**arch_args_1b(),
|
||||
"quantization_args": {
|
||||
|
@ -718,7 +675,6 @@ def llama3_2_quantized_models() -> List[Model]:
|
|||
quantization_format=CheckpointQuantizationFormat.int4,
|
||||
description="Llama 3.2 3b INT4 quantized LoRA",
|
||||
huggingface_repo="meta-llama/Llama-3.2-3B-Instruct-QLORA_INT4_EO8",
|
||||
recommended_sampling_params=recommended_sampling_params(),
|
||||
arch_args={
|
||||
**arch_args_3b(),
|
||||
"quantization_args": {
|
||||
|
@ -737,7 +693,6 @@ def llama3_2_quantized_models() -> List[Model]:
|
|||
quantization_format=CheckpointQuantizationFormat.int4,
|
||||
description="Llama 3.2 3b INT4 quantized SpinQuant",
|
||||
huggingface_repo="meta-llama/Llama-3.2-3B-Instruct-SpinQuant_INT4_EO8",
|
||||
recommended_sampling_params=recommended_sampling_params(),
|
||||
arch_args={
|
||||
**arch_args_3b(),
|
||||
"quantization_args": {
|
||||
|
@ -755,7 +710,6 @@ def llama3_2_instruct_models() -> List[Model]:
|
|||
core_model_id=CoreModelId.llama3_2_1b_instruct,
|
||||
description="Llama 3.2 1b instruct model",
|
||||
huggingface_repo="meta-llama/Llama-3.2-1B-Instruct",
|
||||
recommended_sampling_params=recommended_sampling_params(),
|
||||
arch_args=arch_args_1b(),
|
||||
pth_file_count=1,
|
||||
),
|
||||
|
@ -763,7 +717,6 @@ def llama3_2_instruct_models() -> List[Model]:
|
|||
core_model_id=CoreModelId.llama3_2_3b_instruct,
|
||||
description="Llama 3.2 3b instruct model",
|
||||
huggingface_repo="meta-llama/Llama-3.2-3B-Instruct",
|
||||
recommended_sampling_params=recommended_sampling_params(),
|
||||
arch_args=arch_args_3b(),
|
||||
pth_file_count=1,
|
||||
),
|
||||
|
@ -772,7 +725,6 @@ def llama3_2_instruct_models() -> List[Model]:
|
|||
core_model_id=CoreModelId.llama3_2_11b_vision_instruct,
|
||||
description="Llama 3.2 11b vision instruct model",
|
||||
huggingface_repo="meta-llama/Llama-3.2-11B-Vision-Instruct",
|
||||
recommended_sampling_params=recommended_sampling_params(),
|
||||
arch_args={
|
||||
"dim": 4096,
|
||||
"n_layers": 32,
|
||||
|
@ -794,7 +746,6 @@ def llama3_2_instruct_models() -> List[Model]:
|
|||
core_model_id=CoreModelId.llama3_2_90b_vision_instruct,
|
||||
description="Llama 3.2 90b vision instruct model",
|
||||
huggingface_repo="meta-llama/Llama-3.2-90B-Vision-Instruct",
|
||||
recommended_sampling_params=recommended_sampling_params(),
|
||||
arch_args={
|
||||
"dim": 8192,
|
||||
"n_layers": 80,
|
||||
|
@ -821,7 +772,6 @@ def llama3_3_instruct_models() -> List[Model]:
|
|||
core_model_id=CoreModelId.llama3_3_70b_instruct,
|
||||
description="Llama 3.3 70b instruct",
|
||||
huggingface_repo="meta-llama/Llama-3.3-70B-Instruct",
|
||||
recommended_sampling_params=recommended_sampling_params(),
|
||||
arch_args={
|
||||
"dim": 8192,
|
||||
"n_layers": 80,
|
||||
|
@ -846,7 +796,6 @@ def safety_models() -> List[Model]:
|
|||
core_model_id=CoreModelId.llama_guard_3_11b_vision,
|
||||
description="Llama Guard v3 11b vision system safety model",
|
||||
huggingface_repo="meta-llama/Llama-Guard-3-11B-Vision",
|
||||
recommended_sampling_params=recommended_sampling_params(),
|
||||
arch_args={
|
||||
"dim": 4096,
|
||||
"n_layers": 32,
|
||||
|
@ -870,7 +819,6 @@ def safety_models() -> List[Model]:
|
|||
description="Llama Guard v3 1b 'int4' quantized system safety model",
|
||||
huggingface_repo="meta-llama/Llama-Guard-3-1B-INT4",
|
||||
quantization_format=CheckpointQuantizationFormat.int4,
|
||||
recommended_sampling_params=recommended_sampling_params(),
|
||||
arch_args={
|
||||
"dim": 2048,
|
||||
"n_layers": 12,
|
||||
|
@ -888,7 +836,6 @@ def safety_models() -> List[Model]:
|
|||
core_model_id=CoreModelId.llama_guard_3_1b,
|
||||
description="Llama Guard v3 1b system safety model",
|
||||
huggingface_repo="meta-llama/Llama-Guard-3-1B",
|
||||
recommended_sampling_params=recommended_sampling_params(),
|
||||
arch_args={
|
||||
"dim": 2048,
|
||||
"n_layers": 16,
|
||||
|
|
229
llama_stack/models/llama/sku_types.py
Normal file
229
llama_stack/models/llama/sku_types.py
Normal file
|
@ -0,0 +1,229 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from enum import Enum
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from pydantic import BaseModel, ConfigDict, Field
|
||||
|
||||
|
||||
class CheckpointQuantizationFormat(Enum):
|
||||
# default format
|
||||
bf16 = "bf16"
|
||||
|
||||
# used for enabling fp8_rowwise inference, some weights are bf16
|
||||
fp8_mixed = "fp8-mixed"
|
||||
|
||||
int8 = "int8"
|
||||
|
||||
int4 = "int4"
|
||||
|
||||
|
||||
class ModelFamily(Enum):
|
||||
llama2 = "llama2"
|
||||
llama3 = "llama3"
|
||||
llama3_1 = "llama3_1"
|
||||
llama3_2 = "llama3_2"
|
||||
llama3_3 = "llama3_3"
|
||||
llama4 = "llama4"
|
||||
safety = "safety"
|
||||
|
||||
|
||||
class CoreModelId(Enum):
|
||||
"""Each of these models is a unique "SKU". These root models can be served in various garbs (especially by quantizing them)"""
|
||||
|
||||
# Llama 2 family
|
||||
llama2_7b = "Llama-2-7b"
|
||||
llama2_13b = "Llama-2-13b"
|
||||
llama2_70b = "Llama-2-70b"
|
||||
llama2_7b_chat = "Llama-2-7b-chat"
|
||||
llama2_13b_chat = "Llama-2-13b-chat"
|
||||
llama2_70b_chat = "Llama-2-70b-chat"
|
||||
|
||||
# Llama 3 family
|
||||
llama3_8b = "Llama-3-8B"
|
||||
llama3_70b = "Llama-3-70B"
|
||||
llama3_8b_instruct = "Llama-3-8B-Instruct"
|
||||
llama3_70b_instruct = "Llama-3-70B-Instruct"
|
||||
|
||||
# Llama 3.1 family
|
||||
llama3_1_8b = "Llama3.1-8B"
|
||||
llama3_1_70b = "Llama3.1-70B"
|
||||
llama3_1_405b = "Llama3.1-405B"
|
||||
llama3_1_8b_instruct = "Llama3.1-8B-Instruct"
|
||||
llama3_1_70b_instruct = "Llama3.1-70B-Instruct"
|
||||
llama3_1_405b_instruct = "Llama3.1-405B-Instruct"
|
||||
|
||||
# Llama 3.2 family
|
||||
llama3_2_1b = "Llama3.2-1B"
|
||||
llama3_2_3b = "Llama3.2-3B"
|
||||
llama3_2_1b_instruct = "Llama3.2-1B-Instruct"
|
||||
llama3_2_3b_instruct = "Llama3.2-3B-Instruct"
|
||||
llama3_2_11b_vision = "Llama3.2-11B-Vision"
|
||||
llama3_2_90b_vision = "Llama3.2-90B-Vision"
|
||||
llama3_2_11b_vision_instruct = "Llama3.2-11B-Vision-Instruct"
|
||||
llama3_2_90b_vision_instruct = "Llama3.2-90B-Vision-Instruct"
|
||||
|
||||
# Llama 3.3 family
|
||||
llama3_3_70b_instruct = "Llama3.3-70B-Instruct"
|
||||
|
||||
# Llama 4 family
|
||||
llama4_scout_17b_16e = "Llama-4-Scout-17B-16E"
|
||||
llama4_scout_17b_16e_instruct = "Llama-4-Scout-17B-16E-Instruct"
|
||||
llama4_maverick_17b_128e = "Llama-4-Maverick-17B-128E"
|
||||
llama4_maverick_17b_128e_instruct = "Llama-4-Maverick-17B-128E-Instruct"
|
||||
|
||||
# Safety models
|
||||
llama_guard_3_8b = "Llama-Guard-3-8B"
|
||||
llama_guard_2_8b = "Llama-Guard-2-8B"
|
||||
llama_guard_3_11b_vision = "Llama-Guard-3-11B-Vision"
|
||||
llama_guard_3_1b = "Llama-Guard-3-1B"
|
||||
|
||||
|
||||
def is_multimodal(model_id) -> bool:
|
||||
if model_id in [
|
||||
CoreModelId.llama3_2_11b_vision,
|
||||
CoreModelId.llama3_2_90b_vision,
|
||||
CoreModelId.llama3_2_11b_vision_instruct,
|
||||
CoreModelId.llama3_2_90b_vision_instruct,
|
||||
]:
|
||||
return True
|
||||
else:
|
||||
return False
|
||||
|
||||
|
||||
def model_family(model_id) -> ModelFamily:
|
||||
if model_id in [
|
||||
CoreModelId.llama2_7b,
|
||||
CoreModelId.llama2_13b,
|
||||
CoreModelId.llama2_70b,
|
||||
CoreModelId.llama2_7b_chat,
|
||||
CoreModelId.llama2_13b_chat,
|
||||
CoreModelId.llama2_70b_chat,
|
||||
]:
|
||||
return ModelFamily.llama2
|
||||
elif model_id in [
|
||||
CoreModelId.llama3_8b,
|
||||
CoreModelId.llama3_70b,
|
||||
CoreModelId.llama3_8b_instruct,
|
||||
CoreModelId.llama3_70b_instruct,
|
||||
]:
|
||||
return ModelFamily.llama3
|
||||
elif model_id in [
|
||||
CoreModelId.llama3_1_8b,
|
||||
CoreModelId.llama3_1_70b,
|
||||
CoreModelId.llama3_1_405b,
|
||||
CoreModelId.llama3_1_8b_instruct,
|
||||
CoreModelId.llama3_1_70b_instruct,
|
||||
CoreModelId.llama3_1_405b_instruct,
|
||||
]:
|
||||
return ModelFamily.llama3_1
|
||||
elif model_id in [
|
||||
CoreModelId.llama3_2_1b,
|
||||
CoreModelId.llama3_2_3b,
|
||||
CoreModelId.llama3_2_1b_instruct,
|
||||
CoreModelId.llama3_2_3b_instruct,
|
||||
CoreModelId.llama3_2_11b_vision,
|
||||
CoreModelId.llama3_2_90b_vision,
|
||||
CoreModelId.llama3_2_11b_vision_instruct,
|
||||
CoreModelId.llama3_2_90b_vision_instruct,
|
||||
]:
|
||||
return ModelFamily.llama3_2
|
||||
elif model_id in [
|
||||
CoreModelId.llama3_3_70b_instruct,
|
||||
]:
|
||||
return ModelFamily.llama3_3
|
||||
elif model_id in [
|
||||
CoreModelId.llama4_scout_17b_16e,
|
||||
CoreModelId.llama4_scout_17b_16e_instruct,
|
||||
CoreModelId.llama4_maverick_17b_128e,
|
||||
CoreModelId.llama4_maverick_17b_128e_instruct,
|
||||
]:
|
||||
return ModelFamily.llama4
|
||||
elif model_id in [
|
||||
CoreModelId.llama_guard_3_8b,
|
||||
CoreModelId.llama_guard_2_8b,
|
||||
CoreModelId.llama_guard_3_11b_vision,
|
||||
CoreModelId.llama_guard_3_1b,
|
||||
]:
|
||||
return ModelFamily.safety
|
||||
else:
|
||||
raise ValueError(f"Unknown model family for {model_id}")
|
||||
|
||||
|
||||
class Model(BaseModel):
|
||||
core_model_id: CoreModelId
|
||||
description: str
|
||||
huggingface_repo: Optional[str] = None
|
||||
arch_args: Dict[str, Any]
|
||||
variant: str = ""
|
||||
|
||||
quantization_format: CheckpointQuantizationFormat = CheckpointQuantizationFormat.bf16
|
||||
pth_file_count: int
|
||||
metadata: Dict[str, Any] = Field(default_factory=dict)
|
||||
|
||||
# silence pydantic until we remove the `model_` fields
|
||||
model_config = ConfigDict(protected_namespaces=())
|
||||
|
||||
@property
|
||||
def model_family(self) -> ModelFamily:
|
||||
return model_family(self.core_model_id)
|
||||
|
||||
# The SKU is uniquely identified by (model_id, variant) combo
|
||||
def descriptor(self, shorten_default_variant: bool = True) -> str:
|
||||
if not self.variant:
|
||||
return self.core_model_id.value
|
||||
return f"{self.core_model_id.value}:{self.variant}"
|
||||
|
||||
@property
|
||||
def is_instruct_model(self) -> bool:
|
||||
return "instruct" in self.core_model_id.value
|
||||
|
||||
# Featured models are shown in the non-exhaustive model list
|
||||
@property
|
||||
def is_featured(self) -> bool:
|
||||
return self.model_family in [
|
||||
ModelFamily.llama3_1,
|
||||
ModelFamily.llama3_2,
|
||||
ModelFamily.llama3_3,
|
||||
ModelFamily.llama4,
|
||||
ModelFamily.safety,
|
||||
]
|
||||
|
||||
@property
|
||||
def max_seq_length(self) -> int:
|
||||
if self.model_family == ModelFamily.llama2:
|
||||
return 4096
|
||||
elif self.core_model_id == CoreModelId.llama_guard_2_8b:
|
||||
return 4096
|
||||
elif self.model_family == ModelFamily.llama3:
|
||||
return 8192
|
||||
elif self.model_family in [ModelFamily.llama3_1, ModelFamily.llama3_3]:
|
||||
return 131072
|
||||
elif self.model_family == ModelFamily.llama3_2:
|
||||
if self.quantization_format == CheckpointQuantizationFormat.int4:
|
||||
return 8192
|
||||
return 131072
|
||||
elif self.model_family == ModelFamily.llama4:
|
||||
if self.core_model_id in {
|
||||
CoreModelId.llama4_scout_17b_16e,
|
||||
CoreModelId.llama4_maverick_17b_128e,
|
||||
}:
|
||||
return 262144
|
||||
if self.core_model_id == CoreModelId.llama4_scout_17b_16e_instruct:
|
||||
return 10485760
|
||||
if self.core_model_id == CoreModelId.llama4_maverick_17b_128e_instruct:
|
||||
return 1048576
|
||||
|
||||
raise AssertionError(f"Unexpected core model id: {self.core_model_id}")
|
||||
elif self.core_model_id in [
|
||||
CoreModelId.llama_guard_3_8b,
|
||||
CoreModelId.llama_guard_3_11b_vision,
|
||||
CoreModelId.llama_guard_3_1b,
|
||||
]:
|
||||
return 131072
|
||||
else:
|
||||
raise ValueError(f"Unknown max_seq_len for {self.core_model_id}")
|
|
@ -52,6 +52,7 @@ from llama_stack.apis.inference import (
|
|||
StopReason,
|
||||
SystemMessage,
|
||||
ToolDefinition,
|
||||
ToolParamDefinition,
|
||||
ToolResponse,
|
||||
ToolResponseMessage,
|
||||
UserMessage,
|
||||
|
@ -63,7 +64,6 @@ from llama_stack.log import get_logger
|
|||
from llama_stack.models.llama.datatypes import (
|
||||
BuiltinTool,
|
||||
ToolCall,
|
||||
ToolParamDefinition,
|
||||
)
|
||||
from llama_stack.providers.utils.kvstore import KVStore
|
||||
from llama_stack.providers.utils.telemetry import tracing
|
||||
|
@ -89,7 +89,6 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
self,
|
||||
agent_id: str,
|
||||
agent_config: AgentConfig,
|
||||
tempdir: str,
|
||||
inference_api: Inference,
|
||||
safety_api: Safety,
|
||||
tool_runtime_api: ToolRuntime,
|
||||
|
@ -99,7 +98,6 @@ class ChatAgent(ShieldRunnerMixin):
|
|||
):
|
||||
self.agent_id = agent_id
|
||||
self.agent_config = agent_config
|
||||
self.tempdir = tempdir
|
||||
self.inference_api = inference_api
|
||||
self.safety_api = safety_api
|
||||
self.vector_io_api = vector_io_api
|
||||
|
|
|
@ -7,7 +7,6 @@
|
|||
import json
|
||||
import logging
|
||||
import shutil
|
||||
import tempfile
|
||||
import uuid
|
||||
from typing import AsyncGenerator, List, Optional, Union
|
||||
|
||||
|
@ -64,7 +63,6 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
self.tool_groups_api = tool_groups_api
|
||||
|
||||
self.in_memory_store = InmemoryKVStoreImpl()
|
||||
self.tempdir = tempfile.mkdtemp()
|
||||
|
||||
async def initialize(self) -> None:
|
||||
self.persistence_store = await kvstore_impl(self.config.persistence_store)
|
||||
|
@ -107,7 +105,6 @@ class MetaReferenceAgentsImpl(Agents):
|
|||
return ChatAgent(
|
||||
agent_id=agent_id,
|
||||
agent_config=agent_config,
|
||||
tempdir=self.tempdir,
|
||||
inference_api=self.inference_api,
|
||||
safety_api=self.safety_api,
|
||||
vector_io_api=self.vector_io_api,
|
||||
|
|
|
@ -4,13 +4,13 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Dict, Union
|
||||
from typing import Any, Dict
|
||||
|
||||
from .config import MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig
|
||||
from .config import MetaReferenceInferenceConfig
|
||||
|
||||
|
||||
async def get_provider_impl(
|
||||
config: Union[MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig],
|
||||
config: MetaReferenceInferenceConfig,
|
||||
_deps: Dict[str, Any],
|
||||
):
|
||||
from .inference import MetaReferenceInferenceImpl
|
||||
|
|
|
@ -5,19 +5,10 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
from pathlib import Path
|
||||
from typing import List, Optional
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.distribution.utils.model_utils import model_local_dir
|
||||
|
||||
|
||||
class TokenResult(BaseModel):
|
||||
token: int
|
||||
text: str
|
||||
logprobs: Optional[List[float]] = None
|
||||
|
||||
|
||||
def model_checkpoint_dir(model_id) -> str:
|
||||
checkpoint_dir = Path(model_local_dir(model_id))
|
||||
|
||||
|
|
|
@ -21,6 +21,7 @@ class MetaReferenceInferenceConfig(BaseModel):
|
|||
torch_seed: Optional[int] = None
|
||||
max_seq_len: int = 4096
|
||||
max_batch_size: int = 1
|
||||
model_parallel_size: Optional[int] = None
|
||||
|
||||
# when this is False, we assume that the distributed process group is setup by someone
|
||||
# outside of this code (e.g., when run inside `torchrun`). that is useful for clients
|
||||
|
@ -31,6 +32,8 @@ class MetaReferenceInferenceConfig(BaseModel):
|
|||
# can override by specifying the directory explicitly
|
||||
checkpoint_dir: Optional[str] = None
|
||||
|
||||
quantization: Optional[QuantizationConfig] = None
|
||||
|
||||
@field_validator("model")
|
||||
@classmethod
|
||||
def validate_model(cls, model: str) -> str:
|
||||
|
@ -47,27 +50,16 @@ class MetaReferenceInferenceConfig(BaseModel):
|
|||
cls,
|
||||
model: str = "Llama3.2-3B-Instruct",
|
||||
checkpoint_dir: str = "${env.CHECKPOINT_DIR:null}",
|
||||
quantization_type: str = "${env.QUANTIZATION_TYPE:bf16}",
|
||||
model_parallel_size: str = "${env.MODEL_PARALLEL_SIZE:0}",
|
||||
**kwargs,
|
||||
) -> Dict[str, Any]:
|
||||
return {
|
||||
"model": model,
|
||||
"max_seq_len": 4096,
|
||||
"checkpoint_dir": checkpoint_dir,
|
||||
"quantization": {
|
||||
"type": quantization_type,
|
||||
},
|
||||
"model_parallel_size": model_parallel_size,
|
||||
}
|
||||
|
||||
|
||||
class MetaReferenceQuantizedInferenceConfig(MetaReferenceInferenceConfig):
|
||||
quantization: QuantizationConfig
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(
|
||||
cls,
|
||||
model: str = "Llama3.2-3B-Instruct",
|
||||
checkpoint_dir: str = "${env.CHECKPOINT_DIR:null}",
|
||||
**kwargs,
|
||||
) -> Dict[str, Any]:
|
||||
config = super().sample_run_config(model, checkpoint_dir, **kwargs)
|
||||
config["quantization"] = {
|
||||
"type": "fp8",
|
||||
}
|
||||
return config
|
||||
|
|
|
@ -11,19 +11,18 @@ import torch
|
|||
from lmformatenforcer import JsonSchemaParser, TokenEnforcer, TokenEnforcerTokenizerData
|
||||
|
||||
from llama_stack.apis.inference import (
|
||||
Fp8QuantizationConfig,
|
||||
Int4QuantizationConfig,
|
||||
GreedySamplingStrategy,
|
||||
JsonSchemaResponseFormat,
|
||||
ResponseFormat,
|
||||
)
|
||||
from llama_stack.models.llama.datatypes import (
|
||||
GreedySamplingStrategy,
|
||||
Model,
|
||||
SamplingParams,
|
||||
TopPSamplingStrategy,
|
||||
)
|
||||
from llama_stack.models.llama.datatypes import QuantizationMode
|
||||
from llama_stack.models.llama.llama3.generation import Llama3
|
||||
from llama_stack.models.llama.llama3.tokenizer import Tokenizer as Llama3Tokenizer
|
||||
from llama_stack.models.llama.llama4.generation import Llama4
|
||||
from llama_stack.models.llama.llama4.tokenizer import Tokenizer as Llama4Tokenizer
|
||||
from llama_stack.models.llama.sku_types import Model
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
ChatCompletionRequestWithRawContent,
|
||||
CompletionRequestWithRawContent,
|
||||
|
@ -31,10 +30,8 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
|
|||
)
|
||||
|
||||
from .common import model_checkpoint_dir
|
||||
from .config import MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig
|
||||
from .config import MetaReferenceInferenceConfig
|
||||
from .inference import resolve_model
|
||||
from .llama3.generation import Llama3
|
||||
from .llama4.generation import Llama4
|
||||
|
||||
Tokenizer = Llama4Tokenizer | Llama3Tokenizer
|
||||
|
||||
|
@ -116,10 +113,11 @@ def _infer_tool_prompt_format(request: ChatCompletionRequestWithRawContent):
|
|||
return get_default_tool_prompt_format(request.model)
|
||||
|
||||
|
||||
# TODO: combine Llama3 and Llama4 generators since they are almost identical now
|
||||
class Llama4Generator:
|
||||
def __init__(
|
||||
self,
|
||||
config: MetaReferenceInferenceConfig | MetaReferenceQuantizedInferenceConfig,
|
||||
config: MetaReferenceInferenceConfig,
|
||||
model_id: str,
|
||||
llama_model: Model,
|
||||
):
|
||||
|
@ -134,11 +132,13 @@ class Llama4Generator:
|
|||
# if the model is a native llama model, get the default checkpoint_dir based on model core_model_id value
|
||||
ckpt_dir = model_checkpoint_dir(resolved_model.descriptor())
|
||||
|
||||
if isinstance(config, MetaReferenceQuantizedInferenceConfig):
|
||||
if isinstance(config.quantization, Fp8QuantizationConfig):
|
||||
quantization_mode = "fp8_mixed"
|
||||
elif isinstance(config.quantization, Int4QuantizationConfig):
|
||||
quantization_mode = "int4_mixed"
|
||||
if config.quantization:
|
||||
if config.quantization.type == "fp8_mixed":
|
||||
quantization_mode = QuantizationMode.fp8_mixed
|
||||
elif config.quantization.type == "int4_mixed":
|
||||
quantization_mode = QuantizationMode.int4_mixed
|
||||
elif config.quantization.type == "bf16":
|
||||
quantization_mode = None
|
||||
else:
|
||||
raise ValueError(f"Unsupported quantization mode {config.quantization}")
|
||||
else:
|
||||
|
@ -148,7 +148,7 @@ class Llama4Generator:
|
|||
ckpt_dir=ckpt_dir,
|
||||
max_seq_len=config.max_seq_len,
|
||||
max_batch_size=config.max_batch_size,
|
||||
world_size=llama_model.pth_file_count,
|
||||
world_size=config.model_parallel_size or llama_model.pth_file_count,
|
||||
quantization_mode=quantization_mode,
|
||||
)
|
||||
|
||||
|
@ -166,8 +166,8 @@ class Llama4Generator:
|
|||
max_gen_len = self.args.max_seq_len - 1
|
||||
|
||||
temperature, top_p = _infer_sampling_params(sampling_params)
|
||||
yield from self.inner_generator.generate(
|
||||
llm_input=self.formatter.encode_content(request.content),
|
||||
for result in self.inner_generator.generate(
|
||||
llm_inputs=[self.formatter.encode_content(request.content)],
|
||||
max_gen_len=max_gen_len,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
|
@ -178,7 +178,8 @@ class Llama4Generator:
|
|||
self.args.vocab_size,
|
||||
request.response_format,
|
||||
),
|
||||
)
|
||||
):
|
||||
yield result[0]
|
||||
|
||||
def chat_completion(
|
||||
self,
|
||||
|
@ -190,8 +191,8 @@ class Llama4Generator:
|
|||
max_gen_len = self.args.max_seq_len - 1
|
||||
|
||||
temperature, top_p = _infer_sampling_params(sampling_params)
|
||||
yield from self.inner_generator.generate(
|
||||
llm_input=self.formatter.encode_dialog_prompt(request.messages, _infer_tool_prompt_format(request)),
|
||||
for result in self.inner_generator.generate(
|
||||
llm_inputs=[self.formatter.encode_dialog_prompt(request.messages, _infer_tool_prompt_format(request))],
|
||||
max_gen_len=max_gen_len,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
|
@ -202,20 +203,46 @@ class Llama4Generator:
|
|||
self.args.vocab_size,
|
||||
request.response_format,
|
||||
),
|
||||
)
|
||||
):
|
||||
yield result[0]
|
||||
|
||||
|
||||
class Llama3Generator:
|
||||
def __init__(
|
||||
self,
|
||||
config: MetaReferenceInferenceConfig | MetaReferenceQuantizedInferenceConfig,
|
||||
config: MetaReferenceInferenceConfig,
|
||||
model_id: str,
|
||||
llama_model: Model,
|
||||
):
|
||||
if config.checkpoint_dir and config.checkpoint_dir != "null":
|
||||
ckpt_dir = config.checkpoint_dir
|
||||
else:
|
||||
resolved_model = resolve_model(model_id)
|
||||
if resolved_model is None:
|
||||
# if the model is not a native llama model, get the default checkpoint_dir based on model id
|
||||
ckpt_dir = model_checkpoint_dir(model_id)
|
||||
else:
|
||||
# if the model is a native llama model, get the default checkpoint_dir based on model core_model_id value
|
||||
ckpt_dir = model_checkpoint_dir(resolved_model.descriptor())
|
||||
|
||||
if config.quantization:
|
||||
if config.quantization.type == "fp8_mixed":
|
||||
quantization_mode = QuantizationMode.fp8_mixed
|
||||
elif config.quantization.type == "int4_mixed":
|
||||
quantization_mode = QuantizationMode.int4_mixed
|
||||
elif config.quantization.type == "bf16":
|
||||
quantization_mode = None
|
||||
else:
|
||||
raise ValueError(f"Unsupported quantization mode {config.quantization}")
|
||||
else:
|
||||
quantization_mode = None
|
||||
|
||||
self.inner_generator = Llama3.build(
|
||||
config=config,
|
||||
model_id=model_id,
|
||||
llama_model=llama_model,
|
||||
ckpt_dir=ckpt_dir,
|
||||
max_seq_len=config.max_seq_len,
|
||||
max_batch_size=config.max_batch_size,
|
||||
world_size=config.model_parallel_size or llama_model.pth_file_count,
|
||||
quantization_mode=quantization_mode,
|
||||
)
|
||||
self.tokenizer = self.inner_generator.tokenizer
|
||||
self.args = self.inner_generator.args
|
||||
|
@ -231,8 +258,8 @@ class Llama3Generator:
|
|||
max_gen_len = self.args.max_seq_len - 1
|
||||
|
||||
temperature, top_p = _infer_sampling_params(sampling_params)
|
||||
yield from self.inner_generator.generate(
|
||||
model_input=self.formatter.encode_content(request.content),
|
||||
for result in self.inner_generator.generate(
|
||||
model_inputs=[self.formatter.encode_content(request.content)],
|
||||
max_gen_len=max_gen_len,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
|
@ -243,7 +270,8 @@ class Llama3Generator:
|
|||
self.args.vocab_size,
|
||||
request.response_format,
|
||||
),
|
||||
)
|
||||
):
|
||||
yield result[0]
|
||||
|
||||
def chat_completion(
|
||||
self,
|
||||
|
@ -255,8 +283,8 @@ class Llama3Generator:
|
|||
max_gen_len = self.args.max_seq_len - 1
|
||||
|
||||
temperature, top_p = _infer_sampling_params(sampling_params)
|
||||
yield from self.inner_generator.generate(
|
||||
model_input=self.formatter.encode_dialog_prompt(request.messages, _infer_tool_prompt_format(request)),
|
||||
for result in self.inner_generator.generate(
|
||||
model_inputs=[self.formatter.encode_dialog_prompt(request.messages, _infer_tool_prompt_format(request))],
|
||||
max_gen_len=max_gen_len,
|
||||
temperature=temperature,
|
||||
top_p=top_p,
|
||||
|
@ -267,4 +295,5 @@ class Llama3Generator:
|
|||
self.args.vocab_size,
|
||||
request.response_format,
|
||||
),
|
||||
)
|
||||
):
|
||||
yield result[0]
|
||||
|
|
|
@ -6,8 +6,11 @@
|
|||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
from typing import AsyncGenerator, List, Optional, Union
|
||||
|
||||
from termcolor import cprint
|
||||
|
||||
from llama_stack.apis.common.content_types import (
|
||||
TextDelta,
|
||||
ToolCallDelta,
|
||||
|
@ -28,23 +31,21 @@ from llama_stack.apis.inference import (
|
|||
LogProbConfig,
|
||||
Message,
|
||||
ResponseFormat,
|
||||
SamplingParams,
|
||||
StopReason,
|
||||
TokenLogProbs,
|
||||
ToolChoice,
|
||||
ToolConfig,
|
||||
)
|
||||
from llama_stack.apis.models import Model, ModelType
|
||||
from llama_stack.models.llama.datatypes import (
|
||||
ModelFamily,
|
||||
SamplingParams,
|
||||
StopReason,
|
||||
ToolDefinition,
|
||||
ToolPromptFormat,
|
||||
)
|
||||
from llama_stack.apis.models import Model, ModelType
|
||||
from llama_stack.models.llama.llama3.chat_format import ChatFormat as Llama3ChatFormat
|
||||
from llama_stack.models.llama.llama3.tokenizer import Tokenizer as Llama3Tokenizer
|
||||
from llama_stack.models.llama.llama4.chat_format import ChatFormat as Llama4ChatFormat
|
||||
from llama_stack.models.llama.llama4.tokenizer import Tokenizer as Llama4Tokenizer
|
||||
from llama_stack.models.llama.sku_list import resolve_model
|
||||
from llama_stack.models.llama.sku_types import ModelFamily
|
||||
from llama_stack.providers.datatypes import ModelsProtocolPrivate
|
||||
from llama_stack.providers.utils.inference.embedding_mixin import (
|
||||
SentenceTransformerEmbeddingMixin,
|
||||
|
@ -148,7 +149,7 @@ class MetaReferenceInferenceImpl(
|
|||
|
||||
if self.config.create_distributed_process_group:
|
||||
self.generator = LlamaModelParallelGenerator(
|
||||
model_parallel_size=llama_model.pth_file_count,
|
||||
model_parallel_size=self.config.model_parallel_size or llama_model.pth_file_count,
|
||||
builder_fn=builder_fn,
|
||||
builder_params=builder_params,
|
||||
formatter=(
|
||||
|
@ -338,6 +339,9 @@ class MetaReferenceInferenceImpl(
|
|||
stop_reason = None
|
||||
|
||||
for token_result in self.generator.chat_completion(request):
|
||||
if os.environ.get("LLAMA_MODELS_DEBUG", "0") == "1":
|
||||
cprint(token_result.text, "cyan", end="")
|
||||
|
||||
tokens.append(token_result.token)
|
||||
|
||||
if token_result.token == tokenizer.eot_id:
|
||||
|
@ -386,6 +390,9 @@ class MetaReferenceInferenceImpl(
|
|||
ipython = False
|
||||
|
||||
for token_result in self.generator.chat_completion(request):
|
||||
if os.environ.get("LLAMA_MODELS_DEBUG", "0") == "1":
|
||||
cprint(token_result.text, "cyan", end="")
|
||||
|
||||
tokens.append(token_result.token)
|
||||
|
||||
if not ipython and token_result.text.startswith("<|python_tag|>"):
|
||||
|
|
|
@ -1,346 +0,0 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
|
||||
import json
|
||||
import os
|
||||
import sys
|
||||
import time
|
||||
from pathlib import Path
|
||||
from typing import Callable, Generator, Optional, Union
|
||||
|
||||
import torch
|
||||
import torch.nn.functional as F
|
||||
from fairscale.nn.model_parallel.initialize import (
|
||||
get_model_parallel_rank,
|
||||
initialize_model_parallel,
|
||||
model_parallel_is_initialized,
|
||||
)
|
||||
|
||||
from llama_stack.apis.inference import (
|
||||
Fp8QuantizationConfig,
|
||||
Int4QuantizationConfig,
|
||||
)
|
||||
from llama_stack.log import get_logger
|
||||
from llama_stack.models.llama.datatypes import Model
|
||||
from llama_stack.models.llama.llama3.chat_format import ChatFormat, LLMInput
|
||||
from llama_stack.models.llama.llama3.tokenizer import Tokenizer
|
||||
from llama_stack.models.llama.sku_list import resolve_model
|
||||
|
||||
from ..common import TokenResult, model_checkpoint_dir
|
||||
from ..config import MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig
|
||||
from .args import ModelArgs
|
||||
from .model import Transformer
|
||||
from .multimodal.model import CrossAttentionTransformer
|
||||
|
||||
log = get_logger(__name__, category="inference")
|
||||
|
||||
|
||||
class Llama3:
|
||||
@staticmethod
|
||||
def build(
|
||||
config: Union[MetaReferenceInferenceConfig, MetaReferenceQuantizedInferenceConfig],
|
||||
model_id: str,
|
||||
llama_model: Model,
|
||||
):
|
||||
"""
|
||||
Build a Llama instance by initializing and loading a model checkpoint.
|
||||
|
||||
Note:
|
||||
This method initializes the distributed process group, sets the device to CUDA,
|
||||
and loads the pre-trained model and tokenizer.
|
||||
"""
|
||||
if "DEVICE" in os.environ:
|
||||
device = os.environ.get("DEVICE")
|
||||
if device == "cuda":
|
||||
assert torch.cuda.is_available(), "PyTorch CUDA backend not available"
|
||||
if device == "xpu":
|
||||
assert torch.xpu.is_available(), "PyTorch XPU backend not available"
|
||||
else:
|
||||
if torch.cuda.is_available():
|
||||
device = "cuda"
|
||||
elif torch.xpu.is_available():
|
||||
device = "xpu"
|
||||
else:
|
||||
device = "cpu"
|
||||
log.info(f"Using {device} device")
|
||||
|
||||
llama_model_id = llama_model.core_model_id.value
|
||||
if not torch.distributed.is_initialized():
|
||||
if device == "cuda":
|
||||
torch.distributed.init_process_group("nccl")
|
||||
else:
|
||||
torch.distributed.init_process_group("gloo")
|
||||
|
||||
model_parallel_size = llama_model.pth_file_count
|
||||
|
||||
if not model_parallel_is_initialized():
|
||||
initialize_model_parallel(model_parallel_size)
|
||||
|
||||
local_rank = int(os.environ.get("LOCAL_RANK", 0))
|
||||
if device == "cuda":
|
||||
torch.cuda.set_device(local_rank)
|
||||
elif device == "xpu":
|
||||
torch.xpu.set_device(local_rank)
|
||||
|
||||
# seed must be the same in all processes
|
||||
if config.torch_seed is not None:
|
||||
torch.manual_seed(config.torch_seed)
|
||||
|
||||
if local_rank > 0:
|
||||
sys.stdout = open(os.devnull, "w")
|
||||
|
||||
start_time = time.time()
|
||||
if config.checkpoint_dir and config.checkpoint_dir != "null":
|
||||
ckpt_dir = config.checkpoint_dir
|
||||
else:
|
||||
resolved_model = resolve_model(model_id)
|
||||
if resolved_model is None:
|
||||
# if the model is not a native llama model, get the default checkpoint_dir based on model id
|
||||
ckpt_dir = model_checkpoint_dir(model_id)
|
||||
else:
|
||||
# if the model is a native llama model, get the default checkpoint_dir based on model core_model_id value
|
||||
ckpt_dir = model_checkpoint_dir(resolved_model.descriptor())
|
||||
|
||||
checkpoints = sorted(Path(ckpt_dir).glob("*.pth"))
|
||||
assert len(checkpoints) > 0, f"no checkpoint files found in {ckpt_dir}"
|
||||
assert model_parallel_size == len(checkpoints), (
|
||||
f"Loading a checkpoint for MP={len(checkpoints)} but world size is {model_parallel_size}"
|
||||
)
|
||||
ckpt_path = checkpoints[get_model_parallel_rank()]
|
||||
state_dict = torch.load(ckpt_path, map_location="cpu", weights_only=True)
|
||||
with open(Path(ckpt_dir) / "params.json", "r") as f:
|
||||
params = json.loads(f.read())
|
||||
|
||||
if "model" in params:
|
||||
params = params["model"]
|
||||
|
||||
model_args: ModelArgs = ModelArgs(
|
||||
max_seq_len=config.max_seq_len,
|
||||
max_batch_size=config.max_batch_size,
|
||||
**params,
|
||||
)
|
||||
|
||||
tokenizer = Tokenizer.get_instance()
|
||||
assert model_args.vocab_size == tokenizer.n_words, (
|
||||
f"model_args vocab = {model_args.vocab_size} but tokenizer vocab = {tokenizer.n_words}"
|
||||
)
|
||||
|
||||
if isinstance(config, MetaReferenceQuantizedInferenceConfig):
|
||||
if isinstance(config.quantization, Fp8QuantizationConfig):
|
||||
from .quantization.loader import convert_to_fp8_quantized_model
|
||||
|
||||
# load on CPU in bf16 so that fp8 conversion does not find an
|
||||
# unexpected (fp32, e.g.) datatype
|
||||
torch.set_default_tensor_type(torch.BFloat16Tensor)
|
||||
if model_args.vision_chunk_size > 0:
|
||||
model = CrossAttentionTransformer(model_args)
|
||||
model.setup_cache(model_args.max_batch_size, torch.bfloat16)
|
||||
else:
|
||||
model = Transformer(model_args)
|
||||
model.load_state_dict(state_dict, strict=False)
|
||||
model = convert_to_fp8_quantized_model(model, config, ckpt_dir)
|
||||
elif isinstance(config.quantization, Int4QuantizationConfig):
|
||||
from .quantization.loader import convert_to_int4_quantized_model
|
||||
|
||||
model = Transformer(model_args)
|
||||
model = convert_to_int4_quantized_model(model, model_args, config)
|
||||
model.load_state_dict(state_dict, strict=True)
|
||||
|
||||
if model_args.quantization_args is not None and model_args.quantization_args.spinquant:
|
||||
# Add a wrapper for adding hadamard transform for spinquant.
|
||||
# This needs to be done after loading the state dict otherwise an error will be raised while
|
||||
# loading the state dict.
|
||||
from ..hadamard_utils import (
|
||||
add_hadamard_transform_for_spinquant,
|
||||
)
|
||||
|
||||
add_hadamard_transform_for_spinquant(model)
|
||||
else:
|
||||
raise NotImplementedError("Currently int4 and fp8 are the only supported quantization methods.")
|
||||
else:
|
||||
if device == "cuda":
|
||||
if torch.cuda.is_bf16_supported():
|
||||
torch.set_default_tensor_type(torch.cuda.BFloat16Tensor)
|
||||
else:
|
||||
torch.set_default_tensor_type(torch.cuda.HalfTensor)
|
||||
else:
|
||||
torch.set_default_device(device)
|
||||
if device == "xpu" and torch.xpu.is_bf16_supported():
|
||||
torch.set_default_dtype(torch.bfloat16)
|
||||
else:
|
||||
torch.set_default_dtype(torch.half)
|
||||
if model_args.vision_chunk_size > 0:
|
||||
model = CrossAttentionTransformer(model_args)
|
||||
model.setup_cache(model_args.max_batch_size, torch.bfloat16)
|
||||
else:
|
||||
model = Transformer(model_args)
|
||||
model.load_state_dict(state_dict, strict=False)
|
||||
|
||||
model.to(device)
|
||||
|
||||
log.info(f"Loaded in {time.time() - start_time:.2f} seconds")
|
||||
return Llama3(model, tokenizer, model_args, llama_model_id)
|
||||
|
||||
def __init__(
|
||||
self,
|
||||
model: Transformer,
|
||||
tokenizer: Tokenizer,
|
||||
args: ModelArgs,
|
||||
llama_model: str,
|
||||
):
|
||||
self.args = args
|
||||
self.model = model
|
||||
self.tokenizer = tokenizer
|
||||
self.formatter = ChatFormat(tokenizer)
|
||||
self.llama_model = llama_model
|
||||
|
||||
@torch.inference_mode()
|
||||
def generate(
|
||||
self,
|
||||
model_input: LLMInput,
|
||||
max_gen_len: int,
|
||||
temperature: float = 0.6,
|
||||
top_p: float = 0.9,
|
||||
logprobs: bool = False,
|
||||
echo: bool = False,
|
||||
print_input_tokens: bool = False,
|
||||
logits_processor: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
|
||||
) -> Generator:
|
||||
params = self.model.params
|
||||
|
||||
if print_input_tokens:
|
||||
input_tokens = [self.formatter.vision_token if t == 128256 else t for t in model_input.tokens]
|
||||
log.info("Input to model -> " + self.tokenizer.decode(input_tokens))
|
||||
prompt_tokens = [model_input.tokens]
|
||||
|
||||
bsz = 1
|
||||
assert bsz <= params.max_batch_size, (bsz, params.max_batch_size)
|
||||
|
||||
min_prompt_len = min(len(t) for t in prompt_tokens)
|
||||
max_prompt_len = max(len(t) for t in prompt_tokens)
|
||||
|
||||
if max_prompt_len >= params.max_seq_len:
|
||||
log.error(f"Out of token budget {max_prompt_len} vs {params.max_seq_len}")
|
||||
return
|
||||
|
||||
total_len = min(max_gen_len + max_prompt_len, params.max_seq_len)
|
||||
|
||||
is_vision = isinstance(self.model, CrossAttentionTransformer)
|
||||
if is_vision:
|
||||
images = model_input.vision.images if model_input.vision is not None else []
|
||||
mask = model_input.vision.mask if model_input.vision is not None else []
|
||||
|
||||
# the method works for bsz > 1 so add a batch dimension
|
||||
xattn_caches, cross_attention_masks, full_text_row_masked_out_mask = self.model.compute_vision_tokens_masks(
|
||||
batch_images=[images],
|
||||
batch_masks=[mask],
|
||||
total_len=total_len,
|
||||
)
|
||||
|
||||
pad_id = self.tokenizer.pad_id
|
||||
tokens = torch.full((bsz, total_len), pad_id, dtype=torch.long)
|
||||
for k, t in enumerate(prompt_tokens):
|
||||
tokens[k, : len(t)] = torch.tensor(t, dtype=torch.long)
|
||||
if logprobs:
|
||||
token_logprobs = torch.zeros_like(tokens)
|
||||
|
||||
prev_pos = 0
|
||||
eos_reached = torch.tensor([False] * bsz)
|
||||
input_text_mask = tokens != pad_id
|
||||
if min_prompt_len == total_len:
|
||||
# TODO(ashwin): unify this branch with the one below and figure out multimodal crap
|
||||
logits = self.model.forward(tokens, prev_pos)
|
||||
token_logprobs = -F.cross_entropy(
|
||||
input=logits.transpose(1, 2),
|
||||
target=tokens,
|
||||
reduction="none",
|
||||
ignore_index=pad_id,
|
||||
)
|
||||
|
||||
stop_tokens = torch.tensor(self.tokenizer.stop_tokens)
|
||||
for cur_pos in range(min_prompt_len, total_len):
|
||||
if is_vision:
|
||||
position_ids = torch.arange(prev_pos, cur_pos, dtype=torch.long)
|
||||
logits = self.model.forward(
|
||||
position_ids,
|
||||
tokens,
|
||||
cross_attention_masks,
|
||||
full_text_row_masked_out_mask,
|
||||
xattn_caches,
|
||||
)
|
||||
else:
|
||||
logits = self.model.forward(tokens[:, prev_pos:cur_pos], prev_pos)
|
||||
|
||||
if logits_processor is not None:
|
||||
logits = logits_processor(tokens[:, :cur_pos], logits)
|
||||
|
||||
if temperature > 0:
|
||||
probs = torch.softmax(logits[:, -1] / temperature, dim=-1)
|
||||
next_token = sample_top_p(probs, top_p)
|
||||
else:
|
||||
next_token = torch.argmax(logits[:, -1], dim=-1)
|
||||
|
||||
next_token = next_token.reshape(-1)
|
||||
# only replace token if prompt has already been generated
|
||||
next_token = torch.where(input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token)
|
||||
tokens[:, cur_pos] = next_token
|
||||
|
||||
target = tokens[:, prev_pos + 1 : cur_pos + 1]
|
||||
if is_vision:
|
||||
# the logits space (num_classes) is designed to never contain a media_token
|
||||
# however our input token stream does contain them. we need to nuke them here
|
||||
# or else the CUDA kernels will crash with an illegal memory access
|
||||
vision_tokens = [self.tokenizer.special_tokens["<|image|>"], 128256]
|
||||
masks = [target.eq(t) for t in vision_tokens]
|
||||
if len(masks) > 1:
|
||||
mask = torch.logical_or(*masks)
|
||||
else:
|
||||
mask = masks[0]
|
||||
target[mask] = 0
|
||||
|
||||
if logprobs:
|
||||
token_logprobs[:, prev_pos + 1 : cur_pos + 1] = -F.cross_entropy(
|
||||
input=logits.transpose(1, 2),
|
||||
target=tokens[:, prev_pos + 1 : cur_pos + 1],
|
||||
reduction="none",
|
||||
ignore_index=pad_id,
|
||||
)
|
||||
eos_reached |= (~input_text_mask[:, cur_pos]) & (torch.isin(next_token, stop_tokens))
|
||||
yield TokenResult(
|
||||
token=next_token[0].item(),
|
||||
text=self.tokenizer.decode(next_token.tolist()),
|
||||
logprobs=(token_logprobs[:, cur_pos : cur_pos + 1][0].tolist() if logprobs else None),
|
||||
)
|
||||
|
||||
prev_pos = cur_pos
|
||||
if all(eos_reached):
|
||||
break
|
||||
|
||||
|
||||
def sample_top_p(probs, p):
|
||||
"""
|
||||
Perform top-p (nucleus) sampling on a probability distribution.
|
||||
|
||||
Args:
|
||||
probs (torch.Tensor): Probability distribution tensor.
|
||||
p (float): Probability threshold for top-p sampling.
|
||||
|
||||
Returns:
|
||||
torch.Tensor: Sampled token indices.
|
||||
|
||||
Note:
|
||||
Top-p sampling selects the smallest set of tokens whose cumulative probability mass
|
||||
exceeds the threshold p. The distribution is renormalized based on the selected tokens.
|
||||
"""
|
||||
probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
|
||||
probs_sum = torch.cumsum(probs_sort, dim=-1)
|
||||
mask = probs_sum - probs_sort > p
|
||||
probs_sort[mask] = 0.0
|
||||
probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
|
||||
next_token = torch.multinomial(probs_sort, num_samples=1)
|
||||
next_token = torch.gather(probs_idx, -1, next_token)
|
||||
return next_token
|
|
@ -32,13 +32,12 @@ from pydantic import BaseModel, Field
|
|||
from torch.distributed.launcher.api import LaunchConfig, elastic_launch
|
||||
from typing_extensions import Annotated
|
||||
|
||||
from llama_stack.models.llama.datatypes import GenerationResult
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
ChatCompletionRequestWithRawContent,
|
||||
CompletionRequestWithRawContent,
|
||||
)
|
||||
|
||||
from .common import TokenResult
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
||||
|
@ -75,7 +74,7 @@ class TaskRequest(BaseModel):
|
|||
|
||||
class TaskResponse(BaseModel):
|
||||
type: Literal[ProcessingMessageName.task_response] = ProcessingMessageName.task_response
|
||||
result: TokenResult
|
||||
result: GenerationResult
|
||||
|
||||
|
||||
class ExceptionResponse(BaseModel):
|
||||
|
|
|
@ -14,9 +14,10 @@ from llama_stack.apis.inference import (
|
|||
JsonSchemaResponseFormat,
|
||||
Message,
|
||||
ToolChoice,
|
||||
ToolDefinition,
|
||||
UserMessage,
|
||||
)
|
||||
from llama_stack.models.llama.datatypes import BuiltinTool, ToolDefinition
|
||||
from llama_stack.models.llama.datatypes import BuiltinTool
|
||||
from llama_stack.providers.utils.inference.openai_compat import (
|
||||
convert_message_to_openai_dict,
|
||||
get_sampling_options,
|
||||
|
|
|
@ -46,6 +46,8 @@ from llama_stack.apis.inference import (
|
|||
TokenLogProbs,
|
||||
ToolChoice,
|
||||
ToolConfig,
|
||||
TopKSamplingStrategy,
|
||||
TopPSamplingStrategy,
|
||||
)
|
||||
from llama_stack.apis.models import Model
|
||||
from llama_stack.log import get_logger
|
||||
|
@ -55,8 +57,6 @@ from llama_stack.models.llama.datatypes import (
|
|||
ToolCall,
|
||||
ToolDefinition,
|
||||
ToolPromptFormat,
|
||||
TopKSamplingStrategy,
|
||||
TopPSamplingStrategy,
|
||||
)
|
||||
from llama_stack.models.llama.llama3.chat_format import ChatFormat
|
||||
from llama_stack.models.llama.llama3.tokenizer import Tokenizer
|
||||
|
|
|
@ -22,8 +22,8 @@ from torchtune.models.llama3_2 import lora_llama3_2_3b
|
|||
from torchtune.modules.transforms import Transform
|
||||
|
||||
from llama_stack.apis.post_training import DatasetFormat
|
||||
from llama_stack.models.llama.datatypes import Model
|
||||
from llama_stack.models.llama.sku_list import resolve_model
|
||||
from llama_stack.models.llama.sku_types import Model
|
||||
|
||||
BuildLoraModelCallable = Callable[..., torch.nn.Module]
|
||||
BuildTokenizerCallable = Callable[..., Llama3Tokenizer]
|
||||
|
|
|
@ -23,7 +23,8 @@ from llama_stack.apis.safety import (
|
|||
)
|
||||
from llama_stack.apis.shields import Shield
|
||||
from llama_stack.distribution.datatypes import Api
|
||||
from llama_stack.models.llama.datatypes import CoreModelId, Role
|
||||
from llama_stack.models.llama.datatypes import Role
|
||||
from llama_stack.models.llama.sku_types import CoreModelId
|
||||
from llama_stack.providers.datatypes import ShieldsProtocolPrivate
|
||||
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||
interleaved_content_as_str,
|
||||
|
|
|
@ -24,6 +24,8 @@ META_REFERENCE_DEPS = [
|
|||
"zmq",
|
||||
"lm-format-enforcer",
|
||||
"sentence-transformers",
|
||||
"torchao==0.5.0",
|
||||
"fbgemm-gpu-genai==1.1.2",
|
||||
]
|
||||
|
||||
|
||||
|
@ -36,13 +38,6 @@ def available_providers() -> List[ProviderSpec]:
|
|||
module="llama_stack.providers.inline.inference.meta_reference",
|
||||
config_class="llama_stack.providers.inline.inference.meta_reference.MetaReferenceInferenceConfig",
|
||||
),
|
||||
InlineProviderSpec(
|
||||
api=Api.inference,
|
||||
provider_type="inline::meta-reference-quantized",
|
||||
pip_packages=META_REFERENCE_DEPS + ["fbgemm-gpu", "torchao==0.5.0"],
|
||||
module="llama_stack.providers.inline.inference.meta_reference",
|
||||
config_class="llama_stack.providers.inline.inference.meta_reference.MetaReferenceQuantizedInferenceConfig",
|
||||
),
|
||||
InlineProviderSpec(
|
||||
api=Api.inference,
|
||||
provider_type="inline::vllm",
|
||||
|
@ -222,6 +217,56 @@ def available_providers() -> List[ProviderSpec]:
|
|||
provider_data_validator="llama_stack.providers.remote.inference.groq.config.GroqProviderDataValidator",
|
||||
),
|
||||
),
|
||||
remote_provider_spec(
|
||||
api=Api.inference,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="fireworks-openai-compat",
|
||||
pip_packages=["litellm"],
|
||||
module="llama_stack.providers.remote.inference.fireworks_openai_compat",
|
||||
config_class="llama_stack.providers.remote.inference.fireworks_openai_compat.config.FireworksCompatConfig",
|
||||
provider_data_validator="llama_stack.providers.remote.inference.fireworks_openai_compat.config.FireworksProviderDataValidator",
|
||||
),
|
||||
),
|
||||
remote_provider_spec(
|
||||
api=Api.inference,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="together-openai-compat",
|
||||
pip_packages=["litellm"],
|
||||
module="llama_stack.providers.remote.inference.together_openai_compat",
|
||||
config_class="llama_stack.providers.remote.inference.together_openai_compat.config.TogetherCompatConfig",
|
||||
provider_data_validator="llama_stack.providers.remote.inference.together_openai_compat.config.TogetherProviderDataValidator",
|
||||
),
|
||||
),
|
||||
remote_provider_spec(
|
||||
api=Api.inference,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="groq-openai-compat",
|
||||
pip_packages=["litellm"],
|
||||
module="llama_stack.providers.remote.inference.groq_openai_compat",
|
||||
config_class="llama_stack.providers.remote.inference.groq_openai_compat.config.GroqCompatConfig",
|
||||
provider_data_validator="llama_stack.providers.remote.inference.groq_openai_compat.config.GroqProviderDataValidator",
|
||||
),
|
||||
),
|
||||
remote_provider_spec(
|
||||
api=Api.inference,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="sambanova-openai-compat",
|
||||
pip_packages=["litellm"],
|
||||
module="llama_stack.providers.remote.inference.sambanova_openai_compat",
|
||||
config_class="llama_stack.providers.remote.inference.sambanova_openai_compat.config.SambaNovaCompatConfig",
|
||||
provider_data_validator="llama_stack.providers.remote.inference.sambanova_openai_compat.config.SambaNovaProviderDataValidator",
|
||||
),
|
||||
),
|
||||
remote_provider_spec(
|
||||
api=Api.inference,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="cerebras-openai-compat",
|
||||
pip_packages=["litellm"],
|
||||
module="llama_stack.providers.remote.inference.cerebras_openai_compat",
|
||||
config_class="llama_stack.providers.remote.inference.cerebras_openai_compat.config.CerebrasCompatConfig",
|
||||
provider_data_validator="llama_stack.providers.remote.inference.cerebras_openai_compat.config.CerebrasProviderDataValidator",
|
||||
),
|
||||
),
|
||||
remote_provider_spec(
|
||||
api=Api.inference,
|
||||
adapter=AdapterSpec(
|
||||
|
|
|
@ -4,7 +4,7 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.models.llama.datatypes import CoreModelId
|
||||
from llama_stack.models.llama.sku_types import CoreModelId
|
||||
from llama_stack.providers.utils.inference.model_registry import (
|
||||
build_hf_repo_model_entry,
|
||||
)
|
||||
|
|
|
@ -28,8 +28,8 @@ from llama_stack.apis.inference import (
|
|||
ToolConfig,
|
||||
ToolDefinition,
|
||||
ToolPromptFormat,
|
||||
TopKSamplingStrategy,
|
||||
)
|
||||
from llama_stack.models.llama.datatypes import TopKSamplingStrategy
|
||||
from llama_stack.providers.utils.inference.model_registry import (
|
||||
ModelRegistryHelper,
|
||||
)
|
||||
|
|
|
@ -4,7 +4,7 @@
|
|||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.models.llama.datatypes import CoreModelId
|
||||
from llama_stack.models.llama.sku_types import CoreModelId
|
||||
from llama_stack.providers.utils.inference.model_registry import (
|
||||
build_hf_repo_model_entry,
|
||||
)
|
||||
|
|
|
@ -0,0 +1,17 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.apis.inference import Inference
|
||||
|
||||
from .config import CerebrasCompatConfig
|
||||
|
||||
|
||||
async def get_adapter_impl(config: CerebrasCompatConfig, _deps) -> Inference:
|
||||
# import dynamically so the import is used only when it is needed
|
||||
from .cerebras import CerebrasCompatInferenceAdapter
|
||||
|
||||
adapter = CerebrasCompatInferenceAdapter(config)
|
||||
return adapter
|
|
@ -0,0 +1,30 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.providers.remote.inference.cerebras_openai_compat.config import CerebrasCompatConfig
|
||||
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
|
||||
|
||||
from ..cerebras.models import MODEL_ENTRIES
|
||||
|
||||
|
||||
class CerebrasCompatInferenceAdapter(LiteLLMOpenAIMixin):
|
||||
_config: CerebrasCompatConfig
|
||||
|
||||
def __init__(self, config: CerebrasCompatConfig):
|
||||
LiteLLMOpenAIMixin.__init__(
|
||||
self,
|
||||
model_entries=MODEL_ENTRIES,
|
||||
api_key_from_config=config.api_key,
|
||||
provider_data_api_key_field="cerebras_api_key",
|
||||
openai_compat_api_base=config.openai_compat_api_base,
|
||||
)
|
||||
self.config = config
|
||||
|
||||
async def initialize(self):
|
||||
await super().initialize()
|
||||
|
||||
async def shutdown(self):
|
||||
await super().shutdown()
|
|
@ -0,0 +1,38 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
||||
|
||||
class CerebrasProviderDataValidator(BaseModel):
|
||||
cerebras_api_key: Optional[str] = Field(
|
||||
default=None,
|
||||
description="API key for Cerebras models",
|
||||
)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class CerebrasCompatConfig(BaseModel):
|
||||
api_key: Optional[str] = Field(
|
||||
default=None,
|
||||
description="The Cerebras API key",
|
||||
)
|
||||
|
||||
openai_compat_api_base: str = Field(
|
||||
default="https://api.cerebras.ai/v1",
|
||||
description="The URL for the Cerebras API server",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, api_key: str = "${env.CEREBRAS_API_KEY}", **kwargs) -> Dict[str, Any]:
|
||||
return {
|
||||
"openai_compat_api_base": "https://api.cerebras.ai/v1",
|
||||
"api_key": api_key,
|
||||
}
|
|
@ -28,7 +28,7 @@ from llama_stack.apis.inference import (
|
|||
ToolDefinition,
|
||||
ToolPromptFormat,
|
||||
)
|
||||
from llama_stack.models.llama.datatypes import CoreModelId
|
||||
from llama_stack.models.llama.sku_types import CoreModelId
|
||||
from llama_stack.providers.utils.inference.model_registry import (
|
||||
ModelRegistryHelper,
|
||||
build_hf_repo_model_entry,
|
||||
|
|
|
@ -5,7 +5,7 @@
|
|||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.apis.models.models import ModelType
|
||||
from llama_stack.models.llama.datatypes import CoreModelId
|
||||
from llama_stack.models.llama.sku_types import CoreModelId
|
||||
from llama_stack.providers.utils.inference.model_registry import (
|
||||
ProviderModelEntry,
|
||||
build_hf_repo_model_entry,
|
||||
|
@ -48,6 +48,14 @@ MODEL_ENTRIES = [
|
|||
"accounts/fireworks/models/llama-guard-3-11b-vision",
|
||||
CoreModelId.llama_guard_3_11b_vision.value,
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"accounts/fireworks/models/llama4-scout-instruct-basic",
|
||||
CoreModelId.llama4_scout_17b_16e_instruct.value,
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"accounts/fireworks/models/llama4-maverick-instruct-basic",
|
||||
CoreModelId.llama4_maverick_17b_128e_instruct.value,
|
||||
),
|
||||
ProviderModelEntry(
|
||||
provider_model_id="nomic-ai/nomic-embed-text-v1.5",
|
||||
model_type=ModelType.embedding,
|
||||
|
|
|
@ -0,0 +1,17 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.apis.inference import Inference
|
||||
|
||||
from .config import FireworksCompatConfig
|
||||
|
||||
|
||||
async def get_adapter_impl(config: FireworksCompatConfig, _deps) -> Inference:
|
||||
# import dynamically so the import is used only when it is needed
|
||||
from .fireworks import FireworksCompatInferenceAdapter
|
||||
|
||||
adapter = FireworksCompatInferenceAdapter(config)
|
||||
return adapter
|
|
@ -0,0 +1,38 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
||||
|
||||
class FireworksProviderDataValidator(BaseModel):
|
||||
fireworks_api_key: Optional[str] = Field(
|
||||
default=None,
|
||||
description="API key for Fireworks models",
|
||||
)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class FireworksCompatConfig(BaseModel):
|
||||
api_key: Optional[str] = Field(
|
||||
default=None,
|
||||
description="The Fireworks API key",
|
||||
)
|
||||
|
||||
openai_compat_api_base: str = Field(
|
||||
default="https://api.fireworks.ai/inference/v1",
|
||||
description="The URL for the Fireworks API server",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, api_key: str = "${env.FIREWORKS_API_KEY}", **kwargs) -> Dict[str, Any]:
|
||||
return {
|
||||
"openai_compat_api_base": "https://api.fireworks.ai/inference/v1",
|
||||
"api_key": api_key,
|
||||
}
|
|
@ -0,0 +1,30 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.providers.remote.inference.fireworks_openai_compat.config import FireworksCompatConfig
|
||||
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
|
||||
|
||||
from ..fireworks.models import MODEL_ENTRIES
|
||||
|
||||
|
||||
class FireworksCompatInferenceAdapter(LiteLLMOpenAIMixin):
|
||||
_config: FireworksCompatConfig
|
||||
|
||||
def __init__(self, config: FireworksCompatConfig):
|
||||
LiteLLMOpenAIMixin.__init__(
|
||||
self,
|
||||
model_entries=MODEL_ENTRIES,
|
||||
api_key_from_config=config.api_key,
|
||||
provider_data_api_key_field="fireworks_api_key",
|
||||
openai_compat_api_base=config.openai_compat_api_base,
|
||||
)
|
||||
self.config = config
|
||||
|
||||
async def initialize(self):
|
||||
await super().initialize()
|
||||
|
||||
async def shutdown(self):
|
||||
await super().shutdown()
|
|
@ -35,4 +35,12 @@ MODEL_ENTRIES = [
|
|||
"groq/llama-3.2-3b-preview",
|
||||
CoreModelId.llama3_2_3b_instruct.value,
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"groq/llama-4-scout-17b-16e-instruct",
|
||||
CoreModelId.llama4_scout_17b_16e_instruct.value,
|
||||
),
|
||||
build_hf_repo_model_entry(
|
||||
"groq/llama-4-maverick-17b-128e-instruct",
|
||||
CoreModelId.llama4_maverick_17b_128e_instruct.value,
|
||||
),
|
||||
]
|
||||
|
|
|
@ -0,0 +1,17 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.apis.inference import Inference
|
||||
|
||||
from .config import GroqCompatConfig
|
||||
|
||||
|
||||
async def get_adapter_impl(config: GroqCompatConfig, _deps) -> Inference:
|
||||
# import dynamically so the import is used only when it is needed
|
||||
from .groq import GroqCompatInferenceAdapter
|
||||
|
||||
adapter = GroqCompatInferenceAdapter(config)
|
||||
return adapter
|
|
@ -0,0 +1,38 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from llama_stack.schema_utils import json_schema_type
|
||||
|
||||
|
||||
class GroqProviderDataValidator(BaseModel):
|
||||
groq_api_key: Optional[str] = Field(
|
||||
default=None,
|
||||
description="API key for Groq models",
|
||||
)
|
||||
|
||||
|
||||
@json_schema_type
|
||||
class GroqCompatConfig(BaseModel):
|
||||
api_key: Optional[str] = Field(
|
||||
default=None,
|
||||
description="The Groq API key",
|
||||
)
|
||||
|
||||
openai_compat_api_base: str = Field(
|
||||
default="https://api.groq.com/openai/v1",
|
||||
description="The URL for the Groq API server",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, api_key: str = "${env.GROQ_API_KEY}", **kwargs) -> Dict[str, Any]:
|
||||
return {
|
||||
"openai_compat_api_base": "https://api.groq.com/openai/v1",
|
||||
"api_key": api_key,
|
||||
}
|
|
@ -0,0 +1,30 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.providers.remote.inference.groq_openai_compat.config import GroqCompatConfig
|
||||
from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin
|
||||
|
||||
from ..groq.models import MODEL_ENTRIES
|
||||
|
||||
|
||||
class GroqCompatInferenceAdapter(LiteLLMOpenAIMixin):
|
||||
_config: GroqCompatConfig
|
||||
|
||||
def __init__(self, config: GroqCompatConfig):
|
||||
LiteLLMOpenAIMixin.__init__(
|
||||
self,
|
||||
model_entries=MODEL_ENTRIES,
|
||||
api_key_from_config=config.api_key,
|
||||
provider_data_api_key_field="groq_api_key",
|
||||
openai_compat_api_base=config.openai_compat_api_base,
|
||||
)
|
||||
self.config = config
|
||||
|
||||
async def initialize(self):
|
||||
await super().initialize()
|
||||
|
||||
async def shutdown(self):
|
||||
await super().shutdown()
|
Some files were not shown because too many files have changed in this diff Show more
Loading…
Add table
Add a link
Reference in a new issue