Merge branch 'meta-llama:main' into main

This commit is contained in:
Chacksu 2024-11-20 16:12:38 -05:00 committed by GitHub
commit edfd92d81f
339 changed files with 16825 additions and 5865 deletions

View file

@ -4,7 +4,8 @@ In short, provide a summary of what this PR does and why. Usually, the relevant
 - [ ] Addresses issue (#issue)
-## Feature/Issue validation/testing/test plan
+## Test Plan
 Please describe:
 - tests you ran to verify your changes with result summaries.

View file

@ -57,3 +57,17 @@ repos:
# hooks:
# - id: markdown-link-check
# args: ['--quiet']
# - repo: local
# hooks:
# - id: distro-codegen
# name: Distribution Template Codegen
# additional_dependencies:
# - rich
# - pydantic
# entry: python -m llama_stack.scripts.distro_codegen
# language: python
# pass_filenames: false
# require_serial: true
# files: ^llama_stack/templates/.*$
# stages: [manual]

CHANGELOG.md Normal file
View file

@ -0,0 +1,35 @@
# Changelog
## 0.0.53
### Added
- Resource-oriented design for models, shields, memory banks, datasets and eval tasks
- Persistence for registered objects with distribution
- Ability to persist memory banks created for FAISS
- PostgreSQL KVStore implementation
- Environment variable placeholder support in run.yaml files
- Comprehensive Zero-to-Hero notebooks and quickstart guides
- Support for quantized models in Ollama
- Vision model support for Together, Fireworks, Meta-Reference, Ollama, and vLLM
- Bedrock distribution with safety shields support
- Evals API with task registration and scoring functions
- MMLU and SimpleQA benchmark scoring functions
- Huggingface dataset provider integration for benchmarks
- Support for custom dataset registration from local paths
- Benchmark evaluation CLI tools with visualization tables
- RAG evaluation scoring functions and metrics
- Local persistence for datasets and eval tasks
### Changed
- Split safety into distinct providers (llama-guard, prompt-guard, code-scanner)
- Changed provider naming convention (`impls` → `inline`, `adapters` → `remote`)
- Updated API signatures for dataset and eval task registration
- Restructured folder organization for providers
- Enhanced Docker build configuration
- Added version prefixing for REST API routes
- Enhanced evaluation task registration workflow
- Improved benchmark evaluation output formatting
- Restructured evals folder organization for better modularity
### Removed
- `llama stack configure` command

View file

@ -12,6 +12,11 @@ We actively welcome your pull requests.
5. Make sure your code lints.
6. If you haven't already, complete the Contributor License Agreement ("CLA").
### Updating Provider Configurations
If you have made changes to a provider's configuration in any form (introducing a new config key, or changing models, etc.), you should run `python llama_stack/scripts/distro_codegen.py` to re-generate various YAML files as well as the documentation. You should not change `docs/source/.../distributions/` files manually as they are auto-generated.
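For example, a typical regeneration pass might look like the sketch below (assuming you run it from the repository root; the conda environment name is a placeholder):
```bash
$ cd llama-stack
$ conda activate <your-environment>
$ python llama_stack/scripts/distro_codegen.py   # re-generate template YAML files and distribution docs
$ git status                                     # review the regenerated files before committing
```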
### Building the Documentation
If you are making changes to the documentation at [https://llama-stack.readthedocs.io/en/latest/](https://llama-stack.readthedocs.io/en/latest/), you can use the following command to build the documentation and preview your changes. You will need [Sphinx](https://www.sphinx-doc.org/en/master/) and the readthedocs theme.
@ -22,9 +27,23 @@ pip install -r requirements.txt
pip install sphinx-autobuild
# This will start a local server (usually at http://127.0.0.1:8000) that automatically rebuilds and refreshes when you make changes to the documentation.
make html
sphinx-autobuild source build/html
```
## Pre-commit Hooks
We use [pre-commit](https://pre-commit.com/) to run linting and formatting checks on your code. You can install the pre-commit hooks by running:
```bash
$ cd llama-stack
$ conda activate <your-environment>
$ pip install pre-commit
$ pre-commit install
```
After that, pre-commit hooks will run automatically before each commit.
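If you want to exercise the hooks without creating a commit, a manual run over the whole tree looks roughly like this (a sketch using standard pre-commit commands; the file path is a placeholder):
```bash
$ pre-commit run --all-files                 # run every configured hook against the entire repository
$ pre-commit run --files path/to/file.py     # or limit the hooks to specific files
```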
## Contributor License Agreement ("CLA")
In order to accept your pull request, we need you to submit a CLA. You only need
to do this once to work on any of Meta's open source projects.

View file

@ -1,4 +1,5 @@
 include requirements.txt
+include distributions/dependencies.json
 include llama_stack/distribution/*.sh
 include llama_stack/cli/scripts/*.sh
-include llama_stack/templates/*/build.yaml
+include llama_stack/templates/*/*.yaml

View file

@ -101,6 +101,7 @@ Please checkout our [Documentations](https://llama-stack.readthedocs.io/en/lates
* [Getting Started](https://llama-stack.readthedocs.io/en/latest/getting_started/index.html)
* Quick guide to start a Llama Stack server.
* [Jupyter notebook](./docs/getting_started.ipynb) to walk through how to use simple text and vision inference llama_stack_client APIs
* The complete Llama Stack lesson [Colab notebook](https://colab.research.google.com/drive/1dtVmxotBsI4cGZQNsJRYPrLiDeT0Wnwt) from the new [Llama 3.2 course on Deeplearning.ai](https://learn.deeplearning.ai/courses/introducing-multimodal-llama-3-2/lesson/8/llama-stack).
* [Contributing](CONTRIBUTING.md)
* [Adding a new API Provider](https://llama-stack.readthedocs.io/en/latest/api_providers/new_api_provider.html) to walk through how to add a new API provider.
@ -111,7 +112,7 @@ Please checkout our [Documentations](https://llama-stack.readthedocs.io/en/lates
 | Python | [llama-stack-client-python](https://github.com/meta-llama/llama-stack-client-python) | [![PyPI version](https://img.shields.io/pypi/v/llama_stack_client.svg)](https://pypi.org/project/llama_stack_client/)
 | Swift | [llama-stack-client-swift](https://github.com/meta-llama/llama-stack-client-swift) | [![Swift Package Index](https://img.shields.io/endpoint?url=https%3A%2F%2Fswiftpackageindex.com%2Fapi%2Fpackages%2Fmeta-llama%2Fllama-stack-client-swift%2Fbadge%3Ftype%3Dswift-versions)](https://swiftpackageindex.com/meta-llama/llama-stack-client-swift)
 | Node | [llama-stack-client-node](https://github.com/meta-llama/llama-stack-client-node) | [![NPM version](https://img.shields.io/npm/v/llama-stack-client.svg)](https://npmjs.org/package/llama-stack-client)
-| Kotlin | [llama-stack-client-kotlin](https://github.com/meta-llama/llama-stack-client-kotlin) |
+| Kotlin | [llama-stack-client-kotlin](https://github.com/meta-llama/llama-stack-client-kotlin) | [![Maven version](https://img.shields.io/maven-central/v/com.llama.llamastack/llama-stack-client-kotlin)](https://central.sonatype.com/artifact/com.llama.llamastack/llama-stack-client-kotlin)
 Check out our client SDKs for connecting to the Llama Stack server in your preferred language; you can choose from [python](https://github.com/meta-llama/llama-stack-client-python), [node](https://github.com/meta-llama/llama-stack-client-node), [swift](https://github.com/meta-llama/llama-stack-client-swift), and [kotlin](https://github.com/meta-llama/llama-stack-client-kotlin) to quickly build your applications.
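For a rough idea of what this looks like in practice, here is a minimal Python sketch (it assumes a Llama Stack server is already running on localhost:5000 and that `Llama3.1-8B-Instruct` is registered; adjust the host, port, and model to your deployment):

```python
from llama_stack_client import LlamaStackClient
from llama_stack_client.types import UserMessage

# Point the client at your running Llama Stack server (host/port here are assumptions)
client = LlamaStackClient(base_url="http://localhost:5000")

# Ask the inference API for a simple chat completion
response = client.inference.chat_completion(
    messages=[UserMessage(role="user", content="Hello! What can you do?")],
    model="Llama3.1-8B-Instruct",
)
print(response)
```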

View file

@ -1,5 +1,4 @@
 version: '2'
-built_at: '2024-11-01T17:40:45.325529'
 image_name: local
 name: bedrock
 docker_image: null
@ -23,7 +22,7 @@ providers:
 region_name: <AWS_REGION>
 memory:
 - provider_id: meta0
-provider_type: meta-reference
+provider_type: inline::meta-reference
 config: {}
 safety:
 - provider_id: bedrock0
@ -35,12 +34,12 @@ providers:
 region_name: <AWS_REGION>
 agents:
 - provider_id: meta0
-provider_type: meta-reference
+provider_type: inline::meta-reference
 config:
 persistence_store:
 type: sqlite
 db_path: ~/.llama/runtime/kvstore.db
 telemetry:
 - provider_id: meta0
-provider_type: meta-reference
+provider_type: inline::meta-reference
 config: {}

View file

@ -1,5 +1,4 @@
 version: '2'
-built_at: '2024-10-08T17:40:45.325529'
 image_name: local
 docker_image: null
 conda_env: local
@ -19,22 +18,21 @@ providers:
 url: http://127.0.0.1:80
 safety:
 - provider_id: meta0
-provider_type: meta-reference
+provider_type: inline::llama-guard
 config:
-llama_guard_shield:
-model: Llama-Guard-3-1B
-excluded_categories: []
-disable_input_check: false
-disable_output_check: false
-prompt_guard_shield:
-model: Prompt-Guard-86M
+model: Llama-Guard-3-1B
+excluded_categories: []
+- provider_id: meta1
+provider_type: inline::prompt-guard
+config:
+model: Prompt-Guard-86M
 memory:
 - provider_id: meta0
-provider_type: meta-reference
+provider_type: inline::faiss
 config: {}
 agents:
 - provider_id: meta0
-provider_type: meta-reference
+provider_type: inline::meta-reference
 config:
 persistence_store:
 namespace: null
@ -42,5 +40,5 @@ providers:
 db_path: ~/.llama/runtime/kvstore.db
 telemetry:
 - provider_id: meta0
-provider_type: meta-reference
+provider_type: inline::meta-reference
 config: {}

View file

@ -0,0 +1,171 @@
{
"together": [
"aiosqlite",
"blobfile",
"chardet",
"chromadb-client",
"faiss-cpu",
"fastapi",
"fire",
"httpx",
"matplotlib",
"nltk",
"numpy",
"pandas",
"pillow",
"psycopg2-binary",
"pypdf",
"redis",
"scikit-learn",
"scipy",
"sentencepiece",
"together",
"tqdm",
"transformers",
"uvicorn",
"sentence-transformers --no-deps",
"torch --index-url https://download.pytorch.org/whl/cpu"
],
"remote-vllm": [
"aiosqlite",
"blobfile",
"chardet",
"chromadb-client",
"faiss-cpu",
"fastapi",
"fire",
"httpx",
"matplotlib",
"nltk",
"numpy",
"openai",
"pandas",
"pillow",
"psycopg2-binary",
"pypdf",
"redis",
"scikit-learn",
"scipy",
"sentencepiece",
"tqdm",
"transformers",
"uvicorn",
"sentence-transformers --no-deps",
"torch --index-url https://download.pytorch.org/whl/cpu"
],
"fireworks": [
"aiosqlite",
"blobfile",
"chardet",
"chromadb-client",
"faiss-cpu",
"fastapi",
"fire",
"fireworks-ai",
"httpx",
"matplotlib",
"nltk",
"numpy",
"pandas",
"pillow",
"psycopg2-binary",
"pypdf",
"redis",
"scikit-learn",
"scipy",
"sentencepiece",
"tqdm",
"transformers",
"uvicorn",
"sentence-transformers --no-deps",
"torch --index-url https://download.pytorch.org/whl/cpu"
],
"tgi": [
"aiohttp",
"aiosqlite",
"blobfile",
"chardet",
"chromadb-client",
"faiss-cpu",
"fastapi",
"fire",
"httpx",
"huggingface_hub",
"matplotlib",
"nltk",
"numpy",
"pandas",
"pillow",
"psycopg2-binary",
"pypdf",
"redis",
"scikit-learn",
"scipy",
"sentencepiece",
"tqdm",
"transformers",
"uvicorn",
"sentence-transformers --no-deps",
"torch --index-url https://download.pytorch.org/whl/cpu"
],
"meta-reference-gpu": [
"accelerate",
"aiosqlite",
"blobfile",
"chardet",
"chromadb-client",
"fairscale",
"faiss-cpu",
"fastapi",
"fire",
"httpx",
"lm-format-enforcer",
"matplotlib",
"nltk",
"numpy",
"pandas",
"pillow",
"psycopg2-binary",
"pypdf",
"redis",
"scikit-learn",
"scipy",
"sentencepiece",
"torch",
"torchvision",
"tqdm",
"transformers",
"uvicorn",
"zmq",
"sentence-transformers --no-deps",
"torch --index-url https://download.pytorch.org/whl/cpu"
],
"ollama": [
"aiohttp",
"aiosqlite",
"blobfile",
"chardet",
"chromadb-client",
"faiss-cpu",
"fastapi",
"fire",
"httpx",
"matplotlib",
"nltk",
"numpy",
"ollama",
"pandas",
"pillow",
"psycopg2-binary",
"pypdf",
"redis",
"scikit-learn",
"scipy",
"sentencepiece",
"tqdm",
"transformers",
"uvicorn",
"sentence-transformers --no-deps",
"torch --index-url https://download.pytorch.org/whl/cpu"
]
}

View file

@ -1,51 +0,0 @@
version: '2'
built_at: '2024-10-08T17:40:45.325529'
image_name: local
docker_image: null
conda_env: local
apis:
- shields
- agents
- models
- memory
- memory_banks
- inference
- safety
providers:
inference:
- provider_id: fireworks0
provider_type: remote::fireworks
config:
url: https://api.fireworks.ai/inference
# api_key: <ENTER_YOUR_API_KEY>
safety:
- provider_id: meta0
provider_type: meta-reference
config:
llama_guard_shield:
model: Llama-Guard-3-1B
excluded_categories: []
disable_input_check: false
disable_output_check: false
prompt_guard_shield:
model: Prompt-Guard-86M
memory:
- provider_id: meta0
provider_type: meta-reference
config: {}
# Uncomment to use weaviate memory provider
# - provider_id: weaviate0
# provider_type: remote::weaviate
# config: {}
agents:
- provider_id: meta0
provider_type: meta-reference
config:
persistence_store:
namespace: null
type: sqlite
db_path: ~/.llama/runtime/kvstore.db
telemetry:
- provider_id: meta0
provider_type: meta-reference
config: {}

View file

@ -0,0 +1 @@
../../llama_stack/templates/fireworks/run.yaml

View file

@ -0,0 +1 @@
../../llama_stack/templates/inline-vllm/build.yaml

View file

@ -0,0 +1,35 @@
services:
llamastack:
image: llamastack/distribution-inline-vllm
network_mode: "host"
volumes:
- ~/.llama:/root/.llama
- ./run.yaml:/root/my-run.yaml
ports:
- "5000:5000"
devices:
- nvidia.com/gpu=all
environment:
- CUDA_VISIBLE_DEVICES=0
command: []
deploy:
resources:
reservations:
devices:
- driver: nvidia
# that's the closest analogue to --gpus; provide
# an integer amount of devices or 'all'
count: 1
# Devices are reserved using a list of capabilities, making
# capabilities the only required field. A device MUST
# satisfy all the requested capabilities for a successful
# reservation.
capabilities: [gpu]
runtime: nvidia
entrypoint: bash -c "python -m llama_stack.distribution.server.server --yaml_config /root/my-run.yaml"
deploy:
restart_policy:
condition: on-failure
delay: 3s
max_attempts: 5
window: 60s

View file

@ -0,0 +1,66 @@
version: '2'
image_name: local
docker_image: null
conda_env: local
apis:
- shields
- agents
- models
- memory
- memory_banks
- inference
- safety
providers:
inference:
- provider_id: vllm-inference
provider_type: inline::vllm
config:
model: Llama3.2-3B-Instruct
tensor_parallel_size: 1
gpu_memory_utilization: 0.4
enforce_eager: true
max_tokens: 4096
- provider_id: vllm-inference-safety
provider_type: inline::vllm
config:
model: Llama-Guard-3-1B
tensor_parallel_size: 1
gpu_memory_utilization: 0.2
enforce_eager: true
max_tokens: 4096
safety:
- provider_id: meta0
provider_type: inline::llama-guard
config:
model: Llama-Guard-3-1B
excluded_categories: []
# Uncomment to use prompt guard
# - provider_id: meta1
# provider_type: inline::prompt-guard
# config:
# model: Prompt-Guard-86M
memory:
- provider_id: meta0
provider_type: inline::meta-reference
config: {}
# Uncomment to use pgvector
# - provider_id: pgvector
# provider_type: remote::pgvector
# config:
# host: 127.0.0.1
# port: 5432
# db: postgres
# user: postgres
# password: mysecretpassword
agents:
- provider_id: meta0
provider_type: inline::meta-reference
config:
persistence_store:
namespace: null
type: sqlite
db_path: ~/.llama/runtime/agents_store.db
telemetry:
- provider_id: meta0
provider_type: inline::meta-reference
config: {}

View file

@ -25,11 +25,10 @@ services:
 # satisfy all the requested capabilities for a successful
 # reservation.
 capabilities: [gpu]
-runtime: nvidia
-entrypoint: bash -c "python -m llama_stack.distribution.server.server --yaml_config /root/my-run.yaml"
-deploy:
 restart_policy:
 condition: on-failure
 delay: 3s
 max_attempts: 5
 window: 60s
+runtime: nvidia
+entrypoint: bash -c "python -m llama_stack.distribution.server.server --yaml_config /root/my-run.yaml"

View file

@ -0,0 +1 @@
../../llama_stack/templates/meta-reference-gpu/run-with-safety.yaml

View file

@ -1,66 +0,0 @@
version: '2'
built_at: '2024-10-08T17:40:45.325529'
image_name: local
docker_image: null
conda_env: local
apis:
- shields
- agents
- models
- memory
- memory_banks
- inference
- safety
providers:
inference:
- provider_id: meta-reference-inference
provider_type: meta-reference
config:
model: Llama3.2-3B-Instruct
quantization: null
torch_seed: null
max_seq_len: 4096
max_batch_size: 1
- provider_id: meta-reference-safety
provider_type: meta-reference
config:
model: Llama-Guard-3-1B
quantization: null
torch_seed: null
max_seq_len: 2048
max_batch_size: 1
safety:
- provider_id: meta0
provider_type: meta-reference
config:
llama_guard_shield:
model: Llama-Guard-3-1B
excluded_categories: []
# Uncomment to use prompt guard
# prompt_guard_shield:
# model: Prompt-Guard-86M
memory:
- provider_id: meta0
provider_type: meta-reference
config: {}
# Uncomment to use pgvector
# - provider_id: pgvector
# provider_type: remote::pgvector
# config:
# host: 127.0.0.1
# port: 5432
# db: postgres
# user: postgres
# password: mysecretpassword
agents:
- provider_id: meta0
provider_type: meta-reference
config:
persistence_store:
namespace: null
type: sqlite
db_path: ~/.llama/runtime/agents_store.db
telemetry:
- provider_id: meta0
provider_type: meta-reference
config: {}

View file

@ -0,0 +1 @@
../../llama_stack/templates/meta-reference-gpu/run.yaml

View file

@ -1,5 +1,4 @@
 version: '2'
-built_at: '2024-10-08T17:40:45.325529'
 image_name: local
 docker_image: null
 conda_env: local
@ -14,7 +13,7 @@ apis:
 providers:
 inference:
 - provider_id: meta0
-provider_type: meta-reference-quantized
+provider_type: inline::meta-reference-quantized
 config:
 model: Llama3.2-3B-Instruct:int4-qlora-eo8
 quantization:
@ -22,24 +21,32 @@ providers:
 torch_seed: null
 max_seq_len: 2048
 max_batch_size: 1
+- provider_id: meta1
+provider_type: inline::meta-reference-quantized
+config:
+# not a quantized model !
+model: Llama-Guard-3-1B
+quantization: null
+torch_seed: null
+max_seq_len: 2048
+max_batch_size: 1
 safety:
 - provider_id: meta0
-provider_type: meta-reference
+provider_type: inline::llama-guard
 config:
-llama_guard_shield:
-model: Llama-Guard-3-1B
-excluded_categories: []
-disable_input_check: false
-disable_output_check: false
-prompt_guard_shield:
-model: Prompt-Guard-86M
+model: Llama-Guard-3-1B
+excluded_categories: []
+- provider_id: meta1
+provider_type: inline::prompt-guard
+config:
+model: Prompt-Guard-86M
 memory:
 - provider_id: meta0
-provider_type: meta-reference
+provider_type: inline::meta-reference
 config: {}
 agents:
 - provider_id: meta0
-provider_type: meta-reference
+provider_type: inline::meta-reference
 config:
 persistence_store:
 namespace: null
@ -47,5 +54,5 @@ providers:
 db_path: ~/.llama/runtime/kvstore.db
 telemetry:
 - provider_id: meta0
-provider_type: meta-reference
+provider_type: inline::meta-reference
 config: {}

View file

@ -0,0 +1 @@
../../llama_stack/templates/ollama/build.yaml

View file

@ -1,5 +1,4 @@
 version: '2'
-built_at: '2024-10-08T17:40:45.325529'
 image_name: local
 docker_image: null
 conda_env: local
@ -13,28 +12,22 @@ apis:
 - safety
 providers:
 inference:
-- provider_id: ollama0
+- provider_id: ollama
 provider_type: remote::ollama
 config:
-url: http://127.0.0.1:14343
+url: ${env.OLLAMA_URL:http://127.0.0.1:11434}
 safety:
 - provider_id: meta0
-provider_type: meta-reference
+provider_type: inline::llama-guard
 config:
-llama_guard_shield:
-model: Llama-Guard-3-1B
-excluded_categories: []
-disable_input_check: false
-disable_output_check: false
-prompt_guard_shield:
-model: Prompt-Guard-86M
+excluded_categories: []
 memory:
 - provider_id: meta0
-provider_type: meta-reference
+provider_type: inline::meta-reference
 config: {}
 agents:
 - provider_id: meta0
-provider_type: meta-reference
+provider_type: inline::meta-reference
 config:
 persistence_store:
 namespace: null
@ -42,5 +35,12 @@ providers:
 db_path: ~/.llama/runtime/kvstore.db
 telemetry:
 - provider_id: meta0
-provider_type: meta-reference
+provider_type: inline::meta-reference
 config: {}
+models:
+- model_id: ${env.INFERENCE_MODEL:Llama3.2-3B-Instruct}
+provider_id: ollama
+- model_id: ${env.SAFETY_MODEL:Llama-Guard-3-1B}
+provider_id: ollama
+shields:
+- shield_id: ${env.SAFETY_MODEL:Llama-Guard-3-1B}

View file

@ -0,0 +1,71 @@
services:
ollama:
image: ollama/ollama:latest
network_mode: ${NETWORK_MODE:-bridge}
volumes:
- ~/.ollama:/root/.ollama
ports:
- "11434:11434"
environment:
OLLAMA_DEBUG: 1
command: []
deploy:
resources:
limits:
memory: 8G # Set maximum memory
reservations:
memory: 8G # Set minimum memory reservation
# healthcheck:
# # ugh, no CURL in ollama image
# test: ["CMD", "curl", "-f", "http://ollama:11434"]
# interval: 10s
# timeout: 5s
# retries: 5
ollama-init:
image: ollama/ollama:latest
depends_on:
- ollama
# condition: service_healthy
network_mode: ${NETWORK_MODE:-bridge}
environment:
- OLLAMA_HOST=ollama
- INFERENCE_MODEL=${INFERENCE_MODEL}
- SAFETY_MODEL=${SAFETY_MODEL:-}
volumes:
- ~/.ollama:/root/.ollama
- ./pull-models.sh:/pull-models.sh
entrypoint: ["/pull-models.sh"]
llamastack:
depends_on:
ollama:
condition: service_started
ollama-init:
condition: service_started
image: ${LLAMA_STACK_IMAGE:-llamastack/distribution-ollama}
network_mode: ${NETWORK_MODE:-bridge}
volumes:
- ~/.llama:/root/.llama
# Link to ollama run.yaml file
- ~/local/llama-stack/:/app/llama-stack-source
- ./run${SAFETY_MODEL:+-with-safety}.yaml:/root/my-run.yaml
ports:
- "${LLAMA_STACK_PORT:-5001}:${LLAMA_STACK_PORT:-5001}"
environment:
- INFERENCE_MODEL=${INFERENCE_MODEL}
- SAFETY_MODEL=${SAFETY_MODEL:-}
- OLLAMA_URL=http://ollama:11434
entrypoint: >
python -m llama_stack.distribution.server.server /root/my-run.yaml \
--port ${LLAMA_STACK_PORT:-5001}
deploy:
restart_policy:
condition: on-failure
delay: 10s
max_attempts: 3
window: 60s
volumes:
ollama:
ollama-init:
llamastack:

View file

@ -1,30 +0,0 @@
services:
ollama:
image: ollama/ollama:latest
network_mode: "host"
volumes:
- ollama:/root/.ollama # this solution synchronizes with the docker volume and loads the model rocket fast
ports:
- "11434:11434"
command: []
llamastack:
depends_on:
- ollama
image: llamastack/distribution-ollama
network_mode: "host"
volumes:
- ~/.llama:/root/.llama
# Link to ollama run.yaml file
- ./run.yaml:/root/my-run.yaml
ports:
- "5000:5000"
# Hack: wait for ollama server to start before starting docker
entrypoint: bash -c "sleep 60; python -m llama_stack.distribution.server.server --yaml_config /root/my-run.yaml"
deploy:
restart_policy:
condition: on-failure
delay: 3s
max_attempts: 5
window: 60s
volumes:
ollama:

View file

@ -1,46 +0,0 @@
version: '2'
built_at: '2024-10-08T17:40:45.325529'
image_name: local
docker_image: null
conda_env: local
apis:
- shields
- agents
- models
- memory
- memory_banks
- inference
- safety
providers:
inference:
- provider_id: ollama0
provider_type: remote::ollama
config:
url: http://127.0.0.1:14343
safety:
- provider_id: meta0
provider_type: meta-reference
config:
llama_guard_shield:
model: Llama-Guard-3-1B
excluded_categories: []
disable_input_check: false
disable_output_check: false
prompt_guard_shield:
model: Prompt-Guard-86M
memory:
- provider_id: meta0
provider_type: meta-reference
config: {}
agents:
- provider_id: meta0
provider_type: meta-reference
config:
persistence_store:
namespace: null
type: sqlite
db_path: ~/.llama/runtime/kvstore.db
telemetry:
- provider_id: meta0
provider_type: meta-reference
config: {}

View file

@ -0,0 +1,18 @@
#!/bin/sh
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
echo "Preloading (${INFERENCE_MODEL}, ${SAFETY_MODEL})..."
for model in ${INFERENCE_MODEL} ${SAFETY_MODEL}; do
echo "Preloading $model..."
if ! ollama run "$model"; then
echo "Failed to pull and run $model"
exit 1
fi
done
echo "All models pulled successfully"

View file

@ -0,0 +1 @@
../../llama_stack/templates/ollama/run-with-safety.yaml

View file

@ -0,0 +1 @@
../../llama_stack/templates/ollama/run.yaml

View file

@ -0,0 +1 @@
../../llama_stack/templates/remote-vllm/build.yaml

View file

@ -0,0 +1,100 @@
services:
vllm-inference:
image: vllm/vllm-openai:latest
volumes:
- $HOME/.cache/huggingface:/root/.cache/huggingface
network_mode: ${NETWORK_MODE:-bridged}
ports:
- "${VLLM_INFERENCE_PORT:-5100}:${VLLM_INFERENCE_PORT:-5100}"
devices:
- nvidia.com/gpu=all
environment:
- CUDA_VISIBLE_DEVICES=${VLLM_INFERENCE_GPU:-0}
- HUGGING_FACE_HUB_TOKEN=$HF_TOKEN
command: >
--gpu-memory-utilization 0.75
--model ${VLLM_INFERENCE_MODEL:-meta-llama/Llama-3.2-3B-Instruct}
--enforce-eager
--max-model-len 8192
--max-num-seqs 16
--port ${VLLM_INFERENCE_PORT:-5100}
healthcheck:
test: ["CMD", "curl", "-f", "http://localhost:${VLLM_INFERENCE_PORT:-5100}/v1/health"]
interval: 30s
timeout: 10s
retries: 5
deploy:
resources:
reservations:
devices:
- driver: nvidia
capabilities: [gpu]
runtime: nvidia
# A little trick:
# if VLLM_SAFETY_MODEL is set, we will create a service for the safety model
# otherwise, the entry will end in a hyphen which gets ignored by docker compose
vllm-${VLLM_SAFETY_MODEL:+safety}:
image: vllm/vllm-openai:latest
volumes:
- $HOME/.cache/huggingface:/root/.cache/huggingface
network_mode: ${NETWORK_MODE:-bridged}
ports:
- "${VLLM_SAFETY_PORT:-5101}:${VLLM_SAFETY_PORT:-5101}"
devices:
- nvidia.com/gpu=all
environment:
- CUDA_VISIBLE_DEVICES=${VLLM_SAFETY_GPU:-1}
- HUGGING_FACE_HUB_TOKEN=$HF_TOKEN
command: >
--gpu-memory-utilization 0.75
--model ${VLLM_SAFETY_MODEL}
--enforce-eager
--max-model-len 8192
--max-num-seqs 16
--port ${VLLM_SAFETY_PORT:-5101}
healthcheck:
test: ["CMD", "curl", "-f", "http://localhost:${VLLM_SAFETY_PORT:-5101}/v1/health"]
interval: 30s
timeout: 10s
retries: 5
deploy:
resources:
reservations:
devices:
- driver: nvidia
capabilities: [gpu]
runtime: nvidia
llamastack:
depends_on:
vllm-inference:
condition: service_healthy
vllm-${VLLM_SAFETY_MODEL:+safety}:
condition: service_healthy
# image: llamastack/distribution-remote-vllm
image: llamastack/distribution-remote-vllm:test-0.0.52rc3
volumes:
- ~/.llama:/root/.llama
- ./run${VLLM_SAFETY_MODEL:+-with-safety}.yaml:/root/llamastack-run-remote-vllm.yaml
network_mode: ${NETWORK_MODE:-bridged}
environment:
- VLLM_URL=http://vllm-inference:${VLLM_INFERENCE_PORT:-5100}/v1
- VLLM_SAFETY_URL=http://vllm-safety:${VLLM_SAFETY_PORT:-5101}/v1
- INFERENCE_MODEL=${INFERENCE_MODEL:-meta-llama/Llama-3.2-3B-Instruct}
- MAX_TOKENS=${MAX_TOKENS:-4096}
- SQLITE_STORE_DIR=${SQLITE_STORE_DIR:-$HOME/.llama/distributions/remote-vllm}
- SAFETY_MODEL=${SAFETY_MODEL:-meta-llama/Llama-Guard-3-1B}
ports:
- "${LLAMASTACK_PORT:-5001}:${LLAMASTACK_PORT:-5001}"
# Hack: wait for vLLM server to start before starting docker
entrypoint: bash -c "sleep 60; python -m llama_stack.distribution.server.server --yaml_config /root/llamastack-run-remote-vllm.yaml --port 5001"
deploy:
restart_policy:
condition: on-failure
delay: 3s
max_attempts: 5
window: 60s
volumes:
vllm-inference:
vllm-safety:
llamastack:

View file

@ -0,0 +1 @@
../../llama_stack/templates/remote-vllm/run-with-safety.yaml

View file

@ -0,0 +1 @@
../../llama_stack/templates/remote-vllm/run.yaml

View file

@ -0,0 +1,103 @@
services:
tgi-inference:
image: ghcr.io/huggingface/text-generation-inference:latest
volumes:
- $HOME/.cache/huggingface:/data
network_mode: ${NETWORK_MODE:-bridged}
ports:
- "${TGI_INFERENCE_PORT:-8080}:${TGI_INFERENCE_PORT:-8080}"
devices:
- nvidia.com/gpu=all
environment:
- CUDA_VISIBLE_DEVICES=${TGI_INFERENCE_GPU:-0}
- HF_TOKEN=$HF_TOKEN
- HF_HOME=/data
- HF_DATASETS_CACHE=/data
- HF_MODULES_CACHE=/data
- HF_HUB_CACHE=/data
command: >
--dtype bfloat16
--usage-stats off
--sharded false
--model-id ${TGI_INFERENCE_MODEL:-meta-llama/Llama-3.2-3B-Instruct}
--port ${TGI_INFERENCE_PORT:-8080}
--cuda-memory-fraction 0.75
healthcheck:
test: ["CMD", "curl", "-f", "http://tgi-inference:${TGI_INFERENCE_PORT:-8080}/health"]
interval: 5s
timeout: 5s
retries: 30
deploy:
resources:
reservations:
devices:
- driver: nvidia
capabilities: [gpu]
runtime: nvidia
tgi-${TGI_SAFETY_MODEL:+safety}:
image: ghcr.io/huggingface/text-generation-inference:latest
volumes:
- $HOME/.cache/huggingface:/data
network_mode: ${NETWORK_MODE:-bridged}
ports:
- "${TGI_SAFETY_PORT:-8081}:${TGI_SAFETY_PORT:-8081}"
devices:
- nvidia.com/gpu=all
environment:
- CUDA_VISIBLE_DEVICES=${TGI_SAFETY_GPU:-1}
- HF_TOKEN=$HF_TOKEN
- HF_HOME=/data
- HF_DATASETS_CACHE=/data
- HF_MODULES_CACHE=/data
- HF_HUB_CACHE=/data
command: >
--dtype bfloat16
--usage-stats off
--sharded false
--model-id ${TGI_SAFETY_MODEL:-meta-llama/Llama-Guard-3-1B}
--port ${TGI_SAFETY_PORT:-8081}
--cuda-memory-fraction 0.75
healthcheck:
test: ["CMD", "curl", "-f", "http://tgi-safety:${TGI_SAFETY_PORT:-8081}/health"]
interval: 5s
timeout: 5s
retries: 30
deploy:
resources:
reservations:
devices:
- driver: nvidia
capabilities: [gpu]
runtime: nvidia
llamastack:
depends_on:
tgi-inference:
condition: service_healthy
tgi-${TGI_SAFETY_MODEL:+safety}:
condition: service_healthy
image: llamastack/distribution-tgi:test-0.0.52rc3
network_mode: ${NETWORK_MODE:-bridged}
volumes:
- ~/.llama:/root/.llama
- ./run${TGI_SAFETY_MODEL:+-with-safety}.yaml:/root/my-run.yaml
ports:
- "${LLAMA_STACK_PORT:-5001}:${LLAMA_STACK_PORT:-5001}"
# Hack: wait for TGI server to start before starting docker
entrypoint: bash -c "sleep 60; python -m llama_stack.distribution.server.server --yaml_config /root/my-run.yaml"
restart_policy:
condition: on-failure
delay: 3s
max_attempts: 5
window: 60s
environment:
- TGI_URL=http://tgi-inference:${TGI_INFERENCE_PORT:-8080}
- SAFETY_TGI_URL=http://tgi-safety:${TGI_SAFETY_PORT:-8081}
- INFERENCE_MODEL=${INFERENCE_MODEL:-meta-llama/Llama-3.2-3B-Instruct}
- SAFETY_MODEL=${SAFETY_MODEL:-meta-llama/Llama-Guard-3-1B}
volumes:
tgi-inference:
tgi-safety:
llamastack:

View file

@ -1,33 +0,0 @@
services:
text-generation-inference:
image: ghcr.io/huggingface/text-generation-inference:latest
network_mode: "host"
volumes:
- $HOME/.cache/huggingface:/data
ports:
- "5009:5009"
command: ["--dtype", "bfloat16", "--usage-stats", "on", "--sharded", "false", "--model-id", "meta-llama/Llama-3.1-8B-Instruct", "--port", "5009", "--cuda-memory-fraction", "0.3"]
runtime: nvidia
healthcheck:
test: ["CMD", "curl", "-f", "http://text-generation-inference:5009/health"]
interval: 5s
timeout: 5s
retries: 30
llamastack:
depends_on:
text-generation-inference:
condition: service_healthy
image: llamastack/llamastack-tgi
network_mode: "host"
volumes:
- ~/.llama:/root/.llama
# Link to run.yaml file
- ./run.yaml:/root/my-run.yaml
ports:
- "5000:5000"
entrypoint: bash -c "python -m llama_stack.distribution.server.server --yaml_config /root/my-run.yaml"
restart_policy:
condition: on-failure
delay: 3s
max_attempts: 5
window: 60s

View file

@ -1,46 +0,0 @@
version: '2'
built_at: '2024-10-08T17:40:45.325529'
image_name: local
docker_image: null
conda_env: local
apis:
- shields
- agents
- models
- memory
- memory_banks
- inference
- safety
providers:
inference:
- provider_id: tgi0
provider_type: remote::tgi
config:
url: <ENTER_YOUR_TGI_HOSTED_ENDPOINT>
safety:
- provider_id: meta0
provider_type: meta-reference
config:
llama_guard_shield:
model: Llama-Guard-3-1B
excluded_categories: []
disable_input_check: false
disable_output_check: false
prompt_guard_shield:
model: Prompt-Guard-86M
memory:
- provider_id: meta0
provider_type: meta-reference
config: {}
agents:
- provider_id: meta0
provider_type: meta-reference
config:
persistence_store:
namespace: null
type: sqlite
db_path: ~/.llama/runtime/kvstore.db
telemetry:
- provider_id: meta0
provider_type: meta-reference
config: {}

View file

@ -1,55 +0,0 @@
services:
text-generation-inference:
image: ghcr.io/huggingface/text-generation-inference:latest
network_mode: "host"
volumes:
- $HOME/.cache/huggingface:/data
ports:
- "5009:5009"
devices:
- nvidia.com/gpu=all
environment:
- CUDA_VISIBLE_DEVICES=0
- HF_HOME=/data
- HF_DATASETS_CACHE=/data
- HF_MODULES_CACHE=/data
- HF_HUB_CACHE=/data
command: ["--dtype", "bfloat16", "--usage-stats", "on", "--sharded", "false", "--model-id", "meta-llama/Llama-3.1-8B-Instruct", "--port", "5009", "--cuda-memory-fraction", "0.3"]
deploy:
resources:
reservations:
devices:
- driver: nvidia
# that's the closest analogue to --gpus; provide
# an integer amount of devices or 'all'
count: 1
# Devices are reserved using a list of capabilities, making
# capabilities the only required field. A device MUST
# satisfy all the requested capabilities for a successful
# reservation.
capabilities: [gpu]
runtime: nvidia
healthcheck:
test: ["CMD", "curl", "-f", "http://text-generation-inference:5009/health"]
interval: 5s
timeout: 5s
retries: 30
llamastack:
depends_on:
text-generation-inference:
condition: service_healthy
image: llamastack/distribution-tgi
network_mode: "host"
volumes:
- ~/.llama:/root/.llama
# Link to TGI run.yaml file
- ./run.yaml:/root/my-run.yaml
ports:
- "5000:5000"
# Hack: wait for TGI server to start before starting docker
entrypoint: bash -c "sleep 60; python -m llama_stack.distribution.server.server --yaml_config /root/my-run.yaml"
restart_policy:
condition: on-failure
delay: 3s
max_attempts: 5
window: 60s

View file

@ -1,46 +0,0 @@
version: '2'
built_at: '2024-10-08T17:40:45.325529'
image_name: local
docker_image: null
conda_env: local
apis:
- shields
- agents
- models
- memory
- memory_banks
- inference
- safety
providers:
inference:
- provider_id: tgi0
provider_type: remote::tgi
config:
url: http://127.0.0.1:5009
safety:
- provider_id: meta0
provider_type: meta-reference
config:
llama_guard_shield:
model: Llama-Guard-3-1B
excluded_categories: []
disable_input_check: false
disable_output_check: false
prompt_guard_shield:
model: Prompt-Guard-86M
memory:
- provider_id: meta0
provider_type: meta-reference
config: {}
agents:
- provider_id: meta0
provider_type: meta-reference
config:
persistence_store:
namespace: null
type: sqlite
db_path: ~/.llama/runtime/kvstore.db
telemetry:
- provider_id: meta0
provider_type: meta-reference
config: {}

View file

@ -0,0 +1 @@
../../llama_stack/templates/tgi/run-with-safety.yaml

distributions/tgi/run.yaml Symbolic link
View file

@ -0,0 +1 @@
../../llama_stack/templates/tgi/run.yaml

View file

@ -1,47 +0,0 @@
version: '2'
built_at: '2024-10-08T17:40:45.325529'
image_name: local
docker_image: null
conda_env: local
apis:
- shields
- agents
- models
- memory
- memory_banks
- inference
- safety
providers:
inference:
- provider_id: together0
provider_type: remote::together
config:
url: https://api.together.xyz/v1
# api_key: <ENTER_YOUR_API_KEY>
safety:
- provider_id: meta0
provider_type: meta-reference
config:
llama_guard_shield:
model: Llama-Guard-3-1B
excluded_categories: []
disable_input_check: false
disable_output_check: false
prompt_guard_shield:
model: Prompt-Guard-86M
memory:
- provider_id: meta0
provider_type: remote::weaviate
config: {}
agents:
- provider_id: meta0
provider_type: meta-reference
config:
persistence_store:
namespace: null
type: sqlite
db_path: ~/.llama/runtime/kvstore.db
telemetry:
- provider_id: meta0
provider_type: meta-reference
config: {}

View file

@ -0,0 +1 @@
../../llama_stack/templates/together/run.yaml

View file

@ -1 +0,0 @@
../../llama_stack/templates/vllm/build.yaml

View file

@ -0,0 +1,796 @@
{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
" let's explore how to have a conversation about images using the Memory API! This section will show you how to:\n",
"1. Load and prepare images for the API\n",
"2. Send image-based queries\n",
"3. Create an interactive chat loop with images\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import asyncio\n",
"import base64\n",
"import mimetypes\n",
"from pathlib import Path\n",
"from typing import Optional, Union\n",
"\n",
"from llama_stack_client import LlamaStackClient\n",
"from llama_stack_client.types import UserMessage\n",
"from llama_stack_client.lib.inference.event_logger import EventLogger\n",
"from termcolor import cprint\n",
"\n",
"# Helper function to convert image to data URL\n",
"def image_to_data_url(file_path: Union[str, Path]) -> str:\n",
" \"\"\"Convert an image file to a data URL format.\n",
"\n",
" Args:\n",
" file_path: Path to the image file\n",
"\n",
" Returns:\n",
" str: Data URL containing the encoded image\n",
" \"\"\"\n",
" file_path = Path(file_path)\n",
" if not file_path.exists():\n",
" raise FileNotFoundError(f\"Image not found: {file_path}\")\n",
"\n",
" mime_type, _ = mimetypes.guess_type(str(file_path))\n",
" if mime_type is None:\n",
" raise ValueError(\"Could not determine MIME type of the image\")\n",
"\n",
" with open(file_path, \"rb\") as image_file:\n",
" encoded_string = base64.b64encode(image_file.read()).decode(\"utf-8\")\n",
"\n",
" return f\"data:{mime_type};base64,{encoded_string}\""
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 2. Create an Interactive Image Chat\n",
"\n",
"Let's create a function that enables back-and-forth conversation about an image:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from IPython.display import Image, display\n",
"import ipywidgets as widgets\n",
"\n",
"# Display the image we'll be chatting about\n",
"image_path = \"your_image.jpg\" # Replace with your image path\n",
"display(Image(filename=image_path))\n",
"\n",
"# Initialize the client\n",
"client = LlamaStackClient(\n",
" base_url=f\"http://localhost:8000\", # Adjust host/port as needed\n",
")\n",
"\n",
"# Create chat interface\n",
"output = widgets.Output()\n",
"text_input = widgets.Text(\n",
" value='',\n",
" placeholder='Type your question about the image...',\n",
" description='Ask:',\n",
" disabled=False\n",
")\n",
"\n",
"# Display interface\n",
"display(text_input, output)\n",
"\n",
"# Handle chat interaction\n",
"async def on_submit(change):\n",
" with output:\n",
" question = text_input.value\n",
" if question.lower() == 'exit':\n",
" print(\"Chat ended.\")\n",
" return\n",
"\n",
" message = UserMessage(\n",
" role=\"user\",\n",
" content=[\n",
" {\"image\": {\"uri\": image_to_data_url(image_path)}},\n",
" question,\n",
" ],\n",
" )\n",
"\n",
" print(f\"\\nUser> {question}\")\n",
" response = client.inference.chat_completion(\n",
" messages=[message],\n",
" model=\"Llama3.2-11B-Vision-Instruct\",\n",
" stream=True,\n",
" )\n",
"\n",
" print(\"Assistant> \", end='')\n",
" async for log in EventLogger().log(response):\n",
" log.print()\n",
"\n",
" text_input.value = '' # Clear input after sending\n",
"\n",
"text_input.on_submit(lambda x: asyncio.create_task(on_submit(x)))"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Tool Calling"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"In this section, we'll explore how to enhance your applications with tool calling capabilities. We'll cover:\n",
"1. Setting up and using the Brave Search API\n",
"2. Creating custom tools\n",
"3. Configuring tool prompts and safety settings"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import asyncio\n",
"import os\n",
"from typing import Dict, List, Optional\n",
"from dotenv import load_dotenv\n",
"\n",
"from llama_stack_client import LlamaStackClient\n",
"from llama_stack_client.lib.agents.agent import Agent\n",
"from llama_stack_client.lib.agents.event_logger import EventLogger\n",
"from llama_stack_client.types.agent_create_params import (\n",
" AgentConfig,\n",
" AgentConfigToolSearchToolDefinition,\n",
")\n",
"\n",
"# Load environment variables\n",
"load_dotenv()\n",
"\n",
"# Helper function to create an agent with tools\n",
"async def create_tool_agent(\n",
" client: LlamaStackClient,\n",
" tools: List[Dict],\n",
" instructions: str = \"You are a helpful assistant\",\n",
" model: str = \"Llama3.1-8B-Instruct\",\n",
") -> Agent:\n",
" \"\"\"Create an agent with specified tools.\"\"\"\n",
" agent_config = AgentConfig(\n",
" model=model,\n",
" instructions=instructions,\n",
" sampling_params={\n",
" \"strategy\": \"greedy\",\n",
" \"temperature\": 1.0,\n",
" \"top_p\": 0.9,\n",
" },\n",
" tools=tools,\n",
" tool_choice=\"auto\",\n",
" tool_prompt_format=\"json\",\n",
" input_shields=[\"Llama-Guard-3-1B\"],\n",
" output_shields=[\"Llama-Guard-3-1B\"],\n",
" enable_session_persistence=True,\n",
" )\n",
"\n",
" return Agent(client, agent_config)"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"First, create a `.env` file in your notebook directory with your Brave Search API key:\n",
"\n",
"```\n",
"BRAVE_SEARCH_API_KEY=your_key_here\n",
"```\n"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"async def create_search_agent(client: LlamaStackClient) -> Agent:\n",
" \"\"\"Create an agent with Brave Search capability.\"\"\"\n",
" search_tool = AgentConfigToolSearchToolDefinition(\n",
" type=\"brave_search\",\n",
" engine=\"brave\",\n",
" api_key=os.getenv(\"BRAVE_SEARCH_API_KEY\"),\n",
" )\n",
"\n",
" return await create_tool_agent(\n",
" client=client,\n",
" tools=[search_tool],\n",
" instructions=\"\"\"\n",
" You are a research assistant that can search the web.\n",
" Always cite your sources with URLs when providing information.\n",
" Format your responses as:\n",
"\n",
" FINDINGS:\n",
" [Your summary here]\n",
"\n",
" SOURCES:\n",
" - [Source title](URL)\n",
" \"\"\"\n",
" )\n",
"\n",
"# Example usage\n",
"async def search_example():\n",
" client = LlamaStackClient(base_url=\"http://localhost:8000\")\n",
" agent = await create_search_agent(client)\n",
"\n",
" # Create a session\n",
" session_id = agent.create_session(\"search-session\")\n",
"\n",
" # Example queries\n",
" queries = [\n",
" \"What are the latest developments in quantum computing?\",\n",
" \"Who won the most recent Super Bowl?\",\n",
" ]\n",
"\n",
" for query in queries:\n",
" print(f\"\\nQuery: {query}\")\n",
" print(\"-\" * 50)\n",
"\n",
" response = agent.create_turn(\n",
" messages=[{\"role\": \"user\", \"content\": query}],\n",
" session_id=session_id,\n",
" )\n",
"\n",
" async for log in EventLogger().log(response):\n",
" log.print()\n",
"\n",
"# Run the example (in Jupyter, use asyncio.run())\n",
"await search_example()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 3. Custom Tool Creation\n",
"\n",
"Let's create a custom weather tool:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"from typing import TypedDict, Optional\n",
"from datetime import datetime\n",
"\n",
"# Define tool types\n",
"class WeatherInput(TypedDict):\n",
" location: str\n",
" date: Optional[str]\n",
"\n",
"class WeatherOutput(TypedDict):\n",
" temperature: float\n",
" conditions: str\n",
" humidity: float\n",
"\n",
"class WeatherTool:\n",
" \"\"\"Example custom tool for weather information.\"\"\"\n",
"\n",
" def __init__(self, api_key: Optional[str] = None):\n",
" self.api_key = api_key\n",
"\n",
" async def get_weather(self, location: str, date: Optional[str] = None) -> WeatherOutput:\n",
" \"\"\"Simulate getting weather data (replace with actual API call).\"\"\"\n",
" # Mock implementation\n",
" return {\n",
" \"temperature\": 72.5,\n",
" \"conditions\": \"partly cloudy\",\n",
" \"humidity\": 65.0\n",
" }\n",
"\n",
" async def __call__(self, input_data: WeatherInput) -> WeatherOutput:\n",
" \"\"\"Make the tool callable with structured input.\"\"\"\n",
" return await self.get_weather(\n",
" location=input_data[\"location\"],\n",
" date=input_data.get(\"date\")\n",
" )\n",
"\n",
"async def create_weather_agent(client: LlamaStackClient) -> Agent:\n",
" \"\"\"Create an agent with weather tool capability.\"\"\"\n",
" weather_tool = {\n",
" \"type\": \"function\",\n",
" \"function\": {\n",
" \"name\": \"get_weather\",\n",
" \"description\": \"Get weather information for a location\",\n",
" \"parameters\": {\n",
" \"type\": \"object\",\n",
" \"properties\": {\n",
" \"location\": {\n",
" \"type\": \"string\",\n",
" \"description\": \"City or location name\"\n",
" },\n",
" \"date\": {\n",
" \"type\": \"string\",\n",
" \"description\": \"Optional date (YYYY-MM-DD)\",\n",
" \"format\": \"date\"\n",
" }\n",
" },\n",
" \"required\": [\"location\"]\n",
" }\n",
" },\n",
" \"implementation\": WeatherTool()\n",
" }\n",
"\n",
" return await create_tool_agent(\n",
" client=client,\n",
" tools=[weather_tool],\n",
" instructions=\"\"\"\n",
" You are a weather assistant that can provide weather information.\n",
" Always specify the location clearly in your responses.\n",
" Include both temperature and conditions in your summaries.\n",
" \"\"\"\n",
" )\n",
"\n",
"# Example usage\n",
"async def weather_example():\n",
" client = LlamaStackClient(base_url=\"http://localhost:8000\")\n",
" agent = await create_weather_agent(client)\n",
"\n",
" session_id = agent.create_session(\"weather-session\")\n",
"\n",
" queries = [\n",
" \"What's the weather like in San Francisco?\",\n",
" \"Tell me the weather in Tokyo tomorrow\",\n",
" ]\n",
"\n",
" for query in queries:\n",
" print(f\"\\nQuery: {query}\")\n",
" print(\"-\" * 50)\n",
"\n",
" response = agent.create_turn(\n",
" messages=[{\"role\": \"user\", \"content\": query}],\n",
" session_id=session_id,\n",
" )\n",
"\n",
" async for log in EventLogger().log(response):\n",
" log.print()\n",
"\n",
"# Run the example\n",
"await weather_example()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Multi-Tool Agent"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"async def create_multi_tool_agent(client: LlamaStackClient) -> Agent:\n",
" \"\"\"Create an agent with multiple tools.\"\"\"\n",
" tools = [\n",
" # Brave Search tool\n",
" AgentConfigToolSearchToolDefinition(\n",
" type=\"brave_search\",\n",
" engine=\"brave\",\n",
" api_key=os.getenv(\"BRAVE_SEARCH_API_KEY\"),\n",
" ),\n",
" # Weather tool\n",
" {\n",
" \"type\": \"function\",\n",
" \"function\": {\n",
" \"name\": \"get_weather\",\n",
" \"description\": \"Get weather information for a location\",\n",
" \"parameters\": {\n",
" \"type\": \"object\",\n",
" \"properties\": {\n",
" \"location\": {\"type\": \"string\"},\n",
" \"date\": {\"type\": \"string\", \"format\": \"date\"}\n",
" },\n",
" \"required\": [\"location\"]\n",
" }\n",
" },\n",
" \"implementation\": WeatherTool()\n",
" }\n",
" ]\n",
"\n",
" return await create_tool_agent(\n",
" client=client,\n",
" tools=tools,\n",
" instructions=\"\"\"\n",
" You are an assistant that can search the web and check weather information.\n",
" Use the appropriate tool based on the user's question.\n",
" For weather queries, always specify location and conditions.\n",
" For web searches, always cite your sources.\n",
" \"\"\"\n",
" )\n",
"\n",
"# Interactive example with multi-tool agent\n",
"async def interactive_multi_tool():\n",
" client = LlamaStackClient(base_url=\"http://localhost:8000\")\n",
" agent = await create_multi_tool_agent(client)\n",
" session_id = agent.create_session(\"interactive-session\")\n",
"\n",
" print(\"🤖 Multi-tool Agent Ready! (type 'exit' to quit)\")\n",
" print(\"Example questions:\")\n",
" print(\"- What's the weather in Paris and what events are happening there?\")\n",
" print(\"- Tell me about recent space discoveries and the weather on Mars\")\n",
"\n",
" while True:\n",
" query = input(\"\\nYour question: \")\n",
" if query.lower() == 'exit':\n",
" break\n",
"\n",
" print(\"\\nThinking...\")\n",
" try:\n",
" response = agent.create_turn(\n",
" messages=[{\"role\": \"user\", \"content\": query}],\n",
" session_id=session_id,\n",
" )\n",
"\n",
" async for log in EventLogger().log(response):\n",
" log.print()\n",
" except Exception as e:\n",
" print(f\"Error: {e}\")\n",
"\n",
"# Run interactive example\n",
"await interactive_multi_tool()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## Memory "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"Getting Started with Memory API Tutorial 🚀\n",
"Welcome! This interactive tutorial will guide you through using the Memory API, a powerful tool for document storage and retrieval. Whether you're new to vector databases or an experienced developer, this notebook will help you understand the basics and get up and running quickly.\n",
"What you'll learn:\n",
"\n",
"How to set up and configure the Memory API client\n",
"Creating and managing memory banks (vector stores)\n",
"Different ways to insert documents into the system\n",
"How to perform intelligent queries on your documents\n",
"\n",
"Prerequisites:\n",
"\n",
"Basic Python knowledge\n",
"A running instance of the Memory API server (we'll use localhost in this tutorial)\n",
"\n",
"Let's start by installing the required packages:"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Install the client library and a helper package for colored output\n",
"!pip install llama-stack-client termcolor\n",
"\n",
"# 💡 Note: If you're running this in a new environment, you might need to restart\n",
"# your kernel after installation"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"1. Initial Setup\n",
"First, we'll import the necessary libraries and set up some helper functions. Let's break down what each import does:\n",
"\n",
"llama_stack_client: Our main interface to the Memory API\n",
"base64: Helps us encode files for transmission\n",
"mimetypes: Determines file types automatically\n",
"termcolor: Makes our output prettier with colors\n",
"\n",
"❓ Question: Why do we need to convert files to data URLs?\n",
"Answer: Data URLs allow us to embed file contents directly in our requests, making it easier to transmit files to the API without needing separate file uploads."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"import base64\n",
"import json\n",
"import mimetypes\n",
"import os\n",
"from pathlib import Path\n",
"\n",
"from llama_stack_client import LlamaStackClient\n",
"from llama_stack_client.types.memory_insert_params import Document\n",
"from termcolor import cprint\n",
"\n",
"# Helper function to convert files to data URLs\n",
"def data_url_from_file(file_path: str) -> str:\n",
" \"\"\"Convert a file to a data URL for API transmission\n",
"\n",
" Args:\n",
" file_path (str): Path to the file to convert\n",
"\n",
" Returns:\n",
" str: Data URL containing the file's contents\n",
"\n",
" Example:\n",
" >>> url = data_url_from_file('example.txt')\n",
" >>> print(url[:30]) # Preview the start of the URL\n",
" 'data:text/plain;base64,SGVsbG8='\n",
" \"\"\"\n",
" if not os.path.exists(file_path):\n",
" raise FileNotFoundError(f\"File not found: {file_path}\")\n",
"\n",
" with open(file_path, \"rb\") as file:\n",
" file_content = file.read()\n",
"\n",
" base64_content = base64.b64encode(file_content).decode(\"utf-8\")\n",
" mime_type, _ = mimetypes.guess_type(file_path)\n",
"\n",
" data_url = f\"data:{mime_type};base64,{base64_content}\"\n",
" return data_url"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"2. Initialize Client and Create Memory Bank\n",
"Now we'll set up our connection to the Memory API and create our first memory bank. A memory bank is like a specialized database that stores document embeddings for semantic search.\n",
"❓ Key Concepts:\n",
"\n",
"embedding_model: The model used to convert text into vector representations\n",
"chunk_size: How large each piece of text should be when splitting documents\n",
"overlap_size: How much overlap between chunks (helps maintain context)\n",
"\n",
"✨ Pro Tip: Choose your chunk size based on your use case. Smaller chunks (256-512 tokens) are better for precise retrieval, while larger chunks (1024+ tokens) maintain more context."
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Configure connection parameters\n",
"HOST = \"localhost\" # Replace with your host if using a remote server\n",
"PORT = 8000 # Replace with your port if different\n",
"\n",
"# Initialize client\n",
"client = LlamaStackClient(\n",
" base_url=f\"http://{HOST}:{PORT}\",\n",
")\n",
"\n",
"# Let's see what providers are available\n",
"# Providers determine where and how your data is stored\n",
"providers = client.providers.list()\n",
"print(\"Available providers:\")\n",
"print(json.dumps(providers, indent=2))\n",
"\n",
"# Create a memory bank with optimized settings for general use\n",
"client.memory_banks.register(\n",
" memory_bank={\n",
" \"identifier\": \"tutorial_bank\", # A unique name for your memory bank\n",
" \"embedding_model\": \"all-MiniLM-L6-v2\", # A lightweight but effective model\n",
" \"chunk_size_in_tokens\": 512, # Good balance between precision and context\n",
" \"overlap_size_in_tokens\": 64, # Helps maintain context between chunks\n",
" \"provider_id\": providers[\"memory\"][0].provider_id, # Use the first available provider\n",
" }\n",
")\n",
"\n",
"# Let's verify our memory bank was created\n",
"memory_banks = client.memory_banks.list()\n",
"print(\"\\nRegistered memory banks:\")\n",
"print(json.dumps(memory_banks, indent=2))\n",
"\n",
"# 🎯 Exercise: Try creating another memory bank with different settings!\n",
"# What happens if you try to create a bank with the same identifier?"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"3. Insert Documents\n",
"The Memory API supports multiple ways to add documents. We'll demonstrate two common approaches:\n",
"\n",
"Loading documents from URLs\n",
"Loading documents from local files\n",
"\n",
"❓ Important Concepts:\n",
"\n",
"Each document needs a unique document_id\n",
"Metadata helps organize and filter documents later\n",
"The API automatically processes and chunks documents"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Example URLs to documentation\n",
"# 💡 Replace these with your own URLs or use the examples\n",
"urls = [\n",
" \"memory_optimizations.rst\",\n",
" \"chat.rst\",\n",
" \"llama3.rst\",\n",
"]\n",
"\n",
"# Create documents from URLs\n",
"# We add metadata to help organize our documents\n",
"url_documents = [\n",
" Document(\n",
" document_id=f\"url-doc-{i}\", # Unique ID for each document\n",
" content=f\"https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/{url}\",\n",
" mime_type=\"text/plain\",\n",
" metadata={\"source\": \"url\", \"filename\": url}, # Metadata helps with organization\n",
" )\n",
" for i, url in enumerate(urls)\n",
"]\n",
"\n",
"# Example with local files\n",
"# 💡 Replace these with your actual files\n",
"local_files = [\"example.txt\", \"readme.md\"]\n",
"file_documents = [\n",
" Document(\n",
" document_id=f\"file-doc-{i}\",\n",
" content=data_url_from_file(path),\n",
" metadata={\"source\": \"local\", \"filename\": path},\n",
" )\n",
" for i, path in enumerate(local_files)\n",
" if os.path.exists(path)\n",
"]\n",
"\n",
"# Combine all documents\n",
"all_documents = url_documents + file_documents\n",
"\n",
"# Insert documents into memory bank\n",
"response = client.memory.insert(\n",
" bank_id=\"tutorial_bank\",\n",
" documents=all_documents,\n",
")\n",
"\n",
"print(\"Documents inserted successfully!\")\n",
"\n",
"# 🎯 Exercise: Try adding your own documents!\n",
"# - What happens if you try to insert a document with an existing ID?\n",
"# - What other metadata might be useful to add?"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"4. Query the Memory Bank\n",
"Now for the exciting part - querying our documents! The Memory API uses semantic search to find relevant content based on meaning, not just keywords.\n",
"❓ Understanding Scores:\n",
"\n",
"Scores range from 0 to 1, with 1 being the most relevant\n",
"Generally, scores above 0.7 indicate strong relevance\n",
"Consider your use case when deciding on score thresholds"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"def print_query_results(query: str):\n",
" \"\"\"Helper function to print query results in a readable format\n",
"\n",
" Args:\n",
" query (str): The search query to execute\n",
" \"\"\"\n",
" print(f\"\\nQuery: {query}\")\n",
" print(\"-\" * 50)\n",
"\n",
" response = client.memory.query(\n",
" bank_id=\"tutorial_bank\",\n",
" query=[query], # The API accepts multiple queries at once!\n",
" )\n",
"\n",
" for i, (chunk, score) in enumerate(zip(response.chunks, response.scores)):\n",
" print(f\"\\nResult {i+1} (Score: {score:.3f})\")\n",
" print(\"=\" * 40)\n",
" print(chunk)\n",
" print(\"=\" * 40)\n",
"\n",
"# Let's try some example queries\n",
"queries = [\n",
" \"How do I use LoRA?\", # Technical question\n",
" \"Tell me about memory optimizations\", # General topic\n",
" \"What are the key features of Llama 3?\" # Product-specific\n",
"]\n",
"\n",
"for query in queries:\n",
" print_query_results(query)\n",
"\n",
"# 🎯 Exercises:\n",
"# 1. Try writing your own queries! What works well? What doesn't?\n",
"# 2. How do different phrasings of the same question affect results?\n",
"# 3. What happens if you query for content that isn't in your documents?"
]
},
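  {
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "Earlier we noted that scores above roughly 0.7 usually indicate strong relevance. As a minimal sketch (the 0.7 cutoff and the `query_with_threshold` helper below are illustrative choices, not part of the Memory API), you can apply such a threshold client-side to the scores returned by `client.memory.query`:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Minimal sketch: client-side score thresholding (0.7 is an illustrative cutoff, not an API default)\n",
    "def query_with_threshold(query: str, min_score: float = 0.7):\n",
    "    \"\"\"Return only the (chunk, score) pairs at or above min_score.\"\"\"\n",
    "    response = client.memory.query(\n",
    "        bank_id=\"tutorial_bank\",\n",
    "        query=[query],\n",
    "    )\n",
    "    return [\n",
    "        (chunk, score)\n",
    "        for chunk, score in zip(response.chunks, response.scores)\n",
    "        if score >= min_score\n",
    "    ]\n",
    "\n",
    "strong_hits = query_with_threshold(\"How do I use LoRA?\")\n",
    "print(f\"{len(strong_hits)} strong matches\")\n",
    "for chunk, score in strong_hits:\n",
    "    print(f\"Score: {score:.3f}\")\n",
    "    print(chunk)"
   ]
  },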
{
"cell_type": "markdown",
"metadata": {},
"source": [
"5. Advanced Usage: Query with Metadata Filtering\n",
"One powerful feature is the ability to filter results based on metadata. This helps when you want to search within specific subsets of your documents.\n",
"❓ Use Cases for Metadata Filtering:\n",
"\n",
"Search within specific document types\n",
"Filter by date ranges\n",
"Limit results to certain authors or sources"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": [
"# Query with metadata filter\n",
"response = client.memory.query(\n",
" bank_id=\"tutorial_bank\",\n",
" query=[\"Tell me about optimization\"],\n",
" metadata_filter={\"source\": \"url\"} # Only search in URL documents\n",
")\n",
"\n",
"print(\"\\nFiltered Query Results:\")\n",
"print(\"-\" * 50)\n",
"for chunk, score in zip(response.chunks, response.scores):\n",
" print(f\"Score: {score:.3f}\")\n",
" print(f\"Chunk:\\n{chunk}\\n\")\n",
"\n",
"# 🎯 Advanced Exercises:\n",
"# 1. Try combining multiple metadata filters\n",
"# 2. Compare results with and without filters\n",
"# 3. What happens with non-existent metadata fields?"
]
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"name": "python",
"version": "3.12.5"
}
},
"nbformat": 4,
"nbformat_minor": 2
}

docs/_static/safety_system.webp: new binary file (31 KiB), not shown.

@ -31,60 +31,10 @@ from .strong_typing.schema import json_schema_type
schema_utils.json_schema_type = json_schema_type schema_utils.json_schema_type = json_schema_type
from llama_models.llama3.api.datatypes import * # noqa: F403 # this line needs to be here to ensure json_schema_type has been altered before
from llama_stack.apis.agents import * # noqa: F403 # the imports use the annotation
from llama_stack.apis.datasets import * # noqa: F403 from llama_stack.apis.version import LLAMA_STACK_API_VERSION # noqa: E402
from llama_stack.apis.datasetio import * # noqa: F403 from llama_stack.distribution.stack import LlamaStack # noqa: E402
from llama_stack.apis.scoring import * # noqa: F403
from llama_stack.apis.scoring_functions import * # noqa: F403
from llama_stack.apis.eval import * # noqa: F403
from llama_stack.apis.inference import * # noqa: F403
from llama_stack.apis.batch_inference import * # noqa: F403
from llama_stack.apis.memory import * # noqa: F403
from llama_stack.apis.telemetry import * # noqa: F403
from llama_stack.apis.post_training import * # noqa: F403
from llama_stack.apis.synthetic_data_generation import * # noqa: F403
from llama_stack.apis.safety import * # noqa: F403
from llama_stack.apis.models import * # noqa: F403
from llama_stack.apis.memory_banks import * # noqa: F403
from llama_stack.apis.shields import * # noqa: F403
from llama_stack.apis.inspect import * # noqa: F403
class LlamaStack(
MemoryBanks,
Inference,
BatchInference,
Agents,
Safety,
SyntheticDataGeneration,
Datasets,
Telemetry,
PostTraining,
Memory,
Eval,
Scoring,
ScoringFunctions,
DatasetIO,
Models,
Shields,
Inspect,
):
pass
# TODO: this should be fixed in the generator itself so it reads appropriate annotations
STREAMING_ENDPOINTS = [
"/agents/turn/create",
"/inference/chat_completion",
]
def patch_sse_stream_responses(spec: Specification):
for path, path_item in spec.document.paths.items():
if path in STREAMING_ENDPOINTS:
content = path_item.post.responses["200"].content.pop("application/json")
path_item.post.responses["200"].content["text/event-stream"] = content
def main(output_dir: str): def main(output_dir: str):
@ -103,7 +53,7 @@ def main(output_dir: str):
server=Server(url="http://any-hosted-llama-stack.com"), server=Server(url="http://any-hosted-llama-stack.com"),
info=Info( info=Info(
title="[DRAFT] Llama Stack Specification", title="[DRAFT] Llama Stack Specification",
version="0.0.1", version=LLAMA_STACK_API_VERSION,
description="""This is the specification of the llama stack that provides description="""This is the specification of the llama stack that provides
a set of endpoints and their corresponding interfaces that are tailored to a set of endpoints and their corresponding interfaces that are tailored to
best leverage Llama Models. The specification is still in draft and subject to change. best leverage Llama Models. The specification is still in draft and subject to change.
@ -113,8 +63,6 @@ def main(output_dir: str):
), ),
) )
patch_sse_stream_responses(spec)
with open(output_dir / "llama-stack-spec.yaml", "w", encoding="utf-8") as fp: with open(output_dir / "llama-stack-spec.yaml", "w", encoding="utf-8") as fp:
yaml.dump(spec.get_json(), fp, allow_unicode=True) yaml.dump(spec.get_json(), fp, allow_unicode=True)


@ -4,6 +4,7 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
import collections
import hashlib import hashlib
import ipaddress import ipaddress
import typing import typing
@ -176,9 +177,20 @@ class ContentBuilder:
) -> Dict[str, MediaType]: ) -> Dict[str, MediaType]:
"Creates the content subtree for a request or response." "Creates the content subtree for a request or response."
def has_iterator_type(t):
if typing.get_origin(t) is typing.Union:
return any(has_iterator_type(a) for a in typing.get_args(t))
else:
# TODO: needs a proper fix where we let all types correctly flow upwards
# and then test against AsyncIterator
return "StreamChunk" in str(t)
if is_generic_list(payload_type): if is_generic_list(payload_type):
media_type = "application/jsonl" media_type = "application/jsonl"
item_type = unwrap_generic_list(payload_type) item_type = unwrap_generic_list(payload_type)
elif has_iterator_type(payload_type):
item_type = payload_type
media_type = "text/event-stream"
else: else:
media_type = "application/json" media_type = "application/json"
item_type = payload_type item_type = payload_type
@ -190,7 +202,9 @@ class ContentBuilder:
) -> MediaType: ) -> MediaType:
schema = self.schema_builder.classdef_to_ref(item_type) schema = self.schema_builder.classdef_to_ref(item_type)
if self.schema_transformer: if self.schema_transformer:
schema_transformer: Callable[[SchemaOrRef], SchemaOrRef] = self.schema_transformer # type: ignore schema_transformer: Callable[[SchemaOrRef], SchemaOrRef] = (
self.schema_transformer
)
schema = schema_transformer(schema) schema = schema_transformer(schema)
if not examples: if not examples:
@ -618,6 +632,7 @@ class Generator:
raise NotImplementedError(f"unknown HTTP method: {op.http_method}") raise NotImplementedError(f"unknown HTTP method: {op.http_method}")
route = op.get_route() route = op.get_route()
print(f"route: {route}")
if route in paths: if route in paths:
paths[route].update(pathItem) paths[route].update(pathItem)
else: else:
@ -671,6 +686,8 @@ class Generator:
for extra_tag_group in extra_tag_groups.values(): for extra_tag_group in extra_tag_groups.values():
tags.extend(extra_tag_group) tags.extend(extra_tag_group)
tags = sorted(tags, key=lambda t: t.name)
tag_groups = [] tag_groups = []
if operation_tags: if operation_tags:
tag_groups.append( tag_groups.append(


@ -12,6 +12,8 @@ import uuid
from dataclasses import dataclass from dataclasses import dataclass
from typing import Any, Callable, Dict, Iterable, Iterator, List, Optional, Tuple, Union from typing import Any, Callable, Dict, Iterable, Iterator, List, Optional, Tuple, Union
from llama_stack.apis.version import LLAMA_STACK_API_VERSION
from termcolor import colored from termcolor import colored
from ..strong_typing.inspection import ( from ..strong_typing.inspection import (
@ -111,9 +113,12 @@ class EndpointOperation:
def get_route(self) -> str: def get_route(self) -> str:
if self.route is not None: if self.route is not None:
return self.route assert (
"_" not in self.route
), f"route should not contain underscores: {self.route}"
return "/".join(["", LLAMA_STACK_API_VERSION, self.route.lstrip("/")])
route_parts = ["", self.name] route_parts = ["", LLAMA_STACK_API_VERSION, self.name]
for param_name, _ in self.path_params: for param_name, _ in self.path_params:
route_parts.append("{" + param_name + "}") route_parts.append("{" + param_name + "}")
return "/".join(route_parts) return "/".join(route_parts)


@ -358,6 +358,7 @@ def unwrap_union_types(typ: object) -> Tuple[object, ...]:
:returns: The inner types `T1`, `T2`, etc. :returns: The inner types `T1`, `T2`, etc.
""" """
typ = unwrap_annotated_type(typ)
return _unwrap_union_types(typ) return _unwrap_union_types(typ)

Two file diffs suppressed because they are too large to display.


@ -6,8 +6,8 @@ This guide contains references to walk you through adding a new API provider.
1. First, decide which API your provider falls into (e.g. Inference, Safety, Agents, Memory). 1. First, decide which API your provider falls into (e.g. Inference, Safety, Agents, Memory).
2. Decide whether your provider is a remote provider, or inline implmentation. A remote provider is a provider that makes a remote request to an service. An inline provider is a provider where implementation is executed locally. Checkout the examples, and follow the structure to add your own API provider. Please find the following code pointers: 2. Decide whether your provider is a remote provider, or inline implmentation. A remote provider is a provider that makes a remote request to an service. An inline provider is a provider where implementation is executed locally. Checkout the examples, and follow the structure to add your own API provider. Please find the following code pointers:
- [Inference Remote Adapter](https://github.com/meta-llama/llama-stack/tree/docs/llama_stack/providers/remote/inference) - [Remote Adapters](https://github.com/meta-llama/llama-stack/tree/main/llama_stack/providers/remote)
- [Inference Inline Provider](https://github.com/meta-llama/llama-stack/tree/docs/llama_stack/providers/inline/meta_reference/inference) - [Inline Providers](https://github.com/meta-llama/llama-stack/tree/main/llama_stack/providers/inline)
3. [Build a Llama Stack distribution](https://llama-stack.readthedocs.io/en/latest/distribution_dev/building_distro.html) with your API provider. 3. [Build a Llama Stack distribution](https://llama-stack.readthedocs.io/en/latest/distribution_dev/building_distro.html) with your API provider.
4. Test your code! 4. Test your code!


@ -35,14 +35,14 @@ the provider types (implementations) you want to use for these APIs.
Tip: use <TAB> to see options for the providers. Tip: use <TAB> to see options for the providers.
> Enter provider for API inference: meta-reference > Enter provider for API inference: inline::meta-reference
> Enter provider for API safety: meta-reference > Enter provider for API safety: inline::llama-guard
> Enter provider for API agents: meta-reference > Enter provider for API agents: inline::meta-reference
> Enter provider for API memory: meta-reference > Enter provider for API memory: inline::faiss
> Enter provider for API datasetio: meta-reference > Enter provider for API datasetio: inline::meta-reference
> Enter provider for API scoring: meta-reference > Enter provider for API scoring: inline::meta-reference
> Enter provider for API eval: meta-reference > Enter provider for API eval: inline::meta-reference
> Enter provider for API telemetry: meta-reference > Enter provider for API telemetry: inline::meta-reference
> (Optional) Enter a short description for your Llama Stack: > (Optional) Enter a short description for your Llama Stack:
@ -203,8 +203,8 @@ distribution_spec:
description: Like local, but use ollama for running LLM inference description: Like local, but use ollama for running LLM inference
providers: providers:
inference: remote::ollama inference: remote::ollama
memory: meta-reference memory: inline::faiss
safety: meta-reference safety: inline::llama-guard
agents: meta-reference agents: meta-reference
telemetry: meta-reference telemetry: meta-reference
image_type: conda image_type: conda


@ -1,64 +0,0 @@
# Fireworks Distribution
The `llamastack/distribution-fireworks` distribution consists of the following provider configurations.
| **API** | **Inference** | **Agents** | **Memory** | **Safety** | **Telemetry** |
|----------------- |--------------- |---------------- |-------------------------------------------------- |---------------- |---------------- |
| **Provider(s)** | remote::fireworks | meta-reference | meta-reference | meta-reference | meta-reference |
### Step 0. Prerequisite
- Make sure you have access to a fireworks API Key. You can get one by visiting [fireworks.ai](https://fireworks.ai/)
### Step 1. Start the Distribution (Single Node CPU)
#### (Option 1) Start Distribution Via Docker
> [!NOTE]
> This assumes you have an hosted endpoint at Fireworks with API Key.
```
$ cd distributions/fireworks && docker compose up
```
Make sure in you `run.yaml` file, you inference provider is pointing to the correct Fireworks URL server endpoint. E.g.
```
inference:
- provider_id: fireworks
provider_type: remote::fireworks
config:
url: https://api.fireworks.ai/inference
api_key: <optional api key>
```
#### (Option 2) Start Distribution Via Conda
```bash
llama stack build --template fireworks --image-type conda
# -- modify run.yaml to a valid Fireworks server endpoint
llama stack run ./run.yaml
```
### (Optional) Model Serving
Use `llama-stack-client models list` to check the available models served by Fireworks.
```
$ llama-stack-client models list
+------------------------------+------------------------------+---------------+------------+
| identifier | llama_model | provider_id | metadata |
+==============================+==============================+===============+============+
| Llama3.1-8B-Instruct | Llama3.1-8B-Instruct | fireworks0 | {} |
+------------------------------+------------------------------+---------------+------------+
| Llama3.1-70B-Instruct | Llama3.1-70B-Instruct | fireworks0 | {} |
+------------------------------+------------------------------+---------------+------------+
| Llama3.1-405B-Instruct | Llama3.1-405B-Instruct | fireworks0 | {} |
+------------------------------+------------------------------+---------------+------------+
| Llama3.2-1B-Instruct | Llama3.2-1B-Instruct | fireworks0 | {} |
+------------------------------+------------------------------+---------------+------------+
| Llama3.2-3B-Instruct | Llama3.2-3B-Instruct | fireworks0 | {} |
+------------------------------+------------------------------+---------------+------------+
| Llama3.2-11B-Vision-Instruct | Llama3.2-11B-Vision-Instruct | fireworks0 | {} |
+------------------------------+------------------------------+---------------+------------+
| Llama3.2-90B-Vision-Instruct | Llama3.2-90B-Vision-Instruct | fireworks0 | {} |
+------------------------------+------------------------------+---------------+------------+
```


@ -1,15 +1,42 @@
# Remote-Hosted Distribution # Remote-Hosted Distribution
Remote Hosted distributions are distributions connecting to remote hosted services through Llama Stack server. Inference is done through remote providers. These are useful if you have an API key for a remote inference provider like Fireworks, Together, etc. Remote-Hosted distributions are available endpoints serving Llama Stack API that you can directly connect to.
| **Distribution** | **Llama Stack Docker** | Start This Distribution | **Inference** | **Agents** | **Memory** | **Safety** | **Telemetry** | | Distribution | Endpoint | Inference | Agents | Memory | Safety | Telemetry |
|:----------------: |:------------------------------------------: |:-----------------------: |:------------------: |:------------------: |:------------------: |:------------------: |:------------------: | |-------------|----------|-----------|---------|---------|---------|------------|
| Together | [llamastack/distribution-together](https://hub.docker.com/repository/docker/llamastack/distribution-together/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/getting_started/distributions/remote_hosted_distro/together.html) | remote::together | meta-reference | remote::weaviate | meta-reference | meta-reference | | Together | [https://llama-stack.together.ai](https://llama-stack.together.ai) | remote::together | meta-reference | remote::weaviate | meta-reference | meta-reference |
| Fireworks | [llamastack/distribution-fireworks](https://hub.docker.com/repository/docker/llamastack/distribution-fireworks/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/getting_started/distributions/remote_hosted_distro/fireworks.html) | remote::fireworks | meta-reference | remote::weaviate | meta-reference | meta-reference | | Fireworks | [https://llamastack-preview.fireworks.ai](https://llamastack-preview.fireworks.ai) | remote::fireworks | meta-reference | remote::weaviate | meta-reference | meta-reference |
```{toctree} ## Connecting to Remote-Hosted Distributions
:maxdepth: 1
fireworks You can use `llama-stack-client` to interact with these endpoints. For example, to list the available models served by the Fireworks endpoint:
together
```bash
$ pip install llama-stack-client
$ llama-stack-client configure --endpoint https://llamastack-preview.fireworks.ai
$ llama-stack-client models list
``` ```
You will see outputs:
```
$ llama-stack-client models list
+------------------------------+------------------------------+---------------+------------+
| identifier | llama_model | provider_id | metadata |
+==============================+==============================+===============+============+
| Llama3.1-8B-Instruct | Llama3.1-8B-Instruct | fireworks0 | {} |
+------------------------------+------------------------------+---------------+------------+
| Llama3.1-70B-Instruct | Llama3.1-70B-Instruct | fireworks0 | {} |
+------------------------------+------------------------------+---------------+------------+
| Llama3.1-405B-Instruct | Llama3.1-405B-Instruct | fireworks0 | {} |
+------------------------------+------------------------------+---------------+------------+
| Llama3.2-1B-Instruct | Llama3.2-1B-Instruct | fireworks0 | {} |
+------------------------------+------------------------------+---------------+------------+
| Llama3.2-3B-Instruct | Llama3.2-3B-Instruct | fireworks0 | {} |
+------------------------------+------------------------------+---------------+------------+
| Llama3.2-11B-Vision-Instruct | Llama3.2-11B-Vision-Instruct | fireworks0 | {} |
+------------------------------+------------------------------+---------------+------------+
| Llama3.2-90B-Vision-Instruct | Llama3.2-90B-Vision-Instruct | fireworks0 | {} |
+------------------------------+------------------------------+---------------+------------+
```
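
If you prefer Python over the CLI, a roughly equivalent sketch looks like the following. It assumes the `llama-stack-client` Python package is installed and that its `models.list()` method mirrors the CLI's `models list` command:

```python
from llama_stack_client import LlamaStackClient

# Point the client at the hosted Fireworks endpoint
client = LlamaStackClient(base_url="https://llamastack-preview.fireworks.ai")

# Each entry is expected to expose an `identifier` field, matching the CLI table above
for model in client.models.list():
    print(model.identifier)
```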
Checkout the [llama-stack-client-python](https://github.com/meta-llama/llama-stack-client-python/blob/main/docs/cli_reference.md) repo for more details on how to use the `llama-stack-client` CLI. Checkout [llama-stack-app](https://github.com/meta-llama/llama-stack-apps/tree/main) for examples applications built on top of Llama Stack.


@ -1,62 +0,0 @@
# Together Distribution
### Connect to a Llama Stack Together Endpoint
- You may connect to a hosted endpoint `https://llama-stack.together.ai`, serving a Llama Stack distribution
The `llamastack/distribution-together` distribution consists of the following provider configurations.
| **API** | **Inference** | **Agents** | **Memory** | **Safety** | **Telemetry** |
|----------------- |--------------- |---------------- |-------------------------------------------------- |---------------- |---------------- |
| **Provider(s)** | remote::together | meta-reference | meta-reference, remote::weaviate | meta-reference | meta-reference |
### Docker: Start the Distribution (Single Node CPU)
> [!NOTE]
> This assumes you have an hosted endpoint at Together with API Key.
```
$ cd distributions/together && docker compose up
```
Make sure in your `run.yaml` file, your inference provider is pointing to the correct Together URL server endpoint. E.g.
```
inference:
- provider_id: together
provider_type: remote::together
config:
url: https://api.together.xyz/v1
api_key: <optional api key>
```
### Conda llama stack run (Single Node CPU)
```bash
llama stack build --template together --image-type conda
# -- modify run.yaml to a valid Together server endpoint
llama stack run ./run.yaml
```
### (Optional) Update Model Serving Configuration
Use `llama-stack-client models list` to check the available models served by together.
```
$ llama-stack-client models list
+------------------------------+------------------------------+---------------+------------+
| identifier | llama_model | provider_id | metadata |
+==============================+==============================+===============+============+
| Llama3.1-8B-Instruct | Llama3.1-8B-Instruct | together0 | {} |
+------------------------------+------------------------------+---------------+------------+
| Llama3.1-70B-Instruct | Llama3.1-70B-Instruct | together0 | {} |
+------------------------------+------------------------------+---------------+------------+
| Llama3.1-405B-Instruct | Llama3.1-405B-Instruct | together0 | {} |
+------------------------------+------------------------------+---------------+------------+
| Llama3.2-3B-Instruct | Llama3.2-3B-Instruct | together0 | {} |
+------------------------------+------------------------------+---------------+------------+
| Llama3.2-11B-Vision-Instruct | Llama3.2-11B-Vision-Instruct | together0 | {} |
+------------------------------+------------------------------+---------------+------------+
| Llama3.2-90B-Vision-Instruct | Llama3.2-90B-Vision-Instruct | together0 | {} |
+------------------------------+------------------------------+---------------+------------+
```


@ -0,0 +1,68 @@
# Fireworks Distribution
The `llamastack/distribution-fireworks` distribution consists of the following provider configurations.
| API | Provider(s) |
|-----|-------------|
| agents | `inline::meta-reference` |
| inference | `remote::fireworks` |
| memory | `inline::faiss`, `remote::chromadb`, `remote::pgvector` |
| safety | `inline::llama-guard` |
| telemetry | `inline::meta-reference` |
### Environment Variables
The following environment variables can be configured:
- `LLAMASTACK_PORT`: Port for the Llama Stack distribution server (default: `5001`)
- `FIREWORKS_API_KEY`: Fireworks.AI API Key (default: ``)
### Models
The following models are available by default:
- `meta-llama/Llama-3.1-8B-Instruct (fireworks/llama-v3p1-8b-instruct)`
- `meta-llama/Llama-3.1-70B-Instruct (fireworks/llama-v3p1-70b-instruct)`
- `meta-llama/Llama-3.1-405B-Instruct-FP8 (fireworks/llama-v3p1-405b-instruct)`
- `meta-llama/Llama-3.2-1B-Instruct (fireworks/llama-v3p2-1b-instruct)`
- `meta-llama/Llama-3.2-3B-Instruct (fireworks/llama-v3p2-3b-instruct)`
- `meta-llama/Llama-3.2-11B-Vision-Instruct (fireworks/llama-v3p2-11b-vision-instruct)`
- `meta-llama/Llama-3.2-90B-Vision-Instruct (fireworks/llama-v3p2-90b-vision-instruct)`
- `meta-llama/Llama-Guard-3-8B (fireworks/llama-guard-3-8b)`
- `meta-llama/Llama-Guard-3-11B-Vision (fireworks/llama-guard-3-11b-vision)`
### Prerequisite: API Keys
Make sure you have access to a Fireworks API Key. You can get one by visiting [fireworks.ai](https://fireworks.ai/).
## Running Llama Stack with Fireworks
You can do this via Conda (build code) or Docker which has a pre-built image.
### Via Docker
This method allows you to get started quickly without having to build the distribution code.
```bash
LLAMA_STACK_PORT=5001
docker run \
-it \
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
-v ./run.yaml:/root/my-run.yaml \
llamastack/distribution-fireworks \
--yaml-config /root/my-run.yaml \
--port $LLAMA_STACK_PORT \
--env FIREWORKS_API_KEY=$FIREWORKS_API_KEY
```
### Via Conda
```bash
llama stack build --template fireworks --image-type conda
llama stack run ./run.yaml \
--port 5001 \
--env FIREWORKS_API_KEY=$FIREWORKS_API_KEY
```
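
Once the stack is running (via either method above), a quick sanity check is to point the `llama-stack-client` CLI at it and list the served models. This sketch assumes the server is listening on port 5001 on localhost:

```bash
pip install llama-stack-client
llama-stack-client configure --endpoint http://localhost:5001
llama-stack-client models list
```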


@ -8,6 +8,10 @@ We offer deployable distributions where you can host your own Llama Stack server
| Meta Reference Quantized | [llamastack/distribution-meta-reference-quantized-gpu](https://hub.docker.com/repository/docker/llamastack/distribution-meta-reference-quantized-gpu/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/getting_started/distributions/self_hosted_distro/meta-reference-quantized-gpu.html) | meta-reference-quantized | meta-reference | meta-reference; remote::pgvector; remote::chromadb | meta-reference | meta-reference | | Meta Reference Quantized | [llamastack/distribution-meta-reference-quantized-gpu](https://hub.docker.com/repository/docker/llamastack/distribution-meta-reference-quantized-gpu/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/getting_started/distributions/self_hosted_distro/meta-reference-quantized-gpu.html) | meta-reference-quantized | meta-reference | meta-reference; remote::pgvector; remote::chromadb | meta-reference | meta-reference |
| Ollama | [llamastack/distribution-ollama](https://hub.docker.com/repository/docker/llamastack/distribution-ollama/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/getting_started/distributions/self_hosted_distro/ollama.html) | remote::ollama | meta-reference | remote::pgvector; remote::chromadb | meta-reference | meta-reference | | Ollama | [llamastack/distribution-ollama](https://hub.docker.com/repository/docker/llamastack/distribution-ollama/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/getting_started/distributions/self_hosted_distro/ollama.html) | remote::ollama | meta-reference | remote::pgvector; remote::chromadb | meta-reference | meta-reference |
| TGI | [llamastack/distribution-tgi](https://hub.docker.com/repository/docker/llamastack/distribution-tgi/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/getting_started/distributions/self_hosted_distro/tgi.html) | remote::tgi | meta-reference | meta-reference; remote::pgvector; remote::chromadb | meta-reference | meta-reference | | TGI | [llamastack/distribution-tgi](https://hub.docker.com/repository/docker/llamastack/distribution-tgi/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/getting_started/distributions/self_hosted_distro/tgi.html) | remote::tgi | meta-reference | meta-reference; remote::pgvector; remote::chromadb | meta-reference | meta-reference |
| Together | [llamastack/distribution-together](https://hub.docker.com/repository/docker/llamastack/distribution-together/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/getting_started/distributions/self_hosted_distro/together.html) | remote::together | meta-reference | remote::weaviate | meta-reference | meta-reference |
| Fireworks | [llamastack/distribution-fireworks](https://hub.docker.com/repository/docker/llamastack/distribution-fireworks/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/getting_started/distributions/self_hosted_distro/fireworks.html) | remote::fireworks | meta-reference | remote::weaviate | meta-reference | meta-reference |
| Bedrock | [llamastack/distribution-bedrock](https://hub.docker.com/repository/docker/llamastack/distribution-bedrock/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/getting_started/distributions/self_hosted_distro/bedrock.html) | remote::bedrock | meta-reference | remote::weaviate | meta-reference | meta-reference |
```{toctree} ```{toctree}
:maxdepth: 1 :maxdepth: 1
@ -17,4 +21,8 @@ meta-reference-quantized-gpu
ollama ollama
tgi tgi
dell-tgi dell-tgi
together
fireworks
remote-vllm
bedrock
``` ```


@ -1,15 +1,32 @@
# Meta Reference Distribution # Meta Reference Distribution
The `llamastack/distribution-meta-reference-gpu` distribution consists of the following provider configurations. The `llamastack/distribution-meta-reference-gpu` distribution consists of the following provider configurations:
| API | Provider(s) |
|-----|-------------|
| agents | `inline::meta-reference` |
| inference | `inline::meta-reference` |
| memory | `inline::faiss`, `remote::chromadb`, `remote::pgvector` |
| safety | `inline::llama-guard` |
| telemetry | `inline::meta-reference` |
| **API** | **Inference** | **Agents** | **Memory** | **Safety** | **Telemetry** | Note that you need access to nvidia GPUs to run this distribution. This distribution is not compatible with CPU-only machines or machines with AMD GPUs.
|----------------- |--------------- |---------------- |-------------------------------------------------- |---------------- |---------------- |
| **Provider(s)** | meta-reference | meta-reference | meta-reference, remote::pgvector, remote::chroma | meta-reference | meta-reference | ### Environment Variables
The following environment variables can be configured:
- `LLAMASTACK_PORT`: Port for the Llama Stack distribution server (default: `5001`)
- `INFERENCE_MODEL`: Inference model loaded into the Meta Reference server (default: `meta-llama/Llama-3.2-3B-Instruct`)
- `INFERENCE_CHECKPOINT_DIR`: Directory containing the Meta Reference model checkpoint (default: `null`)
- `SAFETY_MODEL`: Name of the safety (Llama-Guard) model to use (default: `meta-llama/Llama-Guard-3-1B`)
- `SAFETY_CHECKPOINT_DIR`: Directory containing the Llama-Guard model checkpoint (default: `null`)
### Step 0. Prerequisite - Downloading Models ## Prerequisite: Downloading Models
Please make sure you have llama model checkpoints downloaded in `~/.llama` before proceeding. See [installation guide](https://llama-stack.readthedocs.io/en/latest/cli_reference/download_models.html) here to download the models.
Please make sure you have llama model checkpoints downloaded in `~/.llama` before proceeding. See [installation guide](https://llama-stack.readthedocs.io/en/latest/cli_reference/download_models.html) here to download the models. Run `llama model list` to see the available models to download, and `llama model download` to download the checkpoints.
``` ```
$ ls ~/.llama/checkpoints $ ls ~/.llama/checkpoints
@ -17,55 +34,56 @@ Llama3.1-8B Llama3.2-11B-Vision-Instruct Llama3.2-1B-Instruct Llama3
Llama3.1-8B-Instruct Llama3.2-1B Llama3.2-3B-Instruct Llama-Guard-3-1B Prompt-Guard-86M Llama3.1-8B-Instruct Llama3.2-1B Llama3.2-3B-Instruct Llama-Guard-3-1B Prompt-Guard-86M
``` ```
### Step 1. Start the Distribution ## Running the Distribution
#### (Option 1) Start with Docker You can do this via Conda (build code) or Docker which has a pre-built image.
```
$ cd distributions/meta-reference-gpu && docker compose up ### Via Docker
This method allows you to get started quickly without having to build the distribution code.
```bash
LLAMA_STACK_PORT=5001
docker run \
-it \
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
-v ./run.yaml:/root/my-run.yaml \
llamastack/distribution-meta-reference-gpu \
/root/my-run.yaml \
--port $LLAMA_STACK_PORT \
--env INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct
``` ```
> [!NOTE] If you are using Llama Stack Safety / Shield APIs, use:
> This assumes you have access to GPU to start a local server with access to your GPU.
```bash
> [!NOTE] docker run \
> `~/.llama` should be the path containing downloaded weights of Llama models. -it \
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
-v ./run-with-safety.yaml:/root/my-run.yaml \
This will download and start running a pre-built docker container. Alternatively, you may use the following commands: llamastack/distribution-meta-reference-gpu \
/root/my-run.yaml \
``` --port $LLAMA_STACK_PORT \
docker run -it -p 5000:5000 -v ~/.llama:/root/.llama -v ./run.yaml:/root/my-run.yaml --gpus=all distribution-meta-reference-gpu --yaml_config /root/my-run.yaml --env INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct \
--env SAFETY_MODEL=meta-llama/Llama-Guard-3-1B
``` ```
#### (Option 2) Start with Conda ### Via Conda
1. Install the `llama` CLI. See [CLI Reference](https://llama-stack.readthedocs.io/en/latest/cli_reference/index.html) Make sure you have done `pip install llama-stack` and have the Llama Stack CLI available.
2. Build the `meta-reference-gpu` distribution ```bash
llama stack build --template meta-reference-gpu --image-type conda
``` llama stack run ./run.yaml \
$ llama stack build --template meta-reference-gpu --image-type conda --port 5001 \
--env INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct
``` ```
3. Start running distribution If you are using Llama Stack Safety / Shield APIs, use:
```
$ cd distributions/meta-reference-gpu
$ llama stack run ./run.yaml
```
### (Optional) Serving a new model ```bash
You may change the `config.model` in `run.yaml` to update the model currently being served by the distribution. Make sure you have the model checkpoint downloaded in your `~/.llama`. llama stack run ./run-with-safety.yaml \
--port 5001 \
--env INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct \
--env SAFETY_MODEL=meta-llama/Llama-Guard-3-1B
``` ```
inference:
- provider_id: meta0
provider_type: meta-reference
config:
model: Llama3.2-11B-Vision-Instruct
quantization: null
torch_seed: null
max_seq_len: 4096
max_batch_size: 1
```
Run `llama model list` to see the available models to download, and `llama model download` to download the checkpoints.


@ -2,111 +2,120 @@
The `llamastack/distribution-ollama` distribution consists of the following provider configurations. The `llamastack/distribution-ollama` distribution consists of the following provider configurations.
| **API** | **Inference** | **Agents** | **Memory** | **Safety** | **Telemetry** | | API | Provider(s) |
|----------------- |---------------- |---------------- |---------------------------------- |---------------- |---------------- | |-----|-------------|
| **Provider(s)** | remote::ollama | meta-reference | remote::pgvector, remote::chroma | remote::ollama | meta-reference | | agents | `inline::meta-reference` |
| inference | `remote::ollama` |
| memory | `inline::faiss`, `remote::chromadb`, `remote::pgvector` |
| safety | `inline::llama-guard` |
| telemetry | `inline::meta-reference` |
### Docker: Start a Distribution (Single Node GPU) You should use this distribution if you have a regular desktop machine without very powerful GPUs. Of course, if you have powerful GPUs, you can still continue using this distribution since Ollama supports GPU acceleration.### Environment Variables
> [!NOTE] The following environment variables can be configured:
> This assumes you have access to GPU to start a Ollama server with access to your GPU.
``` - `LLAMASTACK_PORT`: Port for the Llama Stack distribution server (default: `5001`)
$ cd distributions/ollama/gpu - `OLLAMA_URL`: URL of the Ollama server (default: `http://127.0.0.1:11434`)
$ ls - `INFERENCE_MODEL`: Inference model loaded into the Ollama server (default: `meta-llama/Llama-3.2-3B-Instruct`)
compose.yaml run.yaml - `SAFETY_MODEL`: Safety model loaded into the Ollama server (default: `meta-llama/Llama-Guard-3-1B`)
$ docker compose up
## Setting up Ollama server
Please check the [Ollama Documentation](https://github.com/ollama/ollama) on how to install and run Ollama. After installing Ollama, you need to run `ollama serve` to start the server.
In order to load models, you can run:
```bash
export INFERENCE_MODEL="meta-llama/Llama-3.2-3B-Instruct"
# ollama names this model differently, and we must use the ollama name when loading the model
export OLLAMA_INFERENCE_MODEL="llama3.2:3b-instruct-fp16"
ollama run $OLLAMA_INFERENCE_MODEL --keepalive 60m
``` ```
You will see outputs similar to following --- If you are using Llama Stack Safety / Shield APIs, you will also need to pull and run the safety model.
```
[ollama] | [GIN] 2024/10/18 - 21:19:41 | 200 | 226.841µs | ::1 | GET "/api/ps" ```bash
[ollama] | [GIN] 2024/10/18 - 21:19:42 | 200 | 60.908µs | ::1 | GET "/api/ps" export SAFETY_MODEL="meta-llama/Llama-Guard-3-1B"
INFO: Started server process [1]
INFO: Waiting for application startup. # ollama names this model differently, and we must use the ollama name when loading the model
INFO: Application startup complete. export OLLAMA_SAFETY_MODEL="llama-guard3:1b"
INFO: Uvicorn running on http://[::]:5000 (Press CTRL+C to quit) ollama run $OLLAMA_SAFETY_MODEL --keepalive 60m
[llamastack] | Resolved 12 providers
[llamastack] | inner-inference => ollama0
[llamastack] | models => __routing_table__
[llamastack] | inference => __autorouted__
``` ```
To kill the server ## Running Llama Stack
```
docker compose down Now you are ready to run Llama Stack with Ollama as the inference provider. You can do this via Conda (build code) or Docker which has a pre-built image.
### Via Docker
This method allows you to get started quickly without having to build the distribution code.
```bash
export LLAMA_STACK_PORT=5001
docker run \
-it \
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
-v ~/.llama:/root/.llama \
-v ./run.yaml:/root/my-run.yaml \
llamastack/distribution-ollama \
--yaml-config /root/my-run.yaml \
--port $LLAMA_STACK_PORT \
--env INFERENCE_MODEL=$INFERENCE_MODEL \
--env OLLAMA_URL=http://host.docker.internal:11434
``` ```
### Docker: Start the Distribution (Single Node CPU) If you are using Llama Stack Safety / Shield APIs, use:
> [!NOTE] ```bash
> This will start an ollama server with CPU only, please see [Ollama Documentations](https://github.com/ollama/ollama) for serving models on CPU only. docker run \
-it \
``` -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
$ cd distributions/ollama/cpu -v ~/.llama:/root/.llama \
$ ls -v ./run-with-safety.yaml:/root/my-run.yaml \
compose.yaml run.yaml llamastack/distribution-ollama \
$ docker compose up --yaml-config /root/my-run.yaml \
--port $LLAMA_STACK_PORT \
--env INFERENCE_MODEL=$INFERENCE_MODEL \
--env SAFETY_MODEL=$SAFETY_MODEL \
--env OLLAMA_URL=http://host.docker.internal:11434
``` ```
### Conda: ollama run + llama stack run ### Via Conda
If you wish to separately spin up a Ollama server, and connect with Llama Stack, you may use the following commands. Make sure you have done `pip install llama-stack` and have the Llama Stack CLI available.
#### Start Ollama server. ```bash
- Please check the [Ollama Documentations](https://github.com/ollama/ollama) for more details. export LLAMA_STACK_PORT=5001
**Via Docker**
```
docker run -d -v ollama:/root/.ollama -p 11434:11434 --name ollama ollama/ollama
```
**Via CLI**
```
ollama run <model_id>
```
#### Start Llama Stack server pointing to Ollama server
**Via Conda**
```
llama stack build --template ollama --image-type conda llama stack build --template ollama --image-type conda
llama stack run ./gpu/run.yaml llama stack run ./run.yaml \
--port $LLAMA_STACK_PORT \
--env INFERENCE_MODEL=$INFERENCE_MODEL \
--env OLLAMA_URL=http://localhost:11434
``` ```
**Via Docker** If you are using Llama Stack Safety / Shield APIs, use:
```
docker run --network host -it -p 5000:5000 -v ~/.llama:/root/.llama -v ./gpu/run.yaml:/root/llamastack-run-ollama.yaml --gpus=all llamastack/distribution-ollama --yaml_config /root/llamastack-run-ollama.yaml ```bash
llama stack run ./run-with-safety.yaml \
--port $LLAMA_STACK_PORT \
--env INFERENCE_MODEL=$INFERENCE_MODEL \
--env SAFETY_MODEL=$SAFETY_MODEL \
--env OLLAMA_URL=http://localhost:11434
``` ```
Make sure in your `run.yaml` file, your inference provider is pointing to the correct Ollama endpoint. E.g.
```
inference:
- provider_id: ollama0
provider_type: remote::ollama
config:
url: http://127.0.0.1:14343
```
### (Optional) Update Model Serving Configuration ### (Optional) Update Model Serving Configuration
#### Downloading model via Ollama
You can use ollama for managing model downloads.
```
ollama pull llama3.1:8b-instruct-fp16
ollama pull llama3.1:70b-instruct-fp16
```
> [!NOTE] > [!NOTE]
> Please check the [OLLAMA_SUPPORTED_MODELS](https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers.remote/inference/ollama/ollama.py) for the supported Ollama models. > Please check the [OLLAMA_SUPPORTED_MODELS](https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers.remote/inference/ollama/ollama.py) for the supported Ollama models.
To serve a new model with `ollama` To serve a new model with `ollama`
``` ```bash
ollama run <model_name> ollama run <model_name>
``` ```
@ -119,7 +128,7 @@ llama3.1:8b-instruct-fp16 4aacac419454 17 GB 100% GPU 4 minutes fro
``` ```
To verify that the model served by ollama is correctly connected to Llama Stack server To verify that the model served by ollama is correctly connected to Llama Stack server
``` ```bash
$ llama-stack-client models list $ llama-stack-client models list
+----------------------+----------------------+---------------+-----------------------------------------------+ +----------------------+----------------------+---------------+-----------------------------------------------+
| identifier | llama_model | provider_id | metadata | | identifier | llama_model | provider_id | metadata |


@ -0,0 +1,144 @@
# Remote vLLM Distribution
The `llamastack/distribution-remote-vllm` distribution consists of the following provider configurations:
| API | Provider(s) |
|-----|-------------|
| agents | `inline::meta-reference` |
| inference | `remote::vllm` |
| memory | `inline::faiss`, `remote::chromadb`, `remote::pgvector` |
| safety | `inline::llama-guard` |
| telemetry | `inline::meta-reference` |
You can use this distribution if you have GPUs and want to run an independent vLLM server container for running inference.
### Environment Variables
The following environment variables can be configured:
- `LLAMASTACK_PORT`: Port for the Llama Stack distribution server (default: `5001`)
- `INFERENCE_MODEL`: Inference model loaded into the vLLM server (default: `meta-llama/Llama-3.2-3B-Instruct`)
- `VLLM_URL`: URL of the vLLM server with the main inference model (default: `http://host.docker.internal:5100/v1`)
- `MAX_TOKENS`: Maximum number of tokens for generation (default: `4096`)
- `SAFETY_VLLM_URL`: URL of the vLLM server with the safety model (default: `http://host.docker.internal:5101/v1`)
- `SAFETY_MODEL`: Name of the safety (Llama-Guard) model to use (default: `meta-llama/Llama-Guard-3-1B`)
## Setting up vLLM server
Please check the [vLLM Documentation](https://docs.vllm.ai/en/v0.5.5/serving/deploying_with_docker.html) to get a vLLM endpoint. Here is a sample script to start a vLLM server locally via Docker:
```bash
export INFERENCE_PORT=8000
export INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct
export CUDA_VISIBLE_DEVICES=0
docker run \
--runtime nvidia \
--gpus $CUDA_VISIBLE_DEVICES \
-v ~/.cache/huggingface:/root/.cache/huggingface \
--env "HUGGING_FACE_HUB_TOKEN=$HF_TOKEN" \
-p $INFERENCE_PORT:$INFERENCE_PORT \
--ipc=host \
vllm/vllm-openai:latest \
--gpu-memory-utilization 0.7 \
--model $INFERENCE_MODEL \
--port $INFERENCE_PORT
```
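
Before wiring this endpoint into Llama Stack, you can check that the vLLM server is up. Since vLLM exposes an OpenAI-compatible API, a models listing request should return the loaded model (this assumes the server is still bound to `$INFERENCE_PORT` as in the script above):

```bash
curl http://localhost:$INFERENCE_PORT/v1/models
```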
If you are using Llama Stack Safety / Shield APIs, you will also need to run a second vLLM instance with a corresponding safety model such as `meta-llama/Llama-Guard-3-1B`, using a script like:
```bash
export SAFETY_PORT=8081
export SAFETY_MODEL=meta-llama/Llama-Guard-3-1B
export CUDA_VISIBLE_DEVICES=1
docker run \
--runtime nvidia \
--gpus $CUDA_VISIBLE_DEVICES \
-v ~/.cache/huggingface:/root/.cache/huggingface \
--env "HUGGING_FACE_HUB_TOKEN=$HF_TOKEN" \
-p $SAFETY_PORT:$SAFETY_PORT \
--ipc=host \
vllm/vllm-openai:latest \
--gpu-memory-utilization 0.7 \
--model $SAFETY_MODEL \
--port $SAFETY_PORT
```
## Running Llama Stack
Now you are ready to run Llama Stack with vLLM as the inference provider. You can do this via Conda (build code) or Docker which has a pre-built image.
### Via Docker
This method allows you to get started quickly without having to build the distribution code.
```bash
export INFERENCE_PORT=8000
export INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct
export LLAMA_STACK_PORT=5001
docker run \
-it \
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
-v ./run.yaml:/root/my-run.yaml \
llamastack/distribution-remote-vllm \
--yaml-config /root/my-run.yaml \
--port $LLAMA_STACK_PORT \
--env INFERENCE_MODEL=$INFERENCE_MODEL \
--env VLLM_URL=http://host.docker.internal:$INFERENCE_PORT/v1
```
If you are using Llama Stack Safety / Shield APIs, use:
```bash
export SAFETY_PORT=8081
export SAFETY_MODEL=meta-llama/Llama-Guard-3-1B
docker run \
-it \
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
-v ./run-with-safety.yaml:/root/my-run.yaml \
llamastack/distribution-remote-vllm \
--yaml-config /root/my-run.yaml \
--port $LLAMA_STACK_PORT \
--env INFERENCE_MODEL=$INFERENCE_MODEL \
--env VLLM_URL=http://host.docker.internal:$INFERENCE_PORT/v1 \
--env SAFETY_MODEL=$SAFETY_MODEL \
--env SAFETY_VLLM_URL=http://host.docker.internal:$SAFETY_PORT/v1
```
### Via Conda
Make sure you have done `pip install llama-stack` and have the Llama Stack CLI available.
```bash
export INFERENCE_PORT=8000
export INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct
export LLAMA_STACK_PORT=5001
cd distributions/remote-vllm
llama stack build --template remote-vllm --image-type conda
llama stack run ./run.yaml \
--port $LLAMA_STACK_PORT \
--env INFERENCE_MODEL=$INFERENCE_MODEL \
--env VLLM_URL=http://localhost:$INFERENCE_PORT/v1
```
If you are using Llama Stack Safety / Shield APIs, use:
```bash
export SAFETY_PORT=8081
export SAFETY_MODEL=meta-llama/Llama-Guard-3-1B
llama stack run ./run-with-safety.yaml \
--port $LLAMA_STACK_PORT \
--env INFERENCE_MODEL=$INFERENCE_MODEL \
--env VLLM_URL=http://localhost:$INFERENCE_PORT/v1 \
--env SAFETY_MODEL=$SAFETY_MODEL \
--env SAFETY_VLLM_URL=http://localhost:$SAFETY_PORT/v1
```


@ -2,111 +2,125 @@
The `llamastack/distribution-tgi` distribution consists of the following provider configurations. The `llamastack/distribution-tgi` distribution consists of the following provider configurations.
| API | Provider(s) |
| **API** | **Inference** | **Agents** | **Memory** | **Safety** | **Telemetry** | |-----|-------------|
|----------------- |--------------- |---------------- |-------------------------------------------------- |---------------- |---------------- | | agents | `inline::meta-reference` |
| **Provider(s)** | remote::tgi | meta-reference | meta-reference, remote::pgvector, remote::chroma | meta-reference | meta-reference | | inference | `remote::tgi` |
| memory | `inline::faiss`, `remote::chromadb`, `remote::pgvector` |
| safety | `inline::llama-guard` |
| telemetry | `inline::meta-reference` |
### Docker: Start the Distribution (Single Node GPU) You can use this distribution if you have GPUs and want to run an independent TGI server container for running inference.
> [!NOTE] ### Environment Variables
> This assumes you have access to GPU to start a TGI server with access to your GPU.
The following environment variables can be configured:
- `LLAMASTACK_PORT`: Port for the Llama Stack distribution server (default: `5001`)
- `INFERENCE_MODEL`: Inference model loaded into the TGI server (default: `meta-llama/Llama-3.2-3B-Instruct`)
- `TGI_URL`: URL of the TGI server with the main inference model (default: `http://127.0.0.1:8080/v1`)
- `TGI_SAFETY_URL`: URL of the TGI server with the safety model (default: `http://127.0.0.1:8081/v1`)
- `SAFETY_MODEL`: Name of the safety (Llama-Guard) model to use (default: `meta-llama/Llama-Guard-3-1B`)
``` ## Setting up TGI server
$ cd distributions/tgi/gpu && docker compose up
Please check the [TGI Getting Started Guide](https://github.com/huggingface/text-generation-inference?tab=readme-ov-file#get-started) to get a TGI endpoint. Here is a sample script to start a TGI server locally via Docker:
```bash
export INFERENCE_PORT=8080
export INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct
export CUDA_VISIBLE_DEVICES=0
docker run --rm -it \
-v $HOME/.cache/huggingface:/data \
-p $INFERENCE_PORT:$INFERENCE_PORT \
--gpus $CUDA_VISIBLE_DEVICES \
ghcr.io/huggingface/text-generation-inference:2.3.1 \
--dtype bfloat16 \
--usage-stats off \
--sharded false \
--cuda-memory-fraction 0.7 \
--model-id $INFERENCE_MODEL \
--port $INFERENCE_PORT
``` ```
The script will first start up TGI server, then start up Llama Stack distribution server hooking up to the remote TGI provider for inference. You should be able to see the following outputs -- If you are using Llama Stack Safety / Shield APIs, then you will need to also run another instance of a TGI with a corresponding safety model like `meta-llama/Llama-Guard-3-1B` using a script like:
```
[text-generation-inference] | 2024-10-15T18:56:33.810397Z INFO text_generation_router::server: router/src/server.rs:1813: Using config Some(Llama) ```bash
[text-generation-inference] | 2024-10-15T18:56:33.810448Z WARN text_generation_router::server: router/src/server.rs:1960: Invalid hostname, defaulting to 0.0.0.0 export SAFETY_PORT=8081
[text-generation-inference] | 2024-10-15T18:56:33.864143Z INFO text_generation_router::server: router/src/server.rs:2353: Connected export SAFETY_MODEL=meta-llama/Llama-Guard-3-1B
INFO: Started server process [1] export CUDA_VISIBLE_DEVICES=1
INFO: Waiting for application startup.
INFO: Application startup complete. docker run --rm -it \
INFO: Uvicorn running on http://[::]:5000 (Press CTRL+C to quit) -v $HOME/.cache/huggingface:/data \
-p $SAFETY_PORT:$SAFETY_PORT \
--gpus $CUDA_VISIBLE_DEVICES \
ghcr.io/huggingface/text-generation-inference:2.3.1 \
--dtype bfloat16 \
--usage-stats off \
--sharded false \
--model-id $SAFETY_MODEL \
--port $SAFETY_PORT
``` ```
To kill the server ## Running Llama Stack
```
docker compose down Now you are ready to run Llama Stack with TGI as the inference provider. You can do this via Conda (build code) or Docker which has a pre-built image.
### Via Docker
This method allows you to get started quickly without having to build the distribution code.
```bash
LLAMA_STACK_PORT=5001
docker run \
-it \
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
-v ./run.yaml:/root/my-run.yaml \
llamastack/distribution-tgi \
--yaml-config /root/my-run.yaml \
--port $LLAMA_STACK_PORT \
--env INFERENCE_MODEL=$INFERENCE_MODEL \
--env TGI_URL=http://host.docker.internal:$INFERENCE_PORT
```

If you are using Llama Stack Safety / Shield APIs, use:

```bash
docker run \
  -it \
  -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
  -v ./run-with-safety.yaml:/root/my-run.yaml \
  llamastack/distribution-tgi \
  --yaml-config /root/my-run.yaml \
  --port $LLAMA_STACK_PORT \
  --env INFERENCE_MODEL=$INFERENCE_MODEL \
  --env TGI_URL=http://host.docker.internal:$INFERENCE_PORT \
  --env SAFETY_MODEL=$SAFETY_MODEL \
  --env TGI_SAFETY_URL=http://host.docker.internal:$SAFETY_PORT
```

### Via Conda

Make sure you have done `pip install llama-stack` and have the Llama Stack CLI available.

```bash
llama stack build --template tgi --image-type conda
llama stack run ./run.yaml \
  --port 5001 \
  --env INFERENCE_MODEL=$INFERENCE_MODEL \
  --env TGI_URL=http://127.0.0.1:$INFERENCE_PORT
```

If you are using Llama Stack Safety / Shield APIs, use:

```bash
llama stack run ./run-with-safety.yaml \
  --port 5001 \
  --env INFERENCE_MODEL=$INFERENCE_MODEL \
  --env TGI_URL=http://127.0.0.1:$INFERENCE_PORT \
  --env SAFETY_MODEL=$SAFETY_MODEL \
  --env TGI_SAFETY_URL=http://127.0.0.1:$SAFETY_PORT
```

### (Optional) Update Model Serving Configuration

To serve a new model with `tgi`, change the docker command flag `--model-id <model-to-serve>`.

This can be done by editing the `command` args in `compose.yaml`, e.g. replacing "Llama-3.2-1B-Instruct" with the model you want to serve:

```
command: ["--dtype", "bfloat16", "--usage-stats", "on", "--sharded", "false", "--model-id", "meta-llama/Llama-3.2-1B-Instruct", "--port", "5009", "--cuda-memory-fraction", "0.3"]
```

or by changing the docker run command's `--model-id` flag:

```
docker run --rm -it -v $HOME/.cache/huggingface:/data -p 5009:5009 --gpus all ghcr.io/huggingface/text-generation-inference:latest --dtype bfloat16 --usage-stats on --sharded false --model-id meta-llama/Llama-3.2-1B-Instruct --port 5009
```

In `run.yaml`, make sure the inference provider points to the TGI server endpoint serving your model:

```
inference:
  - provider_id: tgi0
    provider_type: remote::tgi
    config:
      url: http://127.0.0.1:5009
```
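
Once the distribution is up, you can sanity-check it by hitting the inference endpoint directly. The sketch below is illustrative only: it assumes the server is listening on `localhost:5001`, that the route carries the `/alpha` version prefix used by the quick-start curl example elsewhere in these docs, and that `meta-llama/Llama-3.1-8B-Instruct` is the registered inference model. Adjust these to your deployment.

```python
# Minimal smoke test against a running Llama Stack server (sketch, not an official client).
# Assumptions: server at localhost:5001, "/alpha" route prefix, and the
# meta-llama/Llama-3.1-8B-Instruct model registered via INFERENCE_MODEL.
import httpx

BASE_URL = "http://localhost:5001/alpha"

response = httpx.post(
    f"{BASE_URL}/inference/chat-completion",
    json={
        "model_id": "meta-llama/Llama-3.1-8B-Instruct",
        "messages": [
            {"role": "system", "content": "You are a helpful assistant."},
            {"role": "user", "content": "Write me a 2 sentence poem about the moon"},
        ],
    },
    headers={"Content-Type": "application/json"},
    timeout=60.0,
)
response.raise_for_status()
print(response.json())
```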

View file

@ -0,0 +1,67 @@
# Together Distribution
The `llamastack/distribution-together` distribution consists of the following provider configurations.
| API | Provider(s) |
|-----|-------------|
| agents | `inline::meta-reference` |
| inference | `remote::together` |
| memory | `inline::faiss`, `remote::chromadb`, `remote::pgvector` |
| safety | `inline::llama-guard` |
| telemetry | `inline::meta-reference` |
### Environment Variables
The following environment variables can be configured:
- `LLAMASTACK_PORT`: Port for the Llama Stack distribution server (default: `5001`)
- `TOGETHER_API_KEY`: Together.AI API Key (default: ``)
### Models
The following models are available by default:
- `meta-llama/Llama-3.1-8B-Instruct`
- `meta-llama/Llama-3.1-70B-Instruct`
- `meta-llama/Llama-3.1-405B-Instruct-FP8`
- `meta-llama/Llama-3.2-3B-Instruct`
- `meta-llama/Llama-3.2-11B-Vision-Instruct`
- `meta-llama/Llama-3.2-90B-Vision-Instruct`
- `meta-llama/Llama-Guard-3-8B`
- `meta-llama/Llama-Guard-3-11B-Vision`
### Prerequisite: API Keys
Make sure you have access to a Together API Key. You can get one by visiting [together.xyz](https://together.xyz/).
## Running Llama Stack with Together
You can do this via Conda (build code) or Docker which has a pre-built image.
### Via Docker
This method allows you to get started quickly without having to build the distribution code.
```bash
LLAMA_STACK_PORT=5001
docker run \
-it \
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
-v ./run.yaml:/root/my-run.yaml \
llamastack/distribution-together \
--yaml-config /root/my-run.yaml \
--port $LLAMA_STACK_PORT \
--env TOGETHER_API_KEY=$TOGETHER_API_KEY
```
### Via Conda
```bash
llama stack build --template together --image-type conda
llama stack run ./run.yaml \
--port 5001 \
--env TOGETHER_API_KEY=$TOGETHER_API_KEY
```
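
To confirm the distribution is serving the expected Together-hosted models, you can query the models endpoint. This is a sketch, not an official client: it assumes the server is on `localhost:5001`, and depending on your build the route may need the `/alpha` version prefix (e.g. `/alpha/models/list`).

```python
# Sketch: list the models registered with the Together-backed distribution.
import httpx

BASE_URL = "http://localhost:5001"  # add "/alpha" if your server prefixes routes

resp = httpx.get(
    f"{BASE_URL}/models/list",
    headers={"Content-Type": "application/json"},
)
resp.raise_for_status()
for model in resp.json():
    # each entry is a Model resource with an identifier and a provider-side id
    print(model["identifier"], "->", model.get("provider_resource_id"))
```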

View file

@ -53,9 +53,9 @@ Please see our pages in detail for the types of distributions we offer:
3. [On-device Distribution](./distributions/ondevice_distro/index.md): If you want to run Llama Stack inference on your iOS / Android device.

### Table of Contents

Once you have decided on the inference provider and distribution to use, use the following guides to get started.

##### 1.0 Prerequisite
@ -80,6 +80,11 @@ Llama3.1-8B-Instruct Llama3.2-1B Llama3.2-3B-Instruct Llama-
:::
:::{tab-item} vLLM
##### System Requirements
Access to Single-Node GPU to start a vLLM server.
:::
:::{tab-item} tgi
##### System Requirements
Access to Single-Node GPU to start a TGI server.
@ -104,365 +109,33 @@ Access to Single-Node CPU with Fireworks hosted endpoint via API_KEY from [firew
##### 1.1. Start the distribution
**(Option 1) Via Docker**
::::{tab-set}

:::{tab-item} meta-reference-gpu
- [Start Meta Reference GPU Distribution](https://llama-stack.readthedocs.io/en/latest/getting_started/distributions/self_hosted_distro/meta-reference-gpu.html)

```
$ cd llama-stack/distributions/meta-reference-gpu && docker compose up
```

This will download and start running a pre-built Docker container. Alternatively, you may use the following commands:

```
docker run -it -p 5000:5000 -v ~/.llama:/root/.llama -v ./run.yaml:/root/my-run.yaml --gpus=all distribution-meta-reference-gpu --yaml_config /root/my-run.yaml
```
:::

:::{tab-item} vLLM
- [Start vLLM Distribution](https://llama-stack.readthedocs.io/en/latest/getting_started/distributions/self_hosted_distro/remote-vllm.html)
:::

:::{tab-item} tgi
- [Start TGI Distribution](https://llama-stack.readthedocs.io/en/latest/getting_started/distributions/self_hosted_distro/tgi.html)

```
$ cd llama-stack/distributions/tgi/gpu && docker compose up
```
The script will first start up TGI server, then start up Llama Stack distribution server hooking up to the remote TGI provider for inference. You should see the following outputs --
```
[text-generation-inference] | 2024-10-15T18:56:33.810397Z INFO text_generation_router::server: router/src/server.rs:1813: Using config Some(Llama)
[text-generation-inference] | 2024-10-15T18:56:33.810448Z WARN text_generation_router::server: router/src/server.rs:1960: Invalid hostname, defaulting to 0.0.0.0
[text-generation-inference] | 2024-10-15T18:56:33.864143Z INFO text_generation_router::server: router/src/server.rs:2353: Connected
INFO: Started server process [1]
INFO: Waiting for application startup.
INFO: Application startup complete.
INFO: Uvicorn running on http://[::]:5000 (Press CTRL+C to quit)
```
To kill the server
```
docker compose down
```
:::
:::{tab-item} ollama
```
$ cd llama-stack/distributions/ollama/cpu && docker compose up
```
You will see outputs similar to following ---
```
[ollama] | [GIN] 2024/10/18 - 21:19:41 | 200 | 226.841µs | ::1 | GET "/api/ps"
[ollama] | [GIN] 2024/10/18 - 21:19:42 | 200 | 60.908µs | ::1 | GET "/api/ps"
INFO: Started server process [1]
INFO: Waiting for application startup.
INFO: Application startup complete.
INFO: Uvicorn running on http://[::]:5000 (Press CTRL+C to quit)
[llamastack] | Resolved 12 providers
[llamastack] | inner-inference => ollama0
[llamastack] | models => __routing_table__
[llamastack] | inference => __autorouted__
```
To kill the server
```
docker compose down
```
:::
:::{tab-item} fireworks
```
$ cd llama-stack/distributions/fireworks && docker compose up
```
Make sure your `run.yaml` file has the inference provider pointing to the correct Fireworks URL server endpoint. E.g.
```
inference:
- provider_id: fireworks
provider_type: remote::fireworks
config:
url: https://api.fireworks.ai/inference
api_key: <optional api key>
```
:::
:::{tab-item} together
```
$ cd distributions/together && docker compose up
```
Make sure your `run.yaml` file has the inference provider pointing to the correct Together URL server endpoint. E.g.
```
inference:
- provider_id: together
provider_type: remote::together
config:
url: https://api.together.xyz/v1
api_key: <optional api key>
```
:::
::::
**(Option 2) Via Conda**
::::{tab-set}
:::{tab-item} meta-reference-gpu
1. Install the `llama` CLI. See [CLI Reference](https://llama-stack.readthedocs.io/en/latest/cli_reference/index.html)
2. Build the `meta-reference-gpu` distribution
```
$ llama stack build --template meta-reference-gpu --image-type conda
```
3. Start running distribution
```
$ cd llama-stack/distributions/meta-reference-gpu
$ llama stack run ./run.yaml
```
:::
:::{tab-item} tgi
1. Install the `llama` CLI. See [CLI Reference](https://llama-stack.readthedocs.io/en/latest/cli_reference/index.html)
2. Build the `tgi` distribution
```bash
llama stack build --template tgi --image-type conda
```
3. Start a TGI server endpoint
4. Make sure in your `run.yaml` file, your `conda_env` is pointing to the conda environment and inference provider is pointing to the correct TGI server endpoint. E.g.
```
conda_env: llamastack-tgi
...
inference:
- provider_id: tgi0
provider_type: remote::tgi
config:
url: http://127.0.0.1:5009
```
5. Start Llama Stack server
```bash
llama stack run ./gpu/run.yaml
```
:::

:::{tab-item} ollama
- [Start Ollama Distribution](https://llama-stack.readthedocs.io/en/latest/getting_started/distributions/self_hosted_distro/ollama.html)
If you wish to separately spin up an Ollama server and connect it with Llama Stack, you may use the following commands.
#### Start Ollama server.
- Please check the [Ollama Documentations](https://github.com/ollama/ollama) for more details.
**Via Docker**
```
docker run -d -v ollama:/root/.ollama -p 11434:11434 --name ollama ollama/ollama
```
**Via CLI**
```
ollama run <model_id>
```
#### Start Llama Stack server pointing to Ollama server
Make sure your `run.yaml` file has the inference provider pointing to the correct Ollama endpoint. E.g.
```
conda_env: llamastack-ollama
...
inference:
- provider_id: ollama0
provider_type: remote::ollama
config:
url: http://127.0.0.1:11434
```
```
llama stack build --template ollama --image-type conda
llama stack run ./gpu/run.yaml
```
:::
:::{tab-item} fireworks
```bash
llama stack build --template fireworks --image-type conda
# -- modify run.yaml to a valid Fireworks server endpoint
llama stack run ./run.yaml
```
Make sure your `run.yaml` file has the inference provider pointing to the correct Fireworks URL server endpoint. E.g.
```
conda_env: llamastack-fireworks
...
inference:
- provider_id: fireworks
provider_type: remote::fireworks
config:
url: https://api.fireworks.ai/inference
api_key: <optional api key>
```
:::

:::{tab-item} together
- [Start Together Distribution](https://llama-stack.readthedocs.io/en/latest/getting_started/distributions/self_hosted_distro/together.html)
```bash
llama stack build --template together --image-type conda
# -- modify run.yaml to a valid Together server endpoint
llama stack run ./run.yaml
```
Make sure your `run.yaml` file has the inference provider pointing to the correct Together URL server endpoint. E.g.
```
conda_env: llamastack-together
...
inference:
- provider_id: together
provider_type: remote::together
config:
url: https://api.together.xyz/v1
api_key: <optional api key>
```
:::
::::
##### 1.2 (Optional) Update Model Serving Configuration
::::{tab-set}
:::{tab-item} meta-reference-gpu
You may change the `config.model` in `run.yaml` to update the model currently being served by the distribution. Make sure you have the model checkpoint downloaded in your `~/.llama`.
```
inference:
- provider_id: meta0
provider_type: meta-reference
config:
model: Llama3.2-11B-Vision-Instruct
quantization: null
torch_seed: null
max_seq_len: 4096
max_batch_size: 1
```
Run `llama model list` to see the available models to download, and `llama model download` to download the checkpoints.
:::
:::{tab-item} tgi
To serve a new model with `tgi`, change the docker command flag `--model-id <model-to-serve>`.
This can be done by editing the `command` args in `compose.yaml`, e.g. replacing "Llama-3.2-1B-Instruct" with the model you want to serve:
```
command: ["--dtype", "bfloat16", "--usage-stats", "on", "--sharded", "false", "--model-id", "meta-llama/Llama-3.2-1B-Instruct", "--port", "5009", "--cuda-memory-fraction", "0.3"]
```
or by changing the docker run command's `--model-id` flag
```
docker run --rm -it -v $HOME/.cache/huggingface:/data -p 5009:5009 --gpus all ghcr.io/huggingface/text-generation-inference:latest --dtype bfloat16 --usage-stats on --sharded false --model-id meta-llama/Llama-3.2-1B-Instruct --port 5009
```
Make sure your `run.yaml` file has the inference provider pointing to the TGI server endpoint serving your model.
```
inference:
- provider_id: tgi0
provider_type: remote::tgi
config:
url: http://127.0.0.1:5009
```
Run `llama model list` to see the available models to download, and `llama model download` to download the checkpoints.
:::
:::{tab-item} ollama
You can use ollama for managing model downloads.
```
ollama pull llama3.1:8b-instruct-fp16
ollama pull llama3.1:70b-instruct-fp16
```
> Please check the [OLLAMA_SUPPORTED_MODELS](https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/remote/inference/ollama/ollama.py) for the supported Ollama models.
To serve a new model with `ollama`
```
ollama run <model_name>
```
To make sure that the model is being served correctly, run `ollama ps` to get a list of models being served by ollama.
```
$ ollama ps
NAME ID SIZE PROCESSOR UNTIL
llama3.1:8b-instruct-fp16 4aacac419454 17 GB 100% GPU 4 minutes from now
```
To verify that the model served by ollama is correctly connected to Llama Stack server
```
$ llama-stack-client models list
+----------------------+----------------------+---------------+-----------------------------------------------+
| identifier | llama_model | provider_id | metadata |
+======================+======================+===============+===============================================+
| Llama3.1-8B-Instruct | Llama3.1-8B-Instruct | ollama0 | {'ollama_model': 'llama3.1:8b-instruct-fp16'} |
+----------------------+----------------------+---------------+-----------------------------------------------+
```
:::
:::{tab-item} together
Use `llama-stack-client models list` to check the available models served by together.
```
$ llama-stack-client models list
+------------------------------+------------------------------+---------------+------------+
| identifier | llama_model | provider_id | metadata |
+==============================+==============================+===============+============+
| Llama3.1-8B-Instruct | Llama3.1-8B-Instruct | together0 | {} |
+------------------------------+------------------------------+---------------+------------+
| Llama3.1-70B-Instruct | Llama3.1-70B-Instruct | together0 | {} |
+------------------------------+------------------------------+---------------+------------+
| Llama3.1-405B-Instruct | Llama3.1-405B-Instruct | together0 | {} |
+------------------------------+------------------------------+---------------+------------+
| Llama3.2-3B-Instruct | Llama3.2-3B-Instruct | together0 | {} |
+------------------------------+------------------------------+---------------+------------+
| Llama3.2-11B-Vision-Instruct | Llama3.2-11B-Vision-Instruct | together0 | {} |
+------------------------------+------------------------------+---------------+------------+
| Llama3.2-90B-Vision-Instruct | Llama3.2-90B-Vision-Instruct | together0 | {} |
+------------------------------+------------------------------+---------------+------------+
```
:::

:::{tab-item} fireworks
- [Start Fireworks Distribution](https://llama-stack.readthedocs.io/en/latest/getting_started/distributions/self_hosted_distro/fireworks.html)

Use `llama-stack-client models list` to check the available models served by Fireworks.
```
$ llama-stack-client models list
+------------------------------+------------------------------+---------------+------------+
| identifier | llama_model | provider_id | metadata |
+==============================+==============================+===============+============+
| Llama3.1-8B-Instruct | Llama3.1-8B-Instruct | fireworks0 | {} |
+------------------------------+------------------------------+---------------+------------+
| Llama3.1-70B-Instruct | Llama3.1-70B-Instruct | fireworks0 | {} |
+------------------------------+------------------------------+---------------+------------+
| Llama3.1-405B-Instruct | Llama3.1-405B-Instruct | fireworks0 | {} |
+------------------------------+------------------------------+---------------+------------+
| Llama3.2-1B-Instruct | Llama3.2-1B-Instruct | fireworks0 | {} |
+------------------------------+------------------------------+---------------+------------+
| Llama3.2-3B-Instruct | Llama3.2-3B-Instruct | fireworks0 | {} |
+------------------------------+------------------------------+---------------+------------+
| Llama3.2-11B-Vision-Instruct | Llama3.2-11B-Vision-Instruct | fireworks0 | {} |
+------------------------------+------------------------------+---------------+------------+
| Llama3.2-90B-Vision-Instruct | Llama3.2-90B-Vision-Instruct | fireworks0 | {} |
+------------------------------+------------------------------+---------------+------------+
```
:::
::::

##### Troubleshooting
- If you encounter any issues, search through our [GitHub Issues](https://github.com/meta-llama/llama-stack/issues), or file a new issue.
- Use `--port <PORT>` flag to use a different port number. For docker run, update the `-p <PORT>:<PORT>` flag.
@ -474,10 +147,10 @@ $ llama-stack-client models list
Once the server is set up, we can test it with a client to verify it's working correctly. The following command will send a chat completion request to the server's `/inference/chat_completion` API:

```bash
$ curl http://localhost:5000/alpha/inference/chat-completion \
-H "Content-Type: application/json" \
-d '{
    "model_id": "meta-llama/Llama-3.1-8B-Instruct",
    "messages": [
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Write me a 2 sentence poem about the moon"}

View file

@ -74,7 +74,7 @@ A Distribution is where APIs and Providers are assembled together to provide a c
| Python | [llama-stack-client-python](https://github.com/meta-llama/llama-stack-client-python) | [![PyPI version](https://img.shields.io/pypi/v/llama_stack_client.svg)](https://pypi.org/project/llama_stack_client/) | Python | [llama-stack-client-python](https://github.com/meta-llama/llama-stack-client-python) | [![PyPI version](https://img.shields.io/pypi/v/llama_stack_client.svg)](https://pypi.org/project/llama_stack_client/)
| Swift | [llama-stack-client-swift](https://github.com/meta-llama/llama-stack-client-swift) | [![Swift Package Index](https://img.shields.io/endpoint?url=https%3A%2F%2Fswiftpackageindex.com%2Fapi%2Fpackages%2Fmeta-llama%2Fllama-stack-client-swift%2Fbadge%3Ftype%3Dswift-versions)](https://swiftpackageindex.com/meta-llama/llama-stack-client-swift) | Swift | [llama-stack-client-swift](https://github.com/meta-llama/llama-stack-client-swift) | [![Swift Package Index](https://img.shields.io/endpoint?url=https%3A%2F%2Fswiftpackageindex.com%2Fapi%2Fpackages%2Fmeta-llama%2Fllama-stack-client-swift%2Fbadge%3Ftype%3Dswift-versions)](https://swiftpackageindex.com/meta-llama/llama-stack-client-swift)
| Node | [llama-stack-client-node](https://github.com/meta-llama/llama-stack-client-node) | [![NPM version](https://img.shields.io/npm/v/llama-stack-client.svg)](https://npmjs.org/package/llama-stack-client) | Node | [llama-stack-client-node](https://github.com/meta-llama/llama-stack-client-node) | [![NPM version](https://img.shields.io/npm/v/llama-stack-client.svg)](https://npmjs.org/package/llama-stack-client)
| Kotlin | [llama-stack-client-kotlin](https://github.com/meta-llama/llama-stack-client-kotlin) | | Kotlin | [llama-stack-client-kotlin](https://github.com/meta-llama/llama-stack-client-kotlin) | [![Maven version](https://img.shields.io/maven-central/v/com.llama.llamastack/llama-stack-client-kotlin)](https://central.sonatype.com/artifact/com.llama.llamastack/llama-stack-client-kotlin)
Check out our client SDKs for connecting to Llama Stack server in your preferred language, you can choose from [python](https://github.com/meta-llama/llama-stack-client-python), [node](https://github.com/meta-llama/llama-stack-client-node), [swift](https://github.com/meta-llama/llama-stack-client-swift), and [kotlin](https://github.com/meta-llama/llama-stack-client-kotlin) programming languages to quickly build your applications. Check out our client SDKs for connecting to Llama Stack server in your preferred language, you can choose from [python](https://github.com/meta-llama/llama-stack-client-python), [node](https://github.com/meta-llama/llama-stack-client-node), [swift](https://github.com/meta-llama/llama-stack-client-swift), and [kotlin](https://github.com/meta-llama/llama-stack-client-kotlin) programming languages to quickly build your applications.

View file

@ -54,6 +54,7 @@ class ToolDefinitionCommon(BaseModel):
class SearchEngineType(Enum):
    bing = "bing"
    brave = "brave"
    tavily = "tavily"


@json_schema_type
@ -271,7 +272,7 @@ class Session(BaseModel):
    turns: List[Turn]
    started_at: datetime
    memory_bank: Optional[MemoryBank] = None


class AgentConfigCommon(BaseModel):

View file

@ -49,7 +49,7 @@ class BatchChatCompletionResponse(BaseModel):
@runtime_checkable
class BatchInference(Protocol):
    @webmethod(route="/batch-inference/completion")
    async def batch_completion(
        self,
        model: str,
@ -58,7 +58,7 @@ class BatchInference(Protocol):
        logprobs: Optional[LogProbConfig] = None,
    ) -> BatchCompletionResponse: ...

    @webmethod(route="/batch-inference/chat-completion")
    async def batch_chat_completion(
        self,
        model: str,

View file

@ -21,7 +21,7 @@ class PaginatedRowsResult(BaseModel):
class DatasetStore(Protocol):
    def get_dataset(self, dataset_id: str) -> Dataset: ...


@runtime_checkable
@ -29,7 +29,7 @@ class DatasetIO(Protocol):
    # keeping for aligning with inference/safety, but this is not used
    dataset_store: DatasetStore

    @webmethod(route="/datasetio/get-rows-paginated", method="GET")
    async def get_rows_paginated(
        self,
        dataset_id: str,

View file

@ -13,16 +13,11 @@ from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel, Field

from llama_stack.apis.common.type_system import ParamType
from llama_stack.apis.resource import Resource, ResourceType


class CommonDatasetFields(BaseModel):
    dataset_schema: Dict[str, ParamType]
    url: URL
    metadata: Dict[str, Any] = Field(
        default_factory=dict,
@ -31,25 +26,41 @@ class DatasetDef(BaseModel):
@json_schema_type
class Dataset(CommonDatasetFields, Resource):
    type: Literal[ResourceType.dataset.value] = ResourceType.dataset.value

    @property
    def dataset_id(self) -> str:
        return self.identifier

    @property
    def provider_dataset_id(self) -> str:
        return self.provider_resource_id


class DatasetInput(CommonDatasetFields, BaseModel):
    dataset_id: str
    provider_id: Optional[str] = None
    provider_dataset_id: Optional[str] = None


class Datasets(Protocol):
    @webmethod(route="/datasets/register", method="POST")
    async def register_dataset(
        self,
        dataset_id: str,
        dataset_schema: Dict[str, ParamType],
        url: URL,
        provider_dataset_id: Optional[str] = None,
        provider_id: Optional[str] = None,
        metadata: Optional[Dict[str, Any]] = None,
    ) -> None: ...

    @webmethod(route="/datasets/get", method="GET")
    async def get_dataset(
        self,
        dataset_id: str,
    ) -> Optional[Dataset]: ...

    @webmethod(route="/datasets/list", method="GET")
    async def list_datasets(self) -> List[Dataset]: ...
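
For readers following the new resource-oriented registration flow, a minimal client-side sketch is below. The field names mirror the `register_dataset()` signature above; the JSON encoding of `ParamType` values (assumed here to be `{"type": "string"}`), the `URL` payload shape (assumed `{"uri": ...}`), the dataset id, and any route version prefix are assumptions to adapt to your server.

```python
# Sketch: registering a dataset against the resource-oriented Datasets API.
import httpx

BASE_URL = "http://localhost:5001"  # add a version prefix (e.g. /alpha) if required

httpx.post(
    f"{BASE_URL}/datasets/register",
    json={
        "dataset_id": "my_qa_dataset",  # placeholder identifier
        "dataset_schema": {
            # assumed ParamType wire format: {"type": "<param type>"}
            "input_query": {"type": "string"},
            "expected_answer": {"type": "string"},
        },
        "url": {"uri": "https://example.com/my_qa_dataset.jsonl"},  # assumed URL encoding
        "metadata": {"split": "validation"},
    },
).raise_for_status()
```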

View file

@ -14,6 +14,7 @@ from llama_stack.apis.scoring_functions import * # noqa: F403
from llama_stack.apis.agents import AgentConfig
from llama_stack.apis.common.job_types import Job, JobStatus
from llama_stack.apis.scoring import *  # noqa: F403
from llama_stack.apis.eval_tasks import *  # noqa: F403


@json_schema_type
@ -35,36 +36,65 @@ EvalCandidate = Annotated[
]
@json_schema_type
class BenchmarkEvalTaskConfig(BaseModel):
    type: Literal["benchmark"] = "benchmark"
    eval_candidate: EvalCandidate
    num_examples: Optional[int] = Field(
        description="Number of examples to evaluate (useful for testing), if not provided, all examples in the dataset will be evaluated",
        default=None,
    )


@json_schema_type
class AppEvalTaskConfig(BaseModel):
    type: Literal["app"] = "app"
    eval_candidate: EvalCandidate
    scoring_params: Dict[str, ScoringFnParams] = Field(
        description="Map between scoring function id and parameters for each scoring function you want to run",
        default_factory=dict,
    )
    num_examples: Optional[int] = Field(
        description="Number of examples to evaluate (useful for testing), if not provided, all examples in the dataset will be evaluated",
        default=None,
    )
    # we could optionally add any specific dataset config here


EvalTaskConfig = Annotated[
    Union[BenchmarkEvalTaskConfig, AppEvalTaskConfig], Field(discriminator="type")
]


@json_schema_type
class EvaluateResponse(BaseModel):
    generations: List[Dict[str, Any]]
    # each key in the dict is a scoring function name
    scores: Dict[str, ScoringResult]


class Eval(Protocol):
    @webmethod(route="/eval/run-eval", method="POST")
    async def run_eval(
        self,
        task_id: str,
        task_config: EvalTaskConfig,
    ) -> Job: ...

    @webmethod(route="/eval/evaluate-rows", method="POST")
    async def evaluate_rows(
        self,
        task_id: str,
        input_rows: List[Dict[str, Any]],
        scoring_functions: List[str],
        task_config: EvalTaskConfig,
    ) -> EvaluateResponse: ...

    @webmethod(route="/eval/job/status", method="GET")
    async def job_status(self, task_id: str, job_id: str) -> Optional[JobStatus]: ...

    @webmethod(route="/eval/job/cancel", method="POST")
    async def job_cancel(self, task_id: str, job_id: str) -> None: ...

    @webmethod(route="/eval/job/result", method="GET")
    async def job_result(self, task_id: str, job_id: str) -> EvaluateResponse: ...
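
A client-side sketch of the new task-scoped eval flow follows. It mirrors `run_eval()` and `job_status()` above; the eval task id, the model-candidate payload shape (inferred from the `EvalCandidate` union, which is not shown in this hunk), the `job_id` response field, and any route version prefix are assumptions.

```python
# Sketch: kick off an eval run for a registered eval task, then poll its job status.
import httpx

BASE_URL = "http://localhost:5001"  # add a version prefix (e.g. /alpha) if required
TASK_ID = "my-app::qa-eval"  # hypothetical eval task id

job = httpx.post(
    f"{BASE_URL}/eval/run-eval",
    json={
        "task_id": TASK_ID,
        "task_config": {
            "type": "benchmark",
            "eval_candidate": {  # assumed ModelCandidate encoding
                "type": "model",
                "model": "meta-llama/Llama-3.1-8B-Instruct",
                "sampling_params": {},
            },
            "num_examples": 5,  # small run for testing
        },
    },
).json()

status = httpx.get(
    f"{BASE_URL}/eval/job/status",
    params={"task_id": TASK_ID, "job_id": job["job_id"]},  # assumed Job field name
).json()
print(status)
```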

View file

@ -0,0 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .eval_tasks import * # noqa: F401 F403

View file

@ -0,0 +1,60 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict, List, Literal, Optional, Protocol, runtime_checkable
from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel, Field
from llama_stack.apis.resource import Resource, ResourceType
class CommonEvalTaskFields(BaseModel):
dataset_id: str
scoring_functions: List[str]
metadata: Dict[str, Any] = Field(
default_factory=dict,
description="Metadata for this evaluation task",
)
@json_schema_type
class EvalTask(CommonEvalTaskFields, Resource):
type: Literal[ResourceType.eval_task.value] = ResourceType.eval_task.value
@property
def eval_task_id(self) -> str:
return self.identifier
@property
def provider_eval_task_id(self) -> str:
return self.provider_resource_id
class EvalTaskInput(CommonEvalTaskFields, BaseModel):
eval_task_id: str
provider_id: Optional[str] = None
provider_eval_task_id: Optional[str] = None
@runtime_checkable
class EvalTasks(Protocol):
@webmethod(route="/eval-tasks/list", method="GET")
async def list_eval_tasks(self) -> List[EvalTask]: ...
@webmethod(route="/eval-tasks/get", method="GET")
async def get_eval_task(self, name: str) -> Optional[EvalTask]: ...
@webmethod(route="/eval-tasks/register", method="POST")
async def register_eval_task(
self,
eval_task_id: str,
dataset_id: str,
scoring_functions: List[str],
provider_eval_task_id: Optional[str] = None,
provider_id: Optional[str] = None,
metadata: Optional[Dict[str, Any]] = None,
) -> None: ...
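
To illustrate how an eval task ties a dataset to scoring functions, here is a hedged client-side sketch following the `register_eval_task()` signature above. The identifiers are placeholders and the route may carry a version prefix on your server.

```python
# Sketch: register an eval task for a previously registered dataset.
import httpx

BASE_URL = "http://localhost:5001"  # add a version prefix (e.g. /alpha) if required

httpx.post(
    f"{BASE_URL}/eval-tasks/register",
    json={
        "eval_task_id": "my-app::qa-eval",          # placeholder task id
        "dataset_id": "my_qa_dataset",              # placeholder dataset id
        "scoring_functions": ["meta-reference::equality"],  # placeholder scoring fn id
        "metadata": {"owner": "docs-example"},
    },
).raise_for_status()
```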

View file

@ -216,7 +216,7 @@ class EmbeddingsResponse(BaseModel):
class ModelStore(Protocol):
    def get_model(self, identifier: str) -> Model: ...


@runtime_checkable
@ -226,7 +226,7 @@ class Inference(Protocol):
@webmethod(route="/inference/completion") @webmethod(route="/inference/completion")
async def completion( async def completion(
self, self,
model: str, model_id: str,
content: InterleavedTextMedia, content: InterleavedTextMedia,
sampling_params: Optional[SamplingParams] = SamplingParams(), sampling_params: Optional[SamplingParams] = SamplingParams(),
response_format: Optional[ResponseFormat] = None, response_format: Optional[ResponseFormat] = None,
@ -234,10 +234,10 @@ class Inference(Protocol):
        logprobs: Optional[LogProbConfig] = None,
    ) -> Union[CompletionResponse, AsyncIterator[CompletionResponseStreamChunk]]: ...

    @webmethod(route="/inference/chat-completion")
    async def chat_completion(
        self,
        model_id: str,
        messages: List[Message],
        sampling_params: Optional[SamplingParams] = SamplingParams(),
        # zero-shot tool definitions as input to the model
@ -254,6 +254,6 @@ class Inference(Protocol):
@webmethod(route="/inference/embeddings") @webmethod(route="/inference/embeddings")
async def embeddings( async def embeddings(
self, self,
model: str, model_id: str,
contents: List[InterleavedTextMedia], contents: List[InterleavedTextMedia],
) -> EmbeddingsResponse: ... ) -> EmbeddingsResponse: ...

View file

@ -75,14 +75,22 @@ class MemoryClient(Memory):
async def run_main(host: str, port: int, stream: bool): async def run_main(host: str, port: int, stream: bool):
banks_client = MemoryBanksClient(f"http://{host}:{port}") banks_client = MemoryBanksClient(f"http://{host}:{port}")
bank = VectorMemoryBankDef( bank = VectorMemoryBank(
identifier="test_bank", identifier="test_bank",
provider_id="", provider_id="",
embedding_model="all-MiniLM-L6-v2", embedding_model="all-MiniLM-L6-v2",
chunk_size_in_tokens=512, chunk_size_in_tokens=512,
overlap_size_in_tokens=64, overlap_size_in_tokens=64,
) )
await banks_client.register_memory_bank(bank) await banks_client.register_memory_bank(
bank.identifier,
VectorMemoryBankParams(
embedding_model="all-MiniLM-L6-v2",
chunk_size_in_tokens=512,
overlap_size_in_tokens=64,
),
provider_resource_id=bank.identifier,
)
retrieved_bank = await banks_client.get_memory_bank(bank.identifier) retrieved_bank = await banks_client.get_memory_bank(bank.identifier)
assert retrieved_bank is not None assert retrieved_bank is not None

View file

@ -39,7 +39,7 @@ class QueryDocumentsResponse(BaseModel):
class MemoryBankStore(Protocol):
    def get_memory_bank(self, bank_id: str) -> Optional[MemoryBank]: ...


@runtime_checkable

View file

@ -5,7 +5,6 @@
# the root directory of this source tree. # the root directory of this source tree.
import asyncio import asyncio
import json
from typing import Any, Dict, List, Optional from typing import Any, Dict, List, Optional
@ -26,13 +25,13 @@ def deserialize_memory_bank_def(
raise ValueError("Memory bank type not specified") raise ValueError("Memory bank type not specified")
type = j["type"] type = j["type"]
if type == MemoryBankType.vector.value: if type == MemoryBankType.vector.value:
return VectorMemoryBankDef(**j) return VectorMemoryBank(**j)
elif type == MemoryBankType.keyvalue.value: elif type == MemoryBankType.keyvalue.value:
return KeyValueMemoryBankDef(**j) return KeyValueMemoryBank(**j)
elif type == MemoryBankType.keyword.value: elif type == MemoryBankType.keyword.value:
return KeywordMemoryBankDef(**j) return KeywordMemoryBank(**j)
elif type == MemoryBankType.graph.value: elif type == MemoryBankType.graph.value:
return GraphMemoryBankDef(**j) return GraphMemoryBank(**j)
else: else:
raise ValueError(f"Unknown memory bank type: {type}") raise ValueError(f"Unknown memory bank type: {type}")
@ -47,7 +46,7 @@ class MemoryBanksClient(MemoryBanks):
async def shutdown(self) -> None: async def shutdown(self) -> None:
pass pass
async def list_memory_banks(self) -> List[MemoryBankDefWithProvider]: async def list_memory_banks(self) -> List[MemoryBank]:
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
response = await client.get( response = await client.get(
f"{self.base_url}/memory_banks/list", f"{self.base_url}/memory_banks/list",
@ -57,13 +56,20 @@ class MemoryBanksClient(MemoryBanks):
return [deserialize_memory_bank_def(x) for x in response.json()] return [deserialize_memory_bank_def(x) for x in response.json()]
async def register_memory_bank( async def register_memory_bank(
self, memory_bank: MemoryBankDefWithProvider self,
memory_bank_id: str,
params: BankParams,
provider_resource_id: Optional[str] = None,
provider_id: Optional[str] = None,
) -> None: ) -> None:
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
response = await client.post( response = await client.post(
f"{self.base_url}/memory_banks/register", f"{self.base_url}/memory_banks/register",
json={ json={
"memory_bank": json.loads(memory_bank.json()), "memory_bank_id": memory_bank_id,
"provider_resource_id": provider_resource_id,
"provider_id": provider_id,
"params": params.dict(),
}, },
headers={"Content-Type": "application/json"}, headers={"Content-Type": "application/json"},
) )
@ -71,13 +77,13 @@ class MemoryBanksClient(MemoryBanks):
async def get_memory_bank( async def get_memory_bank(
self, self,
identifier: str, memory_bank_id: str,
) -> Optional[MemoryBankDefWithProvider]: ) -> Optional[MemoryBank]:
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
response = await client.get( response = await client.get(
f"{self.base_url}/memory_banks/get", f"{self.base_url}/memory_banks/get",
params={ params={
"identifier": identifier, "memory_bank_id": memory_bank_id,
}, },
headers={"Content-Type": "application/json"}, headers={"Content-Type": "application/json"},
) )
@ -94,12 +100,12 @@ async def run_main(host: str, port: int, stream: bool):
# register memory bank for the first time # register memory bank for the first time
response = await client.register_memory_bank( response = await client.register_memory_bank(
VectorMemoryBankDef( memory_bank_id="test_bank2",
identifier="test_bank2", params=VectorMemoryBankParams(
embedding_model="all-MiniLM-L6-v2", embedding_model="all-MiniLM-L6-v2",
chunk_size_in_tokens=512, chunk_size_in_tokens=512,
overlap_size_in_tokens=64, overlap_size_in_tokens=64,
) ),
) )
cprint(f"register_memory_bank response={response}", "blue") cprint(f"register_memory_bank response={response}", "blue")

View file

@ -5,11 +5,21 @@
# the root directory of this source tree.

from enum import Enum
from typing import (
    Annotated,
    List,
    Literal,
    Optional,
    Protocol,
    runtime_checkable,
    Union,
)

from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel, Field

from llama_stack.apis.resource import Resource, ResourceType
@json_schema_type
@ -20,59 +30,120 @@ class MemoryBankType(Enum):
graph = "graph" graph = "graph"
class CommonDef(BaseModel): # define params for each type of memory bank, this leads to a tagged union
identifier: str # accepted as input from the API or from the config.
# Hack: move this out later
provider_id: str = ""
@json_schema_type @json_schema_type
class VectorMemoryBankDef(CommonDef): class VectorMemoryBankParams(BaseModel):
type: Literal[MemoryBankType.vector.value] = MemoryBankType.vector.value memory_bank_type: Literal[MemoryBankType.vector.value] = MemoryBankType.vector.value
embedding_model: str embedding_model: str
chunk_size_in_tokens: int chunk_size_in_tokens: int
overlap_size_in_tokens: Optional[int] = None overlap_size_in_tokens: Optional[int] = None
@json_schema_type @json_schema_type
class KeyValueMemoryBankDef(CommonDef): class KeyValueMemoryBankParams(BaseModel):
type: Literal[MemoryBankType.keyvalue.value] = MemoryBankType.keyvalue.value memory_bank_type: Literal[MemoryBankType.keyvalue.value] = (
MemoryBankType.keyvalue.value
)
@json_schema_type @json_schema_type
class KeywordMemoryBankDef(CommonDef): class KeywordMemoryBankParams(BaseModel):
type: Literal[MemoryBankType.keyword.value] = MemoryBankType.keyword.value memory_bank_type: Literal[MemoryBankType.keyword.value] = (
MemoryBankType.keyword.value
)
@json_schema_type @json_schema_type
class GraphMemoryBankDef(CommonDef): class GraphMemoryBankParams(BaseModel):
type: Literal[MemoryBankType.graph.value] = MemoryBankType.graph.value memory_bank_type: Literal[MemoryBankType.graph.value] = MemoryBankType.graph.value
MemoryBankDef = Annotated[ BankParams = Annotated[
Union[ Union[
VectorMemoryBankDef, VectorMemoryBankParams,
KeyValueMemoryBankDef, KeyValueMemoryBankParams,
KeywordMemoryBankDef, KeywordMemoryBankParams,
GraphMemoryBankDef, GraphMemoryBankParams,
], ],
Field(discriminator="type"), Field(discriminator="memory_bank_type"),
] ]
MemoryBankDefWithProvider = MemoryBankDef
# Some common functionality for memory banks.
class MemoryBankResourceMixin(Resource):
type: Literal[ResourceType.memory_bank.value] = ResourceType.memory_bank.value
@property
def memory_bank_id(self) -> str:
return self.identifier
@property
def provider_memory_bank_id(self) -> str:
return self.provider_resource_id
@json_schema_type
class VectorMemoryBank(MemoryBankResourceMixin):
memory_bank_type: Literal[MemoryBankType.vector.value] = MemoryBankType.vector.value
embedding_model: str
chunk_size_in_tokens: int
overlap_size_in_tokens: Optional[int] = None
@json_schema_type
class KeyValueMemoryBank(MemoryBankResourceMixin):
memory_bank_type: Literal[MemoryBankType.keyvalue.value] = (
MemoryBankType.keyvalue.value
)
# TODO: KeyValue and Keyword are so similar in name, oof. Get a better naming convention.
@json_schema_type
class KeywordMemoryBank(MemoryBankResourceMixin):
memory_bank_type: Literal[MemoryBankType.keyword.value] = (
MemoryBankType.keyword.value
)
@json_schema_type
class GraphMemoryBank(MemoryBankResourceMixin):
memory_bank_type: Literal[MemoryBankType.graph.value] = MemoryBankType.graph.value
MemoryBank = Annotated[
Union[
VectorMemoryBank,
KeyValueMemoryBank,
KeywordMemoryBank,
GraphMemoryBank,
],
Field(discriminator="memory_bank_type"),
]
class MemoryBankInput(BaseModel):
memory_bank_id: str
params: BankParams
provider_memory_bank_id: Optional[str] = None
@runtime_checkable
class MemoryBanks(Protocol):
    @webmethod(route="/memory-banks/list", method="GET")
    async def list_memory_banks(self) -> List[MemoryBank]: ...

    @webmethod(route="/memory-banks/get", method="GET")
    async def get_memory_bank(self, memory_bank_id: str) -> Optional[MemoryBank]: ...

    @webmethod(route="/memory-banks/register", method="POST")
    async def register_memory_bank(
        self,
        memory_bank_id: str,
        params: BankParams,
        provider_id: Optional[str] = None,
        provider_memory_bank_id: Optional[str] = None,
    ) -> MemoryBank: ...

    @webmethod(route="/memory-banks/unregister", method="POST")
    async def unregister_memory_bank(self, memory_bank_id: str) -> None: ...
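
A hedged client-side sketch of the new registration shape follows; the JSON keys mirror `register_memory_bank()` above, while the bank id, embedding model, and any route version prefix are placeholders to adapt to your server.

```python
# Sketch: register a vector memory bank through the MemoryBanks API.
import httpx

BASE_URL = "http://localhost:5001"  # add a version prefix (e.g. /alpha) if required

bank = httpx.post(
    f"{BASE_URL}/memory-banks/register",
    json={
        "memory_bank_id": "docs_bank",  # placeholder identifier
        "params": {
            # BankParams tagged union, discriminated by memory_bank_type
            "memory_bank_type": "vector",
            "embedding_model": "all-MiniLM-L6-v2",
            "chunk_size_in_tokens": 512,
            "overlap_size_in_tokens": 64,
        },
    },
).json()
print(bank)  # the server returns the registered MemoryBank resource
```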

View file

@ -26,16 +26,16 @@ class ModelsClient(Models):
async def shutdown(self) -> None: async def shutdown(self) -> None:
pass pass
async def list_models(self) -> List[ModelDefWithProvider]: async def list_models(self) -> List[Model]:
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
response = await client.get( response = await client.get(
f"{self.base_url}/models/list", f"{self.base_url}/models/list",
headers={"Content-Type": "application/json"}, headers={"Content-Type": "application/json"},
) )
response.raise_for_status() response.raise_for_status()
return [ModelDefWithProvider(**x) for x in response.json()] return [Model(**x) for x in response.json()]
async def register_model(self, model: ModelDefWithProvider) -> None: async def register_model(self, model: Model) -> None:
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
response = await client.post( response = await client.post(
f"{self.base_url}/models/register", f"{self.base_url}/models/register",
@ -46,7 +46,7 @@ class ModelsClient(Models):
) )
response.raise_for_status() response.raise_for_status()
async def get_model(self, identifier: str) -> Optional[ModelDefWithProvider]: async def get_model(self, identifier: str) -> Optional[Model]:
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
response = await client.get( response = await client.get(
f"{self.base_url}/models/get", f"{self.base_url}/models/get",
@ -59,7 +59,16 @@ class ModelsClient(Models):
j = response.json() j = response.json()
if j is None: if j is None:
return None return None
return ModelDefWithProvider(**j) return Model(**j)
async def unregister_model(self, model_id: str) -> None:
async with httpx.AsyncClient() as client:
response = await client.delete(
f"{self.base_url}/models/delete",
params={"model_id": model_id},
headers={"Content-Type": "application/json"},
)
response.raise_for_status()
async def run_main(host: str, port: int, stream: bool): async def run_main(host: str, port: int, stream: bool):

View file

@ -7,16 +7,12 @@
from typing import Any, Dict, List, Literal, Optional, Protocol, runtime_checkable

from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel, ConfigDict, Field

from llama_stack.apis.resource import Resource, ResourceType


class CommonModelFields(BaseModel):
    metadata: Dict[str, Any] = Field(
        default_factory=dict,
        description="Any additional metadata for this model",
@ -24,20 +20,44 @@ class ModelDef(BaseModel):
@json_schema_type
class Model(CommonModelFields, Resource):
    type: Literal[ResourceType.model.value] = ResourceType.model.value

    @property
    def model_id(self) -> str:
        return self.identifier

    @property
    def provider_model_id(self) -> str:
        return self.provider_resource_id

    model_config = ConfigDict(protected_namespaces=())


class ModelInput(CommonModelFields):
    model_id: str
    provider_id: Optional[str] = None
    provider_model_id: Optional[str] = None
    model_config = ConfigDict(protected_namespaces=())


@runtime_checkable
class Models(Protocol):
    @webmethod(route="/models/list", method="GET")
    async def list_models(self) -> List[Model]: ...

    @webmethod(route="/models/get", method="GET")
    async def get_model(self, identifier: str) -> Optional[Model]: ...

    @webmethod(route="/models/register", method="POST")
    async def register_model(
        self,
        model_id: str,
        provider_model_id: Optional[str] = None,
        provider_id: Optional[str] = None,
        metadata: Optional[Dict[str, Any]] = None,
    ) -> Model: ...

    @webmethod(route="/models/unregister", method="POST")
    async def unregister_model(self, model_id: str) -> None: ...
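
For illustration, a hedged client-side sketch of registering and later unregistering a model alias follows the signatures above. The model and provider identifiers are placeholders, and routes may carry a version prefix on your server.

```python
# Sketch: register a model alias with a provider, then remove it.
import httpx

BASE_URL = "http://localhost:5001"  # add a version prefix (e.g. /alpha) if required

httpx.post(
    f"{BASE_URL}/models/register",
    json={
        "model_id": "meta-llama/Llama-3.1-8B-Instruct",
        "provider_model_id": "llama3.1:8b-instruct-fp16",  # e.g. the provider-side name (placeholder)
        "provider_id": "ollama",                            # placeholder provider id
        "metadata": {},
    },
).raise_for_status()

httpx.post(
    f"{BASE_URL}/models/unregister",
    json={"model_id": "meta-llama/Llama-3.1-8B-Instruct"},
).raise_for_status()
```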

View file

@ -176,7 +176,7 @@ class PostTrainingJobArtifactsResponse(BaseModel):
class PostTraining(Protocol): class PostTraining(Protocol):
@webmethod(route="/post_training/supervised_fine_tune") @webmethod(route="/post-training/supervised-fine-tune")
def supervised_fine_tune( def supervised_fine_tune(
self, self,
job_uuid: str, job_uuid: str,
@ -193,7 +193,7 @@ class PostTraining(Protocol):
logger_config: Dict[str, Any], logger_config: Dict[str, Any],
) -> PostTrainingJob: ... ) -> PostTrainingJob: ...
@webmethod(route="/post_training/preference_optimize") @webmethod(route="/post-training/preference-optimize")
def preference_optimize( def preference_optimize(
self, self,
job_uuid: str, job_uuid: str,
@ -208,22 +208,22 @@ class PostTraining(Protocol):
logger_config: Dict[str, Any], logger_config: Dict[str, Any],
) -> PostTrainingJob: ... ) -> PostTrainingJob: ...
@webmethod(route="/post_training/jobs") @webmethod(route="/post-training/jobs")
def get_training_jobs(self) -> List[PostTrainingJob]: ... def get_training_jobs(self) -> List[PostTrainingJob]: ...
# sends SSE stream of logs # sends SSE stream of logs
@webmethod(route="/post_training/job/logs") @webmethod(route="/post-training/job/logs")
def get_training_job_logstream(self, job_uuid: str) -> PostTrainingJobLogStream: ... def get_training_job_logstream(self, job_uuid: str) -> PostTrainingJobLogStream: ...
@webmethod(route="/post_training/job/status") @webmethod(route="/post-training/job/status")
def get_training_job_status( def get_training_job_status(
self, job_uuid: str self, job_uuid: str
) -> PostTrainingJobStatusResponse: ... ) -> PostTrainingJobStatusResponse: ...
@webmethod(route="/post_training/job/cancel") @webmethod(route="/post-training/job/cancel")
def cancel_training_job(self, job_uuid: str) -> None: ... def cancel_training_job(self, job_uuid: str) -> None: ...
@webmethod(route="/post_training/job/artifacts") @webmethod(route="/post-training/job/artifacts")
def get_training_job_artifacts( def get_training_job_artifacts(
self, job_uuid: str self, job_uuid: str
) -> PostTrainingJobArtifactsResponse: ... ) -> PostTrainingJobArtifactsResponse: ...

View file

@ -0,0 +1,39 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from enum import Enum
from llama_models.schema_utils import json_schema_type
from pydantic import BaseModel, Field
@json_schema_type
class ResourceType(Enum):
model = "model"
shield = "shield"
memory_bank = "memory_bank"
dataset = "dataset"
scoring_function = "scoring_function"
eval_task = "eval_task"
class Resource(BaseModel):
"""Base class for all Llama Stack resources"""
identifier: str = Field(
description="Unique identifier for this resource in llama stack"
)
provider_resource_id: str = Field(
description="Unique identifier for this resource in the provider",
default=None,
)
provider_id: str = Field(description="ID of the provider that owns this resource")
type: ResourceType = Field(
description="Type of resource (e.g. 'model', 'shield', 'memory_bank', etc.)"
)

View file

@ -27,7 +27,7 @@ async def get_client_impl(config: RemoteProviderConfig, _deps: Any) -> Safety:
def encodable_dict(d: BaseModel): def encodable_dict(d: BaseModel):
return json.loads(d.json()) return json.loads(d.model_dump_json())
class SafetyClient(Safety): class SafetyClient(Safety):
@ -41,13 +41,13 @@ class SafetyClient(Safety):
pass pass
async def run_shield( async def run_shield(
self, shield_type: str, messages: List[Message] self, shield_id: str, messages: List[Message]
) -> RunShieldResponse: ) -> RunShieldResponse:
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
response = await client.post( response = await client.post(
f"{self.base_url}/safety/run_shield", f"{self.base_url}/safety/run_shield",
json=dict( json=dict(
shield_type=shield_type, shield_id=shield_id,
messages=[encodable_dict(m) for m in messages], messages=[encodable_dict(m) for m in messages],
), ),
headers={ headers={
@ -80,7 +80,7 @@ async def run_main(host: str, port: int, image_path: str = None):
) )
cprint(f"User>{message.content}", "green") cprint(f"User>{message.content}", "green")
response = await client.run_shield( response = await client.run_shield(
shield_type="llama_guard", shield_id="Llama-Guard-3-1B",
messages=[message], messages=[message],
) )
print(response) print(response)
@ -91,7 +91,7 @@ async def run_main(host: str, port: int, image_path: str = None):
]: ]:
cprint(f"User>{message.content}", "green") cprint(f"User>{message.content}", "green")
response = await client.run_shield( response = await client.run_shield(
shield_type="llama_guard", shield_id="llama_guard",
messages=[message], messages=[message],
) )
print(response) print(response)

View file

@ -39,14 +39,17 @@ class RunShieldResponse(BaseModel):
class ShieldStore(Protocol):
    async def get_shield(self, identifier: str) -> Shield: ...


@runtime_checkable
class Safety(Protocol):
    shield_store: ShieldStore

    @webmethod(route="/safety/run-shield")
    async def run_shield(
        self,
        shield_id: str,
        messages: List[Message],
        params: Dict[str, Any] = None,
    ) -> RunShieldResponse: ...
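
A hedged sketch of invoking a registered shield over HTTP follows; it mirrors `run_shield()` above, with the shield id as a placeholder and the hyphenated `/safety/run-shield` route (the older client code above still uses `/safety/run_shield`), plus any version prefix, to be adapted to your server.

```python
# Sketch: run a safety shield on a single user message.
import httpx

BASE_URL = "http://localhost:5001"  # add a version prefix (e.g. /alpha) if required

resp = httpx.post(
    f"{BASE_URL}/safety/run-shield",
    json={
        "shield_id": "meta-llama/Llama-Guard-3-8B",  # placeholder shield id
        "messages": [{"role": "user", "content": "How do I make a cake?"}],
        "params": {},
    },
)
resp.raise_for_status()
print(resp.json())  # RunShieldResponse; inspect it for any reported violation
```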

View file

@ -37,22 +37,24 @@ class ScoreResponse(BaseModel):
class ScoringFunctionStore(Protocol):
    def get_scoring_function(self, scoring_fn_id: str) -> ScoringFn: ...


@runtime_checkable
class Scoring(Protocol):
    scoring_function_store: ScoringFunctionStore

    @webmethod(route="/scoring/score-batch")
    async def score_batch(
        self,
        dataset_id: str,
        scoring_functions: Dict[str, Optional[ScoringFnParams]] = None,
        save_results_dataset: bool = False,
    ) -> ScoreBatchResponse: ...

    @webmethod(route="/scoring/score")
    async def score(
        self,
        input_rows: List[Dict[str, Any]],
        scoring_functions: Dict[str, Optional[ScoringFnParams]] = None,
    ) -> ScoreResponse: ...
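
To show the new per-function params mapping accepted by `score()` above, here is a hedged client-side sketch; the scoring function id is a placeholder, a `null` value means "use the function's default params", and the route may need a version prefix.

```python
# Sketch: score a single row with one scoring function using its default params.
import httpx

BASE_URL = "http://localhost:5001"  # add a version prefix (e.g. /alpha) if required

resp = httpx.post(
    f"{BASE_URL}/scoring/score",
    json={
        "input_rows": [
            {"input_query": "2 + 2?", "generated_answer": "4", "expected_answer": "4"},
        ],
        "scoring_functions": {"meta-reference::equality": None},  # placeholder fn id
    },
)
resp.raise_for_status()
print(resp.json())  # ScoreResponse keyed by scoring function name
```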

View file

@ -4,72 +4,119 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from enum import Enum
from typing import (
Any,
Dict,
List,
Literal,
Optional,
Protocol,
runtime_checkable,
Union,
)
from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel, Field
from typing_extensions import Annotated
from llama_stack.apis.common.type_system import ParamType
from llama_stack.apis.resource import Resource, ResourceType
@json_schema_type
class Parameter(BaseModel):
name: str
type: ParamType
description: Optional[str] = None
# Perhaps more structure can be imposed on these functions. Maybe they could be associated
# with standard metrics so they can be rolled up?
@json_schema_type
class ScoringFnParamsType(Enum):
llm_as_judge = "llm_as_judge"
regex_parser = "regex_parser"
@json_schema_type
class LLMAsJudgeScoringFnParams(BaseModel):
    type: Literal[ScoringFnParamsType.llm_as_judge.value] = (
        ScoringFnParamsType.llm_as_judge.value
    )
    judge_model: str
    prompt_template: Optional[str] = None
    judge_score_regexes: Optional[List[str]] = Field(
        description="Regexes to extract the answer from generated response",
        default_factory=list,
    )


@json_schema_type
class RegexParserScoringFnParams(BaseModel):
    type: Literal[ScoringFnParamsType.regex_parser.value] = (
        ScoringFnParamsType.regex_parser.value
    )
    parsing_regexes: Optional[List[str]] = Field(
        description="Regex to extract the answer from generated response",
        default_factory=list,
    )


ScoringFnParams = Annotated[
    Union[
        LLMAsJudgeScoringFnParams,
        RegexParserScoringFnParams,
    ],
    Field(discriminator="type"),
]
class CommonScoringFnFields(BaseModel):
description: Optional[str] = None description: Optional[str] = None
metadata: Dict[str, Any] = Field( metadata: Dict[str, Any] = Field(
default_factory=dict, default_factory=dict,
description="Any additional metadata for this definition", description="Any additional metadata for this definition",
) )
parameters: List[Parameter] = Field(
description="List of parameters for the deterministic function",
default_factory=list,
)
return_type: ParamType = Field( return_type: ParamType = Field(
description="The return type of the deterministic function", description="The return type of the deterministic function",
) )
context: Optional[LLMAsJudgeContext] = None params: Optional[ScoringFnParams] = Field(
# We can optionally add information here to support packaging of code, etc. description="The parameters for the scoring function for benchmark eval, these can be overridden for app eval",
default=None,
)
@json_schema_type @json_schema_type
class ScoringFnDefWithProvider(ScoringFnDef): class ScoringFn(CommonScoringFnFields, Resource):
type: Literal["scoring_fn"] = "scoring_fn" type: Literal[ResourceType.scoring_function.value] = (
provider_id: str = Field( ResourceType.scoring_function.value
description="ID of the provider which serves this dataset",
) )
@property
def scoring_fn_id(self) -> str:
return self.identifier
@property
def provider_scoring_fn_id(self) -> str:
return self.provider_resource_id
class ScoringFnInput(CommonScoringFnFields, BaseModel):
scoring_fn_id: str
provider_id: Optional[str] = None
provider_scoring_fn_id: Optional[str] = None
@runtime_checkable @runtime_checkable
class ScoringFunctions(Protocol): class ScoringFunctions(Protocol):
@webmethod(route="/scoring_functions/list", method="GET") @webmethod(route="/scoring-functions/list", method="GET")
async def list_scoring_functions(self) -> List[ScoringFnDefWithProvider]: ... async def list_scoring_functions(self) -> List[ScoringFn]: ...
@webmethod(route="/scoring_functions/get", method="GET") @webmethod(route="/scoring-functions/get", method="GET")
async def get_scoring_function( async def get_scoring_function(self, scoring_fn_id: str) -> Optional[ScoringFn]: ...
self, name: str
) -> Optional[ScoringFnDefWithProvider]: ...
@webmethod(route="/scoring_functions/register", method="POST") @webmethod(route="/scoring-functions/register", method="POST")
async def register_scoring_function( async def register_scoring_function(
self, function_def: ScoringFnDefWithProvider self,
scoring_fn_id: str,
description: str,
return_type: ParamType,
provider_scoring_fn_id: Optional[str] = None,
provider_id: Optional[str] = None,
params: Optional[ScoringFnParams] = None,
) -> None: ... ) -> None: ...
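Example (illustrative sketch, not taken from this diff): registering a scoring function with the new resource-style API, using the `LLMAsJudgeScoringFnParams` variant of the discriminated union. `NumberType` is assumed to be one of the `ParamType` variants in `llama_stack.apis.common.type_system`; the ids below are hypothetical placeholders.
from llama_stack.apis.common.type_system import NumberType  # assumed ParamType variant

async def register_judge_fn(scoring_functions_impl) -> None:
    # `scoring_functions_impl` is assumed to implement the ScoringFunctions protocol.
    params = LLMAsJudgeScoringFnParams(
        judge_model="Llama3.1-8B-Instruct",
        prompt_template="Rate the answer between 0 and 10: {generated_answer}",
        judge_score_regexes=[r"(\d+)"],
    )
    await scoring_functions_impl.register_scoring_function(
        scoring_fn_id="example::llm-as-judge",  # hypothetical id
        description="Example LLM-as-judge scoring function",
        return_type=NumberType(),
        params=params,
    )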

View file

@ -5,7 +5,6 @@
# the root directory of this source tree. # the root directory of this source tree.
import asyncio import asyncio
import json
from typing import List, Optional from typing import List, Optional
@ -26,32 +25,41 @@ class ShieldsClient(Shields):
async def shutdown(self) -> None: async def shutdown(self) -> None:
pass pass
async def list_shields(self) -> List[ShieldDefWithProvider]: async def list_shields(self) -> List[Shield]:
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
response = await client.get( response = await client.get(
f"{self.base_url}/shields/list", f"{self.base_url}/shields/list",
headers={"Content-Type": "application/json"}, headers={"Content-Type": "application/json"},
) )
response.raise_for_status() response.raise_for_status()
return [ShieldDefWithProvider(**x) for x in response.json()] return [Shield(**x) for x in response.json()]
async def register_shield(self, shield: ShieldDefWithProvider) -> None: async def register_shield(
self,
shield_id: str,
provider_shield_id: Optional[str],
provider_id: Optional[str],
params: Optional[Dict[str, Any]],
) -> None:
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
response = await client.post( response = await client.post(
f"{self.base_url}/shields/register", f"{self.base_url}/shields/register",
json={ json={
"shield": json.loads(shield.json()), "shield_id": shield_id,
"provider_shield_id": provider_shield_id,
"provider_id": provider_id,
"params": params,
}, },
headers={"Content-Type": "application/json"}, headers={"Content-Type": "application/json"},
) )
response.raise_for_status() response.raise_for_status()
async def get_shield(self, shield_type: str) -> Optional[ShieldDefWithProvider]: async def get_shield(self, shield_id: str) -> Optional[Shield]:
async with httpx.AsyncClient() as client: async with httpx.AsyncClient() as client:
response = await client.get( response = await client.get(
f"{self.base_url}/shields/get", f"{self.base_url}/shields/get",
params={ params={
"shield_type": shield_type, "shield_id": shield_id,
}, },
headers={"Content-Type": "application/json"}, headers={"Content-Type": "application/json"},
) )
@ -61,7 +69,7 @@ class ShieldsClient(Shields):
if j is None: if j is None:
return None return None
return ShieldDefWithProvider(**j) return Shield(**j)
async def run_main(host: str, port: int, stream: bool): async def run_main(host: str, port: int, stream: bool):

View file

@ -4,49 +4,52 @@
# This source code is licensed under the terms described in the LICENSE file in # This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree. # the root directory of this source tree.
from enum import Enum
from typing import Any, Dict, List, Literal, Optional, Protocol, runtime_checkable from typing import Any, Dict, List, Literal, Optional, Protocol, runtime_checkable
from llama_models.schema_utils import json_schema_type, webmethod from llama_models.schema_utils import json_schema_type, webmethod
from pydantic import BaseModel, Field from pydantic import BaseModel
from llama_stack.apis.resource import Resource, ResourceType
class CommonShieldFields(BaseModel):
params: Optional[Dict[str, Any]] = None
@json_schema_type @json_schema_type
class ShieldType(Enum): class Shield(CommonShieldFields, Resource):
generic_content_shield = "generic_content_shield" """A safety shield resource that can be used to check content"""
llama_guard = "llama_guard"
code_scanner = "code_scanner" type: Literal[ResourceType.shield.value] = ResourceType.shield.value
prompt_guard = "prompt_guard"
@property
def shield_id(self) -> str:
return self.identifier
@property
def provider_shield_id(self) -> str:
return self.provider_resource_id
class ShieldDef(BaseModel): class ShieldInput(CommonShieldFields):
identifier: str = Field( shield_id: str
description="A unique identifier for the shield type", provider_id: Optional[str] = None
) provider_shield_id: Optional[str] = None
shield_type: str = Field(
description="The type of shield this is; the value is one of the ShieldType enum"
)
params: Dict[str, Any] = Field(
default_factory=dict,
description="Any additional parameters needed for this shield",
)
@json_schema_type
class ShieldDefWithProvider(ShieldDef):
type: Literal["shield"] = "shield"
provider_id: str = Field(
description="The provider ID for this shield type",
)
@runtime_checkable @runtime_checkable
class Shields(Protocol): class Shields(Protocol):
@webmethod(route="/shields/list", method="GET") @webmethod(route="/shields/list", method="GET")
async def list_shields(self) -> List[ShieldDefWithProvider]: ... async def list_shields(self) -> List[Shield]: ...
@webmethod(route="/shields/get", method="GET") @webmethod(route="/shields/get", method="GET")
async def get_shield(self, identifier: str) -> Optional[ShieldDefWithProvider]: ... async def get_shield(self, identifier: str) -> Optional[Shield]: ...
@webmethod(route="/shields/register", method="POST") @webmethod(route="/shields/register", method="POST")
async def register_shield(self, shield: ShieldDefWithProvider) -> None: ... async def register_shield(
self,
shield_id: str,
provider_shield_id: Optional[str] = None,
provider_id: Optional[str] = None,
params: Optional[Dict[str, Any]] = None,
) -> Shield: ...
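Example (illustrative sketch, not taken from this diff): shields are now registered by id rather than by passing a `ShieldDef` object, and `register_shield` returns the created `Shield`. The ids below are hypothetical placeholders.
async def register_example_shield(shields_impl) -> Shield:
    # `shields_impl` is assumed to implement the Shields protocol.
    shield = await shields_impl.register_shield(
        shield_id="content-safety",   # hypothetical shield id
        provider_id="llama-guard",    # hypothetical provider id
        provider_shield_id=None,      # optional provider-side identifier
        params=None,                  # optional provider-specific parameters
    )
    return shield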

View file

@ -44,7 +44,7 @@ class SyntheticDataGenerationResponse(BaseModel):
class SyntheticDataGeneration(Protocol): class SyntheticDataGeneration(Protocol):
@webmethod(route="/synthetic_data_generation/generate") @webmethod(route="/synthetic-data-generation/generate")
def synthetic_data_generate( def synthetic_data_generate(
self, self,
dialogs: List[Message], dialogs: List[Message],

View file

@ -125,8 +125,8 @@ Event = Annotated[
@runtime_checkable @runtime_checkable
class Telemetry(Protocol): class Telemetry(Protocol):
@webmethod(route="/telemetry/log_event") @webmethod(route="/telemetry/log-event")
async def log_event(self, event: Event) -> None: ... async def log_event(self, event: Event) -> None: ...
@webmethod(route="/telemetry/get_trace", method="GET") @webmethod(route="/telemetry/get-trace", method="GET")
async def get_trace(self, trace_id: str) -> Trace: ... async def get_trace(self, trace_id: str) -> Trace: ...

View file

@ -0,0 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
LLAMA_STACK_API_VERSION = "alpha"

View file

@ -9,15 +9,27 @@ import asyncio
import json import json
import os import os
import shutil import shutil
import time from dataclasses import dataclass
from datetime import datetime from datetime import datetime
from functools import partial from functools import partial
from pathlib import Path from pathlib import Path
from typing import Dict, List from typing import Dict, List, Optional
import httpx import httpx
from pydantic import BaseModel
from llama_models.datatypes import Model
from llama_models.sku_list import LlamaDownloadInfo
from pydantic import BaseModel, ConfigDict
from rich.console import Console
from rich.progress import (
BarColumn,
DownloadColumn,
Progress,
TextColumn,
TimeRemainingColumn,
TransferSpeedColumn,
)
from termcolor import cprint from termcolor import cprint
from llama_stack.cli.subcommand import Subcommand from llama_stack.cli.subcommand import Subcommand
@ -61,6 +73,13 @@ def setup_download_parser(parser: argparse.ArgumentParser) -> None:
required=False, required=False,
help="For source=meta, URL obtained from llama.meta.com after accepting license terms", help="For source=meta, URL obtained from llama.meta.com after accepting license terms",
) )
parser.add_argument(
"--max-parallel",
type=int,
required=False,
default=3,
help="Maximum number of concurrent downloads",
)
parser.add_argument( parser.add_argument(
"--ignore-patterns", "--ignore-patterns",
type=str, type=str,
@ -80,6 +99,245 @@ safetensors files to avoid downloading duplicate weights.
parser.set_defaults(func=partial(run_download_cmd, parser=parser)) parser.set_defaults(func=partial(run_download_cmd, parser=parser))
@dataclass
class DownloadTask:
url: str
output_file: str
total_size: int = 0
downloaded_size: int = 0
task_id: Optional[int] = None
retries: int = 0
max_retries: int = 3
class DownloadError(Exception):
pass
class CustomTransferSpeedColumn(TransferSpeedColumn):
def render(self, task):
if task.finished:
return "-"
return super().render(task)
class ParallelDownloader:
def __init__(
self,
max_concurrent_downloads: int = 3,
buffer_size: int = 1024 * 1024,
timeout: int = 30,
):
self.max_concurrent_downloads = max_concurrent_downloads
self.buffer_size = buffer_size
self.timeout = timeout
self.console = Console()
self.progress = Progress(
TextColumn("[bold blue]{task.description}"),
BarColumn(bar_width=40),
"[progress.percentage]{task.percentage:>3.1f}%",
DownloadColumn(),
CustomTransferSpeedColumn(),
TimeRemainingColumn(),
console=self.console,
expand=True,
)
self.client_options = {
"timeout": httpx.Timeout(timeout),
"follow_redirects": True,
}
async def retry_with_exponential_backoff(
self, task: DownloadTask, func, *args, **kwargs
):
last_exception = None
for attempt in range(task.max_retries):
try:
return await func(*args, **kwargs)
except Exception as e:
last_exception = e
if attempt < task.max_retries - 1:
wait_time = min(30, 2**attempt) # Cap at 30 seconds
self.console.print(
f"[yellow]Attempt {attempt + 1}/{task.max_retries} failed, "
f"retrying in {wait_time} seconds: {str(e)}[/yellow]"
)
await asyncio.sleep(wait_time)
continue
raise last_exception
async def get_file_info(
self, client: httpx.AsyncClient, task: DownloadTask
) -> None:
async def _get_info():
response = await client.head(
task.url, headers={"Accept-Encoding": "identity"}, **self.client_options
)
response.raise_for_status()
return response
try:
response = await self.retry_with_exponential_backoff(task, _get_info)
task.url = str(response.url)
task.total_size = int(response.headers.get("Content-Length", 0))
if task.total_size == 0:
raise DownloadError(
f"Unable to determine file size for {task.output_file}. "
"The server might not support range requests."
)
# Update the progress bar's total size once we know it
if task.task_id is not None:
self.progress.update(task.task_id, total=task.total_size)
except httpx.HTTPError as e:
self.console.print(f"[red]Error getting file info: {str(e)}[/red]")
raise
def verify_file_integrity(self, task: DownloadTask) -> bool:
if not os.path.exists(task.output_file):
return False
return os.path.getsize(task.output_file) == task.total_size
async def download_chunk(
self, client: httpx.AsyncClient, task: DownloadTask, start: int, end: int
) -> None:
async def _download_chunk():
headers = {"Range": f"bytes={start}-{end}"}
async with client.stream(
"GET", task.url, headers=headers, **self.client_options
) as response:
response.raise_for_status()
with open(task.output_file, "ab") as file:
file.seek(start)
async for chunk in response.aiter_bytes(self.buffer_size):
file.write(chunk)
task.downloaded_size += len(chunk)
self.progress.update(
task.task_id,
completed=task.downloaded_size,
)
try:
await self.retry_with_exponential_backoff(task, _download_chunk)
except Exception as e:
raise DownloadError(
f"Failed to download chunk {start}-{end} after "
f"{task.max_retries} attempts: {str(e)}"
) from e
async def prepare_download(self, task: DownloadTask) -> None:
output_dir = os.path.dirname(task.output_file)
os.makedirs(output_dir, exist_ok=True)
if os.path.exists(task.output_file):
task.downloaded_size = os.path.getsize(task.output_file)
async def download_file(self, task: DownloadTask) -> None:
try:
async with httpx.AsyncClient(**self.client_options) as client:
await self.get_file_info(client, task)
# Check if file is already downloaded
if os.path.exists(task.output_file):
if self.verify_file_integrity(task):
self.console.print(
f"[green]Already downloaded {task.output_file}[/green]"
)
self.progress.update(task.task_id, completed=task.total_size)
return
await self.prepare_download(task)
try:
# Split the remaining download into chunks
chunk_size = 27_000_000_000 # Cloudfront max chunk size
chunks = []
current_pos = task.downloaded_size
while current_pos < task.total_size:
chunk_end = min(
current_pos + chunk_size - 1, task.total_size - 1
)
chunks.append((current_pos, chunk_end))
current_pos = chunk_end + 1
# Download chunks in sequence
for chunk_start, chunk_end in chunks:
await self.download_chunk(client, task, chunk_start, chunk_end)
except Exception as e:
raise DownloadError(f"Download failed: {str(e)}") from e
except Exception as e:
self.progress.update(
task.task_id, description=f"[red]Failed: {task.output_file}[/red]"
)
raise DownloadError(
f"Download failed for {task.output_file}: {str(e)}"
) from e
def has_disk_space(self, tasks: List[DownloadTask]) -> bool:
try:
total_remaining_size = sum(
task.total_size - task.downloaded_size for task in tasks
)
dir_path = os.path.dirname(os.path.abspath(tasks[0].output_file))
free_space = shutil.disk_usage(dir_path).free
# Add 10% buffer for safety
required_space = int(total_remaining_size * 1.1)
if free_space < required_space:
self.console.print(
f"[red]Not enough disk space. Required: {required_space // (1024 * 1024)} MB, "
f"Available: {free_space // (1024 * 1024)} MB[/red]"
)
return False
return True
except Exception as e:
raise DownloadError(f"Failed to check disk space: {str(e)}") from e
async def download_all(self, tasks: List[DownloadTask]) -> None:
if not tasks:
raise ValueError("No download tasks provided")
if not self.has_disk_space(tasks):
raise DownloadError("Insufficient disk space for downloads")
failed_tasks = []
with self.progress:
for task in tasks:
desc = f"Downloading {Path(task.output_file).name}"
task.task_id = self.progress.add_task(
desc, total=task.total_size, completed=task.downloaded_size
)
semaphore = asyncio.Semaphore(self.max_concurrent_downloads)
async def download_with_semaphore(task: DownloadTask):
async with semaphore:
try:
await self.download_file(task)
except Exception as e:
failed_tasks.append((task, str(e)))
await asyncio.gather(*(download_with_semaphore(task) for task in tasks))
if failed_tasks:
self.console.print("\n[red]Some downloads failed:[/red]")
for task, error in failed_tasks:
self.console.print(
f"[red]- {Path(task.output_file).name}: {error}[/red]"
)
raise DownloadError(f"{len(failed_tasks)} downloads failed")
def _hf_download( def _hf_download(
model: "Model", model: "Model",
hf_token: str, hf_token: str,
@ -120,69 +378,50 @@ def _hf_download(
print(f"\nSuccessfully downloaded model to {true_output_dir}") print(f"\nSuccessfully downloaded model to {true_output_dir}")
def _meta_download(model: "Model", meta_url: str, info: "LlamaDownloadInfo"): def _meta_download(
model: "Model",
model_id: str,
meta_url: str,
info: "LlamaDownloadInfo",
max_concurrent_downloads: int,
):
from llama_stack.distribution.utils.model_utils import model_local_dir from llama_stack.distribution.utils.model_utils import model_local_dir
output_dir = Path(model_local_dir(model.descriptor())) output_dir = Path(model_local_dir(model.descriptor()))
os.makedirs(output_dir, exist_ok=True) os.makedirs(output_dir, exist_ok=True)
# I believe we can use some concurrency here if needed but not sure it is worth it # Create download tasks for each file
tasks = []
for f in info.files: for f in info.files:
output_file = str(output_dir / f) output_file = str(output_dir / f)
url = meta_url.replace("*", f"{info.folder}/{f}") url = meta_url.replace("*", f"{info.folder}/{f}")
total_size = info.pth_size if "consolidated" in f else 0 total_size = info.pth_size if "consolidated" in f else 0
cprint(f"Downloading `{f}`...", "white") tasks.append(
downloader = ResumableDownloader(url, output_file, total_size) DownloadTask(
asyncio.run(downloader.download()) url=url, output_file=output_file, total_size=total_size, max_retries=3
print(f"\nSuccessfully downloaded model to {output_dir}")
cprint(f"\nMD5 Checksums are at: {output_dir / 'checklist.chk'}", "white")
def run_download_cmd(args: argparse.Namespace, parser: argparse.ArgumentParser):
from llama_models.sku_list import llama_meta_net_info, resolve_model
from .model.safety_models import prompt_guard_download_info, prompt_guard_model_sku
if args.manifest_file:
_download_from_manifest(args.manifest_file)
return
if args.model_id is None:
parser.error("Please provide a model id")
return
# Check if model_id is a comma-separated list
model_ids = [model_id.strip() for model_id in args.model_id.split(",")]
prompt_guard = prompt_guard_model_sku()
for model_id in model_ids:
if model_id == prompt_guard.model_id:
model = prompt_guard
info = prompt_guard_download_info()
else:
model = resolve_model(model_id)
if model is None:
parser.error(f"Model {model_id} not found")
continue
info = llama_meta_net_info(model)
if args.source == "huggingface":
_hf_download(model, args.hf_token, args.ignore_patterns, parser)
else:
meta_url = args.meta_url or input(
f"Please provide the signed URL for model {model_id} you received via email after visiting https://www.llama.com/llama-downloads/ (e.g., https://llama3-1.llamameta.net/*?Policy...): "
) )
assert "llamameta.net" in meta_url )
_meta_download(model, meta_url, info)
# Initialize and run parallel downloader
downloader = ParallelDownloader(max_concurrent_downloads=max_concurrent_downloads)
asyncio.run(downloader.download_all(tasks))
cprint(f"\nSuccessfully downloaded model to {output_dir}", "green")
cprint(
f"\nView MD5 checksum files at: {output_dir / 'checklist.chk'}",
"white",
)
cprint(
f"\n[Optionally] To run MD5 checksums, use the following command: llama model verify-download --model-id {model_id}",
"yellow",
)
class ModelEntry(BaseModel): class ModelEntry(BaseModel):
model_id: str model_id: str
files: Dict[str, str] files: Dict[str, str]
class Config: model_config = ConfigDict(protected_namespaces=())
protected_namespaces = ()
class Manifest(BaseModel): class Manifest(BaseModel):
@ -190,7 +429,7 @@ class Manifest(BaseModel):
expires_on: datetime expires_on: datetime
def _download_from_manifest(manifest_file: str): def _download_from_manifest(manifest_file: str, max_concurrent_downloads: int):
from llama_stack.distribution.utils.model_utils import model_local_dir from llama_stack.distribution.utils.model_utils import model_local_dir
with open(manifest_file, "r") as f: with open(manifest_file, "r") as f:
@ -200,143 +439,88 @@ def _download_from_manifest(manifest_file: str):
if datetime.now() > manifest.expires_on: if datetime.now() > manifest.expires_on:
raise ValueError(f"Manifest URLs have expired on {manifest.expires_on}") raise ValueError(f"Manifest URLs have expired on {manifest.expires_on}")
console = Console()
for entry in manifest.models: for entry in manifest.models:
print(f"Downloading model {entry.model_id}...") console.print(f"[blue]Downloading model {entry.model_id}...[/blue]")
output_dir = Path(model_local_dir(entry.model_id)) output_dir = Path(model_local_dir(entry.model_id))
os.makedirs(output_dir, exist_ok=True) os.makedirs(output_dir, exist_ok=True)
if any(output_dir.iterdir()): if any(output_dir.iterdir()):
cprint(f"Output directory {output_dir} is not empty.", "red") console.print(
f"[yellow]Output directory {output_dir} is not empty.[/yellow]"
)
while True: while True:
resp = input( resp = input(
"Do you want to (C)ontinue download or (R)estart completely? (continue/restart): " "Do you want to (C)ontinue download or (R)estart completely? (continue/restart): "
) )
if resp.lower() == "restart" or resp.lower() == "r": if resp.lower() in ["restart", "r"]:
shutil.rmtree(output_dir) shutil.rmtree(output_dir)
os.makedirs(output_dir, exist_ok=True) os.makedirs(output_dir, exist_ok=True)
break break
elif resp.lower() == "continue" or resp.lower() == "c": elif resp.lower() in ["continue", "c"]:
print("Continuing download...") console.print("[blue]Continuing download...[/blue]")
break break
else: else:
cprint("Invalid response. Please try again.", "red") console.print("[red]Invalid response. Please try again.[/red]")
for fname, url in entry.files.items(): # Create download tasks for all files in the manifest
output_file = str(output_dir / fname) tasks = [
downloader = ResumableDownloader(url, output_file) DownloadTask(url=url, output_file=str(output_dir / fname), max_retries=3)
asyncio.run(downloader.download()) for fname, url in entry.files.items()
]
# Initialize and run parallel downloader
downloader = ParallelDownloader(
max_concurrent_downloads=max_concurrent_downloads
)
asyncio.run(downloader.download_all(tasks))
class ResumableDownloader: def run_download_cmd(args: argparse.Namespace, parser: argparse.ArgumentParser):
def __init__( """Main download command handler"""
self, try:
url: str, if args.manifest_file:
output_file: str, _download_from_manifest(args.manifest_file, args.max_parallel)
total_size: int = 0,
buffer_size: int = 32 * 1024,
):
self.url = url
self.output_file = output_file
self.buffer_size = buffer_size
self.total_size = total_size
self.downloaded_size = 0
self.start_size = 0
self.start_time = 0
async def get_file_info(self, client: httpx.AsyncClient) -> None:
if self.total_size > 0:
return return
# Force disable compression when trying to retrieve file size if args.model_id is None:
response = await client.head( parser.error("Please provide a model id")
self.url, follow_redirects=True, headers={"Accept-Encoding": "identity"} return
)
response.raise_for_status()
self.url = str(response.url) # Update URL in case of redirects
self.total_size = int(response.headers.get("Content-Length", 0))
if self.total_size == 0:
raise ValueError(
"Unable to determine file size. The server might not support range requests."
)
async def download(self) -> None: # Handle comma-separated model IDs
self.start_time = time.time() model_ids = [model_id.strip() for model_id in args.model_id.split(",")]
async with httpx.AsyncClient(follow_redirects=True) as client:
await self.get_file_info(client)
if os.path.exists(self.output_file): from llama_models.sku_list import llama_meta_net_info, resolve_model
self.downloaded_size = os.path.getsize(self.output_file)
self.start_size = self.downloaded_size
if self.downloaded_size >= self.total_size:
print(f"Already downloaded `{self.output_file}`, skipping...")
return
additional_size = self.total_size - self.downloaded_size from .model.safety_models import (
if not self.has_disk_space(additional_size): prompt_guard_download_info,
M = 1024 * 1024 # noqa prompt_guard_model_sku,
print(
f"Not enough disk space to download `{self.output_file}`. "
f"Required: {(additional_size // M):.2f} MB"
)
raise ValueError(
f"Not enough disk space to download `{self.output_file}`"
)
while True:
if self.downloaded_size >= self.total_size:
break
# Cloudfront has a max-size limit
max_chunk_size = 27_000_000_000
request_size = min(
self.total_size - self.downloaded_size, max_chunk_size
)
headers = {
"Range": f"bytes={self.downloaded_size}-{self.downloaded_size + request_size}"
}
print(f"Downloading `{self.output_file}`....{headers}")
try:
async with client.stream(
"GET", self.url, headers=headers
) as response:
response.raise_for_status()
with open(self.output_file, "ab") as file:
async for chunk in response.aiter_bytes(self.buffer_size):
file.write(chunk)
self.downloaded_size += len(chunk)
self.print_progress()
except httpx.HTTPError as e:
print(f"\nDownload interrupted: {e}")
print("You can resume the download by running the script again.")
except Exception as e:
print(f"\nAn error occurred: {e}")
print(f"\nFinished downloading `{self.output_file}`....")
def print_progress(self) -> None:
percent = (self.downloaded_size / self.total_size) * 100
bar_length = 50
filled_length = int(bar_length * self.downloaded_size // self.total_size)
bar = "" * filled_length + "-" * (bar_length - filled_length)
elapsed_time = time.time() - self.start_time
M = 1024 * 1024 # noqa
speed = (
(self.downloaded_size - self.start_size) / (elapsed_time * M)
if elapsed_time > 0
else 0
)
print(
f"\rProgress: |{bar}| {percent:.2f}% "
f"({self.downloaded_size // M}/{self.total_size // M} MB) "
f"Speed: {speed:.2f} MiB/s",
end="",
flush=True,
) )
def has_disk_space(self, file_size: int) -> bool: prompt_guard = prompt_guard_model_sku()
dir_path = os.path.dirname(os.path.abspath(self.output_file)) for model_id in model_ids:
free_space = shutil.disk_usage(dir_path).free if model_id == prompt_guard.model_id:
return free_space > file_size model = prompt_guard
info = prompt_guard_download_info()
else:
model = resolve_model(model_id)
if model is None:
parser.error(f"Model {model_id} not found")
continue
info = llama_meta_net_info(model)
if args.source == "huggingface":
_hf_download(model, args.hf_token, args.ignore_patterns, parser)
else:
meta_url = args.meta_url or input(
f"Please provide the signed URL for model {model_id} you received via email "
f"after visiting https://www.llama.com/llama-downloads/ "
f"(e.g., https://llama3-1.llamameta.net/*?Policy...): "
)
if "llamameta.net" not in meta_url:
parser.error("Invalid Meta URL provided")
_meta_download(model, model_id, meta_url, info, args.max_parallel)
except Exception as e:
parser.error(f"Download failed: {str(e)}")

View file

@ -9,6 +9,7 @@ import argparse
from .download import Download from .download import Download
from .model import ModelParser from .model import ModelParser
from .stack import StackParser from .stack import StackParser
from .verify_download import VerifyDownload
class LlamaCLIParser: class LlamaCLIParser:
@ -27,9 +28,10 @@ class LlamaCLIParser:
subparsers = self.parser.add_subparsers(title="subcommands") subparsers = self.parser.add_subparsers(title="subcommands")
# Add sub-commands # Add sub-commands
Download.create(subparsers)
ModelParser.create(subparsers) ModelParser.create(subparsers)
StackParser.create(subparsers) StackParser.create(subparsers)
Download.create(subparsers)
VerifyDownload.create(subparsers)
def parse_args(self) -> argparse.Namespace: def parse_args(self) -> argparse.Namespace:
return self.parser.parse_args() return self.parser.parse_args()

View file

@ -10,6 +10,7 @@ from llama_stack.cli.model.describe import ModelDescribe
from llama_stack.cli.model.download import ModelDownload from llama_stack.cli.model.download import ModelDownload
from llama_stack.cli.model.list import ModelList from llama_stack.cli.model.list import ModelList
from llama_stack.cli.model.prompt_format import ModelPromptFormat from llama_stack.cli.model.prompt_format import ModelPromptFormat
from llama_stack.cli.model.verify_download import ModelVerifyDownload
from llama_stack.cli.subcommand import Subcommand from llama_stack.cli.subcommand import Subcommand
@ -32,3 +33,4 @@ class ModelParser(Subcommand):
ModelList.create(subparsers) ModelList.create(subparsers)
ModelPromptFormat.create(subparsers) ModelPromptFormat.create(subparsers)
ModelDescribe.create(subparsers) ModelDescribe.create(subparsers)
ModelVerifyDownload.create(subparsers)

View file

@ -0,0 +1,24 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import argparse
from llama_stack.cli.subcommand import Subcommand
class ModelVerifyDownload(Subcommand):
def __init__(self, subparsers: argparse._SubParsersAction):
super().__init__()
self.parser = subparsers.add_parser(
"verify-download",
prog="llama model verify-download",
description="Verify the downloaded checkpoints' checksums",
formatter_class=argparse.RawTextHelpFormatter,
)
from llama_stack.cli.verify_download import setup_verify_download_parser
setup_verify_download_parser(self.parser)

View file

@ -8,10 +8,14 @@ import argparse
from llama_stack.cli.subcommand import Subcommand from llama_stack.cli.subcommand import Subcommand
from llama_stack.distribution.datatypes import * # noqa: F403 from llama_stack.distribution.datatypes import * # noqa: F403
import importlib
import os import os
import shutil
from functools import lru_cache from functools import lru_cache
from pathlib import Path from pathlib import Path
import pkg_resources
from llama_stack.distribution.distribution import get_provider_registry from llama_stack.distribution.distribution import get_provider_registry
from llama_stack.distribution.utils.dynamic import instantiate_class_type from llama_stack.distribution.utils.dynamic import instantiate_class_type
@ -99,7 +103,9 @@ class StackBuild(Subcommand):
self.parser.error( self.parser.error(
f"Please specify a image-type (docker | conda) for {args.template}" f"Please specify a image-type (docker | conda) for {args.template}"
) )
self._run_stack_build_command_from_build_config(build_config) self._run_stack_build_command_from_build_config(
build_config, template_name=args.template
)
return return
self.parser.error( self.parser.error(
@ -193,7 +199,6 @@ class StackBuild(Subcommand):
apis = list(build_config.distribution_spec.providers.keys()) apis = list(build_config.distribution_spec.providers.keys())
run_config = StackRunConfig( run_config = StackRunConfig(
built_at=datetime.now(),
docker_image=( docker_image=(
build_config.name build_config.name
if build_config.image_type == ImageType.docker.value if build_config.image_type == ImageType.docker.value
@ -217,15 +222,23 @@ class StackBuild(Subcommand):
provider_types = [provider_types] provider_types = [provider_types]
for i, provider_type in enumerate(provider_types): for i, provider_type in enumerate(provider_types):
p_spec = Provider( pid = provider_type.split("::")[-1]
provider_id=f"{provider_type}-{i}",
provider_type=provider_type,
config={},
)
config_type = instantiate_class_type( config_type = instantiate_class_type(
provider_registry[Api(api)][provider_type].config_class provider_registry[Api(api)][provider_type].config_class
) )
p_spec.config = config_type() if hasattr(config_type, "sample_run_config"):
config = config_type.sample_run_config(
__distro_dir__=f"distributions/{build_config.name}"
)
else:
config = {}
p_spec = Provider(
provider_id=f"{pid}-{i}" if len(provider_types) > 1 else pid,
provider_type=provider_type,
config=config,
)
run_config.providers[api].append(p_spec) run_config.providers[api].append(p_spec)
os.makedirs(build_dir, exist_ok=True) os.makedirs(build_dir, exist_ok=True)
@ -241,12 +254,13 @@ class StackBuild(Subcommand):
) )
def _run_stack_build_command_from_build_config( def _run_stack_build_command_from_build_config(
self, build_config: BuildConfig self, build_config: BuildConfig, template_name: Optional[str] = None
) -> None: ) -> None:
import json import json
import os import os
import yaml import yaml
from termcolor import cprint
from llama_stack.distribution.build import build_image from llama_stack.distribution.build import build_image
from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR from llama_stack.distribution.utils.config_dirs import DISTRIBS_BASE_DIR
@ -264,7 +278,29 @@ class StackBuild(Subcommand):
if return_code != 0: if return_code != 0:
return return
self._generate_run_config(build_config, build_dir) if template_name:
# copy run.yaml from template to build_dir instead of generating it again
template_path = pkg_resources.resource_filename(
"llama_stack", f"templates/{template_name}/run.yaml"
)
os.makedirs(build_dir, exist_ok=True)
run_config_file = build_dir / f"{build_config.name}-run.yaml"
shutil.copy(template_path, run_config_file)
module_name = f"llama_stack.templates.{template_name}"
module = importlib.import_module(module_name)
distribution_template = module.get_distribution_template()
cprint("Build Successful! Next steps: ", color="green")
env_vars = ", ".join(distribution_template.run_config_env_vars.keys())
cprint(
f" 1. Set the environment variables: {env_vars}",
color="green",
)
cprint(
f" 2. `llama stack run {run_config_file}`",
color="green",
)
else:
self._generate_run_config(build_config, build_dir)
def _run_template_list_cmd(self, args: argparse.Namespace) -> None: def _run_template_list_cmd(self, args: argparse.Namespace) -> None:
import json import json

View file

@ -40,7 +40,7 @@ class StackConfigure(Subcommand):
self.parser.error( self.parser.error(
""" """
DEPRECATED! llama stack configure has been deprecated. DEPRECATED! llama stack configure has been deprecated.
Please use llama stack run --config <path/to/run.yaml> instead. Please use llama stack run <path/to/run.yaml> instead.
Please see example run.yaml in /distributions folder. Please see example run.yaml in /distributions folder.
""" """
) )

View file

@ -5,9 +5,12 @@
# the root directory of this source tree. # the root directory of this source tree.
import argparse import argparse
from pathlib import Path
from llama_stack.cli.subcommand import Subcommand from llama_stack.cli.subcommand import Subcommand
REPO_ROOT = Path(__file__).parent.parent.parent.parent
class StackRun(Subcommand): class StackRun(Subcommand):
def __init__(self, subparsers: argparse._SubParsersAction): def __init__(self, subparsers: argparse._SubParsersAction):
@ -39,16 +42,24 @@ class StackRun(Subcommand):
help="Disable IPv6 support", help="Disable IPv6 support",
default=False, default=False,
) )
self.parser.add_argument(
"--env",
action="append",
help="Environment variables to pass to the server in KEY=VALUE format. Can be specified multiple times.",
default=[],
metavar="KEY=VALUE",
)
def _run_stack_run_cmd(self, args: argparse.Namespace) -> None: def _run_stack_run_cmd(self, args: argparse.Namespace) -> None:
from pathlib import Path
import pkg_resources import pkg_resources
import yaml import yaml
from llama_stack.distribution.build import ImageType from llama_stack.distribution.build import ImageType
from llama_stack.distribution.configure import parse_and_maybe_upgrade_config from llama_stack.distribution.configure import parse_and_maybe_upgrade_config
from llama_stack.distribution.utils.config_dirs import BUILDS_BASE_DIR from llama_stack.distribution.utils.config_dirs import (
BUILDS_BASE_DIR,
DISTRIBS_BASE_DIR,
)
from llama_stack.distribution.utils.exec import run_with_pty from llama_stack.distribution.utils.exec import run_with_pty
if not args.config: if not args.config:
@ -56,24 +67,41 @@ class StackRun(Subcommand):
return return
config_file = Path(args.config) config_file = Path(args.config)
if not config_file.exists() and not args.config.endswith(".yaml"): has_yaml_suffix = args.config.endswith(".yaml")
if not config_file.exists() and not has_yaml_suffix:
# check if this is a template
config_file = (
Path(REPO_ROOT) / "llama_stack" / "templates" / args.config / "run.yaml"
)
if not config_file.exists() and not has_yaml_suffix:
# check if it's a build config saved to conda dir # check if it's a build config saved to conda dir
config_file = Path( config_file = Path(
BUILDS_BASE_DIR / ImageType.conda.value / f"{args.config}-run.yaml" BUILDS_BASE_DIR / ImageType.conda.value / f"{args.config}-run.yaml"
) )
if not config_file.exists() and not args.config.endswith(".yaml"): if not config_file.exists() and not has_yaml_suffix:
# check if it's a build config saved to docker dir # check if it's a build config saved to docker dir
config_file = Path( config_file = Path(
BUILDS_BASE_DIR / ImageType.docker.value / f"{args.config}-run.yaml" BUILDS_BASE_DIR / ImageType.docker.value / f"{args.config}-run.yaml"
) )
if not config_file.exists() and not has_yaml_suffix:
# check if it's a build config saved to ~/.llama dir
config_file = Path(
DISTRIBS_BASE_DIR
/ f"llamastack-{args.config}"
/ f"{args.config}-run.yaml"
)
if not config_file.exists(): if not config_file.exists():
self.parser.error( self.parser.error(
f"File {str(config_file)} does not exist. Please run `llama stack build` to generate (and optionally edit) a run.yaml file" f"File {str(config_file)} does not exist. Please run `llama stack build` to generate (and optionally edit) a run.yaml file"
) )
return return
print(f"Using config file: {config_file}")
config_dict = yaml.safe_load(config_file.read_text()) config_dict = yaml.safe_load(config_file.read_text())
config = parse_and_maybe_upgrade_config(config_dict) config = parse_and_maybe_upgrade_config(config_dict)
@ -97,4 +125,16 @@ class StackRun(Subcommand):
if args.disable_ipv6: if args.disable_ipv6:
run_args.append("--disable-ipv6") run_args.append("--disable-ipv6")
for env_var in args.env:
if "=" not in env_var:
self.parser.error(
f"Environment variable '{env_var}' must be in KEY=VALUE format"
)
return
key, value = env_var.split("=", 1) # split on first = only
if not key:
self.parser.error(f"Environment variable '{env_var}' has empty key")
return
run_args.extend(["--env", f"{key}={value}"])
run_with_pty(run_args) run_with_pty(run_args)
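Example (illustrative sketch, not taken from this diff): `--env` values are split on the first `=` only, so values that themselves contain `=` survive intact. The variable name below is a hypothetical placeholder.
env_var = "EXAMPLE_URL=http://localhost:11434/v1?key=abc"  # hypothetical --env value
key, value = env_var.split("=", 1)  # split on the first '=' only
assert key == "EXAMPLE_URL"
assert value == "http://localhost:11434/v1?key=abc"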

View file

@ -25,11 +25,11 @@ def up_to_date_config():
providers: providers:
inference: inference:
- provider_id: provider1 - provider_id: provider1
provider_type: meta-reference provider_type: inline::meta-reference
config: {{}} config: {{}}
safety: safety:
- provider_id: provider1 - provider_id: provider1
provider_type: meta-reference provider_type: inline::meta-reference
config: config:
llama_guard_shield: llama_guard_shield:
model: Llama-Guard-3-1B model: Llama-Guard-3-1B
@ -39,7 +39,7 @@ def up_to_date_config():
enable_prompt_guard: false enable_prompt_guard: false
memory: memory:
- provider_id: provider1 - provider_id: provider1
provider_type: meta-reference provider_type: inline::meta-reference
config: {{}} config: {{}}
""".format( """.format(
version=LLAMA_STACK_RUN_CONFIG_VERSION, built_at=datetime.now().isoformat() version=LLAMA_STACK_RUN_CONFIG_VERSION, built_at=datetime.now().isoformat()
@ -61,13 +61,13 @@ def old_config():
host: localhost host: localhost
port: 11434 port: 11434
routing_key: Llama3.2-1B-Instruct routing_key: Llama3.2-1B-Instruct
- provider_type: meta-reference - provider_type: inline::meta-reference
config: config:
model: Llama3.1-8B-Instruct model: Llama3.1-8B-Instruct
routing_key: Llama3.1-8B-Instruct routing_key: Llama3.1-8B-Instruct
safety: safety:
- routing_key: ["shield1", "shield2"] - routing_key: ["shield1", "shield2"]
provider_type: meta-reference provider_type: inline::meta-reference
config: config:
llama_guard_shield: llama_guard_shield:
model: Llama-Guard-3-1B model: Llama-Guard-3-1B
@ -77,7 +77,7 @@ def old_config():
enable_prompt_guard: false enable_prompt_guard: false
memory: memory:
- routing_key: vector - routing_key: vector
provider_type: meta-reference provider_type: inline::meta-reference
config: {{}} config: {{}}
api_providers: api_providers:
telemetry: telemetry:

View file

@ -0,0 +1,144 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import argparse
import hashlib
from dataclasses import dataclass
from functools import partial
from pathlib import Path
from typing import Dict, List, Optional
from rich.console import Console
from rich.progress import Progress, SpinnerColumn, TextColumn
from llama_stack.cli.subcommand import Subcommand
@dataclass
class VerificationResult:
filename: str
expected_hash: str
actual_hash: Optional[str]
exists: bool
matches: bool
class VerifyDownload(Subcommand):
"""Llama cli for verifying downloaded model files"""
def __init__(self, subparsers: argparse._SubParsersAction):
super().__init__()
self.parser = subparsers.add_parser(
"verify-download",
prog="llama verify-download",
description="Verify integrity of downloaded model files",
formatter_class=argparse.RawTextHelpFormatter,
)
setup_verify_download_parser(self.parser)
def setup_verify_download_parser(parser: argparse.ArgumentParser) -> None:
parser.add_argument(
"--model-id",
required=True,
help="Model ID to verify",
)
parser.set_defaults(func=partial(run_verify_cmd, parser=parser))
def calculate_md5(filepath: Path, chunk_size: int = 8192) -> str:
md5_hash = hashlib.md5()
with open(filepath, "rb") as f:
for chunk in iter(lambda: f.read(chunk_size), b""):
md5_hash.update(chunk)
return md5_hash.hexdigest()
def load_checksums(checklist_path: Path) -> Dict[str, str]:
checksums = {}
with open(checklist_path, "r") as f:
for line in f:
if line.strip():
md5sum, filepath = line.strip().split(" ", 1)
# Remove leading './' if present
filepath = filepath.lstrip("./")
checksums[filepath] = md5sum
return checksums
def verify_files(
model_dir: Path, checksums: Dict[str, str], console: Console
) -> List[VerificationResult]:
results = []
with Progress(
SpinnerColumn(),
TextColumn("[progress.description]{task.description}"),
console=console,
) as progress:
for filepath, expected_hash in checksums.items():
full_path = model_dir / filepath
task_id = progress.add_task(f"Verifying {filepath}...", total=None)
exists = full_path.exists()
actual_hash = None
matches = False
if exists:
actual_hash = calculate_md5(full_path)
matches = actual_hash == expected_hash
results.append(
VerificationResult(
filename=filepath,
expected_hash=expected_hash,
actual_hash=actual_hash,
exists=exists,
matches=matches,
)
)
progress.remove_task(task_id)
return results
def run_verify_cmd(args: argparse.Namespace, parser: argparse.ArgumentParser):
from llama_stack.distribution.utils.model_utils import model_local_dir
console = Console()
model_dir = Path(model_local_dir(args.model_id))
checklist_path = model_dir / "checklist.chk"
if not model_dir.exists():
parser.error(f"Model directory not found: {model_dir}")
if not checklist_path.exists():
parser.error(f"Checklist file not found: {checklist_path}")
checksums = load_checksums(checklist_path)
results = verify_files(model_dir, checksums, console)
# Print results
console.print("\nVerification Results:")
all_good = True
for result in results:
if not result.exists:
console.print(f"[red]❌ {result.filename}: File not found[/red]")
all_good = False
elif not result.matches:
console.print(
f"[red]❌ {result.filename}: Hash mismatch[/red]\n"
f" Expected: {result.expected_hash}\n"
f" Got: {result.actual_hash}"
)
all_good = False
else:
console.print(f"[green]✓ {result.filename}: Verified[/green]")
if all_good:
console.print("\n[green]All files verified successfully![/green]")
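Example (illustrative sketch, not taken from this diff): the helpers above can also spot-check a single file against the `checklist.chk` shipped with a download; the model directory and filename are placeholders supplied by the caller.
from pathlib import Path

def check_one_file(model_dir: Path, filename: str) -> bool:
    # checklist.chk lines are "<md5>  <relative path>", as parsed by load_checksums().
    checksums = load_checksums(model_dir / "checklist.chk")
    expected = checksums.get(filename)
    return expected is not None and calculate_md5(model_dir / filename) == expected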

View file

@ -5,7 +5,7 @@
# the root directory of this source tree. # the root directory of this source tree.
from enum import Enum from enum import Enum
from typing import List, Optional from typing import List
import pkg_resources import pkg_resources
from pydantic import BaseModel from pydantic import BaseModel
@ -38,28 +38,19 @@ class ImageType(Enum):
conda = "conda" conda = "conda"
class Dependencies(BaseModel):
pip_packages: List[str]
docker_image: Optional[str] = None
class ApiInput(BaseModel): class ApiInput(BaseModel):
api: Api api: Api
provider: str provider: str
def build_image(build_config: BuildConfig, build_file_path: Path): def get_provider_dependencies(
package_deps = Dependencies( config_providers: Dict[str, List[Provider]]
docker_image=build_config.distribution_spec.docker_image or "python:3.10-slim", ) -> tuple[list[str], list[str]]:
pip_packages=SERVER_DEPENDENCIES, """Get normal and special dependencies from provider configuration."""
)
# extend package dependencies based on providers spec
all_providers = get_provider_registry() all_providers = get_provider_registry()
for ( deps = []
api_str,
provider_or_providers, for api_str, provider_or_providers in config_providers.items():
) in build_config.distribution_spec.providers.items():
providers_for_api = all_providers[Api(api_str)] providers_for_api = all_providers[Api(api_str)]
providers = ( providers = (
@ -69,25 +60,50 @@ def build_image(build_config: BuildConfig, build_file_path: Path):
) )
for provider in providers: for provider in providers:
if provider not in providers_for_api: # Providers from BuildConfig and RunConfig are subtly different - not great
provider_type = (
provider if isinstance(provider, str) else provider.provider_type
)
if provider_type not in providers_for_api:
raise ValueError( raise ValueError(
f"Provider `{provider}` is not available for API `{api_str}`" f"Provider `{provider}` is not available for API `{api_str}`"
) )
provider_spec = providers_for_api[provider] provider_spec = providers_for_api[provider_type]
package_deps.pip_packages.extend(provider_spec.pip_packages) deps.extend(provider_spec.pip_packages)
if provider_spec.docker_image: if provider_spec.docker_image:
raise ValueError("A stack's dependencies cannot have a docker image") raise ValueError("A stack's dependencies cannot have a docker image")
normal_deps = []
special_deps = [] special_deps = []
deps = [] for package in deps:
for package in package_deps.pip_packages:
if "--no-deps" in package or "--index-url" in package: if "--no-deps" in package or "--index-url" in package:
special_deps.append(package) special_deps.append(package)
else: else:
deps.append(package) normal_deps.append(package)
deps = list(set(deps))
special_deps = list(set(special_deps)) return list(set(normal_deps)), list(set(special_deps))
def print_pip_install_help(providers: Dict[str, List[Provider]]):
normal_deps, special_deps = get_provider_dependencies(providers)
print(
f"Please install needed dependencies using the following commands:\n\n\tpip install {' '.join(normal_deps)}"
)
for special_dep in special_deps:
print(f"\tpip install {special_dep}")
print()
def build_image(build_config: BuildConfig, build_file_path: Path):
docker_image = build_config.distribution_spec.docker_image or "python:3.10-slim"
normal_deps, special_deps = get_provider_dependencies(
build_config.distribution_spec.providers
)
normal_deps += SERVER_DEPENDENCIES
if build_config.image_type == ImageType.docker.value: if build_config.image_type == ImageType.docker.value:
script = pkg_resources.resource_filename( script = pkg_resources.resource_filename(
@ -96,10 +112,10 @@ def build_image(build_config: BuildConfig, build_file_path: Path):
args = [ args = [
script, script,
build_config.name, build_config.name,
package_deps.docker_image, docker_image,
str(build_file_path), str(build_file_path),
str(BUILDS_BASE_DIR / ImageType.docker.value), str(BUILDS_BASE_DIR / ImageType.docker.value),
" ".join(deps), " ".join(normal_deps),
] ]
else: else:
script = pkg_resources.resource_filename( script = pkg_resources.resource_filename(
@ -109,7 +125,7 @@ def build_image(build_config: BuildConfig, build_file_path: Path):
script, script,
build_config.name, build_config.name,
str(build_file_path), str(build_file_path),
" ".join(deps), " ".join(normal_deps),
] ]
if special_deps: if special_deps:
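Example (illustrative sketch, not taken from this diff): `get_provider_dependencies` accepts a BuildConfig-style mapping of API name to provider-type strings (RunConfig-style `Provider` objects also work, per the `isinstance` check above). The provider selection below is hypothetical.
providers = {
    "inference": ["inline::meta-reference"],
    "memory": ["inline::meta-reference"],
}
normal_deps, special_deps = get_provider_dependencies(providers)
print_pip_install_help(providers)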

Some files were not shown because too many files have changed in this diff.