mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-05 18:22:41 +00:00
Merge branch 'main' into add-watsonx-inference-adapter
This commit is contained in:
commit
7eb83264ef
116 changed files with 2286 additions and 2719 deletions
71
CHANGELOG.md
71
CHANGELOG.md
|
@ -1,5 +1,76 @@
|
|||
# Changelog
|
||||
|
||||
# v0.1.8
|
||||
Published on: 2025-03-24T01:28:50Z
|
||||
|
||||
# v0.1.8 Release Notes
|
||||
|
||||
### Build and Test Agents
|
||||
* Safety: Integrated NVIDIA as a safety provider.
|
||||
* VectorDB: Added Qdrant as an inline provider.
|
||||
* Agents: Added support for multiple tool groups in agents.
|
||||
* Agents: Simplified imports for Agents in client package
|
||||
|
||||
|
||||
### Agent Evals and Model Customization
|
||||
* Introduced DocVQA and IfEval benchmarks.
|
||||
|
||||
### Deploying and Monitoring Agents
|
||||
* Introduced a Containerfile and image workflow for the Playground.
|
||||
* Implemented support for Bearer (API Key) authentication.
|
||||
* Added attribute-based access control for resources.
|
||||
* Fixes on docker deployments: use --pull always and standardized the default port to 8321
|
||||
* Deprecated: /v1/inspect/providers use /v1/providers/ instead
|
||||
|
||||
### Better Engineering
|
||||
* Consolidated scripts under the ./scripts directory.
|
||||
* Addressed mypy violations in various modules.
|
||||
* Added Dependabot scans for Python dependencies.
|
||||
* Implemented a scheduled workflow to update the changelog automatically.
|
||||
* Enforced concurrency to reduce CI loads.
|
||||
|
||||
|
||||
### New Contributors
|
||||
* @cmodi-meta made their first contribution in https://github.com/meta-llama/llama-stack/pull/1650
|
||||
* @jeffmaury made their first contribution in https://github.com/meta-llama/llama-stack/pull/1671
|
||||
* @derekhiggins made their first contribution in https://github.com/meta-llama/llama-stack/pull/1698
|
||||
* @Bobbins228 made their first contribution in https://github.com/meta-llama/llama-stack/pull/1745
|
||||
|
||||
**Full Changelog**: https://github.com/meta-llama/llama-stack/compare/v0.1.7...v0.1.8
|
||||
|
||||
---
|
||||
|
||||
# v0.1.7
|
||||
Published on: 2025-03-14T22:30:51Z
|
||||
|
||||
## 0.1.7 Release Notes
|
||||
|
||||
### Build and Test Agents
|
||||
* Inference: ImageType is now refactored to LlamaStackImageType
|
||||
* Inference: Added tests to measure TTFT
|
||||
* Inference: Bring back usage metrics
|
||||
* Agents: Added endpoint for get agent, list agents and list sessions
|
||||
* Agents: Automated conversion of type hints in client tool for lite llm format
|
||||
* Agents: Deprecated ToolResponseMessage in agent.resume API
|
||||
* Added Provider API for listing and inspecting provider info
|
||||
|
||||
### Agent Evals and Model Customization
|
||||
* Eval: Added new eval benchmarks Math 500 and BFCL v3
|
||||
* Deploy and Monitoring of Agents
|
||||
* Telemetry: Fix tracing to work across coroutines
|
||||
|
||||
### Better Engineering
|
||||
* Display code coverage for unit tests
|
||||
* Updated call sites (inference, tool calls, agents) to move to async non blocking calls
|
||||
* Unit tests also run on Python 3.11, 3.12, and 3.13
|
||||
* Added ollama inference to Integration tests CI
|
||||
* Improved documentation across examples, testing, CLI, updated providers table )
|
||||
|
||||
|
||||
|
||||
|
||||
---
|
||||
|
||||
# v0.1.6
|
||||
Published on: 2025-03-08T04:35:08Z
|
||||
|
||||
|
|
|
@ -81,7 +81,9 @@ Note that you can create a dotenv file `.env` that includes necessary environmen
|
|||
LLAMA_STACK_BASE_URL=http://localhost:8321
|
||||
LLAMA_STACK_CLIENT_LOG=debug
|
||||
LLAMA_STACK_PORT=8321
|
||||
LLAMA_STACK_CONFIG=
|
||||
LLAMA_STACK_CONFIG=<provider-name>
|
||||
TAVILY_SEARCH_API_KEY=
|
||||
BRAVE_SEARCH_API_KEY=
|
||||
```
|
||||
|
||||
And then use this dotenv file when running client SDK tests via the following:
|
||||
|
|
|
@ -1 +0,0 @@
|
|||
../../llama_stack/templates/bedrock/build.yaml
|
|
@ -1,15 +0,0 @@
|
|||
services:
|
||||
llamastack:
|
||||
image: distribution-bedrock
|
||||
volumes:
|
||||
- ~/.llama:/root/.llama
|
||||
- ./run.yaml:/root/llamastack-run-bedrock.yaml
|
||||
ports:
|
||||
- "8321:8321"
|
||||
entrypoint: bash -c "python -m llama_stack.distribution.server.server --yaml_config /root/llamastack-run-bedrock.yaml"
|
||||
deploy:
|
||||
restart_policy:
|
||||
condition: on-failure
|
||||
delay: 3s
|
||||
max_attempts: 5
|
||||
window: 60s
|
|
@ -1 +0,0 @@
|
|||
../../llama_stack/templates/bedrock/run.yaml
|
|
@ -1 +0,0 @@
|
|||
../../llama_stack/templates/cerebras/build.yaml
|
|
@ -1,16 +0,0 @@
|
|||
services:
|
||||
llamastack:
|
||||
image: llamastack/distribution-cerebras
|
||||
network_mode: "host"
|
||||
volumes:
|
||||
- ~/.llama:/root/.llama
|
||||
- ./run.yaml:/root/llamastack-run-cerebras.yaml
|
||||
ports:
|
||||
- "8321:8321"
|
||||
entrypoint: bash -c "python -m llama_stack.distribution.server.server --yaml_config /root/llamastack-run-cerebras.yaml"
|
||||
deploy:
|
||||
restart_policy:
|
||||
condition: on-failure
|
||||
delay: 3s
|
||||
max_attempts: 5
|
||||
window: 60s
|
|
@ -1 +0,0 @@
|
|||
../../llama_stack/templates/cerebras/run.yaml
|
|
@ -1,50 +0,0 @@
|
|||
services:
|
||||
text-generation-inference:
|
||||
image: registry.dell.huggingface.co/enterprise-dell-inference-meta-llama-meta-llama-3.1-8b-instruct
|
||||
network_mode: "host"
|
||||
volumes:
|
||||
- $HOME/.cache/huggingface:/data
|
||||
ports:
|
||||
- "5009:5009"
|
||||
devices:
|
||||
- nvidia.com/gpu=all
|
||||
environment:
|
||||
- CUDA_VISIBLE_DEVICES=0,1,2,3,4
|
||||
- NUM_SHARD=4
|
||||
- MAX_BATCH_PREFILL_TOKENS=32768
|
||||
- MAX_INPUT_TOKENS=8000
|
||||
- MAX_TOTAL_TOKENS=8192
|
||||
command: []
|
||||
deploy:
|
||||
resources:
|
||||
reservations:
|
||||
devices:
|
||||
- driver: nvidia
|
||||
# that's the closest analogue to --gpus; provide
|
||||
# an integer amount of devices or 'all'
|
||||
count: all
|
||||
# Devices are reserved using a list of capabilities, making
|
||||
# capabilities the only required field. A device MUST
|
||||
# satisfy all the requested capabilities for a successful
|
||||
# reservation.
|
||||
capabilities: [gpu]
|
||||
runtime: nvidia
|
||||
llamastack:
|
||||
depends_on:
|
||||
text-generation-inference:
|
||||
condition: service_healthy
|
||||
image: llamastack/distribution-tgi
|
||||
network_mode: "host"
|
||||
volumes:
|
||||
- ~/.llama:/root/.llama
|
||||
# Link to TGI run.yaml file
|
||||
- ./run.yaml:/root/my-run.yaml
|
||||
ports:
|
||||
- "8321:8321"
|
||||
# Hack: wait for TGI server to start before starting docker
|
||||
entrypoint: bash -c "sleep 60; python -m llama_stack.distribution.server.server --yaml_config /root/my-run.yaml"
|
||||
restart_policy:
|
||||
condition: on-failure
|
||||
delay: 3s
|
||||
max_attempts: 5
|
||||
window: 60s
|
|
@ -1,44 +0,0 @@
|
|||
version: '2'
|
||||
image_name: local
|
||||
container_image: null
|
||||
conda_env: local
|
||||
apis:
|
||||
- shields
|
||||
- agents
|
||||
- models
|
||||
- memory
|
||||
- memory_banks
|
||||
- inference
|
||||
- safety
|
||||
providers:
|
||||
inference:
|
||||
- provider_id: tgi0
|
||||
provider_type: remote::tgi
|
||||
config:
|
||||
url: http://127.0.0.1:80
|
||||
safety:
|
||||
- provider_id: meta0
|
||||
provider_type: inline::llama-guard
|
||||
config:
|
||||
model: Llama-Guard-3-1B
|
||||
excluded_categories: []
|
||||
- provider_id: meta1
|
||||
provider_type: inline::prompt-guard
|
||||
config:
|
||||
model: Prompt-Guard-86M
|
||||
memory:
|
||||
- provider_id: meta0
|
||||
provider_type: inline::faiss
|
||||
config: {}
|
||||
agents:
|
||||
- provider_id: meta0
|
||||
provider_type: inline::meta-reference
|
||||
config:
|
||||
persistence_store:
|
||||
namespace: null
|
||||
type: sqlite
|
||||
db_path: ~/.llama/runtime/kvstore.db
|
||||
telemetry:
|
||||
- provider_id: meta0
|
||||
provider_type: inline::meta-reference
|
||||
config: {}
|
|
@ -433,6 +433,7 @@
|
|||
"zmq"
|
||||
],
|
||||
"nvidia": [
|
||||
"aiohttp",
|
||||
"aiosqlite",
|
||||
"blobfile",
|
||||
"chardet",
|
||||
|
|
|
@ -1 +0,0 @@
|
|||
../../llama_stack/templates/fireworks/build.yaml
|
|
@ -1,14 +0,0 @@
|
|||
services:
|
||||
llamastack:
|
||||
image: llamastack/distribution-fireworks
|
||||
ports:
|
||||
- "8321:8321"
|
||||
environment:
|
||||
- FIREWORKS_API_KEY=${FIREWORKS_API_KEY}
|
||||
entrypoint: bash -c "python -m llama_stack.distribution.server.server --template fireworks"
|
||||
deploy:
|
||||
restart_policy:
|
||||
condition: on-failure
|
||||
delay: 3s
|
||||
max_attempts: 5
|
||||
window: 60s
|
|
@ -1 +0,0 @@
|
|||
../../llama_stack/templates/fireworks/run.yaml
|
|
@ -1 +0,0 @@
|
|||
../../llama_stack/templates/meta-reference-gpu/build.yaml
|
|
@ -1,34 +0,0 @@
|
|||
services:
|
||||
llamastack:
|
||||
image: llamastack/distribution-meta-reference-gpu
|
||||
network_mode: "host"
|
||||
volumes:
|
||||
- ~/.llama:/root/.llama
|
||||
- ./run.yaml:/root/my-run.yaml
|
||||
ports:
|
||||
- "8321:8321"
|
||||
devices:
|
||||
- nvidia.com/gpu=all
|
||||
environment:
|
||||
- CUDA_VISIBLE_DEVICES=0
|
||||
command: []
|
||||
deploy:
|
||||
resources:
|
||||
reservations:
|
||||
devices:
|
||||
- driver: nvidia
|
||||
# that's the closest analogue to --gpus; provide
|
||||
# an integer amount of devices or 'all'
|
||||
count: 1
|
||||
# Devices are reserved using a list of capabilities, making
|
||||
# capabilities the only required field. A device MUST
|
||||
# satisfy all the requested capabilities for a successful
|
||||
# reservation.
|
||||
capabilities: [gpu]
|
||||
restart_policy:
|
||||
condition: on-failure
|
||||
delay: 3s
|
||||
max_attempts: 5
|
||||
window: 60s
|
||||
runtime: nvidia
|
||||
entrypoint: bash -c "python -m llama_stack.distribution.server.server --yaml_config /root/my-run.yaml"
|
|
@ -1 +0,0 @@
|
|||
../../llama_stack/templates/meta-reference-gpu/run-with-safety.yaml
|
|
@ -1 +0,0 @@
|
|||
../../llama_stack/templates/meta-reference-gpu/run.yaml
|
|
@ -1 +0,0 @@
|
|||
../../llama_stack/templates/meta-reference-quantized-gpu/build.yaml
|
|
@ -1,35 +0,0 @@
|
|||
services:
|
||||
llamastack:
|
||||
image: llamastack/distribution-meta-reference-quantized-gpu
|
||||
network_mode: "host"
|
||||
volumes:
|
||||
- ~/.llama:/root/.llama
|
||||
- ./run.yaml:/root/my-run.yaml
|
||||
ports:
|
||||
- "8321:8321"
|
||||
devices:
|
||||
- nvidia.com/gpu=all
|
||||
environment:
|
||||
- CUDA_VISIBLE_DEVICES=0
|
||||
command: []
|
||||
deploy:
|
||||
resources:
|
||||
reservations:
|
||||
devices:
|
||||
- driver: nvidia
|
||||
# that's the closest analogue to --gpus; provide
|
||||
# an integer amount of devices or 'all'
|
||||
count: 1
|
||||
# Devices are reserved using a list of capabilities, making
|
||||
# capabilities the only required field. A device MUST
|
||||
# satisfy all the requested capabilities for a successful
|
||||
# reservation.
|
||||
capabilities: [gpu]
|
||||
runtime: nvidia
|
||||
entrypoint: bash -c "python -m llama_stack.distribution.server.server --yaml_config /root/my-run.yaml"
|
||||
deploy:
|
||||
restart_policy:
|
||||
condition: on-failure
|
||||
delay: 3s
|
||||
max_attempts: 5
|
||||
window: 60s
|
|
@ -1,58 +0,0 @@
|
|||
version: '2'
|
||||
image_name: local
|
||||
container_image: null
|
||||
conda_env: local
|
||||
apis:
|
||||
- shields
|
||||
- agents
|
||||
- models
|
||||
- memory
|
||||
- memory_banks
|
||||
- inference
|
||||
- safety
|
||||
providers:
|
||||
inference:
|
||||
- provider_id: meta0
|
||||
provider_type: inline::meta-reference-quantized
|
||||
config:
|
||||
model: Llama3.2-3B-Instruct:int4-qlora-eo8
|
||||
quantization:
|
||||
type: int4
|
||||
torch_seed: null
|
||||
max_seq_len: 2048
|
||||
max_batch_size: 1
|
||||
- provider_id: meta1
|
||||
provider_type: inline::meta-reference-quantized
|
||||
config:
|
||||
# not a quantized model !
|
||||
model: Llama-Guard-3-1B
|
||||
quantization: null
|
||||
torch_seed: null
|
||||
max_seq_len: 2048
|
||||
max_batch_size: 1
|
||||
safety:
|
||||
- provider_id: meta0
|
||||
provider_type: inline::llama-guard
|
||||
config:
|
||||
model: Llama-Guard-3-1B
|
||||
excluded_categories: []
|
||||
- provider_id: meta1
|
||||
provider_type: inline::prompt-guard
|
||||
config:
|
||||
model: Prompt-Guard-86M
|
||||
memory:
|
||||
- provider_id: meta0
|
||||
provider_type: inline::meta-reference
|
||||
config: {}
|
||||
agents:
|
||||
- provider_id: meta0
|
||||
provider_type: inline::meta-reference
|
||||
config:
|
||||
persistence_store:
|
||||
namespace: null
|
||||
type: sqlite
|
||||
db_path: ~/.llama/runtime/kvstore.db
|
||||
telemetry:
|
||||
- provider_id: meta0
|
||||
provider_type: inline::meta-reference
|
||||
config: {}
|
|
@ -1 +0,0 @@
|
|||
../../llama_stack/templates/ollama/build.yaml
|
|
@ -1,71 +0,0 @@
|
|||
services:
|
||||
ollama:
|
||||
image: ollama/ollama:latest
|
||||
network_mode: ${NETWORK_MODE:-bridge}
|
||||
volumes:
|
||||
- ~/.ollama:/root/.ollama
|
||||
ports:
|
||||
- "11434:11434"
|
||||
environment:
|
||||
OLLAMA_DEBUG: 1
|
||||
command: []
|
||||
deploy:
|
||||
resources:
|
||||
limits:
|
||||
memory: 8G # Set maximum memory
|
||||
reservations:
|
||||
memory: 8G # Set minimum memory reservation
|
||||
# healthcheck:
|
||||
# # ugh, no CURL in ollama image
|
||||
# test: ["CMD", "curl", "-f", "http://ollama:11434"]
|
||||
# interval: 10s
|
||||
# timeout: 5s
|
||||
# retries: 5
|
||||
|
||||
ollama-init:
|
||||
image: ollama/ollama:latest
|
||||
depends_on:
|
||||
- ollama
|
||||
# condition: service_healthy
|
||||
network_mode: ${NETWORK_MODE:-bridge}
|
||||
environment:
|
||||
- OLLAMA_HOST=ollama
|
||||
- INFERENCE_MODEL=${INFERENCE_MODEL}
|
||||
- SAFETY_MODEL=${SAFETY_MODEL:-}
|
||||
volumes:
|
||||
- ~/.ollama:/root/.ollama
|
||||
- ./pull-models.sh:/pull-models.sh
|
||||
entrypoint: ["/pull-models.sh"]
|
||||
|
||||
llamastack:
|
||||
depends_on:
|
||||
ollama:
|
||||
condition: service_started
|
||||
ollama-init:
|
||||
condition: service_started
|
||||
image: ${LLAMA_STACK_IMAGE:-llamastack/distribution-ollama}
|
||||
network_mode: ${NETWORK_MODE:-bridge}
|
||||
volumes:
|
||||
- ~/.llama:/root/.llama
|
||||
# Link to ollama run.yaml file
|
||||
- ~/local/llama-stack/:/app/llama-stack-source
|
||||
- ./run${SAFETY_MODEL:+-with-safety}.yaml:/root/my-run.yaml
|
||||
ports:
|
||||
- "${LLAMA_STACK_PORT:-8321}:${LLAMA_STACK_PORT:-8321}"
|
||||
environment:
|
||||
- INFERENCE_MODEL=${INFERENCE_MODEL}
|
||||
- SAFETY_MODEL=${SAFETY_MODEL:-}
|
||||
- OLLAMA_URL=http://ollama:11434
|
||||
entrypoint: >
|
||||
python -m llama_stack.distribution.server.server /root/my-run.yaml \
|
||||
--port ${LLAMA_STACK_PORT:-8321}
|
||||
deploy:
|
||||
restart_policy:
|
||||
condition: on-failure
|
||||
delay: 10s
|
||||
max_attempts: 3
|
||||
window: 60s
|
||||
volumes:
|
||||
ollama:
|
||||
ollama-init:
|
||||
llamastack:
|
|
@ -1,18 +0,0 @@
|
|||
#!/bin/sh
|
||||
|
||||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
echo "Preloading (${INFERENCE_MODEL}, ${SAFETY_MODEL})..."
|
||||
for model in ${INFERENCE_MODEL} ${SAFETY_MODEL}; do
|
||||
echo "Preloading $model..."
|
||||
if ! ollama run "$model"; then
|
||||
echo "Failed to pull and run $model"
|
||||
exit 1
|
||||
fi
|
||||
done
|
||||
|
||||
echo "All models pulled successfully"
|
|
@ -1 +0,0 @@
|
|||
../../llama_stack/templates/ollama/run-with-safety.yaml
|
|
@ -1 +0,0 @@
|
|||
../../llama_stack/templates/ollama/run.yaml
|
Binary file not shown.
|
@ -1 +0,0 @@
|
|||
../../llama_stack/templates/nvidia/build.yaml
|
|
@ -1,19 +0,0 @@
|
|||
services:
|
||||
llamastack:
|
||||
image: distribution-nvidia:dev
|
||||
network_mode: "host"
|
||||
volumes:
|
||||
- ~/.llama:/root/.llama
|
||||
- ./run.yaml:/root/llamastack-run-nvidia.yaml
|
||||
ports:
|
||||
- "8321:8321"
|
||||
environment:
|
||||
- INFERENCE_MODEL=${INFERENCE_MODEL:-Llama3.1-8B-Instruct}
|
||||
- NVIDIA_API_KEY=${NVIDIA_API_KEY:-}
|
||||
entrypoint: bash -c "python -m llama_stack.distribution.server.server --yaml-config /root/llamastack-run-nvidia.yaml"
|
||||
deploy:
|
||||
restart_policy:
|
||||
condition: on-failure
|
||||
delay: 3s
|
||||
max_attempts: 5
|
||||
window: 60s
|
|
@ -1 +0,0 @@
|
|||
../../llama_stack/templates/nvidia/run.yaml
|
|
@ -1 +0,0 @@
|
|||
../../llama_stack/templates/remote-vllm/build.yaml
|
|
@ -1,99 +0,0 @@
|
|||
services:
|
||||
vllm-inference:
|
||||
image: vllm/vllm-openai:latest
|
||||
volumes:
|
||||
- $HOME/.cache/huggingface:/root/.cache/huggingface
|
||||
network_mode: ${NETWORK_MODE:-bridged}
|
||||
ports:
|
||||
- "${VLLM_INFERENCE_PORT:-5100}:${VLLM_INFERENCE_PORT:-5100}"
|
||||
devices:
|
||||
- nvidia.com/gpu=all
|
||||
environment:
|
||||
- CUDA_VISIBLE_DEVICES=${VLLM_INFERENCE_GPU:-0}
|
||||
- HUGGING_FACE_HUB_TOKEN=$HF_TOKEN
|
||||
command: >
|
||||
--gpu-memory-utilization 0.75
|
||||
--model ${VLLM_INFERENCE_MODEL:-meta-llama/Llama-3.2-3B-Instruct}
|
||||
--enforce-eager
|
||||
--max-model-len 8192
|
||||
--max-num-seqs 16
|
||||
--port ${VLLM_INFERENCE_PORT:-5100}
|
||||
healthcheck:
|
||||
test: ["CMD", "curl", "-f", "http://localhost:${VLLM_INFERENCE_PORT:-5100}/v1/health"]
|
||||
interval: 30s
|
||||
timeout: 10s
|
||||
retries: 5
|
||||
deploy:
|
||||
resources:
|
||||
reservations:
|
||||
devices:
|
||||
- driver: nvidia
|
||||
capabilities: [gpu]
|
||||
runtime: nvidia
|
||||
|
||||
# A little trick:
|
||||
# if VLLM_SAFETY_MODEL is set, we will create a service for the safety model
|
||||
# otherwise, the entry will end in a hyphen which gets ignored by docker compose
|
||||
vllm-${VLLM_SAFETY_MODEL:+safety}:
|
||||
image: vllm/vllm-openai:latest
|
||||
volumes:
|
||||
- $HOME/.cache/huggingface:/root/.cache/huggingface
|
||||
network_mode: ${NETWORK_MODE:-bridged}
|
||||
ports:
|
||||
- "${VLLM_SAFETY_PORT:-5101}:${VLLM_SAFETY_PORT:-5101}"
|
||||
devices:
|
||||
- nvidia.com/gpu=all
|
||||
environment:
|
||||
- CUDA_VISIBLE_DEVICES=${VLLM_SAFETY_GPU:-1}
|
||||
- HUGGING_FACE_HUB_TOKEN=$HF_TOKEN
|
||||
command: >
|
||||
--gpu-memory-utilization 0.75
|
||||
--model ${VLLM_SAFETY_MODEL}
|
||||
--enforce-eager
|
||||
--max-model-len 8192
|
||||
--max-num-seqs 16
|
||||
--port ${VLLM_SAFETY_PORT:-5101}
|
||||
healthcheck:
|
||||
test: ["CMD", "curl", "-f", "http://localhost:${VLLM_SAFETY_PORT:-5101}/v1/health"]
|
||||
interval: 30s
|
||||
timeout: 10s
|
||||
retries: 5
|
||||
deploy:
|
||||
resources:
|
||||
reservations:
|
||||
devices:
|
||||
- driver: nvidia
|
||||
capabilities: [gpu]
|
||||
runtime: nvidia
|
||||
llamastack:
|
||||
depends_on:
|
||||
- vllm-inference:
|
||||
condition: service_healthy
|
||||
- vllm-${VLLM_SAFETY_MODEL:+safety}:
|
||||
condition: service_healthy
|
||||
image: llamastack/distribution-remote-vllm:test-0.0.52rc3
|
||||
volumes:
|
||||
- ~/.llama:/root/.llama
|
||||
- ./run${VLLM_SAFETY_MODEL:+-with-safety}.yaml:/root/llamastack-run-remote-vllm.yaml
|
||||
network_mode: ${NETWORK_MODE:-bridged}
|
||||
environment:
|
||||
- VLLM_URL=http://vllm-inference:${VLLM_INFERENCE_PORT:-5100}/v1
|
||||
- VLLM_SAFETY_URL=http://vllm-safety:${VLLM_SAFETY_PORT:-5101}/v1
|
||||
- INFERENCE_MODEL=${INFERENCE_MODEL:-meta-llama/Llama-3.2-3B-Instruct}
|
||||
- MAX_TOKENS=${MAX_TOKENS:-4096}
|
||||
- SQLITE_STORE_DIR=${SQLITE_STORE_DIR:-$HOME/.llama/distributions/remote-vllm}
|
||||
- SAFETY_MODEL=${SAFETY_MODEL:-meta-llama/Llama-Guard-3-1B}
|
||||
ports:
|
||||
- "${LLAMA_STACK_PORT:-8321}:${LLAMA_STACK_PORT:-8321}"
|
||||
# Hack: wait for vLLM server to start before starting docker
|
||||
entrypoint: bash -c "sleep 60; python -m llama_stack.distribution.server.server --yaml_config /root/llamastack-run-remote-vllm.yaml --port 8321"
|
||||
deploy:
|
||||
restart_policy:
|
||||
condition: on-failure
|
||||
delay: 3s
|
||||
max_attempts: 5
|
||||
window: 60s
|
||||
volumes:
|
||||
vllm-inference:
|
||||
vllm-safety:
|
||||
llamastack:
|
|
@ -1 +0,0 @@
|
|||
../../llama_stack/templates/remote-vllm/run-with-safety.yaml
|
|
@ -1 +0,0 @@
|
|||
../../llama_stack/templates/remote-vllm/run.yaml
|
|
@ -1,9 +0,0 @@
|
|||
name: runpod
|
||||
distribution_spec:
|
||||
description: Use Runpod for running LLM inference
|
||||
providers:
|
||||
inference: remote::runpod
|
||||
memory: meta-reference
|
||||
safety: meta-reference
|
||||
agents: meta-reference
|
||||
telemetry: meta-reference
|
|
@ -1 +0,0 @@
|
|||
../../llama_stack/templates/sambanova/build.yaml
|
|
@ -1,16 +0,0 @@
|
|||
services:
|
||||
llamastack:
|
||||
image: llamastack/distribution-sambanova
|
||||
network_mode: "host"
|
||||
volumes:
|
||||
- ~/.llama:/root/.llama
|
||||
- ./run.yaml:/root/llamastack-run-sambanova.yaml
|
||||
ports:
|
||||
- "5000:5000"
|
||||
entrypoint: bash -c "python -m llama_stack.distribution.server.server --yaml_config /root/llamastack-run-sambanova.yaml"
|
||||
deploy:
|
||||
restart_policy:
|
||||
condition: on-failure
|
||||
delay: 3s
|
||||
max_attempts: 5
|
||||
window: 60s
|
|
@ -1 +0,0 @@
|
|||
../../llama_stack/templates/sambanova/run.yaml
|
|
@ -1 +0,0 @@
|
|||
../../llama_stack/templates/tgi/build.yaml
|
|
@ -1,103 +0,0 @@
|
|||
services:
|
||||
tgi-inference:
|
||||
image: ghcr.io/huggingface/text-generation-inference:latest
|
||||
volumes:
|
||||
- $HOME/.cache/huggingface:/data
|
||||
network_mode: ${NETWORK_MODE:-bridged}
|
||||
ports:
|
||||
- "${TGI_INFERENCE_PORT:-8080}:${TGI_INFERENCE_PORT:-8080}"
|
||||
devices:
|
||||
- nvidia.com/gpu=all
|
||||
environment:
|
||||
- CUDA_VISIBLE_DEVICES=${TGI_INFERENCE_GPU:-0}
|
||||
- HF_TOKEN=$HF_TOKEN
|
||||
- HF_HOME=/data
|
||||
- HF_DATASETS_CACHE=/data
|
||||
- HF_MODULES_CACHE=/data
|
||||
- HF_HUB_CACHE=/data
|
||||
command: >
|
||||
--dtype bfloat16
|
||||
--usage-stats off
|
||||
--sharded false
|
||||
--model-id ${TGI_INFERENCE_MODEL:-meta-llama/Llama-3.2-3B-Instruct}
|
||||
--port ${TGI_INFERENCE_PORT:-8080}
|
||||
--cuda-memory-fraction 0.75
|
||||
healthcheck:
|
||||
test: ["CMD", "curl", "-f", "http://tgi-inference:${TGI_INFERENCE_PORT:-8080}/health"]
|
||||
interval: 5s
|
||||
timeout: 5s
|
||||
retries: 30
|
||||
deploy:
|
||||
resources:
|
||||
reservations:
|
||||
devices:
|
||||
- driver: nvidia
|
||||
capabilities: [gpu]
|
||||
runtime: nvidia
|
||||
|
||||
tgi-${TGI_SAFETY_MODEL:+safety}:
|
||||
image: ghcr.io/huggingface/text-generation-inference:latest
|
||||
volumes:
|
||||
- $HOME/.cache/huggingface:/data
|
||||
network_mode: ${NETWORK_MODE:-bridged}
|
||||
ports:
|
||||
- "${TGI_SAFETY_PORT:-8081}:${TGI_SAFETY_PORT:-8081}"
|
||||
devices:
|
||||
- nvidia.com/gpu=all
|
||||
environment:
|
||||
- CUDA_VISIBLE_DEVICES=${TGI_SAFETY_GPU:-1}
|
||||
- HF_TOKEN=$HF_TOKEN
|
||||
- HF_HOME=/data
|
||||
- HF_DATASETS_CACHE=/data
|
||||
- HF_MODULES_CACHE=/data
|
||||
- HF_HUB_CACHE=/data
|
||||
command: >
|
||||
--dtype bfloat16
|
||||
--usage-stats off
|
||||
--sharded false
|
||||
--model-id ${TGI_SAFETY_MODEL:-meta-llama/Llama-Guard-3-1B}
|
||||
--port ${TGI_SAFETY_PORT:-8081}
|
||||
--cuda-memory-fraction 0.75
|
||||
healthcheck:
|
||||
test: ["CMD", "curl", "-f", "http://tgi-safety:${TGI_SAFETY_PORT:-8081}/health"]
|
||||
interval: 5s
|
||||
timeout: 5s
|
||||
retries: 30
|
||||
deploy:
|
||||
resources:
|
||||
reservations:
|
||||
devices:
|
||||
- driver: nvidia
|
||||
capabilities: [gpu]
|
||||
runtime: nvidia
|
||||
|
||||
llamastack:
|
||||
depends_on:
|
||||
tgi-inference:
|
||||
condition: service_healthy
|
||||
tgi-${TGI_SAFETY_MODEL:+safety}:
|
||||
condition: service_healthy
|
||||
image: llamastack/distribution-tgi:test-0.0.52rc3
|
||||
network_mode: ${NETWORK_MODE:-bridged}
|
||||
volumes:
|
||||
- ~/.llama:/root/.llama
|
||||
- ./run${TGI_SAFETY_MODEL:+-with-safety}.yaml:/root/my-run.yaml
|
||||
ports:
|
||||
- "${LLAMA_STACK_PORT:-8321}:${LLAMA_STACK_PORT:-8321}"
|
||||
# Hack: wait for TGI server to start before starting docker
|
||||
entrypoint: bash -c "sleep 60; python -m llama_stack.distribution.server.server --yaml_config /root/my-run.yaml"
|
||||
restart_policy:
|
||||
condition: on-failure
|
||||
delay: 3s
|
||||
max_attempts: 5
|
||||
window: 60s
|
||||
environment:
|
||||
- TGI_URL=http://tgi-inference:${TGI_INFERENCE_PORT:-8080}
|
||||
- SAFETY_TGI_URL=http://tgi-safety:${TGI_SAFETY_PORT:-8081}
|
||||
- INFERENCE_MODEL=${INFERENCE_MODEL:-meta-llama/Llama-3.2-3B-Instruct}
|
||||
- SAFETY_MODEL=${SAFETY_MODEL:-meta-llama/Llama-Guard-3-1B}
|
||||
|
||||
volumes:
|
||||
tgi-inference:
|
||||
tgi-safety:
|
||||
llamastack:
|
|
@ -1 +0,0 @@
|
|||
../../llama_stack/templates/tgi/run-with-safety.yaml
|
|
@ -1 +0,0 @@
|
|||
../../llama_stack/templates/tgi/run.yaml
|
|
@ -1 +0,0 @@
|
|||
../../llama_stack/templates/together/build.yaml
|
|
@ -1,14 +0,0 @@
|
|||
services:
|
||||
llamastack:
|
||||
image: llamastack/distribution-together
|
||||
ports:
|
||||
- "8321:8321"
|
||||
environment:
|
||||
- TOGETHER_API_KEY=${TOGETHER_API_KEY}
|
||||
entrypoint: bash -c "python -m llama_stack.distribution.server.server --template together"
|
||||
deploy:
|
||||
restart_policy:
|
||||
condition: on-failure
|
||||
delay: 3s
|
||||
max_attempts: 5
|
||||
window: 60s
|
|
@ -1 +0,0 @@
|
|||
../../llama_stack/templates/together/run.yaml
|
|
@ -1 +0,0 @@
|
|||
../../llama_stack/templates/inline-vllm/build.yaml
|
|
@ -1,35 +0,0 @@
|
|||
services:
|
||||
llamastack:
|
||||
image: llamastack/distribution-inline-vllm
|
||||
network_mode: "host"
|
||||
volumes:
|
||||
- ~/.llama:/root/.llama
|
||||
- ./run.yaml:/root/my-run.yaml
|
||||
ports:
|
||||
- "8321:8321"
|
||||
devices:
|
||||
- nvidia.com/gpu=all
|
||||
environment:
|
||||
- CUDA_VISIBLE_DEVICES=0
|
||||
command: []
|
||||
deploy:
|
||||
resources:
|
||||
reservations:
|
||||
devices:
|
||||
- driver: nvidia
|
||||
# that's the closest analogue to --gpus; provide
|
||||
# an integer amount of devices or 'all'
|
||||
count: 1
|
||||
# Devices are reserved using a list of capabilities, making
|
||||
# capabilities the only required field. A device MUST
|
||||
# satisfy all the requested capabilities for a successful
|
||||
# reservation.
|
||||
capabilities: [gpu]
|
||||
runtime: nvidia
|
||||
entrypoint: bash -c "python -m llama_stack.distribution.server.server --yaml_config /root/my-run.yaml"
|
||||
deploy:
|
||||
restart_policy:
|
||||
condition: on-failure
|
||||
delay: 3s
|
||||
max_attempts: 5
|
||||
window: 60s
|
|
@ -1,66 +0,0 @@
|
|||
version: '2'
|
||||
image_name: local
|
||||
container_image: null
|
||||
conda_env: local
|
||||
apis:
|
||||
- shields
|
||||
- agents
|
||||
- models
|
||||
- memory
|
||||
- memory_banks
|
||||
- inference
|
||||
- safety
|
||||
providers:
|
||||
inference:
|
||||
- provider_id: vllm-inference
|
||||
provider_type: inline::vllm
|
||||
config:
|
||||
model: Llama3.2-3B-Instruct
|
||||
tensor_parallel_size: 1
|
||||
gpu_memory_utilization: 0.4
|
||||
enforce_eager: true
|
||||
max_tokens: 4096
|
||||
- provider_id: vllm-inference-safety
|
||||
provider_type: inline::vllm
|
||||
config:
|
||||
model: Llama-Guard-3-1B
|
||||
tensor_parallel_size: 1
|
||||
gpu_memory_utilization: 0.2
|
||||
enforce_eager: true
|
||||
max_tokens: 4096
|
||||
safety:
|
||||
- provider_id: meta0
|
||||
provider_type: inline::llama-guard
|
||||
config:
|
||||
model: Llama-Guard-3-1B
|
||||
excluded_categories: []
|
||||
# Uncomment to use prompt guard
|
||||
# - provider_id: meta1
|
||||
# provider_type: inline::prompt-guard
|
||||
# config:
|
||||
# model: Prompt-Guard-86M
|
||||
memory:
|
||||
- provider_id: meta0
|
||||
provider_type: inline::meta-reference
|
||||
config: {}
|
||||
# Uncomment to use pgvector
|
||||
# - provider_id: pgvector
|
||||
# provider_type: remote::pgvector
|
||||
# config:
|
||||
# host: 127.0.0.1
|
||||
# port: 5432
|
||||
# db: postgres
|
||||
# user: postgres
|
||||
# password: mysecretpassword
|
||||
agents:
|
||||
- provider_id: meta0
|
||||
provider_type: inline::meta-reference
|
||||
config:
|
||||
persistence_store:
|
||||
namespace: null
|
||||
type: sqlite
|
||||
db_path: ~/.llama/runtime/agents_store.db
|
||||
telemetry:
|
||||
- provider_id: meta0
|
||||
provider_type: inline::meta-reference
|
||||
config: {}
|
95
docs/_static/llama-stack-spec.html
vendored
95
docs/_static/llama-stack-spec.html
vendored
|
@ -818,14 +818,7 @@
|
|||
"delete": {
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "OK",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"$ref": "#/components/schemas/FileResponse"
|
||||
}
|
||||
}
|
||||
}
|
||||
"description": "OK"
|
||||
},
|
||||
"400": {
|
||||
"$ref": "#/components/responses/BadRequest400"
|
||||
|
@ -6140,46 +6133,6 @@
|
|||
"title": "FileUploadResponse",
|
||||
"description": "Response after initiating a file upload session."
|
||||
},
|
||||
"FileResponse": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"bucket": {
|
||||
"type": "string",
|
||||
"description": "Bucket under which the file is stored (valid chars: a-zA-Z0-9_-)"
|
||||
},
|
||||
"key": {
|
||||
"type": "string",
|
||||
"description": "Key under which the file is stored (valid chars: a-zA-Z0-9_-/.)"
|
||||
},
|
||||
"mime_type": {
|
||||
"type": "string",
|
||||
"description": "MIME type of the file"
|
||||
},
|
||||
"url": {
|
||||
"type": "string",
|
||||
"description": "Upload URL for the file contents"
|
||||
},
|
||||
"bytes": {
|
||||
"type": "integer",
|
||||
"description": "Size of the file in bytes"
|
||||
},
|
||||
"created_at": {
|
||||
"type": "integer",
|
||||
"description": "Timestamp of when the file was created"
|
||||
}
|
||||
},
|
||||
"additionalProperties": false,
|
||||
"required": [
|
||||
"bucket",
|
||||
"key",
|
||||
"mime_type",
|
||||
"url",
|
||||
"bytes",
|
||||
"created_at"
|
||||
],
|
||||
"title": "FileResponse",
|
||||
"description": "Response representing a file entry."
|
||||
},
|
||||
"EmbeddingsRequest": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
|
@ -6933,6 +6886,46 @@
|
|||
"title": "URIDataSource",
|
||||
"description": "A dataset that can be obtained from a URI."
|
||||
},
|
||||
"FileResponse": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"bucket": {
|
||||
"type": "string",
|
||||
"description": "Bucket under which the file is stored (valid chars: a-zA-Z0-9_-)"
|
||||
},
|
||||
"key": {
|
||||
"type": "string",
|
||||
"description": "Key under which the file is stored (valid chars: a-zA-Z0-9_-/.)"
|
||||
},
|
||||
"mime_type": {
|
||||
"type": "string",
|
||||
"description": "MIME type of the file"
|
||||
},
|
||||
"url": {
|
||||
"type": "string",
|
||||
"description": "Upload URL for the file contents"
|
||||
},
|
||||
"bytes": {
|
||||
"type": "integer",
|
||||
"description": "Size of the file in bytes"
|
||||
},
|
||||
"created_at": {
|
||||
"type": "integer",
|
||||
"description": "Timestamp of when the file was created"
|
||||
}
|
||||
},
|
||||
"additionalProperties": false,
|
||||
"required": [
|
||||
"bucket",
|
||||
"key",
|
||||
"mime_type",
|
||||
"url",
|
||||
"bytes",
|
||||
"created_at"
|
||||
],
|
||||
"title": "FileResponse",
|
||||
"description": "Response representing a file entry."
|
||||
},
|
||||
"Model": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
|
@ -7671,7 +7664,8 @@
|
|||
"completed",
|
||||
"in_progress",
|
||||
"failed",
|
||||
"scheduled"
|
||||
"scheduled",
|
||||
"cancelled"
|
||||
],
|
||||
"title": "JobStatus"
|
||||
},
|
||||
|
@ -8135,7 +8129,8 @@
|
|||
"completed",
|
||||
"in_progress",
|
||||
"failed",
|
||||
"scheduled"
|
||||
"scheduled",
|
||||
"cancelled"
|
||||
],
|
||||
"title": "JobStatus"
|
||||
}
|
||||
|
|
72
docs/_static/llama-stack-spec.yaml
vendored
72
docs/_static/llama-stack-spec.yaml
vendored
|
@ -557,10 +557,6 @@ paths:
|
|||
responses:
|
||||
'200':
|
||||
description: OK
|
||||
content:
|
||||
application/json:
|
||||
schema:
|
||||
$ref: '#/components/schemas/FileResponse'
|
||||
'400':
|
||||
$ref: '#/components/responses/BadRequest400'
|
||||
'429':
|
||||
|
@ -4286,39 +4282,6 @@ components:
|
|||
title: FileUploadResponse
|
||||
description: >-
|
||||
Response after initiating a file upload session.
|
||||
FileResponse:
|
||||
type: object
|
||||
properties:
|
||||
bucket:
|
||||
type: string
|
||||
description: >-
|
||||
Bucket under which the file is stored (valid chars: a-zA-Z0-9_-)
|
||||
key:
|
||||
type: string
|
||||
description: >-
|
||||
Key under which the file is stored (valid chars: a-zA-Z0-9_-/.)
|
||||
mime_type:
|
||||
type: string
|
||||
description: MIME type of the file
|
||||
url:
|
||||
type: string
|
||||
description: Upload URL for the file contents
|
||||
bytes:
|
||||
type: integer
|
||||
description: Size of the file in bytes
|
||||
created_at:
|
||||
type: integer
|
||||
description: Timestamp of when the file was created
|
||||
additionalProperties: false
|
||||
required:
|
||||
- bucket
|
||||
- key
|
||||
- mime_type
|
||||
- url
|
||||
- bytes
|
||||
- created_at
|
||||
title: FileResponse
|
||||
description: Response representing a file entry.
|
||||
EmbeddingsRequest:
|
||||
type: object
|
||||
properties:
|
||||
|
@ -4830,6 +4793,39 @@ components:
|
|||
title: URIDataSource
|
||||
description: >-
|
||||
A dataset that can be obtained from a URI.
|
||||
FileResponse:
|
||||
type: object
|
||||
properties:
|
||||
bucket:
|
||||
type: string
|
||||
description: >-
|
||||
Bucket under which the file is stored (valid chars: a-zA-Z0-9_-)
|
||||
key:
|
||||
type: string
|
||||
description: >-
|
||||
Key under which the file is stored (valid chars: a-zA-Z0-9_-/.)
|
||||
mime_type:
|
||||
type: string
|
||||
description: MIME type of the file
|
||||
url:
|
||||
type: string
|
||||
description: Upload URL for the file contents
|
||||
bytes:
|
||||
type: integer
|
||||
description: Size of the file in bytes
|
||||
created_at:
|
||||
type: integer
|
||||
description: Timestamp of when the file was created
|
||||
additionalProperties: false
|
||||
required:
|
||||
- bucket
|
||||
- key
|
||||
- mime_type
|
||||
- url
|
||||
- bytes
|
||||
- created_at
|
||||
title: FileResponse
|
||||
description: Response representing a file entry.
|
||||
Model:
|
||||
type: object
|
||||
properties:
|
||||
|
@ -5306,6 +5302,7 @@ components:
|
|||
- in_progress
|
||||
- failed
|
||||
- scheduled
|
||||
- cancelled
|
||||
title: JobStatus
|
||||
scheduled_at:
|
||||
type: string
|
||||
|
@ -5583,6 +5580,7 @@ components:
|
|||
- in_progress
|
||||
- failed
|
||||
- scheduled
|
||||
- cancelled
|
||||
title: JobStatus
|
||||
additionalProperties: false
|
||||
required:
|
||||
|
|
File diff suppressed because one or more lines are too long
|
@ -963,16 +963,19 @@
|
|||
"\n",
|
||||
"client.benchmarks.register(\n",
|
||||
" benchmark_id=\"meta-reference::mmmu\",\n",
|
||||
" # Note: we can use any value as `dataset_id` because we'll be using the `evaluate_rows` API which accepts the \n",
|
||||
" # `input_rows` argument and does not fetch data from the dataset.\n",
|
||||
" dataset_id=f\"mmmu-{subset}-{split}\",\n",
|
||||
" scoring_functions=[\"basic::regex_parser_multiple_choice_answer\"],\n",
|
||||
" # Note: for the same reason as above, we can use any value as `scoring_functions`.\n",
|
||||
" scoring_functions=[],\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"response = client.eval.evaluate_rows_alpha(\n",
|
||||
"response = client.eval.evaluate_rows(\n",
|
||||
" benchmark_id=\"meta-reference::mmmu\",\n",
|
||||
" input_rows=eval_rows,\n",
|
||||
" # Note: Here we define the actual scoring functions.\n",
|
||||
" scoring_functions=[\"basic::regex_parser_multiple_choice_answer\"],\n",
|
||||
" benchmark_config={\n",
|
||||
" \"type\": \"benchmark\",\n",
|
||||
" \"eval_candidate\": {\n",
|
||||
" \"type\": \"model\",\n",
|
||||
" \"model\": \"meta-llama/Llama-3.2-90B-Vision-Instruct\",\n",
|
||||
|
@ -1139,12 +1142,11 @@
|
|||
" scoring_functions=[\"llm-as-judge::405b-simpleqa\"],\n",
|
||||
")\n",
|
||||
"\n",
|
||||
"response = client.eval.evaluate_rows_alpha(\n",
|
||||
"response = client.eval.evaluate_rows(\n",
|
||||
" benchmark_id=\"meta-reference::simpleqa\",\n",
|
||||
" input_rows=eval_rows.data,\n",
|
||||
" scoring_functions=[\"llm-as-judge::405b-simpleqa\"],\n",
|
||||
" benchmark_config={\n",
|
||||
" \"type\": \"benchmark\",\n",
|
||||
" \"eval_candidate\": {\n",
|
||||
" \"type\": \"model\",\n",
|
||||
" \"model\": \"meta-llama/Llama-3.2-90B-Vision-Instruct\",\n",
|
||||
|
@ -1288,12 +1290,11 @@
|
|||
" \"enable_session_persistence\": False,\n",
|
||||
"}\n",
|
||||
"\n",
|
||||
"response = client.eval.evaluate_rows_alpha(\n",
|
||||
"response = client.eval.evaluate_rows(\n",
|
||||
" benchmark_id=\"meta-reference::simpleqa\",\n",
|
||||
" input_rows=eval_rows.data,\n",
|
||||
" scoring_functions=[\"llm-as-judge::405b-simpleqa\"],\n",
|
||||
" benchmark_config={\n",
|
||||
" \"type\": \"benchmark\",\n",
|
||||
" \"eval_candidate\": {\n",
|
||||
" \"type\": \"agent\",\n",
|
||||
" \"config\": agent_config,\n",
|
||||
|
|
|
@ -21,7 +21,7 @@ from llama_stack.distribution.stack import LlamaStack # noqa: E402
|
|||
|
||||
from .pyopenapi.options import Options # noqa: E402
|
||||
from .pyopenapi.specification import Info, Server # noqa: E402
|
||||
from .pyopenapi.utility import Specification, validate_api_method_return_types # noqa: E402
|
||||
from .pyopenapi.utility import Specification, validate_api # noqa: E402
|
||||
|
||||
|
||||
def str_presenter(dumper, data):
|
||||
|
@ -40,8 +40,7 @@ def main(output_dir: str):
|
|||
raise ValueError(f"Directory {output_dir} does not exist")
|
||||
|
||||
# Validate API protocols before generating spec
|
||||
print("Validating API method return types...")
|
||||
return_type_errors = validate_api_method_return_types()
|
||||
return_type_errors = validate_api()
|
||||
if return_type_errors:
|
||||
print("\nAPI Method Return Type Validation Errors:\n")
|
||||
for error in return_type_errors:
|
||||
|
|
|
@ -7,10 +7,9 @@
|
|||
import json
|
||||
import typing
|
||||
import inspect
|
||||
import os
|
||||
from pathlib import Path
|
||||
from typing import TextIO
|
||||
from typing import Any, Dict, List, Optional, Protocol, Type, Union, get_type_hints, get_origin, get_args
|
||||
from typing import Any, List, Optional, Union, get_type_hints, get_origin, get_args
|
||||
|
||||
from llama_stack.strong_typing.schema import object_to_json, StrictJsonType
|
||||
from llama_stack.distribution.resolver import api_protocol_map
|
||||
|
@ -125,29 +124,59 @@ def is_optional_type(type_: Any) -> bool:
|
|||
return origin is Optional or (origin is Union and type(None) in args)
|
||||
|
||||
|
||||
def validate_api_method_return_types() -> List[str]:
|
||||
"""Validate that all API methods have proper return types."""
|
||||
def _validate_api_method_return_type(method) -> str | None:
|
||||
hints = get_type_hints(method)
|
||||
|
||||
if 'return' not in hints:
|
||||
return "has no return type annotation"
|
||||
|
||||
return_type = hints['return']
|
||||
if is_optional_type(return_type):
|
||||
return "returns Optional type"
|
||||
|
||||
|
||||
def _validate_api_delete_method_returns_none(method) -> str | None:
|
||||
hints = get_type_hints(method)
|
||||
|
||||
if 'return' not in hints:
|
||||
return "has no return type annotation"
|
||||
|
||||
return_type = hints['return']
|
||||
if return_type is not None and return_type is not type(None):
|
||||
return "does not return None"
|
||||
|
||||
|
||||
_VALIDATORS = {
|
||||
"GET": [
|
||||
_validate_api_method_return_type,
|
||||
],
|
||||
"DELETE": [
|
||||
_validate_api_delete_method_returns_none,
|
||||
],
|
||||
}
|
||||
|
||||
|
||||
def _get_methods_by_type(protocol, method_type: str):
|
||||
members = inspect.getmembers(protocol, predicate=inspect.isfunction)
|
||||
return {
|
||||
method_name: method
|
||||
for method_name, method in members
|
||||
if (webmethod := getattr(method, '__webmethod__', None))
|
||||
if webmethod and webmethod.method == method_type
|
||||
}
|
||||
|
||||
|
||||
def validate_api() -> List[str]:
|
||||
"""Validate the API protocols."""
|
||||
errors = []
|
||||
protocols = api_protocol_map()
|
||||
|
||||
for protocol_name, protocol in protocols.items():
|
||||
methods = inspect.getmembers(protocol, predicate=inspect.isfunction)
|
||||
|
||||
for method_name, method in methods:
|
||||
if not hasattr(method, '__webmethod__'):
|
||||
continue
|
||||
|
||||
# Only check GET methods
|
||||
if method.__webmethod__.method != "GET":
|
||||
continue
|
||||
|
||||
hints = get_type_hints(method)
|
||||
|
||||
if 'return' not in hints:
|
||||
errors.append(f"Method {protocol_name}.{method_name} has no return type annotation")
|
||||
else:
|
||||
return_type = hints['return']
|
||||
if is_optional_type(return_type):
|
||||
errors.append(f"Method {protocol_name}.{method_name} returns Optional type")
|
||||
for target, validators in _VALIDATORS.items():
|
||||
for protocol_name, protocol in protocols.items():
|
||||
for validator in validators:
|
||||
for method_name, method in _get_methods_by_type(protocol, target).items():
|
||||
err = validator(method)
|
||||
if err:
|
||||
errors.append(f"Method {protocol_name}.{method_name} {err}")
|
||||
|
||||
return errors
|
||||
|
|
|
@ -9,6 +9,7 @@ The `llamastack/distribution-nvidia` distribution consists of the following prov
|
|||
| datasetio | `inline::localfs` |
|
||||
| eval | `inline::meta-reference` |
|
||||
| inference | `remote::nvidia` |
|
||||
| post_training | `remote::nvidia` |
|
||||
| safety | `remote::nvidia` |
|
||||
| scoring | `inline::basic` |
|
||||
| telemetry | `inline::meta-reference` |
|
||||
|
@ -21,6 +22,12 @@ The `llamastack/distribution-nvidia` distribution consists of the following prov
|
|||
The following environment variables can be configured:
|
||||
|
||||
- `NVIDIA_API_KEY`: NVIDIA API Key (default: ``)
|
||||
- `NVIDIA_USER_ID`: NVIDIA User ID (default: `llama-stack-user`)
|
||||
- `NVIDIA_DATASET_NAMESPACE`: NVIDIA Dataset Namespace (default: `default`)
|
||||
- `NVIDIA_ACCESS_POLICIES`: NVIDIA Access Policies (default: `{}`)
|
||||
- `NVIDIA_PROJECT_ID`: NVIDIA Project ID (default: `test-project`)
|
||||
- `NVIDIA_CUSTOMIZER_URL`: NVIDIA Customizer URL (default: `https://customizer.api.nvidia.com`)
|
||||
- `NVIDIA_OUTPUT_MODEL_DIR`: NVIDIA Output Model Directory (default: `test-example-model@v1`)
|
||||
- `GUARDRAILS_SERVICE_URL`: URL for the NeMo Guardrails Service (default: `http://0.0.0.0:7331`)
|
||||
- `INFERENCE_MODEL`: Inference model (default: `Llama3.1-8B-Instruct`)
|
||||
- `SAFETY_MODEL`: Name of the model to use for safety (default: `meta/llama-3.1-8b-instruct`)
|
||||
|
|
|
@ -15,6 +15,7 @@ class JobStatus(Enum):
|
|||
in_progress = "in_progress"
|
||||
failed = "failed"
|
||||
scheduled = "scheduled"
|
||||
cancelled = "cancelled"
|
||||
|
||||
|
||||
@json_schema_type
|
||||
|
|
|
@ -34,6 +34,7 @@ class Api(Enum):
|
|||
scoring_functions = "scoring_functions"
|
||||
benchmarks = "benchmarks"
|
||||
tool_groups = "tool_groups"
|
||||
files = "files"
|
||||
|
||||
# built-in API
|
||||
inspect = "inspect"
|
||||
|
|
|
@ -164,7 +164,7 @@ class Files(Protocol):
|
|||
self,
|
||||
bucket: str,
|
||||
key: str,
|
||||
) -> FileResponse:
|
||||
) -> None:
|
||||
"""
|
||||
Delete a file identified by a bucket and key.
|
||||
|
||||
|
|
|
@ -43,7 +43,7 @@ class StackRun(Subcommand):
|
|||
self.parser.add_argument(
|
||||
"--image-name",
|
||||
type=str,
|
||||
default=os.environ.get("CONDA_DEFAULT_ENV"),
|
||||
default=None,
|
||||
help="Name of the image to run. Defaults to the current conda environment",
|
||||
)
|
||||
self.parser.add_argument(
|
||||
|
|
|
@ -12,6 +12,7 @@ from llama_stack.apis.benchmarks import Benchmarks
|
|||
from llama_stack.apis.datasetio import DatasetIO
|
||||
from llama_stack.apis.datasets import Datasets
|
||||
from llama_stack.apis.eval import Eval
|
||||
from llama_stack.apis.files import Files
|
||||
from llama_stack.apis.inference import Inference
|
||||
from llama_stack.apis.inspect import Inspect
|
||||
from llama_stack.apis.models import Models
|
||||
|
@ -79,6 +80,7 @@ def api_protocol_map() -> Dict[Api, Any]:
|
|||
Api.post_training: PostTraining,
|
||||
Api.tool_groups: ToolGroups,
|
||||
Api.tool_runtime: ToolRuntime,
|
||||
Api.files: Files,
|
||||
}
|
||||
|
||||
|
||||
|
|
|
@ -13,6 +13,7 @@ LLAMA_CHECKPOINT_DIR=${LLAMA_CHECKPOINT_DIR:-}
|
|||
LLAMA_STACK_DIR=${LLAMA_STACK_DIR:-}
|
||||
TEST_PYPI_VERSION=${TEST_PYPI_VERSION:-}
|
||||
PYPI_VERSION=${PYPI_VERSION:-}
|
||||
VIRTUAL_ENV=${VIRTUAL_ENV:-}
|
||||
|
||||
set -euo pipefail
|
||||
|
||||
|
@ -69,22 +70,25 @@ while [[ $# -gt 0 ]]; do
|
|||
;;
|
||||
esac
|
||||
done
|
||||
|
||||
PYTHON_BINARY="python"
|
||||
case "$env_type" in
|
||||
"venv")
|
||||
# Activate virtual environment
|
||||
if [ ! -d "$env_path_or_name" ]; then
|
||||
echo -e "${RED}Error: Virtual environment not found at $env_path_or_name${NC}" >&2
|
||||
exit 1
|
||||
fi
|
||||
if [ -n "$VIRTUAL_ENV" && "$VIRTUAL_ENV" == "$env_path_or_name" ]; then
|
||||
echo -e "${GREEN}Virtual environment already activated${NC}" >&2
|
||||
else
|
||||
# Activate virtual environment
|
||||
if [ ! -d "$env_path_or_name" ]; then
|
||||
echo -e "${RED}Error: Virtual environment not found at $env_path_or_name${NC}" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
if [ ! -f "$env_path_or_name/bin/activate" ]; then
|
||||
echo -e "${RED}Error: Virtual environment activate binary not found at $env_path_or_name/bin/activate" >&2
|
||||
exit 1
|
||||
fi
|
||||
if [ ! -f "$env_path_or_name/bin/activate" ]; then
|
||||
echo -e "${RED}Error: Virtual environment activate binary not found at $env_path_or_name/bin/activate" >&2
|
||||
exit 1
|
||||
fi
|
||||
|
||||
source "$env_path_or_name/bin/activate"
|
||||
source "$env_path_or_name/bin/activate"
|
||||
fi
|
||||
;;
|
||||
"conda")
|
||||
if ! is_command_available conda; then
|
||||
|
|
|
@ -18,15 +18,19 @@ def preserve_contexts_async_generator(
|
|||
This is needed because we start a new asyncio event loop for each streaming request,
|
||||
and we need to preserve the context across the event loop boundary.
|
||||
"""
|
||||
# Capture initial context values
|
||||
initial_context_values = {context_var.name: context_var.get() for context_var in context_vars}
|
||||
|
||||
async def wrapper() -> AsyncGenerator[T, None]:
|
||||
while True:
|
||||
try:
|
||||
item = await gen.__anext__()
|
||||
context_values = {context_var.name: context_var.get() for context_var in context_vars}
|
||||
yield item
|
||||
# Restore context values before any await
|
||||
for context_var in context_vars:
|
||||
_ = context_var.set(context_values[context_var.name])
|
||||
context_var.set(initial_context_values[context_var.name])
|
||||
|
||||
item = await gen.__anext__()
|
||||
yield item
|
||||
|
||||
except StopAsyncIteration:
|
||||
break
|
||||
|
||||
|
|
|
@ -28,6 +28,11 @@ class TelemetryConfig(BaseModel):
|
|||
default="http://localhost:4318/v1/metrics",
|
||||
description="The OpenTelemetry collector endpoint URL for metrics",
|
||||
)
|
||||
service_name: str = Field(
|
||||
# service name is always the same, use zero-width space to avoid clutter
|
||||
default="",
|
||||
description="The service name to use for telemetry",
|
||||
)
|
||||
sinks: List[TelemetrySink] = Field(
|
||||
default=[TelemetrySink.CONSOLE, TelemetrySink.SQLITE],
|
||||
description="List of telemetry sinks to enable (possible values: otel, sqlite, console)",
|
||||
|
@ -47,6 +52,7 @@ class TelemetryConfig(BaseModel):
|
|||
@classmethod
|
||||
def sample_run_config(cls, __distro_dir__: str, db_name: str = "trace_store.db") -> Dict[str, Any]:
|
||||
return {
|
||||
"service_name": "${env.OTEL_SERVICE_NAME:}",
|
||||
"sinks": "${env.TELEMETRY_SINKS:console,sqlite}",
|
||||
"sqlite_db_path": "${env.SQLITE_DB_PATH:" + __distro_dir__ + "/" + db_name + "}",
|
||||
}
|
||||
|
|
|
@ -67,8 +67,7 @@ class TelemetryAdapter(TelemetryDatasetMixin, Telemetry):
|
|||
|
||||
resource = Resource.create(
|
||||
{
|
||||
# service name is always the same, use zero-width space to avoid clutter
|
||||
ResourceAttributes.SERVICE_NAME: "",
|
||||
ResourceAttributes.SERVICE_NAME: self.config.service_name,
|
||||
}
|
||||
)
|
||||
|
||||
|
|
11
llama_stack/providers/registry/files.py
Normal file
11
llama_stack/providers/registry/files.py
Normal file
|
@ -0,0 +1,11 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from llama_stack.providers.datatypes import ProviderSpec
|
||||
|
||||
|
||||
def available_providers() -> list[ProviderSpec]:
|
||||
return []
|
|
@ -6,7 +6,7 @@
|
|||
|
||||
from typing import List
|
||||
|
||||
from llama_stack.providers.datatypes import Api, InlineProviderSpec, ProviderSpec
|
||||
from llama_stack.providers.datatypes import AdapterSpec, Api, InlineProviderSpec, ProviderSpec, remote_provider_spec
|
||||
|
||||
|
||||
def available_providers() -> List[ProviderSpec]:
|
||||
|
@ -22,4 +22,13 @@ def available_providers() -> List[ProviderSpec]:
|
|||
Api.datasets,
|
||||
],
|
||||
),
|
||||
remote_provider_spec(
|
||||
api=Api.post_training,
|
||||
adapter=AdapterSpec(
|
||||
adapter_type="nvidia",
|
||||
pip_packages=["requests", "aiohttp"],
|
||||
module="llama_stack.providers.remote.post_training.nvidia",
|
||||
config_class="llama_stack.providers.remote.post_training.nvidia.NvidiaPostTrainingConfig",
|
||||
),
|
||||
),
|
||||
]
|
||||
|
|
|
@ -55,7 +55,7 @@ from .openai_utils import (
|
|||
convert_openai_completion_choice,
|
||||
convert_openai_completion_stream,
|
||||
)
|
||||
from .utils import _is_nvidia_hosted, check_health
|
||||
from .utils import _is_nvidia_hosted
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
@ -134,7 +134,9 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
|
|||
if content_has_media(content):
|
||||
raise NotImplementedError("Media is not supported")
|
||||
|
||||
await check_health(self._config) # this raises errors
|
||||
# ToDo: check health of NeMo endpoints and enable this
|
||||
# removing this health check as NeMo customizer endpoint health check is returning 404
|
||||
# await check_health(self._config) # this raises errors
|
||||
|
||||
provider_model_id = self.get_provider_model_id(model_id)
|
||||
request = convert_completion_request(
|
||||
|
@ -236,7 +238,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
|
|||
if tool_prompt_format:
|
||||
warnings.warn("tool_prompt_format is not supported by NVIDIA NIM, ignoring", stacklevel=2)
|
||||
|
||||
await check_health(self._config) # this raises errors
|
||||
# await check_health(self._config) # this raises errors
|
||||
|
||||
provider_model_id = self.get_provider_model_id(model_id)
|
||||
request = await convert_chat_completion_request(
|
||||
|
|
5
llama_stack/providers/remote/post_training/__init__.py
Normal file
5
llama_stack/providers/remote/post_training/__init__.py
Normal file
|
@ -0,0 +1,5 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
138
llama_stack/providers/remote/post_training/nvidia/README.md
Normal file
138
llama_stack/providers/remote/post_training/nvidia/README.md
Normal file
|
@ -0,0 +1,138 @@
|
|||
# NVIDIA Post-Training Provider for LlamaStack
|
||||
|
||||
This provider enables fine-tuning of LLMs using NVIDIA's NeMo Customizer service.
|
||||
|
||||
## Features
|
||||
|
||||
- Supervised fine-tuning of Llama models
|
||||
- LoRA fine-tuning support
|
||||
- Job management and status tracking
|
||||
|
||||
## Getting Started
|
||||
|
||||
### Prerequisites
|
||||
|
||||
- LlamaStack with NVIDIA configuration
|
||||
- Access to Hosted NVIDIA NeMo Customizer service
|
||||
- Dataset registered in the Hosted NVIDIA NeMo Customizer service
|
||||
- Base model downloaded and available in the Hosted NVIDIA NeMo Customizer service
|
||||
|
||||
### Setup
|
||||
|
||||
Build the NVIDIA environment:
|
||||
|
||||
```bash
|
||||
llama stack build --template nvidia --image-type conda
|
||||
```
|
||||
|
||||
### Basic Usage using the LlamaStack Python Client
|
||||
|
||||
### Create Customization Job
|
||||
|
||||
#### Initialize the client
|
||||
|
||||
```python
|
||||
import os
|
||||
|
||||
os.environ["NVIDIA_API_KEY"] = "your-api-key"
|
||||
os.environ["NVIDIA_CUSTOMIZER_URL"] = "http://nemo.test"
|
||||
os.environ["NVIDIA_USER_ID"] = "llama-stack-user"
|
||||
os.environ["NVIDIA_DATASET_NAMESPACE"] = "default"
|
||||
os.environ["NVIDIA_PROJECT_ID"] = "test-project"
|
||||
os.environ["NVIDIA_OUTPUT_MODEL_DIR"] = "test-example-model@v1"
|
||||
|
||||
from llama_stack.distribution.library_client import LlamaStackAsLibraryClient
|
||||
|
||||
client = LlamaStackAsLibraryClient("nvidia")
|
||||
client.initialize()
|
||||
```
|
||||
|
||||
#### Configure fine-tuning parameters
|
||||
|
||||
```python
|
||||
from llama_stack_client.types.post_training_supervised_fine_tune_params import (
|
||||
TrainingConfig,
|
||||
TrainingConfigDataConfig,
|
||||
TrainingConfigOptimizerConfig,
|
||||
)
|
||||
from llama_stack_client.types.algorithm_config_param import LoraFinetuningConfig
|
||||
```
|
||||
|
||||
#### Set up LoRA configuration
|
||||
|
||||
```python
|
||||
algorithm_config = LoraFinetuningConfig(type="LoRA", adapter_dim=16)
|
||||
```
|
||||
|
||||
#### Configure training data
|
||||
|
||||
```python
|
||||
data_config = TrainingConfigDataConfig(
|
||||
dataset_id="your-dataset-id", # Use client.datasets.list() to see available datasets
|
||||
batch_size=16,
|
||||
)
|
||||
```
|
||||
|
||||
#### Configure optimizer
|
||||
|
||||
```python
|
||||
optimizer_config = TrainingConfigOptimizerConfig(
|
||||
lr=0.0001,
|
||||
)
|
||||
```
|
||||
|
||||
#### Set up training configuration
|
||||
|
||||
```python
|
||||
training_config = TrainingConfig(
|
||||
n_epochs=2,
|
||||
data_config=data_config,
|
||||
optimizer_config=optimizer_config,
|
||||
)
|
||||
```
|
||||
|
||||
#### Start fine-tuning job
|
||||
|
||||
```python
|
||||
training_job = client.post_training.supervised_fine_tune(
|
||||
job_uuid="unique-job-id",
|
||||
model="meta-llama/Llama-3.1-8B-Instruct",
|
||||
checkpoint_dir="",
|
||||
algorithm_config=algorithm_config,
|
||||
training_config=training_config,
|
||||
logger_config={},
|
||||
hyperparam_search_config={},
|
||||
)
|
||||
```
|
||||
|
||||
### List all jobs
|
||||
|
||||
```python
|
||||
jobs = client.post_training.job.list()
|
||||
```
|
||||
|
||||
### Check job status
|
||||
|
||||
```python
|
||||
job_status = client.post_training.job.status(job_uuid="your-job-id")
|
||||
```
|
||||
|
||||
### Cancel a job
|
||||
|
||||
```python
|
||||
client.post_training.job.cancel(job_uuid="your-job-id")
|
||||
```
|
||||
|
||||
### Inference with the fine-tuned model
|
||||
|
||||
```python
|
||||
response = client.inference.completion(
|
||||
content="Complete the sentence using one word: Roses are red, violets are ",
|
||||
stream=False,
|
||||
model_id="test-example-model@v1",
|
||||
sampling_params={
|
||||
"max_tokens": 50,
|
||||
},
|
||||
)
|
||||
print(response.content)
|
||||
```
|
|
@ -0,0 +1,23 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from .config import NvidiaPostTrainingConfig
|
||||
|
||||
|
||||
async def get_adapter_impl(
|
||||
config: NvidiaPostTrainingConfig,
|
||||
_deps,
|
||||
):
|
||||
from .post_training import NvidiaPostTrainingAdapter
|
||||
|
||||
if not isinstance(config, NvidiaPostTrainingConfig):
|
||||
raise RuntimeError(f"Unexpected config type: {type(config)}")
|
||||
|
||||
impl = NvidiaPostTrainingAdapter(config)
|
||||
return impl
|
||||
|
||||
|
||||
__all__ = ["get_adapter_impl", "NvidiaPostTrainingAdapter"]
|
113
llama_stack/providers/remote/post_training/nvidia/config.py
Normal file
113
llama_stack/providers/remote/post_training/nvidia/config.py
Normal file
|
@ -0,0 +1,113 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import os
|
||||
from typing import Any, Dict, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
# TODO: add default values for all fields
|
||||
|
||||
|
||||
class NvidiaPostTrainingConfig(BaseModel):
|
||||
"""Configuration for NVIDIA Post Training implementation."""
|
||||
|
||||
api_key: Optional[str] = Field(
|
||||
default_factory=lambda: os.getenv("NVIDIA_API_KEY"),
|
||||
description="The NVIDIA API key.",
|
||||
)
|
||||
|
||||
dataset_namespace: Optional[str] = Field(
|
||||
default_factory=lambda: os.getenv("NVIDIA_DATASET_NAMESPACE", "default"),
|
||||
description="The NVIDIA dataset namespace.",
|
||||
)
|
||||
|
||||
project_id: Optional[str] = Field(
|
||||
default_factory=lambda: os.getenv("NVIDIA_PROJECT_ID", "test-example-model@v1"),
|
||||
description="The NVIDIA project ID.",
|
||||
)
|
||||
|
||||
# ToDO: validate this, add default value
|
||||
customizer_url: Optional[str] = Field(
|
||||
default_factory=lambda: os.getenv("NVIDIA_CUSTOMIZER_URL"),
|
||||
description="Base URL for the NeMo Customizer API",
|
||||
)
|
||||
|
||||
timeout: int = Field(
|
||||
default=300,
|
||||
description="Timeout for the NVIDIA Post Training API",
|
||||
)
|
||||
|
||||
max_retries: int = Field(
|
||||
default=3,
|
||||
description="Maximum number of retries for the NVIDIA Post Training API",
|
||||
)
|
||||
|
||||
# ToDo: validate this
|
||||
output_model_dir: str = Field(
|
||||
default_factory=lambda: os.getenv("NVIDIA_OUTPUT_MODEL_DIR", "test-example-model@v1"),
|
||||
description="Directory to save the output model",
|
||||
)
|
||||
|
||||
@classmethod
|
||||
def sample_run_config(cls, **kwargs) -> Dict[str, Any]:
|
||||
return {
|
||||
"api_key": "${env.NVIDIA_API_KEY:}",
|
||||
"dataset_namespace": "${env.NVIDIA_DATASET_NAMESPACE:default}",
|
||||
"project_id": "${env.NVIDIA_PROJECT_ID:test-project}",
|
||||
"customizer_url": "${env.NVIDIA_CUSTOMIZER_URL:http://nemo.test}",
|
||||
}
|
||||
|
||||
|
||||
class SFTLoRADefaultConfig(BaseModel):
|
||||
"""NVIDIA-specific training configuration with default values."""
|
||||
|
||||
# ToDo: split into SFT and LoRA configs??
|
||||
|
||||
# General training parameters
|
||||
n_epochs: int = 50
|
||||
|
||||
# NeMo customizer specific parameters
|
||||
log_every_n_steps: Optional[int] = None
|
||||
val_check_interval: float = 0.25
|
||||
sequence_packing_enabled: bool = False
|
||||
weight_decay: float = 0.01
|
||||
lr: float = 0.0001
|
||||
|
||||
# SFT specific parameters
|
||||
hidden_dropout: Optional[float] = None
|
||||
attention_dropout: Optional[float] = None
|
||||
ffn_dropout: Optional[float] = None
|
||||
|
||||
# LoRA default parameters
|
||||
lora_adapter_dim: int = 8
|
||||
lora_adapter_dropout: Optional[float] = None
|
||||
lora_alpha: int = 16
|
||||
|
||||
# Data config
|
||||
batch_size: int = 8
|
||||
|
||||
@classmethod
|
||||
def sample_config(cls) -> Dict[str, Any]:
|
||||
"""Return a sample configuration for NVIDIA training."""
|
||||
return {
|
||||
"n_epochs": 50,
|
||||
"log_every_n_steps": 10,
|
||||
"val_check_interval": 0.25,
|
||||
"sequence_packing_enabled": False,
|
||||
"weight_decay": 0.01,
|
||||
"hidden_dropout": 0.1,
|
||||
"attention_dropout": 0.1,
|
||||
"lora_adapter_dim": 8,
|
||||
"lora_alpha": 16,
|
||||
"data_config": {
|
||||
"dataset_id": "default",
|
||||
"batch_size": 8,
|
||||
},
|
||||
"optimizer_config": {
|
||||
"lr": 0.0001,
|
||||
},
|
||||
}
|
24
llama_stack/providers/remote/post_training/nvidia/models.py
Normal file
24
llama_stack/providers/remote/post_training/nvidia/models.py
Normal file
|
@ -0,0 +1,24 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
from typing import List
|
||||
|
||||
from llama_stack.models.llama.datatypes import CoreModelId
|
||||
from llama_stack.providers.utils.inference.model_registry import (
|
||||
ProviderModelEntry,
|
||||
build_hf_repo_model_entry,
|
||||
)
|
||||
|
||||
_MODEL_ENTRIES = [
|
||||
build_hf_repo_model_entry(
|
||||
"meta/llama-3.1-8b-instruct",
|
||||
CoreModelId.llama3_1_8b_instruct.value,
|
||||
)
|
||||
]
|
||||
|
||||
|
||||
def get_model_entries() -> List[ProviderModelEntry]:
|
||||
return _MODEL_ENTRIES
|
|
@ -0,0 +1,439 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
import warnings
|
||||
from datetime import datetime
|
||||
from typing import Any, Dict, List, Literal, Optional
|
||||
|
||||
import aiohttp
|
||||
from pydantic import BaseModel, ConfigDict
|
||||
|
||||
from llama_stack.apis.post_training import (
|
||||
AlgorithmConfig,
|
||||
DPOAlignmentConfig,
|
||||
JobStatus,
|
||||
PostTrainingJob,
|
||||
PostTrainingJobArtifactsResponse,
|
||||
PostTrainingJobStatusResponse,
|
||||
TrainingConfig,
|
||||
)
|
||||
from llama_stack.providers.remote.post_training.nvidia.config import NvidiaPostTrainingConfig
|
||||
from llama_stack.providers.remote.post_training.nvidia.utils import warn_unsupported_params
|
||||
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
|
||||
|
||||
from .models import _MODEL_ENTRIES
|
||||
|
||||
# Map API status to JobStatus enum
|
||||
STATUS_MAPPING = {
|
||||
"running": "in_progress",
|
||||
"completed": "completed",
|
||||
"failed": "failed",
|
||||
"cancelled": "cancelled",
|
||||
"pending": "scheduled",
|
||||
}
|
||||
|
||||
|
||||
class NvidiaPostTrainingJob(PostTrainingJob):
|
||||
"""Parse the response from the Customizer API.
|
||||
Inherits job_uuid from PostTrainingJob.
|
||||
Adds status, created_at, updated_at parameters.
|
||||
Passes through all other parameters from data field in the response.
|
||||
"""
|
||||
|
||||
model_config = ConfigDict(extra="allow")
|
||||
status: JobStatus
|
||||
created_at: datetime
|
||||
updated_at: datetime
|
||||
|
||||
|
||||
class ListNvidiaPostTrainingJobs(BaseModel):
|
||||
data: List[NvidiaPostTrainingJob]
|
||||
|
||||
|
||||
class NvidiaPostTrainingJobStatusResponse(PostTrainingJobStatusResponse):
|
||||
model_config = ConfigDict(extra="allow")
|
||||
|
||||
|
||||
class NvidiaPostTrainingAdapter(ModelRegistryHelper):
|
||||
def __init__(self, config: NvidiaPostTrainingConfig):
|
||||
self.config = config
|
||||
self.headers = {}
|
||||
if config.api_key:
|
||||
self.headers["Authorization"] = f"Bearer {config.api_key}"
|
||||
|
||||
self.timeout = aiohttp.ClientTimeout(total=config.timeout)
|
||||
# TODO: filter by available models based on /config endpoint
|
||||
ModelRegistryHelper.__init__(self, model_entries=_MODEL_ENTRIES)
|
||||
self.session = aiohttp.ClientSession(headers=self.headers, timeout=self.timeout)
|
||||
self.customizer_url = config.customizer_url
|
||||
|
||||
if not self.customizer_url:
|
||||
warnings.warn("Customizer URL is not set, using default value: http://nemo.test", stacklevel=2)
|
||||
self.customizer_url = "http://nemo.test"
|
||||
|
||||
async def _make_request(
|
||||
self,
|
||||
method: str,
|
||||
path: str,
|
||||
headers: Optional[Dict[str, Any]] = None,
|
||||
params: Optional[Dict[str, Any]] = None,
|
||||
json: Optional[Dict[str, Any]] = None,
|
||||
**kwargs,
|
||||
) -> Dict[str, Any]:
|
||||
"""Helper method to make HTTP requests to the Customizer API."""
|
||||
url = f"{self.customizer_url}{path}"
|
||||
request_headers = self.headers.copy()
|
||||
|
||||
if headers:
|
||||
request_headers.update(headers)
|
||||
|
||||
# Add content-type header for JSON requests
|
||||
if json and "Content-Type" not in request_headers:
|
||||
request_headers["Content-Type"] = "application/json"
|
||||
|
||||
for _ in range(self.config.max_retries):
|
||||
async with self.session.request(method, url, params=params, json=json, **kwargs) as response:
|
||||
if response.status >= 400:
|
||||
error_data = await response.json()
|
||||
raise Exception(f"API request failed: {error_data}")
|
||||
return await response.json()
|
||||
|
||||
async def get_training_jobs(
|
||||
self,
|
||||
page: Optional[int] = 1,
|
||||
page_size: Optional[int] = 10,
|
||||
sort: Optional[Literal["created_at", "-created_at"]] = "created_at",
|
||||
) -> ListNvidiaPostTrainingJobs:
|
||||
"""Get all customization jobs.
|
||||
Updated the base class return type from ListPostTrainingJobsResponse to ListNvidiaPostTrainingJobs.
|
||||
|
||||
Returns a ListNvidiaPostTrainingJobs object with the following fields:
|
||||
- data: List[NvidiaPostTrainingJob] - List of NvidiaPostTrainingJob objects
|
||||
|
||||
ToDo: Support for schema input for filtering.
|
||||
"""
|
||||
params = {"page": page, "page_size": page_size, "sort": sort}
|
||||
|
||||
response = await self._make_request("GET", "/v1/customization/jobs", params=params)
|
||||
|
||||
jobs = []
|
||||
for job in response.get("data", []):
|
||||
job_id = job.pop("id")
|
||||
job_status = job.pop("status", "unknown").lower()
|
||||
mapped_status = STATUS_MAPPING.get(job_status, "unknown")
|
||||
|
||||
# Convert string timestamps to datetime objects
|
||||
created_at = (
|
||||
datetime.fromisoformat(job.pop("created_at"))
|
||||
if "created_at" in job
|
||||
else datetime.now(tz=datetime.timezone.utc)
|
||||
)
|
||||
updated_at = (
|
||||
datetime.fromisoformat(job.pop("updated_at"))
|
||||
if "updated_at" in job
|
||||
else datetime.now(tz=datetime.timezone.utc)
|
||||
)
|
||||
|
||||
# Create NvidiaPostTrainingJob instance
|
||||
jobs.append(
|
||||
NvidiaPostTrainingJob(
|
||||
job_uuid=job_id,
|
||||
status=JobStatus(mapped_status),
|
||||
created_at=created_at,
|
||||
updated_at=updated_at,
|
||||
**job,
|
||||
)
|
||||
)
|
||||
|
||||
return ListNvidiaPostTrainingJobs(data=jobs)
|
||||
|
||||
async def get_training_job_status(self, job_uuid: str) -> NvidiaPostTrainingJobStatusResponse:
|
||||
"""Get the status of a customization job.
|
||||
Updated the base class return type from PostTrainingJobResponse to NvidiaPostTrainingJob.
|
||||
|
||||
Returns a NvidiaPostTrainingJob object with the following fields:
|
||||
- job_uuid: str - Unique identifier for the job
|
||||
- status: JobStatus - Current status of the job (in_progress, completed, failed, cancelled, scheduled)
|
||||
- created_at: datetime - The time when the job was created
|
||||
- updated_at: datetime - The last time the job status was updated
|
||||
|
||||
Additional fields that may be included:
|
||||
- steps_completed: Optional[int] - Number of training steps completed
|
||||
- epochs_completed: Optional[int] - Number of epochs completed
|
||||
- percentage_done: Optional[float] - Percentage of training completed (0-100)
|
||||
- best_epoch: Optional[int] - The epoch with the best performance
|
||||
- train_loss: Optional[float] - Training loss of the best checkpoint
|
||||
- val_loss: Optional[float] - Validation loss of the best checkpoint
|
||||
- metrics: Optional[Dict] - Additional training metrics
|
||||
- status_logs: Optional[List] - Detailed logs of status changes
|
||||
"""
|
||||
response = await self._make_request(
|
||||
"GET",
|
||||
f"/v1/customization/jobs/{job_uuid}/status",
|
||||
params={"job_id": job_uuid},
|
||||
)
|
||||
|
||||
api_status = response.pop("status").lower()
|
||||
mapped_status = STATUS_MAPPING.get(api_status, "unknown")
|
||||
|
||||
return NvidiaPostTrainingJobStatusResponse(
|
||||
status=JobStatus(mapped_status),
|
||||
job_uuid=job_uuid,
|
||||
started_at=datetime.fromisoformat(response.pop("created_at")),
|
||||
updated_at=datetime.fromisoformat(response.pop("updated_at")),
|
||||
**response,
|
||||
)
|
||||
|
||||
async def cancel_training_job(self, job_uuid: str) -> None:
|
||||
await self._make_request(
|
||||
method="POST", path=f"/v1/customization/jobs/{job_uuid}/cancel", params={"job_id": job_uuid}
|
||||
)
|
||||
|
||||
async def get_training_job_artifacts(self, job_uuid: str) -> PostTrainingJobArtifactsResponse:
|
||||
raise NotImplementedError("Job artifacts are not implemented yet")
|
||||
|
||||
async def get_post_training_artifacts(self, job_uuid: str) -> PostTrainingJobArtifactsResponse:
|
||||
raise NotImplementedError("Job artifacts are not implemented yet")
|
||||
|
||||
async def supervised_fine_tune(
|
||||
self,
|
||||
job_uuid: str,
|
||||
training_config: Dict[str, Any],
|
||||
hyperparam_search_config: Dict[str, Any],
|
||||
logger_config: Dict[str, Any],
|
||||
model: str,
|
||||
checkpoint_dir: Optional[str],
|
||||
algorithm_config: Optional[AlgorithmConfig] = None,
|
||||
extra_json: Optional[Dict[str, Any]] = None,
|
||||
params: Optional[Dict[str, Any]] = None,
|
||||
headers: Optional[Dict[str, Any]] = None,
|
||||
**kwargs,
|
||||
) -> NvidiaPostTrainingJob:
|
||||
"""
|
||||
Fine-tunes a model on a dataset.
|
||||
Currently only supports Lora finetuning for standlone docker container.
|
||||
Assumptions:
|
||||
- nemo microservice is running and endpoint is set in config.customizer_url
|
||||
- dataset is registered separately in nemo datastore
|
||||
- model checkpoint is downloaded as per nemo customizer requirements
|
||||
|
||||
Parameters:
|
||||
training_config: TrainingConfig - Configuration for training
|
||||
model: str - Model identifier
|
||||
algorithm_config: Optional[AlgorithmConfig] - Algorithm-specific configuration
|
||||
checkpoint_dir: Optional[str] - Directory containing model checkpoints, ignored atm
|
||||
job_uuid: str - Unique identifier for the job, ignored atm
|
||||
hyperparam_search_config: Dict[str, Any] - Configuration for hyperparameter search, ignored atm
|
||||
logger_config: Dict[str, Any] - Configuration for logging, ignored atm
|
||||
|
||||
Environment Variables:
|
||||
- NVIDIA_API_KEY: str - API key for the NVIDIA API
|
||||
Default: None
|
||||
- NVIDIA_DATASET_NAMESPACE: str - Namespace of the dataset
|
||||
Default: "default"
|
||||
- NVIDIA_CUSTOMIZER_URL: str - URL of the NeMo Customizer API
|
||||
Default: "http://nemo.test"
|
||||
- NVIDIA_PROJECT_ID: str - ID of the project
|
||||
Default: "test-project"
|
||||
- NVIDIA_OUTPUT_MODEL_DIR: str - Directory to save the output model
|
||||
Default: "test-example-model@v1"
|
||||
|
||||
Supported models:
|
||||
- meta/llama-3.1-8b-instruct
|
||||
|
||||
Supported algorithm configs:
|
||||
- LoRA, SFT
|
||||
|
||||
Supported Parameters:
|
||||
- TrainingConfig:
|
||||
- n_epochs: int - Number of epochs to train
|
||||
Default: 50
|
||||
- data_config: DataConfig - Configuration for the dataset
|
||||
- optimizer_config: OptimizerConfig - Configuration for the optimizer
|
||||
- dtype: str - Data type for training
|
||||
not supported (users are informed via warnings)
|
||||
- efficiency_config: EfficiencyConfig - Configuration for efficiency
|
||||
not supported
|
||||
- max_steps_per_epoch: int - Maximum number of steps per epoch
|
||||
Default: 1000
|
||||
## NeMo customizer specific parameters
|
||||
- log_every_n_steps: int - Log every n steps
|
||||
Default: None
|
||||
- val_check_interval: float - Validation check interval
|
||||
Default: 0.25
|
||||
- sequence_packing_enabled: bool - Sequence packing enabled
|
||||
Default: False
|
||||
## NeMo customizer specific SFT parameters
|
||||
- hidden_dropout: float - Hidden dropout
|
||||
Default: None (0.0-1.0)
|
||||
- attention_dropout: float - Attention dropout
|
||||
Default: None (0.0-1.0)
|
||||
- ffn_dropout: float - FFN dropout
|
||||
Default: None (0.0-1.0)
|
||||
|
||||
- DataConfig:
|
||||
- dataset_id: str - Dataset ID
|
||||
- batch_size: int - Batch size
|
||||
Default: 8
|
||||
|
||||
- OptimizerConfig:
|
||||
- lr: float - Learning rate
|
||||
Default: 0.0001
|
||||
## NeMo customizer specific parameter
|
||||
- weight_decay: float - Weight decay
|
||||
Default: 0.01
|
||||
|
||||
- LoRA config:
|
||||
## NeMo customizer specific LoRA parameters
|
||||
- adapter_dim: int - Adapter dimension
|
||||
Default: 8 (supports powers of 2)
|
||||
- adapter_dropout: float - Adapter dropout
|
||||
Default: None (0.0-1.0)
|
||||
- alpha: int - Scaling factor for the LoRA update
|
||||
Default: 16
|
||||
Note:
|
||||
- checkpoint_dir, hyperparam_search_config, logger_config are not supported (users are informed via warnings)
|
||||
- Some parameters from TrainingConfig, DataConfig, OptimizerConfig are not supported (users are informed via warnings)
|
||||
|
||||
User is informed about unsupported parameters via warnings.
|
||||
"""
|
||||
# Map model to nvidia model name
|
||||
# ToDo: only supports llama-3.1-8b-instruct now, need to update this to support other models
|
||||
nvidia_model = self.get_provider_model_id(model)
|
||||
|
||||
# Check for unsupported method parameters
|
||||
unsupported_method_params = []
|
||||
if checkpoint_dir:
|
||||
unsupported_method_params.append(f"checkpoint_dir={checkpoint_dir}")
|
||||
if hyperparam_search_config:
|
||||
unsupported_method_params.append("hyperparam_search_config")
|
||||
if logger_config:
|
||||
unsupported_method_params.append("logger_config")
|
||||
|
||||
if unsupported_method_params:
|
||||
warnings.warn(
|
||||
f"Parameters: {', '.join(unsupported_method_params)} are not supported and will be ignored",
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
# Define all supported parameters
|
||||
supported_params = {
|
||||
"training_config": {
|
||||
"n_epochs",
|
||||
"data_config",
|
||||
"optimizer_config",
|
||||
"log_every_n_steps",
|
||||
"val_check_interval",
|
||||
"sequence_packing_enabled",
|
||||
"hidden_dropout",
|
||||
"attention_dropout",
|
||||
"ffn_dropout",
|
||||
},
|
||||
"data_config": {"dataset_id", "batch_size"},
|
||||
"optimizer_config": {"lr", "weight_decay"},
|
||||
"lora_config": {"type", "adapter_dim", "adapter_dropout", "alpha"},
|
||||
}
|
||||
|
||||
# Validate all parameters at once
|
||||
warn_unsupported_params(training_config, supported_params["training_config"], "TrainingConfig")
|
||||
warn_unsupported_params(training_config["data_config"], supported_params["data_config"], "DataConfig")
|
||||
warn_unsupported_params(
|
||||
training_config["optimizer_config"], supported_params["optimizer_config"], "OptimizerConfig"
|
||||
)
|
||||
|
||||
output_model = self.config.output_model_dir
|
||||
|
||||
# Prepare base job configuration
|
||||
job_config = {
|
||||
"config": nvidia_model,
|
||||
"dataset": {
|
||||
"name": training_config["data_config"]["dataset_id"],
|
||||
"namespace": self.config.dataset_namespace,
|
||||
},
|
||||
"hyperparameters": {
|
||||
"training_type": "sft",
|
||||
"finetuning_type": "lora",
|
||||
**{
|
||||
k: v
|
||||
for k, v in {
|
||||
"epochs": training_config.get("n_epochs"),
|
||||
"batch_size": training_config["data_config"].get("batch_size"),
|
||||
"learning_rate": training_config["optimizer_config"].get("lr"),
|
||||
"weight_decay": training_config["optimizer_config"].get("weight_decay"),
|
||||
"log_every_n_steps": training_config.get("log_every_n_steps"),
|
||||
"val_check_interval": training_config.get("val_check_interval"),
|
||||
"sequence_packing_enabled": training_config.get("sequence_packing_enabled"),
|
||||
}.items()
|
||||
if v is not None
|
||||
},
|
||||
},
|
||||
"project": self.config.project_id,
|
||||
# TODO: ignored ownership, add it later
|
||||
# "ownership": {"created_by": self.config.user_id, "access_policies": self.config.access_policies},
|
||||
"output_model": output_model,
|
||||
}
|
||||
|
||||
# Handle SFT-specific optional parameters
|
||||
job_config["hyperparameters"]["sft"] = {
|
||||
k: v
|
||||
for k, v in {
|
||||
"ffn_dropout": training_config.get("ffn_dropout"),
|
||||
"hidden_dropout": training_config.get("hidden_dropout"),
|
||||
"attention_dropout": training_config.get("attention_dropout"),
|
||||
}.items()
|
||||
if v is not None
|
||||
}
|
||||
|
||||
# Remove the sft dictionary if it's empty
|
||||
if not job_config["hyperparameters"]["sft"]:
|
||||
job_config["hyperparameters"].pop("sft")
|
||||
|
||||
# Handle LoRA-specific configuration
|
||||
if algorithm_config:
|
||||
if isinstance(algorithm_config, dict) and algorithm_config.get("type") == "LoRA":
|
||||
warn_unsupported_params(algorithm_config, supported_params["lora_config"], "LoRA config")
|
||||
job_config["hyperparameters"]["lora"] = {
|
||||
k: v
|
||||
for k, v in {
|
||||
"adapter_dim": algorithm_config.get("adapter_dim"),
|
||||
"alpha": algorithm_config.get("alpha"),
|
||||
"adapter_dropout": algorithm_config.get("adapter_dropout"),
|
||||
}.items()
|
||||
if v is not None
|
||||
}
|
||||
else:
|
||||
raise NotImplementedError(f"Unsupported algorithm config: {algorithm_config}")
|
||||
|
||||
# Create the customization job
|
||||
response = await self._make_request(
|
||||
method="POST",
|
||||
path="/v1/customization/jobs",
|
||||
headers={"Accept": "application/json"},
|
||||
json=job_config,
|
||||
)
|
||||
|
||||
job_uuid = response["id"]
|
||||
response.pop("status")
|
||||
created_at = datetime.fromisoformat(response.pop("created_at"))
|
||||
updated_at = datetime.fromisoformat(response.pop("updated_at"))
|
||||
|
||||
return NvidiaPostTrainingJob(
|
||||
job_uuid=job_uuid, status=JobStatus.in_progress, created_at=created_at, updated_at=updated_at, **response
|
||||
)
|
||||
|
||||
async def preference_optimize(
|
||||
self,
|
||||
job_uuid: str,
|
||||
finetuned_model: str,
|
||||
algorithm_config: DPOAlignmentConfig,
|
||||
training_config: TrainingConfig,
|
||||
hyperparam_search_config: Dict[str, Any],
|
||||
logger_config: Dict[str, Any],
|
||||
) -> PostTrainingJob:
|
||||
"""Optimize a model based on preference data."""
|
||||
raise NotImplementedError("Preference optimization is not implemented yet")
|
||||
|
||||
async def get_training_job_container_logs(self, job_uuid: str) -> PostTrainingJobStatusResponse:
|
||||
raise NotImplementedError("Job logs are not implemented yet")
|
63
llama_stack/providers/remote/post_training/nvidia/utils.py
Normal file
63
llama_stack/providers/remote/post_training/nvidia/utils.py
Normal file
|
@ -0,0 +1,63 @@
|
|||
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||
# All rights reserved.
|
||||
#
|
||||
# This source code is licensed under the terms described in the LICENSE file in
|
||||
# the root directory of this source tree.
|
||||
|
||||
import logging
|
||||
import warnings
|
||||
from typing import Any, Dict, Set, Tuple
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from llama_stack.apis.post_training import TrainingConfig
|
||||
from llama_stack.providers.remote.post_training.nvidia.config import SFTLoRADefaultConfig
|
||||
|
||||
from .config import NvidiaPostTrainingConfig
|
||||
|
||||
logger = logging.getLogger(__name__)
|
||||
|
||||
|
||||
def warn_unsupported_params(config_dict: Any, supported_keys: Set[str], config_name: str) -> None:
|
||||
keys = set(config_dict.__annotations__.keys()) if isinstance(config_dict, BaseModel) else config_dict.keys()
|
||||
unsupported_params = [k for k in keys if k not in supported_keys]
|
||||
if unsupported_params:
|
||||
warnings.warn(
|
||||
f"Parameters: {unsupported_params} in `{config_name}` not supported and will be ignored.", stacklevel=2
|
||||
)
|
||||
|
||||
|
||||
def validate_training_params(
|
||||
training_config: Dict[str, Any], supported_keys: Set[str], config_name: str = "TrainingConfig"
|
||||
) -> None:
|
||||
"""
|
||||
Validates training parameters against supported keys.
|
||||
|
||||
Args:
|
||||
training_config: Dictionary containing training configuration parameters
|
||||
supported_keys: Set of supported parameter keys
|
||||
config_name: Name of the configuration for warning messages
|
||||
"""
|
||||
sft_lora_fields = set(SFTLoRADefaultConfig.__annotations__.keys())
|
||||
training_config_fields = set(TrainingConfig.__annotations__.keys())
|
||||
|
||||
# Check for not supported parameters:
|
||||
# - not in either of configs
|
||||
# - in TrainingConfig but not in SFTLoRADefaultConfig
|
||||
unsupported_params = []
|
||||
for key in training_config:
|
||||
if isinstance(key, str) and key not in (supported_keys.union(sft_lora_fields)):
|
||||
if key in (not sft_lora_fields or training_config_fields):
|
||||
unsupported_params.append(key)
|
||||
|
||||
if unsupported_params:
|
||||
warnings.warn(
|
||||
f"Parameters: {unsupported_params} in `{config_name}` are not supported and will be ignored.", stacklevel=2
|
||||
)
|
||||
|
||||
|
||||
# ToDo: implement post health checks for customizer are enabled
|
||||
async def _get_health(url: str) -> Tuple[bool, bool]: ...
|
||||
|
||||
|
||||
async def check_health(config: NvidiaPostTrainingConfig) -> None: ...
|
|
@ -39,6 +39,7 @@ providers:
|
|||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
config:
|
||||
service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
|
||||
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
|
||||
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/bedrock/trace_store.db}
|
||||
eval:
|
||||
|
|
|
@ -79,6 +79,7 @@ providers:
|
|||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
config:
|
||||
service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
|
||||
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
|
||||
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/cerebras/trace_store.db}
|
||||
tool_runtime:
|
||||
|
|
|
@ -42,6 +42,7 @@ providers:
|
|||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
config:
|
||||
service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
|
||||
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
|
||||
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/ci-tests/trace_store.db}
|
||||
eval:
|
||||
|
|
|
@ -45,6 +45,7 @@ providers:
|
|||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
config:
|
||||
service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
|
||||
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
|
||||
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/dell/trace_store.db}
|
||||
eval:
|
||||
|
|
|
@ -41,6 +41,7 @@ providers:
|
|||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
config:
|
||||
service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
|
||||
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
|
||||
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/dell/trace_store.db}
|
||||
eval:
|
||||
|
|
|
@ -71,6 +71,7 @@ providers:
|
|||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
config:
|
||||
service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
|
||||
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
|
||||
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/dev/trace_store.db}
|
||||
eval:
|
||||
|
|
|
@ -50,6 +50,7 @@ providers:
|
|||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
config:
|
||||
service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
|
||||
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
|
||||
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/fireworks/trace_store.db}
|
||||
eval:
|
||||
|
|
|
@ -45,6 +45,7 @@ providers:
|
|||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
config:
|
||||
service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
|
||||
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
|
||||
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/fireworks/trace_store.db}
|
||||
eval:
|
||||
|
|
|
@ -45,6 +45,7 @@ providers:
|
|||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
config:
|
||||
service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
|
||||
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
|
||||
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/groq/trace_store.db}
|
||||
eval:
|
||||
|
|
|
@ -50,6 +50,7 @@ providers:
|
|||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
config:
|
||||
service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
|
||||
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
|
||||
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/hf-endpoint/trace_store.db}
|
||||
eval:
|
||||
|
|
|
@ -45,6 +45,7 @@ providers:
|
|||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
config:
|
||||
service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
|
||||
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
|
||||
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/hf-endpoint/trace_store.db}
|
||||
eval:
|
||||
|
|
|
@ -50,6 +50,7 @@ providers:
|
|||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
config:
|
||||
service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
|
||||
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
|
||||
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/hf-serverless/trace_store.db}
|
||||
eval:
|
||||
|
|
|
@ -45,6 +45,7 @@ providers:
|
|||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
config:
|
||||
service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
|
||||
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
|
||||
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/hf-serverless/trace_store.db}
|
||||
eval:
|
||||
|
|
|
@ -52,6 +52,7 @@ providers:
|
|||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
config:
|
||||
service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
|
||||
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
|
||||
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/meta-reference-gpu/trace_store.db}
|
||||
eval:
|
||||
|
|
|
@ -46,6 +46,7 @@ providers:
|
|||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
config:
|
||||
service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
|
||||
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
|
||||
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/meta-reference-gpu/trace_store.db}
|
||||
eval:
|
||||
|
|
|
@ -48,6 +48,7 @@ providers:
|
|||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
config:
|
||||
service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
|
||||
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
|
||||
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/meta-reference-quantized-gpu/trace_store.db}
|
||||
eval:
|
||||
|
|
|
@ -14,6 +14,8 @@ distribution_spec:
|
|||
- inline::meta-reference
|
||||
eval:
|
||||
- inline::meta-reference
|
||||
post_training:
|
||||
- remote::nvidia
|
||||
datasetio:
|
||||
- inline::localfs
|
||||
scoring:
|
||||
|
|
|
@ -21,6 +21,7 @@ def get_distribution_template() -> DistributionTemplate:
|
|||
"agents": ["inline::meta-reference"],
|
||||
"telemetry": ["inline::meta-reference"],
|
||||
"eval": ["inline::meta-reference"],
|
||||
"post_training": ["remote::nvidia"],
|
||||
"datasetio": ["inline::localfs"],
|
||||
"scoring": ["inline::basic"],
|
||||
"tool_runtime": ["inline::rag-runtime"],
|
||||
|
@ -89,6 +90,31 @@ def get_distribution_template() -> DistributionTemplate:
|
|||
"",
|
||||
"NVIDIA API Key",
|
||||
),
|
||||
## Nemo Customizer related variables
|
||||
"NVIDIA_USER_ID": (
|
||||
"llama-stack-user",
|
||||
"NVIDIA User ID",
|
||||
),
|
||||
"NVIDIA_DATASET_NAMESPACE": (
|
||||
"default",
|
||||
"NVIDIA Dataset Namespace",
|
||||
),
|
||||
"NVIDIA_ACCESS_POLICIES": (
|
||||
"{}",
|
||||
"NVIDIA Access Policies",
|
||||
),
|
||||
"NVIDIA_PROJECT_ID": (
|
||||
"test-project",
|
||||
"NVIDIA Project ID",
|
||||
),
|
||||
"NVIDIA_CUSTOMIZER_URL": (
|
||||
"https://customizer.api.nvidia.com",
|
||||
"NVIDIA Customizer URL",
|
||||
),
|
||||
"NVIDIA_OUTPUT_MODEL_DIR": (
|
||||
"test-example-model@v1",
|
||||
"NVIDIA Output Model Directory",
|
||||
),
|
||||
"GUARDRAILS_SERVICE_URL": (
|
||||
"http://0.0.0.0:7331",
|
||||
"URL for the NeMo Guardrails Service",
|
||||
|
|
|
@ -5,6 +5,7 @@ apis:
|
|||
- datasetio
|
||||
- eval
|
||||
- inference
|
||||
- post_training
|
||||
- safety
|
||||
- scoring
|
||||
- telemetry
|
||||
|
@ -48,6 +49,7 @@ providers:
|
|||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
config:
|
||||
service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
|
||||
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
|
||||
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/nvidia/trace_store.db}
|
||||
eval:
|
||||
|
@ -58,6 +60,14 @@ providers:
|
|||
type: sqlite
|
||||
namespace: null
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/nvidia}/meta_reference_eval.db
|
||||
post_training:
|
||||
- provider_id: nvidia
|
||||
provider_type: remote::nvidia
|
||||
config:
|
||||
api_key: ${env.NVIDIA_API_KEY:}
|
||||
dataset_namespace: ${env.NVIDIA_DATASET_NAMESPACE:default}
|
||||
project_id: ${env.NVIDIA_PROJECT_ID:test-project}
|
||||
customizer_url: ${env.NVIDIA_CUSTOMIZER_URL:http://nemo.test}
|
||||
datasetio:
|
||||
- provider_id: localfs
|
||||
provider_type: inline::localfs
|
||||
|
|
|
@ -5,6 +5,7 @@ apis:
|
|||
- datasetio
|
||||
- eval
|
||||
- inference
|
||||
- post_training
|
||||
- safety
|
||||
- scoring
|
||||
- telemetry
|
||||
|
@ -43,6 +44,7 @@ providers:
|
|||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
config:
|
||||
service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
|
||||
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
|
||||
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/nvidia/trace_store.db}
|
||||
eval:
|
||||
|
@ -53,6 +55,14 @@ providers:
|
|||
type: sqlite
|
||||
namespace: null
|
||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/nvidia}/meta_reference_eval.db
|
||||
post_training:
|
||||
- provider_id: nvidia
|
||||
provider_type: remote::nvidia
|
||||
config:
|
||||
api_key: ${env.NVIDIA_API_KEY:}
|
||||
dataset_namespace: ${env.NVIDIA_DATASET_NAMESPACE:default}
|
||||
project_id: ${env.NVIDIA_PROJECT_ID:test-project}
|
||||
customizer_url: ${env.NVIDIA_CUSTOMIZER_URL:http://nemo.test}
|
||||
datasetio:
|
||||
- provider_id: localfs
|
||||
provider_type: inline::localfs
|
||||
|
|
|
@ -43,6 +43,7 @@ providers:
|
|||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
config:
|
||||
service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
|
||||
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
|
||||
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/ollama/trace_store.db}
|
||||
eval:
|
||||
|
|
|
@ -41,6 +41,7 @@ providers:
|
|||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
config:
|
||||
service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
|
||||
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
|
||||
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/ollama/trace_store.db}
|
||||
eval:
|
||||
|
|
|
@ -68,6 +68,7 @@ providers:
|
|||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
config:
|
||||
service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
|
||||
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
|
||||
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/open-benchmark/trace_store.db}
|
||||
eval:
|
||||
|
|
|
@ -50,6 +50,7 @@ providers:
|
|||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
config:
|
||||
service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
|
||||
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
|
||||
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/passthrough/trace_store.db}
|
||||
eval:
|
||||
|
|
|
@ -45,6 +45,7 @@ providers:
|
|||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
config:
|
||||
service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
|
||||
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
|
||||
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/passthrough/trace_store.db}
|
||||
eval:
|
||||
|
|
|
@ -88,6 +88,7 @@ providers:
|
|||
- provider_id: meta-reference
|
||||
provider_type: inline::meta-reference
|
||||
config:
|
||||
service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
|
||||
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
|
||||
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/remote-vllm/trace_store.db}
|
||||
tool_runtime:
|
||||
|
|
Some files were not shown because too many files have changed in this diff Show more
Loading…
Add table
Add a link
Reference in a new issue