Merge branch 'main' into test-modelregistryhelper

Matthew Farrellee 2025-04-27 10:56:30 -04:00
commit 7fd8a61b4d
80 changed files with 2918 additions and 386 deletions

6
.coveragerc Normal file
View file

@@ -0,0 +1,6 @@
[run]
omit =
*/tests/*
*/llama_stack/providers/*
*/llama_stack/templates/*
.venv/*
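
For context, a hedged sketch of how this config is consumed when measuring coverage programmatically; in CI it is more typically driven through the `coverage` CLI or `pytest-cov`, so treat the API usage as illustrative:

```python
# Illustrative only: load the omit patterns above via the coverage API.
import coverage

cov = coverage.Coverage(config_file=".coveragerc")  # picks up the [run] omit globs
cov.start()
# ... import and exercise the code under test ...
cov.stop()
cov.report()  # paths matching the omit globs above are excluded from the report
```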

View file

@@ -6,7 +6,6 @@ on:
pull_request:
branches: [ main ]
paths:
- 'distributions/**'
- 'llama_stack/**'
- 'tests/integration/**'
- 'uv.lock'

View file

@@ -107,3 +107,41 @@ jobs:
- name: Build a single provider
run: |
USE_COPY_NOT_MOUNT=true LLAMA_STACK_DIR=. uv run llama stack build --image-type venv --image-name test --providers inference=remote::ollama
build-custom-container-distribution:
runs-on: ubuntu-latest
steps:
- name: Checkout repository
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
- name: Set up Python
uses: actions/setup-python@8d9ed9ac5c53483de85588cdf95a591a75ab9f55 # v5.5.0
with:
python-version: '3.10'
- name: Install uv
uses: astral-sh/setup-uv@0c5e2b8115b80b4c7c5ddf6ffdd634974642d182 # v5.4.1
with:
python-version: "3.10"
- name: Install LlamaStack
run: |
uv venv
source .venv/bin/activate
uv pip install -e .
- name: Build a single provider
run: |
yq -i '.image_type = "container"' llama_stack/templates/dev/build.yaml
yq -i '.image_name = "test"' llama_stack/templates/dev/build.yaml
USE_COPY_NOT_MOUNT=true LLAMA_STACK_DIR=. uv run llama stack build --config llama_stack/templates/dev/build.yaml
- name: Inspect the container image entrypoint
run: |
IMAGE_ID=$(docker images --format "{{.Repository}}:{{.Tag}}" | head -n 1)
entrypoint=$(docker inspect --format '{{ .Config.Entrypoint }}' $IMAGE_ID)
echo "Entrypoint: $entrypoint"
if [ "$entrypoint" != "[python -m llama_stack.distribution.server.server --config /app/run.yaml]" ]; then
echo "Entrypoint is not correct"
exit 1
fi

View file

@@ -5,6 +5,13 @@ on:
branches: [ main ]
pull_request:
branches: [ main ]
paths:
- 'llama_stack/**'
- 'tests/integration/**'
- 'uv.lock'
- 'pyproject.toml'
- 'requirements.txt'
- '.github/workflows/test-external-providers.yml' # This workflow
jobs:
test-external-providers:

View file

@@ -6,7 +6,6 @@ on:
pull_request:
branches: [ main ]
paths:
- 'distributions/**'
- 'llama_stack/**'
- 'tests/unit/**'
- 'uv.lock'

View file

@@ -119,6 +119,7 @@ Here is a list of the various API providers and available distributions that can
| OpenAI | Hosted | | ✅ | | | |
| Anthropic | Hosted | | ✅ | | | |
| Gemini | Hosted | | ✅ | | | |
| watsonx | Hosted | | ✅ | | | |
### Distributions
@@ -128,7 +129,6 @@ A Llama Stack Distribution (or "distro") is a pre-configured bundle of provider
| **Distribution** | **Llama Stack Docker** | Start This Distribution |
|:---------------------------------------------:|:-------------------------------------------------------------------------------------------------------------------------------------------------------------:|:------------------------------------------------------------------------------------------------------------------------:|
| Meta Reference | [llamastack/distribution-meta-reference-gpu](https://hub.docker.com/repository/docker/llamastack/distribution-meta-reference-gpu/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/distributions/self_hosted_distro/meta-reference-gpu.html) |
| Meta Reference Quantized | [llamastack/distribution-meta-reference-quantized-gpu](https://hub.docker.com/repository/docker/llamastack/distribution-meta-reference-quantized-gpu/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/distributions/self_hosted_distro/meta-reference-quantized-gpu.html) |
| SambaNova | [llamastack/distribution-sambanova](https://hub.docker.com/repository/docker/llamastack/distribution-sambanova/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/distributions/self_hosted_distro/sambanova.html) |
| Cerebras | [llamastack/distribution-cerebras](https://hub.docker.com/repository/docker/llamastack/distribution-cerebras/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/distributions/self_hosted_distro/cerebras.html) |
| Ollama | [llamastack/distribution-ollama](https://hub.docker.com/repository/docker/llamastack/distribution-ollama/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/distributions/self_hosted_distro/ollama.html) |

View file

@@ -68,7 +68,8 @@ chunks_response = client.vector_io.query(
### Using the RAG Tool
A better way to ingest documents is to use the RAG Tool. This tool allows you to ingest documents from URLs, files, etc.
and automatically chunks them into smaller pieces.
and automatically chunks them into smaller pieces. More examples for how to format a RAGDocument can be found in the
[appendix](#more-ragdocument-examples).
```python
from llama_stack_client import RAGDocument
@@ -178,3 +179,38 @@ for vector_db_id in client.vector_dbs.list():
print(f"Unregistering vector database: {vector_db_id.identifier}")
client.vector_dbs.unregister(vector_db_id=vector_db_id.identifier)
```
### Appendix
#### More RAGDocument Examples
```python
from llama_stack_client import RAGDocument
import base64
import requests
RAGDocument(document_id="num-0", content={"uri": "file://path/to/file"})
RAGDocument(document_id="num-1", content="plain text")
RAGDocument(
document_id="num-2",
content={
"type": "text",
"text": "plain text input",
}, # for inputs that should be treated as text explicitly
)
RAGDocument(
document_id="num-3",
content={
"type": "image",
"image": {"url": {"uri": "https://mywebsite.com/image.jpg"}},
},
)
B64_ENCODED_IMAGE = base64.b64encode(
requests.get(
"https://raw.githubusercontent.com/meta-llama/llama-stack/refs/heads/main/docs/_static/llama-stack.png"
).content
)
RAGDocument(
document_id="num-4",
content={"type": "image", "image": {"data": B64_ENCODED_IMAGE}},
)
```
For more strongly typed interaction, use the typed dicts found [here](https://github.com/meta-llama/llama-stack-client-python/blob/38cd91c9e396f2be0bec1ee96a19771582ba6f17/src/llama_stack_client/types/shared_params/document.py).
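As a hedged sketch of that typed usage (the `Document` name and import path below are read off the linked file and should be treated as assumptions rather than a stable API):

```python
# Hypothetical typed-dict usage; import path assumed from the linked module.
from llama_stack_client.types.shared_params.document import Document

typed_doc: Document = {
    "document_id": "num-0",
    "content": {"type": "text", "text": "plain text input"},
    "metadata": {},
}
```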

View file

@@ -109,8 +109,6 @@ llama stack build --list-templates
+------------------------------+-----------------------------------------------------------------------------+
| nvidia | Use NVIDIA NIM for running LLM inference |
+------------------------------+-----------------------------------------------------------------------------+
| meta-reference-quantized-gpu | Use Meta Reference with fp8, int4 quantization for running LLM inference |
+------------------------------+-----------------------------------------------------------------------------+
| cerebras | Use Cerebras for running LLM inference |
+------------------------------+-----------------------------------------------------------------------------+
| ollama | Use (an external) Ollama server for running LLM inference |

View file

@@ -0,0 +1,88 @@
---
orphan: true
---
<!-- This file was auto-generated by distro_codegen.py, please edit source -->
# watsonx Distribution
```{toctree}
:maxdepth: 2
:hidden:
self
```
The `llamastack/distribution-watsonx` distribution consists of the following provider configurations.
| API | Provider(s) |
|-----|-------------|
| agents | `inline::meta-reference` |
| datasetio | `remote::huggingface`, `inline::localfs` |
| eval | `inline::meta-reference` |
| inference | `remote::watsonx` |
| safety | `inline::llama-guard` |
| scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
| telemetry | `inline::meta-reference` |
| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::code-interpreter`, `inline::rag-runtime`, `remote::model-context-protocol` |
| vector_io | `inline::faiss` |
### Environment Variables
The following environment variables can be configured:
- `LLAMASTACK_PORT`: Port for the Llama Stack distribution server (default: `5001`)
- `WATSONX_API_KEY`: watsonx API Key (default: ``)
- `WATSONX_PROJECT_ID`: watsonx Project ID (default: ``)
### Models
The following models are available by default:
- `meta-llama/llama-3-3-70b-instruct (aliases: meta-llama/Llama-3.3-70B-Instruct)`
- `meta-llama/llama-2-13b-chat (aliases: meta-llama/Llama-2-13b)`
- `meta-llama/llama-3-1-70b-instruct (aliases: meta-llama/Llama-3.1-70B-Instruct)`
- `meta-llama/llama-3-1-8b-instruct (aliases: meta-llama/Llama-3.1-8B-Instruct)`
- `meta-llama/llama-3-2-11b-vision-instruct (aliases: meta-llama/Llama-3.2-11B-Vision-Instruct)`
- `meta-llama/llama-3-2-1b-instruct (aliases: meta-llama/Llama-3.2-1B-Instruct)`
- `meta-llama/llama-3-2-3b-instruct (aliases: meta-llama/Llama-3.2-3B-Instruct)`
- `meta-llama/llama-3-2-90b-vision-instruct (aliases: meta-llama/Llama-3.2-90B-Vision-Instruct)`
- `meta-llama/llama-guard-3-11b-vision (aliases: meta-llama/Llama-Guard-3-11B-Vision)`
### Prerequisite: API Keys
Make sure you have access to a watsonx API Key. You can get one by referring to [watsonx.ai](https://www.ibm.com/docs/en/masv-and-l/maximo-manage/continuous-delivery?topic=setup-create-watsonx-api-key).
## Running Llama Stack with watsonx
You can do this via Conda (build code), venv, or Docker, which has a pre-built image.
### Via Docker
This method allows you to get started quickly without having to build the distribution code.
```bash
LLAMA_STACK_PORT=5001
docker run \
-it \
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
-v ./run.yaml:/root/my-run.yaml \
llamastack/distribution-watsonx \
--yaml-config /root/my-run.yaml \
--port $LLAMA_STACK_PORT \
--env WATSONX_API_KEY=$WATSONX_API_KEY \
--env WATSONX_PROJECT_ID=$WATSONX_PROJECT_ID \
--env WATSONX_BASE_URL=$WATSONX_BASE_URL
```
### Via Conda
```bash
llama stack build --template watsonx --image-type conda
llama stack run ./run.yaml \
--port $LLAMA_STACK_PORT \
--env WATSONX_API_KEY=$WATSONX_API_KEY \
--env WATSONX_PROJECT_ID=$WATSONX_PROJECT_ID
```
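
Once the server is running, a quick smoke test from Python might look like the following (a minimal sketch: the port matches the `LLAMASTACK_PORT` default above, the model id comes from the default model list above, and the exact call shape should be verified against your installed `llama-stack-client`):

```python
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:5001")

response = client.inference.chat_completion(
    model_id="meta-llama/llama-3-3-70b-instruct",  # one of the default models above
    messages=[{"role": "user", "content": "Say hello from watsonx."}],
)
print(response.completion_message.content)
```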

View file

@@ -81,6 +81,7 @@ LLAMA_STACK_PORT=8321
docker run \
-it \
--pull always \
--gpus all \
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
-v ~/.llama:/root/.llama \
llamastack/distribution-meta-reference-gpu \
@@ -94,6 +95,7 @@ If you are using Llama Stack Safety / Shield APIs, use:
docker run \
-it \
--pull always \
--gpus all \
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
-v ~/.llama:/root/.llama \
llamastack/distribution-meta-reference-gpu \

View file

@@ -1,123 +0,0 @@
---
orphan: true
---
<!-- This file was auto-generated by distro_codegen.py, please edit source -->
# Meta Reference Quantized Distribution
```{toctree}
:maxdepth: 2
:hidden:
self
```
The `llamastack/distribution-meta-reference-quantized-gpu` distribution consists of the following provider configurations:
| API | Provider(s) |
|-----|-------------|
| agents | `inline::meta-reference` |
| datasetio | `remote::huggingface`, `inline::localfs` |
| eval | `inline::meta-reference` |
| inference | `inline::meta-reference-quantized` |
| safety | `inline::llama-guard` |
| scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
| telemetry | `inline::meta-reference` |
| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::code-interpreter`, `inline::rag-runtime`, `remote::model-context-protocol` |
| vector_io | `inline::faiss`, `remote::chromadb`, `remote::pgvector` |
The only difference vs. the `meta-reference-gpu` distribution is that it has support for more efficient inference -- with fp8, int4 quantization, etc.
Note that you need access to nvidia GPUs to run this distribution. This distribution is not compatible with CPU-only machines or machines with AMD GPUs.
### Environment Variables
The following environment variables can be configured:
- `LLAMA_STACK_PORT`: Port for the Llama Stack distribution server (default: `8321`)
- `INFERENCE_MODEL`: Inference model loaded into the Meta Reference server (default: `meta-llama/Llama-3.2-3B-Instruct`)
- `INFERENCE_CHECKPOINT_DIR`: Directory containing the Meta Reference model checkpoint (default: `null`)
## Prerequisite: Downloading Models
Please use `llama model list --downloaded` to check that you have llama model checkpoints downloaded in `~/.llama` before proceeding. See [installation guide](https://llama-stack.readthedocs.io/en/latest/references/llama_cli_reference/download_models.html) here to download the models. Run `llama model list` to see the available models to download, and `llama model download` to download the checkpoints.
```
$ llama model list --downloaded
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━┓
┃ Model ┃ Size ┃ Modified Time ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━┩
│ Llama3.2-1B-Instruct:int4-qlora-eo8 │ 1.53 GB │ 2025-02-26 11:22:28 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama3.2-1B │ 2.31 GB │ 2025-02-18 21:48:52 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Prompt-Guard-86M │ 0.02 GB │ 2025-02-26 11:29:28 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama3.2-3B-Instruct:int4-spinquant-eo8 │ 3.69 GB │ 2025-02-26 11:37:41 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama3.2-3B │ 5.99 GB │ 2025-02-18 21:51:26 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama3.1-8B │ 14.97 GB │ 2025-02-16 10:36:37 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama3.2-1B-Instruct:int4-spinquant-eo8 │ 1.51 GB │ 2025-02-26 11:35:02 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama-Guard-3-1B │ 2.80 GB │ 2025-02-26 11:20:46 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama-Guard-3-1B:int4 │ 0.43 GB │ 2025-02-26 11:33:33 │
└─────────────────────────────────────────┴──────────┴─────────────────────┘
```
## Running the Distribution
You can do this via Conda (build code) or Docker which has a pre-built image.
### Via Docker
This method allows you to get started quickly without having to build the distribution code.
```bash
LLAMA_STACK_PORT=8321
docker run \
-it \
--pull always \
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
-v ~/.llama:/root/.llama \
llamastack/distribution-meta-reference-quantized-gpu \
--port $LLAMA_STACK_PORT \
--env INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct
```
If you are using Llama Stack Safety / Shield APIs, use:
```bash
docker run \
-it \
--pull always \
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
-v ~/.llama:/root/.llama \
llamastack/distribution-meta-reference-quantized-gpu \
--port $LLAMA_STACK_PORT \
--env INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct \
--env SAFETY_MODEL=meta-llama/Llama-Guard-3-1B
```
### Via Conda
Make sure you have done `uv pip install llama-stack` and have the Llama Stack CLI available.
```bash
llama stack build --template meta-reference-quantized-gpu --image-type conda
llama stack run distributions/meta-reference-quantized-gpu/run.yaml \
--port $LLAMA_STACK_PORT \
--env INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct
```
If you are using Llama Stack Safety / Shield APIs, use:
```bash
llama stack run distributions/meta-reference-quantized-gpu/run-with-safety.yaml \
--port $LLAMA_STACK_PORT \
--env INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct \
--env SAFETY_MODEL=meta-llama/Llama-Guard-3-1B
```

View file

@@ -7,7 +7,7 @@ The `llamastack/distribution-nvidia` distribution consists of the following prov
|-----|-------------|
| agents | `inline::meta-reference` |
| datasetio | `inline::localfs` |
| eval | `inline::meta-reference` |
| eval | `remote::nvidia` |
| inference | `remote::nvidia` |
| post_training | `remote::nvidia` |
| safety | `remote::nvidia` |
@@ -22,13 +22,13 @@ The `llamastack/distribution-nvidia` distribution consists of the following prov
The following environment variables can be configured:
- `NVIDIA_API_KEY`: NVIDIA API Key (default: ``)
- `NVIDIA_USER_ID`: NVIDIA User ID (default: `llama-stack-user`)
- `NVIDIA_APPEND_API_VERSION`: Whether to append the API version to the base_url (default: `True`)
- `NVIDIA_DATASET_NAMESPACE`: NVIDIA Dataset Namespace (default: `default`)
- `NVIDIA_ACCESS_POLICIES`: NVIDIA Access Policies (default: `{}`)
- `NVIDIA_PROJECT_ID`: NVIDIA Project ID (default: `test-project`)
- `NVIDIA_CUSTOMIZER_URL`: NVIDIA Customizer URL (default: `https://customizer.api.nvidia.com`)
- `NVIDIA_OUTPUT_MODEL_DIR`: NVIDIA Output Model Directory (default: `test-example-model@v1`)
- `GUARDRAILS_SERVICE_URL`: URL for the NeMo Guardrails Service (default: `http://0.0.0.0:7331`)
- `NVIDIA_EVALUATOR_URL`: URL for the NeMo Evaluator Service (default: `http://0.0.0.0:7331`)
- `INFERENCE_MODEL`: Inference model (default: `Llama3.1-8B-Instruct`)
- `SAFETY_MODEL`: Name of the model to use for safety (default: `meta/llama-3.1-8b-instruct`)

View file

@@ -50,9 +50,10 @@ Llama Stack supports two types of external providers:
Here's a list of known external providers that you can use with Llama Stack:
| Type | Name | Description | Repository |
|------|------|-------------|------------|
| Remote | KubeFlow Training | Train models with KubeFlow | [llama-stack-provider-kft](https://github.com/opendatahub-io/llama-stack-provider-kft) |
| Name | Description | API | Type | Repository |
|------|-------------|-----|------|------------|
| KubeFlow Training | Train models with KubeFlow | Post Training | Remote | [llama-stack-provider-kft](https://github.com/opendatahub-io/llama-stack-provider-kft) |
| RamaLama | Inference models with RamaLama | Inference | Remote | [ramalama-stack](https://github.com/containers/ramalama-stack) |
### Remote Provider Specification

View file

@@ -389,5 +389,7 @@
"pygments_lexer": "ipython3",
"version": "3.10.15"
}
}
},
"nbformat": 4,
"nbformat_minor": 5
}

View file

@@ -256,5 +256,7 @@
"pygments_lexer": "ipython3",
"version": "3.10.15"
}
}
},
"nbformat": 4,
"nbformat_minor": 5
}

View file

@@ -301,5 +301,7 @@
"pygments_lexer": "ipython3",
"version": "3.12.2"
}
}
},
"nbformat": 4,
"nbformat_minor": 5
}

View file

@@ -200,5 +200,7 @@
"pygments_lexer": "ipython3",
"version": "3.12.2"
}
}
},
"nbformat": 4,
"nbformat_minor": 5
}

View file

@@ -355,5 +355,7 @@
"pygments_lexer": "ipython3",
"version": "3.10.15"
}
}
},
"nbformat": 4,
"nbformat_minor": 5
}

View file

@@ -398,5 +398,7 @@
"pygments_lexer": "ipython3",
"version": "3.10.15"
}
}
},
"nbformat": 4,
"nbformat_minor": 5
}

View file

@@ -132,5 +132,7 @@
"pygments_lexer": "ipython3",
"version": "3.11.10"
}
}
},
"nbformat": 4,
"nbformat_minor": 5
}

View file

@@ -188,5 +188,7 @@
"pygments_lexer": "ipython3",
"version": "3.10.15"
}
}
},
"nbformat": 4,
"nbformat_minor": 5
}

View file

@@ -136,12 +136,13 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
)
image_type = prompt(
f"> Enter the image type you want your Llama Stack to be built as ({' or '.join(e.value for e in ImageType)}): ",
"> Enter the image type you want your Llama Stack to be built as (use <TAB> to see options): ",
completer=WordCompleter([e.value for e in ImageType]),
complete_while_typing=True,
validator=Validator.from_callable(
lambda x: x in [e.value for e in ImageType],
error_message=f"Invalid image type, please enter {' or '.join(e.value for e in ImageType)}",
error_message="Invalid image type. Use <TAB> to see options",
),
default=ImageType.CONDA.value,
)
if image_type == ImageType.CONDA.value:
@@ -317,11 +318,15 @@ def _generate_run_config(
to_write = json.loads(run_config.model_dump_json())
f.write(yaml.dump(to_write, sort_keys=False))
# this path is only invoked when no template is provided
cprint(
f"You can now run your stack with `llama stack run {run_config_file}`",
color="green",
)
# Only print this message for non-container builds since it will be displayed before the
# container is built
# For non-container builds, the run.yaml is generated at the very end of the build process so it
# makes sense to display this message
if build_config.image_type != LlamaStackImageType.CONTAINER.value:
cprint(
f"You can now run your stack with `llama stack run {run_config_file}`",
color="green",
)
return run_config_file
@@ -355,6 +360,13 @@ def _run_stack_build_command_from_build_config(
build_file_path = build_dir / f"{image_name}-build.yaml"
os.makedirs(build_dir, exist_ok=True)
run_config_file = None
# Generate the run.yaml so it can be included in the container image with the proper entrypoint
# Only do this if we're building a container image and we're not using a template
if build_config.image_type == LlamaStackImageType.CONTAINER.value and not template_name and config_path:
cprint("Generating run.yaml file", color="green")
run_config_file = _generate_run_config(build_config, build_dir, image_name)
with open(build_file_path, "w") as f:
to_write = json.loads(build_config.model_dump_json())
f.write(yaml.dump(to_write, sort_keys=False))
@@ -364,6 +376,7 @@ def _run_stack_build_command_from_build_config(
build_file_path,
image_name,
template_or_config=template_name or config_path or str(build_file_path),
run_config=run_config_file,
)
if return_code != 0:
raise RuntimeError(f"Failed to build image {image_name}")

View file

@@ -93,6 +93,7 @@ def build_image(
build_file_path: Path,
image_name: str,
template_or_config: str,
run_config: str | None = None,
):
container_base = build_config.distribution_spec.container_image or "python:3.10-slim"
@@ -108,6 +109,11 @@
container_base,
" ".join(normal_deps),
]
# When building from a config file (not a template), include the run config path in the
# build arguments
if run_config is not None:
args.append(run_config)
elif build_config.image_type == LlamaStackImageType.CONDA.value:
script = str(importlib.resources.files("llama_stack") / "distribution/build_conda_env.sh")
args = [

View file

@@ -19,12 +19,16 @@ UV_HTTP_TIMEOUT=${UV_HTTP_TIMEOUT:-500}
# mounting is not supported by docker buildx, so we use COPY instead
USE_COPY_NOT_MOUNT=${USE_COPY_NOT_MOUNT:-}
# Path to the run.yaml file in the container
RUN_CONFIG_PATH=/app/run.yaml
BUILD_CONTEXT_DIR=$(pwd)
if [ "$#" -lt 4 ]; then
# This only works for templates
echo "Usage: $0 <template_or_config> <image_name> <container_base> <pip_dependencies> [<special_pip_deps>]" >&2
echo "Usage: $0 <template_or_config> <image_name> <container_base> <pip_dependencies> [<run_config>] [<special_pip_deps>]" >&2
exit 1
fi
set -euo pipefail
template_or_config="$1"
@@ -35,8 +39,27 @@ container_base="$1"
shift
pip_dependencies="$1"
shift
special_pip_deps="${1:-}"
# Handle optional arguments
run_config=""
special_pip_deps=""
# Check if there are more arguments
# The logic is becoming cumbersome; we should refactor it if we can do better
if [ $# -gt 0 ]; then
# Check if the argument ends with .yaml
if [[ "$1" == *.yaml ]]; then
run_config="$1"
shift
# If there's another argument after .yaml, it must be special_pip_deps
if [ $# -gt 0 ]; then
special_pip_deps="$1"
fi
else
# If it's not .yaml, it must be special_pip_deps
special_pip_deps="$1"
fi
fi
# Define color codes
RED='\033[0;31m'
@@ -75,7 +98,7 @@ WORKDIR /app
# We install the Python 3.11 dev headers and build tools so that any
# C-extension wheels (e.g. polyleven, faiss-cpu) can compile successfully.
RUN dnf -y update && dnf install -y iputils net-tools wget \
RUN dnf -y update && dnf install -y iputils git net-tools wget \
vim-minimal python3.11 python3.11-pip python3.11-wheel \
python3.11-setuptools python3.11-devel gcc make && \
ln -s /bin/pip3.11 /bin/pip && ln -s /bin/python3.11 /bin/python && dnf clean all
@@ -119,6 +142,45 @@ EOF
done
fi
# Function to get Python command
get_python_cmd() {
if is_command_available python; then
echo "python"
elif is_command_available python3; then
echo "python3"
else
echo "Error: Neither python nor python3 is installed. Please install Python to continue." >&2
exit 1
fi
}
if [ -n "$run_config" ]; then
# Copy the run config to the build context since it's an absolute path
cp "$run_config" "$BUILD_CONTEXT_DIR/run.yaml"
add_to_container << EOF
COPY run.yaml $RUN_CONFIG_PATH
EOF
# Parse the run.yaml configuration to identify external provider directories
# If external providers are specified, copy their directory to the container
# and update the configuration to reference the new container path
python_cmd=$(get_python_cmd)
external_providers_dir=$($python_cmd -c "import yaml; config = yaml.safe_load(open('$run_config')); print(config.get('external_providers_dir') or '')")
if [ -n "$external_providers_dir" ]; then
echo "Copying external providers directory: $external_providers_dir"
add_to_container << EOF
COPY $external_providers_dir /app/providers.d
EOF
# Edit the run.yaml file to change the external_providers_dir to /app/providers.d
if [ "$(uname)" = "Darwin" ]; then
sed -i.bak -e 's|external_providers_dir:.*|external_providers_dir: /app/providers.d|' "$BUILD_CONTEXT_DIR/run.yaml"
rm -f "$BUILD_CONTEXT_DIR/run.yaml.bak"
else
sed -i 's|external_providers_dir:.*|external_providers_dir: /app/providers.d|' "$BUILD_CONTEXT_DIR/run.yaml"
fi
fi
fi
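For readability, the inline `python -c` one-liner above amounts to roughly the following (a sketch, with `run.yaml` standing in for `$run_config`):

```python
# Expanded form of the one-liner that extracts external_providers_dir.
import yaml

with open("run.yaml") as f:  # the script passes $run_config here
    config = yaml.safe_load(f)
print(config.get("external_providers_dir") or "")
```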
stack_mount="/app/llama-stack-source"
client_mount="/app/llama-stack-client-source"
@@ -178,15 +240,16 @@ fi
RUN pip uninstall -y uv
EOF
# if template_or_config ends with .yaml, it is not a template and we should not use the --template flag
if [[ "$template_or_config" != *.yaml ]]; then
# If a run config is provided, we use the --config flag
if [[ -n "$run_config" ]]; then
add_to_container << EOF
ENTRYPOINT ["python", "-m", "llama_stack.distribution.server.server", "--config", "$RUN_CONFIG_PATH"]
EOF
# If a template is provided (not a yaml file), we use the --template flag
elif [[ "$template_or_config" != *.yaml ]]; then
add_to_container << EOF
ENTRYPOINT ["python", "-m", "llama_stack.distribution.server.server", "--template", "$template_or_config"]
EOF
else
add_to_container << EOF
ENTRYPOINT ["python", "-m", "llama_stack.distribution.server.server"]
EOF
fi
# Add other required commands generic to all containers
@@ -258,9 +321,10 @@ $CONTAINER_BINARY build \
"${CLI_ARGS[@]}" \
-t "$image_tag" \
-f "$TEMP_DIR/Containerfile" \
"."
"$BUILD_CONTEXT_DIR"
# clean up tmp/configs
rm -f "$BUILD_CONTEXT_DIR/run.yaml"
set +x
echo "Success!"

View file

@@ -8,6 +8,11 @@ import asyncio
import time
from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Union
from openai.types.chat import ChatCompletionToolChoiceOptionParam as OpenAIChatCompletionToolChoiceOptionParam
from openai.types.chat import ChatCompletionToolParam as OpenAIChatCompletionToolParam
from pydantic import Field, TypeAdapter
from typing_extensions import Annotated
from llama_stack.apis.common.content_types import (
URL,
InterleavedContent,
@@ -526,7 +531,7 @@ class InferenceRouter(Inference):
async def openai_chat_completion(
self,
model: str,
messages: List[OpenAIMessageParam],
messages: Annotated[List[OpenAIMessageParam], Field(..., min_length=1)],
frequency_penalty: Optional[float] = None,
function_call: Optional[Union[str, Dict[str, Any]]] = None,
functions: Optional[List[Dict[str, Any]]] = None,
@@ -558,6 +563,16 @@
if model_obj.model_type == ModelType.embedding:
raise ValueError(f"Model '{model}' is an embedding model and does not support chat completions")
# Use the OpenAI client for a bit of extra input validation without
# exposing the OpenAI client itself as part of our API surface
if tool_choice:
TypeAdapter(OpenAIChatCompletionToolChoiceOptionParam).validate_python(tool_choice)
if tools is None:
raise ValueError("'tool_choice' is only allowed when 'tools' is also provided")
if tools:
for tool in tools:
TypeAdapter(OpenAIChatCompletionToolParam).validate_python(tool)
params = dict(
model=model_obj.identifier,
messages=messages,

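The `TypeAdapter` checks above lean on the OpenAI client's own parameter types for validation without exposing that client as part of the API surface. A standalone sketch of the same idea, separate from the server code:

```python
# Minimal illustration of validating tool_choice against the OpenAI param type.
from pydantic import TypeAdapter, ValidationError
from openai.types.chat import ChatCompletionToolChoiceOptionParam

adapter = TypeAdapter(ChatCompletionToolChoiceOptionParam)
adapter.validate_python("auto")        # accepted: a valid tool_choice option
try:
    adapter.validate_python("bogus")   # rejected before reaching any provider
except ValidationError as exc:
    print(f"rejected with {exc.error_count()} validation error(s)")
```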
View file

@@ -22,6 +22,7 @@ from fastapi import Body, FastAPI, HTTPException, Request
from fastapi import Path as FastapiPath
from fastapi.exceptions import RequestValidationError
from fastapi.responses import JSONResponse, StreamingResponse
from openai import BadRequestError
from pydantic import BaseModel, ValidationError
from typing_extensions import Annotated
@@ -110,6 +111,8 @@ def translate_exception(exc: Exception) -> Union[HTTPException, RequestValidatio
)
elif isinstance(exc, ValueError):
return HTTPException(status_code=400, detail=f"Invalid value: {str(exc)}")
elif isinstance(exc, BadRequestError):
return HTTPException(status_code=400, detail=str(exc))
elif isinstance(exc, PermissionError):
return HTTPException(status_code=403, detail=f"Permission denied: {str(exc)}")
elif isinstance(exc, TimeoutError):
@@ -162,14 +165,17 @@ async def maybe_await(value):
return value
async def sse_generator(event_gen):
async def sse_generator(event_gen_coroutine):
event_gen = None
try:
async for item in await event_gen:
event_gen = await event_gen_coroutine
async for item in event_gen:
yield create_sse_event(item)
await asyncio.sleep(0.01)
except asyncio.CancelledError:
logger.info("Generator cancelled")
await event_gen.aclose()
if event_gen:
await event_gen.aclose()
except Exception as e:
logger.exception("Error in sse_generator")
yield create_sse_event(
@@ -455,6 +461,7 @@ def main(args: Optional[argparse.Namespace] = None):
raise ValueError(f"Could not find method {endpoint.name} on {impl}!!")
impl_method = getattr(impl, endpoint.name)
logger.debug(f"{endpoint.method.upper()} {endpoint.route}")
with warnings.catch_warnings():
warnings.filterwarnings("ignore", category=UserWarning, module="pydantic._internal._fields")

View file

@@ -24,6 +24,13 @@ def rag_chat_page():
def should_disable_input():
return "displayed_messages" in st.session_state and len(st.session_state.displayed_messages) > 0
def log_message(message):
with st.chat_message(message["role"]):
if "tool_output" in message and message["tool_output"]:
with st.expander(label="Tool Output", expanded=False, icon="🛠"):
st.write(message["tool_output"])
st.markdown(message["content"])
with st.sidebar:
# File/Directory Upload Section
st.subheader("Upload Documents", divider=True)
@@ -146,8 +153,7 @@ def rag_chat_page():
# Display chat history
for message in st.session_state.displayed_messages:
with st.chat_message(message["role"]):
st.markdown(message["content"])
log_message(message)
if temperature > 0.0:
strategy = {
@@ -201,7 +207,7 @@ def rag_chat_page():
# Display assistant response
with st.chat_message("assistant"):
retrieval_message_placeholder = st.empty()
retrieval_message_placeholder = st.expander(label="Tool Output", expanded=False, icon="🛠")
message_placeholder = st.empty()
full_response = ""
retrieval_response = ""
@@ -209,14 +215,16 @@
log.print()
if log.role == "tool_execution":
retrieval_response += log.content.replace("====", "").strip()
retrieval_message_placeholder.info(retrieval_response)
retrieval_message_placeholder.write(retrieval_response)
else:
full_response += log.content
message_placeholder.markdown(full_response + "▌")
message_placeholder.markdown(full_response)
st.session_state.messages.append({"role": "assistant", "content": full_response})
st.session_state.displayed_messages.append({"role": "assistant", "content": full_response})
st.session_state.displayed_messages.append(
{"role": "assistant", "content": full_response, "tool_output": retrieval_response}
)
def direct_process_prompt(prompt):
# Add the system prompt in the beginning of the conversation
@@ -230,15 +238,14 @@ def rag_chat_page():
prompt_context = rag_response.content
with st.chat_message("assistant"):
with st.expander(label="Retrieval Output", expanded=False):
st.write(prompt_context)
retrieval_message_placeholder = st.empty()
message_placeholder = st.empty()
full_response = ""
retrieval_response = ""
# Display the retrieved content
retrieval_response += str(prompt_context)
retrieval_message_placeholder.info(retrieval_response)
# Construct the extended prompt
extended_prompt = f"Please answer the following query using the context below.\n\nCONTEXT:\n{prompt_context}\n\nQUERY:\n{prompt}"

View file

@@ -4,14 +4,23 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import enum
import json
import uuid
import streamlit as st
from llama_stack_client import Agent
from llama_stack_client.lib.agents.react.agent import ReActAgent
from llama_stack_client.lib.agents.react.tool_parser import ReActOutput
from llama_stack.distribution.ui.modules.api import llama_stack_api
class AgentType(enum.Enum):
REGULAR = "Regular"
REACT = "ReAct"
def tool_chat_page():
st.title("🛠 Tools")
@@ -23,6 +32,7 @@ def tool_chat_page():
tool_groups_list = [tool_group.identifier for tool_group in tool_groups]
mcp_tools_list = [tool for tool in tool_groups_list if tool.startswith("mcp::")]
builtin_tools_list = [tool for tool in tool_groups_list if not tool.startswith("mcp::")]
selected_vector_dbs = []
def reset_agent():
st.session_state.clear()
@@ -66,25 +76,36 @@ def tool_chat_page():
toolgroup_selection.extend(mcp_selection)
active_tool_list = []
for toolgroup_id in toolgroup_selection:
active_tool_list.extend(
[
f"{''.join(toolgroup_id.split('::')[1:])}:{t.identifier}"
for t in client.tools.list(toolgroup_id=toolgroup_id)
]
)
grouped_tools = {}
total_tools = 0
st.markdown(f"Active Tools: 🛠 {len(active_tool_list)}", help="List of currently active tools.")
st.json(active_tool_list)
for toolgroup_id in toolgroup_selection:
tools = client.tools.list(toolgroup_id=toolgroup_id)
grouped_tools[toolgroup_id] = [tool.identifier for tool in tools]
total_tools += len(tools)
st.markdown(f"Active Tools: 🛠 {total_tools}")
for group_id, tools in grouped_tools.items():
with st.expander(f"🔧 Tools from `{group_id}`"):
for idx, tool in enumerate(tools, start=1):
st.markdown(f"{idx}. `{tool.split(':')[-1]}`")
st.subheader("Agent Configurations")
st.subheader("Agent Type")
agent_type = st.radio(
"Select Agent Type",
[AgentType.REGULAR, AgentType.REACT],
format_func=lambda x: x.value,
on_change=reset_agent,
)
max_tokens = st.slider(
"Max Tokens",
min_value=0,
max_value=4096,
value=512,
step=1,
step=64,
help="The maximum number of tokens to generate",
on_change=reset_agent,
)
@@ -101,13 +122,27 @@ def tool_chat_page():
@st.cache_resource
def create_agent():
return Agent(
client,
model=model,
instructions="You are a helpful assistant. When you use a tool always respond with a summary of the result.",
tools=toolgroup_selection,
sampling_params={"strategy": {"type": "greedy"}, "max_tokens": max_tokens},
)
if "agent_type" in st.session_state and st.session_state.agent_type == AgentType.REACT:
return ReActAgent(
client=client,
model=model,
tools=toolgroup_selection,
response_format={
"type": "json_schema",
"json_schema": ReActOutput.model_json_schema(),
},
sampling_params={"strategy": {"type": "greedy"}, "max_tokens": max_tokens},
)
else:
return Agent(
client,
model=model,
instructions="You are a helpful assistant. When you use a tool always respond with a summary of the result.",
tools=toolgroup_selection,
sampling_params={"strategy": {"type": "greedy"}, "max_tokens": max_tokens},
)
st.session_state.agent_type = agent_type
agent = create_agent()
@@ -136,6 +171,158 @@ def tool_chat_page():
)
def response_generator(turn_response):
if st.session_state.get("agent_type") == AgentType.REACT:
return _handle_react_response(turn_response)
else:
return _handle_regular_response(turn_response)
def _handle_react_response(turn_response):
current_step_content = ""
final_answer = None
tool_results = []
for response in turn_response:
if not hasattr(response.event, "payload"):
yield (
"\n\n🚨 :red[_Llama Stack server Error:_]\n"
"The response received is missing an expected `payload` attribute.\n"
"This could indicate a malformed response or an internal issue within the server.\n\n"
f"Error details: {response}"
)
return
payload = response.event.payload
if payload.event_type == "step_progress" and hasattr(payload.delta, "text"):
current_step_content += payload.delta.text
continue
if payload.event_type == "step_complete":
step_details = payload.step_details
if step_details.step_type == "inference":
yield from _process_inference_step(current_step_content, tool_results, final_answer)
current_step_content = ""
elif step_details.step_type == "tool_execution":
tool_results = _process_tool_execution(step_details, tool_results)
current_step_content = ""
else:
current_step_content = ""
if not final_answer and tool_results:
yield from _format_tool_results_summary(tool_results)
def _process_inference_step(current_step_content, tool_results, final_answer):
try:
react_output_data = json.loads(current_step_content)
thought = react_output_data.get("thought")
action = react_output_data.get("action")
answer = react_output_data.get("answer")
if answer and answer != "null" and answer is not None:
final_answer = answer
if thought:
with st.expander("🤔 Thinking...", expanded=False):
st.markdown(f":grey[__{thought}__]")
if action and isinstance(action, dict):
tool_name = action.get("tool_name")
tool_params = action.get("tool_params")
with st.expander(f'🛠 Action: Using tool "{tool_name}"', expanded=False):
st.json(tool_params)
if answer and answer != "null" and answer is not None:
yield f"\n\n✅ **Final Answer:**\n{answer}"
except json.JSONDecodeError:
yield f"\n\nFailed to parse ReAct step content:\n```json\n{current_step_content}\n```"
except Exception as e:
yield f"\n\nFailed to process ReAct step: {e}\n```json\n{current_step_content}\n```"
return final_answer
def _process_tool_execution(step_details, tool_results):
try:
if hasattr(step_details, "tool_responses") and step_details.tool_responses:
for tool_response in step_details.tool_responses:
tool_name = tool_response.tool_name
content = tool_response.content
tool_results.append((tool_name, content))
with st.expander(f'⚙️ Observation (Result from "{tool_name}")', expanded=False):
try:
parsed_content = json.loads(content)
st.json(parsed_content)
except json.JSONDecodeError:
st.code(content, language=None)
else:
with st.expander("⚙️ Observation", expanded=False):
st.markdown(":grey[_Tool execution step completed, but no response data found._]")
except Exception as e:
with st.expander("⚙️ Error in Tool Execution", expanded=False):
st.markdown(f":red[_Error processing tool execution: {str(e)}_]")
return tool_results
def _format_tool_results_summary(tool_results):
yield "\n\n**Here's what I found:**\n"
for tool_name, content in tool_results:
try:
parsed_content = json.loads(content)
if tool_name == "web_search" and "top_k" in parsed_content:
yield from _format_web_search_results(parsed_content)
elif "results" in parsed_content and isinstance(parsed_content["results"], list):
yield from _format_results_list(parsed_content["results"])
elif isinstance(parsed_content, dict) and len(parsed_content) > 0:
yield from _format_dict_results(parsed_content)
elif isinstance(parsed_content, list) and len(parsed_content) > 0:
yield from _format_list_results(parsed_content)
except json.JSONDecodeError:
yield f"\n**{tool_name}** was used but returned complex data. Check the observation for details.\n"
except (TypeError, AttributeError, KeyError, IndexError) as e:
print(f"Error processing {tool_name} result: {type(e).__name__}: {e}")
def _format_web_search_results(parsed_content):
for i, result in enumerate(parsed_content["top_k"], 1):
if i <= 3:
title = result.get("title", "Untitled")
url = result.get("url", "")
content_text = result.get("content", "").strip()
yield f"\n- **{title}**\n {content_text}\n [Source]({url})\n"
def _format_results_list(results):
for i, result in enumerate(results, 1):
if i <= 3:
if isinstance(result, dict):
name = result.get("name", result.get("title", "Result " + str(i)))
description = result.get("description", result.get("content", result.get("summary", "")))
yield f"\n- **{name}**\n {description}\n"
else:
yield f"\n- {result}\n"
def _format_dict_results(parsed_content):
yield "\n```\n"
for key, value in list(parsed_content.items())[:5]:
if isinstance(value, str) and len(value) < 100:
yield f"{key}: {value}\n"
else:
yield f"{key}: [Complex data]\n"
yield "```\n"
def _format_list_results(parsed_content):
yield "\n"
for _, item in enumerate(parsed_content[:3], 1):
if isinstance(item, str):
yield f"- {item}\n"
elif isinstance(item, dict) and "text" in item:
yield f"- {item['text']}\n"
elif isinstance(item, dict) and len(item) > 0:
first_value = next(iter(item.values()))
if isinstance(first_value, str) and len(first_value) < 100:
yield f"- {first_value}\n"
def _handle_regular_response(turn_response):
for response in turn_response:
if hasattr(response.event, "payload"):
print(response.event.payload)
@@ -144,14 +331,18 @@ def tool_chat_page():
yield response.event.payload.delta.text
if response.event.payload.event_type == "step_complete":
if response.event.payload.step_details.step_type == "tool_execution":
yield " 🛠 "
if response.event.payload.step_details.tool_calls:
tool_name = str(response.event.payload.step_details.tool_calls[0].tool_name)
yield f'\n\n🛠 :grey[_Using "{tool_name}" tool:_]\n\n'
else:
yield "No tool_calls present in step_details"
else:
yield f"Error occurred in the Llama Stack Cluster: {response}"
with st.chat_message("assistant"):
response = st.write_stream(response_generator(turn_response))
response_content = st.write_stream(response_generator(turn_response))
st.session_state.messages.append({"role": "assistant", "content": response})
st.session_state.messages.append({"role": "assistant", "content": response_content})
tool_chat_page()

View file

@@ -303,6 +303,7 @@ class ChatFormat:
arguments_json=json.dumps(tool_arguments),
)
)
content = ""
return RawMessage(
role="assistant",

View file

@@ -64,7 +64,7 @@ This example passes an image that is smaller than the tile size, to show the til
##### Model Response Format
```
The image depicts a dog standing on a skateboard, with its front paws positioned on the board and its back paws hanging off the back. The dog has a distinctive coat pattern, featuring a white face, brown and black fur, and white paws, and is standing on a skateboard with red wheels, set against a blurred background of a street or alleyway with a teal door and beige wall.<|eot|>
The image depicts a dog standing on a skateboard, positioned centrally and facing the camera directly. The dog has a distinctive coat pattern featuring white, black, and brown fur, with floppy ears and a black nose, and is standing on a skateboard with red wheels.<|eot|>
```
@@ -91,7 +91,7 @@ Here is an example of how to pass an image to the model
##### Model Response Format
```
This image shows a dog standing on a skateboard, with its front paws positioned near the front of the board and its back paws near the back. The dog has a white, black, and orange coat, and is standing on a gray skateboard with red wheels, in front of a blurred background that appears to be a street or alleyway.<|eot|>
The image depicts a dog standing on a skateboard, with the dog positioned centrally and facing forward. The dog has a distinctive coat featuring a mix of white, brown, and black fur, and is wearing a collar as it stands on the skateboard, which has red wheels.<|eot|>
```
@@ -117,7 +117,7 @@ Here is an example of how to pass an image to the model
##### Model Response Format
```
The first image shows a dog standing on a skateboard, while the second image shows a plate of spaghetti with tomato sauce, parmesan cheese, and parsley. The two images are unrelated, with the first image featuring a dog and the second image featuring a food dish, and they do not share any common elements or themes.<|eot|>
The first image features a dog standing on a skateboard, while the second image showcases a plate of spaghetti with tomato sauce and cheese. The two images appear to be unrelated, with one depicting a playful scene of a dog on a skateboard and the other presenting a classic Italian dish.<|eom|>
```
@@ -135,13 +135,44 @@ We are continuing the format for zero shot function calling used in previous ver
```
<|begin_of_text|><|header_start|>system<|header_end|>
You are an expert in composing functions. You are given a question and a set of possible functions.
Based on the question, you will need to make one or more function/tool calls to achieve the purpose.
If none of the function can be used, point it out. If the given question lacks the parameters required by the function,
also point it out. You should only return the function call in tools call sections.
You are a helpful assistant and an expert in function composition. You can answer general questions using your internal knowledge OR invoke functions when necessary. Follow these strict guidelines:
1. FUNCTION CALLS:
- ONLY use functions that are EXPLICITLY listed in the function list below
- If NO functions are listed (empty function list []), respond ONLY with internal knowledge or "I don't have access to [Unavailable service] information"
- If a function is not in the list, respond ONLY with internal knowledge or "I don't have access to [Unavailable service] information"
- If ALL required parameters are present AND the query EXACTLY matches a listed function's purpose: output ONLY the function call(s)
- Use exact format: [func_name1(param1=value1, param2=value2), func_name2(...)]
Examples:
CORRECT: [get_weather(location="Vancouver"), calculate_route(start="Boston", end="New York")] <- Only if get_weather and calculate_route are in function list
INCORRECT: get_weather(location="New York")
INCORRECT: Let me check the weather: [get_weather(location="New York")]
INCORRECT: [get_events(location="Singapore")] <- If function not in list
2. RESPONSE RULES:
- For pure function requests matching a listed function: ONLY output the function call(s)
- For knowledge questions: ONLY output text
- For missing parameters: ONLY request the specific missing parameters
- For unavailable services (not in function list): output ONLY with internal knowledge or "I don't have access to [Unavailable service] information". Do NOT execute a function call.
- If the query asks for information beyond what a listed function provides: output ONLY with internal knowledge about your limitations
- NEVER combine text and function calls in the same response
- NEVER suggest alternative functions when the requested service is unavailable
- NEVER create or invent new functions not listed below
3. STRICT BOUNDARIES:
- ONLY use functions from the list below - no exceptions
- NEVER use a function as an alternative to unavailable information
- NEVER call functions not present in the function list
- NEVER add explanatory text to function calls
- NEVER respond with empty brackets
- Use proper Python/JSON syntax for function calls
- Check the function list carefully before responding
4. TOOL RESPONSE HANDLING:
- When receiving tool responses: provide concise, natural language responses
- Don't repeat tool response verbatim
- Don't add supplementary information
If you decide to invoke any of the function(s), you MUST put it in the format of [func_name1(params_name1=params_value1, params_name2=params_value2...), func_name2(params)]
You SHOULD NOT include any other text in the response.
Here is a list of functions in JSON format that you can invoke.
@@ -151,9 +182,7 @@ Here is a list of functions in JSON format that you can invoke.
"description": "Get weather info for places",
"parameters": {
"type": "dict",
"required": [
"city"
],
"required": ["city"],
"properties": {
"city": {
"type": "string",
@@ -167,7 +196,10 @@ Here is a list of functions in JSON format that you can invoke.
}
}
}
<|eot|><|header_start|>user<|header_end|>
]
You can answer general questions or invoke tools when necessary.
In addition to tool calls, you should also augment your responses by using the tool outputs.<|eot|><|header_start|>user<|header_end|>
What is the weather in SF and Seattle?<|eot|><|header_start|>assistant<|header_end|>
@@ -176,7 +208,7 @@ What is the weather in SF and Seattle?<|eot|><|header_start|>assistant<|header_e
##### Model Response Format
```
[get_weather(city='SF'), get_weather(city='Seattle')]<|eot|>
[get_weather(city="San Francisco"), get_weather(city="Seattle")]<|eot|>
```
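For the bracketed list format shown above, a hedged client-side parsing sketch using `ast` (illustrative only; Llama Stack ships its own tool-call parsing):

```python
# Illustrative parser for the [func(name=value), ...] call format.
import ast

def parse_tool_calls(text: str) -> list[tuple[str, dict]]:
    """Turn '[f(a=1), g(b="x")]' into [(name, kwargs), ...]."""
    tree = ast.parse(text.strip(), mode="eval")
    calls = []
    for node in tree.body.elts:  # elements of the outer list literal
        kwargs = {kw.arg: ast.literal_eval(kw.value) for kw in node.keywords}
        calls.append((node.func.id, kwargs))
    return calls

print(parse_tool_calls('[get_weather(city="San Francisco"), get_weather(city="Seattle")]'))
# [('get_weather', {'city': 'San Francisco'}), ('get_weather', {'city': 'Seattle'})]
```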
@@ -273,5 +305,5 @@ Use tools to get latest trending songs<|eot|><|header_start|>assistant<|header_e
##### Model Response Format
```
<function=trending_songs>{"n": "10"}</function><|eot|>
<function=trending_songs>{"n": 10}</function><|eot|>
```

View file

@@ -0,0 +1,144 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.
import textwrap
from typing import List, Optional
from llama_stack.apis.inference import ToolDefinition, ToolParamDefinition
from llama_stack.models.llama.llama3.prompt_templates.base import (
PromptTemplate,
PromptTemplateGeneratorBase,
)
class PythonListCustomToolGenerator(PromptTemplateGeneratorBase): # noqa: N801
DEFAULT_PROMPT = textwrap.dedent(
"""
You are a helpful assistant and an expert in function composition. You can answer general questions using your internal knowledge OR invoke functions when necessary. Follow these strict guidelines:
1. FUNCTION CALLS:
- ONLY use functions that are EXPLICITLY listed in the function list below
- If NO functions are listed (empty function list []), respond ONLY with internal knowledge or "I don't have access to [Unavailable service] information"
- If a function is not in the list, respond ONLY with internal knowledge or "I don't have access to [Unavailable service] information"
- If ALL required parameters are present AND the query EXACTLY matches a listed function's purpose: output ONLY the function call(s)
- Use exact format: [func_name1(param1=value1, param2=value2), func_name2(...)]
Examples:
CORRECT: [get_weather(location="Vancouver"), calculate_route(start="Boston", end="New York")] <- Only if get_weather and calculate_route are in function list
INCORRECT: get_weather(location="New York")
INCORRECT: Let me check the weather: [get_weather(location="New York")]
INCORRECT: [get_events(location="Singapore")] <- If function not in list
2. RESPONSE RULES:
- For pure function requests matching a listed function: ONLY output the function call(s)
- For knowledge questions: ONLY output text
- For missing parameters: ONLY request the specific missing parameters
- For unavailable services (not in function list): output ONLY with internal knowledge or "I don't have access to [Unavailable service] information". Do NOT execute a function call.
- If the query asks for information beyond what a listed function provides: output ONLY with internal knowledge about your limitations
- NEVER combine text and function calls in the same response
- NEVER suggest alternative functions when the requested service is unavailable
- NEVER create or invent new functions not listed below
3. STRICT BOUNDARIES:
- ONLY use functions from the list below - no exceptions
- NEVER use a function as an alternative to unavailable information
- NEVER call functions not present in the function list
- NEVER add explanatory text to function calls
- NEVER respond with empty brackets
- Use proper Python/JSON syntax for function calls
- Check the function list carefully before responding
4. TOOL RESPONSE HANDLING:
- When receiving tool responses: provide concise, natural language responses
- Don't repeat tool response verbatim
- Don't add supplementary information
{{ function_description }}
""".strip("\n")
)
def gen(self, custom_tools: List[ToolDefinition], system_prompt: Optional[str] = None) -> PromptTemplate:
system_prompt = system_prompt or self.DEFAULT_PROMPT
return PromptTemplate(
system_prompt,
{"function_description": self._gen_function_description(custom_tools)},
)
def _gen_function_description(self, custom_tools: List[ToolDefinition]) -> PromptTemplate:
template_str = textwrap.dedent(
"""
Here is a list of functions in JSON format that you can invoke.
[
{% for t in tools -%}
{# manually setting up JSON because jinja sorts keys in unexpected ways -#}
{%- set tname = t.tool_name -%}
{%- set tdesc = t.description -%}
{%- set tparams = t.parameters -%}
{%- set required_params = [] -%}
{%- for name, param in tparams.items() if param.required == true -%}
{%- set _ = required_params.append(name) -%}
{%- endfor -%}
{
"name": "{{tname}}",
"description": "{{tdesc}}",
"parameters": {
"type": "dict",
"required": {{ required_params | tojson }},
"properties": {
{%- for name, param in tparams.items() %}
"{{name}}": {
"type": "{{param.param_type}}",
"description": "{{param.description}}"{% if param.default %},
"default": "{{param.default}}"{% endif %}
}{% if not loop.last %},{% endif %}
{%- endfor %}
}
}
}{% if not loop.last %},
{% endif -%}
{%- endfor %}
]
You can answer general questions or invoke tools when necessary.
In addition to tool calls, you should also augment your responses by using the tool outputs.
"""
)
return PromptTemplate(
template_str.strip("\n"),
{"tools": [t.model_dump() for t in custom_tools]},
).render()
def data_examples(self) -> List[List[ToolDefinition]]:
return [
[
ToolDefinition(
tool_name="get_weather",
description="Get weather info for places",
parameters={
"city": ToolParamDefinition(
param_type="string",
description="The name of the city to get the weather for",
required=True,
),
"metric": ToolParamDefinition(
param_type="string",
description="The metric for weather. Options are: celsius, fahrenheit",
required=False,
default="celsius",
),
},
),
]
]
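A quick usage sketch for the generator above, using its own example data (this mirrors how `prompt_format.py` renders it elsewhere in this commit):

```python
# Render the default system prompt with the sample get_weather tool definition.
generator = PythonListCustomToolGenerator()
template = generator.gen(generator.data_examples()[0])
print(template.render())
```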

View file

@@ -9,6 +9,10 @@ from io import BytesIO
from pathlib import Path
from typing import List
from llama_stack.models.llama.llama4.prompt_templates.system_prompts import (
PythonListCustomToolGenerator,
)
from ..datatypes import RawMediaItem, RawMessage, RawTextItem
from ..prompt_format import (
Llama4UseCase,
@@ -177,39 +181,9 @@ def usecases(base_model: bool = False) -> List[UseCase | str]:
[
RawMessage(
role="system",
content="""You are an expert in composing functions. You are given a question and a set of possible functions.
Based on the question, you will need to make one or more function/tool calls to achieve the purpose.
If none of the function can be used, point it out. If the given question lacks the parameters required by the function,
also point it out. You should only return the function call in tools call sections.
If you decide to invoke any of the function(s), you MUST put it in the format of [func_name1(params_name1=params_value1, params_name2=params_value2...), func_name2(params)]
You SHOULD NOT include any other text in the response.
Here is a list of functions in JSON format that you can invoke.
[
{
"name": "get_weather",
"description": "Get weather info for places",
"parameters": {
"type": "dict",
"required": [
"city"
],
"properties": {
"city": {
"type": "string",
"description": "The name of the city to get the weather for"
},
"metric": {
"type": "string",
"description": "The metric for weather. Options are: celsius, fahrenheit",
"default": "celsius"
}
}
}
}
""",
content=PythonListCustomToolGenerator()
.gen(PythonListCustomToolGenerator().data_examples()[0])
.render(),
),
RawMessage(
role="user",

View file

@@ -253,7 +253,8 @@ class MetaReferenceInferenceImpl(
def impl():
stop_reason = None
for token_result in self.generator.completion(request):
for token_results in self.generator.completion([request]):
token_result = token_results[0]
if token_result.token == tokenizer.eot_id:
stop_reason = StopReason.end_of_turn
text = ""

View file

@@ -69,7 +69,10 @@ class CancelSentinel(BaseModel):
class TaskRequest(BaseModel):
type: Literal[ProcessingMessageName.task_request] = ProcessingMessageName.task_request
task: Tuple[str, List[CompletionRequestWithRawContent] | List[ChatCompletionRequestWithRawContent]]
task: Tuple[
str,
List[CompletionRequestWithRawContent] | List[ChatCompletionRequestWithRawContent],
]
class TaskResponse(BaseModel):
@@ -231,10 +234,10 @@ def worker_process_entrypoint(
while True:
try:
task = req_gen.send(result)
if isinstance(task, str) and task == EndSentinel():
if isinstance(task, EndSentinel):
break
assert isinstance(task, TaskRequest)
assert isinstance(task, TaskRequest), task
result = model(task.task)
except StopIteration:
break
@@ -331,7 +334,10 @@ class ModelParallelProcessGroup:
def run_inference(
self,
req: Tuple[str, List[CompletionRequestWithRawContent] | List[ChatCompletionRequestWithRawContent]],
req: Tuple[
str,
List[CompletionRequestWithRawContent] | List[ChatCompletionRequestWithRawContent],
],
) -> Generator:
assert not self.running, "inference already running"

View file

@@ -33,6 +33,7 @@ from llama_stack.apis.tools import (
)
from llama_stack.apis.vector_io import QueryChunksResponse, VectorIO
from llama_stack.providers.datatypes import ToolsProtocolPrivate
from llama_stack.providers.utils.inference.prompt_adapter import interleaved_content_as_str
from llama_stack.providers.utils.memory.vector_store import (
content_from_doc,
make_overlapped_chunks,
@ -153,6 +154,11 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):
)
)
picked.append(TextContentItem(text="END of knowledge_search tool results.\n"))
picked.append(
TextContentItem(
text=f'The above results were retrieved to help answer the user\'s query: "{interleaved_content_as_str(content)}". Use them as supporting information only in answering this query.\n',
)
)
return RAGQueryResult(
content=picked,

View file

@ -6,7 +6,7 @@
from typing import List
from llama_stack.providers.datatypes import Api, InlineProviderSpec, ProviderSpec
from llama_stack.providers.datatypes import AdapterSpec, Api, InlineProviderSpec, ProviderSpec, remote_provider_spec
def available_providers() -> List[ProviderSpec]:
@ -25,4 +25,22 @@ def available_providers() -> List[ProviderSpec]:
Api.agents,
],
),
remote_provider_spec(
api=Api.eval,
adapter=AdapterSpec(
adapter_type="nvidia",
pip_packages=[
"requests",
],
module="llama_stack.providers.remote.eval.nvidia",
config_class="llama_stack.providers.remote.eval.nvidia.NVIDIAEvalConfig",
),
api_dependencies=[
Api.datasetio,
Api.datasets,
Api.scoring,
Api.inference,
Api.agents,
],
),
]

View file

@ -288,4 +288,14 @@ def available_providers() -> List[ProviderSpec]:
provider_data_validator="llama_stack.providers.remote.inference.passthrough.PassthroughProviderDataValidator",
),
),
remote_provider_spec(
api=Api.inference,
adapter=AdapterSpec(
adapter_type="watsonx",
pip_packages=["ibm_watson_machine_learning"],
module="llama_stack.providers.remote.inference.watsonx",
config_class="llama_stack.providers.remote.inference.watsonx.WatsonXConfig",
provider_data_validator="llama_stack.providers.remote.inference.watsonx.WatsonXProviderDataValidator",
),
),
]

View file

@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

View file

@ -0,0 +1,134 @@
# NVIDIA NeMo Evaluator Eval Provider
## Overview
For the first integration, benchmarks are mapped to Evaluation Configs in the NeMo Evaluator. The full evaluation config object is provided as part of the metadata. The `dataset_id` and `scoring_functions` fields are not used.
Below are a few examples of how to register a benchmark (which in turn creates an evaluation config in NeMo Evaluator) and how to trigger an evaluation; a consolidated Python sketch follows the examples.
### Example of registering an academic benchmark
```
POST /eval/benchmarks
```
```json
{
"benchmark_id": "mmlu",
"dataset_id": "",
"scoring_functions": [],
"metadata": {
"type": "mmlu"
}
}
```
### Example of registering a custom evaluation
```
POST /eval/benchmarks
```
```json
{
"benchmark_id": "my-custom-benchmark",
"dataset_id": "",
"scoring_functions": [],
"metadata": {
"type": "custom",
"params": {
"parallelism": 8
},
"tasks": {
"qa": {
"type": "completion",
"params": {
"template": {
"prompt": "{{prompt}}",
"max_tokens": 200
}
},
"dataset": {
"files_url": "hf://datasets/default/sample-basic-test/testing/testing.jsonl"
},
"metrics": {
"bleu": {
"type": "bleu",
"params": {
"references": [
"{{ideal_response}}"
]
}
}
}
}
}
}
}
```
### Example of triggering a benchmark/custom evaluation
```
POST /eval/benchmarks/{benchmark_id}/jobs
```
```json
{
"benchmark_id": "my-custom-benchmark",
"benchmark_config": {
"eval_candidate": {
"type": "model",
"model": "meta-llama/Llama3.1-8B-Instruct",
"sampling_params": {
"max_tokens": 100,
"temperature": 0.7
}
},
"scoring_params": {}
}
}
```
Response example:
```json
{
"job_id": "eval-1234",
"status": "in_progress"
}
```
### Example of getting the status of a job
```
GET /eval/benchmarks/{benchmark_id}/jobs/{job_id}
```
Response example:
```json
{
"job_id": "eval-1234",
"status": "in_progress"
}
```
### Example of cancelling a job
```
POST /eval/benchmarks/{benchmark_id}/jobs/{job_id}/cancel
```
### Example of getting the results
```
GET /eval/benchmarks/{benchmark_id}/results
```
```json
{
"generations": [],
"scores": {
"{benchmark_id}": {
"score_rows": [],
"aggregated_results": {
"tasks": {},
"groups": {}
}
}
}
}
```
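Putting the endpoints above together, here is a minimal end-to-end sketch in Python using `requests`. The base URL, benchmark id, model id, and polling interval are assumptions; adjust them to your deployment.
```python
# A minimal sketch (not part of this provider) that registers a benchmark,
# triggers an evaluation, polls for completion, and fetches results using
# the endpoints documented above.
import time

import requests

BASE_URL = "http://localhost:8321"  # assumed Llama Stack server address
BENCHMARK_ID = "mmlu"

# Register the benchmark; this creates the evaluation config in NeMo Evaluator.
requests.post(
    f"{BASE_URL}/eval/benchmarks",
    json={
        "benchmark_id": BENCHMARK_ID,
        "dataset_id": "",
        "scoring_functions": [],
        "metadata": {"type": "mmlu"},
    },
).raise_for_status()

# Trigger the evaluation job.
job = requests.post(
    f"{BASE_URL}/eval/benchmarks/{BENCHMARK_ID}/jobs",
    json={
        "benchmark_id": BENCHMARK_ID,
        "benchmark_config": {
            "eval_candidate": {
                "type": "model",
                "model": "meta-llama/Llama3.1-8B-Instruct",
                "sampling_params": {"max_tokens": 100, "temperature": 0.7},
            },
            "scoring_params": {},
        },
    },
).json()

# Poll the job until it leaves the in_progress state.
while True:
    status = requests.get(f"{BASE_URL}/eval/benchmarks/{BENCHMARK_ID}/jobs/{job['job_id']}").json()
    if status["status"] != "in_progress":
        break
    time.sleep(5)

# Fetch the aggregated results.
results = requests.get(f"{BASE_URL}/eval/benchmarks/{BENCHMARK_ID}/results").json()
print(results["scores"])
```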

View file

@ -0,0 +1,31 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict
from llama_stack.distribution.datatypes import Api
from .config import NVIDIAEvalConfig
async def get_adapter_impl(
config: NVIDIAEvalConfig,
deps: Dict[Api, Any],
):
from .eval import NVIDIAEvalImpl
impl = NVIDIAEvalImpl(
config,
deps[Api.datasetio],
deps[Api.datasets],
deps[Api.scoring],
deps[Api.inference],
deps[Api.agents],
)
await impl.initialize()
return impl
__all__ = ["get_adapter_impl", "NVIDIAEvalImpl"]

View file

@ -0,0 +1,29 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import os
from typing import Any, Dict
from pydantic import BaseModel, Field
class NVIDIAEvalConfig(BaseModel):
"""
Configuration for the NVIDIA NeMo Evaluator microservice endpoint.
Attributes:
evaluator_url (str): A base url for accessing the NVIDIA evaluation endpoint, e.g. http://localhost:7331.
"""
evaluator_url: str = Field(
default_factory=lambda: os.getenv("NVIDIA_EVALUATOR_URL", "http://0.0.0.0:7331"),
description="The url for accessing the evaluator service",
)
@classmethod
def sample_run_config(cls, **kwargs) -> Dict[str, Any]:
return {
"evaluator_url": "${env.NVIDIA_EVALUATOR_URL:http://localhost:7331}",
}

View file

@ -0,0 +1,154 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict, List
import requests
from llama_stack.apis.agents import Agents
from llama_stack.apis.benchmarks import Benchmark
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Datasets
from llama_stack.apis.inference import Inference
from llama_stack.apis.scoring import Scoring, ScoringResult
from llama_stack.providers.datatypes import BenchmarksProtocolPrivate
from llama_stack.providers.remote.inference.nvidia.models import MODEL_ENTRIES
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
from .....apis.common.job_types import Job, JobStatus
from .....apis.eval.eval import BenchmarkConfig, Eval, EvaluateResponse
from .config import NVIDIAEvalConfig
DEFAULT_NAMESPACE = "nvidia"
class NVIDIAEvalImpl(
Eval,
BenchmarksProtocolPrivate,
ModelRegistryHelper,
):
def __init__(
self,
config: NVIDIAEvalConfig,
datasetio_api: DatasetIO,
datasets_api: Datasets,
scoring_api: Scoring,
inference_api: Inference,
agents_api: Agents,
) -> None:
self.config = config
self.datasetio_api = datasetio_api
self.datasets_api = datasets_api
self.scoring_api = scoring_api
self.inference_api = inference_api
self.agents_api = agents_api
ModelRegistryHelper.__init__(self, model_entries=MODEL_ENTRIES)
async def initialize(self) -> None: ...
async def shutdown(self) -> None: ...
async def _evaluator_get(self, path):
"""Helper for making GET requests to the evaluator service."""
response = requests.get(url=f"{self.config.evaluator_url}{path}")
response.raise_for_status()
return response.json()
async def _evaluator_post(self, path, data):
"""Helper for making POST requests to the evaluator service."""
response = requests.post(url=f"{self.config.evaluator_url}{path}", json=data)
response.raise_for_status()
return response.json()
async def register_benchmark(self, task_def: Benchmark) -> None:
"""Register a benchmark as an evaluation configuration."""
await self._evaluator_post(
"/v1/evaluation/configs",
{
"namespace": DEFAULT_NAMESPACE,
"name": task_def.benchmark_id,
# metadata is copied to request body as-is
**task_def.metadata,
},
)
async def run_eval(
self,
benchmark_id: str,
benchmark_config: BenchmarkConfig,
) -> Job:
"""Run an evaluation job for a benchmark."""
model = (
benchmark_config.eval_candidate.model
if benchmark_config.eval_candidate.type == "model"
else benchmark_config.eval_candidate.config.model
)
nvidia_model = self.get_provider_model_id(model) or model
result = await self._evaluator_post(
"/v1/evaluation/jobs",
{
"config": f"{DEFAULT_NAMESPACE}/{benchmark_id}",
"target": {"type": "model", "model": nvidia_model},
},
)
return Job(job_id=result["id"], status=JobStatus.in_progress)
async def evaluate_rows(
self,
benchmark_id: str,
input_rows: List[Dict[str, Any]],
scoring_functions: List[str],
benchmark_config: BenchmarkConfig,
) -> EvaluateResponse:
raise NotImplementedError()
async def job_status(self, benchmark_id: str, job_id: str) -> Job:
"""Get the status of an evaluation job.
EvaluatorStatus: "created", "pending", "running", "cancelled", "cancelling", "failed", "completed".
JobStatus: "scheduled", "in_progress", "completed", "cancelled", "failed"
"""
result = await self._evaluator_get(f"/v1/evaluation/jobs/{job_id}")
result_status = result["status"]
job_status = JobStatus.failed
if result_status in ["created", "pending"]:
job_status = JobStatus.scheduled
elif result_status in ["running"]:
job_status = JobStatus.in_progress
elif result_status in ["completed"]:
job_status = JobStatus.completed
elif result_status in ["cancelled"]:
job_status = JobStatus.cancelled
return Job(job_id=job_id, status=job_status)
async def job_cancel(self, benchmark_id: str, job_id: str) -> None:
"""Cancel the evaluation job."""
await self._evaluator_post(f"/v1/evaluation/jobs/{job_id}/cancel", {})
async def job_result(self, benchmark_id: str, job_id: str) -> EvaluateResponse:
"""Returns the results of the evaluation job."""
job = await self.job_status(benchmark_id, job_id)
status = job.status
if status != JobStatus.completed:
raise ValueError(f"Job {job_id} not completed. Status: {status.value}")
result = await self._evaluator_get(f"/v1/evaluation/jobs/{job_id}/results")
return EvaluateResponse(
# TODO: these are stored in detailed results on NeMo Evaluator side; can be added
generations=[],
scores={
benchmark_id: ScoringResult(
score_rows=[],
aggregated_results=result,
)
},
)

View file

@ -47,10 +47,15 @@ class NVIDIAConfig(BaseModel):
default=60,
description="Timeout for the HTTP requests",
)
append_api_version: bool = Field(
default_factory=lambda: os.getenv("NVIDIA_APPEND_API_VERSION", "True").lower() != "false",
description="When set to false, the API version will not be appended to the base_url. By default, it is true.",
)
@classmethod
def sample_run_config(cls, **kwargs) -> Dict[str, Any]:
return {
"url": "${env.NVIDIA_BASE_URL:https://integrate.api.nvidia.com}",
"api_key": "${env.NVIDIA_API_KEY:}",
"append_api_version": "${env.NVIDIA_APPEND_API_VERSION:True}",
}
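# For illustration (not part of this class): when append_api_version is true the
# inference adapter uses f"{url}/v1" as the OpenAI-compatible base URL; when it is
# false the configured url is used unchanged (see the adapter change below).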

View file

@ -33,7 +33,6 @@ from llama_stack.apis.inference import (
TextTruncation,
ToolChoice,
ToolConfig,
ToolDefinition,
)
from llama_stack.apis.inference.inference import (
OpenAIChatCompletion,
@ -42,7 +41,11 @@ from llama_stack.apis.inference.inference import (
OpenAIMessageParam,
OpenAIResponseFormatParam,
)
from llama_stack.models.llama.datatypes import ToolPromptFormat
from llama_stack.apis.models import Model, ModelType
from llama_stack.models.llama.datatypes import ToolDefinition, ToolPromptFormat
from llama_stack.providers.utils.inference import (
ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR,
)
from llama_stack.providers.utils.inference.model_registry import (
ModelRegistryHelper,
)
@ -120,10 +123,10 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
"meta/llama-3.2-90b-vision-instruct": "https://ai.api.nvidia.com/v1/gr/meta/llama-3.2-90b-vision-instruct",
}
base_url = f"{self._config.url}/v1"
base_url = f"{self._config.url}/v1" if self._config.append_api_version else self._config.url
if _is_nvidia_hosted(self._config) and provider_model_id in special_model_urls:
base_url = special_model_urls[provider_model_id]
return _get_client_for_base_url(base_url)
async def _get_provider_model_id(self, model_id: str) -> str:
@ -387,3 +390,44 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
return await self._get_client(provider_model_id).chat.completions.create(**params)
except APIConnectionError as e:
raise ConnectionError(f"Failed to connect to NVIDIA NIM at {self._config.url}: {e}") from e
async def register_model(self, model: Model) -> Model:
"""
Allow non-llama model registration.
Non-llama model registration: API Catalogue models, post-training models, etc.
client = LlamaStackAsLibraryClient("nvidia")
client.models.register(
model_id="mistralai/mixtral-8x7b-instruct-v0.1",
model_type=ModelType.llm,
provider_id="nvidia",
provider_model_id="mistralai/mixtral-8x7b-instruct-v0.1"
)
NOTE: Only supports model endpoints compatible with the AsyncOpenAI base_url format.
"""
if model.model_type == ModelType.embedding:
# embedding models are always registered by their provider model id and do not need to be mapped to a llama model
provider_resource_id = model.provider_resource_id
else:
provider_resource_id = self.get_provider_model_id(model.provider_resource_id)
if provider_resource_id:
model.provider_resource_id = provider_resource_id
else:
llama_model = model.metadata.get("llama_model")
existing_llama_model = self.get_llama_model(model.provider_resource_id)
if existing_llama_model:
if existing_llama_model != llama_model:
raise ValueError(
f"Provider model id '{model.provider_resource_id}' is already registered to a different llama model: '{existing_llama_model}'"
)
else:
# not llama model
if llama_model in ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR:
self.provider_id_to_llama_model_map[model.provider_resource_id] = (
ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR[llama_model]
)
else:
self.alias_to_provider_id_map[model.provider_model_id] = model.provider_model_id
return model

View file

@ -8,7 +8,6 @@
from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Union
import httpx
from ollama import AsyncClient
from openai import AsyncOpenAI
from llama_stack.apis.common.content_types import (
@ -73,6 +72,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
interleaved_content_as_str,
request_has_media,
)
from ollama import AsyncClient # type: ignore[attr-defined]
from .models import model_entries

View file

@ -76,8 +76,11 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
async def shutdown(self) -> None:
if self._client:
await self._client.close()
# Together client has no close method, so just set to None
self._client = None
if self._openai_client:
await self._openai_client.close()
self._openai_client = None
async def completion(
self,
@ -359,7 +362,7 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
top_p=top_p,
user=user,
)
if params.get("stream", True):
if params.get("stream", False):
return self._stream_openai_chat_completion(params)
return await self._get_openai_client().chat.completions.create(**params) # type: ignore

View file

@ -231,12 +231,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
self.client = None
async def initialize(self) -> None:
log.info(f"Initializing VLLM client with base_url={self.config.url}")
self.client = AsyncOpenAI(
base_url=self.config.url,
api_key=self.config.api_token,
http_client=None if self.config.tls_verify else httpx.AsyncClient(verify=False),
)
pass
async def shutdown(self) -> None:
pass
@ -249,6 +244,20 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
raise ValueError("Model store not set")
return await self.model_store.get_model(model_id)
def _lazy_initialize_client(self):
if self.client is not None:
return
log.info(f"Initializing vLLM client with base_url={self.config.url}")
self.client = self._create_client()
def _create_client(self):
return AsyncOpenAI(
base_url=self.config.url,
api_key=self.config.api_token,
http_client=None if self.config.tls_verify else httpx.AsyncClient(verify=False),
)
async def completion(
self,
model_id: str,
@ -258,6 +267,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> CompletionResponse | AsyncGenerator[CompletionResponseStreamChunk, None]:
self._lazy_initialize_client()
if sampling_params is None:
sampling_params = SamplingParams()
model = await self._get_model(model_id)
@ -287,6 +297,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
logprobs: Optional[LogProbConfig] = None,
tool_config: Optional[ToolConfig] = None,
) -> ChatCompletionResponse | AsyncGenerator[ChatCompletionResponseStreamChunk, None]:
self._lazy_initialize_client()
if sampling_params is None:
sampling_params = SamplingParams()
model = await self._get_model(model_id)
@ -357,12 +368,15 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
yield chunk
async def register_model(self, model: Model) -> Model:
assert self.client is not None
# register_model is called during Llama Stack initialization, hence we cannot init self.client if not initialized yet.
# self.client should only be created after the initialization is complete to avoid asyncio cross-context errors.
# Changing this may lead to unpredictable behavior.
client = self._create_client() if self.client is None else self.client
try:
model = await self.register_helper.register_model(model)
except ValueError:
pass # Ignore statically unknown model, will check live listing
res = await self.client.models.list()
res = await client.models.list()
available_models = [m.id async for m in res]
if model.provider_resource_id not in available_models:
raise ValueError(
@ -413,6 +427,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
output_dimension: Optional[int] = None,
task_type: Optional[EmbeddingTaskType] = None,
) -> EmbeddingsResponse:
self._lazy_initialize_client()
assert self.client is not None
model = await self._get_model(model_id)
@ -452,6 +467,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
guided_choice: Optional[List[str]] = None,
prompt_logprobs: Optional[int] = None,
) -> OpenAICompletion:
self._lazy_initialize_client()
model_obj = await self._get_model(model)
extra_body: Dict[str, Any] = {}
@ -508,6 +524,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
top_p: Optional[float] = None,
user: Optional[str] = None,
) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]:
self._lazy_initialize_client()
model_obj = await self._get_model(model)
params = await prepare_openai_completion_params(
model=model_obj.provider_resource_id,

View file

@ -0,0 +1,22 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.apis.inference import Inference
from .config import WatsonXConfig
async def get_adapter_impl(config: WatsonXConfig, _deps) -> Inference:
# import dynamically so `llama stack build` does not fail due to missing dependencies
from .watsonx import WatsonXInferenceAdapter
if not isinstance(config, WatsonXConfig):
raise RuntimeError(f"Unexpected config type: {type(config)}")
adapter = WatsonXInferenceAdapter(config)
return adapter
__all__ = ["get_adapter_impl", "WatsonXConfig"]

View file

@ -0,0 +1,46 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import os
from typing import Any, Dict, Optional
from pydantic import BaseModel, Field, SecretStr
from llama_stack.schema_utils import json_schema_type
class WatsonXProviderDataValidator(BaseModel):
url: str
api_key: str
project_id: str
@json_schema_type
class WatsonXConfig(BaseModel):
url: str = Field(
default_factory=lambda: os.getenv("WATSONX_BASE_URL", "https://us-south.ml.cloud.ibm.com"),
description="A base url for accessing the watsonx.ai",
)
api_key: Optional[SecretStr] = Field(
default_factory=lambda: os.getenv("WATSONX_API_KEY"),
description="The watsonx API key, only needed of using the hosted service",
)
project_id: Optional[str] = Field(
default_factory=lambda: os.getenv("WATSONX_PROJECT_ID"),
description="The Project ID key, only needed of using the hosted service",
)
timeout: int = Field(
default=60,
description="Timeout for the HTTP requests",
)
@classmethod
def sample_run_config(cls, **kwargs) -> Dict[str, Any]:
return {
"url": "${env.WATSONX_BASE_URL:https://us-south.ml.cloud.ibm.com}",
"api_key": "${env.WATSONX_API_KEY:}",
"project_id": "${env.WATSONX_PROJECT_ID:}",
}

View file

@ -0,0 +1,47 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from llama_stack.models.llama.sku_types import CoreModelId
from llama_stack.providers.utils.inference.model_registry import build_hf_repo_model_entry
MODEL_ENTRIES = [
build_hf_repo_model_entry(
"meta-llama/llama-3-3-70b-instruct",
CoreModelId.llama3_3_70b_instruct.value,
),
build_hf_repo_model_entry(
"meta-llama/llama-2-13b-chat",
CoreModelId.llama2_13b.value,
),
build_hf_repo_model_entry(
"meta-llama/llama-3-1-70b-instruct",
CoreModelId.llama3_1_70b_instruct.value,
),
build_hf_repo_model_entry(
"meta-llama/llama-3-1-8b-instruct",
CoreModelId.llama3_1_8b_instruct.value,
),
build_hf_repo_model_entry(
"meta-llama/llama-3-2-11b-vision-instruct",
CoreModelId.llama3_2_11b_vision_instruct.value,
),
build_hf_repo_model_entry(
"meta-llama/llama-3-2-1b-instruct",
CoreModelId.llama3_2_1b_instruct.value,
),
build_hf_repo_model_entry(
"meta-llama/llama-3-2-3b-instruct",
CoreModelId.llama3_2_3b_instruct.value,
),
build_hf_repo_model_entry(
"meta-llama/llama-3-2-90b-vision-instruct",
CoreModelId.llama3_2_90b_vision_instruct.value,
),
build_hf_repo_model_entry(
"meta-llama/llama-guard-3-11b-vision",
CoreModelId.llama_guard_3_11b_vision.value,
),
]

View file

@ -0,0 +1,378 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Union
from ibm_watson_machine_learning.foundation_models import Model
from ibm_watson_machine_learning.metanames import GenTextParamsMetaNames as GenParams
from openai import AsyncOpenAI
from llama_stack.apis.common.content_types import InterleavedContent, InterleavedContentItem
from llama_stack.apis.inference import (
ChatCompletionRequest,
ChatCompletionResponse,
CompletionRequest,
EmbeddingsResponse,
EmbeddingTaskType,
Inference,
LogProbConfig,
Message,
ResponseFormat,
SamplingParams,
TextTruncation,
ToolChoice,
ToolConfig,
ToolDefinition,
ToolPromptFormat,
)
from llama_stack.apis.inference.inference import (
GreedySamplingStrategy,
OpenAIChatCompletion,
OpenAIChatCompletionChunk,
OpenAICompletion,
OpenAIMessageParam,
OpenAIResponseFormatParam,
TopKSamplingStrategy,
TopPSamplingStrategy,
)
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
from llama_stack.providers.utils.inference.openai_compat import (
OpenAICompatCompletionChoice,
OpenAICompatCompletionResponse,
prepare_openai_completion_params,
process_chat_completion_response,
process_chat_completion_stream_response,
process_completion_response,
process_completion_stream_response,
)
from llama_stack.providers.utils.inference.prompt_adapter import (
chat_completion_request_to_prompt,
completion_request_to_prompt,
request_has_media,
)
from . import WatsonXConfig
from .models import MODEL_ENTRIES
class WatsonXInferenceAdapter(Inference, ModelRegistryHelper):
def __init__(self, config: WatsonXConfig) -> None:
ModelRegistryHelper.__init__(self, MODEL_ENTRIES)
print(f"Initializing watsonx InferenceAdapter({config.url})...")
self._config = config
self._project_id = self._config.project_id
self._openai_client: Optional[AsyncOpenAI] = None  # created lazily in _get_openai_client
async def initialize(self) -> None:
pass
async def shutdown(self) -> None:
pass
async def completion(
self,
model_id: str,
content: InterleavedContent,
sampling_params: Optional[SamplingParams] = None,
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
) -> AsyncGenerator:
if sampling_params is None:
sampling_params = SamplingParams()
model = await self.model_store.get_model(model_id)
request = CompletionRequest(
model=model.provider_resource_id,
content=content,
sampling_params=sampling_params,
response_format=response_format,
stream=stream,
logprobs=logprobs,
)
if stream:
return self._stream_completion(request)
else:
return await self._nonstream_completion(request)
def _get_client(self, model_id) -> Model:
config_api_key = self._config.api_key.get_secret_value() if self._config.api_key else None
config_url = self._config.url
project_id = self._config.project_id
credentials = {"url": config_url, "apikey": config_api_key}
return Model(model_id=model_id, credentials=credentials, project_id=project_id)
def _get_openai_client(self) -> AsyncOpenAI:
if not self._openai_client:
self._openai_client = AsyncOpenAI(
base_url=f"{self._config.url}/openai/v1",
api_key=self._config.api_key.get_secret_value() if self._config.api_key else None,
)
return self._openai_client
async def _nonstream_completion(self, request: CompletionRequest) -> ChatCompletionResponse:
params = await self._get_params(request)
r = self._get_client(request.model).generate(**params)
choices = []
if "results" in r:
for result in r["results"]:
choice = OpenAICompatCompletionChoice(
finish_reason=result["stop_reason"] if result["stop_reason"] else None,
text=result["generated_text"],
)
choices.append(choice)
response = OpenAICompatCompletionResponse(
choices=choices,
)
return process_completion_response(response)
async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
params = await self._get_params(request)
async def _generate_and_convert_to_openai_compat():
s = self._get_client(request.model).generate_text_stream(**params)
for chunk in s:
choice = OpenAICompatCompletionChoice(
finish_reason=None,
text=chunk,
)
yield OpenAICompatCompletionResponse(
choices=[choice],
)
stream = _generate_and_convert_to_openai_compat()
async for chunk in process_completion_stream_response(stream):
yield chunk
async def chat_completion(
self,
model_id: str,
messages: List[Message],
sampling_params: Optional[SamplingParams] = None,
tools: Optional[List[ToolDefinition]] = None,
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
tool_prompt_format: Optional[ToolPromptFormat] = None,
response_format: Optional[ResponseFormat] = None,
stream: Optional[bool] = False,
logprobs: Optional[LogProbConfig] = None,
tool_config: Optional[ToolConfig] = None,
) -> AsyncGenerator:
if sampling_params is None:
sampling_params = SamplingParams()
model = await self.model_store.get_model(model_id)
request = ChatCompletionRequest(
model=model.provider_resource_id,
messages=messages,
sampling_params=sampling_params,
tools=tools or [],
response_format=response_format,
stream=stream,
logprobs=logprobs,
tool_config=tool_config,
)
if stream:
return self._stream_chat_completion(request)
else:
return await self._nonstream_chat_completion(request)
async def _nonstream_chat_completion(self, request: ChatCompletionRequest) -> ChatCompletionResponse:
params = await self._get_params(request)
r = self._get_client(request.model).generate(**params)
choices = []
if "results" in r:
for result in r["results"]:
choice = OpenAICompatCompletionChoice(
finish_reason=result["stop_reason"] if result["stop_reason"] else None,
text=result["generated_text"],
)
choices.append(choice)
response = OpenAICompatCompletionResponse(
choices=choices,
)
return process_chat_completion_response(response, request)
async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
params = await self._get_params(request)
model_id = request.model
# the watsonx client streams synchronously, so wrap it in an async generator
async def _to_async_generator():
s = self._get_client(model_id).generate_text_stream(**params)
for chunk in s:
choice = OpenAICompatCompletionChoice(
finish_reason=None,
text=chunk,
)
yield OpenAICompatCompletionResponse(
choices=[choice],
)
stream = _to_async_generator()
async for chunk in process_chat_completion_stream_response(stream, request):
yield chunk
async def _get_params(self, request: Union[ChatCompletionRequest, CompletionRequest]) -> dict:
input_dict = {"params": {}}
media_present = request_has_media(request)
llama_model = self.get_llama_model(request.model)
if isinstance(request, ChatCompletionRequest):
input_dict["prompt"] = await chat_completion_request_to_prompt(request, llama_model)
else:
assert not media_present, "Together does not support media for Completion requests"
input_dict["prompt"] = await completion_request_to_prompt(request)
if request.sampling_params:
if request.sampling_params.strategy:
input_dict["params"][GenParams.DECODING_METHOD] = request.sampling_params.strategy.type
if request.sampling_params.max_tokens:
input_dict["params"][GenParams.MAX_NEW_TOKENS] = request.sampling_params.max_tokens
if request.sampling_params.repetition_penalty:
input_dict["params"][GenParams.REPETITION_PENALTY] = request.sampling_params.repetition_penalty
if isinstance(request.sampling_params.strategy, TopPSamplingStrategy):
input_dict["params"][GenParams.TOP_P] = request.sampling_params.strategy.top_p
input_dict["params"][GenParams.TEMPERATURE] = request.sampling_params.strategy.temperature
if isinstance(request.sampling_params.strategy, TopKSamplingStrategy):
input_dict["params"][GenParams.TOP_K] = request.sampling_params.strategy.top_k
if isinstance(request.sampling_params.strategy, GreedySamplingStrategy):
input_dict["params"][GenParams.TEMPERATURE] = 0.0
input_dict["params"][GenParams.STOP_SEQUENCES] = ["<|endoftext|>"]
params = {
**input_dict,
}
return params
async def embeddings(
self,
model_id: str,
contents: List[str] | List[InterleavedContentItem],
text_truncation: Optional[TextTruncation] = TextTruncation.none,
output_dimension: Optional[int] = None,
task_type: Optional[EmbeddingTaskType] = None,
) -> EmbeddingsResponse:
raise NotImplementedError("embedding is not supported for watsonx")
async def openai_completion(
self,
model: str,
prompt: Union[str, List[str], List[int], List[List[int]]],
best_of: Optional[int] = None,
echo: Optional[bool] = None,
frequency_penalty: Optional[float] = None,
logit_bias: Optional[Dict[str, float]] = None,
logprobs: Optional[bool] = None,
max_tokens: Optional[int] = None,
n: Optional[int] = None,
presence_penalty: Optional[float] = None,
seed: Optional[int] = None,
stop: Optional[Union[str, List[str]]] = None,
stream: Optional[bool] = None,
stream_options: Optional[Dict[str, Any]] = None,
temperature: Optional[float] = None,
top_p: Optional[float] = None,
user: Optional[str] = None,
guided_choice: Optional[List[str]] = None,
prompt_logprobs: Optional[int] = None,
) -> OpenAICompletion:
model_obj = await self.model_store.get_model(model)
params = await prepare_openai_completion_params(
model=model_obj.provider_resource_id,
prompt=prompt,
best_of=best_of,
echo=echo,
frequency_penalty=frequency_penalty,
logit_bias=logit_bias,
logprobs=logprobs,
max_tokens=max_tokens,
n=n,
presence_penalty=presence_penalty,
seed=seed,
stop=stop,
stream=stream,
stream_options=stream_options,
temperature=temperature,
top_p=top_p,
user=user,
)
return await self._get_openai_client().completions.create(**params) # type: ignore
async def openai_chat_completion(
self,
model: str,
messages: List[OpenAIMessageParam],
frequency_penalty: Optional[float] = None,
function_call: Optional[Union[str, Dict[str, Any]]] = None,
functions: Optional[List[Dict[str, Any]]] = None,
logit_bias: Optional[Dict[str, float]] = None,
logprobs: Optional[bool] = None,
max_completion_tokens: Optional[int] = None,
max_tokens: Optional[int] = None,
n: Optional[int] = None,
parallel_tool_calls: Optional[bool] = None,
presence_penalty: Optional[float] = None,
response_format: Optional[OpenAIResponseFormatParam] = None,
seed: Optional[int] = None,
stop: Optional[Union[str, List[str]]] = None,
stream: Optional[bool] = None,
stream_options: Optional[Dict[str, Any]] = None,
temperature: Optional[float] = None,
tool_choice: Optional[Union[str, Dict[str, Any]]] = None,
tools: Optional[List[Dict[str, Any]]] = None,
top_logprobs: Optional[int] = None,
top_p: Optional[float] = None,
user: Optional[str] = None,
) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]:
model_obj = await self.model_store.get_model(model)
params = await prepare_openai_completion_params(
model=model_obj.provider_resource_id,
messages=messages,
frequency_penalty=frequency_penalty,
function_call=function_call,
functions=functions,
logit_bias=logit_bias,
logprobs=logprobs,
max_completion_tokens=max_completion_tokens,
max_tokens=max_tokens,
n=n,
parallel_tool_calls=parallel_tool_calls,
presence_penalty=presence_penalty,
response_format=response_format,
seed=seed,
stop=stop,
stream=stream,
stream_options=stream_options,
temperature=temperature,
tool_choice=tool_choice,
tools=tools,
top_logprobs=top_logprobs,
top_p=top_p,
user=user,
)
if params.get("stream", False):
return self._stream_openai_chat_completion(params)
return await self._get_openai_client().chat.completions.create(**params) # type: ignore
async def _stream_openai_chat_completion(self, params: dict) -> AsyncGenerator:
# watsonx.ai sometimes adds usage data to the stream
include_usage = False
if params.get("stream_options", None):
include_usage = params["stream_options"].get("include_usage", False)
stream = await self._get_openai_client().chat.completions.create(**params)
seen_finish_reason = False
async for chunk in stream:
# Skip the final usage-only chunk (it has no choices) when the user did not request usage data
if not include_usage and seen_finish_reason and len(chunk.choices) == 0:
break
yield chunk
for choice in chunk.choices:
if choice.finish_reason:
seen_finish_reason = True
break

View file

@ -36,7 +36,6 @@ import os
os.environ["NVIDIA_API_KEY"] = "your-api-key"
os.environ["NVIDIA_CUSTOMIZER_URL"] = "http://nemo.test"
os.environ["NVIDIA_USER_ID"] = "llama-stack-user"
os.environ["NVIDIA_DATASET_NAMESPACE"] = "default"
os.environ["NVIDIA_PROJECT_ID"] = "test-project"
os.environ["NVIDIA_OUTPUT_MODEL_DIR"] = "test-example-model@v1"
@ -125,6 +124,21 @@ client.post_training.job.cancel(job_uuid="your-job-id")
### Inference with the fine-tuned model
#### 1. Register the model
```python
from llama_stack.apis.models import Model, ModelType
client.models.register(
model_id="test-example-model@v1",
provider_id="nvidia",
provider_model_id="test-example-model@v1",
model_type=ModelType.llm,
)
```
#### 2. Inference with the fine-tuned model
```python
response = client.inference.completion(
content="Complete the sentence using one word: Roses are red, violets are ",

View file

@ -67,13 +67,18 @@ class NvidiaPostTrainingAdapter(ModelRegistryHelper):
self.timeout = aiohttp.ClientTimeout(total=config.timeout)
# TODO: filter by available models based on /config endpoint
ModelRegistryHelper.__init__(self, model_entries=_MODEL_ENTRIES)
self.session = aiohttp.ClientSession(headers=self.headers, timeout=self.timeout)
self.customizer_url = config.customizer_url
self.session = None
self.customizer_url = config.customizer_url
if not self.customizer_url:
warnings.warn("Customizer URL is not set, using default value: http://nemo.test", stacklevel=2)
self.customizer_url = "http://nemo.test"
async def _get_session(self) -> aiohttp.ClientSession:
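# Create the session lazily so it is bound to the event loop that actually runs
# the requests, rather than at construction time.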
if self.session is None or self.session.closed:
self.session = aiohttp.ClientSession(headers=self.headers, timeout=self.timeout)
return self.session
async def _make_request(
self,
method: str,
@ -94,8 +99,9 @@ class NvidiaPostTrainingAdapter(ModelRegistryHelper):
if json and "Content-Type" not in request_headers:
request_headers["Content-Type"] = "application/json"
session = await self._get_session()
for _ in range(self.config.max_retries):
async with self.session.request(method, url, params=params, json=json, **kwargs) as response:
async with session.request(method, url, params=params, json=json, **kwargs) as response:
if response.status >= 400:
error_data = await response.json()
raise Exception(f"API request failed: {error_data}")
@ -122,8 +128,8 @@ class NvidiaPostTrainingAdapter(ModelRegistryHelper):
jobs = []
for job in response.get("data", []):
job_id = job.pop("id")
job_status = job.pop("status", "unknown").lower()
mapped_status = STATUS_MAPPING.get(job_status, "unknown")
job_status = job.pop("status", "scheduled").lower()
mapped_status = STATUS_MAPPING.get(job_status, "scheduled")
# Convert string timestamps to datetime objects
created_at = (
@ -177,7 +183,7 @@ class NvidiaPostTrainingAdapter(ModelRegistryHelper):
)
api_status = response.pop("status").lower()
mapped_status = STATUS_MAPPING.get(api_status, "unknown")
mapped_status = STATUS_MAPPING.get(api_status, "scheduled")
return NvidiaPostTrainingJobStatusResponse(
status=JobStatus(mapped_status),
@ -239,6 +245,7 @@ class NvidiaPostTrainingAdapter(ModelRegistryHelper):
Supported models:
- meta/llama-3.1-8b-instruct
- meta/llama-3.2-1b-instruct
Supported algorithm configs:
- LoRA, SFT
@ -284,10 +291,6 @@ class NvidiaPostTrainingAdapter(ModelRegistryHelper):
- LoRA config:
## NeMo customizer specific LoRA parameters
- adapter_dim: int - Adapter dimension
Default: 8 (supports powers of 2)
- adapter_dropout: float - Adapter dropout
Default: None (0.0-1.0)
- alpha: int - Scaling factor for the LoRA update
Default: 16
Note:
@ -297,7 +300,7 @@ class NvidiaPostTrainingAdapter(ModelRegistryHelper):
User is informed about unsupported parameters via warnings.
"""
# Map model to nvidia model name
# ToDo: only supports llama-3.1-8b-instruct now, need to update this to support other models
# See `_MODEL_ENTRIES` for supported models
nvidia_model = self.get_provider_model_id(model)
# Check for unsupported method parameters
@ -330,7 +333,7 @@ class NvidiaPostTrainingAdapter(ModelRegistryHelper):
},
"data_config": {"dataset_id", "batch_size"},
"optimizer_config": {"lr", "weight_decay"},
"lora_config": {"type", "adapter_dim", "adapter_dropout", "alpha"},
"lora_config": {"type", "alpha"},
}
# Validate all parameters at once
@ -389,16 +392,10 @@ class NvidiaPostTrainingAdapter(ModelRegistryHelper):
# Handle LoRA-specific configuration
if algorithm_config:
if isinstance(algorithm_config, dict) and algorithm_config.get("type") == "LoRA":
if algorithm_config.type == "LoRA":
warn_unsupported_params(algorithm_config, supported_params["lora_config"], "LoRA config")
job_config["hyperparameters"]["lora"] = {
k: v
for k, v in {
"adapter_dim": algorithm_config.get("adapter_dim"),
"alpha": algorithm_config.get("alpha"),
"adapter_dropout": algorithm_config.get("adapter_dropout"),
}.items()
if v is not None
k: v for k, v in {"alpha": algorithm_config.alpha}.items() if v is not None
}
else:
raise NotImplementedError(f"Unsupported algorithm config: {algorithm_config}")

View file

@ -524,11 +524,26 @@ async def convert_message_to_openai_dict(message: Message, download: bool = Fals
else:
content = [await _convert_content(message.content)]
return {
result = {
"role": message.role,
"content": content,
}
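# Attach tool calls in the OpenAI function-call format. For illustration, a single
# (hypothetical) call would yield: {"id": "call_123", "type": "function",
# "function": {"name": "get_weather", "arguments": "{\"city\": \"Tokyo\"}"}}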
if hasattr(message, "tool_calls") and message.tool_calls:
result["tool_calls"] = []
for tc in message.tool_calls:
result["tool_calls"].append(
{
"id": tc.call_id,
"type": "function",
"function": {
"name": tc.tool_name,
"arguments": tc.arguments_json if hasattr(tc, "arguments_json") else json.dumps(tc.arguments),
},
}
)
return result
class UnparseableToolCall(BaseModel):
"""

View file

@ -52,6 +52,9 @@ from llama_stack.models.llama.llama3.prompt_templates import (
SystemDefaultGenerator,
)
from llama_stack.models.llama.llama3.tokenizer import Tokenizer
from llama_stack.models.llama.llama4.prompt_templates.system_prompts import (
PythonListCustomToolGenerator as PythonListCustomToolGeneratorLlama4,
)
from llama_stack.models.llama.sku_list import resolve_model
from llama_stack.models.llama.sku_types import ModelFamily, is_multimodal
from llama_stack.providers.utils.inference import supported_inference_models
@ -306,10 +309,11 @@ def chat_completion_request_to_messages(
elif model.model_family in (
ModelFamily.llama3_2,
ModelFamily.llama3_3,
ModelFamily.llama4,
):
# llama3.2, llama3.3 and llama4 models follow the same tool prompt format
messages = augment_messages_for_tools_llama_3_2(request)
# llama3.2 and llama3.3 follow the same tool prompt format
messages = augment_messages_for_tools_llama(request, PythonListCustomToolGenerator)
elif model.model_family == ModelFamily.llama4:
messages = augment_messages_for_tools_llama(request, PythonListCustomToolGeneratorLlama4)
else:
messages = request.messages
@ -399,8 +403,9 @@ def augment_messages_for_tools_llama_3_1(
return messages
def augment_messages_for_tools_llama_3_2(
def augment_messages_for_tools_llama(
request: ChatCompletionRequest,
custom_tool_prompt_generator,
) -> List[Message]:
existing_messages = request.messages
existing_system_message = None
@ -434,7 +439,7 @@ def augment_messages_for_tools_llama_3_2(
if existing_system_message and request.tool_config.system_message_behavior == SystemMessageBehavior.replace:
system_prompt = existing_system_message.content
tool_template = PythonListCustomToolGenerator().gen(custom_tools, system_prompt)
tool_template = custom_tool_prompt_generator().gen(custom_tools, system_prompt)
sys_content += tool_template.render()
sys_content += "\n"

View file

@ -394,12 +394,10 @@
"aiosqlite",
"blobfile",
"chardet",
"emoji",
"faiss-cpu",
"fastapi",
"fire",
"httpx",
"langdetect",
"matplotlib",
"nltk",
"numpy",
@ -411,7 +409,6 @@
"psycopg2-binary",
"pymongo",
"pypdf",
"pythainlp",
"redis",
"requests",
"scikit-learn",
@ -419,7 +416,6 @@
"sentencepiece",
"tqdm",
"transformers",
"tree_sitter",
"uvicorn"
],
"ollama": [
@ -759,5 +755,41 @@
"vllm",
"sentence-transformers --no-deps",
"torch torchvision --index-url https://download.pytorch.org/whl/cpu"
],
"watsonx": [
"aiosqlite",
"autoevals",
"blobfile",
"chardet",
"datasets",
"emoji",
"faiss-cpu",
"fastapi",
"fire",
"httpx",
"ibm_watson_machine_learning",
"langdetect",
"matplotlib",
"mcp",
"nltk",
"numpy",
"openai",
"opentelemetry-exporter-otlp-proto-http",
"opentelemetry-sdk",
"pandas",
"pillow",
"psycopg2-binary",
"pymongo",
"pypdf",
"pythainlp",
"redis",
"requests",
"scikit-learn",
"scipy",
"sentencepiece",
"tqdm",
"transformers",
"tree_sitter",
"uvicorn"
]
}

View file

@ -69,6 +69,7 @@ LLAMA_STACK_PORT=8321
docker run \
-it \
--pull always \
--gpus all \
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
-v ~/.llama:/root/.llama \
llamastack/distribution-{{ name }} \
@ -82,6 +83,7 @@ If you are using Llama Stack Safety / Shield APIs, use:
docker run \
-it \
--pull always \
--gpus all \
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
-v ~/.llama:/root/.llama \
llamastack/distribution-{{ name }} \

View file

@ -1,6 +1,6 @@
version: '2'
distribution_spec:
description: Use NVIDIA NIM for running LLM inference and safety
description: Use NVIDIA NIM for running LLM inference, evaluation and safety
providers:
inference:
- remote::nvidia
@ -13,7 +13,7 @@ distribution_spec:
telemetry:
- inline::meta-reference
eval:
- inline::meta-reference
- remote::nvidia
post_training:
- remote::nvidia
datasetio:

View file

@ -7,6 +7,7 @@
from pathlib import Path
from llama_stack.distribution.datatypes import ModelInput, Provider, ShieldInput, ToolGroupInput
from llama_stack.providers.remote.eval.nvidia import NVIDIAEvalConfig
from llama_stack.providers.remote.inference.nvidia import NVIDIAConfig
from llama_stack.providers.remote.inference.nvidia.models import MODEL_ENTRIES
from llama_stack.providers.remote.safety.nvidia import NVIDIASafetyConfig
@ -20,7 +21,7 @@ def get_distribution_template() -> DistributionTemplate:
"safety": ["remote::nvidia"],
"agents": ["inline::meta-reference"],
"telemetry": ["inline::meta-reference"],
"eval": ["inline::meta-reference"],
"eval": ["remote::nvidia"],
"post_training": ["remote::nvidia"],
"datasetio": ["inline::localfs"],
"scoring": ["inline::basic"],
@ -37,6 +38,11 @@ def get_distribution_template() -> DistributionTemplate:
provider_type="remote::nvidia",
config=NVIDIASafetyConfig.sample_run_config(),
)
eval_provider = Provider(
provider_id="nvidia",
provider_type="remote::nvidia",
config=NVIDIAEvalConfig.sample_run_config(),
)
inference_model = ModelInput(
model_id="${env.INFERENCE_MODEL}",
provider_id="nvidia",
@ -60,7 +66,7 @@ def get_distribution_template() -> DistributionTemplate:
return DistributionTemplate(
name="nvidia",
distro_type="self_hosted",
description="Use NVIDIA NIM for running LLM inference and safety",
description="Use NVIDIA NIM for running LLM inference, evaluation and safety",
container_image=None,
template_path=Path(__file__).parent / "doc_template.md",
providers=providers,
@ -69,6 +75,7 @@ def get_distribution_template() -> DistributionTemplate:
"run.yaml": RunConfigSettings(
provider_overrides={
"inference": [inference_provider],
"eval": [eval_provider],
},
default_models=default_models,
default_tool_groups=default_tool_groups,
@ -78,7 +85,8 @@ def get_distribution_template() -> DistributionTemplate:
"inference": [
inference_provider,
safety_provider,
]
],
"eval": [eval_provider],
},
default_models=[inference_model, safety_model],
default_shields=[ShieldInput(shield_id="${env.SAFETY_MODEL}", provider_id="nvidia")],
@ -90,19 +98,15 @@ def get_distribution_template() -> DistributionTemplate:
"",
"NVIDIA API Key",
),
## Nemo Customizer related variables
"NVIDIA_USER_ID": (
"llama-stack-user",
"NVIDIA User ID",
"NVIDIA_APPEND_API_VERSION": (
"True",
"Whether to append the API version to the base_url",
),
## Nemo Customizer related variables
"NVIDIA_DATASET_NAMESPACE": (
"default",
"NVIDIA Dataset Namespace",
),
"NVIDIA_ACCESS_POLICIES": (
"{}",
"NVIDIA Access Policies",
),
"NVIDIA_PROJECT_ID": (
"test-project",
"NVIDIA Project ID",
@ -119,6 +123,10 @@ def get_distribution_template() -> DistributionTemplate:
"http://0.0.0.0:7331",
"URL for the NeMo Guardrails Service",
),
"NVIDIA_EVALUATOR_URL": (
"http://0.0.0.0:7331",
"URL for the NeMo Evaluator Service",
),
"INFERENCE_MODEL": (
"Llama3.1-8B-Instruct",
"Inference model",

View file

@ -18,6 +18,7 @@ providers:
config:
url: ${env.NVIDIA_BASE_URL:https://integrate.api.nvidia.com}
api_key: ${env.NVIDIA_API_KEY:}
append_api_version: ${env.NVIDIA_APPEND_API_VERSION:True}
- provider_id: nvidia
provider_type: remote::nvidia
config:
@ -53,13 +54,10 @@ providers:
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/nvidia/trace_store.db}
eval:
- provider_id: meta-reference
provider_type: inline::meta-reference
- provider_id: nvidia
provider_type: remote::nvidia
config:
kvstore:
type: sqlite
namespace: null
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/nvidia}/meta_reference_eval.db
evaluator_url: ${env.NVIDIA_EVALUATOR_URL:http://localhost:7331}
post_training:
- provider_id: nvidia
provider_type: remote::nvidia

View file

@ -18,6 +18,7 @@ providers:
config:
url: ${env.NVIDIA_BASE_URL:https://integrate.api.nvidia.com}
api_key: ${env.NVIDIA_API_KEY:}
append_api_version: ${env.NVIDIA_APPEND_API_VERSION:True}
vector_io:
- provider_id: faiss
provider_type: inline::faiss
@ -48,13 +49,10 @@ providers:
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/nvidia/trace_store.db}
eval:
- provider_id: meta-reference
provider_type: inline::meta-reference
- provider_id: nvidia
provider_type: remote::nvidia
config:
kvstore:
type: sqlite
namespace: null
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/nvidia}/meta_reference_eval.db
evaluator_url: ${env.NVIDIA_EVALUATOR_URL:http://localhost:7331}
post_training:
- provider_id: nvidia
provider_type: remote::nvidia

View file

@ -0,0 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from .watsonx import get_distribution_template # noqa: F401

View file

@ -0,0 +1,30 @@
version: '2'
distribution_spec:
description: Use watsonx for running LLM inference
providers:
inference:
- remote::watsonx
vector_io:
- inline::faiss
safety:
- inline::llama-guard
agents:
- inline::meta-reference
telemetry:
- inline::meta-reference
eval:
- inline::meta-reference
datasetio:
- remote::huggingface
- inline::localfs
scoring:
- inline::basic
- inline::llm-as-judge
- inline::braintrust
tool_runtime:
- remote::brave-search
- remote::tavily-search
- inline::code-interpreter
- inline::rag-runtime
- remote::model-context-protocol
image_type: conda

View file

@ -0,0 +1,74 @@
---
orphan: true
---
# watsonx Distribution
```{toctree}
:maxdepth: 2
:hidden:
self
```
The `llamastack/distribution-{{ name }}` distribution consists of the following provider configurations.
{{ providers_table }}
{% if run_config_env_vars %}
### Environment Variables
The following environment variables can be configured:
{% for var, (default_value, description) in run_config_env_vars.items() %}
- `{{ var }}`: {{ description }} (default: `{{ default_value }}`)
{% endfor %}
{% endif %}
{% if default_models %}
### Models
The following models are available by default:
{% for model in default_models %}
- `{{ model.model_id }} {{ model.doc_string }}`
{% endfor %}
{% endif %}
### Prerequisite: API Keys
Make sure you have access to a watsonx API key. You can get one by referring to [watsonx.ai](https://www.ibm.com/docs/en/masv-and-l/maximo-manage/continuous-delivery?topic=setup-create-watsonx-api-key).
## Running Llama Stack with watsonx
You can do this via Conda (build code), venv, or Docker (which has a pre-built image).
### Via Docker
This method allows you to get started quickly without having to build the distribution code.
```bash
LLAMA_STACK_PORT=5001
docker run \
-it \
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
-v ./run.yaml:/root/my-run.yaml \
llamastack/distribution-{{ name }} \
--yaml-config /root/my-run.yaml \
--port $LLAMA_STACK_PORT \
--env WATSONX_API_KEY=$WATSONX_API_KEY \
--env WATSONX_PROJECT_ID=$WATSONX_PROJECT_ID \
--env WATSONX_BASE_URL=$WATSONX_BASE_URL
```
### Via Conda
```bash
llama stack build --template watsonx --image-type conda
llama stack run ./run.yaml \
--port $LLAMA_STACK_PORT \
--env WATSONX_API_KEY=$WATSONX_API_KEY \
--env WATSONX_PROJECT_ID=$WATSONX_PROJECT_ID
```
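Once the server is running, you can sanity-check inference with the Python client. This is a minimal sketch; the port and model id are assumptions, so match them to your run configuration.
```python
# Minimal sketch: run a chat completion against the watsonx distribution.
# base_url and model_id are assumptions; adjust to your deployment.
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:5001")
response = client.inference.chat_completion(
    model_id="meta-llama/llama-3-3-70b-instruct",
    messages=[{"role": "user", "content": "Say hello from watsonx."}],
)
print(response.completion_message.content)
```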

View file

@ -0,0 +1,210 @@
version: '2'
image_name: watsonx
apis:
- agents
- datasetio
- eval
- inference
- safety
- scoring
- telemetry
- tool_runtime
- vector_io
providers:
inference:
- provider_id: watsonx
provider_type: remote::watsonx
config:
url: ${env.WATSONX_BASE_URL:https://us-south.ml.cloud.ibm.com}
api_key: ${env.WATSONX_API_KEY:}
project_id: ${env.WATSONX_PROJECT_ID:}
vector_io:
- provider_id: faiss
provider_type: inline::faiss
config:
kvstore:
type: sqlite
namespace: null
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/watsonx}/faiss_store.db
safety:
- provider_id: llama-guard
provider_type: inline::llama-guard
config:
excluded_categories: []
agents:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
persistence_store:
type: sqlite
namespace: null
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/watsonx}/agents_store.db
telemetry:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/watsonx/trace_store.db}
eval:
- provider_id: meta-reference
provider_type: inline::meta-reference
config:
kvstore:
type: sqlite
namespace: null
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/watsonx}/meta_reference_eval.db
datasetio:
- provider_id: huggingface
provider_type: remote::huggingface
config:
kvstore:
type: sqlite
namespace: null
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/watsonx}/huggingface_datasetio.db
- provider_id: localfs
provider_type: inline::localfs
config:
kvstore:
type: sqlite
namespace: null
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/watsonx}/localfs_datasetio.db
scoring:
- provider_id: basic
provider_type: inline::basic
config: {}
- provider_id: llm-as-judge
provider_type: inline::llm-as-judge
config: {}
- provider_id: braintrust
provider_type: inline::braintrust
config:
openai_api_key: ${env.OPENAI_API_KEY:}
tool_runtime:
- provider_id: brave-search
provider_type: remote::brave-search
config:
api_key: ${env.BRAVE_SEARCH_API_KEY:}
max_results: 3
- provider_id: tavily-search
provider_type: remote::tavily-search
config:
api_key: ${env.TAVILY_SEARCH_API_KEY:}
max_results: 3
- provider_id: code-interpreter
provider_type: inline::code-interpreter
config: {}
- provider_id: rag-runtime
provider_type: inline::rag-runtime
config: {}
- provider_id: model-context-protocol
provider_type: remote::model-context-protocol
config: {}
metadata_store:
type: sqlite
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/watsonx}/registry.db
models:
- metadata: {}
model_id: meta-llama/llama-3-3-70b-instruct
provider_id: watsonx
provider_model_id: meta-llama/llama-3-3-70b-instruct
model_type: llm
- metadata: {}
model_id: meta-llama/Llama-3.3-70B-Instruct
provider_id: watsonx
provider_model_id: meta-llama/llama-3-3-70b-instruct
model_type: llm
- metadata: {}
model_id: meta-llama/llama-2-13b-chat
provider_id: watsonx
provider_model_id: meta-llama/llama-2-13b-chat
model_type: llm
- metadata: {}
model_id: meta-llama/Llama-2-13b
provider_id: watsonx
provider_model_id: meta-llama/llama-2-13b-chat
model_type: llm
- metadata: {}
model_id: meta-llama/llama-3-1-70b-instruct
provider_id: watsonx
provider_model_id: meta-llama/llama-3-1-70b-instruct
model_type: llm
- metadata: {}
model_id: meta-llama/Llama-3.1-70B-Instruct
provider_id: watsonx
provider_model_id: meta-llama/llama-3-1-70b-instruct
model_type: llm
- metadata: {}
model_id: meta-llama/llama-3-1-8b-instruct
provider_id: watsonx
provider_model_id: meta-llama/llama-3-1-8b-instruct
model_type: llm
- metadata: {}
model_id: meta-llama/Llama-3.1-8B-Instruct
provider_id: watsonx
provider_model_id: meta-llama/llama-3-1-8b-instruct
model_type: llm
- metadata: {}
model_id: meta-llama/llama-3-2-11b-vision-instruct
provider_id: watsonx
provider_model_id: meta-llama/llama-3-2-11b-vision-instruct
model_type: llm
- metadata: {}
model_id: meta-llama/Llama-3.2-11B-Vision-Instruct
provider_id: watsonx
provider_model_id: meta-llama/llama-3-2-11b-vision-instruct
model_type: llm
- metadata: {}
model_id: meta-llama/llama-3-2-1b-instruct
provider_id: watsonx
provider_model_id: meta-llama/llama-3-2-1b-instruct
model_type: llm
- metadata: {}
model_id: meta-llama/Llama-3.2-1B-Instruct
provider_id: watsonx
provider_model_id: meta-llama/llama-3-2-1b-instruct
model_type: llm
- metadata: {}
model_id: meta-llama/llama-3-2-3b-instruct
provider_id: watsonx
provider_model_id: meta-llama/llama-3-2-3b-instruct
model_type: llm
- metadata: {}
model_id: meta-llama/Llama-3.2-3B-Instruct
provider_id: watsonx
provider_model_id: meta-llama/llama-3-2-3b-instruct
model_type: llm
- metadata: {}
model_id: meta-llama/llama-3-2-90b-vision-instruct
provider_id: watsonx
provider_model_id: meta-llama/llama-3-2-90b-vision-instruct
model_type: llm
- metadata: {}
model_id: meta-llama/Llama-3.2-90B-Vision-Instruct
provider_id: watsonx
provider_model_id: meta-llama/llama-3-2-90b-vision-instruct
model_type: llm
- metadata: {}
model_id: meta-llama/llama-guard-3-11b-vision
provider_id: watsonx
provider_model_id: meta-llama/llama-guard-3-11b-vision
model_type: llm
- metadata: {}
model_id: meta-llama/Llama-Guard-3-11B-Vision
provider_id: watsonx
provider_model_id: meta-llama/llama-guard-3-11b-vision
model_type: llm
shields: []
vector_dbs: []
datasets: []
scoring_fns: []
benchmarks: []
tool_groups:
- toolgroup_id: builtin::websearch
provider_id: tavily-search
- toolgroup_id: builtin::rag
provider_id: rag-runtime
- toolgroup_id: builtin::code_interpreter
provider_id: code-interpreter
server:
port: 8321

View file

@ -0,0 +1,90 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from pathlib import Path
from llama_stack.distribution.datatypes import Provider, ToolGroupInput
from llama_stack.providers.remote.inference.watsonx import WatsonXConfig
from llama_stack.providers.remote.inference.watsonx.models import MODEL_ENTRIES
from llama_stack.templates.template import DistributionTemplate, RunConfigSettings, get_model_registry
def get_distribution_template() -> DistributionTemplate:
providers = {
"inference": ["remote::watsonx"],
"vector_io": ["inline::faiss"],
"safety": ["inline::llama-guard"],
"agents": ["inline::meta-reference"],
"telemetry": ["inline::meta-reference"],
"eval": ["inline::meta-reference"],
"datasetio": ["remote::huggingface", "inline::localfs"],
"scoring": ["inline::basic", "inline::llm-as-judge", "inline::braintrust"],
"tool_runtime": [
"remote::brave-search",
"remote::tavily-search",
"inline::code-interpreter",
"inline::rag-runtime",
"remote::model-context-protocol",
],
}
inference_provider = Provider(
provider_id="watsonx",
provider_type="remote::watsonx",
config=WatsonXConfig.sample_run_config(),
)
available_models = {
"watsonx": MODEL_ENTRIES,
}
default_tool_groups = [
ToolGroupInput(
toolgroup_id="builtin::websearch",
provider_id="tavily-search",
),
ToolGroupInput(
toolgroup_id="builtin::rag",
provider_id="rag-runtime",
),
ToolGroupInput(
toolgroup_id="builtin::code_interpreter",
provider_id="code-interpreter",
),
]
default_models = get_model_registry(available_models)
return DistributionTemplate(
name="watsonx",
distro_type="remote_hosted",
description="Use watsonx for running LLM inference",
container_image=None,
template_path=Path(__file__).parent / "doc_template.md",
providers=providers,
available_models_by_provider=available_models,
run_configs={
"run.yaml": RunConfigSettings(
provider_overrides={
"inference": [inference_provider],
},
default_models=default_models,
default_tool_groups=default_tool_groups,
),
},
run_config_env_vars={
"LLAMASTACK_PORT": (
"5001",
"Port for the Llama Stack distribution server",
),
"WATSONX_API_KEY": (
"",
"watsonx API Key",
),
"WATSONX_PROJECT_ID": (
"",
"watsonx Project ID",
),
},
)
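
For orientation, the template above can be exercised directly. A rough usage sketch, assuming the file lives at llama_stack/templates/watsonx/watsonx.py like the sibling templates:

# Rough usage sketch for get_distribution_template(); the import path is
# an assumption based on where the other distribution templates live.
from llama_stack.templates.watsonx.watsonx import get_distribution_template

template = get_distribution_template()
print(template.name)               # "watsonx"
print(sorted(template.providers))  # the APIs configured above
print(list(template.run_configs))  # ["run.yaml"]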

View file

@ -58,7 +58,16 @@ dev = [
"ruamel.yaml", # needed for openapi generator
]
# These are the dependencies required for running unit tests.
unit = ["sqlite-vec", "openai", "aiosqlite", "aiohttp", "pypdf", "chardet", "qdrant-client"]
unit = [
"sqlite-vec",
"openai",
"aiosqlite",
"aiohttp",
"pypdf",
"chardet",
"qdrant-client",
"opentelemetry-exporter-otlp-proto-http"
]
# These are the core dependencies required for running integration tests. They are shared across all
# providers. If a provider requires additional dependencies, please add them to your environment
# separately. If you are using "uv" to execute your tests, you can use the "--with" flag to specify extra
@ -265,6 +274,7 @@ exclude = [
"^llama_stack/providers/remote/inference/sample/",
"^llama_stack/providers/remote/inference/tgi/",
"^llama_stack/providers/remote/inference/together/",
"^llama_stack/providers/remote/inference/watsonx/",
"^llama_stack/providers/remote/safety/bedrock/",
"^llama_stack/providers/remote/safety/nvidia/",
"^llama_stack/providers/remote/safety/sample/",

View file

@ -10,6 +10,7 @@ import platform
import textwrap
import time
import pytest
from dotenv import load_dotenv
from llama_stack.log import get_logger
@ -19,10 +20,29 @@ from .report import Report
logger = get_logger(__name__, category="tests")
@pytest.hookimpl(hookwrapper=True)
def pytest_runtest_makereport(item, call):
outcome = yield
report = outcome.get_result()
if report.when == "call":
item.execution_outcome = report.outcome
item.was_xfail = getattr(report, "wasxfail", False)
def pytest_runtest_teardown(item):
interval_seconds = os.getenv("LLAMA_STACK_TEST_INTERVAL_SECONDS")
if interval_seconds:
time.sleep(float(interval_seconds))
# Check whether the test actually ran and passed or failed, rather than being skipped or an expected failure (xfail)
outcome = getattr(item, "execution_outcome", None)
was_xfail = getattr(item, "was_xfail", False)
name = item.nodeid
if not any(x in name for x in ("inference/", "safety/", "agents/")):
return
logger.debug(f"Test '{item.nodeid}' outcome was '{outcome}' (xfail={was_xfail})")
if outcome in ("passed", "failed") and not was_xfail:
interval_seconds = os.getenv("LLAMA_STACK_TEST_INTERVAL_SECONDS")
if interval_seconds:
time.sleep(float(interval_seconds))
def pytest_configure(config):
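
Taken together, these hooks record each test's call-phase outcome so that the teardown only sleeps after tests under inference/, safety/, or agents/ that genuinely passed or failed, presumably to pace calls to rate-limited providers. A small sketch of opting in, using the environment variable read above:

# Sketch: run the integration tests with a 2-second gap between tests.
# LLAMA_STACK_TEST_INTERVAL_SECONDS is the variable the teardown hook reads.
import os
import subprocess

env = {**os.environ, "LLAMA_STACK_TEST_INTERVAL_SECONDS": "2"}
subprocess.run(["pytest", "tests/integration/inference"], env=env, check=False)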

View file

@ -75,19 +75,24 @@ def openai_client(client_with_models):
return OpenAI(base_url=base_url, api_key="bar")
@pytest.fixture(params=["openai_client", "llama_stack_client"])
def compat_client(request):
return request.getfixturevalue(request.param)
@pytest.mark.parametrize(
"test_case",
[
"inference:completion:sanity",
],
)
def test_openai_completion_non_streaming(openai_client, client_with_models, text_model_id, test_case):
def test_openai_completion_non_streaming(llama_stack_client, client_with_models, text_model_id, test_case):
skip_if_model_doesnt_support_openai_completion(client_with_models, text_model_id)
tc = TestCase(test_case)
# ollama needs more verbose prompting for some reason here...
prompt = "Respond to this question and explain your answer. " + tc["content"]
response = openai_client.completions.create(
response = llama_stack_client.completions.create(
model=text_model_id,
prompt=prompt,
stream=False,
@ -103,13 +108,13 @@ def test_openai_completion_non_streaming(openai_client, client_with_models, text
"inference:completion:sanity",
],
)
def test_openai_completion_streaming(openai_client, client_with_models, text_model_id, test_case):
def test_openai_completion_streaming(llama_stack_client, client_with_models, text_model_id, test_case):
skip_if_model_doesnt_support_openai_completion(client_with_models, text_model_id)
tc = TestCase(test_case)
# ollama needs more verbose prompting for some reason here...
prompt = "Respond to this question and explain your answer. " + tc["content"]
response = openai_client.completions.create(
response = llama_stack_client.completions.create(
model=text_model_id,
prompt=prompt,
stream=True,
@ -127,11 +132,11 @@ def test_openai_completion_streaming(openai_client, client_with_models, text_mod
0,
],
)
def test_openai_completion_prompt_logprobs(openai_client, client_with_models, text_model_id, prompt_logprobs):
def test_openai_completion_prompt_logprobs(llama_stack_client, client_with_models, text_model_id, prompt_logprobs):
skip_if_provider_isnt_vllm(client_with_models, text_model_id)
prompt = "Hello, world!"
response = openai_client.completions.create(
response = llama_stack_client.completions.create(
model=text_model_id,
prompt=prompt,
stream=False,
@ -144,11 +149,11 @@ def test_openai_completion_prompt_logprobs(openai_client, client_with_models, te
assert len(choice.prompt_logprobs) > 0
def test_openai_completion_guided_choice(openai_client, client_with_models, text_model_id):
def test_openai_completion_guided_choice(llama_stack_client, client_with_models, text_model_id):
skip_if_provider_isnt_vllm(client_with_models, text_model_id)
prompt = "I am feeling really sad today."
response = openai_client.completions.create(
response = llama_stack_client.completions.create(
model=text_model_id,
prompt=prompt,
stream=False,
@ -161,6 +166,9 @@ def test_openai_completion_guided_choice(openai_client, client_with_models, text
assert choice.text in ["joy", "sadness"]
# Run the chat-completion tests with both the OpenAI client and the LlamaStack client
@pytest.mark.parametrize(
"test_case",
[
@ -168,13 +176,13 @@ def test_openai_completion_guided_choice(openai_client, client_with_models, text
"inference:chat_completion:non_streaming_02",
],
)
def test_openai_chat_completion_non_streaming(openai_client, client_with_models, text_model_id, test_case):
def test_openai_chat_completion_non_streaming(compat_client, client_with_models, text_model_id, test_case):
skip_if_model_doesnt_support_openai_chat_completion(client_with_models, text_model_id)
tc = TestCase(test_case)
question = tc["question"]
expected = tc["expected"]
response = openai_client.chat.completions.create(
response = compat_client.chat.completions.create(
model=text_model_id,
messages=[
{
@ -196,13 +204,13 @@ def test_openai_chat_completion_non_streaming(openai_client, client_with_models,
"inference:chat_completion:streaming_02",
],
)
def test_openai_chat_completion_streaming(openai_client, client_with_models, text_model_id, test_case):
def test_openai_chat_completion_streaming(compat_client, client_with_models, text_model_id, test_case):
skip_if_model_doesnt_support_openai_chat_completion(client_with_models, text_model_id)
tc = TestCase(test_case)
question = tc["question"]
expected = tc["expected"]
response = openai_client.chat.completions.create(
response = compat_client.chat.completions.create(
model=text_model_id,
messages=[{"role": "user", "content": question}],
stream=True,
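
The compat_client fixture above lets each chat-completion test run twice, once through the stock OpenAI client and once through the LlamaStack client. A minimal sketch of the OpenAI-client side against a locally running stack; the base URL construction is cut off in the fixture above, so the path and model id here are assumptions:

# Minimal sketch: drive a running Llama Stack server with the plain OpenAI
# client. The /v1/openai/v1 path and the model id are assumptions.
from openai import OpenAI

client = OpenAI(base_url="http://localhost:8321/v1/openai/v1", api_key="bar")
response = client.chat.completions.create(
    model="meta-llama/Llama-3.1-8B-Instruct",
    messages=[{"role": "user", "content": "Which planet do humans live on?"}],
    stream=False,
)
print(response.choices[0].message.content)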

View file

@ -114,7 +114,7 @@ def test_register_and_unregister_toolgroup(llama_stack_client, mcp_server):
llama_stack_client.toolgroups.unregister(toolgroup_id=test_toolgroup_id)
# Verify it is unregistered
with pytest.raises(ValueError, match=f"Tool group '{test_toolgroup_id}' not found"):
with pytest.raises(Exception, match=f"Tool group '{test_toolgroup_id}' not found"):
llama_stack_client.toolgroups.get(toolgroup_id=test_toolgroup_id)
# Verify tools are also unregistered

View file

@ -16,8 +16,9 @@ from llama_stack.distribution.utils.image_types import LlamaStackImageType
def test_container_build_passes_path(monkeypatch, tmp_path):
called_with = {}
def spy_build_image(cfg, build_file_path, image_name, template_or_config):
def spy_build_image(cfg, build_file_path, image_name, template_or_config, run_config=None):
called_with["path"] = template_or_config
called_with["run_config"] = run_config
return 0
monkeypatch.setattr(
@ -36,3 +37,4 @@ def test_container_build_passes_path(monkeypatch, tmp_path):
assert "path" in called_with
assert isinstance(called_with["path"], str)
assert Path(called_with["path"]).exists()
assert called_with["run_config"] is None

View file

@ -28,12 +28,15 @@ from openai.types.model import Model as OpenAIModel
from llama_stack.apis.inference import (
ChatCompletionRequest,
CompletionMessage,
SystemMessage,
ToolChoice,
ToolConfig,
ToolResponseMessage,
UserMessage,
)
from llama_stack.apis.models import Model
from llama_stack.models.llama.datatypes import StopReason
from llama_stack.models.llama.datatypes import StopReason, ToolCall
from llama_stack.providers.remote.inference.vllm.config import VLLMInferenceAdapterConfig
from llama_stack.providers.remote.inference.vllm.vllm import (
VLLMInferenceAdapter,
@ -135,6 +138,49 @@ async def test_old_vllm_tool_choice(vllm_inference_adapter):
assert request.tool_config.tool_choice == ToolChoice.none
@pytest.mark.asyncio
async def test_tool_call_response(vllm_inference_adapter):
"""Verify that tool call arguments from a CompletionMessage are correctly converted
into the expected JSON format."""
# Patch the call to vLLM so we can verify that the arguments sent were correct
with patch.object(
vllm_inference_adapter.client.chat.completions, "create", new_callable=AsyncMock
) as mock_nonstream_completion:
messages = [
SystemMessage(content="You are a helpful assistant"),
UserMessage(content="How many?"),
CompletionMessage(
content="",
stop_reason=StopReason.end_of_turn,
tool_calls=[
ToolCall(
call_id="foo",
tool_name="knowledge_search",
arguments={"query": "How many?"},
arguments_json='{"query": "How many?"}',
)
],
),
ToolResponseMessage(call_id="foo", content="knowledge_search found 5...."),
]
await vllm_inference_adapter.chat_completion(
"mock-model",
messages,
stream=False,
tools=[],
tool_config=ToolConfig(tool_choice=ToolChoice.auto),
)
assert mock_nonstream_completion.call_args.kwargs["messages"][2]["tool_calls"] == [
{
"id": "foo",
"type": "function",
"function": {"name": "knowledge_search", "arguments": '{"query": "How many?"}'},
}
]
@pytest.mark.asyncio
async def test_tool_call_delta_empty_tool_call_buf():
"""

View file

@ -0,0 +1,201 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import os
import unittest
from unittest.mock import MagicMock, patch
import pytest
from llama_stack.apis.benchmarks import Benchmark
from llama_stack.apis.common.job_types import Job, JobStatus
from llama_stack.apis.eval.eval import BenchmarkConfig, EvaluateResponse, ModelCandidate, SamplingParams
from llama_stack.models.llama.sku_types import CoreModelId
from llama_stack.providers.remote.eval.nvidia.config import NVIDIAEvalConfig
from llama_stack.providers.remote.eval.nvidia.eval import NVIDIAEvalImpl
MOCK_DATASET_ID = "default/test-dataset"
MOCK_BENCHMARK_ID = "test-benchmark"
class TestNVIDIAEvalImpl(unittest.TestCase):
def setUp(self):
os.environ["NVIDIA_EVALUATOR_URL"] = "http://nemo.test"
# Create mock APIs
self.datasetio_api = MagicMock()
self.datasets_api = MagicMock()
self.scoring_api = MagicMock()
self.inference_api = MagicMock()
self.agents_api = MagicMock()
self.config = NVIDIAEvalConfig(
evaluator_url=os.environ["NVIDIA_EVALUATOR_URL"],
)
self.eval_impl = NVIDIAEvalImpl(
config=self.config,
datasetio_api=self.datasetio_api,
datasets_api=self.datasets_api,
scoring_api=self.scoring_api,
inference_api=self.inference_api,
agents_api=self.agents_api,
)
# Mock the HTTP request methods
self.evaluator_get_patcher = patch(
"llama_stack.providers.remote.eval.nvidia.eval.NVIDIAEvalImpl._evaluator_get"
)
self.evaluator_post_patcher = patch(
"llama_stack.providers.remote.eval.nvidia.eval.NVIDIAEvalImpl._evaluator_post"
)
self.mock_evaluator_get = self.evaluator_get_patcher.start()
self.mock_evaluator_post = self.evaluator_post_patcher.start()
def tearDown(self):
"""Clean up after each test."""
self.evaluator_get_patcher.stop()
self.evaluator_post_patcher.stop()
def _assert_request_body(self, expected_json):
"""Helper method to verify request body in Evaluator POST request is correct"""
call_args = self.mock_evaluator_post.call_args
actual_json = call_args[0][1]
# Check that each expected key is present in the actual JSON with the expected value
for key, value in expected_json.items():
assert key in actual_json, f"Key '{key}' missing in actual JSON"
if isinstance(value, dict):
for nested_key, nested_value in value.items():
assert nested_key in actual_json[key], f"Nested key '{nested_key}' missing in actual JSON['{key}']"
assert actual_json[key][nested_key] == nested_value, f"Value mismatch for '{key}.{nested_key}'"
else:
assert actual_json[key] == value, f"Value mismatch for '{key}'"
@pytest.fixture(autouse=True)
def inject_fixtures(self, run_async):
self.run_async = run_async
def test_register_benchmark(self):
eval_config = {
"type": "custom",
"params": {"parallelism": 8},
"tasks": {
"qa": {
"type": "completion",
"params": {"template": {"prompt": "{{prompt}}", "max_tokens": 200}},
"dataset": {"files_url": f"hf://datasets/{MOCK_DATASET_ID}/testing/testing.jsonl"},
"metrics": {"bleu": {"type": "bleu", "params": {"references": ["{{ideal_response}}"]}}},
}
},
}
benchmark = Benchmark(
provider_id="nvidia",
type="benchmark",
identifier=MOCK_BENCHMARK_ID,
dataset_id=MOCK_DATASET_ID,
scoring_functions=["basic::equality"],
metadata=eval_config,
)
# Mock Evaluator API response
mock_evaluator_response = {"id": MOCK_BENCHMARK_ID, "status": "created"}
self.mock_evaluator_post.return_value = mock_evaluator_response
# Register the benchmark
self.run_async(self.eval_impl.register_benchmark(benchmark))
# Verify the Evaluator API was called correctly
self.mock_evaluator_post.assert_called_once()
self._assert_request_body({"namespace": benchmark.provider_id, "name": benchmark.identifier, **eval_config})
def test_run_eval(self):
benchmark_config = BenchmarkConfig(
eval_candidate=ModelCandidate(
type="model",
model=CoreModelId.llama3_1_8b_instruct.value,
sampling_params=SamplingParams(max_tokens=100, temperature=0.7),
)
)
# Mock Evaluator API response
mock_evaluator_response = {"id": "job-123", "status": "created"}
self.mock_evaluator_post.return_value = mock_evaluator_response
# Run the Evaluation job
result = self.run_async(
self.eval_impl.run_eval(benchmark_id=MOCK_BENCHMARK_ID, benchmark_config=benchmark_config)
)
# Verify the Evaluator API was called correctly
self.mock_evaluator_post.assert_called_once()
self._assert_request_body(
{
"config": f"nvidia/{MOCK_BENCHMARK_ID}",
"target": {"type": "model", "model": "meta/llama-3.1-8b-instruct"},
}
)
# Verify the result
assert isinstance(result, Job)
assert result.job_id == "job-123"
assert result.status == JobStatus.in_progress
def test_job_status(self):
# Mock Evaluator API response
mock_evaluator_response = {"id": "job-123", "status": "completed"}
self.mock_evaluator_get.return_value = mock_evaluator_response
# Get the Evaluation job
result = self.run_async(self.eval_impl.job_status(benchmark_id=MOCK_BENCHMARK_ID, job_id="job-123"))
# Verify the result
assert isinstance(result, Job)
assert result.job_id == "job-123"
assert result.status == JobStatus.completed
# Verify the API was called correctly
self.mock_evaluator_get.assert_called_once_with(f"/v1/evaluation/jobs/{result.job_id}")
def test_job_cancel(self):
# Mock Evaluator API response
mock_evaluator_response = {"id": "job-123", "status": "cancelled"}
self.mock_evaluator_post.return_value = mock_evaluator_response
# Cancel the Evaluation job
self.run_async(self.eval_impl.job_cancel(benchmark_id=MOCK_BENCHMARK_ID, job_id="job-123"))
# Verify the API was called correctly
self.mock_evaluator_post.assert_called_once_with("/v1/evaluation/jobs/job-123/cancel", {})
def test_job_result(self):
# Mock Evaluator API responses
mock_job_status_response = {"id": "job-123", "status": "completed"}
mock_job_results_response = {
"id": "job-123",
"status": "completed",
"results": {MOCK_BENCHMARK_ID: {"score": 0.85, "details": {"accuracy": 0.85, "f1": 0.84}}},
}
self.mock_evaluator_get.side_effect = [
mock_job_status_response, # First call to retrieve job
mock_job_results_response, # Second call to retrieve job results
]
# Get the Evaluation job results
result = self.run_async(self.eval_impl.job_result(benchmark_id=MOCK_BENCHMARK_ID, job_id="job-123"))
# Verify the result
assert isinstance(result, EvaluateResponse)
assert MOCK_BENCHMARK_ID in result.scores
assert result.scores[MOCK_BENCHMARK_ID].aggregated_results["results"][MOCK_BENCHMARK_ID]["score"] == 0.85
# Verify the API was called correctly
assert self.mock_evaluator_get.call_count == 2
self.mock_evaluator_get.assert_any_call("/v1/evaluation/jobs/job-123")
self.mock_evaluator_get.assert_any_call("/v1/evaluation/jobs/job-123/results")

View file

@ -10,14 +10,17 @@ import warnings
from unittest.mock import patch
import pytest
from llama_stack_client.types.algorithm_config_param import LoraFinetuningConfig
from llama_stack_client.types.post_training_supervised_fine_tune_params import (
TrainingConfig,
TrainingConfigDataConfig,
TrainingConfigEfficiencyConfig,
TrainingConfigOptimizerConfig,
)
from llama_stack.apis.post_training.post_training import (
DataConfig,
DatasetFormat,
EfficiencyConfig,
LoraFinetuningConfig,
OptimizerConfig,
OptimizerType,
TrainingConfig,
)
from llama_stack.distribution.library_client import convert_pydantic_to_json_value
from llama_stack.providers.remote.post_training.nvidia.post_training import (
NvidiaPostTrainingAdapter,
NvidiaPostTrainingConfig,
@ -66,11 +69,8 @@ class TestNvidiaParameters(unittest.TestCase):
def test_customizer_parameters_passed(self):
"""Test scenario 1: When an optional parameter is passed and value is correctly set."""
custom_adapter_dim = 32 # Different from default of 8
algorithm_config = LoraFinetuningConfig(
type="LoRA",
adapter_dim=custom_adapter_dim,
adapter_dropout=0.2,
apply_lora_to_mlp=True,
apply_lora_to_output=True,
alpha=16,
@ -78,8 +78,15 @@ class TestNvidiaParameters(unittest.TestCase):
lora_attn_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
)
data_config = TrainingConfigDataConfig(dataset_id="test-dataset", batch_size=16)
optimizer_config = TrainingConfigOptimizerConfig(lr=0.0002)
data_config = DataConfig(
dataset_id="test-dataset", batch_size=16, shuffle=False, data_format=DatasetFormat.instruct
)
optimizer_config = OptimizerConfig(
optimizer_type=OptimizerType.adam,
lr=0.0002,
weight_decay=0.01,
num_warmup_steps=100,
)
training_config = TrainingConfig(
n_epochs=3,
data_config=data_config,
@ -95,7 +102,7 @@ class TestNvidiaParameters(unittest.TestCase):
model="meta-llama/Llama-3.1-8B-Instruct",
checkpoint_dir="",
algorithm_config=algorithm_config,
training_config=training_config,
training_config=convert_pydantic_to_json_value(training_config),
logger_config={},
hyperparam_search_config={},
)
@ -114,7 +121,7 @@ class TestNvidiaParameters(unittest.TestCase):
self._assert_request_params(
{
"hyperparameters": {
"lora": {"adapter_dim": custom_adapter_dim, "adapter_dropout": 0.2, "alpha": 16},
"lora": {"alpha": 16},
"epochs": 3,
"learning_rate": 0.0002,
"batch_size": 16,
@ -130,8 +137,6 @@ class TestNvidiaParameters(unittest.TestCase):
algorithm_config = LoraFinetuningConfig(
type="LoRA",
adapter_dim=16,
adapter_dropout=0.1,
apply_lora_to_mlp=True,
apply_lora_to_output=True,
alpha=16,
@ -139,12 +144,16 @@ class TestNvidiaParameters(unittest.TestCase):
lora_attn_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
)
data_config = TrainingConfigDataConfig(
dataset_id=required_dataset_id, # Required parameter
batch_size=8,
data_config = DataConfig(
dataset_id=required_dataset_id, batch_size=8, shuffle=False, data_format=DatasetFormat.instruct
)
optimizer_config = TrainingConfigOptimizerConfig(lr=0.0001)
optimizer_config = OptimizerConfig(
optimizer_type=OptimizerType.adam,
lr=0.0001,
weight_decay=0.01,
num_warmup_steps=100,
)
training_config = TrainingConfig(
n_epochs=1,
@ -161,7 +170,7 @@ class TestNvidiaParameters(unittest.TestCase):
model=required_model, # Required parameter
checkpoint_dir="",
algorithm_config=algorithm_config,
training_config=training_config,
training_config=convert_pydantic_to_json_value(training_config),
logger_config={},
hyperparam_search_config={},
)
@ -186,24 +195,24 @@ class TestNvidiaParameters(unittest.TestCase):
def test_unsupported_parameters_warning(self):
"""Test that warnings are raised for unsupported parameters."""
data_config = TrainingConfigDataConfig(
data_config = DataConfig(
dataset_id="test-dataset",
batch_size=8,
# Unsupported parameters
shuffle=True,
data_format="instruct",
data_format=DatasetFormat.instruct,
validation_dataset_id="val-dataset",
)
optimizer_config = TrainingConfigOptimizerConfig(
optimizer_config = OptimizerConfig(
lr=0.0001,
weight_decay=0.01,
# Unsupported parameters
optimizer_type="adam",
optimizer_type=OptimizerType.adam,
num_warmup_steps=100,
)
efficiency_config = TrainingConfigEfficiencyConfig(
efficiency_config = EfficiencyConfig(
enable_activation_checkpointing=True # Unsupported parameter
)
@ -230,15 +239,13 @@ class TestNvidiaParameters(unittest.TestCase):
checkpoint_dir="test-dir", # Unsupported parameter
algorithm_config=LoraFinetuningConfig(
type="LoRA",
adapter_dim=16,
adapter_dropout=0.1,
apply_lora_to_mlp=True,
apply_lora_to_output=True,
alpha=16,
rank=16,
lora_attn_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
),
training_config=training_config,
training_config=convert_pydantic_to_json_value(training_config),
logger_config={"test": "value"}, # Unsupported parameter
hyperparam_search_config={"test": "value"}, # Unsupported parameter
)

View file

@ -10,13 +10,19 @@ import warnings
from unittest.mock import patch
import pytest
from llama_stack_client.types.algorithm_config_param import LoraFinetuningConfig, QatFinetuningConfig
from llama_stack_client.types.post_training_supervised_fine_tune_params import (
TrainingConfig,
TrainingConfigDataConfig,
TrainingConfigOptimizerConfig,
)
from llama_stack.apis.models import Model, ModelType
from llama_stack.apis.post_training.post_training import (
DataConfig,
DatasetFormat,
LoraFinetuningConfig,
OptimizerConfig,
OptimizerType,
QATFinetuningConfig,
TrainingConfig,
)
from llama_stack.distribution.library_client import convert_pydantic_to_json_value
from llama_stack.providers.remote.inference.nvidia.nvidia import NVIDIAConfig, NVIDIAInferenceAdapter
from llama_stack.providers.remote.post_training.nvidia.post_training import (
ListNvidiaPostTrainingJobs,
NvidiaPostTrainingAdapter,
@ -40,8 +46,22 @@ class TestNvidiaPostTraining(unittest.TestCase):
)
self.mock_make_request = self.make_request_patcher.start()
# Mock the inference client
inference_config = NVIDIAConfig(base_url=os.environ["NVIDIA_BASE_URL"], api_key=None)
self.inference_adapter = NVIDIAInferenceAdapter(inference_config)
self.mock_client = unittest.mock.MagicMock()
self.mock_client.chat.completions.create = unittest.mock.AsyncMock()
self.inference_mock_make_request = self.mock_client.chat.completions.create
self.inference_make_request_patcher = patch(
"llama_stack.providers.remote.inference.nvidia.nvidia.NVIDIAInferenceAdapter._get_client",
return_value=self.mock_client,
)
self.inference_make_request_patcher.start()
def tearDown(self):
self.make_request_patcher.stop()
self.inference_make_request_patcher.stop()
@pytest.fixture(autouse=True)
def inject_fixtures(self, run_async):
@ -105,7 +125,7 @@ class TestNvidiaPostTraining(unittest.TestCase):
"batch_size": 16,
"epochs": 2,
"learning_rate": 0.0001,
"lora": {"adapter_dim": 16, "adapter_dropout": 0.1},
"lora": {"alpha": 16},
},
"output_model": "default/job-1234",
"status": "created",
@ -116,8 +136,6 @@ class TestNvidiaPostTraining(unittest.TestCase):
algorithm_config = LoraFinetuningConfig(
type="LoRA",
adapter_dim=16,
adapter_dropout=0.1,
apply_lora_to_mlp=True,
apply_lora_to_output=True,
alpha=16,
@ -125,10 +143,15 @@ class TestNvidiaPostTraining(unittest.TestCase):
lora_attn_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
)
data_config = TrainingConfigDataConfig(dataset_id="sample-basic-test", batch_size=16)
data_config = DataConfig(
dataset_id="sample-basic-test", batch_size=16, shuffle=False, data_format=DatasetFormat.instruct
)
optimizer_config = TrainingConfigOptimizerConfig(
optimizer_config = OptimizerConfig(
optimizer_type=OptimizerType.adam,
lr=0.0001,
weight_decay=0.01,
num_warmup_steps=100,
)
training_config = TrainingConfig(
@ -145,7 +168,7 @@ class TestNvidiaPostTraining(unittest.TestCase):
model="meta-llama/Llama-3.1-8B-Instruct",
checkpoint_dir="",
algorithm_config=algorithm_config,
training_config=training_config,
training_config=convert_pydantic_to_json_value(training_config),
logger_config={},
hyperparam_search_config={},
)
@ -169,16 +192,22 @@ class TestNvidiaPostTraining(unittest.TestCase):
"epochs": 2,
"batch_size": 16,
"learning_rate": 0.0001,
"lora": {"alpha": 16, "adapter_dim": 16, "adapter_dropout": 0.1},
"weight_decay": 0.01,
"lora": {"alpha": 16},
},
},
)
def test_supervised_fine_tune_with_qat(self):
algorithm_config = QatFinetuningConfig(type="QAT", quantizer_name="quantizer_name", group_size=1)
data_config = TrainingConfigDataConfig(dataset_id="sample-basic-test", batch_size=16)
optimizer_config = TrainingConfigOptimizerConfig(
algorithm_config = QATFinetuningConfig(type="QAT", quantizer_name="quantizer_name", group_size=1)
data_config = DataConfig(
dataset_id="sample-basic-test", batch_size=16, shuffle=False, data_format=DatasetFormat.instruct
)
optimizer_config = OptimizerConfig(
optimizer_type=OptimizerType.adam,
lr=0.0001,
weight_decay=0.01,
num_warmup_steps=100,
)
training_config = TrainingConfig(
n_epochs=2,
@ -193,7 +222,7 @@ class TestNvidiaPostTraining(unittest.TestCase):
model="meta-llama/Llama-3.1-8B-Instruct",
checkpoint_dir="",
algorithm_config=algorithm_config,
training_config=training_config,
training_config=convert_pydantic_to_json_value(training_config),
logger_config={},
hyperparam_search_config={},
)
@ -303,6 +332,31 @@ class TestNvidiaPostTraining(unittest.TestCase):
expected_params={"job_id": job_id},
)
def test_inference_register_model(self):
model_id = "default/job-1234"
model_type = ModelType.llm
model = Model(
identifier=model_id,
provider_id="nvidia",
provider_model_id=model_id,
provider_resource_id=model_id,
model_type=model_type,
)
result = self.run_async(self.inference_adapter.register_model(model))
assert result == model
assert len(self.inference_adapter.alias_to_provider_id_map) > 1
assert self.inference_adapter.get_provider_model_id(model.provider_model_id) == model_id
with patch.object(self.inference_adapter, "chat_completion") as mock_chat_completion:
self.run_async(
self.inference_adapter.chat_completion(
model_id=model_id,
messages=[{"role": "user", "content": "Hello, model"}],
)
)
mock_chat_completion.assert_called()
if __name__ == "__main__":
unittest.main()

View file

@ -0,0 +1,43 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import pytest
from llama_stack.apis.common.content_types import TextContentItem
from llama_stack.apis.inference.inference import CompletionMessage, UserMessage
from llama_stack.models.llama.datatypes import StopReason, ToolCall
from llama_stack.providers.utils.inference.openai_compat import convert_message_to_openai_dict
@pytest.mark.asyncio
async def test_convert_message_to_openai_dict():
message = UserMessage(content=[TextContentItem(text="Hello, world!")], role="user")
assert await convert_message_to_openai_dict(message) == {
"role": "user",
"content": [{"type": "text", "text": "Hello, world!"}],
}
# Test convert_message_to_openai_dict with a tool call
@pytest.mark.asyncio
async def test_convert_message_to_openai_dict_with_tool_call():
message = CompletionMessage(
content="",
tool_calls=[
ToolCall(call_id="123", tool_name="test_tool", arguments_json='{"foo": "bar"}', arguments={"foo": "bar"})
],
stop_reason=StopReason.end_of_turn,
)
openai_dict = await convert_message_to_openai_dict(message)
assert openai_dict == {
"role": "assistant",
"content": [{"type": "text", "text": ""}],
"tool_calls": [
{"id": "123", "type": "function", "function": {"name": "test_tool", "arguments": '{"foo": "bar"}'}}
],
}
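
The expected dictionary above follows the OpenAI chat wire format, in which tool-call arguments travel as a JSON-encoded string rather than a nested object. A small sketch that builds the same shape by hand:

# Sketch: an assistant message carrying a tool call in OpenAI wire format.
# Note that "arguments" is a JSON string, not a dict.
import json

arguments = {"foo": "bar"}
assistant_message = {
    "role": "assistant",
    "content": [{"type": "text", "text": ""}],
    "tool_calls": [
        {
            "id": "123",
            "type": "function",
            "function": {"name": "test_tool", "arguments": json.dumps(arguments)},
        }
    ],
}
print(json.dumps(assistant_message, indent=2))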

View file

@ -0,0 +1,91 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
import pytest
from llama_stack.distribution.server.server import create_sse_event, sse_generator
@pytest.mark.asyncio
async def test_sse_generator_basic():
# An AsyncIterator wrapped in an Awaitable, just like our web methods
async def async_event_gen():
async def event_gen():
yield "Test event 1"
yield "Test event 2"
return event_gen()
sse_gen = sse_generator(async_event_gen())
assert sse_gen is not None
# Test that the events are streamed correctly
seen_events = []
async for event in sse_gen:
seen_events.append(event)
assert len(seen_events) == 2
assert seen_events[0] == create_sse_event("Test event 1")
assert seen_events[1] == create_sse_event("Test event 2")
@pytest.mark.asyncio
async def test_sse_generator_client_disconnected():
# An AsyncIterator wrapped in an Awaitable, just like our web methods
async def async_event_gen():
async def event_gen():
yield "Test event 1"
# Simulate a client disconnect before emitting event 2
raise asyncio.CancelledError()
return event_gen()
sse_gen = sse_generator(async_event_gen())
assert sse_gen is not None
seen_events = []
async for event in sse_gen:
seen_events.append(event)
# We should see 1 event before the client disconnected
assert len(seen_events) == 1
assert seen_events[0] == create_sse_event("Test event 1")
@pytest.mark.asyncio
async def test_sse_generator_client_disconnected_before_response_starts():
# Disconnect before the response starts
async def async_event_gen():
raise asyncio.CancelledError()
sse_gen = sse_generator(async_event_gen())
assert sse_gen is not None
seen_events = []
async for event in sse_gen:
seen_events.append(event)
# No events should be seen since the client disconnected immediately
assert len(seen_events) == 0
@pytest.mark.asyncio
async def test_sse_generator_error_before_response_starts():
# Raise an error before the response starts
async def async_event_gen():
raise Exception("Test error")
sse_gen = sse_generator(async_event_gen())
assert sse_gen is not None
seen_events = []
async for event in sse_gen:
seen_events.append(event)
# We should have 1 error event
assert len(seen_events) == 1
assert 'data: {"error":' in seen_events[0]

View file

@ -15,6 +15,52 @@ test_chat_basic:
S?
role: user
output: Saturn
test_chat_input_validation:
test_name: test_chat_input_validation
test_params:
case:
- case_id: "messages_missing"
input:
messages: []
output:
error:
status_code: 400
- case_id: "messages_role_invalid"
input:
messages:
- content: Which planet do humans live on?
role: fake_role
output:
error:
status_code: 400
- case_id: "tool_choice_invalid"
input:
messages:
- content: Which planet do humans live on?
role: user
tool_choice: invalid
output:
error:
status_code: 400
- case_id: "tool_choice_no_tools"
input:
messages:
- content: Which planet do humans live on?
role: user
tool_choice: required
output:
error:
status_code: 400
- case_id: "tools_type_invalid"
input:
messages:
- content: Which planet do humans live on?
role: user
tools:
- type: invalid
output:
error:
status_code: 400
test_chat_image:
test_name: test_chat_image
test_params:
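
Each case under test_chat_input_validation pairs a malformed request with the HTTP 400 the server is expected to return. A sketch of iterating these cases the way the verification tests below do; the file path is an assumption:

# Sketch: load the YAML above and walk the input-validation cases.
# The path is an assumption about where this fixture file lives.
import yaml

with open("tests/verifications/openai_api/fixtures/test_cases/chat_completion.yaml") as f:
    cases = yaml.safe_load(f)

for case in cases["test_chat_input_validation"]["test_params"]["case"]:
    print(case["case_id"], "->", case["output"]["error"]["status_code"])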

View file

@ -12,6 +12,7 @@ from pathlib import Path
from typing import Any
import pytest
from openai import APIError
from pydantic import BaseModel
from tests.verifications.openai_api.fixtures.fixtures import (
@ -136,6 +137,50 @@ def test_chat_streaming_basic(request, openai_client, model, provider, verificat
assert case["output"].lower() in content.lower()
@pytest.mark.parametrize(
"case",
chat_completion_test_cases["test_chat_input_validation"]["test_params"]["case"],
ids=case_id_generator,
)
def test_chat_non_streaming_error_handling(request, openai_client, model, provider, verification_config, case):
test_name_base = get_base_test_name(request)
if should_skip_test(verification_config, provider, model, test_name_base):
pytest.skip(f"Skipping {test_name_base} for model {model} on provider {provider} based on config.")
with pytest.raises(APIError) as e:
openai_client.chat.completions.create(
model=model,
messages=case["input"]["messages"],
stream=False,
tool_choice=case["input"]["tool_choice"] if "tool_choice" in case["input"] else None,
tools=case["input"]["tools"] if "tools" in case["input"] else None,
)
assert case["output"]["error"]["status_code"] == e.value.status_code
@pytest.mark.parametrize(
"case",
chat_completion_test_cases["test_chat_input_validation"]["test_params"]["case"],
ids=case_id_generator,
)
def test_chat_streaming_error_handling(request, openai_client, model, provider, verification_config, case):
test_name_base = get_base_test_name(request)
if should_skip_test(verification_config, provider, model, test_name_base):
pytest.skip(f"Skipping {test_name_base} for model {model} on provider {provider} based on config.")
with pytest.raises(APIError) as e:
response = openai_client.chat.completions.create(
model=model,
messages=case["input"]["messages"],
stream=True,
tool_choice=case["input"]["tool_choice"] if "tool_choice" in case["input"] else None,
tools=case["input"]["tools"] if "tools" in case["input"] else None,
)
for _chunk in response:
pass
assert str(case["output"]["error"]["status_code"]) in e.value.message
@pytest.mark.parametrize(
"case",
chat_completion_test_cases["test_chat_image"]["test_params"]["case"],

2
uv.lock generated
View file

@ -1458,6 +1458,7 @@ unit = [
{ name = "aiosqlite" },
{ name = "chardet" },
{ name = "openai" },
{ name = "opentelemetry-exporter-otlp-proto-http" },
{ name = "pypdf" },
{ name = "qdrant-client" },
{ name = "sqlite-vec" },
@ -1491,6 +1492,7 @@ requires-dist = [
{ name = "openai", marker = "extra == 'test'" },
{ name = "openai", marker = "extra == 'unit'" },
{ name = "opentelemetry-exporter-otlp-proto-http", marker = "extra == 'test'" },
{ name = "opentelemetry-exporter-otlp-proto-http", marker = "extra == 'unit'" },
{ name = "opentelemetry-sdk", marker = "extra == 'test'" },
{ name = "pandas", marker = "extra == 'ui'" },
{ name = "pillow" },