Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-07-21 03:59:42 +00:00)

Commit 7cdd2a0410: Merge branch 'main' into nvidia-e2e-notebook
264 changed files with 229,042 additions and 8,445 deletions
@@ -320,7 +320,7 @@ jobs:
       - name: "PR - Update comment"
         id: pr_update_comment
         if: github.event_name == 'pull_request_target'
-        uses: thollander/actions-comment-pull-request@65f9e5c9a1f2cd378bd74b2e057c9736982a8e74 # v3.0.1
+        uses: thollander/actions-comment-pull-request@24bffb9b452ba05a4f3f77933840a6a841d1b32b # v3.0.1
         with:
           filePath: test-summary.md
.github/workflows/integration-tests.yml (vendored, 36 lines changed)
@@ -34,22 +34,20 @@ jobs:
         uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2

       - name: Install uv
-        uses: astral-sh/setup-uv@22695119d769bdb6f7032ad67b9bca0ef8c4a174 # v5.4.0
+        uses: astral-sh/setup-uv@0c5e2b8115b80b4c7c5ddf6ffdd634974642d182 # v5.4.1
         with:
           python-version: "3.10"

-      - name: Install Ollama
+      - name: Install and start Ollama
         run: |
+          # the ollama installer also starts the ollama service
           curl -fsSL https://ollama.com/install.sh | sh

       - name: Pull Ollama image
         run: |
+          # TODO: cache the model. OLLAMA_MODELS defaults to ~ollama/.ollama/models.
           ollama pull llama3.2:3b-instruct-fp16

-      - name: Start Ollama in background
-        run: |
-          nohup ollama run llama3.2:3b-instruct-fp16 > ollama.log 2>&1 &
-
       - name: Set Up Environment and Install Dependencies
         run: |
           uv sync --extra dev --extra test
@@ -61,21 +59,6 @@ jobs:
           uv pip install -e .
           llama stack build --template ollama --image-type venv

-      - name: Wait for Ollama to start
-        run: |
-          echo "Waiting for Ollama..."
-          for i in {1..30}; do
-            if curl -s http://localhost:11434 | grep -q "Ollama is running"; then
-              echo "Ollama is running!"
-              exit 0
-            fi
-            sleep 1
-          done
-          echo "Ollama failed to start"
-          ollama ps
-          ollama.log
-          exit 1
-
       - name: Start Llama Stack server in background
         if: matrix.client-type == 'http'
         env:
@@ -99,6 +82,17 @@ jobs:
           cat server.log
           exit 1

+      - name: Verify Ollama status is OK
+        if: matrix.client-type == 'http'
+        run: |
+          echo "Verifying Ollama status..."
+          ollama_status=$(curl -s -L http://127.0.0.1:8321/v1/providers/ollama|jq --raw-output .health.status)
+          echo "Ollama status: $ollama_status"
+          if [ "$ollama_status" != "OK" ]; then
+            echo "Ollama health check failed"
+            exit 1
+          fi
+
       - name: Run Integration Tests
         env:
           INFERENCE_MODEL: "meta-llama/Llama-3.2-3B-Instruct"
.github/workflows/pre-commit.yml (vendored, 9 lines changed)

@@ -31,3 +31,12 @@ jobs:
       - name: Verify if there are any diff files after pre-commit
         run: |
           git diff --exit-code || (echo "There are uncommitted changes, run pre-commit locally and commit again" && exit 1)
+
+      - name: Verify if there are any new files after pre-commit
+        run: |
+          unstaged_files=$(git ls-files --others --exclude-standard)
+          if [ -n "$unstaged_files" ]; then
+            echo "There are uncommitted new files, run pre-commit locally and commit again"
+            echo "$unstaged_files"
+            exit 1
+          fi
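The new step fails the job when the hooks generate files that were never committed. A quick way to reproduce both checks locally before pushing (assuming the standard pre-commit tooling; the install command is illustrative):

```bash
# Install and run the repository's pre-commit hooks locally
pip install pre-commit
pre-commit run --all-files

# The same signals the CI steps above look for
git diff --exit-code                      # any modified-but-uncommitted files?
git ls-files --others --exclude-standard  # any new, untracked files?
```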
.github/workflows/providers-build.yml (vendored, 28 lines changed)

@@ -56,7 +56,7 @@ jobs:
           python-version: '3.10'

       - name: Install uv
-        uses: astral-sh/setup-uv@22695119d769bdb6f7032ad67b9bca0ef8c4a174 # v5.4.0
+        uses: astral-sh/setup-uv@0c5e2b8115b80b4c7c5ddf6ffdd634974642d182 # v5.4.1
         with:
           python-version: "3.10"

@@ -81,3 +81,29 @@ jobs:
         run: |
           source test/bin/activate
           uv pip list
+
+  build-single-provider:
+    runs-on: ubuntu-latest
+    steps:
+      - name: Checkout repository
+        uses: actions/checkout@v4
+
+      - name: Set up Python
+        uses: actions/setup-python@v5
+        with:
+          python-version: '3.10'
+
+      - name: Install uv
+        uses: astral-sh/setup-uv@v5
+        with:
+          python-version: "3.10"
+
+      - name: Install LlamaStack
+        run: |
+          uv venv
+          source .venv/bin/activate
+          uv pip install -e .
+
+      - name: Build a single provider
+        run: |
+          USE_COPY_NOT_MOUNT=true LLAMA_STACK_DIR=. uv run llama stack build --image-type venv --image-name test --providers inference=remote::ollama
.github/workflows/test-external-providers.yml (vendored, new file, 93 lines)

@@ -0,0 +1,93 @@
+name: Test External Providers
+
+on:
+  push:
+    branches: [ main ]
+  pull_request:
+    branches: [ main ]
+
+jobs:
+  test-external-providers:
+    runs-on: ubuntu-latest
+    steps:
+      - name: Checkout repository
+        uses: actions/checkout@v4
+
+      - name: Install uv
+        uses: astral-sh/setup-uv@v5
+        with:
+          python-version: "3.10"
+
+      - name: Install Ollama
+        run: |
+          curl -fsSL https://ollama.com/install.sh | sh
+
+      - name: Pull Ollama image
+        run: |
+          ollama pull llama3.2:3b-instruct-fp16
+
+      - name: Start Ollama in background
+        run: |
+          nohup ollama run llama3.2:3b-instruct-fp16 --keepalive=30m > ollama.log 2>&1 &
+
+      - name: Set Up Environment and Install Dependencies
+        run: |
+          uv sync --extra dev --extra test
+          uv pip install -e .
+
+      - name: Install Ollama custom provider
+        run: |
+          mkdir -p tests/external-provider/llama-stack-provider-ollama/src/
+          cp -a llama_stack/providers/remote/inference/ollama/ tests/external-provider/llama-stack-provider-ollama/src/llama_stack_provider_ollama
+          uv pip install tests/external-provider/llama-stack-provider-ollama
+
+      - name: Create provider configuration
+        run: |
+          mkdir -p /tmp/providers.d/remote/inference
+          cp tests/external-provider/llama-stack-provider-ollama/custom_ollama.yaml /tmp/providers.d/remote/inference/custom_ollama.yaml
+
+      - name: Wait for Ollama to start
+        run: |
+          echo "Waiting for Ollama..."
+          for i in {1..30}; do
+            if curl -s http://localhost:11434 | grep -q "Ollama is running"; then
+              echo "Ollama is running!"
+              exit 0
+            fi
+            sleep 1
+          done
+          echo "Ollama failed to start"
+          ollama ps
+          ollama.log
+          exit 1
+
+      - name: Start Llama Stack server in background
+        env:
+          INFERENCE_MODEL: "meta-llama/Llama-3.2-3B-Instruct"
+        run: |
+          source .venv/bin/activate
+          nohup uv run llama stack run tests/external-provider/llama-stack-provider-ollama/run.yaml --image-type venv > server.log 2>&1 &
+
+      - name: Wait for Llama Stack server to be ready
+        run: |
+          echo "Waiting for Llama Stack server..."
+          for i in {1..30}; do
+            if curl -s http://localhost:8321/v1/health | grep -q "OK"; then
+              echo "Llama Stack server is up!"
+              if grep -q "remote::custom_ollama from /tmp/providers.d/remote/inference/custom_ollama.yaml" server.log; then
+                echo "Llama Stack server is using custom Ollama provider"
+                exit 0
+              else
+                echo "Llama Stack server is not using custom Ollama provider"
+                exit 1
+              fi
+            fi
+            sleep 1
+          done
+          echo "Llama Stack server failed to start"
+          cat server.log
+          exit 1
+
+      - name: run inference tests
+        run: |
+          uv run pytest -v tests/integration/inference/test_text_inference.py --stack-config="http://localhost:8321" --text-model="meta-llama/Llama-3.2-3B-Instruct" --embedding-model=all-MiniLM-L6-v2
.github/workflows/unit-tests.yml (vendored, 2 lines changed)

@@ -38,7 +38,7 @@ jobs:
         with:
           python-version: ${{ matrix.python }}

-      - uses: astral-sh/setup-uv@22695119d769bdb6f7032ad67b9bca0ef8c4a174 # v5.4.0
+      - uses: astral-sh/setup-uv@0c5e2b8115b80b4c7c5ddf6ffdd634974642d182 # v5.4.1
        with:
          python-version: ${{ matrix.python }}
          enable-cache: false
.github/workflows/update-readthedocs.yml (vendored, 2 lines changed)

@@ -41,7 +41,7 @@ jobs:
           python-version: '3.11'

       - name: Install the latest version of uv
-        uses: astral-sh/setup-uv@22695119d769bdb6f7032ad67b9bca0ef8c4a174 # v5.4.0
+        uses: astral-sh/setup-uv@0c5e2b8115b80b4c7c5ddf6ffdd634974642d182 # v5.4.1

       - name: Sync with uv
         run: uv sync --extra docs
CHANGELOG.md (37 lines changed)

@@ -1,5 +1,42 @@
 # Changelog

+# v0.2.1
+Published on: 2025-04-05T23:13:00Z
+
+---
+
+# v0.2.0
+Published on: 2025-04-05T19:04:29Z
+
+## Llama 4 Support
+
+Checkout more at https://www.llama.com
+
+---
+
+# v0.1.9
+Published on: 2025-03-29T00:52:23Z
+
+### Build and Test Agents
+* Agents: Entire document context with attachments
+* RAG: Documentation with sqlite-vec faiss comparison
+* Getting started: Fixes to getting started notebook.
+
+### Agent Evals and Model Customization
+* (**New**) Post-training: Add nemo customizer
+
+### Better Engineering
+* Moved sqlite-vec to non-blocking calls
+* Don't return a payload on file delete
+
+---
+
 # v0.1.8
 Published on: 2025-03-24T01:28:50Z
@@ -1,8 +1,10 @@
 include pyproject.toml
 include llama_stack/templates/dependencies.json
 include llama_stack/models/llama/llama3/tokenizer.model
+include llama_stack/models/llama/llama4/tokenizer.model
 include llama_stack/distribution/*.sh
 include llama_stack/cli/scripts/*.sh
 include llama_stack/templates/*/*.yaml
 include llama_stack/providers/tests/test_cases/inference/*.json
 include llama_stack/models/llama/*/*.md
+include llama_stack/tests/integration/*.jpg
README.md (66 lines changed)

@@ -3,12 +3,76 @@
 [](https://pypi.org/project/llama_stack/)
 [](https://pypi.org/project/llama-stack/)
 [](https://github.com/meta-llama/llama-stack/blob/main/LICENSE)
 [](https://discord.gg/llama-stack)
 [](https://github.com/meta-llama/llama-stack/actions/workflows/unit-tests.yml?query=branch%3Amain)
 [](https://github.com/meta-llama/llama-stack/actions/workflows/integration-tests.yml?query=branch%3Amain)

 [**Quick Start**](https://llama-stack.readthedocs.io/en/latest/getting_started/index.html) | [**Documentation**](https://llama-stack.readthedocs.io/en/latest/index.html) | [**Colab Notebook**](./docs/getting_started.ipynb)

+### ✨🎉 Llama 4 Support 🎉✨
+We released [Version 0.2.0](https://github.com/meta-llama/llama-stack/releases/tag/v0.2.0) with support for the Llama 4 herd of models released by Meta.
+
+<details>
+
+<summary>👋 Click here to see how to run Llama 4 models on Llama Stack </summary>
+
+\
+*Note you need 8xH100 GPU-host to run these models*
+
+```bash
+pip install -U llama_stack
+
+MODEL="Llama-4-Scout-17B-16E-Instruct"
+# get meta url from llama.com
+llama model download --source meta --model-id $MODEL --meta-url <META_URL>
+
+# start a llama stack server
+INFERENCE_MODEL=meta-llama/$MODEL llama stack build --run --template meta-reference-gpu
+
+# install client to interact with the server
+pip install llama-stack-client
+```
+### CLI
+```bash
+# Run a chat completion
+llama-stack-client --endpoint http://localhost:8321 \
+  inference chat-completion \
+  --model-id meta-llama/$MODEL \
+  --message "write a haiku for meta's llama 4 models"
+
+ChatCompletionResponse(
+    completion_message=CompletionMessage(content="Whispers in code born\nLlama's gentle, wise heartbeat\nFuture's soft unfold", role='assistant', stop_reason='end_of_turn', tool_calls=[]),
+    logprobs=None,
+    metrics=[Metric(metric='prompt_tokens', value=21.0, unit=None), Metric(metric='completion_tokens', value=28.0, unit=None), Metric(metric='total_tokens', value=49.0, unit=None)]
+)
+```
+### Python SDK
+```python
+from llama_stack_client import LlamaStackClient
+
+client = LlamaStackClient(base_url=f"http://localhost:8321")
+
+model_id = "meta-llama/Llama-4-Scout-17B-16E-Instruct"
+prompt = "Write a haiku about coding"
+
+print(f"User> {prompt}")
+response = client.inference.chat_completion(
+    model_id=model_id,
+    messages=[
+        {"role": "system", "content": "You are a helpful assistant."},
+        {"role": "user", "content": prompt},
+    ],
+)
+print(f"Assistant> {response.completion_message.content}")
+```
+As more providers start supporting Llama 4, you can use them in Llama Stack as well. We are adding to the list. Stay tuned!
+
+</details>
+
+### Overview
+
 Llama Stack standardizes the core building blocks that simplify AI application development. It codifies best practices across the Llama ecosystem. More specifically, it provides

 - **Unified API layer** for Inference, RAG, Agents, Tools, Safety, Evals, and Telemetry.
docs/_static/css/my_theme.css (vendored, 11 lines changed)

@@ -16,3 +16,14 @@
 .hide-title h1 {
   display: none;
 }
+
+h2, h3, h4 {
+  font-weight: normal;
+}
+html[data-theme="dark"] .rst-content div[class^="highlight"] {
+  background-color: #0b0b0b;
+}
+pre {
+  white-space: pre-wrap !important;
+  word-break: break-all;
+}
docs/_static/js/detect_theme.js (vendored, new file, 9 lines)

@@ -0,0 +1,9 @@
+document.addEventListener("DOMContentLoaded", function () {
+  const prefersDark = window.matchMedia("(prefers-color-scheme: dark)").matches;
+  const htmlElement = document.documentElement;
+  if (prefersDark) {
+    htmlElement.setAttribute("data-theme", "dark");
+  } else {
+    htmlElement.setAttribute("data-theme", "light");
+  }
+});
docs/_static/llama-stack-spec.html (vendored, 1496 lines changed): file diff suppressed because it is too large
docs/_static/llama-stack-spec.yaml (vendored, 1092 lines changed): file diff suppressed because it is too large
File diff suppressed because it is too large
docs/getting_started_llama4.ipynb (new file, 876 lines): file diff suppressed because one or more lines are too long
@@ -51,6 +51,7 @@ def main(output_dir: str):
         "Converting the spec to YAML (openapi.yaml) and HTML (openapi.html) at " + now
     )
     print("")
+
     spec = Specification(
         LlamaStack,
         Options(
@@ -519,7 +519,7 @@ class Generator:
         )

     def _build_extra_tag_groups(
-        self, extra_types: Dict[str, List[type]]
+        self, extra_types: Dict[str, Dict[str, type]]
     ) -> Dict[str, List[Tag]]:
         """
         Creates a dictionary of tag group captions as keys, and tag lists as values.
@@ -532,9 +532,8 @@ class Generator:
         for category_name, category_items in extra_types.items():
             tag_list: List[Tag] = []

-            for extra_type in category_items:
-                name = python_type_to_name(extra_type)
-                schema = self.schema_builder.classdef_to_named_schema(name, extra_type)
+            for name, extra_type in category_items.items():
+                schema = self.schema_builder.classdef_to_schema(extra_type)
                 tag_list.append(self._build_type_tag(name, schema))

             if tag_list:
@@ -863,7 +862,7 @@ class Generator:
         for caption, extra_tag_group in extra_tag_groups.items():
             tag_groups.append(
                 TagGroup(
-                    name=self.options.map(caption),
+                    name=caption,
                     tags=sorted(tag.name for tag in extra_tag_group),
                 )
             )
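The signature change above means callers now pass explicitly named types per category instead of bare type lists. A minimal sketch of the before/after shape (the category and type names are stand-ins, not the generator's real registry):

```python
from dataclasses import dataclass


@dataclass
class ExampleResource:
    # Stand-in for a Llama Stack API type registered with the generator.
    id: str


# Old shape: caption -> list of types; the tag name had to be derived from the
# class itself (e.g. via python_type_to_name).
old_extra_types = {"Types": [ExampleResource]}

# New shape: caption -> {explicit tag name: type}; the generator uses the
# supplied name directly and builds the schema with classdef_to_schema.
new_extra_types = {"Types": {"ExampleResource": ExampleResource}}

print(list(new_extra_types["Types"].items()))
```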
@@ -2,6 +2,14 @@

 Here's a collection of comprehensive guides, examples, and resources for building AI applications with Llama Stack. For the complete documentation, visit our [ReadTheDocs page](https://llama-stack.readthedocs.io/en/latest/index.html).

+## Render locally
+```bash
+pip install -r requirements.txt
+cd docs
+python -m sphinx_autobuild source _build
+```
+You can open up the docs in your browser at http://localhost:8000
+
 ## Content

 Try out Llama Stack's capabilities through our detailed Jupyter notebooks:
@@ -3,10 +3,12 @@ myst-parser
 linkify
 -e git+https://github.com/pytorch/pytorch_sphinx_theme.git#egg=pytorch_sphinx_theme
 sphinx-rtd-theme>=1.0.0
-sphinx-pdj-theme
+sphinx_autobuild
 sphinx-copybutton
-sphinx-tabs
 sphinx-design
+sphinx-pdj-theme
+sphinx_rtd_dark_mode
+sphinx-tabs
 sphinxcontrib-openapi
 sphinxcontrib-redoc
 sphinxcontrib-mermaid
@@ -1,6 +1,9 @@
-# Llama Stack Agent Framework
+# Agents

-The Llama Stack agent framework is built on a modular architecture that allows for flexible and powerful AI applications. This document explains the key components and how they work together.
+An Agent in Llama Stack is a powerful abstraction that allows you to build complex AI applications.
+
+The Llama Stack agent framework is built on a modular architecture that allows for flexible and powerful AI
+applications. This document explains the key components and how they work together.

 ## Core Concepts

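To make the abstraction concrete, here is a minimal client-side sketch of creating and running an agent with the `llama_stack_client` package used later in these docs; the endpoint and model id are illustrative assumptions, not values taken from this diff:

```python
from llama_stack_client import LlamaStackClient, Agent, AgentEventLogger

# Assumed local server endpoint and model id, for illustration only.
client = LlamaStackClient(base_url="http://localhost:8321")

agent = Agent(
    client,
    model="meta-llama/Llama-3.2-3B-Instruct",
    instructions="You are a helpful assistant.",
)
session_id = agent.create_session("demo-session")

# Each turn runs the agent's execution loop: inference, optional tool calls,
# and safety checks, streamed back as events.
response = agent.create_turn(
    messages=[{"role": "user", "content": "Summarize what Llama Stack provides."}],
    session_id=session_id,
    stream=True,
)
for log in AgentEventLogger().log(response):
    log.print()
```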
@@ -1,6 +1,10 @@
 ## Agent Execution Loop

-Agents are the heart of complex AI applications. They combine inference, memory, safety, and tool usage into coherent workflows. At its core, an agent follows a sophisticated execution loop that enables multi-step reasoning, tool usage, and safety checks.
+Agents are the heart of Llama Stack applications. They combine inference, memory, safety, and tool usage into coherent
+workflows. At its core, an agent follows a sophisticated execution loop that enables multi-step reasoning, tool usage,
+and safety checks.
+
+### Steps in the Agent Workflow

 Each agent turn follows these key steps:

@@ -64,7 +68,10 @@ sequenceDiagram
     S->>U: 5. Final Response
 ```

-Each step in this process can be monitored and controlled through configurations. Here's an example that demonstrates monitoring the agent's execution:
+Each step in this process can be monitored and controlled through configurations.
+
+### Agent Execution Loop Example
+Here's an example that demonstrates monitoring the agent's execution:

 ```python
 from llama_stack_client import LlamaStackClient, Agent, AgentEventLogger
@@ -8,9 +8,9 @@ The best way to get started is to look at this notebook which walks through the

 Here are some key topics that will help you build effective agents:

+- **[RAG (Retrieval-Augmented Generation)](rag)**: Learn how to enhance your agents with external knowledge through retrieval mechanisms.
 - **[Agent](agent)**: Understand the components and design patterns of the Llama Stack agent framework.
 - **[Agent Execution Loop](agent_execution_loop)**: Understand how agents process information, make decisions, and execute actions in a continuous loop.
-- **[RAG (Retrieval-Augmented Generation)](rag)**: Learn how to enhance your agents with external knowledge through retrieval mechanisms.
 - **[Tools](tools)**: Extend your agents' capabilities by integrating with external tools and APIs.
 - **[Evals](evals)**: Evaluate your agents' effectiveness and identify areas for improvement.
 - **[Telemetry](telemetry)**: Monitor and analyze your agents' performance and behavior.
@@ -20,12 +20,11 @@ Here are some key topics that will help you build effective agents:
 :hidden:
 :maxdepth: 1

+rag
 agent
 agent_execution_loop
-rag
 tools
-telemetry
 evals
-advanced_agent_patterns
+telemetry
 safety
 ```
@@ -3,9 +3,9 @@
 RAG enables your applications to reference and recall information from previous interactions or external documents.

 Llama Stack organizes the APIs that enable RAG into three layers:
-- the lowermost APIs deal with raw storage and retrieval. These include Vector IO, KeyValue IO (coming soon) and Relational IO (also coming soon.)
-- next is the "Rag Tool", a first-class tool as part of the Tools API that allows you to ingest documents (from URLs, files, etc) with various chunking strategies and query them smartly.
-- finally, it all comes together with the top-level "Agents" API that allows you to create agents that can use the tools to answer questions, perform tasks, and more.
+1. The lowermost APIs deal with raw storage and retrieval. These include Vector IO, KeyValue IO (coming soon) and Relational IO (also coming soon.).
+2. The next is the "Rag Tool", a first-class tool as part of the [Tools API](tools.md) that allows you to ingest documents (from URLs, files, etc) with various chunking strategies and query them smartly.
+3. Finally, it all comes together with the top-level ["Agents" API](agent.md) that allows you to create agents that can use the tools to answer questions, perform tasks, and more.

 <img src="rag.png" alt="RAG System" width="50%">

@@ -17,14 +17,19 @@ We may add more storage types like Graph IO in the future.

 ### Setting up Vector DBs

+For this guide, we will use [Ollama](https://ollama.com/) as the inference provider.
+Ollama is an LLM runtime that allows you to run Llama models locally.
+
 Here's how to set up a vector database for RAG:

 ```python
 # Create http client
+import os
 from llama_stack_client import LlamaStackClient

 client = LlamaStackClient(base_url=f"http://localhost:{os.environ['LLAMA_STACK_PORT']}")

+
 # Register a vector db
 vector_db_id = "my_documents"
 response = client.vector_dbs.register(
@@ -33,17 +38,27 @@ response = client.vector_dbs.register(
     embedding_dimension=384,
     provider_id="faiss",
 )
+```
+
+### Ingesting Documents
+You can ingest documents into the vector database using two methods: directly inserting pre-chunked
+documents or using the RAG Tool.
+```python
 # You can insert a pre-chunked document directly into the vector db
 chunks = [
     {
-        "document_id": "doc1",
         "content": "Your document text here",
         "mime_type": "text/plain",
+        "metadata": {
+            "document_id": "doc1",
+        },
     },
 ]
 client.vector_io.insert(vector_db_id=vector_db_id, chunks=chunks)
+```
+
+### Retrieval
+You can query the vector database to retrieve documents based on their embeddings.
+```python
 # You can then query for these chunks
 chunks_response = client.vector_io.query(
     vector_db_id=vector_db_id, query="What do you know about..."
@@ -52,7 +67,8 @@ chunks_response = client.vector_io.query(

 ### Using the RAG Tool

-A better way to ingest documents is to use the RAG Tool. This tool allows you to ingest documents from URLs, files, etc. and automatically chunks them into smaller pieces.
+A better way to ingest documents is to use the RAG Tool. This tool allows you to ingest documents from URLs, files, etc.
+and automatically chunks them into smaller pieces.

 ```python
 from llama_stack_client import RAGDocument
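The diff cuts off after the `RAGDocument` import; for orientation, here is a minimal sketch of how the RAG Tool ingestion and query flow described above typically looks from the client side (the `tool_runtime.rag_tool` call signatures, the endpoint, and the document URL are assumptions, not taken from this diff):

```python
from llama_stack_client import LlamaStackClient, RAGDocument

client = LlamaStackClient(base_url="http://localhost:8321")  # assumed endpoint
vector_db_id = "my_documents"

# Wrap a source document; the RAG Tool handles chunking on insert.
documents = [
    RAGDocument(
        document_id="doc1",
        content="https://example.com/handbook.txt",  # illustrative URL
        mime_type="text/plain",
        metadata={},
    )
]

# Assumed API shape for the RAG Tool: ingest with a chunk size, then query.
client.tool_runtime.rag_tool.insert(
    documents=documents,
    vector_db_id=vector_db_id,
    chunk_size_in_tokens=512,
)
results = client.tool_runtime.rag_tool.query(
    vector_db_ids=[vector_db_id],
    content="What does the handbook say about onboarding?",
)
print(results)
```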
@@ -12,11 +12,12 @@
 # -- Project information -----------------------------------------------------
 # https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information

-from docutils import nodes
-from pathlib import Path
-import requests
 import json
 from datetime import datetime
+from pathlib import Path
+
+import requests
+from docutils import nodes

 # Read version from pyproject.toml
 with Path(__file__).parent.parent.parent.joinpath("pyproject.toml").open("rb") as f:
@@ -25,7 +26,9 @@ with Path(__file__).parent.parent.parent.joinpath("pyproject.toml").open("rb") a
 print(f"{version_tag=}")

 # generate the full link including text and url here
-llama_stack_version_url = f"https://github.com/meta-llama/llama-stack/releases/tag/v{version_tag}"
+llama_stack_version_url = (
+    f"https://github.com/meta-llama/llama-stack/releases/tag/v{version_tag}"
+)
 llama_stack_version_link = f"<a href='{llama_stack_version_url}'>release notes</a>"

 project = "llama-stack"
@@ -37,11 +40,11 @@ author = "Meta"

 extensions = [
     "myst_parser",
+    "sphinx_copybutton",
+    "sphinx_design",
     "sphinx_rtd_theme",
     "sphinx_rtd_dark_mode",
-    "sphinx_copybutton",
     "sphinx_tabs.tabs",
-    "sphinx_design",
     "sphinxcontrib.redoc",
     "sphinxcontrib.mermaid",
     "sphinxcontrib.video",
@@ -85,7 +88,7 @@ myst_substitutions = {
     "llama_stack_version_link": llama_stack_version_link,
 }

-suppress_warnings = ['myst.header']
+suppress_warnings = ["myst.header"]

 # Copy button settings
 copybutton_prompt_text = "$ " # for bash prompts
@@ -105,17 +108,21 @@ source_suffix = {
 # html_theme = "alabaster"
 html_theme_options = {
     "canonical_url": "https://github.com/meta-llama/llama-stack",
-    'collapse_navigation': False,
+    "collapse_navigation": False,
     # "style_nav_header_background": "#c3c9d4",
 }

+default_dark_mode = False
+
 html_static_path = ["../_static"]
 # html_logo = "../_static/llama-stack-logo.png"
 # html_style = "../_static/css/my_theme.css"


 def setup(app):
     app.add_css_file("css/my_theme.css")
+    app.add_js_file("js/detect_theme.js")

 def dockerhub_role(name, rawtext, text, lineno, inliner, options={}, content=[]):
     url = f"https://hub.docker.com/r/llamastack/{text}"
     node = nodes.reference(rawtext, text, refuri=url, **options)
@@ -231,7 +231,7 @@ options:
   -h, --help            show this help message and exit
   --port PORT           Port to run the server on. It can also be passed via the env var LLAMA_STACK_PORT. (default: 8321)
   --image-name IMAGE_NAME
-                        Name of the image to run. Defaults to the current conda environment (default: None)
+                        Name of the image to run. Defaults to the current environment (default: None)
   --disable-ipv6        Disable IPv6 support (default: False)
   --env KEY=VALUE       Environment variables to pass to the server in KEY=VALUE format. Can be specified multiple times. (default: [])
   --tls-keyfile TLS_KEYFILE
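Taken together, the options above are typically combined on one command line; a hypothetical invocation (the config path, image name, and variable values are illustrative, reusing values that appear elsewhere in this diff):

```bash
# Run a stack from a run configuration, overriding the port and passing env vars through
llama stack run ./run.yaml \
  --port 8321 \
  --image-name my-stack-env \
  --env INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct
```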
@@ -2,7 +2,7 @@

 The Llama Stack runtime configuration is specified as a YAML file. Here is a simplified version of an example configuration file for the Ollama distribution:

-```{dropdown} Sample Configuration File
+```{dropdown} 👋 Click here for a Sample Configuration File

 ```yaml
 version: 2
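As a rough orientation for what such a run configuration contains, here is an abbreviated, hypothetical sketch (keys beyond `version` are assumptions based on the Ollama distribution rather than the file shown in this diff):

```yaml
# Abbreviated, illustrative run configuration for an Ollama-backed stack
version: 2
image_name: ollama
apis:
- inference
providers:
  inference:
  - provider_id: ollama
    provider_type: remote::ollama
    config:
      url: http://localhost:11434
```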
@@ -17,7 +17,7 @@ client = LlamaStackAsLibraryClient(
     # provider_data is optional, but if you need to pass in any provider specific data, you can do so here.
     provider_data={"tavily_search_api_key": os.environ["TAVILY_SEARCH_API_KEY"]},
 )
-await client.initialize()
+client.initialize()
 ```

 This will parse your config and set up any inline implementations and remote clients needed for your implementation.
@@ -7,13 +7,18 @@ In this guide, we'll use a local [Kind](https://kind.sigs.k8s.io/) cluster and a

 First, create a local Kubernetes cluster via Kind:

-```bash
+```
 kind create cluster --image kindest/node:v1.32.0 --name llama-stack-test
 ```

-First, create a Kubernetes PVC and Secret for downloading and storing Hugging Face model:
+First set your hugging face token as an environment variable.
+```
+export HF_TOKEN=$(echo -n "your-hf-token" | base64)
+```

-```bash
+Now create a Kubernetes PVC and Secret for downloading and storing Hugging Face model:
+
+```
 cat <<EOF |kubectl apply -f -
 apiVersion: v1
 kind: PersistentVolumeClaim
@@ -33,13 +38,14 @@ metadata:
   name: hf-token-secret
 type: Opaque
 data:
-  token: $(HF_TOKEN)
+  token: $HF_TOKEN
+EOF
 ```


 Next, start the vLLM server as a Kubernetes Deployment and Service:

-```bash
+```
 cat <<EOF |kubectl apply -f -
 apiVersion: apps/v1
 kind: Deployment
@@ -95,7 +101,7 @@ EOF

 We can verify that the vLLM server has started successfully via the logs (this might take a couple of minutes to download the model):

-```bash
+```
 $ kubectl logs -l app.kubernetes.io/name=vllm
 ...
 INFO: Started server process [1]
@@ -119,8 +125,8 @@ providers:

 Once we have defined the run configuration for Llama Stack, we can build an image with that configuration and the server source code:

-```bash
-cat >/tmp/test-vllm-llama-stack/Containerfile.llama-stack-run-k8s <<EOF
+```
+tmp_dir=$(mktemp -d) && cat >$tmp_dir/Containerfile.llama-stack-run-k8s <<EOF
 FROM distribution-myenv:dev

 RUN apt-get update && apt-get install -y git
@@ -128,14 +134,14 @@ RUN git clone https://github.com/meta-llama/llama-stack.git /app/llama-stack-sou

 ADD ./vllm-llama-stack-run-k8s.yaml /app/config.yaml
 EOF
-podman build -f /tmp/test-vllm-llama-stack/Containerfile.llama-stack-run-k8s -t llama-stack-run-k8s /tmp/test-vllm-llama-stack
+podman build -f $tmp_dir/Containerfile.llama-stack-run-k8s -t llama-stack-run-k8s $tmp_dir
 ```

 ### Deploying Llama Stack Server in Kubernetes

 We can then start the Llama Stack server by deploying a Kubernetes Pod and Service:

-```bash
+```
 cat <<EOF |kubectl apply -f -
 apiVersion: v1
 kind: PersistentVolumeClaim
@@ -195,7 +201,7 @@ EOF
 ### Verifying the Deployment
 We can check that the LlamaStack server has started:

-```bash
+```
 $ kubectl logs -l app.kubernetes.io/name=llama-stack
 ...
 INFO: Started server process [1]
@@ -207,7 +213,7 @@ INFO: Uvicorn running on http://['::', '0.0.0.0']:5000 (Press CTRL+C to quit

 Finally, we forward the Kubernetes service to a local port and test some inference requests against it via the Llama Stack Client:

-```bash
+```
 kubectl port-forward service/llama-stack-service 5000:5000
 llama-stack-client --endpoint http://localhost:5000 inference chat-completion --message "hello, what model are you?"
 ```
@@ -46,6 +46,8 @@ The following models are available by default:
 - `accounts/fireworks/models/llama-v3p3-70b-instruct (aliases: meta-llama/Llama-3.3-70B-Instruct)`
 - `accounts/fireworks/models/llama-guard-3-8b (aliases: meta-llama/Llama-Guard-3-8B)`
 - `accounts/fireworks/models/llama-guard-3-11b-vision (aliases: meta-llama/Llama-Guard-3-11B-Vision)`
+- `accounts/fireworks/models/llama4-scout-instruct-basic (aliases: meta-llama/Llama-4-Scout-17B-16E-Instruct)`
+- `accounts/fireworks/models/llama4-maverick-instruct-basic (aliases: meta-llama/Llama-4-Maverick-17B-128E-Instruct)`
 - `nomic-ai/nomic-embed-text-v1.5 `

@@ -42,6 +42,10 @@ The following models are available by default:
 - `groq/llama3-70b-8192 (aliases: meta-llama/Llama-3-70B-Instruct)`
 - `groq/llama-3.3-70b-versatile (aliases: meta-llama/Llama-3.3-70B-Instruct)`
 - `groq/llama-3.2-3b-preview (aliases: meta-llama/Llama-3.2-3B-Instruct)`
+- `groq/llama-4-scout-17b-16e-instruct (aliases: meta-llama/Llama-4-Scout-17B-16E-Instruct)`
+- `groq/meta-llama/llama-4-scout-17b-16e-instruct (aliases: meta-llama/Llama-4-Scout-17B-16E-Instruct)`
+- `groq/llama-4-maverick-17b-128e-instruct (aliases: meta-llama/Llama-4-Maverick-17B-128E-Instruct)`
+- `groq/meta-llama/llama-4-maverick-17b-128e-instruct (aliases: meta-llama/Llama-4-Maverick-17B-128E-Instruct)`

 ### Prerequisite: API Keys
@@ -1,3 +1,4 @@
+<!-- This file was auto-generated by distro_codegen.py, please edit source -->
 # NVIDIA Distribution

 The `llamastack/distribution-nvidia` distribution consists of the following provider configurations.
@@ -5,24 +6,49 @@ The `llamastack/distribution-nvidia` distribution consists of the following prov
 | API | Provider(s) |
 |-----|-------------|
 | agents | `inline::meta-reference` |
+| datasetio | `inline::localfs` |
+| eval | `inline::meta-reference` |
 | inference | `remote::nvidia` |
-| memory | `inline::faiss`, `remote::chromadb`, `remote::pgvector` |
-| safety | `inline::llama-guard` |
+| post_training | `remote::nvidia` |
+| safety | `remote::nvidia` |
+| scoring | `inline::basic` |
 | telemetry | `inline::meta-reference` |
+| tool_runtime | `inline::rag-runtime` |
+| vector_io | `inline::faiss` |

 ### Environment Variables

 The following environment variables can be configured:

-- `LLAMASTACK_PORT`: Port for the Llama Stack distribution server (default: `8321`)
 - `NVIDIA_API_KEY`: NVIDIA API Key (default: ``)
+- `NVIDIA_USER_ID`: NVIDIA User ID (default: `llama-stack-user`)
+- `NVIDIA_DATASET_NAMESPACE`: NVIDIA Dataset Namespace (default: `default`)
+- `NVIDIA_ACCESS_POLICIES`: NVIDIA Access Policies (default: `{}`)
+- `NVIDIA_PROJECT_ID`: NVIDIA Project ID (default: `test-project`)
+- `NVIDIA_CUSTOMIZER_URL`: NVIDIA Customizer URL (default: `https://customizer.api.nvidia.com`)
+- `NVIDIA_OUTPUT_MODEL_DIR`: NVIDIA Output Model Directory (default: `test-example-model@v1`)
+- `GUARDRAILS_SERVICE_URL`: URL for the NeMo Guardrails Service (default: `http://0.0.0.0:7331`)
+- `INFERENCE_MODEL`: Inference model (default: `Llama3.1-8B-Instruct`)
+- `SAFETY_MODEL`: Name of the model to use for safety (default: `meta/llama-3.1-8b-instruct`)

 ### Models

 The following models are available by default:

-- `${env.INFERENCE_MODEL} (None)`
+- `meta/llama3-8b-instruct (aliases: meta-llama/Llama-3-8B-Instruct)`
+- `meta/llama3-70b-instruct (aliases: meta-llama/Llama-3-70B-Instruct)`
+- `meta/llama-3.1-8b-instruct (aliases: meta-llama/Llama-3.1-8B-Instruct)`
+- `meta/llama-3.1-70b-instruct (aliases: meta-llama/Llama-3.1-70B-Instruct)`
+- `meta/llama-3.1-405b-instruct (aliases: meta-llama/Llama-3.1-405B-Instruct-FP8)`
+- `meta/llama-3.2-1b-instruct (aliases: meta-llama/Llama-3.2-1B-Instruct)`
+- `meta/llama-3.2-3b-instruct (aliases: meta-llama/Llama-3.2-3B-Instruct)`
+- `meta/llama-3.2-11b-vision-instruct (aliases: meta-llama/Llama-3.2-11B-Vision-Instruct)`
+- `meta/llama-3.2-90b-vision-instruct (aliases: meta-llama/Llama-3.2-90B-Vision-Instruct)`
+- `nvidia/llama-3.2-nv-embedqa-1b-v2 `
+- `nvidia/nv-embedqa-e5-v5 `
+- `nvidia/nv-embedqa-mistral-7b-v2 `
+- `snowflake/arctic-embed-l `

 ### Prerequisite: API Keys
@@ -58,4 +84,5 @@ llama stack build --template nvidia --image-type conda
 llama stack run ./run.yaml \
   --port 8321 \
   --env NVIDIA_API_KEY=$NVIDIA_API_KEY
+  --env INFERENCE_MODEL=$INFERENCE_MODEL
 ```
@@ -25,7 +25,7 @@ The `llamastack/distribution-remote-vllm` distribution consists of the following
 | vector_io | `inline::faiss`, `remote::chromadb`, `remote::pgvector` |


-You can use this distribution if you have GPUs and want to run an independent vLLM server container for running inference.
+You can use this distribution if you want to run an independent vLLM server for inference.

 ### Environment Variables

@@ -41,6 +41,83 @@ The following environment variables can be configured:

 ## Setting up vLLM server

+In the following sections, we'll use either AMD and NVIDIA GPUs to serve as hardware accelerators for the vLLM
+server, which acts as both the LLM inference provider and the safety provider. Note that vLLM also
+[supports many other hardware accelerators](https://docs.vllm.ai/en/latest/getting_started/installation.html) and
+that we only use GPUs here for demonstration purposes.
+
+### Setting up vLLM server on AMD GPU
+
+AMD provides two main vLLM container options:
+- rocm/vllm: Production-ready container
+- rocm/vllm-dev: Development container with the latest vLLM features
+
+Please check the [Blog about ROCm vLLM Usage](https://rocm.blogs.amd.com/software-tools-optimization/vllm-container/README.html) to get more details.
+
+Here is a sample script to start a ROCm vLLM server locally via Docker:
+
+```bash
+export INFERENCE_PORT=8000
+export INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct
+export CUDA_VISIBLE_DEVICES=0
+export VLLM_DIMG="rocm/vllm-dev:main"
+
+docker run \
+    --pull always \
+    --ipc=host \
+    --privileged \
+    --shm-size 16g \
+    --device=/dev/kfd \
+    --device=/dev/dri \
+    --group-add video \
+    --cap-add=SYS_PTRACE \
+    --cap-add=CAP_SYS_ADMIN \
+    --security-opt seccomp=unconfined \
+    --security-opt apparmor=unconfined \
+    --env "HUGGING_FACE_HUB_TOKEN=$HF_TOKEN" \
+    --env "HIP_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES" \
+    -p $INFERENCE_PORT:$INFERENCE_PORT \
+    -v ~/.cache/huggingface:/root/.cache/huggingface \
+    $VLLM_DIMG \
+    python -m vllm.entrypoints.openai.api_server \
+    --model $INFERENCE_MODEL \
+    --port $INFERENCE_PORT
+```
+
+Note that you'll also need to set `--enable-auto-tool-choice` and `--tool-call-parser` to [enable tool calling in vLLM](https://docs.vllm.ai/en/latest/features/tool_calling.html).
+
+If you are using Llama Stack Safety / Shield APIs, then you will need to also run another instance of a vLLM with a corresponding safety model like `meta-llama/Llama-Guard-3-1B` using a script like:
+
+```bash
+export SAFETY_PORT=8081
+export SAFETY_MODEL=meta-llama/Llama-Guard-3-1B
+export CUDA_VISIBLE_DEVICES=1
+export VLLM_DIMG="rocm/vllm-dev:main"
+
+docker run \
+    --pull always \
+    --ipc=host \
+    --privileged \
+    --shm-size 16g \
+    --device=/dev/kfd \
+    --device=/dev/dri \
+    --group-add video \
+    --cap-add=SYS_PTRACE \
+    --cap-add=CAP_SYS_ADMIN \
+    --security-opt seccomp=unconfined \
+    --security-opt apparmor=unconfined \
+    --env "HUGGING_FACE_HUB_TOKEN=$HF_TOKEN" \
+    --env "HIP_VISIBLE_DEVICES=$CUDA_VISIBLE_DEVICES" \
+    -p $SAFETY_PORT:$SAFETY_PORT \
+    -v ~/.cache/huggingface:/root/.cache/huggingface \
+    $VLLM_DIMG \
+    python -m vllm.entrypoints.openai.api_server \
+    --model $SAFETY_MODEL \
+    --port $SAFETY_PORT
+```
+
+### Setting up vLLM server on NVIDIA GPU
+
 Please check the [vLLM Documentation](https://docs.vllm.ai/en/v0.5.5/serving/deploying_with_docker.html) to get a vLLM endpoint. Here is a sample script to start a vLLM server locally via Docker:

 ```bash
@@ -43,6 +43,7 @@ The following models are available by default:
 - `Llama-3.2-11B-Vision-Instruct (aliases: meta-llama/Llama-3.2-11B-Vision-Instruct)`
 - `Llama-3.2-90B-Vision-Instruct (aliases: meta-llama/Llama-3.2-90B-Vision-Instruct)`
 - `Meta-Llama-Guard-3-8B (aliases: meta-llama/Llama-Guard-3-8B)`
+- `Llama-4-Scout-17B-16E-Instruct (aliases: meta-llama/Llama-4-Scout-17B-16E-Instruct)`

 ### Prerequisite: API Keys
@@ -48,6 +48,8 @@ The following models are available by default:
 - `meta-llama/Llama-Guard-3-11B-Vision-Turbo (aliases: meta-llama/Llama-Guard-3-11B-Vision)`
 - `togethercomputer/m2-bert-80M-8k-retrieval `
 - `togethercomputer/m2-bert-80M-32k-retrieval `
+- `meta-llama/Llama-4-Scout-17B-16E-Instruct (aliases: meta-llama/Llama-4-Scout-17B-16E-Instruct, together/meta-llama/Llama-4-Scout-17B-16E-Instruct)`
+- `meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8 (aliases: meta-llama/Llama-4-Maverick-17B-128E-Instruct, together/meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8)`

 ### Prerequisite: API Keys
@@ -2,22 +2,22 @@

You can run a Llama Stack server in one of the following ways:

## As a Library:

This is the simplest way to get started. Using Llama Stack as a library means you do not need to start a server. This is especially useful when you are not running inference locally and are relying on an external inference service (e.g. Fireworks, Together, Groq, etc.). See [Using Llama Stack as a Library](importing_as_library) and the sketch below.
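As a rough sketch of library mode (assuming the `ollama` template has been built and an Ollama server is reachable; adjust the template name to your setup):

```python
from llama_stack import LlamaStackAsLibraryClient

# Run the stack in-process from a built template instead of connecting to a server.
client = LlamaStackAsLibraryClient("ollama")
client.initialize()

# The same APIs as the HTTP client are then available, e.g. listing models.
print([m.identifier for m in client.models.list()])
```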
## Container:

Another simple way to start interacting with Llama Stack is to just spin up a container (via Docker or Podman) which is pre-built with all the providers you need. We provide a number of pre-built images so you can start a Llama Stack server instantly. You can also build your own custom container. Which distribution to choose depends on the hardware you have. See [Selection of a Distribution](selection) for more details.

## Conda:

If you have a custom or an advanced setup or you are developing on Llama Stack you can also build a custom Llama Stack server. Using `llama stack build` and `llama stack run` you can build/run a custom Llama Stack server containing the exact combination of providers you wish (see the example below). We have also provided various templates to make getting started easier. See [Building a Custom Distribution](building_distro) for more details.
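For example, building and then running a custom server from the Ollama template might look like this (a sketch; substitute the template, image type, and image name you actually use):

```bash
# Build a distribution from a template
llama stack build --template ollama --image-type conda --image-name llama3-3b-conda

# Run the server for that distribution
llama stack run ollama
```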
## Kubernetes:

If you have built a container image and want to deploy it in a Kubernetes cluster instead of starting the Llama Stack server locally, see the [Kubernetes Deployment Guide](kubernetes_deployment) for more details.
docs/source/getting_started/detailed_tutorial.md (new file, 541 lines)

@@ -0,0 +1,541 @@
# Detailed Tutorial

In this guide, we'll walk through how you can use the Llama Stack (server and client SDK) to test a simple agent.
A Llama Stack agent is a simple integrated system that can perform tasks by combining a Llama model for reasoning with
tools (e.g., RAG, web search, code execution, etc.) for taking actions.
In Llama Stack, we provide a server exposing multiple APIs. These APIs are backed by implementations from different providers.

Llama Stack is a stateful service with REST APIs to support seamless transition of AI applications across different environments. The server can be run in a variety of ways, including as a standalone binary, Docker container, or hosted service. You can build and test using a local server first and deploy to a hosted endpoint for production.

In this guide, we'll walk through how to build a RAG agent locally using Llama Stack with [Ollama](https://ollama.com/)
as the inference [provider](../providers/index.md#inference) for a Llama Model.

## Step 1: Installation and Setup

Install Ollama by following the instructions on the [Ollama website](https://ollama.com/download), then
download the Llama 3.2 3B model, and then start the Ollama service.
```bash
ollama pull llama3.2:3b
ollama run llama3.2:3b --keepalive 60m
```
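Before moving on, you can optionally confirm that Ollama is reachable (a quick sanity check assuming Ollama's default port of 11434):

```bash
# Ollama answers with "Ollama is running" on its root endpoint
curl http://localhost:11434
```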
Install [uv](https://docs.astral.sh/uv/) to set up your virtual environment.

::::{tab-set}

:::{tab-item} macOS and Linux
Use `curl` to download the script and execute it with `sh`:
```console
curl -LsSf https://astral.sh/uv/install.sh | sh
```
:::

:::{tab-item} Windows
Use `irm` to download the script and execute it with `iex`:

```console
powershell -ExecutionPolicy ByPass -c "irm https://astral.sh/uv/install.ps1 | iex"
```
:::
::::

Set up your virtual environment.

```bash
uv venv --python 3.10
source .venv/bin/activate
```
## Step 2: Run Llama Stack
Llama Stack is a server that exposes multiple APIs; you connect to it using the Llama Stack client SDK.

::::{tab-set}

:::{tab-item} Using `venv`
You can use Python to build and run the Llama Stack server, which is useful for testing and development.

Llama Stack uses a [YAML configuration file](../distributions/configuration.md) to specify the stack setup,
which defines the providers and their settings.
Now let's build and run the Llama Stack config for Ollama.

```bash
INFERENCE_MODEL=llama3.2:3b llama stack build --template ollama --image-type venv --run
```
:::
:::{tab-item} Using `conda`
You can use Python to build and run the Llama Stack server, which is useful for testing and development.

Llama Stack uses a [YAML configuration file](../distributions/configuration.md) to specify the stack setup,
which defines the providers and their settings.
Now let's build and run the Llama Stack config for Ollama.

```bash
INFERENCE_MODEL=llama3.2:3b llama stack build --template ollama --image-type conda --image-name llama3-3b-conda --run
```
:::
:::{tab-item} Using a Container
You can use a container image to run the Llama Stack server. We provide several container images for the server
component that work with different inference providers out of the box. For this guide, we will use
`llamastack/distribution-ollama` as the container image. If you'd like to build your own image or customize the
configurations, please check out [this guide](../references/index.md).
First, let's set up some environment variables and create a local directory to mount into the container's file system.
```bash
export INFERENCE_MODEL="llama3.2:3b"
export LLAMA_STACK_PORT=8321
mkdir -p ~/.llama
```
Then start the server using the container tool of your choice. For example, if you are running Docker you can use the
following command:
```bash
docker run -it \
  --pull always \
  -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
  -v ~/.llama:/root/.llama \
  llamastack/distribution-ollama \
  --port $LLAMA_STACK_PORT \
  --env INFERENCE_MODEL=$INFERENCE_MODEL \
  --env OLLAMA_URL=http://host.docker.internal:11434
```
Note that to start the container with Podman, you can do the same but replace `docker` at the start of the command with
`podman`. If you are using `podman` older than `4.7.0`, please also replace `host.docker.internal` in the `OLLAMA_URL`
with `host.containers.internal`.
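For reference, the Podman variant of the same command looks like this (identical flags; assumes Podman 4.7.0 or newer so that `host.docker.internal` still resolves):

```bash
podman run -it \
  --pull always \
  -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
  -v ~/.llama:/root/.llama \
  llamastack/distribution-ollama \
  --port $LLAMA_STACK_PORT \
  --env INFERENCE_MODEL=$INFERENCE_MODEL \
  --env OLLAMA_URL=http://host.docker.internal:11434
```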
The configuration YAML for the Ollama distribution is available at `distributions/ollama/run.yaml`.
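To give a feel for what that file contains, the inference provider entry looks roughly like this (a simplified sketch rather than the full file; the exact fields and the URL default may differ in your version):

```yaml
providers:
  inference:
    - provider_id: ollama
      provider_type: remote::ollama
      config:
        url: ${env.OLLAMA_URL:http://localhost:11434}
```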
```{tip}

Docker containers run in their own isolated network namespaces on Linux. To allow the container to communicate with services running on the host via `localhost`, you need `--network=host`. This makes the container use the host's network directly so it can connect to Ollama running on `localhost:11434`.

Linux users having issues running the above command should instead try the following:
```bash
docker run -it \
  --pull always \
  -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
  -v ~/.llama:/root/.llama \
  --network=host \
  llamastack/distribution-ollama \
  --port $LLAMA_STACK_PORT \
  --env INFERENCE_MODEL=$INFERENCE_MODEL \
  --env OLLAMA_URL=http://localhost:11434
```
:::
::::
You will see output like below:
```
INFO: Application startup complete.
INFO: Uvicorn running on http://['::', '0.0.0.0']:8321 (Press CTRL+C to quit)
```
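From another terminal you can also sanity-check that the server is responding (assuming the default port and that your build exposes the standard health route):

```bash
curl -s http://localhost:8321/v1/health
```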
Now you can use the Llama Stack client to run inference and build agents!

You can reuse the server setup or use the [Llama Stack Client](https://github.com/meta-llama/llama-stack-client-python/).
Note that the client package is already included in the `llama-stack` package.

## Step 3: Run Client CLI

Open a new terminal and navigate to the same directory you started the server from. Then set up a new virtual environment or activate your
existing server virtual environment.

::::{tab-set}

:::{tab-item} Reuse Server `venv`
```bash
# The client is included in the llama-stack package so we just activate the server venv
source .venv/bin/activate
```
:::

:::{tab-item} Install with `venv`
```bash
uv venv client --python 3.10
source client/bin/activate
pip install llama-stack-client
```
:::

:::{tab-item} Install with `conda`
```bash
yes | conda create -n stack-client python=3.10
conda activate stack-client
pip install llama-stack-client
```
:::
::::

Now let's use the `llama-stack-client` [CLI](../references/llama_stack_client_cli_reference.md) to check the
connectivity to the server.

```bash
llama-stack-client configure --endpoint http://localhost:8321 --api-key none
```
You will see the below:
```
Done! You can now use the Llama Stack Client CLI with endpoint http://localhost:8321
```

List the models:
```bash
llama-stack-client models list

Available Models

┏━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━━━━┓
┃ model_type  ┃ identifier       ┃ provider_resource_id ┃ metadata                       ┃ provider_id ┃
┡━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━━━━┩
│ embedding   │ all-MiniLM-L6-v2 │ all-minilm:latest    │ {'embedding_dimension': 384.0} │ ollama      │
├─────────────┼──────────────────┼──────────────────────┼────────────────────────────────┼─────────────┤
│ llm         │ llama3.2:3b      │ llama3.2:3b          │                                │ ollama      │
└─────────────┴──────────────────┴──────────────────────┴────────────────────────────────┴─────────────┘

Total models: 2

```
You can test basic Llama inference completion using the CLI.

```bash
llama-stack-client inference chat-completion --message "tell me a joke"
```
Sample output:
```python
ChatCompletionResponse(
    completion_message=CompletionMessage(
        content="Here's one:\n\nWhat do you call a fake noodle?\n\nAn impasta!",
        role="assistant",
        stop_reason="end_of_turn",
        tool_calls=[],
    ),
    logprobs=None,
    metrics=[
        Metric(metric="prompt_tokens", value=14.0, unit=None),
        Metric(metric="completion_tokens", value=27.0, unit=None),
        Metric(metric="total_tokens", value=41.0, unit=None),
    ],
)
```

## Step 4: Run the Demos

Note that these demos show the [Python Client SDK](../references/python_sdk_reference/index.md).
Other SDKs are also available; please refer to the [Client SDK](../index.md#client-sdks) list for the complete options.

::::{tab-set}

:::{tab-item} Basic Inference
Now you can run inference using the Llama Stack client SDK.

### i. Create the Script

Create a file `inference.py` and add the following code:
```python
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:8321")

# List available models
models = client.models.list()

# Select the first LLM
llm = next(m for m in models if m.model_type == "llm")
model_id = llm.identifier

print("Model:", model_id)

response = client.inference.chat_completion(
    model_id=model_id,
    messages=[
        {"role": "system", "content": "You are a helpful assistant."},
        {"role": "user", "content": "Write a haiku about coding"},
    ],
)
print(response.completion_message.content)
```

### ii. Run the Script
Let's run the script using `uv`:
```bash
uv run python inference.py
```
Which will output:
```
Model: llama3.2:3b
Here is a haiku about coding:

Lines of code unfold
Logic flows through digital night
Beauty in the bits
```
:::
:::{tab-item} Build a Simple Agent
Next we can move beyond simple inference and build an agent that can perform tasks using the Llama Stack server.
### i. Create the Script
Create a file `agent.py` and add the following code:

```python
from llama_stack_client import LlamaStackClient
from llama_stack_client import Agent, AgentEventLogger
from rich.pretty import pprint
import uuid

client = LlamaStackClient(base_url=f"http://localhost:8321")

models = client.models.list()
llm = next(m for m in models if m.model_type == "llm")
model_id = llm.identifier

agent = Agent(client, model=model_id, instructions="You are a helpful assistant.")

s_id = agent.create_session(session_name=f"s{uuid.uuid4().hex}")

print("Non-streaming ...")
response = agent.create_turn(
    messages=[{"role": "user", "content": "Who are you?"}],
    session_id=s_id,
    stream=False,
)
print("agent>", response.output_message.content)

print("Streaming ...")
stream = agent.create_turn(
    messages=[{"role": "user", "content": "Who are you?"}], session_id=s_id, stream=True
)
for event in stream:
    pprint(event)

print("Streaming with print helper...")
stream = agent.create_turn(
    messages=[{"role": "user", "content": "Who are you?"}], session_id=s_id, stream=True
)
for event in AgentEventLogger().log(stream):
    event.print()
```
### ii. Run the Script
Let's run the script using `uv`
```bash
uv run python agent.py
```

```{dropdown} 👋 Click here to see the sample output
Non-streaming ...
agent> I'm an artificial intelligence designed to assist and communicate with users like you. I don't have a personal identity, but I'm here to provide information, answer questions, and help with tasks to the best of my abilities.

I can be used for a wide range of purposes, such as:

* Providing definitions and explanations
* Offering suggestions and ideas
* Helping with language translation
* Assisting with writing and proofreading
* Generating text or responses to questions
* Playing simple games or chatting about topics of interest

I'm constantly learning and improving my abilities, so feel free to ask me anything, and I'll do my best to help!

Streaming ...
AgentTurnResponseStreamChunk(
│   event=TurnResponseEvent(
│   │   payload=AgentTurnResponseStepStartPayload(
│   │   │   event_type='step_start',
│   │   │   step_id='69831607-fa75-424a-949b-e2049e3129d1',
│   │   │   step_type='inference',
│   │   │   metadata={}
│   │   )
│   )
)
AgentTurnResponseStreamChunk(
│   event=TurnResponseEvent(
│   │   payload=AgentTurnResponseStepProgressPayload(
│   │   │   delta=TextDelta(text='As', type='text'),
│   │   │   event_type='step_progress',
│   │   │   step_id='69831607-fa75-424a-949b-e2049e3129d1',
│   │   │   step_type='inference'
│   │   )
│   )
)
AgentTurnResponseStreamChunk(
│   event=TurnResponseEvent(
│   │   payload=AgentTurnResponseStepProgressPayload(
│   │   │   delta=TextDelta(text=' a', type='text'),
│   │   │   event_type='step_progress',
│   │   │   step_id='69831607-fa75-424a-949b-e2049e3129d1',
│   │   │   step_type='inference'
│   │   )
│   )
)
...
AgentTurnResponseStreamChunk(
│   event=TurnResponseEvent(
│   │   payload=AgentTurnResponseStepCompletePayload(
│   │   │   event_type='step_complete',
│   │   │   step_details=InferenceStep(
│   │   │   │   api_model_response=CompletionMessage(
│   │   │   │   │   content='As a conversational AI, I don\'t have a personal identity in the classical sense. I exist as a program running on computer servers, designed to process and respond to text-based inputs.\n\nI\'m an instance of a type of artificial intelligence called a "language model," which is trained on vast amounts of text data to generate human-like responses. My primary function is to understand and respond to natural language inputs, like our conversation right now.\n\nThink of me as a virtual assistant, a chatbot, or a conversational interface – I\'m here to provide information, answer questions, and engage in conversation to the best of my abilities. I don\'t have feelings, emotions, or consciousness like humans do, but I\'m designed to simulate human-like interactions to make our conversations feel more natural and helpful.\n\nSo, that\'s me in a nutshell! What can I help you with today?',
│   │   │   │   │   role='assistant',
│   │   │   │   │   stop_reason='end_of_turn',
│   │   │   │   │   tool_calls=[]
│   │   │   │   ),
│   │   │   │   step_id='69831607-fa75-424a-949b-e2049e3129d1',
│   │   │   │   step_type='inference',
│   │   │   │   turn_id='8b360202-f7cb-4786-baa9-166a1b46e2ca',
│   │   │   │   completed_at=datetime.datetime(2025, 4, 3, 1, 15, 21, 716174, tzinfo=TzInfo(UTC)),
│   │   │   │   started_at=datetime.datetime(2025, 4, 3, 1, 15, 14, 28823, tzinfo=TzInfo(UTC))
│   │   │   ),
│   │   │   step_id='69831607-fa75-424a-949b-e2049e3129d1',
│   │   │   step_type='inference'
│   │   )
│   )
)
AgentTurnResponseStreamChunk(
│   event=TurnResponseEvent(
│   │   payload=AgentTurnResponseTurnCompletePayload(
│   │   │   event_type='turn_complete',
│   │   │   turn=Turn(
│   │   │   │   input_messages=[UserMessage(content='Who are you?', role='user', context=None)],
│   │   │   │   output_message=CompletionMessage(
│   │   │   │   │   content='As a conversational AI, I don\'t have a personal identity in the classical sense. I exist as a program running on computer servers, designed to process and respond to text-based inputs.\n\nI\'m an instance of a type of artificial intelligence called a "language model," which is trained on vast amounts of text data to generate human-like responses. My primary function is to understand and respond to natural language inputs, like our conversation right now.\n\nThink of me as a virtual assistant, a chatbot, or a conversational interface – I\'m here to provide information, answer questions, and engage in conversation to the best of my abilities. I don\'t have feelings, emotions, or consciousness like humans do, but I\'m designed to simulate human-like interactions to make our conversations feel more natural and helpful.\n\nSo, that\'s me in a nutshell! What can I help you with today?',
│   │   │   │   │   role='assistant',
│   │   │   │   │   stop_reason='end_of_turn',
│   │   │   │   │   tool_calls=[]
│   │   │   │   ),
│   │   │   │   session_id='abd4afea-4324-43f4-9513-cfe3970d92e8',
│   │   │   │   started_at=datetime.datetime(2025, 4, 3, 1, 15, 14, 28722, tzinfo=TzInfo(UTC)),
│   │   │   │   steps=[
│   │   │   │   │   InferenceStep(
│   │   │   │   │   │   api_model_response=CompletionMessage(
│   │   │   │   │   │   │   content='As a conversational AI, I don\'t have a personal identity in the classical sense. I exist as a program running on computer servers, designed to process and respond to text-based inputs.\n\nI\'m an instance of a type of artificial intelligence called a "language model," which is trained on vast amounts of text data to generate human-like responses. My primary function is to understand and respond to natural language inputs, like our conversation right now.\n\nThink of me as a virtual assistant, a chatbot, or a conversational interface – I\'m here to provide information, answer questions, and engage in conversation to the best of my abilities. I don\'t have feelings, emotions, or consciousness like humans do, but I\'m designed to simulate human-like interactions to make our conversations feel more natural and helpful.\n\nSo, that\'s me in a nutshell! What can I help you with today?',
│   │   │   │   │   │   │   role='assistant',
│   │   │   │   │   │   │   stop_reason='end_of_turn',
│   │   │   │   │   │   │   tool_calls=[]
│   │   │   │   │   │   ),
│   │   │   │   │   │   step_id='69831607-fa75-424a-949b-e2049e3129d1',
│   │   │   │   │   │   step_type='inference',
│   │   │   │   │   │   turn_id='8b360202-f7cb-4786-baa9-166a1b46e2ca',
│   │   │   │   │   │   completed_at=datetime.datetime(2025, 4, 3, 1, 15, 21, 716174, tzinfo=TzInfo(UTC)),
│   │   │   │   │   │   started_at=datetime.datetime(2025, 4, 3, 1, 15, 14, 28823, tzinfo=TzInfo(UTC))
│   │   │   │   │   )
│   │   │   │   ],
│   │   │   │   turn_id='8b360202-f7cb-4786-baa9-166a1b46e2ca',
│   │   │   │   completed_at=datetime.datetime(2025, 4, 3, 1, 15, 21, 727364, tzinfo=TzInfo(UTC)),
│   │   │   │   output_attachments=[]
│   │   │   )
│   │   )
│   )
)

Streaming with print helper...
inference> Déjà vu!

As I mentioned earlier, I'm an artificial intelligence language model. I don't have a personal identity or consciousness like humans do. I exist solely to process and respond to text-based inputs, providing information and assistance on a wide range of topics.

I'm a computer program designed to simulate human-like conversations, using natural language processing (NLP) and machine learning algorithms to understand and generate responses. My purpose is to help users like you with their questions, provide information, and engage in conversation.

Think of me as a virtual companion, a helpful tool designed to make your interactions more efficient and enjoyable. I don't have personal opinions, emotions, or biases, but I'm here to provide accurate and informative responses to the best of my abilities.

So, who am I? I'm just a computer program designed to help you!
```
:::

:::{tab-item} Build a RAG Agent

For our last demo, we can build a RAG agent that can answer questions about the Torchtune project using the documents
in a vector database.
### i. Create the Script
Create a file `rag_agent.py` and add the following code:

```python
from llama_stack_client import LlamaStackClient
from llama_stack_client import Agent, AgentEventLogger
from llama_stack_client.types import Document
import uuid
from termcolor import cprint

client = LlamaStackClient(base_url="http://localhost:8321")

# Create a vector database instance
embed_lm = next(m for m in client.models.list() if m.model_type == "embedding")
embedding_model = embed_lm.identifier
vector_db_id = f"v{uuid.uuid4().hex}"
client.vector_dbs.register(
    vector_db_id=vector_db_id,
    embedding_model=embedding_model,
)

# Create Documents
urls = [
    "memory_optimizations.rst",
    "chat.rst",
    "llama3.rst",
    "datasets.rst",
    "qat_finetune.rst",
    "lora_finetune.rst",
]
documents = [
    Document(
        document_id=f"num-{i}",
        content=f"https://raw.githubusercontent.com/pytorch/torchtune/main/docs/source/tutorials/{url}",
        mime_type="text/plain",
        metadata={},
    )
    for i, url in enumerate(urls)
]

# Insert documents
client.tool_runtime.rag_tool.insert(
    documents=documents,
    vector_db_id=vector_db_id,
    chunk_size_in_tokens=512,
)

# Get the model being served
llm = next(m for m in client.models.list() if m.model_type == "llm")
model = llm.identifier

# Create the RAG agent
rag_agent = Agent(
    client,
    model=model,
    instructions="You are a helpful assistant. Use the RAG tool to answer questions as needed.",
    tools=[
        {
            "name": "builtin::rag/knowledge_search",
            "args": {"vector_db_ids": [vector_db_id]},
        }
    ],
)

session_id = rag_agent.create_session(session_name=f"s{uuid.uuid4().hex}")

turns = ["what is torchtune", "tell me about dora"]

for t in turns:
    print("user>", t)
    stream = rag_agent.create_turn(
        messages=[{"role": "user", "content": t}], session_id=session_id, stream=True
    )
    for event in AgentEventLogger().log(stream):
        event.print()
```
### ii. Run the Script
Let's run the script using `uv`
```bash
uv run python rag_agent.py
```

```{dropdown} 👋 Click here to see the sample output
user> what is torchtune
inference> [knowledge_search(query='TorchTune')]
tool_execution> Tool:knowledge_search Args:{'query': 'TorchTune'}
tool_execution> Tool:knowledge_search Response:[TextContentItem(text='knowledge_search tool found 5 chunks:\nBEGIN of knowledge_search tool results.\n', type='text'), TextContentItem(text='Result 1:\nDocument_id:num-1\nContent: conversational data, :func:`~torchtune.datasets.chat_dataset` seems to be a good fit. ..., type='text'), TextContentItem(text='END of knowledge_search tool results.\n', type='text')]
inference> Here is a high-level overview of the text:

**LoRA Finetuning with PyTorch Tune**

PyTorch Tune provides a recipe for LoRA (Low-Rank Adaptation) finetuning, which is a technique to adapt pre-trained models to new tasks. The recipe uses the `lora_finetune_distributed` command.
...
Overall, DORA is a powerful reinforcement learning algorithm that can learn complex tasks from human demonstrations. However, it requires careful consideration of the challenges and limitations to achieve optimal results.
```
:::

::::

**You're Ready to Build Your Own Apps!**

Congrats! 🥳 Now you're ready to [build your own Llama Stack applications](../building_applications/index)! 🚀
@@ -1,304 +1,121 @@

# Quickstart

Get started with Llama Stack in minutes!

Llama Stack is a stateful service with REST APIs to support the seamless transition of AI applications across different
environments. You can build and test using a local server first and deploy to a hosted endpoint for production.

In this guide, we'll walk through how to build a RAG application locally using Llama Stack with [Ollama](https://ollama.com/)
as the inference [provider](../providers/index.md#inference) for a Llama Model.

#### Step 1: Install and setup
1. Install [uv](https://docs.astral.sh/uv/)
2. Run inference on a Llama model with [Ollama](https://ollama.com/download)
```bash
ollama run llama3.2:3b --keepalive 60m
```
#### Step 2: Run the Llama Stack server
We will use `uv` to run the Llama Stack server.
```bash
INFERENCE_MODEL=llama3.2:3b uv run --with llama-stack llama stack build --template ollama --image-type venv --run
```
#### Step 3: Run the demo
Now open up a new terminal and copy the following script into a file named `demo_script.py`.

```python
from llama_stack_client import Agent, AgentEventLogger, RAGDocument, LlamaStackClient

vector_db_id = "my_demo_vector_db"
client = LlamaStackClient(base_url="http://localhost:8321")

models = client.models.list()

# Select the first LLM and first embedding models
model_id = next(m for m in models if m.model_type == "llm").identifier
embedding_model_id = (
    em := next(m for m in models if m.model_type == "embedding")
).identifier
embedding_dimension = em.metadata["embedding_dimension"]

_ = client.vector_dbs.register(
    vector_db_id=vector_db_id,
    embedding_model=embedding_model_id,
    embedding_dimension=embedding_dimension,
    provider_id="faiss",
)
source = "https://www.paulgraham.com/greatwork.html"
print("rag_tool> Ingesting document:", source)
document = RAGDocument(
    document_id="document_1",
    content=source,
    mime_type="text/html",
    metadata={},
)
client.tool_runtime.rag_tool.insert(
    documents=[document],
    vector_db_id=vector_db_id,
    chunk_size_in_tokens=50,
)
agent = Agent(
    client,
    model=model_id,
    instructions="You are a helpful assistant",
    tools=[
        {
            "name": "builtin::rag/knowledge_search",
            "args": {"vector_db_ids": [vector_db_id]},
        }
    ],
)

prompt = "How do you do great work?"
print("prompt>", prompt)

response = agent.create_turn(
    messages=[{"role": "user", "content": prompt}],
    session_id=agent.create_session("rag_session"),
    stream=True,
)

for log in AgentEventLogger().log(response):
    log.print()
```
We will use `uv` to run the script
```bash
uv run --with llama-stack-client demo_script.py
```
And you should see output like below.
```
rag_tool> Ingesting document: https://www.paulgraham.com/greatwork.html

prompt> How do you do great work?

inference> [knowledge_search(query="What is the key to doing great work")]

tool_execution> Tool:knowledge_search Args:{'query': 'What is the key to doing great work'}

tool_execution> Tool:knowledge_search Response:[TextContentItem(text='knowledge_search tool found 5 chunks:\nBEGIN of knowledge_search tool results.\n', type='text'), TextContentItem(text="Result 1:\nDocument_id:docum\nContent: work. Doing great work means doing something important\nso well that you expand people's ideas of what's possible. But\nthere's no threshold for importance. It's a matter of degree, and\noften hard to judge at the time anyway.\n", type='text'), TextContentItem(text="Result 2:\nDocument_id:docum\nContent: work. Doing great work means doing something important\nso well that you expand people's ideas of what's possible. But\nthere's no threshold for importance. It's a matter of degree, and\noften hard to judge at the time anyway.\n", type='text'), TextContentItem(text="Result 3:\nDocument_id:docum\nContent: work. Doing great work means doing something important\nso well that you expand people's ideas of what's possible. But\nthere's no threshold for importance. It's a matter of degree, and\noften hard to judge at the time anyway.\n", type='text'), TextContentItem(text="Result 4:\nDocument_id:docum\nContent: work. Doing great work means doing something important\nso well that you expand people's ideas of what's possible. But\nthere's no threshold for importance. It's a matter of degree, and\noften hard to judge at the time anyway.\n", type='text'), TextContentItem(text="Result 5:\nDocument_id:docum\nContent: work. Doing great work means doing something important\nso well that you expand people's ideas of what's possible. But\nthere's no threshold for importance. It's a matter of degree, and\noften hard to judge at the time anyway.\n", type='text'), TextContentItem(text='END of knowledge_search tool results.\n', type='text')]

inference> Based on the search results, it seems that doing great work means doing something important so well that you expand people's ideas of what's possible. However, there is no clear threshold for importance, and it can be difficult to judge at the time.

To further clarify, I would suggest that doing great work involves:

* Completing tasks with high quality and attention to detail
* Expanding on existing knowledge or ideas
* Making a positive impact on others through your work
* Striving for excellence and continuous improvement

Ultimately, great work is about making a meaningful contribution and leaving a lasting impression.
```
Congratulations! You've successfully built your first RAG application using Llama Stack! 🎉🥳

## Next Steps

Now you're ready to dive deeper into Llama Stack!
- Explore the [Detailed Tutorial](./detailed_tutorial.md).
- Try the [Getting Started Notebook](https://github.com/meta-llama/llama-stack/blob/main/docs/getting_started.ipynb).
- Browse more [Notebooks on GitHub](https://github.com/meta-llama/llama-stack/tree/main/docs/notebooks).
- Learn about Llama Stack [Concepts](../concepts/index.md).
- Discover how to [Build Llama Stacks](../distributions/index.md).
- Refer to our [References](../references/index.md) for details on the Llama CLI and Python SDK.
- Check out the [llama-stack-apps](https://github.com/meta-llama/llama-stack-apps/tree/main/examples) repository for example applications and tutorials.
@@ -1,10 +1,16 @@

# Llama Stack
Welcome to Llama Stack, the open-source framework for building generative AI applications.
```{admonition} Llama 4 is here!
:class: tip

Check out [Getting Started with Llama 4](https://colab.research.google.com/github/meta-llama/llama-stack/blob/main/docs/getting_started_llama4.ipynb)
```
```{admonition} News
:class: tip

Llama Stack {{ llama_stack_version }} is now available! See the {{ llama_stack_version_link }} for more details.
```

## What is Llama Stack?

@@ -24,19 +30,17 @@ Llama Stack defines and standardizes the core building blocks needed to bring ge

Our goal is to provide pre-packaged implementations (aka "distributions") which can be run in a variety of deployment environments. LlamaStack can assist you in your entire app development lifecycle - start iterating on local, mobile or desktop and seamlessly transition to on-prem or public cloud deployments. At every point in this transition, the same set of APIs and the same developer experience is available.

## How does Llama Stack work?
Llama Stack consists of a [server](./distributions/index.md) (with multiple pluggable API [providers](./providers/index.md)) and Client SDKs (see below) meant to
be used in your applications. The server can be run in a variety of environments, including local (inline)
development, on-premises, and cloud. The client SDKs are available for Python, Swift, Node, and
Kotlin.

## Quick Links

- Ready to build? Check out the [Quick Start](getting_started/index) to get started.
- Want to contribute? See the [Contributing](contributing/index) guide.

## Client SDKs

We have a number of client-side SDKs available for different languages.

@@ -95,8 +99,9 @@ A number of "adapters" are available for some popular Inference and Vector Store

:maxdepth: 3

self
getting_started/index
getting_started/detailed_tutorial
introduction/index
concepts/index
providers/index
distributions/index
@@ -103,7 +103,5 @@ llama stack run together

2. Start Streamlit UI
```bash
uv run --with ".[ui]" streamlit run llama_stack/distribution/ui/app.py
```
docs/source/providers/external.md (new file, 234 lines)
@@ -0,0 +1,234 @@

# External Providers

Llama Stack supports external providers that live outside of the main codebase. This allows you to:
- Create and maintain your own providers independently
- Share providers with others without contributing to the main codebase
- Keep provider-specific code separate from the core Llama Stack code

## Configuration

To enable external providers, you need to configure the `external_providers_dir` in your Llama Stack configuration. This directory should contain your external provider specifications:

```yaml
external_providers_dir: /etc/llama-stack/providers.d/
```

## Directory Structure

The external providers directory should follow this structure:

```
providers.d/
  remote/
    inference/
      custom_ollama.yaml
      vllm.yaml
    vector_io/
      qdrant.yaml
    safety/
      llama-guard.yaml
  inline/
    inference/
      custom_ollama.yaml
      vllm.yaml
    vector_io/
      qdrant.yaml
    safety/
      llama-guard.yaml
```

Each YAML file in these directories defines a provider specification for that particular API.

## Provider Types

Llama Stack supports two types of external providers:

1. **Remote Providers**: Providers that communicate with external services (e.g., cloud APIs)
2. **Inline Providers**: Providers that run locally within the Llama Stack process

## Known External Providers

Here's a list of known external providers that you can use with Llama Stack:

| Type | Name | Description | Repository |
|------|------|-------------|------------|
| Remote | KubeFlow Training | Train models with KubeFlow | [llama-stack-provider-kft](https://github.com/opendatahub-io/llama-stack-provider-kft) |

### Remote Provider Specification

Remote providers are used when you need to communicate with external services. Here's an example for a custom Ollama provider:

```yaml
adapter:
  adapter_type: custom_ollama
  pip_packages:
    - ollama
    - aiohttp
  config_class: llama_stack_ollama_provider.config.OllamaImplConfig
  module: llama_stack_ollama_provider
api_dependencies: []
optional_api_dependencies: []
```

#### Adapter Configuration

The `adapter` section defines how to load and configure the provider:

- `adapter_type`: A unique identifier for this adapter
- `pip_packages`: List of Python packages required by the provider
- `config_class`: The full path to the configuration class
- `module`: The Python module containing the provider implementation

### Inline Provider Specification

Inline providers run locally within the Llama Stack process. Here's an example for a custom vector store provider:

```yaml
module: llama_stack_vector_provider
config_class: llama_stack_vector_provider.config.VectorStoreConfig
pip_packages:
  - faiss-cpu
  - numpy
api_dependencies:
  - inference
optional_api_dependencies:
  - vector_io
provider_data_validator: llama_stack_vector_provider.validator.VectorStoreValidator
container_image: custom-vector-store:latest # optional
```

#### Inline Provider Fields

- `module`: The Python module containing the provider implementation
- `config_class`: The full path to the configuration class
- `pip_packages`: List of Python packages required by the provider
- `api_dependencies`: List of Llama Stack APIs that this provider depends on
- `optional_api_dependencies`: List of optional Llama Stack APIs that this provider can use
- `provider_data_validator`: Optional validator for provider data
- `container_image`: Optional container image to use instead of pip packages

## Required Implementation

### Remote Providers

Remote providers must expose a `get_adapter_impl()` function in their module that takes two arguments:
1. `config`: An instance of the provider's config class
2. `deps`: A dictionary of API dependencies

This function must return an instance of the provider's adapter class that implements the required protocol for the API.

Example:
```python
async def get_adapter_impl(
    config: OllamaImplConfig, deps: Dict[Api, Any]
) -> OllamaInferenceAdapter:
    return OllamaInferenceAdapter(config)
```

### Inline Providers

Inline providers must expose a `get_provider_impl()` function in their module that takes two arguments:
1. `config`: An instance of the provider's config class
2. `deps`: A dictionary of API dependencies

Example:
```python
async def get_provider_impl(
    config: VectorStoreConfig, deps: Dict[Api, Any]
) -> VectorStoreImpl:
    impl = VectorStoreImpl(config, deps[Api.inference])
    await impl.initialize()
    return impl
```

## Dependencies

The provider package must be installed on the system. For example:

```bash
$ uv pip show llama-stack-ollama-provider
Name: llama-stack-ollama-provider
Version: 0.1.0
Location: /path/to/venv/lib/python3.10/site-packages
```

## Example: Custom Ollama Provider

Here's a complete example of creating and using a custom Ollama provider:

1. First, create the provider package:

```bash
mkdir -p llama-stack-provider-ollama
cd llama-stack-provider-ollama
git init
uv init
```

2. Edit `pyproject.toml`:

```toml
[project]
name = "llama-stack-provider-ollama"
version = "0.1.0"
description = "Ollama provider for Llama Stack"
requires-python = ">=3.10"
dependencies = ["llama-stack", "pydantic", "ollama", "aiohttp"]
```

3. Create the provider specification:

```yaml
# /etc/llama-stack/providers.d/remote/inference/custom_ollama.yaml
adapter:
  adapter_type: custom_ollama
  pip_packages: ["ollama", "aiohttp"]
  config_class: llama_stack_provider_ollama.config.OllamaImplConfig
  module: llama_stack_provider_ollama
api_dependencies: []
optional_api_dependencies: []
```

4. Install the provider:

```bash
uv pip install -e .
```

5. Configure Llama Stack to use external providers:

```yaml
external_providers_dir: /etc/llama-stack/providers.d/
```

The provider will now be available in Llama Stack with the type `remote::custom_ollama`.

## Best Practices

1. **Package Naming**: Use the prefix `llama-stack-provider-` for your provider packages to make them easily identifiable.

2. **Version Management**: Keep your provider package versioned and compatible with the Llama Stack version you're using.

3. **Dependencies**: Only include the minimum required dependencies in your provider package.

4. **Documentation**: Include clear documentation in your provider package about:
   - Installation requirements
   - Configuration options
   - Usage examples
   - Any limitations or known issues

5. **Testing**: Include tests in your provider package to ensure it works correctly with Llama Stack. You can refer to the [integration tests guide](https://github.com/meta-llama/llama-stack/blob/main/tests/integration/README.md) for more information. Execute the test for the Provider type you are developing.

## Troubleshooting

If your external provider isn't being loaded:

1. Check that the `external_providers_dir` path is correct and accessible.
2. Verify that the YAML files are properly formatted.
3. Ensure all required Python packages are installed.
4. Check the Llama Stack server logs for any error messages - turn on debug logging to get more information using `LLAMA_STACK_LOGGING=all=debug`.
5. Verify that the provider package is installed in your Python environment.
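As an additional check, you can query the running server's provider listing and confirm that the external provider type shows up. The following is a minimal sketch, assuming the server is reachable on the default address `127.0.0.1:8321` and exposes the provider listing at `/v1/providers`; adjust both if your deployment differs.

```python
# Hypothetical check: list providers registered with a running Llama Stack
# server and look for the external provider type.
import httpx

resp = httpx.get("http://127.0.0.1:8321/v1/providers")
resp.raise_for_status()
providers = resp.json().get("data", [])

for p in providers:
    print(p["provider_id"], p["provider_type"])

# Fail loudly if the external provider did not load.
assert any(p["provider_type"] == "remote::custom_ollama" for p in providers), "external provider was not loaded"
```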
@@ -1,8 +1,8 @@
# Providers Overview

The goal of Llama Stack is to build an ecosystem where users can easily swap out different implementations for the same API. Examples for these include:
-- LLM inference providers (e.g., Fireworks, Together, AWS Bedrock, Groq, Cerebras, SambaNova, vLLM, etc.),
+- LLM inference providers (e.g., Ollama, Fireworks, Together, AWS Bedrock, Groq, Cerebras, SambaNova, vLLM, etc.),
-- Vector databases (e.g., ChromaDB, Weaviate, Qdrant, Milvus, FAISS, PGVector, etc.),
+- Vector databases (e.g., ChromaDB, Weaviate, Qdrant, Milvus, FAISS, PGVector, SQLite-Vec, etc.),
- Safety providers (e.g., Meta's Llama Guard, AWS Bedrock Guardrails, etc.)

Providers come in two flavors:

@@ -11,6 +11,10 @@ Providers come in two flavors:

Importantly, Llama Stack always strives to provide at least one fully inline provider for each API so you can iterate on a fully featured environment locally.

+## External Providers
+
+Llama Stack supports external providers that live outside of the main codebase. This allows you to create and maintain your own providers independently. See the [External Providers Guide](external) for details.
+
## Agents
Run multi-step agentic workflows with LLMs with tool usage, memory (RAG), etc.

@@ -50,6 +54,7 @@ The following providers (i.e., databases) are available for Vector IO:
```{toctree}
:maxdepth: 1

+external
vector_io/faiss
vector_io/sqlite-vec
vector_io/chromadb
@@ -6,11 +6,8 @@
from typing import List, Optional, Protocol, runtime_checkable

-from pydantic import BaseModel
+from llama_stack.apis.common.job_types import Job

from llama_stack.apis.inference import (
-    ChatCompletionResponse,
-    CompletionResponse,
    InterleavedContent,
    LogProbConfig,
    Message,

@@ -20,41 +17,39 @@ from llama_stack.apis.inference import (
    ToolDefinition,
    ToolPromptFormat,
)
-from llama_stack.schema_utils import json_schema_type, webmethod
+from llama_stack.schema_utils import webmethod


-@json_schema_type
-class BatchCompletionResponse(BaseModel):
-    batch: List[CompletionResponse]
-
-
-@json_schema_type
-class BatchChatCompletionResponse(BaseModel):
-    batch: List[ChatCompletionResponse]
-
-
@runtime_checkable
class BatchInference(Protocol):
+    """Batch inference API for generating completions and chat completions.
+
+    This is an asynchronous API. If the request is successful, the response will be a job which can be polled for completion.
+
+    NOTE: This API is not yet implemented and is subject to change in concert with other asynchronous APIs
+    including (post-training, evals, etc).
+    """
+
    @webmethod(route="/batch-inference/completion", method="POST")
-    async def batch_completion(
+    async def completion(
        self,
        model: str,
        content_batch: List[InterleavedContent],
        sampling_params: Optional[SamplingParams] = None,
        response_format: Optional[ResponseFormat] = None,
        logprobs: Optional[LogProbConfig] = None,
-    ) -> BatchCompletionResponse: ...
+    ) -> Job: ...

    @webmethod(route="/batch-inference/chat-completion", method="POST")
-    async def batch_chat_completion(
+    async def chat_completion(
        self,
        model: str,
        messages_batch: List[List[Message]],
        sampling_params: Optional[SamplingParams] = None,
        # zero-shot tool definitions as input to the model
-        tools: Optional[List[ToolDefinition]] = list,
+        tools: Optional[List[ToolDefinition]] = None,
        tool_choice: Optional[ToolChoice] = ToolChoice.auto,
        tool_prompt_format: Optional[ToolPromptFormat] = None,
        response_format: Optional[ResponseFormat] = None,
        logprobs: Optional[LogProbConfig] = None,
-    ) -> BatchChatCompletionResponse: ...
+    ) -> Job: ...
@@ -18,22 +18,71 @@ from typing import (
)

from pydantic import BaseModel, Field, field_validator
-from typing_extensions import Annotated
+from typing_extensions import Annotated, TypedDict

from llama_stack.apis.common.content_types import ContentDelta, InterleavedContent, InterleavedContentItem
from llama_stack.apis.models import Model
from llama_stack.apis.telemetry.telemetry import MetricResponseMixin
from llama_stack.models.llama.datatypes import (
    BuiltinTool,
-    SamplingParams,
    StopReason,
    ToolCall,
    ToolDefinition,
+    ToolParamDefinition,
    ToolPromptFormat,
)
from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol
from llama_stack.schema_utils import json_schema_type, register_schema, webmethod

+register_schema(ToolCall)
+register_schema(ToolParamDefinition)
+register_schema(ToolDefinition)
+
+
+@json_schema_type
+class GreedySamplingStrategy(BaseModel):
+    type: Literal["greedy"] = "greedy"
+
+
+@json_schema_type
+class TopPSamplingStrategy(BaseModel):
+    type: Literal["top_p"] = "top_p"
+    temperature: Optional[float] = Field(..., gt=0.0)
+    top_p: Optional[float] = 0.95
+
+
+@json_schema_type
+class TopKSamplingStrategy(BaseModel):
+    type: Literal["top_k"] = "top_k"
+    top_k: int = Field(..., ge=1)
+
+
+SamplingStrategy = Annotated[
+    Union[GreedySamplingStrategy, TopPSamplingStrategy, TopKSamplingStrategy],
+    Field(discriminator="type"),
+]
+register_schema(SamplingStrategy, name="SamplingStrategy")
+
+
+@json_schema_type
+class SamplingParams(BaseModel):
+    """Sampling parameters.
+
+    :param strategy: The sampling strategy.
+    :param max_tokens: The maximum number of tokens that can be generated in the completion. The token count of
+        your prompt plus max_tokens cannot exceed the model's context length.
+    :param repetition_penalty: Number between -2.0 and 2.0. Positive values penalize new tokens
+        based on whether they appear in the text so far, increasing the model's likelihood to talk about new topics.
+    :param stop: Up to 4 sequences where the API will stop generating further tokens.
+        The returned text will not contain the stop sequence.
+    """
+
+    strategy: SamplingStrategy = Field(default_factory=GreedySamplingStrategy)
+
+    max_tokens: Optional[int] = 0
+    repetition_penalty: Optional[float] = 1.0
+    stop: Optional[List[str]] = None
+
+
class LogProbConfig(BaseModel):
    """
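For illustration, here is a minimal sketch of how the new sampling types compose. It assumes these classes are importable from `llama_stack.apis.inference`, the module this hunk modifies; the import path is otherwise an assumption.

```python
# Build SamplingParams with the discriminated-union strategies defined above.
from llama_stack.apis.inference import SamplingParams, TopPSamplingStrategy

# Default: greedy decoding (strategy falls back to GreedySamplingStrategy).
greedy = SamplingParams(max_tokens=128)

# Nucleus sampling: the "type" discriminator selects TopPSamplingStrategy.
nucleus = SamplingParams(
    strategy=TopPSamplingStrategy(temperature=0.7, top_p=0.9),
    max_tokens=256,
    stop=["</answer>"],
)

# Pydantic v2 serialization round-trips the tagged union cleanly.
print(nucleus.model_dump_json(indent=2))
```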
@@ -48,18 +97,18 @@ class QuantizationType(Enum):
    """Type of model quantization to run inference with.

    :cvar bf16: BFloat16 typically this means _no_ quantization
-    :cvar fp8: 8-bit floating point quantization
+    :cvar fp8_mixed: 8-bit floating point quantization with mixed precision
-    :cvar int4: 4-bit integer quantization
+    :cvar int4_mixed: 4-bit integer quantization with mixed precision
    """

    bf16 = "bf16"
-    fp8 = "fp8"
+    fp8_mixed = "fp8_mixed"
-    int4 = "int4"
+    int4_mixed = "int4_mixed"


@json_schema_type
class Fp8QuantizationConfig(BaseModel):
-    type: Literal["fp8"] = "fp8"
+    type: Literal["fp8_mixed"] = "fp8_mixed"


@json_schema_type

@@ -75,7 +124,7 @@ class Int4QuantizationConfig(BaseModel):
    :param scheme: Quantization scheme to use. Defaults to "int4_weight_int8_dynamic_activation"
    """

-    type: Literal["int4"] = "int4"
+    type: Literal["int4_mixed"] = "int4_mixed"
    scheme: Optional[str] = "int4_weight_int8_dynamic_activation"
@@ -393,6 +442,352 @@ class EmbeddingsResponse(BaseModel):
    embeddings: List[List[float]]


+@json_schema_type
+class OpenAIChatCompletionContentPartTextParam(BaseModel):
+    type: Literal["text"] = "text"
+    text: str
+
+
+@json_schema_type
+class OpenAIImageURL(BaseModel):
+    url: str
+    detail: Optional[str] = None
+
+
+@json_schema_type
+class OpenAIChatCompletionContentPartImageParam(BaseModel):
+    type: Literal["image_url"] = "image_url"
+    image_url: OpenAIImageURL
+
+
+OpenAIChatCompletionContentPartParam = Annotated[
+    Union[
+        OpenAIChatCompletionContentPartTextParam,
+        OpenAIChatCompletionContentPartImageParam,
+    ],
+    Field(discriminator="type"),
+]
+register_schema(OpenAIChatCompletionContentPartParam, name="OpenAIChatCompletionContentPartParam")
+
+
+OpenAIChatCompletionMessageContent = Union[str, List[OpenAIChatCompletionContentPartParam]]
+
+
+@json_schema_type
+class OpenAIUserMessageParam(BaseModel):
+    """A message from the user in an OpenAI-compatible chat completion request.
+
+    :param role: Must be "user" to identify this as a user message
+    :param content: The content of the message, which can include text and other media
+    :param name: (Optional) The name of the user message participant.
+    """
+
+    role: Literal["user"] = "user"
+    content: OpenAIChatCompletionMessageContent
+    name: Optional[str] = None
+
+
+@json_schema_type
+class OpenAISystemMessageParam(BaseModel):
+    """A system message providing instructions or context to the model.
+
+    :param role: Must be "system" to identify this as a system message
+    :param content: The content of the "system prompt". If multiple system messages are provided, they are concatenated. The underlying Llama Stack code may also add other system messages (for example, for formatting tool definitions).
+    :param name: (Optional) The name of the system message participant.
+    """
+
+    role: Literal["system"] = "system"
+    content: OpenAIChatCompletionMessageContent
+    name: Optional[str] = None
+
+
+@json_schema_type
+class OpenAIChatCompletionToolCallFunction(BaseModel):
+    name: Optional[str] = None
+    arguments: Optional[str] = None
+
+
+@json_schema_type
+class OpenAIChatCompletionToolCall(BaseModel):
+    index: Optional[int] = None
+    id: Optional[str] = None
+    type: Literal["function"] = "function"
+    function: Optional[OpenAIChatCompletionToolCallFunction] = None
+
+
+@json_schema_type
+class OpenAIAssistantMessageParam(BaseModel):
+    """A message containing the model's (assistant) response in an OpenAI-compatible chat completion request.
+
+    :param role: Must be "assistant" to identify this as the model's response
+    :param content: The content of the model's response
+    :param name: (Optional) The name of the assistant message participant.
+    :param tool_calls: List of tool calls. Each tool call is an OpenAIChatCompletionToolCall object.
+    """
+
+    role: Literal["assistant"] = "assistant"
+    content: OpenAIChatCompletionMessageContent
+    name: Optional[str] = None
+    tool_calls: Optional[List[OpenAIChatCompletionToolCall]] = Field(default_factory=list)
+
+
+@json_schema_type
+class OpenAIToolMessageParam(BaseModel):
+    """A message representing the result of a tool invocation in an OpenAI-compatible chat completion request.
+
+    :param role: Must be "tool" to identify this as a tool response
+    :param tool_call_id: Unique identifier for the tool call this response is for
+    :param content: The response content from the tool
+    """
+
+    role: Literal["tool"] = "tool"
+    tool_call_id: str
+    content: OpenAIChatCompletionMessageContent
+
+
+@json_schema_type
+class OpenAIDeveloperMessageParam(BaseModel):
+    """A message from the developer in an OpenAI-compatible chat completion request.
+
+    :param role: Must be "developer" to identify this as a developer message
+    :param content: The content of the developer message
+    :param name: (Optional) The name of the developer message participant.
+    """
+
+    role: Literal["developer"] = "developer"
+    content: OpenAIChatCompletionMessageContent
+    name: Optional[str] = None
+
+
+OpenAIMessageParam = Annotated[
+    Union[
+        OpenAIUserMessageParam,
+        OpenAISystemMessageParam,
+        OpenAIAssistantMessageParam,
+        OpenAIToolMessageParam,
+        OpenAIDeveloperMessageParam,
+    ],
+    Field(discriminator="role"),
+]
+register_schema(OpenAIMessageParam, name="OpenAIMessageParam")
+
+
+@json_schema_type
+class OpenAIResponseFormatText(BaseModel):
+    type: Literal["text"] = "text"
+
+
+@json_schema_type
+class OpenAIJSONSchema(TypedDict, total=False):
+    name: str
+    description: Optional[str] = None
+    strict: Optional[bool] = None
+
+    # Pydantic BaseModel cannot be used with a schema param, since it already
+    # has one. And, we don't want to alias here because then have to handle
+    # that alias when converting to OpenAI params. So, to support schema,
+    # we use a TypedDict.
+    schema: Optional[Dict[str, Any]] = None
+
+
+@json_schema_type
+class OpenAIResponseFormatJSONSchema(BaseModel):
+    type: Literal["json_schema"] = "json_schema"
+    json_schema: OpenAIJSONSchema
+
+
+@json_schema_type
+class OpenAIResponseFormatJSONObject(BaseModel):
+    type: Literal["json_object"] = "json_object"
+
+
+OpenAIResponseFormatParam = Annotated[
+    Union[
+        OpenAIResponseFormatText,
+        OpenAIResponseFormatJSONSchema,
+        OpenAIResponseFormatJSONObject,
+    ],
+    Field(discriminator="type"),
+]
+register_schema(OpenAIResponseFormatParam, name="OpenAIResponseFormatParam")
+
+
+@json_schema_type
+class OpenAITopLogProb(BaseModel):
+    """The top log probability for a token from an OpenAI-compatible chat completion response.
+
+    :token: The token
+    :bytes: (Optional) The bytes for the token
+    :logprob: The log probability of the token
+    """
+
+    token: str
+    bytes: Optional[List[int]] = None
+    logprob: float
+
+
+@json_schema_type
+class OpenAITokenLogProb(BaseModel):
+    """The log probability for a token from an OpenAI-compatible chat completion response.
+
+    :token: The token
+    :bytes: (Optional) The bytes for the token
+    :logprob: The log probability of the token
+    :top_logprobs: The top log probabilities for the token
+    """
+
+    token: str
+    bytes: Optional[List[int]] = None
+    logprob: float
+    top_logprobs: List[OpenAITopLogProb]
+
+
+@json_schema_type
+class OpenAIChoiceLogprobs(BaseModel):
+    """The log probabilities for the tokens in the message from an OpenAI-compatible chat completion response.
+
+    :param content: (Optional) The log probabilities for the tokens in the message
+    :param refusal: (Optional) The log probabilities for the tokens in the message
+    """
+
+    content: Optional[List[OpenAITokenLogProb]] = None
+    refusal: Optional[List[OpenAITokenLogProb]] = None
+
+
+@json_schema_type
+class OpenAIChoiceDelta(BaseModel):
+    """A delta from an OpenAI-compatible chat completion streaming response.
+
+    :param content: (Optional) The content of the delta
+    :param refusal: (Optional) The refusal of the delta
+    :param role: (Optional) The role of the delta
+    :param tool_calls: (Optional) The tool calls of the delta
+    """
+
+    content: Optional[str] = None
+    refusal: Optional[str] = None
+    role: Optional[str] = None
+    tool_calls: Optional[List[OpenAIChatCompletionToolCall]] = None
+
+
+@json_schema_type
+class OpenAIChunkChoice(BaseModel):
+    """A chunk choice from an OpenAI-compatible chat completion streaming response.
+
+    :param delta: The delta from the chunk
+    :param finish_reason: The reason the model stopped generating
+    :param index: The index of the choice
+    :param logprobs: (Optional) The log probabilities for the tokens in the message
+    """
+
+    delta: OpenAIChoiceDelta
+    finish_reason: str
+    index: int
+    logprobs: Optional[OpenAIChoiceLogprobs] = None
+
+
+@json_schema_type
+class OpenAIChoice(BaseModel):
+    """A choice from an OpenAI-compatible chat completion response.
+
+    :param message: The message from the model
+    :param finish_reason: The reason the model stopped generating
+    :param index: The index of the choice
+    :param logprobs: (Optional) The log probabilities for the tokens in the message
+    """
+
+    message: OpenAIMessageParam
+    finish_reason: str
+    index: int
+    logprobs: Optional[OpenAIChoiceLogprobs] = None
+
+
+@json_schema_type
+class OpenAIChatCompletion(BaseModel):
+    """Response from an OpenAI-compatible chat completion request.
+
+    :param id: The ID of the chat completion
+    :param choices: List of choices
+    :param object: The object type, which will be "chat.completion"
+    :param created: The Unix timestamp in seconds when the chat completion was created
+    :param model: The model that was used to generate the chat completion
+    """
+
+    id: str
+    choices: List[OpenAIChoice]
+    object: Literal["chat.completion"] = "chat.completion"
+    created: int
+    model: str
+
+
+@json_schema_type
+class OpenAIChatCompletionChunk(BaseModel):
+    """Chunk from a streaming response to an OpenAI-compatible chat completion request.
+
+    :param id: The ID of the chat completion
+    :param choices: List of choices
+    :param object: The object type, which will be "chat.completion.chunk"
+    :param created: The Unix timestamp in seconds when the chat completion was created
+    :param model: The model that was used to generate the chat completion
+    """
+
+    id: str
+    choices: List[OpenAIChunkChoice]
+    object: Literal["chat.completion.chunk"] = "chat.completion.chunk"
+    created: int
+    model: str
+
+
+@json_schema_type
+class OpenAICompletionLogprobs(BaseModel):
+    """The log probabilities for the tokens in the message from an OpenAI-compatible completion response.
+
+    :text_offset: (Optional) The offset of the token in the text
+    :token_logprobs: (Optional) The log probabilities for the tokens
+    :tokens: (Optional) The tokens
+    :top_logprobs: (Optional) The top log probabilities for the tokens
+    """
+
+    text_offset: Optional[List[int]] = None
+    token_logprobs: Optional[List[float]] = None
+    tokens: Optional[List[str]] = None
+    top_logprobs: Optional[List[Dict[str, float]]] = None
+
+
+@json_schema_type
+class OpenAICompletionChoice(BaseModel):
+    """A choice from an OpenAI-compatible completion response.
+
+    :finish_reason: The reason the model stopped generating
+    :text: The text of the choice
+    :index: The index of the choice
+    :logprobs: (Optional) The log probabilities for the tokens in the choice
+    """
+
+    finish_reason: str
+    text: str
+    index: int
+    logprobs: Optional[OpenAIChoiceLogprobs] = None
+
+
+@json_schema_type
+class OpenAICompletion(BaseModel):
+    """Response from an OpenAI-compatible completion request.
+
+    :id: The ID of the completion
+    :choices: List of choices
+    :created: The Unix timestamp in seconds when the completion was created
+    :model: The model that was used to generate the completion
+    :object: The object type, which will be "text_completion"
+    """
+
+    id: str
+    choices: List[OpenAICompletionChoice]
+    created: int
+    model: str
+    object: Literal["text_completion"] = "text_completion"
+
+
class ModelStore(Protocol):
    async def get_model(self, identifier: str) -> Model: ...
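For illustration, a minimal sketch of composing an OpenAI-compatible message list from the params defined above (assuming they are importable from `llama_stack.apis.inference`, the module this hunk modifies):

```python
# Compose a conversation using the discriminated OpenAIMessageParam union.
from llama_stack.apis.inference import (
    OpenAIAssistantMessageParam,
    OpenAISystemMessageParam,
    OpenAIUserMessageParam,
)

messages = [
    OpenAISystemMessageParam(content="You are a terse assistant."),
    OpenAIUserMessageParam(content="Name one vector database."),
    # A prior model turn can be replayed as an assistant message.
    OpenAIAssistantMessageParam(content="FAISS."),
    OpenAIUserMessageParam(content="Name another."),
]

# The "role" discriminator makes each message serialize unambiguously.
print([m.model_dump() for m in messages])
```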
@@ -421,6 +816,16 @@ class EmbeddingTaskType(Enum):
    document = "document"


+@json_schema_type
+class BatchCompletionResponse(BaseModel):
+    batch: List[CompletionResponse]
+
+
+@json_schema_type
+class BatchChatCompletionResponse(BaseModel):
+    batch: List[ChatCompletionResponse]
+
+
@runtime_checkable
@trace_protocol
class Inference(Protocol):

@@ -456,6 +861,17 @@ class Inference(Protocol):
        """
        ...

+    @webmethod(route="/inference/batch-completion", method="POST", experimental=True)
+    async def batch_completion(
+        self,
+        model_id: str,
+        content_batch: List[InterleavedContent],
+        sampling_params: Optional[SamplingParams] = None,
+        response_format: Optional[ResponseFormat] = None,
+        logprobs: Optional[LogProbConfig] = None,
+    ) -> BatchCompletionResponse:
+        raise NotImplementedError("Batch completion is not implemented")
+
    @webmethod(route="/inference/chat-completion", method="POST")
    async def chat_completion(
        self,

@@ -496,6 +912,19 @@ class Inference(Protocol):
        """
        ...

+    @webmethod(route="/inference/batch-chat-completion", method="POST", experimental=True)
+    async def batch_chat_completion(
+        self,
+        model_id: str,
+        messages_batch: List[List[Message]],
+        sampling_params: Optional[SamplingParams] = None,
+        tools: Optional[List[ToolDefinition]] = None,
+        tool_config: Optional[ToolConfig] = None,
+        response_format: Optional[ResponseFormat] = None,
+        logprobs: Optional[LogProbConfig] = None,
+    ) -> BatchChatCompletionResponse:
+        raise NotImplementedError("Batch chat completion is not implemented")
+
    @webmethod(route="/inference/embeddings", method="POST")
    async def embeddings(
        self,

@@ -515,3 +944,105 @@ class Inference(Protocol):
        :returns: An array of embeddings, one for each content. Each embedding is a list of floats. The dimensionality of the embedding is model-specific; you can check model metadata using /models/{model_id}
        """
        ...
+
+    @webmethod(route="/openai/v1/completions", method="POST")
+    async def openai_completion(
+        self,
+        # Standard OpenAI completion parameters
+        model: str,
+        prompt: Union[str, List[str], List[int], List[List[int]]],
+        best_of: Optional[int] = None,
+        echo: Optional[bool] = None,
+        frequency_penalty: Optional[float] = None,
+        logit_bias: Optional[Dict[str, float]] = None,
+        logprobs: Optional[bool] = None,
+        max_tokens: Optional[int] = None,
+        n: Optional[int] = None,
+        presence_penalty: Optional[float] = None,
+        seed: Optional[int] = None,
+        stop: Optional[Union[str, List[str]]] = None,
+        stream: Optional[bool] = None,
+        stream_options: Optional[Dict[str, Any]] = None,
+        temperature: Optional[float] = None,
+        top_p: Optional[float] = None,
+        user: Optional[str] = None,
+        # vLLM-specific parameters
+        guided_choice: Optional[List[str]] = None,
+        prompt_logprobs: Optional[int] = None,
+    ) -> OpenAICompletion:
+        """Generate an OpenAI-compatible completion for the given prompt using the specified model.
+
+        :param model: The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint.
+        :param prompt: The prompt to generate a completion for
+        :param best_of: (Optional) The number of completions to generate
+        :param echo: (Optional) Whether to echo the prompt
+        :param frequency_penalty: (Optional) The penalty for repeated tokens
+        :param logit_bias: (Optional) The logit bias to use
+        :param logprobs: (Optional) The log probabilities to use
+        :param max_tokens: (Optional) The maximum number of tokens to generate
+        :param n: (Optional) The number of completions to generate
+        :param presence_penalty: (Optional) The penalty for repeated tokens
+        :param seed: (Optional) The seed to use
+        :param stop: (Optional) The stop tokens to use
+        :param stream: (Optional) Whether to stream the response
+        :param stream_options: (Optional) The stream options to use
+        :param temperature: (Optional) The temperature to use
+        :param top_p: (Optional) The top p to use
+        :param user: (Optional) The user to use
+        """
+        ...
+
+    @webmethod(route="/openai/v1/chat/completions", method="POST")
+    async def openai_chat_completion(
+        self,
+        model: str,
+        messages: List[OpenAIMessageParam],
+        frequency_penalty: Optional[float] = None,
+        function_call: Optional[Union[str, Dict[str, Any]]] = None,
+        functions: Optional[List[Dict[str, Any]]] = None,
+        logit_bias: Optional[Dict[str, float]] = None,
+        logprobs: Optional[bool] = None,
+        max_completion_tokens: Optional[int] = None,
+        max_tokens: Optional[int] = None,
+        n: Optional[int] = None,
+        parallel_tool_calls: Optional[bool] = None,
+        presence_penalty: Optional[float] = None,
+        response_format: Optional[OpenAIResponseFormatParam] = None,
+        seed: Optional[int] = None,
+        stop: Optional[Union[str, List[str]]] = None,
+        stream: Optional[bool] = None,
+        stream_options: Optional[Dict[str, Any]] = None,
+        temperature: Optional[float] = None,
+        tool_choice: Optional[Union[str, Dict[str, Any]]] = None,
+        tools: Optional[List[Dict[str, Any]]] = None,
+        top_logprobs: Optional[int] = None,
+        top_p: Optional[float] = None,
+        user: Optional[str] = None,
+    ) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]:
+        """Generate an OpenAI-compatible chat completion for the given messages using the specified model.
+
+        :param model: The identifier of the model to use. The model must be registered with Llama Stack and available via the /models endpoint.
+        :param messages: List of messages in the conversation
+        :param frequency_penalty: (Optional) The penalty for repeated tokens
+        :param function_call: (Optional) The function call to use
+        :param functions: (Optional) List of functions to use
+        :param logit_bias: (Optional) The logit bias to use
+        :param logprobs: (Optional) The log probabilities to use
+        :param max_completion_tokens: (Optional) The maximum number of tokens to generate
+        :param max_tokens: (Optional) The maximum number of tokens to generate
+        :param n: (Optional) The number of completions to generate
+        :param parallel_tool_calls: (Optional) Whether to parallelize tool calls
+        :param presence_penalty: (Optional) The penalty for repeated tokens
+        :param response_format: (Optional) The response format to use
+        :param seed: (Optional) The seed to use
+        :param stop: (Optional) The stop tokens to use
+        :param stream: (Optional) Whether to stream the response
+        :param stream_options: (Optional) The stream options to use
+        :param temperature: (Optional) The temperature to use
+        :param tool_choice: (Optional) The tool choice to use
+        :param tools: (Optional) The tools to use
+        :param top_logprobs: (Optional) The top log probabilities to use
+        :param top_p: (Optional) The top p to use
+        :param user: (Optional) The user to use
+        """
+        ...
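A sketch of calling the new OpenAI-compatible route from a client. Assumptions: the stack server listens on `127.0.0.1:8321`, webmethod routes are served under the `/v1` prefix (making the full path `/v1/openai/v1/chat/completions`), and a model with the id used below is registered; adjust all three for your deployment.

```python
# Minimal HTTP call against the OpenAI-compatible chat completion endpoint.
import httpx

payload = {
    "model": "llama3.2:3b-instruct-fp16",  # any model registered with the stack
    "messages": [{"role": "user", "content": "Say hello in one word."}],
    "temperature": 0.2,
    "stream": False,
}

resp = httpx.post(
    "http://127.0.0.1:8321/v1/openai/v1/chat/completions",
    json=payload,
    timeout=60,
)
resp.raise_for_status()
completion = resp.json()
print(completion["choices"][0]["message"]["content"])
```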
@@ -8,6 +8,7 @@ from typing import List, Protocol, runtime_checkable

from pydantic import BaseModel

+from llama_stack.providers.datatypes import HealthStatus
from llama_stack.schema_utils import json_schema_type, webmethod


@@ -20,8 +21,7 @@ class RouteInfo(BaseModel):

@json_schema_type
class HealthInfo(BaseModel):
-    status: str
+    status: HealthStatus
-    # TODO: add a provider level status


@json_schema_type
@@ -56,12 +56,35 @@ class ListModelsResponse(BaseModel):
    data: List[Model]


+@json_schema_type
+class OpenAIModel(BaseModel):
+    """A model from OpenAI.
+
+    :id: The ID of the model
+    :object: The object type, which will be "model"
+    :created: The Unix timestamp in seconds when the model was created
+    :owned_by: The owner of the model
+    """
+
+    id: str
+    object: Literal["model"] = "model"
+    created: int
+    owned_by: str
+
+
+class OpenAIListModelsResponse(BaseModel):
+    data: List[OpenAIModel]
+
+
@runtime_checkable
@trace_protocol
class Models(Protocol):
    @webmethod(route="/models", method="GET")
    async def list_models(self) -> ListModelsResponse: ...

+    @webmethod(route="/openai/v1/models", method="GET")
+    async def openai_list_models(self) -> OpenAIListModelsResponse: ...
+
    @webmethod(route="/models/{model_id:path}", method="GET")
    async def get_model(
        self,
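For illustration, a sketch of listing models through the new OpenAI-compatible endpoint. It assumes the server's default address `127.0.0.1:8321` and the `/v1` route prefix (full path `/v1/openai/v1/models`); both are assumptions to adjust for your setup.

```python
# Fetch the OpenAI-style model listing from a running stack server.
import httpx

resp = httpx.get("http://127.0.0.1:8321/v1/openai/v1/models")
resp.raise_for_status()
for model in resp.json()["data"]:
    print(model["id"], "owned by", model["owned_by"])
```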
@@ -60,11 +60,11 @@ class EfficiencyConfig(BaseModel):
@json_schema_type
class TrainingConfig(BaseModel):
    n_epochs: int
-    max_steps_per_epoch: int
+    max_steps_per_epoch: int = 1
-    gradient_accumulation_steps: int
+    gradient_accumulation_steps: int = 1
-    max_validation_steps: int
+    max_validation_steps: Optional[int] = 1
-    data_config: DataConfig
+    data_config: Optional[DataConfig] = None
-    optimizer_config: OptimizerConfig
+    optimizer_config: Optional[OptimizerConfig] = None
    efficiency_config: Optional[EfficiencyConfig] = None
    dtype: Optional[str] = "bf16"

@@ -177,9 +177,9 @@ class PostTraining(Protocol):
        training_config: TrainingConfig,
        hyperparam_search_config: Dict[str, Any],
        logger_config: Dict[str, Any],
-        model: str = Field(
+        model: Optional[str] = Field(
-            default="Llama3.2-3B-Instruct",
+            default=None,
-            description="Model descriptor from `llama model list`",
+            description="Model descriptor for training if not in provider config`",
        ),
        checkpoint_dir: Optional[str] = None,
        algorithm_config: Optional[AlgorithmConfig] = None,
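A small sketch of what the relaxed schema permits, assuming `TrainingConfig` is importable from `llama_stack.apis.post_training` (the module path is an assumption):

```python
# With the new defaults, a minimal training config only needs the epoch count;
# step counts get defaults and data/optimizer settings become optional.
from llama_stack.apis.post_training import TrainingConfig

minimal = TrainingConfig(n_epochs=1)
print(minimal.max_steps_per_epoch)  # 1, the new default
print(minimal.data_config)          # None - may instead come from the provider config
```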
@@ -8,6 +8,7 @@ from typing import Any, Dict, List, Protocol, runtime_checkable

from pydantic import BaseModel

+from llama_stack.providers.datatypes import HealthResponse
from llama_stack.schema_utils import json_schema_type, webmethod


@@ -17,6 +18,7 @@ class ProviderInfo(BaseModel):
    provider_id: str
    provider_type: str
    config: Dict[str, Any]
+    health: HealthResponse


class ListProvidersResponse(BaseModel):
@@ -29,8 +29,8 @@ from rich.progress import (
from termcolor import cprint

from llama_stack.cli.subcommand import Subcommand
-from llama_stack.models.llama.datatypes import Model
from llama_stack.models.llama.sku_list import LlamaDownloadInfo
+from llama_stack.models.llama.sku_types import Model


class Download(Subcommand):

@@ -162,6 +162,10 @@ class ParallelDownloader:
            raise last_exception

    async def get_file_info(self, client: httpx.AsyncClient, task: DownloadTask) -> None:
+        if task.total_size > 0:
+            self.progress.update(task.task_id, total=task.total_size)
+            return
+
        async def _get_info():
            response = await client.head(task.url, headers={"Accept-Encoding": "identity"}, **self.client_options)
            response.raise_for_status()

@@ -282,7 +286,7 @@ class ParallelDownloader:
        if not tasks:
            raise ValueError("No download tasks provided")

-        if not self.has_disk_space(tasks):
+        if not os.environ.get("LLAMA_DOWNLOAD_NO_SPACE_CHECK") and not self.has_disk_space(tasks):
            raise DownloadError("Insufficient disk space for downloads")

        failed_tasks = []
@@ -63,17 +63,6 @@ class ModelDescribe(Subcommand):
            ("Model params.json", json.dumps(model.arch_args, indent=4)),
        ]

-        if model.recommended_sampling_params is not None:
-            sampling_params = model.recommended_sampling_params.model_dump()
-            for k in ("max_tokens", "repetition_penalty"):
-                del sampling_params[k]
-            rows.append(
-                (
-                    "Recommended sampling params",
-                    json.dumps(sampling_params, indent=4),
-                )
-            )
-
        print_table(
            rows,
            headers,
@@ -11,7 +11,7 @@ from pathlib import Path

from llama_stack.cli.subcommand import Subcommand
from llama_stack.cli.table import print_table
-from llama_stack.models.llama.datatypes import CoreModelId, ModelFamily, is_multimodal, model_family
+from llama_stack.models.llama.sku_types import CoreModelId, ModelFamily, is_multimodal, model_family

ROOT_DIR = Path(__file__).parent.parent.parent
@@ -4,12 +4,12 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

-from typing import Any, Dict, Optional
+from typing import Any, Dict

from pydantic import BaseModel, ConfigDict, Field

-from llama_stack.models.llama.datatypes import CheckpointQuantizationFormat, SamplingParams
from llama_stack.models.llama.sku_list import LlamaDownloadInfo
+from llama_stack.models.llama.sku_types import CheckpointQuantizationFormat


class PromptGuardModel(BaseModel):

@@ -23,7 +23,6 @@ class PromptGuardModel(BaseModel):
    is_instruct_model: bool = False
    quantization_format: CheckpointQuantizationFormat = CheckpointQuantizationFormat.bf16
    arch_args: Dict[str, Any] = Field(default_factory=dict)
-    recommended_sampling_params: Optional[SamplingParams] = None

    def descriptor(self) -> str:
        return self.model_id
@@ -89,6 +89,43 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
                color="red",
            )
            sys.exit(1)
+    elif args.providers:
+        providers = dict()
+        for api_provider in args.providers.split(","):
+            if "=" not in api_provider:
+                cprint(
+                    "Could not parse `--providers`. Please ensure the list is in the format api1=provider1,api2=provider2",
+                    color="red",
+                )
+                sys.exit(1)
+            api, provider = api_provider.split("=")
+            providers_for_api = get_provider_registry().get(Api(api), None)
+            if providers_for_api is None:
+                cprint(
+                    f"{api} is not a valid API.",
+                    color="red",
+                )
+                sys.exit(1)
+            if provider in providers_for_api:
+                providers.setdefault(api, []).append(provider)
+            else:
+                cprint(
+                    f"{provider} is not a valid provider for the {api} API.",
+                    color="red",
+                )
+                sys.exit(1)
+        distribution_spec = DistributionSpec(
+            providers=providers,
+            description=",".join(args.providers),
+        )
+        if not args.image_type:
+            cprint(
+                f"Please specify a image-type (container | conda | venv) for {args.template}",
+                color="red",
+            )
+            sys.exit(1)
+
+        build_config = BuildConfig(image_type=args.image_type, distribution_spec=distribution_spec)
    elif not args.config and not args.template:
        name = prompt(
            "> Enter a name for your Llama Stack (e.g. my-local-stack): ",
@@ -57,7 +57,7 @@ class StackBuild(Subcommand):
            type=str,
            help=textwrap.dedent(
                f"""[for image-type={"|".join(e.value for e in ImageType)}] Name of the conda or virtual environment to use for
-the build. If not specified, currently active Conda environment will be used if found.
+the build. If not specified, currently active environment will be used if found.
            """
            ),
            default=None,

@@ -75,6 +75,12 @@ the build. If not specified, currently active Conda environment will be used if
            default=False,
            help="Run the stack after building using the same image type, name, and other applicable arguments",
        )
+        self.parser.add_argument(
+            "--providers",
+            type=str,
+            default=None,
+            help="Build a config for a list of providers and only those providers. This list is formatted like: api1=provider1,api2=provider2. Where there can be multiple providers per API.",
+        )

    def _run_stack_build_command(self, args: argparse.Namespace) -> None:
        # always keep implementation completely silo-ed away from CLI so CLI
@@ -45,7 +45,7 @@ class StackRun(Subcommand):
            "--image-name",
            type=str,
            default=os.environ.get("CONDA_DEFAULT_ENV"),
-            help="Name of the image to run. Defaults to the current conda environment",
+            help="Name of the image to run. Defaults to the current environment",
        )
        self.parser.add_argument(
            "--disable-ipv6",
@@ -312,6 +312,11 @@ a default SQLite store will be used.""",
        description="Configuration for the HTTP(S) server",
    )

+    external_providers_dir: Optional[str] = Field(
+        default=None,
+        description="Path to directory containing external provider implementations. The providers code and dependencies must be installed on the system.",
+    )
+

class BuildConfig(BaseModel):
    version: str = LLAMA_STACK_BUILD_CONFIG_VERSION
@@ -4,12 +4,25 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

+import glob
 import importlib
-from typing import Dict, List
+import os
+from typing import Any, Dict, List

+import yaml
 from pydantic import BaseModel

-from llama_stack.providers.datatypes import Api, ProviderSpec
+from llama_stack.distribution.datatypes import StackRunConfig
+from llama_stack.log import get_logger
+from llama_stack.providers.datatypes import (
+    AdapterSpec,
+    Api,
+    InlineProviderSpec,
+    ProviderSpec,
+    remote_provider_spec,
+)
+
+logger = get_logger(name=__name__, category="core")


 def stack_apis() -> List[Api]:
@@ -59,11 +72,116 @@ def providable_apis() -> List[Api]:
     return [api for api in Api if api not in routing_table_apis and api != Api.inspect and api != Api.providers]


-def get_provider_registry() -> Dict[Api, Dict[str, ProviderSpec]]:
-    ret = {}
+def _load_remote_provider_spec(spec_data: Dict[str, Any], api: Api) -> ProviderSpec:
+    adapter = AdapterSpec(**spec_data["adapter"])
+    spec = remote_provider_spec(
+        api=api,
+        adapter=adapter,
+        api_dependencies=[Api(dep) for dep in spec_data.get("api_dependencies", [])],
+    )
+    return spec
+
+
+def _load_inline_provider_spec(spec_data: Dict[str, Any], api: Api, provider_name: str) -> ProviderSpec:
+    spec = InlineProviderSpec(
+        api=api,
+        provider_type=f"inline::{provider_name}",
+        pip_packages=spec_data.get("pip_packages", []),
+        module=spec_data["module"],
+        config_class=spec_data["config_class"],
+        api_dependencies=[Api(dep) for dep in spec_data.get("api_dependencies", [])],
+        optional_api_dependencies=[Api(dep) for dep in spec_data.get("optional_api_dependencies", [])],
+        provider_data_validator=spec_data.get("provider_data_validator"),
+        container_image=spec_data.get("container_image"),
+    )
+    return spec
+
+
+def get_provider_registry(config: StackRunConfig | None = None) -> Dict[Api, Dict[str, ProviderSpec]]:
+    """Get the provider registry, optionally including external providers.
+
+    This function loads both built-in providers and external providers from YAML files.
+    External providers are loaded from a directory structure like:
+
+    providers.d/
+      remote/
+        inference/
+          custom_ollama.yaml
+          vllm.yaml
+        vector_io/
+          qdrant.yaml
+        safety/
+          llama-guard.yaml
+      inline/
+        inference/
+          custom_ollama.yaml
+          vllm.yaml
+        vector_io/
+          qdrant.yaml
+        safety/
+          llama-guard.yaml
+
+    Args:
+        config: Optional StackRunConfig containing the external providers directory path
+
+    Returns:
+        A dictionary mapping APIs to their available providers
+
+    Raises:
+        FileNotFoundError: If the external providers directory doesn't exist
+        ValueError: If any provider spec is invalid
+    """
+
+    ret: Dict[Api, Dict[str, ProviderSpec]] = {}
     for api in providable_apis():
         name = api.name.lower()
-        module = importlib.import_module(f"llama_stack.providers.registry.{name}")
-        ret[api] = {a.provider_type: a for a in module.available_providers()}
+        logger.debug(f"Importing module {name}")
+        try:
+            module = importlib.import_module(f"llama_stack.providers.registry.{name}")
+            ret[api] = {a.provider_type: a for a in module.available_providers()}
+        except ImportError as e:
+            logger.warning(f"Failed to import module {name}: {e}")
+
+    if config and config.external_providers_dir:
+        external_providers_dir = os.path.abspath(config.external_providers_dir)
+        if not os.path.exists(external_providers_dir):
+            raise FileNotFoundError(f"External providers directory not found: {external_providers_dir}")
+        logger.info(f"Loading external providers from {external_providers_dir}")
+
+        for api in providable_apis():
+            api_name = api.name.lower()
+
+            # Process both remote and inline providers
+            for provider_type in ["remote", "inline"]:
+                api_dir = os.path.join(external_providers_dir, provider_type, api_name)
+                if not os.path.exists(api_dir):
+                    logger.debug(f"No {provider_type} provider directory found for {api_name}")
+                    continue
+
+                # Look for provider spec files in the API directory
+                for spec_path in glob.glob(os.path.join(api_dir, "*.yaml")):
+                    provider_name = os.path.splitext(os.path.basename(spec_path))[0]
+                    logger.info(f"Loading {provider_type} provider spec from {spec_path}")
+
+                    try:
+                        with open(spec_path) as f:
+                            spec_data = yaml.safe_load(f)
+
+                        if provider_type == "remote":
+                            spec = _load_remote_provider_spec(spec_data, api)
+                            provider_type_key = f"remote::{provider_name}"
+                        else:
+                            spec = _load_inline_provider_spec(spec_data, api, provider_name)
+                            provider_type_key = f"inline::{provider_name}"
+
+                        logger.info(f"Loaded {provider_type} provider spec for {provider_type_key} from {spec_path}")
+                        if provider_type_key in ret[api]:
+                            logger.warning(f"Overriding already registered provider {provider_type_key} for {api.name}")
+                        ret[api][provider_type_key] = spec
+                    except yaml.YAMLError as yaml_err:
+                        logger.error(f"Failed to parse YAML file {spec_path}: {yaml_err}")
+                        raise yaml_err
+                    except Exception as e:
+                        logger.error(f"Failed to load provider spec from {spec_path}: {e}")
+                        raise e
     return ret
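To make the loader above concrete, a sketch of what a single remote provider spec file (for example `providers.d/remote/inference/custom_ollama.yaml`) could contain. The adapter keys mirror `AdapterSpec`; the `module` and `config_class` values are placeholder package names, not a published provider:

```python
import yaml

spec_data = yaml.safe_load(
    """
adapter:
  adapter_type: custom_ollama
  pip_packages: ["ollama", "aiohttp"]
  config_class: llama_stack_ollama_provider.config.OllamaImplConfig
  module: llama_stack_ollama_provider
api_dependencies: []
"""
)
# This is the dictionary shape that _load_remote_provider_spec(spec_data, api) expects.
print(spec_data["adapter"]["adapter_type"])  # custom_ollama
```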
@@ -17,6 +17,7 @@ from llama_stack.apis.inspect import (
 )
 from llama_stack.distribution.datatypes import StackRunConfig
 from llama_stack.distribution.server.endpoints import get_all_api_endpoints
+from llama_stack.providers.datatypes import HealthStatus


 class DistributionInspectConfig(BaseModel):
@@ -58,7 +59,7 @@ class DistributionInspectImpl(Inspect):
         return ListRoutesResponse(data=ret)

     async def health(self) -> HealthInfo:
-        return HealthInfo(status="OK")
+        return HealthInfo(status=HealthStatus.OK)

     async def version(self) -> VersionInfo:
         return VersionInfo(version=version("llama-stack"))
@@ -43,9 +43,9 @@ from llama_stack.distribution.server.endpoints import (
 from llama_stack.distribution.stack import (
     construct_stack,
     get_stack_run_config_from_template,
-    redact_sensitive_fields,
     replace_env_vars,
 )
+from llama_stack.distribution.utils.config import redact_sensitive_fields
 from llama_stack.distribution.utils.context import preserve_contexts_async_generator
 from llama_stack.distribution.utils.exec import in_notebook
 from llama_stack.providers.utils.telemetry.tracing import (
@@ -4,14 +4,17 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

+import asyncio
+from typing import Any, Dict
+
 from pydantic import BaseModel

 from llama_stack.apis.providers import ListProvidersResponse, ProviderInfo, Providers
 from llama_stack.log import get_logger
+from llama_stack.providers.datatypes import HealthResponse, HealthStatus

 from .datatypes import StackRunConfig
-from .stack import redact_sensitive_fields
+from .utils.config import redact_sensitive_fields

 logger = get_logger(name=__name__, category="core")

@@ -41,19 +44,24 @@ class ProviderImpl(Providers):
     async def list_providers(self) -> ListProvidersResponse:
         run_config = self.config.run_config
         safe_config = StackRunConfig(**redact_sensitive_fields(run_config.model_dump()))
+        providers_health = await self.get_providers_health()
         ret = []
         for api, providers in safe_config.providers.items():
-            ret.extend(
-                [
-                    ProviderInfo(
-                        api=api,
-                        provider_id=p.provider_id,
-                        provider_type=p.provider_type,
-                        config=p.config,
-                    )
-                    for p in providers
-                ]
-            )
+            for p in providers:
+                ret.append(
+                    ProviderInfo(
+                        api=api,
+                        provider_id=p.provider_id,
+                        provider_type=p.provider_type,
+                        config=p.config,
+                        health=providers_health.get(api, {}).get(
+                            p.provider_id,
+                            HealthResponse(
+                                status=HealthStatus.NOT_IMPLEMENTED, message="Provider does not implement health check"
+                            ),
+                        ),
+                    )
+                )

         return ListProvidersResponse(data=ret)

@@ -64,3 +72,57 @@ class ProviderImpl(Providers):
                 return p

         raise ValueError(f"Provider {provider_id} not found")
+
+    async def get_providers_health(self) -> Dict[str, Dict[str, HealthResponse]]:
+        """Get health status for all providers.
+
+        Returns:
+            Dict[str, Dict[str, HealthResponse]]: A dictionary mapping API names to provider health statuses.
+                Each API maps to a dictionary of provider IDs to their health responses.
+        """
+        providers_health: Dict[str, Dict[str, HealthResponse]] = {}
+        timeout = 1.0
+
+        async def check_provider_health(impl: Any) -> tuple[str, HealthResponse] | None:
+            # Skip special implementations (inspect/providers) that don't have provider specs
+            if not hasattr(impl, "__provider_spec__"):
+                return None
+            api_name = impl.__provider_spec__.api.name
+            if not hasattr(impl, "health"):
+                return (
+                    api_name,
+                    HealthResponse(
+                        status=HealthStatus.NOT_IMPLEMENTED, message="Provider does not implement health check"
+                    ),
+                )
+
+            try:
+                health = await asyncio.wait_for(impl.health(), timeout=timeout)
+                return api_name, health
+            except asyncio.TimeoutError:
+                return (
+                    api_name,
+                    HealthResponse(
+                        status=HealthStatus.ERROR, message=f"Health check timed out after {timeout} seconds"
+                    ),
+                )
+            except Exception as e:
+                return (
+                    api_name,
+                    HealthResponse(status=HealthStatus.ERROR, message=f"Health check failed: {str(e)}"),
+                )
+
+        # Create tasks for all providers
+        tasks = [check_provider_health(impl) for impl in self.deps.values()]
+
+        # Wait for all health checks to complete
+        results = await asyncio.gather(*tasks)
+
+        # Organize results by API and provider ID
+        for result in results:
+            if result is None:  # Skip special implementations
+                continue
+            api_name, health_response = result
+            providers_health[api_name] = health_response
+
+        return providers_health
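The health aggregation above fans out one bounded check per provider. A self-contained sketch of the same pattern with toy probes (not Llama Stack APIs), showing how a slow provider is reported as an error instead of blocking the others:

```python
import asyncio


async def probe(name: str, delay: float) -> tuple[str, str]:
    # Toy stand-in for a provider's health() implementation.
    await asyncio.sleep(delay)
    return name, "OK"


async def check(name: str, delay: float, timeout: float = 1.0) -> tuple[str, str]:
    try:
        return await asyncio.wait_for(probe(name, delay), timeout=timeout)
    except asyncio.TimeoutError:
        return name, f"ERROR: timed out after {timeout} seconds"


async def main() -> None:
    results = await asyncio.gather(check("fast-provider", 0.1), check("slow-provider", 5.0))
    print(dict(results))  # {'fast-provider': 'OK', 'slow-provider': 'ERROR: timed out after 1.0 seconds'}


asyncio.run(main())
```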
@@ -41,7 +41,6 @@ from llama_stack.providers.datatypes import (
     Api,
     BenchmarksProtocolPrivate,
     DatasetsProtocolPrivate,
-    InlineProviderSpec,
     ModelsProtocolPrivate,
     ProviderSpec,
     RemoteProviderConfig,
@@ -230,50 +229,9 @@ def sort_providers_by_deps(
         {k: list(v.values()) for k, v in providers_with_specs.items()}
     )

-    # Append built-in "inspect" provider
-    apis = [x[1].spec.api for x in sorted_providers]
-    sorted_providers.append(
-        (
-            "inspect",
-            ProviderWithSpec(
-                provider_id="__builtin__",
-                provider_type="__builtin__",
-                config={"run_config": run_config.model_dump()},
-                spec=InlineProviderSpec(
-                    api=Api.inspect,
-                    provider_type="__builtin__",
-                    config_class="llama_stack.distribution.inspect.DistributionInspectConfig",
-                    module="llama_stack.distribution.inspect",
-                    api_dependencies=apis,
-                    deps__=[x.value for x in apis],
-                ),
-            ),
-        )
-    )
-
-    sorted_providers.append(
-        (
-            "providers",
-            ProviderWithSpec(
-                provider_id="__builtin__",
-                provider_type="__builtin__",
-                config={"run_config": run_config.model_dump()},
-                spec=InlineProviderSpec(
-                    api=Api.providers,
-                    provider_type="__builtin__",
-                    config_class="llama_stack.distribution.providers.ProviderImplConfig",
-                    module="llama_stack.distribution.providers",
-                    api_dependencies=apis,
-                    deps__=[x.value for x in apis],
-                ),
-            ),
-        )
-    )
-
     logger.debug(f"Resolved {len(sorted_providers)} providers")
     for api_str, provider in sorted_providers:
         logger.debug(f" {api_str} => {provider.provider_id}")
-    logger.debug("")
     return sorted_providers
@@ -351,6 +309,7 @@ async def instantiate_provider(
     if not hasattr(provider_spec, "module"):
         raise AttributeError(f"ProviderSpec of type {type(provider_spec)} does not have a 'module' attribute")

+    logger.debug(f"Instantiating provider {provider.provider_id} from {provider_spec.module}")
     module = importlib.import_module(provider_spec.module)
     args = []
     if isinstance(provider_spec, RemoteProviderSpec):
@@ -399,6 +358,8 @@ def check_protocol_compliance(obj: Any, protocol: Any) -> None:
     mro = type(obj).__mro__
     for name, value in inspect.getmembers(protocol):
         if inspect.isfunction(value) and hasattr(value, "__webmethod__"):
+            if value.__webmethod__.experimental:
+                continue
             if not hasattr(obj, name):
                 missing_methods.append((name, "missing"))
             elif not callable(getattr(obj, name)):
@@ -4,6 +4,7 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

+import asyncio
 import time
 from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Union

@@ -17,6 +18,8 @@ from llama_stack.apis.datasetio import DatasetIO
 from llama_stack.apis.datasets import DatasetPurpose, DataSource
 from llama_stack.apis.eval import BenchmarkConfig, Eval, EvaluateResponse, Job
 from llama_stack.apis.inference import (
+    BatchChatCompletionResponse,
+    BatchCompletionResponse,
     ChatCompletionResponse,
     ChatCompletionResponseEventType,
     ChatCompletionResponseStreamChunk,
@@ -35,6 +38,13 @@ from llama_stack.apis.inference import (
     ToolDefinition,
     ToolPromptFormat,
 )
+from llama_stack.apis.inference.inference import (
+    OpenAIChatCompletion,
+    OpenAIChatCompletionChunk,
+    OpenAICompletion,
+    OpenAIMessageParam,
+    OpenAIResponseFormatParam,
+)
 from llama_stack.apis.models import Model, ModelType
 from llama_stack.apis.safety import RunShieldResponse, Safety
 from llama_stack.apis.scoring import (
@@ -57,7 +67,7 @@ from llama_stack.apis.vector_io import Chunk, QueryChunksResponse, VectorIO
 from llama_stack.log import get_logger
 from llama_stack.models.llama.llama3.chat_format import ChatFormat
 from llama_stack.models.llama.llama3.tokenizer import Tokenizer
-from llama_stack.providers.datatypes import RoutingTable
+from llama_stack.providers.datatypes import HealthResponse, HealthStatus, RoutingTable
 from llama_stack.providers.utils.telemetry.tracing import get_current_span

 logger = get_logger(name=__name__, category="core")
@@ -333,6 +343,30 @@ class InferenceRouter(Inference):
             response.metrics = metrics if response.metrics is None else response.metrics + metrics
         return response

+    async def batch_chat_completion(
+        self,
+        model_id: str,
+        messages_batch: List[List[Message]],
+        tools: Optional[List[ToolDefinition]] = None,
+        tool_config: Optional[ToolConfig] = None,
+        sampling_params: Optional[SamplingParams] = None,
+        response_format: Optional[ResponseFormat] = None,
+        logprobs: Optional[LogProbConfig] = None,
+    ) -> BatchChatCompletionResponse:
+        logger.debug(
+            f"InferenceRouter.batch_chat_completion: {model_id=}, {len(messages_batch)=}, {sampling_params=}, {response_format=}, {logprobs=}",
+        )
+        provider = self.routing_table.get_provider_impl(model_id)
+        return await provider.batch_chat_completion(
+            model_id=model_id,
+            messages_batch=messages_batch,
+            tools=tools,
+            tool_config=tool_config,
+            sampling_params=sampling_params,
+            response_format=response_format,
+            logprobs=logprobs,
+        )
+
     async def completion(
         self,
         model_id: str,
@@ -397,6 +431,20 @@ class InferenceRouter(Inference):
             response.metrics = metrics if response.metrics is None else response.metrics + metrics
         return response

+    async def batch_completion(
+        self,
+        model_id: str,
+        content_batch: List[InterleavedContent],
+        sampling_params: Optional[SamplingParams] = None,
+        response_format: Optional[ResponseFormat] = None,
+        logprobs: Optional[LogProbConfig] = None,
+    ) -> BatchCompletionResponse:
+        logger.debug(
+            f"InferenceRouter.batch_completion: {model_id=}, {len(content_batch)=}, {sampling_params=}, {response_format=}, {logprobs=}",
+        )
+        provider = self.routing_table.get_provider_impl(model_id)
+        return await provider.batch_completion(model_id, content_batch, sampling_params, response_format, logprobs)
+
     async def embeddings(
         self,
         model_id: str,
@@ -419,6 +467,149 @@ class InferenceRouter(Inference):
             task_type=task_type,
         )

+    async def openai_completion(
+        self,
+        model: str,
+        prompt: Union[str, List[str], List[int], List[List[int]]],
+        best_of: Optional[int] = None,
+        echo: Optional[bool] = None,
+        frequency_penalty: Optional[float] = None,
+        logit_bias: Optional[Dict[str, float]] = None,
+        logprobs: Optional[bool] = None,
+        max_tokens: Optional[int] = None,
+        n: Optional[int] = None,
+        presence_penalty: Optional[float] = None,
+        seed: Optional[int] = None,
+        stop: Optional[Union[str, List[str]]] = None,
+        stream: Optional[bool] = None,
+        stream_options: Optional[Dict[str, Any]] = None,
+        temperature: Optional[float] = None,
+        top_p: Optional[float] = None,
+        user: Optional[str] = None,
+        guided_choice: Optional[List[str]] = None,
+        prompt_logprobs: Optional[int] = None,
+    ) -> OpenAICompletion:
+        logger.debug(
+            f"InferenceRouter.openai_completion: {model=}, {stream=}, {prompt=}",
+        )
+        model_obj = await self.routing_table.get_model(model)
+        if model_obj is None:
+            raise ValueError(f"Model '{model}' not found")
+        if model_obj.model_type == ModelType.embedding:
+            raise ValueError(f"Model '{model}' is an embedding model and does not support completions")
+
+        params = dict(
+            model=model_obj.identifier,
+            prompt=prompt,
+            best_of=best_of,
+            echo=echo,
+            frequency_penalty=frequency_penalty,
+            logit_bias=logit_bias,
+            logprobs=logprobs,
+            max_tokens=max_tokens,
+            n=n,
+            presence_penalty=presence_penalty,
+            seed=seed,
+            stop=stop,
+            stream=stream,
+            stream_options=stream_options,
+            temperature=temperature,
+            top_p=top_p,
+            user=user,
+            guided_choice=guided_choice,
+            prompt_logprobs=prompt_logprobs,
+        )
+
+        provider = self.routing_table.get_provider_impl(model_obj.identifier)
+        return await provider.openai_completion(**params)
+
+    async def openai_chat_completion(
+        self,
+        model: str,
+        messages: List[OpenAIMessageParam],
+        frequency_penalty: Optional[float] = None,
+        function_call: Optional[Union[str, Dict[str, Any]]] = None,
+        functions: Optional[List[Dict[str, Any]]] = None,
+        logit_bias: Optional[Dict[str, float]] = None,
+        logprobs: Optional[bool] = None,
+        max_completion_tokens: Optional[int] = None,
+        max_tokens: Optional[int] = None,
+        n: Optional[int] = None,
+        parallel_tool_calls: Optional[bool] = None,
+        presence_penalty: Optional[float] = None,
+        response_format: Optional[OpenAIResponseFormatParam] = None,
+        seed: Optional[int] = None,
+        stop: Optional[Union[str, List[str]]] = None,
+        stream: Optional[bool] = None,
+        stream_options: Optional[Dict[str, Any]] = None,
+        temperature: Optional[float] = None,
+        tool_choice: Optional[Union[str, Dict[str, Any]]] = None,
+        tools: Optional[List[Dict[str, Any]]] = None,
+        top_logprobs: Optional[int] = None,
+        top_p: Optional[float] = None,
+        user: Optional[str] = None,
+    ) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]:
+        logger.debug(
+            f"InferenceRouter.openai_chat_completion: {model=}, {stream=}, {messages=}",
+        )
+        model_obj = await self.routing_table.get_model(model)
+        if model_obj is None:
+            raise ValueError(f"Model '{model}' not found")
+        if model_obj.model_type == ModelType.embedding:
+            raise ValueError(f"Model '{model}' is an embedding model and does not support chat completions")
+
+        params = dict(
+            model=model_obj.identifier,
+            messages=messages,
+            frequency_penalty=frequency_penalty,
+            function_call=function_call,
+            functions=functions,
+            logit_bias=logit_bias,
+            logprobs=logprobs,
+            max_completion_tokens=max_completion_tokens,
+            max_tokens=max_tokens,
+            n=n,
+            parallel_tool_calls=parallel_tool_calls,
+            presence_penalty=presence_penalty,
+            response_format=response_format,
+            seed=seed,
+            stop=stop,
+            stream=stream,
+            stream_options=stream_options,
+            temperature=temperature,
+            tool_choice=tool_choice,
+            tools=tools,
+            top_logprobs=top_logprobs,
+            top_p=top_p,
+            user=user,
+        )
+
+        provider = self.routing_table.get_provider_impl(model_obj.identifier)
+        return await provider.openai_chat_completion(**params)
+
+    async def health(self) -> Dict[str, HealthResponse]:
+        health_statuses = {}
+        timeout = 0.5
+        for provider_id, impl in self.routing_table.impls_by_provider_id.items():
+            try:
+                # check if the provider has a health method
+                if not hasattr(impl, "health"):
+                    continue
+                health = await asyncio.wait_for(impl.health(), timeout=timeout)
+                health_statuses[provider_id] = health
+            except asyncio.TimeoutError:
+                health_statuses[provider_id] = HealthResponse(
+                    status=HealthStatus.ERROR,
+                    message=f"Health check timed out after {timeout} seconds",
+                )
+            except NotImplementedError:
+                health_statuses[provider_id] = HealthResponse(status=HealthStatus.NOT_IMPLEMENTED)
+            except Exception as e:
+                health_statuses[provider_id] = HealthResponse(
+                    status=HealthStatus.ERROR, message=f"Health check failed: {str(e)}"
+                )
+        return health_statuses
+

 class SafetyRouter(Safety):
     def __init__(
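Since the router now forwards OpenAI-style completion and chat-completion calls, a plain OpenAI client can talk to the stack. A usage sketch under two assumptions: the OpenAI-compatible routes are served under `/v1/openai/v1` on the default port (verify against your deployment), and the model id matches one registered in your distribution:

```python
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8321/v1/openai/v1", api_key="none")
resp = client.chat.completions.create(
    model="llama3.2:3b-instruct-fp16",  # assumed to be registered in the running stack
    messages=[{"role": "user", "content": "Say hello in one word."}],
)
print(resp.choices[0].message.content)
```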
@@ -5,6 +5,7 @@
 # the root directory of this source tree.

 import logging
+import time
 import uuid
 from typing import Any, Dict, List, Optional

@@ -23,7 +24,7 @@ from llama_stack.apis.datasets import (
     RowsDataSource,
     URIDataSource,
 )
-from llama_stack.apis.models import ListModelsResponse, Model, Models, ModelType
+from llama_stack.apis.models import ListModelsResponse, Model, Models, ModelType, OpenAIListModelsResponse, OpenAIModel
 from llama_stack.apis.resource import ResourceType
 from llama_stack.apis.scoring_functions import (
     ListScoringFunctionsResponse,
@@ -254,6 +255,19 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models):
     async def list_models(self) -> ListModelsResponse:
         return ListModelsResponse(data=await self.get_all_with_type("model"))

+    async def openai_list_models(self) -> OpenAIListModelsResponse:
+        models = await self.get_all_with_type("model")
+        openai_models = [
+            OpenAIModel(
+                id=model.identifier,
+                object="model",
+                created=int(time.time()),
+                owned_by="llama_stack",
+            )
+            for model in models
+        ]
+        return OpenAIListModelsResponse(data=openai_models)
+
     async def get_model(self, model_id: str) -> Model:
         model = await self.get_object_by_identifier("model", model_id)
         if model is None:
@@ -608,8 +622,8 @@ class ToolGroupsRoutingTable(CommonRoutingTableImpl, ToolGroups):
         tool_group = await self.get_tool_group(toolgroup_id)
         if tool_group is None:
             raise ValueError(f"Tool group {toolgroup_id} not found")
-        tools = (await self.list_tools(toolgroup_id)).data
-        for tool in tools:
+        tools = await self.list_tools(toolgroup_id)
+        for tool in getattr(tools, "data", []):
             await self.unregister_object(tool)
         await self.unregister_object(tool_group)
@@ -38,10 +38,10 @@ from llama_stack.distribution.server.endpoints import (
 )
 from llama_stack.distribution.stack import (
     construct_stack,
-    redact_sensitive_fields,
     replace_env_vars,
     validate_env_pair,
 )
+from llama_stack.distribution.utils.config import redact_sensitive_fields
 from llama_stack.distribution.utils.context import preserve_contexts_async_generator
 from llama_stack.log import get_logger
 from llama_stack.providers.datatypes import Api
@@ -229,15 +229,30 @@ class TracingMiddleware:
     def __init__(self, app, impls):
         self.app = app
         self.impls = impls
+        # FastAPI built-in paths that should bypass custom routing
+        self.fastapi_paths = ("/docs", "/redoc", "/openapi.json", "/favicon.ico", "/static")

     async def __call__(self, scope, receive, send):
         if scope.get("type") == "lifespan":
             return await self.app(scope, receive, send)

         path = scope.get("path", "")
+
+        # Check if the path is a FastAPI built-in path
+        if path.startswith(self.fastapi_paths):
+            # Pass through to FastAPI's built-in handlers
+            logger.debug(f"Bypassing custom routing for FastAPI built-in path: {path}")
+            return await self.app(scope, receive, send)
+
         if not hasattr(self, "endpoint_impls"):
             self.endpoint_impls = initialize_endpoint_impls(self.impls)
-        _, _, trace_path = find_matching_endpoint(scope.get("method", "GET"), path, self.endpoint_impls)
+
+        try:
+            _, _, trace_path = find_matching_endpoint(scope.get("method", "GET"), path, self.endpoint_impls)
+        except ValueError:
+            # If no matching endpoint is found, pass through to FastAPI
+            logger.debug(f"No matching endpoint found for path: {path}, falling back to FastAPI")
+            return await self.app(scope, receive, send)

         trace_context = await start_trace(trace_path, {"__location__": "server", "raw_path": path})

@@ -388,7 +403,12 @@ def main(args: Optional[argparse.Namespace] = None):
     safe_config = redact_sensitive_fields(config.model_dump())
     logger.info(yaml.dump(safe_config, indent=2))

-    app = FastAPI(lifespan=lifespan)
+    app = FastAPI(
+        lifespan=lifespan,
+        docs_url="/docs",
+        redoc_url="/redoc",
+        openapi_url="/openapi.json",
+    )
     if not os.environ.get("LLAMA_STACK_DISABLE_VERSION_CHECK"):
         app.add_middleware(ClientVersionMiddleware)
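One detail that keeps the middleware change compact: `str.startswith` accepts a tuple of prefixes, so a single test covers every FastAPI built-in path:

```python
fastapi_paths = ("/docs", "/redoc", "/openapi.json", "/favicon.ico", "/static")
print("/docs/index.html".startswith(fastapi_paths))  # True
print("/v1/models".startswith(fastapi_paths))        # False
```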
@@ -35,6 +35,8 @@ from llama_stack.apis.vector_dbs import VectorDBs
 from llama_stack.apis.vector_io import VectorIO
 from llama_stack.distribution.datatypes import Provider, StackRunConfig
 from llama_stack.distribution.distribution import get_provider_registry
+from llama_stack.distribution.inspect import DistributionInspectConfig, DistributionInspectImpl
+from llama_stack.distribution.providers import ProviderImpl, ProviderImplConfig
 from llama_stack.distribution.resolver import ProviderRegistry, resolve_impls
 from llama_stack.distribution.store.registry import create_dist_registry
 from llama_stack.distribution.utils.dynamic import instantiate_class_type
@@ -96,7 +98,10 @@ async def register_resources(run_config: StackRunConfig, impls: Dict[Api, Any]):

         method = getattr(impls[api], register_method)
         for obj in objects:
-            await method(**obj.model_dump())
+            # we want to maintain the type information in arguments to method.
+            # instead of method(**obj.model_dump()), which may convert a typed attr to a dict,
+            # we use model_dump() to find all the attrs and then getattr to get the still typed value.
+            await method(**{k: getattr(obj, k) for k in obj.model_dump().keys()})

         method = getattr(impls[api], list_method)
         response = await method()
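The comment in the hunk above is the whole point of the change: `model_dump()` turns nested models into plain dicts, while `getattr` keeps the typed objects. A small standalone illustration with throwaway models:

```python
from pydantic import BaseModel


class Inner(BaseModel):
    url: str


class Outer(BaseModel):
    name: str
    inner: Inner


obj = Outer(name="m", inner=Inner(url="http://localhost"))
dumped = obj.model_dump()
print(type(dumped["inner"]).__name__)  # dict  -- type information lost
kwargs = {k: getattr(obj, k) for k in dumped.keys()}
print(type(kwargs["inner"]).__name__)  # Inner -- typed value preserved
```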
@@ -116,26 +121,6 @@ class EnvVarError(Exception):
         super().__init__(f"Environment variable '{var_name}' not set or empty{f' at {path}' if path else ''}")


-def redact_sensitive_fields(data: Dict[str, Any]) -> Dict[str, Any]:
-    """Redact sensitive information from config before printing."""
-    sensitive_patterns = ["api_key", "api_token", "password", "secret"]
-
-    def _redact_dict(d: Dict[str, Any]) -> Dict[str, Any]:
-        result = {}
-        for k, v in d.items():
-            if isinstance(v, dict):
-                result[k] = _redact_dict(v)
-            elif isinstance(v, list):
-                result[k] = [_redact_dict(i) if isinstance(i, dict) else i for i in v]
-            elif any(pattern in k.lower() for pattern in sensitive_patterns):
-                result[k] = "********"
-            else:
-                result[k] = v
-        return result
-
-    return _redact_dict(data)
-
-
 def replace_env_vars(config: Any, path: str = "") -> Any:
     if isinstance(config, dict):
         result = {}
@@ -212,13 +197,37 @@ def validate_env_pair(env_pair: str) -> tuple[str, str]:
         ) from e


+def add_internal_implementations(impls: Dict[Api, Any], run_config: StackRunConfig) -> None:
+    """Add internal implementations (inspect and providers) to the implementations dictionary.
+
+    Args:
+        impls: Dictionary of API implementations
+        run_config: Stack run configuration
+    """
+    inspect_impl = DistributionInspectImpl(
+        DistributionInspectConfig(run_config=run_config),
+        deps=impls,
+    )
+    impls[Api.inspect] = inspect_impl
+
+    providers_impl = ProviderImpl(
+        ProviderImplConfig(run_config=run_config),
+        deps=impls,
+    )
+    impls[Api.providers] = providers_impl
+
+
 # Produces a stack of providers for the given run config. Not all APIs may be
 # asked for in the run config.
 async def construct_stack(
     run_config: StackRunConfig, provider_registry: Optional[ProviderRegistry] = None
 ) -> Dict[Api, Any]:
     dist_registry, _ = await create_dist_registry(run_config.metadata_store, run_config.image_name)
-    impls = await resolve_impls(run_config, provider_registry or get_provider_registry(), dist_registry)
+    impls = await resolve_impls(run_config, provider_registry or get_provider_registry(run_config), dist_registry)
+
+    # Add internal implementations after all other providers are resolved
+    add_internal_implementations(impls, run_config)
+
     await register_resources(run_config, impls)
     return impls
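The redaction helper removed here now lives in `llama_stack.distribution.utils.config` (see the import changes earlier in this diff). Assuming the relocated function keeps the behavior shown in the removed body, usage looks like:

```python
from llama_stack.distribution.utils.config import redact_sensitive_fields

cfg = {"inference": {"api_key": "sk-secret", "url": "http://localhost:11434"}}
print(redact_sensitive_fields(cfg))
# Expected, per the removed implementation: {'inference': {'api_key': '********', 'url': 'http://localhost:11434'}}
```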
@@ -18,6 +18,7 @@ VIRTUAL_ENV=${VIRTUAL_ENV:-}
 set -euo pipefail

 RED='\033[0;31m'
+GREEN='\033[0;32m'
 NC='\033[0m' # No Color

 error_handler() {
@@ -73,7 +74,7 @@ done
 PYTHON_BINARY="python"
 case "$env_type" in
   "venv")
-    if [ -n "$VIRTUAL_ENV" && "$VIRTUAL_ENV" == "$env_path_or_name" ]; then
+    if [ -n "$VIRTUAL_ENV" ] && [ "$VIRTUAL_ENV" == "$env_path_or_name" ]; then
         echo -e "${GREEN}Virtual environment already activated${NC}" >&2
     else
         # Activate virtual environment
@@ -1,7 +1,7 @@
 # More info on playground configuration can be found here:
 # https://llama-stack.readthedocs.io/en/latest/playground

-FROM python:3.9-slim
+FROM python:3.12-slim
 WORKDIR /app
 COPY . /app/
 RUN /usr/local/bin/python -m pip install --upgrade pip && \
@@ -36,9 +36,7 @@ llama-stack-client benchmarks register \
 3. Start Streamlit UI

 ```bash
-cd llama_stack/distribution/ui
-pip install -r requirements.txt
-streamlit run app.py
+uv run --with ".[ui]" streamlit run llama_stack/distribution/ui/app.py
 ```

 ## Environment Variables
@@ -24,6 +24,7 @@ def main():
     # Playground pages
     chat_page = st.Page("page/playground/chat.py", title="Chat", icon="💬", default=True)
     rag_page = st.Page("page/playground/rag.py", title="RAG", icon="💬", default=False)
+    tool_page = st.Page("page/playground/tools.py", title="Tools", icon="🛠", default=False)

     # Distribution pages
     resources_page = st.Page("page/distribution/resources.py", title="Resources", icon="🔍", default=False)
@@ -39,6 +40,7 @@ def main():
         "Playground": [
             chat_page,
             rag_page,
+            tool_page,
             application_evaluation_page,
             native_evaluation_page,
         ],
@@ -19,6 +19,7 @@ class LlamaStackApi:
                 "together_api_key": os.environ.get("TOGETHER_API_KEY", ""),
                 "sambanova_api_key": os.environ.get("SAMBANOVA_API_KEY", ""),
                 "openai_api_key": os.environ.get("OPENAI_API_KEY", ""),
+                "tavily_search_api_key": os.environ.get("TAVILY_SEARCH_API_KEY", ""),
             },
         )
@ -4,9 +4,12 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
import uuid
|
||||||
|
|
||||||
import streamlit as st
|
import streamlit as st
|
||||||
from llama_stack_client import Agent, AgentEventLogger, RAGDocument
|
from llama_stack_client import Agent, AgentEventLogger, RAGDocument
|
||||||
|
|
||||||
|
from llama_stack.apis.common.content_types import ToolCallDelta
|
||||||
from llama_stack.distribution.ui.modules.api import llama_stack_api
|
from llama_stack.distribution.ui.modules.api import llama_stack_api
|
||||||
from llama_stack.distribution.ui.modules.utils import data_url_from_file
|
from llama_stack.distribution.ui.modules.utils import data_url_from_file
|
||||||
|
|
||||||
|
@ -14,9 +17,16 @@ from llama_stack.distribution.ui.modules.utils import data_url_from_file
|
||||||
def rag_chat_page():
|
def rag_chat_page():
|
||||||
st.title("🦙 RAG")
|
st.title("🦙 RAG")
|
||||||
|
|
||||||
|
def reset_agent_and_chat():
|
||||||
|
st.session_state.clear()
|
||||||
|
st.cache_resource.clear()
|
||||||
|
|
||||||
|
def should_disable_input():
|
||||||
|
return "displayed_messages" in st.session_state and len(st.session_state.displayed_messages) > 0
|
||||||
|
|
||||||
with st.sidebar:
|
with st.sidebar:
|
||||||
# File/Directory Upload Section
|
# File/Directory Upload Section
|
||||||
st.subheader("Upload Documents")
|
st.subheader("Upload Documents", divider=True)
|
||||||
uploaded_files = st.file_uploader(
|
uploaded_files = st.file_uploader(
|
||||||
"Upload file(s) or directory",
|
"Upload file(s) or directory",
|
||||||
accept_multiple_files=True,
|
accept_multiple_files=True,
|
||||||
|
@ -27,11 +37,11 @@ def rag_chat_page():
|
||||||
st.success(f"Successfully uploaded {len(uploaded_files)} files")
|
st.success(f"Successfully uploaded {len(uploaded_files)} files")
|
||||||
# Add memory bank name input field
|
# Add memory bank name input field
|
||||||
vector_db_name = st.text_input(
|
vector_db_name = st.text_input(
|
||||||
"Vector Database Name",
|
"Document Collection Name",
|
||||||
value="rag_vector_db",
|
value="rag_vector_db",
|
||||||
help="Enter a unique identifier for this vector database",
|
help="Enter a unique identifier for this document collection",
|
||||||
)
|
)
|
||||||
if st.button("Create Vector Database"):
|
if st.button("Create Document Collection"):
|
||||||
documents = [
|
documents = [
|
||||||
RAGDocument(
|
RAGDocument(
|
||||||
document_id=uploaded_file.name,
|
document_id=uploaded_file.name,
|
||||||
|
@ -62,26 +72,45 @@ def rag_chat_page():
|
||||||
)
|
)
|
||||||
st.success("Vector database created successfully!")
|
st.success("Vector database created successfully!")
|
||||||
|
|
||||||
st.subheader("Configure Agent")
|
st.subheader("RAG Parameters", divider=True)
|
||||||
|
|
||||||
|
rag_mode = st.radio(
|
||||||
|
"RAG mode",
|
||||||
|
["Direct", "Agent-based"],
|
||||||
|
captions=[
|
||||||
|
"RAG is performed by directly retrieving the information and augmenting the user query",
|
||||||
|
"RAG is performed by an agent activating a dedicated knowledge search tool.",
|
||||||
|
],
|
||||||
|
on_change=reset_agent_and_chat,
|
||||||
|
disabled=should_disable_input(),
|
||||||
|
)
|
||||||
|
|
||||||
# select memory banks
|
# select memory banks
|
||||||
vector_dbs = llama_stack_api.client.vector_dbs.list()
|
vector_dbs = llama_stack_api.client.vector_dbs.list()
|
||||||
vector_dbs = [vector_db.identifier for vector_db in vector_dbs]
|
vector_dbs = [vector_db.identifier for vector_db in vector_dbs]
|
||||||
selected_vector_dbs = st.multiselect(
|
selected_vector_dbs = st.multiselect(
|
||||||
"Select Vector Databases",
|
label="Select Document Collections to use in RAG queries",
|
||||||
vector_dbs,
|
options=vector_dbs,
|
||||||
|
on_change=reset_agent_and_chat,
|
||||||
|
disabled=should_disable_input(),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
st.subheader("Inference Parameters", divider=True)
|
||||||
available_models = llama_stack_api.client.models.list()
|
available_models = llama_stack_api.client.models.list()
|
||||||
available_models = [model.identifier for model in available_models if model.model_type == "llm"]
|
available_models = [model.identifier for model in available_models if model.model_type == "llm"]
|
||||||
selected_model = st.selectbox(
|
selected_model = st.selectbox(
|
||||||
"Choose a model",
|
label="Choose a model",
|
||||||
available_models,
|
options=available_models,
|
||||||
index=0,
|
index=0,
|
||||||
|
on_change=reset_agent_and_chat,
|
||||||
|
disabled=should_disable_input(),
|
||||||
)
|
)
|
||||||
system_prompt = st.text_area(
|
system_prompt = st.text_area(
|
||||||
"System Prompt",
|
"System Prompt",
|
||||||
value="You are a helpful assistant. ",
|
value="You are a helpful assistant. ",
|
||||||
help="Initial instructions given to the AI to set its behavior and context",
|
help="Initial instructions given to the AI to set its behavior and context",
|
||||||
|
on_change=reset_agent_and_chat,
|
||||||
|
disabled=should_disable_input(),
|
||||||
)
|
)
|
||||||
temperature = st.slider(
|
temperature = st.slider(
|
||||||
"Temperature",
|
"Temperature",
|
||||||
|
@ -90,6 +119,8 @@ def rag_chat_page():
|
||||||
value=0.0,
|
value=0.0,
|
||||||
step=0.1,
|
step=0.1,
|
||||||
help="Controls the randomness of the response. Higher values make the output more creative and unexpected, lower values make it more conservative and predictable",
|
help="Controls the randomness of the response. Higher values make the output more creative and unexpected, lower values make it more conservative and predictable",
|
||||||
|
on_change=reset_agent_and_chat,
|
||||||
|
disabled=should_disable_input(),
|
||||||
)
|
)
|
||||||
|
|
||||||
top_p = st.slider(
|
top_p = st.slider(
|
||||||
|
@ -98,19 +129,23 @@ def rag_chat_page():
|
||||||
max_value=1.0,
|
max_value=1.0,
|
||||||
value=0.95,
|
value=0.95,
|
||||||
step=0.1,
|
step=0.1,
|
||||||
|
on_change=reset_agent_and_chat,
|
||||||
|
disabled=should_disable_input(),
|
||||||
)
|
)
|
||||||
|
|
||||||
# Add clear chat button to sidebar
|
# Add clear chat button to sidebar
|
||||||
if st.button("Clear Chat", use_container_width=True):
|
if st.button("Clear Chat", use_container_width=True):
|
||||||
st.session_state.messages = []
|
reset_agent_and_chat()
|
||||||
st.rerun()
|
st.rerun()
|
||||||
|
|
||||||
# Chat Interface
|
# Chat Interface
|
||||||
if "messages" not in st.session_state:
|
if "messages" not in st.session_state:
|
||||||
st.session_state.messages = []
|
st.session_state.messages = []
|
||||||
|
if "displayed_messages" not in st.session_state:
|
||||||
|
st.session_state.displayed_messages = []
|
||||||
|
|
||||||
# Display chat history
|
# Display chat history
|
||||||
for message in st.session_state.messages:
|
for message in st.session_state.displayed_messages:
|
||||||
with st.chat_message(message["role"]):
|
with st.chat_message(message["role"]):
|
||||||
st.markdown(message["content"])
|
st.markdown(message["content"])
|
||||||
|
|
||||||
|
@ -123,33 +158,37 @@ def rag_chat_page():
|
||||||
else:
|
else:
|
||||||
strategy = {"type": "greedy"}
|
strategy = {"type": "greedy"}
|
||||||
|
|
||||||
agent = Agent(
|
@st.cache_resource
|
||||||
llama_stack_api.client,
|
def create_agent():
|
||||||
model=selected_model,
|
return Agent(
|
||||||
instructions=system_prompt,
|
llama_stack_api.client,
|
||||||
sampling_params={
|
model=selected_model,
|
||||||
"strategy": strategy,
|
instructions=system_prompt,
|
||||||
},
|
sampling_params={
|
||||||
tools=[
|
"strategy": strategy,
|
||||||
dict(
|
},
|
||||||
name="builtin::rag/knowledge_search",
|
tools=[
|
||||||
args={
|
dict(
|
||||||
"vector_db_ids": list(selected_vector_dbs),
|
name="builtin::rag/knowledge_search",
|
||||||
},
|
args={
|
||||||
)
|
"vector_db_ids": list(selected_vector_dbs),
|
||||||
],
|
},
|
||||||
)
|
)
|
||||||
session_id = agent.create_session("rag-session")
|
],
|
||||||
|
)
|
||||||
|
|
||||||
# Chat input
|
if rag_mode == "Agent-based":
|
||||||
if prompt := st.chat_input("Ask a question about your documents"):
|
agent = create_agent()
|
||||||
|
if "agent_session_id" not in st.session_state:
|
||||||
|
st.session_state["agent_session_id"] = agent.create_session(session_name=f"rag_demo_{uuid.uuid4()}")
|
||||||
|
|
||||||
|
session_id = st.session_state["agent_session_id"]
|
||||||
|
|
||||||
|
def agent_process_prompt(prompt):
|
||||||
# Add user message to chat history
|
# Add user message to chat history
|
||||||
st.session_state.messages.append({"role": "user", "content": prompt})
|
    st.session_state.messages.append({"role": "user", "content": prompt})

    # Display user message
    # Send the prompt to the agent
    with st.chat_message("user"):
        st.markdown(prompt)

    response = agent.create_turn(
        messages=[
            {

@ -177,6 +216,79 @@ def rag_chat_page():
            message_placeholder.markdown(full_response)

        st.session_state.messages.append({"role": "assistant", "content": full_response})
        st.session_state.displayed_messages.append({"role": "assistant", "content": full_response})

    def direct_process_prompt(prompt):
        # Add the system prompt in the beginning of the conversation
        if len(st.session_state.messages) == 0:
            st.session_state.messages.append({"role": "system", "content": system_prompt})

        # Query the vector DB
        rag_response = llama_stack_api.client.tool_runtime.rag_tool.query(
            content=prompt, vector_db_ids=list(selected_vector_dbs)
        )
        prompt_context = rag_response.content

        with st.chat_message("assistant"):
            retrieval_message_placeholder = st.empty()
            message_placeholder = st.empty()
            full_response = ""
            retrieval_response = ""

            # Display the retrieved content
            retrieval_response += str(prompt_context)
            retrieval_message_placeholder.info(retrieval_response)

            # Construct the extended prompt
            extended_prompt = f"Please answer the following query using the context below.\n\nCONTEXT:\n{prompt_context}\n\nQUERY:\n{prompt}"

            # Run inference directly
            st.session_state.messages.append({"role": "user", "content": extended_prompt})
            response = llama_stack_api.client.inference.chat_completion(
                messages=st.session_state.messages,
                model_id=selected_model,
                sampling_params={
                    "strategy": strategy,
                },
                stream=True,
            )

            # Display assistant response
            for chunk in response:
                response_delta = chunk.event.delta
                if isinstance(response_delta, ToolCallDelta):
                    retrieval_response += response_delta.tool_call.replace("====", "").strip()
                    retrieval_message_placeholder.info(retrieval_response)
                else:
                    full_response += chunk.event.delta.text
                    message_placeholder.markdown(full_response + "▌")
            message_placeholder.markdown(full_response)

            response_dict = {"role": "assistant", "content": full_response, "stop_reason": "end_of_message"}
            st.session_state.messages.append(response_dict)
            st.session_state.displayed_messages.append(response_dict)

    # Chat input
    if prompt := st.chat_input("Ask a question about your documents"):
        # Add user message to chat history
        st.session_state.displayed_messages.append({"role": "user", "content": prompt})

        # Display user message
        with st.chat_message("user"):
            st.markdown(prompt)

        # store the prompt to process it after page refresh
        st.session_state.prompt = prompt

        # force page refresh to disable the settings widgets
        st.rerun()

    if "prompt" in st.session_state and st.session_state.prompt is not None:
        if rag_mode == "Agent-based":
            agent_process_prompt(st.session_state.prompt)
        else:  # rag_mode == "Direct"
            direct_process_prompt(st.session_state.prompt)
        st.session_state.prompt = None


rag_chat_page()
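For orientation, the "Direct" RAG mode above simply concatenates the retrieved chunks into the user query before calling inference. A minimal sketch (illustrative values, not from the diff):

prompt_context = "Llama Stack ships a Streamlit playground UI."
prompt = "What UI does Llama Stack ship?"
extended_prompt = (
    "Please answer the following query using the context below.\n\n"
    f"CONTEXT:\n{prompt_context}\n\nQUERY:\n{prompt}"
)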
116  llama_stack/distribution/ui/page/playground/tools.py  Normal file
@ -0,0 +1,116 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import uuid

import streamlit as st
from llama_stack_client import Agent

from llama_stack.distribution.ui.modules.api import llama_stack_api


def tool_chat_page():
    st.title("🛠 Tools")

    client = llama_stack_api.client
    models = client.models.list()
    model_list = [model.identifier for model in models if model.api_model_type == "llm"]

    tool_groups = client.toolgroups.list()
    tool_groups_list = [tool_group.identifier for tool_group in tool_groups]
    mcp_tools_list = [tool for tool in tool_groups_list if tool.startswith("mcp::")]
    builtin_tools_list = [tool for tool in tool_groups_list if not tool.startswith("mcp::")]

    def reset_agent():
        st.session_state.clear()
        st.cache_resource.clear()

    with st.sidebar:
        st.subheader("Model")
        model = st.selectbox(label="models", options=model_list, on_change=reset_agent)

        st.subheader("Builtin Tools")
        toolgroup_selection = st.pills(
            label="Available ToolGroups", options=builtin_tools_list, selection_mode="multi", on_change=reset_agent
        )

        st.subheader("MCP Servers")
        mcp_selection = st.pills(
            label="Available MCP Servers", options=mcp_tools_list, selection_mode="multi", on_change=reset_agent
        )

        toolgroup_selection.extend(mcp_selection)

        active_tool_list = []
        for toolgroup_id in toolgroup_selection:
            active_tool_list.extend(
                [
                    f"{''.join(toolgroup_id.split('::')[1:])}:{t.identifier}"
                    for t in client.tools.list(toolgroup_id=toolgroup_id)
                ]
            )

        st.subheader(f"Active Tools: 🛠 {len(active_tool_list)}")
        st.json(active_tool_list)

    @st.cache_resource
    def create_agent():
        return Agent(
            client,
            model=model,
            instructions="You are a helpful assistant. When you use a tool always respond with a summary of the result.",
            tools=toolgroup_selection,
            sampling_params={
                "strategy": {"type": "greedy"},
            },
        )

    agent = create_agent()

    if "agent_session_id" not in st.session_state:
        st.session_state["agent_session_id"] = agent.create_session(session_name=f"tool_demo_{uuid.uuid4()}")

    session_id = st.session_state["agent_session_id"]

    if "messages" not in st.session_state:
        st.session_state["messages"] = [{"role": "assistant", "content": "How can I help you?"}]

    for msg in st.session_state.messages:
        with st.chat_message(msg["role"]):
            st.markdown(msg["content"])

    if prompt := st.chat_input(placeholder=""):
        with st.chat_message("user"):
            st.markdown(prompt)

        st.session_state.messages.append({"role": "user", "content": prompt})

        turn_response = agent.create_turn(
            session_id=session_id,
            messages=[{"role": "user", "content": prompt}],
            stream=True,
        )

        def response_generator(turn_response):
            for response in turn_response:
                if hasattr(response.event, "payload"):
                    print(response.event.payload)
                    if response.event.payload.event_type == "step_progress":
                        if hasattr(response.event.payload.delta, "text"):
                            yield response.event.payload.delta.text
                    if response.event.payload.event_type == "step_complete":
                        if response.event.payload.step_details.step_type == "tool_execution":
                            yield " 🛠 "
                else:
                    yield f"Error occurred in the Llama Stack Cluster: {response}"

        with st.chat_message("assistant"):
            response = st.write_stream(response_generator(turn_response))

        st.session_state.messages.append({"role": "assistant", "content": response})


tool_chat_page()
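The active-tool labels shown in the sidebar come from the string reshaping in the list comprehension above; a quick illustration with hypothetical identifiers:

toolgroup_id = "mcp::filesystem"
tool_identifier = "read_file"
label = f"{''.join(toolgroup_id.split('::')[1:])}:{tool_identifier}"  # -> "filesystem:read_file"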
@ -1,4 +1,5 @@
streamlit
pandas
llama-stack-client>=0.0.55
llama-stack-client>=0.2.1
streamlit-option-menu
llama-stack>=0.2.1
30  llama_stack/distribution/utils/config.py  Normal file
@ -0,0 +1,30 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Any, Dict


def redact_sensitive_fields(data: Dict[str, Any]) -> Dict[str, Any]:
    """Redact sensitive information from config before printing."""
    sensitive_patterns = ["api_key", "api_token", "password", "secret"]

    def _redact_value(v: Any) -> Any:
        if isinstance(v, dict):
            return _redact_dict(v)
        elif isinstance(v, list):
            return [_redact_value(i) for i in v]
        return v

    def _redact_dict(d: Dict[str, Any]) -> Dict[str, Any]:
        result = {}
        for k, v in d.items():
            if any(pattern in k.lower() for pattern in sensitive_patterns):
                result[k] = "********"
            else:
                result[k] = _redact_value(v)
        return result

    return _redact_dict(data)
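A small usage sketch of the helper above (the config values are made up): keys matching a sensitive pattern are masked at any nesting depth, everything else passes through unchanged.

from llama_stack.distribution.utils.config import redact_sensitive_fields

cfg = {"providers": [{"provider_id": "ollama", "config": {"url": "http://localhost:11434", "api_key": "abc123"}}]}
print(redact_sensitive_fields(cfg))
# {'providers': [{'provider_id': 'ollama', 'config': {'url': 'http://localhost:11434', 'api_key': '********'}}]}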
@ -29,6 +29,11 @@ def preserve_contexts_async_generator(
                    context_var.set(initial_context_values[context_var.name])

                item = await gen.__anext__()

                # Update our tracked values with any changes made during this iteration
                for context_var in context_vars:
                    initial_context_values[context_var.name] = context_var.get()

                yield item

            except StopAsyncIteration:
164  llama_stack/models/llama/checkpoint.py  Normal file
@ -0,0 +1,164 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import concurrent.futures
import re
from pathlib import Path
from typing import Any, Dict, List, Optional, Union

import numpy as np
import torch
from fairscale.nn.model_parallel.initialize import get_model_parallel_rank, get_model_parallel_world_size


def map_mp_rank(old_mp_size: int, new_mp_size: int, new_mp_rank: int) -> List[int]:
    """Map a new MP rank to a list of old MP ranks given a change in MP size."""
    if new_mp_size % old_mp_size == 0:
        # Read old MP shard and split it into smaller ones
        return [new_mp_rank * old_mp_size // new_mp_size]
    elif old_mp_size % new_mp_size == 0:
        # Merge old MP shards into a single one
        mp_factor = old_mp_size // new_mp_size
        return list(range(new_mp_rank * mp_factor, (new_mp_rank + 1) * mp_factor))
    else:
        raise ValueError(
            f"Either old MP size or new MP size should be a multiple of the other: "
            f"{old_mp_size} % {new_mp_size} != 0 and {new_mp_size} % {old_mp_size} != 0"
        )


def maybe_reshard_state_dict(
    ckpt_paths: List[Path],
    n_kv_heads: int,
    moe_num_experts: Optional[int] = None,
    map_location: Union[str, torch.device] = "cpu",
    mmap: bool = True,
) -> Dict[str, torch.Tensor]:
    if str(map_location) == "cpu":
        torch.set_default_tensor_type(torch.BFloat16Tensor)
    else:
        torch.set_default_tensor_type(torch.cuda.BFloat16Tensor)

    ckpt_paths = np.array(sorted(ckpt_paths))

    new_mp_size, new_mp_rank = get_model_parallel_world_size(), get_model_parallel_rank()
    old_mp_size = len(ckpt_paths)
    old_mp_ranks = map_mp_rank(old_mp_size, new_mp_size, new_mp_rank)

    print(f"Loading checkpoint shards:\n{str(ckpt_paths[old_mp_ranks])}")  # type: ignore
    paths = ckpt_paths[old_mp_ranks]  # type: ignore
    state_dicts = [torch.load(str(p), map_location=map_location, mmap=mmap) for p in paths]

    if new_mp_size == old_mp_size:
        return state_dicts[0]  # type: ignore

    if moe_num_experts is not None:
        state_dicts = [convert_moe_weights(d, moe_num_experts) for d in state_dicts]

    print(f"Resharding {len(state_dicts)} state dicts from MP size {old_mp_size} to MP size {new_mp_size}")
    return reshard_mp(
        state_dicts,
        size=max(new_mp_size // old_mp_size, 1),
        rank=new_mp_rank % max(new_mp_size // old_mp_size, 1),
        repeat_qk_qv=max(new_mp_size // n_kv_heads, 1),
    )


_WEIGHT_ROW_KEY = {
    "feed_forward.w2",
    "feed_forward.mlp.fc2",
    "attention.wo",
    "feed_forward.mlp.fc2_weight",
    "feed_forward.w_out_shared_DF.weight",
    "attn.wo.weight",
    "mlp.c_proj.weight",
}
_MOE_WEIGHT_ROW_KEY = {"feed_forward.experts.(moe_w_in_eD_F|moe_w_swiglu_eD_F)"}

_WEIGHT_COLUMN_KEY = {
    "output",
    "feed_forward.(w1|w3)",
    "feed_forward.mlp.(fc1|fc3)",
    "feed_forward.mlp.fc1_weight",
    "attention.(wk|wq|wv|wqkv).weight",
    "feed_forward.(w_in_shared_FD|w_swiglu_FD)",
    "attn.(wk|wq|wv).weight",
    "attn.(wk|wq|wv).bias",
    "mlp.c_fc.weight",
    "mlp.c_fc.bias",
    "conv1._linear.weight",
    "tok_embeddings.weight",
    "vision_projection.weight",
}
_MOE_WEIGHT_COLUMN_KEY = {"feed_forward.experts.moe_w_out_eF_D"}


def reshard_mp(
    state_dicts: List[Dict[str, torch.Tensor]],
    size: int,
    rank: int,
    repeat_qk_qv: int = 1,
) -> Dict[str, torch.Tensor]:
    """
    Reshard a list of state dicts into a single state dict given a change in MP size.
    If the list has more than one state dict, we concatenate the values of the same
    key across all state dicts. Otherwise, we just slice it for the current MP rank.
    """

    def concat_or_chunk(tensors: List[torch.Tensor], dim: int) -> torch.Tensor:
        if len(tensors) > 1:
            return torch.cat(tensors, dim=dim)
        return tensors[0].chunk(size, dim=dim)[rank].clone()

    def process_key(key: str) -> torch.Tensor:
        if row_regex.search(key):
            return concat_or_chunk([s[key] for s in state_dicts], dim=-1)
        elif column_regex.search(key):
            if "w13" in key or "fc1_weight" in key:
                dims = state_dicts[0][key].size()
                values = [s[key].view(2, dims[0] // 2, *dims[1:]) for s in state_dicts]
                return concat_or_chunk(values, dim=1).flatten(0, 1)
            elif "qkv" in key:
                q_dim = state_dicts[0][key.replace("qkv", "o")].size(1)
                kv_dim = (state_dicts[0][key].size(0) - q_dim) // 2
                values = [s[key].split((q_dim, kv_dim, kv_dim)) for s in state_dicts]
                return torch.cat([concat_or_chunk(x, dim=0) for x in zip(*values, strict=False)])  # type: ignore
            elif "wk.weight" in key or "wv.weight" in key:
                # Support MP > #kv_head
                return concat_or_chunk([s[key].repeat(repeat_qk_qv, 1) for s in state_dicts], dim=0)
            elif key == "output.bias" or key == "fc.weight":
                return concat_or_chunk([s[key] for s in state_dicts], dim=0)
            elif "w_" in key:
                return concat_or_chunk([s[key] for s in state_dicts], dim=-2)
            else:
                return concat_or_chunk([s[key] for s in state_dicts], dim=0)
        else:
            return state_dicts[0][key].clone()

    row_keys = _WEIGHT_ROW_KEY | _MOE_WEIGHT_ROW_KEY
    column_keys = _WEIGHT_COLUMN_KEY | _MOE_WEIGHT_COLUMN_KEY

    column_regex = re.compile("|".join(column_keys))
    row_regex = re.compile("|".join(row_keys))

    output: Dict[str, torch.Tensor] = {}
    with concurrent.futures.ThreadPoolExecutor() as executor:
        # Note: only processes keys in the first state dict.
        # Assumes keys are the same across all state dicts.
        mappings = {executor.submit(process_key, key): key for key in state_dicts[0]}
        for future in concurrent.futures.as_completed(mappings):
            output[mappings[future]] = future.result()
    return output


def convert_moe_weights(state_dict: Dict[str, Any], num_experts: int) -> Dict[str, Any]:
    routed_keys = _MOE_WEIGHT_ROW_KEY | _MOE_WEIGHT_COLUMN_KEY
    routed_regex = re.compile("|".join(routed_keys))
    keys = list(state_dict.keys())
    for key in keys:
        if routed_regex.search(key):
            state_dict[key] = state_dict.pop(key).unflatten(0, (num_experts, -1)).squeeze(dim=0)
    return state_dict
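To make the rank mapping above concrete, here is what map_mp_rank returns for a few sizes (values follow directly from the arithmetic in the function):

map_mp_rank(old_mp_size=2, new_mp_size=8, new_mp_rank=5)  # -> [1]  (split an old shard into smaller ones)
map_mp_rank(old_mp_size=8, new_mp_size=2, new_mp_rank=1)  # -> [4, 5, 6, 7]  (merge four old shards)
map_mp_rank(old_mp_size=3, new_mp_size=2, new_mp_rank=0)  # raises ValueError (neither size divides the other)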
@ -4,13 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.

import base64
from enum import Enum
from io import BytesIO
@ -19,8 +12,6 @@ from typing import Any, Dict, List, Literal, Optional, Union
from pydantic import BaseModel, ConfigDict, Field, field_serializer, field_validator
from typing_extensions import Annotated

from llama_stack.schema_utils import json_schema_type, register_schema

# The goal is that these set of types are relevant for all Llama models.
# That isn't the current state yet -- e.g., BuiltinTool is somewhat specific to
# the llama3 series of models.
@ -98,6 +89,29 @@ class StopReason(Enum):
    out_of_tokens = "out_of_tokens"


class ToolParamDefinition(BaseModel):
    param_type: str
    description: Optional[str] = None
    required: Optional[bool] = True
    default: Optional[Any] = None


class ToolDefinition(BaseModel):
    tool_name: Union[BuiltinTool, str]
    description: Optional[str] = None
    parameters: Optional[Dict[str, ToolParamDefinition]] = None

    @field_validator("tool_name", mode="before")
    @classmethod
    def validate_field(cls, v):
        if isinstance(v, str):
            try:
                return BuiltinTool(v)
            except ValueError:
                return v
        return v


class RawMediaItem(BaseModel):
    type: Literal["image"] = "image"
    data: bytes | BytesIO
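The "before" validator above coerces known builtin names to the BuiltinTool enum and leaves anything else as a plain string. A rough sketch, assuming BuiltinTool has a brave_search member as in the llama3 datatypes:

ToolDefinition(tool_name="brave_search").tool_name  # -> BuiltinTool.brave_search
ToolDefinition(tool_name="get_weather").tool_name   # -> "get_weather" (custom tool, kept as a string)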
@ -140,267 +154,25 @@ class RawMessage(BaseModel):
    tool_calls: List[ToolCall] = Field(default_factory=list)


register_schema(ToolCall)


@json_schema_type
class ToolParamDefinition(BaseModel):
    param_type: str
    description: Optional[str] = None
    required: Optional[bool] = True
    default: Optional[Any] = None


@json_schema_type
class ToolDefinition(BaseModel):
    tool_name: Union[BuiltinTool, str]
    description: Optional[str] = None
    parameters: Optional[Dict[str, ToolParamDefinition]] = None

    @field_validator("tool_name", mode="before")
    @classmethod
    def validate_field(cls, v):
        if isinstance(v, str):
            try:
                return BuiltinTool(v)
            except ValueError:
                return v
        return v


@json_schema_type
class GreedySamplingStrategy(BaseModel):
    type: Literal["greedy"] = "greedy"


@json_schema_type
class TopPSamplingStrategy(BaseModel):
    type: Literal["top_p"] = "top_p"
    temperature: Optional[float] = Field(..., gt=0.0)
    top_p: Optional[float] = 0.95


@json_schema_type
class TopKSamplingStrategy(BaseModel):
    type: Literal["top_k"] = "top_k"
    top_k: int = Field(..., ge=1)


SamplingStrategy = Annotated[
    Union[GreedySamplingStrategy, TopPSamplingStrategy, TopKSamplingStrategy],
    Field(discriminator="type"),
]
register_schema(SamplingStrategy, name="SamplingStrategy")


@json_schema_type
class SamplingParams(BaseModel):
    """Sampling parameters.

    :param strategy: The sampling strategy.
    :param max_tokens: The maximum number of tokens that can be generated in the completion. The token count of
        your prompt plus max_tokens cannot exceed the model's context length.
    :param repetition_penalty: Number between -2.0 and 2.0. Positive values penalize new tokens
        based on whether they appear in the text so far, increasing the model's likelihood to talk about new topics.
    :param stop: Up to 4 sequences where the API will stop generating further tokens.
        The returned text will not contain the stop sequence.
    """

    strategy: SamplingStrategy = Field(default_factory=GreedySamplingStrategy)

    max_tokens: Optional[int] = 0
    repetition_penalty: Optional[float] = 1.0
    stop: Optional[List[str]] = None


class CheckpointQuantizationFormat(Enum):
    # default format
    bf16 = "bf16"

    # used for enabling fp8_rowwise inference, some weights are bf16
    fp8_mixed = "fp8-mixed"

    int8 = "int8"

    int4 = "int4"


class ModelFamily(Enum):
    llama2 = "llama2"
    llama3 = "llama3"
    llama3_1 = "llama3_1"
    llama3_2 = "llama3_2"
    llama3_3 = "llama3_3"
    safety = "safety"


class CoreModelId(Enum):
    """Each of these models is a unique "SKU". These root models can be served in various garbs (especially by quantizing them)"""

    # Llama 2 family
    llama2_7b = "Llama-2-7b"
    llama2_13b = "Llama-2-13b"
    llama2_70b = "Llama-2-70b"
    llama2_7b_chat = "Llama-2-7b-chat"
    llama2_13b_chat = "Llama-2-13b-chat"
    llama2_70b_chat = "Llama-2-70b-chat"

    # Llama 3 family
    llama3_8b = "Llama-3-8B"
    llama3_70b = "Llama-3-70B"
    llama3_8b_instruct = "Llama-3-8B-Instruct"
    llama3_70b_instruct = "Llama-3-70B-Instruct"

    # Llama 3.1 family
    llama3_1_8b = "Llama3.1-8B"
    llama3_1_70b = "Llama3.1-70B"
    llama3_1_405b = "Llama3.1-405B"
    llama3_1_8b_instruct = "Llama3.1-8B-Instruct"
    llama3_1_70b_instruct = "Llama3.1-70B-Instruct"
    llama3_1_405b_instruct = "Llama3.1-405B-Instruct"

    # Llama 3.2 family
    llama3_2_1b = "Llama3.2-1B"
    llama3_2_3b = "Llama3.2-3B"
    llama3_2_1b_instruct = "Llama3.2-1B-Instruct"
    llama3_2_3b_instruct = "Llama3.2-3B-Instruct"
    llama3_2_11b_vision = "Llama3.2-11B-Vision"
    llama3_2_90b_vision = "Llama3.2-90B-Vision"
    llama3_2_11b_vision_instruct = "Llama3.2-11B-Vision-Instruct"
    llama3_2_90b_vision_instruct = "Llama3.2-90B-Vision-Instruct"

    # Llama 3.3 family
    llama3_3_70b_instruct = "Llama3.3-70B-Instruct"

    # Safety models
    llama_guard_3_8b = "Llama-Guard-3-8B"
    llama_guard_2_8b = "Llama-Guard-2-8B"
    llama_guard_3_11b_vision = "Llama-Guard-3-11B-Vision"
    llama_guard_3_1b = "Llama-Guard-3-1B"


def is_multimodal(model_id) -> bool:
    if model_id in [
        CoreModelId.llama3_2_11b_vision,
        CoreModelId.llama3_2_90b_vision,
        CoreModelId.llama3_2_11b_vision_instruct,
        CoreModelId.llama3_2_90b_vision_instruct,
    ]:
        return True
    else:
        return False


def model_family(model_id) -> ModelFamily:
    if model_id in [
        CoreModelId.llama2_7b,
        CoreModelId.llama2_13b,
        CoreModelId.llama2_70b,
        CoreModelId.llama2_7b_chat,
        CoreModelId.llama2_13b_chat,
        CoreModelId.llama2_70b_chat,
    ]:
        return ModelFamily.llama2
    elif model_id in [
        CoreModelId.llama3_8b,
        CoreModelId.llama3_70b,
        CoreModelId.llama3_8b_instruct,
        CoreModelId.llama3_70b_instruct,
    ]:
        return ModelFamily.llama3
    elif model_id in [
        CoreModelId.llama3_1_8b,
        CoreModelId.llama3_1_70b,
        CoreModelId.llama3_1_405b,
        CoreModelId.llama3_1_8b_instruct,
        CoreModelId.llama3_1_70b_instruct,
        CoreModelId.llama3_1_405b_instruct,
    ]:
        return ModelFamily.llama3_1
    elif model_id in [
        CoreModelId.llama3_2_1b,
        CoreModelId.llama3_2_3b,
        CoreModelId.llama3_2_1b_instruct,
        CoreModelId.llama3_2_3b_instruct,
        CoreModelId.llama3_2_11b_vision,
        CoreModelId.llama3_2_90b_vision,
        CoreModelId.llama3_2_11b_vision_instruct,
        CoreModelId.llama3_2_90b_vision_instruct,
    ]:
        return ModelFamily.llama3_2
    elif model_id in [
        CoreModelId.llama3_3_70b_instruct,
    ]:
        return ModelFamily.llama3_3
    elif model_id in [
        CoreModelId.llama_guard_3_8b,
        CoreModelId.llama_guard_2_8b,
        CoreModelId.llama_guard_3_11b_vision,
        CoreModelId.llama_guard_3_1b,
    ]:
        return ModelFamily.safety
    else:
        raise ValueError(f"Unknown model family for {model_id}")


class Model(BaseModel):
    core_model_id: CoreModelId
    description: str
    huggingface_repo: Optional[str] = None
    recommended_sampling_params: Optional[SamplingParams] = None
    arch_args: Dict[str, Any]
    variant: str = ""

    quantization_format: CheckpointQuantizationFormat = CheckpointQuantizationFormat.bf16
    pth_file_count: int
    metadata: Optional[Dict[str, Any]] = Field(default_factory=dict)

    # silence pydantic until we remove the `model_` fields
    model_config = ConfigDict(protected_namespaces=())

    @property
    def model_family(self) -> ModelFamily:
        return model_family(self.core_model_id)

    # The SKU is uniquely identified by (model_id, variant) combo
    def descriptor(self, shorten_default_variant: bool = True) -> str:
        if not self.variant:
            return self.core_model_id.value
        return f"{self.core_model_id.value}:{self.variant}"

    @property
    def is_instruct_model(self) -> bool:
        return "instruct" in self.id.name

    # Featured models are shown in the non-exhaustive model list
    @property
    def is_featured(self) -> bool:
        return self.model_family in [
            ModelFamily.llama3_1,
            ModelFamily.llama3_2,
            ModelFamily.llama3_3,
            ModelFamily.safety,
        ]

    @property
    def max_seq_length(self) -> int:
        if self.model_family == ModelFamily.llama2:
            return 4096
        elif self.core_model_id == CoreModelId.llama_guard_2_8b:
            return 4096
        elif self.model_family == ModelFamily.llama3:
            return 8192
        elif self.model_family in [ModelFamily.llama3_1, ModelFamily.llama3_3]:
            return 131072
        elif self.model_family == ModelFamily.llama3_2:
            if self.quantization_format == CheckpointQuantizationFormat.int4:
                return 8192
            return 131072
        elif self.core_model_id in [
            CoreModelId.llama_guard_3_8b,
            CoreModelId.llama_guard_3_11b_vision,
            CoreModelId.llama_guard_3_1b,
        ]:
            return 131072
        else:
            raise ValueError(f"Unknown max_seq_len for {self.core_model_id}")


class GenerationResult(BaseModel):
    token: int
    text: str
    logprobs: Optional[List[float]] = None

    source: Literal["input"] | Literal["output"]

    # index within the batch
    batch_idx: int
    # whether generation for this item is already finished. note that tokens can
    # get returned even afterwards since other items in the batch can still be generating tokens
    finished: bool
    # because a batch is parallel processed, useful decoding for one item can correspond to processing
    # pad tokens or tokens beyond EOS for other items. we could have decided to return None for this case
    # but it's more convenient to return a list of GenerationResult and filter out the ignored tokens
    ignore_token: bool


class QuantizationMode(str, Enum):
    none = "none"
    fp8_mixed = "fp8_mixed"
    int4_mixed = "int4_mixed"
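One possible way to consume the new GenerationResult stream (a sketch, not part of the diff; llama and llm_inputs are assumed to exist): each yielded list covers the whole batch, so per-item text is rebuilt via batch_idx while padded or post-EOS tokens are dropped through ignore_token.

texts = {}
for results in llama.generate(llm_inputs):
    for r in results:
        if r.ignore_token:
            continue
        texts[r.batch_idx] = texts.get(r.batch_idx, "") + r.text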
@ -4,13 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.

from dataclasses import dataclass
from enum import Enum
from typing import Optional
@ -4,13 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.

import io
import json
import uuid

@ -19,7 +12,7 @@ from typing import Dict, List, Optional, Tuple

from PIL import Image as PIL_Image

from llama_stack.models.llama.datatypes import (
from ..datatypes import (
    BuiltinTool,
    RawContent,
    RawMediaItem,

@ -30,7 +23,6 @@ from llama_stack.models.llama.datatypes import (
    ToolCall,
    ToolPromptFormat,
)

from .tokenizer import Tokenizer
from .tool_utils import ToolUtils

@ -234,7 +226,6 @@ class ChatFormat:
                    arguments_json=json.dumps(tool_arguments),
                )
            )
        content = ""

        return RawMessage(
            role="assistant",
371  llama_stack/models/llama/llama3/generation.py  Normal file
@ -0,0 +1,371 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.

import json
import os
import sys
import time
from pathlib import Path
from typing import Callable, Generator, List, Optional

import torch
import torch.nn.functional as F
from fairscale.nn.model_parallel.initialize import (
    initialize_model_parallel,
    model_parallel_is_initialized,
)
from termcolor import cprint

from ..checkpoint import maybe_reshard_state_dict
from ..datatypes import GenerationResult, QuantizationMode, RawContent, RawMessage, ToolPromptFormat
from .args import ModelArgs
from .chat_format import ChatFormat, LLMInput
from .model import Transformer
from .multimodal.model import CrossAttentionTransformer
from .tokenizer import Tokenizer


class Llama3:
    @staticmethod
    def build(
        ckpt_dir: str,
        max_seq_len: int,
        max_batch_size: int,
        world_size: Optional[int] = None,
        quantization_mode: Optional[QuantizationMode] = None,
        seed: int = 1,
        device: str = "cuda",
    ):
        device = torch.device(device)
        if (
            device.type == "cuda"
            and not torch.cuda.is_available()
            or device.type == "xpu"
            and not torch.xpu.is_available()
        ):
            raise RuntimeError(f"PyTorch backend for {device.type} device type is not available")

        if not torch.distributed.is_initialized():
            if device.type == "cuda":
                torch.distributed.init_process_group("nccl")
            else:
                torch.distributed.init_process_group("gloo")

        if not model_parallel_is_initialized():
            if world_size is None:
                world_size = int(os.environ.get("WORLD_SIZE", 1))
            initialize_model_parallel(world_size)

        local_rank = int(os.environ.get("LOCAL_RANK", 0))
        if device.type == "cuda":
            torch.cuda.set_device(local_rank)
        elif device.type == "xpu":
            torch.xpu.set_device(local_rank)

        torch.manual_seed(seed)

        if local_rank > 0:
            sys.stdout = open(os.devnull, "w")

        start_time = time.time()

        ckpt_paths = sorted(Path(ckpt_dir).glob("*.pth"))
        assert len(ckpt_paths) > 0, f"no checkpoint files found in {ckpt_dir}"
        print(f"Loading a checkpoint (shards={len(ckpt_paths)}, current-mp-size={world_size})")
        with open(Path(ckpt_dir) / "params.json", "r") as f:
            params = json.loads(f.read())

        model_args: ModelArgs = ModelArgs(
            max_seq_len=max_seq_len,
            max_batch_size=max_batch_size,
            **params,
        )
        tokenizer = Tokenizer.get_instance()

        state_dict = maybe_reshard_state_dict(
            ckpt_paths,
            n_kv_heads=model_args.n_kv_heads if model_args.n_kv_heads else model_args.n_heads,
        )

        assert model_args.vocab_size == tokenizer.n_words

        def build_model():
            if model_args.vision_chunk_size > 0:
                model = CrossAttentionTransformer(model_args)
                model.setup_cache(model_args.max_batch_size, device=device, dtype=torch.get_default_dtype())
            else:
                model = Transformer(model_args)
            return model

        if quantization_mode == QuantizationMode.fp8_mixed or quantization_mode == QuantizationMode.int4_mixed:
            from .quantization.loader import convert_to_quantized_model

            torch.set_default_tensor_type(torch.BFloat16Tensor)
            model = build_model()
            print("Loading state dict...")
            model.load_state_dict(state_dict, strict=False)
            print("Done...")
            model = convert_to_quantized_model(model, ckpt_dir, quantization_mode, device=device)
            torch.set_default_device(device)
        else:
            print(f"Setting default device to {device}")
            if device.type == "cuda":
                if torch.cuda.is_bf16_supported():
                    torch.set_default_tensor_type(torch.cuda.BFloat16Tensor)
                else:
                    torch.set_default_tensor_type(torch.cuda.Float16Tensor)
            elif device.type == "xpu":
                if torch.xpu.is_bf16_supported():
                    torch.set_default_tensor_type(torch.xpu.BFloat16Tensor)
                else:
                    torch.set_default_tensor_type(torch.xpu.Float16Tensor)

            model = build_model()
            print("Loading state dict...")
            model.load_state_dict(state_dict, strict=True)
            model.to(device)
            print("Done...")

        print(f"Loaded in {time.time() - start_time:.2f} seconds")

        return Llama3(model, tokenizer, model_args)

    def __init__(
        self,
        model: Transformer | CrossAttentionTransformer,
        tokenizer: Tokenizer,
        args: ModelArgs,
    ):
        self.args = args
        self.model = model
        self.tokenizer = tokenizer
        self.formatter = ChatFormat(tokenizer)

    @torch.inference_mode()
    def generate(
        self,
        llm_inputs: List[LLMInput],
        temperature: float = 0.6,
        top_p: float = 0.9,
        max_gen_len: Optional[int] = None,
        logprobs: bool = False,
        echo: bool = False,
        print_model_input: bool = False,
        logits_processor: Optional[Callable[[torch.Tensor, torch.Tensor], torch.Tensor]] = None,
    ) -> Generator[List[GenerationResult], None, None]:
        if max_gen_len is None or max_gen_len == 0 or max_gen_len >= self.args.max_seq_len:
            max_gen_len = self.args.max_seq_len - 1
        params = self.model.params

        print_model_input = print_model_input or os.environ.get("LLAMA_MODELS_DEBUG", "0") == "1"
        if print_model_input:
            for inp in llm_inputs:
                tokens_to_print = [self.formatter.vision_token if t == 128256 else t for t in inp.tokens]
                cprint(
                    "Input to model:\n" + self.tokenizer.decode(tokens_to_print) + "\n",
                    "red",
                )
        prompt_tokens = [inp.tokens for inp in llm_inputs]

        bsz = len(llm_inputs)
        assert bsz <= params.max_batch_size, (bsz, params.max_batch_size)

        min_prompt_len = min(len(t) for t in prompt_tokens)
        max_prompt_len = max(len(t) for t in prompt_tokens)

        if max_prompt_len >= params.max_seq_len:
            cprint(f"Out of token budget {max_prompt_len} vs {params.max_seq_len}", "red")
            return

        total_len = min(max_gen_len + max_prompt_len, params.max_seq_len)

        pad_id = self.tokenizer.pad_id
        tokens = torch.full((bsz, total_len), pad_id, dtype=torch.long)
        for k, t in enumerate(prompt_tokens):
            tokens[k, : len(t)] = torch.tensor(t, dtype=torch.long)
        if logprobs:
            token_logprobs = torch.zeros_like(tokens, dtype=torch.float)

        is_vision = not isinstance(self.model, Transformer)
        if is_vision:
            images = [inp.vision.images if inp.vision is not None else [] for inp in llm_inputs]
            mask = [inp.vision.mask if inp.vision is not None else [] for inp in llm_inputs]

            xattn_caches, cross_attention_masks, full_text_row_masked_out_mask = self.model.compute_vision_tokens_masks(
                batch_images=images,
                batch_masks=mask,
                total_len=total_len,
                device=tokens.device,
            )

        eos_reached = torch.tensor([False] * bsz)
        input_text_mask = tokens != pad_id

        if echo:
            for i in range(max_prompt_len):
                results = []
                for j, t in enumerate(tokens[:, i]):
                    results.append(
                        GenerationResult(
                            token=t.item(),
                            text=self.tokenizer.decode([t.item()]),
                            source="input",
                            logprobs=(token_logprobs[j, i : i + 1].tolist() if logprobs else None),
                            batch_idx=j,
                            finished=False,
                            ignore_token=t.item() == pad_id,
                        )
                    )
                yield results

        stop_tokens = torch.tensor(self.tokenizer.stop_tokens)

        prev_pos = 0
        for cur_pos in range(min_prompt_len, total_len):
            if is_vision:
                position_ids = torch.arange(prev_pos, cur_pos, dtype=torch.long)
                text_only_inference = all(inp.vision is None for inp in llm_inputs)
                logits = self.model.forward(
                    position_ids,
                    tokens,
                    cross_attention_masks,
                    full_text_row_masked_out_mask,
                    xattn_caches,
                    text_only_inference,
                )
            else:
                logits = self.model.forward(tokens[:, prev_pos:cur_pos], prev_pos)

            if logits_processor is not None:
                logits = logits_processor(tokens[:, :cur_pos], logits)

            if temperature > 0:
                probs = torch.softmax(logits[:, -1] / temperature, dim=-1)
                next_token = sample_top_p(probs, top_p)
            else:
                next_token = torch.argmax(logits[:, -1], dim=-1)

            next_token = next_token.reshape(-1)
            # only replace token if prompt has already been generated
            next_token = torch.where(input_text_mask[:, cur_pos], tokens[:, cur_pos], next_token)
            tokens[:, cur_pos] = next_token

            target = tokens[:, prev_pos + 1 : cur_pos + 1]
            if is_vision:
                # the logits space (num_classes) is designed to never contain a media_token
                # however our input token stream does contain them. we need to nuke them here
                # or else the CUDA kernels will crash with an illegal memory access
                vision_tokens = [self.tokenizer.special_tokens["<|image|>"], 128256]
                masks = [target.eq(t) for t in vision_tokens]
                if len(masks) > 1:
                    mask = torch.logical_or(*masks)
                else:
                    mask = masks[0]
                target[mask] = 0

            if logprobs:
                token_logprobs[:, prev_pos + 1 : cur_pos + 1] = -F.cross_entropy(
                    input=logits.transpose(1, 2),
                    target=target,
                    reduction="none",
                    ignore_index=pad_id,
                )
            eos_reached |= (~input_text_mask[:, cur_pos]) & (torch.isin(next_token, stop_tokens))
            results = []
            for idx, t in enumerate(next_token):
                results.append(
                    GenerationResult(
                        token=t.item(),
                        text=self.tokenizer.decode([t.item()]),
                        source="output",
                        logprobs=(token_logprobs[idx, cur_pos : cur_pos + 1].tolist() if logprobs else None),
                        batch_idx=idx,
                        finished=eos_reached[idx].item(),
                        ignore_token=cur_pos < len(prompt_tokens[idx]),
                    )
                )
            yield results

            prev_pos = cur_pos
            if all(eos_reached):
                break

    def completion(
        self,
        contents: List[RawContent],
        temperature: float = 0.6,
        top_p: float = 0.9,
        max_gen_len: Optional[int] = None,
        logprobs: bool = False,
        echo: bool = False,
    ) -> Generator[List[GenerationResult], None, None]:
        model_inputs = [self.formatter.encode_content(c) for c in contents]
        for result in self.generate(
            model_inputs=model_inputs,
            temperature=temperature,
            top_p=top_p,
            max_gen_len=max_gen_len,
            logprobs=logprobs,
            echo=echo,
        ):
            yield result
            if all(r.finished for r in result):
                break

    def chat_completion(
        self,
        messages_batch: List[List[RawMessage]],
        temperature: float = 0.6,
        top_p: float = 0.9,
        max_gen_len: Optional[int] = None,
        logprobs: bool = False,
        tool_prompt_format: ToolPromptFormat = ToolPromptFormat.json,
        echo: bool = False,
    ) -> Generator[List[GenerationResult], None, None]:
        model_inputs = [self.formatter.encode_dialog_prompt(messages) for messages in messages_batch]
        for result in self.generate(
            model_inputs=model_inputs,
            temperature=temperature,
            top_p=top_p,
            max_gen_len=max_gen_len,
            logprobs=logprobs,
            echo=echo,
        ):
            yield result
            if all(r.finished for r in result):
                break


def sample_top_p(probs, p):
    """
    Perform top-p (nucleus) sampling on a probability distribution.

    Args:
        probs (torch.Tensor): Probability distribution tensor.
        p (float): Probability threshold for top-p sampling.

    Returns:
        torch.Tensor: Sampled token indices.

    Note:
        Top-p sampling selects the smallest set of tokens whose cumulative probability mass
        exceeds the threshold p. The distribution is renormalized based on the selected tokens.
    """
    probs_sort, probs_idx = torch.sort(probs, dim=-1, descending=True)
    probs_sum = torch.cumsum(probs_sort, dim=-1)
    mask = probs_sum - probs_sort > p
    probs_sort[mask] = 0.0
    probs_sort.div_(probs_sort.sum(dim=-1, keepdim=True))
    next_token = torch.multinomial(probs_sort, num_samples=1)
    next_token = torch.gather(probs_idx, -1, next_token)
    return next_token
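A quick check of the nucleus-sampling helper above on a toy distribution (sketch; numbers are illustrative):

import torch

probs = torch.tensor([[0.5, 0.3, 0.15, 0.05]])
token = sample_top_p(probs, p=0.7)
# only indices 0 and 1 survive the cumulative cutoff; one of them is sampled
# after the surviving mass is renormalized, so token is 0 or 1 with ratio 5:3.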
@ -16,7 +16,7 @@ from typing import List, Optional

from termcolor import colored

from llama_stack.models.llama.datatypes import (
from ..datatypes import (
    BuiltinTool,
    RawMessage,
    StopReason,

@ -24,7 +24,6 @@ from llama_stack.models.llama.datatypes import (
    ToolDefinition,
    ToolPromptFormat,
)

from . import template_data
from .chat_format import ChatFormat
from .prompt_templates import (
@ -4,16 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.

# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.

import math
from typing import Optional, Tuple

@ -29,6 +19,10 @@ from torch import nn

from .args import ModelArgs

# **NOTE**: This code is not runnable without installing `torch` and `fairscale`
# dependencies. These dependencies are not part of the default dependencies
# (requirements.txt) of the `llama-models` package.


class RMSNorm(torch.nn.Module):
    def __init__(self, dim: int, eps: float = 1e-6):

@ -111,9 +105,9 @@ class Attention(nn.Module):
    def __init__(self, args: ModelArgs):
        super().__init__()
        self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
        model_parallel_size = fs_init.get_model_parallel_world_size()
        world_size = fs_init.get_model_parallel_world_size()
        self.n_local_heads = args.n_heads // model_parallel_size
        self.n_local_heads = args.n_heads // world_size
        self.n_local_kv_heads = self.n_kv_heads // model_parallel_size
        self.n_local_kv_heads = self.n_kv_heads // world_size
        self.n_rep = self.n_local_heads // self.n_local_kv_heads
        self.head_dim = args.dim // args.n_heads
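The rename above is purely cosmetic; the head-splitting arithmetic is unchanged. With illustrative numbers (not from the diff):

n_heads, n_kv_heads, world_size = 32, 8, 4
n_local_heads = n_heads // world_size              # 8 query heads per rank
n_local_kv_heads = n_kv_heads // world_size        # 2 KV heads per rank
n_rep = n_local_heads // n_local_kv_heads          # each local KV head serves 4 query heads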
@ -4,16 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.

# Copyright (c) Meta Platforms, Inc. and affiliates.
# This software may be used and distributed according to the terms of the Llama 2 Community License Agreement.

import logging
import math
from functools import partial

@ -180,14 +170,14 @@ class ImageAttention(nn.Module):
        n_heads,
    ):
        super().__init__()
        model_parallel_size = fs_init.get_model_parallel_world_size()
        world_size = fs_init.get_model_parallel_world_size()
        qkvo_replication = 1
        if model_parallel_size > 16:
        if world_size > 16:
            qkvo_replication = model_parallel_size // 8
            qkvo_replication = world_size // 8

        self.n_kv_heads = n_heads
        self.n_local_heads = n_heads * qkvo_replication // model_parallel_size
        self.n_local_heads = n_heads * qkvo_replication // world_size
        self.n_local_kv_heads = self.n_kv_heads * qkvo_replication // model_parallel_size
        self.n_local_kv_heads = self.n_kv_heads * qkvo_replication // world_size
        self.n_rep = self.n_local_heads // self.n_local_kv_heads
        self.head_dim = dim // n_heads

@ -536,16 +526,16 @@ class Attention(nn.Module):
            cache_v (torch.Tensor): Cached values for attention.
        """
        super().__init__()
        model_parallel_size = fs_init.get_model_parallel_world_size()
        world_size = fs_init.get_model_parallel_world_size()
        replication_factor = 1
        if model_parallel_size > 8:
        if world_size > 8:
            replication_factor = model_parallel_size // MP_SCALE
            replication_factor = world_size // MP_SCALE

        self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
        self.n_kv_heads *= replication_factor

        self.n_local_heads = args.n_heads // model_parallel_size
        self.n_local_heads = args.n_heads // world_size
        self.n_local_kv_heads = self.n_kv_heads // model_parallel_size
        self.n_local_kv_heads = self.n_kv_heads // world_size
        self.n_rep = self.n_local_heads // self.n_local_kv_heads
        self.head_dim = args.dim // args.n_heads
        self.max_seq_len = args.max_seq_len

@ -587,13 +577,11 @@ class Attention(nn.Module):
            self.n_local_kv_heads,
            self.head_dim,
        )
        device = next(self.parameters()).device
        self.register_buffer(
            "key_cache",
            torch.zeros(
                cache_shape,
                dtype=dtype,
                device=device,
            ),
            persistent=False,
        )

@ -602,7 +590,6 @@ class Attention(nn.Module):
            torch.zeros(
                cache_shape,
                dtype=dtype,
                device=device,
            ),
            persistent=False,
        )

@ -614,6 +601,9 @@ class Attention(nn.Module):
        freqs_cis: torch.Tensor,
        position_ids: torch.LongTensor,
    ):
        self.key_cache = self.key_cache.to(x.device)
        self.value_cache = self.value_cache.to(x.device)

        xq, xk, xv = [F.linear(x, w) for w in [self.wq.weight, self.wk.weight, self.wv.weight]]

        bs, slen, _ = xq.shape

@ -832,10 +822,10 @@ class CrossAttention(torch.nn.Module):
        norm_eps: float,
    ):
        super().__init__()
        self.model_parallel_size = fs_init.get_model_parallel_world_size()
        self.world_size = fs_init.get_model_parallel_world_size()
        replication_factor = 1
        if self.model_parallel_size > 8:
        if self.world_size > 8:
            replication_factor = self.model_parallel_size // MP_SCALE
            replication_factor = self.world_size // MP_SCALE
        n_kv_heads *= replication_factor

        assert n_heads % n_kv_heads == 0

@ -889,10 +879,10 @@ class CrossAttention(torch.nn.Module):
        # trunk LLM (i.e., group query attention) -- @dubeya
        # local heads
        assert self.n_heads % self.n_kv_heads == 0
        assert self.n_heads % self.model_parallel_size == 0
        assert self.n_heads % self.world_size == 0
        assert self.n_kv_heads % self.model_parallel_size == 0
        assert self.n_kv_heads % self.world_size == 0
        self.n_local_heads = self.n_heads // self.model_parallel_size
        self.n_local_heads = self.n_heads // self.world_size
        self.n_local_kv_heads = self.n_kv_heads // self.model_parallel_size
        self.n_local_kv_heads = self.n_kv_heads // self.world_size
        self.n_rep = self.n_local_heads // self.n_local_kv_heads

    def _compute_xattn_kv_cache(self, xattn_tokens: torch.Tensor) -> torch.Tensor:

@ -1041,7 +1031,7 @@ class CrossAttentionTransformerVision(torch.nn.Module):
        self.image_res = args.vision_chunk_size
        self.max_num_chunks = args.vision_max_num_chunks
        if return_intermediate is not None:
            return_intermediate = [int(level) for level in return_intermediate.split(",")]
            return_intermediate = [int(layer) for layer in return_intermediate.split(",")]
            self.vision_input_dim = (len(return_intermediate) + 1) * self.vision_input_dim
        self.patch_size = 14
        self.vision_encoder = VisionEncoder(

@ -1076,15 +1066,15 @@ class CrossAttentionTransformerText(torch.nn.Module):

    def __init__(self, args: ModelArgs) -> None:
        super().__init__()
        self.model_parallel_size = fs_init.get_model_parallel_world_size()
        self.world_size = fs_init.get_model_parallel_world_size()
        assert args.vocab_size > 0
        self.vocab_size = args.vocab_size
        self.n_layers = args.n_layers
        self.dim = args.dim
        self.head_dim = args.dim // args.n_heads
        self.n_kv_heads = args.n_heads if args.n_kv_heads is None else args.n_kv_heads
        self.n_local_kv_heads = self.n_kv_heads // self.model_parallel_size
        self.n_local_kv_heads = self.n_kv_heads // self.world_size
        assert self.vocab_size % self.model_parallel_size == 0
        assert self.vocab_size % self.world_size == 0
        self.tok_embeddings = VocabParallelEmbedding(args.vocab_size, args.dim, init_method=lambda x: x)
        self.pos_embeddings = None
        # final norm layer (not necessary for post-norm)

@ -1184,6 +1174,8 @@ class CrossAttentionTransformerText(torch.nn.Module):
        text_only_inference: bool = False,
    ):
        assert self.cache_is_setup, "Please set up cache before calling forward"
        self.mask_cache = self.mask_cache.to(h.device)
        self.freqs_cis = self.freqs_cis.to(h.device)
        mask = self.mask_cache.index_select(2, position_ids)
        freqs_cis = self.freqs_cis.index_select(0, position_ids)

@ -1212,9 +1204,8 @@ class CrossAttentionTransformerText(torch.nn.Module):
        output = gather_from_tensor_model_parallel_region(output)
        return output.float()

    def setup_cache(self, max_batch_size: int, dtype=torch.bfloat16):
    def setup_cache(self, max_batch_size: int, device: torch.device, dtype=torch.bfloat16):
        # Set up the text kv caches
        device = next(self.parameters()).device
        ones = torch.ones(
            (self.max_seq_len, self.max_seq_len),
            dtype=torch.bool,

@ -1265,7 +1256,7 @@ class CrossAttentionTransformerText(torch.nn.Module):
|
|
||||||
return (
|
return (
|
||||||
cross_attention_masks.to(device=text_device, dtype=text_dtype),
|
cross_attention_masks.to(device=text_device, dtype=text_dtype),
|
||||||
full_text_row_masked_out_mask,
|
full_text_row_masked_out_mask.to(device=text_device),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
@ -1284,14 +1275,15 @@ class CrossAttentionTransformer(torch.nn.Module):
|
||||||
max_num_chunks=args.vision_max_num_chunks,
|
max_num_chunks=args.vision_max_num_chunks,
|
||||||
)
|
)
|
||||||
|
|
||||||
def setup_cache(self, max_batch_size: int, dtype: torch.dtype):
|
def setup_cache(self, max_batch_size: int, device: torch.device, dtype: torch.dtype):
|
||||||
self.text_model.setup_cache(max_batch_size, dtype)
|
self.text_model.setup_cache(max_batch_size, device, dtype)
|
||||||
|
|
||||||
def compute_vision_tokens_masks(
|
def compute_vision_tokens_masks(
|
||||||
self,
|
self,
|
||||||
batch_images: List[List[PIL_Image.Image]],
|
batch_images: List[List[PIL_Image.Image]],
|
||||||
batch_masks: List[List[List[int]]],
|
batch_masks: List[List[List[int]]],
|
||||||
total_len: int,
|
total_len: int,
|
||||||
|
device: torch.device,
|
||||||
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
) -> Tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
||||||
skip_vision_encoder = False
|
skip_vision_encoder = False
|
||||||
|
|
||||||
|
@ -1318,6 +1310,7 @@ class CrossAttentionTransformer(torch.nn.Module):
|
||||||
image_res=self.params.vision_chunk_size,
|
image_res=self.params.vision_chunk_size,
|
||||||
max_num_images=max_num_images,
|
max_num_images=max_num_images,
|
||||||
)
|
)
|
||||||
|
stacked_images = stacked_images.to(device=device)
|
||||||
|
|
||||||
if skip_vision_encoder:
|
if skip_vision_encoder:
|
||||||
vision_tokens = torch.zeros(
|
vision_tokens = torch.zeros(
|
||||||
|
@ -1330,7 +1323,7 @@ class CrossAttentionTransformer(torch.nn.Module):
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
else:
|
else:
|
||||||
vision_tokens = self.vision_model(stacked_images, aspect_ratios)
|
vision_tokens = self.vision_model(stacked_images, aspect_ratios).to(device=device)
|
||||||
|
|
||||||
bsz, nimg, nchunk, ntok, image_token_dim = tuple(vision_tokens.shape)
|
bsz, nimg, nchunk, ntok, image_token_dim = tuple(vision_tokens.shape)
|
||||||
xattn_caches = torch.stack(
|
xattn_caches = torch.stack(
|
|
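The hunks above move device placement out of module construction and into the callers: `setup_cache` now takes an explicit `device`, and `compute_vision_tokens_masks` moves the stacked images and vision tokens onto that device. A minimal calling sketch, assuming a `CrossAttentionTransformer` instance named `model` built elsewhere (the variable names and values here are illustrative, not part of this diff):

import torch

# Hypothetical caller: pick a device once and pass it through explicitly.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
model.setup_cache(max_batch_size=1, device=device, dtype=torch.bfloat16)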
@@ -15,7 +15,7 @@ import textwrap
 from datetime import datetime
 from typing import Any, List, Optional

-from llama_stack.models.llama.datatypes import (
+from llama_stack.apis.inference import (
     BuiltinTool,
     ToolDefinition,
     ToolParamDefinition,
@@ -229,6 +229,11 @@ class PythonListCustomToolGenerator(PromptTemplateGeneratorBase): # noqa: N801
         You are an expert in composing functions. You are given a question and a set of possible functions.
         Based on the question, you may or may not need to make one function/tool call to achieve the purpose.

+        If you decide to invoke any of the function(s), you MUST put it in the format of [func_name1(params_name1=params_value1, params_name2=params_value2...), func_name2(params)]
+        If you decide to invoke a function, you SHOULD NOT include any other text in the response. besides the function call in the above format.
+        For a boolean parameter, be sure to use `True` or `False` (capitalized) for the value.
+
+
         {{ function_description }}
         """.strip("\n")
     )
@@ -243,10 +248,6 @@ class PythonListCustomToolGenerator(PromptTemplateGeneratorBase): # noqa: N801
     def _gen_function_description(self, custom_tools: List[ToolDefinition]) -> PromptTemplate:
         template_str = textwrap.dedent(
             """
-            If you decide to invoke any of the function(s), you MUST put it in the format of [func_name1(params_name1=params_value1, params_name2=params_value2...), func_name2(params)]
-            For a boolean parameter, be sure to use `True` or `False` (capitalized) for the value.
-            You SHOULD NOT include any other text in the response.
-
             Here is a list of functions in JSON format that you can invoke.

             [
@@ -279,6 +280,10 @@ class PythonListCustomToolGenerator(PromptTemplateGeneratorBase): # noqa: N801
             {% endif -%}
             {%- endfor %}
             ]
+
+            You can answer general questions or invoke tools when necessary.
+            In addition to tool calls, you should also augment your responses by using the tool outputs.
+
             """
         )
         return PromptTemplate(
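The updated system prompt asks the model to emit tool calls as a bracketed list of function calls rather than free-form text. For reference, a model response that follows that format would look like the line below (the tool and parameter names are invented for illustration):

[get_weather(city="Tokyo", celsius=True)]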
llama_stack/models/llama/llama3/quantization/__init__.py (new file, +5)
@@ -0,0 +1,5 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the terms described in the LICENSE file in
+# the root directory of this source tree.

@@ -4,12 +4,9 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
+# type: ignore

-import logging
 import os
-from typing import Any, Dict, List, Optional
+from typing import Any, Dict, List, Optional, cast

 import torch
 from fairscale.nn.model_parallel.initialize import get_model_parallel_rank
@@ -18,52 +15,53 @@ from fairscale.nn.model_parallel.mappings import reduce_from_model_parallel_regi
 from torch import Tensor, nn
 from torchao.quantization.GPTQ import Int8DynActInt4WeightLinear

-from llama_stack.apis.inference import QuantizationType
-from llama_stack.models.llama.datatypes import CheckpointQuantizationFormat
-from llama_stack.models.llama.sku_list import resolve_model
-from ...llama3.args import ModelArgs
-from ...llama3.model import Transformer, TransformerBlock
-from ..config import MetaReferenceQuantizedInferenceConfig
-log = logging.getLogger(__name__)
+from ...datatypes import QuantizationMode
+from ...quantize_impls import (
+    Fp8ScaledWeights,
+    ffn_swiglu,
+    load_fp8,
+    quantize_fp8,
+)
+from ..model import Transformer, TransformerBlock
+from ..multimodal.model import CrossAttentionTransformer


 def swiglu_wrapper(
     self,
     x: Tensor,
 ):
-    from .fp8_impls import ffn_swiglu
-
     out = ffn_swiglu(x, self.w1.weight, self.w3.weight, self.w2.weight)
     return reduce_from_model_parallel_region(out)


+def convert_to_quantized_model(
+    model: Transformer | CrossAttentionTransformer,
+    checkpoint_dir: str,
+    quantization_mode: Optional[str] = None,
+    fp8_activation_scale_ub: Optional[float] = 1200.0,
+    device: Optional[torch.device] = None,
+) -> Transformer | CrossAttentionTransformer:
+    if quantization_mode == QuantizationMode.fp8_mixed:
+        return convert_to_fp8_quantized_model(model, checkpoint_dir, fp8_activation_scale_ub, device)
+    elif quantization_mode == QuantizationMode.int4_mixed:
+        return convert_to_int4_quantized_model(model, checkpoint_dir, device)
+    else:
+        raise ValueError(f"Unsupported quantization mode: {quantization_mode}")
+
+
 def convert_to_fp8_quantized_model(
     model: Transformer,
-    config: MetaReferenceQuantizedInferenceConfig,
     checkpoint_dir: str,
     fp8_activation_scale_ub: Optional[float] = 1200.0,
+    device: Optional[torch.device] = None,
 ) -> Transformer:
-    if config.quantization.type == QuantizationType.bf16.value:
-        return model
-
-    elif config.quantization.type != QuantizationType.fp8.value:
-        raise ValueError("Only FP8 quantization is supported")
-
-    from .fp8_impls import Fp8ScaledWeights, load_fp8, quantize_fp8
-
-    llama_model = resolve_model(config.model)
-    assert llama_model is not None, f"Model {config.model} not found"
-
     # Move weights to GPU with quantization
-    if llama_model.quantization_format == CheckpointQuantizationFormat.fp8_mixed.value:
-        log.info("Loading fp8 scales...")
-        fp8_scales_path = os.path.join(checkpoint_dir, f"fp8_scales_{get_model_parallel_rank()}.pt")
-        assert os.path.isfile(fp8_scales_path), f"fp8_scales_path not found for rank {get_model_parallel_rank()}"
+    fp8_scales_path = os.path.join(checkpoint_dir, f"fp8_scales_{get_model_parallel_rank()}.pt")
+    if os.path.isfile(fp8_scales_path):
+        print("Loading fp8 scales...")
         fp8_scales = torch.load(fp8_scales_path, weights_only=True)

-        for block in model.layers:
+        for _, block in model.named_modules():
             if isinstance(block, TransformerBlock):
                 if block.layer_id == 0 or block.layer_id == (model.n_layers - 1):
                     continue
@@ -77,23 +75,23 @@ def convert_to_fp8_quantized_model(
                     fp8_activation_scale_ub,
                 )
     else:
-        log.info("Quantizing fp8 weights from bf16...")
-        for block in model.layers:
+        print("Quantizing fp8 weights from bf16...")
+        for _, block in model.named_modules():
             if isinstance(block, TransformerBlock):
                 if block.layer_id == 0 or block.layer_id == (model.n_layers - 1):
                     continue
-                block.feed_forward.forward = swiglu_wrapper.__get__(block.feed_forward)
+                block.feed_forward.forward = swiglu_wrapper.__get__(block.feed_forward)  # type: ignore
                 for key in ("w1", "w3", "w2"):
                     param = getattr(block.feed_forward, key)
                     param.weight = quantize_fp8(
                         param.weight,
                         fp8_activation_scale_ub,
-                        output_device=torch.device("cuda"),
+                        output_device=device,
                     )

     for _, parameter in model.named_parameters():
         if not isinstance(parameter, Fp8ScaledWeights):
-            parameter.data = parameter.to(device="cuda")
+            parameter.data = parameter.to(device=device)
     return model


@@ -136,6 +134,8 @@ class Int8DynActInt4WeightLinearLoRA(Int8DynActInt4WeightLinear):
             precision=precision,
             scales_precision=scales_precision,
         )
+        self.lora_scale: Optional[float] = None
+        self.adaptor: Optional[nn.Sequential] = None
         if lora_rank is not None:
             assert lora_scale is not None, "Please specify lora scale for LoRA."
             # Low-rank adaptation. See paper for more details: https://arxiv.org/abs/2106.09685
@@ -143,9 +143,6 @@ class Int8DynActInt4WeightLinearLoRA(Int8DynActInt4WeightLinear):
             self.adaptor.add_module("A", nn.Linear(in_features, lora_rank, bias=False))
             self.adaptor.add_module("B", nn.Linear(lora_rank, out_features, bias=False))
             self.lora_scale = lora_scale
-        else:
-            self.adaptor = None
-            self.lora_scale = None
         self._register_load_state_dict_pre_hook(self.load_hook)

     def load_hook(
@@ -287,16 +284,16 @@ def _prepare_model_int4_weight_int8_dynamic_activation(


 def convert_to_int4_quantized_model(
-    model: Transformer,
-    model_args: ModelArgs,
-    config: MetaReferenceQuantizedInferenceConfig,
-) -> Transformer:
+    model: Transformer | CrossAttentionTransformer,
+    checkpoint_dir: str,
+    device: Optional[torch.device] = None,
+) -> Transformer | CrossAttentionTransformer:
     """Convert the model to int4 quantized model."""
-    if model_args.quantization_args is None:
-        raise ValueError("'quantization_args' cannot be None. Please specify it.")
+    model_args = model.params
+    assert model_args.quantization_args is not None, "Quantization args must be specified."

     quantization_args = model_args.quantization_args
+    if quantization_args.scheme is None:
+        raise ValueError("Quantization scheme must be specified in 'quantization_args'.")
+
     if quantization_args.scheme.value != "int4_weight_int8_dynamic_activation":
         raise NotImplementedError(
@@ -316,5 +313,4 @@ def convert_to_int4_quantized_model(
     lora_scale = model_args.lora_args.scale

     _prepare_model_int4_weight_int8_dynamic_activation(model, group_size, lora_rank, lora_scale)
-    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
-    return model.to(device)
+    return cast(Transformer | CrossAttentionTransformer, model.to(device=device))
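The new `convert_to_quantized_model` entry point dispatches on a quantization mode (fp8 mixed vs. int4 mixed) instead of the old provider config object. A rough calling sketch under stated assumptions: `model` is an already-built `Transformer`, `QuantizationMode` values compare equal to the strings shown, and the checkpoint directory is a placeholder path, none of which comes from this diff:

import torch

# Hypothetical call site; the checkpoint directory below is a placeholder.
quantized = convert_to_quantized_model(
    model,
    checkpoint_dir="/path/to/checkpoint",  # placeholder path
    quantization_mode="fp8_mixed",         # or "int4_mixed", assuming string-valued modes
    device=torch.device("cuda"),
)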
@@ -12,8 +12,7 @@
 # the top-level of this source tree.

-
-from llama_stack.models.llama.datatypes import BuiltinTool, StopReason, ToolCall
+from ..datatypes import BuiltinTool, StopReason, ToolCall

 from .prompt_templates import (
     BuiltinToolGenerator,
     JsonCustomToolGenerator,

@@ -4,16 +4,6 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# top-level folder for each specific model found within the models/ directory at
-# the top-level of this source tree.
-
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# This software may be used and distributed in accordance with the terms of the Llama 3 Community License Agreement.
-
 import os
 from logging import getLogger
 from pathlib import Path

@@ -4,19 +4,13 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# top-level folder for each specific model found within the models/ directory at
-# the top-level of this source tree.
-import ast
 import json
 import re
 from typing import Optional, Tuple

 from llama_stack.log import get_logger
-from llama_stack.models.llama.datatypes import BuiltinTool, RecursiveType, ToolCall, ToolPromptFormat
+from ..datatypes import BuiltinTool, RecursiveType, ToolCall, ToolPromptFormat

 logger = get_logger(name=__name__, category="inference")

@@ -34,80 +28,141 @@ def is_json(s):
     return True


-def is_valid_python_list(input_string):
-    """Check if the input string is a valid Python list of function calls"""
-    try:
-        # Try to parse the string
-        tree = ast.parse(input_string)
-
-        # Check if it's a single expression
-        if len(tree.body) != 1 or not isinstance(tree.body[0], ast.Expr):
-            return False
-
-        # Check if the expression is a list
-        expr = tree.body[0].value
-        if not isinstance(expr, ast.List):
-            return False
-
-        # Check if the list is empty
-        if len(expr.elts) == 0:
-            return False
-
-        # Check if all elements in the list are function calls
-        for element in expr.elts:
-            if not isinstance(element, ast.Call):
-                return False
-
-            # Check if the function call has a valid name
-            if not isinstance(element.func, ast.Name):
-                return False
-
-            # Check if all arguments are keyword arguments
-            if element.args or not all(isinstance(arg, ast.keyword) for arg in element.keywords):
-                return False
-
-        return True
-
-    except SyntaxError:
-        # If parsing fails, it's not a valid Python expression
-        return False
-
-
-def parse_python_list_for_function_calls(input_string):
+def parse_llama_tool_call_format(input_string):
     """
-    Parse a Python list of function calls and
-    return a list of tuples containing the function name and arguments
-    """
-    # Parse the string into an AST
-    tree = ast.parse(input_string)
-
-    # Ensure the input is a list
-    if not isinstance(tree.body[0], ast.Expr) or not isinstance(tree.body[0].value, ast.List):
-        raise ValueError("Input must be a list of function calls")
-
+    Parse tool calls in the format:
+    [func_name1(params_name1=params_value1, params_name2=params_value2...), func_name2(params)]
+
+    Returns a list of (function_name, arguments_dict) tuples or None if parsing fails.
+    """
+    # Strip outer brackets and whitespace
+    input_string = input_string.strip()
+    if not (input_string.startswith("[") and input_string.endswith("]")):
+        return None
+
+    content = input_string[1:-1].strip()
+    if not content:
+        return None
+
     result = []

-    # Iterate through each function call in the list
-    for node in tree.body[0].value.elts:
-        if isinstance(node, ast.Call):
-            function_name = node.func.id
-            function_args = {}
-
-            # Extract keyword arguments
-            for keyword in node.keywords:
-                try:
-                    function_args[keyword.arg] = ast.literal_eval(keyword.value)
-                except ValueError as e:
-                    logger.error(
-                        f"Error parsing tool call argument '{keyword.arg}': {e}, full input string: '{input_string}'"
-                    )
-                    raise ValueError(
-                        f"Error parsing tool call argument '{keyword.arg}', full input string: '{input_string}'"
-                    ) from e
-
-            result.append((function_name, function_args))
-
-    return result
+    # State variables for parsing
+    pos = 0
+    length = len(content)
+
+    while pos < length:
+        # Find function name
+        name_end = content.find("(", pos)
+        if name_end == -1:
+            break
+
+        func_name = content[pos:name_end].strip()
+
+        # Find closing parenthesis for this function call
+        paren_level = 1
+        args_start = name_end + 1
+        args_end = args_start
+
+        while args_end < length and paren_level > 0:
+            if content[args_end] == "(":
+                paren_level += 1
+            elif content[args_end] == ")":
+                paren_level -= 1
+            args_end += 1
+
+        if paren_level != 0:
+            # Unmatched parentheses
+            return None
+
+        # Parse arguments
+        args_str = content[args_start : args_end - 1].strip()
+        args_dict = {}
+
+        if args_str:
+            # Split by commas, but respect nested structures
+            parts = []
+            part_start = 0
+            in_quotes = False
+            quote_char = None
+            nested_level = 0
+
+            for i, char in enumerate(args_str):
+                if char in ('"', "'") and (i == 0 or args_str[i - 1] != "\\"):
+                    if not in_quotes:
+                        in_quotes = True
+                        quote_char = char
+                    elif char == quote_char:
+                        in_quotes = False
+                        quote_char = None
+                elif not in_quotes:
+                    if char in ("{", "["):
+                        nested_level += 1
+                    elif char in ("}", "]"):
+                        nested_level -= 1
+                    elif char == "," and nested_level == 0:
+                        parts.append(args_str[part_start:i].strip())
+                        part_start = i + 1
+
+            parts.append(args_str[part_start:].strip())
+
+            # Process each key=value pair
+            for part in parts:
+                if "=" in part:
+                    key, value = part.split("=", 1)
+                    key = key.strip()
+                    value = value.strip()
+
+                    # Try to convert value to appropriate Python type
+                    if (value.startswith('"') and value.endswith('"')) or (
+                        value.startswith("'") and value.endswith("'")
+                    ):
+                        # String
+                        value = value[1:-1]
+                    elif value.lower() == "true":
+                        value = True
+                    elif value.lower() == "false":
+                        value = False
+                    elif value.lower() == "none":
+                        value = None
+                    elif value.startswith("{") and value.endswith("}"):
+                        # This is a nested dictionary
+                        try:
+                            # Try to parse as JSON
+                            value = json.loads(value.replace("'", '"'))
+                        except json.JSONDecodeError:
+                            # Keep as string if parsing fails
+                            pass
+                    elif value.startswith("[") and value.endswith("]"):
+                        # This is a nested list
+                        try:
+                            # Try to parse as JSON
+                            value = json.loads(value.replace("'", '"'))
+                        except json.JSONDecodeError:
+                            # Keep as string if parsing fails
+                            pass
+                    else:
+                        # Try to convert to number
+                        try:
+                            if "." in value:
+                                value = float(value)
+                            else:
+                                value = int(value)
+                        except ValueError:
+                            # Keep as string if not a valid number
+                            pass
+
+                    args_dict[key] = value
+
+        result.append((func_name, args_dict))
+
+        # Move to the next function call
+        pos = args_end
+
+        # Skip the comma between function calls if present
+        if pos < length and content[pos] == ",":
+            pos += 1
+
+    return result if result else None


 class ToolUtils:
@@ -149,17 +204,19 @@ class ToolUtils:
             return None
         elif is_json(message_body):
            response = json.loads(message_body)
-            if ("type" in response and response["type"] == "function") or ("name" in response):
+            if ("type" in response and response["type"] == "function") or (
+                "name" in response and "parameters" in response
+            ):
                function_name = response["name"]
                args = response["parameters"]
                return function_name, args
            else:
                return None
-        elif is_valid_python_list(message_body):
-            res = parse_python_list_for_function_calls(message_body)
+        elif function_calls := parse_llama_tool_call_format(message_body):
             # FIXME: Enable multiple tool calls
-            return res[0]
+            return function_calls[0]
         else:
+            logger.debug(f"Did not parse tool call from message body: {message_body}")
             return None

     @staticmethod
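To illustrate what the new parser produces, here is a small usage sketch (the input string is invented for this example, not taken from the diff):

calls = parse_llama_tool_call_format('[get_weather(city="Tokyo", celsius=True)]')
# Expected result: [("get_weather", {"city": "Tokyo", "celsius": True})]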
@@ -21,8 +21,7 @@ from llama_stack.models.llama.datatypes import (
     ToolCall,
     ToolPromptFormat,
 )
-from ..prompt_format import (
+from llama_stack.models.llama.prompt_format import (
     # llama3_1_e2e_tool_call_dialog,
     TextCompletionContent,
     UseCase,

@@ -3,10 +3,3 @@
 #
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
-
-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# top-level folder for each specific model found within the models/ directory at
-# the top-level of this source tree.

@@ -4,12 +4,6 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# top-level folder for each specific model found within the models/ directory at
-# the top-level of this source tree.
 import json
 import textwrap

@@ -4,13 +4,6 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.

-# Copyright (c) Meta Platforms, Inc. and affiliates.
-# All rights reserved.
-#
-# This source code is licensed under the terms described in the LICENSE file in
-# top-level folder for each specific model found within the models/ directory at
-# the top-level of this source tree.
-
 import textwrap
 from pathlib import Path

Some files were not shown because too many files have changed in this diff.