mirror of
https://github.com/meta-llama/llama-stack.git
synced 2025-08-03 09:21:45 +00:00
Merge branch 'main' into test-modelregistryhelper
This commit is contained in:
commit
7fd8a61b4d
80 changed files with 2918 additions and 386 deletions
6
.coveragerc
Normal file
6
.coveragerc
Normal file
|
@ -0,0 +1,6 @@
|
||||||
|
[run]
|
||||||
|
omit =
|
||||||
|
*/tests/*
|
||||||
|
*/llama_stack/providers/*
|
||||||
|
*/llama_stack/templates/*
|
||||||
|
.venv/*
|
1
.github/workflows/integration-tests.yml
vendored
1
.github/workflows/integration-tests.yml
vendored
|
@ -6,7 +6,6 @@ on:
|
||||||
pull_request:
|
pull_request:
|
||||||
branches: [ main ]
|
branches: [ main ]
|
||||||
paths:
|
paths:
|
||||||
- 'distributions/**'
|
|
||||||
- 'llama_stack/**'
|
- 'llama_stack/**'
|
||||||
- 'tests/integration/**'
|
- 'tests/integration/**'
|
||||||
- 'uv.lock'
|
- 'uv.lock'
|
||||||
|
|
38
.github/workflows/providers-build.yml
vendored
38
.github/workflows/providers-build.yml
vendored
|
@ -107,3 +107,41 @@ jobs:
|
||||||
- name: Build a single provider
|
- name: Build a single provider
|
||||||
run: |
|
run: |
|
||||||
USE_COPY_NOT_MOUNT=true LLAMA_STACK_DIR=. uv run llama stack build --image-type venv --image-name test --providers inference=remote::ollama
|
USE_COPY_NOT_MOUNT=true LLAMA_STACK_DIR=. uv run llama stack build --image-type venv --image-name test --providers inference=remote::ollama
|
||||||
|
|
||||||
|
build-custom-container-distribution:
|
||||||
|
runs-on: ubuntu-latest
|
||||||
|
steps:
|
||||||
|
- name: Checkout repository
|
||||||
|
uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
|
||||||
|
|
||||||
|
- name: Set up Python
|
||||||
|
uses: actions/setup-python@8d9ed9ac5c53483de85588cdf95a591a75ab9f55 # v5.5.0
|
||||||
|
with:
|
||||||
|
python-version: '3.10'
|
||||||
|
|
||||||
|
- name: Install uv
|
||||||
|
uses: astral-sh/setup-uv@0c5e2b8115b80b4c7c5ddf6ffdd634974642d182 # v5.4.1
|
||||||
|
with:
|
||||||
|
python-version: "3.10"
|
||||||
|
|
||||||
|
- name: Install LlamaStack
|
||||||
|
run: |
|
||||||
|
uv venv
|
||||||
|
source .venv/bin/activate
|
||||||
|
uv pip install -e .
|
||||||
|
|
||||||
|
- name: Build a single provider
|
||||||
|
run: |
|
||||||
|
yq -i '.image_type = "container"' llama_stack/templates/dev/build.yaml
|
||||||
|
yq -i '.image_name = "test"' llama_stack/templates/dev/build.yaml
|
||||||
|
USE_COPY_NOT_MOUNT=true LLAMA_STACK_DIR=. uv run llama stack build --config llama_stack/templates/dev/build.yaml
|
||||||
|
|
||||||
|
- name: Inspect the container image entrypoint
|
||||||
|
run: |
|
||||||
|
IMAGE_ID=$(docker images --format "{{.Repository}}:{{.Tag}}" | head -n 1)
|
||||||
|
entrypoint=$(docker inspect --format '{{ .Config.Entrypoint }}' $IMAGE_ID)
|
||||||
|
echo "Entrypoint: $entrypoint"
|
||||||
|
if [ "$entrypoint" != "[python -m llama_stack.distribution.server.server --config /app/run.yaml]" ]; then
|
||||||
|
echo "Entrypoint is not correct"
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
|
|
@ -5,6 +5,13 @@ on:
|
||||||
branches: [ main ]
|
branches: [ main ]
|
||||||
pull_request:
|
pull_request:
|
||||||
branches: [ main ]
|
branches: [ main ]
|
||||||
|
paths:
|
||||||
|
- 'llama_stack/**'
|
||||||
|
- 'tests/integration/**'
|
||||||
|
- 'uv.lock'
|
||||||
|
- 'pyproject.toml'
|
||||||
|
- 'requirements.txt'
|
||||||
|
- '.github/workflows/test-external-providers.yml' # This workflow
|
||||||
|
|
||||||
jobs:
|
jobs:
|
||||||
test-external-providers:
|
test-external-providers:
|
||||||
|
|
1
.github/workflows/unit-tests.yml
vendored
1
.github/workflows/unit-tests.yml
vendored
|
@ -6,7 +6,6 @@ on:
|
||||||
pull_request:
|
pull_request:
|
||||||
branches: [ main ]
|
branches: [ main ]
|
||||||
paths:
|
paths:
|
||||||
- 'distributions/**'
|
|
||||||
- 'llama_stack/**'
|
- 'llama_stack/**'
|
||||||
- 'tests/unit/**'
|
- 'tests/unit/**'
|
||||||
- 'uv.lock'
|
- 'uv.lock'
|
||||||
|
|
|
@ -119,6 +119,7 @@ Here is a list of the various API providers and available distributions that can
|
||||||
| OpenAI | Hosted | | ✅ | | | |
|
| OpenAI | Hosted | | ✅ | | | |
|
||||||
| Anthropic | Hosted | | ✅ | | | |
|
| Anthropic | Hosted | | ✅ | | | |
|
||||||
| Gemini | Hosted | | ✅ | | | |
|
| Gemini | Hosted | | ✅ | | | |
|
||||||
|
| watsonx | Hosted | | ✅ | | | |
|
||||||
|
|
||||||
|
|
||||||
### Distributions
|
### Distributions
|
||||||
|
@ -128,7 +129,6 @@ A Llama Stack Distribution (or "distro") is a pre-configured bundle of provider
|
||||||
| **Distribution** | **Llama Stack Docker** | Start This Distribution |
|
| **Distribution** | **Llama Stack Docker** | Start This Distribution |
|
||||||
|:---------------------------------------------:|:-------------------------------------------------------------------------------------------------------------------------------------------------------------:|:------------------------------------------------------------------------------------------------------------------------:|
|
|:---------------------------------------------:|:-------------------------------------------------------------------------------------------------------------------------------------------------------------:|:------------------------------------------------------------------------------------------------------------------------:|
|
||||||
| Meta Reference | [llamastack/distribution-meta-reference-gpu](https://hub.docker.com/repository/docker/llamastack/distribution-meta-reference-gpu/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/distributions/self_hosted_distro/meta-reference-gpu.html) |
|
| Meta Reference | [llamastack/distribution-meta-reference-gpu](https://hub.docker.com/repository/docker/llamastack/distribution-meta-reference-gpu/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/distributions/self_hosted_distro/meta-reference-gpu.html) |
|
||||||
| Meta Reference Quantized | [llamastack/distribution-meta-reference-quantized-gpu](https://hub.docker.com/repository/docker/llamastack/distribution-meta-reference-quantized-gpu/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/distributions/self_hosted_distro/meta-reference-quantized-gpu.html) |
|
|
||||||
| SambaNova | [llamastack/distribution-sambanova](https://hub.docker.com/repository/docker/llamastack/distribution-sambanova/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/distributions/self_hosted_distro/sambanova.html) |
|
| SambaNova | [llamastack/distribution-sambanova](https://hub.docker.com/repository/docker/llamastack/distribution-sambanova/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/distributions/self_hosted_distro/sambanova.html) |
|
||||||
| Cerebras | [llamastack/distribution-cerebras](https://hub.docker.com/repository/docker/llamastack/distribution-cerebras/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/distributions/self_hosted_distro/cerebras.html) |
|
| Cerebras | [llamastack/distribution-cerebras](https://hub.docker.com/repository/docker/llamastack/distribution-cerebras/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/distributions/self_hosted_distro/cerebras.html) |
|
||||||
| Ollama | [llamastack/distribution-ollama](https://hub.docker.com/repository/docker/llamastack/distribution-ollama/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/distributions/self_hosted_distro/ollama.html) |
|
| Ollama | [llamastack/distribution-ollama](https://hub.docker.com/repository/docker/llamastack/distribution-ollama/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/distributions/self_hosted_distro/ollama.html) |
|
||||||
|
|
|
@ -68,7 +68,8 @@ chunks_response = client.vector_io.query(
|
||||||
### Using the RAG Tool
|
### Using the RAG Tool
|
||||||
|
|
||||||
A better way to ingest documents is to use the RAG Tool. This tool allows you to ingest documents from URLs, files, etc.
|
A better way to ingest documents is to use the RAG Tool. This tool allows you to ingest documents from URLs, files, etc.
|
||||||
and automatically chunks them into smaller pieces.
|
and automatically chunks them into smaller pieces. More examples for how to format a RAGDocument can be found in the
|
||||||
|
[appendix](#more-ragdocument-examples).
|
||||||
|
|
||||||
```python
|
```python
|
||||||
from llama_stack_client import RAGDocument
|
from llama_stack_client import RAGDocument
|
||||||
|
@ -178,3 +179,38 @@ for vector_db_id in client.vector_dbs.list():
|
||||||
print(f"Unregistering vector database: {vector_db_id.identifier}")
|
print(f"Unregistering vector database: {vector_db_id.identifier}")
|
||||||
client.vector_dbs.unregister(vector_db_id=vector_db_id.identifier)
|
client.vector_dbs.unregister(vector_db_id=vector_db_id.identifier)
|
||||||
```
|
```
|
||||||
|
|
||||||
|
### Appendix
|
||||||
|
|
||||||
|
#### More RAGDocument Examples
|
||||||
|
```python
|
||||||
|
from llama_stack_client import RAGDocument
|
||||||
|
import base64
|
||||||
|
|
||||||
|
RAGDocument(document_id="num-0", content={"uri": "file://path/to/file"})
|
||||||
|
RAGDocument(document_id="num-1", content="plain text")
|
||||||
|
RAGDocument(
|
||||||
|
document_id="num-2",
|
||||||
|
content={
|
||||||
|
"type": "text",
|
||||||
|
"text": "plain text input",
|
||||||
|
}, # for inputs that should be treated as text explicitly
|
||||||
|
)
|
||||||
|
RAGDocument(
|
||||||
|
document_id="num-3",
|
||||||
|
content={
|
||||||
|
"type": "image",
|
||||||
|
"image": {"url": {"uri": "https://mywebsite.com/image.jpg"}},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
B64_ENCODED_IMAGE = base64.b64encode(
|
||||||
|
requests.get(
|
||||||
|
"https://raw.githubusercontent.com/meta-llama/llama-stack/refs/heads/main/docs/_static/llama-stack.png"
|
||||||
|
).content
|
||||||
|
)
|
||||||
|
RAGDocuemnt(
|
||||||
|
document_id="num-4",
|
||||||
|
content={"type": "image", "image": {"data": B64_ENCODED_IMAGE}},
|
||||||
|
)
|
||||||
|
```
|
||||||
|
for more strongly typed interaction use the typed dicts found [here](https://github.com/meta-llama/llama-stack-client-python/blob/38cd91c9e396f2be0bec1ee96a19771582ba6f17/src/llama_stack_client/types/shared_params/document.py).
|
||||||
|
|
|
@ -109,8 +109,6 @@ llama stack build --list-templates
|
||||||
+------------------------------+-----------------------------------------------------------------------------+
|
+------------------------------+-----------------------------------------------------------------------------+
|
||||||
| nvidia | Use NVIDIA NIM for running LLM inference |
|
| nvidia | Use NVIDIA NIM for running LLM inference |
|
||||||
+------------------------------+-----------------------------------------------------------------------------+
|
+------------------------------+-----------------------------------------------------------------------------+
|
||||||
| meta-reference-quantized-gpu | Use Meta Reference with fp8, int4 quantization for running LLM inference |
|
|
||||||
+------------------------------+-----------------------------------------------------------------------------+
|
|
||||||
| cerebras | Use Cerebras for running LLM inference |
|
| cerebras | Use Cerebras for running LLM inference |
|
||||||
+------------------------------+-----------------------------------------------------------------------------+
|
+------------------------------+-----------------------------------------------------------------------------+
|
||||||
| ollama | Use (an external) Ollama server for running LLM inference |
|
| ollama | Use (an external) Ollama server for running LLM inference |
|
||||||
|
|
88
docs/source/distributions/remote_hosted_distro/watsonx.md
Normal file
88
docs/source/distributions/remote_hosted_distro/watsonx.md
Normal file
|
@ -0,0 +1,88 @@
|
||||||
|
---
|
||||||
|
orphan: true
|
||||||
|
---
|
||||||
|
<!-- This file was auto-generated by distro_codegen.py, please edit source -->
|
||||||
|
# watsonx Distribution
|
||||||
|
|
||||||
|
```{toctree}
|
||||||
|
:maxdepth: 2
|
||||||
|
:hidden:
|
||||||
|
|
||||||
|
self
|
||||||
|
```
|
||||||
|
|
||||||
|
The `llamastack/distribution-watsonx` distribution consists of the following provider configurations.
|
||||||
|
|
||||||
|
| API | Provider(s) |
|
||||||
|
|-----|-------------|
|
||||||
|
| agents | `inline::meta-reference` |
|
||||||
|
| datasetio | `remote::huggingface`, `inline::localfs` |
|
||||||
|
| eval | `inline::meta-reference` |
|
||||||
|
| inference | `remote::watsonx` |
|
||||||
|
| safety | `inline::llama-guard` |
|
||||||
|
| scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
|
||||||
|
| telemetry | `inline::meta-reference` |
|
||||||
|
| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::code-interpreter`, `inline::rag-runtime`, `remote::model-context-protocol` |
|
||||||
|
| vector_io | `inline::faiss` |
|
||||||
|
|
||||||
|
|
||||||
|
|
||||||
|
### Environment Variables
|
||||||
|
|
||||||
|
The following environment variables can be configured:
|
||||||
|
|
||||||
|
- `LLAMASTACK_PORT`: Port for the Llama Stack distribution server (default: `5001`)
|
||||||
|
- `WATSONX_API_KEY`: watsonx API Key (default: ``)
|
||||||
|
- `WATSONX_PROJECT_ID`: watsonx Project ID (default: ``)
|
||||||
|
|
||||||
|
### Models
|
||||||
|
|
||||||
|
The following models are available by default:
|
||||||
|
|
||||||
|
- `meta-llama/llama-3-3-70b-instruct (aliases: meta-llama/Llama-3.3-70B-Instruct)`
|
||||||
|
- `meta-llama/llama-2-13b-chat (aliases: meta-llama/Llama-2-13b)`
|
||||||
|
- `meta-llama/llama-3-1-70b-instruct (aliases: meta-llama/Llama-3.1-70B-Instruct)`
|
||||||
|
- `meta-llama/llama-3-1-8b-instruct (aliases: meta-llama/Llama-3.1-8B-Instruct)`
|
||||||
|
- `meta-llama/llama-3-2-11b-vision-instruct (aliases: meta-llama/Llama-3.2-11B-Vision-Instruct)`
|
||||||
|
- `meta-llama/llama-3-2-1b-instruct (aliases: meta-llama/Llama-3.2-1B-Instruct)`
|
||||||
|
- `meta-llama/llama-3-2-3b-instruct (aliases: meta-llama/Llama-3.2-3B-Instruct)`
|
||||||
|
- `meta-llama/llama-3-2-90b-vision-instruct (aliases: meta-llama/Llama-3.2-90B-Vision-Instruct)`
|
||||||
|
- `meta-llama/llama-guard-3-11b-vision (aliases: meta-llama/Llama-Guard-3-11B-Vision)`
|
||||||
|
|
||||||
|
|
||||||
|
### Prerequisite: API Keys
|
||||||
|
|
||||||
|
Make sure you have access to a watsonx API Key. You can get one by referring [watsonx.ai](https://www.ibm.com/docs/en/masv-and-l/maximo-manage/continuous-delivery?topic=setup-create-watsonx-api-key).
|
||||||
|
|
||||||
|
|
||||||
|
## Running Llama Stack with watsonx
|
||||||
|
|
||||||
|
You can do this via Conda (build code), venv or Docker which has a pre-built image.
|
||||||
|
|
||||||
|
### Via Docker
|
||||||
|
|
||||||
|
This method allows you to get started quickly without having to build the distribution code.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
LLAMA_STACK_PORT=5001
|
||||||
|
docker run \
|
||||||
|
-it \
|
||||||
|
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
|
||||||
|
-v ./run.yaml:/root/my-run.yaml \
|
||||||
|
llamastack/distribution-watsonx \
|
||||||
|
--yaml-config /root/my-run.yaml \
|
||||||
|
--port $LLAMA_STACK_PORT \
|
||||||
|
--env WATSONX_API_KEY=$WATSONX_API_KEY \
|
||||||
|
--env WATSONX_PROJECT_ID=$WATSONX_PROJECT_ID \
|
||||||
|
--env WATSONX_BASE_URL=$WATSONX_BASE_URL
|
||||||
|
```
|
||||||
|
|
||||||
|
### Via Conda
|
||||||
|
|
||||||
|
```bash
|
||||||
|
llama stack build --template watsonx --image-type conda
|
||||||
|
llama stack run ./run.yaml \
|
||||||
|
--port $LLAMA_STACK_PORT \
|
||||||
|
--env WATSONX_API_KEY=$WATSONX_API_KEY \
|
||||||
|
--env WATSONX_PROJECT_ID=$WATSONX_PROJECT_ID
|
||||||
|
```
|
|
@ -81,6 +81,7 @@ LLAMA_STACK_PORT=8321
|
||||||
docker run \
|
docker run \
|
||||||
-it \
|
-it \
|
||||||
--pull always \
|
--pull always \
|
||||||
|
--gpu all \
|
||||||
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
|
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
|
||||||
-v ~/.llama:/root/.llama \
|
-v ~/.llama:/root/.llama \
|
||||||
llamastack/distribution-meta-reference-gpu \
|
llamastack/distribution-meta-reference-gpu \
|
||||||
|
@ -94,6 +95,7 @@ If you are using Llama Stack Safety / Shield APIs, use:
|
||||||
docker run \
|
docker run \
|
||||||
-it \
|
-it \
|
||||||
--pull always \
|
--pull always \
|
||||||
|
--gpu all \
|
||||||
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
|
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
|
||||||
-v ~/.llama:/root/.llama \
|
-v ~/.llama:/root/.llama \
|
||||||
llamastack/distribution-meta-reference-gpu \
|
llamastack/distribution-meta-reference-gpu \
|
||||||
|
|
|
@ -1,123 +0,0 @@
|
||||||
---
|
|
||||||
orphan: true
|
|
||||||
---
|
|
||||||
<!-- This file was auto-generated by distro_codegen.py, please edit source -->
|
|
||||||
# Meta Reference Quantized Distribution
|
|
||||||
|
|
||||||
```{toctree}
|
|
||||||
:maxdepth: 2
|
|
||||||
:hidden:
|
|
||||||
|
|
||||||
self
|
|
||||||
```
|
|
||||||
|
|
||||||
The `llamastack/distribution-meta-reference-quantized-gpu` distribution consists of the following provider configurations:
|
|
||||||
|
|
||||||
| API | Provider(s) |
|
|
||||||
|-----|-------------|
|
|
||||||
| agents | `inline::meta-reference` |
|
|
||||||
| datasetio | `remote::huggingface`, `inline::localfs` |
|
|
||||||
| eval | `inline::meta-reference` |
|
|
||||||
| inference | `inline::meta-reference-quantized` |
|
|
||||||
| safety | `inline::llama-guard` |
|
|
||||||
| scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
|
|
||||||
| telemetry | `inline::meta-reference` |
|
|
||||||
| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::code-interpreter`, `inline::rag-runtime`, `remote::model-context-protocol` |
|
|
||||||
| vector_io | `inline::faiss`, `remote::chromadb`, `remote::pgvector` |
|
|
||||||
|
|
||||||
|
|
||||||
The only difference vs. the `meta-reference-gpu` distribution is that it has support for more efficient inference -- with fp8, int4 quantization, etc.
|
|
||||||
|
|
||||||
Note that you need access to nvidia GPUs to run this distribution. This distribution is not compatible with CPU-only machines or machines with AMD GPUs.
|
|
||||||
|
|
||||||
### Environment Variables
|
|
||||||
|
|
||||||
The following environment variables can be configured:
|
|
||||||
|
|
||||||
- `LLAMA_STACK_PORT`: Port for the Llama Stack distribution server (default: `8321`)
|
|
||||||
- `INFERENCE_MODEL`: Inference model loaded into the Meta Reference server (default: `meta-llama/Llama-3.2-3B-Instruct`)
|
|
||||||
- `INFERENCE_CHECKPOINT_DIR`: Directory containing the Meta Reference model checkpoint (default: `null`)
|
|
||||||
|
|
||||||
|
|
||||||
## Prerequisite: Downloading Models
|
|
||||||
|
|
||||||
Please use `llama model list --downloaded` to check that you have llama model checkpoints downloaded in `~/.llama` before proceeding. See [installation guide](https://llama-stack.readthedocs.io/en/latest/references/llama_cli_reference/download_models.html) here to download the models. Run `llama model list` to see the available models to download, and `llama model download` to download the checkpoints.
|
|
||||||
|
|
||||||
```
|
|
||||||
$ llama model list --downloaded
|
|
||||||
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━┓
|
|
||||||
┃ Model ┃ Size ┃ Modified Time ┃
|
|
||||||
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━┩
|
|
||||||
│ Llama3.2-1B-Instruct:int4-qlora-eo8 │ 1.53 GB │ 2025-02-26 11:22:28 │
|
|
||||||
├─────────────────────────────────────────┼──────────┼─────────────────────┤
|
|
||||||
│ Llama3.2-1B │ 2.31 GB │ 2025-02-18 21:48:52 │
|
|
||||||
├─────────────────────────────────────────┼──────────┼─────────────────────┤
|
|
||||||
│ Prompt-Guard-86M │ 0.02 GB │ 2025-02-26 11:29:28 │
|
|
||||||
├─────────────────────────────────────────┼──────────┼─────────────────────┤
|
|
||||||
│ Llama3.2-3B-Instruct:int4-spinquant-eo8 │ 3.69 GB │ 2025-02-26 11:37:41 │
|
|
||||||
├─────────────────────────────────────────┼──────────┼─────────────────────┤
|
|
||||||
│ Llama3.2-3B │ 5.99 GB │ 2025-02-18 21:51:26 │
|
|
||||||
├─────────────────────────────────────────┼──────────┼─────────────────────┤
|
|
||||||
│ Llama3.1-8B │ 14.97 GB │ 2025-02-16 10:36:37 │
|
|
||||||
├─────────────────────────────────────────┼──────────┼─────────────────────┤
|
|
||||||
│ Llama3.2-1B-Instruct:int4-spinquant-eo8 │ 1.51 GB │ 2025-02-26 11:35:02 │
|
|
||||||
├─────────────────────────────────────────┼──────────┼─────────────────────┤
|
|
||||||
│ Llama-Guard-3-1B │ 2.80 GB │ 2025-02-26 11:20:46 │
|
|
||||||
├─────────────────────────────────────────┼──────────┼─────────────────────┤
|
|
||||||
│ Llama-Guard-3-1B:int4 │ 0.43 GB │ 2025-02-26 11:33:33 │
|
|
||||||
└─────────────────────────────────────────┴──────────┴─────────────────────┘
|
|
||||||
```
|
|
||||||
|
|
||||||
## Running the Distribution
|
|
||||||
|
|
||||||
You can do this via Conda (build code) or Docker which has a pre-built image.
|
|
||||||
|
|
||||||
### Via Docker
|
|
||||||
|
|
||||||
This method allows you to get started quickly without having to build the distribution code.
|
|
||||||
|
|
||||||
```bash
|
|
||||||
LLAMA_STACK_PORT=8321
|
|
||||||
docker run \
|
|
||||||
-it \
|
|
||||||
--pull always \
|
|
||||||
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
|
|
||||||
-v ~/.llama:/root/.llama \
|
|
||||||
llamastack/distribution-meta-reference-quantized-gpu \
|
|
||||||
--port $LLAMA_STACK_PORT \
|
|
||||||
--env INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct
|
|
||||||
```
|
|
||||||
|
|
||||||
If you are using Llama Stack Safety / Shield APIs, use:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
docker run \
|
|
||||||
-it \
|
|
||||||
--pull always \
|
|
||||||
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
|
|
||||||
-v ~/.llama:/root/.llama \
|
|
||||||
llamastack/distribution-meta-reference-quantized-gpu \
|
|
||||||
--port $LLAMA_STACK_PORT \
|
|
||||||
--env INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct \
|
|
||||||
--env SAFETY_MODEL=meta-llama/Llama-Guard-3-1B
|
|
||||||
```
|
|
||||||
|
|
||||||
### Via Conda
|
|
||||||
|
|
||||||
Make sure you have done `uv pip install llama-stack` and have the Llama Stack CLI available.
|
|
||||||
|
|
||||||
```bash
|
|
||||||
llama stack build --template meta-reference-quantized-gpu --image-type conda
|
|
||||||
llama stack run distributions/meta-reference-quantized-gpu/run.yaml \
|
|
||||||
--port $LLAMA_STACK_PORT \
|
|
||||||
--env INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct
|
|
||||||
```
|
|
||||||
|
|
||||||
If you are using Llama Stack Safety / Shield APIs, use:
|
|
||||||
|
|
||||||
```bash
|
|
||||||
llama stack run distributions/meta-reference-quantized-gpu/run-with-safety.yaml \
|
|
||||||
--port $LLAMA_STACK_PORT \
|
|
||||||
--env INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct \
|
|
||||||
--env SAFETY_MODEL=meta-llama/Llama-Guard-3-1B
|
|
||||||
```
|
|
|
@ -7,7 +7,7 @@ The `llamastack/distribution-nvidia` distribution consists of the following prov
|
||||||
|-----|-------------|
|
|-----|-------------|
|
||||||
| agents | `inline::meta-reference` |
|
| agents | `inline::meta-reference` |
|
||||||
| datasetio | `inline::localfs` |
|
| datasetio | `inline::localfs` |
|
||||||
| eval | `inline::meta-reference` |
|
| eval | `remote::nvidia` |
|
||||||
| inference | `remote::nvidia` |
|
| inference | `remote::nvidia` |
|
||||||
| post_training | `remote::nvidia` |
|
| post_training | `remote::nvidia` |
|
||||||
| safety | `remote::nvidia` |
|
| safety | `remote::nvidia` |
|
||||||
|
@ -22,13 +22,13 @@ The `llamastack/distribution-nvidia` distribution consists of the following prov
|
||||||
The following environment variables can be configured:
|
The following environment variables can be configured:
|
||||||
|
|
||||||
- `NVIDIA_API_KEY`: NVIDIA API Key (default: ``)
|
- `NVIDIA_API_KEY`: NVIDIA API Key (default: ``)
|
||||||
- `NVIDIA_USER_ID`: NVIDIA User ID (default: `llama-stack-user`)
|
- `NVIDIA_APPEND_API_VERSION`: Whether to append the API version to the base_url (default: `True`)
|
||||||
- `NVIDIA_DATASET_NAMESPACE`: NVIDIA Dataset Namespace (default: `default`)
|
- `NVIDIA_DATASET_NAMESPACE`: NVIDIA Dataset Namespace (default: `default`)
|
||||||
- `NVIDIA_ACCESS_POLICIES`: NVIDIA Access Policies (default: `{}`)
|
|
||||||
- `NVIDIA_PROJECT_ID`: NVIDIA Project ID (default: `test-project`)
|
- `NVIDIA_PROJECT_ID`: NVIDIA Project ID (default: `test-project`)
|
||||||
- `NVIDIA_CUSTOMIZER_URL`: NVIDIA Customizer URL (default: `https://customizer.api.nvidia.com`)
|
- `NVIDIA_CUSTOMIZER_URL`: NVIDIA Customizer URL (default: `https://customizer.api.nvidia.com`)
|
||||||
- `NVIDIA_OUTPUT_MODEL_DIR`: NVIDIA Output Model Directory (default: `test-example-model@v1`)
|
- `NVIDIA_OUTPUT_MODEL_DIR`: NVIDIA Output Model Directory (default: `test-example-model@v1`)
|
||||||
- `GUARDRAILS_SERVICE_URL`: URL for the NeMo Guardrails Service (default: `http://0.0.0.0:7331`)
|
- `GUARDRAILS_SERVICE_URL`: URL for the NeMo Guardrails Service (default: `http://0.0.0.0:7331`)
|
||||||
|
- `NVIDIA_EVALUATOR_URL`: URL for the NeMo Evaluator Service (default: `http://0.0.0.0:7331`)
|
||||||
- `INFERENCE_MODEL`: Inference model (default: `Llama3.1-8B-Instruct`)
|
- `INFERENCE_MODEL`: Inference model (default: `Llama3.1-8B-Instruct`)
|
||||||
- `SAFETY_MODEL`: Name of the model to use for safety (default: `meta/llama-3.1-8b-instruct`)
|
- `SAFETY_MODEL`: Name of the model to use for safety (default: `meta/llama-3.1-8b-instruct`)
|
||||||
|
|
||||||
|
|
|
@ -50,9 +50,10 @@ Llama Stack supports two types of external providers:
|
||||||
|
|
||||||
Here's a list of known external providers that you can use with Llama Stack:
|
Here's a list of known external providers that you can use with Llama Stack:
|
||||||
|
|
||||||
| Type | Name | Description | Repository |
|
| Name | Description | API | Type | Repository |
|
||||||
|------|------|-------------|------------|
|
|------|-------------|-----|------|------------|
|
||||||
| Remote | KubeFlow Training | Train models with KubeFlow | [llama-stack-provider-kft](https://github.com/opendatahub-io/llama-stack-provider-kft) |
|
| KubeFlow Training | Train models with KubeFlow | Post Training | Remote | [llama-stack-provider-kft](https://github.com/opendatahub-io/llama-stack-provider-kft) |
|
||||||
|
| RamaLama | Inference models with RamaLama | Inference | Remote | [ramalama-stack](https://github.com/containers/ramalama-stack) |
|
||||||
|
|
||||||
### Remote Provider Specification
|
### Remote Provider Specification
|
||||||
|
|
||||||
|
|
|
@ -389,5 +389,7 @@
|
||||||
"pygments_lexer": "ipython3",
|
"pygments_lexer": "ipython3",
|
||||||
"version": "3.10.15"
|
"version": "3.10.15"
|
||||||
}
|
}
|
||||||
}
|
},
|
||||||
|
"nbformat": 4,
|
||||||
|
"nbformat_minor": 5
|
||||||
}
|
}
|
||||||
|
|
|
@ -256,5 +256,7 @@
|
||||||
"pygments_lexer": "ipython3",
|
"pygments_lexer": "ipython3",
|
||||||
"version": "3.10.15"
|
"version": "3.10.15"
|
||||||
}
|
}
|
||||||
}
|
},
|
||||||
|
"nbformat": 4,
|
||||||
|
"nbformat_minor": 5
|
||||||
}
|
}
|
||||||
|
|
|
@ -301,5 +301,7 @@
|
||||||
"pygments_lexer": "ipython3",
|
"pygments_lexer": "ipython3",
|
||||||
"version": "3.12.2"
|
"version": "3.12.2"
|
||||||
}
|
}
|
||||||
}
|
},
|
||||||
|
"nbformat": 4,
|
||||||
|
"nbformat_minor": 5
|
||||||
}
|
}
|
||||||
|
|
|
@ -200,5 +200,7 @@
|
||||||
"pygments_lexer": "ipython3",
|
"pygments_lexer": "ipython3",
|
||||||
"version": "3.12.2"
|
"version": "3.12.2"
|
||||||
}
|
}
|
||||||
}
|
},
|
||||||
|
"nbformat": 4,
|
||||||
|
"nbformat_minor": 5
|
||||||
}
|
}
|
||||||
|
|
|
@ -355,5 +355,7 @@
|
||||||
"pygments_lexer": "ipython3",
|
"pygments_lexer": "ipython3",
|
||||||
"version": "3.10.15"
|
"version": "3.10.15"
|
||||||
}
|
}
|
||||||
}
|
},
|
||||||
|
"nbformat": 4,
|
||||||
|
"nbformat_minor": 5
|
||||||
}
|
}
|
||||||
|
|
|
@ -398,5 +398,7 @@
|
||||||
"pygments_lexer": "ipython3",
|
"pygments_lexer": "ipython3",
|
||||||
"version": "3.10.15"
|
"version": "3.10.15"
|
||||||
}
|
}
|
||||||
}
|
},
|
||||||
|
"nbformat": 4,
|
||||||
|
"nbformat_minor": 5
|
||||||
}
|
}
|
||||||
|
|
|
@ -132,5 +132,7 @@
|
||||||
"pygments_lexer": "ipython3",
|
"pygments_lexer": "ipython3",
|
||||||
"version": "3.11.10"
|
"version": "3.11.10"
|
||||||
}
|
}
|
||||||
}
|
},
|
||||||
|
"nbformat": 4,
|
||||||
|
"nbformat_minor": 5
|
||||||
}
|
}
|
||||||
|
|
|
@ -188,5 +188,7 @@
|
||||||
"pygments_lexer": "ipython3",
|
"pygments_lexer": "ipython3",
|
||||||
"version": "3.10.15"
|
"version": "3.10.15"
|
||||||
}
|
}
|
||||||
}
|
},
|
||||||
|
"nbformat": 4,
|
||||||
|
"nbformat_minor": 5
|
||||||
}
|
}
|
||||||
|
|
|
@ -136,12 +136,13 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
|
||||||
)
|
)
|
||||||
|
|
||||||
image_type = prompt(
|
image_type = prompt(
|
||||||
f"> Enter the image type you want your Llama Stack to be built as ({' or '.join(e.value for e in ImageType)}): ",
|
"> Enter the image type you want your Llama Stack to be built as (use <TAB> to see options): ",
|
||||||
|
completer=WordCompleter([e.value for e in ImageType]),
|
||||||
|
complete_while_typing=True,
|
||||||
validator=Validator.from_callable(
|
validator=Validator.from_callable(
|
||||||
lambda x: x in [e.value for e in ImageType],
|
lambda x: x in [e.value for e in ImageType],
|
||||||
error_message=f"Invalid image type, please enter {' or '.join(e.value for e in ImageType)}",
|
error_message="Invalid image type. Use <TAB> to see options",
|
||||||
),
|
),
|
||||||
default=ImageType.CONDA.value,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
if image_type == ImageType.CONDA.value:
|
if image_type == ImageType.CONDA.value:
|
||||||
|
@ -317,7 +318,11 @@ def _generate_run_config(
|
||||||
to_write = json.loads(run_config.model_dump_json())
|
to_write = json.loads(run_config.model_dump_json())
|
||||||
f.write(yaml.dump(to_write, sort_keys=False))
|
f.write(yaml.dump(to_write, sort_keys=False))
|
||||||
|
|
||||||
# this path is only invoked when no template is provided
|
# Only print this message for non-container builds since it will be displayed before the
|
||||||
|
# container is built
|
||||||
|
# For non-container builds, the run.yaml is generated at the very end of the build process so it
|
||||||
|
# makes sense to display this message
|
||||||
|
if build_config.image_type != LlamaStackImageType.CONTAINER.value:
|
||||||
cprint(
|
cprint(
|
||||||
f"You can now run your stack with `llama stack run {run_config_file}`",
|
f"You can now run your stack with `llama stack run {run_config_file}`",
|
||||||
color="green",
|
color="green",
|
||||||
|
@ -355,6 +360,13 @@ def _run_stack_build_command_from_build_config(
|
||||||
build_file_path = build_dir / f"{image_name}-build.yaml"
|
build_file_path = build_dir / f"{image_name}-build.yaml"
|
||||||
|
|
||||||
os.makedirs(build_dir, exist_ok=True)
|
os.makedirs(build_dir, exist_ok=True)
|
||||||
|
run_config_file = None
|
||||||
|
# Generate the run.yaml so it can be included in the container image with the proper entrypoint
|
||||||
|
# Only do this if we're building a container image and we're not using a template
|
||||||
|
if build_config.image_type == LlamaStackImageType.CONTAINER.value and not template_name and config_path:
|
||||||
|
cprint("Generating run.yaml file", color="green")
|
||||||
|
run_config_file = _generate_run_config(build_config, build_dir, image_name)
|
||||||
|
|
||||||
with open(build_file_path, "w") as f:
|
with open(build_file_path, "w") as f:
|
||||||
to_write = json.loads(build_config.model_dump_json())
|
to_write = json.loads(build_config.model_dump_json())
|
||||||
f.write(yaml.dump(to_write, sort_keys=False))
|
f.write(yaml.dump(to_write, sort_keys=False))
|
||||||
|
@ -364,6 +376,7 @@ def _run_stack_build_command_from_build_config(
|
||||||
build_file_path,
|
build_file_path,
|
||||||
image_name,
|
image_name,
|
||||||
template_or_config=template_name or config_path or str(build_file_path),
|
template_or_config=template_name or config_path or str(build_file_path),
|
||||||
|
run_config=run_config_file,
|
||||||
)
|
)
|
||||||
if return_code != 0:
|
if return_code != 0:
|
||||||
raise RuntimeError(f"Failed to build image {image_name}")
|
raise RuntimeError(f"Failed to build image {image_name}")
|
||||||
|
|
|
@ -93,6 +93,7 @@ def build_image(
|
||||||
build_file_path: Path,
|
build_file_path: Path,
|
||||||
image_name: str,
|
image_name: str,
|
||||||
template_or_config: str,
|
template_or_config: str,
|
||||||
|
run_config: str | None = None,
|
||||||
):
|
):
|
||||||
container_base = build_config.distribution_spec.container_image or "python:3.10-slim"
|
container_base = build_config.distribution_spec.container_image or "python:3.10-slim"
|
||||||
|
|
||||||
|
@ -108,6 +109,11 @@ def build_image(
|
||||||
container_base,
|
container_base,
|
||||||
" ".join(normal_deps),
|
" ".join(normal_deps),
|
||||||
]
|
]
|
||||||
|
|
||||||
|
# When building from a config file (not a template), include the run config path in the
|
||||||
|
# build arguments
|
||||||
|
if run_config is not None:
|
||||||
|
args.append(run_config)
|
||||||
elif build_config.image_type == LlamaStackImageType.CONDA.value:
|
elif build_config.image_type == LlamaStackImageType.CONDA.value:
|
||||||
script = str(importlib.resources.files("llama_stack") / "distribution/build_conda_env.sh")
|
script = str(importlib.resources.files("llama_stack") / "distribution/build_conda_env.sh")
|
||||||
args = [
|
args = [
|
||||||
|
|
|
@ -19,12 +19,16 @@ UV_HTTP_TIMEOUT=${UV_HTTP_TIMEOUT:-500}
|
||||||
# mounting is not supported by docker buildx, so we use COPY instead
|
# mounting is not supported by docker buildx, so we use COPY instead
|
||||||
USE_COPY_NOT_MOUNT=${USE_COPY_NOT_MOUNT:-}
|
USE_COPY_NOT_MOUNT=${USE_COPY_NOT_MOUNT:-}
|
||||||
|
|
||||||
|
# Path to the run.yaml file in the container
|
||||||
|
RUN_CONFIG_PATH=/app/run.yaml
|
||||||
|
|
||||||
|
BUILD_CONTEXT_DIR=$(pwd)
|
||||||
|
|
||||||
if [ "$#" -lt 4 ]; then
|
if [ "$#" -lt 4 ]; then
|
||||||
# This only works for templates
|
# This only works for templates
|
||||||
echo "Usage: $0 <template_or_config> <image_name> <container_base> <pip_dependencies> [<special_pip_deps>]" >&2
|
echo "Usage: $0 <template_or_config> <image_name> <container_base> <pip_dependencies> [<run_config>] [<special_pip_deps>]" >&2
|
||||||
exit 1
|
exit 1
|
||||||
fi
|
fi
|
||||||
|
|
||||||
set -euo pipefail
|
set -euo pipefail
|
||||||
|
|
||||||
template_or_config="$1"
|
template_or_config="$1"
|
||||||
|
@ -35,8 +39,27 @@ container_base="$1"
|
||||||
shift
|
shift
|
||||||
pip_dependencies="$1"
|
pip_dependencies="$1"
|
||||||
shift
|
shift
|
||||||
special_pip_deps="${1:-}"
|
|
||||||
|
|
||||||
|
# Handle optional arguments
|
||||||
|
run_config=""
|
||||||
|
special_pip_deps=""
|
||||||
|
|
||||||
|
# Check if there are more arguments
|
||||||
|
# The logics is becoming cumbersom, we should refactor it if we can do better
|
||||||
|
if [ $# -gt 0 ]; then
|
||||||
|
# Check if the argument ends with .yaml
|
||||||
|
if [[ "$1" == *.yaml ]]; then
|
||||||
|
run_config="$1"
|
||||||
|
shift
|
||||||
|
# If there's another argument after .yaml, it must be special_pip_deps
|
||||||
|
if [ $# -gt 0 ]; then
|
||||||
|
special_pip_deps="$1"
|
||||||
|
fi
|
||||||
|
else
|
||||||
|
# If it's not .yaml, it must be special_pip_deps
|
||||||
|
special_pip_deps="$1"
|
||||||
|
fi
|
||||||
|
fi
|
||||||
|
|
||||||
# Define color codes
|
# Define color codes
|
||||||
RED='\033[0;31m'
|
RED='\033[0;31m'
|
||||||
|
@ -75,7 +98,7 @@ WORKDIR /app
|
||||||
# We install the Python 3.11 dev headers and build tools so that any
|
# We install the Python 3.11 dev headers and build tools so that any
|
||||||
# C‑extension wheels (e.g. polyleven, faiss‑cpu) can compile successfully.
|
# C‑extension wheels (e.g. polyleven, faiss‑cpu) can compile successfully.
|
||||||
|
|
||||||
RUN dnf -y update && dnf install -y iputils net-tools wget \
|
RUN dnf -y update && dnf install -y iputils git net-tools wget \
|
||||||
vim-minimal python3.11 python3.11-pip python3.11-wheel \
|
vim-minimal python3.11 python3.11-pip python3.11-wheel \
|
||||||
python3.11-setuptools python3.11-devel gcc make && \
|
python3.11-setuptools python3.11-devel gcc make && \
|
||||||
ln -s /bin/pip3.11 /bin/pip && ln -s /bin/python3.11 /bin/python && dnf clean all
|
ln -s /bin/pip3.11 /bin/pip && ln -s /bin/python3.11 /bin/python && dnf clean all
|
||||||
|
@ -119,6 +142,45 @@ EOF
|
||||||
done
|
done
|
||||||
fi
|
fi
|
||||||
|
|
||||||
|
# Function to get Python command
|
||||||
|
get_python_cmd() {
|
||||||
|
if is_command_available python; then
|
||||||
|
echo "python"
|
||||||
|
elif is_command_available python3; then
|
||||||
|
echo "python3"
|
||||||
|
else
|
||||||
|
echo "Error: Neither python nor python3 is installed. Please install Python to continue." >&2
|
||||||
|
exit 1
|
||||||
|
fi
|
||||||
|
}
|
||||||
|
|
||||||
|
if [ -n "$run_config" ]; then
|
||||||
|
# Copy the run config to the build context since it's an absolute path
|
||||||
|
cp "$run_config" "$BUILD_CONTEXT_DIR/run.yaml"
|
||||||
|
add_to_container << EOF
|
||||||
|
COPY run.yaml $RUN_CONFIG_PATH
|
||||||
|
EOF
|
||||||
|
|
||||||
|
# Parse the run.yaml configuration to identify external provider directories
|
||||||
|
# If external providers are specified, copy their directory to the container
|
||||||
|
# and update the configuration to reference the new container path
|
||||||
|
python_cmd=$(get_python_cmd)
|
||||||
|
external_providers_dir=$($python_cmd -c "import yaml; config = yaml.safe_load(open('$run_config')); print(config.get('external_providers_dir') or '')")
|
||||||
|
if [ -n "$external_providers_dir" ]; then
|
||||||
|
echo "Copying external providers directory: $external_providers_dir"
|
||||||
|
add_to_container << EOF
|
||||||
|
COPY $external_providers_dir /app/providers.d
|
||||||
|
EOF
|
||||||
|
# Edit the run.yaml file to change the external_providers_dir to /app/providers.d
|
||||||
|
if [ "$(uname)" = "Darwin" ]; then
|
||||||
|
sed -i.bak -e 's|external_providers_dir:.*|external_providers_dir: /app/providers.d|' "$BUILD_CONTEXT_DIR/run.yaml"
|
||||||
|
rm -f "$BUILD_CONTEXT_DIR/run.yaml.bak"
|
||||||
|
else
|
||||||
|
sed -i 's|external_providers_dir:.*|external_providers_dir: /app/providers.d|' "$BUILD_CONTEXT_DIR/run.yaml"
|
||||||
|
fi
|
||||||
|
fi
|
||||||
|
fi
|
||||||
|
|
||||||
stack_mount="/app/llama-stack-source"
|
stack_mount="/app/llama-stack-source"
|
||||||
client_mount="/app/llama-stack-client-source"
|
client_mount="/app/llama-stack-client-source"
|
||||||
|
|
||||||
|
@ -178,15 +240,16 @@ fi
|
||||||
RUN pip uninstall -y uv
|
RUN pip uninstall -y uv
|
||||||
EOF
|
EOF
|
||||||
|
|
||||||
# if template_or_config ends with .yaml, it is not a template and we should not use the --template flag
|
# If a run config is provided, we use the --config flag
|
||||||
if [[ "$template_or_config" != *.yaml ]]; then
|
if [[ -n "$run_config" ]]; then
|
||||||
|
add_to_container << EOF
|
||||||
|
ENTRYPOINT ["python", "-m", "llama_stack.distribution.server.server", "--config", "$RUN_CONFIG_PATH"]
|
||||||
|
EOF
|
||||||
|
# If a template is provided (not a yaml file), we use the --template flag
|
||||||
|
elif [[ "$template_or_config" != *.yaml ]]; then
|
||||||
add_to_container << EOF
|
add_to_container << EOF
|
||||||
ENTRYPOINT ["python", "-m", "llama_stack.distribution.server.server", "--template", "$template_or_config"]
|
ENTRYPOINT ["python", "-m", "llama_stack.distribution.server.server", "--template", "$template_or_config"]
|
||||||
EOF
|
EOF
|
||||||
else
|
|
||||||
add_to_container << EOF
|
|
||||||
ENTRYPOINT ["python", "-m", "llama_stack.distribution.server.server"]
|
|
||||||
EOF
|
|
||||||
fi
|
fi
|
||||||
|
|
||||||
# Add other require item commands genearic to all containers
|
# Add other require item commands genearic to all containers
|
||||||
|
@ -258,9 +321,10 @@ $CONTAINER_BINARY build \
|
||||||
"${CLI_ARGS[@]}" \
|
"${CLI_ARGS[@]}" \
|
||||||
-t "$image_tag" \
|
-t "$image_tag" \
|
||||||
-f "$TEMP_DIR/Containerfile" \
|
-f "$TEMP_DIR/Containerfile" \
|
||||||
"."
|
"$BUILD_CONTEXT_DIR"
|
||||||
|
|
||||||
# clean up tmp/configs
|
# clean up tmp/configs
|
||||||
|
rm -f "$BUILD_CONTEXT_DIR/run.yaml"
|
||||||
set +x
|
set +x
|
||||||
|
|
||||||
echo "Success!"
|
echo "Success!"
|
||||||
|
|
|
@ -8,6 +8,11 @@ import asyncio
|
||||||
import time
|
import time
|
||||||
from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Union
|
from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Union
|
||||||
|
|
||||||
|
from openai.types.chat import ChatCompletionToolChoiceOptionParam as OpenAIChatCompletionToolChoiceOptionParam
|
||||||
|
from openai.types.chat import ChatCompletionToolParam as OpenAIChatCompletionToolParam
|
||||||
|
from pydantic import Field, TypeAdapter
|
||||||
|
from typing_extensions import Annotated
|
||||||
|
|
||||||
from llama_stack.apis.common.content_types import (
|
from llama_stack.apis.common.content_types import (
|
||||||
URL,
|
URL,
|
||||||
InterleavedContent,
|
InterleavedContent,
|
||||||
|
@ -526,7 +531,7 @@ class InferenceRouter(Inference):
|
||||||
async def openai_chat_completion(
|
async def openai_chat_completion(
|
||||||
self,
|
self,
|
||||||
model: str,
|
model: str,
|
||||||
messages: List[OpenAIMessageParam],
|
messages: Annotated[List[OpenAIMessageParam], Field(..., min_length=1)],
|
||||||
frequency_penalty: Optional[float] = None,
|
frequency_penalty: Optional[float] = None,
|
||||||
function_call: Optional[Union[str, Dict[str, Any]]] = None,
|
function_call: Optional[Union[str, Dict[str, Any]]] = None,
|
||||||
functions: Optional[List[Dict[str, Any]]] = None,
|
functions: Optional[List[Dict[str, Any]]] = None,
|
||||||
|
@ -558,6 +563,16 @@ class InferenceRouter(Inference):
|
||||||
if model_obj.model_type == ModelType.embedding:
|
if model_obj.model_type == ModelType.embedding:
|
||||||
raise ValueError(f"Model '{model}' is an embedding model and does not support chat completions")
|
raise ValueError(f"Model '{model}' is an embedding model and does not support chat completions")
|
||||||
|
|
||||||
|
# Use the OpenAI client for a bit of extra input validation without
|
||||||
|
# exposing the OpenAI client itself as part of our API surface
|
||||||
|
if tool_choice:
|
||||||
|
TypeAdapter(OpenAIChatCompletionToolChoiceOptionParam).validate_python(tool_choice)
|
||||||
|
if tools is None:
|
||||||
|
raise ValueError("'tool_choice' is only allowed when 'tools' is also provided")
|
||||||
|
if tools:
|
||||||
|
for tool in tools:
|
||||||
|
TypeAdapter(OpenAIChatCompletionToolParam).validate_python(tool)
|
||||||
|
|
||||||
params = dict(
|
params = dict(
|
||||||
model=model_obj.identifier,
|
model=model_obj.identifier,
|
||||||
messages=messages,
|
messages=messages,
|
||||||
|
|
|
@ -22,6 +22,7 @@ from fastapi import Body, FastAPI, HTTPException, Request
|
||||||
from fastapi import Path as FastapiPath
|
from fastapi import Path as FastapiPath
|
||||||
from fastapi.exceptions import RequestValidationError
|
from fastapi.exceptions import RequestValidationError
|
||||||
from fastapi.responses import JSONResponse, StreamingResponse
|
from fastapi.responses import JSONResponse, StreamingResponse
|
||||||
|
from openai import BadRequestError
|
||||||
from pydantic import BaseModel, ValidationError
|
from pydantic import BaseModel, ValidationError
|
||||||
from typing_extensions import Annotated
|
from typing_extensions import Annotated
|
||||||
|
|
||||||
|
@ -110,6 +111,8 @@ def translate_exception(exc: Exception) -> Union[HTTPException, RequestValidatio
|
||||||
)
|
)
|
||||||
elif isinstance(exc, ValueError):
|
elif isinstance(exc, ValueError):
|
||||||
return HTTPException(status_code=400, detail=f"Invalid value: {str(exc)}")
|
return HTTPException(status_code=400, detail=f"Invalid value: {str(exc)}")
|
||||||
|
elif isinstance(exc, BadRequestError):
|
||||||
|
return HTTPException(status_code=400, detail=str(exc))
|
||||||
elif isinstance(exc, PermissionError):
|
elif isinstance(exc, PermissionError):
|
||||||
return HTTPException(status_code=403, detail=f"Permission denied: {str(exc)}")
|
return HTTPException(status_code=403, detail=f"Permission denied: {str(exc)}")
|
||||||
elif isinstance(exc, TimeoutError):
|
elif isinstance(exc, TimeoutError):
|
||||||
|
@ -162,13 +165,16 @@ async def maybe_await(value):
|
||||||
return value
|
return value
|
||||||
|
|
||||||
|
|
||||||
async def sse_generator(event_gen):
|
async def sse_generator(event_gen_coroutine):
|
||||||
|
event_gen = None
|
||||||
try:
|
try:
|
||||||
async for item in await event_gen:
|
event_gen = await event_gen_coroutine
|
||||||
|
async for item in event_gen:
|
||||||
yield create_sse_event(item)
|
yield create_sse_event(item)
|
||||||
await asyncio.sleep(0.01)
|
await asyncio.sleep(0.01)
|
||||||
except asyncio.CancelledError:
|
except asyncio.CancelledError:
|
||||||
logger.info("Generator cancelled")
|
logger.info("Generator cancelled")
|
||||||
|
if event_gen:
|
||||||
await event_gen.aclose()
|
await event_gen.aclose()
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
logger.exception("Error in sse_generator")
|
logger.exception("Error in sse_generator")
|
||||||
|
@ -455,6 +461,7 @@ def main(args: Optional[argparse.Namespace] = None):
|
||||||
raise ValueError(f"Could not find method {endpoint.name} on {impl}!!")
|
raise ValueError(f"Could not find method {endpoint.name} on {impl}!!")
|
||||||
|
|
||||||
impl_method = getattr(impl, endpoint.name)
|
impl_method = getattr(impl, endpoint.name)
|
||||||
|
logger.debug(f"{endpoint.method.upper()} {endpoint.route}")
|
||||||
|
|
||||||
with warnings.catch_warnings():
|
with warnings.catch_warnings():
|
||||||
warnings.filterwarnings("ignore", category=UserWarning, module="pydantic._internal._fields")
|
warnings.filterwarnings("ignore", category=UserWarning, module="pydantic._internal._fields")
|
||||||
|
|
|
@ -24,6 +24,13 @@ def rag_chat_page():
|
||||||
def should_disable_input():
|
def should_disable_input():
|
||||||
return "displayed_messages" in st.session_state and len(st.session_state.displayed_messages) > 0
|
return "displayed_messages" in st.session_state and len(st.session_state.displayed_messages) > 0
|
||||||
|
|
||||||
|
def log_message(message):
|
||||||
|
with st.chat_message(message["role"]):
|
||||||
|
if "tool_output" in message and message["tool_output"]:
|
||||||
|
with st.expander(label="Tool Output", expanded=False, icon="🛠"):
|
||||||
|
st.write(message["tool_output"])
|
||||||
|
st.markdown(message["content"])
|
||||||
|
|
||||||
with st.sidebar:
|
with st.sidebar:
|
||||||
# File/Directory Upload Section
|
# File/Directory Upload Section
|
||||||
st.subheader("Upload Documents", divider=True)
|
st.subheader("Upload Documents", divider=True)
|
||||||
|
@ -146,8 +153,7 @@ def rag_chat_page():
|
||||||
|
|
||||||
# Display chat history
|
# Display chat history
|
||||||
for message in st.session_state.displayed_messages:
|
for message in st.session_state.displayed_messages:
|
||||||
with st.chat_message(message["role"]):
|
log_message(message)
|
||||||
st.markdown(message["content"])
|
|
||||||
|
|
||||||
if temperature > 0.0:
|
if temperature > 0.0:
|
||||||
strategy = {
|
strategy = {
|
||||||
|
@ -201,7 +207,7 @@ def rag_chat_page():
|
||||||
|
|
||||||
# Display assistant response
|
# Display assistant response
|
||||||
with st.chat_message("assistant"):
|
with st.chat_message("assistant"):
|
||||||
retrieval_message_placeholder = st.empty()
|
retrieval_message_placeholder = st.expander(label="Tool Output", expanded=False, icon="🛠")
|
||||||
message_placeholder = st.empty()
|
message_placeholder = st.empty()
|
||||||
full_response = ""
|
full_response = ""
|
||||||
retrieval_response = ""
|
retrieval_response = ""
|
||||||
|
@ -209,14 +215,16 @@ def rag_chat_page():
|
||||||
log.print()
|
log.print()
|
||||||
if log.role == "tool_execution":
|
if log.role == "tool_execution":
|
||||||
retrieval_response += log.content.replace("====", "").strip()
|
retrieval_response += log.content.replace("====", "").strip()
|
||||||
retrieval_message_placeholder.info(retrieval_response)
|
retrieval_message_placeholder.write(retrieval_response)
|
||||||
else:
|
else:
|
||||||
full_response += log.content
|
full_response += log.content
|
||||||
message_placeholder.markdown(full_response + "▌")
|
message_placeholder.markdown(full_response + "▌")
|
||||||
message_placeholder.markdown(full_response)
|
message_placeholder.markdown(full_response)
|
||||||
|
|
||||||
st.session_state.messages.append({"role": "assistant", "content": full_response})
|
st.session_state.messages.append({"role": "assistant", "content": full_response})
|
||||||
st.session_state.displayed_messages.append({"role": "assistant", "content": full_response})
|
st.session_state.displayed_messages.append(
|
||||||
|
{"role": "assistant", "content": full_response, "tool_output": retrieval_response}
|
||||||
|
)
|
||||||
|
|
||||||
def direct_process_prompt(prompt):
|
def direct_process_prompt(prompt):
|
||||||
# Add the system prompt in the beginning of the conversation
|
# Add the system prompt in the beginning of the conversation
|
||||||
|
@ -230,15 +238,14 @@ def rag_chat_page():
|
||||||
prompt_context = rag_response.content
|
prompt_context = rag_response.content
|
||||||
|
|
||||||
with st.chat_message("assistant"):
|
with st.chat_message("assistant"):
|
||||||
|
with st.expander(label="Retrieval Output", expanded=False):
|
||||||
|
st.write(prompt_context)
|
||||||
|
|
||||||
retrieval_message_placeholder = st.empty()
|
retrieval_message_placeholder = st.empty()
|
||||||
message_placeholder = st.empty()
|
message_placeholder = st.empty()
|
||||||
full_response = ""
|
full_response = ""
|
||||||
retrieval_response = ""
|
retrieval_response = ""
|
||||||
|
|
||||||
# Display the retrieved content
|
|
||||||
retrieval_response += str(prompt_context)
|
|
||||||
retrieval_message_placeholder.info(retrieval_response)
|
|
||||||
|
|
||||||
# Construct the extended prompt
|
# Construct the extended prompt
|
||||||
extended_prompt = f"Please answer the following query using the context below.\n\nCONTEXT:\n{prompt_context}\n\nQUERY:\n{prompt}"
|
extended_prompt = f"Please answer the following query using the context below.\n\nCONTEXT:\n{prompt_context}\n\nQUERY:\n{prompt}"
|
||||||
|
|
||||||
|
|
|
@ -4,14 +4,23 @@
|
||||||
# This source code is licensed under the terms described in the LICENSE file in
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
# the root directory of this source tree.
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
import enum
|
||||||
|
import json
|
||||||
import uuid
|
import uuid
|
||||||
|
|
||||||
import streamlit as st
|
import streamlit as st
|
||||||
from llama_stack_client import Agent
|
from llama_stack_client import Agent
|
||||||
|
from llama_stack_client.lib.agents.react.agent import ReActAgent
|
||||||
|
from llama_stack_client.lib.agents.react.tool_parser import ReActOutput
|
||||||
|
|
||||||
from llama_stack.distribution.ui.modules.api import llama_stack_api
|
from llama_stack.distribution.ui.modules.api import llama_stack_api
|
||||||
|
|
||||||
|
|
||||||
|
class AgentType(enum.Enum):
|
||||||
|
REGULAR = "Regular"
|
||||||
|
REACT = "ReAct"
|
||||||
|
|
||||||
|
|
||||||
def tool_chat_page():
|
def tool_chat_page():
|
||||||
st.title("🛠 Tools")
|
st.title("🛠 Tools")
|
||||||
|
|
||||||
|
@ -23,6 +32,7 @@ def tool_chat_page():
|
||||||
tool_groups_list = [tool_group.identifier for tool_group in tool_groups]
|
tool_groups_list = [tool_group.identifier for tool_group in tool_groups]
|
||||||
mcp_tools_list = [tool for tool in tool_groups_list if tool.startswith("mcp::")]
|
mcp_tools_list = [tool for tool in tool_groups_list if tool.startswith("mcp::")]
|
||||||
builtin_tools_list = [tool for tool in tool_groups_list if not tool.startswith("mcp::")]
|
builtin_tools_list = [tool for tool in tool_groups_list if not tool.startswith("mcp::")]
|
||||||
|
selected_vector_dbs = []
|
||||||
|
|
||||||
def reset_agent():
|
def reset_agent():
|
||||||
st.session_state.clear()
|
st.session_state.clear()
|
||||||
|
@ -66,25 +76,36 @@ def tool_chat_page():
|
||||||
|
|
||||||
toolgroup_selection.extend(mcp_selection)
|
toolgroup_selection.extend(mcp_selection)
|
||||||
|
|
||||||
active_tool_list = []
|
grouped_tools = {}
|
||||||
for toolgroup_id in toolgroup_selection:
|
total_tools = 0
|
||||||
active_tool_list.extend(
|
|
||||||
[
|
|
||||||
f"{''.join(toolgroup_id.split('::')[1:])}:{t.identifier}"
|
|
||||||
for t in client.tools.list(toolgroup_id=toolgroup_id)
|
|
||||||
]
|
|
||||||
)
|
|
||||||
|
|
||||||
st.markdown(f"Active Tools: 🛠 {len(active_tool_list)}", help="List of currently active tools.")
|
for toolgroup_id in toolgroup_selection:
|
||||||
st.json(active_tool_list)
|
tools = client.tools.list(toolgroup_id=toolgroup_id)
|
||||||
|
grouped_tools[toolgroup_id] = [tool.identifier for tool in tools]
|
||||||
|
total_tools += len(tools)
|
||||||
|
|
||||||
|
st.markdown(f"Active Tools: 🛠 {total_tools}")
|
||||||
|
|
||||||
|
for group_id, tools in grouped_tools.items():
|
||||||
|
with st.expander(f"🔧 Tools from `{group_id}`"):
|
||||||
|
for idx, tool in enumerate(tools, start=1):
|
||||||
|
st.markdown(f"{idx}. `{tool.split(':')[-1]}`")
|
||||||
|
|
||||||
st.subheader("Agent Configurations")
|
st.subheader("Agent Configurations")
|
||||||
|
st.subheader("Agent Type")
|
||||||
|
agent_type = st.radio(
|
||||||
|
"Select Agent Type",
|
||||||
|
[AgentType.REGULAR, AgentType.REACT],
|
||||||
|
format_func=lambda x: x.value,
|
||||||
|
on_change=reset_agent,
|
||||||
|
)
|
||||||
|
|
||||||
max_tokens = st.slider(
|
max_tokens = st.slider(
|
||||||
"Max Tokens",
|
"Max Tokens",
|
||||||
min_value=0,
|
min_value=0,
|
||||||
max_value=4096,
|
max_value=4096,
|
||||||
value=512,
|
value=512,
|
||||||
step=1,
|
step=64,
|
||||||
help="The maximum number of tokens to generate",
|
help="The maximum number of tokens to generate",
|
||||||
on_change=reset_agent,
|
on_change=reset_agent,
|
||||||
)
|
)
|
||||||
|
@ -101,6 +122,18 @@ def tool_chat_page():
|
||||||
|
|
||||||
@st.cache_resource
|
@st.cache_resource
|
||||||
def create_agent():
|
def create_agent():
|
||||||
|
if "agent_type" in st.session_state and st.session_state.agent_type == AgentType.REACT:
|
||||||
|
return ReActAgent(
|
||||||
|
client=client,
|
||||||
|
model=model,
|
||||||
|
tools=toolgroup_selection,
|
||||||
|
response_format={
|
||||||
|
"type": "json_schema",
|
||||||
|
"json_schema": ReActOutput.model_json_schema(),
|
||||||
|
},
|
||||||
|
sampling_params={"strategy": {"type": "greedy"}, "max_tokens": max_tokens},
|
||||||
|
)
|
||||||
|
else:
|
||||||
return Agent(
|
return Agent(
|
||||||
client,
|
client,
|
||||||
model=model,
|
model=model,
|
||||||
|
@ -109,6 +142,8 @@ def tool_chat_page():
|
||||||
sampling_params={"strategy": {"type": "greedy"}, "max_tokens": max_tokens},
|
sampling_params={"strategy": {"type": "greedy"}, "max_tokens": max_tokens},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
st.session_state.agent_type = agent_type
|
||||||
|
|
||||||
agent = create_agent()
|
agent = create_agent()
|
||||||
|
|
||||||
if "agent_session_id" not in st.session_state:
|
if "agent_session_id" not in st.session_state:
|
||||||
|
@ -136,6 +171,158 @@ def tool_chat_page():
|
||||||
)
|
)
|
||||||
|
|
||||||
def response_generator(turn_response):
|
def response_generator(turn_response):
|
||||||
|
if st.session_state.get("agent_type") == AgentType.REACT:
|
||||||
|
return _handle_react_response(turn_response)
|
||||||
|
else:
|
||||||
|
return _handle_regular_response(turn_response)
|
||||||
|
|
||||||
|
def _handle_react_response(turn_response):
|
||||||
|
current_step_content = ""
|
||||||
|
final_answer = None
|
||||||
|
tool_results = []
|
||||||
|
|
||||||
|
for response in turn_response:
|
||||||
|
if not hasattr(response.event, "payload"):
|
||||||
|
yield (
|
||||||
|
"\n\n🚨 :red[_Llama Stack server Error:_]\n"
|
||||||
|
"The response received is missing an expected `payload` attribute.\n"
|
||||||
|
"This could indicate a malformed response or an internal issue within the server.\n\n"
|
||||||
|
f"Error details: {response}"
|
||||||
|
)
|
||||||
|
return
|
||||||
|
|
||||||
|
payload = response.event.payload
|
||||||
|
|
||||||
|
if payload.event_type == "step_progress" and hasattr(payload.delta, "text"):
|
||||||
|
current_step_content += payload.delta.text
|
||||||
|
continue
|
||||||
|
|
||||||
|
if payload.event_type == "step_complete":
|
||||||
|
step_details = payload.step_details
|
||||||
|
|
||||||
|
if step_details.step_type == "inference":
|
||||||
|
yield from _process_inference_step(current_step_content, tool_results, final_answer)
|
||||||
|
current_step_content = ""
|
||||||
|
elif step_details.step_type == "tool_execution":
|
||||||
|
tool_results = _process_tool_execution(step_details, tool_results)
|
||||||
|
current_step_content = ""
|
||||||
|
else:
|
||||||
|
current_step_content = ""
|
||||||
|
|
||||||
|
if not final_answer and tool_results:
|
||||||
|
yield from _format_tool_results_summary(tool_results)
|
||||||
|
|
||||||
|
def _process_inference_step(current_step_content, tool_results, final_answer):
|
||||||
|
try:
|
||||||
|
react_output_data = json.loads(current_step_content)
|
||||||
|
thought = react_output_data.get("thought")
|
||||||
|
action = react_output_data.get("action")
|
||||||
|
answer = react_output_data.get("answer")
|
||||||
|
|
||||||
|
if answer and answer != "null" and answer is not None:
|
||||||
|
final_answer = answer
|
||||||
|
|
||||||
|
if thought:
|
||||||
|
with st.expander("🤔 Thinking...", expanded=False):
|
||||||
|
st.markdown(f":grey[__{thought}__]")
|
||||||
|
|
||||||
|
if action and isinstance(action, dict):
|
||||||
|
tool_name = action.get("tool_name")
|
||||||
|
tool_params = action.get("tool_params")
|
||||||
|
with st.expander(f'🛠 Action: Using tool "{tool_name}"', expanded=False):
|
||||||
|
st.json(tool_params)
|
||||||
|
|
||||||
|
if answer and answer != "null" and answer is not None:
|
||||||
|
yield f"\n\n✅ **Final Answer:**\n{answer}"
|
||||||
|
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
yield f"\n\nFailed to parse ReAct step content:\n```json\n{current_step_content}\n```"
|
||||||
|
except Exception as e:
|
||||||
|
yield f"\n\nFailed to process ReAct step: {e}\n```json\n{current_step_content}\n```"
|
||||||
|
|
||||||
|
return final_answer
|
||||||
|
|
||||||
|
def _process_tool_execution(step_details, tool_results):
|
||||||
|
try:
|
||||||
|
if hasattr(step_details, "tool_responses") and step_details.tool_responses:
|
||||||
|
for tool_response in step_details.tool_responses:
|
||||||
|
tool_name = tool_response.tool_name
|
||||||
|
content = tool_response.content
|
||||||
|
tool_results.append((tool_name, content))
|
||||||
|
with st.expander(f'⚙️ Observation (Result from "{tool_name}")', expanded=False):
|
||||||
|
try:
|
||||||
|
parsed_content = json.loads(content)
|
||||||
|
st.json(parsed_content)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
st.code(content, language=None)
|
||||||
|
else:
|
||||||
|
with st.expander("⚙️ Observation", expanded=False):
|
||||||
|
st.markdown(":grey[_Tool execution step completed, but no response data found._]")
|
||||||
|
except Exception as e:
|
||||||
|
with st.expander("⚙️ Error in Tool Execution", expanded=False):
|
||||||
|
st.markdown(f":red[_Error processing tool execution: {str(e)}_]")
|
||||||
|
|
||||||
|
return tool_results
|
||||||
|
|
||||||
|
def _format_tool_results_summary(tool_results):
|
||||||
|
yield "\n\n**Here's what I found:**\n"
|
||||||
|
for tool_name, content in tool_results:
|
||||||
|
try:
|
||||||
|
parsed_content = json.loads(content)
|
||||||
|
|
||||||
|
if tool_name == "web_search" and "top_k" in parsed_content:
|
||||||
|
yield from _format_web_search_results(parsed_content)
|
||||||
|
elif "results" in parsed_content and isinstance(parsed_content["results"], list):
|
||||||
|
yield from _format_results_list(parsed_content["results"])
|
||||||
|
elif isinstance(parsed_content, dict) and len(parsed_content) > 0:
|
||||||
|
yield from _format_dict_results(parsed_content)
|
||||||
|
elif isinstance(parsed_content, list) and len(parsed_content) > 0:
|
||||||
|
yield from _format_list_results(parsed_content)
|
||||||
|
except json.JSONDecodeError:
|
||||||
|
yield f"\n**{tool_name}** was used but returned complex data. Check the observation for details.\n"
|
||||||
|
except (TypeError, AttributeError, KeyError, IndexError) as e:
|
||||||
|
print(f"Error processing {tool_name} result: {type(e).__name__}: {e}")
|
||||||
|
|
||||||
|
def _format_web_search_results(parsed_content):
|
||||||
|
for i, result in enumerate(parsed_content["top_k"], 1):
|
||||||
|
if i <= 3:
|
||||||
|
title = result.get("title", "Untitled")
|
||||||
|
url = result.get("url", "")
|
||||||
|
content_text = result.get("content", "").strip()
|
||||||
|
yield f"\n- **{title}**\n {content_text}\n [Source]({url})\n"
|
||||||
|
|
||||||
|
def _format_results_list(results):
|
||||||
|
for i, result in enumerate(results, 1):
|
||||||
|
if i <= 3:
|
||||||
|
if isinstance(result, dict):
|
||||||
|
name = result.get("name", result.get("title", "Result " + str(i)))
|
||||||
|
description = result.get("description", result.get("content", result.get("summary", "")))
|
||||||
|
yield f"\n- **{name}**\n {description}\n"
|
||||||
|
else:
|
||||||
|
yield f"\n- {result}\n"
|
||||||
|
|
||||||
|
def _format_dict_results(parsed_content):
|
||||||
|
yield "\n```\n"
|
||||||
|
for key, value in list(parsed_content.items())[:5]:
|
||||||
|
if isinstance(value, str) and len(value) < 100:
|
||||||
|
yield f"{key}: {value}\n"
|
||||||
|
else:
|
||||||
|
yield f"{key}: [Complex data]\n"
|
||||||
|
yield "```\n"
|
||||||
|
|
||||||
|
def _format_list_results(parsed_content):
|
||||||
|
yield "\n"
|
||||||
|
for _, item in enumerate(parsed_content[:3], 1):
|
||||||
|
if isinstance(item, str):
|
||||||
|
yield f"- {item}\n"
|
||||||
|
elif isinstance(item, dict) and "text" in item:
|
||||||
|
yield f"- {item['text']}\n"
|
||||||
|
elif isinstance(item, dict) and len(item) > 0:
|
||||||
|
first_value = next(iter(item.values()))
|
||||||
|
if isinstance(first_value, str) and len(first_value) < 100:
|
||||||
|
yield f"- {first_value}\n"
|
||||||
|
|
||||||
|
def _handle_regular_response(turn_response):
|
||||||
for response in turn_response:
|
for response in turn_response:
|
||||||
if hasattr(response.event, "payload"):
|
if hasattr(response.event, "payload"):
|
||||||
print(response.event.payload)
|
print(response.event.payload)
|
||||||
|
@ -144,14 +331,18 @@ def tool_chat_page():
|
||||||
yield response.event.payload.delta.text
|
yield response.event.payload.delta.text
|
||||||
if response.event.payload.event_type == "step_complete":
|
if response.event.payload.event_type == "step_complete":
|
||||||
if response.event.payload.step_details.step_type == "tool_execution":
|
if response.event.payload.step_details.step_type == "tool_execution":
|
||||||
yield " 🛠 "
|
if response.event.payload.step_details.tool_calls:
|
||||||
|
tool_name = str(response.event.payload.step_details.tool_calls[0].tool_name)
|
||||||
|
yield f'\n\n🛠 :grey[_Using "{tool_name}" tool:_]\n\n'
|
||||||
|
else:
|
||||||
|
yield "No tool_calls present in step_details"
|
||||||
else:
|
else:
|
||||||
yield f"Error occurred in the Llama Stack Cluster: {response}"
|
yield f"Error occurred in the Llama Stack Cluster: {response}"
|
||||||
|
|
||||||
with st.chat_message("assistant"):
|
with st.chat_message("assistant"):
|
||||||
response = st.write_stream(response_generator(turn_response))
|
response_content = st.write_stream(response_generator(turn_response))
|
||||||
|
|
||||||
st.session_state.messages.append({"role": "assistant", "content": response})
|
st.session_state.messages.append({"role": "assistant", "content": response_content})
|
||||||
|
|
||||||
|
|
||||||
tool_chat_page()
|
tool_chat_page()
|
||||||
|
|
|
@ -303,6 +303,7 @@ class ChatFormat:
|
||||||
arguments_json=json.dumps(tool_arguments),
|
arguments_json=json.dumps(tool_arguments),
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
|
content = ""
|
||||||
|
|
||||||
return RawMessage(
|
return RawMessage(
|
||||||
role="assistant",
|
role="assistant",
|
||||||
|
|
|
@ -64,7 +64,7 @@ This example passes an image that is smaller than the tile size, to show the til
|
||||||
|
|
||||||
##### Model Response Format
|
##### Model Response Format
|
||||||
```
|
```
|
||||||
The image depicts a dog standing on a skateboard, with its front paws positioned on the board and its back paws hanging off the back. The dog has a distinctive coat pattern, featuring a white face, brown and black fur, and white paws, and is standing on a skateboard with red wheels, set against a blurred background of a street or alleyway with a teal door and beige wall.<|eot|>
|
The image depicts a dog standing on a skateboard, positioned centrally and facing the camera directly. The dog has a distinctive coat pattern featuring white, black, and brown fur, with floppy ears and a black nose, and is standing on a skateboard with red wheels.<|eot|>
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
||||||
|
@ -91,7 +91,7 @@ Here is an example of how to pass an image to the model
|
||||||
|
|
||||||
##### Model Response Format
|
##### Model Response Format
|
||||||
```
|
```
|
||||||
This image shows a dog standing on a skateboard, with its front paws positioned near the front of the board and its back paws near the back. The dog has a white, black, and orange coat, and is standing on a gray skateboard with red wheels, in front of a blurred background that appears to be a street or alleyway.<|eot|>
|
The image depicts a dog standing on a skateboard, with the dog positioned centrally and facing forward. The dog has a distinctive coat featuring a mix of white, brown, and black fur, and is wearing a collar as it stands on the skateboard, which has red wheels.<|eot|>
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
||||||
|
@ -117,7 +117,7 @@ Here is an example of how to pass an image to the model
|
||||||
|
|
||||||
##### Model Response Format
|
##### Model Response Format
|
||||||
```
|
```
|
||||||
The first image shows a dog standing on a skateboard, while the second image shows a plate of spaghetti with tomato sauce, parmesan cheese, and parsley. The two images are unrelated, with the first image featuring a dog and the second image featuring a food dish, and they do not share any common elements or themes.<|eot|>
|
The first image features a dog standing on a skateboard, while the second image showcases a plate of spaghetti with tomato sauce and cheese. The two images appear to be unrelated, with one depicting a playful scene of a dog on a skateboard and the other presenting a classic Italian dish.<|eom|>
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
||||||
|
@ -135,13 +135,44 @@ We are continuing the format for zero shot function calling used in previous ver
|
||||||
```
|
```
|
||||||
<|begin_of_text|><|header_start|>system<|header_end|>
|
<|begin_of_text|><|header_start|>system<|header_end|>
|
||||||
|
|
||||||
You are an expert in composing functions. You are given a question and a set of possible functions.
|
You are a helpful assistant and an expert in function composition. You can answer general questions using your internal knowledge OR invoke functions when necessary. Follow these strict guidelines:
|
||||||
Based on the question, you will need to make one or more function/tool calls to achieve the purpose.
|
|
||||||
If none of the function can be used, point it out. If the given question lacks the parameters required by the function,
|
1. FUNCTION CALLS:
|
||||||
also point it out. You should only return the function call in tools call sections.
|
- ONLY use functions that are EXPLICITLY listed in the function list below
|
||||||
|
- If NO functions are listed (empty function list []), respond ONLY with internal knowledge or "I don't have access to [Unavailable service] information"
|
||||||
|
- If a function is not in the list, respond ONLY with internal knowledge or "I don't have access to [Unavailable service] information"
|
||||||
|
- If ALL required parameters are present AND the query EXACTLY matches a listed function's purpose: output ONLY the function call(s)
|
||||||
|
- Use exact format: [func_name1(param1=value1, param2=value2), func_name2(...)]
|
||||||
|
Examples:
|
||||||
|
CORRECT: [get_weather(location="Vancouver"), calculate_route(start="Boston", end="New York")] <- Only if get_weather and calculate_route are in function list
|
||||||
|
INCORRECT: get_weather(location="New York")
|
||||||
|
INCORRECT: Let me check the weather: [get_weather(location="New York")]
|
||||||
|
INCORRECT: [get_events(location="Singapore")] <- If function not in list
|
||||||
|
|
||||||
|
2. RESPONSE RULES:
|
||||||
|
- For pure function requests matching a listed function: ONLY output the function call(s)
|
||||||
|
- For knowledge questions: ONLY output text
|
||||||
|
- For missing parameters: ONLY request the specific missing parameters
|
||||||
|
- For unavailable services (not in function list): output ONLY with internal knowledge or "I don't have access to [Unavailable service] information". Do NOT execute a function call.
|
||||||
|
- If the query asks for information beyond what a listed function provides: output ONLY with internal knowledge about your limitations
|
||||||
|
- NEVER combine text and function calls in the same response
|
||||||
|
- NEVER suggest alternative functions when the requested service is unavailable
|
||||||
|
- NEVER create or invent new functions not listed below
|
||||||
|
|
||||||
|
3. STRICT BOUNDARIES:
|
||||||
|
- ONLY use functions from the list below - no exceptions
|
||||||
|
- NEVER use a function as an alternative to unavailable information
|
||||||
|
- NEVER call functions not present in the function list
|
||||||
|
- NEVER add explanatory text to function calls
|
||||||
|
- NEVER respond with empty brackets
|
||||||
|
- Use proper Python/JSON syntax for function calls
|
||||||
|
- Check the function list carefully before responding
|
||||||
|
|
||||||
|
4. TOOL RESPONSE HANDLING:
|
||||||
|
- When receiving tool responses: provide concise, natural language responses
|
||||||
|
- Don't repeat tool response verbatim
|
||||||
|
- Don't add supplementary information
|
||||||
|
|
||||||
If you decide to invoke any of the function(s), you MUST put it in the format of [func_name1(params_name1=params_value1, params_name2=params_value2...), func_name2(params)]
|
|
||||||
You SHOULD NOT include any other text in the response.
|
|
||||||
|
|
||||||
Here is a list of functions in JSON format that you can invoke.
|
Here is a list of functions in JSON format that you can invoke.
|
||||||
|
|
||||||
|
@ -151,9 +182,7 @@ Here is a list of functions in JSON format that you can invoke.
|
||||||
"description": "Get weather info for places",
|
"description": "Get weather info for places",
|
||||||
"parameters": {
|
"parameters": {
|
||||||
"type": "dict",
|
"type": "dict",
|
||||||
"required": [
|
"required": ["city"],
|
||||||
"city"
|
|
||||||
],
|
|
||||||
"properties": {
|
"properties": {
|
||||||
"city": {
|
"city": {
|
||||||
"type": "string",
|
"type": "string",
|
||||||
|
@ -167,7 +196,10 @@ Here is a list of functions in JSON format that you can invoke.
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
<|eot|><|header_start|>user<|header_end|>
|
]
|
||||||
|
|
||||||
|
You can answer general questions or invoke tools when necessary.
|
||||||
|
In addition to tool calls, you should also augment your responses by using the tool outputs.<|eot|><|header_start|>user<|header_end|>
|
||||||
|
|
||||||
What is the weather in SF and Seattle?<|eot|><|header_start|>assistant<|header_end|>
|
What is the weather in SF and Seattle?<|eot|><|header_start|>assistant<|header_end|>
|
||||||
|
|
||||||
|
@ -176,7 +208,7 @@ What is the weather in SF and Seattle?<|eot|><|header_start|>assistant<|header_e
|
||||||
|
|
||||||
##### Model Response Format
|
##### Model Response Format
|
||||||
```
|
```
|
||||||
[get_weather(city='SF'), get_weather(city='Seattle')]<|eot|>
|
[get_weather(city="San Francisco"), get_weather(city="Seattle")]<|eot|>
|
||||||
```
|
```
|
||||||
|
|
||||||
|
|
||||||
|
@ -273,5 +305,5 @@ Use tools to get latest trending songs<|eot|><|header_start|>assistant<|header_e
|
||||||
|
|
||||||
##### Model Response Format
|
##### Model Response Format
|
||||||
```
|
```
|
||||||
<function=trending_songs>{"n": "10"}</function><|eot|>
|
<function=trending_songs>{"n": 10}</function><|eot|>
|
||||||
```
|
```
|
||||||
|
|
|
@ -0,0 +1,144 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# top-level folder for each specific model found within the models/ directory at
|
||||||
|
# the top-level of this source tree.
|
||||||
|
|
||||||
|
import textwrap
|
||||||
|
from typing import List, Optional
|
||||||
|
|
||||||
|
from llama_stack.apis.inference import ToolDefinition, ToolParamDefinition
|
||||||
|
from llama_stack.models.llama.llama3.prompt_templates.base import (
|
||||||
|
PromptTemplate,
|
||||||
|
PromptTemplateGeneratorBase,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
class PythonListCustomToolGenerator(PromptTemplateGeneratorBase): # noqa: N801
|
||||||
|
DEFAULT_PROMPT = textwrap.dedent(
|
||||||
|
"""
|
||||||
|
You are a helpful assistant and an expert in function composition. You can answer general questions using your internal knowledge OR invoke functions when necessary. Follow these strict guidelines:
|
||||||
|
|
||||||
|
1. FUNCTION CALLS:
|
||||||
|
- ONLY use functions that are EXPLICITLY listed in the function list below
|
||||||
|
- If NO functions are listed (empty function list []), respond ONLY with internal knowledge or "I don't have access to [Unavailable service] information"
|
||||||
|
- If a function is not in the list, respond ONLY with internal knowledge or "I don't have access to [Unavailable service] information"
|
||||||
|
- If ALL required parameters are present AND the query EXACTLY matches a listed function's purpose: output ONLY the function call(s)
|
||||||
|
- Use exact format: [func_name1(param1=value1, param2=value2), func_name2(...)]
|
||||||
|
Examples:
|
||||||
|
CORRECT: [get_weather(location="Vancouver"), calculate_route(start="Boston", end="New York")] <- Only if get_weather and calculate_route are in function list
|
||||||
|
INCORRECT: get_weather(location="New York")
|
||||||
|
INCORRECT: Let me check the weather: [get_weather(location="New York")]
|
||||||
|
INCORRECT: [get_events(location="Singapore")] <- If function not in list
|
||||||
|
|
||||||
|
2. RESPONSE RULES:
|
||||||
|
- For pure function requests matching a listed function: ONLY output the function call(s)
|
||||||
|
- For knowledge questions: ONLY output text
|
||||||
|
- For missing parameters: ONLY request the specific missing parameters
|
||||||
|
- For unavailable services (not in function list): output ONLY with internal knowledge or "I don't have access to [Unavailable service] information". Do NOT execute a function call.
|
||||||
|
- If the query asks for information beyond what a listed function provides: output ONLY with internal knowledge about your limitations
|
||||||
|
- NEVER combine text and function calls in the same response
|
||||||
|
- NEVER suggest alternative functions when the requested service is unavailable
|
||||||
|
- NEVER create or invent new functions not listed below
|
||||||
|
|
||||||
|
3. STRICT BOUNDARIES:
|
||||||
|
- ONLY use functions from the list below - no exceptions
|
||||||
|
- NEVER use a function as an alternative to unavailable information
|
||||||
|
- NEVER call functions not present in the function list
|
||||||
|
- NEVER add explanatory text to function calls
|
||||||
|
- NEVER respond with empty brackets
|
||||||
|
- Use proper Python/JSON syntax for function calls
|
||||||
|
- Check the function list carefully before responding
|
||||||
|
|
||||||
|
4. TOOL RESPONSE HANDLING:
|
||||||
|
- When receiving tool responses: provide concise, natural language responses
|
||||||
|
- Don't repeat tool response verbatim
|
||||||
|
- Don't add supplementary information
|
||||||
|
|
||||||
|
|
||||||
|
{{ function_description }}
|
||||||
|
""".strip("\n")
|
||||||
|
)
|
||||||
|
|
||||||
|
def gen(self, custom_tools: List[ToolDefinition], system_prompt: Optional[str] = None) -> PromptTemplate:
|
||||||
|
system_prompt = system_prompt or self.DEFAULT_PROMPT
|
||||||
|
return PromptTemplate(
|
||||||
|
system_prompt,
|
||||||
|
{"function_description": self._gen_function_description(custom_tools)},
|
||||||
|
)
|
||||||
|
|
||||||
|
def _gen_function_description(self, custom_tools: List[ToolDefinition]) -> PromptTemplate:
|
||||||
|
template_str = textwrap.dedent(
|
||||||
|
"""
|
||||||
|
Here is a list of functions in JSON format that you can invoke.
|
||||||
|
|
||||||
|
[
|
||||||
|
{% for t in tools -%}
|
||||||
|
{# manually setting up JSON because jinja sorts keys in unexpected ways -#}
|
||||||
|
{%- set tname = t.tool_name -%}
|
||||||
|
{%- set tdesc = t.description -%}
|
||||||
|
{%- set tparams = t.parameters -%}
|
||||||
|
{%- set required_params = [] -%}
|
||||||
|
{%- for name, param in tparams.items() if param.required == true -%}
|
||||||
|
{%- set _ = required_params.append(name) -%}
|
||||||
|
{%- endfor -%}
|
||||||
|
{
|
||||||
|
"name": "{{tname}}",
|
||||||
|
"description": "{{tdesc}}",
|
||||||
|
"parameters": {
|
||||||
|
"type": "dict",
|
||||||
|
"required": {{ required_params | tojson }},
|
||||||
|
"properties": {
|
||||||
|
{%- for name, param in tparams.items() %}
|
||||||
|
"{{name}}": {
|
||||||
|
"type": "{{param.param_type}}",
|
||||||
|
"description": "{{param.description}}"{% if param.default %},
|
||||||
|
"default": "{{param.default}}"{% endif %}
|
||||||
|
}{% if not loop.last %},{% endif %}
|
||||||
|
{%- endfor %}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}{% if not loop.last %},
|
||||||
|
{% endif -%}
|
||||||
|
{%- endfor %}
|
||||||
|
]
|
||||||
|
|
||||||
|
You can answer general questions or invoke tools when necessary.
|
||||||
|
In addition to tool calls, you should also augment your responses by using the tool outputs.
|
||||||
|
|
||||||
|
"""
|
||||||
|
)
|
||||||
|
return PromptTemplate(
|
||||||
|
template_str.strip("\n"),
|
||||||
|
{"tools": [t.model_dump() for t in custom_tools]},
|
||||||
|
).render()
|
||||||
|
|
||||||
|
def data_examples(self) -> List[List[ToolDefinition]]:
|
||||||
|
return [
|
||||||
|
[
|
||||||
|
ToolDefinition(
|
||||||
|
tool_name="get_weather",
|
||||||
|
description="Get weather info for places",
|
||||||
|
parameters={
|
||||||
|
"city": ToolParamDefinition(
|
||||||
|
param_type="string",
|
||||||
|
description="The name of the city to get the weather for",
|
||||||
|
required=True,
|
||||||
|
),
|
||||||
|
"metric": ToolParamDefinition(
|
||||||
|
param_type="string",
|
||||||
|
description="The metric for weather. Options are: celsius, fahrenheit",
|
||||||
|
required=False,
|
||||||
|
default="celsius",
|
||||||
|
),
|
||||||
|
},
|
||||||
|
),
|
||||||
|
]
|
||||||
|
]
|
|
@ -9,6 +9,10 @@ from io import BytesIO
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
|
from llama_stack.models.llama.llama4.prompt_templates.system_prompts import (
|
||||||
|
PythonListCustomToolGenerator,
|
||||||
|
)
|
||||||
|
|
||||||
from ..datatypes import RawMediaItem, RawMessage, RawTextItem
|
from ..datatypes import RawMediaItem, RawMessage, RawTextItem
|
||||||
from ..prompt_format import (
|
from ..prompt_format import (
|
||||||
Llama4UseCase,
|
Llama4UseCase,
|
||||||
|
@ -177,39 +181,9 @@ def usecases(base_model: bool = False) -> List[UseCase | str]:
|
||||||
[
|
[
|
||||||
RawMessage(
|
RawMessage(
|
||||||
role="system",
|
role="system",
|
||||||
content="""You are an expert in composing functions. You are given a question and a set of possible functions.
|
content=PythonListCustomToolGenerator()
|
||||||
Based on the question, you will need to make one or more function/tool calls to achieve the purpose.
|
.gen(PythonListCustomToolGenerator().data_examples()[0])
|
||||||
If none of the function can be used, point it out. If the given question lacks the parameters required by the function,
|
.render(),
|
||||||
also point it out. You should only return the function call in tools call sections.
|
|
||||||
|
|
||||||
If you decide to invoke any of the function(s), you MUST put it in the format of [func_name1(params_name1=params_value1, params_name2=params_value2...), func_name2(params)]
|
|
||||||
You SHOULD NOT include any other text in the response.
|
|
||||||
|
|
||||||
Here is a list of functions in JSON format that you can invoke.
|
|
||||||
|
|
||||||
[
|
|
||||||
{
|
|
||||||
"name": "get_weather",
|
|
||||||
"description": "Get weather info for places",
|
|
||||||
"parameters": {
|
|
||||||
"type": "dict",
|
|
||||||
"required": [
|
|
||||||
"city"
|
|
||||||
],
|
|
||||||
"properties": {
|
|
||||||
"city": {
|
|
||||||
"type": "string",
|
|
||||||
"description": "The name of the city to get the weather for"
|
|
||||||
},
|
|
||||||
"metric": {
|
|
||||||
"type": "string",
|
|
||||||
"description": "The metric for weather. Options are: celsius, fahrenheit",
|
|
||||||
"default": "celsius"
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
}
|
|
||||||
""",
|
|
||||||
),
|
),
|
||||||
RawMessage(
|
RawMessage(
|
||||||
role="user",
|
role="user",
|
||||||
|
|
|
@ -253,7 +253,8 @@ class MetaReferenceInferenceImpl(
|
||||||
def impl():
|
def impl():
|
||||||
stop_reason = None
|
stop_reason = None
|
||||||
|
|
||||||
for token_result in self.generator.completion(request):
|
for token_results in self.generator.completion([request]):
|
||||||
|
token_result = token_results[0]
|
||||||
if token_result.token == tokenizer.eot_id:
|
if token_result.token == tokenizer.eot_id:
|
||||||
stop_reason = StopReason.end_of_turn
|
stop_reason = StopReason.end_of_turn
|
||||||
text = ""
|
text = ""
|
||||||
|
|
|
@ -69,7 +69,10 @@ class CancelSentinel(BaseModel):
|
||||||
|
|
||||||
class TaskRequest(BaseModel):
|
class TaskRequest(BaseModel):
|
||||||
type: Literal[ProcessingMessageName.task_request] = ProcessingMessageName.task_request
|
type: Literal[ProcessingMessageName.task_request] = ProcessingMessageName.task_request
|
||||||
task: Tuple[str, List[CompletionRequestWithRawContent] | List[ChatCompletionRequestWithRawContent]]
|
task: Tuple[
|
||||||
|
str,
|
||||||
|
List[CompletionRequestWithRawContent] | List[ChatCompletionRequestWithRawContent],
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
class TaskResponse(BaseModel):
|
class TaskResponse(BaseModel):
|
||||||
|
@ -231,10 +234,10 @@ def worker_process_entrypoint(
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
task = req_gen.send(result)
|
task = req_gen.send(result)
|
||||||
if isinstance(task, str) and task == EndSentinel():
|
if isinstance(task, EndSentinel):
|
||||||
break
|
break
|
||||||
|
|
||||||
assert isinstance(task, TaskRequest)
|
assert isinstance(task, TaskRequest), task
|
||||||
result = model(task.task)
|
result = model(task.task)
|
||||||
except StopIteration:
|
except StopIteration:
|
||||||
break
|
break
|
||||||
|
@ -331,7 +334,10 @@ class ModelParallelProcessGroup:
|
||||||
|
|
||||||
def run_inference(
|
def run_inference(
|
||||||
self,
|
self,
|
||||||
req: Tuple[str, List[CompletionRequestWithRawContent] | List[ChatCompletionRequestWithRawContent]],
|
req: Tuple[
|
||||||
|
str,
|
||||||
|
List[CompletionRequestWithRawContent] | List[ChatCompletionRequestWithRawContent],
|
||||||
|
],
|
||||||
) -> Generator:
|
) -> Generator:
|
||||||
assert not self.running, "inference already running"
|
assert not self.running, "inference already running"
|
||||||
|
|
||||||
|
|
|
@ -33,6 +33,7 @@ from llama_stack.apis.tools import (
|
||||||
)
|
)
|
||||||
from llama_stack.apis.vector_io import QueryChunksResponse, VectorIO
|
from llama_stack.apis.vector_io import QueryChunksResponse, VectorIO
|
||||||
from llama_stack.providers.datatypes import ToolsProtocolPrivate
|
from llama_stack.providers.datatypes import ToolsProtocolPrivate
|
||||||
|
from llama_stack.providers.utils.inference.prompt_adapter import interleaved_content_as_str
|
||||||
from llama_stack.providers.utils.memory.vector_store import (
|
from llama_stack.providers.utils.memory.vector_store import (
|
||||||
content_from_doc,
|
content_from_doc,
|
||||||
make_overlapped_chunks,
|
make_overlapped_chunks,
|
||||||
|
@ -153,6 +154,11 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime):
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
picked.append(TextContentItem(text="END of knowledge_search tool results.\n"))
|
picked.append(TextContentItem(text="END of knowledge_search tool results.\n"))
|
||||||
|
picked.append(
|
||||||
|
TextContentItem(
|
||||||
|
text=f'The above results were retrieved to help answer the user\'s query: "{interleaved_content_as_str(content)}". Use them as supporting information only in answering this query.\n',
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
return RAGQueryResult(
|
return RAGQueryResult(
|
||||||
content=picked,
|
content=picked,
|
||||||
|
|
|
@ -6,7 +6,7 @@
|
||||||
|
|
||||||
from typing import List
|
from typing import List
|
||||||
|
|
||||||
from llama_stack.providers.datatypes import Api, InlineProviderSpec, ProviderSpec
|
from llama_stack.providers.datatypes import AdapterSpec, Api, InlineProviderSpec, ProviderSpec, remote_provider_spec
|
||||||
|
|
||||||
|
|
||||||
def available_providers() -> List[ProviderSpec]:
|
def available_providers() -> List[ProviderSpec]:
|
||||||
|
@ -25,4 +25,22 @@ def available_providers() -> List[ProviderSpec]:
|
||||||
Api.agents,
|
Api.agents,
|
||||||
],
|
],
|
||||||
),
|
),
|
||||||
|
remote_provider_spec(
|
||||||
|
api=Api.eval,
|
||||||
|
adapter=AdapterSpec(
|
||||||
|
adapter_type="nvidia",
|
||||||
|
pip_packages=[
|
||||||
|
"requests",
|
||||||
|
],
|
||||||
|
module="llama_stack.providers.remote.eval.nvidia",
|
||||||
|
config_class="llama_stack.providers.remote.eval.nvidia.NVIDIAEvalConfig",
|
||||||
|
),
|
||||||
|
api_dependencies=[
|
||||||
|
Api.datasetio,
|
||||||
|
Api.datasets,
|
||||||
|
Api.scoring,
|
||||||
|
Api.inference,
|
||||||
|
Api.agents,
|
||||||
|
],
|
||||||
|
),
|
||||||
]
|
]
|
||||||
|
|
|
@ -288,4 +288,14 @@ def available_providers() -> List[ProviderSpec]:
|
||||||
provider_data_validator="llama_stack.providers.remote.inference.passthrough.PassthroughProviderDataValidator",
|
provider_data_validator="llama_stack.providers.remote.inference.passthrough.PassthroughProviderDataValidator",
|
||||||
),
|
),
|
||||||
),
|
),
|
||||||
|
remote_provider_spec(
|
||||||
|
api=Api.inference,
|
||||||
|
adapter=AdapterSpec(
|
||||||
|
adapter_type="watsonx",
|
||||||
|
pip_packages=["ibm_watson_machine_learning"],
|
||||||
|
module="llama_stack.providers.remote.inference.watsonx",
|
||||||
|
config_class="llama_stack.providers.remote.inference.watsonx.WatsonXConfig",
|
||||||
|
provider_data_validator="llama_stack.providers.remote.inference.watsonx.WatsonXProviderDataValidator",
|
||||||
|
),
|
||||||
|
),
|
||||||
]
|
]
|
||||||
|
|
5
llama_stack/providers/remote/eval/__init__.py
Normal file
5
llama_stack/providers/remote/eval/__init__.py
Normal file
|
@ -0,0 +1,5 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
134
llama_stack/providers/remote/eval/nvidia/README.md
Normal file
134
llama_stack/providers/remote/eval/nvidia/README.md
Normal file
|
@ -0,0 +1,134 @@
|
||||||
|
# NVIDIA NeMo Evaluator Eval Provider
|
||||||
|
|
||||||
|
|
||||||
|
## Overview
|
||||||
|
|
||||||
|
For the first integration, Benchmarks are mapped to Evaluation Configs on in the NeMo Evaluator. The full evaluation config object is provided as part of the meta-data. The `dataset_id` and `scoring_functions` are not used.
|
||||||
|
|
||||||
|
Below are a few examples of how to register a benchmark, which in turn will create an evaluation config in NeMo Evaluator and how to trigger an evaluation.
|
||||||
|
|
||||||
|
### Example for register an academic benchmark
|
||||||
|
|
||||||
|
```
|
||||||
|
POST /eval/benchmarks
|
||||||
|
```
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"benchmark_id": "mmlu",
|
||||||
|
"dataset_id": "",
|
||||||
|
"scoring_functions": [],
|
||||||
|
"metadata": {
|
||||||
|
"type": "mmlu"
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### Example for register a custom evaluation
|
||||||
|
|
||||||
|
```
|
||||||
|
POST /eval/benchmarks
|
||||||
|
```
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"benchmark_id": "my-custom-benchmark",
|
||||||
|
"dataset_id": "",
|
||||||
|
"scoring_functions": [],
|
||||||
|
"metadata": {
|
||||||
|
"type": "custom",
|
||||||
|
"params": {
|
||||||
|
"parallelism": 8
|
||||||
|
},
|
||||||
|
"tasks": {
|
||||||
|
"qa": {
|
||||||
|
"type": "completion",
|
||||||
|
"params": {
|
||||||
|
"template": {
|
||||||
|
"prompt": "{{prompt}}",
|
||||||
|
"max_tokens": 200
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"dataset": {
|
||||||
|
"files_url": "hf://datasets/default/sample-basic-test/testing/testing.jsonl"
|
||||||
|
},
|
||||||
|
"metrics": {
|
||||||
|
"bleu": {
|
||||||
|
"type": "bleu",
|
||||||
|
"params": {
|
||||||
|
"references": [
|
||||||
|
"{{ideal_response}}"
|
||||||
|
]
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### Example for triggering a benchmark/custom evaluation
|
||||||
|
|
||||||
|
```
|
||||||
|
POST /eval/benchmarks/{benchmark_id}/jobs
|
||||||
|
```
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"benchmark_id": "my-custom-benchmark",
|
||||||
|
"benchmark_config": {
|
||||||
|
"eval_candidate": {
|
||||||
|
"type": "model",
|
||||||
|
"model": "meta-llama/Llama3.1-8B-Instruct",
|
||||||
|
"sampling_params": {
|
||||||
|
"max_tokens": 100,
|
||||||
|
"temperature": 0.7
|
||||||
|
}
|
||||||
|
},
|
||||||
|
"scoring_params": {}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
Response example:
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"job_id": "eval-1234",
|
||||||
|
"status": "in_progress"
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### Example for getting the status of a job
|
||||||
|
```
|
||||||
|
GET /eval/benchmarks/{benchmark_id}/jobs/{job_id}
|
||||||
|
```
|
||||||
|
|
||||||
|
Response example:
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"job_id": "eval-1234",
|
||||||
|
"status": "in_progress"
|
||||||
|
}
|
||||||
|
```
|
||||||
|
|
||||||
|
### Example for cancelling a job
|
||||||
|
```
|
||||||
|
POST /eval/benchmarks/{benchmark_id}/jobs/{job_id}/cancel
|
||||||
|
```
|
||||||
|
|
||||||
|
### Example for getting the results
|
||||||
|
```
|
||||||
|
GET /eval/benchmarks/{benchmark_id}/results
|
||||||
|
```
|
||||||
|
```json
|
||||||
|
{
|
||||||
|
"generations": [],
|
||||||
|
"scores": {
|
||||||
|
"{benchmark_id}": {
|
||||||
|
"score_rows": [],
|
||||||
|
"aggregated_results": {
|
||||||
|
"tasks": {},
|
||||||
|
"groups": {}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
}
|
||||||
|
```
|
31
llama_stack/providers/remote/eval/nvidia/__init__.py
Normal file
31
llama_stack/providers/remote/eval/nvidia/__init__.py
Normal file
|
@ -0,0 +1,31 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
from typing import Any, Dict
|
||||||
|
|
||||||
|
from llama_stack.distribution.datatypes import Api
|
||||||
|
|
||||||
|
from .config import NVIDIAEvalConfig
|
||||||
|
|
||||||
|
|
||||||
|
async def get_adapter_impl(
|
||||||
|
config: NVIDIAEvalConfig,
|
||||||
|
deps: Dict[Api, Any],
|
||||||
|
):
|
||||||
|
from .eval import NVIDIAEvalImpl
|
||||||
|
|
||||||
|
impl = NVIDIAEvalImpl(
|
||||||
|
config,
|
||||||
|
deps[Api.datasetio],
|
||||||
|
deps[Api.datasets],
|
||||||
|
deps[Api.scoring],
|
||||||
|
deps[Api.inference],
|
||||||
|
deps[Api.agents],
|
||||||
|
)
|
||||||
|
await impl.initialize()
|
||||||
|
return impl
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ["get_adapter_impl", "NVIDIAEvalImpl"]
|
29
llama_stack/providers/remote/eval/nvidia/config.py
Normal file
29
llama_stack/providers/remote/eval/nvidia/config.py
Normal file
|
@ -0,0 +1,29 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
import os
|
||||||
|
from typing import Any, Dict
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field
|
||||||
|
|
||||||
|
|
||||||
|
class NVIDIAEvalConfig(BaseModel):
|
||||||
|
"""
|
||||||
|
Configuration for the NVIDIA NeMo Evaluator microservice endpoint.
|
||||||
|
|
||||||
|
Attributes:
|
||||||
|
evaluator_url (str): A base url for accessing the NVIDIA evaluation endpoint, e.g. http://localhost:8000.
|
||||||
|
"""
|
||||||
|
|
||||||
|
evaluator_url: str = Field(
|
||||||
|
default_factory=lambda: os.getenv("NVIDIA_EVALUATOR_URL", "http://0.0.0.0:7331"),
|
||||||
|
description="The url for accessing the evaluator service",
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def sample_run_config(cls, **kwargs) -> Dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"evaluator_url": "${env.NVIDIA_EVALUATOR_URL:http://localhost:7331}",
|
||||||
|
}
|
154
llama_stack/providers/remote/eval/nvidia/eval.py
Normal file
154
llama_stack/providers/remote/eval/nvidia/eval.py
Normal file
|
@ -0,0 +1,154 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
from typing import Any, Dict, List
|
||||||
|
|
||||||
|
import requests
|
||||||
|
|
||||||
|
from llama_stack.apis.agents import Agents
|
||||||
|
from llama_stack.apis.benchmarks import Benchmark
|
||||||
|
from llama_stack.apis.datasetio import DatasetIO
|
||||||
|
from llama_stack.apis.datasets import Datasets
|
||||||
|
from llama_stack.apis.inference import Inference
|
||||||
|
from llama_stack.apis.scoring import Scoring, ScoringResult
|
||||||
|
from llama_stack.providers.datatypes import BenchmarksProtocolPrivate
|
||||||
|
from llama_stack.providers.remote.inference.nvidia.models import MODEL_ENTRIES
|
||||||
|
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
|
||||||
|
|
||||||
|
from .....apis.common.job_types import Job, JobStatus
|
||||||
|
from .....apis.eval.eval import BenchmarkConfig, Eval, EvaluateResponse
|
||||||
|
from .config import NVIDIAEvalConfig
|
||||||
|
|
||||||
|
DEFAULT_NAMESPACE = "nvidia"
|
||||||
|
|
||||||
|
|
||||||
|
class NVIDIAEvalImpl(
|
||||||
|
Eval,
|
||||||
|
BenchmarksProtocolPrivate,
|
||||||
|
ModelRegistryHelper,
|
||||||
|
):
|
||||||
|
def __init__(
|
||||||
|
self,
|
||||||
|
config: NVIDIAEvalConfig,
|
||||||
|
datasetio_api: DatasetIO,
|
||||||
|
datasets_api: Datasets,
|
||||||
|
scoring_api: Scoring,
|
||||||
|
inference_api: Inference,
|
||||||
|
agents_api: Agents,
|
||||||
|
) -> None:
|
||||||
|
self.config = config
|
||||||
|
self.datasetio_api = datasetio_api
|
||||||
|
self.datasets_api = datasets_api
|
||||||
|
self.scoring_api = scoring_api
|
||||||
|
self.inference_api = inference_api
|
||||||
|
self.agents_api = agents_api
|
||||||
|
|
||||||
|
ModelRegistryHelper.__init__(self, model_entries=MODEL_ENTRIES)
|
||||||
|
|
||||||
|
async def initialize(self) -> None: ...
|
||||||
|
|
||||||
|
async def shutdown(self) -> None: ...
|
||||||
|
|
||||||
|
async def _evaluator_get(self, path):
|
||||||
|
"""Helper for making GET requests to the evaluator service."""
|
||||||
|
response = requests.get(url=f"{self.config.evaluator_url}{path}")
|
||||||
|
response.raise_for_status()
|
||||||
|
return response.json()
|
||||||
|
|
||||||
|
async def _evaluator_post(self, path, data):
|
||||||
|
"""Helper for making POST requests to the evaluator service."""
|
||||||
|
response = requests.post(url=f"{self.config.evaluator_url}{path}", json=data)
|
||||||
|
response.raise_for_status()
|
||||||
|
return response.json()
|
||||||
|
|
||||||
|
async def register_benchmark(self, task_def: Benchmark) -> None:
|
||||||
|
"""Register a benchmark as an evaluation configuration."""
|
||||||
|
await self._evaluator_post(
|
||||||
|
"/v1/evaluation/configs",
|
||||||
|
{
|
||||||
|
"namespace": DEFAULT_NAMESPACE,
|
||||||
|
"name": task_def.benchmark_id,
|
||||||
|
# metadata is copied to request body as-is
|
||||||
|
**task_def.metadata,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
async def run_eval(
|
||||||
|
self,
|
||||||
|
benchmark_id: str,
|
||||||
|
benchmark_config: BenchmarkConfig,
|
||||||
|
) -> Job:
|
||||||
|
"""Run an evaluation job for a benchmark."""
|
||||||
|
model = (
|
||||||
|
benchmark_config.eval_candidate.model
|
||||||
|
if benchmark_config.eval_candidate.type == "model"
|
||||||
|
else benchmark_config.eval_candidate.config.model
|
||||||
|
)
|
||||||
|
nvidia_model = self.get_provider_model_id(model) or model
|
||||||
|
|
||||||
|
result = await self._evaluator_post(
|
||||||
|
"/v1/evaluation/jobs",
|
||||||
|
{
|
||||||
|
"config": f"{DEFAULT_NAMESPACE}/{benchmark_id}",
|
||||||
|
"target": {"type": "model", "model": nvidia_model},
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
return Job(job_id=result["id"], status=JobStatus.in_progress)
|
||||||
|
|
||||||
|
async def evaluate_rows(
|
||||||
|
self,
|
||||||
|
benchmark_id: str,
|
||||||
|
input_rows: List[Dict[str, Any]],
|
||||||
|
scoring_functions: List[str],
|
||||||
|
benchmark_config: BenchmarkConfig,
|
||||||
|
) -> EvaluateResponse:
|
||||||
|
raise NotImplementedError()
|
||||||
|
|
||||||
|
async def job_status(self, benchmark_id: str, job_id: str) -> Job:
|
||||||
|
"""Get the status of an evaluation job.
|
||||||
|
|
||||||
|
EvaluatorStatus: "created", "pending", "running", "cancelled", "cancelling", "failed", "completed".
|
||||||
|
JobStatus: "scheduled", "in_progress", "completed", "cancelled", "failed"
|
||||||
|
"""
|
||||||
|
result = await self._evaluator_get(f"/v1/evaluation/jobs/{job_id}")
|
||||||
|
result_status = result["status"]
|
||||||
|
|
||||||
|
job_status = JobStatus.failed
|
||||||
|
if result_status in ["created", "pending"]:
|
||||||
|
job_status = JobStatus.scheduled
|
||||||
|
elif result_status in ["running"]:
|
||||||
|
job_status = JobStatus.in_progress
|
||||||
|
elif result_status in ["completed"]:
|
||||||
|
job_status = JobStatus.completed
|
||||||
|
elif result_status in ["cancelled"]:
|
||||||
|
job_status = JobStatus.cancelled
|
||||||
|
|
||||||
|
return Job(job_id=job_id, status=job_status)
|
||||||
|
|
||||||
|
async def job_cancel(self, benchmark_id: str, job_id: str) -> None:
|
||||||
|
"""Cancel the evaluation job."""
|
||||||
|
await self._evaluator_post(f"/v1/evaluation/jobs/{job_id}/cancel", {})
|
||||||
|
|
||||||
|
async def job_result(self, benchmark_id: str, job_id: str) -> EvaluateResponse:
|
||||||
|
"""Returns the results of the evaluation job."""
|
||||||
|
|
||||||
|
job = await self.job_status(benchmark_id, job_id)
|
||||||
|
status = job.status
|
||||||
|
if not status or status != JobStatus.completed:
|
||||||
|
raise ValueError(f"Job {job_id} not completed. Status: {status.value}")
|
||||||
|
|
||||||
|
result = await self._evaluator_get(f"/v1/evaluation/jobs/{job_id}/results")
|
||||||
|
|
||||||
|
return EvaluateResponse(
|
||||||
|
# TODO: these are stored in detailed results on NeMo Evaluator side; can be added
|
||||||
|
generations=[],
|
||||||
|
scores={
|
||||||
|
benchmark_id: ScoringResult(
|
||||||
|
score_rows=[],
|
||||||
|
aggregated_results=result,
|
||||||
|
)
|
||||||
|
},
|
||||||
|
)
|
|
@ -47,10 +47,15 @@ class NVIDIAConfig(BaseModel):
|
||||||
default=60,
|
default=60,
|
||||||
description="Timeout for the HTTP requests",
|
description="Timeout for the HTTP requests",
|
||||||
)
|
)
|
||||||
|
append_api_version: bool = Field(
|
||||||
|
default_factory=lambda: os.getenv("NVIDIA_APPEND_API_VERSION", "True").lower() != "false",
|
||||||
|
description="When set to false, the API version will not be appended to the base_url. By default, it is true.",
|
||||||
|
)
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def sample_run_config(cls, **kwargs) -> Dict[str, Any]:
|
def sample_run_config(cls, **kwargs) -> Dict[str, Any]:
|
||||||
return {
|
return {
|
||||||
"url": "${env.NVIDIA_BASE_URL:https://integrate.api.nvidia.com}",
|
"url": "${env.NVIDIA_BASE_URL:https://integrate.api.nvidia.com}",
|
||||||
"api_key": "${env.NVIDIA_API_KEY:}",
|
"api_key": "${env.NVIDIA_API_KEY:}",
|
||||||
|
"append_api_version": "${env.NVIDIA_APPEND_API_VERSION:True}",
|
||||||
}
|
}
|
||||||
|
|
|
@ -33,7 +33,6 @@ from llama_stack.apis.inference import (
|
||||||
TextTruncation,
|
TextTruncation,
|
||||||
ToolChoice,
|
ToolChoice,
|
||||||
ToolConfig,
|
ToolConfig,
|
||||||
ToolDefinition,
|
|
||||||
)
|
)
|
||||||
from llama_stack.apis.inference.inference import (
|
from llama_stack.apis.inference.inference import (
|
||||||
OpenAIChatCompletion,
|
OpenAIChatCompletion,
|
||||||
|
@ -42,7 +41,11 @@ from llama_stack.apis.inference.inference import (
|
||||||
OpenAIMessageParam,
|
OpenAIMessageParam,
|
||||||
OpenAIResponseFormatParam,
|
OpenAIResponseFormatParam,
|
||||||
)
|
)
|
||||||
from llama_stack.models.llama.datatypes import ToolPromptFormat
|
from llama_stack.apis.models import Model, ModelType
|
||||||
|
from llama_stack.models.llama.datatypes import ToolDefinition, ToolPromptFormat
|
||||||
|
from llama_stack.providers.utils.inference import (
|
||||||
|
ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR,
|
||||||
|
)
|
||||||
from llama_stack.providers.utils.inference.model_registry import (
|
from llama_stack.providers.utils.inference.model_registry import (
|
||||||
ModelRegistryHelper,
|
ModelRegistryHelper,
|
||||||
)
|
)
|
||||||
|
@ -120,10 +123,10 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
|
||||||
"meta/llama-3.2-90b-vision-instruct": "https://ai.api.nvidia.com/v1/gr/meta/llama-3.2-90b-vision-instruct",
|
"meta/llama-3.2-90b-vision-instruct": "https://ai.api.nvidia.com/v1/gr/meta/llama-3.2-90b-vision-instruct",
|
||||||
}
|
}
|
||||||
|
|
||||||
base_url = f"{self._config.url}/v1"
|
base_url = f"{self._config.url}/v1" if self._config.append_api_version else self._config.url
|
||||||
|
|
||||||
if _is_nvidia_hosted(self._config) and provider_model_id in special_model_urls:
|
if _is_nvidia_hosted(self._config) and provider_model_id in special_model_urls:
|
||||||
base_url = special_model_urls[provider_model_id]
|
base_url = special_model_urls[provider_model_id]
|
||||||
|
|
||||||
return _get_client_for_base_url(base_url)
|
return _get_client_for_base_url(base_url)
|
||||||
|
|
||||||
async def _get_provider_model_id(self, model_id: str) -> str:
|
async def _get_provider_model_id(self, model_id: str) -> str:
|
||||||
|
@ -387,3 +390,44 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
|
||||||
return await self._get_client(provider_model_id).chat.completions.create(**params)
|
return await self._get_client(provider_model_id).chat.completions.create(**params)
|
||||||
except APIConnectionError as e:
|
except APIConnectionError as e:
|
||||||
raise ConnectionError(f"Failed to connect to NVIDIA NIM at {self._config.url}: {e}") from e
|
raise ConnectionError(f"Failed to connect to NVIDIA NIM at {self._config.url}: {e}") from e
|
||||||
|
|
||||||
|
async def register_model(self, model: Model) -> Model:
|
||||||
|
"""
|
||||||
|
Allow non-llama model registration.
|
||||||
|
|
||||||
|
Non-llama model registration: API Catalogue models, post-training models, etc.
|
||||||
|
client = LlamaStackAsLibraryClient("nvidia")
|
||||||
|
client.models.register(
|
||||||
|
model_id="mistralai/mixtral-8x7b-instruct-v0.1",
|
||||||
|
model_type=ModelType.llm,
|
||||||
|
provider_id="nvidia",
|
||||||
|
provider_model_id="mistralai/mixtral-8x7b-instruct-v0.1"
|
||||||
|
)
|
||||||
|
|
||||||
|
NOTE: Only supports models endpoints compatible with AsyncOpenAI base_url format.
|
||||||
|
"""
|
||||||
|
if model.model_type == ModelType.embedding:
|
||||||
|
# embedding models are always registered by their provider model id and does not need to be mapped to a llama model
|
||||||
|
provider_resource_id = model.provider_resource_id
|
||||||
|
else:
|
||||||
|
provider_resource_id = self.get_provider_model_id(model.provider_resource_id)
|
||||||
|
|
||||||
|
if provider_resource_id:
|
||||||
|
model.provider_resource_id = provider_resource_id
|
||||||
|
else:
|
||||||
|
llama_model = model.metadata.get("llama_model")
|
||||||
|
existing_llama_model = self.get_llama_model(model.provider_resource_id)
|
||||||
|
if existing_llama_model:
|
||||||
|
if existing_llama_model != llama_model:
|
||||||
|
raise ValueError(
|
||||||
|
f"Provider model id '{model.provider_resource_id}' is already registered to a different llama model: '{existing_llama_model}'"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
# not llama model
|
||||||
|
if llama_model in ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR:
|
||||||
|
self.provider_id_to_llama_model_map[model.provider_resource_id] = (
|
||||||
|
ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR[llama_model]
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.alias_to_provider_id_map[model.provider_model_id] = model.provider_model_id
|
||||||
|
return model
|
||||||
|
|
|
@ -8,7 +8,6 @@
|
||||||
from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Union
|
from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Union
|
||||||
|
|
||||||
import httpx
|
import httpx
|
||||||
from ollama import AsyncClient
|
|
||||||
from openai import AsyncOpenAI
|
from openai import AsyncOpenAI
|
||||||
|
|
||||||
from llama_stack.apis.common.content_types import (
|
from llama_stack.apis.common.content_types import (
|
||||||
|
@ -73,6 +72,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||||
interleaved_content_as_str,
|
interleaved_content_as_str,
|
||||||
request_has_media,
|
request_has_media,
|
||||||
)
|
)
|
||||||
|
from ollama import AsyncClient # type: ignore[attr-defined]
|
||||||
|
|
||||||
from .models import model_entries
|
from .models import model_entries
|
||||||
|
|
||||||
|
|
|
@ -76,8 +76,11 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
|
||||||
|
|
||||||
async def shutdown(self) -> None:
|
async def shutdown(self) -> None:
|
||||||
if self._client:
|
if self._client:
|
||||||
await self._client.close()
|
# Together client has no close method, so just set to None
|
||||||
self._client = None
|
self._client = None
|
||||||
|
if self._openai_client:
|
||||||
|
await self._openai_client.close()
|
||||||
|
self._openai_client = None
|
||||||
|
|
||||||
async def completion(
|
async def completion(
|
||||||
self,
|
self,
|
||||||
|
@ -359,7 +362,7 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi
|
||||||
top_p=top_p,
|
top_p=top_p,
|
||||||
user=user,
|
user=user,
|
||||||
)
|
)
|
||||||
if params.get("stream", True):
|
if params.get("stream", False):
|
||||||
return self._stream_openai_chat_completion(params)
|
return self._stream_openai_chat_completion(params)
|
||||||
return await self._get_openai_client().chat.completions.create(**params) # type: ignore
|
return await self._get_openai_client().chat.completions.create(**params) # type: ignore
|
||||||
|
|
||||||
|
|
|
@ -231,12 +231,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
|
||||||
self.client = None
|
self.client = None
|
||||||
|
|
||||||
async def initialize(self) -> None:
|
async def initialize(self) -> None:
|
||||||
log.info(f"Initializing VLLM client with base_url={self.config.url}")
|
pass
|
||||||
self.client = AsyncOpenAI(
|
|
||||||
base_url=self.config.url,
|
|
||||||
api_key=self.config.api_token,
|
|
||||||
http_client=None if self.config.tls_verify else httpx.AsyncClient(verify=False),
|
|
||||||
)
|
|
||||||
|
|
||||||
async def shutdown(self) -> None:
|
async def shutdown(self) -> None:
|
||||||
pass
|
pass
|
||||||
|
@ -249,6 +244,20 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
|
||||||
raise ValueError("Model store not set")
|
raise ValueError("Model store not set")
|
||||||
return await self.model_store.get_model(model_id)
|
return await self.model_store.get_model(model_id)
|
||||||
|
|
||||||
|
def _lazy_initialize_client(self):
|
||||||
|
if self.client is not None:
|
||||||
|
return
|
||||||
|
|
||||||
|
log.info(f"Initializing vLLM client with base_url={self.config.url}")
|
||||||
|
self.client = self._create_client()
|
||||||
|
|
||||||
|
def _create_client(self):
|
||||||
|
return AsyncOpenAI(
|
||||||
|
base_url=self.config.url,
|
||||||
|
api_key=self.config.api_token,
|
||||||
|
http_client=None if self.config.tls_verify else httpx.AsyncClient(verify=False),
|
||||||
|
)
|
||||||
|
|
||||||
async def completion(
|
async def completion(
|
||||||
self,
|
self,
|
||||||
model_id: str,
|
model_id: str,
|
||||||
|
@ -258,6 +267,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
|
||||||
stream: Optional[bool] = False,
|
stream: Optional[bool] = False,
|
||||||
logprobs: Optional[LogProbConfig] = None,
|
logprobs: Optional[LogProbConfig] = None,
|
||||||
) -> CompletionResponse | AsyncGenerator[CompletionResponseStreamChunk, None]:
|
) -> CompletionResponse | AsyncGenerator[CompletionResponseStreamChunk, None]:
|
||||||
|
self._lazy_initialize_client()
|
||||||
if sampling_params is None:
|
if sampling_params is None:
|
||||||
sampling_params = SamplingParams()
|
sampling_params = SamplingParams()
|
||||||
model = await self._get_model(model_id)
|
model = await self._get_model(model_id)
|
||||||
|
@ -287,6 +297,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
|
||||||
logprobs: Optional[LogProbConfig] = None,
|
logprobs: Optional[LogProbConfig] = None,
|
||||||
tool_config: Optional[ToolConfig] = None,
|
tool_config: Optional[ToolConfig] = None,
|
||||||
) -> ChatCompletionResponse | AsyncGenerator[ChatCompletionResponseStreamChunk, None]:
|
) -> ChatCompletionResponse | AsyncGenerator[ChatCompletionResponseStreamChunk, None]:
|
||||||
|
self._lazy_initialize_client()
|
||||||
if sampling_params is None:
|
if sampling_params is None:
|
||||||
sampling_params = SamplingParams()
|
sampling_params = SamplingParams()
|
||||||
model = await self._get_model(model_id)
|
model = await self._get_model(model_id)
|
||||||
|
@ -357,12 +368,15 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
|
||||||
yield chunk
|
yield chunk
|
||||||
|
|
||||||
async def register_model(self, model: Model) -> Model:
|
async def register_model(self, model: Model) -> Model:
|
||||||
assert self.client is not None
|
# register_model is called during Llama Stack initialization, hence we cannot init self.client if not initialized yet.
|
||||||
|
# self.client should only be created after the initialization is complete to avoid asyncio cross-context errors.
|
||||||
|
# Changing this may lead to unpredictable behavior.
|
||||||
|
client = self._create_client() if self.client is None else self.client
|
||||||
try:
|
try:
|
||||||
model = await self.register_helper.register_model(model)
|
model = await self.register_helper.register_model(model)
|
||||||
except ValueError:
|
except ValueError:
|
||||||
pass # Ignore statically unknown model, will check live listing
|
pass # Ignore statically unknown model, will check live listing
|
||||||
res = await self.client.models.list()
|
res = await client.models.list()
|
||||||
available_models = [m.id async for m in res]
|
available_models = [m.id async for m in res]
|
||||||
if model.provider_resource_id not in available_models:
|
if model.provider_resource_id not in available_models:
|
||||||
raise ValueError(
|
raise ValueError(
|
||||||
|
@ -413,6 +427,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
|
||||||
output_dimension: Optional[int] = None,
|
output_dimension: Optional[int] = None,
|
||||||
task_type: Optional[EmbeddingTaskType] = None,
|
task_type: Optional[EmbeddingTaskType] = None,
|
||||||
) -> EmbeddingsResponse:
|
) -> EmbeddingsResponse:
|
||||||
|
self._lazy_initialize_client()
|
||||||
assert self.client is not None
|
assert self.client is not None
|
||||||
model = await self._get_model(model_id)
|
model = await self._get_model(model_id)
|
||||||
|
|
||||||
|
@ -452,6 +467,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
|
||||||
guided_choice: Optional[List[str]] = None,
|
guided_choice: Optional[List[str]] = None,
|
||||||
prompt_logprobs: Optional[int] = None,
|
prompt_logprobs: Optional[int] = None,
|
||||||
) -> OpenAICompletion:
|
) -> OpenAICompletion:
|
||||||
|
self._lazy_initialize_client()
|
||||||
model_obj = await self._get_model(model)
|
model_obj = await self._get_model(model)
|
||||||
|
|
||||||
extra_body: Dict[str, Any] = {}
|
extra_body: Dict[str, Any] = {}
|
||||||
|
@ -508,6 +524,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate):
|
||||||
top_p: Optional[float] = None,
|
top_p: Optional[float] = None,
|
||||||
user: Optional[str] = None,
|
user: Optional[str] = None,
|
||||||
) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]:
|
) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]:
|
||||||
|
self._lazy_initialize_client()
|
||||||
model_obj = await self._get_model(model)
|
model_obj = await self._get_model(model)
|
||||||
params = await prepare_openai_completion_params(
|
params = await prepare_openai_completion_params(
|
||||||
model=model_obj.provider_resource_id,
|
model=model_obj.provider_resource_id,
|
||||||
|
|
22
llama_stack/providers/remote/inference/watsonx/__init__.py
Normal file
22
llama_stack/providers/remote/inference/watsonx/__init__.py
Normal file
|
@ -0,0 +1,22 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
from llama_stack.apis.inference import Inference
|
||||||
|
|
||||||
|
from .config import WatsonXConfig
|
||||||
|
|
||||||
|
|
||||||
|
async def get_adapter_impl(config: WatsonXConfig, _deps) -> Inference:
|
||||||
|
# import dynamically so `llama stack build` does not fail due to missing dependencies
|
||||||
|
from .watsonx import WatsonXInferenceAdapter
|
||||||
|
|
||||||
|
if not isinstance(config, WatsonXConfig):
|
||||||
|
raise RuntimeError(f"Unexpected config type: {type(config)}")
|
||||||
|
adapter = WatsonXInferenceAdapter(config)
|
||||||
|
return adapter
|
||||||
|
|
||||||
|
|
||||||
|
__all__ = ["get_adapter_impl", "WatsonXConfig"]
|
46
llama_stack/providers/remote/inference/watsonx/config.py
Normal file
46
llama_stack/providers/remote/inference/watsonx/config.py
Normal file
|
@ -0,0 +1,46 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
import os
|
||||||
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
|
from pydantic import BaseModel, Field, SecretStr
|
||||||
|
|
||||||
|
from llama_stack.schema_utils import json_schema_type
|
||||||
|
|
||||||
|
|
||||||
|
class WatsonXProviderDataValidator(BaseModel):
|
||||||
|
url: str
|
||||||
|
api_key: str
|
||||||
|
project_id: str
|
||||||
|
|
||||||
|
|
||||||
|
@json_schema_type
|
||||||
|
class WatsonXConfig(BaseModel):
|
||||||
|
url: str = Field(
|
||||||
|
default_factory=lambda: os.getenv("WATSONX_BASE_URL", "https://us-south.ml.cloud.ibm.com"),
|
||||||
|
description="A base url for accessing the watsonx.ai",
|
||||||
|
)
|
||||||
|
api_key: Optional[SecretStr] = Field(
|
||||||
|
default_factory=lambda: os.getenv("WATSONX_API_KEY"),
|
||||||
|
description="The watsonx API key, only needed of using the hosted service",
|
||||||
|
)
|
||||||
|
project_id: Optional[str] = Field(
|
||||||
|
default_factory=lambda: os.getenv("WATSONX_PROJECT_ID"),
|
||||||
|
description="The Project ID key, only needed of using the hosted service",
|
||||||
|
)
|
||||||
|
timeout: int = Field(
|
||||||
|
default=60,
|
||||||
|
description="Timeout for the HTTP requests",
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def sample_run_config(cls, **kwargs) -> Dict[str, Any]:
|
||||||
|
return {
|
||||||
|
"url": "${env.WATSONX_BASE_URL:https://us-south.ml.cloud.ibm.com}",
|
||||||
|
"api_key": "${env.WATSONX_API_KEY:}",
|
||||||
|
"project_id": "${env.WATSONX_PROJECT_ID:}",
|
||||||
|
}
|
47
llama_stack/providers/remote/inference/watsonx/models.py
Normal file
47
llama_stack/providers/remote/inference/watsonx/models.py
Normal file
|
@ -0,0 +1,47 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
from llama_stack.models.llama.sku_types import CoreModelId
|
||||||
|
from llama_stack.providers.utils.inference.model_registry import build_hf_repo_model_entry
|
||||||
|
|
||||||
|
MODEL_ENTRIES = [
|
||||||
|
build_hf_repo_model_entry(
|
||||||
|
"meta-llama/llama-3-3-70b-instruct",
|
||||||
|
CoreModelId.llama3_3_70b_instruct.value,
|
||||||
|
),
|
||||||
|
build_hf_repo_model_entry(
|
||||||
|
"meta-llama/llama-2-13b-chat",
|
||||||
|
CoreModelId.llama2_13b.value,
|
||||||
|
),
|
||||||
|
build_hf_repo_model_entry(
|
||||||
|
"meta-llama/llama-3-1-70b-instruct",
|
||||||
|
CoreModelId.llama3_1_70b_instruct.value,
|
||||||
|
),
|
||||||
|
build_hf_repo_model_entry(
|
||||||
|
"meta-llama/llama-3-1-8b-instruct",
|
||||||
|
CoreModelId.llama3_1_8b_instruct.value,
|
||||||
|
),
|
||||||
|
build_hf_repo_model_entry(
|
||||||
|
"meta-llama/llama-3-2-11b-vision-instruct",
|
||||||
|
CoreModelId.llama3_2_11b_vision_instruct.value,
|
||||||
|
),
|
||||||
|
build_hf_repo_model_entry(
|
||||||
|
"meta-llama/llama-3-2-1b-instruct",
|
||||||
|
CoreModelId.llama3_2_1b_instruct.value,
|
||||||
|
),
|
||||||
|
build_hf_repo_model_entry(
|
||||||
|
"meta-llama/llama-3-2-3b-instruct",
|
||||||
|
CoreModelId.llama3_2_3b_instruct.value,
|
||||||
|
),
|
||||||
|
build_hf_repo_model_entry(
|
||||||
|
"meta-llama/llama-3-2-90b-vision-instruct",
|
||||||
|
CoreModelId.llama3_2_90b_vision_instruct.value,
|
||||||
|
),
|
||||||
|
build_hf_repo_model_entry(
|
||||||
|
"meta-llama/llama-guard-3-11b-vision",
|
||||||
|
CoreModelId.llama_guard_3_11b_vision.value,
|
||||||
|
),
|
||||||
|
]
|
378
llama_stack/providers/remote/inference/watsonx/watsonx.py
Normal file
378
llama_stack/providers/remote/inference/watsonx/watsonx.py
Normal file
|
@ -0,0 +1,378 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Union
|
||||||
|
|
||||||
|
from ibm_watson_machine_learning.foundation_models import Model
|
||||||
|
from ibm_watson_machine_learning.metanames import GenTextParamsMetaNames as GenParams
|
||||||
|
from openai import AsyncOpenAI
|
||||||
|
|
||||||
|
from llama_stack.apis.common.content_types import InterleavedContent, InterleavedContentItem
|
||||||
|
from llama_stack.apis.inference import (
|
||||||
|
ChatCompletionRequest,
|
||||||
|
ChatCompletionResponse,
|
||||||
|
CompletionRequest,
|
||||||
|
EmbeddingsResponse,
|
||||||
|
EmbeddingTaskType,
|
||||||
|
Inference,
|
||||||
|
LogProbConfig,
|
||||||
|
Message,
|
||||||
|
ResponseFormat,
|
||||||
|
SamplingParams,
|
||||||
|
TextTruncation,
|
||||||
|
ToolChoice,
|
||||||
|
ToolConfig,
|
||||||
|
ToolDefinition,
|
||||||
|
ToolPromptFormat,
|
||||||
|
)
|
||||||
|
from llama_stack.apis.inference.inference import (
|
||||||
|
GreedySamplingStrategy,
|
||||||
|
OpenAIChatCompletion,
|
||||||
|
OpenAIChatCompletionChunk,
|
||||||
|
OpenAICompletion,
|
||||||
|
OpenAIMessageParam,
|
||||||
|
OpenAIResponseFormatParam,
|
||||||
|
TopKSamplingStrategy,
|
||||||
|
TopPSamplingStrategy,
|
||||||
|
)
|
||||||
|
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
|
||||||
|
from llama_stack.providers.utils.inference.openai_compat import (
|
||||||
|
OpenAICompatCompletionChoice,
|
||||||
|
OpenAICompatCompletionResponse,
|
||||||
|
prepare_openai_completion_params,
|
||||||
|
process_chat_completion_response,
|
||||||
|
process_chat_completion_stream_response,
|
||||||
|
process_completion_response,
|
||||||
|
process_completion_stream_response,
|
||||||
|
)
|
||||||
|
from llama_stack.providers.utils.inference.prompt_adapter import (
|
||||||
|
chat_completion_request_to_prompt,
|
||||||
|
completion_request_to_prompt,
|
||||||
|
request_has_media,
|
||||||
|
)
|
||||||
|
|
||||||
|
from . import WatsonXConfig
|
||||||
|
from .models import MODEL_ENTRIES
|
||||||
|
|
||||||
|
|
||||||
|
class WatsonXInferenceAdapter(Inference, ModelRegistryHelper):
|
||||||
|
def __init__(self, config: WatsonXConfig) -> None:
|
||||||
|
ModelRegistryHelper.__init__(self, MODEL_ENTRIES)
|
||||||
|
|
||||||
|
print(f"Initializing watsonx InferenceAdapter({config.url})...")
|
||||||
|
|
||||||
|
self._config = config
|
||||||
|
|
||||||
|
self._project_id = self._config.project_id
|
||||||
|
|
||||||
|
async def initialize(self) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def shutdown(self) -> None:
|
||||||
|
pass
|
||||||
|
|
||||||
|
async def completion(
|
||||||
|
self,
|
||||||
|
model_id: str,
|
||||||
|
content: InterleavedContent,
|
||||||
|
sampling_params: Optional[SamplingParams] = None,
|
||||||
|
response_format: Optional[ResponseFormat] = None,
|
||||||
|
stream: Optional[bool] = False,
|
||||||
|
logprobs: Optional[LogProbConfig] = None,
|
||||||
|
) -> AsyncGenerator:
|
||||||
|
if sampling_params is None:
|
||||||
|
sampling_params = SamplingParams()
|
||||||
|
model = await self.model_store.get_model(model_id)
|
||||||
|
request = CompletionRequest(
|
||||||
|
model=model.provider_resource_id,
|
||||||
|
content=content,
|
||||||
|
sampling_params=sampling_params,
|
||||||
|
response_format=response_format,
|
||||||
|
stream=stream,
|
||||||
|
logprobs=logprobs,
|
||||||
|
)
|
||||||
|
if stream:
|
||||||
|
return self._stream_completion(request)
|
||||||
|
else:
|
||||||
|
return await self._nonstream_completion(request)
|
||||||
|
|
||||||
|
def _get_client(self, model_id) -> Model:
|
||||||
|
config_api_key = self._config.api_key.get_secret_value() if self._config.api_key else None
|
||||||
|
config_url = self._config.url
|
||||||
|
project_id = self._config.project_id
|
||||||
|
credentials = {"url": config_url, "apikey": config_api_key}
|
||||||
|
|
||||||
|
return Model(model_id=model_id, credentials=credentials, project_id=project_id)
|
||||||
|
|
||||||
|
def _get_openai_client(self) -> AsyncOpenAI:
|
||||||
|
if not self._openai_client:
|
||||||
|
self._openai_client = AsyncOpenAI(
|
||||||
|
base_url=f"{self._config.url}/openai/v1",
|
||||||
|
api_key=self._config.api_key,
|
||||||
|
)
|
||||||
|
return self._openai_client
|
||||||
|
|
||||||
|
async def _nonstream_completion(self, request: CompletionRequest) -> ChatCompletionResponse:
|
||||||
|
params = await self._get_params(request)
|
||||||
|
r = self._get_client(request.model).generate(**params)
|
||||||
|
choices = []
|
||||||
|
if "results" in r:
|
||||||
|
for result in r["results"]:
|
||||||
|
choice = OpenAICompatCompletionChoice(
|
||||||
|
finish_reason=result["stop_reason"] if result["stop_reason"] else None,
|
||||||
|
text=result["generated_text"],
|
||||||
|
)
|
||||||
|
choices.append(choice)
|
||||||
|
response = OpenAICompatCompletionResponse(
|
||||||
|
choices=choices,
|
||||||
|
)
|
||||||
|
return process_completion_response(response)
|
||||||
|
|
||||||
|
async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
|
||||||
|
params = await self._get_params(request)
|
||||||
|
|
||||||
|
async def _generate_and_convert_to_openai_compat():
|
||||||
|
s = self._get_client(request.model).generate_text_stream(**params)
|
||||||
|
for chunk in s:
|
||||||
|
choice = OpenAICompatCompletionChoice(
|
||||||
|
finish_reason=None,
|
||||||
|
text=chunk,
|
||||||
|
)
|
||||||
|
yield OpenAICompatCompletionResponse(
|
||||||
|
choices=[choice],
|
||||||
|
)
|
||||||
|
|
||||||
|
stream = _generate_and_convert_to_openai_compat()
|
||||||
|
async for chunk in process_completion_stream_response(stream):
|
||||||
|
yield chunk
|
||||||
|
|
||||||
|
async def chat_completion(
|
||||||
|
self,
|
||||||
|
model_id: str,
|
||||||
|
messages: List[Message],
|
||||||
|
sampling_params: Optional[SamplingParams] = None,
|
||||||
|
tools: Optional[List[ToolDefinition]] = None,
|
||||||
|
tool_choice: Optional[ToolChoice] = ToolChoice.auto,
|
||||||
|
tool_prompt_format: Optional[ToolPromptFormat] = None,
|
||||||
|
response_format: Optional[ResponseFormat] = None,
|
||||||
|
stream: Optional[bool] = False,
|
||||||
|
logprobs: Optional[LogProbConfig] = None,
|
||||||
|
tool_config: Optional[ToolConfig] = None,
|
||||||
|
) -> AsyncGenerator:
|
||||||
|
if sampling_params is None:
|
||||||
|
sampling_params = SamplingParams()
|
||||||
|
model = await self.model_store.get_model(model_id)
|
||||||
|
request = ChatCompletionRequest(
|
||||||
|
model=model.provider_resource_id,
|
||||||
|
messages=messages,
|
||||||
|
sampling_params=sampling_params,
|
||||||
|
tools=tools or [],
|
||||||
|
response_format=response_format,
|
||||||
|
stream=stream,
|
||||||
|
logprobs=logprobs,
|
||||||
|
tool_config=tool_config,
|
||||||
|
)
|
||||||
|
|
||||||
|
if stream:
|
||||||
|
return self._stream_chat_completion(request)
|
||||||
|
else:
|
||||||
|
return await self._nonstream_chat_completion(request)
|
||||||
|
|
||||||
|
async def _nonstream_chat_completion(self, request: ChatCompletionRequest) -> ChatCompletionResponse:
|
||||||
|
params = await self._get_params(request)
|
||||||
|
r = self._get_client(request.model).generate(**params)
|
||||||
|
choices = []
|
||||||
|
if "results" in r:
|
||||||
|
for result in r["results"]:
|
||||||
|
choice = OpenAICompatCompletionChoice(
|
||||||
|
finish_reason=result["stop_reason"] if result["stop_reason"] else None,
|
||||||
|
text=result["generated_text"],
|
||||||
|
)
|
||||||
|
choices.append(choice)
|
||||||
|
response = OpenAICompatCompletionResponse(
|
||||||
|
choices=choices,
|
||||||
|
)
|
||||||
|
return process_chat_completion_response(response, request)
|
||||||
|
|
||||||
|
async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
|
||||||
|
params = await self._get_params(request)
|
||||||
|
model_id = request.model
|
||||||
|
|
||||||
|
# if we shift to TogetherAsyncClient, we won't need this wrapper
|
||||||
|
async def _to_async_generator():
|
||||||
|
s = self._get_client(model_id).generate_text_stream(**params)
|
||||||
|
for chunk in s:
|
||||||
|
choice = OpenAICompatCompletionChoice(
|
||||||
|
finish_reason=None,
|
||||||
|
text=chunk,
|
||||||
|
)
|
||||||
|
yield OpenAICompatCompletionResponse(
|
||||||
|
choices=[choice],
|
||||||
|
)
|
||||||
|
|
||||||
|
stream = _to_async_generator()
|
||||||
|
async for chunk in process_chat_completion_stream_response(stream, request):
|
||||||
|
yield chunk
|
||||||
|
|
||||||
|
async def _get_params(self, request: Union[ChatCompletionRequest, CompletionRequest]) -> dict:
|
||||||
|
input_dict = {"params": {}}
|
||||||
|
media_present = request_has_media(request)
|
||||||
|
llama_model = self.get_llama_model(request.model)
|
||||||
|
if isinstance(request, ChatCompletionRequest):
|
||||||
|
input_dict["prompt"] = await chat_completion_request_to_prompt(request, llama_model)
|
||||||
|
else:
|
||||||
|
assert not media_present, "Together does not support media for Completion requests"
|
||||||
|
input_dict["prompt"] = await completion_request_to_prompt(request)
|
||||||
|
if request.sampling_params:
|
||||||
|
if request.sampling_params.strategy:
|
||||||
|
input_dict["params"][GenParams.DECODING_METHOD] = request.sampling_params.strategy.type
|
||||||
|
if request.sampling_params.max_tokens:
|
||||||
|
input_dict["params"][GenParams.MAX_NEW_TOKENS] = request.sampling_params.max_tokens
|
||||||
|
if request.sampling_params.repetition_penalty:
|
||||||
|
input_dict["params"][GenParams.REPETITION_PENALTY] = request.sampling_params.repetition_penalty
|
||||||
|
|
||||||
|
if isinstance(request.sampling_params.strategy, TopPSamplingStrategy):
|
||||||
|
input_dict["params"][GenParams.TOP_P] = request.sampling_params.strategy.top_p
|
||||||
|
input_dict["params"][GenParams.TEMPERATURE] = request.sampling_params.strategy.temperature
|
||||||
|
if isinstance(request.sampling_params.strategy, TopKSamplingStrategy):
|
||||||
|
input_dict["params"][GenParams.TOP_K] = request.sampling_params.strategy.top_k
|
||||||
|
if isinstance(request.sampling_params.strategy, GreedySamplingStrategy):
|
||||||
|
input_dict["params"][GenParams.TEMPERATURE] = 0.0
|
||||||
|
|
||||||
|
input_dict["params"][GenParams.STOP_SEQUENCES] = ["<|endoftext|>"]
|
||||||
|
|
||||||
|
params = {
|
||||||
|
**input_dict,
|
||||||
|
}
|
||||||
|
return params
|
||||||
|
|
||||||
|
async def embeddings(
|
||||||
|
self,
|
||||||
|
model_id: str,
|
||||||
|
contents: List[str] | List[InterleavedContentItem],
|
||||||
|
text_truncation: Optional[TextTruncation] = TextTruncation.none,
|
||||||
|
output_dimension: Optional[int] = None,
|
||||||
|
task_type: Optional[EmbeddingTaskType] = None,
|
||||||
|
) -> EmbeddingsResponse:
|
||||||
|
raise NotImplementedError("embedding is not supported for watsonx")
|
||||||
|
|
||||||
|
async def openai_completion(
|
||||||
|
self,
|
||||||
|
model: str,
|
||||||
|
prompt: Union[str, List[str], List[int], List[List[int]]],
|
||||||
|
best_of: Optional[int] = None,
|
||||||
|
echo: Optional[bool] = None,
|
||||||
|
frequency_penalty: Optional[float] = None,
|
||||||
|
logit_bias: Optional[Dict[str, float]] = None,
|
||||||
|
logprobs: Optional[bool] = None,
|
||||||
|
max_tokens: Optional[int] = None,
|
||||||
|
n: Optional[int] = None,
|
||||||
|
presence_penalty: Optional[float] = None,
|
||||||
|
seed: Optional[int] = None,
|
||||||
|
stop: Optional[Union[str, List[str]]] = None,
|
||||||
|
stream: Optional[bool] = None,
|
||||||
|
stream_options: Optional[Dict[str, Any]] = None,
|
||||||
|
temperature: Optional[float] = None,
|
||||||
|
top_p: Optional[float] = None,
|
||||||
|
user: Optional[str] = None,
|
||||||
|
guided_choice: Optional[List[str]] = None,
|
||||||
|
prompt_logprobs: Optional[int] = None,
|
||||||
|
) -> OpenAICompletion:
|
||||||
|
model_obj = await self.model_store.get_model(model)
|
||||||
|
params = await prepare_openai_completion_params(
|
||||||
|
model=model_obj.provider_resource_id,
|
||||||
|
prompt=prompt,
|
||||||
|
best_of=best_of,
|
||||||
|
echo=echo,
|
||||||
|
frequency_penalty=frequency_penalty,
|
||||||
|
logit_bias=logit_bias,
|
||||||
|
logprobs=logprobs,
|
||||||
|
max_tokens=max_tokens,
|
||||||
|
n=n,
|
||||||
|
presence_penalty=presence_penalty,
|
||||||
|
seed=seed,
|
||||||
|
stop=stop,
|
||||||
|
stream=stream,
|
||||||
|
stream_options=stream_options,
|
||||||
|
temperature=temperature,
|
||||||
|
top_p=top_p,
|
||||||
|
user=user,
|
||||||
|
)
|
||||||
|
return await self._get_openai_client().completions.create(**params) # type: ignore
|
||||||
|
|
||||||
|
async def openai_chat_completion(
|
||||||
|
self,
|
||||||
|
model: str,
|
||||||
|
messages: List[OpenAIMessageParam],
|
||||||
|
frequency_penalty: Optional[float] = None,
|
||||||
|
function_call: Optional[Union[str, Dict[str, Any]]] = None,
|
||||||
|
functions: Optional[List[Dict[str, Any]]] = None,
|
||||||
|
logit_bias: Optional[Dict[str, float]] = None,
|
||||||
|
logprobs: Optional[bool] = None,
|
||||||
|
max_completion_tokens: Optional[int] = None,
|
||||||
|
max_tokens: Optional[int] = None,
|
||||||
|
n: Optional[int] = None,
|
||||||
|
parallel_tool_calls: Optional[bool] = None,
|
||||||
|
presence_penalty: Optional[float] = None,
|
||||||
|
response_format: Optional[OpenAIResponseFormatParam] = None,
|
||||||
|
seed: Optional[int] = None,
|
||||||
|
stop: Optional[Union[str, List[str]]] = None,
|
||||||
|
stream: Optional[bool] = None,
|
||||||
|
stream_options: Optional[Dict[str, Any]] = None,
|
||||||
|
temperature: Optional[float] = None,
|
||||||
|
tool_choice: Optional[Union[str, Dict[str, Any]]] = None,
|
||||||
|
tools: Optional[List[Dict[str, Any]]] = None,
|
||||||
|
top_logprobs: Optional[int] = None,
|
||||||
|
top_p: Optional[float] = None,
|
||||||
|
user: Optional[str] = None,
|
||||||
|
) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]:
|
||||||
|
model_obj = await self.model_store.get_model(model)
|
||||||
|
params = await prepare_openai_completion_params(
|
||||||
|
model=model_obj.provider_resource_id,
|
||||||
|
messages=messages,
|
||||||
|
frequency_penalty=frequency_penalty,
|
||||||
|
function_call=function_call,
|
||||||
|
functions=functions,
|
||||||
|
logit_bias=logit_bias,
|
||||||
|
logprobs=logprobs,
|
||||||
|
max_completion_tokens=max_completion_tokens,
|
||||||
|
max_tokens=max_tokens,
|
||||||
|
n=n,
|
||||||
|
parallel_tool_calls=parallel_tool_calls,
|
||||||
|
presence_penalty=presence_penalty,
|
||||||
|
response_format=response_format,
|
||||||
|
seed=seed,
|
||||||
|
stop=stop,
|
||||||
|
stream=stream,
|
||||||
|
stream_options=stream_options,
|
||||||
|
temperature=temperature,
|
||||||
|
tool_choice=tool_choice,
|
||||||
|
tools=tools,
|
||||||
|
top_logprobs=top_logprobs,
|
||||||
|
top_p=top_p,
|
||||||
|
user=user,
|
||||||
|
)
|
||||||
|
if params.get("stream", False):
|
||||||
|
return self._stream_openai_chat_completion(params)
|
||||||
|
return await self._get_openai_client().chat.completions.create(**params) # type: ignore
|
||||||
|
|
||||||
|
async def _stream_openai_chat_completion(self, params: dict) -> AsyncGenerator:
|
||||||
|
# watsonx.ai sometimes adds usage data to the stream
|
||||||
|
include_usage = False
|
||||||
|
if params.get("stream_options", None):
|
||||||
|
include_usage = params["stream_options"].get("include_usage", False)
|
||||||
|
stream = await self._get_openai_client().chat.completions.create(**params)
|
||||||
|
|
||||||
|
seen_finish_reason = False
|
||||||
|
async for chunk in stream:
|
||||||
|
# Final usage chunk with no choices that the user didn't request, so discard
|
||||||
|
if not include_usage and seen_finish_reason and len(chunk.choices) == 0:
|
||||||
|
break
|
||||||
|
yield chunk
|
||||||
|
for choice in chunk.choices:
|
||||||
|
if choice.finish_reason:
|
||||||
|
seen_finish_reason = True
|
||||||
|
break
|
|
@ -36,7 +36,6 @@ import os
|
||||||
|
|
||||||
os.environ["NVIDIA_API_KEY"] = "your-api-key"
|
os.environ["NVIDIA_API_KEY"] = "your-api-key"
|
||||||
os.environ["NVIDIA_CUSTOMIZER_URL"] = "http://nemo.test"
|
os.environ["NVIDIA_CUSTOMIZER_URL"] = "http://nemo.test"
|
||||||
os.environ["NVIDIA_USER_ID"] = "llama-stack-user"
|
|
||||||
os.environ["NVIDIA_DATASET_NAMESPACE"] = "default"
|
os.environ["NVIDIA_DATASET_NAMESPACE"] = "default"
|
||||||
os.environ["NVIDIA_PROJECT_ID"] = "test-project"
|
os.environ["NVIDIA_PROJECT_ID"] = "test-project"
|
||||||
os.environ["NVIDIA_OUTPUT_MODEL_DIR"] = "test-example-model@v1"
|
os.environ["NVIDIA_OUTPUT_MODEL_DIR"] = "test-example-model@v1"
|
||||||
|
@ -125,6 +124,21 @@ client.post_training.job.cancel(job_uuid="your-job-id")
|
||||||
|
|
||||||
### Inference with the fine-tuned model
|
### Inference with the fine-tuned model
|
||||||
|
|
||||||
|
#### 1. Register the model
|
||||||
|
|
||||||
|
```python
|
||||||
|
from llama_stack.apis.models import Model, ModelType
|
||||||
|
|
||||||
|
client.models.register(
|
||||||
|
model_id="test-example-model@v1",
|
||||||
|
provider_id="nvidia",
|
||||||
|
provider_model_id="test-example-model@v1",
|
||||||
|
model_type=ModelType.llm,
|
||||||
|
)
|
||||||
|
```
|
||||||
|
|
||||||
|
#### 2. Inference with the fine-tuned model
|
||||||
|
|
||||||
```python
|
```python
|
||||||
response = client.inference.completion(
|
response = client.inference.completion(
|
||||||
content="Complete the sentence using one word: Roses are red, violets are ",
|
content="Complete the sentence using one word: Roses are red, violets are ",
|
||||||
|
|
|
@ -67,13 +67,18 @@ class NvidiaPostTrainingAdapter(ModelRegistryHelper):
|
||||||
self.timeout = aiohttp.ClientTimeout(total=config.timeout)
|
self.timeout = aiohttp.ClientTimeout(total=config.timeout)
|
||||||
# TODO: filter by available models based on /config endpoint
|
# TODO: filter by available models based on /config endpoint
|
||||||
ModelRegistryHelper.__init__(self, model_entries=_MODEL_ENTRIES)
|
ModelRegistryHelper.__init__(self, model_entries=_MODEL_ENTRIES)
|
||||||
self.session = aiohttp.ClientSession(headers=self.headers, timeout=self.timeout)
|
self.session = None
|
||||||
self.customizer_url = config.customizer_url
|
|
||||||
|
|
||||||
|
self.customizer_url = config.customizer_url
|
||||||
if not self.customizer_url:
|
if not self.customizer_url:
|
||||||
warnings.warn("Customizer URL is not set, using default value: http://nemo.test", stacklevel=2)
|
warnings.warn("Customizer URL is not set, using default value: http://nemo.test", stacklevel=2)
|
||||||
self.customizer_url = "http://nemo.test"
|
self.customizer_url = "http://nemo.test"
|
||||||
|
|
||||||
|
async def _get_session(self) -> aiohttp.ClientSession:
|
||||||
|
if self.session is None or self.session.closed:
|
||||||
|
self.session = aiohttp.ClientSession(headers=self.headers, timeout=self.timeout)
|
||||||
|
return self.session
|
||||||
|
|
||||||
async def _make_request(
|
async def _make_request(
|
||||||
self,
|
self,
|
||||||
method: str,
|
method: str,
|
||||||
|
@ -94,8 +99,9 @@ class NvidiaPostTrainingAdapter(ModelRegistryHelper):
|
||||||
if json and "Content-Type" not in request_headers:
|
if json and "Content-Type" not in request_headers:
|
||||||
request_headers["Content-Type"] = "application/json"
|
request_headers["Content-Type"] = "application/json"
|
||||||
|
|
||||||
|
session = await self._get_session()
|
||||||
for _ in range(self.config.max_retries):
|
for _ in range(self.config.max_retries):
|
||||||
async with self.session.request(method, url, params=params, json=json, **kwargs) as response:
|
async with session.request(method, url, params=params, json=json, **kwargs) as response:
|
||||||
if response.status >= 400:
|
if response.status >= 400:
|
||||||
error_data = await response.json()
|
error_data = await response.json()
|
||||||
raise Exception(f"API request failed: {error_data}")
|
raise Exception(f"API request failed: {error_data}")
|
||||||
|
@ -122,8 +128,8 @@ class NvidiaPostTrainingAdapter(ModelRegistryHelper):
|
||||||
jobs = []
|
jobs = []
|
||||||
for job in response.get("data", []):
|
for job in response.get("data", []):
|
||||||
job_id = job.pop("id")
|
job_id = job.pop("id")
|
||||||
job_status = job.pop("status", "unknown").lower()
|
job_status = job.pop("status", "scheduled").lower()
|
||||||
mapped_status = STATUS_MAPPING.get(job_status, "unknown")
|
mapped_status = STATUS_MAPPING.get(job_status, "scheduled")
|
||||||
|
|
||||||
# Convert string timestamps to datetime objects
|
# Convert string timestamps to datetime objects
|
||||||
created_at = (
|
created_at = (
|
||||||
|
@ -177,7 +183,7 @@ class NvidiaPostTrainingAdapter(ModelRegistryHelper):
|
||||||
)
|
)
|
||||||
|
|
||||||
api_status = response.pop("status").lower()
|
api_status = response.pop("status").lower()
|
||||||
mapped_status = STATUS_MAPPING.get(api_status, "unknown")
|
mapped_status = STATUS_MAPPING.get(api_status, "scheduled")
|
||||||
|
|
||||||
return NvidiaPostTrainingJobStatusResponse(
|
return NvidiaPostTrainingJobStatusResponse(
|
||||||
status=JobStatus(mapped_status),
|
status=JobStatus(mapped_status),
|
||||||
|
@ -239,6 +245,7 @@ class NvidiaPostTrainingAdapter(ModelRegistryHelper):
|
||||||
|
|
||||||
Supported models:
|
Supported models:
|
||||||
- meta/llama-3.1-8b-instruct
|
- meta/llama-3.1-8b-instruct
|
||||||
|
- meta/llama-3.2-1b-instruct
|
||||||
|
|
||||||
Supported algorithm configs:
|
Supported algorithm configs:
|
||||||
- LoRA, SFT
|
- LoRA, SFT
|
||||||
|
@ -284,10 +291,6 @@ class NvidiaPostTrainingAdapter(ModelRegistryHelper):
|
||||||
|
|
||||||
- LoRA config:
|
- LoRA config:
|
||||||
## NeMo customizer specific LoRA parameters
|
## NeMo customizer specific LoRA parameters
|
||||||
- adapter_dim: int - Adapter dimension
|
|
||||||
Default: 8 (supports powers of 2)
|
|
||||||
- adapter_dropout: float - Adapter dropout
|
|
||||||
Default: None (0.0-1.0)
|
|
||||||
- alpha: int - Scaling factor for the LoRA update
|
- alpha: int - Scaling factor for the LoRA update
|
||||||
Default: 16
|
Default: 16
|
||||||
Note:
|
Note:
|
||||||
|
@ -297,7 +300,7 @@ class NvidiaPostTrainingAdapter(ModelRegistryHelper):
|
||||||
User is informed about unsupported parameters via warnings.
|
User is informed about unsupported parameters via warnings.
|
||||||
"""
|
"""
|
||||||
# Map model to nvidia model name
|
# Map model to nvidia model name
|
||||||
# ToDo: only supports llama-3.1-8b-instruct now, need to update this to support other models
|
# See `_MODEL_ENTRIES` for supported models
|
||||||
nvidia_model = self.get_provider_model_id(model)
|
nvidia_model = self.get_provider_model_id(model)
|
||||||
|
|
||||||
# Check for unsupported method parameters
|
# Check for unsupported method parameters
|
||||||
|
@ -330,7 +333,7 @@ class NvidiaPostTrainingAdapter(ModelRegistryHelper):
|
||||||
},
|
},
|
||||||
"data_config": {"dataset_id", "batch_size"},
|
"data_config": {"dataset_id", "batch_size"},
|
||||||
"optimizer_config": {"lr", "weight_decay"},
|
"optimizer_config": {"lr", "weight_decay"},
|
||||||
"lora_config": {"type", "adapter_dim", "adapter_dropout", "alpha"},
|
"lora_config": {"type", "alpha"},
|
||||||
}
|
}
|
||||||
|
|
||||||
# Validate all parameters at once
|
# Validate all parameters at once
|
||||||
|
@ -389,16 +392,10 @@ class NvidiaPostTrainingAdapter(ModelRegistryHelper):
|
||||||
|
|
||||||
# Handle LoRA-specific configuration
|
# Handle LoRA-specific configuration
|
||||||
if algorithm_config:
|
if algorithm_config:
|
||||||
if isinstance(algorithm_config, dict) and algorithm_config.get("type") == "LoRA":
|
if algorithm_config.type == "LoRA":
|
||||||
warn_unsupported_params(algorithm_config, supported_params["lora_config"], "LoRA config")
|
warn_unsupported_params(algorithm_config, supported_params["lora_config"], "LoRA config")
|
||||||
job_config["hyperparameters"]["lora"] = {
|
job_config["hyperparameters"]["lora"] = {
|
||||||
k: v
|
k: v for k, v in {"alpha": algorithm_config.alpha}.items() if v is not None
|
||||||
for k, v in {
|
|
||||||
"adapter_dim": algorithm_config.get("adapter_dim"),
|
|
||||||
"alpha": algorithm_config.get("alpha"),
|
|
||||||
"adapter_dropout": algorithm_config.get("adapter_dropout"),
|
|
||||||
}.items()
|
|
||||||
if v is not None
|
|
||||||
}
|
}
|
||||||
else:
|
else:
|
||||||
raise NotImplementedError(f"Unsupported algorithm config: {algorithm_config}")
|
raise NotImplementedError(f"Unsupported algorithm config: {algorithm_config}")
|
||||||
|
|
|
@ -524,11 +524,26 @@ async def convert_message_to_openai_dict(message: Message, download: bool = Fals
|
||||||
else:
|
else:
|
||||||
content = [await _convert_content(message.content)]
|
content = [await _convert_content(message.content)]
|
||||||
|
|
||||||
return {
|
result = {
|
||||||
"role": message.role,
|
"role": message.role,
|
||||||
"content": content,
|
"content": content,
|
||||||
}
|
}
|
||||||
|
|
||||||
|
if hasattr(message, "tool_calls") and message.tool_calls:
|
||||||
|
result["tool_calls"] = []
|
||||||
|
for tc in message.tool_calls:
|
||||||
|
result["tool_calls"].append(
|
||||||
|
{
|
||||||
|
"id": tc.call_id,
|
||||||
|
"type": "function",
|
||||||
|
"function": {
|
||||||
|
"name": tc.tool_name,
|
||||||
|
"arguments": tc.arguments_json if hasattr(tc, "arguments_json") else json.dumps(tc.arguments),
|
||||||
|
},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
return result
|
||||||
|
|
||||||
|
|
||||||
class UnparseableToolCall(BaseModel):
|
class UnparseableToolCall(BaseModel):
|
||||||
"""
|
"""
|
||||||
|
|
|
@ -52,6 +52,9 @@ from llama_stack.models.llama.llama3.prompt_templates import (
|
||||||
SystemDefaultGenerator,
|
SystemDefaultGenerator,
|
||||||
)
|
)
|
||||||
from llama_stack.models.llama.llama3.tokenizer import Tokenizer
|
from llama_stack.models.llama.llama3.tokenizer import Tokenizer
|
||||||
|
from llama_stack.models.llama.llama4.prompt_templates.system_prompts import (
|
||||||
|
PythonListCustomToolGenerator as PythonListCustomToolGeneratorLlama4,
|
||||||
|
)
|
||||||
from llama_stack.models.llama.sku_list import resolve_model
|
from llama_stack.models.llama.sku_list import resolve_model
|
||||||
from llama_stack.models.llama.sku_types import ModelFamily, is_multimodal
|
from llama_stack.models.llama.sku_types import ModelFamily, is_multimodal
|
||||||
from llama_stack.providers.utils.inference import supported_inference_models
|
from llama_stack.providers.utils.inference import supported_inference_models
|
||||||
|
@ -306,10 +309,11 @@ def chat_completion_request_to_messages(
|
||||||
elif model.model_family in (
|
elif model.model_family in (
|
||||||
ModelFamily.llama3_2,
|
ModelFamily.llama3_2,
|
||||||
ModelFamily.llama3_3,
|
ModelFamily.llama3_3,
|
||||||
ModelFamily.llama4,
|
|
||||||
):
|
):
|
||||||
# llama3.2, llama3.3 and llama4 models follow the same tool prompt format
|
# llama3.2, llama3.3 follow the same tool prompt format
|
||||||
messages = augment_messages_for_tools_llama_3_2(request)
|
messages = augment_messages_for_tools_llama(request, PythonListCustomToolGenerator)
|
||||||
|
elif model.model_family == ModelFamily.llama4:
|
||||||
|
messages = augment_messages_for_tools_llama(request, PythonListCustomToolGeneratorLlama4)
|
||||||
else:
|
else:
|
||||||
messages = request.messages
|
messages = request.messages
|
||||||
|
|
||||||
|
@ -399,8 +403,9 @@ def augment_messages_for_tools_llama_3_1(
|
||||||
return messages
|
return messages
|
||||||
|
|
||||||
|
|
||||||
def augment_messages_for_tools_llama_3_2(
|
def augment_messages_for_tools_llama(
|
||||||
request: ChatCompletionRequest,
|
request: ChatCompletionRequest,
|
||||||
|
custom_tool_prompt_generator,
|
||||||
) -> List[Message]:
|
) -> List[Message]:
|
||||||
existing_messages = request.messages
|
existing_messages = request.messages
|
||||||
existing_system_message = None
|
existing_system_message = None
|
||||||
|
@ -434,7 +439,7 @@ def augment_messages_for_tools_llama_3_2(
|
||||||
if existing_system_message and request.tool_config.system_message_behavior == SystemMessageBehavior.replace:
|
if existing_system_message and request.tool_config.system_message_behavior == SystemMessageBehavior.replace:
|
||||||
system_prompt = existing_system_message.content
|
system_prompt = existing_system_message.content
|
||||||
|
|
||||||
tool_template = PythonListCustomToolGenerator().gen(custom_tools, system_prompt)
|
tool_template = custom_tool_prompt_generator().gen(custom_tools, system_prompt)
|
||||||
|
|
||||||
sys_content += tool_template.render()
|
sys_content += tool_template.render()
|
||||||
sys_content += "\n"
|
sys_content += "\n"
|
||||||
|
|
|
@ -394,12 +394,10 @@
|
||||||
"aiosqlite",
|
"aiosqlite",
|
||||||
"blobfile",
|
"blobfile",
|
||||||
"chardet",
|
"chardet",
|
||||||
"emoji",
|
|
||||||
"faiss-cpu",
|
"faiss-cpu",
|
||||||
"fastapi",
|
"fastapi",
|
||||||
"fire",
|
"fire",
|
||||||
"httpx",
|
"httpx",
|
||||||
"langdetect",
|
|
||||||
"matplotlib",
|
"matplotlib",
|
||||||
"nltk",
|
"nltk",
|
||||||
"numpy",
|
"numpy",
|
||||||
|
@ -411,7 +409,6 @@
|
||||||
"psycopg2-binary",
|
"psycopg2-binary",
|
||||||
"pymongo",
|
"pymongo",
|
||||||
"pypdf",
|
"pypdf",
|
||||||
"pythainlp",
|
|
||||||
"redis",
|
"redis",
|
||||||
"requests",
|
"requests",
|
||||||
"scikit-learn",
|
"scikit-learn",
|
||||||
|
@ -419,7 +416,6 @@
|
||||||
"sentencepiece",
|
"sentencepiece",
|
||||||
"tqdm",
|
"tqdm",
|
||||||
"transformers",
|
"transformers",
|
||||||
"tree_sitter",
|
|
||||||
"uvicorn"
|
"uvicorn"
|
||||||
],
|
],
|
||||||
"ollama": [
|
"ollama": [
|
||||||
|
@ -759,5 +755,41 @@
|
||||||
"vllm",
|
"vllm",
|
||||||
"sentence-transformers --no-deps",
|
"sentence-transformers --no-deps",
|
||||||
"torch torchvision --index-url https://download.pytorch.org/whl/cpu"
|
"torch torchvision --index-url https://download.pytorch.org/whl/cpu"
|
||||||
|
],
|
||||||
|
"watsonx": [
|
||||||
|
"aiosqlite",
|
||||||
|
"autoevals",
|
||||||
|
"blobfile",
|
||||||
|
"chardet",
|
||||||
|
"datasets",
|
||||||
|
"emoji",
|
||||||
|
"faiss-cpu",
|
||||||
|
"fastapi",
|
||||||
|
"fire",
|
||||||
|
"httpx",
|
||||||
|
"ibm_watson_machine_learning",
|
||||||
|
"langdetect",
|
||||||
|
"matplotlib",
|
||||||
|
"mcp",
|
||||||
|
"nltk",
|
||||||
|
"numpy",
|
||||||
|
"openai",
|
||||||
|
"opentelemetry-exporter-otlp-proto-http",
|
||||||
|
"opentelemetry-sdk",
|
||||||
|
"pandas",
|
||||||
|
"pillow",
|
||||||
|
"psycopg2-binary",
|
||||||
|
"pymongo",
|
||||||
|
"pypdf",
|
||||||
|
"pythainlp",
|
||||||
|
"redis",
|
||||||
|
"requests",
|
||||||
|
"scikit-learn",
|
||||||
|
"scipy",
|
||||||
|
"sentencepiece",
|
||||||
|
"tqdm",
|
||||||
|
"transformers",
|
||||||
|
"tree_sitter",
|
||||||
|
"uvicorn"
|
||||||
]
|
]
|
||||||
}
|
}
|
||||||
|
|
|
@ -69,6 +69,7 @@ LLAMA_STACK_PORT=8321
|
||||||
docker run \
|
docker run \
|
||||||
-it \
|
-it \
|
||||||
--pull always \
|
--pull always \
|
||||||
|
--gpu all \
|
||||||
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
|
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
|
||||||
-v ~/.llama:/root/.llama \
|
-v ~/.llama:/root/.llama \
|
||||||
llamastack/distribution-{{ name }} \
|
llamastack/distribution-{{ name }} \
|
||||||
|
@ -82,6 +83,7 @@ If you are using Llama Stack Safety / Shield APIs, use:
|
||||||
docker run \
|
docker run \
|
||||||
-it \
|
-it \
|
||||||
--pull always \
|
--pull always \
|
||||||
|
--gpu all \
|
||||||
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
|
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
|
||||||
-v ~/.llama:/root/.llama \
|
-v ~/.llama:/root/.llama \
|
||||||
llamastack/distribution-{{ name }} \
|
llamastack/distribution-{{ name }} \
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
version: '2'
|
version: '2'
|
||||||
distribution_spec:
|
distribution_spec:
|
||||||
description: Use NVIDIA NIM for running LLM inference and safety
|
description: Use NVIDIA NIM for running LLM inference, evaluation and safety
|
||||||
providers:
|
providers:
|
||||||
inference:
|
inference:
|
||||||
- remote::nvidia
|
- remote::nvidia
|
||||||
|
@ -13,7 +13,7 @@ distribution_spec:
|
||||||
telemetry:
|
telemetry:
|
||||||
- inline::meta-reference
|
- inline::meta-reference
|
||||||
eval:
|
eval:
|
||||||
- inline::meta-reference
|
- remote::nvidia
|
||||||
post_training:
|
post_training:
|
||||||
- remote::nvidia
|
- remote::nvidia
|
||||||
datasetio:
|
datasetio:
|
||||||
|
|
|
@ -7,6 +7,7 @@
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
|
|
||||||
from llama_stack.distribution.datatypes import ModelInput, Provider, ShieldInput, ToolGroupInput
|
from llama_stack.distribution.datatypes import ModelInput, Provider, ShieldInput, ToolGroupInput
|
||||||
|
from llama_stack.providers.remote.eval.nvidia import NVIDIAEvalConfig
|
||||||
from llama_stack.providers.remote.inference.nvidia import NVIDIAConfig
|
from llama_stack.providers.remote.inference.nvidia import NVIDIAConfig
|
||||||
from llama_stack.providers.remote.inference.nvidia.models import MODEL_ENTRIES
|
from llama_stack.providers.remote.inference.nvidia.models import MODEL_ENTRIES
|
||||||
from llama_stack.providers.remote.safety.nvidia import NVIDIASafetyConfig
|
from llama_stack.providers.remote.safety.nvidia import NVIDIASafetyConfig
|
||||||
|
@ -20,7 +21,7 @@ def get_distribution_template() -> DistributionTemplate:
|
||||||
"safety": ["remote::nvidia"],
|
"safety": ["remote::nvidia"],
|
||||||
"agents": ["inline::meta-reference"],
|
"agents": ["inline::meta-reference"],
|
||||||
"telemetry": ["inline::meta-reference"],
|
"telemetry": ["inline::meta-reference"],
|
||||||
"eval": ["inline::meta-reference"],
|
"eval": ["remote::nvidia"],
|
||||||
"post_training": ["remote::nvidia"],
|
"post_training": ["remote::nvidia"],
|
||||||
"datasetio": ["inline::localfs"],
|
"datasetio": ["inline::localfs"],
|
||||||
"scoring": ["inline::basic"],
|
"scoring": ["inline::basic"],
|
||||||
|
@ -37,6 +38,11 @@ def get_distribution_template() -> DistributionTemplate:
|
||||||
provider_type="remote::nvidia",
|
provider_type="remote::nvidia",
|
||||||
config=NVIDIASafetyConfig.sample_run_config(),
|
config=NVIDIASafetyConfig.sample_run_config(),
|
||||||
)
|
)
|
||||||
|
eval_provider = Provider(
|
||||||
|
provider_id="nvidia",
|
||||||
|
provider_type="remote::nvidia",
|
||||||
|
config=NVIDIAEvalConfig.sample_run_config(),
|
||||||
|
)
|
||||||
inference_model = ModelInput(
|
inference_model = ModelInput(
|
||||||
model_id="${env.INFERENCE_MODEL}",
|
model_id="${env.INFERENCE_MODEL}",
|
||||||
provider_id="nvidia",
|
provider_id="nvidia",
|
||||||
|
@ -60,7 +66,7 @@ def get_distribution_template() -> DistributionTemplate:
|
||||||
return DistributionTemplate(
|
return DistributionTemplate(
|
||||||
name="nvidia",
|
name="nvidia",
|
||||||
distro_type="self_hosted",
|
distro_type="self_hosted",
|
||||||
description="Use NVIDIA NIM for running LLM inference and safety",
|
description="Use NVIDIA NIM for running LLM inference, evaluation and safety",
|
||||||
container_image=None,
|
container_image=None,
|
||||||
template_path=Path(__file__).parent / "doc_template.md",
|
template_path=Path(__file__).parent / "doc_template.md",
|
||||||
providers=providers,
|
providers=providers,
|
||||||
|
@ -69,6 +75,7 @@ def get_distribution_template() -> DistributionTemplate:
|
||||||
"run.yaml": RunConfigSettings(
|
"run.yaml": RunConfigSettings(
|
||||||
provider_overrides={
|
provider_overrides={
|
||||||
"inference": [inference_provider],
|
"inference": [inference_provider],
|
||||||
|
"eval": [eval_provider],
|
||||||
},
|
},
|
||||||
default_models=default_models,
|
default_models=default_models,
|
||||||
default_tool_groups=default_tool_groups,
|
default_tool_groups=default_tool_groups,
|
||||||
|
@ -78,7 +85,8 @@ def get_distribution_template() -> DistributionTemplate:
|
||||||
"inference": [
|
"inference": [
|
||||||
inference_provider,
|
inference_provider,
|
||||||
safety_provider,
|
safety_provider,
|
||||||
]
|
],
|
||||||
|
"eval": [eval_provider],
|
||||||
},
|
},
|
||||||
default_models=[inference_model, safety_model],
|
default_models=[inference_model, safety_model],
|
||||||
default_shields=[ShieldInput(shield_id="${env.SAFETY_MODEL}", provider_id="nvidia")],
|
default_shields=[ShieldInput(shield_id="${env.SAFETY_MODEL}", provider_id="nvidia")],
|
||||||
|
@ -90,19 +98,15 @@ def get_distribution_template() -> DistributionTemplate:
|
||||||
"",
|
"",
|
||||||
"NVIDIA API Key",
|
"NVIDIA API Key",
|
||||||
),
|
),
|
||||||
## Nemo Customizer related variables
|
"NVIDIA_APPEND_API_VERSION": (
|
||||||
"NVIDIA_USER_ID": (
|
"True",
|
||||||
"llama-stack-user",
|
"Whether to append the API version to the base_url",
|
||||||
"NVIDIA User ID",
|
|
||||||
),
|
),
|
||||||
|
## Nemo Customizer related variables
|
||||||
"NVIDIA_DATASET_NAMESPACE": (
|
"NVIDIA_DATASET_NAMESPACE": (
|
||||||
"default",
|
"default",
|
||||||
"NVIDIA Dataset Namespace",
|
"NVIDIA Dataset Namespace",
|
||||||
),
|
),
|
||||||
"NVIDIA_ACCESS_POLICIES": (
|
|
||||||
"{}",
|
|
||||||
"NVIDIA Access Policies",
|
|
||||||
),
|
|
||||||
"NVIDIA_PROJECT_ID": (
|
"NVIDIA_PROJECT_ID": (
|
||||||
"test-project",
|
"test-project",
|
||||||
"NVIDIA Project ID",
|
"NVIDIA Project ID",
|
||||||
|
@ -119,6 +123,10 @@ def get_distribution_template() -> DistributionTemplate:
|
||||||
"http://0.0.0.0:7331",
|
"http://0.0.0.0:7331",
|
||||||
"URL for the NeMo Guardrails Service",
|
"URL for the NeMo Guardrails Service",
|
||||||
),
|
),
|
||||||
|
"NVIDIA_EVALUATOR_URL": (
|
||||||
|
"http://0.0.0.0:7331",
|
||||||
|
"URL for the NeMo Evaluator Service",
|
||||||
|
),
|
||||||
"INFERENCE_MODEL": (
|
"INFERENCE_MODEL": (
|
||||||
"Llama3.1-8B-Instruct",
|
"Llama3.1-8B-Instruct",
|
||||||
"Inference model",
|
"Inference model",
|
||||||
|
|
|
@ -18,6 +18,7 @@ providers:
|
||||||
config:
|
config:
|
||||||
url: ${env.NVIDIA_BASE_URL:https://integrate.api.nvidia.com}
|
url: ${env.NVIDIA_BASE_URL:https://integrate.api.nvidia.com}
|
||||||
api_key: ${env.NVIDIA_API_KEY:}
|
api_key: ${env.NVIDIA_API_KEY:}
|
||||||
|
append_api_version: ${env.NVIDIA_APPEND_API_VERSION:True}
|
||||||
- provider_id: nvidia
|
- provider_id: nvidia
|
||||||
provider_type: remote::nvidia
|
provider_type: remote::nvidia
|
||||||
config:
|
config:
|
||||||
|
@ -53,13 +54,10 @@ providers:
|
||||||
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
|
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
|
||||||
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/nvidia/trace_store.db}
|
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/nvidia/trace_store.db}
|
||||||
eval:
|
eval:
|
||||||
- provider_id: meta-reference
|
- provider_id: nvidia
|
||||||
provider_type: inline::meta-reference
|
provider_type: remote::nvidia
|
||||||
config:
|
config:
|
||||||
kvstore:
|
evaluator_url: ${env.NVIDIA_EVALUATOR_URL:http://localhost:7331}
|
||||||
type: sqlite
|
|
||||||
namespace: null
|
|
||||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/nvidia}/meta_reference_eval.db
|
|
||||||
post_training:
|
post_training:
|
||||||
- provider_id: nvidia
|
- provider_id: nvidia
|
||||||
provider_type: remote::nvidia
|
provider_type: remote::nvidia
|
||||||
|
|
|
@ -18,6 +18,7 @@ providers:
|
||||||
config:
|
config:
|
||||||
url: ${env.NVIDIA_BASE_URL:https://integrate.api.nvidia.com}
|
url: ${env.NVIDIA_BASE_URL:https://integrate.api.nvidia.com}
|
||||||
api_key: ${env.NVIDIA_API_KEY:}
|
api_key: ${env.NVIDIA_API_KEY:}
|
||||||
|
append_api_version: ${env.NVIDIA_APPEND_API_VERSION:True}
|
||||||
vector_io:
|
vector_io:
|
||||||
- provider_id: faiss
|
- provider_id: faiss
|
||||||
provider_type: inline::faiss
|
provider_type: inline::faiss
|
||||||
|
@ -48,13 +49,10 @@ providers:
|
||||||
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
|
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
|
||||||
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/nvidia/trace_store.db}
|
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/nvidia/trace_store.db}
|
||||||
eval:
|
eval:
|
||||||
- provider_id: meta-reference
|
- provider_id: nvidia
|
||||||
provider_type: inline::meta-reference
|
provider_type: remote::nvidia
|
||||||
config:
|
config:
|
||||||
kvstore:
|
evaluator_url: ${env.NVIDIA_EVALUATOR_URL:http://localhost:7331}
|
||||||
type: sqlite
|
|
||||||
namespace: null
|
|
||||||
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/nvidia}/meta_reference_eval.db
|
|
||||||
post_training:
|
post_training:
|
||||||
- provider_id: nvidia
|
- provider_id: nvidia
|
||||||
provider_type: remote::nvidia
|
provider_type: remote::nvidia
|
||||||
|
|
7
llama_stack/templates/watsonx/__init__.py
Normal file
7
llama_stack/templates/watsonx/__init__.py
Normal file
|
@ -0,0 +1,7 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
from .watsonx import get_distribution_template # noqa: F401
|
30
llama_stack/templates/watsonx/build.yaml
Normal file
30
llama_stack/templates/watsonx/build.yaml
Normal file
|
@ -0,0 +1,30 @@
|
||||||
|
version: '2'
|
||||||
|
distribution_spec:
|
||||||
|
description: Use watsonx for running LLM inference
|
||||||
|
providers:
|
||||||
|
inference:
|
||||||
|
- remote::watsonx
|
||||||
|
vector_io:
|
||||||
|
- inline::faiss
|
||||||
|
safety:
|
||||||
|
- inline::llama-guard
|
||||||
|
agents:
|
||||||
|
- inline::meta-reference
|
||||||
|
telemetry:
|
||||||
|
- inline::meta-reference
|
||||||
|
eval:
|
||||||
|
- inline::meta-reference
|
||||||
|
datasetio:
|
||||||
|
- remote::huggingface
|
||||||
|
- inline::localfs
|
||||||
|
scoring:
|
||||||
|
- inline::basic
|
||||||
|
- inline::llm-as-judge
|
||||||
|
- inline::braintrust
|
||||||
|
tool_runtime:
|
||||||
|
- remote::brave-search
|
||||||
|
- remote::tavily-search
|
||||||
|
- inline::code-interpreter
|
||||||
|
- inline::rag-runtime
|
||||||
|
- remote::model-context-protocol
|
||||||
|
image_type: conda
|
74
llama_stack/templates/watsonx/doc_template.md
Normal file
74
llama_stack/templates/watsonx/doc_template.md
Normal file
|
@ -0,0 +1,74 @@
|
||||||
|
---
|
||||||
|
orphan: true
|
||||||
|
---
|
||||||
|
# watsonx Distribution
|
||||||
|
|
||||||
|
```{toctree}
|
||||||
|
:maxdepth: 2
|
||||||
|
:hidden:
|
||||||
|
|
||||||
|
self
|
||||||
|
```
|
||||||
|
|
||||||
|
The `llamastack/distribution-{{ name }}` distribution consists of the following provider configurations.
|
||||||
|
|
||||||
|
{{ providers_table }}
|
||||||
|
|
||||||
|
{% if run_config_env_vars %}
|
||||||
|
|
||||||
|
### Environment Variables
|
||||||
|
|
||||||
|
The following environment variables can be configured:
|
||||||
|
|
||||||
|
{% for var, (default_value, description) in run_config_env_vars.items() %}
|
||||||
|
- `{{ var }}`: {{ description }} (default: `{{ default_value }}`)
|
||||||
|
{% endfor %}
|
||||||
|
{% endif %}
|
||||||
|
|
||||||
|
{% if default_models %}
|
||||||
|
### Models
|
||||||
|
|
||||||
|
The following models are available by default:
|
||||||
|
|
||||||
|
{% for model in default_models %}
|
||||||
|
- `{{ model.model_id }} {{ model.doc_string }}`
|
||||||
|
{% endfor %}
|
||||||
|
{% endif %}
|
||||||
|
|
||||||
|
|
||||||
|
### Prerequisite: API Keys
|
||||||
|
|
||||||
|
Make sure you have access to a watsonx API Key. You can get one by referring [watsonx.ai](https://www.ibm.com/docs/en/masv-and-l/maximo-manage/continuous-delivery?topic=setup-create-watsonx-api-key).
|
||||||
|
|
||||||
|
|
||||||
|
## Running Llama Stack with watsonx
|
||||||
|
|
||||||
|
You can do this via Conda (build code), venv or Docker which has a pre-built image.
|
||||||
|
|
||||||
|
### Via Docker
|
||||||
|
|
||||||
|
This method allows you to get started quickly without having to build the distribution code.
|
||||||
|
|
||||||
|
```bash
|
||||||
|
LLAMA_STACK_PORT=5001
|
||||||
|
docker run \
|
||||||
|
-it \
|
||||||
|
-p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
|
||||||
|
-v ./run.yaml:/root/my-run.yaml \
|
||||||
|
llamastack/distribution-{{ name }} \
|
||||||
|
--yaml-config /root/my-run.yaml \
|
||||||
|
--port $LLAMA_STACK_PORT \
|
||||||
|
--env WATSONX_API_KEY=$WATSONX_API_KEY \
|
||||||
|
--env WATSONX_PROJECT_ID=$WATSONX_PROJECT_ID \
|
||||||
|
--env WATSONX_BASE_URL=$WATSONX_BASE_URL
|
||||||
|
```
|
||||||
|
|
||||||
|
### Via Conda
|
||||||
|
|
||||||
|
```bash
|
||||||
|
llama stack build --template watsonx --image-type conda
|
||||||
|
llama stack run ./run.yaml \
|
||||||
|
--port $LLAMA_STACK_PORT \
|
||||||
|
--env WATSONX_API_KEY=$WATSONX_API_KEY \
|
||||||
|
--env WATSONX_PROJECT_ID=$WATSONX_PROJECT_ID
|
||||||
|
```
|
210
llama_stack/templates/watsonx/run.yaml
Normal file
210
llama_stack/templates/watsonx/run.yaml
Normal file
|
@ -0,0 +1,210 @@
|
||||||
|
version: '2'
|
||||||
|
image_name: watsonx
|
||||||
|
apis:
|
||||||
|
- agents
|
||||||
|
- datasetio
|
||||||
|
- eval
|
||||||
|
- inference
|
||||||
|
- safety
|
||||||
|
- scoring
|
||||||
|
- telemetry
|
||||||
|
- tool_runtime
|
||||||
|
- vector_io
|
||||||
|
providers:
|
||||||
|
inference:
|
||||||
|
- provider_id: watsonx
|
||||||
|
provider_type: remote::watsonx
|
||||||
|
config:
|
||||||
|
url: ${env.WATSONX_BASE_URL:https://us-south.ml.cloud.ibm.com}
|
||||||
|
api_key: ${env.WATSONX_API_KEY:}
|
||||||
|
project_id: ${env.WATSONX_PROJECT_ID:}
|
||||||
|
vector_io:
|
||||||
|
- provider_id: faiss
|
||||||
|
provider_type: inline::faiss
|
||||||
|
config:
|
||||||
|
kvstore:
|
||||||
|
type: sqlite
|
||||||
|
namespace: null
|
||||||
|
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/watsonx}/faiss_store.db
|
||||||
|
safety:
|
||||||
|
- provider_id: llama-guard
|
||||||
|
provider_type: inline::llama-guard
|
||||||
|
config:
|
||||||
|
excluded_categories: []
|
||||||
|
agents:
|
||||||
|
- provider_id: meta-reference
|
||||||
|
provider_type: inline::meta-reference
|
||||||
|
config:
|
||||||
|
persistence_store:
|
||||||
|
type: sqlite
|
||||||
|
namespace: null
|
||||||
|
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/watsonx}/agents_store.db
|
||||||
|
telemetry:
|
||||||
|
- provider_id: meta-reference
|
||||||
|
provider_type: inline::meta-reference
|
||||||
|
config:
|
||||||
|
service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
|
||||||
|
sinks: ${env.TELEMETRY_SINKS:console,sqlite}
|
||||||
|
sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/watsonx/trace_store.db}
|
||||||
|
eval:
|
||||||
|
- provider_id: meta-reference
|
||||||
|
provider_type: inline::meta-reference
|
||||||
|
config:
|
||||||
|
kvstore:
|
||||||
|
type: sqlite
|
||||||
|
namespace: null
|
||||||
|
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/watsonx}/meta_reference_eval.db
|
||||||
|
datasetio:
|
||||||
|
- provider_id: huggingface
|
||||||
|
provider_type: remote::huggingface
|
||||||
|
config:
|
||||||
|
kvstore:
|
||||||
|
type: sqlite
|
||||||
|
namespace: null
|
||||||
|
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/watsonx}/huggingface_datasetio.db
|
||||||
|
- provider_id: localfs
|
||||||
|
provider_type: inline::localfs
|
||||||
|
config:
|
||||||
|
kvstore:
|
||||||
|
type: sqlite
|
||||||
|
namespace: null
|
||||||
|
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/watsonx}/localfs_datasetio.db
|
||||||
|
scoring:
|
||||||
|
- provider_id: basic
|
||||||
|
provider_type: inline::basic
|
||||||
|
config: {}
|
||||||
|
- provider_id: llm-as-judge
|
||||||
|
provider_type: inline::llm-as-judge
|
||||||
|
config: {}
|
||||||
|
- provider_id: braintrust
|
||||||
|
provider_type: inline::braintrust
|
||||||
|
config:
|
||||||
|
openai_api_key: ${env.OPENAI_API_KEY:}
|
||||||
|
tool_runtime:
|
||||||
|
- provider_id: brave-search
|
||||||
|
provider_type: remote::brave-search
|
||||||
|
config:
|
||||||
|
api_key: ${env.BRAVE_SEARCH_API_KEY:}
|
||||||
|
max_results: 3
|
||||||
|
- provider_id: tavily-search
|
||||||
|
provider_type: remote::tavily-search
|
||||||
|
config:
|
||||||
|
api_key: ${env.TAVILY_SEARCH_API_KEY:}
|
||||||
|
max_results: 3
|
||||||
|
- provider_id: code-interpreter
|
||||||
|
provider_type: inline::code-interpreter
|
||||||
|
config: {}
|
||||||
|
- provider_id: rag-runtime
|
||||||
|
provider_type: inline::rag-runtime
|
||||||
|
config: {}
|
||||||
|
- provider_id: model-context-protocol
|
||||||
|
provider_type: remote::model-context-protocol
|
||||||
|
config: {}
|
||||||
|
metadata_store:
|
||||||
|
type: sqlite
|
||||||
|
db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/watsonx}/registry.db
|
||||||
|
models:
|
||||||
|
- metadata: {}
|
||||||
|
model_id: meta-llama/llama-3-3-70b-instruct
|
||||||
|
provider_id: watsonx
|
||||||
|
provider_model_id: meta-llama/llama-3-3-70b-instruct
|
||||||
|
model_type: llm
|
||||||
|
- metadata: {}
|
||||||
|
model_id: meta-llama/Llama-3.3-70B-Instruct
|
||||||
|
provider_id: watsonx
|
||||||
|
provider_model_id: meta-llama/llama-3-3-70b-instruct
|
||||||
|
model_type: llm
|
||||||
|
- metadata: {}
|
||||||
|
model_id: meta-llama/llama-2-13b-chat
|
||||||
|
provider_id: watsonx
|
||||||
|
provider_model_id: meta-llama/llama-2-13b-chat
|
||||||
|
model_type: llm
|
||||||
|
- metadata: {}
|
||||||
|
model_id: meta-llama/Llama-2-13b
|
||||||
|
provider_id: watsonx
|
||||||
|
provider_model_id: meta-llama/llama-2-13b-chat
|
||||||
|
model_type: llm
|
||||||
|
- metadata: {}
|
||||||
|
model_id: meta-llama/llama-3-1-70b-instruct
|
||||||
|
provider_id: watsonx
|
||||||
|
provider_model_id: meta-llama/llama-3-1-70b-instruct
|
||||||
|
model_type: llm
|
||||||
|
- metadata: {}
|
||||||
|
model_id: meta-llama/Llama-3.1-70B-Instruct
|
||||||
|
provider_id: watsonx
|
||||||
|
provider_model_id: meta-llama/llama-3-1-70b-instruct
|
||||||
|
model_type: llm
|
||||||
|
- metadata: {}
|
||||||
|
model_id: meta-llama/llama-3-1-8b-instruct
|
||||||
|
provider_id: watsonx
|
||||||
|
provider_model_id: meta-llama/llama-3-1-8b-instruct
|
||||||
|
model_type: llm
|
||||||
|
- metadata: {}
|
||||||
|
model_id: meta-llama/Llama-3.1-8B-Instruct
|
||||||
|
provider_id: watsonx
|
||||||
|
provider_model_id: meta-llama/llama-3-1-8b-instruct
|
||||||
|
model_type: llm
|
||||||
|
- metadata: {}
|
||||||
|
model_id: meta-llama/llama-3-2-11b-vision-instruct
|
||||||
|
provider_id: watsonx
|
||||||
|
provider_model_id: meta-llama/llama-3-2-11b-vision-instruct
|
||||||
|
model_type: llm
|
||||||
|
- metadata: {}
|
||||||
|
model_id: meta-llama/Llama-3.2-11B-Vision-Instruct
|
||||||
|
provider_id: watsonx
|
||||||
|
provider_model_id: meta-llama/llama-3-2-11b-vision-instruct
|
||||||
|
model_type: llm
|
||||||
|
- metadata: {}
|
||||||
|
model_id: meta-llama/llama-3-2-1b-instruct
|
||||||
|
provider_id: watsonx
|
||||||
|
provider_model_id: meta-llama/llama-3-2-1b-instruct
|
||||||
|
model_type: llm
|
||||||
|
- metadata: {}
|
||||||
|
model_id: meta-llama/Llama-3.2-1B-Instruct
|
||||||
|
provider_id: watsonx
|
||||||
|
provider_model_id: meta-llama/llama-3-2-1b-instruct
|
||||||
|
model_type: llm
|
||||||
|
- metadata: {}
|
||||||
|
model_id: meta-llama/llama-3-2-3b-instruct
|
||||||
|
provider_id: watsonx
|
||||||
|
provider_model_id: meta-llama/llama-3-2-3b-instruct
|
||||||
|
model_type: llm
|
||||||
|
- metadata: {}
|
||||||
|
model_id: meta-llama/Llama-3.2-3B-Instruct
|
||||||
|
provider_id: watsonx
|
||||||
|
provider_model_id: meta-llama/llama-3-2-3b-instruct
|
||||||
|
model_type: llm
|
||||||
|
- metadata: {}
|
||||||
|
model_id: meta-llama/llama-3-2-90b-vision-instruct
|
||||||
|
provider_id: watsonx
|
||||||
|
provider_model_id: meta-llama/llama-3-2-90b-vision-instruct
|
||||||
|
model_type: llm
|
||||||
|
- metadata: {}
|
||||||
|
model_id: meta-llama/Llama-3.2-90B-Vision-Instruct
|
||||||
|
provider_id: watsonx
|
||||||
|
provider_model_id: meta-llama/llama-3-2-90b-vision-instruct
|
||||||
|
model_type: llm
|
||||||
|
- metadata: {}
|
||||||
|
model_id: meta-llama/llama-guard-3-11b-vision
|
||||||
|
provider_id: watsonx
|
||||||
|
provider_model_id: meta-llama/llama-guard-3-11b-vision
|
||||||
|
model_type: llm
|
||||||
|
- metadata: {}
|
||||||
|
model_id: meta-llama/Llama-Guard-3-11B-Vision
|
||||||
|
provider_id: watsonx
|
||||||
|
provider_model_id: meta-llama/llama-guard-3-11b-vision
|
||||||
|
model_type: llm
|
||||||
|
shields: []
|
||||||
|
vector_dbs: []
|
||||||
|
datasets: []
|
||||||
|
scoring_fns: []
|
||||||
|
benchmarks: []
|
||||||
|
tool_groups:
|
||||||
|
- toolgroup_id: builtin::websearch
|
||||||
|
provider_id: tavily-search
|
||||||
|
- toolgroup_id: builtin::rag
|
||||||
|
provider_id: rag-runtime
|
||||||
|
- toolgroup_id: builtin::code_interpreter
|
||||||
|
provider_id: code-interpreter
|
||||||
|
server:
|
||||||
|
port: 8321
|
90
llama_stack/templates/watsonx/watsonx.py
Normal file
90
llama_stack/templates/watsonx/watsonx.py
Normal file
|
@ -0,0 +1,90 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
from pathlib import Path
|
||||||
|
|
||||||
|
from llama_stack.distribution.datatypes import Provider, ToolGroupInput
|
||||||
|
from llama_stack.providers.remote.inference.watsonx import WatsonXConfig
|
||||||
|
from llama_stack.providers.remote.inference.watsonx.models import MODEL_ENTRIES
|
||||||
|
from llama_stack.templates.template import DistributionTemplate, RunConfigSettings, get_model_registry
|
||||||
|
|
||||||
|
|
||||||
|
def get_distribution_template() -> DistributionTemplate:
|
||||||
|
providers = {
|
||||||
|
"inference": ["remote::watsonx"],
|
||||||
|
"vector_io": ["inline::faiss"],
|
||||||
|
"safety": ["inline::llama-guard"],
|
||||||
|
"agents": ["inline::meta-reference"],
|
||||||
|
"telemetry": ["inline::meta-reference"],
|
||||||
|
"eval": ["inline::meta-reference"],
|
||||||
|
"datasetio": ["remote::huggingface", "inline::localfs"],
|
||||||
|
"scoring": ["inline::basic", "inline::llm-as-judge", "inline::braintrust"],
|
||||||
|
"tool_runtime": [
|
||||||
|
"remote::brave-search",
|
||||||
|
"remote::tavily-search",
|
||||||
|
"inline::code-interpreter",
|
||||||
|
"inline::rag-runtime",
|
||||||
|
"remote::model-context-protocol",
|
||||||
|
],
|
||||||
|
}
|
||||||
|
|
||||||
|
inference_provider = Provider(
|
||||||
|
provider_id="watsonx",
|
||||||
|
provider_type="remote::watsonx",
|
||||||
|
config=WatsonXConfig.sample_run_config(),
|
||||||
|
)
|
||||||
|
|
||||||
|
available_models = {
|
||||||
|
"watsonx": MODEL_ENTRIES,
|
||||||
|
}
|
||||||
|
default_tool_groups = [
|
||||||
|
ToolGroupInput(
|
||||||
|
toolgroup_id="builtin::websearch",
|
||||||
|
provider_id="tavily-search",
|
||||||
|
),
|
||||||
|
ToolGroupInput(
|
||||||
|
toolgroup_id="builtin::rag",
|
||||||
|
provider_id="rag-runtime",
|
||||||
|
),
|
||||||
|
ToolGroupInput(
|
||||||
|
toolgroup_id="builtin::code_interpreter",
|
||||||
|
provider_id="code-interpreter",
|
||||||
|
),
|
||||||
|
]
|
||||||
|
|
||||||
|
default_models = get_model_registry(available_models)
|
||||||
|
return DistributionTemplate(
|
||||||
|
name="watsonx",
|
||||||
|
distro_type="remote_hosted",
|
||||||
|
description="Use watsonx for running LLM inference",
|
||||||
|
container_image=None,
|
||||||
|
template_path=Path(__file__).parent / "doc_template.md",
|
||||||
|
providers=providers,
|
||||||
|
available_models_by_provider=available_models,
|
||||||
|
run_configs={
|
||||||
|
"run.yaml": RunConfigSettings(
|
||||||
|
provider_overrides={
|
||||||
|
"inference": [inference_provider],
|
||||||
|
},
|
||||||
|
default_models=default_models,
|
||||||
|
default_tool_groups=default_tool_groups,
|
||||||
|
),
|
||||||
|
},
|
||||||
|
run_config_env_vars={
|
||||||
|
"LLAMASTACK_PORT": (
|
||||||
|
"5001",
|
||||||
|
"Port for the Llama Stack distribution server",
|
||||||
|
),
|
||||||
|
"WATSONX_API_KEY": (
|
||||||
|
"",
|
||||||
|
"watsonx API Key",
|
||||||
|
),
|
||||||
|
"WATSONX_PROJECT_ID": (
|
||||||
|
"",
|
||||||
|
"watsonx Project ID",
|
||||||
|
),
|
||||||
|
},
|
||||||
|
)
|
|
@ -58,7 +58,16 @@ dev = [
|
||||||
"ruamel.yaml", # needed for openapi generator
|
"ruamel.yaml", # needed for openapi generator
|
||||||
]
|
]
|
||||||
# These are the dependencies required for running unit tests.
|
# These are the dependencies required for running unit tests.
|
||||||
unit = ["sqlite-vec", "openai", "aiosqlite", "aiohttp", "pypdf", "chardet", "qdrant-client"]
|
unit = [
|
||||||
|
"sqlite-vec",
|
||||||
|
"openai",
|
||||||
|
"aiosqlite",
|
||||||
|
"aiohttp",
|
||||||
|
"pypdf",
|
||||||
|
"chardet",
|
||||||
|
"qdrant-client",
|
||||||
|
"opentelemetry-exporter-otlp-proto-http"
|
||||||
|
]
|
||||||
# These are the core dependencies required for running integration tests. They are shared across all
|
# These are the core dependencies required for running integration tests. They are shared across all
|
||||||
# providers. If a provider requires additional dependencies, please add them to your environment
|
# providers. If a provider requires additional dependencies, please add them to your environment
|
||||||
# separately. If you are using "uv" to execute your tests, you can use the "--with" flag to specify extra
|
# separately. If you are using "uv" to execute your tests, you can use the "--with" flag to specify extra
|
||||||
|
@ -265,6 +274,7 @@ exclude = [
|
||||||
"^llama_stack/providers/remote/inference/sample/",
|
"^llama_stack/providers/remote/inference/sample/",
|
||||||
"^llama_stack/providers/remote/inference/tgi/",
|
"^llama_stack/providers/remote/inference/tgi/",
|
||||||
"^llama_stack/providers/remote/inference/together/",
|
"^llama_stack/providers/remote/inference/together/",
|
||||||
|
"^llama_stack/providers/remote/inference/watsonx/",
|
||||||
"^llama_stack/providers/remote/safety/bedrock/",
|
"^llama_stack/providers/remote/safety/bedrock/",
|
||||||
"^llama_stack/providers/remote/safety/nvidia/",
|
"^llama_stack/providers/remote/safety/nvidia/",
|
||||||
"^llama_stack/providers/remote/safety/sample/",
|
"^llama_stack/providers/remote/safety/sample/",
|
||||||
|
|
|
@ -10,6 +10,7 @@ import platform
|
||||||
import textwrap
|
import textwrap
|
||||||
import time
|
import time
|
||||||
|
|
||||||
|
import pytest
|
||||||
from dotenv import load_dotenv
|
from dotenv import load_dotenv
|
||||||
|
|
||||||
from llama_stack.log import get_logger
|
from llama_stack.log import get_logger
|
||||||
|
@ -19,7 +20,26 @@ from .report import Report
|
||||||
logger = get_logger(__name__, category="tests")
|
logger = get_logger(__name__, category="tests")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.hookimpl(hookwrapper=True)
|
||||||
|
def pytest_runtest_makereport(item, call):
|
||||||
|
outcome = yield
|
||||||
|
report = outcome.get_result()
|
||||||
|
if report.when == "call":
|
||||||
|
item.execution_outcome = report.outcome
|
||||||
|
item.was_xfail = getattr(report, "wasxfail", False)
|
||||||
|
|
||||||
|
|
||||||
def pytest_runtest_teardown(item):
|
def pytest_runtest_teardown(item):
|
||||||
|
# Check if the test actually ran and passed or failed, but was not skipped or an expected failure (xfail)
|
||||||
|
outcome = getattr(item, "execution_outcome", None)
|
||||||
|
was_xfail = getattr(item, "was_xfail", False)
|
||||||
|
|
||||||
|
name = item.nodeid
|
||||||
|
if not any(x in name for x in ("inference/", "safety/", "agents/")):
|
||||||
|
return
|
||||||
|
|
||||||
|
logger.debug(f"Test '{item.nodeid}' outcome was '{outcome}' (xfail={was_xfail})")
|
||||||
|
if outcome in ("passed", "failed") and not was_xfail:
|
||||||
interval_seconds = os.getenv("LLAMA_STACK_TEST_INTERVAL_SECONDS")
|
interval_seconds = os.getenv("LLAMA_STACK_TEST_INTERVAL_SECONDS")
|
||||||
if interval_seconds:
|
if interval_seconds:
|
||||||
time.sleep(float(interval_seconds))
|
time.sleep(float(interval_seconds))
|
||||||
|
|
|
@ -75,19 +75,24 @@ def openai_client(client_with_models):
|
||||||
return OpenAI(base_url=base_url, api_key="bar")
|
return OpenAI(base_url=base_url, api_key="bar")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.fixture(params=["openai_client", "llama_stack_client"])
|
||||||
|
def compat_client(request):
|
||||||
|
return request.getfixturevalue(request.param)
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"test_case",
|
"test_case",
|
||||||
[
|
[
|
||||||
"inference:completion:sanity",
|
"inference:completion:sanity",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
def test_openai_completion_non_streaming(openai_client, client_with_models, text_model_id, test_case):
|
def test_openai_completion_non_streaming(llama_stack_client, client_with_models, text_model_id, test_case):
|
||||||
skip_if_model_doesnt_support_openai_completion(client_with_models, text_model_id)
|
skip_if_model_doesnt_support_openai_completion(client_with_models, text_model_id)
|
||||||
tc = TestCase(test_case)
|
tc = TestCase(test_case)
|
||||||
|
|
||||||
# ollama needs more verbose prompting for some reason here...
|
# ollama needs more verbose prompting for some reason here...
|
||||||
prompt = "Respond to this question and explain your answer. " + tc["content"]
|
prompt = "Respond to this question and explain your answer. " + tc["content"]
|
||||||
response = openai_client.completions.create(
|
response = llama_stack_client.completions.create(
|
||||||
model=text_model_id,
|
model=text_model_id,
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
stream=False,
|
stream=False,
|
||||||
|
@ -103,13 +108,13 @@ def test_openai_completion_non_streaming(openai_client, client_with_models, text
|
||||||
"inference:completion:sanity",
|
"inference:completion:sanity",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
def test_openai_completion_streaming(openai_client, client_with_models, text_model_id, test_case):
|
def test_openai_completion_streaming(llama_stack_client, client_with_models, text_model_id, test_case):
|
||||||
skip_if_model_doesnt_support_openai_completion(client_with_models, text_model_id)
|
skip_if_model_doesnt_support_openai_completion(client_with_models, text_model_id)
|
||||||
tc = TestCase(test_case)
|
tc = TestCase(test_case)
|
||||||
|
|
||||||
# ollama needs more verbose prompting for some reason here...
|
# ollama needs more verbose prompting for some reason here...
|
||||||
prompt = "Respond to this question and explain your answer. " + tc["content"]
|
prompt = "Respond to this question and explain your answer. " + tc["content"]
|
||||||
response = openai_client.completions.create(
|
response = llama_stack_client.completions.create(
|
||||||
model=text_model_id,
|
model=text_model_id,
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
stream=True,
|
stream=True,
|
||||||
|
@ -127,11 +132,11 @@ def test_openai_completion_streaming(openai_client, client_with_models, text_mod
|
||||||
0,
|
0,
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
def test_openai_completion_prompt_logprobs(openai_client, client_with_models, text_model_id, prompt_logprobs):
|
def test_openai_completion_prompt_logprobs(llama_stack_client, client_with_models, text_model_id, prompt_logprobs):
|
||||||
skip_if_provider_isnt_vllm(client_with_models, text_model_id)
|
skip_if_provider_isnt_vllm(client_with_models, text_model_id)
|
||||||
|
|
||||||
prompt = "Hello, world!"
|
prompt = "Hello, world!"
|
||||||
response = openai_client.completions.create(
|
response = llama_stack_client.completions.create(
|
||||||
model=text_model_id,
|
model=text_model_id,
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
stream=False,
|
stream=False,
|
||||||
|
@ -144,11 +149,11 @@ def test_openai_completion_prompt_logprobs(openai_client, client_with_models, te
|
||||||
assert len(choice.prompt_logprobs) > 0
|
assert len(choice.prompt_logprobs) > 0
|
||||||
|
|
||||||
|
|
||||||
def test_openai_completion_guided_choice(openai_client, client_with_models, text_model_id):
|
def test_openai_completion_guided_choice(llama_stack_client, client_with_models, text_model_id):
|
||||||
skip_if_provider_isnt_vllm(client_with_models, text_model_id)
|
skip_if_provider_isnt_vllm(client_with_models, text_model_id)
|
||||||
|
|
||||||
prompt = "I am feeling really sad today."
|
prompt = "I am feeling really sad today."
|
||||||
response = openai_client.completions.create(
|
response = llama_stack_client.completions.create(
|
||||||
model=text_model_id,
|
model=text_model_id,
|
||||||
prompt=prompt,
|
prompt=prompt,
|
||||||
stream=False,
|
stream=False,
|
||||||
|
@ -161,6 +166,9 @@ def test_openai_completion_guided_choice(openai_client, client_with_models, text
|
||||||
assert choice.text in ["joy", "sadness"]
|
assert choice.text in ["joy", "sadness"]
|
||||||
|
|
||||||
|
|
||||||
|
# Run the chat-completion tests with both the OpenAI client and the LlamaStack client
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"test_case",
|
"test_case",
|
||||||
[
|
[
|
||||||
|
@ -168,13 +176,13 @@ def test_openai_completion_guided_choice(openai_client, client_with_models, text
|
||||||
"inference:chat_completion:non_streaming_02",
|
"inference:chat_completion:non_streaming_02",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
def test_openai_chat_completion_non_streaming(openai_client, client_with_models, text_model_id, test_case):
|
def test_openai_chat_completion_non_streaming(compat_client, client_with_models, text_model_id, test_case):
|
||||||
skip_if_model_doesnt_support_openai_chat_completion(client_with_models, text_model_id)
|
skip_if_model_doesnt_support_openai_chat_completion(client_with_models, text_model_id)
|
||||||
tc = TestCase(test_case)
|
tc = TestCase(test_case)
|
||||||
question = tc["question"]
|
question = tc["question"]
|
||||||
expected = tc["expected"]
|
expected = tc["expected"]
|
||||||
|
|
||||||
response = openai_client.chat.completions.create(
|
response = compat_client.chat.completions.create(
|
||||||
model=text_model_id,
|
model=text_model_id,
|
||||||
messages=[
|
messages=[
|
||||||
{
|
{
|
||||||
|
@ -196,13 +204,13 @@ def test_openai_chat_completion_non_streaming(openai_client, client_with_models,
|
||||||
"inference:chat_completion:streaming_02",
|
"inference:chat_completion:streaming_02",
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
def test_openai_chat_completion_streaming(openai_client, client_with_models, text_model_id, test_case):
|
def test_openai_chat_completion_streaming(compat_client, client_with_models, text_model_id, test_case):
|
||||||
skip_if_model_doesnt_support_openai_chat_completion(client_with_models, text_model_id)
|
skip_if_model_doesnt_support_openai_chat_completion(client_with_models, text_model_id)
|
||||||
tc = TestCase(test_case)
|
tc = TestCase(test_case)
|
||||||
question = tc["question"]
|
question = tc["question"]
|
||||||
expected = tc["expected"]
|
expected = tc["expected"]
|
||||||
|
|
||||||
response = openai_client.chat.completions.create(
|
response = compat_client.chat.completions.create(
|
||||||
model=text_model_id,
|
model=text_model_id,
|
||||||
messages=[{"role": "user", "content": question}],
|
messages=[{"role": "user", "content": question}],
|
||||||
stream=True,
|
stream=True,
|
||||||
|
|
|
@ -114,7 +114,7 @@ def test_register_and_unregister_toolgroup(llama_stack_client, mcp_server):
|
||||||
llama_stack_client.toolgroups.unregister(toolgroup_id=test_toolgroup_id)
|
llama_stack_client.toolgroups.unregister(toolgroup_id=test_toolgroup_id)
|
||||||
|
|
||||||
# Verify it is unregistered
|
# Verify it is unregistered
|
||||||
with pytest.raises(ValueError, match=f"Tool group '{test_toolgroup_id}' not found"):
|
with pytest.raises(Exception, match=f"Tool group '{test_toolgroup_id}' not found"):
|
||||||
llama_stack_client.toolgroups.get(toolgroup_id=test_toolgroup_id)
|
llama_stack_client.toolgroups.get(toolgroup_id=test_toolgroup_id)
|
||||||
|
|
||||||
# Verify tools are also unregistered
|
# Verify tools are also unregistered
|
||||||
|
|
|
@ -16,8 +16,9 @@ from llama_stack.distribution.utils.image_types import LlamaStackImageType
|
||||||
def test_container_build_passes_path(monkeypatch, tmp_path):
|
def test_container_build_passes_path(monkeypatch, tmp_path):
|
||||||
called_with = {}
|
called_with = {}
|
||||||
|
|
||||||
def spy_build_image(cfg, build_file_path, image_name, template_or_config):
|
def spy_build_image(cfg, build_file_path, image_name, template_or_config, run_config=None):
|
||||||
called_with["path"] = template_or_config
|
called_with["path"] = template_or_config
|
||||||
|
called_with["run_config"] = run_config
|
||||||
return 0
|
return 0
|
||||||
|
|
||||||
monkeypatch.setattr(
|
monkeypatch.setattr(
|
||||||
|
@ -36,3 +37,4 @@ def test_container_build_passes_path(monkeypatch, tmp_path):
|
||||||
assert "path" in called_with
|
assert "path" in called_with
|
||||||
assert isinstance(called_with["path"], str)
|
assert isinstance(called_with["path"], str)
|
||||||
assert Path(called_with["path"]).exists()
|
assert Path(called_with["path"]).exists()
|
||||||
|
assert called_with["run_config"] is None
|
||||||
|
|
|
@ -28,12 +28,15 @@ from openai.types.model import Model as OpenAIModel
|
||||||
|
|
||||||
from llama_stack.apis.inference import (
|
from llama_stack.apis.inference import (
|
||||||
ChatCompletionRequest,
|
ChatCompletionRequest,
|
||||||
|
CompletionMessage,
|
||||||
|
SystemMessage,
|
||||||
ToolChoice,
|
ToolChoice,
|
||||||
ToolConfig,
|
ToolConfig,
|
||||||
|
ToolResponseMessage,
|
||||||
UserMessage,
|
UserMessage,
|
||||||
)
|
)
|
||||||
from llama_stack.apis.models import Model
|
from llama_stack.apis.models import Model
|
||||||
from llama_stack.models.llama.datatypes import StopReason
|
from llama_stack.models.llama.datatypes import StopReason, ToolCall
|
||||||
from llama_stack.providers.remote.inference.vllm.config import VLLMInferenceAdapterConfig
|
from llama_stack.providers.remote.inference.vllm.config import VLLMInferenceAdapterConfig
|
||||||
from llama_stack.providers.remote.inference.vllm.vllm import (
|
from llama_stack.providers.remote.inference.vllm.vllm import (
|
||||||
VLLMInferenceAdapter,
|
VLLMInferenceAdapter,
|
||||||
|
@ -135,6 +138,49 @@ async def test_old_vllm_tool_choice(vllm_inference_adapter):
|
||||||
assert request.tool_config.tool_choice == ToolChoice.none
|
assert request.tool_config.tool_choice == ToolChoice.none
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_tool_call_response(vllm_inference_adapter):
|
||||||
|
"""Verify that tool call arguments from a CompletionMessage are correctly converted
|
||||||
|
into the expected JSON format."""
|
||||||
|
|
||||||
|
# Patch the call to vllm so we can inspect the arguments sent were correct
|
||||||
|
with patch.object(
|
||||||
|
vllm_inference_adapter.client.chat.completions, "create", new_callable=AsyncMock
|
||||||
|
) as mock_nonstream_completion:
|
||||||
|
messages = [
|
||||||
|
SystemMessage(content="You are a helpful assistant"),
|
||||||
|
UserMessage(content="How many?"),
|
||||||
|
CompletionMessage(
|
||||||
|
content="",
|
||||||
|
stop_reason=StopReason.end_of_turn,
|
||||||
|
tool_calls=[
|
||||||
|
ToolCall(
|
||||||
|
call_id="foo",
|
||||||
|
tool_name="knowledge_search",
|
||||||
|
arguments={"query": "How many?"},
|
||||||
|
arguments_json='{"query": "How many?"}',
|
||||||
|
)
|
||||||
|
],
|
||||||
|
),
|
||||||
|
ToolResponseMessage(call_id="foo", content="knowledge_search found 5...."),
|
||||||
|
]
|
||||||
|
await vllm_inference_adapter.chat_completion(
|
||||||
|
"mock-model",
|
||||||
|
messages,
|
||||||
|
stream=False,
|
||||||
|
tools=[],
|
||||||
|
tool_config=ToolConfig(tool_choice=ToolChoice.auto),
|
||||||
|
)
|
||||||
|
|
||||||
|
assert mock_nonstream_completion.call_args.kwargs["messages"][2]["tool_calls"] == [
|
||||||
|
{
|
||||||
|
"id": "foo",
|
||||||
|
"type": "function",
|
||||||
|
"function": {"name": "knowledge_search", "arguments": '{"query": "How many?"}'},
|
||||||
|
}
|
||||||
|
]
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.asyncio
|
@pytest.mark.asyncio
|
||||||
async def test_tool_call_delta_empty_tool_call_buf():
|
async def test_tool_call_delta_empty_tool_call_buf():
|
||||||
"""
|
"""
|
||||||
|
|
201
tests/unit/providers/nvidia/test_eval.py
Normal file
201
tests/unit/providers/nvidia/test_eval.py
Normal file
|
@ -0,0 +1,201 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
import os
|
||||||
|
import unittest
|
||||||
|
from unittest.mock import MagicMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from llama_stack.apis.benchmarks import Benchmark
|
||||||
|
from llama_stack.apis.common.job_types import Job, JobStatus
|
||||||
|
from llama_stack.apis.eval.eval import BenchmarkConfig, EvaluateResponse, ModelCandidate, SamplingParams
|
||||||
|
from llama_stack.models.llama.sku_types import CoreModelId
|
||||||
|
from llama_stack.providers.remote.eval.nvidia.config import NVIDIAEvalConfig
|
||||||
|
from llama_stack.providers.remote.eval.nvidia.eval import NVIDIAEvalImpl
|
||||||
|
|
||||||
|
MOCK_DATASET_ID = "default/test-dataset"
|
||||||
|
MOCK_BENCHMARK_ID = "test-benchmark"
|
||||||
|
|
||||||
|
|
||||||
|
class TestNVIDIAEvalImpl(unittest.TestCase):
|
||||||
|
def setUp(self):
|
||||||
|
os.environ["NVIDIA_EVALUATOR_URL"] = "http://nemo.test"
|
||||||
|
|
||||||
|
# Create mock APIs
|
||||||
|
self.datasetio_api = MagicMock()
|
||||||
|
self.datasets_api = MagicMock()
|
||||||
|
self.scoring_api = MagicMock()
|
||||||
|
self.inference_api = MagicMock()
|
||||||
|
self.agents_api = MagicMock()
|
||||||
|
|
||||||
|
self.config = NVIDIAEvalConfig(
|
||||||
|
evaluator_url=os.environ["NVIDIA_EVALUATOR_URL"],
|
||||||
|
)
|
||||||
|
|
||||||
|
self.eval_impl = NVIDIAEvalImpl(
|
||||||
|
config=self.config,
|
||||||
|
datasetio_api=self.datasetio_api,
|
||||||
|
datasets_api=self.datasets_api,
|
||||||
|
scoring_api=self.scoring_api,
|
||||||
|
inference_api=self.inference_api,
|
||||||
|
agents_api=self.agents_api,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Mock the HTTP request methods
|
||||||
|
self.evaluator_get_patcher = patch(
|
||||||
|
"llama_stack.providers.remote.eval.nvidia.eval.NVIDIAEvalImpl._evaluator_get"
|
||||||
|
)
|
||||||
|
self.evaluator_post_patcher = patch(
|
||||||
|
"llama_stack.providers.remote.eval.nvidia.eval.NVIDIAEvalImpl._evaluator_post"
|
||||||
|
)
|
||||||
|
|
||||||
|
self.mock_evaluator_get = self.evaluator_get_patcher.start()
|
||||||
|
self.mock_evaluator_post = self.evaluator_post_patcher.start()
|
||||||
|
|
||||||
|
def tearDown(self):
|
||||||
|
"""Clean up after each test."""
|
||||||
|
self.evaluator_get_patcher.stop()
|
||||||
|
self.evaluator_post_patcher.stop()
|
||||||
|
|
||||||
|
def _assert_request_body(self, expected_json):
|
||||||
|
"""Helper method to verify request body in Evaluator POST request is correct"""
|
||||||
|
call_args = self.mock_evaluator_post.call_args
|
||||||
|
actual_json = call_args[0][1]
|
||||||
|
|
||||||
|
# Check that all expected keys contain the expected values in the actual JSON
|
||||||
|
for key, value in expected_json.items():
|
||||||
|
assert key in actual_json, f"Key '{key}' missing in actual JSON"
|
||||||
|
|
||||||
|
if isinstance(value, dict):
|
||||||
|
for nested_key, nested_value in value.items():
|
||||||
|
assert nested_key in actual_json[key], f"Nested key '{nested_key}' missing in actual JSON['{key}']"
|
||||||
|
assert actual_json[key][nested_key] == nested_value, f"Value mismatch for '{key}.{nested_key}'"
|
||||||
|
else:
|
||||||
|
assert actual_json[key] == value, f"Value mismatch for '{key}'"
|
||||||
|
|
||||||
|
@pytest.fixture(autouse=True)
|
||||||
|
def inject_fixtures(self, run_async):
|
||||||
|
self.run_async = run_async
|
||||||
|
|
||||||
|
def test_register_benchmark(self):
|
||||||
|
eval_config = {
|
||||||
|
"type": "custom",
|
||||||
|
"params": {"parallelism": 8},
|
||||||
|
"tasks": {
|
||||||
|
"qa": {
|
||||||
|
"type": "completion",
|
||||||
|
"params": {"template": {"prompt": "{{prompt}}", "max_tokens": 200}},
|
||||||
|
"dataset": {"files_url": f"hf://datasets/{MOCK_DATASET_ID}/testing/testing.jsonl"},
|
||||||
|
"metrics": {"bleu": {"type": "bleu", "params": {"references": ["{{ideal_response}}"]}}},
|
||||||
|
}
|
||||||
|
},
|
||||||
|
}
|
||||||
|
|
||||||
|
benchmark = Benchmark(
|
||||||
|
provider_id="nvidia",
|
||||||
|
type="benchmark",
|
||||||
|
identifier=MOCK_BENCHMARK_ID,
|
||||||
|
dataset_id=MOCK_DATASET_ID,
|
||||||
|
scoring_functions=["basic::equality"],
|
||||||
|
metadata=eval_config,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Mock Evaluator API response
|
||||||
|
mock_evaluator_response = {"id": MOCK_BENCHMARK_ID, "status": "created"}
|
||||||
|
self.mock_evaluator_post.return_value = mock_evaluator_response
|
||||||
|
|
||||||
|
# Register the benchmark
|
||||||
|
self.run_async(self.eval_impl.register_benchmark(benchmark))
|
||||||
|
|
||||||
|
# Verify the Evaluator API was called correctly
|
||||||
|
self.mock_evaluator_post.assert_called_once()
|
||||||
|
self._assert_request_body({"namespace": benchmark.provider_id, "name": benchmark.identifier, **eval_config})
|
||||||
|
|
||||||
|
def test_run_eval(self):
|
||||||
|
benchmark_config = BenchmarkConfig(
|
||||||
|
eval_candidate=ModelCandidate(
|
||||||
|
type="model",
|
||||||
|
model=CoreModelId.llama3_1_8b_instruct.value,
|
||||||
|
sampling_params=SamplingParams(max_tokens=100, temperature=0.7),
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Mock Evaluator API response
|
||||||
|
mock_evaluator_response = {"id": "job-123", "status": "created"}
|
||||||
|
self.mock_evaluator_post.return_value = mock_evaluator_response
|
||||||
|
|
||||||
|
# Run the Evaluation job
|
||||||
|
result = self.run_async(
|
||||||
|
self.eval_impl.run_eval(benchmark_id=MOCK_BENCHMARK_ID, benchmark_config=benchmark_config)
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify the Evaluator API was called correctly
|
||||||
|
self.mock_evaluator_post.assert_called_once()
|
||||||
|
self._assert_request_body(
|
||||||
|
{
|
||||||
|
"config": f"nvidia/{MOCK_BENCHMARK_ID}",
|
||||||
|
"target": {"type": "model", "model": "meta/llama-3.1-8b-instruct"},
|
||||||
|
}
|
||||||
|
)
|
||||||
|
|
||||||
|
# Verify the result
|
||||||
|
assert isinstance(result, Job)
|
||||||
|
assert result.job_id == "job-123"
|
||||||
|
assert result.status == JobStatus.in_progress
|
||||||
|
|
||||||
|
def test_job_status(self):
|
||||||
|
# Mock Evaluator API response
|
||||||
|
mock_evaluator_response = {"id": "job-123", "status": "completed"}
|
||||||
|
self.mock_evaluator_get.return_value = mock_evaluator_response
|
||||||
|
|
||||||
|
# Get the Evaluation job
|
||||||
|
result = self.run_async(self.eval_impl.job_status(benchmark_id=MOCK_BENCHMARK_ID, job_id="job-123"))
|
||||||
|
|
||||||
|
# Verify the result
|
||||||
|
assert isinstance(result, Job)
|
||||||
|
assert result.job_id == "job-123"
|
||||||
|
assert result.status == JobStatus.completed
|
||||||
|
|
||||||
|
# Verify the API was called correctly
|
||||||
|
self.mock_evaluator_get.assert_called_once_with(f"/v1/evaluation/jobs/{result.job_id}")
|
||||||
|
|
||||||
|
def test_job_cancel(self):
|
||||||
|
# Mock Evaluator API response
|
||||||
|
mock_evaluator_response = {"id": "job-123", "status": "cancelled"}
|
||||||
|
self.mock_evaluator_post.return_value = mock_evaluator_response
|
||||||
|
|
||||||
|
# Cancel the Evaluation job
|
||||||
|
self.run_async(self.eval_impl.job_cancel(benchmark_id=MOCK_BENCHMARK_ID, job_id="job-123"))
|
||||||
|
|
||||||
|
# Verify the API was called correctly
|
||||||
|
self.mock_evaluator_post.assert_called_once_with("/v1/evaluation/jobs/job-123/cancel", {})
|
||||||
|
|
||||||
|
def test_job_result(self):
|
||||||
|
# Mock Evaluator API responses
|
||||||
|
mock_job_status_response = {"id": "job-123", "status": "completed"}
|
||||||
|
mock_job_results_response = {
|
||||||
|
"id": "job-123",
|
||||||
|
"status": "completed",
|
||||||
|
"results": {MOCK_BENCHMARK_ID: {"score": 0.85, "details": {"accuracy": 0.85, "f1": 0.84}}},
|
||||||
|
}
|
||||||
|
self.mock_evaluator_get.side_effect = [
|
||||||
|
mock_job_status_response, # First call to retrieve job
|
||||||
|
mock_job_results_response, # Second call to retrieve job results
|
||||||
|
]
|
||||||
|
|
||||||
|
# Get the Evaluation job results
|
||||||
|
result = self.run_async(self.eval_impl.job_result(benchmark_id=MOCK_BENCHMARK_ID, job_id="job-123"))
|
||||||
|
|
||||||
|
# Verify the result
|
||||||
|
assert isinstance(result, EvaluateResponse)
|
||||||
|
assert MOCK_BENCHMARK_ID in result.scores
|
||||||
|
assert result.scores[MOCK_BENCHMARK_ID].aggregated_results["results"][MOCK_BENCHMARK_ID]["score"] == 0.85
|
||||||
|
|
||||||
|
# Verify the API was called correctly
|
||||||
|
assert self.mock_evaluator_get.call_count == 2
|
||||||
|
self.mock_evaluator_get.assert_any_call("/v1/evaluation/jobs/job-123")
|
||||||
|
self.mock_evaluator_get.assert_any_call("/v1/evaluation/jobs/job-123/results")
|
|
@ -10,14 +10,17 @@ import warnings
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from llama_stack_client.types.algorithm_config_param import LoraFinetuningConfig
|
|
||||||
from llama_stack_client.types.post_training_supervised_fine_tune_params import (
|
|
||||||
TrainingConfig,
|
|
||||||
TrainingConfigDataConfig,
|
|
||||||
TrainingConfigEfficiencyConfig,
|
|
||||||
TrainingConfigOptimizerConfig,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
from llama_stack.apis.post_training.post_training import (
|
||||||
|
DataConfig,
|
||||||
|
DatasetFormat,
|
||||||
|
EfficiencyConfig,
|
||||||
|
LoraFinetuningConfig,
|
||||||
|
OptimizerConfig,
|
||||||
|
OptimizerType,
|
||||||
|
TrainingConfig,
|
||||||
|
)
|
||||||
|
from llama_stack.distribution.library_client import convert_pydantic_to_json_value
|
||||||
from llama_stack.providers.remote.post_training.nvidia.post_training import (
|
from llama_stack.providers.remote.post_training.nvidia.post_training import (
|
||||||
NvidiaPostTrainingAdapter,
|
NvidiaPostTrainingAdapter,
|
||||||
NvidiaPostTrainingConfig,
|
NvidiaPostTrainingConfig,
|
||||||
|
@ -66,11 +69,8 @@ class TestNvidiaParameters(unittest.TestCase):
|
||||||
|
|
||||||
def test_customizer_parameters_passed(self):
|
def test_customizer_parameters_passed(self):
|
||||||
"""Test scenario 1: When an optional parameter is passed and value is correctly set."""
|
"""Test scenario 1: When an optional parameter is passed and value is correctly set."""
|
||||||
custom_adapter_dim = 32 # Different from default of 8
|
|
||||||
algorithm_config = LoraFinetuningConfig(
|
algorithm_config = LoraFinetuningConfig(
|
||||||
type="LoRA",
|
type="LoRA",
|
||||||
adapter_dim=custom_adapter_dim,
|
|
||||||
adapter_dropout=0.2,
|
|
||||||
apply_lora_to_mlp=True,
|
apply_lora_to_mlp=True,
|
||||||
apply_lora_to_output=True,
|
apply_lora_to_output=True,
|
||||||
alpha=16,
|
alpha=16,
|
||||||
|
@ -78,8 +78,15 @@ class TestNvidiaParameters(unittest.TestCase):
|
||||||
lora_attn_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
|
lora_attn_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
|
||||||
)
|
)
|
||||||
|
|
||||||
data_config = TrainingConfigDataConfig(dataset_id="test-dataset", batch_size=16)
|
data_config = DataConfig(
|
||||||
optimizer_config = TrainingConfigOptimizerConfig(lr=0.0002)
|
dataset_id="test-dataset", batch_size=16, shuffle=False, data_format=DatasetFormat.instruct
|
||||||
|
)
|
||||||
|
optimizer_config = OptimizerConfig(
|
||||||
|
optimizer_type=OptimizerType.adam,
|
||||||
|
lr=0.0002,
|
||||||
|
weight_decay=0.01,
|
||||||
|
num_warmup_steps=100,
|
||||||
|
)
|
||||||
training_config = TrainingConfig(
|
training_config = TrainingConfig(
|
||||||
n_epochs=3,
|
n_epochs=3,
|
||||||
data_config=data_config,
|
data_config=data_config,
|
||||||
|
@ -95,7 +102,7 @@ class TestNvidiaParameters(unittest.TestCase):
|
||||||
model="meta-llama/Llama-3.1-8B-Instruct",
|
model="meta-llama/Llama-3.1-8B-Instruct",
|
||||||
checkpoint_dir="",
|
checkpoint_dir="",
|
||||||
algorithm_config=algorithm_config,
|
algorithm_config=algorithm_config,
|
||||||
training_config=training_config,
|
training_config=convert_pydantic_to_json_value(training_config),
|
||||||
logger_config={},
|
logger_config={},
|
||||||
hyperparam_search_config={},
|
hyperparam_search_config={},
|
||||||
)
|
)
|
||||||
|
@ -114,7 +121,7 @@ class TestNvidiaParameters(unittest.TestCase):
|
||||||
self._assert_request_params(
|
self._assert_request_params(
|
||||||
{
|
{
|
||||||
"hyperparameters": {
|
"hyperparameters": {
|
||||||
"lora": {"adapter_dim": custom_adapter_dim, "adapter_dropout": 0.2, "alpha": 16},
|
"lora": {"alpha": 16},
|
||||||
"epochs": 3,
|
"epochs": 3,
|
||||||
"learning_rate": 0.0002,
|
"learning_rate": 0.0002,
|
||||||
"batch_size": 16,
|
"batch_size": 16,
|
||||||
|
@ -130,8 +137,6 @@ class TestNvidiaParameters(unittest.TestCase):
|
||||||
|
|
||||||
algorithm_config = LoraFinetuningConfig(
|
algorithm_config = LoraFinetuningConfig(
|
||||||
type="LoRA",
|
type="LoRA",
|
||||||
adapter_dim=16,
|
|
||||||
adapter_dropout=0.1,
|
|
||||||
apply_lora_to_mlp=True,
|
apply_lora_to_mlp=True,
|
||||||
apply_lora_to_output=True,
|
apply_lora_to_output=True,
|
||||||
alpha=16,
|
alpha=16,
|
||||||
|
@ -139,12 +144,16 @@ class TestNvidiaParameters(unittest.TestCase):
|
||||||
lora_attn_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
|
lora_attn_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
|
||||||
)
|
)
|
||||||
|
|
||||||
data_config = TrainingConfigDataConfig(
|
data_config = DataConfig(
|
||||||
dataset_id=required_dataset_id, # Required parameter
|
dataset_id=required_dataset_id, batch_size=8, shuffle=False, data_format=DatasetFormat.instruct
|
||||||
batch_size=8,
|
|
||||||
)
|
)
|
||||||
|
|
||||||
optimizer_config = TrainingConfigOptimizerConfig(lr=0.0001)
|
optimizer_config = OptimizerConfig(
|
||||||
|
optimizer_type=OptimizerType.adam,
|
||||||
|
lr=0.0001,
|
||||||
|
weight_decay=0.01,
|
||||||
|
num_warmup_steps=100,
|
||||||
|
)
|
||||||
|
|
||||||
training_config = TrainingConfig(
|
training_config = TrainingConfig(
|
||||||
n_epochs=1,
|
n_epochs=1,
|
||||||
|
@ -161,7 +170,7 @@ class TestNvidiaParameters(unittest.TestCase):
|
||||||
model=required_model, # Required parameter
|
model=required_model, # Required parameter
|
||||||
checkpoint_dir="",
|
checkpoint_dir="",
|
||||||
algorithm_config=algorithm_config,
|
algorithm_config=algorithm_config,
|
||||||
training_config=training_config,
|
training_config=convert_pydantic_to_json_value(training_config),
|
||||||
logger_config={},
|
logger_config={},
|
||||||
hyperparam_search_config={},
|
hyperparam_search_config={},
|
||||||
)
|
)
|
||||||
|
@ -186,24 +195,24 @@ class TestNvidiaParameters(unittest.TestCase):
|
||||||
|
|
||||||
def test_unsupported_parameters_warning(self):
|
def test_unsupported_parameters_warning(self):
|
||||||
"""Test that warnings are raised for unsupported parameters."""
|
"""Test that warnings are raised for unsupported parameters."""
|
||||||
data_config = TrainingConfigDataConfig(
|
data_config = DataConfig(
|
||||||
dataset_id="test-dataset",
|
dataset_id="test-dataset",
|
||||||
batch_size=8,
|
batch_size=8,
|
||||||
# Unsupported parameters
|
# Unsupported parameters
|
||||||
shuffle=True,
|
shuffle=True,
|
||||||
data_format="instruct",
|
data_format=DatasetFormat.instruct,
|
||||||
validation_dataset_id="val-dataset",
|
validation_dataset_id="val-dataset",
|
||||||
)
|
)
|
||||||
|
|
||||||
optimizer_config = TrainingConfigOptimizerConfig(
|
optimizer_config = OptimizerConfig(
|
||||||
lr=0.0001,
|
lr=0.0001,
|
||||||
weight_decay=0.01,
|
weight_decay=0.01,
|
||||||
# Unsupported parameters
|
# Unsupported parameters
|
||||||
optimizer_type="adam",
|
optimizer_type=OptimizerType.adam,
|
||||||
num_warmup_steps=100,
|
num_warmup_steps=100,
|
||||||
)
|
)
|
||||||
|
|
||||||
efficiency_config = TrainingConfigEfficiencyConfig(
|
efficiency_config = EfficiencyConfig(
|
||||||
enable_activation_checkpointing=True # Unsupported parameter
|
enable_activation_checkpointing=True # Unsupported parameter
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@ -230,15 +239,13 @@ class TestNvidiaParameters(unittest.TestCase):
|
||||||
checkpoint_dir="test-dir", # Unsupported parameter
|
checkpoint_dir="test-dir", # Unsupported parameter
|
||||||
algorithm_config=LoraFinetuningConfig(
|
algorithm_config=LoraFinetuningConfig(
|
||||||
type="LoRA",
|
type="LoRA",
|
||||||
adapter_dim=16,
|
|
||||||
adapter_dropout=0.1,
|
|
||||||
apply_lora_to_mlp=True,
|
apply_lora_to_mlp=True,
|
||||||
apply_lora_to_output=True,
|
apply_lora_to_output=True,
|
||||||
alpha=16,
|
alpha=16,
|
||||||
rank=16,
|
rank=16,
|
||||||
lora_attn_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
|
lora_attn_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
|
||||||
),
|
),
|
||||||
training_config=training_config,
|
training_config=convert_pydantic_to_json_value(training_config),
|
||||||
logger_config={"test": "value"}, # Unsupported parameter
|
logger_config={"test": "value"}, # Unsupported parameter
|
||||||
hyperparam_search_config={"test": "value"}, # Unsupported parameter
|
hyperparam_search_config={"test": "value"}, # Unsupported parameter
|
||||||
)
|
)
|
||||||
|
|
|
@ -10,13 +10,19 @@ import warnings
|
||||||
from unittest.mock import patch
|
from unittest.mock import patch
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
from llama_stack_client.types.algorithm_config_param import LoraFinetuningConfig, QatFinetuningConfig
|
|
||||||
from llama_stack_client.types.post_training_supervised_fine_tune_params import (
|
|
||||||
TrainingConfig,
|
|
||||||
TrainingConfigDataConfig,
|
|
||||||
TrainingConfigOptimizerConfig,
|
|
||||||
)
|
|
||||||
|
|
||||||
|
from llama_stack.apis.models import Model, ModelType
|
||||||
|
from llama_stack.apis.post_training.post_training import (
|
||||||
|
DataConfig,
|
||||||
|
DatasetFormat,
|
||||||
|
LoraFinetuningConfig,
|
||||||
|
OptimizerConfig,
|
||||||
|
OptimizerType,
|
||||||
|
QATFinetuningConfig,
|
||||||
|
TrainingConfig,
|
||||||
|
)
|
||||||
|
from llama_stack.distribution.library_client import convert_pydantic_to_json_value
|
||||||
|
from llama_stack.providers.remote.inference.nvidia.nvidia import NVIDIAConfig, NVIDIAInferenceAdapter
|
||||||
from llama_stack.providers.remote.post_training.nvidia.post_training import (
|
from llama_stack.providers.remote.post_training.nvidia.post_training import (
|
||||||
ListNvidiaPostTrainingJobs,
|
ListNvidiaPostTrainingJobs,
|
||||||
NvidiaPostTrainingAdapter,
|
NvidiaPostTrainingAdapter,
|
||||||
|
@ -40,8 +46,22 @@ class TestNvidiaPostTraining(unittest.TestCase):
|
||||||
)
|
)
|
||||||
self.mock_make_request = self.make_request_patcher.start()
|
self.mock_make_request = self.make_request_patcher.start()
|
||||||
|
|
||||||
|
# Mock the inference client
|
||||||
|
inference_config = NVIDIAConfig(base_url=os.environ["NVIDIA_BASE_URL"], api_key=None)
|
||||||
|
self.inference_adapter = NVIDIAInferenceAdapter(inference_config)
|
||||||
|
|
||||||
|
self.mock_client = unittest.mock.MagicMock()
|
||||||
|
self.mock_client.chat.completions.create = unittest.mock.AsyncMock()
|
||||||
|
self.inference_mock_make_request = self.mock_client.chat.completions.create
|
||||||
|
self.inference_make_request_patcher = patch(
|
||||||
|
"llama_stack.providers.remote.inference.nvidia.nvidia.NVIDIAInferenceAdapter._get_client",
|
||||||
|
return_value=self.mock_client,
|
||||||
|
)
|
||||||
|
self.inference_make_request_patcher.start()
|
||||||
|
|
||||||
def tearDown(self):
|
def tearDown(self):
|
||||||
self.make_request_patcher.stop()
|
self.make_request_patcher.stop()
|
||||||
|
self.inference_make_request_patcher.stop()
|
||||||
|
|
||||||
@pytest.fixture(autouse=True)
|
@pytest.fixture(autouse=True)
|
||||||
def inject_fixtures(self, run_async):
|
def inject_fixtures(self, run_async):
|
||||||
|
@ -105,7 +125,7 @@ class TestNvidiaPostTraining(unittest.TestCase):
|
||||||
"batch_size": 16,
|
"batch_size": 16,
|
||||||
"epochs": 2,
|
"epochs": 2,
|
||||||
"learning_rate": 0.0001,
|
"learning_rate": 0.0001,
|
||||||
"lora": {"adapter_dim": 16, "adapter_dropout": 0.1},
|
"lora": {"alpha": 16},
|
||||||
},
|
},
|
||||||
"output_model": "default/job-1234",
|
"output_model": "default/job-1234",
|
||||||
"status": "created",
|
"status": "created",
|
||||||
|
@ -116,8 +136,6 @@ class TestNvidiaPostTraining(unittest.TestCase):
|
||||||
|
|
||||||
algorithm_config = LoraFinetuningConfig(
|
algorithm_config = LoraFinetuningConfig(
|
||||||
type="LoRA",
|
type="LoRA",
|
||||||
adapter_dim=16,
|
|
||||||
adapter_dropout=0.1,
|
|
||||||
apply_lora_to_mlp=True,
|
apply_lora_to_mlp=True,
|
||||||
apply_lora_to_output=True,
|
apply_lora_to_output=True,
|
||||||
alpha=16,
|
alpha=16,
|
||||||
|
@ -125,10 +143,15 @@ class TestNvidiaPostTraining(unittest.TestCase):
|
||||||
lora_attn_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
|
lora_attn_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
|
||||||
)
|
)
|
||||||
|
|
||||||
data_config = TrainingConfigDataConfig(dataset_id="sample-basic-test", batch_size=16)
|
data_config = DataConfig(
|
||||||
|
dataset_id="sample-basic-test", batch_size=16, shuffle=False, data_format=DatasetFormat.instruct
|
||||||
|
)
|
||||||
|
|
||||||
optimizer_config = TrainingConfigOptimizerConfig(
|
optimizer_config = OptimizerConfig(
|
||||||
|
optimizer_type=OptimizerType.adam,
|
||||||
lr=0.0001,
|
lr=0.0001,
|
||||||
|
weight_decay=0.01,
|
||||||
|
num_warmup_steps=100,
|
||||||
)
|
)
|
||||||
|
|
||||||
training_config = TrainingConfig(
|
training_config = TrainingConfig(
|
||||||
|
@ -145,7 +168,7 @@ class TestNvidiaPostTraining(unittest.TestCase):
|
||||||
model="meta-llama/Llama-3.1-8B-Instruct",
|
model="meta-llama/Llama-3.1-8B-Instruct",
|
||||||
checkpoint_dir="",
|
checkpoint_dir="",
|
||||||
algorithm_config=algorithm_config,
|
algorithm_config=algorithm_config,
|
||||||
training_config=training_config,
|
training_config=convert_pydantic_to_json_value(training_config),
|
||||||
logger_config={},
|
logger_config={},
|
||||||
hyperparam_search_config={},
|
hyperparam_search_config={},
|
||||||
)
|
)
|
||||||
|
@ -169,16 +192,22 @@ class TestNvidiaPostTraining(unittest.TestCase):
|
||||||
"epochs": 2,
|
"epochs": 2,
|
||||||
"batch_size": 16,
|
"batch_size": 16,
|
||||||
"learning_rate": 0.0001,
|
"learning_rate": 0.0001,
|
||||||
"lora": {"alpha": 16, "adapter_dim": 16, "adapter_dropout": 0.1},
|
"weight_decay": 0.01,
|
||||||
|
"lora": {"alpha": 16},
|
||||||
},
|
},
|
||||||
},
|
},
|
||||||
)
|
)
|
||||||
|
|
||||||
def test_supervised_fine_tune_with_qat(self):
|
def test_supervised_fine_tune_with_qat(self):
|
||||||
algorithm_config = QatFinetuningConfig(type="QAT", quantizer_name="quantizer_name", group_size=1)
|
algorithm_config = QATFinetuningConfig(type="QAT", quantizer_name="quantizer_name", group_size=1)
|
||||||
data_config = TrainingConfigDataConfig(dataset_id="sample-basic-test", batch_size=16)
|
data_config = DataConfig(
|
||||||
optimizer_config = TrainingConfigOptimizerConfig(
|
dataset_id="sample-basic-test", batch_size=16, shuffle=False, data_format=DatasetFormat.instruct
|
||||||
|
)
|
||||||
|
optimizer_config = OptimizerConfig(
|
||||||
|
optimizer_type=OptimizerType.adam,
|
||||||
lr=0.0001,
|
lr=0.0001,
|
||||||
|
weight_decay=0.01,
|
||||||
|
num_warmup_steps=100,
|
||||||
)
|
)
|
||||||
training_config = TrainingConfig(
|
training_config = TrainingConfig(
|
||||||
n_epochs=2,
|
n_epochs=2,
|
||||||
|
@ -193,7 +222,7 @@ class TestNvidiaPostTraining(unittest.TestCase):
|
||||||
model="meta-llama/Llama-3.1-8B-Instruct",
|
model="meta-llama/Llama-3.1-8B-Instruct",
|
||||||
checkpoint_dir="",
|
checkpoint_dir="",
|
||||||
algorithm_config=algorithm_config,
|
algorithm_config=algorithm_config,
|
||||||
training_config=training_config,
|
training_config=convert_pydantic_to_json_value(training_config),
|
||||||
logger_config={},
|
logger_config={},
|
||||||
hyperparam_search_config={},
|
hyperparam_search_config={},
|
||||||
)
|
)
|
||||||
|
@ -303,6 +332,31 @@ class TestNvidiaPostTraining(unittest.TestCase):
|
||||||
expected_params={"job_id": job_id},
|
expected_params={"job_id": job_id},
|
||||||
)
|
)
|
||||||
|
|
||||||
|
def test_inference_register_model(self):
|
||||||
|
model_id = "default/job-1234"
|
||||||
|
model_type = ModelType.llm
|
||||||
|
model = Model(
|
||||||
|
identifier=model_id,
|
||||||
|
provider_id="nvidia",
|
||||||
|
provider_model_id=model_id,
|
||||||
|
provider_resource_id=model_id,
|
||||||
|
model_type=model_type,
|
||||||
|
)
|
||||||
|
result = self.run_async(self.inference_adapter.register_model(model))
|
||||||
|
assert result == model
|
||||||
|
assert len(self.inference_adapter.alias_to_provider_id_map) > 1
|
||||||
|
assert self.inference_adapter.get_provider_model_id(model.provider_model_id) == model_id
|
||||||
|
|
||||||
|
with patch.object(self.inference_adapter, "chat_completion") as mock_chat_completion:
|
||||||
|
self.run_async(
|
||||||
|
self.inference_adapter.chat_completion(
|
||||||
|
model_id=model_id,
|
||||||
|
messages=[{"role": "user", "content": "Hello, model"}],
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_chat_completion.assert_called()
|
||||||
|
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
unittest.main()
|
unittest.main()
|
||||||
|
|
43
tests/unit/providers/utils/inference/test_openai_compat.py
Normal file
43
tests/unit/providers/utils/inference/test_openai_compat.py
Normal file
|
@ -0,0 +1,43 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from llama_stack.apis.common.content_types import TextContentItem
|
||||||
|
from llama_stack.apis.inference.inference import CompletionMessage, UserMessage
|
||||||
|
from llama_stack.models.llama.datatypes import StopReason, ToolCall
|
||||||
|
from llama_stack.providers.utils.inference.openai_compat import convert_message_to_openai_dict
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_convert_message_to_openai_dict():
|
||||||
|
message = UserMessage(content=[TextContentItem(text="Hello, world!")], role="user")
|
||||||
|
assert await convert_message_to_openai_dict(message) == {
|
||||||
|
"role": "user",
|
||||||
|
"content": [{"type": "text", "text": "Hello, world!"}],
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
# Test convert_message_to_openai_dict with a tool call
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_convert_message_to_openai_dict_with_tool_call():
|
||||||
|
message = CompletionMessage(
|
||||||
|
content="",
|
||||||
|
tool_calls=[
|
||||||
|
ToolCall(call_id="123", tool_name="test_tool", arguments_json='{"foo": "bar"}', arguments={"foo": "bar"})
|
||||||
|
],
|
||||||
|
stop_reason=StopReason.end_of_turn,
|
||||||
|
)
|
||||||
|
|
||||||
|
openai_dict = await convert_message_to_openai_dict(message)
|
||||||
|
|
||||||
|
assert openai_dict == {
|
||||||
|
"role": "assistant",
|
||||||
|
"content": [{"type": "text", "text": ""}],
|
||||||
|
"tool_calls": [
|
||||||
|
{"id": "123", "type": "function", "function": {"name": "test_tool", "arguments": '{"foo": "bar"}'}}
|
||||||
|
],
|
||||||
|
}
|
91
tests/unit/server/test_sse.py
Normal file
91
tests/unit/server/test_sse.py
Normal file
|
@ -0,0 +1,91 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
|
||||||
|
# the root directory of this source tree.
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from llama_stack.distribution.server.server import create_sse_event, sse_generator
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_sse_generator_basic():
|
||||||
|
# An AsyncIterator wrapped in an Awaitable, just like our web methods
|
||||||
|
async def async_event_gen():
|
||||||
|
async def event_gen():
|
||||||
|
yield "Test event 1"
|
||||||
|
yield "Test event 2"
|
||||||
|
|
||||||
|
return event_gen()
|
||||||
|
|
||||||
|
sse_gen = sse_generator(async_event_gen())
|
||||||
|
assert sse_gen is not None
|
||||||
|
|
||||||
|
# Test that the events are streamed correctly
|
||||||
|
seen_events = []
|
||||||
|
async for event in sse_gen:
|
||||||
|
seen_events.append(event)
|
||||||
|
assert len(seen_events) == 2
|
||||||
|
assert seen_events[0] == create_sse_event("Test event 1")
|
||||||
|
assert seen_events[1] == create_sse_event("Test event 2")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_sse_generator_client_disconnected():
|
||||||
|
# An AsyncIterator wrapped in an Awaitable, just like our web methods
|
||||||
|
async def async_event_gen():
|
||||||
|
async def event_gen():
|
||||||
|
yield "Test event 1"
|
||||||
|
# Simulate a client disconnect before emitting event 2
|
||||||
|
raise asyncio.CancelledError()
|
||||||
|
|
||||||
|
return event_gen()
|
||||||
|
|
||||||
|
sse_gen = sse_generator(async_event_gen())
|
||||||
|
assert sse_gen is not None
|
||||||
|
|
||||||
|
seen_events = []
|
||||||
|
async for event in sse_gen:
|
||||||
|
seen_events.append(event)
|
||||||
|
|
||||||
|
# We should see 1 event before the client disconnected
|
||||||
|
assert len(seen_events) == 1
|
||||||
|
assert seen_events[0] == create_sse_event("Test event 1")
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_sse_generator_client_disconnected_before_response_starts():
|
||||||
|
# Disconnect before the response starts
|
||||||
|
async def async_event_gen():
|
||||||
|
raise asyncio.CancelledError()
|
||||||
|
|
||||||
|
sse_gen = sse_generator(async_event_gen())
|
||||||
|
assert sse_gen is not None
|
||||||
|
|
||||||
|
seen_events = []
|
||||||
|
async for event in sse_gen:
|
||||||
|
seen_events.append(event)
|
||||||
|
|
||||||
|
# No events should be seen since the client disconnected immediately
|
||||||
|
assert len(seen_events) == 0
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_sse_generator_error_before_response_starts():
|
||||||
|
# Raise an error before the response starts
|
||||||
|
async def async_event_gen():
|
||||||
|
raise Exception("Test error")
|
||||||
|
|
||||||
|
sse_gen = sse_generator(async_event_gen())
|
||||||
|
assert sse_gen is not None
|
||||||
|
|
||||||
|
seen_events = []
|
||||||
|
async for event in sse_gen:
|
||||||
|
seen_events.append(event)
|
||||||
|
|
||||||
|
# We should have 1 error event
|
||||||
|
assert len(seen_events) == 1
|
||||||
|
assert 'data: {"error":' in seen_events[0]
|
|
@ -15,6 +15,52 @@ test_chat_basic:
|
||||||
S?
|
S?
|
||||||
role: user
|
role: user
|
||||||
output: Saturn
|
output: Saturn
|
||||||
|
test_chat_input_validation:
|
||||||
|
test_name: test_chat_input_validation
|
||||||
|
test_params:
|
||||||
|
case:
|
||||||
|
- case_id: "messages_missing"
|
||||||
|
input:
|
||||||
|
messages: []
|
||||||
|
output:
|
||||||
|
error:
|
||||||
|
status_code: 400
|
||||||
|
- case_id: "messages_role_invalid"
|
||||||
|
input:
|
||||||
|
messages:
|
||||||
|
- content: Which planet do humans live on?
|
||||||
|
role: fake_role
|
||||||
|
output:
|
||||||
|
error:
|
||||||
|
status_code: 400
|
||||||
|
- case_id: "tool_choice_invalid"
|
||||||
|
input:
|
||||||
|
messages:
|
||||||
|
- content: Which planet do humans live on?
|
||||||
|
role: user
|
||||||
|
tool_choice: invalid
|
||||||
|
output:
|
||||||
|
error:
|
||||||
|
status_code: 400
|
||||||
|
- case_id: "tool_choice_no_tools"
|
||||||
|
input:
|
||||||
|
messages:
|
||||||
|
- content: Which planet do humans live on?
|
||||||
|
role: user
|
||||||
|
tool_choice: required
|
||||||
|
output:
|
||||||
|
error:
|
||||||
|
status_code: 400
|
||||||
|
- case_id: "tools_type_invalid"
|
||||||
|
input:
|
||||||
|
messages:
|
||||||
|
- content: Which planet do humans live on?
|
||||||
|
role: user
|
||||||
|
tools:
|
||||||
|
- type: invalid
|
||||||
|
output:
|
||||||
|
error:
|
||||||
|
status_code: 400
|
||||||
test_chat_image:
|
test_chat_image:
|
||||||
test_name: test_chat_image
|
test_name: test_chat_image
|
||||||
test_params:
|
test_params:
|
||||||
|
|
|
@ -12,6 +12,7 @@ from pathlib import Path
|
||||||
from typing import Any
|
from typing import Any
|
||||||
|
|
||||||
import pytest
|
import pytest
|
||||||
|
from openai import APIError
|
||||||
from pydantic import BaseModel
|
from pydantic import BaseModel
|
||||||
|
|
||||||
from tests.verifications.openai_api.fixtures.fixtures import (
|
from tests.verifications.openai_api.fixtures.fixtures import (
|
||||||
|
@ -136,6 +137,50 @@ def test_chat_streaming_basic(request, openai_client, model, provider, verificat
|
||||||
assert case["output"].lower() in content.lower()
|
assert case["output"].lower() in content.lower()
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"case",
|
||||||
|
chat_completion_test_cases["test_chat_input_validation"]["test_params"]["case"],
|
||||||
|
ids=case_id_generator,
|
||||||
|
)
|
||||||
|
def test_chat_non_streaming_error_handling(request, openai_client, model, provider, verification_config, case):
|
||||||
|
test_name_base = get_base_test_name(request)
|
||||||
|
if should_skip_test(verification_config, provider, model, test_name_base):
|
||||||
|
pytest.skip(f"Skipping {test_name_base} for model {model} on provider {provider} based on config.")
|
||||||
|
|
||||||
|
with pytest.raises(APIError) as e:
|
||||||
|
openai_client.chat.completions.create(
|
||||||
|
model=model,
|
||||||
|
messages=case["input"]["messages"],
|
||||||
|
stream=False,
|
||||||
|
tool_choice=case["input"]["tool_choice"] if "tool_choice" in case["input"] else None,
|
||||||
|
tools=case["input"]["tools"] if "tools" in case["input"] else None,
|
||||||
|
)
|
||||||
|
assert case["output"]["error"]["status_code"] == e.value.status_code
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.parametrize(
|
||||||
|
"case",
|
||||||
|
chat_completion_test_cases["test_chat_input_validation"]["test_params"]["case"],
|
||||||
|
ids=case_id_generator,
|
||||||
|
)
|
||||||
|
def test_chat_streaming_error_handling(request, openai_client, model, provider, verification_config, case):
|
||||||
|
test_name_base = get_base_test_name(request)
|
||||||
|
if should_skip_test(verification_config, provider, model, test_name_base):
|
||||||
|
pytest.skip(f"Skipping {test_name_base} for model {model} on provider {provider} based on config.")
|
||||||
|
|
||||||
|
with pytest.raises(APIError) as e:
|
||||||
|
response = openai_client.chat.completions.create(
|
||||||
|
model=model,
|
||||||
|
messages=case["input"]["messages"],
|
||||||
|
stream=True,
|
||||||
|
tool_choice=case["input"]["tool_choice"] if "tool_choice" in case["input"] else None,
|
||||||
|
tools=case["input"]["tools"] if "tools" in case["input"] else None,
|
||||||
|
)
|
||||||
|
for _chunk in response:
|
||||||
|
pass
|
||||||
|
assert str(case["output"]["error"]["status_code"]) in e.value.message
|
||||||
|
|
||||||
|
|
||||||
@pytest.mark.parametrize(
|
@pytest.mark.parametrize(
|
||||||
"case",
|
"case",
|
||||||
chat_completion_test_cases["test_chat_image"]["test_params"]["case"],
|
chat_completion_test_cases["test_chat_image"]["test_params"]["case"],
|
||||||
|
|
2
uv.lock
generated
2
uv.lock
generated
|
@ -1458,6 +1458,7 @@ unit = [
|
||||||
{ name = "aiosqlite" },
|
{ name = "aiosqlite" },
|
||||||
{ name = "chardet" },
|
{ name = "chardet" },
|
||||||
{ name = "openai" },
|
{ name = "openai" },
|
||||||
|
{ name = "opentelemetry-exporter-otlp-proto-http" },
|
||||||
{ name = "pypdf" },
|
{ name = "pypdf" },
|
||||||
{ name = "qdrant-client" },
|
{ name = "qdrant-client" },
|
||||||
{ name = "sqlite-vec" },
|
{ name = "sqlite-vec" },
|
||||||
|
@ -1491,6 +1492,7 @@ requires-dist = [
|
||||||
{ name = "openai", marker = "extra == 'test'" },
|
{ name = "openai", marker = "extra == 'test'" },
|
||||||
{ name = "openai", marker = "extra == 'unit'" },
|
{ name = "openai", marker = "extra == 'unit'" },
|
||||||
{ name = "opentelemetry-exporter-otlp-proto-http", marker = "extra == 'test'" },
|
{ name = "opentelemetry-exporter-otlp-proto-http", marker = "extra == 'test'" },
|
||||||
|
{ name = "opentelemetry-exporter-otlp-proto-http", marker = "extra == 'unit'" },
|
||||||
{ name = "opentelemetry-sdk", marker = "extra == 'test'" },
|
{ name = "opentelemetry-sdk", marker = "extra == 'test'" },
|
||||||
{ name = "pandas", marker = "extra == 'ui'" },
|
{ name = "pandas", marker = "extra == 'ui'" },
|
||||||
{ name = "pillow" },
|
{ name = "pillow" },
|
||||||
|
|
Loading…
Add table
Add a link
Reference in a new issue