diff --git a/.coveragerc b/.coveragerc new file mode 100644 index 000000000..e16c2e461 --- /dev/null +++ b/.coveragerc @@ -0,0 +1,6 @@ +[run] +omit = + */tests/* + */llama_stack/providers/* + */llama_stack/templates/* + .venv/* diff --git a/.github/workflows/integration-tests.yml b/.github/workflows/integration-tests.yml index 0eb252695..f54bed839 100644 --- a/.github/workflows/integration-tests.yml +++ b/.github/workflows/integration-tests.yml @@ -6,7 +6,6 @@ on: pull_request: branches: [ main ] paths: - - 'distributions/**' - 'llama_stack/**' - 'tests/integration/**' - 'uv.lock' diff --git a/.github/workflows/providers-build.yml b/.github/workflows/providers-build.yml index 117c8b6d2..23257d7dc 100644 --- a/.github/workflows/providers-build.yml +++ b/.github/workflows/providers-build.yml @@ -107,3 +107,41 @@ jobs: - name: Build a single provider run: | USE_COPY_NOT_MOUNT=true LLAMA_STACK_DIR=. uv run llama stack build --image-type venv --image-name test --providers inference=remote::ollama + + build-custom-container-distribution: + runs-on: ubuntu-latest + steps: + - name: Checkout repository + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + + - name: Set up Python + uses: actions/setup-python@8d9ed9ac5c53483de85588cdf95a591a75ab9f55 # v5.5.0 + with: + python-version: '3.10' + + - name: Install uv + uses: astral-sh/setup-uv@0c5e2b8115b80b4c7c5ddf6ffdd634974642d182 # v5.4.1 + with: + python-version: "3.10" + + - name: Install LlamaStack + run: | + uv venv + source .venv/bin/activate + uv pip install -e . + + - name: Build a single provider + run: | + yq -i '.image_type = "container"' llama_stack/templates/dev/build.yaml + yq -i '.image_name = "test"' llama_stack/templates/dev/build.yaml + USE_COPY_NOT_MOUNT=true LLAMA_STACK_DIR=. 
uv run llama stack build --config llama_stack/templates/dev/build.yaml + + - name: Inspect the container image entrypoint + run: | + IMAGE_ID=$(docker images --format "{{.Repository}}:{{.Tag}}" | head -n 1) + entrypoint=$(docker inspect --format '{{ .Config.Entrypoint }}' $IMAGE_ID) + echo "Entrypoint: $entrypoint" + if [ "$entrypoint" != "[python -m llama_stack.distribution.server.server --config /app/run.yaml]" ]; then + echo "Entrypoint is not correct" + exit 1 + fi diff --git a/.github/workflows/test-external-providers.yml b/.github/workflows/test-external-providers.yml index f7801c8d3..37f5c45ab 100644 --- a/.github/workflows/test-external-providers.yml +++ b/.github/workflows/test-external-providers.yml @@ -5,6 +5,13 @@ on: branches: [ main ] pull_request: branches: [ main ] + paths: + - 'llama_stack/**' + - 'tests/integration/**' + - 'uv.lock' + - 'pyproject.toml' + - 'requirements.txt' + - '.github/workflows/test-external-providers.yml' # This workflow jobs: test-external-providers: diff --git a/.github/workflows/unit-tests.yml b/.github/workflows/unit-tests.yml index 4b0c58b99..962141744 100644 --- a/.github/workflows/unit-tests.yml +++ b/.github/workflows/unit-tests.yml @@ -6,7 +6,6 @@ on: pull_request: branches: [ main ] paths: - - 'distributions/**' - 'llama_stack/**' - 'tests/unit/**' - 'uv.lock' diff --git a/README.md b/README.md index 8c201e43d..9a4f1a849 100644 --- a/README.md +++ b/README.md @@ -119,6 +119,7 @@ Here is a list of the various API providers and available distributions that can | OpenAI | Hosted | | ✅ | | | | | Anthropic | Hosted | | ✅ | | | | | Gemini | Hosted | | ✅ | | | | +| watsonx | Hosted | | ✅ | | | | ### Distributions @@ -128,7 +129,6 @@ A Llama Stack Distribution (or "distro") is a pre-configured bundle of provider | **Distribution** | **Llama Stack Docker** | Start This Distribution | |:---------------------------------------------:|:-------------------------------------------------------------------------------------------------------------------------------------------------------------:|:------------------------------------------------------------------------------------------------------------------------:| | Meta Reference | [llamastack/distribution-meta-reference-gpu](https://hub.docker.com/repository/docker/llamastack/distribution-meta-reference-gpu/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/distributions/self_hosted_distro/meta-reference-gpu.html) | -| Meta Reference Quantized | [llamastack/distribution-meta-reference-quantized-gpu](https://hub.docker.com/repository/docker/llamastack/distribution-meta-reference-quantized-gpu/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/distributions/self_hosted_distro/meta-reference-quantized-gpu.html) | | SambaNova | [llamastack/distribution-sambanova](https://hub.docker.com/repository/docker/llamastack/distribution-sambanova/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/distributions/self_hosted_distro/sambanova.html) | | Cerebras | [llamastack/distribution-cerebras](https://hub.docker.com/repository/docker/llamastack/distribution-cerebras/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/distributions/self_hosted_distro/cerebras.html) | | Ollama | [llamastack/distribution-ollama](https://hub.docker.com/repository/docker/llamastack/distribution-ollama/general) | [Guide](https://llama-stack.readthedocs.io/en/latest/distributions/self_hosted_distro/ollama.html) | diff --git a/docs/source/building_applications/rag.md 
b/docs/source/building_applications/rag.md index 39d1ba333..db6303209 100644 --- a/docs/source/building_applications/rag.md +++ b/docs/source/building_applications/rag.md @@ -68,7 +68,8 @@ chunks_response = client.vector_io.query( ### Using the RAG Tool A better way to ingest documents is to use the RAG Tool. This tool allows you to ingest documents from URLs, files, etc. -and automatically chunks them into smaller pieces. +and automatically chunks them into smaller pieces. More examples of how to format a RAGDocument can be found in the +[appendix](#more-ragdocument-examples). ```python from llama_stack_client import RAGDocument @@ -178,3 +179,39 @@ for vector_db_id in client.vector_dbs.list(): print(f"Unregistering vector database: {vector_db_id.identifier}") client.vector_dbs.unregister(vector_db_id=vector_db_id.identifier) ``` + +### Appendix + +#### More RAGDocument Examples +```python +from llama_stack_client import RAGDocument +import base64 +import requests + +RAGDocument(document_id="num-0", content={"uri": "file://path/to/file"}) +RAGDocument(document_id="num-1", content="plain text") +RAGDocument( + document_id="num-2", + content={ + "type": "text", + "text": "plain text input", + }, # for inputs that should be treated as text explicitly +) +RAGDocument( + document_id="num-3", + content={ + "type": "image", + "image": {"url": {"uri": "https://mywebsite.com/image.jpg"}}, + }, +) +B64_ENCODED_IMAGE = base64.b64encode( + requests.get( + "https://raw.githubusercontent.com/meta-llama/llama-stack/refs/heads/main/docs/_static/llama-stack.png" + ).content +) +RAGDocument( + document_id="num-4", + content={"type": "image", "image": {"data": B64_ENCODED_IMAGE}}, +) +``` +For more strongly typed interactions, use the typed dicts found [here](https://github.com/meta-llama/llama-stack-client-python/blob/38cd91c9e396f2be0bec1ee96a19771582ba6f17/src/llama_stack_client/types/shared_params/document.py). diff --git a/docs/source/distributions/building_distro.md b/docs/source/distributions/building_distro.md index 4c342b14b..56b8d30a8 100644 --- a/docs/source/distributions/building_distro.md +++ b/docs/source/distributions/building_distro.md @@ -109,8 +109,6 @@ llama stack build --list-templates +------------------------------+-----------------------------------------------------------------------------+ | nvidia | Use NVIDIA NIM for running LLM inference | +------------------------------+-----------------------------------------------------------------------------+ -| meta-reference-quantized-gpu | Use Meta Reference with fp8, int4 quantization for running LLM inference | -+------------------------------+-----------------------------------------------------------------------------+ | cerebras | Use Cerebras for running LLM inference | +------------------------------+-----------------------------------------------------------------------------+ | ollama | Use (an external) Ollama server for running LLM inference | diff --git a/docs/source/distributions/remote_hosted_distro/watsonx.md b/docs/source/distributions/remote_hosted_distro/watsonx.md new file mode 100644 index 000000000..018dc2a3c --- /dev/null +++ b/docs/source/distributions/remote_hosted_distro/watsonx.md @@ -0,0 +1,88 @@ +--- +orphan: true +--- + +# watsonx Distribution + +```{toctree} +:maxdepth: 2 +:hidden: + +self +``` + +The `llamastack/distribution-watsonx` distribution consists of the following provider configurations.
+ +| API | Provider(s) | +|-----|-------------| +| agents | `inline::meta-reference` | +| datasetio | `remote::huggingface`, `inline::localfs` | +| eval | `inline::meta-reference` | +| inference | `remote::watsonx` | +| safety | `inline::llama-guard` | +| scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` | +| telemetry | `inline::meta-reference` | +| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::code-interpreter`, `inline::rag-runtime`, `remote::model-context-protocol` | +| vector_io | `inline::faiss` | + + + +### Environment Variables + +The following environment variables can be configured: + +- `LLAMASTACK_PORT`: Port for the Llama Stack distribution server (default: `5001`) +- `WATSONX_API_KEY`: watsonx API Key (default: ``) +- `WATSONX_PROJECT_ID`: watsonx Project ID (default: ``) + +### Models + +The following models are available by default: + +- `meta-llama/llama-3-3-70b-instruct (aliases: meta-llama/Llama-3.3-70B-Instruct)` +- `meta-llama/llama-2-13b-chat (aliases: meta-llama/Llama-2-13b)` +- `meta-llama/llama-3-1-70b-instruct (aliases: meta-llama/Llama-3.1-70B-Instruct)` +- `meta-llama/llama-3-1-8b-instruct (aliases: meta-llama/Llama-3.1-8B-Instruct)` +- `meta-llama/llama-3-2-11b-vision-instruct (aliases: meta-llama/Llama-3.2-11B-Vision-Instruct)` +- `meta-llama/llama-3-2-1b-instruct (aliases: meta-llama/Llama-3.2-1B-Instruct)` +- `meta-llama/llama-3-2-3b-instruct (aliases: meta-llama/Llama-3.2-3B-Instruct)` +- `meta-llama/llama-3-2-90b-vision-instruct (aliases: meta-llama/Llama-3.2-90B-Vision-Instruct)` +- `meta-llama/llama-guard-3-11b-vision (aliases: meta-llama/Llama-Guard-3-11B-Vision)` + + +### Prerequisite: API Keys + +Make sure you have access to a watsonx API Key. You can get one by referring to [watsonx.ai](https://www.ibm.com/docs/en/masv-and-l/maximo-manage/continuous-delivery?topic=setup-create-watsonx-api-key). + + +## Running Llama Stack with watsonx + +You can do this via Conda (build code), venv, or Docker, which has a pre-built image. + +### Via Docker + +This method allows you to get started quickly without having to build the distribution code.
+ +```bash +LLAMA_STACK_PORT=5001 +docker run \ + -it \ + -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \ + -v ./run.yaml:/root/my-run.yaml \ + llamastack/distribution-watsonx \ + --yaml-config /root/my-run.yaml \ + --port $LLAMA_STACK_PORT \ + --env WATSONX_API_KEY=$WATSONX_API_KEY \ + --env WATSONX_PROJECT_ID=$WATSONX_PROJECT_ID \ + --env WATSONX_BASE_URL=$WATSONX_BASE_URL +``` + +### Via Conda + +```bash +llama stack build --template watsonx --image-type conda +llama stack run ./run.yaml \ + --port $LLAMA_STACK_PORT \ + --env WATSONX_API_KEY=$WATSONX_API_KEY \ + --env WATSONX_PROJECT_ID=$WATSONX_PROJECT_ID +``` diff --git a/docs/source/distributions/self_hosted_distro/meta-reference-gpu.md b/docs/source/distributions/self_hosted_distro/meta-reference-gpu.md index b90f75347..f58d7bbee 100644 --- a/docs/source/distributions/self_hosted_distro/meta-reference-gpu.md +++ b/docs/source/distributions/self_hosted_distro/meta-reference-gpu.md @@ -81,6 +81,7 @@ LLAMA_STACK_PORT=8321 docker run \ -it \ --pull always \ + --gpus all \ -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \ -v ~/.llama:/root/.llama \ llamastack/distribution-meta-reference-gpu \ @@ -94,6 +95,7 @@ If you are using Llama Stack Safety / Shield APIs, use: docker run \ -it \ --pull always \ + --gpus all \ -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \ -v ~/.llama:/root/.llama \ llamastack/distribution-meta-reference-gpu \ diff --git a/docs/source/distributions/self_hosted_distro/meta-reference-quantized-gpu.md b/docs/source/distributions/self_hosted_distro/meta-reference-quantized-gpu.md deleted file mode 100644 index c3e2b4f2c..000000000 --- a/docs/source/distributions/self_hosted_distro/meta-reference-quantized-gpu.md +++ /dev/null @@ -1,123 +0,0 @@ ---- -orphan: true ---- - -# Meta Reference Quantized Distribution - -```{toctree} -:maxdepth: 2 -:hidden: - -self -``` - -The `llamastack/distribution-meta-reference-quantized-gpu` distribution consists of the following provider configurations: - -| API | Provider(s) | -|-----|-------------| -| agents | `inline::meta-reference` | -| datasetio | `remote::huggingface`, `inline::localfs` | -| eval | `inline::meta-reference` | -| inference | `inline::meta-reference-quantized` | -| safety | `inline::llama-guard` | -| scoring | `inline::basic`, `inline::llm-as-judge`, `inline::braintrust` | -| telemetry | `inline::meta-reference` | -| tool_runtime | `remote::brave-search`, `remote::tavily-search`, `inline::code-interpreter`, `inline::rag-runtime`, `remote::model-context-protocol` | -| vector_io | `inline::faiss`, `remote::chromadb`, `remote::pgvector` | - - -The only difference vs. the `meta-reference-gpu` distribution is that it has support for more efficient inference -- with fp8, int4 quantization, etc. - -Note that you need access to nvidia GPUs to run this distribution. This distribution is not compatible with CPU-only machines or machines with AMD GPUs. - -### Environment Variables - -The following environment variables can be configured: - -- `LLAMA_STACK_PORT`: Port for the Llama Stack distribution server (default: `8321`) -- `INFERENCE_MODEL`: Inference model loaded into the Meta Reference server (default: `meta-llama/Llama-3.2-3B-Instruct`) -- `INFERENCE_CHECKPOINT_DIR`: Directory containing the Meta Reference model checkpoint (default: `null`) - - -## Prerequisite: Downloading Models - -Please use `llama model list --downloaded` to check that you have llama model checkpoints downloaded in `~/.llama` before proceeding.
See [installation guide](https://llama-stack.readthedocs.io/en/latest/references/llama_cli_reference/download_models.html) here to download the models. Run `llama model list` to see the available models to download, and `llama model download` to download the checkpoints. - -``` -$ llama model list --downloaded -┏━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━┳━━━━━━━━━━┳━━━━━━━━━━━━━━━━━━━━━┓ -┃ Model ┃ Size ┃ Modified Time ┃ -┡━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━╇━━━━━━━━━━╇━━━━━━━━━━━━━━━━━━━━━┩ -│ Llama3.2-1B-Instruct:int4-qlora-eo8 │ 1.53 GB │ 2025-02-26 11:22:28 │ -├─────────────────────────────────────────┼──────────┼─────────────────────┤ -│ Llama3.2-1B │ 2.31 GB │ 2025-02-18 21:48:52 │ -├─────────────────────────────────────────┼──────────┼─────────────────────┤ -│ Prompt-Guard-86M │ 0.02 GB │ 2025-02-26 11:29:28 │ -├─────────────────────────────────────────┼──────────┼─────────────────────┤ -│ Llama3.2-3B-Instruct:int4-spinquant-eo8 │ 3.69 GB │ 2025-02-26 11:37:41 │ -├─────────────────────────────────────────┼──────────┼─────────────────────┤ -│ Llama3.2-3B │ 5.99 GB │ 2025-02-18 21:51:26 │ -├─────────────────────────────────────────┼──────────┼─────────────────────┤ -│ Llama3.1-8B │ 14.97 GB │ 2025-02-16 10:36:37 │ -├─────────────────────────────────────────┼──────────┼─────────────────────┤ -│ Llama3.2-1B-Instruct:int4-spinquant-eo8 │ 1.51 GB │ 2025-02-26 11:35:02 │ -├─────────────────────────────────────────┼──────────┼─────────────────────┤ -│ Llama-Guard-3-1B │ 2.80 GB │ 2025-02-26 11:20:46 │ -├─────────────────────────────────────────┼──────────┼─────────────────────┤ -│ Llama-Guard-3-1B:int4 │ 0.43 GB │ 2025-02-26 11:33:33 │ -└─────────────────────────────────────────┴──────────┴─────────────────────┘ -``` - -## Running the Distribution - -You can do this via Conda (build code) or Docker which has a pre-built image. - -### Via Docker - -This method allows you to get started quickly without having to build the distribution code. - -```bash -LLAMA_STACK_PORT=8321 -docker run \ - -it \ - --pull always \ - -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \ - -v ~/.llama:/root/.llama \ - llamastack/distribution-meta-reference-quantized-gpu \ - --port $LLAMA_STACK_PORT \ - --env INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct -``` - -If you are using Llama Stack Safety / Shield APIs, use: - -```bash -docker run \ - -it \ - --pull always \ - -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \ - -v ~/.llama:/root/.llama \ - llamastack/distribution-meta-reference-quantized-gpu \ - --port $LLAMA_STACK_PORT \ - --env INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct \ - --env SAFETY_MODEL=meta-llama/Llama-Guard-3-1B -``` - -### Via Conda - -Make sure you have done `uv pip install llama-stack` and have the Llama Stack CLI available. 
- -```bash -llama stack build --template meta-reference-quantized-gpu --image-type conda -llama stack run distributions/meta-reference-quantized-gpu/run.yaml \ - --port $LLAMA_STACK_PORT \ - --env INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct -``` - -If you are using Llama Stack Safety / Shield APIs, use: - -```bash -llama stack run distributions/meta-reference-quantized-gpu/run-with-safety.yaml \ - --port $LLAMA_STACK_PORT \ - --env INFERENCE_MODEL=meta-llama/Llama-3.2-3B-Instruct \ - --env SAFETY_MODEL=meta-llama/Llama-Guard-3-1B -``` diff --git a/docs/source/distributions/self_hosted_distro/nvidia.md b/docs/source/distributions/self_hosted_distro/nvidia.md index 0922cb512..4407de779 100644 --- a/docs/source/distributions/self_hosted_distro/nvidia.md +++ b/docs/source/distributions/self_hosted_distro/nvidia.md @@ -7,7 +7,7 @@ The `llamastack/distribution-nvidia` distribution consists of the following prov |-----|-------------| | agents | `inline::meta-reference` | | datasetio | `inline::localfs` | -| eval | `inline::meta-reference` | +| eval | `remote::nvidia` | | inference | `remote::nvidia` | | post_training | `remote::nvidia` | | safety | `remote::nvidia` | @@ -22,13 +22,13 @@ The `llamastack/distribution-nvidia` distribution consists of the following prov The following environment variables can be configured: - `NVIDIA_API_KEY`: NVIDIA API Key (default: ``) -- `NVIDIA_USER_ID`: NVIDIA User ID (default: `llama-stack-user`) +- `NVIDIA_APPEND_API_VERSION`: Whether to append the API version to the base_url (default: `True`) - `NVIDIA_DATASET_NAMESPACE`: NVIDIA Dataset Namespace (default: `default`) -- `NVIDIA_ACCESS_POLICIES`: NVIDIA Access Policies (default: `{}`) - `NVIDIA_PROJECT_ID`: NVIDIA Project ID (default: `test-project`) - `NVIDIA_CUSTOMIZER_URL`: NVIDIA Customizer URL (default: `https://customizer.api.nvidia.com`) - `NVIDIA_OUTPUT_MODEL_DIR`: NVIDIA Output Model Directory (default: `test-example-model@v1`) - `GUARDRAILS_SERVICE_URL`: URL for the NeMo Guardrails Service (default: `http://0.0.0.0:7331`) +- `NVIDIA_EVALUATOR_URL`: URL for the NeMo Evaluator Service (default: `http://0.0.0.0:7331`) - `INFERENCE_MODEL`: Inference model (default: `Llama3.1-8B-Instruct`) - `SAFETY_MODEL`: Name of the model to use for safety (default: `meta/llama-3.1-8b-instruct`) diff --git a/docs/source/providers/external.md b/docs/source/providers/external.md index 90fc77979..5aab5ee0f 100644 --- a/docs/source/providers/external.md +++ b/docs/source/providers/external.md @@ -50,9 +50,10 @@ Llama Stack supports two types of external providers: Here's a list of known external providers that you can use with Llama Stack: -| Type | Name | Description | Repository | -|------|------|-------------|------------| -| Remote | KubeFlow Training | Train models with KubeFlow | [llama-stack-provider-kft](https://github.com/opendatahub-io/llama-stack-provider-kft) | +| Name | Description | API | Type | Repository | +|------|-------------|-----|------|------------| +| KubeFlow Training | Train models with KubeFlow | Post Training | Remote | [llama-stack-provider-kft](https://github.com/opendatahub-io/llama-stack-provider-kft) | +| RamaLama | Inference models with RamaLama | Inference | Remote | [ramalama-stack](https://github.com/containers/ramalama-stack) | ### Remote Provider Specification diff --git a/docs/zero_to_hero_guide/00_Inference101.ipynb b/docs/zero_to_hero_guide/00_Inference101.ipynb index b3b781375..4f71f9f89 100644 --- a/docs/zero_to_hero_guide/00_Inference101.ipynb +++ 
b/docs/zero_to_hero_guide/00_Inference101.ipynb @@ -389,5 +389,7 @@ "pygments_lexer": "ipython3", "version": "3.10.15" } - } + }, + "nbformat": 4, + "nbformat_minor": 5 } diff --git a/docs/zero_to_hero_guide/01_Local_Cloud_Inference101.ipynb b/docs/zero_to_hero_guide/01_Local_Cloud_Inference101.ipynb index d66e1b4f5..19a7fe3be 100644 --- a/docs/zero_to_hero_guide/01_Local_Cloud_Inference101.ipynb +++ b/docs/zero_to_hero_guide/01_Local_Cloud_Inference101.ipynb @@ -256,5 +256,7 @@ "pygments_lexer": "ipython3", "version": "3.10.15" } - } + }, + "nbformat": 4, + "nbformat_minor": 5 } diff --git a/docs/zero_to_hero_guide/02_Prompt_Engineering101.ipynb b/docs/zero_to_hero_guide/02_Prompt_Engineering101.ipynb index 7fccf8c51..f3566eeb3 100644 --- a/docs/zero_to_hero_guide/02_Prompt_Engineering101.ipynb +++ b/docs/zero_to_hero_guide/02_Prompt_Engineering101.ipynb @@ -301,5 +301,7 @@ "pygments_lexer": "ipython3", "version": "3.12.2" } - } + }, + "nbformat": 4, + "nbformat_minor": 5 } diff --git a/docs/zero_to_hero_guide/03_Image_Chat101.ipynb b/docs/zero_to_hero_guide/03_Image_Chat101.ipynb index 58353e813..ae10d8808 100644 --- a/docs/zero_to_hero_guide/03_Image_Chat101.ipynb +++ b/docs/zero_to_hero_guide/03_Image_Chat101.ipynb @@ -200,5 +200,7 @@ "pygments_lexer": "ipython3", "version": "3.12.2" } - } + }, + "nbformat": 4, + "nbformat_minor": 5 } diff --git a/docs/zero_to_hero_guide/04_Tool_Calling101.ipynb b/docs/zero_to_hero_guide/04_Tool_Calling101.ipynb index c3a383e8c..de3754b21 100644 --- a/docs/zero_to_hero_guide/04_Tool_Calling101.ipynb +++ b/docs/zero_to_hero_guide/04_Tool_Calling101.ipynb @@ -355,5 +355,7 @@ "pygments_lexer": "ipython3", "version": "3.10.15" } - } + }, + "nbformat": 4, + "nbformat_minor": 5 } diff --git a/docs/zero_to_hero_guide/05_Memory101.ipynb b/docs/zero_to_hero_guide/05_Memory101.ipynb index bfeb40adc..66956259f 100644 --- a/docs/zero_to_hero_guide/05_Memory101.ipynb +++ b/docs/zero_to_hero_guide/05_Memory101.ipynb @@ -398,5 +398,7 @@ "pygments_lexer": "ipython3", "version": "3.10.15" } - } + }, + "nbformat": 4, + "nbformat_minor": 5 } diff --git a/docs/zero_to_hero_guide/06_Safety101.ipynb b/docs/zero_to_hero_guide/06_Safety101.ipynb index c8c1fe9c7..5d7763924 100644 --- a/docs/zero_to_hero_guide/06_Safety101.ipynb +++ b/docs/zero_to_hero_guide/06_Safety101.ipynb @@ -132,5 +132,7 @@ "pygments_lexer": "ipython3", "version": "3.11.10" } - } + }, + "nbformat": 4, + "nbformat_minor": 5 } diff --git a/docs/zero_to_hero_guide/07_Agents101.ipynb b/docs/zero_to_hero_guide/07_Agents101.ipynb index 8c988e1e3..b6df2a4c8 100644 --- a/docs/zero_to_hero_guide/07_Agents101.ipynb +++ b/docs/zero_to_hero_guide/07_Agents101.ipynb @@ -188,5 +188,7 @@ "pygments_lexer": "ipython3", "version": "3.10.15" } - } + }, + "nbformat": 4, + "nbformat_minor": 5 } diff --git a/llama_stack/cli/stack/_build.py b/llama_stack/cli/stack/_build.py index 26c09af4e..2787a93d5 100644 --- a/llama_stack/cli/stack/_build.py +++ b/llama_stack/cli/stack/_build.py @@ -136,12 +136,13 @@ def run_stack_build_command(args: argparse.Namespace) -> None: ) image_type = prompt( - f"> Enter the image type you want your Llama Stack to be built as ({' or '.join(e.value for e in ImageType)}): ", + "> Enter the image type you want your Llama Stack to be built as (use to see options): ", + completer=WordCompleter([e.value for e in ImageType]), + complete_while_typing=True, validator=Validator.from_callable( lambda x: x in [e.value for e in ImageType], - error_message=f"Invalid image type, please enter {' or '.join(e.value 
for e in ImageType)}", + error_message="Invalid image type. Use to see options", ), - default=ImageType.CONDA.value, ) if image_type == ImageType.CONDA.value: @@ -317,11 +318,15 @@ def _generate_run_config( to_write = json.loads(run_config.model_dump_json()) f.write(yaml.dump(to_write, sort_keys=False)) - # this path is only invoked when no template is provided - cprint( - f"You can now run your stack with `llama stack run {run_config_file}`", - color="green", - ) + # Only print this message for non-container builds since it will be displayed before the + # container is built + # For non-container builds, the run.yaml is generated at the very end of the build process so it + # makes sense to display this message + if build_config.image_type != LlamaStackImageType.CONTAINER.value: + cprint( + f"You can now run your stack with `llama stack run {run_config_file}`", + color="green", + ) return run_config_file @@ -355,6 +360,13 @@ def _run_stack_build_command_from_build_config( build_file_path = build_dir / f"{image_name}-build.yaml" os.makedirs(build_dir, exist_ok=True) + run_config_file = None + # Generate the run.yaml so it can be included in the container image with the proper entrypoint + # Only do this if we're building a container image and we're not using a template + if build_config.image_type == LlamaStackImageType.CONTAINER.value and not template_name and config_path: + cprint("Generating run.yaml file", color="green") + run_config_file = _generate_run_config(build_config, build_dir, image_name) + with open(build_file_path, "w") as f: to_write = json.loads(build_config.model_dump_json()) f.write(yaml.dump(to_write, sort_keys=False)) @@ -364,6 +376,7 @@ def _run_stack_build_command_from_build_config( build_file_path, image_name, template_or_config=template_name or config_path or str(build_file_path), + run_config=run_config_file, ) if return_code != 0: raise RuntimeError(f"Failed to build image {image_name}") diff --git a/llama_stack/distribution/build.py b/llama_stack/distribution/build.py index 5b61ae081..9664449f3 100644 --- a/llama_stack/distribution/build.py +++ b/llama_stack/distribution/build.py @@ -93,6 +93,7 @@ def build_image( build_file_path: Path, image_name: str, template_or_config: str, + run_config: str | None = None, ): container_base = build_config.distribution_spec.container_image or "python:3.10-slim" @@ -108,6 +109,11 @@ def build_image( container_base, " ".join(normal_deps), ] + + # When building from a config file (not a template), include the run config path in the + # build arguments + if run_config is not None: + args.append(run_config) elif build_config.image_type == LlamaStackImageType.CONDA.value: script = str(importlib.resources.files("llama_stack") / "distribution/build_conda_env.sh") args = [ diff --git a/llama_stack/distribution/build_container.sh b/llama_stack/distribution/build_container.sh index fb4780432..ad316d45e 100755 --- a/llama_stack/distribution/build_container.sh +++ b/llama_stack/distribution/build_container.sh @@ -19,12 +19,16 @@ UV_HTTP_TIMEOUT=${UV_HTTP_TIMEOUT:-500} # mounting is not supported by docker buildx, so we use COPY instead USE_COPY_NOT_MOUNT=${USE_COPY_NOT_MOUNT:-} +# Path to the run.yaml file in the container +RUN_CONFIG_PATH=/app/run.yaml + +BUILD_CONTEXT_DIR=$(pwd) + if [ "$#" -lt 4 ]; then # This only works for templates - echo "Usage: $0 []" >&2 + echo "Usage: $0 [] []" >&2 exit 1 fi - set -euo pipefail template_or_config="$1" @@ -35,8 +39,27 @@ container_base="$1" shift pip_dependencies="$1" shift 
-special_pip_deps="${1:-}" +# Handle optional arguments +run_config="" +special_pip_deps="" + +# Check if there are more arguments +# The logics is becoming cumbersom, we should refactor it if we can do better +if [ $# -gt 0 ]; then + # Check if the argument ends with .yaml + if [[ "$1" == *.yaml ]]; then + run_config="$1" + shift + # If there's another argument after .yaml, it must be special_pip_deps + if [ $# -gt 0 ]; then + special_pip_deps="$1" + fi + else + # If it's not .yaml, it must be special_pip_deps + special_pip_deps="$1" + fi +fi # Define color codes RED='\033[0;31m' @@ -75,7 +98,7 @@ WORKDIR /app # We install the Python 3.11 dev headers and build tools so that any # C‑extension wheels (e.g. polyleven, faiss‑cpu) can compile successfully. -RUN dnf -y update && dnf install -y iputils net-tools wget \ +RUN dnf -y update && dnf install -y iputils git net-tools wget \ vim-minimal python3.11 python3.11-pip python3.11-wheel \ python3.11-setuptools python3.11-devel gcc make && \ ln -s /bin/pip3.11 /bin/pip && ln -s /bin/python3.11 /bin/python && dnf clean all @@ -119,6 +142,45 @@ EOF done fi +# Function to get Python command +get_python_cmd() { + if is_command_available python; then + echo "python" + elif is_command_available python3; then + echo "python3" + else + echo "Error: Neither python nor python3 is installed. Please install Python to continue." >&2 + exit 1 + fi +} + +if [ -n "$run_config" ]; then + # Copy the run config to the build context since it's an absolute path + cp "$run_config" "$BUILD_CONTEXT_DIR/run.yaml" + add_to_container << EOF +COPY run.yaml $RUN_CONFIG_PATH +EOF + + # Parse the run.yaml configuration to identify external provider directories + # If external providers are specified, copy their directory to the container + # and update the configuration to reference the new container path + python_cmd=$(get_python_cmd) + external_providers_dir=$($python_cmd -c "import yaml; config = yaml.safe_load(open('$run_config')); print(config.get('external_providers_dir') or '')") + if [ -n "$external_providers_dir" ]; then + echo "Copying external providers directory: $external_providers_dir" + add_to_container << EOF +COPY $external_providers_dir /app/providers.d +EOF + # Edit the run.yaml file to change the external_providers_dir to /app/providers.d + if [ "$(uname)" = "Darwin" ]; then + sed -i.bak -e 's|external_providers_dir:.*|external_providers_dir: /app/providers.d|' "$BUILD_CONTEXT_DIR/run.yaml" + rm -f "$BUILD_CONTEXT_DIR/run.yaml.bak" + else + sed -i 's|external_providers_dir:.*|external_providers_dir: /app/providers.d|' "$BUILD_CONTEXT_DIR/run.yaml" + fi + fi +fi + stack_mount="/app/llama-stack-source" client_mount="/app/llama-stack-client-source" @@ -178,15 +240,16 @@ fi RUN pip uninstall -y uv EOF -# if template_or_config ends with .yaml, it is not a template and we should not use the --template flag -if [[ "$template_or_config" != *.yaml ]]; then +# If a run config is provided, we use the --config flag +if [[ -n "$run_config" ]]; then + add_to_container << EOF +ENTRYPOINT ["python", "-m", "llama_stack.distribution.server.server", "--config", "$RUN_CONFIG_PATH"] +EOF +# If a template is provided (not a yaml file), we use the --template flag +elif [[ "$template_or_config" != *.yaml ]]; then add_to_container << EOF ENTRYPOINT ["python", "-m", "llama_stack.distribution.server.server", "--template", "$template_or_config"] EOF -else - add_to_container << EOF -ENTRYPOINT ["python", "-m", "llama_stack.distribution.server.server"] -EOF fi # Add other require item 
commands genearic to all containers @@ -258,9 +321,10 @@ $CONTAINER_BINARY build \ "${CLI_ARGS[@]}" \ -t "$image_tag" \ -f "$TEMP_DIR/Containerfile" \ - "." + "$BUILD_CONTEXT_DIR" # clean up tmp/configs +rm -f "$BUILD_CONTEXT_DIR/run.yaml" set +x echo "Success!" diff --git a/llama_stack/distribution/routers/routers.py b/llama_stack/distribution/routers/routers.py index 17aecdaf8..d88df00bd 100644 --- a/llama_stack/distribution/routers/routers.py +++ b/llama_stack/distribution/routers/routers.py @@ -8,6 +8,11 @@ import asyncio import time from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Union +from openai.types.chat import ChatCompletionToolChoiceOptionParam as OpenAIChatCompletionToolChoiceOptionParam +from openai.types.chat import ChatCompletionToolParam as OpenAIChatCompletionToolParam +from pydantic import Field, TypeAdapter +from typing_extensions import Annotated + from llama_stack.apis.common.content_types import ( URL, InterleavedContent, @@ -526,7 +531,7 @@ class InferenceRouter(Inference): async def openai_chat_completion( self, model: str, - messages: List[OpenAIMessageParam], + messages: Annotated[List[OpenAIMessageParam], Field(..., min_length=1)], frequency_penalty: Optional[float] = None, function_call: Optional[Union[str, Dict[str, Any]]] = None, functions: Optional[List[Dict[str, Any]]] = None, @@ -558,6 +563,16 @@ class InferenceRouter(Inference): if model_obj.model_type == ModelType.embedding: raise ValueError(f"Model '{model}' is an embedding model and does not support chat completions") + # Use the OpenAI client for a bit of extra input validation without + # exposing the OpenAI client itself as part of our API surface + if tool_choice: + TypeAdapter(OpenAIChatCompletionToolChoiceOptionParam).validate_python(tool_choice) + if tools is None: + raise ValueError("'tool_choice' is only allowed when 'tools' is also provided") + if tools: + for tool in tools: + TypeAdapter(OpenAIChatCompletionToolParam).validate_python(tool) + params = dict( model=model_obj.identifier, messages=messages, diff --git a/llama_stack/distribution/server/server.py b/llama_stack/distribution/server/server.py index 6c5e2506c..6e9941d1c 100644 --- a/llama_stack/distribution/server/server.py +++ b/llama_stack/distribution/server/server.py @@ -22,6 +22,7 @@ from fastapi import Body, FastAPI, HTTPException, Request from fastapi import Path as FastapiPath from fastapi.exceptions import RequestValidationError from fastapi.responses import JSONResponse, StreamingResponse +from openai import BadRequestError from pydantic import BaseModel, ValidationError from typing_extensions import Annotated @@ -110,6 +111,8 @@ def translate_exception(exc: Exception) -> Union[HTTPException, RequestValidatio ) elif isinstance(exc, ValueError): return HTTPException(status_code=400, detail=f"Invalid value: {str(exc)}") + elif isinstance(exc, BadRequestError): + return HTTPException(status_code=400, detail=str(exc)) elif isinstance(exc, PermissionError): return HTTPException(status_code=403, detail=f"Permission denied: {str(exc)}") elif isinstance(exc, TimeoutError): @@ -162,14 +165,17 @@ async def maybe_await(value): return value -async def sse_generator(event_gen): +async def sse_generator(event_gen_coroutine): + event_gen = None try: - async for item in await event_gen: + event_gen = await event_gen_coroutine + async for item in event_gen: yield create_sse_event(item) await asyncio.sleep(0.01) except asyncio.CancelledError: logger.info("Generator cancelled") - await event_gen.aclose() + if 
event_gen: + await event_gen.aclose() except Exception as e: logger.exception("Error in sse_generator") yield create_sse_event( @@ -455,6 +461,7 @@ def main(args: Optional[argparse.Namespace] = None): raise ValueError(f"Could not find method {endpoint.name} on {impl}!!") impl_method = getattr(impl, endpoint.name) + logger.debug(f"{endpoint.method.upper()} {endpoint.route}") with warnings.catch_warnings(): warnings.filterwarnings("ignore", category=UserWarning, module="pydantic._internal._fields") diff --git a/llama_stack/distribution/ui/page/playground/rag.py b/llama_stack/distribution/ui/page/playground/rag.py index 392c9afe2..696d89bc2 100644 --- a/llama_stack/distribution/ui/page/playground/rag.py +++ b/llama_stack/distribution/ui/page/playground/rag.py @@ -24,6 +24,13 @@ def rag_chat_page(): def should_disable_input(): return "displayed_messages" in st.session_state and len(st.session_state.displayed_messages) > 0 + def log_message(message): + with st.chat_message(message["role"]): + if "tool_output" in message and message["tool_output"]: + with st.expander(label="Tool Output", expanded=False, icon="🛠"): + st.write(message["tool_output"]) + st.markdown(message["content"]) + with st.sidebar: # File/Directory Upload Section st.subheader("Upload Documents", divider=True) @@ -146,8 +153,7 @@ def rag_chat_page(): # Display chat history for message in st.session_state.displayed_messages: - with st.chat_message(message["role"]): - st.markdown(message["content"]) + log_message(message) if temperature > 0.0: strategy = { @@ -201,7 +207,7 @@ def rag_chat_page(): # Display assistant response with st.chat_message("assistant"): - retrieval_message_placeholder = st.empty() + retrieval_message_placeholder = st.expander(label="Tool Output", expanded=False, icon="🛠") message_placeholder = st.empty() full_response = "" retrieval_response = "" @@ -209,14 +215,16 @@ def rag_chat_page(): log.print() if log.role == "tool_execution": retrieval_response += log.content.replace("====", "").strip() - retrieval_message_placeholder.info(retrieval_response) + retrieval_message_placeholder.write(retrieval_response) else: full_response += log.content message_placeholder.markdown(full_response + "▌") message_placeholder.markdown(full_response) st.session_state.messages.append({"role": "assistant", "content": full_response}) - st.session_state.displayed_messages.append({"role": "assistant", "content": full_response}) + st.session_state.displayed_messages.append( + {"role": "assistant", "content": full_response, "tool_output": retrieval_response} + ) def direct_process_prompt(prompt): # Add the system prompt in the beginning of the conversation @@ -230,15 +238,14 @@ def rag_chat_page(): prompt_context = rag_response.content with st.chat_message("assistant"): + with st.expander(label="Retrieval Output", expanded=False): + st.write(prompt_context) + retrieval_message_placeholder = st.empty() message_placeholder = st.empty() full_response = "" retrieval_response = "" - # Display the retrieved content - retrieval_response += str(prompt_context) - retrieval_message_placeholder.info(retrieval_response) - # Construct the extended prompt extended_prompt = f"Please answer the following query using the context below.\n\nCONTEXT:\n{prompt_context}\n\nQUERY:\n{prompt}" diff --git a/llama_stack/distribution/ui/page/playground/tools.py b/llama_stack/distribution/ui/page/playground/tools.py index c5bb2216a..6c6a9fcfd 100644 --- a/llama_stack/distribution/ui/page/playground/tools.py +++ 
b/llama_stack/distribution/ui/page/playground/tools.py @@ -4,14 +4,23 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +import enum +import json import uuid import streamlit as st from llama_stack_client import Agent +from llama_stack_client.lib.agents.react.agent import ReActAgent +from llama_stack_client.lib.agents.react.tool_parser import ReActOutput from llama_stack.distribution.ui.modules.api import llama_stack_api +class AgentType(enum.Enum): + REGULAR = "Regular" + REACT = "ReAct" + + def tool_chat_page(): st.title("🛠 Tools") @@ -23,6 +32,7 @@ def tool_chat_page(): tool_groups_list = [tool_group.identifier for tool_group in tool_groups] mcp_tools_list = [tool for tool in tool_groups_list if tool.startswith("mcp::")] builtin_tools_list = [tool for tool in tool_groups_list if not tool.startswith("mcp::")] + selected_vector_dbs = [] def reset_agent(): st.session_state.clear() @@ -66,25 +76,36 @@ def tool_chat_page(): toolgroup_selection.extend(mcp_selection) - active_tool_list = [] - for toolgroup_id in toolgroup_selection: - active_tool_list.extend( - [ - f"{''.join(toolgroup_id.split('::')[1:])}:{t.identifier}" - for t in client.tools.list(toolgroup_id=toolgroup_id) - ] - ) + grouped_tools = {} + total_tools = 0 - st.markdown(f"Active Tools: 🛠 {len(active_tool_list)}", help="List of currently active tools.") - st.json(active_tool_list) + for toolgroup_id in toolgroup_selection: + tools = client.tools.list(toolgroup_id=toolgroup_id) + grouped_tools[toolgroup_id] = [tool.identifier for tool in tools] + total_tools += len(tools) + + st.markdown(f"Active Tools: 🛠 {total_tools}") + + for group_id, tools in grouped_tools.items(): + with st.expander(f"🔧 Tools from `{group_id}`"): + for idx, tool in enumerate(tools, start=1): + st.markdown(f"{idx}. `{tool.split(':')[-1]}`") st.subheader("Agent Configurations") + st.subheader("Agent Type") + agent_type = st.radio( + "Select Agent Type", + [AgentType.REGULAR, AgentType.REACT], + format_func=lambda x: x.value, + on_change=reset_agent, + ) + max_tokens = st.slider( "Max Tokens", min_value=0, max_value=4096, value=512, - step=1, + step=64, help="The maximum number of tokens to generate", on_change=reset_agent, ) @@ -101,13 +122,27 @@ def tool_chat_page(): @st.cache_resource def create_agent(): - return Agent( - client, - model=model, - instructions="You are a helpful assistant. When you use a tool always respond with a summary of the result.", - tools=toolgroup_selection, - sampling_params={"strategy": {"type": "greedy"}, "max_tokens": max_tokens}, - ) + if "agent_type" in st.session_state and st.session_state.agent_type == AgentType.REACT: + return ReActAgent( + client=client, + model=model, + tools=toolgroup_selection, + response_format={ + "type": "json_schema", + "json_schema": ReActOutput.model_json_schema(), + }, + sampling_params={"strategy": {"type": "greedy"}, "max_tokens": max_tokens}, + ) + else: + return Agent( + client, + model=model, + instructions="You are a helpful assistant. 
When you use a tool always respond with a summary of the result.", + tools=toolgroup_selection, + sampling_params={"strategy": {"type": "greedy"}, "max_tokens": max_tokens}, + ) + + st.session_state.agent_type = agent_type agent = create_agent() @@ -136,6 +171,158 @@ def tool_chat_page(): ) def response_generator(turn_response): + if st.session_state.get("agent_type") == AgentType.REACT: + return _handle_react_response(turn_response) + else: + return _handle_regular_response(turn_response) + + def _handle_react_response(turn_response): + current_step_content = "" + final_answer = None + tool_results = [] + + for response in turn_response: + if not hasattr(response.event, "payload"): + yield ( + "\n\n🚨 :red[_Llama Stack server Error:_]\n" + "The response received is missing an expected `payload` attribute.\n" + "This could indicate a malformed response or an internal issue within the server.\n\n" + f"Error details: {response}" + ) + return + + payload = response.event.payload + + if payload.event_type == "step_progress" and hasattr(payload.delta, "text"): + current_step_content += payload.delta.text + continue + + if payload.event_type == "step_complete": + step_details = payload.step_details + + if step_details.step_type == "inference": + yield from _process_inference_step(current_step_content, tool_results, final_answer) + current_step_content = "" + elif step_details.step_type == "tool_execution": + tool_results = _process_tool_execution(step_details, tool_results) + current_step_content = "" + else: + current_step_content = "" + + if not final_answer and tool_results: + yield from _format_tool_results_summary(tool_results) + + def _process_inference_step(current_step_content, tool_results, final_answer): + try: + react_output_data = json.loads(current_step_content) + thought = react_output_data.get("thought") + action = react_output_data.get("action") + answer = react_output_data.get("answer") + + if answer and answer != "null" and answer is not None: + final_answer = answer + + if thought: + with st.expander("🤔 Thinking...", expanded=False): + st.markdown(f":grey[__{thought}__]") + + if action and isinstance(action, dict): + tool_name = action.get("tool_name") + tool_params = action.get("tool_params") + with st.expander(f'🛠 Action: Using tool "{tool_name}"', expanded=False): + st.json(tool_params) + + if answer and answer != "null" and answer is not None: + yield f"\n\n✅ **Final Answer:**\n{answer}" + + except json.JSONDecodeError: + yield f"\n\nFailed to parse ReAct step content:\n```json\n{current_step_content}\n```" + except Exception as e: + yield f"\n\nFailed to process ReAct step: {e}\n```json\n{current_step_content}\n```" + + return final_answer + + def _process_tool_execution(step_details, tool_results): + try: + if hasattr(step_details, "tool_responses") and step_details.tool_responses: + for tool_response in step_details.tool_responses: + tool_name = tool_response.tool_name + content = tool_response.content + tool_results.append((tool_name, content)) + with st.expander(f'⚙️ Observation (Result from "{tool_name}")', expanded=False): + try: + parsed_content = json.loads(content) + st.json(parsed_content) + except json.JSONDecodeError: + st.code(content, language=None) + else: + with st.expander("⚙️ Observation", expanded=False): + st.markdown(":grey[_Tool execution step completed, but no response data found._]") + except Exception as e: + with st.expander("⚙️ Error in Tool Execution", expanded=False): + st.markdown(f":red[_Error processing tool execution: {str(e)}_]") + + 
return tool_results + + def _format_tool_results_summary(tool_results): + yield "\n\n**Here's what I found:**\n" + for tool_name, content in tool_results: + try: + parsed_content = json.loads(content) + + if tool_name == "web_search" and "top_k" in parsed_content: + yield from _format_web_search_results(parsed_content) + elif "results" in parsed_content and isinstance(parsed_content["results"], list): + yield from _format_results_list(parsed_content["results"]) + elif isinstance(parsed_content, dict) and len(parsed_content) > 0: + yield from _format_dict_results(parsed_content) + elif isinstance(parsed_content, list) and len(parsed_content) > 0: + yield from _format_list_results(parsed_content) + except json.JSONDecodeError: + yield f"\n**{tool_name}** was used but returned complex data. Check the observation for details.\n" + except (TypeError, AttributeError, KeyError, IndexError) as e: + print(f"Error processing {tool_name} result: {type(e).__name__}: {e}") + + def _format_web_search_results(parsed_content): + for i, result in enumerate(parsed_content["top_k"], 1): + if i <= 3: + title = result.get("title", "Untitled") + url = result.get("url", "") + content_text = result.get("content", "").strip() + yield f"\n- **{title}**\n {content_text}\n [Source]({url})\n" + + def _format_results_list(results): + for i, result in enumerate(results, 1): + if i <= 3: + if isinstance(result, dict): + name = result.get("name", result.get("title", "Result " + str(i))) + description = result.get("description", result.get("content", result.get("summary", ""))) + yield f"\n- **{name}**\n {description}\n" + else: + yield f"\n- {result}\n" + + def _format_dict_results(parsed_content): + yield "\n```\n" + for key, value in list(parsed_content.items())[:5]: + if isinstance(value, str) and len(value) < 100: + yield f"{key}: {value}\n" + else: + yield f"{key}: [Complex data]\n" + yield "```\n" + + def _format_list_results(parsed_content): + yield "\n" + for _, item in enumerate(parsed_content[:3], 1): + if isinstance(item, str): + yield f"- {item}\n" + elif isinstance(item, dict) and "text" in item: + yield f"- {item['text']}\n" + elif isinstance(item, dict) and len(item) > 0: + first_value = next(iter(item.values())) + if isinstance(first_value, str) and len(first_value) < 100: + yield f"- {first_value}\n" + + def _handle_regular_response(turn_response): for response in turn_response: if hasattr(response.event, "payload"): print(response.event.payload) @@ -144,14 +331,18 @@ def tool_chat_page(): yield response.event.payload.delta.text if response.event.payload.event_type == "step_complete": if response.event.payload.step_details.step_type == "tool_execution": - yield " 🛠 " + if response.event.payload.step_details.tool_calls: + tool_name = str(response.event.payload.step_details.tool_calls[0].tool_name) + yield f'\n\n🛠 :grey[_Using "{tool_name}" tool:_]\n\n' + else: + yield "No tool_calls present in step_details" else: yield f"Error occurred in the Llama Stack Cluster: {response}" with st.chat_message("assistant"): - response = st.write_stream(response_generator(turn_response)) + response_content = st.write_stream(response_generator(turn_response)) - st.session_state.messages.append({"role": "assistant", "content": response}) + st.session_state.messages.append({"role": "assistant", "content": response_content}) tool_chat_page() diff --git a/llama_stack/models/llama/llama4/chat_format.py b/llama_stack/models/llama/llama4/chat_format.py index 1debadcc5..1574eeb5e 100644 --- 
a/llama_stack/models/llama/llama4/chat_format.py +++ b/llama_stack/models/llama/llama4/chat_format.py @@ -303,6 +303,7 @@ class ChatFormat: arguments_json=json.dumps(tool_arguments), ) ) + content = "" return RawMessage( role="assistant", diff --git a/llama_stack/models/llama/llama4/prompt_format.md b/llama_stack/models/llama/llama4/prompt_format.md index 698571093..350a5517a 100644 --- a/llama_stack/models/llama/llama4/prompt_format.md +++ b/llama_stack/models/llama/llama4/prompt_format.md @@ -64,7 +64,7 @@ This example passes an image that is smaller than the tile size, to show the til ##### Model Response Format ``` -The image depicts a dog standing on a skateboard, with its front paws positioned on the board and its back paws hanging off the back. The dog has a distinctive coat pattern, featuring a white face, brown and black fur, and white paws, and is standing on a skateboard with red wheels, set against a blurred background of a street or alleyway with a teal door and beige wall.<|eot|> +The image depicts a dog standing on a skateboard, positioned centrally and facing the camera directly. The dog has a distinctive coat pattern featuring white, black, and brown fur, with floppy ears and a black nose, and is standing on a skateboard with red wheels.<|eot|> ``` @@ -91,7 +91,7 @@ Here is an example of how to pass an image to the model ##### Model Response Format ``` -This image shows a dog standing on a skateboard, with its front paws positioned near the front of the board and its back paws near the back. The dog has a white, black, and orange coat, and is standing on a gray skateboard with red wheels, in front of a blurred background that appears to be a street or alleyway.<|eot|> +The image depicts a dog standing on a skateboard, with the dog positioned centrally and facing forward. The dog has a distinctive coat featuring a mix of white, brown, and black fur, and is wearing a collar as it stands on the skateboard, which has red wheels.<|eot|> ``` @@ -117,7 +117,7 @@ Here is an example of how to pass an image to the model ##### Model Response Format ``` -The first image shows a dog standing on a skateboard, while the second image shows a plate of spaghetti with tomato sauce, parmesan cheese, and parsley. The two images are unrelated, with the first image featuring a dog and the second image featuring a food dish, and they do not share any common elements or themes.<|eot|> +The first image features a dog standing on a skateboard, while the second image showcases a plate of spaghetti with tomato sauce and cheese. The two images appear to be unrelated, with one depicting a playful scene of a dog on a skateboard and the other presenting a classic Italian dish.<|eom|> ``` @@ -135,13 +135,44 @@ We are continuing the format for zero shot function calling used in previous ver ``` <|begin_of_text|><|header_start|>system<|header_end|> -You are an expert in composing functions. You are given a question and a set of possible functions. -Based on the question, you will need to make one or more function/tool calls to achieve the purpose. -If none of the function can be used, point it out. If the given question lacks the parameters required by the function, -also point it out. You should only return the function call in tools call sections. +You are a helpful assistant and an expert in function composition. You can answer general questions using your internal knowledge OR invoke functions when necessary. Follow these strict guidelines: + +1. 
FUNCTION CALLS: +- ONLY use functions that are EXPLICITLY listed in the function list below +- If NO functions are listed (empty function list []), respond ONLY with internal knowledge or "I don't have access to [Unavailable service] information" +- If a function is not in the list, respond ONLY with internal knowledge or "I don't have access to [Unavailable service] information" +- If ALL required parameters are present AND the query EXACTLY matches a listed function's purpose: output ONLY the function call(s) +- Use exact format: [func_name1(param1=value1, param2=value2), func_name2(...)] +Examples: +CORRECT: [get_weather(location="Vancouver"), calculate_route(start="Boston", end="New York")] <- Only if get_weather and calculate_route are in function list +INCORRECT: get_weather(location="New York") +INCORRECT: Let me check the weather: [get_weather(location="New York")] +INCORRECT: [get_events(location="Singapore")] <- If function not in list + +2. RESPONSE RULES: +- For pure function requests matching a listed function: ONLY output the function call(s) +- For knowledge questions: ONLY output text +- For missing parameters: ONLY request the specific missing parameters +- For unavailable services (not in function list): output ONLY with internal knowledge or "I don't have access to [Unavailable service] information". Do NOT execute a function call. +- If the query asks for information beyond what a listed function provides: output ONLY with internal knowledge about your limitations +- NEVER combine text and function calls in the same response +- NEVER suggest alternative functions when the requested service is unavailable +- NEVER create or invent new functions not listed below + +3. STRICT BOUNDARIES: +- ONLY use functions from the list below - no exceptions +- NEVER use a function as an alternative to unavailable information +- NEVER call functions not present in the function list +- NEVER add explanatory text to function calls +- NEVER respond with empty brackets +- Use proper Python/JSON syntax for function calls +- Check the function list carefully before responding + +4. TOOL RESPONSE HANDLING: +- When receiving tool responses: provide concise, natural language responses +- Don't repeat tool response verbatim +- Don't add supplementary information -If you decide to invoke any of the function(s), you MUST put it in the format of [func_name1(params_name1=params_value1, params_name2=params_value2...), func_name2(params)] -You SHOULD NOT include any other text in the response. Here is a list of functions in JSON format that you can invoke. @@ -151,9 +182,7 @@ Here is a list of functions in JSON format that you can invoke. "description": "Get weather info for places", "parameters": { "type": "dict", - "required": [ - "city" - ], + "required": ["city"], "properties": { "city": { "type": "string", @@ -167,7 +196,10 @@ Here is a list of functions in JSON format that you can invoke. } } } -<|eot|><|header_start|>user<|header_end|> +] + +You can answer general questions or invoke tools when necessary. 
+In addition to tool calls, you should also augment your responses by using the tool outputs.<|eot|><|header_start|>user<|header_end|> What is the weather in SF and Seattle?<|eot|><|header_start|>assistant<|header_end|> @@ -176,7 +208,7 @@ What is the weather in SF and Seattle?<|eot|><|header_start|>assistant<|header_e ##### Model Response Format ``` -[get_weather(city='SF'), get_weather(city='Seattle')]<|eot|> +[get_weather(city="San Francisco"), get_weather(city="Seattle")]<|eot|> ``` @@ -273,5 +305,5 @@ Use tools to get latest trending songs<|eot|><|header_start|>assistant<|header_e ##### Model Response Format ``` -{"n": "10"}<|eot|> +{"n": 10}<|eot|> ``` diff --git a/llama_stack/models/llama/llama4/prompt_templates/system_prompts.py b/llama_stack/models/llama/llama4/prompt_templates/system_prompts.py new file mode 100644 index 000000000..139e204ad --- /dev/null +++ b/llama_stack/models/llama/llama4/prompt_templates/system_prompts.py @@ -0,0 +1,144 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# top-level folder for each specific model found within the models/ directory at +# the top-level of this source tree. + +import textwrap +from typing import List, Optional + +from llama_stack.apis.inference import ToolDefinition, ToolParamDefinition +from llama_stack.models.llama.llama3.prompt_templates.base import ( + PromptTemplate, + PromptTemplateGeneratorBase, +) + + +class PythonListCustomToolGenerator(PromptTemplateGeneratorBase): # noqa: N801 + DEFAULT_PROMPT = textwrap.dedent( + """ + You are a helpful assistant and an expert in function composition. You can answer general questions using your internal knowledge OR invoke functions when necessary. Follow these strict guidelines: + + 1. FUNCTION CALLS: + - ONLY use functions that are EXPLICITLY listed in the function list below + - If NO functions are listed (empty function list []), respond ONLY with internal knowledge or "I don't have access to [Unavailable service] information" + - If a function is not in the list, respond ONLY with internal knowledge or "I don't have access to [Unavailable service] information" + - If ALL required parameters are present AND the query EXACTLY matches a listed function's purpose: output ONLY the function call(s) + - Use exact format: [func_name1(param1=value1, param2=value2), func_name2(...)] + Examples: + CORRECT: [get_weather(location="Vancouver"), calculate_route(start="Boston", end="New York")] <- Only if get_weather and calculate_route are in function list + INCORRECT: get_weather(location="New York") + INCORRECT: Let me check the weather: [get_weather(location="New York")] + INCORRECT: [get_events(location="Singapore")] <- If function not in list + + 2. RESPONSE RULES: + - For pure function requests matching a listed function: ONLY output the function call(s) + - For knowledge questions: ONLY output text + - For missing parameters: ONLY request the specific missing parameters + - For unavailable services (not in function list): output ONLY with internal knowledge or "I don't have access to [Unavailable service] information". Do NOT execute a function call. 
+ - If the query asks for information beyond what a listed function provides: output ONLY with internal knowledge about your limitations + - NEVER combine text and function calls in the same response + - NEVER suggest alternative functions when the requested service is unavailable + - NEVER create or invent new functions not listed below + + 3. STRICT BOUNDARIES: + - ONLY use functions from the list below - no exceptions + - NEVER use a function as an alternative to unavailable information + - NEVER call functions not present in the function list + - NEVER add explanatory text to function calls + - NEVER respond with empty brackets + - Use proper Python/JSON syntax for function calls + - Check the function list carefully before responding + + 4. TOOL RESPONSE HANDLING: + - When receiving tool responses: provide concise, natural language responses + - Don't repeat tool response verbatim + - Don't add supplementary information + + + {{ function_description }} + """.strip("\n") + ) + + def gen(self, custom_tools: List[ToolDefinition], system_prompt: Optional[str] = None) -> PromptTemplate: + system_prompt = system_prompt or self.DEFAULT_PROMPT + return PromptTemplate( + system_prompt, + {"function_description": self._gen_function_description(custom_tools)}, + ) + + def _gen_function_description(self, custom_tools: List[ToolDefinition]) -> PromptTemplate: + template_str = textwrap.dedent( + """ + Here is a list of functions in JSON format that you can invoke. + + [ + {% for t in tools -%} + {# manually setting up JSON because jinja sorts keys in unexpected ways -#} + {%- set tname = t.tool_name -%} + {%- set tdesc = t.description -%} + {%- set tparams = t.parameters -%} + {%- set required_params = [] -%} + {%- for name, param in tparams.items() if param.required == true -%} + {%- set _ = required_params.append(name) -%} + {%- endfor -%} + { + "name": "{{tname}}", + "description": "{{tdesc}}", + "parameters": { + "type": "dict", + "required": {{ required_params | tojson }}, + "properties": { + {%- for name, param in tparams.items() %} + "{{name}}": { + "type": "{{param.param_type}}", + "description": "{{param.description}}"{% if param.default %}, + "default": "{{param.default}}"{% endif %} + }{% if not loop.last %},{% endif %} + {%- endfor %} + } + } + }{% if not loop.last %}, + {% endif -%} + {%- endfor %} + ] + + You can answer general questions or invoke tools when necessary. + In addition to tool calls, you should also augment your responses by using the tool outputs. + + """ + ) + return PromptTemplate( + template_str.strip("\n"), + {"tools": [t.model_dump() for t in custom_tools]}, + ).render() + + def data_examples(self) -> List[List[ToolDefinition]]: + return [ + [ + ToolDefinition( + tool_name="get_weather", + description="Get weather info for places", + parameters={ + "city": ToolParamDefinition( + param_type="string", + description="The name of the city to get the weather for", + required=True, + ), + "metric": ToolParamDefinition( + param_type="string", + description="The metric for weather. 
Options are: celsius, fahrenheit", + required=False, + default="celsius", + ), + }, + ), + ] + ] diff --git a/llama_stack/models/llama/llama4/prompts.py b/llama_stack/models/llama/llama4/prompts.py index 13b96359a..fe9a59130 100644 --- a/llama_stack/models/llama/llama4/prompts.py +++ b/llama_stack/models/llama/llama4/prompts.py @@ -9,6 +9,10 @@ from io import BytesIO from pathlib import Path from typing import List +from llama_stack.models.llama.llama4.prompt_templates.system_prompts import ( + PythonListCustomToolGenerator, +) + from ..datatypes import RawMediaItem, RawMessage, RawTextItem from ..prompt_format import ( Llama4UseCase, @@ -177,39 +181,9 @@ def usecases(base_model: bool = False) -> List[UseCase | str]: [ RawMessage( role="system", - content="""You are an expert in composing functions. You are given a question and a set of possible functions. -Based on the question, you will need to make one or more function/tool calls to achieve the purpose. -If none of the function can be used, point it out. If the given question lacks the parameters required by the function, -also point it out. You should only return the function call in tools call sections. - -If you decide to invoke any of the function(s), you MUST put it in the format of [func_name1(params_name1=params_value1, params_name2=params_value2...), func_name2(params)] -You SHOULD NOT include any other text in the response. - -Here is a list of functions in JSON format that you can invoke. - -[ - { - "name": "get_weather", - "description": "Get weather info for places", - "parameters": { - "type": "dict", - "required": [ - "city" - ], - "properties": { - "city": { - "type": "string", - "description": "The name of the city to get the weather for" - }, - "metric": { - "type": "string", - "description": "The metric for weather. 
Options are: celsius, fahrenheit", - "default": "celsius" - } - } - } - } -""", + content=PythonListCustomToolGenerator() + .gen(PythonListCustomToolGenerator().data_examples()[0]) + .render(), ), RawMessage( role="user", diff --git a/llama_stack/providers/inline/inference/meta_reference/inference.py b/llama_stack/providers/inline/inference/meta_reference/inference.py index 0e69c2e7e..1bc098fab 100644 --- a/llama_stack/providers/inline/inference/meta_reference/inference.py +++ b/llama_stack/providers/inline/inference/meta_reference/inference.py @@ -253,7 +253,8 @@ class MetaReferenceInferenceImpl( def impl(): stop_reason = None - for token_result in self.generator.completion(request): + for token_results in self.generator.completion([request]): + token_result = token_results[0] if token_result.token == tokenizer.eot_id: stop_reason = StopReason.end_of_turn text = "" diff --git a/llama_stack/providers/inline/inference/meta_reference/parallel_utils.py b/llama_stack/providers/inline/inference/meta_reference/parallel_utils.py index 8752f06f3..8c0ffc632 100644 --- a/llama_stack/providers/inline/inference/meta_reference/parallel_utils.py +++ b/llama_stack/providers/inline/inference/meta_reference/parallel_utils.py @@ -69,7 +69,10 @@ class CancelSentinel(BaseModel): class TaskRequest(BaseModel): type: Literal[ProcessingMessageName.task_request] = ProcessingMessageName.task_request - task: Tuple[str, List[CompletionRequestWithRawContent] | List[ChatCompletionRequestWithRawContent]] + task: Tuple[ + str, + List[CompletionRequestWithRawContent] | List[ChatCompletionRequestWithRawContent], + ] class TaskResponse(BaseModel): @@ -231,10 +234,10 @@ def worker_process_entrypoint( while True: try: task = req_gen.send(result) - if isinstance(task, str) and task == EndSentinel(): + if isinstance(task, EndSentinel): break - assert isinstance(task, TaskRequest) + assert isinstance(task, TaskRequest), task result = model(task.task) except StopIteration: break @@ -331,7 +334,10 @@ class ModelParallelProcessGroup: def run_inference( self, - req: Tuple[str, List[CompletionRequestWithRawContent] | List[ChatCompletionRequestWithRawContent]], + req: Tuple[ + str, + List[CompletionRequestWithRawContent] | List[ChatCompletionRequestWithRawContent], + ], ) -> Generator: assert not self.running, "inference already running" diff --git a/llama_stack/providers/inline/tool_runtime/rag/memory.py b/llama_stack/providers/inline/tool_runtime/rag/memory.py index 97c53d454..8d4689e5d 100644 --- a/llama_stack/providers/inline/tool_runtime/rag/memory.py +++ b/llama_stack/providers/inline/tool_runtime/rag/memory.py @@ -33,6 +33,7 @@ from llama_stack.apis.tools import ( ) from llama_stack.apis.vector_io import QueryChunksResponse, VectorIO from llama_stack.providers.datatypes import ToolsProtocolPrivate +from llama_stack.providers.utils.inference.prompt_adapter import interleaved_content_as_str from llama_stack.providers.utils.memory.vector_store import ( content_from_doc, make_overlapped_chunks, @@ -153,6 +154,11 @@ class MemoryToolRuntimeImpl(ToolsProtocolPrivate, ToolRuntime, RAGToolRuntime): ) ) picked.append(TextContentItem(text="END of knowledge_search tool results.\n")) + picked.append( + TextContentItem( + text=f'The above results were retrieved to help answer the user\'s query: "{interleaved_content_as_str(content)}". 
Use them as supporting information only in answering this query.\n', + ) + ) return RAGQueryResult( content=picked, diff --git a/llama_stack/providers/registry/eval.py b/llama_stack/providers/registry/eval.py index f3e42c531..9604d5da4 100644 --- a/llama_stack/providers/registry/eval.py +++ b/llama_stack/providers/registry/eval.py @@ -6,7 +6,7 @@ from typing import List -from llama_stack.providers.datatypes import Api, InlineProviderSpec, ProviderSpec +from llama_stack.providers.datatypes import AdapterSpec, Api, InlineProviderSpec, ProviderSpec, remote_provider_spec def available_providers() -> List[ProviderSpec]: @@ -25,4 +25,22 @@ def available_providers() -> List[ProviderSpec]: Api.agents, ], ), + remote_provider_spec( + api=Api.eval, + adapter=AdapterSpec( + adapter_type="nvidia", + pip_packages=[ + "requests", + ], + module="llama_stack.providers.remote.eval.nvidia", + config_class="llama_stack.providers.remote.eval.nvidia.NVIDIAEvalConfig", + ), + api_dependencies=[ + Api.datasetio, + Api.datasets, + Api.scoring, + Api.inference, + Api.agents, + ], + ), ] diff --git a/llama_stack/providers/registry/inference.py b/llama_stack/providers/registry/inference.py index 3c54cabcf..4040f0d80 100644 --- a/llama_stack/providers/registry/inference.py +++ b/llama_stack/providers/registry/inference.py @@ -288,4 +288,14 @@ def available_providers() -> List[ProviderSpec]: provider_data_validator="llama_stack.providers.remote.inference.passthrough.PassthroughProviderDataValidator", ), ), + remote_provider_spec( + api=Api.inference, + adapter=AdapterSpec( + adapter_type="watsonx", + pip_packages=["ibm_watson_machine_learning"], + module="llama_stack.providers.remote.inference.watsonx", + config_class="llama_stack.providers.remote.inference.watsonx.WatsonXConfig", + provider_data_validator="llama_stack.providers.remote.inference.watsonx.WatsonXProviderDataValidator", + ), + ), ] diff --git a/llama_stack/providers/remote/eval/__init__.py b/llama_stack/providers/remote/eval/__init__.py new file mode 100644 index 000000000..756f351d8 --- /dev/null +++ b/llama_stack/providers/remote/eval/__init__.py @@ -0,0 +1,5 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. diff --git a/llama_stack/providers/remote/eval/nvidia/README.md b/llama_stack/providers/remote/eval/nvidia/README.md new file mode 100644 index 000000000..cebc77920 --- /dev/null +++ b/llama_stack/providers/remote/eval/nvidia/README.md @@ -0,0 +1,134 @@ +# NVIDIA NeMo Evaluator Eval Provider + + +## Overview + +For the first integration, Benchmarks are mapped to Evaluation Configs on in the NeMo Evaluator. The full evaluation config object is provided as part of the meta-data. The `dataset_id` and `scoring_functions` are not used. + +Below are a few examples of how to register a benchmark, which in turn will create an evaluation config in NeMo Evaluator and how to trigger an evaluation. 
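+For orientation, the equivalent registration through the Llama Stack Python client is sketched below; it assumes the standard `client.benchmarks.register` surface and a locally running stack on port 8321, so treat it as illustrative rather than exact. The raw REST payloads follow.
+
+```python
+from llama_stack_client import LlamaStackClient
+
+# Assumes a Llama Stack server running locally; adjust the URL as needed.
+client = LlamaStackClient(base_url="http://localhost:8321")
+
+# The metadata dict is forwarded as-is to NeMo Evaluator as the evaluation config;
+# dataset_id and scoring_functions are accepted but ignored by this provider.
+client.benchmarks.register(
+    benchmark_id="mmlu",
+    dataset_id="",
+    scoring_functions=[],
+    metadata={"type": "mmlu"},
+)
+```
+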
+ +### Example for register an academic benchmark + +``` +POST /eval/benchmarks +``` +```json +{ + "benchmark_id": "mmlu", + "dataset_id": "", + "scoring_functions": [], + "metadata": { + "type": "mmlu" + } +} +``` + +### Example for register a custom evaluation + +``` +POST /eval/benchmarks +``` +```json +{ + "benchmark_id": "my-custom-benchmark", + "dataset_id": "", + "scoring_functions": [], + "metadata": { + "type": "custom", + "params": { + "parallelism": 8 + }, + "tasks": { + "qa": { + "type": "completion", + "params": { + "template": { + "prompt": "{{prompt}}", + "max_tokens": 200 + } + }, + "dataset": { + "files_url": "hf://datasets/default/sample-basic-test/testing/testing.jsonl" + }, + "metrics": { + "bleu": { + "type": "bleu", + "params": { + "references": [ + "{{ideal_response}}" + ] + } + } + } + } + } + } +} +``` + +### Example for triggering a benchmark/custom evaluation + +``` +POST /eval/benchmarks/{benchmark_id}/jobs +``` +```json +{ + "benchmark_id": "my-custom-benchmark", + "benchmark_config": { + "eval_candidate": { + "type": "model", + "model": "meta-llama/Llama3.1-8B-Instruct", + "sampling_params": { + "max_tokens": 100, + "temperature": 0.7 + } + }, + "scoring_params": {} + } +} +``` + +Response example: +```json +{ + "job_id": "eval-1234", + "status": "in_progress" +} +``` + +### Example for getting the status of a job +``` +GET /eval/benchmarks/{benchmark_id}/jobs/{job_id} +``` + +Response example: +```json +{ + "job_id": "eval-1234", + "status": "in_progress" +} +``` + +### Example for cancelling a job +``` +POST /eval/benchmarks/{benchmark_id}/jobs/{job_id}/cancel +``` + +### Example for getting the results +``` +GET /eval/benchmarks/{benchmark_id}/results +``` +```json +{ + "generations": [], + "scores": { + "{benchmark_id}": { + "score_rows": [], + "aggregated_results": { + "tasks": {}, + "groups": {} + } + } + } +} +``` diff --git a/llama_stack/providers/remote/eval/nvidia/__init__.py b/llama_stack/providers/remote/eval/nvidia/__init__.py new file mode 100644 index 000000000..8abbec9b2 --- /dev/null +++ b/llama_stack/providers/remote/eval/nvidia/__init__.py @@ -0,0 +1,31 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. +from typing import Any, Dict + +from llama_stack.distribution.datatypes import Api + +from .config import NVIDIAEvalConfig + + +async def get_adapter_impl( + config: NVIDIAEvalConfig, + deps: Dict[Api, Any], +): + from .eval import NVIDIAEvalImpl + + impl = NVIDIAEvalImpl( + config, + deps[Api.datasetio], + deps[Api.datasets], + deps[Api.scoring], + deps[Api.inference], + deps[Api.agents], + ) + await impl.initialize() + return impl + + +__all__ = ["get_adapter_impl", "NVIDIAEvalImpl"] diff --git a/llama_stack/providers/remote/eval/nvidia/config.py b/llama_stack/providers/remote/eval/nvidia/config.py new file mode 100644 index 000000000..b660fcd68 --- /dev/null +++ b/llama_stack/providers/remote/eval/nvidia/config.py @@ -0,0 +1,29 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. +import os +from typing import Any, Dict + +from pydantic import BaseModel, Field + + +class NVIDIAEvalConfig(BaseModel): + """ + Configuration for the NVIDIA NeMo Evaluator microservice endpoint. 
+ + Attributes: + evaluator_url (str): A base url for accessing the NVIDIA evaluation endpoint, e.g. http://localhost:8000. + """ + + evaluator_url: str = Field( + default_factory=lambda: os.getenv("NVIDIA_EVALUATOR_URL", "http://0.0.0.0:7331"), + description="The url for accessing the evaluator service", + ) + + @classmethod + def sample_run_config(cls, **kwargs) -> Dict[str, Any]: + return { + "evaluator_url": "${env.NVIDIA_EVALUATOR_URL:http://localhost:7331}", + } diff --git a/llama_stack/providers/remote/eval/nvidia/eval.py b/llama_stack/providers/remote/eval/nvidia/eval.py new file mode 100644 index 000000000..e1a3b5355 --- /dev/null +++ b/llama_stack/providers/remote/eval/nvidia/eval.py @@ -0,0 +1,154 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. +from typing import Any, Dict, List + +import requests + +from llama_stack.apis.agents import Agents +from llama_stack.apis.benchmarks import Benchmark +from llama_stack.apis.datasetio import DatasetIO +from llama_stack.apis.datasets import Datasets +from llama_stack.apis.inference import Inference +from llama_stack.apis.scoring import Scoring, ScoringResult +from llama_stack.providers.datatypes import BenchmarksProtocolPrivate +from llama_stack.providers.remote.inference.nvidia.models import MODEL_ENTRIES +from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper + +from .....apis.common.job_types import Job, JobStatus +from .....apis.eval.eval import BenchmarkConfig, Eval, EvaluateResponse +from .config import NVIDIAEvalConfig + +DEFAULT_NAMESPACE = "nvidia" + + +class NVIDIAEvalImpl( + Eval, + BenchmarksProtocolPrivate, + ModelRegistryHelper, +): + def __init__( + self, + config: NVIDIAEvalConfig, + datasetio_api: DatasetIO, + datasets_api: Datasets, + scoring_api: Scoring, + inference_api: Inference, + agents_api: Agents, + ) -> None: + self.config = config + self.datasetio_api = datasetio_api + self.datasets_api = datasets_api + self.scoring_api = scoring_api + self.inference_api = inference_api + self.agents_api = agents_api + + ModelRegistryHelper.__init__(self, model_entries=MODEL_ENTRIES) + + async def initialize(self) -> None: ... + + async def shutdown(self) -> None: ... 
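+    # Note: _evaluator_get/_evaluator_post below use the synchronous `requests` library
+    # inside async methods, so each call blocks the event loop for the duration of the
+    # HTTP request to the NeMo Evaluator service.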
+ + async def _evaluator_get(self, path): + """Helper for making GET requests to the evaluator service.""" + response = requests.get(url=f"{self.config.evaluator_url}{path}") + response.raise_for_status() + return response.json() + + async def _evaluator_post(self, path, data): + """Helper for making POST requests to the evaluator service.""" + response = requests.post(url=f"{self.config.evaluator_url}{path}", json=data) + response.raise_for_status() + return response.json() + + async def register_benchmark(self, task_def: Benchmark) -> None: + """Register a benchmark as an evaluation configuration.""" + await self._evaluator_post( + "/v1/evaluation/configs", + { + "namespace": DEFAULT_NAMESPACE, + "name": task_def.benchmark_id, + # metadata is copied to request body as-is + **task_def.metadata, + }, + ) + + async def run_eval( + self, + benchmark_id: str, + benchmark_config: BenchmarkConfig, + ) -> Job: + """Run an evaluation job for a benchmark.""" + model = ( + benchmark_config.eval_candidate.model + if benchmark_config.eval_candidate.type == "model" + else benchmark_config.eval_candidate.config.model + ) + nvidia_model = self.get_provider_model_id(model) or model + + result = await self._evaluator_post( + "/v1/evaluation/jobs", + { + "config": f"{DEFAULT_NAMESPACE}/{benchmark_id}", + "target": {"type": "model", "model": nvidia_model}, + }, + ) + + return Job(job_id=result["id"], status=JobStatus.in_progress) + + async def evaluate_rows( + self, + benchmark_id: str, + input_rows: List[Dict[str, Any]], + scoring_functions: List[str], + benchmark_config: BenchmarkConfig, + ) -> EvaluateResponse: + raise NotImplementedError() + + async def job_status(self, benchmark_id: str, job_id: str) -> Job: + """Get the status of an evaluation job. + + EvaluatorStatus: "created", "pending", "running", "cancelled", "cancelling", "failed", "completed". + JobStatus: "scheduled", "in_progress", "completed", "cancelled", "failed" + """ + result = await self._evaluator_get(f"/v1/evaluation/jobs/{job_id}") + result_status = result["status"] + + job_status = JobStatus.failed + if result_status in ["created", "pending"]: + job_status = JobStatus.scheduled + elif result_status in ["running"]: + job_status = JobStatus.in_progress + elif result_status in ["completed"]: + job_status = JobStatus.completed + elif result_status in ["cancelled"]: + job_status = JobStatus.cancelled + + return Job(job_id=job_id, status=job_status) + + async def job_cancel(self, benchmark_id: str, job_id: str) -> None: + """Cancel the evaluation job.""" + await self._evaluator_post(f"/v1/evaluation/jobs/{job_id}/cancel", {}) + + async def job_result(self, benchmark_id: str, job_id: str) -> EvaluateResponse: + """Returns the results of the evaluation job.""" + + job = await self.job_status(benchmark_id, job_id) + status = job.status + if not status or status != JobStatus.completed: + raise ValueError(f"Job {job_id} not completed. 
Status: {status.value}") + + result = await self._evaluator_get(f"/v1/evaluation/jobs/{job_id}/results") + + return EvaluateResponse( + # TODO: these are stored in detailed results on NeMo Evaluator side; can be added + generations=[], + scores={ + benchmark_id: ScoringResult( + score_rows=[], + aggregated_results=result, + ) + }, + ) diff --git a/llama_stack/providers/remote/inference/nvidia/config.py b/llama_stack/providers/remote/inference/nvidia/config.py index abd34b498..8f80408d4 100644 --- a/llama_stack/providers/remote/inference/nvidia/config.py +++ b/llama_stack/providers/remote/inference/nvidia/config.py @@ -47,10 +47,15 @@ class NVIDIAConfig(BaseModel): default=60, description="Timeout for the HTTP requests", ) + append_api_version: bool = Field( + default_factory=lambda: os.getenv("NVIDIA_APPEND_API_VERSION", "True").lower() != "false", + description="When set to false, the API version will not be appended to the base_url. By default, it is true.", + ) @classmethod def sample_run_config(cls, **kwargs) -> Dict[str, Any]: return { "url": "${env.NVIDIA_BASE_URL:https://integrate.api.nvidia.com}", "api_key": "${env.NVIDIA_API_KEY:}", + "append_api_version": "${env.NVIDIA_APPEND_API_VERSION:True}", } diff --git a/llama_stack/providers/remote/inference/nvidia/nvidia.py b/llama_stack/providers/remote/inference/nvidia/nvidia.py index c91b4d768..4a62ad6cb 100644 --- a/llama_stack/providers/remote/inference/nvidia/nvidia.py +++ b/llama_stack/providers/remote/inference/nvidia/nvidia.py @@ -33,7 +33,6 @@ from llama_stack.apis.inference import ( TextTruncation, ToolChoice, ToolConfig, - ToolDefinition, ) from llama_stack.apis.inference.inference import ( OpenAIChatCompletion, @@ -42,7 +41,11 @@ from llama_stack.apis.inference.inference import ( OpenAIMessageParam, OpenAIResponseFormatParam, ) -from llama_stack.models.llama.datatypes import ToolPromptFormat +from llama_stack.apis.models import Model, ModelType +from llama_stack.models.llama.datatypes import ToolDefinition, ToolPromptFormat +from llama_stack.providers.utils.inference import ( + ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR, +) from llama_stack.providers.utils.inference.model_registry import ( ModelRegistryHelper, ) @@ -120,10 +123,10 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper): "meta/llama-3.2-90b-vision-instruct": "https://ai.api.nvidia.com/v1/gr/meta/llama-3.2-90b-vision-instruct", } - base_url = f"{self._config.url}/v1" + base_url = f"{self._config.url}/v1" if self._config.append_api_version else self._config.url + if _is_nvidia_hosted(self._config) and provider_model_id in special_model_urls: base_url = special_model_urls[provider_model_id] - return _get_client_for_base_url(base_url) async def _get_provider_model_id(self, model_id: str) -> str: @@ -387,3 +390,44 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper): return await self._get_client(provider_model_id).chat.completions.create(**params) except APIConnectionError as e: raise ConnectionError(f"Failed to connect to NVIDIA NIM at {self._config.url}: {e}") from e + + async def register_model(self, model: Model) -> Model: + """ + Allow non-llama model registration. + + Non-llama model registration: API Catalogue models, post-training models, etc. 
+ client = LlamaStackAsLibraryClient("nvidia") + client.models.register( + model_id="mistralai/mixtral-8x7b-instruct-v0.1", + model_type=ModelType.llm, + provider_id="nvidia", + provider_model_id="mistralai/mixtral-8x7b-instruct-v0.1" + ) + + NOTE: Only supports models endpoints compatible with AsyncOpenAI base_url format. + """ + if model.model_type == ModelType.embedding: + # embedding models are always registered by their provider model id and does not need to be mapped to a llama model + provider_resource_id = model.provider_resource_id + else: + provider_resource_id = self.get_provider_model_id(model.provider_resource_id) + + if provider_resource_id: + model.provider_resource_id = provider_resource_id + else: + llama_model = model.metadata.get("llama_model") + existing_llama_model = self.get_llama_model(model.provider_resource_id) + if existing_llama_model: + if existing_llama_model != llama_model: + raise ValueError( + f"Provider model id '{model.provider_resource_id}' is already registered to a different llama model: '{existing_llama_model}'" + ) + else: + # not llama model + if llama_model in ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR: + self.provider_id_to_llama_model_map[model.provider_resource_id] = ( + ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR[llama_model] + ) + else: + self.alias_to_provider_id_map[model.provider_model_id] = model.provider_model_id + return model diff --git a/llama_stack/providers/remote/inference/ollama/ollama.py b/llama_stack/providers/remote/inference/ollama/ollama.py index 2282e2726..f51aa2ded 100644 --- a/llama_stack/providers/remote/inference/ollama/ollama.py +++ b/llama_stack/providers/remote/inference/ollama/ollama.py @@ -8,7 +8,6 @@ from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Union import httpx -from ollama import AsyncClient from openai import AsyncOpenAI from llama_stack.apis.common.content_types import ( @@ -73,6 +72,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import ( interleaved_content_as_str, request_has_media, ) +from ollama import AsyncClient # type: ignore[attr-defined] from .models import model_entries diff --git a/llama_stack/providers/remote/inference/together/together.py b/llama_stack/providers/remote/inference/together/together.py index 001e6aac4..48e41f5b0 100644 --- a/llama_stack/providers/remote/inference/together/together.py +++ b/llama_stack/providers/remote/inference/together/together.py @@ -76,8 +76,11 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi async def shutdown(self) -> None: if self._client: - await self._client.close() + # Together client has no close method, so just set to None self._client = None + if self._openai_client: + await self._openai_client.close() + self._openai_client = None async def completion( self, @@ -359,7 +362,7 @@ class TogetherInferenceAdapter(ModelRegistryHelper, Inference, NeedsRequestProvi top_p=top_p, user=user, ) - if params.get("stream", True): + if params.get("stream", False): return self._stream_openai_chat_completion(params) return await self._get_openai_client().chat.completions.create(**params) # type: ignore diff --git a/llama_stack/providers/remote/inference/vllm/vllm.py b/llama_stack/providers/remote/inference/vllm/vllm.py index 4d3aafd6a..ac268c86c 100644 --- a/llama_stack/providers/remote/inference/vllm/vllm.py +++ b/llama_stack/providers/remote/inference/vllm/vllm.py @@ -231,12 +231,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate): self.client = None async def initialize(self) -> 
None: - log.info(f"Initializing VLLM client with base_url={self.config.url}") - self.client = AsyncOpenAI( - base_url=self.config.url, - api_key=self.config.api_token, - http_client=None if self.config.tls_verify else httpx.AsyncClient(verify=False), - ) + pass async def shutdown(self) -> None: pass @@ -249,6 +244,20 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate): raise ValueError("Model store not set") return await self.model_store.get_model(model_id) + def _lazy_initialize_client(self): + if self.client is not None: + return + + log.info(f"Initializing vLLM client with base_url={self.config.url}") + self.client = self._create_client() + + def _create_client(self): + return AsyncOpenAI( + base_url=self.config.url, + api_key=self.config.api_token, + http_client=None if self.config.tls_verify else httpx.AsyncClient(verify=False), + ) + async def completion( self, model_id: str, @@ -258,6 +267,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate): stream: Optional[bool] = False, logprobs: Optional[LogProbConfig] = None, ) -> CompletionResponse | AsyncGenerator[CompletionResponseStreamChunk, None]: + self._lazy_initialize_client() if sampling_params is None: sampling_params = SamplingParams() model = await self._get_model(model_id) @@ -287,6 +297,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate): logprobs: Optional[LogProbConfig] = None, tool_config: Optional[ToolConfig] = None, ) -> ChatCompletionResponse | AsyncGenerator[ChatCompletionResponseStreamChunk, None]: + self._lazy_initialize_client() if sampling_params is None: sampling_params = SamplingParams() model = await self._get_model(model_id) @@ -357,12 +368,15 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate): yield chunk async def register_model(self, model: Model) -> Model: - assert self.client is not None + # register_model is called during Llama Stack initialization, hence we cannot init self.client if not initialized yet. + # self.client should only be created after the initialization is complete to avoid asyncio cross-context errors. + # Changing this may lead to unpredictable behavior. 
+ client = self._create_client() if self.client is None else self.client try: model = await self.register_helper.register_model(model) except ValueError: pass # Ignore statically unknown model, will check live listing - res = await self.client.models.list() + res = await client.models.list() available_models = [m.id async for m in res] if model.provider_resource_id not in available_models: raise ValueError( @@ -413,6 +427,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate): output_dimension: Optional[int] = None, task_type: Optional[EmbeddingTaskType] = None, ) -> EmbeddingsResponse: + self._lazy_initialize_client() assert self.client is not None model = await self._get_model(model_id) @@ -452,6 +467,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate): guided_choice: Optional[List[str]] = None, prompt_logprobs: Optional[int] = None, ) -> OpenAICompletion: + self._lazy_initialize_client() model_obj = await self._get_model(model) extra_body: Dict[str, Any] = {} @@ -508,6 +524,7 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate): top_p: Optional[float] = None, user: Optional[str] = None, ) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]: + self._lazy_initialize_client() model_obj = await self._get_model(model) params = await prepare_openai_completion_params( model=model_obj.provider_resource_id, diff --git a/llama_stack/providers/remote/inference/watsonx/__init__.py b/llama_stack/providers/remote/inference/watsonx/__init__.py new file mode 100644 index 000000000..e59e873b6 --- /dev/null +++ b/llama_stack/providers/remote/inference/watsonx/__init__.py @@ -0,0 +1,22 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from llama_stack.apis.inference import Inference + +from .config import WatsonXConfig + + +async def get_adapter_impl(config: WatsonXConfig, _deps) -> Inference: + # import dynamically so `llama stack build` does not fail due to missing dependencies + from .watsonx import WatsonXInferenceAdapter + + if not isinstance(config, WatsonXConfig): + raise RuntimeError(f"Unexpected config type: {type(config)}") + adapter = WatsonXInferenceAdapter(config) + return adapter + + +__all__ = ["get_adapter_impl", "WatsonXConfig"] diff --git a/llama_stack/providers/remote/inference/watsonx/config.py b/llama_stack/providers/remote/inference/watsonx/config.py new file mode 100644 index 000000000..7ee99b7e0 --- /dev/null +++ b/llama_stack/providers/remote/inference/watsonx/config.py @@ -0,0 +1,46 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
+ +import os +from typing import Any, Dict, Optional + +from pydantic import BaseModel, Field, SecretStr + +from llama_stack.schema_utils import json_schema_type + + +class WatsonXProviderDataValidator(BaseModel): + url: str + api_key: str + project_id: str + + +@json_schema_type +class WatsonXConfig(BaseModel): + url: str = Field( + default_factory=lambda: os.getenv("WATSONX_BASE_URL", "https://us-south.ml.cloud.ibm.com"), + description="A base url for accessing the watsonx.ai", + ) + api_key: Optional[SecretStr] = Field( + default_factory=lambda: os.getenv("WATSONX_API_KEY"), + description="The watsonx API key, only needed of using the hosted service", + ) + project_id: Optional[str] = Field( + default_factory=lambda: os.getenv("WATSONX_PROJECT_ID"), + description="The Project ID key, only needed of using the hosted service", + ) + timeout: int = Field( + default=60, + description="Timeout for the HTTP requests", + ) + + @classmethod + def sample_run_config(cls, **kwargs) -> Dict[str, Any]: + return { + "url": "${env.WATSONX_BASE_URL:https://us-south.ml.cloud.ibm.com}", + "api_key": "${env.WATSONX_API_KEY:}", + "project_id": "${env.WATSONX_PROJECT_ID:}", + } diff --git a/llama_stack/providers/remote/inference/watsonx/models.py b/llama_stack/providers/remote/inference/watsonx/models.py new file mode 100644 index 000000000..d98f0510a --- /dev/null +++ b/llama_stack/providers/remote/inference/watsonx/models.py @@ -0,0 +1,47 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from llama_stack.models.llama.sku_types import CoreModelId +from llama_stack.providers.utils.inference.model_registry import build_hf_repo_model_entry + +MODEL_ENTRIES = [ + build_hf_repo_model_entry( + "meta-llama/llama-3-3-70b-instruct", + CoreModelId.llama3_3_70b_instruct.value, + ), + build_hf_repo_model_entry( + "meta-llama/llama-2-13b-chat", + CoreModelId.llama2_13b.value, + ), + build_hf_repo_model_entry( + "meta-llama/llama-3-1-70b-instruct", + CoreModelId.llama3_1_70b_instruct.value, + ), + build_hf_repo_model_entry( + "meta-llama/llama-3-1-8b-instruct", + CoreModelId.llama3_1_8b_instruct.value, + ), + build_hf_repo_model_entry( + "meta-llama/llama-3-2-11b-vision-instruct", + CoreModelId.llama3_2_11b_vision_instruct.value, + ), + build_hf_repo_model_entry( + "meta-llama/llama-3-2-1b-instruct", + CoreModelId.llama3_2_1b_instruct.value, + ), + build_hf_repo_model_entry( + "meta-llama/llama-3-2-3b-instruct", + CoreModelId.llama3_2_3b_instruct.value, + ), + build_hf_repo_model_entry( + "meta-llama/llama-3-2-90b-vision-instruct", + CoreModelId.llama3_2_90b_vision_instruct.value, + ), + build_hf_repo_model_entry( + "meta-llama/llama-guard-3-11b-vision", + CoreModelId.llama_guard_3_11b_vision.value, + ), +] diff --git a/llama_stack/providers/remote/inference/watsonx/watsonx.py b/llama_stack/providers/remote/inference/watsonx/watsonx.py new file mode 100644 index 000000000..fa9cc4391 --- /dev/null +++ b/llama_stack/providers/remote/inference/watsonx/watsonx.py @@ -0,0 +1,378 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
+ +from typing import Any, AsyncGenerator, AsyncIterator, Dict, List, Optional, Union + +from ibm_watson_machine_learning.foundation_models import Model +from ibm_watson_machine_learning.metanames import GenTextParamsMetaNames as GenParams +from openai import AsyncOpenAI + +from llama_stack.apis.common.content_types import InterleavedContent, InterleavedContentItem +from llama_stack.apis.inference import ( + ChatCompletionRequest, + ChatCompletionResponse, + CompletionRequest, + EmbeddingsResponse, + EmbeddingTaskType, + Inference, + LogProbConfig, + Message, + ResponseFormat, + SamplingParams, + TextTruncation, + ToolChoice, + ToolConfig, + ToolDefinition, + ToolPromptFormat, +) +from llama_stack.apis.inference.inference import ( + GreedySamplingStrategy, + OpenAIChatCompletion, + OpenAIChatCompletionChunk, + OpenAICompletion, + OpenAIMessageParam, + OpenAIResponseFormatParam, + TopKSamplingStrategy, + TopPSamplingStrategy, +) +from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper +from llama_stack.providers.utils.inference.openai_compat import ( + OpenAICompatCompletionChoice, + OpenAICompatCompletionResponse, + prepare_openai_completion_params, + process_chat_completion_response, + process_chat_completion_stream_response, + process_completion_response, + process_completion_stream_response, +) +from llama_stack.providers.utils.inference.prompt_adapter import ( + chat_completion_request_to_prompt, + completion_request_to_prompt, + request_has_media, +) + +from . import WatsonXConfig +from .models import MODEL_ENTRIES + + +class WatsonXInferenceAdapter(Inference, ModelRegistryHelper): + def __init__(self, config: WatsonXConfig) -> None: + ModelRegistryHelper.__init__(self, MODEL_ENTRIES) + + print(f"Initializing watsonx InferenceAdapter({config.url})...") + + self._config = config + + self._project_id = self._config.project_id + + async def initialize(self) -> None: + pass + + async def shutdown(self) -> None: + pass + + async def completion( + self, + model_id: str, + content: InterleavedContent, + sampling_params: Optional[SamplingParams] = None, + response_format: Optional[ResponseFormat] = None, + stream: Optional[bool] = False, + logprobs: Optional[LogProbConfig] = None, + ) -> AsyncGenerator: + if sampling_params is None: + sampling_params = SamplingParams() + model = await self.model_store.get_model(model_id) + request = CompletionRequest( + model=model.provider_resource_id, + content=content, + sampling_params=sampling_params, + response_format=response_format, + stream=stream, + logprobs=logprobs, + ) + if stream: + return self._stream_completion(request) + else: + return await self._nonstream_completion(request) + + def _get_client(self, model_id) -> Model: + config_api_key = self._config.api_key.get_secret_value() if self._config.api_key else None + config_url = self._config.url + project_id = self._config.project_id + credentials = {"url": config_url, "apikey": config_api_key} + + return Model(model_id=model_id, credentials=credentials, project_id=project_id) + + def _get_openai_client(self) -> AsyncOpenAI: + if not self._openai_client: + self._openai_client = AsyncOpenAI( + base_url=f"{self._config.url}/openai/v1", + api_key=self._config.api_key, + ) + return self._openai_client + + async def _nonstream_completion(self, request: CompletionRequest) -> ChatCompletionResponse: + params = await self._get_params(request) + r = self._get_client(request.model).generate(**params) + choices = [] + if "results" in r: + for result in r["results"]: + choice 
= OpenAICompatCompletionChoice( + finish_reason=result["stop_reason"] if result["stop_reason"] else None, + text=result["generated_text"], + ) + choices.append(choice) + response = OpenAICompatCompletionResponse( + choices=choices, + ) + return process_completion_response(response) + + async def _stream_completion(self, request: CompletionRequest) -> AsyncGenerator: + params = await self._get_params(request) + + async def _generate_and_convert_to_openai_compat(): + s = self._get_client(request.model).generate_text_stream(**params) + for chunk in s: + choice = OpenAICompatCompletionChoice( + finish_reason=None, + text=chunk, + ) + yield OpenAICompatCompletionResponse( + choices=[choice], + ) + + stream = _generate_and_convert_to_openai_compat() + async for chunk in process_completion_stream_response(stream): + yield chunk + + async def chat_completion( + self, + model_id: str, + messages: List[Message], + sampling_params: Optional[SamplingParams] = None, + tools: Optional[List[ToolDefinition]] = None, + tool_choice: Optional[ToolChoice] = ToolChoice.auto, + tool_prompt_format: Optional[ToolPromptFormat] = None, + response_format: Optional[ResponseFormat] = None, + stream: Optional[bool] = False, + logprobs: Optional[LogProbConfig] = None, + tool_config: Optional[ToolConfig] = None, + ) -> AsyncGenerator: + if sampling_params is None: + sampling_params = SamplingParams() + model = await self.model_store.get_model(model_id) + request = ChatCompletionRequest( + model=model.provider_resource_id, + messages=messages, + sampling_params=sampling_params, + tools=tools or [], + response_format=response_format, + stream=stream, + logprobs=logprobs, + tool_config=tool_config, + ) + + if stream: + return self._stream_chat_completion(request) + else: + return await self._nonstream_chat_completion(request) + + async def _nonstream_chat_completion(self, request: ChatCompletionRequest) -> ChatCompletionResponse: + params = await self._get_params(request) + r = self._get_client(request.model).generate(**params) + choices = [] + if "results" in r: + for result in r["results"]: + choice = OpenAICompatCompletionChoice( + finish_reason=result["stop_reason"] if result["stop_reason"] else None, + text=result["generated_text"], + ) + choices.append(choice) + response = OpenAICompatCompletionResponse( + choices=choices, + ) + return process_chat_completion_response(response, request) + + async def _stream_chat_completion(self, request: ChatCompletionRequest) -> AsyncGenerator: + params = await self._get_params(request) + model_id = request.model + + # if we shift to TogetherAsyncClient, we won't need this wrapper + async def _to_async_generator(): + s = self._get_client(model_id).generate_text_stream(**params) + for chunk in s: + choice = OpenAICompatCompletionChoice( + finish_reason=None, + text=chunk, + ) + yield OpenAICompatCompletionResponse( + choices=[choice], + ) + + stream = _to_async_generator() + async for chunk in process_chat_completion_stream_response(stream, request): + yield chunk + + async def _get_params(self, request: Union[ChatCompletionRequest, CompletionRequest]) -> dict: + input_dict = {"params": {}} + media_present = request_has_media(request) + llama_model = self.get_llama_model(request.model) + if isinstance(request, ChatCompletionRequest): + input_dict["prompt"] = await chat_completion_request_to_prompt(request, llama_model) + else: + assert not media_present, "Together does not support media for Completion requests" + input_dict["prompt"] = await completion_request_to_prompt(request) + if 
request.sampling_params: + if request.sampling_params.strategy: + input_dict["params"][GenParams.DECODING_METHOD] = request.sampling_params.strategy.type + if request.sampling_params.max_tokens: + input_dict["params"][GenParams.MAX_NEW_TOKENS] = request.sampling_params.max_tokens + if request.sampling_params.repetition_penalty: + input_dict["params"][GenParams.REPETITION_PENALTY] = request.sampling_params.repetition_penalty + + if isinstance(request.sampling_params.strategy, TopPSamplingStrategy): + input_dict["params"][GenParams.TOP_P] = request.sampling_params.strategy.top_p + input_dict["params"][GenParams.TEMPERATURE] = request.sampling_params.strategy.temperature + if isinstance(request.sampling_params.strategy, TopKSamplingStrategy): + input_dict["params"][GenParams.TOP_K] = request.sampling_params.strategy.top_k + if isinstance(request.sampling_params.strategy, GreedySamplingStrategy): + input_dict["params"][GenParams.TEMPERATURE] = 0.0 + + input_dict["params"][GenParams.STOP_SEQUENCES] = ["<|endoftext|>"] + + params = { + **input_dict, + } + return params + + async def embeddings( + self, + model_id: str, + contents: List[str] | List[InterleavedContentItem], + text_truncation: Optional[TextTruncation] = TextTruncation.none, + output_dimension: Optional[int] = None, + task_type: Optional[EmbeddingTaskType] = None, + ) -> EmbeddingsResponse: + raise NotImplementedError("embedding is not supported for watsonx") + + async def openai_completion( + self, + model: str, + prompt: Union[str, List[str], List[int], List[List[int]]], + best_of: Optional[int] = None, + echo: Optional[bool] = None, + frequency_penalty: Optional[float] = None, + logit_bias: Optional[Dict[str, float]] = None, + logprobs: Optional[bool] = None, + max_tokens: Optional[int] = None, + n: Optional[int] = None, + presence_penalty: Optional[float] = None, + seed: Optional[int] = None, + stop: Optional[Union[str, List[str]]] = None, + stream: Optional[bool] = None, + stream_options: Optional[Dict[str, Any]] = None, + temperature: Optional[float] = None, + top_p: Optional[float] = None, + user: Optional[str] = None, + guided_choice: Optional[List[str]] = None, + prompt_logprobs: Optional[int] = None, + ) -> OpenAICompletion: + model_obj = await self.model_store.get_model(model) + params = await prepare_openai_completion_params( + model=model_obj.provider_resource_id, + prompt=prompt, + best_of=best_of, + echo=echo, + frequency_penalty=frequency_penalty, + logit_bias=logit_bias, + logprobs=logprobs, + max_tokens=max_tokens, + n=n, + presence_penalty=presence_penalty, + seed=seed, + stop=stop, + stream=stream, + stream_options=stream_options, + temperature=temperature, + top_p=top_p, + user=user, + ) + return await self._get_openai_client().completions.create(**params) # type: ignore + + async def openai_chat_completion( + self, + model: str, + messages: List[OpenAIMessageParam], + frequency_penalty: Optional[float] = None, + function_call: Optional[Union[str, Dict[str, Any]]] = None, + functions: Optional[List[Dict[str, Any]]] = None, + logit_bias: Optional[Dict[str, float]] = None, + logprobs: Optional[bool] = None, + max_completion_tokens: Optional[int] = None, + max_tokens: Optional[int] = None, + n: Optional[int] = None, + parallel_tool_calls: Optional[bool] = None, + presence_penalty: Optional[float] = None, + response_format: Optional[OpenAIResponseFormatParam] = None, + seed: Optional[int] = None, + stop: Optional[Union[str, List[str]]] = None, + stream: Optional[bool] = None, + stream_options: Optional[Dict[str, 
Any]] = None, + temperature: Optional[float] = None, + tool_choice: Optional[Union[str, Dict[str, Any]]] = None, + tools: Optional[List[Dict[str, Any]]] = None, + top_logprobs: Optional[int] = None, + top_p: Optional[float] = None, + user: Optional[str] = None, + ) -> Union[OpenAIChatCompletion, AsyncIterator[OpenAIChatCompletionChunk]]: + model_obj = await self.model_store.get_model(model) + params = await prepare_openai_completion_params( + model=model_obj.provider_resource_id, + messages=messages, + frequency_penalty=frequency_penalty, + function_call=function_call, + functions=functions, + logit_bias=logit_bias, + logprobs=logprobs, + max_completion_tokens=max_completion_tokens, + max_tokens=max_tokens, + n=n, + parallel_tool_calls=parallel_tool_calls, + presence_penalty=presence_penalty, + response_format=response_format, + seed=seed, + stop=stop, + stream=stream, + stream_options=stream_options, + temperature=temperature, + tool_choice=tool_choice, + tools=tools, + top_logprobs=top_logprobs, + top_p=top_p, + user=user, + ) + if params.get("stream", False): + return self._stream_openai_chat_completion(params) + return await self._get_openai_client().chat.completions.create(**params) # type: ignore + + async def _stream_openai_chat_completion(self, params: dict) -> AsyncGenerator: + # watsonx.ai sometimes adds usage data to the stream + include_usage = False + if params.get("stream_options", None): + include_usage = params["stream_options"].get("include_usage", False) + stream = await self._get_openai_client().chat.completions.create(**params) + + seen_finish_reason = False + async for chunk in stream: + # Final usage chunk with no choices that the user didn't request, so discard + if not include_usage and seen_finish_reason and len(chunk.choices) == 0: + break + yield chunk + for choice in chunk.choices: + if choice.finish_reason: + seen_finish_reason = True + break diff --git a/llama_stack/providers/remote/post_training/nvidia/README.md b/llama_stack/providers/remote/post_training/nvidia/README.md index 230587d66..3ef538d29 100644 --- a/llama_stack/providers/remote/post_training/nvidia/README.md +++ b/llama_stack/providers/remote/post_training/nvidia/README.md @@ -36,7 +36,6 @@ import os os.environ["NVIDIA_API_KEY"] = "your-api-key" os.environ["NVIDIA_CUSTOMIZER_URL"] = "http://nemo.test" -os.environ["NVIDIA_USER_ID"] = "llama-stack-user" os.environ["NVIDIA_DATASET_NAMESPACE"] = "default" os.environ["NVIDIA_PROJECT_ID"] = "test-project" os.environ["NVIDIA_OUTPUT_MODEL_DIR"] = "test-example-model@v1" @@ -125,6 +124,21 @@ client.post_training.job.cancel(job_uuid="your-job-id") ### Inference with the fine-tuned model +#### 1. Register the model + +```python +from llama_stack.apis.models import Model, ModelType + +client.models.register( + model_id="test-example-model@v1", + provider_id="nvidia", + provider_model_id="test-example-model@v1", + model_type=ModelType.llm, +) +``` + +#### 2. 
Inference with the fine-tuned model + ```python response = client.inference.completion( content="Complete the sentence using one word: Roses are red, violets are ", diff --git a/llama_stack/providers/remote/post_training/nvidia/post_training.py b/llama_stack/providers/remote/post_training/nvidia/post_training.py index d3de930f7..c74fb2a24 100644 --- a/llama_stack/providers/remote/post_training/nvidia/post_training.py +++ b/llama_stack/providers/remote/post_training/nvidia/post_training.py @@ -67,13 +67,18 @@ class NvidiaPostTrainingAdapter(ModelRegistryHelper): self.timeout = aiohttp.ClientTimeout(total=config.timeout) # TODO: filter by available models based on /config endpoint ModelRegistryHelper.__init__(self, model_entries=_MODEL_ENTRIES) - self.session = aiohttp.ClientSession(headers=self.headers, timeout=self.timeout) - self.customizer_url = config.customizer_url + self.session = None + self.customizer_url = config.customizer_url if not self.customizer_url: warnings.warn("Customizer URL is not set, using default value: http://nemo.test", stacklevel=2) self.customizer_url = "http://nemo.test" + async def _get_session(self) -> aiohttp.ClientSession: + if self.session is None or self.session.closed: + self.session = aiohttp.ClientSession(headers=self.headers, timeout=self.timeout) + return self.session + async def _make_request( self, method: str, @@ -94,8 +99,9 @@ class NvidiaPostTrainingAdapter(ModelRegistryHelper): if json and "Content-Type" not in request_headers: request_headers["Content-Type"] = "application/json" + session = await self._get_session() for _ in range(self.config.max_retries): - async with self.session.request(method, url, params=params, json=json, **kwargs) as response: + async with session.request(method, url, params=params, json=json, **kwargs) as response: if response.status >= 400: error_data = await response.json() raise Exception(f"API request failed: {error_data}") @@ -122,8 +128,8 @@ class NvidiaPostTrainingAdapter(ModelRegistryHelper): jobs = [] for job in response.get("data", []): job_id = job.pop("id") - job_status = job.pop("status", "unknown").lower() - mapped_status = STATUS_MAPPING.get(job_status, "unknown") + job_status = job.pop("status", "scheduled").lower() + mapped_status = STATUS_MAPPING.get(job_status, "scheduled") # Convert string timestamps to datetime objects created_at = ( @@ -177,7 +183,7 @@ class NvidiaPostTrainingAdapter(ModelRegistryHelper): ) api_status = response.pop("status").lower() - mapped_status = STATUS_MAPPING.get(api_status, "unknown") + mapped_status = STATUS_MAPPING.get(api_status, "scheduled") return NvidiaPostTrainingJobStatusResponse( status=JobStatus(mapped_status), @@ -239,6 +245,7 @@ class NvidiaPostTrainingAdapter(ModelRegistryHelper): Supported models: - meta/llama-3.1-8b-instruct + - meta/llama-3.2-1b-instruct Supported algorithm configs: - LoRA, SFT @@ -284,10 +291,6 @@ class NvidiaPostTrainingAdapter(ModelRegistryHelper): - LoRA config: ## NeMo customizer specific LoRA parameters - - adapter_dim: int - Adapter dimension - Default: 8 (supports powers of 2) - - adapter_dropout: float - Adapter dropout - Default: None (0.0-1.0) - alpha: int - Scaling factor for the LoRA update Default: 16 Note: @@ -297,7 +300,7 @@ class NvidiaPostTrainingAdapter(ModelRegistryHelper): User is informed about unsupported parameters via warnings. 
""" # Map model to nvidia model name - # ToDo: only supports llama-3.1-8b-instruct now, need to update this to support other models + # See `_MODEL_ENTRIES` for supported models nvidia_model = self.get_provider_model_id(model) # Check for unsupported method parameters @@ -330,7 +333,7 @@ class NvidiaPostTrainingAdapter(ModelRegistryHelper): }, "data_config": {"dataset_id", "batch_size"}, "optimizer_config": {"lr", "weight_decay"}, - "lora_config": {"type", "adapter_dim", "adapter_dropout", "alpha"}, + "lora_config": {"type", "alpha"}, } # Validate all parameters at once @@ -389,16 +392,10 @@ class NvidiaPostTrainingAdapter(ModelRegistryHelper): # Handle LoRA-specific configuration if algorithm_config: - if isinstance(algorithm_config, dict) and algorithm_config.get("type") == "LoRA": + if algorithm_config.type == "LoRA": warn_unsupported_params(algorithm_config, supported_params["lora_config"], "LoRA config") job_config["hyperparameters"]["lora"] = { - k: v - for k, v in { - "adapter_dim": algorithm_config.get("adapter_dim"), - "alpha": algorithm_config.get("alpha"), - "adapter_dropout": algorithm_config.get("adapter_dropout"), - }.items() - if v is not None + k: v for k, v in {"alpha": algorithm_config.alpha}.items() if v is not None } else: raise NotImplementedError(f"Unsupported algorithm config: {algorithm_config}") diff --git a/llama_stack/providers/utils/inference/openai_compat.py b/llama_stack/providers/utils/inference/openai_compat.py index f91e7d7dc..4d690287b 100644 --- a/llama_stack/providers/utils/inference/openai_compat.py +++ b/llama_stack/providers/utils/inference/openai_compat.py @@ -524,11 +524,26 @@ async def convert_message_to_openai_dict(message: Message, download: bool = Fals else: content = [await _convert_content(message.content)] - return { + result = { "role": message.role, "content": content, } + if hasattr(message, "tool_calls") and message.tool_calls: + result["tool_calls"] = [] + for tc in message.tool_calls: + result["tool_calls"].append( + { + "id": tc.call_id, + "type": "function", + "function": { + "name": tc.tool_name, + "arguments": tc.arguments_json if hasattr(tc, "arguments_json") else json.dumps(tc.arguments), + }, + } + ) + return result + class UnparseableToolCall(BaseModel): """ diff --git a/llama_stack/providers/utils/inference/prompt_adapter.py b/llama_stack/providers/utils/inference/prompt_adapter.py index 4f9c4927a..657dc4b86 100644 --- a/llama_stack/providers/utils/inference/prompt_adapter.py +++ b/llama_stack/providers/utils/inference/prompt_adapter.py @@ -52,6 +52,9 @@ from llama_stack.models.llama.llama3.prompt_templates import ( SystemDefaultGenerator, ) from llama_stack.models.llama.llama3.tokenizer import Tokenizer +from llama_stack.models.llama.llama4.prompt_templates.system_prompts import ( + PythonListCustomToolGenerator as PythonListCustomToolGeneratorLlama4, +) from llama_stack.models.llama.sku_list import resolve_model from llama_stack.models.llama.sku_types import ModelFamily, is_multimodal from llama_stack.providers.utils.inference import supported_inference_models @@ -306,10 +309,11 @@ def chat_completion_request_to_messages( elif model.model_family in ( ModelFamily.llama3_2, ModelFamily.llama3_3, - ModelFamily.llama4, ): - # llama3.2, llama3.3 and llama4 models follow the same tool prompt format - messages = augment_messages_for_tools_llama_3_2(request) + # llama3.2, llama3.3 follow the same tool prompt format + messages = augment_messages_for_tools_llama(request, PythonListCustomToolGenerator) + elif model.model_family == 
ModelFamily.llama4: + messages = augment_messages_for_tools_llama(request, PythonListCustomToolGeneratorLlama4) else: messages = request.messages @@ -399,8 +403,9 @@ def augment_messages_for_tools_llama_3_1( return messages -def augment_messages_for_tools_llama_3_2( +def augment_messages_for_tools_llama( request: ChatCompletionRequest, + custom_tool_prompt_generator, ) -> List[Message]: existing_messages = request.messages existing_system_message = None @@ -434,7 +439,7 @@ def augment_messages_for_tools_llama_3_2( if existing_system_message and request.tool_config.system_message_behavior == SystemMessageBehavior.replace: system_prompt = existing_system_message.content - tool_template = PythonListCustomToolGenerator().gen(custom_tools, system_prompt) + tool_template = custom_tool_prompt_generator().gen(custom_tools, system_prompt) sys_content += tool_template.render() sys_content += "\n" diff --git a/llama_stack/templates/dependencies.json b/llama_stack/templates/dependencies.json index b96191752..4c16411f0 100644 --- a/llama_stack/templates/dependencies.json +++ b/llama_stack/templates/dependencies.json @@ -394,12 +394,10 @@ "aiosqlite", "blobfile", "chardet", - "emoji", "faiss-cpu", "fastapi", "fire", "httpx", - "langdetect", "matplotlib", "nltk", "numpy", @@ -411,7 +409,6 @@ "psycopg2-binary", "pymongo", "pypdf", - "pythainlp", "redis", "requests", "scikit-learn", @@ -419,7 +416,6 @@ "sentencepiece", "tqdm", "transformers", - "tree_sitter", "uvicorn" ], "ollama": [ @@ -759,5 +755,41 @@ "vllm", "sentence-transformers --no-deps", "torch torchvision --index-url https://download.pytorch.org/whl/cpu" + ], + "watsonx": [ + "aiosqlite", + "autoevals", + "blobfile", + "chardet", + "datasets", + "emoji", + "faiss-cpu", + "fastapi", + "fire", + "httpx", + "ibm_watson_machine_learning", + "langdetect", + "matplotlib", + "mcp", + "nltk", + "numpy", + "openai", + "opentelemetry-exporter-otlp-proto-http", + "opentelemetry-sdk", + "pandas", + "pillow", + "psycopg2-binary", + "pymongo", + "pypdf", + "pythainlp", + "redis", + "requests", + "scikit-learn", + "scipy", + "sentencepiece", + "tqdm", + "transformers", + "tree_sitter", + "uvicorn" ] } diff --git a/llama_stack/templates/meta-reference-gpu/doc_template.md b/llama_stack/templates/meta-reference-gpu/doc_template.md index a174331b4..2ca6793d7 100644 --- a/llama_stack/templates/meta-reference-gpu/doc_template.md +++ b/llama_stack/templates/meta-reference-gpu/doc_template.md @@ -69,6 +69,7 @@ LLAMA_STACK_PORT=8321 docker run \ -it \ --pull always \ + --gpu all \ -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \ -v ~/.llama:/root/.llama \ llamastack/distribution-{{ name }} \ @@ -82,6 +83,7 @@ If you are using Llama Stack Safety / Shield APIs, use: docker run \ -it \ --pull always \ + --gpu all \ -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \ -v ~/.llama:/root/.llama \ llamastack/distribution-{{ name }} \ diff --git a/llama_stack/templates/nvidia/build.yaml b/llama_stack/templates/nvidia/build.yaml index f99ff6c81..a33fa3737 100644 --- a/llama_stack/templates/nvidia/build.yaml +++ b/llama_stack/templates/nvidia/build.yaml @@ -1,6 +1,6 @@ version: '2' distribution_spec: - description: Use NVIDIA NIM for running LLM inference and safety + description: Use NVIDIA NIM for running LLM inference, evaluation and safety providers: inference: - remote::nvidia @@ -13,7 +13,7 @@ distribution_spec: telemetry: - inline::meta-reference eval: - - inline::meta-reference + - remote::nvidia post_training: - remote::nvidia datasetio: diff --git a/llama_stack/templates/nvidia/nvidia.py 
b/llama_stack/templates/nvidia/nvidia.py index a0cefba52..463c13879 100644 --- a/llama_stack/templates/nvidia/nvidia.py +++ b/llama_stack/templates/nvidia/nvidia.py @@ -7,6 +7,7 @@ from pathlib import Path from llama_stack.distribution.datatypes import ModelInput, Provider, ShieldInput, ToolGroupInput +from llama_stack.providers.remote.eval.nvidia import NVIDIAEvalConfig from llama_stack.providers.remote.inference.nvidia import NVIDIAConfig from llama_stack.providers.remote.inference.nvidia.models import MODEL_ENTRIES from llama_stack.providers.remote.safety.nvidia import NVIDIASafetyConfig @@ -20,7 +21,7 @@ def get_distribution_template() -> DistributionTemplate: "safety": ["remote::nvidia"], "agents": ["inline::meta-reference"], "telemetry": ["inline::meta-reference"], - "eval": ["inline::meta-reference"], + "eval": ["remote::nvidia"], "post_training": ["remote::nvidia"], "datasetio": ["inline::localfs"], "scoring": ["inline::basic"], @@ -37,6 +38,11 @@ def get_distribution_template() -> DistributionTemplate: provider_type="remote::nvidia", config=NVIDIASafetyConfig.sample_run_config(), ) + eval_provider = Provider( + provider_id="nvidia", + provider_type="remote::nvidia", + config=NVIDIAEvalConfig.sample_run_config(), + ) inference_model = ModelInput( model_id="${env.INFERENCE_MODEL}", provider_id="nvidia", @@ -60,7 +66,7 @@ def get_distribution_template() -> DistributionTemplate: return DistributionTemplate( name="nvidia", distro_type="self_hosted", - description="Use NVIDIA NIM for running LLM inference and safety", + description="Use NVIDIA NIM for running LLM inference, evaluation and safety", container_image=None, template_path=Path(__file__).parent / "doc_template.md", providers=providers, @@ -69,6 +75,7 @@ def get_distribution_template() -> DistributionTemplate: "run.yaml": RunConfigSettings( provider_overrides={ "inference": [inference_provider], + "eval": [eval_provider], }, default_models=default_models, default_tool_groups=default_tool_groups, @@ -78,7 +85,8 @@ def get_distribution_template() -> DistributionTemplate: "inference": [ inference_provider, safety_provider, - ] + ], + "eval": [eval_provider], }, default_models=[inference_model, safety_model], default_shields=[ShieldInput(shield_id="${env.SAFETY_MODEL}", provider_id="nvidia")], @@ -90,19 +98,15 @@ def get_distribution_template() -> DistributionTemplate: "", "NVIDIA API Key", ), - ## Nemo Customizer related variables - "NVIDIA_USER_ID": ( - "llama-stack-user", - "NVIDIA User ID", + "NVIDIA_APPEND_API_VERSION": ( + "True", + "Whether to append the API version to the base_url", ), + ## Nemo Customizer related variables "NVIDIA_DATASET_NAMESPACE": ( "default", "NVIDIA Dataset Namespace", ), - "NVIDIA_ACCESS_POLICIES": ( - "{}", - "NVIDIA Access Policies", - ), "NVIDIA_PROJECT_ID": ( "test-project", "NVIDIA Project ID", @@ -119,6 +123,10 @@ def get_distribution_template() -> DistributionTemplate: "http://0.0.0.0:7331", "URL for the NeMo Guardrails Service", ), + "NVIDIA_EVALUATOR_URL": ( + "http://0.0.0.0:7331", + "URL for the NeMo Evaluator Service", + ), "INFERENCE_MODEL": ( "Llama3.1-8B-Instruct", "Inference model", diff --git a/llama_stack/templates/nvidia/run-with-safety.yaml b/llama_stack/templates/nvidia/run-with-safety.yaml index 658d9377e..a3e5fefa4 100644 --- a/llama_stack/templates/nvidia/run-with-safety.yaml +++ b/llama_stack/templates/nvidia/run-with-safety.yaml @@ -18,6 +18,7 @@ providers: config: url: ${env.NVIDIA_BASE_URL:https://integrate.api.nvidia.com} api_key: ${env.NVIDIA_API_KEY:} + 
append_api_version: ${env.NVIDIA_APPEND_API_VERSION:True} - provider_id: nvidia provider_type: remote::nvidia config: @@ -53,13 +54,10 @@ providers: sinks: ${env.TELEMETRY_SINKS:console,sqlite} sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/nvidia/trace_store.db} eval: - - provider_id: meta-reference - provider_type: inline::meta-reference + - provider_id: nvidia + provider_type: remote::nvidia config: - kvstore: - type: sqlite - namespace: null - db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/nvidia}/meta_reference_eval.db + evaluator_url: ${env.NVIDIA_EVALUATOR_URL:http://localhost:7331} post_training: - provider_id: nvidia provider_type: remote::nvidia diff --git a/llama_stack/templates/nvidia/run.yaml b/llama_stack/templates/nvidia/run.yaml index ff548d82e..271ce1a16 100644 --- a/llama_stack/templates/nvidia/run.yaml +++ b/llama_stack/templates/nvidia/run.yaml @@ -18,6 +18,7 @@ providers: config: url: ${env.NVIDIA_BASE_URL:https://integrate.api.nvidia.com} api_key: ${env.NVIDIA_API_KEY:} + append_api_version: ${env.NVIDIA_APPEND_API_VERSION:True} vector_io: - provider_id: faiss provider_type: inline::faiss @@ -48,13 +49,10 @@ providers: sinks: ${env.TELEMETRY_SINKS:console,sqlite} sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/nvidia/trace_store.db} eval: - - provider_id: meta-reference - provider_type: inline::meta-reference + - provider_id: nvidia + provider_type: remote::nvidia config: - kvstore: - type: sqlite - namespace: null - db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/nvidia}/meta_reference_eval.db + evaluator_url: ${env.NVIDIA_EVALUATOR_URL:http://localhost:7331} post_training: - provider_id: nvidia provider_type: remote::nvidia diff --git a/llama_stack/templates/watsonx/__init__.py b/llama_stack/templates/watsonx/__init__.py new file mode 100644 index 000000000..078d86144 --- /dev/null +++ b/llama_stack/templates/watsonx/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from .watsonx import get_distribution_template # noqa: F401 diff --git a/llama_stack/templates/watsonx/build.yaml b/llama_stack/templates/watsonx/build.yaml new file mode 100644 index 000000000..badd643ad --- /dev/null +++ b/llama_stack/templates/watsonx/build.yaml @@ -0,0 +1,30 @@ +version: '2' +distribution_spec: + description: Use watsonx for running LLM inference + providers: + inference: + - remote::watsonx + vector_io: + - inline::faiss + safety: + - inline::llama-guard + agents: + - inline::meta-reference + telemetry: + - inline::meta-reference + eval: + - inline::meta-reference + datasetio: + - remote::huggingface + - inline::localfs + scoring: + - inline::basic + - inline::llm-as-judge + - inline::braintrust + tool_runtime: + - remote::brave-search + - remote::tavily-search + - inline::code-interpreter + - inline::rag-runtime + - remote::model-context-protocol +image_type: conda diff --git a/llama_stack/templates/watsonx/doc_template.md b/llama_stack/templates/watsonx/doc_template.md new file mode 100644 index 000000000..af0ae15a8 --- /dev/null +++ b/llama_stack/templates/watsonx/doc_template.md @@ -0,0 +1,74 @@ +--- +orphan: true +--- +# watsonx Distribution + +```{toctree} +:maxdepth: 2 +:hidden: + +self +``` + +The `llamastack/distribution-{{ name }}` distribution consists of the following provider configurations. 
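Both the nvidia run configs above and the watsonx template that follows lean on `${env.VAR:default}` placeholders (for example `NVIDIA_EVALUATOR_URL` and the `WATSONX_*` credentials). As a rough, non-authoritative sketch of the fallback behaviour those defaults suggest — the helper name and regex below are invented for illustration and are not taken from llama-stack's actual resolver — such a placeholder could be expanded like this:

```python
import os
import re

# Hypothetical resolver for the "${env.NAME:default}" placeholders used in the
# run.yaml templates; the real llama-stack resolver may behave differently.
_ENV_PATTERN = re.compile(r"\$\{env\.([A-Za-z0-9_]+)(?::([^}]*))?\}")

def resolve_env_placeholders(value: str) -> str:
    def _sub(match: re.Match) -> str:
        name, default = match.group(1), match.group(2) or ""
        # Prefer the environment variable, fall back to the inline default.
        return os.environ.get(name, default)
    return _ENV_PATTERN.sub(_sub, value)

# With NVIDIA_EVALUATOR_URL unset, this prints the inline default.
print(resolve_env_placeholders("${env.NVIDIA_EVALUATOR_URL:http://localhost:7331}"))
```

Run with `NVIDIA_EVALUATOR_URL` unset, the sketch prints `http://localhost:7331`, which matches the default documented for the NeMo Evaluator Service above.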
+ +{{ providers_table }} + +{% if run_config_env_vars %} + +### Environment Variables + +The following environment variables can be configured: + +{% for var, (default_value, description) in run_config_env_vars.items() %} +- `{{ var }}`: {{ description }} (default: `{{ default_value }}`) +{% endfor %} +{% endif %} + +{% if default_models %} +### Models + +The following models are available by default: + +{% for model in default_models %} +- `{{ model.model_id }} {{ model.doc_string }}` +{% endfor %} +{% endif %} + + +### Prerequisite: API Keys + +Make sure you have access to a watsonx API Key. You can get one by referring [watsonx.ai](https://www.ibm.com/docs/en/masv-and-l/maximo-manage/continuous-delivery?topic=setup-create-watsonx-api-key). + + +## Running Llama Stack with watsonx + +You can do this via Conda (build code), venv or Docker which has a pre-built image. + +### Via Docker + +This method allows you to get started quickly without having to build the distribution code. + +```bash +LLAMA_STACK_PORT=5001 +docker run \ + -it \ + -p $LLAMA_STACK_PORT:$LLAMA_STACK_PORT \ + -v ./run.yaml:/root/my-run.yaml \ + llamastack/distribution-{{ name }} \ + --yaml-config /root/my-run.yaml \ + --port $LLAMA_STACK_PORT \ + --env WATSONX_API_KEY=$WATSONX_API_KEY \ + --env WATSONX_PROJECT_ID=$WATSONX_PROJECT_ID \ + --env WATSONX_BASE_URL=$WATSONX_BASE_URL +``` + +### Via Conda + +```bash +llama stack build --template watsonx --image-type conda +llama stack run ./run.yaml \ + --port $LLAMA_STACK_PORT \ + --env WATSONX_API_KEY=$WATSONX_API_KEY \ + --env WATSONX_PROJECT_ID=$WATSONX_PROJECT_ID +``` diff --git a/llama_stack/templates/watsonx/run.yaml b/llama_stack/templates/watsonx/run.yaml new file mode 100644 index 000000000..1048f7192 --- /dev/null +++ b/llama_stack/templates/watsonx/run.yaml @@ -0,0 +1,210 @@ +version: '2' +image_name: watsonx +apis: +- agents +- datasetio +- eval +- inference +- safety +- scoring +- telemetry +- tool_runtime +- vector_io +providers: + inference: + - provider_id: watsonx + provider_type: remote::watsonx + config: + url: ${env.WATSONX_BASE_URL:https://us-south.ml.cloud.ibm.com} + api_key: ${env.WATSONX_API_KEY:} + project_id: ${env.WATSONX_PROJECT_ID:} + vector_io: + - provider_id: faiss + provider_type: inline::faiss + config: + kvstore: + type: sqlite + namespace: null + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/watsonx}/faiss_store.db + safety: + - provider_id: llama-guard + provider_type: inline::llama-guard + config: + excluded_categories: [] + agents: + - provider_id: meta-reference + provider_type: inline::meta-reference + config: + persistence_store: + type: sqlite + namespace: null + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/watsonx}/agents_store.db + telemetry: + - provider_id: meta-reference + provider_type: inline::meta-reference + config: + service_name: "${env.OTEL_SERVICE_NAME:\u200B}" + sinks: ${env.TELEMETRY_SINKS:console,sqlite} + sqlite_db_path: ${env.SQLITE_DB_PATH:~/.llama/distributions/watsonx/trace_store.db} + eval: + - provider_id: meta-reference + provider_type: inline::meta-reference + config: + kvstore: + type: sqlite + namespace: null + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/watsonx}/meta_reference_eval.db + datasetio: + - provider_id: huggingface + provider_type: remote::huggingface + config: + kvstore: + type: sqlite + namespace: null + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/watsonx}/huggingface_datasetio.db + - provider_id: localfs + provider_type: inline::localfs + config: + 
kvstore: + type: sqlite + namespace: null + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/watsonx}/localfs_datasetio.db + scoring: + - provider_id: basic + provider_type: inline::basic + config: {} + - provider_id: llm-as-judge + provider_type: inline::llm-as-judge + config: {} + - provider_id: braintrust + provider_type: inline::braintrust + config: + openai_api_key: ${env.OPENAI_API_KEY:} + tool_runtime: + - provider_id: brave-search + provider_type: remote::brave-search + config: + api_key: ${env.BRAVE_SEARCH_API_KEY:} + max_results: 3 + - provider_id: tavily-search + provider_type: remote::tavily-search + config: + api_key: ${env.TAVILY_SEARCH_API_KEY:} + max_results: 3 + - provider_id: code-interpreter + provider_type: inline::code-interpreter + config: {} + - provider_id: rag-runtime + provider_type: inline::rag-runtime + config: {} + - provider_id: model-context-protocol + provider_type: remote::model-context-protocol + config: {} +metadata_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:~/.llama/distributions/watsonx}/registry.db +models: +- metadata: {} + model_id: meta-llama/llama-3-3-70b-instruct + provider_id: watsonx + provider_model_id: meta-llama/llama-3-3-70b-instruct + model_type: llm +- metadata: {} + model_id: meta-llama/Llama-3.3-70B-Instruct + provider_id: watsonx + provider_model_id: meta-llama/llama-3-3-70b-instruct + model_type: llm +- metadata: {} + model_id: meta-llama/llama-2-13b-chat + provider_id: watsonx + provider_model_id: meta-llama/llama-2-13b-chat + model_type: llm +- metadata: {} + model_id: meta-llama/Llama-2-13b + provider_id: watsonx + provider_model_id: meta-llama/llama-2-13b-chat + model_type: llm +- metadata: {} + model_id: meta-llama/llama-3-1-70b-instruct + provider_id: watsonx + provider_model_id: meta-llama/llama-3-1-70b-instruct + model_type: llm +- metadata: {} + model_id: meta-llama/Llama-3.1-70B-Instruct + provider_id: watsonx + provider_model_id: meta-llama/llama-3-1-70b-instruct + model_type: llm +- metadata: {} + model_id: meta-llama/llama-3-1-8b-instruct + provider_id: watsonx + provider_model_id: meta-llama/llama-3-1-8b-instruct + model_type: llm +- metadata: {} + model_id: meta-llama/Llama-3.1-8B-Instruct + provider_id: watsonx + provider_model_id: meta-llama/llama-3-1-8b-instruct + model_type: llm +- metadata: {} + model_id: meta-llama/llama-3-2-11b-vision-instruct + provider_id: watsonx + provider_model_id: meta-llama/llama-3-2-11b-vision-instruct + model_type: llm +- metadata: {} + model_id: meta-llama/Llama-3.2-11B-Vision-Instruct + provider_id: watsonx + provider_model_id: meta-llama/llama-3-2-11b-vision-instruct + model_type: llm +- metadata: {} + model_id: meta-llama/llama-3-2-1b-instruct + provider_id: watsonx + provider_model_id: meta-llama/llama-3-2-1b-instruct + model_type: llm +- metadata: {} + model_id: meta-llama/Llama-3.2-1B-Instruct + provider_id: watsonx + provider_model_id: meta-llama/llama-3-2-1b-instruct + model_type: llm +- metadata: {} + model_id: meta-llama/llama-3-2-3b-instruct + provider_id: watsonx + provider_model_id: meta-llama/llama-3-2-3b-instruct + model_type: llm +- metadata: {} + model_id: meta-llama/Llama-3.2-3B-Instruct + provider_id: watsonx + provider_model_id: meta-llama/llama-3-2-3b-instruct + model_type: llm +- metadata: {} + model_id: meta-llama/llama-3-2-90b-vision-instruct + provider_id: watsonx + provider_model_id: meta-llama/llama-3-2-90b-vision-instruct + model_type: llm +- metadata: {} + model_id: meta-llama/Llama-3.2-90B-Vision-Instruct + provider_id: watsonx + 
provider_model_id: meta-llama/llama-3-2-90b-vision-instruct + model_type: llm +- metadata: {} + model_id: meta-llama/llama-guard-3-11b-vision + provider_id: watsonx + provider_model_id: meta-llama/llama-guard-3-11b-vision + model_type: llm +- metadata: {} + model_id: meta-llama/Llama-Guard-3-11B-Vision + provider_id: watsonx + provider_model_id: meta-llama/llama-guard-3-11b-vision + model_type: llm +shields: [] +vector_dbs: [] +datasets: [] +scoring_fns: [] +benchmarks: [] +tool_groups: +- toolgroup_id: builtin::websearch + provider_id: tavily-search +- toolgroup_id: builtin::rag + provider_id: rag-runtime +- toolgroup_id: builtin::code_interpreter + provider_id: code-interpreter +server: + port: 8321 diff --git a/llama_stack/templates/watsonx/watsonx.py b/llama_stack/templates/watsonx/watsonx.py new file mode 100644 index 000000000..d59bb6f20 --- /dev/null +++ b/llama_stack/templates/watsonx/watsonx.py @@ -0,0 +1,90 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from pathlib import Path + +from llama_stack.distribution.datatypes import Provider, ToolGroupInput +from llama_stack.providers.remote.inference.watsonx import WatsonXConfig +from llama_stack.providers.remote.inference.watsonx.models import MODEL_ENTRIES +from llama_stack.templates.template import DistributionTemplate, RunConfigSettings, get_model_registry + + +def get_distribution_template() -> DistributionTemplate: + providers = { + "inference": ["remote::watsonx"], + "vector_io": ["inline::faiss"], + "safety": ["inline::llama-guard"], + "agents": ["inline::meta-reference"], + "telemetry": ["inline::meta-reference"], + "eval": ["inline::meta-reference"], + "datasetio": ["remote::huggingface", "inline::localfs"], + "scoring": ["inline::basic", "inline::llm-as-judge", "inline::braintrust"], + "tool_runtime": [ + "remote::brave-search", + "remote::tavily-search", + "inline::code-interpreter", + "inline::rag-runtime", + "remote::model-context-protocol", + ], + } + + inference_provider = Provider( + provider_id="watsonx", + provider_type="remote::watsonx", + config=WatsonXConfig.sample_run_config(), + ) + + available_models = { + "watsonx": MODEL_ENTRIES, + } + default_tool_groups = [ + ToolGroupInput( + toolgroup_id="builtin::websearch", + provider_id="tavily-search", + ), + ToolGroupInput( + toolgroup_id="builtin::rag", + provider_id="rag-runtime", + ), + ToolGroupInput( + toolgroup_id="builtin::code_interpreter", + provider_id="code-interpreter", + ), + ] + + default_models = get_model_registry(available_models) + return DistributionTemplate( + name="watsonx", + distro_type="remote_hosted", + description="Use watsonx for running LLM inference", + container_image=None, + template_path=Path(__file__).parent / "doc_template.md", + providers=providers, + available_models_by_provider=available_models, + run_configs={ + "run.yaml": RunConfigSettings( + provider_overrides={ + "inference": [inference_provider], + }, + default_models=default_models, + default_tool_groups=default_tool_groups, + ), + }, + run_config_env_vars={ + "LLAMASTACK_PORT": ( + "5001", + "Port for the Llama Stack distribution server", + ), + "WATSONX_API_KEY": ( + "", + "watsonx API Key", + ), + "WATSONX_PROJECT_ID": ( + "", + "watsonx Project ID", + ), + }, + ) diff --git a/pyproject.toml b/pyproject.toml index 47d845c30..d661f45fb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -58,7 +58,16 @@ dev = [ 
"ruamel.yaml", # needed for openapi generator ] # These are the dependencies required for running unit tests. -unit = ["sqlite-vec", "openai", "aiosqlite", "aiohttp", "pypdf", "chardet", "qdrant-client"] +unit = [ + "sqlite-vec", + "openai", + "aiosqlite", + "aiohttp", + "pypdf", + "chardet", + "qdrant-client", + "opentelemetry-exporter-otlp-proto-http" +] # These are the core dependencies required for running integration tests. They are shared across all # providers. If a provider requires additional dependencies, please add them to your environment # separately. If you are using "uv" to execute your tests, you can use the "--with" flag to specify extra @@ -265,6 +274,7 @@ exclude = [ "^llama_stack/providers/remote/inference/sample/", "^llama_stack/providers/remote/inference/tgi/", "^llama_stack/providers/remote/inference/together/", + "^llama_stack/providers/remote/inference/watsonx/", "^llama_stack/providers/remote/safety/bedrock/", "^llama_stack/providers/remote/safety/nvidia/", "^llama_stack/providers/remote/safety/sample/", diff --git a/tests/integration/conftest.py b/tests/integration/conftest.py index 22290b519..131219e52 100644 --- a/tests/integration/conftest.py +++ b/tests/integration/conftest.py @@ -10,6 +10,7 @@ import platform import textwrap import time +import pytest from dotenv import load_dotenv from llama_stack.log import get_logger @@ -19,10 +20,29 @@ from .report import Report logger = get_logger(__name__, category="tests") +@pytest.hookimpl(hookwrapper=True) +def pytest_runtest_makereport(item, call): + outcome = yield + report = outcome.get_result() + if report.when == "call": + item.execution_outcome = report.outcome + item.was_xfail = getattr(report, "wasxfail", False) + + def pytest_runtest_teardown(item): - interval_seconds = os.getenv("LLAMA_STACK_TEST_INTERVAL_SECONDS") - if interval_seconds: - time.sleep(float(interval_seconds)) + # Check if the test actually ran and passed or failed, but was not skipped or an expected failure (xfail) + outcome = getattr(item, "execution_outcome", None) + was_xfail = getattr(item, "was_xfail", False) + + name = item.nodeid + if not any(x in name for x in ("inference/", "safety/", "agents/")): + return + + logger.debug(f"Test '{item.nodeid}' outcome was '{outcome}' (xfail={was_xfail})") + if outcome in ("passed", "failed") and not was_xfail: + interval_seconds = os.getenv("LLAMA_STACK_TEST_INTERVAL_SECONDS") + if interval_seconds: + time.sleep(float(interval_seconds)) def pytest_configure(config): diff --git a/tests/integration/inference/test_openai_completion.py b/tests/integration/inference/test_openai_completion.py index 75b53100c..46ec03d2e 100644 --- a/tests/integration/inference/test_openai_completion.py +++ b/tests/integration/inference/test_openai_completion.py @@ -75,19 +75,24 @@ def openai_client(client_with_models): return OpenAI(base_url=base_url, api_key="bar") +@pytest.fixture(params=["openai_client", "llama_stack_client"]) +def compat_client(request): + return request.getfixturevalue(request.param) + + @pytest.mark.parametrize( "test_case", [ "inference:completion:sanity", ], ) -def test_openai_completion_non_streaming(openai_client, client_with_models, text_model_id, test_case): +def test_openai_completion_non_streaming(llama_stack_client, client_with_models, text_model_id, test_case): skip_if_model_doesnt_support_openai_completion(client_with_models, text_model_id) tc = TestCase(test_case) # ollama needs more verbose prompting for some reason here... prompt = "Respond to this question and explain your answer. 
" + tc["content"] - response = openai_client.completions.create( + response = llama_stack_client.completions.create( model=text_model_id, prompt=prompt, stream=False, @@ -103,13 +108,13 @@ def test_openai_completion_non_streaming(openai_client, client_with_models, text "inference:completion:sanity", ], ) -def test_openai_completion_streaming(openai_client, client_with_models, text_model_id, test_case): +def test_openai_completion_streaming(llama_stack_client, client_with_models, text_model_id, test_case): skip_if_model_doesnt_support_openai_completion(client_with_models, text_model_id) tc = TestCase(test_case) # ollama needs more verbose prompting for some reason here... prompt = "Respond to this question and explain your answer. " + tc["content"] - response = openai_client.completions.create( + response = llama_stack_client.completions.create( model=text_model_id, prompt=prompt, stream=True, @@ -127,11 +132,11 @@ def test_openai_completion_streaming(openai_client, client_with_models, text_mod 0, ], ) -def test_openai_completion_prompt_logprobs(openai_client, client_with_models, text_model_id, prompt_logprobs): +def test_openai_completion_prompt_logprobs(llama_stack_client, client_with_models, text_model_id, prompt_logprobs): skip_if_provider_isnt_vllm(client_with_models, text_model_id) prompt = "Hello, world!" - response = openai_client.completions.create( + response = llama_stack_client.completions.create( model=text_model_id, prompt=prompt, stream=False, @@ -144,11 +149,11 @@ def test_openai_completion_prompt_logprobs(openai_client, client_with_models, te assert len(choice.prompt_logprobs) > 0 -def test_openai_completion_guided_choice(openai_client, client_with_models, text_model_id): +def test_openai_completion_guided_choice(llama_stack_client, client_with_models, text_model_id): skip_if_provider_isnt_vllm(client_with_models, text_model_id) prompt = "I am feeling really sad today." 
- response = openai_client.completions.create( + response = llama_stack_client.completions.create( model=text_model_id, prompt=prompt, stream=False, @@ -161,6 +166,9 @@ def test_openai_completion_guided_choice(openai_client, client_with_models, text assert choice.text in ["joy", "sadness"] +# Run the chat-completion tests with both the OpenAI client and the LlamaStack client + + @pytest.mark.parametrize( "test_case", [ @@ -168,13 +176,13 @@ def test_openai_completion_guided_choice(openai_client, client_with_models, text "inference:chat_completion:non_streaming_02", ], ) -def test_openai_chat_completion_non_streaming(openai_client, client_with_models, text_model_id, test_case): +def test_openai_chat_completion_non_streaming(compat_client, client_with_models, text_model_id, test_case): skip_if_model_doesnt_support_openai_chat_completion(client_with_models, text_model_id) tc = TestCase(test_case) question = tc["question"] expected = tc["expected"] - response = openai_client.chat.completions.create( + response = compat_client.chat.completions.create( model=text_model_id, messages=[ { @@ -196,13 +204,13 @@ def test_openai_chat_completion_non_streaming(openai_client, client_with_models, "inference:chat_completion:streaming_02", ], ) -def test_openai_chat_completion_streaming(openai_client, client_with_models, text_model_id, test_case): +def test_openai_chat_completion_streaming(compat_client, client_with_models, text_model_id, test_case): skip_if_model_doesnt_support_openai_chat_completion(client_with_models, text_model_id) tc = TestCase(test_case) question = tc["question"] expected = tc["expected"] - response = openai_client.chat.completions.create( + response = compat_client.chat.completions.create( model=text_model_id, messages=[{"role": "user", "content": question}], stream=True, diff --git a/tests/integration/tool_runtime/test_registration.py b/tests/integration/tool_runtime/test_registration.py index e4241d813..b36237d05 100644 --- a/tests/integration/tool_runtime/test_registration.py +++ b/tests/integration/tool_runtime/test_registration.py @@ -114,7 +114,7 @@ def test_register_and_unregister_toolgroup(llama_stack_client, mcp_server): llama_stack_client.toolgroups.unregister(toolgroup_id=test_toolgroup_id) # Verify it is unregistered - with pytest.raises(ValueError, match=f"Tool group '{test_toolgroup_id}' not found"): + with pytest.raises(Exception, match=f"Tool group '{test_toolgroup_id}' not found"): llama_stack_client.toolgroups.get(toolgroup_id=test_toolgroup_id) # Verify tools are also unregistered diff --git a/tests/unit/distribution/test_build_path.py b/tests/unit/distribution/test_build_path.py index a913bd88b..555cdda4a 100644 --- a/tests/unit/distribution/test_build_path.py +++ b/tests/unit/distribution/test_build_path.py @@ -16,8 +16,9 @@ from llama_stack.distribution.utils.image_types import LlamaStackImageType def test_container_build_passes_path(monkeypatch, tmp_path): called_with = {} - def spy_build_image(cfg, build_file_path, image_name, template_or_config): + def spy_build_image(cfg, build_file_path, image_name, template_or_config, run_config=None): called_with["path"] = template_or_config + called_with["run_config"] = run_config return 0 monkeypatch.setattr( @@ -36,3 +37,4 @@ def test_container_build_passes_path(monkeypatch, tmp_path): assert "path" in called_with assert isinstance(called_with["path"], str) assert Path(called_with["path"]).exists() + assert called_with["run_config"] is None diff --git a/tests/unit/providers/inference/test_remote_vllm.py 
b/tests/unit/providers/inference/test_remote_vllm.py index 88399198d..b3172cad4 100644 --- a/tests/unit/providers/inference/test_remote_vllm.py +++ b/tests/unit/providers/inference/test_remote_vllm.py @@ -28,12 +28,15 @@ from openai.types.model import Model as OpenAIModel from llama_stack.apis.inference import ( ChatCompletionRequest, + CompletionMessage, + SystemMessage, ToolChoice, ToolConfig, + ToolResponseMessage, UserMessage, ) from llama_stack.apis.models import Model -from llama_stack.models.llama.datatypes import StopReason +from llama_stack.models.llama.datatypes import StopReason, ToolCall from llama_stack.providers.remote.inference.vllm.config import VLLMInferenceAdapterConfig from llama_stack.providers.remote.inference.vllm.vllm import ( VLLMInferenceAdapter, @@ -135,6 +138,49 @@ async def test_old_vllm_tool_choice(vllm_inference_adapter): assert request.tool_config.tool_choice == ToolChoice.none +@pytest.mark.asyncio +async def test_tool_call_response(vllm_inference_adapter): + """Verify that tool call arguments from a CompletionMessage are correctly converted + into the expected JSON format.""" + + # Patch the call to vllm so we can inspect the arguments sent were correct + with patch.object( + vllm_inference_adapter.client.chat.completions, "create", new_callable=AsyncMock + ) as mock_nonstream_completion: + messages = [ + SystemMessage(content="You are a helpful assistant"), + UserMessage(content="How many?"), + CompletionMessage( + content="", + stop_reason=StopReason.end_of_turn, + tool_calls=[ + ToolCall( + call_id="foo", + tool_name="knowledge_search", + arguments={"query": "How many?"}, + arguments_json='{"query": "How many?"}', + ) + ], + ), + ToolResponseMessage(call_id="foo", content="knowledge_search found 5...."), + ] + await vllm_inference_adapter.chat_completion( + "mock-model", + messages, + stream=False, + tools=[], + tool_config=ToolConfig(tool_choice=ToolChoice.auto), + ) + + assert mock_nonstream_completion.call_args.kwargs["messages"][2]["tool_calls"] == [ + { + "id": "foo", + "type": "function", + "function": {"name": "knowledge_search", "arguments": '{"query": "How many?"}'}, + } + ] + + @pytest.mark.asyncio async def test_tool_call_delta_empty_tool_call_buf(): """ diff --git a/tests/unit/providers/nvidia/test_eval.py b/tests/unit/providers/nvidia/test_eval.py new file mode 100644 index 000000000..584ca2101 --- /dev/null +++ b/tests/unit/providers/nvidia/test_eval.py @@ -0,0 +1,201 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
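The vLLM adapter test above, like the NVIDIA evaluator tests that follow, patches a client call with a mock and then asserts on the captured call arguments. A minimal, generic sketch of that `AsyncMock` pattern — `FakeClient` and `call_backend` are made-up names for illustration, not part of llama-stack — is:

```python
import asyncio
from unittest.mock import AsyncMock, patch

class FakeClient:
    async def create(self, **kwargs):
        raise RuntimeError("real network calls are not expected in unit tests")

async def call_backend(client):
    # Stand-in for an adapter method that forwards kwargs to the client.
    return await client.create(model="m", messages=[{"role": "user", "content": "hi"}])

def test_async_call_is_intercepted():
    client = FakeClient()
    with patch.object(client, "create", new_callable=AsyncMock) as mock_create:
        asyncio.run(call_backend(client))
        # Assertions are made directly on the captured keyword arguments.
        assert mock_create.call_args.kwargs["model"] == "m"
```

The same idea scales to asserting on full request bodies, as the `_assert_request_body` helper in the evaluator tests below does against `call_args`.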
+ +import os +import unittest +from unittest.mock import MagicMock, patch + +import pytest + +from llama_stack.apis.benchmarks import Benchmark +from llama_stack.apis.common.job_types import Job, JobStatus +from llama_stack.apis.eval.eval import BenchmarkConfig, EvaluateResponse, ModelCandidate, SamplingParams +from llama_stack.models.llama.sku_types import CoreModelId +from llama_stack.providers.remote.eval.nvidia.config import NVIDIAEvalConfig +from llama_stack.providers.remote.eval.nvidia.eval import NVIDIAEvalImpl + +MOCK_DATASET_ID = "default/test-dataset" +MOCK_BENCHMARK_ID = "test-benchmark" + + +class TestNVIDIAEvalImpl(unittest.TestCase): + def setUp(self): + os.environ["NVIDIA_EVALUATOR_URL"] = "http://nemo.test" + + # Create mock APIs + self.datasetio_api = MagicMock() + self.datasets_api = MagicMock() + self.scoring_api = MagicMock() + self.inference_api = MagicMock() + self.agents_api = MagicMock() + + self.config = NVIDIAEvalConfig( + evaluator_url=os.environ["NVIDIA_EVALUATOR_URL"], + ) + + self.eval_impl = NVIDIAEvalImpl( + config=self.config, + datasetio_api=self.datasetio_api, + datasets_api=self.datasets_api, + scoring_api=self.scoring_api, + inference_api=self.inference_api, + agents_api=self.agents_api, + ) + + # Mock the HTTP request methods + self.evaluator_get_patcher = patch( + "llama_stack.providers.remote.eval.nvidia.eval.NVIDIAEvalImpl._evaluator_get" + ) + self.evaluator_post_patcher = patch( + "llama_stack.providers.remote.eval.nvidia.eval.NVIDIAEvalImpl._evaluator_post" + ) + + self.mock_evaluator_get = self.evaluator_get_patcher.start() + self.mock_evaluator_post = self.evaluator_post_patcher.start() + + def tearDown(self): + """Clean up after each test.""" + self.evaluator_get_patcher.stop() + self.evaluator_post_patcher.stop() + + def _assert_request_body(self, expected_json): + """Helper method to verify request body in Evaluator POST request is correct""" + call_args = self.mock_evaluator_post.call_args + actual_json = call_args[0][1] + + # Check that all expected keys contain the expected values in the actual JSON + for key, value in expected_json.items(): + assert key in actual_json, f"Key '{key}' missing in actual JSON" + + if isinstance(value, dict): + for nested_key, nested_value in value.items(): + assert nested_key in actual_json[key], f"Nested key '{nested_key}' missing in actual JSON['{key}']" + assert actual_json[key][nested_key] == nested_value, f"Value mismatch for '{key}.{nested_key}'" + else: + assert actual_json[key] == value, f"Value mismatch for '{key}'" + + @pytest.fixture(autouse=True) + def inject_fixtures(self, run_async): + self.run_async = run_async + + def test_register_benchmark(self): + eval_config = { + "type": "custom", + "params": {"parallelism": 8}, + "tasks": { + "qa": { + "type": "completion", + "params": {"template": {"prompt": "{{prompt}}", "max_tokens": 200}}, + "dataset": {"files_url": f"hf://datasets/{MOCK_DATASET_ID}/testing/testing.jsonl"}, + "metrics": {"bleu": {"type": "bleu", "params": {"references": ["{{ideal_response}}"]}}}, + } + }, + } + + benchmark = Benchmark( + provider_id="nvidia", + type="benchmark", + identifier=MOCK_BENCHMARK_ID, + dataset_id=MOCK_DATASET_ID, + scoring_functions=["basic::equality"], + metadata=eval_config, + ) + + # Mock Evaluator API response + mock_evaluator_response = {"id": MOCK_BENCHMARK_ID, "status": "created"} + self.mock_evaluator_post.return_value = mock_evaluator_response + + # Register the benchmark + self.run_async(self.eval_impl.register_benchmark(benchmark)) + + # Verify 
the Evaluator API was called correctly + self.mock_evaluator_post.assert_called_once() + self._assert_request_body({"namespace": benchmark.provider_id, "name": benchmark.identifier, **eval_config}) + + def test_run_eval(self): + benchmark_config = BenchmarkConfig( + eval_candidate=ModelCandidate( + type="model", + model=CoreModelId.llama3_1_8b_instruct.value, + sampling_params=SamplingParams(max_tokens=100, temperature=0.7), + ) + ) + + # Mock Evaluator API response + mock_evaluator_response = {"id": "job-123", "status": "created"} + self.mock_evaluator_post.return_value = mock_evaluator_response + + # Run the Evaluation job + result = self.run_async( + self.eval_impl.run_eval(benchmark_id=MOCK_BENCHMARK_ID, benchmark_config=benchmark_config) + ) + + # Verify the Evaluator API was called correctly + self.mock_evaluator_post.assert_called_once() + self._assert_request_body( + { + "config": f"nvidia/{MOCK_BENCHMARK_ID}", + "target": {"type": "model", "model": "meta/llama-3.1-8b-instruct"}, + } + ) + + # Verify the result + assert isinstance(result, Job) + assert result.job_id == "job-123" + assert result.status == JobStatus.in_progress + + def test_job_status(self): + # Mock Evaluator API response + mock_evaluator_response = {"id": "job-123", "status": "completed"} + self.mock_evaluator_get.return_value = mock_evaluator_response + + # Get the Evaluation job + result = self.run_async(self.eval_impl.job_status(benchmark_id=MOCK_BENCHMARK_ID, job_id="job-123")) + + # Verify the result + assert isinstance(result, Job) + assert result.job_id == "job-123" + assert result.status == JobStatus.completed + + # Verify the API was called correctly + self.mock_evaluator_get.assert_called_once_with(f"/v1/evaluation/jobs/{result.job_id}") + + def test_job_cancel(self): + # Mock Evaluator API response + mock_evaluator_response = {"id": "job-123", "status": "cancelled"} + self.mock_evaluator_post.return_value = mock_evaluator_response + + # Cancel the Evaluation job + self.run_async(self.eval_impl.job_cancel(benchmark_id=MOCK_BENCHMARK_ID, job_id="job-123")) + + # Verify the API was called correctly + self.mock_evaluator_post.assert_called_once_with("/v1/evaluation/jobs/job-123/cancel", {}) + + def test_job_result(self): + # Mock Evaluator API responses + mock_job_status_response = {"id": "job-123", "status": "completed"} + mock_job_results_response = { + "id": "job-123", + "status": "completed", + "results": {MOCK_BENCHMARK_ID: {"score": 0.85, "details": {"accuracy": 0.85, "f1": 0.84}}}, + } + self.mock_evaluator_get.side_effect = [ + mock_job_status_response, # First call to retrieve job + mock_job_results_response, # Second call to retrieve job results + ] + + # Get the Evaluation job results + result = self.run_async(self.eval_impl.job_result(benchmark_id=MOCK_BENCHMARK_ID, job_id="job-123")) + + # Verify the result + assert isinstance(result, EvaluateResponse) + assert MOCK_BENCHMARK_ID in result.scores + assert result.scores[MOCK_BENCHMARK_ID].aggregated_results["results"][MOCK_BENCHMARK_ID]["score"] == 0.85 + + # Verify the API was called correctly + assert self.mock_evaluator_get.call_count == 2 + self.mock_evaluator_get.assert_any_call("/v1/evaluation/jobs/job-123") + self.mock_evaluator_get.assert_any_call("/v1/evaluation/jobs/job-123/results") diff --git a/tests/unit/providers/nvidia/test_parameters.py b/tests/unit/providers/nvidia/test_parameters.py index cb1b92fba..ea12122a0 100644 --- a/tests/unit/providers/nvidia/test_parameters.py +++ b/tests/unit/providers/nvidia/test_parameters.py @@ -10,14 
+10,17 @@ import warnings from unittest.mock import patch import pytest -from llama_stack_client.types.algorithm_config_param import LoraFinetuningConfig -from llama_stack_client.types.post_training_supervised_fine_tune_params import ( - TrainingConfig, - TrainingConfigDataConfig, - TrainingConfigEfficiencyConfig, - TrainingConfigOptimizerConfig, -) +from llama_stack.apis.post_training.post_training import ( + DataConfig, + DatasetFormat, + EfficiencyConfig, + LoraFinetuningConfig, + OptimizerConfig, + OptimizerType, + TrainingConfig, +) +from llama_stack.distribution.library_client import convert_pydantic_to_json_value from llama_stack.providers.remote.post_training.nvidia.post_training import ( NvidiaPostTrainingAdapter, NvidiaPostTrainingConfig, @@ -66,11 +69,8 @@ class TestNvidiaParameters(unittest.TestCase): def test_customizer_parameters_passed(self): """Test scenario 1: When an optional parameter is passed and value is correctly set.""" - custom_adapter_dim = 32 # Different from default of 8 algorithm_config = LoraFinetuningConfig( type="LoRA", - adapter_dim=custom_adapter_dim, - adapter_dropout=0.2, apply_lora_to_mlp=True, apply_lora_to_output=True, alpha=16, @@ -78,8 +78,15 @@ class TestNvidiaParameters(unittest.TestCase): lora_attn_modules=["q_proj", "k_proj", "v_proj", "o_proj"], ) - data_config = TrainingConfigDataConfig(dataset_id="test-dataset", batch_size=16) - optimizer_config = TrainingConfigOptimizerConfig(lr=0.0002) + data_config = DataConfig( + dataset_id="test-dataset", batch_size=16, shuffle=False, data_format=DatasetFormat.instruct + ) + optimizer_config = OptimizerConfig( + optimizer_type=OptimizerType.adam, + lr=0.0002, + weight_decay=0.01, + num_warmup_steps=100, + ) training_config = TrainingConfig( n_epochs=3, data_config=data_config, @@ -95,7 +102,7 @@ class TestNvidiaParameters(unittest.TestCase): model="meta-llama/Llama-3.1-8B-Instruct", checkpoint_dir="", algorithm_config=algorithm_config, - training_config=training_config, + training_config=convert_pydantic_to_json_value(training_config), logger_config={}, hyperparam_search_config={}, ) @@ -114,7 +121,7 @@ class TestNvidiaParameters(unittest.TestCase): self._assert_request_params( { "hyperparameters": { - "lora": {"adapter_dim": custom_adapter_dim, "adapter_dropout": 0.2, "alpha": 16}, + "lora": {"alpha": 16}, "epochs": 3, "learning_rate": 0.0002, "batch_size": 16, @@ -130,8 +137,6 @@ class TestNvidiaParameters(unittest.TestCase): algorithm_config = LoraFinetuningConfig( type="LoRA", - adapter_dim=16, - adapter_dropout=0.1, apply_lora_to_mlp=True, apply_lora_to_output=True, alpha=16, @@ -139,12 +144,16 @@ class TestNvidiaParameters(unittest.TestCase): lora_attn_modules=["q_proj", "k_proj", "v_proj", "o_proj"], ) - data_config = TrainingConfigDataConfig( - dataset_id=required_dataset_id, # Required parameter - batch_size=8, + data_config = DataConfig( + dataset_id=required_dataset_id, batch_size=8, shuffle=False, data_format=DatasetFormat.instruct ) - optimizer_config = TrainingConfigOptimizerConfig(lr=0.0001) + optimizer_config = OptimizerConfig( + optimizer_type=OptimizerType.adam, + lr=0.0001, + weight_decay=0.01, + num_warmup_steps=100, + ) training_config = TrainingConfig( n_epochs=1, @@ -161,7 +170,7 @@ class TestNvidiaParameters(unittest.TestCase): model=required_model, # Required parameter checkpoint_dir="", algorithm_config=algorithm_config, - training_config=training_config, + training_config=convert_pydantic_to_json_value(training_config), logger_config={}, hyperparam_search_config={}, ) @@ 
-186,24 +195,24 @@ class TestNvidiaParameters(unittest.TestCase): def test_unsupported_parameters_warning(self): """Test that warnings are raised for unsupported parameters.""" - data_config = TrainingConfigDataConfig( + data_config = DataConfig( dataset_id="test-dataset", batch_size=8, # Unsupported parameters shuffle=True, - data_format="instruct", + data_format=DatasetFormat.instruct, validation_dataset_id="val-dataset", ) - optimizer_config = TrainingConfigOptimizerConfig( + optimizer_config = OptimizerConfig( lr=0.0001, weight_decay=0.01, # Unsupported parameters - optimizer_type="adam", + optimizer_type=OptimizerType.adam, num_warmup_steps=100, ) - efficiency_config = TrainingConfigEfficiencyConfig( + efficiency_config = EfficiencyConfig( enable_activation_checkpointing=True # Unsupported parameter ) @@ -230,15 +239,13 @@ class TestNvidiaParameters(unittest.TestCase): checkpoint_dir="test-dir", # Unsupported parameter algorithm_config=LoraFinetuningConfig( type="LoRA", - adapter_dim=16, - adapter_dropout=0.1, apply_lora_to_mlp=True, apply_lora_to_output=True, alpha=16, rank=16, lora_attn_modules=["q_proj", "k_proj", "v_proj", "o_proj"], ), - training_config=training_config, + training_config=convert_pydantic_to_json_value(training_config), logger_config={"test": "value"}, # Unsupported parameter hyperparam_search_config={"test": "value"}, # Unsupported parameter ) diff --git a/tests/unit/providers/nvidia/test_supervised_fine_tuning.py b/tests/unit/providers/nvidia/test_supervised_fine_tuning.py index 43e0ac11c..319011be3 100644 --- a/tests/unit/providers/nvidia/test_supervised_fine_tuning.py +++ b/tests/unit/providers/nvidia/test_supervised_fine_tuning.py @@ -10,13 +10,19 @@ import warnings from unittest.mock import patch import pytest -from llama_stack_client.types.algorithm_config_param import LoraFinetuningConfig, QatFinetuningConfig -from llama_stack_client.types.post_training_supervised_fine_tune_params import ( - TrainingConfig, - TrainingConfigDataConfig, - TrainingConfigOptimizerConfig, -) +from llama_stack.apis.models import Model, ModelType +from llama_stack.apis.post_training.post_training import ( + DataConfig, + DatasetFormat, + LoraFinetuningConfig, + OptimizerConfig, + OptimizerType, + QATFinetuningConfig, + TrainingConfig, +) +from llama_stack.distribution.library_client import convert_pydantic_to_json_value +from llama_stack.providers.remote.inference.nvidia.nvidia import NVIDIAConfig, NVIDIAInferenceAdapter from llama_stack.providers.remote.post_training.nvidia.post_training import ( ListNvidiaPostTrainingJobs, NvidiaPostTrainingAdapter, @@ -40,8 +46,22 @@ class TestNvidiaPostTraining(unittest.TestCase): ) self.mock_make_request = self.make_request_patcher.start() + # Mock the inference client + inference_config = NVIDIAConfig(base_url=os.environ["NVIDIA_BASE_URL"], api_key=None) + self.inference_adapter = NVIDIAInferenceAdapter(inference_config) + + self.mock_client = unittest.mock.MagicMock() + self.mock_client.chat.completions.create = unittest.mock.AsyncMock() + self.inference_mock_make_request = self.mock_client.chat.completions.create + self.inference_make_request_patcher = patch( + "llama_stack.providers.remote.inference.nvidia.nvidia.NVIDIAInferenceAdapter._get_client", + return_value=self.mock_client, + ) + self.inference_make_request_patcher.start() + def tearDown(self): self.make_request_patcher.stop() + self.inference_make_request_patcher.stop() @pytest.fixture(autouse=True) def inject_fixtures(self, run_async): @@ -105,7 +125,7 @@ class 
TestNvidiaPostTraining(unittest.TestCase): "batch_size": 16, "epochs": 2, "learning_rate": 0.0001, - "lora": {"adapter_dim": 16, "adapter_dropout": 0.1}, + "lora": {"alpha": 16}, }, "output_model": "default/job-1234", "status": "created", @@ -116,8 +136,6 @@ class TestNvidiaPostTraining(unittest.TestCase): algorithm_config = LoraFinetuningConfig( type="LoRA", - adapter_dim=16, - adapter_dropout=0.1, apply_lora_to_mlp=True, apply_lora_to_output=True, alpha=16, @@ -125,10 +143,15 @@ class TestNvidiaPostTraining(unittest.TestCase): lora_attn_modules=["q_proj", "k_proj", "v_proj", "o_proj"], ) - data_config = TrainingConfigDataConfig(dataset_id="sample-basic-test", batch_size=16) + data_config = DataConfig( + dataset_id="sample-basic-test", batch_size=16, shuffle=False, data_format=DatasetFormat.instruct + ) - optimizer_config = TrainingConfigOptimizerConfig( + optimizer_config = OptimizerConfig( + optimizer_type=OptimizerType.adam, lr=0.0001, + weight_decay=0.01, + num_warmup_steps=100, ) training_config = TrainingConfig( @@ -145,7 +168,7 @@ class TestNvidiaPostTraining(unittest.TestCase): model="meta-llama/Llama-3.1-8B-Instruct", checkpoint_dir="", algorithm_config=algorithm_config, - training_config=training_config, + training_config=convert_pydantic_to_json_value(training_config), logger_config={}, hyperparam_search_config={}, ) @@ -169,16 +192,22 @@ class TestNvidiaPostTraining(unittest.TestCase): "epochs": 2, "batch_size": 16, "learning_rate": 0.0001, - "lora": {"alpha": 16, "adapter_dim": 16, "adapter_dropout": 0.1}, + "weight_decay": 0.01, + "lora": {"alpha": 16}, }, }, ) def test_supervised_fine_tune_with_qat(self): - algorithm_config = QatFinetuningConfig(type="QAT", quantizer_name="quantizer_name", group_size=1) - data_config = TrainingConfigDataConfig(dataset_id="sample-basic-test", batch_size=16) - optimizer_config = TrainingConfigOptimizerConfig( + algorithm_config = QATFinetuningConfig(type="QAT", quantizer_name="quantizer_name", group_size=1) + data_config = DataConfig( + dataset_id="sample-basic-test", batch_size=16, shuffle=False, data_format=DatasetFormat.instruct + ) + optimizer_config = OptimizerConfig( + optimizer_type=OptimizerType.adam, lr=0.0001, + weight_decay=0.01, + num_warmup_steps=100, ) training_config = TrainingConfig( n_epochs=2, @@ -193,7 +222,7 @@ class TestNvidiaPostTraining(unittest.TestCase): model="meta-llama/Llama-3.1-8B-Instruct", checkpoint_dir="", algorithm_config=algorithm_config, - training_config=training_config, + training_config=convert_pydantic_to_json_value(training_config), logger_config={}, hyperparam_search_config={}, ) @@ -303,6 +332,31 @@ class TestNvidiaPostTraining(unittest.TestCase): expected_params={"job_id": job_id}, ) + def test_inference_register_model(self): + model_id = "default/job-1234" + model_type = ModelType.llm + model = Model( + identifier=model_id, + provider_id="nvidia", + provider_model_id=model_id, + provider_resource_id=model_id, + model_type=model_type, + ) + result = self.run_async(self.inference_adapter.register_model(model)) + assert result == model + assert len(self.inference_adapter.alias_to_provider_id_map) > 1 + assert self.inference_adapter.get_provider_model_id(model.provider_model_id) == model_id + + with patch.object(self.inference_adapter, "chat_completion") as mock_chat_completion: + self.run_async( + self.inference_adapter.chat_completion( + model_id=model_id, + messages=[{"role": "user", "content": "Hello, model"}], + ) + ) + + mock_chat_completion.assert_called() + if __name__ == "__main__": 
unittest.main() diff --git a/tests/unit/providers/utils/inference/test_openai_compat.py b/tests/unit/providers/utils/inference/test_openai_compat.py new file mode 100644 index 000000000..eb02f8203 --- /dev/null +++ b/tests/unit/providers/utils/inference/test_openai_compat.py @@ -0,0 +1,43 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +import pytest + +from llama_stack.apis.common.content_types import TextContentItem +from llama_stack.apis.inference.inference import CompletionMessage, UserMessage +from llama_stack.models.llama.datatypes import StopReason, ToolCall +from llama_stack.providers.utils.inference.openai_compat import convert_message_to_openai_dict + + +@pytest.mark.asyncio +async def test_convert_message_to_openai_dict(): + message = UserMessage(content=[TextContentItem(text="Hello, world!")], role="user") + assert await convert_message_to_openai_dict(message) == { + "role": "user", + "content": [{"type": "text", "text": "Hello, world!"}], + } + + +# Test convert_message_to_openai_dict with a tool call +@pytest.mark.asyncio +async def test_convert_message_to_openai_dict_with_tool_call(): + message = CompletionMessage( + content="", + tool_calls=[ + ToolCall(call_id="123", tool_name="test_tool", arguments_json='{"foo": "bar"}', arguments={"foo": "bar"}) + ], + stop_reason=StopReason.end_of_turn, + ) + + openai_dict = await convert_message_to_openai_dict(message) + + assert openai_dict == { + "role": "assistant", + "content": [{"type": "text", "text": ""}], + "tool_calls": [ + {"id": "123", "type": "function", "function": {"name": "test_tool", "arguments": '{"foo": "bar"}'}} + ], + } diff --git a/tests/unit/server/test_sse.py b/tests/unit/server/test_sse.py new file mode 100644 index 000000000..c78122294 --- /dev/null +++ b/tests/unit/server/test_sse.py @@ -0,0 +1,91 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. 
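The SSE tests that follow all construct their input the same way: a coroutine that, once awaited, yields an async generator, matching the "AsyncIterator wrapped in an Awaitable" shape the server's web methods return. A self-contained sketch of producing and consuming such a value (independent of `sse_generator` itself) is:

```python
import asyncio

async def make_event_stream():
    # A coroutine that returns an async generator, mirroring the web methods.
    async def event_gen():
        yield "event 1"
        yield "event 2"
    return event_gen()

async def consume():
    stream = await make_event_stream()  # unwrap the Awaitable first
    return [event async for event in stream]

print(asyncio.run(consume()))  # ['event 1', 'event 2']
```

The cancellation and error cases below exercise the same shape, but raise `asyncio.CancelledError` or a plain `Exception` at different points in the generator's lifetime.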
+ +import asyncio + +import pytest + +from llama_stack.distribution.server.server import create_sse_event, sse_generator + + +@pytest.mark.asyncio +async def test_sse_generator_basic(): + # An AsyncIterator wrapped in an Awaitable, just like our web methods + async def async_event_gen(): + async def event_gen(): + yield "Test event 1" + yield "Test event 2" + + return event_gen() + + sse_gen = sse_generator(async_event_gen()) + assert sse_gen is not None + + # Test that the events are streamed correctly + seen_events = [] + async for event in sse_gen: + seen_events.append(event) + assert len(seen_events) == 2 + assert seen_events[0] == create_sse_event("Test event 1") + assert seen_events[1] == create_sse_event("Test event 2") + + +@pytest.mark.asyncio +async def test_sse_generator_client_disconnected(): + # An AsyncIterator wrapped in an Awaitable, just like our web methods + async def async_event_gen(): + async def event_gen(): + yield "Test event 1" + # Simulate a client disconnect before emitting event 2 + raise asyncio.CancelledError() + + return event_gen() + + sse_gen = sse_generator(async_event_gen()) + assert sse_gen is not None + + seen_events = [] + async for event in sse_gen: + seen_events.append(event) + + # We should see 1 event before the client disconnected + assert len(seen_events) == 1 + assert seen_events[0] == create_sse_event("Test event 1") + + +@pytest.mark.asyncio +async def test_sse_generator_client_disconnected_before_response_starts(): + # Disconnect before the response starts + async def async_event_gen(): + raise asyncio.CancelledError() + + sse_gen = sse_generator(async_event_gen()) + assert sse_gen is not None + + seen_events = [] + async for event in sse_gen: + seen_events.append(event) + + # No events should be seen since the client disconnected immediately + assert len(seen_events) == 0 + + +@pytest.mark.asyncio +async def test_sse_generator_error_before_response_starts(): + # Raise an error before the response starts + async def async_event_gen(): + raise Exception("Test error") + + sse_gen = sse_generator(async_event_gen()) + assert sse_gen is not None + + seen_events = [] + async for event in sse_gen: + seen_events.append(event) + + # We should have 1 error event + assert len(seen_events) == 1 + assert 'data: {"error":' in seen_events[0] diff --git a/tests/verifications/openai_api/fixtures/test_cases/chat_completion.yaml b/tests/verifications/openai_api/fixtures/test_cases/chat_completion.yaml index 1ace76e34..0c9f1fe9e 100644 --- a/tests/verifications/openai_api/fixtures/test_cases/chat_completion.yaml +++ b/tests/verifications/openai_api/fixtures/test_cases/chat_completion.yaml @@ -15,6 +15,52 @@ test_chat_basic: S? role: user output: Saturn +test_chat_input_validation: + test_name: test_chat_input_validation + test_params: + case: + - case_id: "messages_missing" + input: + messages: [] + output: + error: + status_code: 400 + - case_id: "messages_role_invalid" + input: + messages: + - content: Which planet do humans live on? + role: fake_role + output: + error: + status_code: 400 + - case_id: "tool_choice_invalid" + input: + messages: + - content: Which planet do humans live on? + role: user + tool_choice: invalid + output: + error: + status_code: 400 + - case_id: "tool_choice_no_tools" + input: + messages: + - content: Which planet do humans live on? + role: user + tool_choice: required + output: + error: + status_code: 400 + - case_id: "tools_type_invalid" + input: + messages: + - content: Which planet do humans live on? 
+ role: user + tools: + - type: invalid + output: + error: + status_code: 400 test_chat_image: test_name: test_chat_image test_params: diff --git a/tests/verifications/openai_api/test_chat_completion.py b/tests/verifications/openai_api/test_chat_completion.py index 3a311667a..277eaafa3 100644 --- a/tests/verifications/openai_api/test_chat_completion.py +++ b/tests/verifications/openai_api/test_chat_completion.py @@ -12,6 +12,7 @@ from pathlib import Path from typing import Any import pytest +from openai import APIError from pydantic import BaseModel from tests.verifications.openai_api.fixtures.fixtures import ( @@ -136,6 +137,50 @@ def test_chat_streaming_basic(request, openai_client, model, provider, verificat assert case["output"].lower() in content.lower() +@pytest.mark.parametrize( + "case", + chat_completion_test_cases["test_chat_input_validation"]["test_params"]["case"], + ids=case_id_generator, +) +def test_chat_non_streaming_error_handling(request, openai_client, model, provider, verification_config, case): + test_name_base = get_base_test_name(request) + if should_skip_test(verification_config, provider, model, test_name_base): + pytest.skip(f"Skipping {test_name_base} for model {model} on provider {provider} based on config.") + + with pytest.raises(APIError) as e: + openai_client.chat.completions.create( + model=model, + messages=case["input"]["messages"], + stream=False, + tool_choice=case["input"]["tool_choice"] if "tool_choice" in case["input"] else None, + tools=case["input"]["tools"] if "tools" in case["input"] else None, + ) + assert case["output"]["error"]["status_code"] == e.value.status_code + + +@pytest.mark.parametrize( + "case", + chat_completion_test_cases["test_chat_input_validation"]["test_params"]["case"], + ids=case_id_generator, +) +def test_chat_streaming_error_handling(request, openai_client, model, provider, verification_config, case): + test_name_base = get_base_test_name(request) + if should_skip_test(verification_config, provider, model, test_name_base): + pytest.skip(f"Skipping {test_name_base} for model {model} on provider {provider} based on config.") + + with pytest.raises(APIError) as e: + response = openai_client.chat.completions.create( + model=model, + messages=case["input"]["messages"], + stream=True, + tool_choice=case["input"]["tool_choice"] if "tool_choice" in case["input"] else None, + tools=case["input"]["tools"] if "tools" in case["input"] else None, + ) + for _chunk in response: + pass + assert str(case["output"]["error"]["status_code"]) in e.value.message + + @pytest.mark.parametrize( "case", chat_completion_test_cases["test_chat_image"]["test_params"]["case"], diff --git a/uv.lock b/uv.lock index cd82a016c..e6368f131 100644 --- a/uv.lock +++ b/uv.lock @@ -1458,6 +1458,7 @@ unit = [ { name = "aiosqlite" }, { name = "chardet" }, { name = "openai" }, + { name = "opentelemetry-exporter-otlp-proto-http" }, { name = "pypdf" }, { name = "qdrant-client" }, { name = "sqlite-vec" }, @@ -1491,6 +1492,7 @@ requires-dist = [ { name = "openai", marker = "extra == 'test'" }, { name = "openai", marker = "extra == 'unit'" }, { name = "opentelemetry-exporter-otlp-proto-http", marker = "extra == 'test'" }, + { name = "opentelemetry-exporter-otlp-proto-http", marker = "extra == 'unit'" }, { name = "opentelemetry-sdk", marker = "extra == 'test'" }, { name = "pandas", marker = "extra == 'ui'" }, { name = "pillow" },