Mirror of https://github.com/meta-llama/llama-stack.git (synced 2025-08-03 01:03:59 +00:00)

Commit 6659ed995a: Merge branch 'main' into fix/nvidia-launch-customization
53 changed files with 2203 additions and 217 deletions
.coveragerc (new file, 6 lines)
@@ -0,0 +1,6 @@
[run]
omit =
    */tests/*
    */llama_stack/providers/*
    */llama_stack/templates/*
    .venv/*
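As a side note, a minimal sketch of exercising this coverage config from Python; the measured import and the presence of `coverage` in the environment are assumptions, not part of the commit.

```python
# Sketch: run code under coverage with the .coveragerc above.
import coverage

cov = coverage.Coverage(config_file=".coveragerc")  # picks up the [run] omit patterns
cov.start()
import llama_stack  # anything imported/executed here is measured (illustrative target)
cov.stop()
cov.save()
cov.report()  # tests/, providers/, templates/ and .venv/ are omitted per the config
```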
@@ -119,6 +119,7 @@ Here is a list of the various API providers and available distributions that can
 | OpenAI | Hosted | | ✅ | | | |
 | Anthropic | Hosted | | ✅ | | | |
 | Gemini | Hosted | | ✅ | | | |
+| watsonx | Hosted | | ✅ | | | |
 
 ### Distributions
@@ -128,7 +129,6 @@ A Llama Stack Distribution (or "distro") is a pre-configured bundle of provider
 | **Distribution** | **Llama Stack Docker** | Start This Distribution |
 |:---------------------------------------------:|:-------------------------------------------------------------------------------------------------------------------------------------------------------------:|:------------------------------------------------------------------------------------------------------------------------:|
 | Meta Reference | [llamastack/distribution-meta-reference-gpu](https://hub.docker.com/repository/docker/llamastack/distribution-meta-reference-gpu/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/distributions/self_hosted_distro/meta-reference-gpu.html) |
-| Meta Reference Quantized | [llamastack/distribution-meta-reference-quantized-gpu](https://hub.docker.com/repository/docker/llamastack/distribution-meta-reference-quantized-gpu/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/distributions/self_hosted_distro/meta-reference-quantized-gpu.html) |
 | SambaNova | [llamastack/distribution-sambanova](https://hub.docker.com/repository/docker/llamastack/distribution-sambanova/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/distributions/self_hosted_distro/sambanova.html) |
 | Cerebras | [llamastack/distribution-cerebras](https://hub.docker.com/repository/docker/llamastack/distribution-cerebras/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/distributions/self_hosted_distro/cerebras.html) |
 | Ollama | [llamastack/distribution-ollama](https://hub.docker.com/repository/docker/llamastack/distribution-ollama/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/distributions/self_hosted_distro/ollama.html) |
@@ -109,8 +109,6 @@ llama stack build --list-templates
 +------------------------------+-----------------------------------------------------------------------------+
 | nvidia | Use NVIDIA NIM for running LLM inference |
 +------------------------------+-----------------------------------------------------------------------------+
-| meta-reference-quantized-gpu | Use Meta Reference with fp8, int4 quantization for running LLM inference |
-+------------------------------+-----------------------------------------------------------------------------+
 | cerebras | Use Cerebras for running LLM inference |
 +------------------------------+-----------------------------------------------------------------------------+
 | ollama | Use (an external) Ollama server for running LLM inference |
docs/source/distributions/remote_hosted_distro/watsonx.md (new file, 88 lines)
@@ -0,0 +1,88 @@
---
orphan: true
---
<!-- This file was auto-generated by distro_codegen.py, please edit source -->
# watsonx Distribution

```{toctree}
:maxdepth: 2
:hidden:

self
```

The `llamastack/distribution-watsonx` distribution consists of the following provider configurations.

| API | Provider(s) |
|-----|-------------|
| agents | `inline::meta-reference` |
| datasetio | `remote::huggingface`, `inline::localfs` |
| eval | `inline::meta-reference` |
| inference | `remote::watsonx` |
| safety | `inline::llama-guard` |
| scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
| telemetry | `inline::meta-reference` |
| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::code-interpreter`, `inline::rag-runtime`, `remote::model-context-protocol` |
| vector_io | `inline::faiss` |

### Environment Variables

The following environment variables can be configured:

- `LLAMASTACK_PORT`: Port for the Llama Stack distribution server (default: `5001`)
- `WATSONX_API_KEY`: watsonx API Key (default: ``)
- `WATSONX_PROJECT_ID`: watsonx Project ID (default: ``)

### Models

The following models are available by default:

- `meta-llama/llama-3-3-70b-instruct (aliases: meta-llama/Llama-3.3-70B-Instruct)`
- `meta-llama/llama-2-13b-chat (aliases: meta-llama/Llama-2-13b)`
- `meta-llama/llama-3-1-70b-instruct (aliases: meta-llama/Llama-3.1-70B-Instruct)`
- `meta-llama/llama-3-1-8b-instruct (aliases: meta-llama/Llama-3.1-8B-Instruct)`
- `meta-llama/llama-3-2-11b-vision-instruct (aliases: meta-llama/Llama-3.2-11B-Vision-Instruct)`
- `meta-llama/llama-3-2-1b-instruct (aliases: meta-llama/Llama-3.2-1B-Instruct)`
- `meta-llama/llama-3-2-3b-instruct (aliases: meta-llama/Llama-3.2-3B-Instruct)`
- `meta-llama/llama-3-2-90b-vision-instruct (aliases: meta-llama/Llama-3.2-90B-Vision-Instruct)`
- `meta-llama/llama-guard-3-11b-vision (aliases: meta-llama/Llama-Guard-3-11B-Vision)`

### Prerequisite: API Keys

Make sure you have access to a watsonx API Key. You can get one by referring to [watsonx.ai](https://www.ibm.com/docs/en/masv-and-l/maximo-manage/continuous-delivery?topic=setup-create-watsonx-api-key).

## Running Llama Stack with watsonx

You can do this via Conda (build code), venv, or Docker, which has a pre-built image.

### Via Docker

This method allows you to get started quickly without having to build the distribution code.

```bash
LLAMA_STACK_PORT=5001
docker run \
  -it \
  -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
  -v ./run.yaml:/root/my-run.yaml \
  llamastack/distribution-watsonx \
  --yaml-config /root/my-run.yaml \
  --port $LLAMA_STACK_PORT \
  --env WATSONX_API_KEY=$WATSONX_API_KEY \
  --env WATSONX_PROJECT_ID=$WATSONX_PROJECT_ID \
  --env WATSONX_BASE_URL=$WATSONX_BASE_URL
```

### Via Conda

```bash
llama stack build --template watsonx --image-type conda
llama stack run ./run.yaml \
  --port $LLAMA_STACK_PORT \
  --env WATSONX_API_KEY=$WATSONX_API_KEY \
  --env WATSONX_PROJECT_ID=$WATSONX_PROJECT_ID
```
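Once the watsonx distribution above is running, a quick smoke test might look like the sketch below. The port and model id mirror the defaults listed in the new doc; the exact client call shape is an assumption based on the llama-stack-client Python SDK, not something this diff pins down.

```python
# Hedged sketch: talk to the watsonx distro started above (default port 5001).
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:5001")  # address is an assumption

response = client.inference.chat_completion(
    model_id="meta-llama/llama-3-3-70b-instruct",  # one of the default models listed above
    messages=[{"role": "user", "content": "Say hello from watsonx."}],
)
print(response.completion_message.content)
```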
@@ -81,6 +81,7 @@ LLAMA_STACK_PORT=8321
 docker run \
   -it \
   --pull always \
+  --gpu all \
   -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
   -v ~/.llama:/root/.llama \
   llamastack/distribution-meta-reference-gpu \
@@ -94,6 +95,7 @@ If you are using Llama Stack Safety / Shield APIs, use:
 docker run \
   -it \
   --pull always \
+  --gpu all \
   -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
   -v ~/.llama:/root/.llama \
   llamastack/distribution-meta-reference-gpu \
@@ -1,123 +0,0 @@ (file deleted)
---
orphan: true
---
<!-- This file was auto-generated by distro_codegen.py, please edit source -->
# Meta Reference Quantized Distribution

```{toctree}
:maxdepth: 2
:hidden:

self
```

The `llamastack/distribution-meta-reference-quantized-gpu` distribution consists of the following provider configurations:

| API | Provider(s) |
|-----|-------------|
| agents | `inline::meta-reference` |
| datasetio | `remote::huggingface`, `inline::localfs` |
| eval | `inline::meta-reference` |
| inference | `inline::meta-reference-quantized` |
| safety | `inline::llama-guard` |
| scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` |
| telemetry | `inline::meta-reference` |
| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::code-interpreter`, `inline::rag-runtime`, `remote::model-context-protocol` |
| vector_io | `inline::faiss`, `remote::chromadb`, `remote::pgvector` |

The only difference vs. the `meta-reference-gpu` distribution is that it has support for more efficient inference -- with fp8, int4 quantization, etc.

Note that you need access to nvidia GPUs to run this distribution. This distribution is not compatible with CPU-only machines or machines with AMD GPUs.

### Environment Variables

The following environment variables can be configured:

- `LLAMA_STACK_PORT`: Port for the Llama Stack distribution server (default: `8321`)
- `INFERENCE_MODEL`: Inference model loaded into the Meta Reference server (default: `meta-llama/Llama-3.2-3B-Instruct`)
- `INFERENCE_CHECKPOINT_DIR`: Directory containing the Meta Reference model checkpoint (default: `null`)

## Prerequisite: Downloading Models

Please use `llama model list --downloaded` to check that you have llama model checkpoints downloaded in `~/.llama` before proceeding. See [installation guide](https://llama-stack.readthedocs.io/en/latest/references/llama_cli_reference/download_models.html) here to download the models. Run `llama model list` to see the available models to download, and `llama model download` to download the checkpoints.

```
$ llama model list --downloaded
┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━┓
┃ Model                                   ┃ Size     ┃ Modified Time       ┃
┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━┩
│ Llama3.2-1B-Instruct:int4-qlora-eo8     │ 1.53 GB  │ 2025-02-26 11:22:28 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama3.2-1B                             │ 2.31 GB  │ 2025-02-18 21:48:52 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Prompt-Guard-86M                        │ 0.02 GB  │ 2025-02-26 11:29:28 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama3.2-3B-Instruct:int4-spinquant-eo8 │ 3.69 GB  │ 2025-02-26 11:37:41 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama3.2-3B                             │ 5.99 GB  │ 2025-02-18 21:51:26 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama3.1-8B                             │ 14.97 GB │ 2025-02-16 10:36:37 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama3.2-1B-Instruct:int4-spinquant-eo8 │ 1.51 GB  │ 2025-02-26 11:35:02 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama-Guard-3-1B                        │ 2.80 GB  │ 2025-02-26 11:20:46 │
├─────────────────────────────────────────┼──────────┼─────────────────────┤
│ Llama-Guard-3-1B:int4                   │ 0.43 GB  │ 2025-02-26 11:33:33 │
└─────────────────────────────────────────┴──────────┴─────────────────────┘
```

## Running the Distribution

You can do this via Conda (build code) or Docker which has a pre-built image.

### Via Docker

This method allows you to get started quickly without having to build the distribution code.

```bash
LLAMA_STACK_PORT=8321
docker run \
  -it \
  --pull always \
  -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
  -v ~/.llama:/root/.llama \
  llamastack/distribution-meta-reference-quantized-gpu \
  --port $LLAMA_STACK_PORT \
  --env INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct
```

If you are using Llama Stack Safety / Shield APIs, use:

```bash
docker run \
  -it \
  --pull always \
  -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
  -v ~/.llama:/root/.llama \
  llamastack/distribution-meta-reference-quantized-gpu \
  --port $LLAMA_STACK_PORT \
  --env INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct \
  --env SAFETY_MODEL=meta-llama/Llama-Guard-3-1B
```

### Via Conda

Make sure you have done `uv pip install llama-stack` and have the Llama Stack CLI available.

```bash
llama stack build --template meta-reference-quantized-gpu --image-type conda
llama stack run distributions/meta-reference-quantized-gpu/run.yaml \
  --port $LLAMA_STACK_PORT \
  --env INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct
```

If you are using Llama Stack Safety / Shield APIs, use:

```bash
llama stack run distributions/meta-reference-quantized-gpu/run-with-safety.yaml \
  --port $LLAMA_STACK_PORT \
  --env INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct \
  --env SAFETY_MODEL=meta-llama/Llama-Guard-3-1B
```
@@ -7,7 +7,7 @@ The `llamastack/distribution-nvidia` distribution consists of the following prov
 |-----|-------------|
 | agents | `inline::meta-reference` |
 | datasetio | `inline::localfs` |
-| eval | `inline::meta-reference` |
+| eval | `remote::nvidia` |
 | inference | `remote::nvidia` |
 | post_training | `remote::nvidia` |
 | safety | `remote::nvidia` |
@@ -22,13 +22,13 @@ The `llamastack/distribution-nvidia` distribution consists of the following prov
 The following environment variables can be configured:
 
 - `NVIDIA_API_KEY`: NVIDIA API Key (default: ``)
-- `NVIDIA_USER_ID`: NVIDIA User ID (default: `llama-stack-user`)
+- `NVIDIA_APPEND_API_VERSION`: Whether to append the API version to the base_url (default: `True`)
 - `NVIDIA_DATASET_NAMESPACE`: NVIDIA Dataset Namespace (default: `default`)
-- `NVIDIA_ACCESS_POLICIES`: NVIDIA Access Policies (default: `{}`)
 - `NVIDIA_PROJECT_ID`: NVIDIA Project ID (default: `test-project`)
 - `NVIDIA_CUSTOMIZER_URL`: NVIDIA Customizer URL (default: `https://customizer.api.nvidia.com`)
 - `NVIDIA_OUTPUT_MODEL_DIR`: NVIDIA Output Model Directory (default: `test-example-model@v1`)
 - `GUARDRAILS_SERVICE_URL`: URL for the NeMo Guardrails Service (default: `http://0.0.0.0:7331`)
+- `NVIDIA_EVALUATOR_URL`: URL for the NeMo Evaluator Service (default: `http://0.0.0.0:7331`)
 - `INFERENCE_MODEL`: Inference model (default: `Llama3.1-8B-Instruct`)
 - `SAFETY_MODEL`: Name of the model to use for safety (default: `meta/llama-3.1-8b-instruct`)
 
@@ -389,5 +389,7 @@
 "pygments_lexer": "ipython3",
 "version": "3.10.15"
 }
-}
+},
+"nbformat": 4,
+"nbformat_minor": 5
 }
@@ -256,5 +256,7 @@
 "pygments_lexer": "ipython3",
 "version": "3.10.15"
 }
-}
+},
+"nbformat": 4,
+"nbformat_minor": 5
 }
@@ -301,5 +301,7 @@
 "pygments_lexer": "ipython3",
 "version": "3.12.2"
 }
-}
+},
+"nbformat": 4,
+"nbformat_minor": 5
 }
@@ -200,5 +200,7 @@
 "pygments_lexer": "ipython3",
 "version": "3.12.2"
 }
-}
+},
+"nbformat": 4,
+"nbformat_minor": 5
 }
@@ -355,5 +355,7 @@
 "pygments_lexer": "ipython3",
 "version": "3.10.15"
 }
-}
+},
+"nbformat": 4,
+"nbformat_minor": 5
 }
@@ -398,5 +398,7 @@
 "pygments_lexer": "ipython3",
 "version": "3.10.15"
 }
-}
+},
+"nbformat": 4,
+"nbformat_minor": 5
 }
@@ -132,5 +132,7 @@
 "pygments_lexer": "ipython3",
 "version": "3.11.10"
 }
-}
+},
+"nbformat": 4,
+"nbformat_minor": 5
 }
@@ -188,5 +188,7 @@
 "pygments_lexer": "ipython3",
 "version": "3.10.15"
 }
-}
+},
+"nbformat": 4,
+"nbformat_minor": 5
 }
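The eight notebook hunks above all make the same fix: the final bare `}` becomes `},`, and the previously missing top-level `nbformat`/`nbformat_minor` keys are appended. A quick way to confirm a notebook now parses as valid nbformat 4 is sketched below; the file path is a placeholder, not a file from this commit.

```python
# Sketch: verify a notebook carries valid top-level metadata after the fix.
import nbformat

nb = nbformat.read("docs/notebooks/example.ipynb", as_version=4)  # placeholder path
nbformat.validate(nb)  # raises ValidationError if the top-level structure is malformed
print(nb.nbformat, nb.nbformat_minor)  # expect: 4 5
```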
@@ -136,12 +136,13 @@ def run_stack_build_command(args: argparse.Namespace) -> None:
         )
 
         image_type = prompt(
-            f"> Enter the image type you want your Llama Stack to be built as ({' or '.join(e.value for e in ImageType)}): ",
+            "> Enter the image type you want your Llama Stack to be built as (use <TAB> to see options): ",
+            completer=WordCompleter([e.value for e in ImageType]),
+            complete_while_typing=True,
             validator=Validator.from_callable(
                 lambda x: x in [e.value for e in ImageType],
-                error_message=f"Invalid image type, please enter {' or '.join(e.value for e in ImageType)}",
+                error_message="Invalid image type. Use <TAB> to see options",
             ),
-            default=ImageType.CONDA.value,
         )
 
         if image_type == ImageType.CONDA.value:
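The build-command change above swaps the enumerated prompt text for tab completion. A standalone sketch of the same prompt_toolkit pattern, with an illustrative choices list (the real CLI derives it from `ImageType`):

```python
# Sketch of the pattern above: tab completion plus validation with prompt_toolkit.
from prompt_toolkit import prompt
from prompt_toolkit.completion import WordCompleter
from prompt_toolkit.validation import Validator

choices = ["conda", "container", "venv"]  # illustrative values, not taken from the diff

image_type = prompt(
    "> Enter the image type (use <TAB> to see options): ",
    completer=WordCompleter(choices),          # <TAB> cycles through the choices
    complete_while_typing=True,
    validator=Validator.from_callable(
        lambda x: x in choices,
        error_message="Invalid image type. Use <TAB> to see options",
    ),
)
print(f"building a {image_type} image")
```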
@@ -166,14 +166,16 @@ async def maybe_await(value):
 
 
 async def sse_generator(event_gen_coroutine):
-    event_gen = await event_gen_coroutine
+    event_gen = None
     try:
+        event_gen = await event_gen_coroutine
         async for item in event_gen:
             yield create_sse_event(item)
             await asyncio.sleep(0.01)
     except asyncio.CancelledError:
         logger.info("Generator cancelled")
-        await event_gen.aclose()
+        if event_gen:
+            await event_gen.aclose()
     except Exception as e:
         logger.exception("Error in sse_generator")
         yield create_sse_event(
@@ -459,6 +461,7 @@ def main(args: Optional[argparse.Namespace] = None):
             raise ValueError(f"Could not find method {endpoint.name} on {impl}!!")
 
         impl_method = getattr(impl, endpoint.name)
+        logger.debug(f"{endpoint.method.upper()} {endpoint.route}")
 
         with warnings.catch_warnings():
             warnings.filterwarnings("ignore", category=UserWarning, module="pydantic._internal._fields")
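The `sse_generator` change above defers awaiting the coroutine into the `try` block and guards `aclose()`, so a cancellation that lands before the inner generator exists no longer blows up on a missing `event_gen`. A self-contained sketch of the same pattern (names here are illustrative, not the server's):

```python
# Self-contained sketch of the guarded-cleanup pattern used in sse_generator.
import asyncio


async def make_events():
    for i in range(3):
        yield f"event {i}"
        await asyncio.sleep(0.01)


async def safe_stream(event_gen_coroutine):
    event_gen = None
    try:
        event_gen = await event_gen_coroutine  # cancellation may hit before this completes
        async for item in event_gen:
            yield item
    except asyncio.CancelledError:
        if event_gen:  # only close a generator that was actually created
            await event_gen.aclose()


async def get_gen():
    await asyncio.sleep(0)  # stands in for async setup work
    return make_events()


async def main():
    async for item in safe_stream(get_gen()):
        print(item)


asyncio.run(main())
```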
@@ -4,14 +4,23 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
+import enum
+import json
 import uuid
 
 import streamlit as st
 from llama_stack_client import Agent
+from llama_stack_client.lib.agents.react.agent import ReActAgent
+from llama_stack_client.lib.agents.react.tool_parser import ReActOutput
 
 from llama_stack.distribution.ui.modules.api import llama_stack_api
 
 
+class AgentType(enum.Enum):
+    REGULAR = "Regular"
+    REACT = "ReAct"
+
+
 def tool_chat_page():
     st.title("🛠 Tools")
 
@@ -23,6 +32,7 @@ def tool_chat_page():
     tool_groups_list = [tool_group.identifier for tool_group in tool_groups]
     mcp_tools_list = [tool for tool in tool_groups_list if tool.startswith("mcp::")]
     builtin_tools_list = [tool for tool in tool_groups_list if not tool.startswith("mcp::")]
+    selected_vector_dbs = []
 
     def reset_agent():
         st.session_state.clear()
@@ -66,25 +76,36 @@ def tool_chat_page():
 
     toolgroup_selection.extend(mcp_selection)
 
-    active_tool_list = []
-    for toolgroup_id in toolgroup_selection:
-        active_tool_list.extend(
-            [
-                f"{''.join(toolgroup_id.split('::')[1:])}:{t.identifier}"
-                for t in client.tools.list(toolgroup_id=toolgroup_id)
-            ]
-        )
-
-    st.markdown(f"Active Tools: 🛠 {len(active_tool_list)}", help="List of currently active tools.")
-    st.json(active_tool_list)
+    grouped_tools = {}
+    total_tools = 0
+
+    for toolgroup_id in toolgroup_selection:
+        tools = client.tools.list(toolgroup_id=toolgroup_id)
+        grouped_tools[toolgroup_id] = [tool.identifier for tool in tools]
+        total_tools += len(tools)
+
+    st.markdown(f"Active Tools: 🛠 {total_tools}")
+
+    for group_id, tools in grouped_tools.items():
+        with st.expander(f"🔧 Tools from `{group_id}`"):
+            for idx, tool in enumerate(tools, start=1):
+                st.markdown(f"{idx}. `{tool.split(':')[-1]}`")
 
     st.subheader("Agent Configurations")
+    st.subheader("Agent Type")
+    agent_type = st.radio(
+        "Select Agent Type",
+        [AgentType.REGULAR, AgentType.REACT],
+        format_func=lambda x: x.value,
+        on_change=reset_agent,
+    )
+
     max_tokens = st.slider(
         "Max Tokens",
         min_value=0,
         max_value=4096,
         value=512,
-        step=1,
+        step=64,
         help="The maximum number of tokens to generate",
         on_change=reset_agent,
     )
@@ -101,13 +122,27 @@ def tool_chat_page():
 
     @st.cache_resource
     def create_agent():
-        return Agent(
-            client,
-            model=model,
-            instructions="You are a helpful assistant. When you use a tool always respond with a summary of the result.",
-            tools=toolgroup_selection,
-            sampling_params={"strategy": {"type": "greedy"}, "max_tokens": max_tokens},
-        )
+        if "agent_type" in st.session_state and st.session_state.agent_type == AgentType.REACT:
+            return ReActAgent(
+                client=client,
+                model=model,
+                tools=toolgroup_selection,
+                response_format={
+                    "type": "json_schema",
+                    "json_schema": ReActOutput.model_json_schema(),
+                },
+                sampling_params={"strategy": {"type": "greedy"}, "max_tokens": max_tokens},
+            )
+        else:
+            return Agent(
+                client,
+                model=model,
+                instructions="You are a helpful assistant. When you use a tool always respond with a summary of the result.",
+                tools=toolgroup_selection,
+                sampling_params={"strategy": {"type": "greedy"}, "max_tokens": max_tokens},
+            )
+
+    st.session_state.agent_type = agent_type
 
     agent = create_agent()
 
@@ -136,6 +171,158 @@ def tool_chat_page():
     )
 
     def response_generator(turn_response):
+        if st.session_state.get("agent_type") == AgentType.REACT:
+            return _handle_react_response(turn_response)
+        else:
+            return _handle_regular_response(turn_response)
+
+    def _handle_react_response(turn_response):
+        current_step_content = ""
+        final_answer = None
+        tool_results = []
+
+        for response in turn_response:
+            if not hasattr(response.event, "payload"):
+                yield (
+                    "\n\n🚨 :red[_Llama Stack server Error:_]\n"
+                    "The response received is missing an expected `payload` attribute.\n"
+                    "This could indicate a malformed response or an internal issue within the server.\n\n"
+                    f"Error details: {response}"
+                )
+                return
+
+            payload = response.event.payload
+
+            if payload.event_type == "step_progress" and hasattr(payload.delta, "text"):
+                current_step_content += payload.delta.text
+                continue
+
+            if payload.event_type == "step_complete":
+                step_details = payload.step_details
+
+                if step_details.step_type == "inference":
+                    yield from _process_inference_step(current_step_content, tool_results, final_answer)
+                    current_step_content = ""
+                elif step_details.step_type == "tool_execution":
+                    tool_results = _process_tool_execution(step_details, tool_results)
+                    current_step_content = ""
+                else:
+                    current_step_content = ""
+
+        if not final_answer and tool_results:
+            yield from _format_tool_results_summary(tool_results)
+
+    def _process_inference_step(current_step_content, tool_results, final_answer):
+        try:
+            react_output_data = json.loads(current_step_content)
+            thought = react_output_data.get("thought")
+            action = react_output_data.get("action")
+            answer = react_output_data.get("answer")
+
+            if answer and answer != "null" and answer is not None:
+                final_answer = answer
+
+            if thought:
+                with st.expander("🤔 Thinking...", expanded=False):
+                    st.markdown(f":grey[__{thought}__]")
+
+            if action and isinstance(action, dict):
+                tool_name = action.get("tool_name")
+                tool_params = action.get("tool_params")
+                with st.expander(f'🛠 Action: Using tool "{tool_name}"', expanded=False):
+                    st.json(tool_params)
+
+            if answer and answer != "null" and answer is not None:
+                yield f"\n\n✅ **Final Answer:**\n{answer}"
+
+        except json.JSONDecodeError:
+            yield f"\n\nFailed to parse ReAct step content:\n```json\n{current_step_content}\n```"
+        except Exception as e:
+            yield f"\n\nFailed to process ReAct step: {e}\n```json\n{current_step_content}\n```"
+
+        return final_answer
+
+    def _process_tool_execution(step_details, tool_results):
+        try:
+            if hasattr(step_details, "tool_responses") and step_details.tool_responses:
+                for tool_response in step_details.tool_responses:
+                    tool_name = tool_response.tool_name
+                    content = tool_response.content
+                    tool_results.append((tool_name, content))
+                    with st.expander(f'⚙️ Observation (Result from "{tool_name}")', expanded=False):
+                        try:
+                            parsed_content = json.loads(content)
+                            st.json(parsed_content)
+                        except json.JSONDecodeError:
+                            st.code(content, language=None)
+            else:
+                with st.expander("⚙️ Observation", expanded=False):
+                    st.markdown(":grey[_Tool execution step completed, but no response data found._]")
+        except Exception as e:
+            with st.expander("⚙️ Error in Tool Execution", expanded=False):
+                st.markdown(f":red[_Error processing tool execution: {str(e)}_]")
+
+        return tool_results
+
+    def _format_tool_results_summary(tool_results):
+        yield "\n\n**Here's what I found:**\n"
+        for tool_name, content in tool_results:
+            try:
+                parsed_content = json.loads(content)
+
+                if tool_name == "web_search" and "top_k" in parsed_content:
+                    yield from _format_web_search_results(parsed_content)
+                elif "results" in parsed_content and isinstance(parsed_content["results"], list):
+                    yield from _format_results_list(parsed_content["results"])
+                elif isinstance(parsed_content, dict) and len(parsed_content) > 0:
+                    yield from _format_dict_results(parsed_content)
+                elif isinstance(parsed_content, list) and len(parsed_content) > 0:
+                    yield from _format_list_results(parsed_content)
+            except json.JSONDecodeError:
+                yield f"\n**{tool_name}** was used but returned complex data. Check the observation for details.\n"
+            except (TypeError, AttributeError, KeyError, IndexError) as e:
+                print(f"Error processing {tool_name} result: {type(e).__name__}: {e}")
+
+    def _format_web_search_results(parsed_content):
+        for i, result in enumerate(parsed_content["top_k"], 1):
+            if i <= 3:
+                title = result.get("title", "Untitled")
+                url = result.get("url", "")
+                content_text = result.get("content", "").strip()
+                yield f"\n- **{title}**\n {content_text}\n [Source]({url})\n"
+
+    def _format_results_list(results):
+        for i, result in enumerate(results, 1):
+            if i <= 3:
+                if isinstance(result, dict):
+                    name = result.get("name", result.get("title", "Result " + str(i)))
+                    description = result.get("description", result.get("content", result.get("summary", "")))
+                    yield f"\n- **{name}**\n {description}\n"
+                else:
+                    yield f"\n- {result}\n"
+
+    def _format_dict_results(parsed_content):
+        yield "\n```\n"
+        for key, value in list(parsed_content.items())[:5]:
+            if isinstance(value, str) and len(value) < 100:
+                yield f"{key}: {value}\n"
+            else:
+                yield f"{key}: [Complex data]\n"
+        yield "```\n"
+
+    def _format_list_results(parsed_content):
+        yield "\n"
+        for _, item in enumerate(parsed_content[:3], 1):
+            if isinstance(item, str):
+                yield f"- {item}\n"
+            elif isinstance(item, dict) and "text" in item:
+                yield f"- {item['text']}\n"
+            elif isinstance(item, dict) and len(item) > 0:
+                first_value = next(iter(item.values()))
+                if isinstance(first_value, str) and len(first_value) < 100:
+                    yield f"- {first_value}\n"
+
+    def _handle_regular_response(turn_response):
         for response in turn_response:
             if hasattr(response.event, "payload"):
                 print(response.event.payload)
@@ -153,9 +340,9 @@ def tool_chat_page():
                     yield f"Error occurred in the Llama Stack Cluster: {response}"
 
         with st.chat_message("assistant"):
-            response = st.write_stream(response_generator(turn_response))
+            response_content = st.write_stream(response_generator(turn_response))
 
-        st.session_state.messages.append({"role": "assistant", "content": response})
+        st.session_state.messages.append({"role": "assistant", "content": response_content})
 
 
 tool_chat_page()
@@ -0,0 +1,144 @@ (new file)
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# top-level folder for each specific model found within the models/ directory at
# the top-level of this source tree.

import textwrap
from typing import List, Optional

from llama_stack.apis.inference import ToolDefinition, ToolParamDefinition
from llama_stack.models.llama.llama3.prompt_templates.base import (
    PromptTemplate,
    PromptTemplateGeneratorBase,
)


class PythonListCustomToolGenerator(PromptTemplateGeneratorBase):  # noqa: N801
    DEFAULT_PROMPT = textwrap.dedent(
        """
        You are a helpful assistant and an expert in function composition. You can answer general questions using your internal knowledge OR invoke functions when necessary. Follow these strict guidelines:

        1. FUNCTION CALLS:
        - ONLY use functions that are EXPLICITLY listed in the function list below
        - If NO functions are listed (empty function list []), respond ONLY with internal knowledge or "I don't have access to [Unavailable service] information"
        - If a function is not in the list, respond ONLY with internal knowledge or "I don't have access to [Unavailable service] information"
        - If ALL required parameters are present AND the query EXACTLY matches a listed function's purpose: output ONLY the function call(s)
        - Use exact format: [func_name1(param1=value1, param2=value2), func_name2(...)]
        Examples:
        CORRECT: [get_weather(location="Vancouver"), calculate_route(start="Boston", end="New York")] <- Only if get_weather and calculate_route are in function list
        INCORRECT: get_weather(location="New York")
        INCORRECT: Let me check the weather: [get_weather(location="New York")]
        INCORRECT: [get_events(location="Singapore")] <- If function not in list

        2. RESPONSE RULES:
        - For pure function requests matching a listed function: ONLY output the function call(s)
        - For knowledge questions: ONLY output text
        - For missing parameters: ONLY request the specific missing parameters
        - For unavailable services (not in function list): output ONLY with internal knowledge or "I don't have access to [Unavailable service] information". Do NOT execute a function call.
        - If the query asks for information beyond what a listed function provides: output ONLY with internal knowledge about your limitations
        - NEVER combine text and function calls in the same response
        - NEVER suggest alternative functions when the requested service is unavailable
        - NEVER create or invent new functions not listed below

        3. STRICT BOUNDARIES:
        - ONLY use functions from the list below - no exceptions
        - NEVER use a function as an alternative to unavailable information
        - NEVER call functions not present in the function list
        - NEVER add explanatory text to function calls
        - NEVER respond with empty brackets
        - Use proper Python/JSON syntax for function calls
        - Check the function list carefully before responding

        4. TOOL RESPONSE HANDLING:
        - When receiving tool responses: provide concise, natural language responses
        - Don't repeat tool response verbatim
        - Don't add supplementary information


        {{ function_description }}
        """.strip("\n")
    )

    def gen(self, custom_tools: List[ToolDefinition], system_prompt: Optional[str] = None) -> PromptTemplate:
        system_prompt = system_prompt or self.DEFAULT_PROMPT
        return PromptTemplate(
            system_prompt,
            {"function_description": self._gen_function_description(custom_tools)},
        )

    def _gen_function_description(self, custom_tools: List[ToolDefinition]) -> PromptTemplate:
        template_str = textwrap.dedent(
            """
            Here is a list of functions in JSON format that you can invoke.

            [
            {% for t in tools -%}
            {# manually setting up JSON because jinja sorts keys in unexpected ways -#}
            {%- set tname = t.tool_name -%}
            {%- set tdesc = t.description -%}
            {%- set tparams = t.parameters -%}
            {%- set required_params = [] -%}
            {%- for name, param in tparams.items() if param.required == true -%}
                {%- set _ = required_params.append(name) -%}
            {%- endfor -%}
                {
                    "name": "{{tname}}",
                    "description": "{{tdesc}}",
                    "parameters": {
                        "type": "dict",
                        "required": {{ required_params | tojson }},
                        "properties": {
                            {%- for name, param in tparams.items() %}
                            "{{name}}": {
                                "type": "{{param.param_type}}",
                                "description": "{{param.description}}"{% if param.default %},
                                "default": "{{param.default}}"{% endif %}
                            }{% if not loop.last %},{% endif %}
                            {%- endfor %}
                        }
                    }
                }{% if not loop.last %},
                {% endif -%}
            {%- endfor %}
            ]

            You can answer general questions or invoke tools when necessary.
            In addition to tool calls, you should also augment your responses by using the tool outputs.

            """
        )
        return PromptTemplate(
            template_str.strip("\n"),
            {"tools": [t.model_dump() for t in custom_tools]},
        ).render()

    def data_examples(self) -> List[List[ToolDefinition]]:
        return [
            [
                ToolDefinition(
                    tool_name="get_weather",
                    description="Get weather info for places",
                    parameters={
                        "city": ToolParamDefinition(
                            param_type="string",
                            description="The name of the city to get the weather for",
                            required=True,
                        ),
                        "metric": ToolParamDefinition(
                            param_type="string",
                            description="The metric for weather. Options are: celsius, fahrenheit",
                            required=False,
                            default="celsius",
                        ),
                    },
                ),
            ]
        ]
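A sketch of rendering this generator against its own example tools is shown below. It assumes the snippet runs in the same module as the class above (the new file's path is not shown in this extract) and that `PromptTemplate.render()` behaves as it does for the existing Llama 3 prompt templates this class builds on.

```python
# Sketch: render the system prompt for the example tools defined in data_examples().
generator = PythonListCustomToolGenerator()
example_tools = generator.data_examples()[0]        # the get_weather ToolDefinition above
template = generator.gen(example_tools)             # PromptTemplate with function_description filled in
print(template.render())                            # assumption: render() expands the Jinja template
```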
@@ -6,7 +6,7 @@
 
 from typing import List
 
-from llama_stack.providers.datatypes import Api, InlineProviderSpec, ProviderSpec
+from llama_stack.providers.datatypes import AdapterSpec, Api, InlineProviderSpec, ProviderSpec, remote_provider_spec
 
 
 def available_providers() -> List[ProviderSpec]:
@@ -25,4 +25,22 @@ def available_providers() -> List[ProviderSpec]:
                 Api.agents,
             ],
         ),
+        remote_provider_spec(
+            api=Api.eval,
+            adapter=AdapterSpec(
+                adapter_type="nvidia",
+                pip_packages=[
+                    "requests",
+                ],
+                module="llama_stack.providers.remote.eval.nvidia",
+                config_class="llama_stack.providers.remote.eval.nvidia.NVIDIAEvalConfig",
+            ),
+            api_dependencies=[
+                Api.datasetio,
+                Api.datasets,
+                Api.scoring,
+                Api.inference,
+                Api.agents,
+            ],
+        ),
     ]
@@ -288,4 +288,14 @@ def available_providers() -> List[ProviderSpec]:
                 provider_data_validator="llama_stack.providers.remote.inference.passthrough.PassthroughProviderDataValidator",
             ),
         ),
+        remote_provider_spec(
+            api=Api.inference,
+            adapter=AdapterSpec(
+                adapter_type="watsonx",
+                pip_packages=["ibm_watson_machine_learning"],
+                module="llama_stack.providers.remote.inference.watsonx",
+                config_class="llama_stack.providers.remote.inference.watsonx.WatsonXConfig",
+                provider_data_validator="llama_stack.providers.remote.inference.watsonx.WatsonXProviderDataValidator",
+            ),
+        ),
     ]
llama_stack/providers/remote/eval/__init__.py (new file, 5 lines)
@@ -0,0 +1,5 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
llama_stack/providers/remote/eval/nvidia/README.md (new file, 134 lines)
@@ -0,0 +1,134 @@
# NVIDIA NeMo Evaluator Eval Provider


## Overview

For the first integration, Benchmarks are mapped to Evaluation Configs in the NeMo Evaluator. The full evaluation config object is provided as part of the metadata. The `dataset_id` and `scoring_functions` are not used.

Below are a few examples of how to register a benchmark, which in turn creates an evaluation config in NeMo Evaluator, and how to trigger an evaluation.

### Example of registering an academic benchmark

```
POST /eval/benchmarks
```
```json
{
  "benchmark_id": "mmlu",
  "dataset_id": "",
  "scoring_functions": [],
  "metadata": {
    "type": "mmlu"
  }
}
```

### Example of registering a custom evaluation

```
POST /eval/benchmarks
```
```json
{
  "benchmark_id": "my-custom-benchmark",
  "dataset_id": "",
  "scoring_functions": [],
  "metadata": {
    "type": "custom",
    "params": {
      "parallelism": 8
    },
    "tasks": {
      "qa": {
        "type": "completion",
        "params": {
          "template": {
            "prompt": "{{prompt}}",
            "max_tokens": 200
          }
        },
        "dataset": {
          "files_url": "hf://datasets/default/sample-basic-test/testing/testing.jsonl"
        },
        "metrics": {
          "bleu": {
            "type": "bleu",
            "params": {
              "references": [
                "{{ideal_response}}"
              ]
            }
          }
        }
      }
    }
  }
}
```

### Example of triggering a benchmark/custom evaluation

```
POST /eval/benchmarks/{benchmark_id}/jobs
```
```json
{
  "benchmark_id": "my-custom-benchmark",
  "benchmark_config": {
    "eval_candidate": {
      "type": "model",
      "model": "meta-llama/Llama3.1-8B-Instruct",
      "sampling_params": {
        "max_tokens": 100,
        "temperature": 0.7
      }
    },
    "scoring_params": {}
  }
}
```

Response example:
```json
{
  "job_id": "eval-1234",
  "status": "in_progress"
}
```

### Example of getting the status of a job
```
GET /eval/benchmarks/{benchmark_id}/jobs/{job_id}
```

Response example:
```json
{
  "job_id": "eval-1234",
  "status": "in_progress"
}
```

### Example of cancelling a job
```
POST /eval/benchmarks/{benchmark_id}/jobs/{job_id}/cancel
```

### Example of getting the results
```
GET /eval/benchmarks/{benchmark_id}/results
```
```json
{
  "generations": [],
  "scores": {
    "{benchmark_id}": {
      "score_rows": [],
      "aggregated_results": {
        "tasks": {},
        "groups": {}
      }
    }
  }
}
```
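Putting the endpoints documented above together, a minimal client-side sketch might look like the following. The routes and payloads are the ones the README lists verbatim; the server base URL is an assumption and may need an API-version prefix depending on how the Llama Stack server exposes them.

```python
# Sketch: register a benchmark, start a job, and poll it via the Eval API routes above.
import time

import requests

BASE = "http://localhost:8321"  # assumption: wherever the Llama Stack server is listening

requests.post(
    f"{BASE}/eval/benchmarks",
    json={"benchmark_id": "mmlu", "dataset_id": "", "scoring_functions": [], "metadata": {"type": "mmlu"}},
).raise_for_status()

job = requests.post(
    f"{BASE}/eval/benchmarks/mmlu/jobs",
    json={
        "benchmark_id": "mmlu",
        "benchmark_config": {
            "eval_candidate": {
                "type": "model",
                "model": "meta-llama/Llama3.1-8B-Instruct",
                "sampling_params": {"max_tokens": 100, "temperature": 0.7},
            },
            "scoring_params": {},
        },
    },
).json()

# Poll until the job leaves the in_progress state, then fetch aggregated results.
while True:
    status = requests.get(f"{BASE}/eval/benchmarks/mmlu/jobs/{job['job_id']}").json()
    if status["status"] != "in_progress":
        break
    time.sleep(10)

print(requests.get(f"{BASE}/eval/benchmarks/mmlu/results").json())
```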
llama_stack/providers/remote/eval/nvidia/__init__.py (new file, 31 lines)
@@ -0,0 +1,31 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
from typing import Any, Dict

from llama_stack.distribution.datatypes import Api

from .config import NVIDIAEvalConfig


async def get_adapter_impl(
    config: NVIDIAEvalConfig,
    deps: Dict[Api, Any],
):
    from .eval import NVIDIAEvalImpl

    impl = NVIDIAEvalImpl(
        config,
        deps[Api.datasetio],
        deps[Api.datasets],
        deps[Api.scoring],
        deps[Api.inference],
        deps[Api.agents],
    )
    await impl.initialize()
    return impl


__all__ = ["get_adapter_impl", "NVIDIAEvalImpl"]
llama_stack/providers/remote/eval/nvidia/config.py (new file, 29 lines)
@@ -0,0 +1,29 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import os
from typing import Any, Dict

from pydantic import BaseModel, Field


class NVIDIAEvalConfig(BaseModel):
    """
    Configuration for the NVIDIA NeMo Evaluator microservice endpoint.

    Attributes:
        evaluator_url (str): A base url for accessing the NVIDIA evaluation endpoint, e.g. http://localhost:8000.
    """

    evaluator_url: str = Field(
        default_factory=lambda: os.getenv("NVIDIA_EVALUATOR_URL", "http://0.0.0.0:7331"),
        description="The url for accessing the evaluator service",
    )

    @classmethod
    def sample_run_config(cls, **kwargs) -> Dict[str, Any]:
        return {
            "evaluator_url": "${env.NVIDIA_EVALUATOR_URL:http://localhost:7331}",
        }
154
llama_stack/providers/remote/eval/nvidia/eval.py
Normal file
154
llama_stack/providers/remote/eval/nvidia/eval.py
Normal file
|
@ -0,0 +1,154 @@
|
||||||
|
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
||||||
|
# All rights reserved.
|
||||||
|
#
|
||||||
|
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import Any, Dict, List

import requests

from llama_stack.apis.agents import Agents
from llama_stack.apis.benchmarks import Benchmark
from llama_stack.apis.datasetio import DatasetIO
from llama_stack.apis.datasets import Datasets
from llama_stack.apis.inference import Inference
from llama_stack.apis.scoring import Scoring, ScoringResult
from llama_stack.providers.datatypes import BenchmarksProtocolPrivate
from llama_stack.providers.remote.inference.nvidia.models import MODEL_ENTRIES
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper

from .....apis.common.job_types import Job, JobStatus
from .....apis.eval.eval import BenchmarkConfig, Eval, EvaluateResponse
from .config import NVIDIAEvalConfig

DEFAULT_NAMESPACE = "nvidia"


class NVIDIAEvalImpl(
    Eval,
    BenchmarksProtocolPrivate,
    ModelRegistryHelper,
):
    def __init__(
        self,
        config: NVIDIAEvalConfig,
        datasetio_api: DatasetIO,
        datasets_api: Datasets,
        scoring_api: Scoring,
        inference_api: Inference,
        agents_api: Agents,
    ) -> None:
        self.config = config
        self.datasetio_api = datasetio_api
        self.datasets_api = datasets_api
        self.scoring_api = scoring_api
        self.inference_api = inference_api
        self.agents_api = agents_api

        ModelRegistryHelper.__init__(self, model_entries=MODEL_ENTRIES)

    async def initialize(self) -> None: ...

    async def shutdown(self) -> None: ...

    async def _evaluator_get(self, path):
        """Helper for making GET requests to the evaluator service."""
        response = requests.get(url=f"{self.config.evaluator_url}{path}")
        response.raise_for_status()
        return response.json()

    async def _evaluator_post(self, path, data):
        """Helper for making POST requests to the evaluator service."""
        response = requests.post(url=f"{self.config.evaluator_url}{path}", json=data)
        response.raise_for_status()
        return response.json()

    async def register_benchmark(self, task_def: Benchmark) -> None:
        """Register a benchmark as an evaluation configuration."""
        await self._evaluator_post(
            "/v1/evaluation/configs",
            {
                "namespace": DEFAULT_NAMESPACE,
                "name": task_def.benchmark_id,
                # metadata is copied to request body as-is
                **task_def.metadata,
            },
        )

    async def run_eval(
        self,
        benchmark_id: str,
        benchmark_config: BenchmarkConfig,
    ) -> Job:
        """Run an evaluation job for a benchmark."""
        model = (
            benchmark_config.eval_candidate.model
            if benchmark_config.eval_candidate.type == "model"
            else benchmark_config.eval_candidate.config.model
        )
        nvidia_model = self.get_provider_model_id(model) or model

        result = await self._evaluator_post(
            "/v1/evaluation/jobs",
            {
                "config": f"{DEFAULT_NAMESPACE}/{benchmark_id}",
                "target": {"type": "model", "model": nvidia_model},
            },
        )

        return Job(job_id=result["id"], status=JobStatus.in_progress)

    async def evaluate_rows(
        self,
        benchmark_id: str,
        input_rows: List[Dict[str, Any]],
        scoring_functions: List[str],
        benchmark_config: BenchmarkConfig,
    ) -> EvaluateResponse:
        raise NotImplementedError()

    async def job_status(self, benchmark_id: str, job_id: str) -> Job:
        """Get the status of an evaluation job.

        EvaluatorStatus: "created", "pending", "running", "cancelled", "cancelling", "failed", "completed".
        JobStatus: "scheduled", "in_progress", "completed", "cancelled", "failed"
        """
        result = await self._evaluator_get(f"/v1/evaluation/jobs/{job_id}")
        result_status = result["status"]

        job_status = JobStatus.failed
        if result_status in ["created", "pending"]:
            job_status = JobStatus.scheduled
        elif result_status in ["running"]:
            job_status = JobStatus.in_progress
        elif result_status in ["completed"]:
            job_status = JobStatus.completed
        elif result_status in ["cancelled"]:
            job_status = JobStatus.cancelled

        return Job(job_id=job_id, status=job_status)

    async def job_cancel(self, benchmark_id: str, job_id: str) -> None:
        """Cancel the evaluation job."""
        await self._evaluator_post(f"/v1/evaluation/jobs/{job_id}/cancel", {})

    async def job_result(self, benchmark_id: str, job_id: str) -> EvaluateResponse:
        """Returns the results of the evaluation job."""
        job = await self.job_status(benchmark_id, job_id)
        status = job.status
        if not status or status != JobStatus.completed:
            raise ValueError(f"Job {job_id} not completed. Status: {status.value}")

        result = await self._evaluator_get(f"/v1/evaluation/jobs/{job_id}/results")

        return EvaluateResponse(
            # TODO: these are stored in detailed results on NeMo Evaluator side; can be added
            generations=[],
            scores={
                benchmark_id: ScoringResult(
                    score_rows=[],
                    aggregated_results=result,
                )
            },
        )
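A rough end-to-end sketch of the flow this provider enables, assuming the standard llama-stack client surface (`client.benchmarks.register`, `client.eval.run_eval`) and a NeMo Evaluator service reachable via `NVIDIA_EVALUATOR_URL`; the benchmark id, dataset id, metadata and model name below are placeholders rather than values from this change:

```python
# Illustrative only -- ids, metadata and the model are placeholders.
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:8321")

# register_benchmark() above forwards this to POST {evaluator_url}/v1/evaluation/configs
client.benchmarks.register(
    benchmark_id="my-benchmark",
    dataset_id="my-dataset",
    scoring_functions=[],
    metadata={"tasks": {}},  # copied into the NeMo Evaluator config as-is
)

# run_eval() above posts to /v1/evaluation/jobs and returns a Job handle
job = client.eval.run_eval(
    benchmark_id="my-benchmark",
    benchmark_config={
        "eval_candidate": {
            "type": "model",
            "model": "meta/llama-3.1-8b-instruct",
            "sampling_params": {},
        }
    },
)
print(job.job_id, job.status)
```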
@@ -47,10 +47,15 @@ class NVIDIAConfig(BaseModel):
         default=60,
         description="Timeout for the HTTP requests",
     )
+    append_api_version: bool = Field(
+        default_factory=lambda: os.getenv("NVIDIA_APPEND_API_VERSION", "True").lower() != "false",
+        description="When set to false, the API version will not be appended to the base_url. By default, it is true.",
+    )
 
     @classmethod
     def sample_run_config(cls, **kwargs) -> Dict[str, Any]:
         return {
             "url": "${env.NVIDIA_BASE_URL:https://integrate.api.nvidia.com}",
             "api_key": "${env.NVIDIA_API_KEY:}",
+            "append_api_version": "${env.NVIDIA_APPEND_API_VERSION:True}",
         }
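The effect of the new flag, as a small standalone sketch (the helper name is ours, not part of the provider; it mirrors the `base_url` change further down in this diff):

```python
# Mirrors the adapter logic: only append /v1 when append_api_version is true.
def resolve_base_url(url: str, append_api_version: bool) -> str:
    return f"{url}/v1" if append_api_version else url


assert resolve_base_url("https://integrate.api.nvidia.com", True) == "https://integrate.api.nvidia.com/v1"
assert resolve_base_url("http://localhost:8000", False) == "http://localhost:8000"
```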
@@ -33,7 +33,6 @@ from llama_stack.apis.inference import (
     TextTruncation,
     ToolChoice,
     ToolConfig,
-    ToolDefinition,
 )
 from llama_stack.apis.inference.inference import (
     OpenAIChatCompletion,
@@ -42,7 +41,11 @@ from llama_stack.apis.inference.inference import (
     OpenAIMessageParam,
     OpenAIResponseFormatParam,
 )
-from llama_stack.models.llama.datatypes import ToolPromptFormat
+from llama_stack.apis.models import Model, ModelType
+from llama_stack.models.llama.datatypes import ToolDefinition, ToolPromptFormat
+from llama_stack.providers.utils.inference import (
+    ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR,
+)
 from llama_stack.providers.utils.inference.model_registry import (
     ModelRegistryHelper,
 )
@@ -120,10 +123,10 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
             "meta/llama-3.2-90b-vision-instruct": "https://ai.api.nvidia.com/v1/gr/meta/llama-3.2-90b-vision-instruct",
         }
 
-        base_url = f"{self._config.url}/v1"
+        base_url = f"{self._config.url}/v1" if self._config.append_api_version else self._config.url
 
         if _is_nvidia_hosted(self._config) and provider_model_id in special_model_urls:
             base_url = special_model_urls[provider_model_id]
         return _get_client_for_base_url(base_url)
 
     async def _get_provider_model_id(self, model_id: str) -> str:
@@ -387,3 +390,44 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper):
             return await self._get_client(provider_model_id).chat.completions.create(**params)
         except APIConnectionError as e:
             raise ConnectionError(f"Failed to connect to NVIDIA NIM at {self._config.url}: {e}") from e
+
+    async def register_model(self, model: Model) -> Model:
+        """
+        Allow non-llama model registration.
+
+        Non-llama model registration: API Catalogue models, post-training models, etc.
+            client = LlamaStackAsLibraryClient("nvidia")
+            client.models.register(
+                model_id="mistralai/mixtral-8x7b-instruct-v0.1",
+                model_type=ModelType.llm,
+                provider_id="nvidia",
+                provider_model_id="mistralai/mixtral-8x7b-instruct-v0.1"
+            )
+
+        NOTE: Only supports models endpoints compatible with AsyncOpenAI base_url format.
+        """
+        if model.model_type == ModelType.embedding:
+            # embedding models are always registered by their provider model id and does not need to be mapped to a llama model
+            provider_resource_id = model.provider_resource_id
+        else:
+            provider_resource_id = self.get_provider_model_id(model.provider_resource_id)
+
+        if provider_resource_id:
+            model.provider_resource_id = provider_resource_id
+        else:
+            llama_model = model.metadata.get("llama_model")
+            existing_llama_model = self.get_llama_model(model.provider_resource_id)
+            if existing_llama_model:
+                if existing_llama_model != llama_model:
+                    raise ValueError(
+                        f"Provider model id '{model.provider_resource_id}' is already registered to a different llama model: '{existing_llama_model}'"
+                    )
+            else:
+                # not llama model
+                if llama_model in ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR:
+                    self.provider_id_to_llama_model_map[model.provider_resource_id] = (
+                        ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR[llama_model]
+                    )
+                else:
+                    self.alias_to_provider_id_map[model.provider_model_id] = model.provider_model_id
+        return model
22
llama_stack/providers/remote/inference/watsonx/__init__.py
Normal file
@@ -0,0 +1,22 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from llama_stack.apis.inference import Inference

from .config import WatsonXConfig


async def get_adapter_impl(config: WatsonXConfig, _deps) -> Inference:
    # import dynamically so `llama stack build` does not fail due to missing dependencies
    from .watsonx import WatsonXInferenceAdapter

    if not isinstance(config, WatsonXConfig):
        raise RuntimeError(f"Unexpected config type: {type(config)}")
    adapter = WatsonXInferenceAdapter(config)
    return adapter


__all__ = ["get_adapter_impl", "WatsonXConfig"]
46
llama_stack/providers/remote/inference/watsonx/config.py
Normal file
@@ -0,0 +1,46 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import os
from typing import Any, Dict, Optional

from pydantic import BaseModel, Field, SecretStr

from llama_stack.schema_utils import json_schema_type


class WatsonXProviderDataValidator(BaseModel):
    url: str
    api_key: str
    project_id: str


@json_schema_type
class WatsonXConfig(BaseModel):
    url: str = Field(
        default_factory=lambda: os.getenv("WATSONX_BASE_URL", "https://us-south.ml.cloud.ibm.com"),
        description="A base url for accessing the watsonx.ai",
    )
    api_key: Optional[SecretStr] = Field(
        default_factory=lambda: os.getenv("WATSONX_API_KEY"),
        description="The watsonx API key, only needed of using the hosted service",
    )
    project_id: Optional[str] = Field(
        default_factory=lambda: os.getenv("WATSONX_PROJECT_ID"),
        description="The Project ID key, only needed of using the hosted service",
    )
    timeout: int = Field(
        default=60,
        description="Timeout for the HTTP requests",
    )

    @classmethod
    def sample_run_config(cls, **kwargs) -> Dict[str, Any]:
        return {
            "url": "${env.WATSONX_BASE_URL:https://us-south.ml.cloud.ibm.com}",
            "api_key": "${env.WATSONX_API_KEY:}",
            "project_id": "${env.WATSONX_PROJECT_ID:}",
        }
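How the config above resolves its credentials at instantiation time, sketched with placeholder values (the env vars must be set before `WatsonXConfig()` is constructed, since the defaults are read via `default_factory`):

```python
import os

# Placeholder values for illustration only.
os.environ["WATSONX_BASE_URL"] = "https://us-south.ml.cloud.ibm.com"
os.environ["WATSONX_API_KEY"] = "my-watsonx-key"
os.environ["WATSONX_PROJECT_ID"] = "my-project-id"

from llama_stack.providers.remote.inference.watsonx.config import WatsonXConfig

config = WatsonXConfig()
assert config.url == "https://us-south.ml.cloud.ibm.com"
assert config.api_key.get_secret_value() == "my-watsonx-key"
assert config.project_id == "my-project-id"
```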
47
llama_stack/providers/remote/inference/watsonx/models.py
Normal file
@@ -0,0 +1,47 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from llama_stack.models.llama.sku_types import CoreModelId
from llama_stack.providers.utils.inference.model_registry import build_hf_repo_model_entry

MODEL_ENTRIES = [
    build_hf_repo_model_entry(
        "meta-llama/llama-3-3-70b-instruct",
        CoreModelId.llama3_3_70b_instruct.value,
    ),
    build_hf_repo_model_entry(
        "meta-llama/llama-2-13b-chat",
        CoreModelId.llama2_13b.value,
    ),
    build_hf_repo_model_entry(
        "meta-llama/llama-3-1-70b-instruct",
        CoreModelId.llama3_1_70b_instruct.value,
    ),
    build_hf_repo_model_entry(
        "meta-llama/llama-3-1-8b-instruct",
        CoreModelId.llama3_1_8b_instruct.value,
    ),
    build_hf_repo_model_entry(
        "meta-llama/llama-3-2-11b-vision-instruct",
        CoreModelId.llama3_2_11b_vision_instruct.value,
    ),
    build_hf_repo_model_entry(
        "meta-llama/llama-3-2-1b-instruct",
        CoreModelId.llama3_2_1b_instruct.value,
    ),
    build_hf_repo_model_entry(
        "meta-llama/llama-3-2-3b-instruct",
        CoreModelId.llama3_2_3b_instruct.value,
    ),
    build_hf_repo_model_entry(
        "meta-llama/llama-3-2-90b-vision-instruct",
        CoreModelId.llama3_2_90b_vision_instruct.value,
    ),
    build_hf_repo_model_entry(
        "meta-llama/llama-guard-3-11b-vision",
        CoreModelId.llama_guard_3_11b_vision.value,
    ),
]
260
llama_stack/providers/remote/inference/watsonx/watsonx.py
Normal file
@@ -0,0 +1,260 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from typing import AsyncGenerator, List, Optional, Union

from ibm_watson_machine_learning.foundation_models import Model
from ibm_watson_machine_learning.metanames import GenTextParamsMetaNames as GenParams

from llama_stack.apis.common.content_types import InterleavedContent, InterleavedContentItem
from llama_stack.apis.inference import (
    ChatCompletionRequest,
    ChatCompletionResponse,
    CompletionRequest,
    EmbeddingsResponse,
    EmbeddingTaskType,
    Inference,
    LogProbConfig,
    Message,
    ResponseFormat,
    SamplingParams,
    TextTruncation,
    ToolChoice,
    ToolConfig,
    ToolDefinition,
    ToolPromptFormat,
)
from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper
from llama_stack.providers.utils.inference.openai_compat import (
    OpenAICompatCompletionChoice,
    OpenAICompatCompletionResponse,
    process_chat_completion_response,
    process_chat_completion_stream_response,
    process_completion_response,
    process_completion_stream_response,
)
from llama_stack.providers.utils.inference.prompt_adapter import (
    chat_completion_request_to_prompt,
    completion_request_to_prompt,
    request_has_media,
)

from . import WatsonXConfig
from .models import MODEL_ENTRIES


class WatsonXInferenceAdapter(Inference, ModelRegistryHelper):
    def __init__(self, config: WatsonXConfig) -> None:
        ModelRegistryHelper.__init__(self, MODEL_ENTRIES)

        print(f"Initializing watsonx InferenceAdapter({config.url})...")

        self._config = config

        self._project_id = self._config.project_id

    async def initialize(self) -> None:
        pass

    async def shutdown(self) -> None:
        pass

    async def completion(
        self,
        model_id: str,
        content: InterleavedContent,
        sampling_params: Optional[SamplingParams] = None,
        response_format: Optional[ResponseFormat] = None,
        stream: Optional[bool] = False,
        logprobs: Optional[LogProbConfig] = None,
    ) -> AsyncGenerator:
        if sampling_params is None:
            sampling_params = SamplingParams()
        model = await self.model_store.get_model(model_id)
        request = CompletionRequest(
            model=model.provider_resource_id,
            content=content,
            sampling_params=sampling_params,
            response_format=response_format,
            stream=stream,
            logprobs=logprobs,
        )
        if stream:
            return self._stream_completion(request)
        else:
            return await self._nonstream_completion(request)

    def _get_client(self, model_id) -> Model:
        config_api_key = self._config.api_key.get_secret_value() if self._config.api_key else None
        config_url = self._config.url
        project_id = self._config.project_id
        credentials = {"url": config_url, "apikey": config_api_key}

        return Model(model_id=model_id, credentials=credentials, project_id=project_id)

    async def _nonstream_completion(self, request: CompletionRequest) -> ChatCompletionResponse:
        params = await self._get_params(request)
        r = self._get_client(request.model).generate(**params)
        choices = []
        if "results" in r:
            for result in r["results"]:
                choice = OpenAICompatCompletionChoice(
                    finish_reason=result["stop_reason"] if result["stop_reason"] else None,
                    text=result["generated_text"],
                )
                choices.append(choice)
        response = OpenAICompatCompletionResponse(
            choices=choices,
        )
        return process_completion_response(response)

    async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator:
        params = await self._get_params(request)

        async def _generate_and_convert_to_openai_compat():
            s = self._get_client(request.model).generate_text_stream(**params)
            for chunk in s:
                choice = OpenAICompatCompletionChoice(
                    finish_reason=None,
                    text=chunk,
                )
                yield OpenAICompatCompletionResponse(
                    choices=[choice],
                )

        stream = _generate_and_convert_to_openai_compat()
        async for chunk in process_completion_stream_response(stream):
            yield chunk

    async def chat_completion(
        self,
        model_id: str,
        messages: List[Message],
        sampling_params: Optional[SamplingParams] = None,
        tools: Optional[List[ToolDefinition]] = None,
        tool_choice: Optional[ToolChoice] = ToolChoice.auto,
        tool_prompt_format: Optional[ToolPromptFormat] = None,
        response_format: Optional[ResponseFormat] = None,
        stream: Optional[bool] = False,
        logprobs: Optional[LogProbConfig] = None,
        tool_config: Optional[ToolConfig] = None,
    ) -> AsyncGenerator:
        if sampling_params is None:
            sampling_params = SamplingParams()
        model = await self.model_store.get_model(model_id)
        request = ChatCompletionRequest(
            model=model.provider_resource_id,
            messages=messages,
            sampling_params=sampling_params,
            tools=tools or [],
            response_format=response_format,
            stream=stream,
            logprobs=logprobs,
            tool_config=tool_config,
        )

        if stream:
            return self._stream_chat_completion(request)
        else:
            return await self._nonstream_chat_completion(request)

    async def _nonstream_chat_completion(self, request: ChatCompletionRequest) -> ChatCompletionResponse:
        params = await self._get_params(request)
        r = self._get_client(request.model).generate(**params)
        choices = []
        if "results" in r:
            for result in r["results"]:
                choice = OpenAICompatCompletionChoice(
                    finish_reason=result["stop_reason"] if result["stop_reason"] else None,
                    text=result["generated_text"],
                )
                choices.append(choice)
        response = OpenAICompatCompletionResponse(
            choices=choices,
        )
        return process_chat_completion_response(response, request)

    async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator:
        params = await self._get_params(request)
        model_id = request.model

        # if we shift to TogetherAsyncClient, we won't need this wrapper
        async def _to_async_generator():
            s = self._get_client(model_id).generate_text_stream(**params)
            for chunk in s:
                choice = OpenAICompatCompletionChoice(
                    finish_reason=None,
                    text=chunk,
                )
                yield OpenAICompatCompletionResponse(
                    choices=[choice],
                )

        stream = _to_async_generator()
        async for chunk in process_chat_completion_stream_response(stream, request):
            yield chunk

    async def _get_params(self, request: Union[ChatCompletionRequest, CompletionRequest]) -> dict:
        input_dict = {"params": {}}
        media_present = request_has_media(request)
        llama_model = self.get_llama_model(request.model)
        if isinstance(request, ChatCompletionRequest):
            input_dict["prompt"] = await chat_completion_request_to_prompt(request, llama_model)
        else:
            assert not media_present, "Together does not support media for Completion requests"
            input_dict["prompt"] = await completion_request_to_prompt(request)
        if request.sampling_params:
            if request.sampling_params.strategy:
                input_dict["params"][GenParams.DECODING_METHOD] = request.sampling_params.strategy.type
            if request.sampling_params.max_tokens:
                input_dict["params"][GenParams.MAX_NEW_TOKENS] = request.sampling_params.max_tokens
            if request.sampling_params.repetition_penalty:
                input_dict["params"][GenParams.REPETITION_PENALTY] = request.sampling_params.repetition_penalty
            if request.sampling_params.additional_params.get("top_p"):
                input_dict["params"][GenParams.TOP_P] = request.sampling_params.additional_params["top_p"]
            if request.sampling_params.additional_params.get("top_k"):
                input_dict["params"][GenParams.TOP_K] = request.sampling_params.additional_params["top_k"]
            if request.sampling_params.additional_params.get("temperature"):
                input_dict["params"][GenParams.TEMPERATURE] = request.sampling_params.additional_params["temperature"]
            if request.sampling_params.additional_params.get("length_penalty"):
                input_dict["params"][GenParams.LENGTH_PENALTY] = request.sampling_params.additional_params[
                    "length_penalty"
                ]
            if request.sampling_params.additional_params.get("random_seed"):
                input_dict["params"][GenParams.RANDOM_SEED] = request.sampling_params.additional_params["random_seed"]
            if request.sampling_params.additional_params.get("min_new_tokens"):
                input_dict["params"][GenParams.MIN_NEW_TOKENS] = request.sampling_params.additional_params[
                    "min_new_tokens"
                ]
            if request.sampling_params.additional_params.get("stop_sequences"):
                input_dict["params"][GenParams.STOP_SEQUENCES] = request.sampling_params.additional_params[
                    "stop_sequences"
                ]
            if request.sampling_params.additional_params.get("time_limit"):
                input_dict["params"][GenParams.TIME_LIMIT] = request.sampling_params.additional_params["time_limit"]
            if request.sampling_params.additional_params.get("truncate_input_tokens"):
                input_dict["params"][GenParams.TRUNCATE_INPUT_TOKENS] = request.sampling_params.additional_params[
                    "truncate_input_tokens"
                ]
            if request.sampling_params.additional_params.get("return_options"):
                input_dict["params"][GenParams.RETURN_OPTIONS] = request.sampling_params.additional_params[
                    "return_options"
                ]

        params = {
            **input_dict,
        }
        return params

    async def embeddings(
        self,
        model_id: str,
        contents: List[str] | List[InterleavedContentItem],
        text_truncation: Optional[TextTruncation] = TextTruncation.none,
        output_dimension: Optional[int] = None,
        task_type: Optional[EmbeddingTaskType] = None,
    ) -> EmbeddingsResponse:
        pass
@@ -36,7 +36,6 @@ import os
 
 os.environ["NVIDIA_API_KEY"] = "your-api-key"
 os.environ["NVIDIA_CUSTOMIZER_URL"] = "http://nemo.test"
-os.environ["NVIDIA_USER_ID"] = "llama-stack-user"
 os.environ["NVIDIA_DATASET_NAMESPACE"] = "default"
 os.environ["NVIDIA_PROJECT_ID"] = "test-project"
 os.environ["NVIDIA_OUTPUT_MODEL_DIR"] = "test-example-model@v1"
@@ -125,6 +124,21 @@ client.post_training.job.cancel(job_uuid="your-job-id")
 
 ### Inference with the fine-tuned model
 
+#### 1. Register the model
+
+```python
+from llama_stack.apis.models import Model, ModelType
+
+client.models.register(
+    model_id="test-example-model@v1",
+    provider_id="nvidia",
+    provider_model_id="test-example-model@v1",
+    model_type=ModelType.llm,
+)
+```
+
+#### 2. Inference with the fine-tuned model
+
 ```python
 response = client.inference.completion(
     content="Complete the sentence using one word: Roses are red, violets are ",
@@ -524,11 +524,26 @@ async def convert_message_to_openai_dict(message: Message, download: bool = False
     else:
         content = [await _convert_content(message.content)]
 
-    return {
+    result = {
         "role": message.role,
         "content": content,
     }
+
+    if hasattr(message, "tool_calls") and message.tool_calls:
+        result["tool_calls"] = []
+        for tc in message.tool_calls:
+            result["tool_calls"].append(
+                {
+                    "id": tc.call_id,
+                    "type": "function",
+                    "function": {
+                        "name": tc.tool_name,
+                        "arguments": tc.arguments_json if hasattr(tc, "arguments_json") else json.dumps(tc.arguments),
+                    },
+                }
+            )
+    return result
 
 
 class UnparseableToolCall(BaseModel):
     """
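For reference, the dict the converter now produces for an assistant message carrying a tool call looks roughly like this (values are illustrative; `arguments_json` is used when the tool call provides it, otherwise the arguments are JSON-serialized):

```python
# Illustrative output shape of convert_message_to_openai_dict with one tool call.
example = {
    "role": "assistant",
    "content": [{"type": "text", "text": ""}],
    "tool_calls": [
        {
            "id": "call_123",
            "type": "function",
            "function": {
                "name": "get_weather",
                "arguments": '{"city": "Tokyo"}',
            },
        }
    ],
}
```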
@@ -52,6 +52,9 @@ from llama_stack.models.llama.llama3.prompt_templates import (
     SystemDefaultGenerator,
 )
 from llama_stack.models.llama.llama3.tokenizer import Tokenizer
+from llama_stack.models.llama.llama4.prompt_templates.system_prompts import (
+    PythonListCustomToolGenerator as PythonListCustomToolGeneratorLlama4,
+)
 from llama_stack.models.llama.sku_list import resolve_model
 from llama_stack.models.llama.sku_types import ModelFamily, is_multimodal
 from llama_stack.providers.utils.inference import supported_inference_models
@@ -306,10 +309,11 @@ def chat_completion_request_to_messages(
     elif model.model_family in (
         ModelFamily.llama3_2,
         ModelFamily.llama3_3,
-        ModelFamily.llama4,
     ):
-        # llama3.2, llama3.3 and llama4 models follow the same tool prompt format
-        messages = augment_messages_for_tools_llama_3_2(request)
+        # llama3.2, llama3.3 follow the same tool prompt format
+        messages = augment_messages_for_tools_llama(request, PythonListCustomToolGenerator)
+    elif model.model_family == ModelFamily.llama4:
+        messages = augment_messages_for_tools_llama(request, PythonListCustomToolGeneratorLlama4)
     else:
         messages = request.messages
 
@@ -399,8 +403,9 @@ def augment_messages_for_tools_llama_3_1(
     return messages
 
 
-def augment_messages_for_tools_llama_3_2(
+def augment_messages_for_tools_llama(
     request: ChatCompletionRequest,
+    custom_tool_prompt_generator,
 ) -> List[Message]:
     existing_messages = request.messages
     existing_system_message = None
@@ -434,7 +439,7 @@ def augment_messages_for_tools_llama_3_2(
     if existing_system_message and request.tool_config.system_message_behavior == SystemMessageBehavior.replace:
         system_prompt = existing_system_message.content
 
-    tool_template = PythonListCustomToolGenerator().gen(custom_tools, system_prompt)
+    tool_template = custom_tool_prompt_generator().gen(custom_tools, system_prompt)
 
     sys_content += tool_template.render()
     sys_content += "\n"
@@ -394,12 +394,10 @@
     "aiosqlite",
     "blobfile",
     "chardet",
-    "emoji",
     "faiss-cpu",
     "fastapi",
     "fire",
     "httpx",
-    "langdetect",
     "matplotlib",
     "nltk",
     "numpy",
@@ -411,7 +409,6 @@
     "psycopg2-binary",
     "pymongo",
     "pypdf",
-    "pythainlp",
     "redis",
     "requests",
     "scikit-learn",
@@ -419,7 +416,6 @@
     "sentencepiece",
     "tqdm",
     "transformers",
-    "tree_sitter",
     "uvicorn"
   ],
   "ollama": [
@@ -759,5 +755,41 @@
     "vllm",
     "sentence-transformers --no-deps",
     "torch torchvision --index-url https://download.pytorch.org/whl/cpu"
+  ],
+  "watsonx": [
+    "aiosqlite",
+    "autoevals",
+    "blobfile",
+    "chardet",
+    "datasets",
+    "emoji",
+    "faiss-cpu",
+    "fastapi",
+    "fire",
+    "httpx",
+    "ibm_watson_machine_learning",
+    "langdetect",
+    "matplotlib",
+    "mcp",
+    "nltk",
+    "numpy",
+    "openai",
+    "opentelemetry-exporter-otlp-proto-http",
+    "opentelemetry-sdk",
+    "pandas",
+    "pillow",
+    "psycopg2-binary",
+    "pymongo",
+    "pypdf",
+    "pythainlp",
+    "redis",
+    "requests",
+    "scikit-learn",
+    "scipy",
+    "sentencepiece",
+    "tqdm",
+    "transformers",
+    "tree_sitter",
+    "uvicorn"
   ]
 }
@@ -69,6 +69,7 @@ LLAMA_STACK_PORT=8321
 docker run \
   -it \
   --pull always \
+  --gpu all \
   -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
   -v ~/.llama:/root/.llama \
   llamastack/distribution-{{ name }} \
@@ -82,6 +83,7 @@ If you are using Llama Stack Safety / Shield APIs, use:
 docker run \
   -it \
   --pull always \
+  --gpu all \
   -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
   -v ~/.llama:/root/.llama \
   llamastack/distribution-{{ name }} \
@@ -1,6 +1,6 @@
 version: '2'
 distribution_spec:
-  description: Use NVIDIA NIM for running LLM inference and safety
+  description: Use NVIDIA NIM for running LLM inference, evaluation and safety
   providers:
     inference:
     - remote::nvidia
@@ -13,7 +13,7 @@ distribution_spec:
     telemetry:
     - inline::meta-reference
     eval:
-    - inline::meta-reference
+    - remote::nvidia
     post_training:
     - remote::nvidia
     datasetio:
@@ -7,6 +7,7 @@
 from pathlib import Path
 
 from llama_stack.distribution.datatypes import ModelInput, Provider, ShieldInput, ToolGroupInput
+from llama_stack.providers.remote.eval.nvidia import NVIDIAEvalConfig
 from llama_stack.providers.remote.inference.nvidia import NVIDIAConfig
 from llama_stack.providers.remote.inference.nvidia.models import MODEL_ENTRIES
 from llama_stack.providers.remote.safety.nvidia import NVIDIASafetyConfig
@@ -20,7 +21,7 @@ def get_distribution_template() -> DistributionTemplate:
         "safety": ["remote::nvidia"],
         "agents": ["inline::meta-reference"],
         "telemetry": ["inline::meta-reference"],
-        "eval": ["inline::meta-reference"],
+        "eval": ["remote::nvidia"],
         "post_training": ["remote::nvidia"],
         "datasetio": ["inline::localfs"],
         "scoring": ["inline::basic"],
@@ -37,6 +38,11 @@ def get_distribution_template() -> DistributionTemplate:
         provider_type="remote::nvidia",
         config=NVIDIASafetyConfig.sample_run_config(),
     )
+    eval_provider = Provider(
+        provider_id="nvidia",
+        provider_type="remote::nvidia",
+        config=NVIDIAEvalConfig.sample_run_config(),
+    )
     inference_model = ModelInput(
         model_id="${env.INFERENCE_MODEL}",
         provider_id="nvidia",
@@ -60,7 +66,7 @@ def get_distribution_template() -> DistributionTemplate:
     return DistributionTemplate(
         name="nvidia",
         distro_type="self_hosted",
-        description="Use NVIDIA NIM for running LLM inference and safety",
+        description="Use NVIDIA NIM for running LLM inference, evaluation and safety",
         container_image=None,
         template_path=Path(__file__).parent / "doc_template.md",
         providers=providers,
@@ -69,6 +75,7 @@ def get_distribution_template() -> DistributionTemplate:
             "run.yaml": RunConfigSettings(
                 provider_overrides={
                     "inference": [inference_provider],
+                    "eval": [eval_provider],
                 },
                 default_models=default_models,
                 default_tool_groups=default_tool_groups,
@@ -78,7 +85,8 @@ def get_distribution_template() -> DistributionTemplate:
                     "inference": [
                         inference_provider,
                         safety_provider,
-                    ]
+                    ],
+                    "eval": [eval_provider],
                 },
                 default_models=[inference_model, safety_model],
                 default_shields=[ShieldInput(shield_id="${env.SAFETY_MODEL}", provider_id="nvidia")],
@@ -90,19 +98,15 @@ def get_distribution_template() -> DistributionTemplate:
                 "",
                 "NVIDIA API Key",
             ),
-            ## Nemo Customizer related variables
-            "NVIDIA_USER_ID": (
-                "llama-stack-user",
-                "NVIDIA User ID",
+            "NVIDIA_APPEND_API_VERSION": (
+                "True",
+                "Whether to append the API version to the base_url",
             ),
+            ## Nemo Customizer related variables
             "NVIDIA_DATASET_NAMESPACE": (
                 "default",
                 "NVIDIA Dataset Namespace",
             ),
-            "NVIDIA_ACCESS_POLICIES": (
-                "{}",
-                "NVIDIA Access Policies",
-            ),
             "NVIDIA_PROJECT_ID": (
                 "test-project",
                 "NVIDIA Project ID",
@@ -119,6 +123,10 @@ def get_distribution_template() -> DistributionTemplate:
                 "http://0.0.0.0:7331",
                 "URL for the NeMo Guardrails Service",
             ),
+            "NVIDIA_EVALUATOR_URL": (
+                "http://0.0.0.0:7331",
+                "URL for the NeMo Evaluator Service",
+            ),
             "INFERENCE_MODEL": (
                 "Llama3.1-8B-Instruct",
                 "Inference model",
@@ -18,6 +18,7 @@ providers:
     config:
       url: ${env.NVIDIA_BASE_URL:https://integrate.api.nvidia.com}
       api_key: ${env.NVIDIA_API_KEY:}
+      append_api_version: ${env.NVIDIA_APPEND_API_VERSION:True}
   - provider_id: nvidia
     provider_type: remote::nvidia
     config:
@@ -53,13 +54,10 @@ providers:
       sinks: ${env.TELEMETRY_SINKS:console,sqlite}
       sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/nvidia/trace_store.db}
   eval:
-  - provider_id: meta-reference
-    provider_type: inline::meta-reference
+  - provider_id: nvidia
+    provider_type: remote::nvidia
     config:
-      kvstore:
-        type: sqlite
-        namespace: null
-        db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/nvidia}/meta_reference_eval.db
+      evaluator_url: ${env.NVIDIA_EVALUATOR_URL:http://localhost:7331}
   post_training:
   - provider_id: nvidia
     provider_type: remote::nvidia
@@ -18,6 +18,7 @@ providers:
     config:
       url: ${env.NVIDIA_BASE_URL:https://integrate.api.nvidia.com}
      api_key: ${env.NVIDIA_API_KEY:}
+      append_api_version: ${env.NVIDIA_APPEND_API_VERSION:True}
   vector_io:
   - provider_id: faiss
     provider_type: inline::faiss
@@ -48,13 +49,10 @@ providers:
       sinks: ${env.TELEMETRY_SINKS:console,sqlite}
       sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/nvidia/trace_store.db}
   eval:
-  - provider_id: meta-reference
-    provider_type: inline::meta-reference
+  - provider_id: nvidia
+    provider_type: remote::nvidia
     config:
-      kvstore:
-        type: sqlite
-        namespace: null
-        db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/nvidia}/meta_reference_eval.db
+      evaluator_url: ${env.NVIDIA_EVALUATOR_URL:http://localhost:7331}
   post_training:
   - provider_id: nvidia
     provider_type: remote::nvidia
7
llama_stack/templates/watsonx/__init__.py
Normal file
@@ -0,0 +1,7 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from .watsonx import get_distribution_template  # noqa: F401
30
llama_stack/templates/watsonx/build.yaml
Normal file
@@ -0,0 +1,30 @@
version: '2'
distribution_spec:
  description: Use watsonx for running LLM inference
  providers:
    inference:
    - remote::watsonx
    vector_io:
    - inline::faiss
    safety:
    - inline::llama-guard
    agents:
    - inline::meta-reference
    telemetry:
    - inline::meta-reference
    eval:
    - inline::meta-reference
    datasetio:
    - remote::huggingface
    - inline::localfs
    scoring:
    - inline::basic
    - inline::llm-as-judge
    - inline::braintrust
    tool_runtime:
    - remote::brave-search
    - remote::tavily-search
    - inline::code-interpreter
    - inline::rag-runtime
    - remote::model-context-protocol
image_type: conda
74
llama_stack/templates/watsonx/doc_template.md
Normal file
@@ -0,0 +1,74 @@
---
orphan: true
---
# watsonx Distribution

```{toctree}
:maxdepth: 2
:hidden:

self
```

The `llamastack/distribution-{{ name }}` distribution consists of the following provider configurations.

{{ providers_table }}

{% if run_config_env_vars %}

### Environment Variables

The following environment variables can be configured:

{% for var, (default_value, description) in run_config_env_vars.items() %}
- `{{ var }}`: {{ description }} (default: `{{ default_value }}`)
{% endfor %}
{% endif %}

{% if default_models %}
### Models

The following models are available by default:

{% for model in default_models %}
- `{{ model.model_id }} {{ model.doc_string }}`
{% endfor %}
{% endif %}


### Prerequisite: API Keys

Make sure you have access to a watsonx API Key. You can get one by referring [watsonx.ai](https://www.ibm.com/docs/en/masv-and-l/maximo-manage/continuous-delivery?topic=setup-create-watsonx-api-key).


## Running Llama Stack with watsonx

You can do this via Conda (build code), venv or Docker which has a pre-built image.

### Via Docker

This method allows you to get started quickly without having to build the distribution code.

```bash
LLAMA_STACK_PORT=5001
docker run \
  -it \
  -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \
  -v ./run.yaml:/root/my-run.yaml \
  llamastack/distribution-{{ name }} \
  --yaml-config /root/my-run.yaml \
  --port $LLAMA_STACK_PORT \
  --env WATSONX_API_KEY=$WATSONX_API_KEY \
  --env WATSONX_PROJECT_ID=$WATSONX_PROJECT_ID \
  --env WATSONX_BASE_URL=$WATSONX_BASE_URL
```

### Via Conda

```bash
llama stack build --template watsonx --image-type conda
llama stack run ./run.yaml \
  --port $LLAMA_STACK_PORT \
  --env WATSONX_API_KEY=$WATSONX_API_KEY \
  --env WATSONX_PROJECT_ID=$WATSONX_PROJECT_ID
```
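Once the distribution is up via either route above, a quick smoke test of the new provider through the Python client could look like the following sketch, which assumes the `llama_stack_client` package, the default port from the Docker example, and one of the model ids registered by this template:

```python
from llama_stack_client import LlamaStackClient

client = LlamaStackClient(base_url="http://localhost:5001")

response = client.inference.chat_completion(
    model_id="meta-llama/llama-3-3-70b-instruct",
    messages=[{"role": "user", "content": "Say hello in one sentence."}],
)
print(response.completion_message.content)
```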
210
llama_stack/templates/watsonx/run.yaml
Normal file
@@ -0,0 +1,210 @@
version: '2'
image_name: watsonx
apis:
- agents
- datasetio
- eval
- inference
- safety
- scoring
- telemetry
- tool_runtime
- vector_io
providers:
  inference:
  - provider_id: watsonx
    provider_type: remote::watsonx
    config:
      url: ${env.WATSONX_BASE_URL:https://us-south.ml.cloud.ibm.com}
      api_key: ${env.WATSONX_API_KEY:}
      project_id: ${env.WATSONX_PROJECT_ID:}
  vector_io:
  - provider_id: faiss
    provider_type: inline::faiss
    config:
      kvstore:
        type: sqlite
        namespace: null
        db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/watsonx}/faiss_store.db
  safety:
  - provider_id: llama-guard
    provider_type: inline::llama-guard
    config:
      excluded_categories: []
  agents:
  - provider_id: meta-reference
    provider_type: inline::meta-reference
    config:
      persistence_store:
        type: sqlite
        namespace: null
        db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/watsonx}/agents_store.db
  telemetry:
  - provider_id: meta-reference
    provider_type: inline::meta-reference
    config:
      service_name: "${env.OTEL_SERVICE_NAME:\u200B}"
      sinks: ${env.TELEMETRY_SINKS:console,sqlite}
      sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/watsonx/trace_store.db}
  eval:
  - provider_id: meta-reference
    provider_type: inline::meta-reference
    config:
      kvstore:
        type: sqlite
        namespace: null
        db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/watsonx}/meta_reference_eval.db
  datasetio:
  - provider_id: huggingface
    provider_type: remote::huggingface
    config:
      kvstore:
        type: sqlite
        namespace: null
        db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/watsonx}/huggingface_datasetio.db
  - provider_id: localfs
    provider_type: inline::localfs
    config:
      kvstore:
        type: sqlite
        namespace: null
        db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/watsonx}/localfs_datasetio.db
  scoring:
  - provider_id: basic
    provider_type: inline::basic
    config: {}
  - provider_id: llm-as-judge
    provider_type: inline::llm-as-judge
    config: {}
  - provider_id: braintrust
    provider_type: inline::braintrust
    config:
      openai_api_key: ${env.OPENAI_API_KEY:}
  tool_runtime:
  - provider_id: brave-search
    provider_type: remote::brave-search
    config:
      api_key: ${env.BRAVE_SEARCH_API_KEY:}
      max_results: 3
  - provider_id: tavily-search
    provider_type: remote::tavily-search
    config:
      api_key: ${env.TAVILY_SEARCH_API_KEY:}
      max_results: 3
  - provider_id: code-interpreter
    provider_type: inline::code-interpreter
    config: {}
  - provider_id: rag-runtime
    provider_type: inline::rag-runtime
    config: {}
  - provider_id: model-context-protocol
    provider_type: remote::model-context-protocol
    config: {}
metadata_store:
  type: sqlite
  db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/watsonx}/registry.db
models:
- metadata: {}
  model_id: meta-llama/llama-3-3-70b-instruct
  provider_id: watsonx
  provider_model_id: meta-llama/llama-3-3-70b-instruct
  model_type: llm
- metadata: {}
  model_id: meta-llama/Llama-3.3-70B-Instruct
  provider_id: watsonx
  provider_model_id: meta-llama/llama-3-3-70b-instruct
  model_type: llm
- metadata: {}
  model_id: meta-llama/llama-2-13b-chat
  provider_id: watsonx
  provider_model_id: meta-llama/llama-2-13b-chat
  model_type: llm
- metadata: {}
  model_id: meta-llama/Llama-2-13b
  provider_id: watsonx
  provider_model_id: meta-llama/llama-2-13b-chat
  model_type: llm
- metadata: {}
  model_id: meta-llama/llama-3-1-70b-instruct
  provider_id: watsonx
  provider_model_id: meta-llama/llama-3-1-70b-instruct
  model_type: llm
- metadata: {}
  model_id: meta-llama/Llama-3.1-70B-Instruct
  provider_id: watsonx
  provider_model_id: meta-llama/llama-3-1-70b-instruct
  model_type: llm
- metadata: {}
  model_id: meta-llama/llama-3-1-8b-instruct
  provider_id: watsonx
  provider_model_id: meta-llama/llama-3-1-8b-instruct
  model_type: llm
- metadata: {}
  model_id: meta-llama/Llama-3.1-8B-Instruct
  provider_id: watsonx
  provider_model_id: meta-llama/llama-3-1-8b-instruct
  model_type: llm
- metadata: {}
  model_id: meta-llama/llama-3-2-11b-vision-instruct
  provider_id: watsonx
  provider_model_id: meta-llama/llama-3-2-11b-vision-instruct
  model_type: llm
- metadata: {}
  model_id: meta-llama/Llama-3.2-11B-Vision-Instruct
  provider_id: watsonx
  provider_model_id: meta-llama/llama-3-2-11b-vision-instruct
  model_type: llm
- metadata: {}
  model_id: meta-llama/llama-3-2-1b-instruct
  provider_id: watsonx
  provider_model_id: meta-llama/llama-3-2-1b-instruct
  model_type: llm
- metadata: {}
  model_id: meta-llama/Llama-3.2-1B-Instruct
  provider_id: watsonx
  provider_model_id: meta-llama/llama-3-2-1b-instruct
  model_type: llm
- metadata: {}
  model_id: meta-llama/llama-3-2-3b-instruct
  provider_id: watsonx
  provider_model_id: meta-llama/llama-3-2-3b-instruct
  model_type: llm
- metadata: {}
  model_id: meta-llama/Llama-3.2-3B-Instruct
  provider_id: watsonx
  provider_model_id: meta-llama/llama-3-2-3b-instruct
  model_type: llm
- metadata: {}
  model_id: meta-llama/llama-3-2-90b-vision-instruct
  provider_id: watsonx
  provider_model_id: meta-llama/llama-3-2-90b-vision-instruct
  model_type: llm
- metadata: {}
  model_id: meta-llama/Llama-3.2-90B-Vision-Instruct
  provider_id: watsonx
  provider_model_id: meta-llama/llama-3-2-90b-vision-instruct
  model_type: llm
- metadata: {}
  model_id: meta-llama/llama-guard-3-11b-vision
  provider_id: watsonx
  provider_model_id: meta-llama/llama-guard-3-11b-vision
  model_type: llm
- metadata: {}
  model_id: meta-llama/Llama-Guard-3-11B-Vision
  provider_id: watsonx
  provider_model_id: meta-llama/llama-guard-3-11b-vision
  model_type: llm
shields: []
vector_dbs: []
datasets: []
scoring_fns: []
benchmarks: []
tool_groups:
- toolgroup_id: builtin::websearch
  provider_id: tavily-search
- toolgroup_id: builtin::rag
  provider_id: rag-runtime
- toolgroup_id: builtin::code_interpreter
  provider_id: code-interpreter
server:
  port: 8321
90
llama_stack/templates/watsonx/watsonx.py
Normal file
90
llama_stack/templates/watsonx/watsonx.py
Normal file
|
@ -0,0 +1,90 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from pathlib import Path

from llama_stack.distribution.datatypes import Provider, ToolGroupInput
from llama_stack.providers.remote.inference.watsonx import WatsonXConfig
from llama_stack.providers.remote.inference.watsonx.models import MODEL_ENTRIES
from llama_stack.templates.template import DistributionTemplate, RunConfigSettings, get_model_registry


def get_distribution_template() -> DistributionTemplate:
    providers = {
        "inference": ["remote::watsonx"],
        "vector_io": ["inline::faiss"],
        "safety": ["inline::llama-guard"],
        "agents": ["inline::meta-reference"],
        "telemetry": ["inline::meta-reference"],
        "eval": ["inline::meta-reference"],
        "datasetio": ["remote::huggingface", "inline::localfs"],
        "scoring": ["inline::basic", "inline::llm-as-judge", "inline::braintrust"],
        "tool_runtime": [
            "remote::brave-search",
            "remote::tavily-search",
            "inline::code-interpreter",
            "inline::rag-runtime",
            "remote::model-context-protocol",
        ],
    }

    inference_provider = Provider(
        provider_id="watsonx",
        provider_type="remote::watsonx",
        config=WatsonXConfig.sample_run_config(),
    )

    available_models = {
        "watsonx": MODEL_ENTRIES,
    }
    default_tool_groups = [
        ToolGroupInput(
            toolgroup_id="builtin::websearch",
            provider_id="tavily-search",
        ),
        ToolGroupInput(
            toolgroup_id="builtin::rag",
            provider_id="rag-runtime",
        ),
        ToolGroupInput(
            toolgroup_id="builtin::code_interpreter",
            provider_id="code-interpreter",
        ),
    ]

    default_models = get_model_registry(available_models)
    return DistributionTemplate(
        name="watsonx",
        distro_type="remote_hosted",
        description="Use watsonx for running LLM inference",
        container_image=None,
        template_path=Path(__file__).parent / "doc_template.md",
        providers=providers,
        available_models_by_provider=available_models,
        run_configs={
            "run.yaml": RunConfigSettings(
                provider_overrides={
                    "inference": [inference_provider],
                },
                default_models=default_models,
                default_tool_groups=default_tool_groups,
            ),
        },
        run_config_env_vars={
            "LLAMASTACK_PORT": (
                "5001",
                "Port for the Llama Stack distribution server",
            ),
            "WATSONX_API_KEY": (
                "",
                "watsonx API Key",
            ),
            "WATSONX_PROJECT_ID": (
                "",
                "watsonx Project ID",
            ),
        },
    )
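
For reference, a small sketch of inspecting the template defined above from a local checkout; the attribute access mirrors the constructor arguments and is an assumption about `DistributionTemplate`'s field names.

```python
# Sketch only: quick check that the watsonx template wires up the expected
# providers. Assumes a local llama-stack checkout containing the new file.
from llama_stack.templates.watsonx.watsonx import get_distribution_template

template = get_distribution_template()
print(template.name)                      # "watsonx"
print(sorted(template.providers.keys()))  # agents, datasetio, eval, inference, ...
assert "remote::watsonx" in template.providers["inference"]
```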
@@ -274,6 +274,7 @@ exclude = [
     "^llama_stack/providers/remote/inference/sample/",
     "^llama_stack/providers/remote/inference/tgi/",
     "^llama_stack/providers/remote/inference/together/",
+    "^llama_stack/providers/remote/inference/watsonx/",
     "^llama_stack/providers/remote/safety/bedrock/",
     "^llama_stack/providers/remote/safety/nvidia/",
     "^llama_stack/providers/remote/safety/sample/",
@@ -75,19 +75,24 @@ def openai_client(client_with_models):
     return OpenAI(base_url=base_url, api_key="bar")


+@pytest.fixture(params=["openai_client", "llama_stack_client"])
+def compat_client(request):
+    return request.getfixturevalue(request.param)
+
+
 @pytest.mark.parametrize(
     "test_case",
     [
         "inference:completion:sanity",
     ],
 )
-def test_openai_completion_non_streaming(openai_client, client_with_models, text_model_id, test_case):
+def test_openai_completion_non_streaming(llama_stack_client, client_with_models, text_model_id, test_case):
     skip_if_model_doesnt_support_openai_completion(client_with_models, text_model_id)
     tc = TestCase(test_case)

     # ollama needs more verbose prompting for some reason here...
     prompt = "Respond to this question and explain your answer. " + tc["content"]
-    response = openai_client.completions.create(
+    response = llama_stack_client.completions.create(
         model=text_model_id,
         prompt=prompt,
         stream=False,
@@ -103,13 +108,13 @@ def test_openai_completion_non_streaming(openai_client, client_with_models, text
         "inference:completion:sanity",
     ],
 )
-def test_openai_completion_streaming(openai_client, client_with_models, text_model_id, test_case):
+def test_openai_completion_streaming(llama_stack_client, client_with_models, text_model_id, test_case):
     skip_if_model_doesnt_support_openai_completion(client_with_models, text_model_id)
     tc = TestCase(test_case)

     # ollama needs more verbose prompting for some reason here...
     prompt = "Respond to this question and explain your answer. " + tc["content"]
-    response = openai_client.completions.create(
+    response = llama_stack_client.completions.create(
         model=text_model_id,
         prompt=prompt,
         stream=True,
@@ -127,11 +132,11 @@ def test_openai_completion_streaming(openai_client, client_with_models, text_mod
         0,
     ],
 )
-def test_openai_completion_prompt_logprobs(openai_client, client_with_models, text_model_id, prompt_logprobs):
+def test_openai_completion_prompt_logprobs(llama_stack_client, client_with_models, text_model_id, prompt_logprobs):
     skip_if_provider_isnt_vllm(client_with_models, text_model_id)

     prompt = "Hello, world!"
-    response = openai_client.completions.create(
+    response = llama_stack_client.completions.create(
         model=text_model_id,
         prompt=prompt,
         stream=False,
@@ -144,11 +149,11 @@ def test_openai_completion_prompt_logprobs(openai_client, client_with_models, te
     assert len(choice.prompt_logprobs) > 0


-def test_openai_completion_guided_choice(openai_client, client_with_models, text_model_id):
+def test_openai_completion_guided_choice(llama_stack_client, client_with_models, text_model_id):
     skip_if_provider_isnt_vllm(client_with_models, text_model_id)

     prompt = "I am feeling really sad today."
-    response = openai_client.completions.create(
+    response = llama_stack_client.completions.create(
         model=text_model_id,
         prompt=prompt,
         stream=False,
@@ -161,6 +166,9 @@ def test_openai_completion_guided_choice(openai_client, client_with_models, text
     assert choice.text in ["joy", "sadness"]


+# Run the chat-completion tests with both the OpenAI client and the LlamaStack client
+
+
 @pytest.mark.parametrize(
     "test_case",
     [
@@ -168,13 +176,13 @@ def test_openai_completion_guided_choice(openai_client, client_with_models, text
         "inference:chat_completion:non_streaming_02",
     ],
 )
-def test_openai_chat_completion_non_streaming(openai_client, client_with_models, text_model_id, test_case):
+def test_openai_chat_completion_non_streaming(compat_client, client_with_models, text_model_id, test_case):
     skip_if_model_doesnt_support_openai_chat_completion(client_with_models, text_model_id)
     tc = TestCase(test_case)
     question = tc["question"]
     expected = tc["expected"]

-    response = openai_client.chat.completions.create(
+    response = compat_client.chat.completions.create(
         model=text_model_id,
         messages=[
             {
@@ -196,13 +204,13 @@ def test_openai_chat_completion_non_streaming(openai_client, client_with_models,
         "inference:chat_completion:streaming_02",
     ],
 )
-def test_openai_chat_completion_streaming(openai_client, client_with_models, text_model_id, test_case):
+def test_openai_chat_completion_streaming(compat_client, client_with_models, text_model_id, test_case):
     skip_if_model_doesnt_support_openai_chat_completion(client_with_models, text_model_id)
     tc = TestCase(test_case)
     question = tc["question"]
     expected = tc["expected"]

-    response = openai_client.chat.completions.create(
+    response = compat_client.chat.completions.create(
         model=text_model_id,
         messages=[{"role": "user", "content": question}],
         stream=True,
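
For reference, a self-contained sketch of the `request.getfixturevalue` pattern that the new `compat_client` fixture relies on; the two fixtures here are hypothetical stand-ins, not the real `openai_client` / `llama_stack_client` fixtures.

```python
# Sketch of the parametrized-fixture pattern used above.
import pytest


@pytest.fixture
def fake_openai_client():
    return "openai"


@pytest.fixture
def fake_llama_stack_client():
    return "llama-stack"


@pytest.fixture(params=["fake_openai_client", "fake_llama_stack_client"])
def compat_client(request):
    # Resolve whichever backing fixture this parametrization round names.
    return request.getfixturevalue(request.param)


def test_runs_against_both_clients(compat_client):
    # Collected twice: once per backing fixture.
    assert compat_client in ("openai", "llama-stack")
```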
@@ -28,12 +28,15 @@ from openai.types.model import Model as OpenAIModel

 from llama_stack.apis.inference import (
     ChatCompletionRequest,
+    CompletionMessage,
+    SystemMessage,
     ToolChoice,
     ToolConfig,
+    ToolResponseMessage,
     UserMessage,
 )
 from llama_stack.apis.models import Model
-from llama_stack.models.llama.datatypes import StopReason
+from llama_stack.models.llama.datatypes import StopReason, ToolCall
 from llama_stack.providers.remote.inference.vllm.config import VLLMInferenceAdapterConfig
 from llama_stack.providers.remote.inference.vllm.vllm import (
     VLLMInferenceAdapter,
@@ -135,6 +138,49 @@ async def test_old_vllm_tool_choice(vllm_inference_adapter):
     assert request.tool_config.tool_choice == ToolChoice.none


+@pytest.mark.asyncio
+async def test_tool_call_response(vllm_inference_adapter):
+    """Verify that tool call arguments from a CompletionMessage are correctly converted
+    into the expected JSON format."""
+
+    # Patch the call to vllm so we can inspect the arguments sent were correct
+    with patch.object(
+        vllm_inference_adapter.client.chat.completions, "create", new_callable=AsyncMock
+    ) as mock_nonstream_completion:
+        messages = [
+            SystemMessage(content="You are a helpful assistant"),
+            UserMessage(content="How many?"),
+            CompletionMessage(
+                content="",
+                stop_reason=StopReason.end_of_turn,
+                tool_calls=[
+                    ToolCall(
+                        call_id="foo",
+                        tool_name="knowledge_search",
+                        arguments={"query": "How many?"},
+                        arguments_json='{"query": "How many?"}',
+                    )
+                ],
+            ),
+            ToolResponseMessage(call_id="foo", content="knowledge_search found 5...."),
+        ]
+        await vllm_inference_adapter.chat_completion(
+            "mock-model",
+            messages,
+            stream=False,
+            tools=[],
+            tool_config=ToolConfig(tool_choice=ToolChoice.auto),
+        )
+
+        assert mock_nonstream_completion.call_args.kwargs["messages"][2]["tool_calls"] == [
+            {
+                "id": "foo",
+                "type": "function",
+                "function": {"name": "knowledge_search", "arguments": '{"query": "How many?"}'},
+            }
+        ]
+
+
 @pytest.mark.asyncio
 async def test_tool_call_delta_empty_tool_call_buf():
     """
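
For reference, a standalone sketch of the ToolCall-to-OpenAI mapping that the assertion above pins down; `FakeToolCall` is a stand-in dataclass, not the real `llama_stack` type.

```python
# Illustrative only: shows the OpenAI-compatible dict shape the adapter is
# expected to emit for a tool call whose arguments are already JSON-encoded.
from dataclasses import dataclass


@dataclass
class FakeToolCall:
    call_id: str
    tool_name: str
    arguments_json: str  # already-serialized JSON arguments


def to_openai_tool_call(tc: FakeToolCall) -> dict:
    return {
        "id": tc.call_id,
        "type": "function",
        "function": {"name": tc.tool_name, "arguments": tc.arguments_json},
    }


print(to_openai_tool_call(FakeToolCall("foo", "knowledge_search", '{"query": "How many?"}')))
```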
tests/unit/providers/nvidia/test_eval.py (new file, 201 lines)
@@ -0,0 +1,201 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import os
import unittest
from unittest.mock import MagicMock, patch

import pytest

from llama_stack.apis.benchmarks import Benchmark
from llama_stack.apis.common.job_types import Job, JobStatus
from llama_stack.apis.eval.eval import BenchmarkConfig, EvaluateResponse, ModelCandidate, SamplingParams
from llama_stack.models.llama.sku_types import CoreModelId
from llama_stack.providers.remote.eval.nvidia.config import NVIDIAEvalConfig
from llama_stack.providers.remote.eval.nvidia.eval import NVIDIAEvalImpl

MOCK_DATASET_ID = "default/test-dataset"
MOCK_BENCHMARK_ID = "test-benchmark"


class TestNVIDIAEvalImpl(unittest.TestCase):
    def setUp(self):
        os.environ["NVIDIA_EVALUATOR_URL"] = "http://nemo.test"

        # Create mock APIs
        self.datasetio_api = MagicMock()
        self.datasets_api = MagicMock()
        self.scoring_api = MagicMock()
        self.inference_api = MagicMock()
        self.agents_api = MagicMock()

        self.config = NVIDIAEvalConfig(
            evaluator_url=os.environ["NVIDIA_EVALUATOR_URL"],
        )

        self.eval_impl = NVIDIAEvalImpl(
            config=self.config,
            datasetio_api=self.datasetio_api,
            datasets_api=self.datasets_api,
            scoring_api=self.scoring_api,
            inference_api=self.inference_api,
            agents_api=self.agents_api,
        )

        # Mock the HTTP request methods
        self.evaluator_get_patcher = patch(
            "llama_stack.providers.remote.eval.nvidia.eval.NVIDIAEvalImpl._evaluator_get"
        )
        self.evaluator_post_patcher = patch(
            "llama_stack.providers.remote.eval.nvidia.eval.NVIDIAEvalImpl._evaluator_post"
        )

        self.mock_evaluator_get = self.evaluator_get_patcher.start()
        self.mock_evaluator_post = self.evaluator_post_patcher.start()

    def tearDown(self):
        """Clean up after each test."""
        self.evaluator_get_patcher.stop()
        self.evaluator_post_patcher.stop()

    def _assert_request_body(self, expected_json):
        """Helper method to verify request body in Evaluator POST request is correct"""
        call_args = self.mock_evaluator_post.call_args
        actual_json = call_args[0][1]

        # Check that all expected keys contain the expected values in the actual JSON
        for key, value in expected_json.items():
            assert key in actual_json, f"Key '{key}' missing in actual JSON"

            if isinstance(value, dict):
                for nested_key, nested_value in value.items():
                    assert nested_key in actual_json[key], f"Nested key '{nested_key}' missing in actual JSON['{key}']"
                    assert actual_json[key][nested_key] == nested_value, f"Value mismatch for '{key}.{nested_key}'"
            else:
                assert actual_json[key] == value, f"Value mismatch for '{key}'"

    @pytest.fixture(autouse=True)
    def inject_fixtures(self, run_async):
        self.run_async = run_async

    def test_register_benchmark(self):
        eval_config = {
            "type": "custom",
            "params": {"parallelism": 8},
            "tasks": {
                "qa": {
                    "type": "completion",
                    "params": {"template": {"prompt": "{{prompt}}", "max_tokens": 200}},
                    "dataset": {"files_url": f"hf://datasets/{MOCK_DATASET_ID}/testing/testing.jsonl"},
                    "metrics": {"bleu": {"type": "bleu", "params": {"references": ["{{ideal_response}}"]}}},
                }
            },
        }

        benchmark = Benchmark(
            provider_id="nvidia",
            type="benchmark",
            identifier=MOCK_BENCHMARK_ID,
            dataset_id=MOCK_DATASET_ID,
            scoring_functions=["basic::equality"],
            metadata=eval_config,
        )

        # Mock Evaluator API response
        mock_evaluator_response = {"id": MOCK_BENCHMARK_ID, "status": "created"}
        self.mock_evaluator_post.return_value = mock_evaluator_response

        # Register the benchmark
        self.run_async(self.eval_impl.register_benchmark(benchmark))

        # Verify the Evaluator API was called correctly
        self.mock_evaluator_post.assert_called_once()
        self._assert_request_body({"namespace": benchmark.provider_id, "name": benchmark.identifier, **eval_config})

    def test_run_eval(self):
        benchmark_config = BenchmarkConfig(
            eval_candidate=ModelCandidate(
                type="model",
                model=CoreModelId.llama3_1_8b_instruct.value,
                sampling_params=SamplingParams(max_tokens=100, temperature=0.7),
            )
        )

        # Mock Evaluator API response
        mock_evaluator_response = {"id": "job-123", "status": "created"}
        self.mock_evaluator_post.return_value = mock_evaluator_response

        # Run the Evaluation job
        result = self.run_async(
            self.eval_impl.run_eval(benchmark_id=MOCK_BENCHMARK_ID, benchmark_config=benchmark_config)
        )

        # Verify the Evaluator API was called correctly
        self.mock_evaluator_post.assert_called_once()
        self._assert_request_body(
            {
                "config": f"nvidia/{MOCK_BENCHMARK_ID}",
                "target": {"type": "model", "model": "meta/llama-3.1-8b-instruct"},
            }
        )

        # Verify the result
        assert isinstance(result, Job)
        assert result.job_id == "job-123"
        assert result.status == JobStatus.in_progress

    def test_job_status(self):
        # Mock Evaluator API response
        mock_evaluator_response = {"id": "job-123", "status": "completed"}
        self.mock_evaluator_get.return_value = mock_evaluator_response

        # Get the Evaluation job
        result = self.run_async(self.eval_impl.job_status(benchmark_id=MOCK_BENCHMARK_ID, job_id="job-123"))

        # Verify the result
        assert isinstance(result, Job)
        assert result.job_id == "job-123"
        assert result.status == JobStatus.completed

        # Verify the API was called correctly
        self.mock_evaluator_get.assert_called_once_with(f"/v1/evaluation/jobs/{result.job_id}")

    def test_job_cancel(self):
        # Mock Evaluator API response
        mock_evaluator_response = {"id": "job-123", "status": "cancelled"}
        self.mock_evaluator_post.return_value = mock_evaluator_response

        # Cancel the Evaluation job
        self.run_async(self.eval_impl.job_cancel(benchmark_id=MOCK_BENCHMARK_ID, job_id="job-123"))

        # Verify the API was called correctly
        self.mock_evaluator_post.assert_called_once_with("/v1/evaluation/jobs/job-123/cancel", {})

    def test_job_result(self):
        # Mock Evaluator API responses
        mock_job_status_response = {"id": "job-123", "status": "completed"}
        mock_job_results_response = {
            "id": "job-123",
            "status": "completed",
            "results": {MOCK_BENCHMARK_ID: {"score": 0.85, "details": {"accuracy": 0.85, "f1": 0.84}}},
        }
        self.mock_evaluator_get.side_effect = [
            mock_job_status_response,  # First call to retrieve job
            mock_job_results_response,  # Second call to retrieve job results
        ]

        # Get the Evaluation job results
        result = self.run_async(self.eval_impl.job_result(benchmark_id=MOCK_BENCHMARK_ID, job_id="job-123"))

        # Verify the result
        assert isinstance(result, EvaluateResponse)
        assert MOCK_BENCHMARK_ID in result.scores
        assert result.scores[MOCK_BENCHMARK_ID].aggregated_results["results"][MOCK_BENCHMARK_ID]["score"] == 0.85

        # Verify the API was called correctly
        assert self.mock_evaluator_get.call_count == 2
        self.mock_evaluator_get.assert_any_call("/v1/evaluation/jobs/job-123")
        self.mock_evaluator_get.assert_any_call("/v1/evaluation/jobs/job-123/results")
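
For reference, a rough sketch of the job-polling flow that the mocked `_evaluator_get` calls stand in for; the plain `requests` calls, base URL, and absence of auth headers are assumptions for illustration only, while the endpoint paths match those asserted above.

```python
# Illustrative sketch of the evaluator job polling flow mocked in the tests.
import requests

EVALUATOR_URL = "http://nemo.test"  # assumed base URL, as in the test setUp
job_id = "job-123"

# Status check (first _evaluator_get in test_job_result)
status = requests.get(f"{EVALUATOR_URL}/v1/evaluation/jobs/{job_id}").json()

# Results fetch (second _evaluator_get in test_job_result)
if status.get("status") == "completed":
    results = requests.get(f"{EVALUATOR_URL}/v1/evaluation/jobs/{job_id}/results").json()
    print(results.get("results", {}))
```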
@@ -11,6 +11,7 @@ from unittest.mock import patch

 import pytest

+from llama_stack.apis.models import Model, ModelType
 from llama_stack.apis.post_training.post_training import (
     DataConfig,
     DatasetFormat,
@@ -21,6 +22,7 @@ from llama_stack.apis.post_training.post_training import (
     TrainingConfig,
 )
 from llama_stack.distribution.library_client import convert_pydantic_to_json_value
+from llama_stack.providers.remote.inference.nvidia.nvidia import NVIDIAConfig, NVIDIAInferenceAdapter
 from llama_stack.providers.remote.post_training.nvidia.post_training import (
     ListNvidiaPostTrainingJobs,
     NvidiaPostTrainingAdapter,
@@ -44,8 +46,22 @@ class TestNvidiaPostTraining(unittest.TestCase):
         )
         self.mock_make_request = self.make_request_patcher.start()

+        # Mock the inference client
+        inference_config = NVIDIAConfig(base_url=os.environ["NVIDIA_BASE_URL"], api_key=None)
+        self.inference_adapter = NVIDIAInferenceAdapter(inference_config)
+
+        self.mock_client = unittest.mock.MagicMock()
+        self.mock_client.chat.completions.create = unittest.mock.AsyncMock()
+        self.inference_mock_make_request = self.mock_client.chat.completions.create
+        self.inference_make_request_patcher = patch(
+            "llama_stack.providers.remote.inference.nvidia.nvidia.NVIDIAInferenceAdapter._get_client",
+            return_value=self.mock_client,
+        )
+        self.inference_make_request_patcher.start()
+
     def tearDown(self):
         self.make_request_patcher.stop()
+        self.inference_make_request_patcher.stop()

     @pytest.fixture(autouse=True)
     def inject_fixtures(self, run_async):
@@ -316,6 +332,31 @@ class TestNvidiaPostTraining(unittest.TestCase):
             expected_params={"job_id": job_id},
         )

+    def test_inference_register_model(self):
+        model_id = "default/job-1234"
+        model_type = ModelType.llm
+        model = Model(
+            identifier=model_id,
+            provider_id="nvidia",
+            provider_model_id=model_id,
+            provider_resource_id=model_id,
+            model_type=model_type,
+        )
+        result = self.run_async(self.inference_adapter.register_model(model))
+        assert result == model
+        assert len(self.inference_adapter.alias_to_provider_id_map) > 1
+        assert self.inference_adapter.get_provider_model_id(model.provider_model_id) == model_id
+
+        with patch.object(self.inference_adapter, "chat_completion") as mock_chat_completion:
+            self.run_async(
+                self.inference_adapter.chat_completion(
+                    model_id=model_id,
+                    messages=[{"role": "user", "content": "Hello, model"}],
+                )
+            )
+
+            mock_chat_completion.assert_called()
+

 if __name__ == "__main__":
     unittest.main()
tests/unit/providers/utils/inference/test_openai_compat.py (new file, 43 lines)
@@ -0,0 +1,43 @@
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

import pytest

from llama_stack.apis.common.content_types import TextContentItem
from llama_stack.apis.inference.inference import CompletionMessage, UserMessage
from llama_stack.models.llama.datatypes import StopReason, ToolCall
from llama_stack.providers.utils.inference.openai_compat import convert_message_to_openai_dict


@pytest.mark.asyncio
async def test_convert_message_to_openai_dict():
    message = UserMessage(content=[TextContentItem(text="Hello, world!")], role="user")
    assert await convert_message_to_openai_dict(message) == {
        "role": "user",
        "content": [{"type": "text", "text": "Hello, world!"}],
    }


# Test convert_message_to_openai_dict with a tool call
@pytest.mark.asyncio
async def test_convert_message_to_openai_dict_with_tool_call():
    message = CompletionMessage(
        content="",
        tool_calls=[
            ToolCall(call_id="123", tool_name="test_tool", arguments_json='{"foo": "bar"}', arguments={"foo": "bar"})
        ],
        stop_reason=StopReason.end_of_turn,
    )

    openai_dict = await convert_message_to_openai_dict(message)

    assert openai_dict == {
        "role": "assistant",
        "content": [{"type": "text", "text": ""}],
        "tool_calls": [
            {"id": "123", "type": "function", "function": {"name": "test_tool", "arguments": '{"foo": "bar"}'}}
        ],
    }
@@ -47,9 +47,45 @@ async def test_sse_generator_client_disconnected():
     sse_gen = sse_generator(async_event_gen())
     assert sse_gen is not None

+    # Start reading the events, ensuring this doesn't raise an exception
     seen_events = []
     async for event in sse_gen:
         seen_events.append(event)

+    # We should see 1 event before the client disconnected
     assert len(seen_events) == 1
     assert seen_events[0] == create_sse_event("Test event 1")
+
+
+@pytest.mark.asyncio
+async def test_sse_generator_client_disconnected_before_response_starts():
+    # Disconnect before the response starts
+    async def async_event_gen():
+        raise asyncio.CancelledError()
+
+    sse_gen = sse_generator(async_event_gen())
+    assert sse_gen is not None
+
+    seen_events = []
+    async for event in sse_gen:
+        seen_events.append(event)
+
+    # No events should be seen since the client disconnected immediately
+    assert len(seen_events) == 0
+
+
+@pytest.mark.asyncio
+async def test_sse_generator_error_before_response_starts():
+    # Raise an error before the response starts
+    async def async_event_gen():
+        raise Exception("Test error")
+
+    sse_gen = sse_generator(async_event_gen())
+    assert sse_gen is not None
+
+    seen_events = []
+    async for event in sse_gen:
+        seen_events.append(event)
+
+    # We should have 1 error event
+    assert len(seen_events) == 1
+    assert 'data: {"error":' in seen_events[0]