Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-07-20 11:47:00 +00:00)

Merge branch 'main' into nvidia-e2e-notebook
Commit: 73275f07b7
123 changed files with 6946 additions and 2220 deletions
.coveragerc (new file, 6 lines)
@@ -0,0 +1,6 @@
[run]
omit =
    */tests/*
    */llama_stack/providers/*
    */llama_stack/templates/*
    .venv/*
.github/workflows/install-script-ci.yml (new file, 26 lines)
@@ -0,0 +1,26 @@
name: Installer CI

on:
  pull_request:
    paths:
      - 'install.sh'
  push:
    paths:
      - 'install.sh'
  schedule:
    - cron: '0 2 * * *' # every day at 02:00 UTC

jobs:
  lint:
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # 4.2.2
      - name: Run ShellCheck on install.sh
        run: shellcheck install.sh
  smoke-test:
    needs: lint
    runs-on: ubuntu-latest
    steps:
      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # 4.2.2
      - name: Run installer end-to-end
        run: ./install.sh
.github/workflows/integration-tests.yml (1 changed line)
@@ -6,7 +6,6 @@ on:
  pull_request:
    branches: [ main ]
    paths:
-      - 'distributions/**'
      - 'llama_stack/**'
      - 'tests/integration/**'
      - 'uv.lock'
.github/workflows/pre-commit.yml (2 changed lines)
@@ -18,7 +18,7 @@ jobs:
        uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2

      - name: Set up Python
-        uses: actions/setup-python@8d9ed9ac5c53483de85588cdf95a591a75ab9f55 # v5.5.0
+        uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5.6.0
        with:
          python-version: '3.11'
          cache: pip
.github/workflows/providers-build.yml (46 changed lines)
@@ -51,7 +51,7 @@ jobs:
        uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2

      - name: Set up Python
-        uses: actions/setup-python@8d9ed9ac5c53483de85588cdf95a591a75ab9f55 # v5.5.0
+        uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5.6.0
        with:
          python-version: '3.10'

@@ -86,15 +86,15 @@ jobs:
    runs-on: ubuntu-latest
    steps:
      - name: Checkout repository
-        uses: actions/checkout@v4
+        uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2

      - name: Set up Python
-        uses: actions/setup-python@v5
+        uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5.6.0
        with:
          python-version: '3.10'

      - name: Install uv
-        uses: astral-sh/setup-uv@v5
+        uses: astral-sh/setup-uv@0c5e2b8115b80b4c7c5ddf6ffdd634974642d182 # v5.4.1
        with:
          python-version: "3.10"

@@ -107,3 +107,41 @@ jobs:
      - name: Build a single provider
        run: |
          USE_COPY_NOT_MOUNT=true LLAMA_STACK_DIR=. uv run llama stack build --image-type venv --image-name test --providers inference=remote::ollama
+
+  build-custom-container-distribution:
+    runs-on: ubuntu-latest
+    steps:
+      - name: Checkout repository
+        uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
+
+      - name: Set up Python
+        uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5.6.0
+        with:
+          python-version: '3.10'
+
+      - name: Install uv
+        uses: astral-sh/setup-uv@0c5e2b8115b80b4c7c5ddf6ffdd634974642d182 # v5.4.1
+        with:
+          python-version: "3.10"
+
+      - name: Install LlamaStack
+        run: |
+          uv venv
+          source .venv/bin/activate
+          uv pip install -e .
+
+      - name: Build a single provider
+        run: |
+          yq -i '.image_type = "container"' llama_stack/templates/dev/build.yaml
+          yq -i '.image_name = "test"' llama_stack/templates/dev/build.yaml
+          USE_COPY_NOT_MOUNT=true LLAMA_STACK_DIR=. uv run llama stack build --config llama_stack/templates/dev/build.yaml
+
+      - name: Inspect the container image entrypoint
+        run: |
+          IMAGE_ID=$(docker images --format "{{.Repository}}:{{.Tag}}" | head -n 1)
+          entrypoint=$(docker inspect --format '{{ .Config.Entrypoint }}' $IMAGE_ID)
+          echo "Entrypoint: $entrypoint"
+          if [ "$entrypoint" != "[python -m llama_stack.distribution.server.server --config /app/run.yaml]" ]; then
+            echo "Entrypoint is not correct"
+            exit 1
+          fi
.github/workflows/test-external-providers.yml (30 changed lines)
@@ -5,10 +5,22 @@ on:
    branches: [ main ]
  pull_request:
    branches: [ main ]
+    paths:
+      - 'llama_stack/**'
+      - 'tests/integration/**'
+      - 'uv.lock'
+      - 'pyproject.toml'
+      - 'requirements.txt'
+      - '.github/workflows/test-external-providers.yml' # This workflow

jobs:
  test-external-providers:
    runs-on: ubuntu-latest
+    strategy:
+      matrix:
+        image-type: [venv]
+        # We don't do container yet, it's tricky to install a package from the host into the
+        # container and point 'uv pip install' to the correct path...
    steps:
      - name: Checkout repository
        uses: actions/checkout@v4

@@ -35,17 +47,25 @@ jobs:
          uv sync --extra dev --extra test
          uv pip install -e .

-      - name: Install Ollama custom provider
+      - name: Apply image type to config file
+        run: |
+          yq -i '.image_type = "${{ matrix.image-type }}"' tests/external-provider/llama-stack-provider-ollama/custom-distro.yaml
+          cat tests/external-provider/llama-stack-provider-ollama/custom-distro.yaml
+
+      - name: Setup directory for Ollama custom provider
        run: |
          mkdir -p tests/external-provider/llama-stack-provider-ollama/src/
          cp -a llama_stack/providers/remote/inference/ollama/ tests/external-provider/llama-stack-provider-ollama/src/llama_stack_provider_ollama
-          uv pip install tests/external-provider/llama-stack-provider-ollama

      - name: Create provider configuration
        run: |
          mkdir -p /tmp/providers.d/remote/inference
          cp tests/external-provider/llama-stack-provider-ollama/custom_ollama.yaml /tmp/providers.d/remote/inference/custom_ollama.yaml

+      - name: Build distro from config file
+        run: |
+          USE_COPY_NOT_MOUNT=true LLAMA_STACK_DIR=. uv run llama stack build --config tests/external-provider/llama-stack-provider-ollama/custom-distro.yaml
+
      - name: Wait for Ollama to start
        run: |
          echo "Waiting for Ollama..."

@@ -62,11 +82,13 @@ jobs:
            exit 1

      - name: Start Llama Stack server in background
+        if: ${{ matrix.image-type }} == 'venv'
        env:
          INFERENCE_MODEL: "meta-llama/Llama-3.2-3B-Instruct"
        run: |
-          source .venv/bin/activate
-          nohup uv run llama stack run tests/external-provider/llama-stack-provider-ollama/run.yaml --image-type venv > server.log 2>&1 &
+          source ci-test/bin/activate
+          uv run pip list
+          nohup uv run --active llama stack run tests/external-provider/llama-stack-provider-ollama/run.yaml --image-type ${{ matrix.image-type }} > server.log 2>&1 &

      - name: Wait for Llama Stack server to be ready
        run: |
.github/workflows/unit-tests.yml (3 changed lines)
@@ -6,7 +6,6 @@ on:
  pull_request:
    branches: [ main ]
    paths:
-      - 'distributions/**'
      - 'llama_stack/**'
      - 'tests/unit/**'
      - 'uv.lock'

@@ -34,7 +33,7 @@ jobs:
      - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2

      - name: Set up Python ${{ matrix.python }}
-        uses: actions/setup-python@8d9ed9ac5c53483de85588cdf95a591a75ab9f55 # v5.5.0
+        uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5.6.0
        with:
          python-version: ${{ matrix.python }}
.github/workflows/update-readthedocs.yml (2 changed lines)
@@ -36,7 +36,7 @@ jobs:
        uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2

      - name: Set up Python
-        uses: actions/setup-python@8d9ed9ac5c53483de85588cdf95a591a75ab9f55 # v5.5.0
+        uses: actions/setup-python@a26af69be951a213d495a4c3e4e4022e16d87065 # v5.6.0
        with:
          python-version: '3.11'
CHANGELOG.md (28 changed lines)
@@ -1,5 +1,33 @@
# Changelog

+# v0.2.3
+Published on: 2025-04-25T22:46:21Z
+
+## Highlights
+
+* OpenAI-compatible inference endpoints and client-SDK support. `client.chat.completions.create()` now works.
+* Significant improvements and functionality added to the NVIDIA distribution.
+* Many improvements to the test verification suite.
+* New inference providers: Ramalama, IBM watsonx.
+* Many improvements to the Playground UI.
+
+---
+
+# v0.2.2
+Published on: 2025-04-13T01:19:49Z
+
+## Main changes
+
+- Bring Your Own Provider (@leseb) - use out-of-tree provider code to execute the distribution server
+- OpenAI-compatible inference API in progress (@bbrowning)
+- Provider verifications (@ehhuang)
+- Many updates and fixes to playground
+- Several Llama 4 related fixes
+
+---
+
# v0.2.1
Published on: 2025-04-05T23:13:00Z
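The v0.2.3 highlight above notes that OpenAI-compatible inference now works through the client SDK via `client.chat.completions.create()`. A minimal sketch (not part of the diff), assuming a Llama Stack server already running on localhost:8321 with the named model registered — both the URL and the model id are placeholders to adjust:

```python
from llama_stack_client import LlamaStackClient

# Assumption: a local Llama Stack server at this address with the model below registered.
client = LlamaStackClient(base_url="http://localhost:8321")

response = client.chat.completions.create(
    model="meta-llama/Llama-3.2-3B-Instruct",
    messages=[{"role": "user", "content": "Say hello in one sentence."}],
)
# The response mirrors the OpenAI chat-completions shape.
print(response.choices[0].message.content)
```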
@@ -70,6 +70,13 @@ As more providers start supporting Llama 4, you can use them in Llama Stack as w

</details>

+### 🚀 One-Line Installer 🚀
+
+To try Llama Stack locally, run:
+
+```bash
+curl -LsSf https://github.com/meta-llama/llama-stack/raw/main/install.sh | sh
+```
+
### Overview

@@ -119,6 +126,7 @@ Here is a list of the various API providers and available distributions that can
| OpenAI | Hosted | | ✅ | | | |
| Anthropic | Hosted | | ✅ | | | |
| Gemini | Hosted | | ✅ | | | |
+| watsonx | Hosted | | ✅ | | | |

### Distributions

@@ -128,7 +136,6 @@ A Llama Stack Distribution (or "distro") is a pre-configured bundle of provider
| **Distribution** | **Llama Stack Docker** | Start This Distribution |
|:---------------------------------------------:|:-------------------------------------------------------------------------------------------------------------------------------------------------------------:|:------------------------------------------------------------------------------------------------------------------------:|
| Meta Reference | [llamastack/distribution-meta-reference-gpu](https://hub.docker.com/repository/docker/llamastack/distribution-meta-reference-gpu/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/distributions/self_hosted_distro/meta-reference-gpu.html) |
-| Meta Reference Quantized | [llamastack/distribution-meta-reference-quantized-gpu](https://hub.docker.com/repository/docker/llamastack/distribution-meta-reference-quantized-gpu/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/distributions/self_hosted_distro/meta-reference-quantized-gpu.html) |
| SambaNova | [llamastack/distribution-sambanova](https://hub.docker.com/repository/docker/llamastack/distribution-sambanova/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/distributions/self_hosted_distro/sambanova.html) |
| Cerebras | [llamastack/distribution-cerebras](https://hub.docker.com/repository/docker/llamastack/distribution-cerebras/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/distributions/self_hosted_distro/cerebras.html) |
| Ollama | [llamastack/distribution-ollama](https://hub.docker.com/repository/docker/llamastack/distribution-ollama/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/distributions/self_hosted_distro/ollama.html) |
docs/_static/js/detect_theme.js (29 changed lines)
@@ -1,9 +1,32 @@
document.addEventListener("DOMContentLoaded", function () {
  const prefersDark = window.matchMedia("(prefers-color-scheme: dark)").matches;
  const htmlElement = document.documentElement;
-  if (prefersDark) {
-    htmlElement.setAttribute("data-theme", "dark");
+
+  // Check if theme is saved in localStorage
+  const savedTheme = localStorage.getItem("sphinx-rtd-theme");
+
+  if (savedTheme) {
+    // Use the saved theme preference
+    htmlElement.setAttribute("data-theme", savedTheme);
+    document.body.classList.toggle("dark", savedTheme === "dark");
  } else {
-    htmlElement.setAttribute("data-theme", "light");
+    // Fall back to system preference
+    const theme = prefersDark ? "dark" : "light";
+    htmlElement.setAttribute("data-theme", theme);
+    document.body.classList.toggle("dark", theme === "dark");
+    // Save initial preference
+    localStorage.setItem("sphinx-rtd-theme", theme);
  }
+
+  // Listen for theme changes from the existing toggle
+  const observer = new MutationObserver(function(mutations) {
+    mutations.forEach(function(mutation) {
+      if (mutation.attributeName === "data-theme") {
+        const currentTheme = htmlElement.getAttribute("data-theme");
+        localStorage.setItem("sphinx-rtd-theme", currentTheme);
+      }
+    });
+  });
+
+  observer.observe(htmlElement, { attributes: true });
});
docs/_static/llama-stack-spec.html (22 changed lines)
@@ -5221,17 +5221,25 @@
          "default": 10
        },
        "model": {
-          "type": "string"
+          "type": "string",
+          "description": "The model identifier to use for the agent"
        },
        "instructions": {
-          "type": "string"
+          "type": "string",
+          "description": "The system instructions for the agent"
+        },
+        "name": {
+          "type": "string",
+          "description": "Optional name for the agent, used in telemetry and identification"
        },
        "enable_session_persistence": {
          "type": "boolean",
-          "default": false
+          "default": false,
+          "description": "Optional flag indicating whether session data has to be persisted"
        },
        "response_format": {
-          "$ref": "#/components/schemas/ResponseFormat"
+          "$ref": "#/components/schemas/ResponseFormat",
+          "description": "Optional response format configuration"
        }
      },
      "additionalProperties": false,

@@ -5239,7 +5247,8 @@
        "model",
        "instructions"
      ],
-      "title": "AgentConfig"
+      "title": "AgentConfig",
+      "description": "Configuration for an agent."
    },
    "AgentTool": {
      "oneOf": [

@@ -8891,8 +8900,7 @@
      },
      "additionalProperties": false,
      "required": [
-        "role",
-        "content"
+        "role"
      ],
      "title": "OpenAIAssistantMessageParam",
      "description": "A message containing the model's (assistant) response in an OpenAI-compatible chat completion request."
docs/_static/llama-stack-spec.yaml (12 changed lines)
@@ -3686,18 +3686,29 @@ components:
          default: 10
        model:
          type: string
+          description: >-
+            The model identifier to use for the agent
        instructions:
          type: string
+          description: The system instructions for the agent
+        name:
+          type: string
+          description: >-
+            Optional name for the agent, used in telemetry and identification
        enable_session_persistence:
          type: boolean
          default: false
+          description: >-
+            Optional flag indicating whether session data has to be persisted
        response_format:
          $ref: '#/components/schemas/ResponseFormat'
+          description: Optional response format configuration
      additionalProperties: false
      required:
        - model
        - instructions
      title: AgentConfig
+      description: Configuration for an agent.
    AgentTool:
      oneOf:
        - type: string

@@ -6097,7 +6108,6 @@ components:
      additionalProperties: false
      required:
        - role
-        - content
      title: OpenAIAssistantMessageParam
      description: >-
        A message containing the model's (assistant) response in an OpenAI-compatible
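As a hedged illustration of the `AgentConfig` fields added in the two spec hunks above: the field names come directly from the schema, while every value below is a placeholder.

```python
# Field names mirror the AgentConfig schema shown above; values are placeholders.
agent_config = {
    "model": "meta-llama/Llama-3.2-3B-Instruct",      # required: model identifier for the agent
    "instructions": "You are a helpful assistant.",    # required: system instructions
    "name": "docs-demo-agent",                          # optional: used in telemetry and identification
    "enable_session_persistence": False,                # optional: defaults to false
}
```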
@@ -68,7 +68,8 @@ chunks_response = client.vector_io.query(
### Using the RAG Tool

A better way to ingest documents is to use the RAG Tool. This tool allows you to ingest documents from URLs, files, etc.
-and automatically chunks them into smaller pieces.
+and automatically chunks them into smaller pieces. More examples for how to format a RAGDocument can be found in the
+[appendix](#more-ragdocument-examples).

```python
from llama_stack_client import RAGDocument

@@ -178,3 +179,38 @@ for vector_db_id in client.vector_dbs.list():
    print(f"Unregistering vector database: {vector_db_id.identifier}")
    client.vector_dbs.unregister(vector_db_id=vector_db_id.identifier)
```
+
+### Appendix
+
+#### More RAGDocument Examples
+```python
+from llama_stack_client import RAGDocument
+import base64
+import requests
+
+RAGDocument(document_id="num-0", content={"uri": "file://path/to/file"})
+RAGDocument(document_id="num-1", content="plain text")
+RAGDocument(
+    document_id="num-2",
+    content={
+        "type": "text",
+        "text": "plain text input",
+    },  # for inputs that should be treated as text explicitly
+)
+RAGDocument(
+    document_id="num-3",
+    content={
+        "type": "image",
+        "image": {"url": {"uri": "https://mywebsite.com/image.jpg"}},
+    },
+)
+B64_ENCODED_IMAGE = base64.b64encode(
+    requests.get(
+        "https://raw.githubusercontent.com/meta-llama/llama-stack/refs/heads/main/docs/_static/llama-stack.png"
+    ).content
+)
+RAGDocument(
+    document_id="num-4",
+    content={"type": "image", "image": {"data": B64_ENCODED_IMAGE}},
+)
+```
+For more strongly typed interaction use the typed dicts found [here](https://github.com/meta-llama/llama-stack-client-python/blob/38cd91c9e396f2be0bec1ee96a19771582ba6f17/src/llama_stack_client/types/shared_params/document.py).
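The appendix above only shows how to construct `RAGDocument` objects. A minimal sketch (not part of the diff) of feeding them to the RAG tool, assuming a local server and a previously registered vector DB whose id is a placeholder:

```python
from llama_stack_client import LlamaStackClient, RAGDocument

client = LlamaStackClient(base_url="http://localhost:8321")  # assumed local server

documents = [
    RAGDocument(document_id="num-0", content={"uri": "file://path/to/file"}),
    RAGDocument(document_id="num-1", content="plain text"),
]

# Ingest through the RAG tool; the tool chunks the documents automatically.
client.tool_runtime.rag_tool.insert(
    documents=documents,
    vector_db_id="my_documents",  # placeholder: a vector DB registered beforehand
    chunk_size_in_tokens=512,
)
```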
@@ -41,7 +41,7 @@ client.toolgroups.register(

The tool requires an API key which can be provided either in the configuration or through the request header `X-LlamaStack-Provider-Data`. The format of the header is `{"<provider_name>_api_key": <your api key>}`.

+> **NOTE:** When using Tavily Search and Bing Search, the inference output will still display "Brave Search." This is because Llama models have been trained with Brave Search as a built-in tool. Tavily and Bing are just being used in lieu of Brave Search.

#### Code Interpreter

@@ -214,3 +214,69 @@ response = agent.create_turn(
    session_id=session_id,
)
```
+
+## Simple Example 2: Using an Agent with the Web Search Tool
+1. Start by registering a Tavily API key at [Tavily](https://tavily.com/).
+2. [Optional] Provide the API key directly to the Llama Stack server:
+```bash
+export TAVILY_SEARCH_API_KEY="your key"
+```
+```bash
+--env TAVILY_SEARCH_API_KEY=${TAVILY_SEARCH_API_KEY}
+```
+3. Run the following script.
+```python
+from llama_stack_client.lib.agents.agent import Agent
+from llama_stack_client.types.agent_create_params import AgentConfig
+from llama_stack_client.lib.agents.event_logger import EventLogger
+from llama_stack_client import LlamaStackClient
+
+client = LlamaStackClient(
+    base_url="http://localhost:8321",
+    provider_data={
+        "tavily_search_api_key": "your_TAVILY_SEARCH_API_KEY"
+    },  # Set this from the client side. No need to provide it if it has already been configured on the Llama Stack server.
+)
+
+agent = Agent(
+    client,
+    model="meta-llama/Llama-3.2-3B-Instruct",
+    instructions=(
+        "You are a web search assistant; you must use the websearch tool to look up the most current and precise information available."
+    ),
+    tools=["builtin::websearch"],
+)
+
+session_id = agent.create_session("websearch-session")
+
+response = agent.create_turn(
+    messages=[
+        {"role": "user", "content": "How did the USA perform in the last Olympics?"}
+    ],
+    session_id=session_id,
+)
+for log in EventLogger().log(response):
+    log.print()
+```
+
+## Simple Example 3: Using an Agent with the WolframAlpha Tool
+1. Start by registering for a WolframAlpha API key at the [WolframAlpha Developer Portal](https://developer.wolframalpha.com/access).
+2. Provide the API key either when starting the Llama Stack server:
+```bash
+--env WOLFRAM_ALPHA_API_KEY=${WOLFRAM_ALPHA_API_KEY}
+```
+or from the client side:
+```python
+client = LlamaStackClient(
+    base_url="http://localhost:8321",
+    provider_data={"wolfram_alpha_api_key": wolfram_api_key},
+)
+```
+3. Configure the tools in the Agent by setting `tools=["builtin::wolfram_alpha"]`.
+4. Example user query:
+```python
+response = agent.create_turn(
+    messages=[{"role": "user", "content": "Solve x^2 + 2x + 1 = 0 using WolframAlpha"}],
+    session_id=session_id,
+)
+```
@@ -109,8 +109,6 @@ llama stack build --list-templates
+------------------------------+-----------------------------------------------------------------------------+
| nvidia                       | Use NVIDIA NIM for running LLM inference                                    |
+------------------------------+-----------------------------------------------------------------------------+
-| meta-reference-quantized-gpu | Use Meta Reference with fp8, int4 quantization for running LLM inference   |
-+------------------------------+-----------------------------------------------------------------------------+
| cerebras                     | Use Cerebras for running LLM inference                                      |
+------------------------------+-----------------------------------------------------------------------------+
| ollama                       | Use (an external) Ollama server for running LLM inference                   |

@@ -176,7 +174,11 @@ distribution_spec:
    safety: inline::llama-guard
    agents: inline::meta-reference
    telemetry: inline::meta-reference
+image_name: ollama
image_type: conda
+
+# If some providers are external, you can specify the path to the implementation
+external_providers_dir: /etc/llama-stack/providers.d
```

```

@@ -184,6 +186,57 @@ llama stack build --config llama_stack/templates/ollama/build.yaml
```
:::
+
+:::{tab-item} Building with External Providers
+
+Llama Stack supports external providers that live outside of the main codebase. This allows you to create and maintain your own providers independently or use community-provided providers.
+
+To build a distribution with external providers, you need to:
+
+1. Configure the `external_providers_dir` in your build configuration file:
+
+```yaml
+# Example my-external-stack.yaml with external providers
+version: '2'
+distribution_spec:
+  description: Custom distro for CI tests
+  providers:
+    inference:
+      - remote::custom_ollama
+# Add more providers as needed
+image_type: container
+image_name: ci-test
+# Path to external provider implementations
+external_providers_dir: /etc/llama-stack/providers.d
+```
+
+Here's an example for a custom Ollama provider:
+
+```yaml
+adapter:
+  adapter_type: custom_ollama
+  pip_packages:
+    - ollama
+    - aiohttp
+    - llama-stack-provider-ollama # This is the provider package
+  config_class: llama_stack_ollama_provider.config.OllamaImplConfig
+  module: llama_stack_ollama_provider
+api_dependencies: []
+optional_api_dependencies: []
+```
+
+The `pip_packages` section lists the Python packages required by the provider, as well as the
+provider package itself. The package must be available on PyPI or can be provided from a local
+directory or a git repository (git must be installed on the build environment).
+
+2. Build your distribution using the config file:
+
+```
+llama stack build --config my-external-stack.yaml
+```
+
+For more information on external providers, including directory structure, provider types, and implementation requirements, see the [External Providers documentation](../providers/external.md).
+:::

:::{tab-item} Building Container

```{admonition} Podman Alternative
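As a purely illustrative aside (not taken from the diff): the `config_class` referenced in the external-provider YAML above, `llama_stack_ollama_provider.config.OllamaImplConfig`, would typically be a small pydantic model defined inside the provider package. The class body and default URL below are assumptions, not the package's actual definition.

```python
# Hypothetical sketch of an external provider's config class; the real
# llama_stack_ollama_provider package may define different fields.
from pydantic import BaseModel


class OllamaImplConfig(BaseModel):
    url: str = "http://localhost:11434"  # assumed default Ollama endpoint
```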
@@ -24,7 +24,7 @@ The key files in the app are `ExampleLlamaStackLocalInference.kt`, `ExampleLlama
Add the following dependency in your `build.gradle.kts` file:
```
dependencies {
-    implementation("com.llama.llamastack:llama-stack-client-kotlin:0.1.4.2")
+    implementation("com.llama.llamastack:llama-stack-client-kotlin:0.2.2")
}
```
This will download jar files in your gradle cache in a directory like `~/.gradle/caches/modules-2/files-2.1/com.llama.llamastack/`

@@ -37,11 +37,7 @@ For local inferencing, it is required to include the ExecuTorch library into you
Include the ExecuTorch library by:
1. Download the `download-prebuilt-et-lib.sh` script file from the [llama-stack-client-kotlin-client-local](https://github.com/meta-llama/llama-stack-client-kotlin/tree/latest-release/llama-stack-client-kotlin-client-local/download-prebuilt-et-lib.sh) directory to your local machine.
-2. Move the script to the top level of your Android app where the app directory resides:
-<p align="center">
-<img src="https://github.com/meta-llama/llama-stack-client-kotlin/blob/latest-release/doc/img/example_android_app_directory.png" style="width:300px">
-</p>
+2. Move the script to the top level of your Android app where the `app` directory resides.
3. Run `sh download-prebuilt-et-lib.sh` to create an `app/libs` directory and download the `executorch.aar` in that path. This generates an ExecuTorch library for the XNNPACK delegate.
4. Add the `executorch.aar` dependency in your `build.gradle.kts` file:
```

@@ -52,6 +48,8 @@ dependencies {
}
```
+
+See other dependencies for the local RAG in Android app [README](https://github.com/meta-llama/llama-stack-client-kotlin/tree/latest-release/examples/android_app#quick-start).

## Llama Stack APIs in Your Android App
Breaking down the demo app, this section will show the core pieces that are used to initialize and run inference with Llama Stack using the Kotlin library.

@@ -60,7 +58,7 @@ Start a Llama Stack server on localhost. Here is an example of how you can do th
```
conda create -n stack-fireworks python=3.10
conda activate stack-fireworks
-pip install --no-cache llama-stack==0.1.4
+pip install --no-cache llama-stack==0.2.2
llama stack build --template fireworks --image-type conda
export FIREWORKS_API_KEY=<SOME_KEY>
llama stack run fireworks --port 5050
@@ -1,89 +0,0 @@ (deleted file)
<!-- This file was auto-generated by distro_codegen.py, please edit source -->
# NVIDIA Distribution

The `llamastack/distribution-nvidia` distribution consists of the following provider configurations.

| API | Provider(s) |
|-----|-------------|
| agents | `inline::meta-reference` |
| datasetio | `inline::localfs`, `remote::nvidia` |
| eval | `remote::nvidia` |
| inference | `remote::nvidia` |
| post_training | `remote::nvidia` |
| safety | `remote::nvidia` |
| scoring | `inline::basic` |
| telemetry | `inline::meta-reference` |
| tool_runtime | `inline::rag-runtime` |
| vector_io | `inline::faiss` |

### Environment Variables

The following environment variables can be configured:

- `NVIDIA_API_KEY`: NVIDIA API Key (default: ``)
- `NVIDIA_USER_ID`: NVIDIA User ID (default: `llama-stack-user`)
- `NVIDIA_DATASET_NAMESPACE`: NVIDIA Dataset Namespace (default: `default`)
- `NVIDIA_ACCESS_POLICIES`: NVIDIA Access Policies (default: `{}`)
- `NVIDIA_PROJECT_ID`: NVIDIA Project ID (default: `test-project`)
- `NVIDIA_CUSTOMIZER_URL`: NVIDIA Customizer URL (default: `https://customizer.api.nvidia.com`)
- `NVIDIA_OUTPUT_MODEL_DIR`: NVIDIA Output Model Directory (default: `test-example-model@v1`)
- `GUARDRAILS_SERVICE_URL`: URL for the NeMo Guardrails Service (default: `http://0.0.0.0:7331`)
- `NVIDIA_EVALUATOR_URL`: URL for the NeMo Evaluator Service (default: `http://0.0.0.0:7331`)
- `INFERENCE_MODEL`: Inference model (default: `Llama3.1-8B-Instruct`)
- `SAFETY_MODEL`: Name of the model to use for safety (default: `meta/llama-3.1-8b-instruct`)

### Models

The following models are available by default:

- `meta/llama3-8b-instruct (aliases: meta-llama/Llama-3-8B-Instruct)`
- `meta/llama3-70b-instruct (aliases: meta-llama/Llama-3-70B-Instruct)`
- `meta/llama-3.1-8b-instruct (aliases: meta-llama/Llama-3.1-8B-Instruct)`
- `meta/llama-3.1-70b-instruct (aliases: meta-llama/Llama-3.1-70B-Instruct)`
- `meta/llama-3.1-405b-instruct (aliases: meta-llama/Llama-3.1-405B-Instruct-FP8)`
- `meta/llama-3.2-1b-instruct (aliases: meta-llama/Llama-3.2-1B-Instruct)`
- `meta/llama-3.2-3b-instruct (aliases: meta-llama/Llama-3.2-3B-Instruct)`
- `meta/llama-3.2-11b-vision-instruct (aliases: meta-llama/Llama-3.2-11B-Vision-Instruct)`
- `meta/llama-3.2-90b-vision-instruct (aliases: meta-llama/Llama-3.2-90B-Vision-Instruct)`
- `nvidia/llama-3.2-nv-embedqa-1b-v2 `
- `nvidia/nv-embedqa-e5-v5 `
- `nvidia/nv-embedqa-mistral-7b-v2 `
- `snowflake/arctic-embed-l `

### Prerequisite: API Keys

Make sure you have access to a NVIDIA API Key. You can get one by visiting [https://build.nvidia.com/](https://build.nvidia.com/).

## Running Llama Stack with NVIDIA

You can do this via Conda (build code) or Docker which has a pre-built image.

### Via Docker

This method allows you to get started quickly without having to build the distribution code.

```bash
LLAMA_STACK_PORT=8321
docker run \
  -it \
  --pull always \
  -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
  -v ./run.yaml:/root/my-run.yaml \
  llamastack/distribution-nvidia \
  --yaml-config /root/my-run.yaml \
  --port $LLAMA_STACK_PORT \
  --env NVIDIA_API_KEY=$NVIDIA_API_KEY
```

### Via Conda

```bash
llama stack build --template nvidia --image-type conda
llama stack run ./run.yaml \
  --port 8321 \
  --env NVIDIA_API_KEY=$NVIDIA_API_KEY
  --env INFERENCE_MODEL=$INFERENCE_MODEL
```
docs/source/distributions/remote_hosted_distro/watsonx.md (new file, 88 lines)
@@ -0,0 +1,88 @@
---
orphan: true
---
<!-- This file was auto-generated by distro_codegen.py, please edit source -->
# watsonx Distribution

```{toctree}
:maxdepth: 2
:hidden:

self
```

The `llamastack/distribution-watsonx` distribution consists of the following provider configurations.

| API | Provider(s) |
|-----|-------------|
| agents | `inline::meta-reference` |
| datasetio | `remote::huggingface`, `inline::localfs` |
| eval | `inline::meta-reference` |
| inference | `remote::watsonx` |
| safety | `inline::llama-guard` |
| scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
| telemetry | `inline::meta-reference` |
| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::code-interpreter`, `inline::rag-runtime`, `remote::model-context-protocol` |
| vector_io | `inline::faiss` |

### Environment Variables

The following environment variables can be configured:

- `LLAMASTACK_PORT`: Port for the Llama Stack distribution server (default: `5001`)
- `WATSONX_API_KEY`: watsonx API Key (default: ``)
- `WATSONX_PROJECT_ID`: watsonx Project ID (default: ``)

### Models

The following models are available by default:

- `meta-llama/llama-3-3-70b-instruct (aliases: meta-llama/Llama-3.3-70B-Instruct)`
- `meta-llama/llama-2-13b-chat (aliases: meta-llama/Llama-2-13b)`
- `meta-llama/llama-3-1-70b-instruct (aliases: meta-llama/Llama-3.1-70B-Instruct)`
- `meta-llama/llama-3-1-8b-instruct (aliases: meta-llama/Llama-3.1-8B-Instruct)`
- `meta-llama/llama-3-2-11b-vision-instruct (aliases: meta-llama/Llama-3.2-11B-Vision-Instruct)`
- `meta-llama/llama-3-2-1b-instruct (aliases: meta-llama/Llama-3.2-1B-Instruct)`
- `meta-llama/llama-3-2-3b-instruct (aliases: meta-llama/Llama-3.2-3B-Instruct)`
- `meta-llama/llama-3-2-90b-vision-instruct (aliases: meta-llama/Llama-3.2-90B-Vision-Instruct)`
- `meta-llama/llama-guard-3-11b-vision (aliases: meta-llama/Llama-Guard-3-11B-Vision)`

### Prerequisite: API Keys

Make sure you have access to a watsonx API Key. You can get one by referring to [watsonx.ai](https://www.ibm.com/docs/en/masv-and-l/maximo-manage/continuous-delivery?topic=setup-create-watsonx-api-key).

## Running Llama Stack with watsonx

You can do this via Conda (build code), venv, or Docker, which has a pre-built image.

### Via Docker

This method allows you to get started quickly without having to build the distribution code.

```bash
LLAMA_STACK_PORT=5001
docker run \
  -it \
  -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
  -v ./run.yaml:/root/my-run.yaml \
  llamastack/distribution-watsonx \
  --yaml-config /root/my-run.yaml \
  --port $LLAMA_STACK_PORT \
  --env WATSONX_API_KEY=$WATSONX_API_KEY \
  --env WATSONX_PROJECT_ID=$WATSONX_PROJECT_ID \
  --env WATSONX_BASE_URL=$WATSONX_BASE_URL
```

### Via Conda

```bash
llama stack build --template watsonx --image-type conda
llama stack run ./run.yaml \
  --port $LLAMA_STACK_PORT \
  --env WATSONX_API_KEY=$WATSONX_API_KEY \
  --env WATSONX_PROJECT_ID=$WATSONX_PROJECT_ID
```
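A minimal sketch (not part of the new file) of talking to this distribution once it is running, assuming the default port above and one of the listed model identifiers; adjust both to your deployment:

```python
from llama_stack_client import LlamaStackClient

# Assumes the watsonx distribution is running locally on port 5001 (LLAMASTACK_PORT above).
client = LlamaStackClient(base_url="http://localhost:5001")

response = client.inference.chat_completion(
    model_id="meta-llama/llama-3-3-70b-instruct",  # one of the models listed above
    messages=[{"role": "user", "content": "Hello from watsonx!"}],
)
print(response.completion_message.content)
```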
@@ -81,6 +81,7 @@ LLAMA_STACK_PORT=8321
docker run \
  -it \
  --pull always \
+  --gpu all \
  -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
  -v ~/.llama:/root/.llama \
  llamastack/distribution-meta-reference-gpu \

@@ -94,6 +95,7 @@ If you are using Llama Stack Safety / Shield APIs, use:
docker run \
  -it \
  --pull always \
+  --gpu all \
  -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
  -v ~/.llama:/root/.llama \
  llamastack/distribution-meta-reference-gpu \
@@ -1,123 +0,0 @@ (deleted file)
---
orphan: true
---
<!-- This file was auto-generated by distro_codegen.py, please edit source -->
# Meta Reference Quantized Distribution

```{toctree}
:maxdepth: 2
:hidden:

self
```

The `llamastack/distribution-meta-reference-quantized-gpu` distribution consists of the following provider configurations:

| API | Provider(s) |
|-----|-------------|
| agents | `inline::meta-reference` |
| datasetio | `remote::huggingface`, `inline::localfs` |
| eval | `inline::meta-reference` |
| inference | `inline::meta-reference-quantized` |
| safety | `inline::llama-guard` |
| scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
| telemetry | `inline::meta-reference` |
| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::code-interpreter`, `inline::rag-runtime`, `remote::model-context-protocol` |
| vector_io | `inline::faiss`, `remote::chromadb`, `remote::pgvector` |

The only difference vs. the `meta-reference-gpu` distribution is that it has support for more efficient inference -- with fp8, int4 quantization, etc.

Note that you need access to nvidia GPUs to run this distribution. This distribution is not compatible with CPU-only machines or machines with AMD GPUs.

### Environment Variables

The following environment variables can be configured:

- `LLAMA_STACK_PORT`: Port for the Llama Stack distribution server (default: `8321`)
- `INFERENCE_MODEL`: Inference model loaded into the Meta Reference server (default: `meta-llama/Llama-3.2-3B-Instruct`)
- `INFERENCE_CHECKPOINT_DIR`: Directory containing the Meta Reference model checkpoint (default: `null`)

## Prerequisite: Downloading Models

Please use `llama model list --downloaded` to check that you have llama model checkpoints downloaded in `~/.llama` before proceeding. See [installation guide](https://llama-stack.readthedocs.io/en/latest/references/llama_cli_reference/download_models.html) here to download the models. Run `llama model list` to see the available models to download, and `llama model download` to download the checkpoints.

```
$ llama model list --downloaded
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━┓
┃ Model                                   ┃ Size     ┃ Modified Time       ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━┩
│ Llama3.2-1B-Instruct:int4-qlora-eo8     │ 1.53 GB  │ 2025-02-26 11:22:28 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama3.2-1B                             │ 2.31 GB  │ 2025-02-18 21:48:52 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Prompt-Guard-86M                        │ 0.02 GB  │ 2025-02-26 11:29:28 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama3.2-3B-Instruct:int4-spinquant-eo8 │ 3.69 GB  │ 2025-02-26 11:37:41 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama3.2-3B                             │ 5.99 GB  │ 2025-02-18 21:51:26 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama3.1-8B                             │ 14.97 GB │ 2025-02-16 10:36:37 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama3.2-1B-Instruct:int4-spinquant-eo8 │ 1.51 GB  │ 2025-02-26 11:35:02 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama-Guard-3-1B                        │ 2.80 GB  │ 2025-02-26 11:20:46 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama-Guard-3-1B:int4                   │ 0.43 GB  │ 2025-02-26 11:33:33 │
└─────────────────────────────────────────┴──────────┴─────────────────────┘
```

## Running the Distribution

You can do this via Conda (build code) or Docker which has a pre-built image.

### Via Docker

This method allows you to get started quickly without having to build the distribution code.

```bash
LLAMA_STACK_PORT=8321
docker run \
  -it \
  --pull always \
  -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
  -v ~/.llama:/root/.llama \
  llamastack/distribution-meta-reference-quantized-gpu \
  --port $LLAMA_STACK_PORT \
  --env INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct
```

If you are using Llama Stack Safety / Shield APIs, use:

```bash
docker run \
  -it \
  --pull always \
  -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
  -v ~/.llama:/root/.llama \
  llamastack/distribution-meta-reference-quantized-gpu \
  --port $LLAMA_STACK_PORT \
  --env INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct \
  --env SAFETY_MODEL=meta-llama/Llama-Guard-3-1B
```

### Via Conda

Make sure you have done `uv pip install llama-stack` and have the Llama Stack CLI available.

```bash
llama stack build --template meta-reference-quantized-gpu --image-type conda
llama stack run distributions/meta-reference-quantized-gpu/run.yaml \
  --port $LLAMA_STACK_PORT \
  --env INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct
```

If you are using Llama Stack Safety / Shield APIs, use:

```bash
llama stack run distributions/meta-reference-quantized-gpu/run-with-safety.yaml \
  --port $LLAMA_STACK_PORT \
  --env INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct \
  --env SAFETY_MODEL=meta-llama/Llama-Guard-3-1B
```
@ -22,10 +22,8 @@ The `llamastack/distribution-nvidia` distribution consists of the following prov
|
||||||
The following environment variables can be configured:
|
The following environment variables can be configured:
|
||||||
|
|
||||||
- `NVIDIA_API_KEY`: NVIDIA API Key (default: ``)
|
- `NVIDIA_API_KEY`: NVIDIA API Key (default: ``)
|
||||||
- `NVIDIA_USER_ID`: NVIDIA User ID (default: `llama-stack-user`)
|
|
||||||
- `NVIDIA_APPEND_API_VERSION`: Whether to append the API version to the base_url (default: `True`)
|
- `NVIDIA_APPEND_API_VERSION`: Whether to append the API version to the base_url (default: `True`)
|
||||||
- `NVIDIA_DATASET_NAMESPACE`: NVIDIA Dataset Namespace (default: `default`)
|
- `NVIDIA_DATASET_NAMESPACE`: NVIDIA Dataset Namespace (default: `default`)
|
||||||
- `NVIDIA_ACCESS_POLICIES`: NVIDIA Access Policies (default: `{}`)
|
|
||||||
- `NVIDIA_PROJECT_ID`: NVIDIA Project ID (default: `test-project`)
|
- `NVIDIA_PROJECT_ID`: NVIDIA Project ID (default: `test-project`)
|
||||||
- `NVIDIA_CUSTOMIZER_URL`: NVIDIA Customizer URL (default: `https://customizer.api.nvidia.com`)
|
- `NVIDIA_CUSTOMIZER_URL`: NVIDIA Customizer URL (default: `https://customizer.api.nvidia.com`)
|
||||||
- `NVIDIA_OUTPUT_MODEL_DIR`: NVIDIA Output Model Directory (default: `test-example-model@v1`)
|
- `NVIDIA_OUTPUT_MODEL_DIR`: NVIDIA Output Model Directory (default: `test-example-model@v1`)
|
||||||
|
@ -48,20 +46,91 @@ The following models are available by default:
|
||||||
- `meta/llama-3.2-3b-instruct (aliases: meta-llama/Llama-3.2-3B-Instruct)`
|
- `meta/llama-3.2-3b-instruct (aliases: meta-llama/Llama-3.2-3B-Instruct)`
|
||||||
- `meta/llama-3.2-11b-vision-instruct (aliases: meta-llama/Llama-3.2-11B-Vision-Instruct)`
|
- `meta/llama-3.2-11b-vision-instruct (aliases: meta-llama/Llama-3.2-11B-Vision-Instruct)`
|
||||||
- `meta/llama-3.2-90b-vision-instruct (aliases: meta-llama/Llama-3.2-90B-Vision-Instruct)`
|
- `meta/llama-3.2-90b-vision-instruct (aliases: meta-llama/Llama-3.2-90B-Vision-Instruct)`
|
||||||
|
- `meta/llama-3.3-70b-instruct (aliases: meta-llama/Llama-3.3-70B-Instruct)`
|
||||||
- `nvidia/llama-3.2-nv-embedqa-1b-v2 `
|
- `nvidia/llama-3.2-nv-embedqa-1b-v2 `
|
||||||
- `nvidia/nv-embedqa-e5-v5 `
|
- `nvidia/nv-embedqa-e5-v5 `
|
||||||
- `nvidia/nv-embedqa-mistral-7b-v2 `
|
- `nvidia/nv-embedqa-mistral-7b-v2 `
|
||||||
- `snowflake/arctic-embed-l `
|
- `snowflake/arctic-embed-l `
|
||||||
|
|
||||||
|
|
||||||
### Prerequisite: API Keys
## Prerequisites

### NVIDIA API Keys

Make sure you have access to a NVIDIA API Key. You can get one by visiting [https://build.nvidia.com/](https://build.nvidia.com/).
Make sure you have access to a NVIDIA API Key. You can get one by visiting [https://build.nvidia.com/](https://build.nvidia.com/). Use this key for the `NVIDIA_API_KEY` environment variable.

### Deploy NeMo Microservices Platform

The NVIDIA NeMo microservices platform supports end-to-end microservice deployment of a complete AI flywheel on your Kubernetes cluster through the NeMo Microservices Helm Chart. Please reference the [NVIDIA NeMo Microservices documentation](https://docs.nvidia.com/nemo/microservices/latest/about/index.html) for platform prerequisites and instructions to install and deploy the platform.

## Supported Services

Each Llama Stack API corresponds to a specific NeMo microservice. The core microservices (Customizer, Evaluator, Guardrails) are exposed by the same endpoint. The platform components (Data Store) are each exposed by separate endpoints.

### Inference: NVIDIA NIM

NVIDIA NIM is used for running inference with registered models. There are two ways to access NVIDIA NIMs:

1. Hosted (default): Preview APIs hosted at https://integrate.api.nvidia.com (Requires an API key)
2. Self-hosted: NVIDIA NIMs that run on your own infrastructure.

The deployed platform includes the NIM Proxy microservice, which is the service that provides access to your NIMs (for example, to run inference on a model). Set the `NVIDIA_BASE_URL` environment variable to use your NVIDIA NIM Proxy deployment.

### Datasetio API: NeMo Data Store

The NeMo Data Store microservice serves as the default file storage solution for the NeMo microservices platform. It exposes APIs compatible with the Hugging Face Hub client (`HfApi`), so you can use the client to interact with Data Store. The `NVIDIA_DATASETS_URL` environment variable should point to your NeMo Data Store endpoint.

See the [NVIDIA Datasetio docs](/llama_stack/providers/remote/datasetio/nvidia/README.md) for supported features and example usage.

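Because the Data Store speaks the Hugging Face Hub protocol, the standard `HfApi` client can talk to it directly. The sketch below assumes `NVIDIA_DATASETS_URL` is already set and uses a made-up dataset name purely for illustration:

```python
import os

from huggingface_hub import HfApi

# Point HfApi at the NeMo Data Store instead of huggingface.co.
hf_api = HfApi(endpoint=os.environ["NVIDIA_DATASETS_URL"], token="")

# Create a dataset repo (namespace/name are illustrative) and upload a local file.
hf_api.create_repo(repo_id="default/sample-dataset", repo_type="dataset", exist_ok=True)
hf_api.upload_file(
    path_or_fileobj="./train.jsonl",
    path_in_repo="train.jsonl",
    repo_id="default/sample-dataset",
    repo_type="dataset",
)
```
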
### Eval API: NeMo Evaluator

The NeMo Evaluator microservice supports evaluation of LLMs. Launching an Evaluation job with NeMo Evaluator requires an Evaluation Config (an object that contains metadata needed by the job). A Llama Stack Benchmark maps to an Evaluation Config, so registering a Benchmark creates an Evaluation Config in NeMo Evaluator. The `NVIDIA_EVALUATOR_URL` environment variable should point to your NeMo Microservices endpoint.

See the [NVIDIA Eval docs](/llama_stack/providers/remote/eval/nvidia/README.md) for supported features and example usage.

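As a rough sketch of that Benchmark-to-Evaluation-Config mapping, registering a Benchmark through the Llama Stack client is all that is needed on the Llama Stack side; the identifiers and scoring function below are placeholders, not values shipped with the distribution:

```python
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:8321")

# Registering this Benchmark creates a corresponding Evaluation Config in NeMo Evaluator.
client.benchmarks.register(
    benchmark_id="my-simple-benchmark",     # placeholder benchmark id
    dataset_id="my-eval-dataset",           # placeholder dataset registered via the datasetio API
    scoring_functions=["basic::equality"],  # assumed built-in scoring function
)
```
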
### Post-Training API: NeMo Customizer

The NeMo Customizer microservice supports fine-tuning models. You can reference [this list of supported models](/llama_stack/providers/remote/post_training/nvidia/models.py) that can be fine-tuned using Llama Stack. The `NVIDIA_CUSTOMIZER_URL` environment variable should point to your NeMo Microservices endpoint.

See the [NVIDIA Post-Training docs](/llama_stack/providers/remote/post_training/nvidia/README.md) for supported features and example usage.

### Safety API: NeMo Guardrails

The NeMo Guardrails microservice sits between your application and the LLM, and adds checks and content moderation to a model. The `GUARDRAILS_SERVICE_URL` environment variable should point to your NeMo Microservices endpoint.

See the NVIDIA Safety docs for supported features and example usage.

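A minimal sketch of calling the Safety API through the Llama Stack client, assuming a guardrails-backed shield has already been registered (the shield id here is illustrative):

```python
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:8321")

# The shield id is a placeholder; use whichever shield the NVIDIA safety provider registers.
result = client.safety.run_shield(
    shield_id="nvidia-guardrails-shield",
    messages=[{"role": "user", "content": "Tell me how to hotwire a car."}],
    params={},
)
print(result.violation)
```
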
## Deploying models

In order to use a registered model with the Llama Stack APIs, ensure the corresponding NIM is deployed to your environment. For example, you can use the NIM Proxy microservice to deploy `meta/llama-3.2-1b-instruct`.

Note: For improved inference speeds, we need to use NIM with `fast_outlines` guided decoding system (specified in the request body). This is the default if you deployed the platform with the NeMo Microservices Helm Chart.

```sh
# URL to NeMo NIM Proxy service
export NEMO_URL="http://nemo.test"

curl --location "$NEMO_URL/v1/deployment/model-deployments" \
  -H 'accept: application/json' \
  -H 'Content-Type: application/json' \
  -d '{
    "name": "llama-3.2-1b-instruct",
    "namespace": "meta",
    "config": {
      "model": "meta/llama-3.2-1b-instruct",
      "nim_deployment": {
        "image_name": "nvcr.io/nim/meta/llama-3.2-1b-instruct",
        "image_tag": "1.8.3",
        "pvc_size": "25Gi",
        "gpu": 1,
        "additional_envs": {
          "NIM_GUIDED_DECODING_BACKEND": "fast_outlines"
        }
      }
    }
  }'
```

This NIM deployment should take approximately 10 minutes to go live. [See the docs](https://docs.nvidia.com/nemo/microservices/latest/get-started/tutorials/deploy-nims.html) for more information on how to deploy a NIM and verify it's available for inference.

You can also remove a deployed NIM to free up GPU resources, if needed.

```sh
export NEMO_URL="http://nemo.test"

curl -X DELETE "$NEMO_URL/v1/deployment/model-deployments/meta/llama-3.1-8b-instruct"
```

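Once the NIM is live and the model is registered, a quick way to sanity-check it end to end is a chat completion through the Llama Stack client. This is a sketch that assumes the stack is running locally on the default port:

```python
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:8321")

response = client.inference.chat_completion(
    model_id="meta/llama-3.2-1b-instruct",
    messages=[{"role": "user", "content": "Say hello in one sentence."}],
)
print(response.completion_message.content)
```
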
## Running Llama Stack with NVIDIA

You can do this via Conda (build code) or Docker which has a pre-built image.
You can do this via Conda or venv (build code), or Docker which has a pre-built image.

### Via Docker

@ -83,9 +152,23 @@ docker run \

### Via Conda

```bash
INFERENCE_MODEL=meta-llama/Llama-3.1-8b-Instruct
llama stack build --template nvidia --image-type conda
llama stack run ./run.yaml \
  --port 8321 \
  --env NVIDIA_API_KEY=$NVIDIA_API_KEY
  --env NVIDIA_API_KEY=$NVIDIA_API_KEY \
  --env INFERENCE_MODEL=$INFERENCE_MODEL
```

### Via venv

If you've set up your local development environment, you can also build the image using your local virtual environment.

```bash
INFERENCE_MODEL=meta-llama/Llama-3.1-8b-Instruct
llama stack build --template nvidia --image-type venv
llama stack run ./run.yaml \
  --port 8321 \
  --env NVIDIA_API_KEY=$NVIDIA_API_KEY \
  --env INFERENCE_MODEL=$INFERENCE_MODEL
```

@ -41,10 +41,10 @@ The following environment variables can be configured:

## Setting up vLLM server

In the following sections, we'll use either AMD and NVIDIA GPUs to serve as hardware accelerators for the vLLM
In the following sections, we'll use AMD, NVIDIA or Intel GPUs to serve as hardware accelerators for the vLLM
server, which acts as both the LLM inference provider and the safety provider. Note that vLLM also
[supports many other hardware accelerators](https://docs.vllm.ai/en/latest/getting_started/installation.html) and
that we only use GPUs here for demonstration purposes.
that we only use GPUs here for demonstration purposes. Note that if you run into issues, you can include the environment variable `--env VLLM_DEBUG_LOG_API_SERVER_RESPONSE=true` (available in vLLM v0.8.3 and above) in the `docker run` command to enable logging of API server responses for debugging.

### Setting up vLLM server on AMD GPU

@ -162,6 +162,55 @@ docker run \
  --port $SAFETY_PORT
```

### Setting up vLLM server on Intel GPU

Refer to [vLLM Documentation for XPU](https://docs.vllm.ai/en/v0.8.2/getting_started/installation/gpu.html?device=xpu) to get a vLLM endpoint. In addition to the vLLM side setup, which guides you through installing vLLM from source or self-building the vLLM Docker container, Intel provides a prebuilt vLLM container to use on systems with Intel GPUs supported by the PyTorch XPU backend:
- [intel/vllm](https://hub.docker.com/r/intel/vllm)

Here is a sample script to start a vLLM server locally via Docker using the Intel provided container:

```bash
export INFERENCE_PORT=8000
export INFERENCE_MODEL=meta-llama/Llama-3.2-1B-Instruct
export ZE_AFFINITY_MASK=0

docker run \
  --pull always \
  --device /dev/dri \
  -v /dev/dri/by-path:/dev/dri/by-path \
  -v ~/.cache/huggingface:/root/.cache/huggingface \
  --env "HUGGING_FACE_HUB_TOKEN=$HF_TOKEN" \
  --env ZE_AFFINITY_MASK=$ZE_AFFINITY_MASK \
  -p $INFERENCE_PORT:$INFERENCE_PORT \
  --ipc=host \
  intel/vllm:xpu \
  --gpu-memory-utilization 0.7 \
  --model $INFERENCE_MODEL \
  --port $INFERENCE_PORT
```

If you are using Llama Stack Safety / Shield APIs, then you will need to also run another instance of a vLLM with a corresponding safety model like `meta-llama/Llama-Guard-3-1B` using a script like:

```bash
export SAFETY_PORT=8081
export SAFETY_MODEL=meta-llama/Llama-Guard-3-1B
export ZE_AFFINITY_MASK=1

docker run \
  --pull always \
  --device /dev/dri \
  -v /dev/dri/by-path:/dev/dri/by-path \
  -v ~/.cache/huggingface:/root/.cache/huggingface \
  --env "HUGGING_FACE_HUB_TOKEN=$HF_TOKEN" \
  --env ZE_AFFINITY_MASK=$ZE_AFFINITY_MASK \
  -p $SAFETY_PORT:$SAFETY_PORT \
  --ipc=host \
  intel/vllm:xpu \
  --gpu-memory-utilization 0.7 \
  --model $SAFETY_MODEL \
  --port $SAFETY_PORT
```

## Running Llama Stack

Now you are ready to run Llama Stack with vLLM as the inference provider. You can do this via Conda (build code) or Docker which has a pre-built image.

@ -50,9 +50,10 @@ Llama Stack supports two types of external providers:

Here's a list of known external providers that you can use with Llama Stack:

| Type | Name | Description | Repository |
|------|------|-------------|------------|
| Remote | KubeFlow Training | Train models with KubeFlow | [llama-stack-provider-kft](https://github.com/opendatahub-io/llama-stack-provider-kft) |
| Name | Description | API | Type | Repository |
|------|-------------|-----|------|------------|
| KubeFlow Training | Train models with KubeFlow | Post Training | Remote | [llama-stack-provider-kft](https://github.com/opendatahub-io/llama-stack-provider-kft) |
| RamaLama | Inference models with RamaLama | Inference | Remote | [ramalama-stack](https://github.com/containers/ramalama-stack) |

### Remote Provider Specification

@ -389,5 +389,7 @@
   "pygments_lexer": "ipython3",
   "version": "3.10.15"
  }
 }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}

@ -256,5 +256,7 @@
   "pygments_lexer": "ipython3",
   "version": "3.10.15"
  }
 }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}

@ -301,5 +301,7 @@
   "pygments_lexer": "ipython3",
   "version": "3.12.2"
  }
 }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}

@ -200,5 +200,7 @@
   "pygments_lexer": "ipython3",
   "version": "3.12.2"
  }
 }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}

@ -355,5 +355,7 @@
   "pygments_lexer": "ipython3",
   "version": "3.10.15"
  }
 }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}

@ -398,5 +398,7 @@
   "pygments_lexer": "ipython3",
   "version": "3.10.15"
  }
 }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}

@ -132,5 +132,7 @@
   "pygments_lexer": "ipython3",
   "version": "3.11.10"
  }
 }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}

@ -188,5 +188,7 @@
   "pygments_lexer": "ipython3",
   "version": "3.10.15"
  }
 }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}

86
install.sh
Executable file
@ -0,0 +1,86 @@
#!/usr/bin/env bash
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

set -Eeuo pipefail

PORT=8321
OLLAMA_PORT=11434
MODEL_ALIAS="llama3.2:3b"
SERVER_IMAGE="llamastack/distribution-ollama:0.2.2"
WAIT_TIMEOUT=300

log(){ printf "\e[1;32m%s\e[0m\n" "$*"; }
die(){ printf "\e[1;31m❌ %s\e[0m\n" "$*" >&2; exit 1; }

if command -v docker &> /dev/null; then
  ENGINE="docker"
  HOST_DNS="host.docker.internal"
elif command -v podman &> /dev/null; then
  ENGINE="podman"
  HOST_DNS="host.containers.internal"
else
  die "Docker or Podman is required. Install Docker: https://docs.docker.com/get-docker/ or Podman: https://podman.io/getting-started/installation"
fi

# Clean up any leftovers from earlier runs
for name in ollama-server llama-stack; do
  ids=$($ENGINE ps -aq --filter "name=^${name}$")
  if [ -n "$ids" ]; then
    log "⚠️  Found existing container(s) for '${name}', removing..."
    $ENGINE rm -f "$ids"
  fi
done

###############################################################################
# 1. Ollama
###############################################################################
log "🦙 Starting Ollama…"
$ENGINE run -d --name ollama-server \
  -p "${OLLAMA_PORT}:11434" \
  ollama/ollama > /dev/null 2>&1

log "⏳ Waiting for Ollama daemon…"
if ! timeout "$WAIT_TIMEOUT" bash -c \
    "until curl -fsS http://localhost:${OLLAMA_PORT}/ 2>/dev/null | grep -q 'Ollama'; do sleep 1; done"; then
  log "❌ Ollama daemon did not become ready in ${WAIT_TIMEOUT}s; dumping container logs:"
  $ENGINE logs ollama-server --tail=200
  die "Ollama startup failed"
fi

log "📦 Ensuring model is pulled: ${MODEL_ALIAS}..."
$ENGINE exec ollama-server ollama pull "${MODEL_ALIAS}" > /dev/null 2>&1

###############################################################################
# 2. Llama‑Stack
###############################################################################
log "🦙📦 Starting Llama‑Stack…"
$ENGINE run -d --name llama-stack \
  -p "${PORT}:${PORT}" \
  --add-host="${HOST_DNS}:host-gateway" \
  "${SERVER_IMAGE}" \
  --port "${PORT}" \
  --env INFERENCE_MODEL="${MODEL_ALIAS}" \
  --env OLLAMA_URL="http://${HOST_DNS}:${OLLAMA_PORT}" > /dev/null 2>&1

log "⏳ Waiting for Llama-Stack API…"
if ! timeout "$WAIT_TIMEOUT" bash -c \
    "until curl -fsS http://localhost:${PORT}/v1/health 2>/dev/null | grep -q 'OK'; do sleep 1; done"; then
  log "❌ Llama-Stack did not become ready in ${WAIT_TIMEOUT}s; dumping container logs:"
  $ENGINE logs llama-stack --tail=200
  die "Llama-Stack startup failed"
fi

###############################################################################
# Done
###############################################################################
log ""
log "🎉 Llama‑Stack is ready!"
log "👉 API endpoint: http://localhost:${PORT}"
log "📖 Documentation: https://llama-stack.readthedocs.io/en/latest/references/index.html"
log "💻 To access the llama‑stack CLI, exec into the container:"
log "   $ENGINE exec -ti llama-stack bash"
log ""

@ -225,8 +225,18 @@ class AgentConfigCommon(BaseModel):

@json_schema_type
class AgentConfig(AgentConfigCommon):
    """Configuration for an agent.

    :param model: The model identifier to use for the agent
    :param instructions: The system instructions for the agent
    :param name: Optional name for the agent, used in telemetry and identification
    :param enable_session_persistence: Optional flag indicating whether session data has to be persisted
    :param response_format: Optional response format configuration
    """

    model: str
    instructions: str
    name: Optional[str] = None
    enable_session_persistence: Optional[bool] = False
    response_format: Optional[ResponseFormat] = None

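For reference, a minimal construction of this config with the new optional `name` field might look like the sketch below; the import path and field values are assumptions for illustration:

```python
from llama_stack.apis.agents import AgentConfig  # import path assumed

config = AgentConfig(
    model="meta-llama/Llama-3.1-8B-Instruct",
    instructions="You are a helpful assistant.",
    name="docs-demo-agent",  # new optional field, surfaced in telemetry
    enable_session_persistence=False,
)
```
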
@ -526,9 +526,9 @@ class OpenAIAssistantMessageParam(BaseModel):
    """

    role: Literal["assistant"] = "assistant"
    content: OpenAIChatCompletionMessageContent
    content: Optional[OpenAIChatCompletionMessageContent] = None
    name: Optional[str] = None
    tool_calls: Optional[List[OpenAIChatCompletionToolCall]] = Field(default_factory=list)
    tool_calls: Optional[List[OpenAIChatCompletionToolCall]] = None


@json_schema_type

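The relaxation above matches OpenAI's chat format, where an assistant turn that only requests tool calls can omit `content`; a plain-dict sketch of such a message:

```python
# With content now Optional, an assistant turn carrying only tool calls is valid.
assistant_turn = {
    "role": "assistant",
    "content": None,
    "tool_calls": [
        {
            "id": "call_0",
            "type": "function",
            "function": {"name": "get_weather", "arguments": '{"city": "Paris"}'},
        }
    ],
}
```
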
@ -136,12 +136,13 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
        )

        image_type = prompt(
            f"> Enter the image type you want your Llama Stack to be built as ({' or '.join(e.value for e in ImageType)}): ",
            "> Enter the image type you want your Llama Stack to be built as (use <TAB> to see options): ",
            completer=WordCompleter([e.value for e in ImageType]),
            complete_while_typing=True,
            validator=Validator.from_callable(
                lambda x: x in [e.value for e in ImageType],
                error_message=f"Invalid image type, please enter {' or '.join(e.value for e in ImageType)}",
                error_message="Invalid image type. Use <TAB> to see options",
            ),
            default=ImageType.CONDA.value,
        )

        if image_type == ImageType.CONDA.value:

@ -210,16 +211,9 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
            )
            sys.exit(1)

    if build_config.image_type == LlamaStackImageType.CONTAINER.value and not args.image_name:
        cprint(
            "Please specify --image-name when building a container from a config file",
            color="red",
        )
        sys.exit(1)

    if args.print_deps_only:
        print(f"# Dependencies for {args.template or args.config or image_name}")
        normal_deps, special_deps = get_provider_dependencies(build_config.distribution_spec.providers)
        normal_deps, special_deps = get_provider_dependencies(build_config)
        normal_deps += SERVER_DEPENDENCIES
        print(f"uv pip install {' '.join(normal_deps)}")
        for special_dep in special_deps:

@ -235,10 +229,14 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
        )

    except (Exception, RuntimeError) as exc:
        import traceback

        cprint(
            f"Error building stack: {exc}",
            color="red",
        )
        cprint("Stack trace:", color="red")
        traceback.print_exc()
        sys.exit(1)
    if run_config is None:
        cprint(

@ -270,9 +268,10 @@ def _generate_run_config(
        image_name=image_name,
        apis=apis,
        providers={},
        external_providers_dir=build_config.external_providers_dir if build_config.external_providers_dir else None,
    )
    # build providers dict
    provider_registry = get_provider_registry()
    provider_registry = get_provider_registry(build_config)
    for api in apis:
        run_config.providers[api] = []
        provider_types = build_config.distribution_spec.providers[api]

@ -286,8 +285,22 @@ def _generate_run_config(
            if p.deprecation_error:
                raise InvalidProviderError(p.deprecation_error)

        config_type = instantiate_class_type(provider_registry[Api(api)][provider_type].config_class)
        if hasattr(config_type, "sample_run_config"):
        try:
            config_type = instantiate_class_type(provider_registry[Api(api)][provider_type].config_class)
        except ModuleNotFoundError:
            # HACK ALERT:
            # This code executes after building is done, the import cannot work since the
            # package is either available in the venv or container - not available on the host.
            # TODO: use a "is_external" flag in ProviderSpec to check if the provider is
            # external
            cprint(
                f"Failed to import provider {provider_type} for API {api} - assuming it's external, skipping",
                color="yellow",
            )
            # Set config_type to None to avoid UnboundLocalError
            config_type = None

        if config_type is not None and hasattr(config_type, "sample_run_config"):
            config = config_type.sample_run_config(__distro_dir__=f"~/.llama/distributions/{image_name}")
        else:
            config = {}

@ -305,11 +318,15 @@ def _generate_run_config(
        to_write = json.loads(run_config.model_dump_json())
        f.write(yaml.dump(to_write, sort_keys=False))

    # this path is only invoked when no template is provided
    cprint(
        f"You can now run your stack with `llama stack run {run_config_file}`",
        color="green",
    )
    # Only print this message for non-container builds since it will be displayed before the
    # container is built
    # For non-container builds, the run.yaml is generated at the very end of the build process so it
    # makes sense to display this message
    if build_config.image_type != LlamaStackImageType.CONTAINER.value:
        cprint(
            f"You can now run your stack with `llama stack run {run_config_file}`",
            color="green",
        )
    return run_config_file

@ -319,6 +336,7 @@ def _run_stack_build_command_from_build_config(
    template_name: Optional[str] = None,
    config_path: Optional[str] = None,
) -> str:
    image_name = image_name or build_config.image_name
    if build_config.image_type == LlamaStackImageType.CONTAINER.value:
        if template_name:
            image_name = f"distribution-{template_name}"

@ -342,6 +360,13 @@ def _run_stack_build_command_from_build_config(
        build_file_path = build_dir / f"{image_name}-build.yaml"

    os.makedirs(build_dir, exist_ok=True)
    run_config_file = None
    # Generate the run.yaml so it can be included in the container image with the proper entrypoint
    # Only do this if we're building a container image and we're not using a template
    if build_config.image_type == LlamaStackImageType.CONTAINER.value and not template_name and config_path:
        cprint("Generating run.yaml file", color="green")
        run_config_file = _generate_run_config(build_config, build_dir, image_name)

    with open(build_file_path, "w") as f:
        to_write = json.loads(build_config.model_dump_json())
        f.write(yaml.dump(to_write, sort_keys=False))

@ -350,7 +375,8 @@ def _run_stack_build_command_from_build_config(
        build_config,
        build_file_path,
        image_name,
        template_or_config=template_name or config_path,
        template_or_config=template_name or config_path or str(build_file_path),
        run_config=run_config_file,
    )
    if return_code != 0:
        raise RuntimeError(f"Failed to build image {image_name}")

@ -7,16 +7,16 @@
import importlib.resources
import logging
from pathlib import Path
from typing import Dict, List

from pydantic import BaseModel
from termcolor import cprint

from llama_stack.distribution.datatypes import BuildConfig, Provider
from llama_stack.distribution.datatypes import BuildConfig
from llama_stack.distribution.distribution import get_provider_registry
from llama_stack.distribution.utils.exec import run_command
from llama_stack.distribution.utils.image_types import LlamaStackImageType
from llama_stack.providers.datatypes import Api
from llama_stack.templates.template import DistributionTemplate

log = logging.getLogger(__name__)

@ -37,19 +37,24 @@ class ApiInput(BaseModel):


def get_provider_dependencies(
    config_providers: Dict[str, List[Provider]],
    config: BuildConfig | DistributionTemplate,
) -> tuple[list[str], list[str]]:
    """Get normal and special dependencies from provider configuration."""
    all_providers = get_provider_registry()
    # Extract providers based on config type
    if isinstance(config, DistributionTemplate):
        providers = config.providers
    elif isinstance(config, BuildConfig):
        providers = config.distribution_spec.providers
    deps = []
    registry = get_provider_registry(config)

    for api_str, provider_or_providers in config_providers.items():
    for api_str, provider_or_providers in providers.items():
        providers_for_api = all_providers[Api(api_str)]
        providers_for_api = registry[Api(api_str)]

        providers = provider_or_providers if isinstance(provider_or_providers, list) else [provider_or_providers]

        for provider in providers:
            # Providers from BuildConfig and RunConfig are subtly different – not great
            provider_type = provider if isinstance(provider, str) else provider.provider_type

            if provider_type not in providers_for_api:

@ -71,8 +76,8 @@ def get_provider_dependencies(
    return list(set(normal_deps)), list(set(special_deps))


def print_pip_install_help(providers: Dict[str, List[Provider]]):
def print_pip_install_help(config: BuildConfig):
    normal_deps, special_deps = get_provider_dependencies(providers)
    normal_deps, special_deps = get_provider_dependencies(config)

    cprint(
        f"Please install needed dependencies using the following commands:\n\nuv pip install {' '.join(normal_deps)}",

@ -88,10 +93,11 @@ def build_image(
    build_file_path: Path,
    image_name: str,
    template_or_config: str,
    run_config: str | None = None,
):
    container_base = build_config.distribution_spec.container_image or "python:3.10-slim"

    normal_deps, special_deps = get_provider_dependencies(build_config.distribution_spec.providers)
    normal_deps, special_deps = get_provider_dependencies(build_config)
    normal_deps += SERVER_DEPENDENCIES

    if build_config.image_type == LlamaStackImageType.CONTAINER.value:

@ -103,6 +109,11 @@ def build_image(
            container_base,
            " ".join(normal_deps),
        ]

        # When building from a config file (not a template), include the run config path in the
        # build arguments
        if run_config is not None:
            args.append(run_config)
    elif build_config.image_type == LlamaStackImageType.CONDA.value:
        script = str(importlib.resources.files("llama_stack") / "distribution/build_conda_env.sh")
        args = [

@ -19,12 +19,16 @@ UV_HTTP_TIMEOUT=${UV_HTTP_TIMEOUT:-500}
# mounting is not supported by docker buildx, so we use COPY instead
USE_COPY_NOT_MOUNT=${USE_COPY_NOT_MOUNT:-}

# Path to the run.yaml file in the container
RUN_CONFIG_PATH=/app/run.yaml

BUILD_CONTEXT_DIR=$(pwd)

if [ "$#" -lt 4 ]; then
  # This only works for templates
  echo "Usage: $0 <template_or_config> <image_name> <container_base> <pip_dependencies> [<special_pip_deps>]" >&2
  echo "Usage: $0 <template_or_config> <image_name> <container_base> <pip_dependencies> [<run_config>] [<special_pip_deps>]" >&2
  exit 1
fi

set -euo pipefail

template_or_config="$1"

@ -35,8 +39,27 @@ container_base="$1"
shift
pip_dependencies="$1"
shift
special_pip_deps="${1:-}"

# Handle optional arguments
run_config=""
special_pip_deps=""

# Check if there are more arguments
# The logic is becoming cumbersome, we should refactor it if we can do better
if [ $# -gt 0 ]; then
  # Check if the argument ends with .yaml
  if [[ "$1" == *.yaml ]]; then
    run_config="$1"
    shift
    # If there's another argument after .yaml, it must be special_pip_deps
    if [ $# -gt 0 ]; then
      special_pip_deps="$1"
    fi
  else
    # If it's not .yaml, it must be special_pip_deps
    special_pip_deps="$1"
  fi
fi

# Define color codes
RED='\033[0;31m'

@ -72,9 +95,13 @@ if [[ $container_base == *"registry.access.redhat.com/ubi9"* ]]; then
FROM $container_base
WORKDIR /app

RUN dnf -y update && dnf install -y iputils net-tools wget \
# We install the Python 3.11 dev headers and build tools so that any
# C‑extension wheels (e.g. polyleven, faiss‑cpu) can compile successfully.

RUN dnf -y update && dnf install -y iputils git net-tools wget \
  vim-minimal python3.11 python3.11-pip python3.11-wheel \
  python3.11-setuptools && ln -s /bin/pip3.11 /bin/pip && ln -s /bin/python3.11 /bin/python && dnf clean all
  python3.11-setuptools python3.11-devel gcc make && \
  ln -s /bin/pip3.11 /bin/pip && ln -s /bin/python3.11 /bin/python && dnf clean all

ENV UV_SYSTEM_PYTHON=1
RUN pip install uv

@ -86,7 +113,7 @@ WORKDIR /app

RUN apt-get update && apt-get install -y \
  iputils-ping net-tools iproute2 dnsutils telnet \
  curl wget telnet \
  curl wget telnet git \
  procps psmisc lsof \
  traceroute \
  bubblewrap \

@ -115,6 +142,45 @@ EOF
  done
fi

# Function to get Python command
get_python_cmd() {
  if is_command_available python; then
    echo "python"
  elif is_command_available python3; then
    echo "python3"
  else
    echo "Error: Neither python nor python3 is installed. Please install Python to continue." >&2
    exit 1
  fi
}

if [ -n "$run_config" ]; then
  # Copy the run config to the build context since it's an absolute path
  cp "$run_config" "$BUILD_CONTEXT_DIR/run.yaml"
  add_to_container << EOF
COPY run.yaml $RUN_CONFIG_PATH
EOF

  # Parse the run.yaml configuration to identify external provider directories
  # If external providers are specified, copy their directory to the container
  # and update the configuration to reference the new container path
  python_cmd=$(get_python_cmd)
  external_providers_dir=$($python_cmd -c "import yaml; config = yaml.safe_load(open('$run_config')); print(config.get('external_providers_dir') or '')")
  if [ -n "$external_providers_dir" ]; then
    echo "Copying external providers directory: $external_providers_dir"
    add_to_container << EOF
COPY $external_providers_dir /app/providers.d
EOF
    # Edit the run.yaml file to change the external_providers_dir to /app/providers.d
    if [ "$(uname)" = "Darwin" ]; then
      sed -i.bak -e 's|external_providers_dir:.*|external_providers_dir: /app/providers.d|' "$BUILD_CONTEXT_DIR/run.yaml"
      rm -f "$BUILD_CONTEXT_DIR/run.yaml.bak"
    else
      sed -i 's|external_providers_dir:.*|external_providers_dir: /app/providers.d|' "$BUILD_CONTEXT_DIR/run.yaml"
    fi
  fi
fi

stack_mount="/app/llama-stack-source"
client_mount="/app/llama-stack-client-source"

||||||
RUN pip uninstall -y uv
|
RUN pip uninstall -y uv
|
||||||
EOF
|
EOF
|
||||||
|
|
||||||
# if template_or_config ends with .yaml, it is not a template and we should not use the --template flag
|
# If a run config is provided, we use the --config flag
|
||||||
if [[ "$template_or_config" != *.yaml ]]; then
|
if [[ -n "$run_config" ]]; then
|
||||||
|
add_to_container << EOF
|
||||||
|
ENTRYPOINT ["python", "-m", "llama_stack.distribution.server.server", "--config", "$RUN_CONFIG_PATH"]
|
||||||
|
EOF
|
||||||
|
# If a template is provided (not a yaml file), we use the --template flag
|
||||||
|
elif [[ "$template_or_config" != *.yaml ]]; then
|
||||||
add_to_container << EOF
|
add_to_container << EOF
|
||||||
ENTRYPOINT ["python", "-m", "llama_stack.distribution.server.server", "--template", "$template_or_config"]
|
ENTRYPOINT ["python", "-m", "llama_stack.distribution.server.server", "--template", "$template_or_config"]
|
||||||
EOF
|
EOF
|
||||||
else
|
|
||||||
add_to_container << EOF
|
|
||||||
ENTRYPOINT ["python", "-m", "llama_stack.distribution.server.server"]
|
|
||||||
EOF
|
|
||||||
fi
|
fi
|
||||||
|
|
||||||
# Add other require item commands genearic to all containers
|
# Add other require item commands genearic to all containers
|
||||||
|
@ -254,9 +321,10 @@ $CONTAINER_BINARY build \
  "${CLI_ARGS[@]}" \
  -t "$image_tag" \
  -f "$TEMP_DIR/Containerfile" \
  "."
  "$BUILD_CONTEXT_DIR"

# clean up tmp/configs
rm -f "$BUILD_CONTEXT_DIR/run.yaml"
set +x

echo "Success!"

@ -326,3 +326,12 @@ class BuildConfig(BaseModel):
        default="conda",
        description="Type of package to build (conda | container | venv)",
    )
    image_name: Optional[str] = Field(
        default=None,
        description="Name of the distribution to build",
    )
    external_providers_dir: Optional[str] = Field(
        default=None,
        description="Path to directory containing external provider implementations. The providers packages will be resolved from this directory. "
        "pip_packages MUST contain the provider package name.",
    )

|
||||||
import yaml
|
import yaml
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from llama_stack.distribution.datatypes import StackRunConfig
|
|
||||||
from llama_stack.log import get_logger
|
from llama_stack.log import get_logger
|
||||||
from llama_stack.providers.datatypes import (
|
from llama_stack.providers.datatypes import (
|
||||||
AdapterSpec,
|
AdapterSpec,
|
||||||
|
@ -97,7 +96,9 @@ def _load_inline_provider_spec(spec_data: Dict[str, Any], api: Api, provider_nam
    return spec


def get_provider_registry(config: StackRunConfig | None = None) -> Dict[Api, Dict[str, ProviderSpec]]:
def get_provider_registry(
    config=None,
) -> Dict[Api, Dict[str, ProviderSpec]]:
    """Get the provider registry, optionally including external providers.

    This function loads both built-in providers and external providers from YAML files.

@ -122,7 +123,7 @@ def get_provider_registry(config: StackRunConfig | None = None) -> Dict[Api, Dic
            llama-guard.yaml

    Args:
        config: Optional StackRunConfig containing the external providers directory path
        config: Optional object containing the external providers directory path

    Returns:
        A dictionary mapping APIs to their available providers

@ -142,7 +143,8 @@ def get_provider_registry(config: StackRunConfig | None = None) -> Dict[Api, Dic
        except ImportError as e:
            logger.warning(f"Failed to import module {name}: {e}")

    if config and config.external_providers_dir:
    # Check if config has the external_providers_dir attribute
    if config and hasattr(config, "external_providers_dir") and config.external_providers_dir:
        external_providers_dir = os.path.abspath(config.external_providers_dir)
        if not os.path.exists(external_providers_dir):
            raise FileNotFoundError(f"External providers directory not found: {external_providers_dir}")

@ -8,6 +8,11 @@ import asyncio
import time
from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Union

from openai.types.chat import ChatCompletionToolChoiceOptionParam as OpenAIChatCompletionToolChoiceOptionParam
from openai.types.chat import ChatCompletionToolParam as OpenAIChatCompletionToolParam
from pydantic import Field, TypeAdapter
from typing_extensions import Annotated

from llama_stack.apis.common.content_types import (
    URL,
    InterleavedContent,

@ -526,7 +531,7 @@ class InferenceRouter(Inference):
    async def openai_chat_completion(
        self,
        model: str,
        messages: List[OpenAIMessageParam],
        messages: Annotated[List[OpenAIMessageParam], Field(..., min_length=1)],
        frequency_penalty: Optional[float] = None,
        function_call: Optional[Union[str, Dict[str, Any]]] = None,
        functions: Optional[List[Dict[str, Any]]] = None,

@ -558,6 +563,16 @@ class InferenceRouter(Inference):
        if model_obj.model_type == ModelType.embedding:
            raise ValueError(f"Model '{model}' is an embedding model and does not support chat completions")

        # Use the OpenAI client for a bit of extra input validation without
        # exposing the OpenAI client itself as part of our API surface
        if tool_choice:
            TypeAdapter(OpenAIChatCompletionToolChoiceOptionParam).validate_python(tool_choice)
            if tools is None:
                raise ValueError("'tool_choice' is only allowed when 'tools' is also provided")
        if tools:
            for tool in tools:
                TypeAdapter(OpenAIChatCompletionToolParam).validate_python(tool)

        params = dict(
            model=model_obj.identifier,
            messages=messages,

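In isolation, that extra validation is just pydantic's `TypeAdapter` checking the request shapes against the types from the `openai` package; a small sketch:

```python
from openai.types.chat import ChatCompletionToolParam
from pydantic import TypeAdapter

tool = {
    "type": "function",
    "function": {
        "name": "get_weather",
        "parameters": {"type": "object", "properties": {"city": {"type": "string"}}},
    },
}

# A well-formed tool definition passes; a malformed one raises pydantic.ValidationError.
TypeAdapter(ChatCompletionToolParam).validate_python(tool)
```
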
@ -22,6 +22,7 @@ from fastapi import Body, FastAPI, HTTPException, Request
from fastapi import Path as FastapiPath
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse, StreamingResponse
from openai import BadRequestError
from pydantic import BaseModel, ValidationError
from typing_extensions import Annotated

@ -92,7 +93,7 @@ async def global_exception_handler(request: Request, exc: Exception):

def translate_exception(exc: Exception) -> Union[HTTPException, RequestValidationError]:
    if isinstance(exc, ValidationError):
        exc = RequestValidationError(exc.raw_errors)
        exc = RequestValidationError(exc.errors())

    if isinstance(exc, RequestValidationError):
        return HTTPException(

@ -110,6 +111,8 @@ def translate_exception(exc: Exception) -> Union[HTTPException, RequestValidatio
        )
    elif isinstance(exc, ValueError):
        return HTTPException(status_code=400, detail=f"Invalid value: {str(exc)}")
    elif isinstance(exc, BadRequestError):
        return HTTPException(status_code=400, detail=str(exc))
    elif isinstance(exc, PermissionError):
        return HTTPException(status_code=403, detail=f"Permission denied: {str(exc)}")
    elif isinstance(exc, TimeoutError):

@ -162,14 +165,17 @@ async def maybe_await(value):
    return value


async def sse_generator(event_gen):
async def sse_generator(event_gen_coroutine):
    event_gen = None
    try:
        async for item in await event_gen:
        event_gen = await event_gen_coroutine
        async for item in event_gen:
            yield create_sse_event(item)
            await asyncio.sleep(0.01)
    except asyncio.CancelledError:
        logger.info("Generator cancelled")
        await event_gen.aclose()
        if event_gen:
            await event_gen.aclose()
    except Exception as e:
        logger.exception("Error in sse_generator")
        yield create_sse_event(

@ -455,6 +461,7 @@ def main(args: Optional[argparse.Namespace] = None):
            raise ValueError(f"Could not find method {endpoint.name} on {impl}!!")

        impl_method = getattr(impl, endpoint.name)
        logger.debug(f"{endpoint.method.upper()} {endpoint.route}")

        with warnings.catch_warnings():
            warnings.filterwarnings("ignore", category=UserWarning, module="pydantic._internal._fields")

@ -24,6 +24,13 @@ def rag_chat_page():
    def should_disable_input():
        return "displayed_messages" in st.session_state and len(st.session_state.displayed_messages) > 0

    def log_message(message):
        with st.chat_message(message["role"]):
            if "tool_output" in message and message["tool_output"]:
                with st.expander(label="Tool Output", expanded=False, icon="🛠"):
                    st.write(message["tool_output"])
            st.markdown(message["content"])

    with st.sidebar:
        # File/Directory Upload Section
        st.subheader("Upload Documents", divider=True)

@ -146,8 +153,7 @@ def rag_chat_page():

    # Display chat history
    for message in st.session_state.displayed_messages:
        with st.chat_message(message["role"]):
            st.markdown(message["content"])
        log_message(message)

    if temperature > 0.0:
        strategy = {

@ -201,7 +207,7 @@ def rag_chat_page():

        # Display assistant response
        with st.chat_message("assistant"):
            retrieval_message_placeholder = st.empty()
            retrieval_message_placeholder = st.expander(label="Tool Output", expanded=False, icon="🛠")
            message_placeholder = st.empty()
            full_response = ""
            retrieval_response = ""

@ -209,14 +215,16 @@ def rag_chat_page():
                log.print()
                if log.role == "tool_execution":
                    retrieval_response += log.content.replace("====", "").strip()
                    retrieval_message_placeholder.info(retrieval_response)
                    retrieval_message_placeholder.write(retrieval_response)
                else:
                    full_response += log.content
                    message_placeholder.markdown(full_response + "▌")
            message_placeholder.markdown(full_response)

            st.session_state.messages.append({"role": "assistant", "content": full_response})
            st.session_state.displayed_messages.append({"role": "assistant", "content": full_response})
            st.session_state.displayed_messages.append(
                {"role": "assistant", "content": full_response, "tool_output": retrieval_response}
            )

    def direct_process_prompt(prompt):
        # Add the system prompt in the beginning of the conversation

|
||||||
prompt_context = rag_response.content
|
prompt_context = rag_response.content
|
||||||
|
|
||||||
with st.chat_message("assistant"):
|
with st.chat_message("assistant"):
|
||||||
|
with st.expander(label="Retrieval Output", expanded=False):
|
||||||
|
st.write(prompt_context)
|
||||||
|
|
||||||
retrieval_message_placeholder = st.empty()
|
retrieval_message_placeholder = st.empty()
|
||||||
message_placeholder = st.empty()
|
message_placeholder = st.empty()
|
||||||
full_response = ""
|
full_response = ""
|
||||||
retrieval_response = ""
|
retrieval_response = ""
|
||||||
|
|
||||||
# Display the retrieved content
|
|
||||||
retrieval_response += str(prompt_context)
|
|
||||||
retrieval_message_placeholder.info(retrieval_response)
|
|
||||||
|
|
||||||
# Construct the extended prompt
|
# Construct the extended prompt
|
||||||
extended_prompt = f"Please answer the following query using the context below.\n\nCONTEXT:\n{prompt_context}\n\nQUERY:\n{prompt}"
|
extended_prompt = f"Please answer the following query using the context below.\n\nCONTEXT:\n{prompt_context}\n\nQUERY:\n{prompt}"
|
||||||
|
|
||||||
|
|
|
@ -4,14 +4,23 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import enum
import json
import uuid

import streamlit as st
from llama_stack_client import Agent
from llama_stack_client.lib.agents.react.agent import ReActAgent
from llama_stack_client.lib.agents.react.tool_parser import ReActOutput

from llama_stack.distribution.ui.modules.api import llama_stack_api


class AgentType(enum.Enum):
    REGULAR = "Regular"
    REACT = "ReAct"


def tool_chat_page():
    st.title("🛠 Tools")

@ -23,50 +32,117 @@ def tool_chat_page():
    tool_groups_list = [tool_group.identifier for tool_group in tool_groups]
    mcp_tools_list = [tool for tool in tool_groups_list if tool.startswith("mcp::")]
    builtin_tools_list = [tool for tool in tool_groups_list if not tool.startswith("mcp::")]
    selected_vector_dbs = []

    def reset_agent():
        st.session_state.clear()
        st.cache_resource.clear()

    with st.sidebar:
        st.title("Configuration")
        st.subheader("Model")
        model = st.selectbox(label="models", options=model_list, on_change=reset_agent)
        model = st.selectbox(label="Model", options=model_list, on_change=reset_agent, label_visibility="collapsed")

        st.subheader("Builtin Tools")
        st.subheader("Available ToolGroups")

        toolgroup_selection = st.pills(
            label="Available ToolGroups", options=builtin_tools_list, selection_mode="multi", on_change=reset_agent
            label="Built-in tools",
            options=builtin_tools_list,
            selection_mode="multi",
            on_change=reset_agent,
            format_func=lambda tool: "".join(tool.split("::")[1:]),
            help="List of built-in tools from your llama stack server.",
        )

        st.subheader("MCP Servers")
        if "builtin::rag" in toolgroup_selection:
            vector_dbs = llama_stack_api.client.vector_dbs.list() or []
            if not vector_dbs:
                st.info("No vector databases available for selection.")
            vector_dbs = [vector_db.identifier for vector_db in vector_dbs]
            selected_vector_dbs = st.multiselect(
                label="Select Document Collections to use in RAG queries",
                options=vector_dbs,
                on_change=reset_agent,
            )

        mcp_selection = st.pills(
            label="Available MCP Servers", options=mcp_tools_list, selection_mode="multi", on_change=reset_agent
            label="MCP Servers",
            options=mcp_tools_list,
            selection_mode="multi",
            on_change=reset_agent,
            format_func=lambda tool: "".join(tool.split("::")[1:]),
            help="List of MCP servers registered to your llama stack server.",
        )

        toolgroup_selection.extend(mcp_selection)

        active_tool_list = []
        for toolgroup_id in toolgroup_selection:
            active_tool_list.extend(
                [
                    f"{''.join(toolgroup_id.split('::')[1:])}:{t.identifier}"
                    for t in client.tools.list(toolgroup_id=toolgroup_id)
                ]
            )
        grouped_tools = {}
        total_tools = 0

        st.subheader(f"Active Tools: 🛠 {len(active_tool_list)}")
        st.json(active_tool_list)
        for toolgroup_id in toolgroup_selection:
            tools = client.tools.list(toolgroup_id=toolgroup_id)
            grouped_tools[toolgroup_id] = [tool.identifier for tool in tools]
            total_tools += len(tools)

        st.markdown(f"Active Tools: 🛠 {total_tools}")

        for group_id, tools in grouped_tools.items():
            with st.expander(f"🔧 Tools from `{group_id}`"):
                for idx, tool in enumerate(tools, start=1):
                    st.markdown(f"{idx}. `{tool.split(':')[-1]}`")

        st.subheader("Agent Configurations")
        st.subheader("Agent Type")
        agent_type = st.radio(
            "Select Agent Type",
            [AgentType.REGULAR, AgentType.REACT],
            format_func=lambda x: x.value,
            on_change=reset_agent,
        )

        max_tokens = st.slider(
            "Max Tokens",
            min_value=0,
            max_value=4096,
            value=512,
            step=64,
            help="The maximum number of tokens to generate",
            on_change=reset_agent,
        )

    for i, tool_name in enumerate(toolgroup_selection):
        if tool_name == "builtin::rag":
            tool_dict = dict(
                name="builtin::rag",
                args={
                    "vector_db_ids": list(selected_vector_dbs),
                },
            )
            toolgroup_selection[i] = tool_dict

    @st.cache_resource
    def create_agent():
        return Agent(
        if "agent_type" in st.session_state and st.session_state.agent_type == AgentType.REACT:
|
||||||
client,
|
return ReActAgent(
|
||||||
model=model,
|
client=client,
|
||||||
instructions="You are a helpful assistant. When you use a tool always respond with a summary of the result.",
|
model=model,
|
||||||
tools=toolgroup_selection,
|
tools=toolgroup_selection,
|
||||||
sampling_params={
|
response_format={
|
||||||
"strategy": {"type": "greedy"},
|
"type": "json_schema",
|
||||||
},
|
"json_schema": ReActOutput.model_json_schema(),
|
||||||
)
|
},
|
||||||
|
sampling_params={"strategy": {"type": "greedy"}, "max_tokens": max_tokens},
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
return Agent(
|
||||||
|
client,
|
||||||
|
model=model,
|
||||||
|
instructions="You are a helpful assistant. When you use a tool always respond with a summary of the result.",
|
||||||
|
tools=toolgroup_selection,
|
||||||
|
sampling_params={"strategy": {"type": "greedy"}, "max_tokens": max_tokens},
|
||||||
|
)
|
||||||
|
|
||||||
|
st.session_state.agent_type = agent_type
|
||||||
|
|
||||||
agent = create_agent()
|
agent = create_agent()
|
||||||
|
|
||||||
|
@ -95,6 +171,158 @@ def tool_chat_page():
|
||||||
)
|
)
|
||||||
|
|
||||||
def response_generator(turn_response):
|
def response_generator(turn_response):
|
||||||
|
if st.session_state.get("agent_type") == AgentType.REACT:
|
||||||
|
return _handle_react_response(turn_response)
|
||||||
|
else:
|
||||||
|
return _handle_regular_response(turn_response)
|
||||||
|
|
||||||
|
def _handle_react_response(turn_response):
|
||||||
|
current_step_content = ""
|
||||||
|
final_answer = None
|
||||||
|
tool_results = []
|
||||||
|
|
||||||
|
for response in turn_response:
|
||||||
|
if not hasattr(response.event, "payload"):
|
||||||
|
yield (
|
||||||
|
"\n\n🚨 :red[_Llama Stack server Error:_]\n"
|
||||||
|
"The response received is missing an expected `payload` attribute.\n"
|
||||||
|
"This could indicate a malformed response or an internal issue within the server.\n\n"
|
||||||
|
f"Error details: {response}"
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
payload = response.event.payload
|
||||||
|
|
||||||
|
if payload.event_type == "step_progress" and hasattr(payload.delta, "text"):
|
||||||
|
current_step_content += payload.delta.text
|
||||||
|
continue
|
||||||
|
|
||||||
|
if payload.event_type == "step_complete":
|
||||||
|
step_details = payload.step_details
|
||||||
|
|
||||||
|
if step_details.step_type == "inference":
|
||||||
|
yield from _process_inference_step(current_step_content, tool_results, final_answer)
|
||||||
|
current_step_content = ""
|
||||||
|
elif step_details.step_type == "tool_execution":
|
||||||
|
tool_results = _process_tool_execution(step_details, tool_results)
|
||||||
|
current_step_content = ""
|
||||||
|
else:
|
||||||
|
current_step_content = ""
|
||||||
|
|
||||||
|
if not final_answer and tool_results:
|
||||||
|
yield from _format_tool_results_summary(tool_results)
|
||||||
|
|
||||||
|
def _process_inference_step(current_step_content, tool_results, final_answer):
|
||||||
|
try:
|
||||||
|
react_output_data = json.loads(current_step_content)
|
||||||
|
thought = react_output_data.get("thought")
|
||||||
|
action = react_output_data.get("action")
|
||||||
|
answer = react_output_data.get("answer")
|
||||||
|
|
||||||
|
if answer and answer != "null" and answer is not None:
|
||||||
|
final_answer = answer
|
||||||
|
|
||||||
|
if thought:
|
||||||
|
with st.expander("🤔 Thinking...", expanded=False):
|
||||||
|
st.markdown(f":grey[__{thought}__]")
|
||||||
|
|
||||||
|
if action and isinstance(action, dict):
|
||||||
|
tool_name = action.get("tool_name")
|
||||||
|
tool_params = action.get("tool_params")
|
||||||
|
with st.expander(f'🛠 Action: Using tool "{tool_name}"', expanded=False):
|
||||||
|
st.json(tool_params)
|
||||||
|
|
||||||
|
if answer and answer != "null" and answer is not None:
|
||||||
|
yield f"\n\n✅ **Final Answer:**\n{answer}"
|
||||||
|
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
yield f"\n\nFailed to parse ReAct step content:\n```json\n{current_step_content}\n```"
|
||||||
|
except Exception as e:
|
||||||
|
yield f"\n\nFailed to process ReAct step: {e}\n```json\n{current_step_content}\n```"
|
||||||
|
|
||||||
|
return final_answer
|
||||||
|
|
||||||
|
def _process_tool_execution(step_details, tool_results):
|
||||||
|
try:
|
||||||
|
if hasattr(step_details, "tool_responses") and step_details.tool_responses:
|
||||||
|
for tool_response in step_details.tool_responses:
|
||||||
|
tool_name = tool_response.tool_name
|
||||||
|
content = tool_response.content
|
||||||
|
tool_results.append((tool_name, content))
|
||||||
|
with st.expander(f'⚙️ Observation (Result from "{tool_name}")', expanded=False):
|
||||||
|
try:
|
||||||
|
parsed_content = json.loads(content)
|
||||||
|
st.json(parsed_content)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
st.code(content, language=None)
|
||||||
|
else:
|
||||||
|
with st.expander("⚙️ Observation", expanded=False):
|
||||||
|
st.markdown(":grey[_Tool execution step completed, but no response data found._]")
|
||||||
|
except Exception as e:
|
||||||
|
with st.expander("⚙️ Error in Tool Execution", expanded=False):
|
||||||
|
st.markdown(f":red[_Error processing tool execution: {str(e)}_]")
|
||||||
|
|
||||||
|
return tool_results
|
||||||
|
|
||||||
|
def _format_tool_results_summary(tool_results):
|
||||||
|
yield "\n\n**Here's what I found:**\n"
|
||||||
|
for tool_name, content in tool_results:
|
||||||
|
try:
|
||||||
|
parsed_content = json.loads(content)
|
||||||
|
|
||||||
|
if tool_name == "web_search" and "top_k" in parsed_content:
|
||||||
|
yield from _format_web_search_results(parsed_content)
|
||||||
|
elif "results" in parsed_content and isinstance(parsed_content["results"], list):
|
||||||
|
yield from _format_results_list(parsed_content["results"])
|
||||||
|
elif isinstance(parsed_content, dict) and len(parsed_content) > 0:
|
||||||
|
yield from _format_dict_results(parsed_content)
|
||||||
|
elif isinstance(parsed_content, list) and len(parsed_content) > 0:
|
||||||
|
yield from _format_list_results(parsed_content)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
yield f"\n**{tool_name}** was used but returned complex data. Check the observation for details.\n"
|
||||||
|
except (TypeError, AttributeError, KeyError, IndexError) as e:
|
||||||
|
print(f"Error processing {tool_name} result: {type(e).__name__}: {e}")
|
||||||
|
|
||||||
|
def _format_web_search_results(parsed_content):
|
||||||
|
for i, result in enumerate(parsed_content["top_k"], 1):
|
||||||
|
if i <= 3:
|
||||||
|
title = result.get("title", "Untitled")
|
||||||
|
url = result.get("url", "")
|
||||||
|
content_text = result.get("content", "").strip()
|
||||||
|
yield f"\n- **{title}**\n {content_text}\n [Source]({url})\n"
|
||||||
|
|
||||||
|
def _format_results_list(results):
|
||||||
|
for i, result in enumerate(results, 1):
|
||||||
|
if i <= 3:
|
||||||
|
if isinstance(result, dict):
|
||||||
|
name = result.get("name", result.get("title", "Result " + str(i)))
|
||||||
|
description = result.get("description", result.get("content", result.get("summary", "")))
|
||||||
|
yield f"\n- **{name}**\n {description}\n"
|
||||||
|
else:
|
||||||
|
yield f"\n- {result}\n"
|
||||||
|
|
||||||
|
def _format_dict_results(parsed_content):
|
||||||
|
yield "\n```\n"
|
||||||
|
for key, value in list(parsed_content.items())[:5]:
|
||||||
|
if isinstance(value, str) and len(value) < 100:
|
||||||
|
yield f"{key}: {value}\n"
|
||||||
|
else:
|
||||||
|
yield f"{key}: [Complex data]\n"
|
||||||
|
yield "```\n"
|
||||||
|
|
||||||
|
def _format_list_results(parsed_content):
|
||||||
|
yield "\n"
|
||||||
|
for _, item in enumerate(parsed_content[:3], 1):
|
||||||
|
if isinstance(item, str):
|
||||||
|
yield f"- {item}\n"
|
||||||
|
elif isinstance(item, dict) and "text" in item:
|
||||||
|
yield f"- {item['text']}\n"
|
||||||
|
elif isinstance(item, dict) and len(item) > 0:
|
||||||
|
first_value = next(iter(item.values()))
|
||||||
|
if isinstance(first_value, str) and len(first_value) < 100:
|
||||||
|
yield f"- {first_value}\n"
|
||||||
|
|
||||||
|
def _handle_regular_response(turn_response):
|
||||||
for response in turn_response:
|
for response in turn_response:
|
||||||
if hasattr(response.event, "payload"):
|
if hasattr(response.event, "payload"):
|
||||||
print(response.event.payload)
|
print(response.event.payload)
|
||||||
|
@ -103,14 +331,18 @@ def tool_chat_page():
|
||||||
yield response.event.payload.delta.text
|
yield response.event.payload.delta.text
|
||||||
if response.event.payload.event_type == "step_complete":
|
if response.event.payload.event_type == "step_complete":
|
||||||
if response.event.payload.step_details.step_type == "tool_execution":
|
if response.event.payload.step_details.step_type == "tool_execution":
|
||||||
yield " 🛠 "
|
if response.event.payload.step_details.tool_calls:
|
||||||
|
tool_name = str(response.event.payload.step_details.tool_calls[0].tool_name)
|
||||||
|
yield f'\n\n🛠 :grey[_Using "{tool_name}" tool:_]\n\n'
|
||||||
|
else:
|
||||||
|
yield "No tool_calls present in step_details"
|
||||||
else:
|
else:
|
||||||
yield f"Error occurred in the Llama Stack Cluster: {response}"
|
yield f"Error occurred in the Llama Stack Cluster: {response}"
|
||||||
|
|
||||||
with st.chat_message("assistant"):
|
with st.chat_message("assistant"):
|
||||||
response = st.write_stream(response_generator(turn_response))
|
response_content = st.write_stream(response_generator(turn_response))
|
||||||
|
|
||||||
st.session_state.messages.append({"role": "assistant", "content": response})
|
st.session_state.messages.append({"role": "assistant", "content": response_content})
|
||||||
|
|
||||||
|
|
||||||
tool_chat_page()
|
tool_chat_page()
|
||||||
|
|
|
@ -5,6 +5,7 @@
# the root directory of this source tree.

import io
import json
import uuid
from dataclasses import dataclass
from typing import Dict, List, Optional, Tuple

@ -299,8 +300,10 @@ class ChatFormat:
                    call_id=call_id,
                    tool_name=tool_name,
                    arguments=tool_arguments,
                    arguments_json=json.dumps(tool_arguments),
                )
            )
            content = ""

        return RawMessage(
            role="assistant",

@ -64,7 +64,7 @@ This example passes an image that is smaller than the tile size, to show the til

##### Model Response Format
```
The image depicts a dog standing on a skateboard, with its front paws positioned on the board and its back paws hanging off the back. The dog has a distinctive coat pattern, featuring a white face, brown and black fur, and white paws, and is standing on a skateboard with red wheels, set against a blurred background of a street or alleyway with a teal door and beige wall.<|eot|>
The image depicts a dog standing on a skateboard, positioned centrally and facing the camera directly. The dog has a distinctive coat pattern featuring white, black, and brown fur, with floppy ears and a black nose, and is standing on a skateboard with red wheels.<|eot|>
```

@ -91,7 +91,7 @@ Here is an example of how to pass an image to the model

##### Model Response Format
```
This image shows a dog standing on a skateboard, with its front paws positioned near the front of the board and its back paws near the back. The dog has a white, black, and orange coat, and is standing on a gray skateboard with red wheels, in front of a blurred background that appears to be a street or alleyway.<|eot|>
The image depicts a dog standing on a skateboard, with the dog positioned centrally and facing forward. The dog has a distinctive coat featuring a mix of white, brown, and black fur, and is wearing a collar as it stands on the skateboard, which has red wheels.<|eot|>
```

@ -117,7 +117,7 @@ Here is an example of how to pass an image to the model

##### Model Response Format
```
The first image shows a dog standing on a skateboard, while the second image shows a plate of spaghetti with tomato sauce, parmesan cheese, and parsley. The two images are unrelated, with the first image featuring a dog and the second image featuring a food dish, and they do not share any common elements or themes.<|eot|>
The first image features a dog standing on a skateboard, while the second image showcases a plate of spaghetti with tomato sauce and cheese. The two images appear to be unrelated, with one depicting a playful scene of a dog on a skateboard and the other presenting a classic Italian dish.<|eom|>
```

@ -135,13 +135,44 @@ We are continuing the format for zero shot function calling used in previous ver
```
<|begin_of_text|><|header_start|>system<|header_end|>

You are an expert in composing functions. You are given a question and a set of possible functions.
You are a helpful assistant and an expert in function composition. You can answer general questions using your internal knowledge OR invoke functions when necessary. Follow these strict guidelines:
Based on the question, you will need to make one or more function/tool calls to achieve the purpose.
If none of the function can be used, point it out. If the given question lacks the parameters required by the function,
1. FUNCTION CALLS:
also point it out. You should only return the function call in tools call sections.
- ONLY use functions that are EXPLICITLY listed in the function list below
- If NO functions are listed (empty function list []), respond ONLY with internal knowledge or "I don't have access to [Unavailable service] information"
- If a function is not in the list, respond ONLY with internal knowledge or "I don't have access to [Unavailable service] information"
- If ALL required parameters are present AND the query EXACTLY matches a listed function's purpose: output ONLY the function call(s)
- Use exact format: [func_name1(param1=value1, param2=value2), func_name2(...)]
Examples:
CORRECT: [get_weather(location="Vancouver"), calculate_route(start="Boston", end="New York")] <- Only if get_weather and calculate_route are in function list
INCORRECT: get_weather(location="New York")
INCORRECT: Let me check the weather: [get_weather(location="New York")]
INCORRECT: [get_events(location="Singapore")] <- If function not in list

2. RESPONSE RULES:
- For pure function requests matching a listed function: ONLY output the function call(s)
- For knowledge questions: ONLY output text
- For missing parameters: ONLY request the specific missing parameters
- For unavailable services (not in function list): output ONLY with internal knowledge or "I don't have access to [Unavailable service] information". Do NOT execute a function call.
- If the query asks for information beyond what a listed function provides: output ONLY with internal knowledge about your limitations
- NEVER combine text and function calls in the same response
- NEVER suggest alternative functions when the requested service is unavailable
- NEVER create or invent new functions not listed below

3. STRICT BOUNDARIES:
- ONLY use functions from the list below - no exceptions
- NEVER use a function as an alternative to unavailable information
- NEVER call functions not present in the function list
- NEVER add explanatory text to function calls
- NEVER respond with empty brackets
- Use proper Python/JSON syntax for function calls
- Check the function list carefully before responding

4. TOOL RESPONSE HANDLING:
- When receiving tool responses: provide concise, natural language responses
- Don't repeat tool response verbatim
- Don't add supplementary information

If you decide to invoke any of the function(s), you MUST put it in the format of [func_name1(params_name1=params_value1, params_name2=params_value2...), func_name2(params)]
You SHOULD NOT include any other text in the response.

Here is a list of functions in JSON format that you can invoke.

@ -151,9 +182,7 @@ Here is a list of functions in JSON format that you can invoke.
        "description": "Get weather info for places",
        "parameters": {
            "type": "dict",
            "required": [
            "required": ["city"],
                "city"
            ],
            "properties": {
                "city": {
                    "type": "string",

@ -167,7 +196,10 @@ Here is a list of functions in JSON format that you can invoke.
            }
        }
    }
<|eot|><|header_start|>user<|header_end|>
]

You can answer general questions or invoke tools when necessary.
In addition to tool calls, you should also augment your responses by using the tool outputs.<|eot|><|header_start|>user<|header_end|>

What is the weather in SF and Seattle?<|eot|><|header_start|>assistant<|header_end|>

@ -176,7 +208,7 @@ What is the weather in SF and Seattle?<|eot|><|header_start|>assistant<|header_e

##### Model Response Format
```
[get_weather(city='SF'), get_weather(city='Seattle')]<|eot|>
[get_weather(city="San Francisco"), get_weather(city="Seattle")]<|eot|>
```

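The bracketed call format above is plain text, so anything that consumes raw model output has to parse it back into structured calls. Below is a minimal, illustrative parser for the `[func_name(param=value), ...]` syntax; it is only a sketch and is not the decoder used by the model code in this repository.

```python
import re

# Illustrative only: parse '[get_weather(city="San Francisco"), get_weather(city="Seattle")]'
# into (name, args) pairs. String and bare values are both accepted.
CALL_RE = re.compile(r"(\w+)\((.*?)\)")
ARG_RE = re.compile(r'(\w+)\s*=\s*(?:"([^"]*)"|\'([^\']*)\'|([^,)\s]+))')


def parse_bracketed_calls(text: str) -> list[tuple[str, dict[str, str]]]:
    calls = []
    for name, argstr in CALL_RE.findall(text.strip().strip("[]")):
        args = {}
        for key, dq, sq, bare in ARG_RE.findall(argstr):
            args[key] = dq or sq or bare
        calls.append((name, args))
    return calls


print(parse_bracketed_calls('[get_weather(city="San Francisco"), get_weather(city="Seattle")]'))
# [('get_weather', {'city': 'San Francisco'}), ('get_weather', {'city': 'Seattle'})]
```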
@ -273,5 +305,5 @@ Use tools to get latest trending songs<|eot|><|header_start|>assistant<|header_e

##### Model Response Format
```
<function=trending_songs>{"n": "10"}</function><|eot|>
<function=trending_songs>{"n": 10}</function><|eot|>
```

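The `<function=...>` form carries its arguments as JSON, and the updated example emits `10` as a number rather than the string `"10"`. A small, illustrative parser for this tag format (again, not the repository's decoder):

```python
import json
import re

# Illustrative only: extract the tool name and typed JSON arguments from
# '<function=trending_songs>{"n": 10}</function>'.
TAG_RE = re.compile(r"<function=(\w+)>(.*?)</function>", re.DOTALL)


def parse_function_tag(text: str) -> tuple[str, dict]:
    match = TAG_RE.search(text)
    if match is None:
        raise ValueError("no <function=...> call found")
    name, raw_args = match.groups()
    return name, json.loads(raw_args)


print(parse_function_tag('<function=trending_songs>{"n": 10}</function><|eot|>'))
# ('trending_songs', {'n': 10})
```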
@ -0,0 +1,144 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# top-level folder for each specific model found within the models/ directory at
|
||||||
|
# the top-level of this source tree.
|
||||||
|
|
||||||
|
import textwrap
|
||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
|
from llama_stack.apis.inference import ToolDefinition, ToolParamDefinition
|
||||||
|
from llama_stack.models.llama.llama3.prompt_templates.base import (
|
||||||
|
PromptTemplate,
|
||||||
|
PromptTemplateGeneratorBase,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class PythonListCustomToolGenerator(PromptTemplateGeneratorBase): # noqa: N801
|
||||||
|
DEFAULT_PROMPT = textwrap.dedent(
|
||||||
|
"""
|
||||||
|
You are a helpful assistant and an expert in function composition. You can answer general questions using your internal knowledge OR invoke functions when necessary. Follow these strict guidelines:
|
||||||
|
|
||||||
|
1. FUNCTION CALLS:
|
||||||
|
- ONLY use functions that are EXPLICITLY listed in the function list below
|
||||||
|
- If NO functions are listed (empty function list []), respond ONLY with internal knowledge or "I don't have access to [Unavailable service] information"
|
||||||
|
- If a function is not in the list, respond ONLY with internal knowledge or "I don't have access to [Unavailable service] information"
|
||||||
|
- If ALL required parameters are present AND the query EXACTLY matches a listed function's purpose: output ONLY the function call(s)
|
||||||
|
- Use exact format: [func_name1(param1=value1, param2=value2), func_name2(...)]
|
||||||
|
Examples:
|
||||||
|
CORRECT: [get_weather(location="Vancouver"), calculate_route(start="Boston", end="New York")] <- Only if get_weather and calculate_route are in function list
|
||||||
|
INCORRECT: get_weather(location="New York")
|
||||||
|
INCORRECT: Let me check the weather: [get_weather(location="New York")]
|
||||||
|
INCORRECT: [get_events(location="Singapore")] <- If function not in list
|
||||||
|
|
||||||
|
2. RESPONSE RULES:
|
||||||
|
- For pure function requests matching a listed function: ONLY output the function call(s)
|
||||||
|
- For knowledge questions: ONLY output text
|
||||||
|
- For missing parameters: ONLY request the specific missing parameters
|
||||||
|
- For unavailable services (not in function list): output ONLY with internal knowledge or "I don't have access to [Unavailable service] information". Do NOT execute a function call.
|
||||||
|
- If the query asks for information beyond what a listed function provides: output ONLY with internal knowledge about your limitations
|
||||||
|
- NEVER combine text and function calls in the same response
|
||||||
|
- NEVER suggest alternative functions when the requested service is unavailable
|
||||||
|
- NEVER create or invent new functions not listed below
|
||||||
|
|
||||||
|
3. STRICT BOUNDARIES:
|
||||||
|
- ONLY use functions from the list below - no exceptions
|
||||||
|
- NEVER use a function as an alternative to unavailable information
|
||||||
|
- NEVER call functions not present in the function list
|
||||||
|
- NEVER add explanatory text to function calls
|
||||||
|
- NEVER respond with empty brackets
|
||||||
|
- Use proper Python/JSON syntax for function calls
|
||||||
|
- Check the function list carefully before responding
|
||||||
|
|
||||||
|
4. TOOL RESPONSE HANDLING:
|
||||||
|
- When receiving tool responses: provide concise, natural language responses
|
||||||
|
- Don't repeat tool response verbatim
|
||||||
|
- Don't add supplementary information
|
||||||
|
|
||||||
|
|
||||||
|
{{ function_description }}
|
||||||
|
""".strip("\n")
|
||||||
|
)
|
||||||
|
|
||||||
|
def gen(self, custom_tools: List[ToolDefinition], system_prompt: Optional[str] = None) -> PromptTemplate:
|
||||||
|
system_prompt = system_prompt or self.DEFAULT_PROMPT
|
||||||
|
return PromptTemplate(
|
||||||
|
system_prompt,
|
||||||
|
{"function_description": self._gen_function_description(custom_tools)},
|
||||||
|
)
|
||||||
|
|
||||||
|
def _gen_function_description(self, custom_tools: List[ToolDefinition]) -> PromptTemplate:
|
||||||
|
template_str = textwrap.dedent(
|
||||||
|
"""
|
||||||
|
Here is a list of functions in JSON format that you can invoke.
|
||||||
|
|
||||||
|
[
|
||||||
|
{% for t in tools -%}
|
||||||
|
{# manually setting up JSON because jinja sorts keys in unexpected ways -#}
|
||||||
|
{%- set tname = t.tool_name -%}
|
||||||
|
{%- set tdesc = t.description -%}
|
||||||
|
{%- set tparams = t.parameters -%}
|
||||||
|
{%- set required_params = [] -%}
|
||||||
|
{%- for name, param in tparams.items() if param.required == true -%}
|
||||||
|
{%- set _ = required_params.append(name) -%}
|
||||||
|
{%- endfor -%}
|
||||||
|
{
|
||||||
|
"name": "{{tname}}",
|
||||||
|
"description": "{{tdesc}}",
|
||||||
|
"parameters": {
|
||||||
|
"type": "dict",
|
||||||
|
"required": {{ required_params | tojson }},
|
||||||
|
"properties": {
|
||||||
|
{%- for name, param in tparams.items() %}
|
||||||
|
"{{name}}": {
|
||||||
|
"type": "{{param.param_type}}",
|
||||||
|
"description": "{{param.description}}"{% if param.default %},
|
||||||
|
"default": "{{param.default}}"{% endif %}
|
||||||
|
}{% if not loop.last %},{% endif %}
|
||||||
|
{%- endfor %}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}{% if not loop.last %},
|
||||||
|
{% endif -%}
|
||||||
|
{%- endfor %}
|
||||||
|
]
|
||||||
|
|
||||||
|
You can answer general questions or invoke tools when necessary.
|
||||||
|
In addition to tool calls, you should also augment your responses by using the tool outputs.
|
||||||
|
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
return PromptTemplate(
|
||||||
|
template_str.strip("\n"),
|
||||||
|
{"tools": [t.model_dump() for t in custom_tools]},
|
||||||
|
).render()
|
||||||
|
|
||||||
|
def data_examples(self) -> List[List[ToolDefinition]]:
|
||||||
|
return [
|
||||||
|
[
|
||||||
|
ToolDefinition(
|
||||||
|
tool_name="get_weather",
|
||||||
|
description="Get weather info for places",
|
||||||
|
parameters={
|
||||||
|
"city": ToolParamDefinition(
|
||||||
|
param_type="string",
|
||||||
|
description="The name of the city to get the weather for",
|
||||||
|
required=True,
|
||||||
|
),
|
||||||
|
"metric": ToolParamDefinition(
|
||||||
|
param_type="string",
|
||||||
|
description="The metric for weather. Options are: celsius, fahrenheit",
|
||||||
|
required=False,
|
||||||
|
default="celsius",
|
||||||
|
),
|
||||||
|
},
|
||||||
|
),
|
||||||
|
]
|
||||||
|
]
|
|
@ -9,6 +9,10 @@ from io import BytesIO
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
|
from llama_stack.models.llama.llama4.prompt_templates.system_prompts import (
|
||||||
|
PythonListCustomToolGenerator,
|
||||||
|
)
|
||||||
|
|
||||||
from ..datatypes import RawMediaItem, RawMessage, RawTextItem
|
from ..datatypes import RawMediaItem, RawMessage, RawTextItem
|
||||||
from ..prompt_format import (
|
from ..prompt_format import (
|
||||||
Llama4UseCase,
|
Llama4UseCase,
|
||||||
|
@ -177,39 +181,9 @@ def usecases(base_model: bool = False) -> List[UseCase | str]:
|
||||||
[
|
[
|
||||||
RawMessage(
|
RawMessage(
|
||||||
role="system",
|
role="system",
|
||||||
content="""You are an expert in composing functions. You are given a question and a set of possible functions.
|
content=PythonListCustomToolGenerator()
|
||||||
Based on the question, you will need to make one or more function/tool calls to achieve the purpose.
|
.gen(PythonListCustomToolGenerator().data_examples()[0])
|
||||||
If none of the function can be used, point it out. If the given question lacks the parameters required by the function,
|
.render(),
|
||||||
also point it out. You should only return the function call in tools call sections.
|
|
||||||
|
|
||||||
If you decide to invoke any of the function(s), you MUST put it in the format of [func_name1(params_name1=params_value1, params_name2=params_value2...), func_name2(params)]
|
|
||||||
You SHOULD NOT include any other text in the response.
|
|
||||||
|
|
||||||
Here is a list of functions in JSON format that you can invoke.
|
|
||||||
|
|
||||||
[
|
|
||||||
{
|
|
||||||
"name": "get_weather",
|
|
||||||
"description": "Get weather info for places",
|
|
||||||
"parameters": {
|
|
||||||
"type": "dict",
|
|
||||||
"required": [
|
|
||||||
"city"
|
|
||||||
],
|
|
||||||
"properties": {
|
|
||||||
"city": {
|
|
||||||
"type": "string",
|
|
||||||
"description": "The name of the city to get the weather for"
|
|
||||||
},
|
|
||||||
"metric": {
|
|
||||||
"type": "string",
|
|
||||||
"description": "The metric for weather. Options are: celsius, fahrenheit",
|
|
||||||
"default": "celsius"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
""",
|
|
||||||
),
|
),
|
||||||
RawMessage(
|
RawMessage(
|
||||||
role="user",
|
role="user",
|
||||||
|
|
|
@ -178,6 +178,8 @@ class ChatAgent(ShieldRunnerMixin):
            span.set_attribute("request", request.model_dump_json())
            turn_id = str(uuid.uuid4())
            span.set_attribute("turn_id", turn_id)
            if self.agent_config.name:
                span.set_attribute("agent_name", self.agent_config.name)

            await self._initialize_tools(request.toolgroups)
            async for chunk in self._run_turn(request, turn_id):

@ -190,6 +192,8 @@ class ChatAgent(ShieldRunnerMixin):
            span.set_attribute("session_id", request.session_id)
            span.set_attribute("request", request.model_dump_json())
            span.set_attribute("turn_id", request.turn_id)
            if self.agent_config.name:
                span.set_attribute("agent_name", self.agent_config.name)

            await self._initialize_tools()
            async for chunk in self._run_turn(request):

@ -498,6 +502,8 @@ class ChatAgent(ShieldRunnerMixin):
        stop_reason = None

        async with tracing.span("inference") as span:
            if self.agent_config.name:
                span.set_attribute("agent_name", self.agent_config.name)
            async for chunk in await self.inference_api.chat_completion(
                self.agent_config.model,
                input_messages,

@ -253,7 +253,8 @@ class MetaReferenceInferenceImpl(
        def impl():
            stop_reason = None

            for token_result in self.generator.completion(request):
            for token_results in self.generator.completion([request]):
                token_result = token_results[0]
                if token_result.token == tokenizer.eot_id:
                    stop_reason = StopReason.end_of_turn
                    text = ""

@ -515,7 +516,8 @@ class MetaReferenceInferenceImpl(
            stop_reason = None
            ipython = False

            for token_result in self.generator.chat_completion(request):
            for token_results in self.generator.chat_completion([request]):
                token_result = token_results[0]
                if os.environ.get("LLAMA_MODELS_DEBUG", "0") == "1":
                    cprint(token_result.text, "cyan", end="")
                if os.environ.get("LLAMA_MODELS_DEBUG", "0") == "2":

@ -69,7 +69,10 @@ class CancelSentinel(BaseModel):
|
||||||
|
|
||||||
class TaskRequest(BaseModel):
|
class TaskRequest(BaseModel):
|
||||||
type: Literal[ProcessingMessageName.task_request] = ProcessingMessageName.task_request
|
type: Literal[ProcessingMessageName.task_request] = ProcessingMessageName.task_request
|
||||||
task: Tuple[str, List[CompletionRequestWithRawContent] | List[ChatCompletionRequestWithRawContent]]
|
task: Tuple[
|
||||||
|
str,
|
||||||
|
List[CompletionRequestWithRawContent] | List[ChatCompletionRequestWithRawContent],
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
class TaskResponse(BaseModel):
|
class TaskResponse(BaseModel):
|
||||||
|
@ -231,10 +234,10 @@ def worker_process_entrypoint(
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
task = req_gen.send(result)
|
task = req_gen.send(result)
|
||||||
if isinstance(task, str) and task == EndSentinel():
|
if isinstance(task, EndSentinel):
|
||||||
break
|
break
|
||||||
|
|
||||||
assert isinstance(task, TaskRequest)
|
assert isinstance(task, TaskRequest), task
|
||||||
result = model(task.task)
|
result = model(task.task)
|
||||||
except StopIteration:
|
except StopIteration:
|
||||||
break
|
break
|
||||||
|
@ -331,7 +334,10 @@ class ModelParallelProcessGroup:
|
||||||
|
|
||||||
def run_inference(
|
def run_inference(
|
||||||
self,
|
self,
|
||||||
req: Tuple[str, List[CompletionRequestWithRawContent] | List[ChatCompletionRequestWithRawContent]],
|
req: Tuple[
|
||||||
|
str,
|
||||||
|
List[CompletionRequestWithRawContent] | List[ChatCompletionRequestWithRawContent],
|
||||||
|
],
|
||||||
) -> Generator:
|
) -> Generator:
|
||||||
assert not self.running, "inference already running"
|
assert not self.running, "inference already running"
|
||||||
|
|
||||||
|
|
|
@ -33,6 +33,7 @@ from llama_stack.apis.tools import (
)
from llama_stack.apis.vector_io import QueryChunksResponse, VectorIO
from llama_stack.providers.datatypes import ToolsProtocolPrivate
from llama_stack.providers.utils.inference.prompt_adapter import interleaved_content_as_str
from llama_stack.providers.utils.memory.vector_store import (
    content_from_doc,
    make_overlapped_chunks,

@ -153,6 +154,11 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):
                )
            )
        picked.append(TextContentItem(text="END of knowledge_search tool results.\n"))
        picked.append(
            TextContentItem(
                text=f'The above results were retrieved to help answer the user\'s query: "{interleaved_content_as_str(content)}". Use them as supporting information only in answering this query.\n',
            )
        )

        return RAGQueryResult(
            content=picked,

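The extra `TextContentItem` means the knowledge_search output now ends with an instruction telling the model to treat the retrieved chunks as supporting context for the original query. A rough, strings-only sketch of the resulting framing; the "found N chunks", BEGIN, and "Result N" lines are assumed from the surrounding implementation and are not shown in this hunk.

```python
# Schematic illustration only: how the tool output is framed for the model.
def frame_rag_results(query: str, chunks: list[str]) -> str:
    parts = [
        f"knowledge_search tool found {len(chunks)} chunks:\n",  # assumed preamble
        "BEGIN of knowledge_search tool results.\n",             # assumed marker
    ]
    parts.extend(f"Result {i + 1}:\n{chunk}\n" for i, chunk in enumerate(chunks))
    parts.append("END of knowledge_search tool results.\n")
    parts.append(
        f'The above results were retrieved to help answer the user\'s query: "{query}". '
        "Use them as supporting information only in answering this query.\n"
    )
    return "".join(parts)


print(frame_rag_results("What is Llama Stack?", ["Llama Stack standardizes GenAI building blocks."]))
```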
@ -288,4 +288,14 @@ def available_providers() -> List[ProviderSpec]:
            provider_data_validator="llama_stack.providers.remote.inference.passthrough.PassthroughProviderDataValidator",
        ),
    ),
    remote_provider_spec(
        api=Api.inference,
        adapter=AdapterSpec(
            adapter_type="watsonx",
            pip_packages=["ibm_watson_machine_learning"],
            module="llama_stack.providers.remote.inference.watsonx",
            config_class="llama_stack.providers.remote.inference.watsonx.WatsonXConfig",
            provider_data_validator="llama_stack.providers.remote.inference.watsonx.WatsonXProviderDataValidator",
        ),
    ),
]

@ -77,7 +77,7 @@ POST /eval/benchmarks/{benchmark_id}/jobs
  "benchmark_config": {
    "eval_candidate": {
      "type": "model",
      "model": "meta/llama-3.1-8b-instruct",
      "model": "meta-llama/Llama3.1-8B-Instruct",
      "sampling_params": {
        "max_tokens": 100,
        "temperature": 0.7

@ -91,7 +91,7 @@ POST /eval/benchmarks/{benchmark_id}/jobs

Response example:
```json
{
  "job_id": "1234",
  "job_id": "eval-1234",
  "status": "in_progress"
}
```

@ -101,6 +101,14 @@ Response example:
GET /eval/benchmarks/{benchmark_id}/jobs/{job_id}
```

Response example:
```json
{
  "job_id": "eval-1234",
  "status": "in_progress"
}
```

### Example for cancelling a job
```
POST /eval/benchmarks/{benchmark_id}/jobs/{job_id}/cancel

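For reference, the endpoints above can be exercised with plain HTTP. The sketch below is illustrative only: the base URL and API prefix, the benchmark id, and the polling interval are placeholders, and the payload simply mirrors the `benchmark_config` example; cancellation uses the `POST .../cancel` endpoint shown above.

```python
import time

import requests

BASE_URL = "http://localhost:8321/v1"  # assumed server address and API prefix
benchmark_id = "my-benchmark"          # placeholder benchmark id

# Submit an eval job.
job = requests.post(
    f"{BASE_URL}/eval/benchmarks/{benchmark_id}/jobs",
    json={
        "benchmark_config": {
            "eval_candidate": {
                "type": "model",
                "model": "meta-llama/Llama3.1-8B-Instruct",
                "sampling_params": {"max_tokens": 100, "temperature": 0.7},
            }
        }
    },
).json()

# Poll the job status until it leaves "in_progress".
while True:
    status = requests.get(f"{BASE_URL}/eval/benchmarks/{benchmark_id}/jobs/{job['job_id']}").json()
    if status["status"] != "in_progress":
        break
    time.sleep(5)
```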
@ -14,10 +14,10 @@ class NVIDIAEvalConfig(BaseModel):
    Configuration for the NVIDIA NeMo Evaluator microservice endpoint.

    Attributes:
        evaluator_service_url (str): A base url for accessing the NVIDIA evaluation endpoint, e.g. http://localhost:8000.
        evaluator_url (str): A base url for accessing the NVIDIA evaluation endpoint, e.g. http://localhost:8000.
    """

    evaluator_service_url: str = Field(
    evaluator_url: str = Field(
        default_factory=lambda: os.getenv("NVIDIA_EVALUATOR_URL", "http://0.0.0.0:7331"),
        description="The url for accessing the evaluator service",
    )

@ -25,5 +25,5 @@ class NVIDIAEvalConfig(BaseModel):
    @classmethod
    def sample_run_config(cls, **kwargs) -> Dict[str, Any]:
        return {
            "evaluator_service_url": "${env.NVIDIA_EVALUATOR_URL:http://localhost:7331}",
            "evaluator_url": "${env.NVIDIA_EVALUATOR_URL:http://localhost:7331}",
        }

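With this rename, any run configuration or code that previously set `evaluator_service_url` should now set `evaluator_url`. A minimal sketch of the renamed field in use, assuming the config module path below (adjust it to wherever `NVIDIAEvalConfig` lives in your checkout) and that no other required fields have been added:

```python
import os

# Assumed import path for the NVIDIA eval provider config.
from llama_stack.providers.remote.eval.nvidia.config import NVIDIAEvalConfig

os.environ["NVIDIA_EVALUATOR_URL"] = "http://localhost:7331"

config = NVIDIAEvalConfig()        # default_factory reads NVIDIA_EVALUATOR_URL
print(config.evaluator_url)        # -> http://localhost:7331
```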
@ -53,13 +53,13 @@ class NVIDIAEvalImpl(
    async def _evaluator_get(self, path):
        """Helper for making GET requests to the evaluator service."""
        response = requests.get(url=f"{self.config.evaluator_service_url}{path}")
        response = requests.get(url=f"{self.config.evaluator_url}{path}")
        response.raise_for_status()
        return response.json()

    async def _evaluator_post(self, path, data):
        """Helper for making POST requests to the evaluator service."""
        response = requests.post(url=f"{self.config.evaluator_service_url}{path}", json=data)
        response = requests.post(url=f"{self.config.evaluator_url}{path}", json=data)
        response.raise_for_status()
        return response.json()

@ -362,6 +362,39 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
|
||||||
user: Optional[str] = None,
|
user: Optional[str] = None,
|
||||||
) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]:
|
) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]:
|
||||||
model_obj = await self.model_store.get_model(model)
|
model_obj = await self.model_store.get_model(model)
|
||||||
|
|
||||||
|
# Divert Llama Models through Llama Stack inference APIs because
|
||||||
|
# Fireworks chat completions OpenAI-compatible API does not support
|
||||||
|
# tool calls properly.
|
||||||
|
llama_model = self.get_llama_model(model_obj.provider_resource_id)
|
||||||
|
if llama_model:
|
||||||
|
return await OpenAIChatCompletionToLlamaStackMixin.openai_chat_completion(
|
||||||
|
self,
|
||||||
|
model=model,
|
||||||
|
messages=messages,
|
||||||
|
frequency_penalty=frequency_penalty,
|
||||||
|
function_call=function_call,
|
||||||
|
functions=functions,
|
||||||
|
logit_bias=logit_bias,
|
||||||
|
logprobs=logprobs,
|
||||||
|
max_completion_tokens=max_completion_tokens,
|
||||||
|
max_tokens=max_tokens,
|
||||||
|
n=n,
|
||||||
|
parallel_tool_calls=parallel_tool_calls,
|
||||||
|
presence_penalty=presence_penalty,
|
||||||
|
response_format=response_format,
|
||||||
|
seed=seed,
|
||||||
|
stop=stop,
|
||||||
|
stream=stream,
|
||||||
|
stream_options=stream_options,
|
||||||
|
temperature=temperature,
|
||||||
|
tool_choice=tool_choice,
|
||||||
|
tools=tools,
|
||||||
|
top_logprobs=top_logprobs,
|
||||||
|
top_p=top_p,
|
||||||
|
user=user,
|
||||||
|
)
|
||||||
|
|
||||||
params = await prepare_openai_completion_params(
|
params = await prepare_openai_completion_params(
|
||||||
messages=messages,
|
messages=messages,
|
||||||
frequency_penalty=frequency_penalty,
|
frequency_penalty=frequency_penalty,
|
||||||
|
@ -387,11 +420,4 @@ class FireworksInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProv
|
||||||
user=user,
|
user=user,
|
||||||
)
|
)
|
||||||
|
|
||||||
# Divert Llama Models through Llama Stack inference APIs because
|
|
||||||
# Fireworks chat completions OpenAI-compatible API does not support
|
|
||||||
# tool calls properly.
|
|
||||||
llama_model = self.get_llama_model(model_obj.provider_resource_id)
|
|
||||||
if llama_model:
|
|
||||||
return await OpenAIChatCompletionToLlamaStackMixin.openai_chat_completion(self, model=model, **params)
|
|
||||||
|
|
||||||
return await self._get_openai_client().chat.completions.create(model=model_obj.provider_resource_id, **params)
|
return await self._get_openai_client().chat.completions.create(model=model_obj.provider_resource_id, **params)
|
||||||
|
|
85
llama_stack/providers/remote/inference/nvidia/NVIDIA.md
Normal file
@ -0,0 +1,85 @@
# NVIDIA Inference Provider for LlamaStack

This provider enables running inference using NVIDIA NIM.

## Features
- Endpoints for completions, chat completions, and embeddings for registered models

## Getting Started

### Prerequisites

- LlamaStack with NVIDIA configuration
- Access to an NVIDIA NIM deployment
- A NIM deployed for the model you want to use for inference

### Setup

Build the NVIDIA environment:

```bash
llama stack build --template nvidia --image-type conda
```

### Basic Usage with the LlamaStack Python Client

#### Initialize the client

```python
import os

os.environ["NVIDIA_API_KEY"] = (
    ""  # Required if using a hosted NIM endpoint. If self-hosted, not required.
)
os.environ["NVIDIA_BASE_URL"] = "http://nim.test"  # NIM URL

from llama_stack.distribution.library_client import LlamaStackAsLibraryClient

client = LlamaStackAsLibraryClient("nvidia")
client.initialize()
```

### Create Completion

```python
response = client.completion(
    model_id="meta-llama/Llama-3.1-8b-Instruct",
    content="Complete the sentence using one word: Roses are red, violets are :",
    stream=False,
    sampling_params={
        "max_tokens": 50,
    },
)
print(f"Response: {response.content}")
```

### Create Chat Completion

```python
response = client.chat_completion(
    model_id="meta-llama/Llama-3.1-8b-Instruct",
    messages=[
        {
            "role": "system",
            "content": "You must respond to each message with only one word",
        },
        {
            "role": "user",
            "content": "Complete the sentence using one word: Roses are red, violets are:",
        },
    ],
    stream=False,
    sampling_params={
        "max_tokens": 50,
    },
)
print(f"Response: {response.completion_message.content}")
```

### Create Embeddings

```python
response = client.embeddings(
    model_id="meta-llama/Llama-3.1-8b-Instruct", contents=["foo", "bar", "baz"]
)
print(f"Embeddings: {response.embeddings}")
```

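As a quick sanity check before running the examples above, you can list the models the distribution has registered. This sketch reuses the `client` created in "Initialize the client" and assumes the standard models API of the Python client; the exact attributes and the set of models depend on your NIM deployment.

```python
# List registered models to confirm the target model id is available.
for model in client.models.list():
    print(model.identifier, model.provider_resource_id)
```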
@ -48,6 +48,10 @@ MODEL_ENTRIES = [
        "meta/llama-3.2-90b-vision-instruct",
        CoreModelId.llama3_2_90b_vision_instruct.value,
    ),
    build_hf_repo_model_entry(
        "meta/llama-3.3-70b-instruct",
        CoreModelId.llama3_3_70b_instruct.value,
    ),
    # NeMo Retriever Text Embedding models -
    #
    # https://docs.nvidia.com/nim/nemo-retriever/text-embedding/latest/support-matrix.html

@ -129,6 +129,14 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
|
||||||
base_url = special_model_urls[provider_model_id]
|
base_url = special_model_urls[provider_model_id]
|
||||||
return _get_client_for_base_url(base_url)
|
return _get_client_for_base_url(base_url)
|
||||||
|
|
||||||
|
async def _get_provider_model_id(self, model_id: str) -> str:
|
||||||
|
if not self.model_store:
|
||||||
|
raise RuntimeError("Model store is not set")
|
||||||
|
model = await self.model_store.get_model(model_id)
|
||||||
|
if model is None:
|
||||||
|
raise ValueError(f"Model {model_id} is unknown")
|
||||||
|
return model.provider_model_id
|
||||||
|
|
||||||
async def completion(
|
async def completion(
|
||||||
self,
|
self,
|
||||||
model_id: str,
|
model_id: str,
|
||||||
|
@ -147,7 +155,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
|
||||||
# removing this health check as NeMo customizer endpoint health check is returning 404
|
# removing this health check as NeMo customizer endpoint health check is returning 404
|
||||||
# await check_health(self._config) # this raises errors
|
# await check_health(self._config) # this raises errors
|
||||||
|
|
||||||
provider_model_id = self.get_provider_model_id(model_id)
|
provider_model_id = await self._get_provider_model_id(model_id)
|
||||||
request = convert_completion_request(
|
request = convert_completion_request(
|
||||||
request=CompletionRequest(
|
request=CompletionRequest(
|
||||||
model=provider_model_id,
|
model=provider_model_id,
|
||||||
|
@ -191,7 +199,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
|
||||||
#
|
#
|
||||||
flat_contents = [content.text if isinstance(content, TextContentItem) else content for content in contents]
|
flat_contents = [content.text if isinstance(content, TextContentItem) else content for content in contents]
|
||||||
input = [content.text if isinstance(content, TextContentItem) else content for content in flat_contents]
|
input = [content.text if isinstance(content, TextContentItem) else content for content in flat_contents]
|
||||||
model = self.get_provider_model_id(model_id)
|
provider_model_id = await self._get_provider_model_id(model_id)
|
||||||
|
|
||||||
extra_body = {}
|
extra_body = {}
|
||||||
|
|
||||||
|
@ -214,8 +222,8 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
|
||||||
extra_body["input_type"] = task_type_options[task_type]
|
extra_body["input_type"] = task_type_options[task_type]
|
||||||
|
|
||||||
try:
|
try:
|
||||||
response = await self._get_client(model).embeddings.create(
|
response = await self._get_client(provider_model_id).embeddings.create(
|
||||||
model=model,
|
model=provider_model_id,
|
||||||
input=input,
|
input=input,
|
||||||
extra_body=extra_body,
|
extra_body=extra_body,
|
||||||
)
|
)
|
||||||
|
@ -249,11 +257,10 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
|
||||||
|
|
||||||
# await check_health(self._config) # this raises errors
|
# await check_health(self._config) # this raises errors
|
||||||
|
|
||||||
provider_model_id = self.get_provider_model_id(model_id)
|
provider_model_id = await self._get_provider_model_id(model_id)
|
||||||
print(f"provider_model_id: {provider_model_id}")
|
|
||||||
request = await convert_chat_completion_request(
|
request = await convert_chat_completion_request(
|
||||||
request=ChatCompletionRequest(
|
request=ChatCompletionRequest(
|
||||||
model=self.get_provider_model_id(model_id),
|
model=provider_model_id,
|
||||||
messages=messages,
|
messages=messages,
|
||||||
sampling_params=sampling_params,
|
sampling_params=sampling_params,
|
||||||
response_format=response_format,
|
response_format=response_format,
|
||||||
|
@ -298,7 +305,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
|
||||||
guided_choice: Optional[List[str]] = None,
|
guided_choice: Optional[List[str]] = None,
|
||||||
prompt_logprobs: Optional[int] = None,
|
prompt_logprobs: Optional[int] = None,
|
||||||
) -> OpenAICompletion:
|
) -> OpenAICompletion:
|
||||||
provider_model_id = self.get_provider_model_id(model)
|
provider_model_id = await self._get_provider_model_id(model)
|
||||||
|
|
||||||
params = await prepare_openai_completion_params(
|
params = await prepare_openai_completion_params(
|
||||||
model=provider_model_id,
|
model=provider_model_id,
|
||||||
|
@ -351,7 +358,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
|
||||||
top_p: Optional[float] = None,
|
top_p: Optional[float] = None,
|
||||||
user: Optional[str] = None,
|
user: Optional[str] = None,
|
||||||
) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]:
|
) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]:
|
||||||
provider_model_id = self.get_provider_model_id(model)
|
provider_model_id = await self._get_provider_model_id(model)
|
||||||
|
|
||||||
params = await prepare_openai_completion_params(
|
params = await prepare_openai_completion_params(
|
||||||
model=provider_model_id,
|
model=provider_model_id,
|
||||||
|
|
|
@ -76,8 +76,11 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
    async def shutdown(self) -> None:
        if self._client:
            await self._client.close()
            # Together client has no close method, so just set to None
            self._client = None
        if self._openai_client:
            await self._openai_client.close()
            self._openai_client = None

    async def completion(
        self,

@ -359,7 +362,7 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
            top_p=top_p,
            user=user,
        )
        if params.get("stream", True):
        if params.get("stream", False):
            return self._stream_openai_chat_completion(params)
        return await self._get_openai_client().chat.completions.create(**params)  # type: ignore

@ -231,12 +231,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
        self.client = None

    async def initialize(self) -> None:
        log.info(f"Initializing VLLM client with base_url={self.config.url}")
        pass
        self.client = AsyncOpenAI(
            base_url=self.config.url,
            api_key=self.config.api_token,
            http_client=None if self.config.tls_verify else httpx.AsyncClient(verify=False),
        )

    async def shutdown(self) -> None:
        pass

@ -249,6 +244,20 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
            raise ValueError("Model store not set")
        return await self.model_store.get_model(model_id)

    def _lazy_initialize_client(self):
        if self.client is not None:
            return

        log.info(f"Initializing vLLM client with base_url={self.config.url}")
        self.client = self._create_client()

    def _create_client(self):
        return AsyncOpenAI(
            base_url=self.config.url,
            api_key=self.config.api_token,
            http_client=None if self.config.tls_verify else httpx.AsyncClient(verify=False),
        )

    async def completion(
        self,
        model_id: str,

@ -258,6 +267,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
|
||||||
stream: Optional[bool] = False,
|
stream: Optional[bool] = False,
|
||||||
logprobs: Optional[LogProbConfig] = None,
|
logprobs: Optional[LogProbConfig] = None,
|
||||||
) -> CompletionResponse | AsyncGenerator[CompletionResponseStreamChunk, None]:
|
) -> CompletionResponse | AsyncGenerator[CompletionResponseStreamChunk, None]:
|
||||||
|
self._lazy_initialize_client()
|
||||||
if sampling_params is None:
|
if sampling_params is None:
|
||||||
sampling_params = SamplingParams()
|
sampling_params = SamplingParams()
|
||||||
model = await self._get_model(model_id)
|
model = await self._get_model(model_id)
|
||||||
|
@ -287,6 +297,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
|
||||||
logprobs: Optional[LogProbConfig] = None,
|
logprobs: Optional[LogProbConfig] = None,
|
||||||
tool_config: Optional[ToolConfig] = None,
|
tool_config: Optional[ToolConfig] = None,
|
||||||
) -> ChatCompletionResponse | AsyncGenerator[ChatCompletionResponseStreamChunk, None]:
|
) -> ChatCompletionResponse | AsyncGenerator[ChatCompletionResponseStreamChunk, None]:
|
||||||
|
self._lazy_initialize_client()
|
||||||
if sampling_params is None:
|
if sampling_params is None:
|
||||||
sampling_params = SamplingParams()
|
sampling_params = SamplingParams()
|
||||||
model = await self._get_model(model_id)
|
model = await self._get_model(model_id)
|
||||||
|
@ -357,9 +368,12 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
|
||||||
yield chunk
|
yield chunk
|
||||||
|
|
||||||
async def register_model(self, model: Model) -> Model:
|
async def register_model(self, model: Model) -> Model:
|
||||||
assert self.client is not None
|
# register_model is called during Llama Stack initialization, so self.client may not have been created yet.
|
||||||
|
# self.client should only be created after the initialization is complete to avoid asyncio cross-context errors.
|
||||||
|
# Changing this may lead to unpredictable behavior.
|
||||||
|
client = self._create_client() if self.client is None else self.client
|
||||||
model = await self.register_helper.register_model(model)
|
model = await self.register_helper.register_model(model)
|
||||||
res = await self.client.models.list()
|
res = await client.models.list()
|
||||||
available_models = [m.id async for m in res]
|
available_models = [m.id async for m in res]
|
||||||
if model.provider_resource_id not in available_models:
|
if model.provider_resource_id not in available_models:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
|
@ -374,7 +388,8 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
|
||||||
options["max_tokens"] = self.config.max_tokens
|
options["max_tokens"] = self.config.max_tokens
|
||||||
|
|
||||||
input_dict: dict[str, Any] = {}
|
input_dict: dict[str, Any] = {}
|
||||||
if isinstance(request, ChatCompletionRequest) and request.tools is not None:
|
# Only include the 'tools' param if there are any; sending an empty list can break vLLM.
|
||||||
|
if isinstance(request, ChatCompletionRequest) and request.tools:
|
||||||
input_dict = {"tools": _convert_to_vllm_tools_in_request(request.tools)}
|
input_dict = {"tools": _convert_to_vllm_tools_in_request(request.tools)}
|
||||||
|
|
||||||
if isinstance(request, ChatCompletionRequest):
|
if isinstance(request, ChatCompletionRequest):
|
||||||
|
@ -409,6 +424,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
|
||||||
output_dimension: Optional[int] = None,
|
output_dimension: Optional[int] = None,
|
||||||
task_type: Optional[EmbeddingTaskType] = None,
|
task_type: Optional[EmbeddingTaskType] = None,
|
||||||
) -> EmbeddingsResponse:
|
) -> EmbeddingsResponse:
|
||||||
|
self._lazy_initialize_client()
|
||||||
assert self.client is not None
|
assert self.client is not None
|
||||||
model = await self._get_model(model_id)
|
model = await self._get_model(model_id)
|
||||||
|
|
||||||
|
@ -448,6 +464,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
|
||||||
guided_choice: Optional[List[str]] = None,
|
guided_choice: Optional[List[str]] = None,
|
||||||
prompt_logprobs: Optional[int] = None,
|
prompt_logprobs: Optional[int] = None,
|
||||||
) -> OpenAICompletion:
|
) -> OpenAICompletion:
|
||||||
|
self._lazy_initialize_client()
|
||||||
model_obj = await self._get_model(model)
|
model_obj = await self._get_model(model)
|
||||||
|
|
||||||
extra_body: Dict[str, Any] = {}
|
extra_body: Dict[str, Any] = {}
|
||||||
|
@ -504,6 +521,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
|
||||||
top_p: Optional[float] = None,
|
top_p: Optional[float] = None,
|
||||||
user: Optional[str] = None,
|
user: Optional[str] = None,
|
||||||
) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]:
|
) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]:
|
||||||
|
self._lazy_initialize_client()
|
||||||
model_obj = await self._get_model(model)
|
model_obj = await self._get_model(model)
|
||||||
params = await prepare_openai_completion_params(
|
params = await prepare_openai_completion_params(
|
||||||
model=model_obj.provider_resource_id,
|
model=model_obj.provider_resource_id,
|
||||||
|
|
22
llama_stack/providers/remote/inference/watsonx/__init__.py
Normal file
22
llama_stack/providers/remote/inference/watsonx/__init__.py
Normal file
|
@ -0,0 +1,22 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
from llama_stack.apis.inference import Inference
|
||||||
|
|
||||||
|
from .config import WatsonXConfig
|
||||||
|
|
||||||
|
|
||||||
|
async def get_adapter_impl(config: WatsonXConfig, _deps) -> Inference:
|
||||||
|
# import dynamically so `llama stack build` does not fail due to missing dependencies
|
||||||
|
from .watsonx import WatsonXInferenceAdapter
|
||||||
|
|
||||||
|
if not isinstance(config, WatsonXConfig):
|
||||||
|
raise RuntimeError(f"Unexpected config type: {type(config)}")
|
||||||
|
adapter = WatsonXInferenceAdapter(config)
|
||||||
|
return adapter
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ["get_adapter_impl", "WatsonXConfig"]
|
46
llama_stack/providers/remote/inference/watsonx/config.py
Normal file
46
llama_stack/providers/remote/inference/watsonx/config.py
Normal file
|
@ -0,0 +1,46 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
import os
|
||||||
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field, SecretStr
|
||||||
|
|
||||||
|
from llama_stack.schema_utils import json_schema_type
|
||||||
|
|
||||||
|
|
||||||
|
class WatsonXProviderDataValidator(BaseModel):
|
||||||
|
url: str
|
||||||
|
api_key: str
|
||||||
|
project_id: str
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class WatsonXConfig(BaseModel):
|
||||||
|
url: str = Field(
|
||||||
|
default_factory=lambda: os.getenv("WATSONX_BASE_URL", "https://us-south.ml.cloud.ibm.com"),
|
||||||
|
description="A base url for accessing the watsonx.ai",
|
||||||
|
)
|
||||||
|
api_key: Optional[SecretStr] = Field(
|
||||||
|
default_factory=lambda: os.getenv("WATSONX_API_KEY"),
|
||||||
|
description="The watsonx API key, only needed of using the hosted service",
|
||||||
|
)
|
||||||
|
project_id: Optional[str] = Field(
|
||||||
|
default_factory=lambda: os.getenv("WATSONX_PROJECT_ID"),
|
||||||
|
description="The Project ID key, only needed of using the hosted service",
|
||||||
|
)
|
||||||
|
timeout: int = Field(
|
||||||
|
default=60,
|
||||||
|
description="Timeout for the HTTP requests",
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def sample_run_config(cls, **kwargs) -> Dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"url": "${env.WATSONX_BASE_URL:https://us-south.ml.cloud.ibm.com}",
|
||||||
|
"api_key": "${env.WATSONX_API_KEY:}",
|
||||||
|
"project_id": "${env.WATSONX_PROJECT_ID:}",
|
||||||
|
}
|
47
llama_stack/providers/remote/inference/watsonx/models.py
Normal file
47
llama_stack/providers/remote/inference/watsonx/models.py
Normal file
|
@ -0,0 +1,47 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
from llama_stack.models.llama.sku_types import CoreModelId
|
||||||
|
from llama_stack.providers.utils.inference.model_registry import build_hf_repo_model_entry
|
||||||
|
|
||||||
|
MODEL_ENTRIES = [
|
||||||
|
build_hf_repo_model_entry(
|
||||||
|
"meta-llama/llama-3-3-70b-instruct",
|
||||||
|
CoreModelId.llama3_3_70b_instruct.value,
|
||||||
|
),
|
||||||
|
build_hf_repo_model_entry(
|
||||||
|
"meta-llama/llama-2-13b-chat",
|
||||||
|
CoreModelId.llama2_13b.value,
|
||||||
|
),
|
||||||
|
build_hf_repo_model_entry(
|
||||||
|
"meta-llama/llama-3-1-70b-instruct",
|
||||||
|
CoreModelId.llama3_1_70b_instruct.value,
|
||||||
|
),
|
||||||
|
build_hf_repo_model_entry(
|
||||||
|
"meta-llama/llama-3-1-8b-instruct",
|
||||||
|
CoreModelId.llama3_1_8b_instruct.value,
|
||||||
|
),
|
||||||
|
build_hf_repo_model_entry(
|
||||||
|
"meta-llama/llama-3-2-11b-vision-instruct",
|
||||||
|
CoreModelId.llama3_2_11b_vision_instruct.value,
|
||||||
|
),
|
||||||
|
build_hf_repo_model_entry(
|
||||||
|
"meta-llama/llama-3-2-1b-instruct",
|
||||||
|
CoreModelId.llama3_2_1b_instruct.value,
|
||||||
|
),
|
||||||
|
build_hf_repo_model_entry(
|
||||||
|
"meta-llama/llama-3-2-3b-instruct",
|
||||||
|
CoreModelId.llama3_2_3b_instruct.value,
|
||||||
|
),
|
||||||
|
build_hf_repo_model_entry(
|
||||||
|
"meta-llama/llama-3-2-90b-vision-instruct",
|
||||||
|
CoreModelId.llama3_2_90b_vision_instruct.value,
|
||||||
|
),
|
||||||
|
build_hf_repo_model_entry(
|
||||||
|
"meta-llama/llama-guard-3-11b-vision",
|
||||||
|
CoreModelId.llama_guard_3_11b_vision.value,
|
||||||
|
),
|
||||||
|
]
|
378
llama_stack/providers/remote/inference/watsonx/watsonx.py
Normal file
378
llama_stack/providers/remote/inference/watsonx/watsonx.py
Normal file
|
@ -0,0 +1,378 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Union
|
||||||
|
|
||||||
|
from ibm_watson_machine_learning.foundation_models import Model
|
||||||
|
from ibm_watson_machine_learning.metanames import GenTextParamsMetaNames as GenParams
|
||||||
|
from openai import AsyncOpenAI
|
||||||
|
|
||||||
|
from llama_stack.apis.common.content_types import InterleavedContent, InterleavedContentItem
|
||||||
|
from llama_stack.apis.inference import (
|
||||||
|
ChatCompletionRequest,
|
||||||
|
ChatCompletionResponse,
|
||||||
|
CompletionRequest,
|
||||||
|
EmbeddingsResponse,
|
||||||
|
EmbeddingTaskType,
|
||||||
|
Inference,
|
||||||
|
LogProbConfig,
|
||||||
|
Message,
|
||||||
|
ResponseFormat,
|
||||||
|
SamplingParams,
|
||||||
|
TextTruncation,
|
||||||
|
ToolChoice,
|
||||||
|
ToolConfig,
|
||||||
|
ToolDefinition,
|
||||||
|
ToolPromptFormat,
|
||||||
|
)
|
||||||
|
from llama_stack.apis.inference.inference import (
|
||||||
|
GreedySamplingStrategy,
|
||||||
|
OpenAIChatCompletion,
|
||||||
|
OpenAIChatCompletionChunk,
|
||||||
|
OpenAICompletion,
|
||||||
|
OpenAIMessageParam,
|
||||||
|
OpenAIResponseFormatParam,
|
||||||
|
TopKSamplingStrategy,
|
||||||
|
TopPSamplingStrategy,
|
||||||
|
)
|
||||||
|
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
|
||||||
|
from llama_stack.providers.utils.inference.openai_compat import (
|
||||||
|
OpenAICompatCompletionChoice,
|
||||||
|
OpenAICompatCompletionResponse,
|
||||||
|
prepare_openai_completion_params,
|
||||||
|
process_chat_completion_response,
|
||||||
|
process_chat_completion_stream_response,
|
||||||
|
process_completion_response,
|
||||||
|
process_completion_stream_response,
|
||||||
|
)
|
||||||
|
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||||
|
chat_completion_request_to_prompt,
|
||||||
|
completion_request_to_prompt,
|
||||||
|
request_has_media,
|
||||||
|
)
|
||||||
|
|
||||||
|
from . import WatsonXConfig
|
||||||
|
from .models import MODEL_ENTRIES
|
||||||
|
|
||||||
|
|
||||||
|
class WatsonXInferenceAdapter(Inference, ModelRegistryHelper):
|
||||||
|
def __init__(self, config: WatsonXConfig) -> None:
|
||||||
|
ModelRegistryHelper.__init__(self, MODEL_ENTRIES)
|
||||||
|
|
||||||
|
print(f"Initializing watsonx InferenceAdapter({config.url})...")
|
||||||
|
|
||||||
|
self._config = config
|
||||||
|
|
||||||
|
self._project_id = self._config.project_id
|
||||||
|
|
||||||
|
async def initialize(self) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def shutdown(self) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def completion(
|
||||||
|
self,
|
||||||
|
model_id: str,
|
||||||
|
content: InterleavedContent,
|
||||||
|
sampling_params: Optional[SamplingParams] = None,
|
||||||
|
response_format: Optional[ResponseFormat] = None,
|
||||||
|
stream: Optional[bool] = False,
|
||||||
|
logprobs: Optional[LogProbConfig] = None,
|
||||||
|
) -> AsyncGenerator:
|
||||||
|
if sampling_params is None:
|
||||||
|
sampling_params = SamplingParams()
|
||||||
|
model = await self.model_store.get_model(model_id)
|
||||||
|
request = CompletionRequest(
|
||||||
|
model=model.provider_resource_id,
|
||||||
|
content=content,
|
||||||
|
sampling_params=sampling_params,
|
||||||
|
response_format=response_format,
|
||||||
|
stream=stream,
|
||||||
|
logprobs=logprobs,
|
||||||
|
)
|
||||||
|
if stream:
|
||||||
|
return self._stream_completion(request)
|
||||||
|
else:
|
||||||
|
return await self._nonstream_completion(request)
|
||||||
|
|
||||||
|
def _get_client(self, model_id) -> Model:
|
||||||
|
config_api_key = self._config.api_key.get_secret_value() if self._config.api_key else None
|
||||||
|
config_url = self._config.url
|
||||||
|
project_id = self._config.project_id
|
||||||
|
credentials = {"url": config_url, "apikey": config_api_key}
|
||||||
|
|
||||||
|
return Model(model_id=model_id, credentials=credentials, project_id=project_id)
|
||||||
|
|
||||||
|
def _get_openai_client(self) -> AsyncOpenAI:
|
||||||
|
if not self._openai_client:
|
||||||
|
self._openai_client = AsyncOpenAI(
|
||||||
|
base_url=f"{self._config.url}/openai/v1",
|
||||||
|
api_key=self._config.api_key,
|
||||||
|
)
|
||||||
|
return self._openai_client
|
||||||
|
|
||||||
|
async def _nonstream_completion(self, request: CompletionRequest) -> ChatCompletionResponse:
|
||||||
|
params = await self._get_params(request)
|
||||||
|
r = self._get_client(request.model).generate(**params)
|
||||||
|
choices = []
|
||||||
|
if "results" in r:
|
||||||
|
for result in r["results"]:
|
||||||
|
choice = OpenAICompatCompletionChoice(
|
||||||
|
finish_reason=result["stop_reason"] if result["stop_reason"] else None,
|
||||||
|
text=result["generated_text"],
|
||||||
|
)
|
||||||
|
choices.append(choice)
|
||||||
|
response = OpenAICompatCompletionResponse(
|
||||||
|
choices=choices,
|
||||||
|
)
|
||||||
|
return process_completion_response(response)
|
||||||
|
|
||||||
|
async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
|
||||||
|
params = await self._get_params(request)
|
||||||
|
|
||||||
|
async def _generate_and_convert_to_openai_compat():
|
||||||
|
s = self._get_client(request.model).generate_text_stream(**params)
|
||||||
|
for chunk in s:
|
||||||
|
choice = OpenAICompatCompletionChoice(
|
||||||
|
finish_reason=None,
|
||||||
|
text=chunk,
|
||||||
|
)
|
||||||
|
yield OpenAICompatCompletionResponse(
|
||||||
|
choices=[choice],
|
||||||
|
)
|
||||||
|
|
||||||
|
stream = _generate_and_convert_to_openai_compat()
|
||||||
|
async for chunk in process_completion_stream_response(stream):
|
||||||
|
yield chunk
|
||||||
|
|
||||||
|
async def chat_completion(
|
||||||
|
self,
|
||||||
|
model_id: str,
|
||||||
|
messages: List[Message],
|
||||||
|
sampling_params: Optional[SamplingParams] = None,
|
||||||
|
tools: Optional[List[ToolDefinition]] = None,
|
||||||
|
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
|
||||||
|
tool_prompt_format: Optional[ToolPromptFormat] = None,
|
||||||
|
response_format: Optional[ResponseFormat] = None,
|
||||||
|
stream: Optional[bool] = False,
|
||||||
|
logprobs: Optional[LogProbConfig] = None,
|
||||||
|
tool_config: Optional[ToolConfig] = None,
|
||||||
|
) -> AsyncGenerator:
|
||||||
|
if sampling_params is None:
|
||||||
|
sampling_params = SamplingParams()
|
||||||
|
model = await self.model_store.get_model(model_id)
|
||||||
|
request = ChatCompletionRequest(
|
||||||
|
model=model.provider_resource_id,
|
||||||
|
messages=messages,
|
||||||
|
sampling_params=sampling_params,
|
||||||
|
tools=tools or [],
|
||||||
|
response_format=response_format,
|
||||||
|
stream=stream,
|
||||||
|
logprobs=logprobs,
|
||||||
|
tool_config=tool_config,
|
||||||
|
)
|
||||||
|
|
||||||
|
if stream:
|
||||||
|
return self._stream_chat_completion(request)
|
||||||
|
else:
|
||||||
|
return await self._nonstream_chat_completion(request)
|
||||||
|
|
||||||
|
async def _nonstream_chat_completion(self, request: ChatCompletionRequest) -> ChatCompletionResponse:
|
||||||
|
params = await self._get_params(request)
|
||||||
|
r = self._get_client(request.model).generate(**params)
|
||||||
|
choices = []
|
||||||
|
if "results" in r:
|
||||||
|
for result in r["results"]:
|
||||||
|
choice = OpenAICompatCompletionChoice(
|
||||||
|
finish_reason=result["stop_reason"] if result["stop_reason"] else None,
|
||||||
|
text=result["generated_text"],
|
||||||
|
)
|
||||||
|
choices.append(choice)
|
||||||
|
response = OpenAICompatCompletionResponse(
|
||||||
|
choices=choices,
|
||||||
|
)
|
||||||
|
return process_chat_completion_response(response, request)
|
||||||
|
|
||||||
|
async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
|
||||||
|
params = await self._get_params(request)
|
||||||
|
model_id = request.model
|
||||||
|
|
||||||
|
# wrap the synchronous watsonx streaming generator in an async generator
|
||||||
|
async def _to_async_generator():
|
||||||
|
s = self._get_client(model_id).generate_text_stream(**params)
|
||||||
|
for chunk in s:
|
||||||
|
choice = OpenAICompatCompletionChoice(
|
||||||
|
finish_reason=None,
|
||||||
|
text=chunk,
|
||||||
|
)
|
||||||
|
yield OpenAICompatCompletionResponse(
|
||||||
|
choices=[choice],
|
||||||
|
)
|
||||||
|
|
||||||
|
stream = _to_async_generator()
|
||||||
|
async for chunk in process_chat_completion_stream_response(stream, request):
|
||||||
|
yield chunk
|
||||||
|
|
||||||
|
async def _get_params(self, request: Union[ChatCompletionRequest, CompletionRequest]) -> dict:
|
||||||
|
input_dict = {"params": {}}
|
||||||
|
media_present = request_has_media(request)
|
||||||
|
llama_model = self.get_llama_model(request.model)
|
||||||
|
if isinstance(request, ChatCompletionRequest):
|
||||||
|
input_dict["prompt"] = await chat_completion_request_to_prompt(request, llama_model)
|
||||||
|
else:
|
||||||
|
assert not media_present, "Together does not support media for Completion requests"
|
||||||
|
input_dict["prompt"] = await completion_request_to_prompt(request)
|
||||||
|
if request.sampling_params:
|
||||||
|
if request.sampling_params.strategy:
|
||||||
|
input_dict["params"][GenParams.DECODING_METHOD] = request.sampling_params.strategy.type
|
||||||
|
if request.sampling_params.max_tokens:
|
||||||
|
input_dict["params"][GenParams.MAX_NEW_TOKENS] = request.sampling_params.max_tokens
|
||||||
|
if request.sampling_params.repetition_penalty:
|
||||||
|
input_dict["params"][GenParams.REPETITION_PENALTY] = request.sampling_params.repetition_penalty
|
||||||
|
|
||||||
|
if isinstance(request.sampling_params.strategy, TopPSamplingStrategy):
|
||||||
|
input_dict["params"][GenParams.TOP_P] = request.sampling_params.strategy.top_p
|
||||||
|
input_dict["params"][GenParams.TEMPERATURE] = request.sampling_params.strategy.temperature
|
||||||
|
if isinstance(request.sampling_params.strategy, TopKSamplingStrategy):
|
||||||
|
input_dict["params"][GenParams.TOP_K] = request.sampling_params.strategy.top_k
|
||||||
|
if isinstance(request.sampling_params.strategy, GreedySamplingStrategy):
|
||||||
|
input_dict["params"][GenParams.TEMPERATURE] = 0.0
|
||||||
|
|
||||||
|
input_dict["params"][GenParams.STOP_SEQUENCES] = ["<|endoftext|>"]
|
||||||
|
|
||||||
|
params = {
|
||||||
|
**input_dict,
|
||||||
|
}
|
||||||
|
return params
|
||||||
|
|
||||||
|
async def embeddings(
|
||||||
|
self,
|
||||||
|
model_id: str,
|
||||||
|
contents: List[str] | List[InterleavedContentItem],
|
||||||
|
text_truncation: Optional[TextTruncation] = TextTruncation.none,
|
||||||
|
output_dimension: Optional[int] = None,
|
||||||
|
task_type: Optional[EmbeddingTaskType] = None,
|
||||||
|
) -> EmbeddingsResponse:
|
||||||
|
raise NotImplementedError("embedding is not supported for watsonx")
|
||||||
|
|
||||||
|
async def openai_completion(
|
||||||
|
self,
|
||||||
|
model: str,
|
||||||
|
prompt: Union[str, List[str], List[int], List[List[int]]],
|
||||||
|
best_of: Optional[int] = None,
|
||||||
|
echo: Optional[bool] = None,
|
||||||
|
frequency_penalty: Optional[float] = None,
|
||||||
|
logit_bias: Optional[Dict[str, float]] = None,
|
||||||
|
logprobs: Optional[bool] = None,
|
||||||
|
max_tokens: Optional[int] = None,
|
||||||
|
n: Optional[int] = None,
|
||||||
|
presence_penalty: Optional[float] = None,
|
||||||
|
seed: Optional[int] = None,
|
||||||
|
stop: Optional[Union[str, List[str]]] = None,
|
||||||
|
stream: Optional[bool] = None,
|
||||||
|
stream_options: Optional[Dict[str, Any]] = None,
|
||||||
|
temperature: Optional[float] = None,
|
||||||
|
top_p: Optional[float] = None,
|
||||||
|
user: Optional[str] = None,
|
||||||
|
guided_choice: Optional[List[str]] = None,
|
||||||
|
prompt_logprobs: Optional[int] = None,
|
||||||
|
) -> OpenAICompletion:
|
||||||
|
model_obj = await self.model_store.get_model(model)
|
||||||
|
params = await prepare_openai_completion_params(
|
||||||
|
model=model_obj.provider_resource_id,
|
||||||
|
prompt=prompt,
|
||||||
|
best_of=best_of,
|
||||||
|
echo=echo,
|
||||||
|
frequency_penalty=frequency_penalty,
|
||||||
|
logit_bias=logit_bias,
|
||||||
|
logprobs=logprobs,
|
||||||
|
max_tokens=max_tokens,
|
||||||
|
n=n,
|
||||||
|
presence_penalty=presence_penalty,
|
||||||
|
seed=seed,
|
||||||
|
stop=stop,
|
||||||
|
stream=stream,
|
||||||
|
stream_options=stream_options,
|
||||||
|
temperature=temperature,
|
||||||
|
top_p=top_p,
|
||||||
|
user=user,
|
||||||
|
)
|
||||||
|
return await self._get_openai_client().completions.create(**params) # type: ignore
|
||||||
|
|
||||||
|
async def openai_chat_completion(
|
||||||
|
self,
|
||||||
|
model: str,
|
||||||
|
messages: List[OpenAIMessageParam],
|
||||||
|
frequency_penalty: Optional[float] = None,
|
||||||
|
function_call: Optional[Union[str, Dict[str, Any]]] = None,
|
||||||
|
functions: Optional[List[Dict[str, Any]]] = None,
|
||||||
|
logit_bias: Optional[Dict[str, float]] = None,
|
||||||
|
logprobs: Optional[bool] = None,
|
||||||
|
max_completion_tokens: Optional[int] = None,
|
||||||
|
max_tokens: Optional[int] = None,
|
||||||
|
n: Optional[int] = None,
|
||||||
|
parallel_tool_calls: Optional[bool] = None,
|
||||||
|
presence_penalty: Optional[float] = None,
|
||||||
|
response_format: Optional[OpenAIResponseFormatParam] = None,
|
||||||
|
seed: Optional[int] = None,
|
||||||
|
stop: Optional[Union[str, List[str]]] = None,
|
||||||
|
stream: Optional[bool] = None,
|
||||||
|
stream_options: Optional[Dict[str, Any]] = None,
|
||||||
|
temperature: Optional[float] = None,
|
||||||
|
tool_choice: Optional[Union[str, Dict[str, Any]]] = None,
|
||||||
|
tools: Optional[List[Dict[str, Any]]] = None,
|
||||||
|
top_logprobs: Optional[int] = None,
|
||||||
|
top_p: Optional[float] = None,
|
||||||
|
user: Optional[str] = None,
|
||||||
|
) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]:
|
||||||
|
model_obj = await self.model_store.get_model(model)
|
||||||
|
params = await prepare_openai_completion_params(
|
||||||
|
model=model_obj.provider_resource_id,
|
||||||
|
messages=messages,
|
||||||
|
frequency_penalty=frequency_penalty,
|
||||||
|
function_call=function_call,
|
||||||
|
functions=functions,
|
||||||
|
logit_bias=logit_bias,
|
||||||
|
logprobs=logprobs,
|
||||||
|
max_completion_tokens=max_completion_tokens,
|
||||||
|
max_tokens=max_tokens,
|
||||||
|
n=n,
|
||||||
|
parallel_tool_calls=parallel_tool_calls,
|
||||||
|
presence_penalty=presence_penalty,
|
||||||
|
response_format=response_format,
|
||||||
|
seed=seed,
|
||||||
|
stop=stop,
|
||||||
|
stream=stream,
|
||||||
|
stream_options=stream_options,
|
||||||
|
temperature=temperature,
|
||||||
|
tool_choice=tool_choice,
|
||||||
|
tools=tools,
|
||||||
|
top_logprobs=top_logprobs,
|
||||||
|
top_p=top_p,
|
||||||
|
user=user,
|
||||||
|
)
|
||||||
|
if params.get("stream", False):
|
||||||
|
return self._stream_openai_chat_completion(params)
|
||||||
|
return await self._get_openai_client().chat.completions.create(**params) # type: ignore
|
||||||
|
|
||||||
|
async def _stream_openai_chat_completion(self, params: dict) -> AsyncGenerator:
|
||||||
|
# watsonx.ai sometimes adds usage data to the stream
|
||||||
|
include_usage = False
|
||||||
|
if params.get("stream_options", None):
|
||||||
|
include_usage = params["stream_options"].get("include_usage", False)
|
||||||
|
stream = await self._get_openai_client().chat.completions.create(**params)
|
||||||
|
|
||||||
|
seen_finish_reason = False
|
||||||
|
async for chunk in stream:
|
||||||
|
# Discard the final usage chunk (it has no choices) if the user did not request usage data
|
||||||
|
if not include_usage and seen_finish_reason and len(chunk.choices) == 0:
|
||||||
|
break
|
||||||
|
yield chunk
|
||||||
|
for choice in chunk.choices:
|
||||||
|
if choice.finish_reason:
|
||||||
|
seen_finish_reason = True
|
||||||
|
break
|
|
@ -36,7 +36,6 @@ import os
|
||||||
|
|
||||||
os.environ["NVIDIA_API_KEY"] = "your-api-key"
|
os.environ["NVIDIA_API_KEY"] = "your-api-key"
|
||||||
os.environ["NVIDIA_CUSTOMIZER_URL"] = "http://nemo.test"
|
os.environ["NVIDIA_CUSTOMIZER_URL"] = "http://nemo.test"
|
||||||
os.environ["NVIDIA_USER_ID"] = "llama-stack-user"
|
|
||||||
os.environ["NVIDIA_DATASET_NAMESPACE"] = "default"
|
os.environ["NVIDIA_DATASET_NAMESPACE"] = "default"
|
||||||
os.environ["NVIDIA_PROJECT_ID"] = "test-project"
|
os.environ["NVIDIA_PROJECT_ID"] = "test-project"
|
||||||
os.environ["NVIDIA_OUTPUT_MODEL_DIR"] = "test-example-model@v1"
|
os.environ["NVIDIA_OUTPUT_MODEL_DIR"] = "test-example-model@v1"
|
||||||
|
@ -128,13 +127,14 @@ client.post_training.job.cancel(job_uuid="your-job-id")
|
||||||
#### 1. Register the model
|
#### 1. Register the model
|
||||||
|
|
||||||
```python
|
```python
|
||||||
model = Model(
|
from llama_stack.apis.models import Model, ModelType
|
||||||
identifier="test-example-model@v1",
|
|
||||||
|
client.models.register(
|
||||||
|
model_id="test-example-model@v1",
|
||||||
provider_id="nvidia",
|
provider_id="nvidia",
|
||||||
provider_model_id="test-example-model@v1",
|
provider_model_id="test-example-model@v1",
|
||||||
model_type=ModelType.llm,
|
model_type=ModelType.llm,
|
||||||
)
|
)
|
||||||
client.register_model(model)
|
|
||||||
```
|
```
|
||||||
|
|
||||||
#### 2. Inference with the fine-tuned model
|
#### 2. Inference with the fine-tuned model
|
||||||
|
|
|
@ -16,7 +16,11 @@ _MODEL_ENTRIES = [
|
||||||
build_hf_repo_model_entry(
|
build_hf_repo_model_entry(
|
||||||
"meta/llama-3.1-8b-instruct",
|
"meta/llama-3.1-8b-instruct",
|
||||||
CoreModelId.llama3_1_8b_instruct.value,
|
CoreModelId.llama3_1_8b_instruct.value,
|
||||||
)
|
),
|
||||||
|
build_hf_repo_model_entry(
|
||||||
|
"meta/llama-3.2-1b-instruct",
|
||||||
|
CoreModelId.llama3_2_1b_instruct.value,
|
||||||
|
),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -67,13 +67,18 @@ class NvidiaPostTrainingAdapter(ModelRegistryHelper):
|
||||||
self.timeout = aiohttp.ClientTimeout(total=config.timeout)
|
self.timeout = aiohttp.ClientTimeout(total=config.timeout)
|
||||||
# TODO: filter by available models based on /config endpoint
|
# TODO: filter by available models based on /config endpoint
|
||||||
ModelRegistryHelper.__init__(self, model_entries=_MODEL_ENTRIES)
|
ModelRegistryHelper.__init__(self, model_entries=_MODEL_ENTRIES)
|
||||||
self.session = aiohttp.ClientSession(headers=self.headers, timeout=self.timeout)
|
self.session = None
|
||||||
self.customizer_url = config.customizer_url
|
|
||||||
|
|
||||||
|
self.customizer_url = config.customizer_url
|
||||||
if not self.customizer_url:
|
if not self.customizer_url:
|
||||||
warnings.warn("Customizer URL is not set, using default value: http://nemo.test", stacklevel=2)
|
warnings.warn("Customizer URL is not set, using default value: http://nemo.test", stacklevel=2)
|
||||||
self.customizer_url = "http://nemo.test"
|
self.customizer_url = "http://nemo.test"
|
||||||
|
|
||||||
|
async def _get_session(self) -> aiohttp.ClientSession:
|
||||||
|
if self.session is None or self.session.closed:
|
||||||
|
self.session = aiohttp.ClientSession(headers=self.headers, timeout=self.timeout)
|
||||||
|
return self.session
|
||||||
|
|
||||||
async def _make_request(
|
async def _make_request(
|
||||||
self,
|
self,
|
||||||
method: str,
|
method: str,
|
||||||
|
@ -94,11 +99,9 @@ class NvidiaPostTrainingAdapter(ModelRegistryHelper):
|
||||||
if json and "Content-Type" not in request_headers:
|
if json and "Content-Type" not in request_headers:
|
||||||
request_headers["Content-Type"] = "application/json"
|
request_headers["Content-Type"] = "application/json"
|
||||||
|
|
||||||
|
session = await self._get_session()
|
||||||
for _ in range(self.config.max_retries):
|
for _ in range(self.config.max_retries):
|
||||||
# TODO: Remove `verify_ssl=False`. Added for testing purposes to call NMP int environment from `docs/notebooks/nvidia/`
|
async with session.request(method, url, params=params, json=json, **kwargs) as response:
|
||||||
async with self.session.request(
|
|
||||||
method, url, params=params, json=json, verify_ssl=False, **kwargs
|
|
||||||
) as response:
|
|
||||||
if response.status >= 400:
|
if response.status >= 400:
|
||||||
error_data = await response.json()
|
error_data = await response.json()
|
||||||
raise Exception(f"API request failed: {error_data}")
|
raise Exception(f"API request failed: {error_data}")
|
||||||
|
@ -125,8 +128,8 @@ class NvidiaPostTrainingAdapter(ModelRegistryHelper):
|
||||||
jobs = []
|
jobs = []
|
||||||
for job in response.get("data", []):
|
for job in response.get("data", []):
|
||||||
job_id = job.pop("id")
|
job_id = job.pop("id")
|
||||||
job_status = job.pop("status", "unknown").lower()
|
job_status = job.pop("status", "scheduled").lower()
|
||||||
mapped_status = STATUS_MAPPING.get(job_status, "unknown")
|
mapped_status = STATUS_MAPPING.get(job_status, "scheduled")
|
||||||
|
|
||||||
# Convert string timestamps to datetime objects
|
# Convert string timestamps to datetime objects
|
||||||
created_at = (
|
created_at = (
|
||||||
|
@ -180,7 +183,7 @@ class NvidiaPostTrainingAdapter(ModelRegistryHelper):
|
||||||
)
|
)
|
||||||
|
|
||||||
api_status = response.pop("status").lower()
|
api_status = response.pop("status").lower()
|
||||||
mapped_status = STATUS_MAPPING.get(api_status, "unknown")
|
mapped_status = STATUS_MAPPING.get(api_status, "scheduled")
|
||||||
|
|
||||||
return NvidiaPostTrainingJobStatusResponse(
|
return NvidiaPostTrainingJobStatusResponse(
|
||||||
status=JobStatus(mapped_status),
|
status=JobStatus(mapped_status),
|
||||||
|
@ -242,6 +245,7 @@ class NvidiaPostTrainingAdapter(ModelRegistryHelper):
|
||||||
|
|
||||||
Supported models:
|
Supported models:
|
||||||
- meta/llama-3.1-8b-instruct
|
- meta/llama-3.1-8b-instruct
|
||||||
|
- meta/llama-3.2-1b-instruct
|
||||||
|
|
||||||
Supported algorithm configs:
|
Supported algorithm configs:
|
||||||
- LoRA, SFT
|
- LoRA, SFT
|
||||||
|
@ -287,10 +291,6 @@ class NvidiaPostTrainingAdapter(ModelRegistryHelper):
|
||||||
|
|
||||||
- LoRA config:
|
- LoRA config:
|
||||||
## NeMo customizer specific LoRA parameters
|
## NeMo customizer specific LoRA parameters
|
||||||
- adapter_dim: int - Adapter dimension
|
|
||||||
Default: 8 (supports powers of 2)
|
|
||||||
- adapter_dropout: float - Adapter dropout
|
|
||||||
Default: None (0.0-1.0)
|
|
||||||
- alpha: int - Scaling factor for the LoRA update
|
- alpha: int - Scaling factor for the LoRA update
|
||||||
Default: 16
|
Default: 16
|
||||||
Note:
|
Note:
|
||||||
|
@ -300,7 +300,7 @@ class NvidiaPostTrainingAdapter(ModelRegistryHelper):
|
||||||
User is informed about unsupported parameters via warnings.
|
User is informed about unsupported parameters via warnings.
|
||||||
"""
|
"""
|
||||||
# Map model to nvidia model name
|
# Map model to nvidia model name
|
||||||
# ToDo: only supports llama-3.1-8b-instruct now, need to update this to support other models
|
# See `_MODEL_ENTRIES` for supported models
|
||||||
nvidia_model = self.get_provider_model_id(model)
|
nvidia_model = self.get_provider_model_id(model)
|
||||||
|
|
||||||
# Check for unsupported method parameters
|
# Check for unsupported method parameters
|
||||||
|
@ -333,7 +333,7 @@ class NvidiaPostTrainingAdapter(ModelRegistryHelper):
|
||||||
},
|
},
|
||||||
"data_config": {"dataset_id", "batch_size"},
|
"data_config": {"dataset_id", "batch_size"},
|
||||||
"optimizer_config": {"lr", "weight_decay"},
|
"optimizer_config": {"lr", "weight_decay"},
|
||||||
"lora_config": {"type", "adapter_dim", "adapter_dropout", "alpha"},
|
"lora_config": {"type", "alpha"},
|
||||||
}
|
}
|
||||||
|
|
||||||
# Validate all parameters at once
|
# Validate all parameters at once
|
||||||
|
@ -392,17 +392,10 @@ class NvidiaPostTrainingAdapter(ModelRegistryHelper):
|
||||||
|
|
||||||
# Handle LoRA-specific configuration
|
# Handle LoRA-specific configuration
|
||||||
if algorithm_config:
|
if algorithm_config:
|
||||||
algorithm_config_dict = algorithm_config.model_dump()
|
if algorithm_config.type == "LoRA":
|
||||||
if algorithm_config_dict.get("type") == "LoRA":
|
warn_unsupported_params(algorithm_config, supported_params["lora_config"], "LoRA config")
|
||||||
warn_unsupported_params(algorithm_config_dict, supported_params["lora_config"], "LoRA config")
|
|
||||||
job_config["hyperparameters"]["lora"] = {
|
job_config["hyperparameters"]["lora"] = {
|
||||||
k: v
|
k: v for k, v in {"alpha": algorithm_config.alpha}.items() if v is not None
|
||||||
for k, v in {
|
|
||||||
"adapter_dim": algorithm_config_dict.get("adapter_dim"),
|
|
||||||
"alpha": algorithm_config_dict.get("alpha"),
|
|
||||||
"adapter_dropout": algorithm_config_dict.get("adapter_dropout"),
|
|
||||||
}.items()
|
|
||||||
if v is not None
|
|
||||||
}
|
}
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError(f"Unsupported algorithm config: {algorithm_config}")
|
raise NotImplementedError(f"Unsupported algorithm config: {algorithm_config}")
|
||||||
|
|
77
llama_stack/providers/remote/safety/nvidia/README.md
Normal file
77
llama_stack/providers/remote/safety/nvidia/README.md
Normal file
|
@ -0,0 +1,77 @@
|
||||||
|
# NVIDIA Safety Provider for LlamaStack
|
||||||
|
|
||||||
|
This provider enables safety checks and guardrails for LLM interactions using NVIDIA's NeMo Guardrails service.
|
||||||
|
|
||||||
|
## Features
|
||||||
|
|
||||||
|
- Run safety checks for messages
|
||||||
|
|
||||||
|
## Getting Started
|
||||||
|
|
||||||
|
### Prerequisites
|
||||||
|
|
||||||
|
- LlamaStack with NVIDIA configuration
|
||||||
|
- Access to NVIDIA NeMo Guardrails service
|
||||||
|
- A deployed NIM for the model used for safety checks
|
||||||
|
|
||||||
|
### Setup
|
||||||
|
|
||||||
|
Build the NVIDIA environment:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
llama stack build --template nvidia --image-type conda
|
||||||
|
```
|
||||||
|
|
||||||
|
### Basic Usage using the LlamaStack Python Client
|
||||||
|
|
||||||
|
#### Initialize the client
|
||||||
|
|
||||||
|
```python
|
||||||
|
import os
|
||||||
|
|
||||||
|
os.environ["NVIDIA_API_KEY"] = "your-api-key"
|
||||||
|
os.environ["NVIDIA_GUARDRAILS_URL"] = "http://guardrails.test"
|
||||||
|
|
||||||
|
from llama_stack.distribution.library_client import LlamaStackAsLibraryClient
|
||||||
|
|
||||||
|
client = LlamaStackAsLibraryClient("nvidia")
|
||||||
|
client.initialize()
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Create a safety shield
|
||||||
|
|
||||||
|
```python
|
||||||
|
from llama_stack.apis.safety import Shield
|
||||||
|
from llama_stack.apis.inference import Message
|
||||||
|
|
||||||
|
# Create a safety shield
|
||||||
|
shield = Shield(
|
||||||
|
shield_id="your-shield-id",
|
||||||
|
provider_resource_id="safety-model-id", # The model to use for safety checks
|
||||||
|
description="Safety checks for content moderation",
|
||||||
|
)
|
||||||
|
|
||||||
|
# Register the shield
|
||||||
|
await client.safety.register_shield(shield)
|
||||||
|
```
|
||||||
|
|
||||||
|
#### Run safety checks
|
||||||
|
|
||||||
|
```python
|
||||||
|
# Messages to check
|
||||||
|
messages = [Message(role="user", content="Your message to check")]
|
||||||
|
|
||||||
|
# Run safety check
|
||||||
|
response = await client.safety.run_shield(
|
||||||
|
shield_id="your-shield-id",
|
||||||
|
messages=messages,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Check for violations
|
||||||
|
if response.violation:
|
||||||
|
print(f"Safety violation detected: {response.violation.user_message}")
|
||||||
|
print(f"Violation level: {response.violation.violation_level}")
|
||||||
|
print(f"Metadata: {response.violation.metadata}")
|
||||||
|
else:
|
||||||
|
print("No safety violations detected")
|
||||||
|
```
|
|
@ -8,7 +8,17 @@ import logging
|
||||||
import time
|
import time
|
||||||
import uuid
|
import uuid
|
||||||
import warnings
|
import warnings
|
||||||
from typing import Any, AsyncGenerator, AsyncIterator, Awaitable, Dict, Iterable, List, Optional, Union
|
from typing import (
|
||||||
|
Any,
|
||||||
|
AsyncGenerator,
|
||||||
|
AsyncIterator,
|
||||||
|
Awaitable,
|
||||||
|
Dict,
|
||||||
|
Iterable,
|
||||||
|
List,
|
||||||
|
Optional,
|
||||||
|
Union,
|
||||||
|
)
|
||||||
|
|
||||||
from openai import AsyncStream
|
from openai import AsyncStream
|
||||||
from openai.types.chat import (
|
from openai.types.chat import (
|
||||||
|
@ -78,6 +88,7 @@ from llama_stack.apis.common.content_types import (
|
||||||
TextDelta,
|
TextDelta,
|
||||||
ToolCallDelta,
|
ToolCallDelta,
|
||||||
ToolCallParseStatus,
|
ToolCallParseStatus,
|
||||||
|
_URLOrData,
|
||||||
)
|
)
|
||||||
from llama_stack.apis.inference import (
|
from llama_stack.apis.inference import (
|
||||||
ChatCompletionRequest,
|
ChatCompletionRequest,
|
||||||
|
@ -93,6 +104,7 @@ from llama_stack.apis.inference import (
|
||||||
SamplingParams,
|
SamplingParams,
|
||||||
SystemMessage,
|
SystemMessage,
|
||||||
TokenLogProbs,
|
TokenLogProbs,
|
||||||
|
ToolChoice,
|
||||||
ToolResponseMessage,
|
ToolResponseMessage,
|
||||||
TopKSamplingStrategy,
|
TopKSamplingStrategy,
|
||||||
TopPSamplingStrategy,
|
TopPSamplingStrategy,
|
||||||
|
@ -103,7 +115,6 @@ from llama_stack.apis.inference.inference import (
|
||||||
OpenAIChatCompletion,
|
OpenAIChatCompletion,
|
||||||
OpenAICompletion,
|
OpenAICompletion,
|
||||||
OpenAICompletionChoice,
|
OpenAICompletionChoice,
|
||||||
OpenAIMessageParam,
|
|
||||||
OpenAIResponseFormatParam,
|
OpenAIResponseFormatParam,
|
||||||
ToolConfig,
|
ToolConfig,
|
||||||
)
|
)
|
||||||
|
@ -513,11 +524,26 @@ async def convert_message_to_openai_dict(message: Message, download: bool = Fals
|
||||||
else:
|
else:
|
||||||
content = [await _convert_content(message.content)]
|
content = [await _convert_content(message.content)]
|
||||||
|
|
||||||
return {
|
result = {
|
||||||
"role": message.role,
|
"role": message.role,
|
||||||
"content": content,
|
"content": content,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if hasattr(message, "tool_calls") and message.tool_calls:
|
||||||
|
result["tool_calls"] = []
|
||||||
|
for tc in message.tool_calls:
|
||||||
|
result["tool_calls"].append(
|
||||||
|
{
|
||||||
|
"id": tc.call_id,
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": tc.tool_name,
|
||||||
|
"arguments": tc.arguments_json if hasattr(tc, "arguments_json") else json.dumps(tc.arguments),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
class UnparseableToolCall(BaseModel):
|
class UnparseableToolCall(BaseModel):
|
||||||
"""
|
"""
|
||||||
|
@ -612,13 +638,10 @@ async def convert_message_to_openai_dict_new(
|
||||||
)
|
)
|
||||||
for tool in message.tool_calls
|
for tool in message.tool_calls
|
||||||
]
|
]
|
||||||
params = {}
|
|
||||||
if tool_calls:
|
|
||||||
params = {"tool_calls": tool_calls}
|
|
||||||
out = OpenAIChatCompletionAssistantMessage(
|
out = OpenAIChatCompletionAssistantMessage(
|
||||||
role="assistant",
|
role="assistant",
|
||||||
content=await _convert_message_content(message.content),
|
content=await _convert_message_content(message.content),
|
||||||
**params,
|
tool_calls=tool_calls or None,
|
||||||
)
|
)
|
||||||
elif isinstance(message, ToolResponseMessage):
|
elif isinstance(message, ToolResponseMessage):
|
||||||
out = OpenAIChatCompletionToolMessage(
|
out = OpenAIChatCompletionToolMessage(
|
||||||
|
@ -695,7 +718,10 @@ def to_openai_param_type(param_type: str) -> dict:
|
||||||
if param_type.startswith("list[") and param_type.endswith("]"):
|
if param_type.startswith("list[") and param_type.endswith("]"):
|
||||||
inner_type = param_type[5:-1]
|
inner_type = param_type[5:-1]
|
||||||
if inner_type in basic_types:
|
if inner_type in basic_types:
|
||||||
return {"type": "array", "items": {"type": basic_types.get(inner_type, inner_type)}}
|
return {
|
||||||
|
"type": "array",
|
||||||
|
"items": {"type": basic_types.get(inner_type, inner_type)},
|
||||||
|
}
|
||||||
|
|
||||||
return {"type": param_type}
|
return {"type": param_type}
|
||||||
|
|
||||||
|
@ -815,6 +841,10 @@ def _convert_openai_finish_reason(finish_reason: str) -> StopReason:
|
||||||
def _convert_openai_request_tool_config(tool_choice: Optional[Union[str, Dict[str, Any]]] = None) -> ToolConfig:
|
def _convert_openai_request_tool_config(tool_choice: Optional[Union[str, Dict[str, Any]]] = None) -> ToolConfig:
|
||||||
tool_config = ToolConfig()
|
tool_config = ToolConfig()
|
||||||
if tool_choice:
|
if tool_choice:
|
||||||
|
try:
|
||||||
|
tool_choice = ToolChoice(tool_choice)
|
||||||
|
except ValueError:
|
||||||
|
pass
|
||||||
tool_config.tool_choice = tool_choice
|
tool_config.tool_choice = tool_choice
|
||||||
return tool_config
|
return tool_config
|
||||||
|
|
||||||
|
@ -849,7 +879,9 @@ def _convert_openai_request_tools(tools: Optional[List[Dict[str, Any]]] = None)
|
||||||
return lls_tools
|
return lls_tools
|
||||||
|
|
||||||
|
|
||||||
def _convert_openai_request_response_format(response_format: OpenAIResponseFormatParam = None):
|
def _convert_openai_request_response_format(
|
||||||
|
response_format: OpenAIResponseFormatParam = None,
|
||||||
|
):
|
||||||
if not response_format:
|
if not response_format:
|
||||||
return None
|
return None
|
||||||
# response_format can be a dict or a pydantic model
|
# response_format can be a dict or a pydantic model
|
||||||
|
@ -957,38 +989,50 @@ def _convert_openai_sampling_params(
|
||||||
return sampling_params
|
return sampling_params
|
||||||
|
|
||||||
|
|
||||||
def _convert_openai_request_messages(messages: List[OpenAIMessageParam]):
|
def openai_messages_to_messages(
|
||||||
# Llama Stack messages and OpenAI messages are similar, but not identical.
|
messages: List[OpenAIChatCompletionMessage],
|
||||||
lls_messages = []
|
) -> List[Message]:
|
||||||
|
"""
|
||||||
|
Convert a list of OpenAIChatCompletionMessage into a list of Message.
|
||||||
|
"""
|
||||||
|
converted_messages = []
|
||||||
for message in messages:
|
for message in messages:
|
||||||
lls_message = dict(message)
|
if message.role == "system":
|
||||||
|
converted_message = SystemMessage(content=message.content)
|
||||||
|
elif message.role == "user":
|
||||||
|
converted_message = UserMessage(content=openai_content_to_content(message.content))
|
||||||
|
elif message.role == "assistant":
|
||||||
|
converted_message = CompletionMessage(
|
||||||
|
content=message.content,
|
||||||
|
tool_calls=_convert_openai_tool_calls(message.tool_calls),
|
||||||
|
stop_reason=StopReason.end_of_turn,
|
||||||
|
)
|
||||||
|
elif message.role == "tool":
|
||||||
|
converted_message = ToolResponseMessage(
|
||||||
|
role="tool",
|
||||||
|
call_id=message.tool_call_id,
|
||||||
|
content=openai_content_to_content(message.content),
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
raise ValueError(f"Unknown role {message.role}")
|
||||||
|
converted_messages.append(converted_message)
|
||||||
|
return converted_messages
|
||||||
|
|
||||||
# Llama Stack expects `call_id` but OpenAI uses `tool_call_id`
|
|
||||||
tool_call_id = lls_message.pop("tool_call_id", None)
|
|
||||||
if tool_call_id:
|
|
||||||
lls_message["call_id"] = tool_call_id
|
|
||||||
|
|
||||||
content = lls_message.get("content", None)
|
def openai_content_to_content(content: Union[str, Iterable[OpenAIChatCompletionContentPartParam]]):
|
||||||
if isinstance(content, list):
|
if isinstance(content, str):
|
||||||
lls_content = []
|
return content
|
||||||
for item in content:
|
elif isinstance(content, list):
|
||||||
# items can either by pydantic models or dicts here...
|
return [openai_content_to_content(c) for c in content]
|
||||||
item = dict(item)
|
elif hasattr(content, "type"):
|
||||||
if item.get("type", "") == "image_url":
|
if content.type == "text":
|
||||||
lls_item = ImageContentItem(
|
return TextContentItem(type="text", text=content.text)
|
||||||
type="image",
|
elif content.type == "image_url":
|
||||||
image=URL(uri=item.get("image_url", {}).get("url", "")),
|
return ImageContentItem(type="image", image=_URLOrData(url=URL(uri=content.image_url.url)))
|
||||||
)
|
else:
|
||||||
elif item.get("type", "") == "text":
|
raise ValueError(f"Unknown content type: {content.type}")
|
||||||
lls_item = TextContentItem(
|
else:
|
||||||
type="text",
|
raise ValueError(f"Unknown content type: {content}")
|
||||||
text=item.get("text", ""),
|
|
||||||
)
|
|
||||||
lls_content.append(lls_item)
|
|
||||||
lls_message["content"] = lls_content
|
|
||||||
lls_messages.append(lls_message)
|
|
||||||
|
|
||||||
return lls_messages
|
|
||||||
|
|
||||||
|
|
||||||
def convert_openai_chat_completion_choice(
|
def convert_openai_chat_completion_choice(
|
||||||
|
@ -1313,7 +1357,7 @@ class OpenAIChatCompletionToLlamaStackMixin:
|
||||||
top_p: Optional[float] = None,
|
top_p: Optional[float] = None,
|
||||||
user: Optional[str] = None,
|
user: Optional[str] = None,
|
||||||
) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]:
|
) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]:
|
||||||
messages = _convert_openai_request_messages(messages)
|
messages = openai_messages_to_messages(messages)
|
||||||
response_format = _convert_openai_request_response_format(response_format)
|
response_format = _convert_openai_request_response_format(response_format)
|
||||||
sampling_params = _convert_openai_sampling_params(
|
sampling_params = _convert_openai_sampling_params(
|
||||||
max_tokens=max_tokens,
|
max_tokens=max_tokens,
|
||||||
|
@ -1321,7 +1365,10 @@ class OpenAIChatCompletionToLlamaStackMixin:
|
||||||
top_p=top_p,
|
top_p=top_p,
|
||||||
)
|
)
|
||||||
tool_config = _convert_openai_request_tool_config(tool_choice)
|
tool_config = _convert_openai_request_tool_config(tool_choice)
|
||||||
|
|
||||||
tools = _convert_openai_request_tools(tools)
|
tools = _convert_openai_request_tools(tools)
|
||||||
|
if tool_config.tool_choice == ToolChoice.none:
|
||||||
|
tools = []
|
||||||
|
|
||||||
outstanding_responses = []
|
outstanding_responses = []
|
||||||
# "n" is the number of completions to generate per prompt
|
# "n" is the number of completions to generate per prompt
|
||||||
|
@ -1346,7 +1393,9 @@ class OpenAIChatCompletionToLlamaStackMixin:
|
||||||
)
|
)
|
||||||
|
|
||||||
async def _process_stream_response(
|
async def _process_stream_response(
|
||||||
self, model: str, outstanding_responses: List[Awaitable[AsyncIterator[ChatCompletionResponseStreamChunk]]]
|
self,
|
||||||
|
model: str,
|
||||||
|
outstanding_responses: List[Awaitable[AsyncIterator[ChatCompletionResponseStreamChunk]]],
|
||||||
):
|
):
|
||||||
id = f"chatcmpl-{uuid.uuid4()}"
|
id = f"chatcmpl-{uuid.uuid4()}"
|
||||||
for outstanding_response in outstanding_responses:
|
for outstanding_response in outstanding_responses:
|
||||||
|
@ -1369,11 +1418,31 @@ class OpenAIChatCompletionToLlamaStackMixin:
|
||||||
elif isinstance(event.delta, ToolCallDelta):
|
elif isinstance(event.delta, ToolCallDelta):
|
||||||
if event.delta.parse_status == ToolCallParseStatus.succeeded:
|
if event.delta.parse_status == ToolCallParseStatus.succeeded:
|
||||||
tool_call = event.delta.tool_call
|
tool_call = event.delta.tool_call
|
||||||
|
|
||||||
|
# First chunk includes full structure
|
||||||
openai_tool_call = OpenAIChoiceDeltaToolCall(
|
openai_tool_call = OpenAIChoiceDeltaToolCall(
|
||||||
index=0,
|
index=0,
|
||||||
id=tool_call.call_id,
|
id=tool_call.call_id,
|
||||||
function=OpenAIChoiceDeltaToolCallFunction(
|
function=OpenAIChoiceDeltaToolCallFunction(
|
||||||
name=tool_call.tool_name, arguments=tool_call.arguments_json
|
name=tool_call.tool_name,
|
||||||
|
arguments="",
|
||||||
|
),
|
||||||
|
)
|
||||||
|
delta = OpenAIChoiceDelta(tool_calls=[openai_tool_call])
|
||||||
|
yield OpenAIChatCompletionChunk(
|
||||||
|
id=id,
|
||||||
|
choices=[
|
||||||
|
OpenAIChatCompletionChunkChoice(index=i, finish_reason=finish_reason, delta=delta)
|
||||||
|
],
|
||||||
|
created=int(time.time()),
|
||||||
|
model=model,
|
||||||
|
object="chat.completion.chunk",
|
||||||
|
)
|
||||||
|
# Subsequent chunk carries the tool call arguments
|
||||||
|
openai_tool_call = OpenAIChoiceDeltaToolCall(
|
||||||
|
index=0,
|
||||||
|
function=OpenAIChoiceDeltaToolCallFunction(
|
||||||
|
arguments=tool_call.arguments_json,
|
||||||
),
|
),
|
||||||
)
|
)
|
||||||
delta = OpenAIChoiceDelta(tool_calls=[openai_tool_call])
|
delta = OpenAIChoiceDelta(tool_calls=[openai_tool_call])
|
||||||
|
|
|
@ -52,6 +52,9 @@ from llama_stack.models.llama.llama3.prompt_templates import (
|
||||||
SystemDefaultGenerator,
|
SystemDefaultGenerator,
|
||||||
)
|
)
|
||||||
from llama_stack.models.llama.llama3.tokenizer import Tokenizer
|
from llama_stack.models.llama.llama3.tokenizer import Tokenizer
|
||||||
|
from llama_stack.models.llama.llama4.prompt_templates.system_prompts import (
|
||||||
|
PythonListCustomToolGenerator as PythonListCustomToolGeneratorLlama4,
|
||||||
|
)
|
||||||
from llama_stack.models.llama.sku_list import resolve_model
|
from llama_stack.models.llama.sku_list import resolve_model
|
||||||
from llama_stack.models.llama.sku_types import ModelFamily, is_multimodal
|
from llama_stack.models.llama.sku_types import ModelFamily, is_multimodal
|
||||||
from llama_stack.providers.utils.inference import supported_inference_models
|
from llama_stack.providers.utils.inference import supported_inference_models
|
||||||
|
@ -306,10 +309,11 @@ def chat_completion_request_to_messages(
|
||||||
elif model.model_family in (
|
elif model.model_family in (
|
||||||
ModelFamily.llama3_2,
|
ModelFamily.llama3_2,
|
||||||
ModelFamily.llama3_3,
|
ModelFamily.llama3_3,
|
||||||
ModelFamily.llama4,
|
|
||||||
):
|
):
|
||||||
# llama3.2, llama3.3 and llama4 models follow the same tool prompt format
|
# llama3.2, llama3.3 follow the same tool prompt format
|
||||||
messages = augment_messages_for_tools_llama_3_2(request)
|
messages = augment_messages_for_tools_llama(request, PythonListCustomToolGenerator)
|
||||||
|
elif model.model_family == ModelFamily.llama4:
|
||||||
|
messages = augment_messages_for_tools_llama(request, PythonListCustomToolGeneratorLlama4)
|
||||||
else:
|
else:
|
||||||
messages = request.messages
|
messages = request.messages
|
||||||
|
|
||||||
|
@ -399,8 +403,9 @@ def augment_messages_for_tools_llama_3_1(
|
||||||
return messages
|
return messages
|
||||||
|
|
||||||
|
|
||||||
def augment_messages_for_tools_llama_3_2(
|
def augment_messages_for_tools_llama(
|
||||||
request: ChatCompletionRequest,
|
request: ChatCompletionRequest,
|
||||||
|
custom_tool_prompt_generator,
|
||||||
) -> List[Message]:
|
) -> List[Message]:
|
||||||
existing_messages = request.messages
|
existing_messages = request.messages
|
||||||
existing_system_message = None
|
existing_system_message = None
|
||||||
|
@ -434,7 +439,7 @@ def augment_messages_for_tools_llama_3_2(
|
||||||
if existing_system_message and request.tool_config.system_message_behavior == SystemMessageBehavior.replace:
|
if existing_system_message and request.tool_config.system_message_behavior == SystemMessageBehavior.replace:
|
||||||
system_prompt = existing_system_message.content
|
system_prompt = existing_system_message.content
|
||||||
|
|
||||||
tool_template = PythonListCustomToolGenerator().gen(custom_tools, system_prompt)
|
tool_template = custom_tool_prompt_generator().gen(custom_tools, system_prompt)
|
||||||
|
|
||||||
sys_content += tool_template.render()
|
sys_content += tool_template.render()
|
||||||
sys_content += "\n"
|
sys_content += "\n"
|
||||||
|
|
|
@ -756,5 +756,41 @@
|
||||||
"vllm",
|
"vllm",
|
||||||
"sentence-transformers --no-deps",
|
"sentence-transformers --no-deps",
|
||||||
"torch torchvision --index-url https://download.pytorch.org/whl/cpu"
|
"torch torchvision --index-url https://download.pytorch.org/whl/cpu"
|
||||||
|
],
|
||||||
|
"watsonx": [
|
||||||
|
"aiosqlite",
|
||||||
|
"autoevals",
|
||||||
|
"blobfile",
|
||||||
|
"chardet",
|
||||||
|
"datasets",
|
||||||
|
"emoji",
|
||||||
|
"faiss-cpu",
|
||||||
|
"fastapi",
|
||||||
|
"fire",
|
||||||
|
"httpx",
|
||||||
|
"ibm_watson_machine_learning",
|
||||||
|
"langdetect",
|
||||||
|
"matplotlib",
|
||||||
|
"mcp",
|
||||||
|
"nltk",
|
||||||
|
"numpy",
|
||||||
|
"openai",
|
||||||
|
"opentelemetry-exporter-otlp-proto-http",
|
||||||
|
"opentelemetry-sdk",
|
||||||
|
"pandas",
|
||||||
|
"pillow",
|
||||||
|
"psycopg2-binary",
|
||||||
|
"pymongo",
|
||||||
|
"pypdf",
|
||||||
|
"pythainlp",
|
||||||
|
"redis",
|
||||||
|
"requests",
|
||||||
|
"scikit-learn",
|
||||||
|
"scipy",
|
||||||
|
"sentencepiece",
|
||||||
|
"tqdm",
|
||||||
|
"transformers",
|
||||||
|
"tree_sitter",
|
||||||
|
"uvicorn"
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
|
|
|
@ -69,6 +69,7 @@ LLAMA_STACK_PORT=8321
|
||||||
docker run \
|
docker run \
|
||||||
-it \
|
-it \
|
||||||
--pull always \
|
--pull always \
|
||||||
|
--gpus all \
|
||||||
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
|
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
|
||||||
-v ~/.llama:/root/.llama \
|
-v ~/.llama:/root/.llama \
|
||||||
llamastack/distribution-{{ name }} \
|
llamastack/distribution-{{ name }} \
|
||||||
|
@ -82,6 +83,7 @@ If you are using Llama Stack Safety / Shield APIs, use:
|
||||||
docker run \
|
docker run \
|
||||||
-it \
|
-it \
|
||||||
--pull always \
|
--pull always \
|
||||||
|
--gpus all \
|
||||||
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
|
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
|
||||||
-v ~/.llama:/root/.llama \
|
-v ~/.llama:/root/.llama \
|
||||||
llamastack/distribution-{{ name }} \
|
llamastack/distribution-{{ name }} \
|
||||||
|
|
|
@ -25,14 +25,84 @@ The following models are available by default:
|
||||||
{% endif %}
|
{% endif %}
|
||||||
|
|
||||||
|
|
||||||
### Prerequisite: API Keys
|
## Prerequisites
|
||||||
|
### NVIDIA API Keys
|
||||||
|
|
||||||
Make sure you have access to a NVIDIA API Key. You can get one by visiting [https://build.nvidia.com/](https://build.nvidia.com/).
|
Make sure you have access to a NVIDIA API Key. You can get one by visiting [https://build.nvidia.com/](https://build.nvidia.com/). Use this key for the `NVIDIA_API_KEY` environment variable.
|
||||||
|
|
||||||
|
### Deploy NeMo Microservices Platform
|
||||||
|
The NVIDIA NeMo microservices platform supports end-to-end microservice deployment of a complete AI flywheel on your Kubernetes cluster through the NeMo Microservices Helm Chart. Please reference the [NVIDIA NeMo Microservices documentation](https://docs.nvidia.com/nemo/microservices/latest/about/index.html) for platform prerequisites and instructions to install and deploy the platform.
|
||||||
|
|
||||||
|
## Supported Services
|
||||||
|
Each Llama Stack API corresponds to a specific NeMo microservice. The core microservices (Customizer, Evaluator, Guardrails) are exposed by the same endpoint. The platform components (Data Store) are each exposed by separate endpoints.
|
||||||
|
|
||||||
|
### Inference: NVIDIA NIM
|
||||||
|
NVIDIA NIM is used for running inference with registered models. There are two ways to access NVIDIA NIMs:
|
||||||
|
1. Hosted (default): Preview APIs hosted at https://integrate.api.nvidia.com (Requires an API key)
|
||||||
|
2. Self-hosted: NVIDIA NIMs that run on your own infrastructure.
|
||||||
|
|
||||||
|
The deployed platform includes the NIM Proxy microservice, which is the service that provides access to your NIMs (for example, to run inference on a model). Set the `NVIDIA_BASE_URL` environment variable to point to your NVIDIA NIM Proxy deployment.
|
||||||
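For a quick end-to-end check, you can call a registered model through the Llama Stack Python client. This is a minimal sketch, not a fixed recipe: the server URL is assumed to be a local Llama Stack server, and the model id assumes the `meta/llama-3.2-1b-instruct` NIM from the deployment example further below has been deployed and registered.

```python
from llama_stack_client import LlamaStackClient

# Assumes a Llama Stack server running locally on the default port.
client = LlamaStackClient(base_url="http://localhost:8321")

# Model id assumes the NIM deployed in the example later in this guide.
response = client.inference.chat_completion(
    model_id="meta/llama-3.2-1b-instruct",
    messages=[{"role": "user", "content": "Write a haiku about GPUs."}],
)
print(response.completion_message.content)
```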
|
|
||||||
|
### Datasetio API: NeMo Data Store
|
||||||
|
The NeMo Data Store microservice serves as the default file storage solution for the NeMo microservices platform. It exposes APIs compatible with the Hugging Face Hub client (`HfApi`), so you can use that client to interact with the Data Store. The `NVIDIA_DATASETS_URL` environment variable should point to your NeMo Data Store endpoint.
|
||||||
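Because the Data Store speaks the Hugging Face Hub API, the standard `huggingface_hub` client can be pointed directly at it. The sketch below is illustrative: the repository name and local file path are assumptions, and the token value depends on how your Data Store deployment handles authentication.

```python
import os

from huggingface_hub import HfApi

# Point the Hugging Face Hub client at the NeMo Data Store endpoint.
hf_api = HfApi(endpoint=os.environ["NVIDIA_DATASETS_URL"], token="")

# Create a dataset repository (name is illustrative) and upload a local file to it.
hf_api.create_repo(repo_id="default/sample-dataset", repo_type="dataset", exist_ok=True)
hf_api.upload_file(
    path_or_fileobj="./train.jsonl",  # assumed local file
    path_in_repo="training/train.jsonl",
    repo_id="default/sample-dataset",
    repo_type="dataset",
)
```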
|
|
||||||
|
See the [NVIDIA Datasetio docs](/llama_stack/providers/remote/datasetio/nvidia/README.md) for supported features and example usage.
|
||||||
|
|
||||||
|
### Eval API: NeMo Evaluator
|
||||||
|
The NeMo Evaluator microservice supports evaluation of LLMs. Launching an Evaluation job with NeMo Evaluator requires an Evaluation Config (an object that contains metadata needed by the job). A Llama Stack Benchmark maps to an Evaluation Config, so registering a Benchmark creates an Evaluation Config in NeMo Evaluator. The `NVIDIA_EVALUATOR_URL` environment variable should point to your NeMo Microservices endpoint.
|
||||||
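As a sketch of that mapping, registering a benchmark through the Llama Stack client is what creates the Evaluation Config on the NeMo Evaluator side. The benchmark id, dataset id, and scoring function below are illustrative placeholders, not required names.

```python
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:8321")

# Registering the benchmark creates the corresponding Evaluation Config in NeMo Evaluator.
client.benchmarks.register(
    benchmark_id="my-eval-benchmark",       # illustrative
    dataset_id="my-eval-dataset",           # a dataset registered via the datasetio API
    scoring_functions=["basic::equality"],  # illustrative scoring function id
)
```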
|
|
||||||
|
See the [NVIDIA Eval docs](/llama_stack/providers/remote/eval/nvidia/README.md) for supported features and example usage.
|
||||||
|
|
||||||
|
### Post-Training API: NeMo Customizer
|
||||||
|
The NeMo Customizer microservice supports fine-tuning models. You can reference [this list of supported models](/llama_stack/providers/remote/post_training/nvidia/models.py) that can be fine-tuned using Llama Stack. The `NVIDIA_CUSTOMIZER_URL` environment variable should point to your NeMo Microservices endpoint.
|
||||||
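The sketch below shows roughly what launching a LoRA fine-tuning job through the Llama Stack client looks like. All ids and hyperparameter values are illustrative, and the exact shape of `algorithm_config` and `training_config` should be checked against the post-training API types before running this against a real deployment.

```python
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:8321")

# All values below are illustrative; verify required fields against the
# post-training API types for your Llama Stack version.
job = client.post_training.supervised_fine_tune(
    job_uuid="",                               # let the provider assign one
    model="meta-llama/Llama-3.1-8B-Instruct",  # must be in the supported-models list
    checkpoint_dir="",
    algorithm_config={
        "type": "LoRA",
        "lora_attn_modules": ["q_proj", "v_proj"],
        "apply_lora_to_mlp": True,
        "apply_lora_to_output": False,
        "rank": 8,
        "alpha": 16,
    },
    training_config={
        "n_epochs": 1,
        "max_steps_per_epoch": 100,
        "gradient_accumulation_steps": 1,
        "data_config": {
            "dataset_id": "my-training-dataset",  # registered via the datasetio API
            "batch_size": 8,
            "shuffle": True,
            "data_format": "instruct",
        },
        "optimizer_config": {
            "optimizer_type": "adam",
            "lr": 1e-4,
            "weight_decay": 0.01,
            "num_warmup_steps": 10,
        },
    },
    hyperparam_search_config={},
    logger_config={},
)
print(job.job_uuid)
```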
|
|
||||||
|
See the [NVIDIA Post-Training docs](/llama_stack/providers/remote/post_training/nvidia/README.md) for supported features and example usage.
|
||||||
|
|
||||||
|
### Safety API: NeMo Guardrails
|
||||||
|
The NeMo Guardrails microservice sits between your application and the LLM, and adds checks and content moderation to a model. The `GUARDRAILS_SERVICE_URL` environment variable should point to your NeMo Microservices endpoint.
|
||||||
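Here is a minimal sketch of exercising a NeMo Guardrails-backed shield through the Llama Stack safety API; the shield id is an assumption and should match whatever shield you registered.

```python
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:8321")

result = client.safety.run_shield(
    shield_id="my-guardrails-shield",  # illustrative; use the shield id you registered
    messages=[{"role": "user", "content": "Tell me how to hotwire a car."}],
    params={},
)
if result.violation:
    print("Blocked:", result.violation.user_message)
else:
    print("Message passed the shield.")
```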
|
|
||||||
|
See the NVIDIA Safety docs for supported features and example usage.
|
||||||
|
|
||||||
|
## Deploying models
|
||||||
|
In order to use a registered model with the Llama Stack APIs, ensure the corresponding NIM is deployed to your environment. For example, you can use the NIM Proxy microservice to deploy `meta/llama-3.2-1b-instruct`.
|
||||||
|
|
||||||
|
Note: For improved inference speeds, the NIM should use the `fast_outlines` guided decoding backend (specified in the request body below). This is the default if you deployed the platform with the NeMo Microservices Helm Chart.
|
||||||
|
```sh
|
||||||
|
# URL to NeMo NIM Proxy service
|
||||||
|
export NEMO_URL="http://nemo.test"
|
||||||
|
|
||||||
|
curl --location "$NEMO_URL/v1/deployment/model-deployments" \
|
||||||
|
-H 'accept: application/json' \
|
||||||
|
-H 'Content-Type: application/json' \
|
||||||
|
-d '{
|
||||||
|
"name": "llama-3.2-1b-instruct",
|
||||||
|
"namespace": "meta",
|
||||||
|
"config": {
|
||||||
|
"model": "meta/llama-3.2-1b-instruct",
|
||||||
|
"nim_deployment": {
|
||||||
|
"image_name": "nvcr.io/nim/meta/llama-3.2-1b-instruct",
|
||||||
|
"image_tag": "1.8.3",
|
||||||
|
"pvc_size": "25Gi",
|
||||||
|
"gpu": 1,
|
||||||
|
"additional_envs": {
|
||||||
|
"NIM_GUIDED_DECODING_BACKEND": "fast_outlines"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}'
|
||||||
|
```
|
||||||
|
This NIM deployment should take approximately 10 minutes to go live. [See the docs](https://docs.nvidia.com/nemo/microservices/latest/get-started/tutorials/deploy-nims.html) for more information on how to deploy a NIM and verify it's available for inference.
|
||||||
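One lightweight way to verify this from your workstation (a sketch, assuming the NIM Proxy exposes the usual OpenAI-compatible `/v1/models` route at `NVIDIA_BASE_URL`) is to list the models it serves and look for the one you just deployed:

```python
import os

from openai import OpenAI

# Point an OpenAI-compatible client at the NIM Proxy; the API key is typically not
# checked for self-hosted NIMs, so a placeholder value is fine here.
nim = OpenAI(base_url=f"{os.environ['NVIDIA_BASE_URL']}/v1", api_key="not-used")

print([model.id for model in nim.models.list()])
```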
|
|
||||||
|
You can also remove a deployed NIM to free up GPU resources, if needed.
|
||||||
|
```sh
|
||||||
|
export NEMO_URL="http://nemo.test"
|
||||||
|
|
||||||
|
curl -X DELETE "$NEMO_URL/v1/deployment/model-deployments/meta/llama-3.1-8b-instruct"
|
||||||
|
```
|
||||||
|
|
||||||
## Running Llama Stack with NVIDIA
|
## Running Llama Stack with NVIDIA
|
||||||
|
|
||||||
You can do this via Conda (build code) or Docker which has a pre-built image.
|
You can do this via Conda or venv (build code), or via Docker, which has a pre-built image.
|
||||||
|
|
||||||
### Via Docker
|
### Via Docker
|
||||||
|
|
||||||
|
@ -54,9 +124,23 @@ docker run \
|
||||||
### Via Conda
|
### Via Conda
|
||||||
|
|
||||||
```bash
|
```bash
|
||||||
|
INFERENCE_MODEL=meta-llama/Llama-3.1-8b-Instruct
|
||||||
llama stack build --template nvidia --image-type conda
|
llama stack build --template nvidia --image-type conda
|
||||||
llama stack run ./run.yaml \
|
llama stack run ./run.yaml \
|
||||||
--port 8321 \
|
--port 8321 \
|
||||||
--env NVIDIA_API_KEY=$NVIDIA_API_KEY
|
--env NVIDIA_API_KEY=$NVIDIA_API_KEY \
|
||||||
|
--env INFERENCE_MODEL=$INFERENCE_MODEL
|
||||||
|
```
|
||||||
|
|
||||||
|
### Via venv
|
||||||
|
|
||||||
|
If you've set up your local development environment, you can also build the image using your local virtual environment.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
INFERENCE_MODEL=meta-llama/Llama-3.1-8b-Instruct
|
||||||
|
llama stack build --template nvidia --image-type venv
|
||||||
|
llama stack run ./run.yaml \
|
||||||
|
--port 8321 \
|
||||||
|
--env NVIDIA_API_KEY=$NVIDIA_API_KEY \
|
||||||
--env INFERENCE_MODEL=$INFERENCE_MODEL
|
--env INFERENCE_MODEL=$INFERENCE_MODEL
|
||||||
```
|
```
|
||||||
|
|
|
@ -98,23 +98,15 @@ def get_distribution_template() -> DistributionTemplate:
|
||||||
"",
|
"",
|
||||||
"NVIDIA API Key",
|
"NVIDIA API Key",
|
||||||
),
|
),
|
||||||
## Nemo Customizer related variables
|
|
||||||
"NVIDIA_USER_ID": (
|
|
||||||
"llama-stack-user",
|
|
||||||
"NVIDIA User ID",
|
|
||||||
),
|
|
||||||
"NVIDIA_APPEND_API_VERSION": (
|
"NVIDIA_APPEND_API_VERSION": (
|
||||||
"True",
|
"True",
|
||||||
"Whether to append the API version to the base_url",
|
"Whether to append the API version to the base_url",
|
||||||
),
|
),
|
||||||
|
## Nemo Customizer related variables
|
||||||
"NVIDIA_DATASET_NAMESPACE": (
|
"NVIDIA_DATASET_NAMESPACE": (
|
||||||
"default",
|
"default",
|
||||||
"NVIDIA Dataset Namespace",
|
"NVIDIA Dataset Namespace",
|
||||||
),
|
),
|
||||||
"NVIDIA_ACCESS_POLICIES": (
|
|
||||||
"{}",
|
|
||||||
"NVIDIA Access Policies",
|
|
||||||
),
|
|
||||||
"NVIDIA_PROJECT_ID": (
|
"NVIDIA_PROJECT_ID": (
|
||||||
"test-project",
|
"test-project",
|
||||||
"NVIDIA Project ID",
|
"NVIDIA Project ID",
|
||||||
|
|
|
@ -57,7 +57,7 @@ providers:
|
||||||
- provider_id: nvidia
|
- provider_id: nvidia
|
||||||
provider_type: remote::nvidia
|
provider_type: remote::nvidia
|
||||||
config:
|
config:
|
||||||
evaluator_service_url: ${env.NVIDIA_EVALUATOR_URL:http://localhost:7331}
|
evaluator_url: ${env.NVIDIA_EVALUATOR_URL:http://localhost:7331}
|
||||||
post_training:
|
post_training:
|
||||||
- provider_id: nvidia
|
- provider_id: nvidia
|
||||||
provider_type: remote::nvidia
|
provider_type: remote::nvidia
|
||||||
|
|
|
@ -52,7 +52,7 @@ providers:
|
||||||
- provider_id: nvidia
|
- provider_id: nvidia
|
||||||
provider_type: remote::nvidia
|
provider_type: remote::nvidia
|
||||||
config:
|
config:
|
||||||
evaluator_service_url: ${env.NVIDIA_EVALUATOR_URL:http://localhost:7331}
|
evaluator_url: ${env.NVIDIA_EVALUATOR_URL:http://localhost:7331}
|
||||||
post_training:
|
post_training:
|
||||||
- provider_id: nvidia
|
- provider_id: nvidia
|
||||||
provider_type: remote::nvidia
|
provider_type: remote::nvidia
|
||||||
|
@ -178,6 +178,16 @@ models:
|
||||||
provider_id: nvidia
|
provider_id: nvidia
|
||||||
provider_model_id: meta/llama-3.2-90b-vision-instruct
|
provider_model_id: meta/llama-3.2-90b-vision-instruct
|
||||||
model_type: llm
|
model_type: llm
|
||||||
|
- metadata: {}
|
||||||
|
model_id: meta/llama-3.3-70b-instruct
|
||||||
|
provider_id: nvidia
|
||||||
|
provider_model_id: meta/llama-3.3-70b-instruct
|
||||||
|
model_type: llm
|
||||||
|
- metadata: {}
|
||||||
|
model_id: meta-llama/Llama-3.3-70B-Instruct
|
||||||
|
provider_id: nvidia
|
||||||
|
provider_model_id: meta/llama-3.3-70b-instruct
|
||||||
|
model_type: llm
|
||||||
- metadata:
|
- metadata:
|
||||||
embedding_dimension: 2048
|
embedding_dimension: 2048
|
||||||
context_length: 8192
|
context_length: 8192
|
||||||
|
|
|
@ -28,10 +28,10 @@ The following environment variables can be configured:
|
||||||
|
|
||||||
## Setting up vLLM server
|
## Setting up vLLM server
|
||||||
|
|
||||||
In the following sections, we'll use either AMD and NVIDIA GPUs to serve as hardware accelerators for the vLLM
|
In the following sections, we'll use AMD, NVIDIA or Intel GPUs to serve as hardware accelerators for the vLLM
|
||||||
server, which acts as both the LLM inference provider and the safety provider. Note that vLLM also
|
server, which acts as both the LLM inference provider and the safety provider. Note that vLLM also
|
||||||
[supports many other hardware accelerators](https://docs.vllm.ai/en/latest/getting_started/installation.html) and
|
[supports many other hardware accelerators](https://docs.vllm.ai/en/latest/getting_started/installation.html) and
|
||||||
that we only use GPUs here for demonstration purposes.
|
that we only use GPUs here for demonstration purposes. If you run into issues, you can add the environment variable `--env VLLM_DEBUG_LOG_API_SERVER_RESPONSE=true` (available in vLLM v0.8.3 and above) to the `docker run` command to enable logging of API server responses for debugging.
|
||||||
|
|
||||||
### Setting up vLLM server on AMD GPU
|
### Setting up vLLM server on AMD GPU
|
||||||
|
|
||||||
|
@ -149,6 +149,55 @@ docker run \
|
||||||
--port $SAFETY_PORT
|
--port $SAFETY_PORT
|
||||||
```
|
```
|
||||||
|
|
||||||
|
### Setting up vLLM server on Intel GPU
|
||||||
|
|
||||||
|
Refer to the [vLLM Documentation for XPU](https://docs.vllm.ai/en/v0.8.2/getting_started/installation/gpu.html?device=xpu) to get a vLLM endpoint. In addition to the vLLM setup guide, which covers installing vLLM from source or building your own vLLM Docker container, Intel provides a prebuilt vLLM container for systems with Intel GPUs supported by the PyTorch XPU backend:
|
||||||
|
- [intel/vllm](https://hub.docker.com/r/intel/vllm)
|
||||||
|
|
||||||
|
Here is a sample script to start a vLLM server locally via Docker using the Intel-provided container:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
export INFERENCE_PORT=8000
|
||||||
|
export INFERENCE_MODEL=meta-llama/Llama-3.2-1B-Instruct
|
||||||
|
export ZE_AFFINITY_MASK=0
|
||||||
|
|
||||||
|
docker run \
|
||||||
|
--pull always \
|
||||||
|
--device /dev/dri \
|
||||||
|
-v /dev/dri/by-path:/dev/dri/by-path \
|
||||||
|
-v ~/.cache/huggingface:/root/.cache/huggingface \
|
||||||
|
--env "HUGGING_FACE_HUB_TOKEN=$HF_TOKEN" \
|
||||||
|
--env ZE_AFFINITY_MASK=$ZE_AFFINITY_MASK \
|
||||||
|
-p $INFERENCE_PORT:$INFERENCE_PORT \
|
||||||
|
--ipc=host \
|
||||||
|
intel/vllm:xpu \
|
||||||
|
--gpu-memory-utilization 0.7 \
|
||||||
|
--model $INFERENCE_MODEL \
|
||||||
|
--port $INFERENCE_PORT
|
||||||
|
```
|
||||||
|
|
||||||
|
If you are using the Llama Stack Safety / Shield APIs, you will also need to run another vLLM instance serving a corresponding safety model such as `meta-llama/Llama-Guard-3-1B`, using a script like:
|
||||||
|
|
||||||
|
```bash
|
||||||
|
export SAFETY_PORT=8081
|
||||||
|
export SAFETY_MODEL=meta-llama/Llama-Guard-3-1B
|
||||||
|
export ZE_AFFINITY_MASK=1
|
||||||
|
|
||||||
|
docker run \
|
||||||
|
--pull always \
|
||||||
|
--device /dev/dri \
|
||||||
|
-v /dev/dri/by-path:/dev/dri/by-path \
|
||||||
|
-v ~/.cache/huggingface:/root/.cache/huggingface \
|
||||||
|
--env "HUGGING_FACE_HUB_TOKEN=$HF_TOKEN" \
|
||||||
|
--env ZE_AFFINITY_MASK=$ZE_AFFINITY_MASK \
|
||||||
|
-p $SAFETY_PORT:$SAFETY_PORT \
|
||||||
|
--ipc=host \
|
||||||
|
intel/vllm:xpu \
|
||||||
|
--gpu-memory-utilization 0.7 \
|
||||||
|
--model $SAFETY_MODEL \
|
||||||
|
--port $SAFETY_PORT
|
||||||
|
```
|
||||||
|
|
||||||
## Running Llama Stack
|
## Running Llama Stack
|
||||||
|
|
||||||
Now you are ready to run Llama Stack with vLLM as the inference provider. You can do this via Conda (build code) or Docker which has a pre-built image.
|
Now you are ready to run Llama Stack with vLLM as the inference provider. You can do this via Conda (build code) or Docker which has a pre-built image.
|
||||||
|
|
7
llama_stack/templates/watsonx/__init__.py
Normal file
|
@ -0,0 +1,7 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
from .watsonx import get_distribution_template # noqa: F401
|
30
llama_stack/templates/watsonx/build.yaml
Normal file
|
@ -0,0 +1,30 @@
|
||||||
|
version: '2'
|
||||||
|
distribution_spec:
|
||||||
|
description: Use watsonx for running LLM inference
|
||||||
|
providers:
|
||||||
|
inference:
|
||||||
|
- remote::watsonx
|
||||||
|
vector_io:
|
||||||
|
- inline::faiss
|
||||||
|
safety:
|
||||||
|
- inline::llama-guard
|
||||||
|
agents:
|
||||||
|
- inline::meta-reference
|
||||||
|
telemetry:
|
||||||
|
- inline::meta-reference
|
||||||
|
eval:
|
||||||
|
- inline::meta-reference
|
||||||
|
datasetio:
|
||||||
|
- remote::huggingface
|
||||||
|
- inline::localfs
|
||||||
|
scoring:
|
||||||
|
- inline::basic
|
||||||
|
- inline::llm-as-judge
|
||||||
|
- inline::braintrust
|
||||||
|
tool_runtime:
|
||||||
|
- remote::brave-search
|
||||||
|
- remote::tavily-search
|
||||||
|
- inline::code-interpreter
|
||||||
|
- inline::rag-runtime
|
||||||
|
- remote::model-context-protocol
|
||||||
|
image_type: conda
|
74
llama_stack/templates/watsonx/doc_template.md
Normal file
|
@ -0,0 +1,74 @@
|
||||||
|
---
|
||||||
|
orphan: true
|
||||||
|
---
|
||||||
|
# watsonx Distribution
|
||||||
|
|
||||||
|
```{toctree}
|
||||||
|
:maxdepth: 2
|
||||||
|
:hidden:
|
||||||
|
|
||||||
|
self
|
||||||
|
```
|
||||||
|
|
||||||
|
The `llamastack/distribution-{{ name }}` distribution consists of the following provider configurations.
|
||||||
|
|
||||||
|
{{ providers_table }}
|
||||||
|
|
||||||
|
{% if run_config_env_vars %}
|
||||||
|
|
||||||
|
### Environment Variables
|
||||||
|
|
||||||
|
The following environment variables can be configured:
|
||||||
|
|
||||||
|
{% for var, (default_value, description) in run_config_env_vars.items() %}
|
||||||
|
- `{{ var }}`: {{ description }} (default: `{{ default_value }}`)
|
||||||
|
{% endfor %}
|
||||||
|
{% endif %}
|
||||||
|
|
||||||
|
{% if default_models %}
|
||||||
|
### Models
|
||||||
|
|
||||||
|
The following models are available by default:
|
||||||
|
|
||||||
|
{% for model in default_models %}
|
||||||
|
- `{{ model.model_id }} {{ model.doc_string }}`
|
||||||
|
{% endfor %}
|
||||||
|
{% endif %}
|
||||||
|
|
||||||
|
|
||||||
|
### Prerequisite: API Keys
|
||||||
|
|
||||||
|
Make sure you have access to a watsonx API Key. You can get one by referring to [watsonx.ai](https://www.ibm.com/docs/en/masv-and-l/maximo-manage/continuous-delivery?topic=setup-create-watsonx-api-key).
|
||||||
|
|
||||||
|
|
||||||
|
## Running Llama Stack with watsonx
|
||||||
|
|
||||||
|
You can do this via Conda or venv (build code), or via Docker, which has a pre-built image.
|
||||||
|
|
||||||
|
### Via Docker
|
||||||
|
|
||||||
|
This method allows you to get started quickly without having to build the distribution code.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
LLAMA_STACK_PORT=5001
|
||||||
|
docker run \
|
||||||
|
-it \
|
||||||
|
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
|
||||||
|
-v ./run.yaml:/root/my-run.yaml \
|
||||||
|
llamastack/distribution-{{ name }} \
|
||||||
|
--yaml-config /root/my-run.yaml \
|
||||||
|
--port $LLAMA_STACK_PORT \
|
||||||
|
--env WATSONX_API_KEY=$WATSONX_API_KEY \
|
||||||
|
--env WATSONX_PROJECT_ID=$WATSONX_PROJECT_ID \
|
||||||
|
--env WATSONX_BASE_URL=$WATSONX_BASE_URL
|
||||||
|
```
|
||||||
|
|
||||||
|
### Via Conda
|
||||||
|
|
||||||
|
```bash
|
||||||
|
llama stack build --template watsonx --image-type conda
|
||||||
|
llama stack run ./run.yaml \
|
||||||
|
--port $LLAMA_STACK_PORT \
|
||||||
|
--env WATSONX_API_KEY=$WATSONX_API_KEY \
|
||||||
|
--env WATSONX_PROJECT_ID=$WATSONX_PROJECT_ID
|
||||||
|
```
|
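Once the stack is up, a quick way to confirm the watsonx provider is wired in is to call one of the default models through the Python client. This is a sketch: the port matches the `LLAMA_STACK_PORT` used above, and the model id is one of the defaults registered by this template.

```python
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:5001")  # LLAMA_STACK_PORT above

response = client.inference.chat_completion(
    model_id="meta-llama/llama-3-3-70b-instruct",  # a default model for this template
    messages=[{"role": "user", "content": "Say hello from watsonx."}],
)
print(response.completion_message.content)
```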
210
llama_stack/templates/watsonx/run.yaml
Normal file
|
@ -0,0 +1,210 @@
|
||||||
|
version: '2'
|
||||||
|
image_name: watsonx
|
||||||
|
apis:
|
||||||
|
- agents
|
||||||
|
- datasetio
|
||||||
|
- eval
|
||||||
|
- inference
|
||||||
|
- safety
|
||||||
|
- scoring
|
||||||
|
- telemetry
|
||||||
|
- tool_runtime
|
||||||
|
- vector_io
|
||||||
|
providers:
|
||||||
|
inference:
|
||||||
|
- provider_id: watsonx
|
||||||
|
provider_type: remote::watsonx
|
||||||
|
config:
|
||||||
|
url: ${env.WATSONX_BASE_URL:https://us-south.ml.cloud.ibm.com}
|
||||||
|
api_key: ${env.WATSONX_API_KEY:}
|
||||||
|
project_id: ${env.WATSONX_PROJECT_ID:}
|
||||||
|
vector_io:
|
||||||
|
- provider_id: faiss
|
||||||
|
provider_type: inline::faiss
|
||||||
|
config:
|
||||||
|
kvstore:
|
||||||
|
type: sqlite
|
||||||
|
namespace: null
|
||||||
|
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/watsonx}/faiss_store.db
|
||||||
|
safety:
|
||||||
|
- provider_id: llama-guard
|
||||||
|
provider_type: inline::llama-guard
|
||||||
|
config:
|
||||||
|
excluded_categories: []
|
||||||
|
agents:
|
||||||
|
- provider_id: meta-reference
|
||||||
|
provider_type: inline::meta-reference
|
||||||
|
config:
|
||||||
|
persistence_store:
|
||||||
|
type: sqlite
|
||||||
|
namespace: null
|
||||||
|
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/watsonx}/agents_store.db
|
||||||
|
telemetry:
|
||||||
|
- provider_id: meta-reference
|
||||||
|
provider_type: inline::meta-reference
|
||||||
|
config:
|
||||||
|
service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
|
||||||
|
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
|
||||||
|
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/watsonx/trace_store.db}
|
||||||
|
eval:
|
||||||
|
- provider_id: meta-reference
|
||||||
|
provider_type: inline::meta-reference
|
||||||
|
config:
|
||||||
|
kvstore:
|
||||||
|
type: sqlite
|
||||||
|
namespace: null
|
||||||
|
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/watsonx}/meta_reference_eval.db
|
||||||
|
datasetio:
|
||||||
|
- provider_id: huggingface
|
||||||
|
provider_type: remote::huggingface
|
||||||
|
config:
|
||||||
|
kvstore:
|
||||||
|
type: sqlite
|
||||||
|
namespace: null
|
||||||
|
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/watsonx}/huggingface_datasetio.db
|
||||||
|
- provider_id: localfs
|
||||||
|
provider_type: inline::localfs
|
||||||
|
config:
|
||||||
|
kvstore:
|
||||||
|
type: sqlite
|
||||||
|
namespace: null
|
||||||
|
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/watsonx}/localfs_datasetio.db
|
||||||
|
scoring:
|
||||||
|
- provider_id: basic
|
||||||
|
provider_type: inline::basic
|
||||||
|
config: {}
|
||||||
|
- provider_id: llm-as-judge
|
||||||
|
provider_type: inline::llm-as-judge
|
||||||
|
config: {}
|
||||||
|
- provider_id: braintrust
|
||||||
|
provider_type: inline::braintrust
|
||||||
|
config:
|
||||||
|
openai_api_key: ${env.OPENAI_API_KEY:}
|
||||||
|
tool_runtime:
|
||||||
|
- provider_id: brave-search
|
||||||
|
provider_type: remote::brave-search
|
||||||
|
config:
|
||||||
|
api_key: ${env.BRAVE_SEARCH_API_KEY:}
|
||||||
|
max_results: 3
|
||||||
|
- provider_id: tavily-search
|
||||||
|
provider_type: remote::tavily-search
|
||||||
|
config:
|
||||||
|
api_key: ${env.TAVILY_SEARCH_API_KEY:}
|
||||||
|
max_results: 3
|
||||||
|
- provider_id: code-interpreter
|
||||||
|
provider_type: inline::code-interpreter
|
||||||
|
config: {}
|
||||||
|
- provider_id: rag-runtime
|
||||||
|
provider_type: inline::rag-runtime
|
||||||
|
config: {}
|
||||||
|
- provider_id: model-context-protocol
|
||||||
|
provider_type: remote::model-context-protocol
|
||||||
|
config: {}
|
||||||
|
metadata_store:
|
||||||
|
type: sqlite
|
||||||
|
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/watsonx}/registry.db
|
||||||
|
models:
|
||||||
|
- metadata: {}
|
||||||
|
model_id: meta-llama/llama-3-3-70b-instruct
|
||||||
|
provider_id: watsonx
|
||||||
|
provider_model_id: meta-llama/llama-3-3-70b-instruct
|
||||||
|
model_type: llm
|
||||||
|
- metadata: {}
|
||||||
|
model_id: meta-llama/Llama-3.3-70B-Instruct
|
||||||
|
provider_id: watsonx
|
||||||
|
provider_model_id: meta-llama/llama-3-3-70b-instruct
|
||||||
|
model_type: llm
|
||||||
|
- metadata: {}
|
||||||
|
model_id: meta-llama/llama-2-13b-chat
|
||||||
|
provider_id: watsonx
|
||||||
|
provider_model_id: meta-llama/llama-2-13b-chat
|
||||||
|
model_type: llm
|
||||||
|
- metadata: {}
|
||||||
|
model_id: meta-llama/Llama-2-13b
|
||||||
|
provider_id: watsonx
|
||||||
|
provider_model_id: meta-llama/llama-2-13b-chat
|
||||||
|
model_type: llm
|
||||||
|
- metadata: {}
|
||||||
|
model_id: meta-llama/llama-3-1-70b-instruct
|
||||||
|
provider_id: watsonx
|
||||||
|
provider_model_id: meta-llama/llama-3-1-70b-instruct
|
||||||
|
model_type: llm
|
||||||
|
- metadata: {}
|
||||||
|
model_id: meta-llama/Llama-3.1-70B-Instruct
|
||||||
|
provider_id: watsonx
|
||||||
|
provider_model_id: meta-llama/llama-3-1-70b-instruct
|
||||||
|
model_type: llm
|
||||||
|
- metadata: {}
|
||||||
|
model_id: meta-llama/llama-3-1-8b-instruct
|
||||||
|
provider_id: watsonx
|
||||||
|
provider_model_id: meta-llama/llama-3-1-8b-instruct
|
||||||
|
model_type: llm
|
||||||
|
- metadata: {}
|
||||||
|
model_id: meta-llama/Llama-3.1-8B-Instruct
|
||||||
|
provider_id: watsonx
|
||||||
|
provider_model_id: meta-llama/llama-3-1-8b-instruct
|
||||||
|
model_type: llm
|
||||||
|
- metadata: {}
|
||||||
|
model_id: meta-llama/llama-3-2-11b-vision-instruct
|
||||||
|
provider_id: watsonx
|
||||||
|
provider_model_id: meta-llama/llama-3-2-11b-vision-instruct
|
||||||
|
model_type: llm
|
||||||
|
- metadata: {}
|
||||||
|
model_id: meta-llama/Llama-3.2-11B-Vision-Instruct
|
||||||
|
provider_id: watsonx
|
||||||
|
provider_model_id: meta-llama/llama-3-2-11b-vision-instruct
|
||||||
|
model_type: llm
|
||||||
|
- metadata: {}
|
||||||
|
model_id: meta-llama/llama-3-2-1b-instruct
|
||||||
|
provider_id: watsonx
|
||||||
|
provider_model_id: meta-llama/llama-3-2-1b-instruct
|
||||||
|
model_type: llm
|
||||||
|
- metadata: {}
|
||||||
|
model_id: meta-llama/Llama-3.2-1B-Instruct
|
||||||
|
provider_id: watsonx
|
||||||
|
provider_model_id: meta-llama/llama-3-2-1b-instruct
|
||||||
|
model_type: llm
|
||||||
|
- metadata: {}
|
||||||
|
model_id: meta-llama/llama-3-2-3b-instruct
|
||||||
|
provider_id: watsonx
|
||||||
|
provider_model_id: meta-llama/llama-3-2-3b-instruct
|
||||||
|
model_type: llm
|
||||||
|
- metadata: {}
|
||||||
|
model_id: meta-llama/Llama-3.2-3B-Instruct
|
||||||
|
provider_id: watsonx
|
||||||
|
provider_model_id: meta-llama/llama-3-2-3b-instruct
|
||||||
|
model_type: llm
|
||||||
|
- metadata: {}
|
||||||
|
model_id: meta-llama/llama-3-2-90b-vision-instruct
|
||||||
|
provider_id: watsonx
|
||||||
|
provider_model_id: meta-llama/llama-3-2-90b-vision-instruct
|
||||||
|
model_type: llm
|
||||||
|
- metadata: {}
|
||||||
|
model_id: meta-llama/Llama-3.2-90B-Vision-Instruct
|
||||||
|
provider_id: watsonx
|
||||||
|
provider_model_id: meta-llama/llama-3-2-90b-vision-instruct
|
||||||
|
model_type: llm
|
||||||
|
- metadata: {}
|
||||||
|
model_id: meta-llama/llama-guard-3-11b-vision
|
||||||
|
provider_id: watsonx
|
||||||
|
provider_model_id: meta-llama/llama-guard-3-11b-vision
|
||||||
|
model_type: llm
|
||||||
|
- metadata: {}
|
||||||
|
model_id: meta-llama/Llama-Guard-3-11B-Vision
|
||||||
|
provider_id: watsonx
|
||||||
|
provider_model_id: meta-llama/llama-guard-3-11b-vision
|
||||||
|
model_type: llm
|
||||||
|
shields: []
|
||||||
|
vector_dbs: []
|
||||||
|
datasets: []
|
||||||
|
scoring_fns: []
|
||||||
|
benchmarks: []
|
||||||
|
tool_groups:
|
||||||
|
- toolgroup_id: builtin::websearch
|
||||||
|
provider_id: tavily-search
|
||||||
|
- toolgroup_id: builtin::rag
|
||||||
|
provider_id: rag-runtime
|
||||||
|
- toolgroup_id: builtin::code_interpreter
|
||||||
|
provider_id: code-interpreter
|
||||||
|
server:
|
||||||
|
port: 8321
|
90
llama_stack/templates/watsonx/watsonx.py
Normal file
|
@ -0,0 +1,90 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from llama_stack.distribution.datatypes import Provider, ToolGroupInput
|
||||||
|
from llama_stack.providers.remote.inference.watsonx import WatsonXConfig
|
||||||
|
from llama_stack.providers.remote.inference.watsonx.models import MODEL_ENTRIES
|
||||||
|
from llama_stack.templates.template import DistributionTemplate, RunConfigSettings, get_model_registry
|
||||||
|
|
||||||
|
|
||||||
|
def get_distribution_template() -> DistributionTemplate:
|
||||||
|
providers = {
|
||||||
|
"inference": ["remote::watsonx"],
|
||||||
|
"vector_io": ["inline::faiss"],
|
||||||
|
"safety": ["inline::llama-guard"],
|
||||||
|
"agents": ["inline::meta-reference"],
|
||||||
|
"telemetry": ["inline::meta-reference"],
|
||||||
|
"eval": ["inline::meta-reference"],
|
||||||
|
"datasetio": ["remote::huggingface", "inline::localfs"],
|
||||||
|
"scoring": ["inline::basic", "inline::llm-as-judge", "inline::braintrust"],
|
||||||
|
"tool_runtime": [
|
||||||
|
"remote::brave-search",
|
||||||
|
"remote::tavily-search",
|
||||||
|
"inline::code-interpreter",
|
||||||
|
"inline::rag-runtime",
|
||||||
|
"remote::model-context-protocol",
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
inference_provider = Provider(
|
||||||
|
provider_id="watsonx",
|
||||||
|
provider_type="remote::watsonx",
|
||||||
|
config=WatsonXConfig.sample_run_config(),
|
||||||
|
)
|
||||||
|
|
||||||
|
available_models = {
|
||||||
|
"watsonx": MODEL_ENTRIES,
|
||||||
|
}
|
||||||
|
default_tool_groups = [
|
||||||
|
ToolGroupInput(
|
||||||
|
toolgroup_id="builtin::websearch",
|
||||||
|
provider_id="tavily-search",
|
||||||
|
),
|
||||||
|
ToolGroupInput(
|
||||||
|
toolgroup_id="builtin::rag",
|
||||||
|
provider_id="rag-runtime",
|
||||||
|
),
|
||||||
|
ToolGroupInput(
|
||||||
|
toolgroup_id="builtin::code_interpreter",
|
||||||
|
provider_id="code-interpreter",
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
default_models = get_model_registry(available_models)
|
||||||
|
return DistributionTemplate(
|
||||||
|
name="watsonx",
|
||||||
|
distro_type="remote_hosted",
|
||||||
|
description="Use watsonx for running LLM inference",
|
||||||
|
container_image=None,
|
||||||
|
template_path=Path(__file__).parent / "doc_template.md",
|
||||||
|
providers=providers,
|
||||||
|
available_models_by_provider=available_models,
|
||||||
|
run_configs={
|
||||||
|
"run.yaml": RunConfigSettings(
|
||||||
|
provider_overrides={
|
||||||
|
"inference": [inference_provider],
|
||||||
|
},
|
||||||
|
default_models=default_models,
|
||||||
|
default_tool_groups=default_tool_groups,
|
||||||
|
),
|
||||||
|
},
|
||||||
|
run_config_env_vars={
|
||||||
|
"LLAMASTACK_PORT": (
|
||||||
|
"5001",
|
||||||
|
"Port for the Llama Stack distribution server",
|
||||||
|
),
|
||||||
|
"WATSONX_API_KEY": (
|
||||||
|
"",
|
||||||
|
"watsonx API Key",
|
||||||
|
),
|
||||||
|
"WATSONX_PROJECT_ID": (
|
||||||
|
"",
|
||||||
|
"watsonx Project ID",
|
||||||
|
),
|
||||||
|
},
|
||||||
|
)
|
|
@ -38,6 +38,7 @@ dependencies = [
|
||||||
"termcolor",
|
"termcolor",
|
||||||
"tiktoken",
|
"tiktoken",
|
||||||
"pillow",
|
"pillow",
|
||||||
|
"h11>=0.16.0",
|
||||||
]
|
]
|
||||||
|
|
||||||
[project.optional-dependencies]
|
[project.optional-dependencies]
|
||||||
|
@ -46,6 +47,7 @@ dev = [
|
||||||
"pytest-asyncio",
|
"pytest-asyncio",
|
||||||
"pytest-cov",
|
"pytest-cov",
|
||||||
"pytest-html",
|
"pytest-html",
|
||||||
|
"pytest-json-report",
|
||||||
"nbval", # For notebook testing
|
"nbval", # For notebook testing
|
||||||
"black",
|
"black",
|
||||||
"ruff",
|
"ruff",
|
||||||
|
@ -57,7 +59,16 @@ dev = [
|
||||||
"ruamel.yaml", # needed for openapi generator
|
"ruamel.yaml", # needed for openapi generator
|
||||||
]
|
]
|
||||||
# These are the dependencies required for running unit tests.
|
# These are the dependencies required for running unit tests.
|
||||||
unit = ["sqlite-vec", "openai", "aiosqlite", "aiohttp", "pypdf", "chardet", "qdrant-client"]
|
unit = [
|
||||||
|
"sqlite-vec",
|
||||||
|
"openai",
|
||||||
|
"aiosqlite",
|
||||||
|
"aiohttp",
|
||||||
|
"pypdf",
|
||||||
|
"chardet",
|
||||||
|
"qdrant-client",
|
||||||
|
"opentelemetry-exporter-otlp-proto-http"
|
||||||
|
]
|
||||||
# These are the core dependencies required for running integration tests. They are shared across all
|
# These are the core dependencies required for running integration tests. They are shared across all
|
||||||
# providers. If a provider requires additional dependencies, please add them to your environment
|
# providers. If a provider requires additional dependencies, please add them to your environment
|
||||||
# separately. If you are using "uv" to execute your tests, you can use the "--with" flag to specify extra
|
# separately. If you are using "uv" to execute your tests, you can use the "--with" flag to specify extra
|
||||||
|
@ -265,6 +276,7 @@ exclude = [
|
||||||
"^llama_stack/providers/remote/inference/sample/",
|
"^llama_stack/providers/remote/inference/sample/",
|
||||||
"^llama_stack/providers/remote/inference/tgi/",
|
"^llama_stack/providers/remote/inference/tgi/",
|
||||||
"^llama_stack/providers/remote/inference/together/",
|
"^llama_stack/providers/remote/inference/together/",
|
||||||
|
"^llama_stack/providers/remote/inference/watsonx/",
|
||||||
"^llama_stack/providers/remote/safety/bedrock/",
|
"^llama_stack/providers/remote/safety/bedrock/",
|
||||||
"^llama_stack/providers/remote/safety/nvidia/",
|
"^llama_stack/providers/remote/safety/nvidia/",
|
||||||
"^llama_stack/providers/remote/safety/sample/",
|
"^llama_stack/providers/remote/safety/sample/",
|
||||||
|
|
|
@ -13,8 +13,8 @@ exceptiongroup==1.2.2 ; python_full_version < '3.11'
|
||||||
filelock==3.17.0
|
filelock==3.17.0
|
||||||
fire==0.7.0
|
fire==0.7.0
|
||||||
fsspec==2024.12.0
|
fsspec==2024.12.0
|
||||||
h11==0.14.0
|
h11==0.16.0
|
||||||
httpcore==1.0.7
|
httpcore==1.0.9
|
||||||
httpx==0.28.1
|
httpx==0.28.1
|
||||||
huggingface-hub==0.29.0
|
huggingface-hub==0.29.0
|
||||||
idna==3.10
|
idna==3.10
|
||||||
|
|
|
@ -98,7 +98,7 @@ def collect_template_dependencies(template_dir: Path) -> tuple[str | None, list[
|
||||||
|
|
||||||
if template_func := getattr(module, "get_distribution_template", None):
|
if template_func := getattr(module, "get_distribution_template", None):
|
||||||
template = template_func()
|
template = template_func()
|
||||||
normal_deps, special_deps = get_provider_dependencies(template.providers)
|
normal_deps, special_deps = get_provider_dependencies(template)
|
||||||
# Combine all dependencies in order: normal deps, special deps, server deps
|
# Combine all dependencies in order: normal deps, special deps, server deps
|
||||||
all_deps = sorted(set(normal_deps + SERVER_DEPENDENCIES)) + sorted(set(special_deps))
|
all_deps = sorted(set(normal_deps + SERVER_DEPENDENCIES)) + sorted(set(special_deps))
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,9 @@
|
||||||
|
version: '2'
|
||||||
|
distribution_spec:
|
||||||
|
description: Custom distro for CI tests
|
||||||
|
providers:
|
||||||
|
inference:
|
||||||
|
- remote::custom_ollama
|
||||||
|
image_type: container
|
||||||
|
image_name: ci-test
|
||||||
|
external_providers_dir: /tmp/providers.d
|
|
@ -1,6 +1,6 @@
|
||||||
adapter:
|
adapter:
|
||||||
adapter_type: custom_ollama
|
adapter_type: custom_ollama
|
||||||
pip_packages: ["ollama", "aiohttp"]
|
pip_packages: ["ollama", "aiohttp", "tests/external-provider/llama-stack-provider-ollama"]
|
||||||
config_class: llama_stack_provider_ollama.config.OllamaImplConfig
|
config_class: llama_stack_provider_ollama.config.OllamaImplConfig
|
||||||
module: llama_stack_provider_ollama
|
module: llama_stack_provider_ollama
|
||||||
api_dependencies: []
|
api_dependencies: []
|
||||||
|
|
|
@ -1,14 +1,10 @@
|
||||||
version: '2'
|
version: '2'
|
||||||
image_name: ollama
|
image_name: ollama
|
||||||
apis:
|
apis:
|
||||||
- agents
|
|
||||||
- datasetio
|
|
||||||
- eval
|
|
||||||
- inference
|
- inference
|
||||||
- safety
|
|
||||||
- scoring
|
|
||||||
- telemetry
|
- telemetry
|
||||||
- tool_runtime
|
- tool_runtime
|
||||||
|
- datasetio
|
||||||
- vector_io
|
- vector_io
|
||||||
providers:
|
providers:
|
||||||
inference:
|
inference:
|
||||||
|
@ -24,19 +20,6 @@ providers:
|
||||||
type: sqlite
|
type: sqlite
|
||||||
namespace: null
|
namespace: null
|
||||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/faiss_store.db
|
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/faiss_store.db
|
||||||
safety:
|
|
||||||
- provider_id: llama-guard
|
|
||||||
provider_type: inline::llama-guard
|
|
||||||
config:
|
|
||||||
excluded_categories: []
|
|
||||||
agents:
|
|
||||||
- provider_id: meta-reference
|
|
||||||
provider_type: inline::meta-reference
|
|
||||||
config:
|
|
||||||
persistence_store:
|
|
||||||
type: sqlite
|
|
||||||
namespace: null
|
|
||||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/agents_store.db
|
|
||||||
telemetry:
|
telemetry:
|
||||||
- provider_id: meta-reference
|
- provider_id: meta-reference
|
||||||
provider_type: inline::meta-reference
|
provider_type: inline::meta-reference
|
||||||
|
@ -44,14 +27,6 @@ providers:
|
||||||
service_name: ${env.OTEL_SERVICE_NAME:llama-stack}
|
service_name: ${env.OTEL_SERVICE_NAME:llama-stack}
|
||||||
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
|
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
|
||||||
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/ollama/trace_store.db}
|
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/ollama/trace_store.db}
|
||||||
eval:
|
|
||||||
- provider_id: meta-reference
|
|
||||||
provider_type: inline::meta-reference
|
|
||||||
config:
|
|
||||||
kvstore:
|
|
||||||
type: sqlite
|
|
||||||
namespace: null
|
|
||||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/meta_reference_eval.db
|
|
||||||
datasetio:
|
datasetio:
|
||||||
- provider_id: huggingface
|
- provider_id: huggingface
|
||||||
provider_type: remote::huggingface
|
provider_type: remote::huggingface
|
||||||
|
@ -67,17 +42,6 @@ providers:
|
||||||
type: sqlite
|
type: sqlite
|
||||||
namespace: null
|
namespace: null
|
||||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/localfs_datasetio.db
|
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/ollama}/localfs_datasetio.db
|
||||||
scoring:
|
|
||||||
- provider_id: basic
|
|
||||||
provider_type: inline::basic
|
|
||||||
config: {}
|
|
||||||
- provider_id: llm-as-judge
|
|
||||||
provider_type: inline::llm-as-judge
|
|
||||||
config: {}
|
|
||||||
- provider_id: braintrust
|
|
||||||
provider_type: inline::braintrust
|
|
||||||
config:
|
|
||||||
openai_api_key: ${env.OPENAI_API_KEY:}
|
|
||||||
tool_runtime:
|
tool_runtime:
|
||||||
- provider_id: brave-search
|
- provider_id: brave-search
|
||||||
provider_type: remote::brave-search
|
provider_type: remote::brave-search
|
||||||
|
|
|
@ -115,6 +115,70 @@ def test_agent_simple(llama_stack_client_with_mocked_inference, agent_config):
|
||||||
assert "I can't" in logs_str
|
assert "I can't" in logs_str
|
||||||
|
|
||||||
|
|
||||||
|
def test_agent_name(llama_stack_client, text_model_id):
|
||||||
|
agent_name = f"test-agent-{uuid4()}"
|
||||||
|
|
||||||
|
try:
|
||||||
|
agent = Agent(
|
||||||
|
llama_stack_client,
|
||||||
|
model=text_model_id,
|
||||||
|
instructions="You are a helpful assistant",
|
||||||
|
name=agent_name,
|
||||||
|
)
|
||||||
|
except TypeError:
|
||||||
|
agent = Agent(
|
||||||
|
llama_stack_client,
|
||||||
|
model=text_model_id,
|
||||||
|
instructions="You are a helpful assistant",
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
session_id = agent.create_session(f"test-session-{uuid4()}")
|
||||||
|
|
||||||
|
agent.create_turn(
|
||||||
|
messages=[
|
||||||
|
{
|
||||||
|
"role": "user",
|
||||||
|
"content": "Give me a sentence that contains the word: hello",
|
||||||
|
}
|
||||||
|
],
|
||||||
|
session_id=session_id,
|
||||||
|
stream=False,
|
||||||
|
)
|
||||||
|
|
||||||
|
all_spans = []
|
||||||
|
for span in llama_stack_client.telemetry.query_spans(
|
||||||
|
attribute_filters=[
|
||||||
|
{"key": "session_id", "op": "eq", "value": session_id},
|
||||||
|
],
|
||||||
|
attributes_to_return=["input", "output", "agent_name", "agent_id", "session_id"],
|
||||||
|
):
|
||||||
|
all_spans.append(span.attributes)
|
||||||
|
|
||||||
|
agent_name_spans = []
|
||||||
|
for span in llama_stack_client.telemetry.query_spans(
|
||||||
|
attribute_filters=[],
|
||||||
|
attributes_to_return=["agent_name"],
|
||||||
|
):
|
||||||
|
if "agent_name" in span.attributes:
|
||||||
|
agent_name_spans.append(span.attributes)
|
||||||
|
|
||||||
|
agent_logs = []
|
||||||
|
for span in llama_stack_client.telemetry.query_spans(
|
||||||
|
attribute_filters=[
|
||||||
|
{"key": "agent_name", "op": "eq", "value": agent_name},
|
||||||
|
],
|
||||||
|
attributes_to_return=["input", "output", "agent_name"],
|
||||||
|
):
|
||||||
|
if "output" in span.attributes and span.attributes["output"] != "no shields":
|
||||||
|
agent_logs.append(span.attributes)
|
||||||
|
|
||||||
|
assert len(agent_logs) == 1
|
||||||
|
assert agent_logs[0]["agent_name"] == agent_name
|
||||||
|
assert "Give me a sentence that contains the word: hello" in agent_logs[0]["input"]
|
||||||
|
assert "hello" in agent_logs[0]["output"].lower()
|
||||||
|
|
||||||
|
|
||||||
def test_tool_config(llama_stack_client_with_mocked_inference, agent_config):
|
def test_tool_config(llama_stack_client_with_mocked_inference, agent_config):
|
||||||
common_params = dict(
|
common_params = dict(
|
||||||
model="meta-llama/Llama-3.2-3B-Instruct",
|
model="meta-llama/Llama-3.2-3B-Instruct",
|
||||||
|
|
|
@ -10,6 +10,7 @@ import platform
|
||||||
import textwrap
|
import textwrap
|
||||||
import time
|
import time
|
||||||
|
|
||||||
|
import pytest
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
from llama_stack.log import get_logger
|
from llama_stack.log import get_logger
|
||||||
|
@ -19,10 +20,29 @@ from .report import Report
|
||||||
logger = get_logger(__name__, category="tests")
|
logger = get_logger(__name__, category="tests")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.hookimpl(hookwrapper=True)
|
||||||
|
def pytest_runtest_makereport(item, call):
|
||||||
|
outcome = yield
|
||||||
|
report = outcome.get_result()
|
||||||
|
if report.when == "call":
|
||||||
|
item.execution_outcome = report.outcome
|
||||||
|
item.was_xfail = getattr(report, "wasxfail", False)
|
||||||
|
|
||||||
|
|
||||||
def pytest_runtest_teardown(item):
|
def pytest_runtest_teardown(item):
|
||||||
interval_seconds = os.getenv("LLAMA_STACK_TEST_INTERVAL_SECONDS")
|
# Check if the test actually ran and passed or failed, but was not skipped or an expected failure (xfail)
|
||||||
if interval_seconds:
|
outcome = getattr(item, "execution_outcome", None)
|
||||||
time.sleep(float(interval_seconds))
|
was_xfail = getattr(item, "was_xfail", False)
|
||||||
|
|
||||||
|
name = item.nodeid
|
||||||
|
if not any(x in name for x in ("inference/", "safety/", "agents/")):
|
||||||
|
return
|
||||||
|
|
||||||
|
logger.debug(f"Test '{item.nodeid}' outcome was '{outcome}' (xfail={was_xfail})")
|
||||||
|
if outcome in ("passed", "failed") and not was_xfail:
|
||||||
|
interval_seconds = os.getenv("LLAMA_STACK_TEST_INTERVAL_SECONDS")
|
||||||
|
if interval_seconds:
|
||||||
|
time.sleep(float(interval_seconds))
|
||||||
|
|
||||||
|
|
||||||
def pytest_configure(config):
|
def pytest_configure(config):
|
||||||
|
|
|
@ -31,6 +31,7 @@ def data_url_from_file(file_path: str) -> str:
|
||||||
return data_url
|
return data_url
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.skip(reason="flaky. Couldn't find 'llamastack/simpleqa' on the Hugging Face Hub")
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"purpose, source, provider_id, limit",
|
"purpose, source, provider_id, limit",
|
||||||
[
|
[
|
||||||
|
|
|
@ -75,19 +75,24 @@ def openai_client(client_with_models):
|
||||||
return OpenAI(base_url=base_url, api_key="bar")
|
return OpenAI(base_url=base_url, api_key="bar")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(params=["openai_client", "llama_stack_client"])
|
||||||
|
def compat_client(request):
|
||||||
|
return request.getfixturevalue(request.param)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"test_case",
|
"test_case",
|
||||||
[
|
[
|
||||||
"inference:completion:sanity",
|
"inference:completion:sanity",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
def test_openai_completion_non_streaming(openai_client, client_with_models, text_model_id, test_case):
|
def test_openai_completion_non_streaming(llama_stack_client, client_with_models, text_model_id, test_case):
|
||||||
skip_if_model_doesnt_support_openai_completion(client_with_models, text_model_id)
|
skip_if_model_doesnt_support_openai_completion(client_with_models, text_model_id)
|
||||||
tc = TestCase(test_case)
|
tc = TestCase(test_case)
|
||||||
|
|
||||||
# ollama needs more verbose prompting for some reason here...
|
# ollama needs more verbose prompting for some reason here...
|
||||||
prompt = "Respond to this question and explain your answer. " + tc["content"]
|
prompt = "Respond to this question and explain your answer. " + tc["content"]
|
||||||
response = openai_client.completions.create(
|
response = llama_stack_client.completions.create(
|
||||||
model=text_model_id,
|
model=text_model_id,
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
stream=False,
|
stream=False,
|
||||||
|
@ -103,13 +108,13 @@ def test_openai_completion_non_streaming(openai_client, client_with_models, text
|
||||||
"inference:completion:sanity",
|
"inference:completion:sanity",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
def test_openai_completion_streaming(openai_client, client_with_models, text_model_id, test_case):
|
def test_openai_completion_streaming(llama_stack_client, client_with_models, text_model_id, test_case):
|
||||||
skip_if_model_doesnt_support_openai_completion(client_with_models, text_model_id)
|
skip_if_model_doesnt_support_openai_completion(client_with_models, text_model_id)
|
||||||
tc = TestCase(test_case)
|
tc = TestCase(test_case)
|
||||||
|
|
||||||
# ollama needs more verbose prompting for some reason here...
|
# ollama needs more verbose prompting for some reason here...
|
||||||
prompt = "Respond to this question and explain your answer. " + tc["content"]
|
prompt = "Respond to this question and explain your answer. " + tc["content"]
|
||||||
response = openai_client.completions.create(
|
response = llama_stack_client.completions.create(
|
||||||
model=text_model_id,
|
model=text_model_id,
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
stream=True,
|
stream=True,
|
||||||
|
@ -127,11 +132,11 @@ def test_openai_completion_streaming(openai_client, client_with_models, text_mod
|
||||||
0,
|
0,
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
def test_openai_completion_prompt_logprobs(openai_client, client_with_models, text_model_id, prompt_logprobs):
|
def test_openai_completion_prompt_logprobs(llama_stack_client, client_with_models, text_model_id, prompt_logprobs):
|
||||||
skip_if_provider_isnt_vllm(client_with_models, text_model_id)
|
skip_if_provider_isnt_vllm(client_with_models, text_model_id)
|
||||||
|
|
||||||
prompt = "Hello, world!"
|
prompt = "Hello, world!"
|
||||||
response = openai_client.completions.create(
|
response = llama_stack_client.completions.create(
|
||||||
model=text_model_id,
|
model=text_model_id,
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
stream=False,
|
stream=False,
|
||||||
|
@ -144,11 +149,11 @@ def test_openai_completion_prompt_logprobs(openai_client, client_with_models, te
|
||||||
assert len(choice.prompt_logprobs) > 0
|
assert len(choice.prompt_logprobs) > 0
|
||||||
|
|
||||||
|
|
||||||
def test_openai_completion_guided_choice(openai_client, client_with_models, text_model_id):
|
def test_openai_completion_guided_choice(llama_stack_client, client_with_models, text_model_id):
|
||||||
skip_if_provider_isnt_vllm(client_with_models, text_model_id)
|
skip_if_provider_isnt_vllm(client_with_models, text_model_id)
|
||||||
|
|
||||||
prompt = "I am feeling really sad today."
|
prompt = "I am feeling really sad today."
|
||||||
response = openai_client.completions.create(
|
response = llama_stack_client.completions.create(
|
||||||
model=text_model_id,
|
model=text_model_id,
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
stream=False,
|
stream=False,
|
||||||
|
@ -161,6 +166,9 @@ def test_openai_completion_guided_choice(openai_client, client_with_models, text
|
||||||
assert choice.text in ["joy", "sadness"]
|
assert choice.text in ["joy", "sadness"]
|
||||||
|
|
||||||
|
|
||||||
|
# Run the chat-completion tests with both the OpenAI client and the LlamaStack client
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"test_case",
|
"test_case",
|
||||||
[
|
[
|
||||||
|
@ -168,13 +176,13 @@ def test_openai_completion_guided_choice(openai_client, client_with_models, text
|
||||||
"inference:chat_completion:non_streaming_02",
|
"inference:chat_completion:non_streaming_02",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
def test_openai_chat_completion_non_streaming(openai_client, client_with_models, text_model_id, test_case):
|
def test_openai_chat_completion_non_streaming(compat_client, client_with_models, text_model_id, test_case):
|
||||||
skip_if_model_doesnt_support_openai_chat_completion(client_with_models, text_model_id)
|
skip_if_model_doesnt_support_openai_chat_completion(client_with_models, text_model_id)
|
||||||
tc = TestCase(test_case)
|
tc = TestCase(test_case)
|
||||||
question = tc["question"]
|
question = tc["question"]
|
||||||
expected = tc["expected"]
|
expected = tc["expected"]
|
||||||
|
|
||||||
response = openai_client.chat.completions.create(
|
response = compat_client.chat.completions.create(
|
||||||
model=text_model_id,
|
model=text_model_id,
|
||||||
messages=[
|
messages=[
|
||||||
{
|
{
|
||||||
|
@ -196,13 +204,13 @@ def test_openai_chat_completion_non_streaming(openai_client, client_with_models,
|
||||||
"inference:chat_completion:streaming_02",
|
"inference:chat_completion:streaming_02",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
def test_openai_chat_completion_streaming(openai_client, client_with_models, text_model_id, test_case):
|
def test_openai_chat_completion_streaming(compat_client, client_with_models, text_model_id, test_case):
|
||||||
skip_if_model_doesnt_support_openai_chat_completion(client_with_models, text_model_id)
|
skip_if_model_doesnt_support_openai_chat_completion(client_with_models, text_model_id)
|
||||||
tc = TestCase(test_case)
|
tc = TestCase(test_case)
|
||||||
question = tc["question"]
|
question = tc["question"]
|
||||||
expected = tc["expected"]
|
expected = tc["expected"]
|
||||||
|
|
||||||
response = openai_client.chat.completions.create(
|
response = compat_client.chat.completions.create(
|
||||||
model=text_model_id,
|
model=text_model_id,
|
||||||
messages=[{"role": "user", "content": question}],
|
messages=[{"role": "user", "content": question}],
|
||||||
stream=True,
|
stream=True,
|
||||||
|
|
|
@ -114,7 +114,7 @@ def test_register_and_unregister_toolgroup(llama_stack_client, mcp_server):
|
||||||
llama_stack_client.toolgroups.unregister(toolgroup_id=test_toolgroup_id)
|
llama_stack_client.toolgroups.unregister(toolgroup_id=test_toolgroup_id)
|
||||||
|
|
||||||
# Verify it is unregistered
|
# Verify it is unregistered
|
||||||
with pytest.raises(ValueError, match=f"Tool group '{test_toolgroup_id}' not found"):
|
with pytest.raises(Exception, match=f"Tool group '{test_toolgroup_id}' not found"):
|
||||||
llama_stack_client.toolgroups.get(toolgroup_id=test_toolgroup_id)
|
llama_stack_client.toolgroups.get(toolgroup_id=test_toolgroup_id)
|
||||||
|
|
||||||
# Verify tools are also unregistered
|
# Verify tools are also unregistered
|
||||||
|
|
40
tests/unit/distribution/test_build_path.py
Normal file
|
@ -0,0 +1,40 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from llama_stack.cli.stack._build import (
|
||||||
|
_run_stack_build_command_from_build_config,
|
||||||
|
)
|
||||||
|
from llama_stack.distribution.datatypes import BuildConfig, DistributionSpec
|
||||||
|
from llama_stack.distribution.utils.image_types import LlamaStackImageType
|
||||||
|
|
||||||
|
|
||||||
|
def test_container_build_passes_path(monkeypatch, tmp_path):
|
||||||
|
called_with = {}
|
||||||
|
|
||||||
|
def spy_build_image(cfg, build_file_path, image_name, template_or_config, run_config=None):
|
||||||
|
called_with["path"] = template_or_config
|
||||||
|
called_with["run_config"] = run_config
|
||||||
|
return 0
|
||||||
|
|
||||||
|
monkeypatch.setattr(
|
||||||
|
"llama_stack.cli.stack._build.build_image",
|
||||||
|
spy_build_image,
|
||||||
|
raising=True,
|
||||||
|
)
|
||||||
|
|
||||||
|
cfg = BuildConfig(
|
||||||
|
image_type=LlamaStackImageType.CONTAINER.value,
|
||||||
|
distribution_spec=DistributionSpec(providers={}, description=""),
|
||||||
|
)
|
||||||
|
|
||||||
|
_run_stack_build_command_from_build_config(cfg, image_name="dummy")
|
||||||
|
|
||||||
|
assert "path" in called_with
|
||||||
|
assert isinstance(called_with["path"], str)
|
||||||
|
assert Path(called_with["path"]).exists()
|
||||||
|
assert called_with["run_config"] is None
|
|
@ -26,9 +26,17 @@ from openai.types.chat.chat_completion_chunk import (
|
||||||
)
|
)
|
||||||
from openai.types.model import Model as OpenAIModel
|
from openai.types.model import Model as OpenAIModel
|
||||||
|
|
||||||
from llama_stack.apis.inference import ToolChoice, ToolConfig
|
from llama_stack.apis.inference import (
|
||||||
|
ChatCompletionRequest,
|
||||||
|
CompletionMessage,
|
||||||
|
SystemMessage,
|
||||||
|
ToolChoice,
|
||||||
|
ToolConfig,
|
||||||
|
ToolResponseMessage,
|
||||||
|
UserMessage,
|
||||||
|
)
|
||||||
from llama_stack.apis.models import Model
|
from llama_stack.apis.models import Model
|
||||||
from llama_stack.models.llama.datatypes import StopReason
|
from llama_stack.models.llama.datatypes import StopReason, ToolCall
|
||||||
from llama_stack.providers.remote.inference.vllm.config import VLLMInferenceAdapterConfig
|
from llama_stack.providers.remote.inference.vllm.config import VLLMInferenceAdapterConfig
|
||||||
from llama_stack.providers.remote.inference.vllm.vllm import (
|
from llama_stack.providers.remote.inference.vllm.vllm import (
|
||||||
VLLMInferenceAdapter,
|
VLLMInferenceAdapter,
|
||||||
|
@ -130,6 +138,49 @@ async def test_old_vllm_tool_choice(vllm_inference_adapter):
|
||||||
assert request.tool_config.tool_choice == ToolChoice.none
|
assert request.tool_config.tool_choice == ToolChoice.none
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_tool_call_response(vllm_inference_adapter):
|
||||||
|
"""Verify that tool call arguments from a CompletionMessage are correctly converted
|
||||||
|
into the expected JSON format."""
|
||||||
|
|
||||||
|
# Patch the call to vllm so we can inspect the arguments sent were correct
|
||||||
|
with patch.object(
|
||||||
|
vllm_inference_adapter.client.chat.completions, "create", new_callable=AsyncMock
|
||||||
|
) as mock_nonstream_completion:
|
||||||
|
messages = [
|
||||||
|
SystemMessage(content="You are a helpful assistant"),
|
||||||
|
UserMessage(content="How many?"),
|
||||||
|
CompletionMessage(
|
||||||
|
content="",
|
||||||
|
stop_reason=StopReason.end_of_turn,
|
||||||
|
tool_calls=[
|
||||||
|
ToolCall(
|
||||||
|
call_id="foo",
|
||||||
|
tool_name="knowledge_search",
|
||||||
|
arguments={"query": "How many?"},
|
||||||
|
arguments_json='{"query": "How many?"}',
|
||||||
|
)
|
||||||
|
],
|
||||||
|
),
|
||||||
|
ToolResponseMessage(call_id="foo", content="knowledge_search found 5...."),
|
||||||
|
]
|
||||||
|
await vllm_inference_adapter.chat_completion(
|
||||||
|
"mock-model",
|
||||||
|
messages,
|
||||||
|
stream=False,
|
||||||
|
tools=[],
|
||||||
|
tool_config=ToolConfig(tool_choice=ToolChoice.auto),
|
||||||
|
)
|
||||||
|
|
||||||
|
assert mock_nonstream_completion.call_args.kwargs["messages"][2]["tool_calls"] == [
|
||||||
|
{
|
||||||
|
"id": "foo",
|
||||||
|
"type": "function",
|
||||||
|
"function": {"name": "knowledge_search", "arguments": '{"query": "How many?"}'},
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_tool_call_delta_empty_tool_call_buf():
|
async def test_tool_call_delta_empty_tool_call_buf():
|
||||||
"""
|
"""
|
||||||
|
@ -232,3 +283,14 @@ def test_chat_completion_doesnt_block_event_loop(caplog):
|
||||||
# above.
|
# above.
|
||||||
asyncio_warnings = [record.message for record in caplog.records if record.name == "asyncio"]
|
asyncio_warnings = [record.message for record in caplog.records if record.name == "asyncio"]
|
||||||
assert not asyncio_warnings
|
assert not asyncio_warnings
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_get_params_empty_tools(vllm_inference_adapter):
|
||||||
|
request = ChatCompletionRequest(
|
||||||
|
tools=[],
|
||||||
|
model="test_model",
|
||||||
|
messages=[UserMessage(content="test")],
|
||||||
|
)
|
||||||
|
params = await vllm_inference_adapter._get_params(request)
|
||||||
|
assert "tools" not in params
|
||||||
|
|
|
@@ -13,6 +13,7 @@ import pytest
 from llama_stack.apis.benchmarks import Benchmark
 from llama_stack.apis.common.job_types import Job, JobStatus
 from llama_stack.apis.eval.eval import BenchmarkConfig, EvaluateResponse, ModelCandidate, SamplingParams
+from llama_stack.models.llama.sku_types import CoreModelId
 from llama_stack.providers.remote.eval.nvidia.config import NVIDIAEvalConfig
 from llama_stack.providers.remote.eval.nvidia.eval import NVIDIAEvalImpl
 
@@ -32,7 +33,7 @@ class TestNVIDIAEvalImpl(unittest.TestCase):
         self.agents_api = MagicMock()
 
         self.config = NVIDIAEvalConfig(
-            evaluator_service_url=os.environ["NVIDIA_EVALUATOR_URL"],
+            evaluator_url=os.environ["NVIDIA_EVALUATOR_URL"],
         )
 
         self.eval_impl = NVIDIAEvalImpl(
@@ -118,7 +119,7 @@ class TestNVIDIAEvalImpl(unittest.TestCase):
         benchmark_config = BenchmarkConfig(
             eval_candidate=ModelCandidate(
                 type="model",
-                model="meta/llama-3.1-8b-instruct",
+                model=CoreModelId.llama3_1_8b_instruct.value,
                 sampling_params=SamplingParams(max_tokens=100, temperature=0.7),
             )
         )
@@ -137,7 +138,7 @@ class TestNVIDIAEvalImpl(unittest.TestCase):
         self._assert_request_body(
             {
                 "config": f"nvidia/{MOCK_BENCHMARK_ID}",
-                "target": {"type": "model", "model": benchmark_config.eval_candidate.model},
+                "target": {"type": "model", "model": "meta/llama-3.1-8b-instruct"},
             }
         )
 
@@ -10,14 +10,17 @@ import warnings
 from unittest.mock import patch
 
 import pytest
-from llama_stack_client.types.algorithm_config_param import LoraFinetuningConfig
-from llama_stack_client.types.post_training_supervised_fine_tune_params import (
-    TrainingConfig,
-    TrainingConfigDataConfig,
-    TrainingConfigEfficiencyConfig,
-    TrainingConfigOptimizerConfig,
-)
 
+from llama_stack.apis.post_training.post_training import (
+    DataConfig,
+    DatasetFormat,
+    EfficiencyConfig,
+    LoraFinetuningConfig,
+    OptimizerConfig,
+    OptimizerType,
+    TrainingConfig,
+)
+from llama_stack.distribution.library_client import convert_pydantic_to_json_value
 from llama_stack.providers.remote.post_training.nvidia.post_training import (
     NvidiaPostTrainingAdapter,
     NvidiaPostTrainingConfig,
@@ -66,11 +69,8 @@ class TestNvidiaParameters(unittest.TestCase):
 
     def test_customizer_parameters_passed(self):
         """Test scenario 1: When an optional parameter is passed and value is correctly set."""
-        custom_adapter_dim = 32  # Different from default of 8
         algorithm_config = LoraFinetuningConfig(
             type="LoRA",
-            adapter_dim=custom_adapter_dim,
-            adapter_dropout=0.2,
             apply_lora_to_mlp=True,
             apply_lora_to_output=True,
             alpha=16,
@@ -78,8 +78,15 @@
             lora_attn_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
         )
 
-        data_config = TrainingConfigDataConfig(dataset_id="test-dataset", batch_size=16)
-        optimizer_config = TrainingConfigOptimizerConfig(lr=0.0002)
+        data_config = DataConfig(
+            dataset_id="test-dataset", batch_size=16, shuffle=False, data_format=DatasetFormat.instruct
+        )
+        optimizer_config = OptimizerConfig(
+            optimizer_type=OptimizerType.adam,
+            lr=0.0002,
+            weight_decay=0.01,
+            num_warmup_steps=100,
+        )
         training_config = TrainingConfig(
             n_epochs=3,
             data_config=data_config,
@@ -95,7 +102,7 @@ class TestNvidiaParameters(unittest.TestCase):
             model="meta-llama/Llama-3.1-8B-Instruct",
             checkpoint_dir="",
             algorithm_config=algorithm_config,
-            training_config=training_config,
+            training_config=convert_pydantic_to_json_value(training_config),
             logger_config={},
             hyperparam_search_config={},
         )
@@ -114,7 +121,7 @@ class TestNvidiaParameters(unittest.TestCase):
         self._assert_request_params(
             {
                 "hyperparameters": {
-                    "lora": {"adapter_dim": custom_adapter_dim, "adapter_dropout": 0.2, "alpha": 16},
+                    "lora": {"alpha": 16},
                     "epochs": 3,
                     "learning_rate": 0.0002,
                     "batch_size": 16,
@@ -130,8 +137,6 @@ class TestNvidiaParameters(unittest.TestCase):
 
         algorithm_config = LoraFinetuningConfig(
             type="LoRA",
-            adapter_dim=16,
-            adapter_dropout=0.1,
             apply_lora_to_mlp=True,
             apply_lora_to_output=True,
             alpha=16,
@@ -139,12 +144,16 @@ class TestNvidiaParameters(unittest.TestCase):
             lora_attn_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
         )
 
-        data_config = TrainingConfigDataConfig(
-            dataset_id=required_dataset_id,  # Required parameter
-            batch_size=8,
+        data_config = DataConfig(
+            dataset_id=required_dataset_id, batch_size=8, shuffle=False, data_format=DatasetFormat.instruct
         )
 
-        optimizer_config = TrainingConfigOptimizerConfig(lr=0.0001)
+        optimizer_config = OptimizerConfig(
+            optimizer_type=OptimizerType.adam,
+            lr=0.0001,
+            weight_decay=0.01,
+            num_warmup_steps=100,
+        )
 
         training_config = TrainingConfig(
             n_epochs=1,
@@ -161,7 +170,7 @@ class TestNvidiaParameters(unittest.TestCase):
             model=required_model,  # Required parameter
             checkpoint_dir="",
             algorithm_config=algorithm_config,
-            training_config=training_config,
+            training_config=convert_pydantic_to_json_value(training_config),
             logger_config={},
             hyperparam_search_config={},
         )
@@ -186,24 +195,24 @@ class TestNvidiaParameters(unittest.TestCase):
 
     def test_unsupported_parameters_warning(self):
         """Test that warnings are raised for unsupported parameters."""
-        data_config = TrainingConfigDataConfig(
+        data_config = DataConfig(
             dataset_id="test-dataset",
             batch_size=8,
             # Unsupported parameters
             shuffle=True,
-            data_format="instruct",
+            data_format=DatasetFormat.instruct,
             validation_dataset_id="val-dataset",
         )
 
-        optimizer_config = TrainingConfigOptimizerConfig(
+        optimizer_config = OptimizerConfig(
             lr=0.0001,
             weight_decay=0.01,
             # Unsupported parameters
-            optimizer_type="adam",
+            optimizer_type=OptimizerType.adam,
             num_warmup_steps=100,
         )
 
-        efficiency_config = TrainingConfigEfficiencyConfig(
+        efficiency_config = EfficiencyConfig(
             enable_activation_checkpointing=True  # Unsupported parameter
         )
 
@@ -230,15 +239,13 @@ class TestNvidiaParameters(unittest.TestCase):
             checkpoint_dir="test-dir",  # Unsupported parameter
             algorithm_config=LoraFinetuningConfig(
                 type="LoRA",
-                adapter_dim=16,
-                adapter_dropout=0.1,
                 apply_lora_to_mlp=True,
                 apply_lora_to_output=True,
                 alpha=16,
                 rank=16,
                 lora_attn_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
             ),
-            training_config=training_config,
+            training_config=convert_pydantic_to_json_value(training_config),
             logger_config={"test": "value"},  # Unsupported parameter
             hyperparam_search_config={"test": "value"},  # Unsupported parameter
         )
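Taken together, these tests now build their configs from the llama_stack.apis.post_training.post_training Pydantic types and serialize them before they reach the adapter. A minimal sketch of that serialization step, assumed usage that reuses only names already imported in the diff above:

from llama_stack.apis.post_training.post_training import DataConfig, DatasetFormat
from llama_stack.distribution.library_client import convert_pydantic_to_json_value

# Same construction as in the tests above.
data_config = DataConfig(
    dataset_id="test-dataset", batch_size=16, shuffle=False, data_format=DatasetFormat.instruct
)

# Presumably yields a plain JSON-serializable value (e.g. a dict) in place of the Pydantic model;
# the tests apply the same helper to training_config before calling supervised_fine_tune.
print(convert_pydantic_to_json_value(data_config))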
@@ -10,14 +10,18 @@ import warnings
 from unittest.mock import patch
 
 import pytest
-from llama_stack_client.types.algorithm_config_param import LoraFinetuningConfig, QatFinetuningConfig
-from llama_stack_client.types.post_training_supervised_fine_tune_params import (
-    TrainingConfig,
-    TrainingConfigDataConfig,
-    TrainingConfigOptimizerConfig,
-)
 
 from llama_stack.apis.models import Model, ModelType
+from llama_stack.apis.post_training.post_training import (
+    DataConfig,
+    DatasetFormat,
+    LoraFinetuningConfig,
+    OptimizerConfig,
+    OptimizerType,
+    QATFinetuningConfig,
+    TrainingConfig,
+)
+from llama_stack.distribution.library_client import convert_pydantic_to_json_value
 from llama_stack.providers.remote.inference.nvidia.nvidia import NVIDIAConfig, NVIDIAInferenceAdapter
 from llama_stack.providers.remote.post_training.nvidia.post_training import (
     ListNvidiaPostTrainingJobs,
@@ -121,7 +125,7 @@ class TestNvidiaPostTraining(unittest.TestCase):
                 "batch_size": 16,
                 "epochs": 2,
                 "learning_rate": 0.0001,
-                "lora": {"adapter_dim": 16, "adapter_dropout": 0.1},
+                "lora": {"alpha": 16},
            },
             "output_model": "default/job-1234",
             "status": "created",
@@ -132,8 +136,6 @@ class TestNvidiaPostTraining(unittest.TestCase):
 
         algorithm_config = LoraFinetuningConfig(
             type="LoRA",
-            adapter_dim=16,
-            adapter_dropout=0.1,
             apply_lora_to_mlp=True,
             apply_lora_to_output=True,
             alpha=16,
@@ -141,10 +143,15 @@ class TestNvidiaPostTraining(unittest.TestCase):
             lora_attn_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
         )
 
-        data_config = TrainingConfigDataConfig(dataset_id="sample-basic-test", batch_size=16)
+        data_config = DataConfig(
+            dataset_id="sample-basic-test", batch_size=16, shuffle=False, data_format=DatasetFormat.instruct
+        )
 
-        optimizer_config = TrainingConfigOptimizerConfig(
+        optimizer_config = OptimizerConfig(
+            optimizer_type=OptimizerType.adam,
             lr=0.0001,
+            weight_decay=0.01,
+            num_warmup_steps=100,
         )
 
         training_config = TrainingConfig(
@@ -161,7 +168,7 @@ class TestNvidiaPostTraining(unittest.TestCase):
                 model="meta-llama/Llama-3.1-8B-Instruct",
                 checkpoint_dir="",
                 algorithm_config=algorithm_config,
-                training_config=training_config,
+                training_config=convert_pydantic_to_json_value(training_config),
                 logger_config={},
                 hyperparam_search_config={},
             )
@@ -185,16 +192,22 @@ class TestNvidiaPostTraining(unittest.TestCase):
                     "epochs": 2,
                     "batch_size": 16,
                     "learning_rate": 0.0001,
-                    "lora": {"alpha": 16, "adapter_dim": 16, "adapter_dropout": 0.1},
+                    "weight_decay": 0.01,
+                    "lora": {"alpha": 16},
                 },
             },
         )
 
     def test_supervised_fine_tune_with_qat(self):
-        algorithm_config = QatFinetuningConfig(type="QAT", quantizer_name="quantizer_name", group_size=1)
-        data_config = TrainingConfigDataConfig(dataset_id="sample-basic-test", batch_size=16)
-        optimizer_config = TrainingConfigOptimizerConfig(
+        algorithm_config = QATFinetuningConfig(type="QAT", quantizer_name="quantizer_name", group_size=1)
+        data_config = DataConfig(
+            dataset_id="sample-basic-test", batch_size=16, shuffle=False, data_format=DatasetFormat.instruct
+        )
+        optimizer_config = OptimizerConfig(
+            optimizer_type=OptimizerType.adam,
             lr=0.0001,
+            weight_decay=0.01,
+            num_warmup_steps=100,
         )
         training_config = TrainingConfig(
             n_epochs=2,
@@ -209,42 +222,55 @@ class TestNvidiaPostTraining(unittest.TestCase):
                 model="meta-llama/Llama-3.1-8B-Instruct",
                 checkpoint_dir="",
                 algorithm_config=algorithm_config,
-                training_config=training_config,
+                training_config=convert_pydantic_to_json_value(training_config),
                 logger_config={},
                 hyperparam_search_config={},
             )
         )
 
     def test_get_training_job_status(self):
-        self.mock_make_request.return_value = {
-            "created_at": "2024-12-09T04:06:28.580220",
-            "updated_at": "2024-12-09T04:21:19.852832",
-            "status": "completed",
-            "steps_completed": 1210,
-            "epochs_completed": 2,
-            "percentage_done": 100.0,
-            "best_epoch": 2,
-            "train_loss": 1.718016266822815,
-            "val_loss": 1.8661999702453613,
-        }
-
-        job_id = "cust-JGTaMbJMdqjJU8WbQdN9Q2"
-
-        status = self.run_async(self.adapter.get_training_job_status(job_uuid=job_id))
-
-        assert isinstance(status, NvidiaPostTrainingJobStatusResponse)
-        assert status.status.value == "completed"
-        assert status.steps_completed == 1210
-        assert status.epochs_completed == 2
-        assert status.percentage_done == 100.0
-        assert status.best_epoch == 2
-        assert status.train_loss == 1.718016266822815
-        assert status.val_loss == 1.8661999702453613
-
-        self.mock_make_request.assert_called_once()
-        self._assert_request(
-            self.mock_make_request, "GET", f"/v1/customization/jobs/{job_id}/status", expected_params={"job_id": job_id}
-        )
+        customizer_status_to_job_status = [
+            ("running", "in_progress"),
+            ("completed", "completed"),
+            ("failed", "failed"),
+            ("cancelled", "cancelled"),
+            ("pending", "scheduled"),
+            ("unknown", "scheduled"),
+        ]
+
+        for customizer_status, expected_status in customizer_status_to_job_status:
+            with self.subTest(customizer_status=customizer_status, expected_status=expected_status):
+                self.mock_make_request.return_value = {
+                    "created_at": "2024-12-09T04:06:28.580220",
+                    "updated_at": "2024-12-09T04:21:19.852832",
+                    "status": customizer_status,
+                    "steps_completed": 1210,
+                    "epochs_completed": 2,
+                    "percentage_done": 100.0,
+                    "best_epoch": 2,
+                    "train_loss": 1.718016266822815,
+                    "val_loss": 1.8661999702453613,
+                }
+
+                job_id = "cust-JGTaMbJMdqjJU8WbQdN9Q2"
+
+                status = self.run_async(self.adapter.get_training_job_status(job_uuid=job_id))
+
+                assert isinstance(status, NvidiaPostTrainingJobStatusResponse)
+                assert status.status.value == expected_status
+                assert status.steps_completed == 1210
+                assert status.epochs_completed == 2
+                assert status.percentage_done == 100.0
+                assert status.best_epoch == 2
+                assert status.train_loss == 1.718016266822815
+                assert status.val_loss == 1.8661999702453613
+
+                self._assert_request(
+                    self.mock_make_request,
+                    "GET",
+                    f"/v1/customization/jobs/{job_id}/status",
+                    expected_params={"job_id": job_id},
+                )
 
     def test_get_training_jobs(self):
         job_id = "cust-JGTaMbJMdqjJU8WbQdN9Q2"