From ff9d4d8a9dd7d6d62f9c6d030e527c1f86798698 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?S=C3=A9bastien=20Han?= Date: Wed, 16 Jul 2025 13:58:05 +0200 Subject: [PATCH 01/40] ci: do not pull model (#2776) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit the model is now available in the container image Signed-off-by: Sébastien Han --- .github/actions/setup-ollama/action.yml | 2 -- .github/workflows/integration-tests.yml | 1 + 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/.github/actions/setup-ollama/action.yml b/.github/actions/setup-ollama/action.yml index 37a369a9a..bb08520f1 100644 --- a/.github/actions/setup-ollama/action.yml +++ b/.github/actions/setup-ollama/action.yml @@ -7,7 +7,5 @@ runs: shell: bash run: | docker run -d --name ollama -p 11434:11434 docker.io/leseb/ollama-with-models - # TODO: rebuild an ollama image with llama-guard3:1b echo "Verifying Ollama status..." timeout 30 bash -c 'while ! curl -s -L http://127.0.0.1:11434; do sleep 1 && echo "."; done' - docker exec ollama ollama pull llama-guard3:1b diff --git a/.github/workflows/integration-tests.yml b/.github/workflows/integration-tests.yml index 7c00acfb5..a5883daf7 100644 --- a/.github/workflows/integration-tests.yml +++ b/.github/workflows/integration-tests.yml @@ -12,6 +12,7 @@ on: - 'pyproject.toml' - 'requirements.txt' - '.github/workflows/integration-tests.yml' # This workflow + - '.github/actions/setup-ollama/action.yml' schedule: - cron: '0 0 * * *' # Daily at 12 AM UTC workflow_dispatch: From fa1bb9ae002c6e60ff3e0a40b29a2444bedae4e3 Mon Sep 17 00:00:00 2001 From: IAN MILLER <75687988+r3v5@users.noreply.github.com> Date: Wed, 16 Jul 2025 15:09:44 +0100 Subject: [PATCH 02/40] docs: fix typo and link self loop for index.html#running-tests (#2777) # What does this PR do? This PR fixes typo "here here" and self loop link at [https://llama-stack.readthedocs.io/en/latest/contributing/index.html#tests/README.md](https://llama-stack.readthedocs.io/en/latest/contributing/index.html#tests/README.md) Closes #2762 ## Test Plan --- CONTRIBUTING.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/CONTRIBUTING.md b/CONTRIBUTING.md index 304c4dd26..75b29213c 100644 --- a/CONTRIBUTING.md +++ b/CONTRIBUTING.md @@ -112,7 +112,7 @@ uv run pre-commit run --all-files ## Running tests -You can find the Llama Stack testing documentation here [here](tests/README.md). +You can find the Llama Stack testing documentation [here](https://github.com/meta-llama/llama-stack/blob/main/tests/README.md). ## Adding a new dependency to the project From a3e249807bfb6de7638ca7c3cc59a6f1780a49e3 Mon Sep 17 00:00:00 2001 From: Matthew Farrellee Date: Wed, 16 Jul 2025 10:10:04 -0400 Subject: [PATCH 03/40] chore: remove vision model URL workarounds and simplify client creation (#2775) The vision models are now available at the standard URL, so the workaround code has been removed. This also simplifies the codebase by eliminating the need for per-model client caching. - Remove special URL handling for meta/llama-3.2-11b/90b-vision-instruct models - Convert _get_client method to _client property for cleaner API - Remove unnecessary lru_cache decorator and functools import - Simplify client creation logic to use single base URL for all models --- .../remote/inference/nvidia/nvidia.py | 47 +++++-------------- .../nvidia/test_supervised_fine_tuning.py | 3 +- 2 files changed, 15 insertions(+), 35 deletions(-) diff --git a/llama_stack/providers/remote/inference/nvidia/nvidia.py b/llama_stack/providers/remote/inference/nvidia/nvidia.py index 1dd72da3f..f790c2312 100644 --- a/llama_stack/providers/remote/inference/nvidia/nvidia.py +++ b/llama_stack/providers/remote/inference/nvidia/nvidia.py @@ -7,7 +7,6 @@ import logging import warnings from collections.abc import AsyncIterator -from functools import lru_cache from typing import Any from openai import APIConnectionError, AsyncOpenAI, BadRequestError @@ -93,41 +92,21 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper): self._config = config - @lru_cache # noqa: B019 - def _get_client(self, provider_model_id: str) -> AsyncOpenAI: + @property + def _client(self) -> AsyncOpenAI: """ - For hosted models, https://integrate.api.nvidia.com/v1 is the primary base_url. However, - some models are hosted on different URLs. This function returns the appropriate client - for the given provider_model_id. + Returns an OpenAI client for the configured NVIDIA API endpoint. - This relies on lru_cache and self._default_client to avoid creating a new client for each request - or for each model that is hosted on https://integrate.api.nvidia.com/v1. - - :param provider_model_id: The provider model ID :return: An OpenAI client """ - @lru_cache # noqa: B019 - def _get_client_for_base_url(base_url: str) -> AsyncOpenAI: - """ - Maintain a single OpenAI client per base_url. - """ - return AsyncOpenAI( - base_url=base_url, - api_key=(self._config.api_key.get_secret_value() if self._config.api_key else "NO KEY"), - timeout=self._config.timeout, - ) - - special_model_urls = { - "meta/llama-3.2-11b-vision-instruct": "https://ai.api.nvidia.com/v1/gr/meta/llama-3.2-11b-vision-instruct", - "meta/llama-3.2-90b-vision-instruct": "https://ai.api.nvidia.com/v1/gr/meta/llama-3.2-90b-vision-instruct", - } - base_url = f"{self._config.url}/v1" if self._config.append_api_version else self._config.url - if _is_nvidia_hosted(self._config) and provider_model_id in special_model_urls: - base_url = special_model_urls[provider_model_id] - return _get_client_for_base_url(base_url) + return AsyncOpenAI( + base_url=base_url, + api_key=(self._config.api_key.get_secret_value() if self._config.api_key else "NO KEY"), + timeout=self._config.timeout, + ) async def _get_provider_model_id(self, model_id: str) -> str: if not self.model_store: @@ -169,7 +148,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper): ) try: - response = await self._get_client(provider_model_id).completions.create(**request) + response = await self._client.completions.create(**request) except APIConnectionError as e: raise ConnectionError(f"Failed to connect to NVIDIA NIM at {self._config.url}: {e}") from e @@ -222,7 +201,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper): extra_body["input_type"] = task_type_options[task_type] try: - response = await self._get_client(provider_model_id).embeddings.create( + response = await self._client.embeddings.create( model=provider_model_id, input=input, extra_body=extra_body, @@ -283,7 +262,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper): ) try: - response = await self._get_client(provider_model_id).chat.completions.create(**request) + response = await self._client.chat.completions.create(**request) except APIConnectionError as e: raise ConnectionError(f"Failed to connect to NVIDIA NIM at {self._config.url}: {e}") from e @@ -339,7 +318,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper): ) try: - return await self._get_client(provider_model_id).completions.create(**params) + return await self._client.completions.create(**params) except APIConnectionError as e: raise ConnectionError(f"Failed to connect to NVIDIA NIM at {self._config.url}: {e}") from e @@ -398,7 +377,7 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper): ) try: - return await self._get_client(provider_model_id).chat.completions.create(**params) + return await self._client.chat.completions.create(**params) except APIConnectionError as e: raise ConnectionError(f"Failed to connect to NVIDIA NIM at {self._config.url}: {e}") from e diff --git a/tests/unit/providers/nvidia/test_supervised_fine_tuning.py b/tests/unit/providers/nvidia/test_supervised_fine_tuning.py index 97ca02fba..f75b0add9 100644 --- a/tests/unit/providers/nvidia/test_supervised_fine_tuning.py +++ b/tests/unit/providers/nvidia/test_supervised_fine_tuning.py @@ -54,7 +54,8 @@ class TestNvidiaPostTraining(unittest.TestCase): self.mock_client.chat.completions.create = unittest.mock.AsyncMock() self.inference_mock_make_request = self.mock_client.chat.completions.create self.inference_make_request_patcher = patch( - "llama_stack.providers.remote.inference.nvidia.nvidia.NVIDIAInferenceAdapter._get_client", + "llama_stack.providers.remote.inference.nvidia.nvidia.NVIDIAInferenceAdapter._client", + new_callable=unittest.mock.PropertyMock, return_value=self.mock_client, ) self.inference_make_request_patcher.start() From 3165197b75d84eb1e4b349cd93e911279982afab Mon Sep 17 00:00:00 2001 From: Nathan Weinberg <31703736+nathan-weinberg@users.noreply.github.com> Date: Wed, 16 Jul 2025 10:12:26 -0400 Subject: [PATCH 04/40] chore: remove 'gha_workflow_llama_stack_tests.yml' (#2767) This was introduced in https://github.com/meta-llama/llama-stack/pull/523 but as far as I can tell has never been used. It's been over six months so it feels fair to remove it at this point. Signed-off-by: Nathan Weinberg --- .../gha_workflow_llama_stack_tests.yml | 355 ------------------ 1 file changed, 355 deletions(-) delete mode 100644 .github/workflows/gha_workflow_llama_stack_tests.yml diff --git a/.github/workflows/gha_workflow_llama_stack_tests.yml b/.github/workflows/gha_workflow_llama_stack_tests.yml deleted file mode 100644 index 9eae291e9..000000000 --- a/.github/workflows/gha_workflow_llama_stack_tests.yml +++ /dev/null @@ -1,355 +0,0 @@ -name: "Run Llama-stack Tests" - -on: - #### Temporarily disable PR runs until tests run as intended within mainline. - #TODO Add this back. - #pull_request_target: - # types: ["opened"] - # branches: - # - 'main' - # paths: - # - 'llama_stack/**/*.py' - # - 'tests/**/*.py' - - workflow_dispatch: - inputs: - runner: - description: 'GHA Runner Scale Set label to run workflow on.' - required: true - default: "llama-stack-gha-runner-gpu" - - checkout_reference: - description: "The branch, tag, or SHA to checkout" - required: true - default: "main" - - debug: - description: 'Run debugging steps?' - required: false - default: "true" - - sleep_time: - description: '[DEBUG] sleep time for debugging' - required: true - default: "0" - - provider_id: - description: 'ID of your provider' - required: true - default: "meta_reference" - - model_id: - description: 'Shorthand name for target model ID (llama_3b or llama_8b)' - required: true - default: "llama_3b" - - model_override_3b: - description: 'Specify shorthand model for ' - required: false - default: "Llama3.2-3B-Instruct" - - model_override_8b: - description: 'Specify shorthand model for ' - required: false - default: "Llama3.1-8B-Instruct" - -env: - # ID used for each test's provider config - PROVIDER_ID: "${{ inputs.provider_id || 'meta_reference' }}" - - # Path to model checkpoints within EFS volume - MODEL_CHECKPOINT_DIR: "/data/llama" - - # Path to directory to run tests from - TESTS_PATH: "${{ github.workspace }}/llama_stack/providers/tests" - - # Keep track of a list of model IDs that are valid to use within pytest fixture marks - AVAILABLE_MODEL_IDs: "llama_3b llama_8b" - - # Shorthand name for model ID, used in pytest fixture marks - MODEL_ID: "${{ inputs.model_id || 'llama_3b' }}" - - # Override the `llama_3b` / `llama_8b' models, else use the default. - LLAMA_3B_OVERRIDE: "${{ inputs.model_override_3b || 'Llama3.2-3B-Instruct' }}" - LLAMA_8B_OVERRIDE: "${{ inputs.model_override_8b || 'Llama3.1-8B-Instruct' }}" - - # Defines which directories in TESTS_PATH to exclude from the test loop - EXCLUDED_DIRS: "__pycache__" - - # Defines the output xml reports generated after a test is run - REPORTS_GEN: "" - -jobs: - execute_workflow: - name: Execute workload on Self-Hosted GPU k8s runner - permissions: - pull-requests: write - defaults: - run: - shell: bash - runs-on: ${{ inputs.runner != '' && inputs.runner || 'llama-stack-gha-runner-gpu' }} - if: always() - steps: - - ############################## - #### INITIAL DEBUG CHECKS #### - ############################## - - name: "[DEBUG] Check content of the EFS mount" - id: debug_efs_volume - continue-on-error: true - if: inputs.debug == 'true' - run: | - echo "========= Content of the EFS mount =============" - ls -la ${{ env.MODEL_CHECKPOINT_DIR }} - - - name: "[DEBUG] Get runner container OS information" - id: debug_os_info - if: ${{ inputs.debug == 'true' }} - run: | - cat /etc/os-release - - - name: "[DEBUG] Print environment variables" - id: debug_env_vars - if: ${{ inputs.debug == 'true' }} - run: | - echo "PROVIDER_ID = ${PROVIDER_ID}" - echo "MODEL_CHECKPOINT_DIR = ${MODEL_CHECKPOINT_DIR}" - echo "AVAILABLE_MODEL_IDs = ${AVAILABLE_MODEL_IDs}" - echo "MODEL_ID = ${MODEL_ID}" - echo "LLAMA_3B_OVERRIDE = ${LLAMA_3B_OVERRIDE}" - echo "LLAMA_8B_OVERRIDE = ${LLAMA_8B_OVERRIDE}" - echo "EXCLUDED_DIRS = ${EXCLUDED_DIRS}" - echo "REPORTS_GEN = ${REPORTS_GEN}" - - ############################ - #### MODEL INPUT CHECKS #### - ############################ - - - name: "Check if env.model_id is valid" - id: check_model_id - run: | - if [[ " ${AVAILABLE_MODEL_IDs[@]} " =~ " ${MODEL_ID} " ]]; then - echo "Model ID '${MODEL_ID}' is valid." - else - echo "Model ID '${MODEL_ID}' is invalid. Terminating workflow." - exit 1 - fi - - ####################### - #### CODE CHECKOUT #### - ####################### - - name: "Checkout 'meta-llama/llama-stack' repository" - id: checkout_repo - uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - with: - ref: ${{ inputs.branch }} - - - name: "[DEBUG] Content of the repository after checkout" - id: debug_content_after_checkout - if: ${{ inputs.debug == 'true' }} - run: | - ls -la ${GITHUB_WORKSPACE} - - ########################################################## - #### OPTIONAL SLEEP DEBUG #### - # # - # Use to "exec" into the test k8s POD and run tests # - # manually to identify what dependencies are being used. # - # # - ########################################################## - - name: "[DEBUG] sleep" - id: debug_sleep - if: ${{ inputs.debug == 'true' && inputs.sleep_time != '' }} - run: | - sleep ${{ inputs.sleep_time }} - - ############################ - #### UPDATE SYSTEM PATH #### - ############################ - - name: "Update path: execute" - id: path_update_exec - run: | - # .local/bin is needed for certain libraries installed below to be recognized - # when calling their executable to install sub-dependencies - mkdir -p ${HOME}/.local/bin - echo "${HOME}/.local/bin" >> "$GITHUB_PATH" - - ##################################### - #### UPDATE CHECKPOINT DIRECTORY #### - ##################################### - - name: "Update checkpoint directory" - id: checkpoint_update - run: | - echo "Checkpoint directory: ${MODEL_CHECKPOINT_DIR}/$LLAMA_3B_OVERRIDE" - if [ "${MODEL_ID}" = "llama_3b" ] && [ -d "${MODEL_CHECKPOINT_DIR}/${LLAMA_3B_OVERRIDE}" ]; then - echo "MODEL_CHECKPOINT_DIR=${MODEL_CHECKPOINT_DIR}/${LLAMA_3B_OVERRIDE}" >> "$GITHUB_ENV" - elif [ "${MODEL_ID}" = "llama_8b" ] && [ -d "${MODEL_CHECKPOINT_DIR}/${LLAMA_8B_OVERRIDE}" ]; then - echo "MODEL_CHECKPOINT_DIR=${MODEL_CHECKPOINT_DIR}/${LLAMA_8B_OVERRIDE}" >> "$GITHUB_ENV" - else - echo "MODEL_ID & LLAMA_*B_OVERRIDE are not a valid pairing. Terminating workflow." - exit 1 - fi - - - name: "[DEBUG] Checkpoint update check" - id: debug_checkpoint_update - if: ${{ inputs.debug == 'true' }} - run: | - echo "MODEL_CHECKPOINT_DIR (after update) = ${MODEL_CHECKPOINT_DIR}" - - ################################## - #### DEPENDENCY INSTALLATIONS #### - ################################## - - name: "Installing 'apt' required packages" - id: install_apt - run: | - echo "[STEP] Installing 'apt' required packages" - sudo apt update -y - sudo apt install -y python3 python3-pip npm wget - - - name: "Installing packages with 'curl'" - id: install_curl - run: | - curl -fsSL https://ollama.com/install.sh | sh - - - name: "Installing packages with 'wget'" - id: install_wget - run: | - wget https://repo.continuum.io/miniconda/Miniconda3-latest-Linux-x86_64.sh - chmod +x Miniconda3-latest-Linux-x86_64.sh - ./Miniconda3-latest-Linux-x86_64.sh -b install -c pytorch -c nvidia faiss-gpu=1.9.0 - # Add miniconda3 bin to system path - echo "${HOME}/miniconda3/bin" >> "$GITHUB_PATH" - - - name: "Installing packages with 'npm'" - id: install_npm_generic - run: | - sudo npm install -g junit-merge - - - name: "Installing pip dependencies" - id: install_pip_generic - run: | - echo "[STEP] Installing 'llama-stack' models" - pip install -U pip setuptools - pip install -r requirements.txt - pip install -e . - pip install -U \ - torch torchvision \ - pytest pytest_asyncio \ - fairscale lm-format-enforcer \ - zmq chardet pypdf \ - pandas sentence_transformers together \ - aiosqlite - - name: "Installing packages with conda" - id: install_conda_generic - run: | - conda install -q -c pytorch -c nvidia faiss-gpu=1.9.0 - - ############################################################# - #### TESTING TO BE DONE FOR BOTH PRS AND MANUAL DISPATCH #### - ############################################################# - - name: "Run Tests: Loop" - id: run_tests_loop - working-directory: "${{ github.workspace }}" - run: | - pattern="" - for dir in llama_stack/providers/tests/*; do - if [ -d "$dir" ]; then - dir_name=$(basename "$dir") - if [[ ! " $EXCLUDED_DIRS " =~ " $dir_name " ]]; then - for file in "$dir"/test_*.py; do - test_name=$(basename "$file") - new_file="result-${dir_name}-${test_name}.xml" - if torchrun $(which pytest) -s -v ${TESTS_PATH}/${dir_name}/${test_name} -m "${PROVIDER_ID} and ${MODEL_ID}" \ - --junitxml="${{ github.workspace }}/${new_file}"; then - echo "Ran test: ${test_name}" - else - echo "Did NOT run test: ${test_name}" - fi - pattern+="${new_file} " - done - fi - fi - done - echo "REPORTS_GEN=$pattern" >> "$GITHUB_ENV" - - - name: "Test Summary: Merge" - id: test_summary_merge - working-directory: "${{ github.workspace }}" - run: | - echo "Merging the following test result files: ${REPORTS_GEN}" - # Defaults to merging them into 'merged-test-results.xml' - junit-merge ${{ env.REPORTS_GEN }} - - ############################################ - #### AUTOMATIC TESTING ON PULL REQUESTS #### - ############################################ - - #### Run tests #### - - - name: "PR - Run Tests" - id: pr_run_tests - working-directory: "${{ github.workspace }}" - if: github.event_name == 'pull_request_target' - run: | - echo "[STEP] Running PyTest tests at 'GITHUB_WORKSPACE' path: ${GITHUB_WORKSPACE} | path: ${{ github.workspace }}" - # (Optional) Add more tests here. - - # Merge test results with 'merged-test-results.xml' from above. - # junit-merge merged-test-results.xml - - #### Create test summary #### - - - name: "PR - Test Summary" - id: pr_test_summary_create - if: github.event_name == 'pull_request_target' - uses: test-summary/action@31493c76ec9e7aa675f1585d3ed6f1da69269a86 # v2.4 - with: - paths: "${{ github.workspace }}/merged-test-results.xml" - output: test-summary.md - - - name: "PR - Upload Test Summary" - id: pr_test_summary_upload - if: github.event_name == 'pull_request_target' - uses: actions/upload-artifact@ea165f8d65b6e75b540449e92b4886f43607fa02 # v4.6.2 - with: - name: test-summary - path: test-summary.md - - #### Update PR request #### - - - name: "PR - Update comment" - id: pr_update_comment - if: github.event_name == 'pull_request_target' - uses: thollander/actions-comment-pull-request@24bffb9b452ba05a4f3f77933840a6a841d1b32b # v3.0.1 - with: - filePath: test-summary.md - - ######################## - #### MANUAL TESTING #### - ######################## - - #### Run tests #### - - - name: "Manual - Run Tests: Prep" - id: manual_run_tests - working-directory: "${{ github.workspace }}" - if: github.event_name == 'workflow_dispatch' - run: | - echo "[STEP] Running PyTest tests at 'GITHUB_WORKSPACE' path: ${{ github.workspace }}" - - #TODO Use this when collection errors are resolved - # pytest -s -v -m "${PROVIDER_ID} and ${MODEL_ID}" --junitxml="${{ github.workspace }}/merged-test-results.xml" - - # (Optional) Add more tests here. - - # Merge test results with 'merged-test-results.xml' from above. - # junit-merge merged-test-results.xml - - #### Create test summary #### - - - name: "Manual - Test Summary" - id: manual_test_summary - if: always() && github.event_name == 'workflow_dispatch' - uses: test-summary/action@31493c76ec9e7aa675f1585d3ed6f1da69269a86 # v2.4 - with: - paths: "${{ github.workspace }}/merged-test-results.xml" From 72e606355d9dba05142d848bd98ae85a777e7050 Mon Sep 17 00:00:00 2001 From: Nathan Weinberg <31703736+nathan-weinberg@users.noreply.github.com> Date: Wed, 16 Jul 2025 11:24:57 -0400 Subject: [PATCH 05/40] fix: add shutdown function for localfs provider (#2781) # What does this PR do? this was causing an unnessessary logger warning ## Test Plan Run `LLAMA_STACK_DIR=. ENABLE_OLLAMA=ollama OLLAMA_INFERENCE_MODEL=llama3.2:3b llama stack build --template starter --image-type venv --run` and then `Crtl-C` to shutdown Signed-off-by: Nathan Weinberg --- llama_stack/providers/inline/files/localfs/files.py | 3 +++ 1 file changed, 3 insertions(+) diff --git a/llama_stack/providers/inline/files/localfs/files.py b/llama_stack/providers/inline/files/localfs/files.py index 851ce2a6a..bdf8c42c7 100644 --- a/llama_stack/providers/inline/files/localfs/files.py +++ b/llama_stack/providers/inline/files/localfs/files.py @@ -51,6 +51,9 @@ class LocalfsFilesImpl(Files): }, ) + async def shutdown(self) -> None: + pass + def _generate_file_id(self) -> str: """Generate a unique file ID for OpenAI API.""" return f"file-{uuid.uuid4().hex}" From 30be1fd8b7fb454d528647cde7cc12be4e32dba8 Mon Sep 17 00:00:00 2001 From: Sergey Yedrikov <48031344+syedriko@users.noreply.github.com> Date: Wed, 16 Jul 2025 11:25:44 -0400 Subject: [PATCH 06/40] fix: SQLiteVecIndex.create(..., bank_id="test_bank.123") - bank_id with a dot - leads to sqlite3.OperationalError (#2770) (#2771) # What does this PR do? Resolves https://github.com/meta-llama/llama-stack/issues/2770. It replaces characters in SQLite table names that are not alphanumeric or underscores with underscores and quotes the table names with square brackets in SQL statements. Closes #[2770] ## Test Plan I added a ".123" suffix to the bank_id on the following line ``` index = await SQLiteVecIndex.create(dimension=embedding_dimension, db_path=db_path, bank_id="test_bank.123") ``` in tests/unit/providers/vector_io/test_sqlite_vec.py, which, without the fix in place, demonstrates the issue. --- .../inline/vector_io/sqlite_vec/sqlite_vec.py | 41 +++++++++++-------- .../providers/vector_io/test_sqlite_vec.py | 4 +- 2 files changed, 25 insertions(+), 20 deletions(-) diff --git a/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py b/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py index 771ffa607..060b5b15c 100644 --- a/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py +++ b/llama_stack/providers/inline/vector_io/sqlite_vec/sqlite_vec.py @@ -7,6 +7,7 @@ import asyncio import json import logging +import re import sqlite3 import struct from typing import Any @@ -117,6 +118,10 @@ def _rrf_rerank( return rrf_scores +def _make_sql_identifier(name: str) -> str: + return re.sub(r"[^a-zA-Z0-9_]", "_", name) + + class SQLiteVecIndex(EmbeddingIndex): """ An index implementation that stores embeddings in a SQLite virtual table using sqlite-vec. @@ -130,9 +135,9 @@ class SQLiteVecIndex(EmbeddingIndex): self.dimension = dimension self.db_path = db_path self.bank_id = bank_id - self.metadata_table = f"chunks_{bank_id}".replace("-", "_") - self.vector_table = f"vec_chunks_{bank_id}".replace("-", "_") - self.fts_table = f"fts_chunks_{bank_id}".replace("-", "_") + self.metadata_table = _make_sql_identifier(f"chunks_{bank_id}") + self.vector_table = _make_sql_identifier(f"vec_chunks_{bank_id}") + self.fts_table = _make_sql_identifier(f"fts_chunks_{bank_id}") self.kvstore = kvstore @classmethod @@ -148,14 +153,14 @@ class SQLiteVecIndex(EmbeddingIndex): try: # Create the table to store chunk metadata. cur.execute(f""" - CREATE TABLE IF NOT EXISTS {self.metadata_table} ( + CREATE TABLE IF NOT EXISTS [{self.metadata_table}] ( id TEXT PRIMARY KEY, chunk TEXT ); """) # Create the virtual table for embeddings. cur.execute(f""" - CREATE VIRTUAL TABLE IF NOT EXISTS {self.vector_table} + CREATE VIRTUAL TABLE IF NOT EXISTS [{self.vector_table}] USING vec0(embedding FLOAT[{self.dimension}], id TEXT); """) connection.commit() @@ -163,7 +168,7 @@ class SQLiteVecIndex(EmbeddingIndex): # based on query. Implementation of the change on client side will allow passing the search_mode option # during initialization to make it easier to create the table that is required. cur.execute(f""" - CREATE VIRTUAL TABLE IF NOT EXISTS {self.fts_table} + CREATE VIRTUAL TABLE IF NOT EXISTS [{self.fts_table}] USING fts5(id, content); """) connection.commit() @@ -178,9 +183,9 @@ class SQLiteVecIndex(EmbeddingIndex): connection = _create_sqlite_connection(self.db_path) cur = connection.cursor() try: - cur.execute(f"DROP TABLE IF EXISTS {self.metadata_table};") - cur.execute(f"DROP TABLE IF EXISTS {self.vector_table};") - cur.execute(f"DROP TABLE IF EXISTS {self.fts_table};") + cur.execute(f"DROP TABLE IF EXISTS [{self.metadata_table}];") + cur.execute(f"DROP TABLE IF EXISTS [{self.vector_table}];") + cur.execute(f"DROP TABLE IF EXISTS [{self.fts_table}];") connection.commit() finally: cur.close() @@ -212,7 +217,7 @@ class SQLiteVecIndex(EmbeddingIndex): metadata_data = [(chunk.chunk_id, chunk.model_dump_json()) for chunk in batch_chunks] cur.executemany( f""" - INSERT INTO {self.metadata_table} (id, chunk) + INSERT INTO [{self.metadata_table}] (id, chunk) VALUES (?, ?) ON CONFLICT(id) DO UPDATE SET chunk = excluded.chunk; """, @@ -230,7 +235,7 @@ class SQLiteVecIndex(EmbeddingIndex): for chunk, emb in zip(batch_chunks, batch_embeddings, strict=True) ] cur.executemany( - f"INSERT INTO {self.vector_table} (id, embedding) VALUES (?, ?);", + f"INSERT INTO [{self.vector_table}] (id, embedding) VALUES (?, ?);", embedding_data, ) @@ -238,13 +243,13 @@ class SQLiteVecIndex(EmbeddingIndex): fts_data = [(chunk.chunk_id, chunk.content) for chunk in batch_chunks] # DELETE existing entries with same IDs (FTS5 doesn't support ON CONFLICT) cur.executemany( - f"DELETE FROM {self.fts_table} WHERE id = ?;", + f"DELETE FROM [{self.fts_table}] WHERE id = ?;", [(row[0],) for row in fts_data], ) # INSERT new entries cur.executemany( - f"INSERT INTO {self.fts_table} (id, content) VALUES (?, ?);", + f"INSERT INTO [{self.fts_table}] (id, content) VALUES (?, ?);", fts_data, ) @@ -280,8 +285,8 @@ class SQLiteVecIndex(EmbeddingIndex): emb_blob = serialize_vector(emb_list) query_sql = f""" SELECT m.id, m.chunk, v.distance - FROM {self.vector_table} AS v - JOIN {self.metadata_table} AS m ON m.id = v.id + FROM [{self.vector_table}] AS v + JOIN [{self.metadata_table}] AS m ON m.id = v.id WHERE v.embedding MATCH ? AND k = ? ORDER BY v.distance; """ @@ -322,9 +327,9 @@ class SQLiteVecIndex(EmbeddingIndex): cur = connection.cursor() try: query_sql = f""" - SELECT DISTINCT m.id, m.chunk, bm25({self.fts_table}) AS score - FROM {self.fts_table} AS f - JOIN {self.metadata_table} AS m ON m.id = f.id + SELECT DISTINCT m.id, m.chunk, bm25([{self.fts_table}]) AS score + FROM [{self.fts_table}] AS f + JOIN [{self.metadata_table}] AS m ON m.id = f.id WHERE f.content MATCH ? ORDER BY score ASC LIMIT ?; diff --git a/tests/unit/providers/vector_io/test_sqlite_vec.py b/tests/unit/providers/vector_io/test_sqlite_vec.py index a61eeeeca..23c4d6ff6 100644 --- a/tests/unit/providers/vector_io/test_sqlite_vec.py +++ b/tests/unit/providers/vector_io/test_sqlite_vec.py @@ -37,7 +37,7 @@ def loop(): async def sqlite_vec_index(embedding_dimension, tmp_path_factory): temp_dir = tmp_path_factory.getbasetemp() db_path = str(temp_dir / "test_sqlite.db") - index = await SQLiteVecIndex.create(dimension=embedding_dimension, db_path=db_path, bank_id="test_bank") + index = await SQLiteVecIndex.create(dimension=embedding_dimension, db_path=db_path, bank_id="test_bank.123") yield index await index.delete() @@ -110,7 +110,7 @@ async def test_chunk_id_conflict(sqlite_vec_index, sample_chunks, embedding_dime cur = connection.cursor() # Retrieve all chunk IDs to check for duplicates - cur.execute(f"SELECT id FROM {sqlite_vec_index.metadata_table}") + cur.execute(f"SELECT id FROM [{sqlite_vec_index.metadata_table}]") chunk_ids = [row[0] for row in cur.fetchall()] cur.close() connection.close() From 919ee3199bed49a701cea65a757a004e5ae38c9e Mon Sep 17 00:00:00 2001 From: Nathan Weinberg <31703736+nathan-weinberg@users.noreply.github.com> Date: Wed, 16 Jul 2025 12:05:48 -0400 Subject: [PATCH 07/40] docs: add missing bold title to match others (#2782) Signed-off-by: Nathan Weinberg --- docs/source/concepts/architecture.md | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/docs/source/concepts/architecture.md b/docs/source/concepts/architecture.md index 14c10848e..50cc62c7c 100644 --- a/docs/source/concepts/architecture.md +++ b/docs/source/concepts/architecture.md @@ -13,7 +13,7 @@ Llama Stack allows you to build different layers of distributions for your AI wo Building production AI applications today requires solving multiple challenges: -Infrastructure Complexity +**Infrastructure Complexity** - Running large language models efficiently requires specialized infrastructure. - Different deployment scenarios (local development, cloud, edge) need different solutions. From 6c516d391bc2ed2da6f488f2ed444f1cb6b7b738 Mon Sep 17 00:00:00 2001 From: Charlie Doern Date: Wed, 16 Jul 2025 12:44:26 -0400 Subject: [PATCH 08/40] fix: de-clutter `llama stack run` logs (#2783) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit # What does this PR do? currently each disabled provider is printed as a warning, switch to debug. This level of verbosity isn't necessary, especially if we intend to grow the list of providers over time that can be in a single run yaml ## Test Plan before: Screenshot 2025-07-16 at 12 37
18 PM after: Screenshot 2025-07-16 at 12 37 42 PM Signed-off-by: Charlie Doern --- llama_stack/distribution/resolver.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/llama_stack/distribution/resolver.py b/llama_stack/distribution/resolver.py index 46cd1161e..c83218276 100644 --- a/llama_stack/distribution/resolver.py +++ b/llama_stack/distribution/resolver.py @@ -200,7 +200,7 @@ def validate_and_prepare_providers( specs = {} for provider in providers: if not provider.provider_id or provider.provider_id == "__disabled__": - logger.warning(f"Provider `{provider.provider_type}` for API `{api}` is disabled") + logger.debug(f"Provider `{provider.provider_type}` for API `{api}` is disabled") continue validate_provider(provider, api, provider_registry) From b57db11bed112cbad8885d9e9a9c0b17612c5f11 Mon Sep 17 00:00:00 2001 From: IAN MILLER <75687988+r3v5@users.noreply.github.com> Date: Wed, 16 Jul 2025 17:49:38 +0100 Subject: [PATCH 09/40] feat: create dynamic model registration for OpenAI and Llama compat remote inference providers (#2745) # What does this PR do? The purpose of this task is to create a solution that can automatically detect when new models are added, deprecated, or removed by OpenAI and Llama API providers, and automatically update the list of supported models in LLamaStack. This feature is vitally important in order to avoid missing new models and editing the entries manually hence I created automation allowing users to dynamically register: - any models from OpenAI provider available at [https://api.openai.com/v1/models](https://api.openai.com/v1/models) that are not in [https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/remote/inference/openai/models.py](https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/remote/inference/openai/models.py) - any models from Llama API provider available at [https://api.llama.com/v1/models](https://api.llama.com/v1/models) that are not in [https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/remote/inference/llama_openai_compat/models.py](https://github.com/meta-llama/llama-stack/blob/main/llama_stack/providers/remote/inference/llama_openai_compat/models.py) Closes #2504 this PR is dependant on #2710 ## Test Plan 1. Create venv at root llamastack directory: `uv venv .venv --python 3.12 --seed` 2. Activate venv: `source .venv/bin/activate` 3. `uv pip install -e .` 4. Create OpenAI distro modifying run.yaml 5. Build distro: `llama stack build --template starter --image-type venv` 6. Then run LlamaStack, but before navigate to templates/starter folder: `llama stack run run.yaml --image-type venv OPENAI_API_KEY= ENABLE_OPENAI=openai` 7. Then try to register dummy llm that doesn't exist in OpenAI provider: ` llama-stack-client models register ianm/ianllm --provider-model-id=ianllm --provider-id=openai ` You should receive this output - combined list of static config + fetched available models from OpenAI: Screenshot 2025-07-14 at 12 48 50 8. Then register real llm from OpenAI: llama-stack-client models register openai/gpt-4-turbo-preview --provider-model-id=gpt-4-turbo-preview --provider-id=openai Screenshot 2025-07-14 at 13 43 02 Screenshot 2025-07-14 at 13 43 11 We correctly fetched all available models from OpenAI As for Llama API, as a non-US person I don't have access to Llama API Key but I joined wait list. The implementation for Llama is the same as for OpenAI since Llama is openai compatible. So, the response from GET endpoint has the same structure as OpenAI https://llama.developer.meta.com/docs/api/models --- .../inference/llama_openai_compat/llama.py | 37 ++++++++++++++++--- .../remote/inference/openai/openai.py | 23 +++++++++++- .../utils/inference/litellm_openai_mixin.py | 8 ---- pyproject.toml | 1 + requirements.txt | 8 ++++ uv.lock | 19 ++++++++++ 6 files changed, 81 insertions(+), 15 deletions(-) diff --git a/llama_stack/providers/remote/inference/llama_openai_compat/llama.py b/llama_stack/providers/remote/inference/llama_openai_compat/llama.py index 29b5e889a..5f9cb20b2 100644 --- a/llama_stack/providers/remote/inference/llama_openai_compat/llama.py +++ b/llama_stack/providers/remote/inference/llama_openai_compat/llama.py @@ -3,16 +3,17 @@ # # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +import logging -from llama_stack.providers.remote.inference.llama_openai_compat.config import ( - LlamaCompatConfig, -) -from llama_stack.providers.utils.inference.litellm_openai_mixin import ( - LiteLLMOpenAIMixin, -) +from llama_api_client import AsyncLlamaAPIClient, NotFoundError + +from llama_stack.providers.remote.inference.llama_openai_compat.config import LlamaCompatConfig +from llama_stack.providers.utils.inference.litellm_openai_mixin import LiteLLMOpenAIMixin from .models import MODEL_ENTRIES +logger = logging.getLogger(__name__) + class LlamaCompatInferenceAdapter(LiteLLMOpenAIMixin): _config: LlamaCompatConfig @@ -27,8 +28,32 @@ class LlamaCompatInferenceAdapter(LiteLLMOpenAIMixin): ) self.config = config + async def check_model_availability(self, model: str) -> bool: + """ + Check if a specific model is available from Llama API. + + :param model: The model identifier to check. + :return: True if the model is available dynamically, False otherwise. + """ + try: + llama_api_client = self._get_llama_api_client() + retrieved_model = await llama_api_client.models.retrieve(model) + logger.info(f"Model {retrieved_model.id} is available from Llama API") + return True + + except NotFoundError: + logger.error(f"Model {model} is not available from Llama API") + return False + + except Exception as e: + logger.error(f"Failed to check model availability from Llama API: {e}") + return False + async def initialize(self): await super().initialize() async def shutdown(self): await super().shutdown() + + def _get_llama_api_client(self) -> AsyncLlamaAPIClient: + return AsyncLlamaAPIClient(api_key=self.get_api_key(), base_url=self.config.openai_compat_api_base) diff --git a/llama_stack/providers/remote/inference/openai/openai.py b/llama_stack/providers/remote/inference/openai/openai.py index 818883919..7e167f621 100644 --- a/llama_stack/providers/remote/inference/openai/openai.py +++ b/llama_stack/providers/remote/inference/openai/openai.py @@ -8,7 +8,7 @@ import logging from collections.abc import AsyncIterator from typing import Any -from openai import AsyncOpenAI +from openai import AsyncOpenAI, NotFoundError from llama_stack.apis.inference import ( OpenAIChatCompletion, @@ -60,6 +60,27 @@ class OpenAIInferenceAdapter(LiteLLMOpenAIMixin): # litellm specific model names, an abstraction leak. self.is_openai_compat = True + async def check_model_availability(self, model: str) -> bool: + """ + Check if a specific model is available from OpenAI. + + :param model: The model identifier to check. + :return: True if the model is available dynamically, False otherwise. + """ + try: + openai_client = self._get_openai_client() + retrieved_model = await openai_client.models.retrieve(model) + logger.info(f"Model {retrieved_model.id} is available from OpenAI") + return True + + except NotFoundError: + logger.error(f"Model {model} is not available from OpenAI") + return False + + except Exception as e: + logger.error(f"Failed to check model availability from OpenAI: {e}") + return False + async def initialize(self) -> None: await super().initialize() diff --git a/llama_stack/providers/utils/inference/litellm_openai_mixin.py b/llama_stack/providers/utils/inference/litellm_openai_mixin.py index 188e82125..0de267f6c 100644 --- a/llama_stack/providers/utils/inference/litellm_openai_mixin.py +++ b/llama_stack/providers/utils/inference/litellm_openai_mixin.py @@ -13,7 +13,6 @@ from llama_stack.apis.common.content_types import ( InterleavedContent, InterleavedContentItem, ) -from llama_stack.apis.common.errors import UnsupportedModelError from llama_stack.apis.inference import ( ChatCompletionRequest, ChatCompletionResponse, @@ -39,7 +38,6 @@ from llama_stack.apis.inference import ( ToolDefinition, ToolPromptFormat, ) -from llama_stack.apis.models import Model from llama_stack.distribution.request_headers import NeedsRequestProviderData from llama_stack.log import get_logger from llama_stack.providers.utils.inference.model_registry import ModelRegistryHelper @@ -90,12 +88,6 @@ class LiteLLMOpenAIMixin( async def shutdown(self): pass - async def register_model(self, model: Model) -> Model: - model_id = self.get_provider_model_id(model.provider_resource_id) - if model_id is None: - raise UnsupportedModelError(model.provider_resource_id, self.alias_to_provider_id_map.keys()) - return model - def get_litellm_model_name(self, model_id: str) -> str: # users may be using openai/ prefix in their model names. the openai/models.py did this by default. # model_id.startswith("openai/") is for backwards compatibility. diff --git a/pyproject.toml b/pyproject.toml index b557dfb9d..72f3a323f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -29,6 +29,7 @@ dependencies = [ "jinja2>=3.1.6", "jsonschema", "llama-stack-client>=0.2.15", + "llama-api-client>=0.1.2", "openai>=1.66", "prompt-toolkit", "python-dotenv", diff --git a/requirements.txt b/requirements.txt index eb97f7b4c..1106efac5 100644 --- a/requirements.txt +++ b/requirements.txt @@ -13,6 +13,7 @@ annotated-types==0.7.0 anyio==4.8.0 # via # httpx + # llama-api-client # llama-stack-client # openai # starlette @@ -49,6 +50,7 @@ deprecated==1.2.18 # opentelemetry-semantic-conventions distro==1.9.0 # via + # llama-api-client # llama-stack-client # openai ecdsa==0.19.1 @@ -80,6 +82,7 @@ httpcore==1.0.9 # via httpx httpx==0.28.1 # via + # llama-api-client # llama-stack # llama-stack-client # openai @@ -101,6 +104,8 @@ jsonschema==4.23.0 # via llama-stack jsonschema-specifications==2024.10.1 # via jsonschema +llama-api-client==0.1.2 + # via llama-stack llama-stack-client==0.2.15 # via llama-stack markdown-it-py==3.0.0 @@ -165,6 +170,7 @@ pycparser==2.22 ; platform_python_implementation != 'PyPy' pydantic==2.10.6 # via # fastapi + # llama-api-client # llama-stack # llama-stack-client # openai @@ -215,6 +221,7 @@ six==1.17.0 sniffio==1.3.1 # via # anyio + # llama-api-client # llama-stack-client # openai starlette==0.45.3 @@ -239,6 +246,7 @@ typing-extensions==4.12.2 # anyio # fastapi # huggingface-hub + # llama-api-client # llama-stack-client # openai # opentelemetry-sdk diff --git a/uv.lock b/uv.lock index 666cdf21f..7a9c5cab0 100644 --- a/uv.lock +++ b/uv.lock @@ -1268,6 +1268,23 @@ wheels = [ { url = "https://files.pythonhosted.org/packages/2a/f7/67689245f48b9e79bcd2f3a10a3690cb1918fb99fffd5a623ed2496bca66/litellm-1.74.2-py3-none-any.whl", hash = "sha256:29bb555b45128e4cc696e72921a6ec24e97b14e9b69e86eed6f155124ad629b1", size = 8587065 }, ] +[[package]] +name = "llama-api-client" +version = "0.1.2" +source = { registry = "https://pypi.org/simple" } +dependencies = [ + { name = "anyio" }, + { name = "distro" }, + { name = "httpx" }, + { name = "pydantic" }, + { name = "sniffio" }, + { name = "typing-extensions" }, +] +sdist = { url = "https://files.pythonhosted.org/packages/d0/78/875de3a16efd0442718ac47cc27319cd80cc5f38e12298e454e08611acc4/llama_api_client-0.1.2.tar.gz", hash = "sha256:709011f2d506009b1b3b3bceea1c84f2a3a7600df1420fb256e680fcd7251387", size = 113695, upload-time = "2025-06-27T19:56:14.057Z" } +wheels = [ + { url = "https://files.pythonhosted.org/packages/99/08/5d7e6e7e6af5353391376288c200acacebb8e6b156d3636eae598a451673/llama_api_client-0.1.2-py3-none-any.whl", hash = "sha256:8ad6e10726f74b2302bfd766c61c41355a9ecf60f57cde2961882d22af998941", size = 84091, upload-time = "2025-06-27T19:56:12.8Z" }, +] + [[package]] name = "llama-stack" version = "0.2.15" @@ -1283,6 +1300,7 @@ dependencies = [ { name = "huggingface-hub" }, { name = "jinja2" }, { name = "jsonschema" }, + { name = "llama-api-client" }, { name = "llama-stack-client" }, { name = "openai" }, { name = "opentelemetry-exporter-otlp-proto-http" }, @@ -1398,6 +1416,7 @@ requires-dist = [ { name = "jsonschema" }, { name = "llama-stack-client", specifier = ">=0.2.15" }, { name = "llama-stack-client", marker = "extra == 'ui'", specifier = ">=0.2.15" }, + { name = "llama-api-client", specifier = ">=0.1.2" }, { name = "openai", specifier = ">=1.66" }, { name = "opentelemetry-exporter-otlp-proto-http", specifier = ">=1.30.0" }, { name = "opentelemetry-sdk", specifier = ">=1.30.0" }, From 51b179e1c5b2fdc5bedb34be4a863fb5fdde95ef Mon Sep 17 00:00:00 2001 From: ehhuang Date: Wed, 16 Jul 2025 15:07:26 -0700 Subject: [PATCH 10/40] chore: update k8s template (#2786) # What does this PR do? - enables auth - updates to use distribution-starter docker ## Test Plan bash apply.sh --- docs/source/deploying/kubernetes_deployment.md | 13 ++++++++++++- docs/source/distributions/k8s/apply.sh | 18 ++++++++++++++++++ .../distributions/k8s/stack-configmap.yaml | 3 +++ .../distributions/k8s/stack-k8s.yaml.template | 2 +- .../distributions/k8s/stack_run_config.yaml | 3 +++ .../distributions/k8s/ui-k8s.yaml.template | 6 ++++++ 6 files changed, 43 insertions(+), 2 deletions(-) diff --git a/docs/source/deploying/kubernetes_deployment.md b/docs/source/deploying/kubernetes_deployment.md index c8fd075fc..7e9791d8d 100644 --- a/docs/source/deploying/kubernetes_deployment.md +++ b/docs/source/deploying/kubernetes_deployment.md @@ -222,10 +222,21 @@ llama-stack-client --endpoint http://localhost:5000 inference chat-completion -- ## Deploying Llama Stack Server in AWS EKS -We've also provided a script to deploy the Llama Stack server in an AWS EKS cluster. Once you have an [EKS cluster](https://docs.aws.amazon.com/eks/latest/userguide/getting-started.html), you can run the following script to deploy the Llama Stack server. +We've also provided a script to deploy the Llama Stack server in an AWS EKS cluster. + +Prerequisites: +- Set up an [EKS cluster](https://docs.aws.amazon.com/eks/latest/userguide/getting-started.html). +- Create a [Github OAuth app](https://docs.github.com/en/apps/oauth-apps/building-oauth-apps/creating-an-oauth-app) and get the client ID and client secret. + - Set the `Authorization callback URL` to `http:///api/auth/callback/` +Run the following script to deploy the Llama Stack server: ``` +export HF_TOKEN= +export GITHUB_CLIENT_ID= +export GITHUB_CLIENT_SECRET= +export LLAMA_STACK_UI_URL= + cd docs/source/distributions/eks ./apply.sh ``` diff --git a/docs/source/distributions/k8s/apply.sh b/docs/source/distributions/k8s/apply.sh index 7b403d34e..3356da53e 100755 --- a/docs/source/distributions/k8s/apply.sh +++ b/docs/source/distributions/k8s/apply.sh @@ -21,6 +21,24 @@ else exit 1 fi +if [ -z "${GITHUB_CLIENT_ID:-}" ]; then + echo "ERROR: GITHUB_CLIENT_ID not set. You need it for Github login to work. Refer to https://llama-stack.readthedocs.io/en/latest/deploying/index.html#kubernetes-deployment-guide" + exit 1 +fi + +if [ -z "${GITHUB_CLIENT_SECRET:-}" ]; then + echo "ERROR: GITHUB_CLIENT_SECRET not set. You need it for Github login to work. Refer to https://llama-stack.readthedocs.io/en/latest/deploying/index.html#kubernetes-deployment-guide" + exit 1 +fi + +if [ -z "${LLAMA_STACK_UI_URL:-}" ]; then + echo "ERROR: LLAMA_STACK_UI_URL not set. Should be set to the external URL of the UI (excluding port). You need it for Github login to work. Refer to https://llama-stack.readthedocs.io/en/latest/deploying/index.html#kubernetes-deployment-guide" + exit 1 +fi + + + + set -euo pipefail set -x diff --git a/docs/source/distributions/k8s/stack-configmap.yaml b/docs/source/distributions/k8s/stack-configmap.yaml index 129471862..c505cba49 100644 --- a/docs/source/distributions/k8s/stack-configmap.yaml +++ b/docs/source/distributions/k8s/stack-configmap.yaml @@ -122,6 +122,9 @@ data: provider_id: rag-runtime server: port: 8321 + auth: + provider_config: + type: github_token kind: ConfigMap metadata: creationTimestamp: null diff --git a/docs/source/distributions/k8s/stack-k8s.yaml.template b/docs/source/distributions/k8s/stack-k8s.yaml.template index 1cfc63ef5..912445f68 100644 --- a/docs/source/distributions/k8s/stack-k8s.yaml.template +++ b/docs/source/distributions/k8s/stack-k8s.yaml.template @@ -27,7 +27,7 @@ spec: spec: containers: - name: llama-stack - image: llamastack/distribution-remote-vllm:latest + image: llamastack/distribution-starter:latest imagePullPolicy: Always # since we have specified latest instead of a version env: - name: ENABLE_CHROMADB diff --git a/docs/source/distributions/k8s/stack_run_config.yaml b/docs/source/distributions/k8s/stack_run_config.yaml index 23993ca5d..4da1bd8b4 100644 --- a/docs/source/distributions/k8s/stack_run_config.yaml +++ b/docs/source/distributions/k8s/stack_run_config.yaml @@ -119,3 +119,6 @@ tool_groups: provider_id: rag-runtime server: port: 8321 + auth: + provider_config: + type: github_token diff --git a/docs/source/distributions/k8s/ui-k8s.yaml.template b/docs/source/distributions/k8s/ui-k8s.yaml.template index ef1bf0c55..a6859cb86 100644 --- a/docs/source/distributions/k8s/ui-k8s.yaml.template +++ b/docs/source/distributions/k8s/ui-k8s.yaml.template @@ -26,6 +26,12 @@ spec: value: "http://llama-stack-service:8321" - name: LLAMA_STACK_UI_PORT value: "8322" + - name: GITHUB_CLIENT_ID + value: "${GITHUB_CLIENT_ID}" + - name: GITHUB_CLIENT_SECRET + value: "${GITHUB_CLIENT_SECRET}" + - name: NEXTAUTH_URL + value: "${LLAMA_STACK_UI_URL}:8322" args: - -c - | From c2b64dce5b58a80b69fbdf243858eb71d674e848 Mon Sep 17 00:00:00 2001 From: Derek Higgins Date: Thu, 17 Jul 2025 15:31:30 +0100 Subject: [PATCH 11/40] fix: Move sentence-transformers to the top (#2703) Move sentence-transformers to be the first embedding in the list of models. This ensures it will always be the default and is more consistent then having the default change based on what env variables are available Closes: #2702 ## Test Plan Manually verified Signed-off-by: Derek Higgins --- llama_stack/templates/starter/run.yaml | 10 +++++----- llama_stack/templates/starter/starter.py | 2 +- 2 files changed, 6 insertions(+), 6 deletions(-) diff --git a/llama_stack/templates/starter/run.yaml b/llama_stack/templates/starter/run.yaml index 8e20f5224..27400348a 100644 --- a/llama_stack/templates/starter/run.yaml +++ b/llama_stack/templates/starter/run.yaml @@ -262,6 +262,11 @@ inference_store: type: sqlite db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/inference_store.db models: +- metadata: + embedding_dimension: 384 + model_id: all-MiniLM-L6-v2 + provider_id: ${env.ENABLE_SENTENCE_TRANSFORMERS:=sentence-transformers} + model_type: embedding - metadata: {} model_id: ${env.ENABLE_CEREBRAS:=__disabled__}/llama3.1-8b provider_id: ${env.ENABLE_CEREBRAS:=__disabled__} @@ -1168,11 +1173,6 @@ models: provider_id: ${env.ENABLE_SAMBANOVA:=__disabled__} provider_model_id: sambanova/Meta-Llama-Guard-3-8B model_type: llm -- metadata: - embedding_dimension: 384 - model_id: all-MiniLM-L6-v2 - provider_id: ${env.ENABLE_SENTENCE_TRANSFORMERS:=sentence-transformers} - model_type: embedding shields: - shield_id: ${env.SAFETY_MODEL:=__disabled__} provider_shield_id: ${env.ENABLE_OLLAMA:=__disabled__}/${env.SAFETY_MODEL:=__disabled__} diff --git a/llama_stack/templates/starter/starter.py b/llama_stack/templates/starter/starter.py index f6ca73028..ec6e8fdce 100644 --- a/llama_stack/templates/starter/starter.py +++ b/llama_stack/templates/starter/starter.py @@ -323,7 +323,7 @@ def get_distribution_template() -> DistributionTemplate: "files": [files_provider], "post_training": [post_training_provider], }, - default_models=default_models + [embedding_model], + default_models=[embedding_model] + default_models, default_tool_groups=default_tool_groups, # TODO: add a way to enable/disable shields on the fly default_shields=shields, From 57745101be4333dab8ddd83454662e598410aefc Mon Sep 17 00:00:00 2001 From: Matthew Farrellee Date: Thu, 17 Jul 2025 11:26:57 -0400 Subject: [PATCH 12/40] chore: internal change, make Model.provider_model_id non-optional (#2690) - POST /v1/models accepts optional provider_model_id - ModelsRoutingTable.register_model handler ensures it is non-None, providing a default usage of Model.provider_model_id will no longer need to detect None --- llama_stack/apis/models/models.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/llama_stack/apis/models/models.py b/llama_stack/apis/models/models.py index 36da97e62..2143346d9 100644 --- a/llama_stack/apis/models/models.py +++ b/llama_stack/apis/models/models.py @@ -7,7 +7,7 @@ from enum import StrEnum from typing import Any, Literal, Protocol, runtime_checkable -from pydantic import BaseModel, ConfigDict, Field +from pydantic import BaseModel, ConfigDict, Field, field_validator from llama_stack.apis.resource import Resource, ResourceType from llama_stack.providers.utils.telemetry.trace_protocol import trace_protocol @@ -36,13 +36,21 @@ class Model(CommonModelFields, Resource): return self.identifier @property - def provider_model_id(self) -> str | None: + def provider_model_id(self) -> str: + assert self.provider_resource_id is not None, "Provider resource ID must be set" return self.provider_resource_id model_config = ConfigDict(protected_namespaces=()) model_type: ModelType = Field(default=ModelType.llm) + @field_validator("provider_resource_id") + @classmethod + def validate_provider_resource_id(cls, v): + if v is None: + raise ValueError("provider_resource_id cannot be None") + return v + class ModelInput(CommonModelFields): model_id: str From 477bcd4d092eb2c9e87794b16a3d987bcc7c9351 Mon Sep 17 00:00:00 2001 From: Matthew Farrellee Date: Thu, 17 Jul 2025 15:11:30 -0400 Subject: [PATCH 13/40] feat: allow dynamic model registration for nvidia inference provider (#2726) # What does this PR do? let's users register models available at https://integrate.api.nvidia.com/v1/models that isn't already in llama_stack/providers/remote/inference/nvidia/models.py ## Test Plan 1. run the nvidia distro 2. register a model from https://integrate.api.nvidia.com/v1/models that isn't already know, as of this writing nvidia/llama-3.1-nemotron-ultra-253b-v1 is a good example 3. perform inference w/ the model --- .../remote/inference/nvidia/nvidia.py | 63 +++++-------------- .../nvidia/test_supervised_fine_tuning.py | 8 ++- 2 files changed, 23 insertions(+), 48 deletions(-) diff --git a/llama_stack/providers/remote/inference/nvidia/nvidia.py b/llama_stack/providers/remote/inference/nvidia/nvidia.py index f790c2312..cb7554523 100644 --- a/llama_stack/providers/remote/inference/nvidia/nvidia.py +++ b/llama_stack/providers/remote/inference/nvidia/nvidia.py @@ -9,7 +9,7 @@ import warnings from collections.abc import AsyncIterator from typing import Any -from openai import APIConnectionError, AsyncOpenAI, BadRequestError +from openai import APIConnectionError, AsyncOpenAI, BadRequestError, NotFoundError from llama_stack.apis.common.content_types import ( InterleavedContent, @@ -40,11 +40,7 @@ from llama_stack.apis.inference import ( ToolChoice, ToolConfig, ) -from llama_stack.apis.models import Model, ModelType from llama_stack.models.llama.datatypes import ToolDefinition, ToolPromptFormat -from llama_stack.providers.utils.inference import ( - ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR, -) from llama_stack.providers.utils.inference.model_registry import ( ModelRegistryHelper, ) @@ -92,6 +88,22 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper): self._config = config + async def check_model_availability(self, model: str) -> bool: + """ + Check if a specific model is available. + + :param model: The model identifier to check. + :return: True if the model is available dynamically, False otherwise. + """ + try: + await self._client.models.retrieve(model) + return True + except NotFoundError: + logger.error(f"Model {model} is not available") + except Exception as e: + logger.error(f"Failed to check model availability: {e}") + return False + @property def _client(self) -> AsyncOpenAI: """ @@ -380,44 +392,3 @@ class NVIDIAInferenceAdapter(Inference, ModelRegistryHelper): return await self._client.chat.completions.create(**params) except APIConnectionError as e: raise ConnectionError(f"Failed to connect to NVIDIA NIM at {self._config.url}: {e}") from e - - async def register_model(self, model: Model) -> Model: - """ - Allow non-llama model registration. - - Non-llama model registration: API Catalogue models, post-training models, etc. - client = LlamaStackAsLibraryClient("nvidia") - client.models.register( - model_id="mistralai/mixtral-8x7b-instruct-v0.1", - model_type=ModelType.llm, - provider_id="nvidia", - provider_model_id="mistralai/mixtral-8x7b-instruct-v0.1" - ) - - NOTE: Only supports models endpoints compatible with AsyncOpenAI base_url format. - """ - if model.model_type == ModelType.embedding: - # embedding models are always registered by their provider model id and does not need to be mapped to a llama model - provider_resource_id = model.provider_resource_id - else: - provider_resource_id = self.get_provider_model_id(model.provider_resource_id) - - if provider_resource_id: - model.provider_resource_id = provider_resource_id - else: - llama_model = model.metadata.get("llama_model") - existing_llama_model = self.get_llama_model(model.provider_resource_id) - if existing_llama_model: - if existing_llama_model != llama_model: - raise ValueError( - f"Provider model id '{model.provider_resource_id}' is already registered to a different llama model: '{existing_llama_model}'" - ) - else: - # not llama model - if llama_model in ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR: - self.provider_id_to_llama_model_map[model.provider_resource_id] = ( - ALL_HUGGINGFACE_REPOS_TO_MODEL_DESCRIPTOR[llama_model] - ) - else: - self.alias_to_provider_id_map[model.provider_model_id] = model.provider_model_id - return model diff --git a/tests/unit/providers/nvidia/test_supervised_fine_tuning.py b/tests/unit/providers/nvidia/test_supervised_fine_tuning.py index f75b0add9..bbbb60a30 100644 --- a/tests/unit/providers/nvidia/test_supervised_fine_tuning.py +++ b/tests/unit/providers/nvidia/test_supervised_fine_tuning.py @@ -7,7 +7,7 @@ import os import unittest import warnings -from unittest.mock import patch +from unittest.mock import AsyncMock, patch import pytest @@ -343,7 +343,11 @@ class TestNvidiaPostTraining(unittest.TestCase): provider_resource_id=model_id, model_type=model_type, ) - result = self.run_async(self.inference_adapter.register_model(model)) + + # simulate a NIM where default/job-1234 is an available model + with patch.object(self.inference_adapter, "check_model_availability", new_callable=AsyncMock) as mock_check: + mock_check.return_value = True + result = self.run_async(self.inference_adapter.register_model(model)) assert result == model assert len(self.inference_adapter.alias_to_provider_id_map) > 1 assert self.inference_adapter.get_provider_model_id(model.provider_model_id) == model_id From 73868ce9e30ea60d97c151e9d74e578bf8355599 Mon Sep 17 00:00:00 2001 From: Mustafa Elbehery Date: Fri, 18 Jul 2025 01:20:12 +0200 Subject: [PATCH 14/40] =?UTF-8?q?chore(test):=20migrate=20unit=20tests=20f?= =?UTF-8?q?rom=20unittest=20to=20pytest=20for=20server=20en=E2=80=A6=20(#2?= =?UTF-8?q?795)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This PR replaces unittest with pytest. Part of https://github.com/meta-llama/llama-stack/issues/2680 cc @leseb Signed-off-by: Mustafa Elbehery --- tests/unit/server/test_replace_env_vars.py | 135 +++++++++++---------- 1 file changed, 74 insertions(+), 61 deletions(-) diff --git a/tests/unit/server/test_replace_env_vars.py b/tests/unit/server/test_replace_env_vars.py index 432d6aee5..55817044d 100644 --- a/tests/unit/server/test_replace_env_vars.py +++ b/tests/unit/server/test_replace_env_vars.py @@ -5,73 +5,86 @@ # the root directory of this source tree. import os -import unittest + +import pytest from llama_stack.distribution.stack import replace_env_vars -class TestReplaceEnvVars(unittest.TestCase): - def setUp(self): - # Clear any existing environment variables we'll use in tests - for var in ["TEST_VAR", "EMPTY_VAR", "ZERO_VAR"]: - if var in os.environ: - del os.environ[var] +@pytest.fixture +def setup_env_vars(): + # Clear any existing environment variables we'll use in tests + for var in ["TEST_VAR", "EMPTY_VAR", "ZERO_VAR"]: + if var in os.environ: + del os.environ[var] - # Set up test environment variables - os.environ["TEST_VAR"] = "test_value" - os.environ["EMPTY_VAR"] = "" - os.environ["ZERO_VAR"] = "0" + # Set up test environment variables + os.environ["TEST_VAR"] = "test_value" + os.environ["EMPTY_VAR"] = "" + os.environ["ZERO_VAR"] = "0" - def test_simple_replacement(self): - self.assertEqual(replace_env_vars("${env.TEST_VAR}"), "test_value") + yield - def test_default_value_when_not_set(self): - self.assertEqual(replace_env_vars("${env.NOT_SET:=default}"), "default") - - def test_default_value_when_set(self): - self.assertEqual(replace_env_vars("${env.TEST_VAR:=default}"), "test_value") - - def test_default_value_when_empty(self): - self.assertEqual(replace_env_vars("${env.EMPTY_VAR:=default}"), "default") - - def test_none_value_when_empty(self): - self.assertEqual(replace_env_vars("${env.EMPTY_VAR:=}"), None) - - def test_value_when_set(self): - self.assertEqual(replace_env_vars("${env.TEST_VAR:=}"), "test_value") - - def test_empty_var_no_default(self): - self.assertEqual(replace_env_vars("${env.EMPTY_VAR_NO_DEFAULT:+}"), None) - - def test_conditional_value_when_set(self): - self.assertEqual(replace_env_vars("${env.TEST_VAR:+conditional}"), "conditional") - - def test_conditional_value_when_not_set(self): - self.assertEqual(replace_env_vars("${env.NOT_SET:+conditional}"), None) - - def test_conditional_value_when_empty(self): - self.assertEqual(replace_env_vars("${env.EMPTY_VAR:+conditional}"), None) - - def test_conditional_value_with_zero(self): - self.assertEqual(replace_env_vars("${env.ZERO_VAR:+conditional}"), "conditional") - - def test_mixed_syntax(self): - self.assertEqual( - replace_env_vars("${env.TEST_VAR:=default} and ${env.NOT_SET:+conditional}"), "test_value and " - ) - self.assertEqual( - replace_env_vars("${env.NOT_SET:=default} and ${env.TEST_VAR:+conditional}"), "default and conditional" - ) - - def test_nested_structures(self): - data = { - "key1": "${env.TEST_VAR:=default}", - "key2": ["${env.NOT_SET:=default}", "${env.TEST_VAR:+conditional}"], - "key3": {"nested": "${env.NOT_SET:+conditional}"}, - } - expected = {"key1": "test_value", "key2": ["default", "conditional"], "key3": {"nested": None}} - self.assertEqual(replace_env_vars(data), expected) + # Cleanup after test + for var in ["TEST_VAR", "EMPTY_VAR", "ZERO_VAR"]: + if var in os.environ: + del os.environ[var] -if __name__ == "__main__": - unittest.main() +def test_simple_replacement(setup_env_vars): + assert replace_env_vars("${env.TEST_VAR}") == "test_value" + + +def test_default_value_when_not_set(setup_env_vars): + assert replace_env_vars("${env.NOT_SET:=default}") == "default" + + +def test_default_value_when_set(setup_env_vars): + assert replace_env_vars("${env.TEST_VAR:=default}") == "test_value" + + +def test_default_value_when_empty(setup_env_vars): + assert replace_env_vars("${env.EMPTY_VAR:=default}") == "default" + + +def test_none_value_when_empty(setup_env_vars): + assert replace_env_vars("${env.EMPTY_VAR:=}") is None + + +def test_value_when_set(setup_env_vars): + assert replace_env_vars("${env.TEST_VAR:=}") == "test_value" + + +def test_empty_var_no_default(setup_env_vars): + assert replace_env_vars("${env.EMPTY_VAR_NO_DEFAULT:+}") is None + + +def test_conditional_value_when_set(setup_env_vars): + assert replace_env_vars("${env.TEST_VAR:+conditional}") == "conditional" + + +def test_conditional_value_when_not_set(setup_env_vars): + assert replace_env_vars("${env.NOT_SET:+conditional}") is None + + +def test_conditional_value_when_empty(setup_env_vars): + assert replace_env_vars("${env.EMPTY_VAR:+conditional}") is None + + +def test_conditional_value_with_zero(setup_env_vars): + assert replace_env_vars("${env.ZERO_VAR:+conditional}") == "conditional" + + +def test_mixed_syntax(setup_env_vars): + assert replace_env_vars("${env.TEST_VAR:=default} and ${env.NOT_SET:+conditional}") == "test_value and " + assert replace_env_vars("${env.NOT_SET:=default} and ${env.TEST_VAR:+conditional}") == "default and conditional" + + +def test_nested_structures(setup_env_vars): + data = { + "key1": "${env.TEST_VAR:=default}", + "key2": ["${env.NOT_SET:=default}", "${env.TEST_VAR:+conditional}"], + "key3": {"nested": "${env.NOT_SET:+conditional}"}, + } + expected = {"key1": "test_value", "key2": ["default", "conditional"], "key3": {"nested": None}} + assert replace_env_vars(data) == expected From 3ae4aeb344b0a9977fc0fe83abb8ff09fb0eefa4 Mon Sep 17 00:00:00 2001 From: ehhuang Date: Thu, 17 Jul 2025 16:20:51 -0700 Subject: [PATCH 15/40] test: add some tests for Telemetry API (#2787) # What does this PR do? ## Test Plan ENABLE_OLLAMA=ollama LLAMA_STACK_CONFIG=starter uv run pytest tests/integration/telemetry --text-model="ollama/llama3.2:3b-instruct-fp16" --- tests/integration/telemetry/test_telemetry.py | 194 +++++++++++++++--- 1 file changed, 168 insertions(+), 26 deletions(-) diff --git a/tests/integration/telemetry/test_telemetry.py b/tests/integration/telemetry/test_telemetry.py index c65f87489..9df03da70 100644 --- a/tests/integration/telemetry/test_telemetry.py +++ b/tests/integration/telemetry/test_telemetry.py @@ -5,41 +5,183 @@ # the root directory of this source tree. import time +from datetime import UTC, datetime from uuid import uuid4 import pytest from llama_stack_client import Agent -@pytest.mark.skip(reason="telemetry is not stable") -def test_agent_query_spans(llama_stack_client, text_model_id): +@pytest.fixture(scope="module", autouse=True) +def setup_telemetry_data(llama_stack_client, text_model_id): + """Setup fixture that creates telemetry data before tests run.""" agent = Agent(llama_stack_client, model=text_model_id, instructions="You are a helpful assistant") - session_id = agent.create_session(f"test-session-{uuid4()}") - agent.create_turn( - messages=[ - { - "role": "user", - "content": "Give me a sentence that contains the word: hello", - } - ], - session_id=session_id, - stream=False, + + session_id = agent.create_session(f"test-setup-session-{uuid4()}") + + messages = [ + "What is 2 + 2?", + "Tell me a short joke", + ] + + for msg in messages: + agent.create_turn( + messages=[{"role": "user", "content": msg}], + session_id=session_id, + stream=False, + ) + + for i in range(2): + llama_stack_client.inference.chat_completion( + model_id=text_model_id, messages=[{"role": "user", "content": f"Test trace {i}"}] + ) + + start_time = time.time() + + while time.time() - start_time < 30: + traces = llama_stack_client.telemetry.query_traces(limit=10) + if len(traces) >= 4: + break + time.sleep(1) + + if len(traces) < 4: + pytest.fail(f"Failed to create sufficient telemetry data after 30s. Got {len(traces)} traces.") + + yield + + +def test_query_traces_basic(llama_stack_client): + """Test basic trace querying functionality with proper data validation.""" + all_traces = llama_stack_client.telemetry.query_traces(limit=5) + + assert isinstance(all_traces, list), "Should return a list of traces" + assert len(all_traces) >= 4, "Should have at least 4 traces from setup" + + # Verify trace structure and data quality + first_trace = all_traces[0] + assert hasattr(first_trace, "trace_id"), "Trace should have trace_id" + assert hasattr(first_trace, "start_time"), "Trace should have start_time" + assert hasattr(first_trace, "root_span_id"), "Trace should have root_span_id" + + # Validate trace_id is a valid UUID format + assert isinstance(first_trace.trace_id, str) and len(first_trace.trace_id) > 0, ( + "trace_id should be non-empty string" ) - # Wait for the span to be logged - time.sleep(2) + # Validate start_time format and not in the future + now = datetime.now(UTC) + if isinstance(first_trace.start_time, str): + trace_time = datetime.fromisoformat(first_trace.start_time.replace("Z", "+00:00")) + else: + # start_time is already a datetime object + trace_time = first_trace.start_time + if trace_time.tzinfo is None: + trace_time = trace_time.replace(tzinfo=UTC) - agent_logs = [] + # Ensure trace time is not in the future (but allow any age in the past for persistent test data) + time_diff = (now - trace_time).total_seconds() + assert time_diff >= 0, f"Trace start_time should not be in the future, got {time_diff}s" - for span in llama_stack_client.telemetry.query_spans( - attribute_filters=[ - {"key": "session_id", "op": "eq", "value": session_id}, - ], - attributes_to_return=["input", "output"], - ): - if span.attributes["output"] != "no shields": - agent_logs.append(span.attributes) + # Validate root_span_id exists and is non-empty + assert isinstance(first_trace.root_span_id, str) and len(first_trace.root_span_id) > 0, ( + "root_span_id should be non-empty string" + ) - assert len(agent_logs) == 1 - assert "Give me a sentence that contains the word: hello" in agent_logs[0]["input"] - assert "hello" in agent_logs[0]["output"].lower() + # Test querying specific trace by ID + specific_trace = llama_stack_client.telemetry.get_trace(trace_id=first_trace.trace_id) + assert specific_trace.trace_id == first_trace.trace_id, "Retrieved trace should match requested ID" + assert specific_trace.start_time == first_trace.start_time, "Retrieved trace should have same start_time" + assert specific_trace.root_span_id == first_trace.root_span_id, "Retrieved trace should have same root_span_id" + + # Test pagination with proper validation + recent_traces = llama_stack_client.telemetry.query_traces(limit=3, offset=0) + assert len(recent_traces) <= 3, "Should return at most 3 traces when limit=3" + assert len(recent_traces) >= 1, "Should return at least 1 trace" + + # Verify all traces have required fields + for trace in recent_traces: + assert hasattr(trace, "trace_id") and trace.trace_id, "Each trace should have non-empty trace_id" + assert hasattr(trace, "start_time") and trace.start_time, "Each trace should have non-empty start_time" + assert hasattr(trace, "root_span_id") and trace.root_span_id, "Each trace should have non-empty root_span_id" + + +def test_query_spans_basic(llama_stack_client): + """Test basic span querying functionality with proper validation.""" + spans = llama_stack_client.telemetry.query_spans(attribute_filters=[], attributes_to_return=[]) + + assert isinstance(spans, list), "Should return a list of spans" + assert len(spans) >= 1, "Should have at least one span from setup" + + # Verify span structure and data quality + first_span = spans[0] + required_attrs = ["span_id", "name", "trace_id"] + for attr in required_attrs: + assert hasattr(first_span, attr), f"Span should have {attr} attribute" + assert getattr(first_span, attr), f"Span {attr} should not be empty" + + # Validate span data types and content + assert isinstance(first_span.span_id, str) and len(first_span.span_id) > 0, "span_id should be non-empty string" + assert isinstance(first_span.name, str) and len(first_span.name) > 0, "span name should be non-empty string" + assert isinstance(first_span.trace_id, str) and len(first_span.trace_id) > 0, "trace_id should be non-empty string" + + # Verify span belongs to a valid trace (test with traces we know exist) + all_traces = llama_stack_client.telemetry.query_traces(limit=10) + trace_ids = {t.trace_id for t in all_traces} + if first_span.trace_id in trace_ids: + trace = llama_stack_client.telemetry.get_trace(trace_id=first_span.trace_id) + assert trace is not None, "Should be able to retrieve trace for valid trace_id" + assert trace.trace_id == first_span.trace_id, "Trace ID should match span's trace_id" + + # Test with span filtering and validate results + filtered_spans = llama_stack_client.telemetry.query_spans( + attribute_filters=[{"key": "name", "op": "eq", "value": first_span.name}], + attributes_to_return=["name", "span_id"], + ) + assert isinstance(filtered_spans, list), "Should return a list with span name filter" + + # Validate filtered spans if filtering works + if len(filtered_spans) > 0: + for span in filtered_spans: + assert hasattr(span, "name"), "Filtered spans should have name attribute" + assert hasattr(span, "span_id"), "Filtered spans should have span_id attribute" + assert span.name == first_span.name, "Filtered spans should match the filter criteria" + assert isinstance(span.span_id, str) and len(span.span_id) > 0, "Filtered span_id should be valid" + + # Test that all spans have consistent structure + for span in spans: + for attr in required_attrs: + assert hasattr(span, attr) and getattr(span, attr), f"All spans should have non-empty {attr}" + + +def test_telemetry_pagination(llama_stack_client): + """Test pagination in telemetry queries.""" + # Get total count of traces + all_traces = llama_stack_client.telemetry.query_traces(limit=20) + total_count = len(all_traces) + assert total_count >= 4, "Should have at least 4 traces from setup" + + # Test trace pagination + page1 = llama_stack_client.telemetry.query_traces(limit=2, offset=0) + page2 = llama_stack_client.telemetry.query_traces(limit=2, offset=2) + + assert len(page1) == 2, "First page should have exactly 2 traces" + assert len(page2) >= 1, "Second page should have at least 1 trace" + + # Verify no overlap between pages + page1_ids = {t.trace_id for t in page1} + page2_ids = {t.trace_id for t in page2} + assert len(page1_ids.intersection(page2_ids)) == 0, "Pages should contain different traces" + + # Test ordering + ordered_traces = llama_stack_client.telemetry.query_traces(limit=5, order_by=["start_time"]) + assert len(ordered_traces) >= 4, "Should have at least 4 traces for ordering test" + + # Verify ordering by start_time + for i in range(len(ordered_traces) - 1): + current_time = ordered_traces[i].start_time + next_time = ordered_traces[i + 1].start_time + assert current_time <= next_time, f"Traces should be ordered by start_time: {current_time} > {next_time}" + + # Test limit behavior + limited = llama_stack_client.telemetry.query_traces(limit=3) + assert len(limited) == 3, "Should return exactly 3 traces when limit=3" From bd8a3ae3ccd4d22165faf69e00ed0759d50cd372 Mon Sep 17 00:00:00 2001 From: Mustafa Elbehery Date: Fri, 18 Jul 2025 01:31:38 +0200 Subject: [PATCH 16/40] chore(test): migrate unit tests from unittest to pytest for prompt adapter (#2788) This PR replaces unittest with pytest. Part of https://github.com/meta-llama/llama-stack/issues/2680 cc @leseb Co-authored-by: ehhuang --- tests/unit/models/test_prompt_adapter.py | 501 ++++++++++++----------- 1 file changed, 256 insertions(+), 245 deletions(-) diff --git a/tests/unit/models/test_prompt_adapter.py b/tests/unit/models/test_prompt_adapter.py index 0e2780e50..577496cec 100644 --- a/tests/unit/models/test_prompt_adapter.py +++ b/tests/unit/models/test_prompt_adapter.py @@ -4,14 +4,14 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -import asyncio -import unittest +import pytest from llama_stack.apis.inference import ( ChatCompletionRequest, CompletionMessage, StopReason, SystemMessage, + SystemMessageBehavior, ToolCall, ToolConfig, UserMessage, @@ -25,264 +25,275 @@ from llama_stack.models.llama.datatypes import ( from llama_stack.providers.utils.inference.prompt_adapter import ( chat_completion_request_to_messages, chat_completion_request_to_prompt, + interleaved_content_as_str, ) MODEL = "Llama3.1-8B-Instruct" MODEL3_2 = "Llama3.2-3B-Instruct" -class PrepareMessagesTests(unittest.IsolatedAsyncioTestCase): - async def asyncSetUp(self): - asyncio.get_running_loop().set_debug(False) +@pytest.mark.asyncio +async def test_system_default(): + content = "Hello !" + request = ChatCompletionRequest( + model=MODEL, + messages=[ + UserMessage(content=content), + ], + ) + messages = chat_completion_request_to_messages(request, MODEL) + assert len(messages) == 2 + assert messages[-1].content == content + assert "Cutting Knowledge Date: December 2023" in interleaved_content_as_str(messages[0].content) - async def test_system_default(self): - content = "Hello !" - request = ChatCompletionRequest( - model=MODEL, - messages=[ - UserMessage(content=content), - ], - ) - messages = chat_completion_request_to_messages(request, MODEL) - self.assertEqual(len(messages), 2) - self.assertEqual(messages[-1].content, content) - self.assertTrue("Cutting Knowledge Date: December 2023" in messages[0].content) - async def test_system_builtin_only(self): - content = "Hello !" - request = ChatCompletionRequest( - model=MODEL, - messages=[ - UserMessage(content=content), - ], - tools=[ - ToolDefinition(tool_name=BuiltinTool.code_interpreter), - ToolDefinition(tool_name=BuiltinTool.brave_search), - ], - ) - messages = chat_completion_request_to_messages(request, MODEL) - self.assertEqual(len(messages), 2) - self.assertEqual(messages[-1].content, content) - self.assertTrue("Cutting Knowledge Date: December 2023" in messages[0].content) - self.assertTrue("Tools: brave_search" in messages[0].content) +@pytest.mark.asyncio +async def test_system_builtin_only(): + content = "Hello !" + request = ChatCompletionRequest( + model=MODEL, + messages=[ + UserMessage(content=content), + ], + tools=[ + ToolDefinition(tool_name=BuiltinTool.code_interpreter), + ToolDefinition(tool_name=BuiltinTool.brave_search), + ], + ) + messages = chat_completion_request_to_messages(request, MODEL) + assert len(messages) == 2 + assert messages[-1].content == content + assert "Cutting Knowledge Date: December 2023" in interleaved_content_as_str(messages[0].content) + assert "Tools: brave_search" in interleaved_content_as_str(messages[0].content) - async def test_system_custom_only(self): - content = "Hello !" - request = ChatCompletionRequest( - model=MODEL, - messages=[ - UserMessage(content=content), - ], - tools=[ - ToolDefinition( - tool_name="custom1", - description="custom1 tool", - parameters={ - "param1": ToolParamDefinition( - param_type="str", - description="param1 description", - required=True, - ), - }, - ) - ], - tool_config=ToolConfig(tool_prompt_format=ToolPromptFormat.json), - ) - messages = chat_completion_request_to_messages(request, MODEL) - self.assertEqual(len(messages), 3) - self.assertTrue("Environment: ipython" in messages[0].content) - self.assertTrue("Return function calls in JSON format" in messages[1].content) - self.assertEqual(messages[-1].content, content) +@pytest.mark.asyncio +async def test_system_custom_only(): + content = "Hello !" + request = ChatCompletionRequest( + model=MODEL, + messages=[ + UserMessage(content=content), + ], + tools=[ + ToolDefinition( + tool_name="custom1", + description="custom1 tool", + parameters={ + "param1": ToolParamDefinition( + param_type="str", + description="param1 description", + required=True, + ), + }, + ) + ], + tool_config=ToolConfig(tool_prompt_format=ToolPromptFormat.json), + ) + messages = chat_completion_request_to_messages(request, MODEL) + assert len(messages) == 3 + assert "Environment: ipython" in interleaved_content_as_str(messages[0].content) - async def test_system_custom_and_builtin(self): - content = "Hello !" - request = ChatCompletionRequest( - model=MODEL, - messages=[ - UserMessage(content=content), - ], - tools=[ - ToolDefinition(tool_name=BuiltinTool.code_interpreter), - ToolDefinition(tool_name=BuiltinTool.brave_search), - ToolDefinition( - tool_name="custom1", - description="custom1 tool", - parameters={ - "param1": ToolParamDefinition( - param_type="str", - description="param1 description", - required=True, - ), - }, - ), - ], - ) - messages = chat_completion_request_to_messages(request, MODEL) - self.assertEqual(len(messages), 3) + assert "Return function calls in JSON format" in interleaved_content_as_str(messages[1].content) + assert messages[-1].content == content - self.assertTrue("Environment: ipython" in messages[0].content) - self.assertTrue("Tools: brave_search" in messages[0].content) - self.assertTrue("Return function calls in JSON format" in messages[1].content) - self.assertEqual(messages[-1].content, content) - - async def test_completion_message_encoding(self): - request = ChatCompletionRequest( - model=MODEL3_2, - messages=[ - UserMessage(content="hello"), - CompletionMessage( - content="", - stop_reason=StopReason.end_of_turn, - tool_calls=[ - ToolCall( - tool_name="custom1", - arguments={"param1": "value1"}, - call_id="123", - ) - ], - ), - ], - tools=[ - ToolDefinition( - tool_name="custom1", - description="custom1 tool", - parameters={ - "param1": ToolParamDefinition( - param_type="str", - description="param1 description", - required=True, - ), - }, - ), - ], - tool_config=ToolConfig(tool_prompt_format=ToolPromptFormat.python_list), - ) - prompt = await chat_completion_request_to_prompt(request, request.model) - self.assertIn('[custom1(param1="value1")]', prompt) - - request.model = MODEL - request.tool_config.tool_prompt_format = ToolPromptFormat.json - prompt = await chat_completion_request_to_prompt(request, request.model) - self.assertIn( - '{"type": "function", "name": "custom1", "parameters": {"param1": "value1"}}', - prompt, - ) - - async def test_user_provided_system_message(self): - content = "Hello !" - system_prompt = "You are a pirate" - request = ChatCompletionRequest( - model=MODEL, - messages=[ - SystemMessage(content=system_prompt), - UserMessage(content=content), - ], - tools=[ - ToolDefinition(tool_name=BuiltinTool.code_interpreter), - ], - ) - messages = chat_completion_request_to_messages(request, MODEL) - self.assertEqual(len(messages), 2, messages) - self.assertTrue(messages[0].content.endswith(system_prompt)) - - self.assertEqual(messages[-1].content, content) - - async def test_repalce_system_message_behavior_builtin_tools(self): - content = "Hello !" - system_prompt = "You are a pirate" - request = ChatCompletionRequest( - model=MODEL, - messages=[ - SystemMessage(content=system_prompt), - UserMessage(content=content), - ], - tools=[ - ToolDefinition(tool_name=BuiltinTool.code_interpreter), - ], - tool_config=ToolConfig( - tool_choice="auto", - tool_prompt_format="python_list", - system_message_behavior="replace", +@pytest.mark.asyncio +async def test_system_custom_and_builtin(): + content = "Hello !" + request = ChatCompletionRequest( + model=MODEL, + messages=[ + UserMessage(content=content), + ], + tools=[ + ToolDefinition(tool_name=BuiltinTool.code_interpreter), + ToolDefinition(tool_name=BuiltinTool.brave_search), + ToolDefinition( + tool_name="custom1", + description="custom1 tool", + parameters={ + "param1": ToolParamDefinition( + param_type="str", + description="param1 description", + required=True, + ), + }, ), - ) - messages = chat_completion_request_to_messages(request, MODEL3_2) - self.assertEqual(len(messages), 2, messages) - self.assertTrue(messages[0].content.endswith(system_prompt)) - self.assertIn("Environment: ipython", messages[0].content) - self.assertEqual(messages[-1].content, content) + ], + ) + messages = chat_completion_request_to_messages(request, MODEL) + assert len(messages) == 3 - async def test_repalce_system_message_behavior_custom_tools(self): - content = "Hello !" - system_prompt = "You are a pirate" - request = ChatCompletionRequest( - model=MODEL, - messages=[ - SystemMessage(content=system_prompt), - UserMessage(content=content), - ], - tools=[ - ToolDefinition(tool_name=BuiltinTool.code_interpreter), - ToolDefinition( - tool_name="custom1", - description="custom1 tool", - parameters={ - "param1": ToolParamDefinition( - param_type="str", - description="param1 description", - required=True, - ), - }, - ), - ], - tool_config=ToolConfig( - tool_choice="auto", - tool_prompt_format="python_list", - system_message_behavior="replace", + assert "Environment: ipython" in interleaved_content_as_str(messages[0].content) + assert "Tools: brave_search" in interleaved_content_as_str(messages[0].content) + + assert "Return function calls in JSON format" in interleaved_content_as_str(messages[1].content) + assert messages[-1].content == content + + +@pytest.mark.asyncio +async def test_completion_message_encoding(): + request = ChatCompletionRequest( + model=MODEL3_2, + messages=[ + UserMessage(content="hello"), + CompletionMessage( + content="", + stop_reason=StopReason.end_of_turn, + tool_calls=[ + ToolCall( + tool_name="custom1", + arguments={"param1": "value1"}, + call_id="123", + ) + ], ), - ) - messages = chat_completion_request_to_messages(request, MODEL3_2) - - self.assertEqual(len(messages), 2, messages) - self.assertTrue(messages[0].content.endswith(system_prompt)) - self.assertIn("Environment: ipython", messages[0].content) - self.assertEqual(messages[-1].content, content) - - async def test_replace_system_message_behavior_custom_tools_with_template(self): - content = "Hello !" - system_prompt = "You are a pirate {{ function_description }}" - request = ChatCompletionRequest( - model=MODEL, - messages=[ - SystemMessage(content=system_prompt), - UserMessage(content=content), - ], - tools=[ - ToolDefinition(tool_name=BuiltinTool.code_interpreter), - ToolDefinition( - tool_name="custom1", - description="custom1 tool", - parameters={ - "param1": ToolParamDefinition( - param_type="str", - description="param1 description", - required=True, - ), - }, - ), - ], - tool_config=ToolConfig( - tool_choice="auto", - tool_prompt_format="python_list", - system_message_behavior="replace", + ], + tools=[ + ToolDefinition( + tool_name="custom1", + description="custom1 tool", + parameters={ + "param1": ToolParamDefinition( + param_type="str", + description="param1 description", + required=True, + ), + }, ), - ) - messages = chat_completion_request_to_messages(request, MODEL3_2) + ], + tool_config=ToolConfig(tool_prompt_format=ToolPromptFormat.python_list), + ) + prompt = await chat_completion_request_to_prompt(request, request.model) + assert '[custom1(param1="value1")]' in prompt - self.assertEqual(len(messages), 2, messages) - self.assertIn("Environment: ipython", messages[0].content) - self.assertIn("You are a pirate", messages[0].content) - # function description is present in the system prompt - self.assertIn('"name": "custom1"', messages[0].content) - self.assertEqual(messages[-1].content, content) + request.model = MODEL + request.tool_config = ToolConfig(tool_prompt_format=ToolPromptFormat.json) + prompt = await chat_completion_request_to_prompt(request, request.model) + assert '{"type": "function", "name": "custom1", "parameters": {"param1": "value1"}}' in prompt + + +@pytest.mark.asyncio +async def test_user_provided_system_message(): + content = "Hello !" + system_prompt = "You are a pirate" + request = ChatCompletionRequest( + model=MODEL, + messages=[ + SystemMessage(content=system_prompt), + UserMessage(content=content), + ], + tools=[ + ToolDefinition(tool_name=BuiltinTool.code_interpreter), + ], + ) + messages = chat_completion_request_to_messages(request, MODEL) + assert len(messages) == 2 + assert interleaved_content_as_str(messages[0].content).endswith(system_prompt) + + assert messages[-1].content == content + + +@pytest.mark.asyncio +async def test_replace_system_message_behavior_builtin_tools(): + content = "Hello !" + system_prompt = "You are a pirate" + request = ChatCompletionRequest( + model=MODEL, + messages=[ + SystemMessage(content=system_prompt), + UserMessage(content=content), + ], + tools=[ + ToolDefinition(tool_name=BuiltinTool.code_interpreter), + ], + tool_config=ToolConfig( + tool_choice="auto", + tool_prompt_format=ToolPromptFormat.python_list, + system_message_behavior=SystemMessageBehavior.replace, + ), + ) + messages = chat_completion_request_to_messages(request, MODEL3_2) + assert len(messages) == 2 + assert interleaved_content_as_str(messages[0].content).endswith(system_prompt) + assert "Environment: ipython" in interleaved_content_as_str(messages[0].content) + assert messages[-1].content == content + + +@pytest.mark.asyncio +async def test_replace_system_message_behavior_custom_tools(): + content = "Hello !" + system_prompt = "You are a pirate" + request = ChatCompletionRequest( + model=MODEL, + messages=[ + SystemMessage(content=system_prompt), + UserMessage(content=content), + ], + tools=[ + ToolDefinition(tool_name=BuiltinTool.code_interpreter), + ToolDefinition( + tool_name="custom1", + description="custom1 tool", + parameters={ + "param1": ToolParamDefinition( + param_type="str", + description="param1 description", + required=True, + ), + }, + ), + ], + tool_config=ToolConfig( + tool_choice="auto", + tool_prompt_format=ToolPromptFormat.python_list, + system_message_behavior=SystemMessageBehavior.replace, + ), + ) + messages = chat_completion_request_to_messages(request, MODEL3_2) + + assert len(messages) == 2 + assert interleaved_content_as_str(messages[0].content).endswith(system_prompt) + assert "Environment: ipython" in interleaved_content_as_str(messages[0].content) + assert messages[-1].content == content + + +@pytest.mark.asyncio +async def test_replace_system_message_behavior_custom_tools_with_template(): + content = "Hello !" + system_prompt = "You are a pirate {{ function_description }}" + request = ChatCompletionRequest( + model=MODEL, + messages=[ + SystemMessage(content=system_prompt), + UserMessage(content=content), + ], + tools=[ + ToolDefinition(tool_name=BuiltinTool.code_interpreter), + ToolDefinition( + tool_name="custom1", + description="custom1 tool", + parameters={ + "param1": ToolParamDefinition( + param_type="str", + description="param1 description", + required=True, + ), + }, + ), + ], + tool_config=ToolConfig( + tool_choice="auto", + tool_prompt_format=ToolPromptFormat.python_list, + system_message_behavior=SystemMessageBehavior.replace, + ), + ) + messages = chat_completion_request_to_messages(request, MODEL3_2) + + assert len(messages) == 2 + assert "Environment: ipython" in interleaved_content_as_str(messages[0].content) + assert "You are a pirate" in interleaved_content_as_str(messages[0].content) + # function description is present in the system prompt + assert '"name": "custom1"' in interleaved_content_as_str(messages[0].content) + assert messages[-1].content == content From 910b0176800243c6eab3a3ab500f452d664028cb Mon Sep 17 00:00:00 2001 From: Matthew Farrellee Date: Thu, 17 Jul 2025 19:33:30 -0400 Subject: [PATCH 17/40] chore: block asyncio marks in tests (#2744) # What does this PR do? use pre-commit to block addition of new asyncio marks, since we configure pytest with async-mode=auto, see https://github.com/meta-llama/llama-stack/pull/2730 --- .pre-commit-config.yaml | 16 ++++++++++++++++ 1 file changed, 16 insertions(+) diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index 3c744c6bc..cf72ecd0e 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -129,6 +129,22 @@ repos: require_serial: true always_run: true files: ^llama_stack/.*$ + - id: forbid-pytest-asyncio + name: Block @pytest.mark.asyncio and @pytest_asyncio.fixture + entry: bash + language: system + types: [python] + pass_filenames: true + args: + - -c + - | + grep -EnH '^[^#]*@pytest\.mark\.asyncio|@pytest_asyncio\.fixture' "$@" && { + echo; + echo "❌ Do not use @pytest.mark.asyncio or @pytest_asyncio.fixture." + echo " pytest is already configured with async-mode=auto." + echo; + exit 1; + } || true ci: autofix_commit_msg: 🎨 [pre-commit.ci] Auto format from pre-commit.com hooks From d64e096c5f8a30f1d8455baca2250e13c73d77c3 Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Thu, 17 Jul 2025 16:40:35 -0700 Subject: [PATCH 18/40] fix(cli): image name should not default to CONDA_DEFAULT_ENV (#2806) If I am running `uv run llama stack run --image-type venv` it should not be saying to me "Conda detected" because I am pretty clearly telling it I need venv. The root cause is the offending line. --- llama_stack/cli/stack/run.py | 3 +-- 1 file changed, 1 insertion(+), 2 deletions(-) diff --git a/llama_stack/cli/stack/run.py b/llama_stack/cli/stack/run.py index 1d6c475f2..f4a119522 100644 --- a/llama_stack/cli/stack/run.py +++ b/llama_stack/cli/stack/run.py @@ -47,8 +47,7 @@ class StackRun(Subcommand): self.parser.add_argument( "--image-name", type=str, - default=os.environ.get("CONDA_DEFAULT_ENV"), - help="Name of the image to run. Defaults to the current environment", + help="Name of the image to run.", ) self.parser.add_argument( "--env", From d7cc38e93424b9d4610b139889c9d8e8d4ee2352 Mon Sep 17 00:00:00 2001 From: Charlie Doern Date: Fri, 18 Jul 2025 00:35:28 -0400 Subject: [PATCH 19/40] fix: remove async test markers (fix pre-commit) (#2808) # What does this PR do? some async test markers are in the codebase causing pre-commit to fail due to #2744 remove these pytest fixtures ## Test Plan pre-commit passes Signed-off-by: Charlie Doern --- tests/integration/post_training/test_post_training.py | 6 +++--- tests/unit/models/test_prompt_adapter.py | 10 ---------- tests/unit/providers/vector_io/remote/test_milvus.py | 10 ++-------- tests/unit/rag/test_rag_query.py | 1 - 4 files changed, 5 insertions(+), 22 deletions(-) diff --git a/tests/integration/post_training/test_post_training.py b/tests/integration/post_training/test_post_training.py index bb4639d17..3d56b322f 100644 --- a/tests/integration/post_training/test_post_training.py +++ b/tests/integration/post_training/test_post_training.py @@ -123,14 +123,14 @@ class TestPostTraining: logger.info(f"Job artifacts: {artifacts}") # TODO: Fix these tests to properly represent the Jobs API in training - # @pytest.mark.asyncio + # # async def test_get_training_jobs(self, post_training_stack): # post_training_impl = post_training_stack # jobs_list = await post_training_impl.get_training_jobs() # assert isinstance(jobs_list, list) # assert jobs_list[0].job_uuid == "1234" - # @pytest.mark.asyncio + # # async def test_get_training_job_status(self, post_training_stack): # post_training_impl = post_training_stack # job_status = await post_training_impl.get_training_job_status("1234") @@ -139,7 +139,7 @@ class TestPostTraining: # assert job_status.status == JobStatus.completed # assert isinstance(job_status.checkpoints[0], Checkpoint) - # @pytest.mark.asyncio + # # async def test_get_training_job_artifacts(self, post_training_stack): # post_training_impl = post_training_stack # job_artifacts = await post_training_impl.get_training_job_artifacts("1234") diff --git a/tests/unit/models/test_prompt_adapter.py b/tests/unit/models/test_prompt_adapter.py index 577496cec..0362eb5dd 100644 --- a/tests/unit/models/test_prompt_adapter.py +++ b/tests/unit/models/test_prompt_adapter.py @@ -4,7 +4,6 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -import pytest from llama_stack.apis.inference import ( ChatCompletionRequest, @@ -32,7 +31,6 @@ MODEL = "Llama3.1-8B-Instruct" MODEL3_2 = "Llama3.2-3B-Instruct" -@pytest.mark.asyncio async def test_system_default(): content = "Hello !" request = ChatCompletionRequest( @@ -47,7 +45,6 @@ async def test_system_default(): assert "Cutting Knowledge Date: December 2023" in interleaved_content_as_str(messages[0].content) -@pytest.mark.asyncio async def test_system_builtin_only(): content = "Hello !" request = ChatCompletionRequest( @@ -67,7 +64,6 @@ async def test_system_builtin_only(): assert "Tools: brave_search" in interleaved_content_as_str(messages[0].content) -@pytest.mark.asyncio async def test_system_custom_only(): content = "Hello !" request = ChatCompletionRequest( @@ -98,7 +94,6 @@ async def test_system_custom_only(): assert messages[-1].content == content -@pytest.mark.asyncio async def test_system_custom_and_builtin(): content = "Hello !" request = ChatCompletionRequest( @@ -132,7 +127,6 @@ async def test_system_custom_and_builtin(): assert messages[-1].content == content -@pytest.mark.asyncio async def test_completion_message_encoding(): request = ChatCompletionRequest( model=MODEL3_2, @@ -174,7 +168,6 @@ async def test_completion_message_encoding(): assert '{"type": "function", "name": "custom1", "parameters": {"param1": "value1"}}' in prompt -@pytest.mark.asyncio async def test_user_provided_system_message(): content = "Hello !" system_prompt = "You are a pirate" @@ -195,7 +188,6 @@ async def test_user_provided_system_message(): assert messages[-1].content == content -@pytest.mark.asyncio async def test_replace_system_message_behavior_builtin_tools(): content = "Hello !" system_prompt = "You are a pirate" @@ -221,7 +213,6 @@ async def test_replace_system_message_behavior_builtin_tools(): assert messages[-1].content == content -@pytest.mark.asyncio async def test_replace_system_message_behavior_custom_tools(): content = "Hello !" system_prompt = "You are a pirate" @@ -259,7 +250,6 @@ async def test_replace_system_message_behavior_custom_tools(): assert messages[-1].content == content -@pytest.mark.asyncio async def test_replace_system_message_behavior_custom_tools_with_template(): content = "Hello !" system_prompt = "You are a pirate {{ function_description }}" diff --git a/tests/unit/providers/vector_io/remote/test_milvus.py b/tests/unit/providers/vector_io/remote/test_milvus.py index 2f212e374..145edf7fb 100644 --- a/tests/unit/providers/vector_io/remote/test_milvus.py +++ b/tests/unit/providers/vector_io/remote/test_milvus.py @@ -8,7 +8,6 @@ from unittest.mock import MagicMock, patch import numpy as np import pytest -import pytest_asyncio from llama_stack.apis.vector_io import QueryChunksResponse @@ -33,7 +32,7 @@ with patch.dict("sys.modules", {"pymilvus": pymilvus_mock}): MILVUS_PROVIDER = "milvus" -@pytest_asyncio.fixture +@pytest.fixture async def mock_milvus_client() -> MagicMock: """Create a mock Milvus client with common method behaviors.""" client = MagicMock() @@ -84,7 +83,7 @@ async def mock_milvus_client() -> MagicMock: return client -@pytest_asyncio.fixture +@pytest.fixture async def milvus_index(mock_milvus_client): """Create a MilvusIndex with mocked client.""" index = MilvusIndex(client=mock_milvus_client, collection_name="test_collection") @@ -92,7 +91,6 @@ async def milvus_index(mock_milvus_client): # No real cleanup needed since we're using mocks -@pytest.mark.asyncio async def test_add_chunks(milvus_index, sample_chunks, sample_embeddings, mock_milvus_client): # Setup: collection doesn't exist initially, then exists after creation mock_milvus_client.has_collection.side_effect = [False, True] @@ -108,7 +106,6 @@ async def test_add_chunks(milvus_index, sample_chunks, sample_embeddings, mock_m assert len(insert_call[1]["data"]) == len(sample_chunks) -@pytest.mark.asyncio async def test_query_chunks_vector( milvus_index, sample_chunks, sample_embeddings, embedding_dimension, mock_milvus_client ): @@ -125,7 +122,6 @@ async def test_query_chunks_vector( mock_milvus_client.search.assert_called_once() -@pytest.mark.asyncio async def test_query_chunks_keyword_search(milvus_index, sample_chunks, sample_embeddings, mock_milvus_client): mock_milvus_client.has_collection.return_value = True await milvus_index.add_chunks(sample_chunks, sample_embeddings) @@ -138,7 +134,6 @@ async def test_query_chunks_keyword_search(milvus_index, sample_chunks, sample_e assert len(response.chunks) == 2 -@pytest.mark.asyncio async def test_bm25_fallback_to_simple_search(milvus_index, sample_chunks, sample_embeddings, mock_milvus_client): """Test that when BM25 search fails, the system falls back to simple text search.""" mock_milvus_client.has_collection.return_value = True @@ -181,7 +176,6 @@ async def test_bm25_fallback_to_simple_search(milvus_index, sample_chunks, sampl assert all(score == 1.0 for score in response.scores), "Simple text search should use binary scoring" -@pytest.mark.asyncio async def test_delete_collection(milvus_index, mock_milvus_client): # Test collection deletion mock_milvus_client.has_collection.return_value = True diff --git a/tests/unit/rag/test_rag_query.py b/tests/unit/rag/test_rag_query.py index ad155c205..a9149541a 100644 --- a/tests/unit/rag/test_rag_query.py +++ b/tests/unit/rag/test_rag_query.py @@ -64,7 +64,6 @@ class TestRagQuery: with pytest.raises(ValueError): RAGQueryConfig(mode="invalid_mode") - @pytest.mark.asyncio async def test_query_accepts_valid_modes(self): RAGQueryConfig() # Test default (vector) RAGQueryConfig(mode="vector") # Test vector From 55713abe7da921d8869abe8bbeb0a23b5f99d7ca Mon Sep 17 00:00:00 2001 From: Mustafa Elbehery Date: Fri, 18 Jul 2025 11:49:45 +0200 Subject: [PATCH 20/40] =?UTF-8?q?chore(test):=20migrate=20unit=20tests=20f?= =?UTF-8?q?rom=20unittest=20to=20pytest=20nvidia=20test=20p=E2=80=A6=20(#2?= =?UTF-8?q?792)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This PR replaces unittest with pytest. Part of https://github.com/meta-llama/llama-stack/issues/2680 cc @leseb Signed-off-by: Mustafa Elbehery --- .../unit/providers/nvidia/test_parameters.py | 38 +++++++------------ 1 file changed, 13 insertions(+), 25 deletions(-) diff --git a/tests/unit/providers/nvidia/test_parameters.py b/tests/unit/providers/nvidia/test_parameters.py index cc33f7609..7e4323bd7 100644 --- a/tests/unit/providers/nvidia/test_parameters.py +++ b/tests/unit/providers/nvidia/test_parameters.py @@ -5,7 +5,6 @@ # the root directory of this source tree. import os -import unittest import warnings from unittest.mock import patch @@ -27,14 +26,13 @@ from llama_stack.providers.remote.post_training.nvidia.post_training import ( ) -class TestNvidiaParameters(unittest.TestCase): - def setUp(self): - os.environ["NVIDIA_BASE_URL"] = "http://nemo.test" +class TestNvidiaParameters: + @pytest.fixture(autouse=True) + def setup_and_teardown(self): + """Setup and teardown for each test method.""" os.environ["NVIDIA_CUSTOMIZER_URL"] = "http://nemo.test" - config = NvidiaPostTrainingConfig( - base_url=os.environ["NVIDIA_BASE_URL"], customizer_url=os.environ["NVIDIA_CUSTOMIZER_URL"], api_key=None - ) + config = NvidiaPostTrainingConfig(customizer_url=os.environ["NVIDIA_CUSTOMIZER_URL"], api_key=None) self.adapter = NvidiaPostTrainingAdapter(config) self.make_request_patcher = patch( @@ -48,7 +46,8 @@ class TestNvidiaParameters(unittest.TestCase): "updated_at": "2025-03-04T13:07:47.543605", } - def tearDown(self): + yield + self.make_request_patcher.stop() def _assert_request_params(self, expected_json): @@ -166,8 +165,8 @@ class TestNvidiaParameters(unittest.TestCase): self.run_async( self.adapter.supervised_fine_tune( - job_uuid=required_job_uuid, # Required parameter - model=required_model, # Required parameter + job_uuid=required_job_uuid, + model=required_model, checkpoint_dir="", algorithm_config=algorithm_config, training_config=convert_pydantic_to_json_value(training_config), @@ -198,7 +197,6 @@ class TestNvidiaParameters(unittest.TestCase): data_config = DataConfig( dataset_id="test-dataset", batch_size=8, - # Unsupported parameters shuffle=True, data_format=DatasetFormat.instruct, validation_dataset_id="val-dataset", @@ -207,20 +205,16 @@ class TestNvidiaParameters(unittest.TestCase): optimizer_config = OptimizerConfig( lr=0.0001, weight_decay=0.01, - # Unsupported parameters optimizer_type=OptimizerType.adam, num_warmup_steps=100, ) - efficiency_config = EfficiencyConfig( - enable_activation_checkpointing=True # Unsupported parameter - ) + efficiency_config = EfficiencyConfig(enable_activation_checkpointing=True) training_config = TrainingConfig( n_epochs=1, data_config=data_config, optimizer_config=optimizer_config, - # Unsupported parameters efficiency_config=efficiency_config, max_steps_per_epoch=1000, gradient_accumulation_steps=4, @@ -228,7 +222,6 @@ class TestNvidiaParameters(unittest.TestCase): dtype="bf16", ) - # Capture warnings with warnings.catch_warnings(record=True) as w: warnings.simplefilter("always") @@ -236,7 +229,7 @@ class TestNvidiaParameters(unittest.TestCase): self.adapter.supervised_fine_tune( job_uuid="test-job", model="meta-llama/Llama-3.1-8B-Instruct", - checkpoint_dir="test-dir", # Unsupported parameter + checkpoint_dir="test-dir", algorithm_config=LoraFinetuningConfig( type="LoRA", apply_lora_to_mlp=True, @@ -246,8 +239,8 @@ class TestNvidiaParameters(unittest.TestCase): lora_attn_modules=["q_proj", "k_proj", "v_proj", "o_proj"], ), training_config=convert_pydantic_to_json_value(training_config), - logger_config={"test": "value"}, # Unsupported parameter - hyperparam_search_config={"test": "value"}, # Unsupported parameter + logger_config={"test": "value"}, + hyperparam_search_config={"test": "value"}, ) ) @@ -265,7 +258,6 @@ class TestNvidiaParameters(unittest.TestCase): "gradient_accumulation_steps", "max_validation_steps", "dtype", - # required unsupported parameters "rank", "apply_lora_to_output", "lora_attn_modules", @@ -273,7 +265,3 @@ class TestNvidiaParameters(unittest.TestCase): ] for field in fields: assert any(field in text for text in warning_texts) - - -if __name__ == "__main__": - unittest.main() From 3cdf748a8ef32d88491b1b107113db6069640135 Mon Sep 17 00:00:00 2001 From: Mustafa Elbehery Date: Fri, 18 Jul 2025 11:52:47 +0200 Subject: [PATCH 21/40] chore(test): migrate unit tests from unittest to pytest for nvidia datastore (#2790) This PR replaces unittest with pytest. Part of https://github.com/meta-llama/llama-stack/issues/2680 cc @leseb Signed-off-by: Mustafa Elbehery --- tests/unit/providers/nvidia/test_datastore.py | 163 +++++++++--------- 1 file changed, 83 insertions(+), 80 deletions(-) diff --git a/tests/unit/providers/nvidia/test_datastore.py b/tests/unit/providers/nvidia/test_datastore.py index a17e51a9c..b59636f7b 100644 --- a/tests/unit/providers/nvidia/test_datastore.py +++ b/tests/unit/providers/nvidia/test_datastore.py @@ -5,103 +5,110 @@ # the root directory of this source tree. import os -import unittest from unittest.mock import patch import pytest from llama_stack.apis.datasets import Dataset, DatasetPurpose, URIDataSource +from llama_stack.apis.resource import ResourceType from llama_stack.providers.remote.datasetio.nvidia.config import NvidiaDatasetIOConfig from llama_stack.providers.remote.datasetio.nvidia.datasetio import NvidiaDatasetIOAdapter -class TestNvidiaDatastore(unittest.TestCase): - def setUp(self): - os.environ["NVIDIA_DATASETS_URL"] = "http://nemo.test/datasets" +@pytest.fixture +def nvidia_adapter(): + """Fixture to set up NvidiaDatasetIOAdapter with mocked requests.""" + os.environ["NVIDIA_DATASETS_URL"] = "http://nemo.test/datasets" - config = NvidiaDatasetIOConfig( - datasets_url=os.environ["NVIDIA_DATASETS_URL"], dataset_namespace="default", project_id="default" - ) - self.adapter = NvidiaDatasetIOAdapter(config) - self.make_request_patcher = patch( - "llama_stack.providers.remote.datasetio.nvidia.datasetio.NvidiaDatasetIOAdapter._make_request" - ) - self.mock_make_request = self.make_request_patcher.start() + config = NvidiaDatasetIOConfig( + datasets_url=os.environ["NVIDIA_DATASETS_URL"], dataset_namespace="default", project_id="default" + ) + adapter = NvidiaDatasetIOAdapter(config) - def tearDown(self): - self.make_request_patcher.stop() + with patch( + "llama_stack.providers.remote.datasetio.nvidia.datasetio.NvidiaDatasetIOAdapter._make_request" + ) as mock_make_request: + yield adapter, mock_make_request - @pytest.fixture(autouse=True) - def inject_fixtures(self, run_async): - self.run_async = run_async - def _assert_request(self, mock_call, expected_method, expected_path, expected_json=None): - """Helper method to verify request details in mock calls.""" - call_args = mock_call.call_args +def _assert_request(mock_call, expected_method, expected_path, expected_json=None): + """Helper function to verify request details in mock calls.""" + call_args = mock_call.call_args - assert call_args[0][0] == expected_method - assert call_args[0][1] == expected_path + assert call_args[0][0] == expected_method + assert call_args[0][1] == expected_path - if expected_json: - for key, value in expected_json.items(): - assert call_args[1]["json"][key] == value + if expected_json: + for key, value in expected_json.items(): + assert call_args[1]["json"][key] == value - def test_register_dataset(self): - self.mock_make_request.return_value = { - "id": "dataset-123456", + +def test_register_dataset(nvidia_adapter, run_async): + adapter, mock_make_request = nvidia_adapter + mock_make_request.return_value = { + "id": "dataset-123456", + "name": "test-dataset", + "namespace": "default", + } + + dataset_def = Dataset( + identifier="test-dataset", + type=ResourceType.dataset, + provider_resource_id="", + provider_id="", + purpose=DatasetPurpose.post_training_messages, + source=URIDataSource(uri="https://example.com/data.jsonl"), + metadata={"provider_id": "nvidia", "format": "jsonl", "description": "Test dataset description"}, + ) + + run_async(adapter.register_dataset(dataset_def)) + + mock_make_request.assert_called_once() + _assert_request( + mock_make_request, + "POST", + "/v1/datasets", + expected_json={ "name": "test-dataset", "namespace": "default", - } + "files_url": "https://example.com/data.jsonl", + "project": "default", + "format": "jsonl", + "description": "Test dataset description", + }, + ) - dataset_def = Dataset( - identifier="test-dataset", - type="dataset", - provider_resource_id="", - provider_id="", - purpose=DatasetPurpose.post_training_messages, - source=URIDataSource(uri="https://example.com/data.jsonl"), - metadata={"provider_id": "nvidia", "format": "jsonl", "description": "Test dataset description"}, - ) - self.run_async(self.adapter.register_dataset(dataset_def)) +def test_unregister_dataset(nvidia_adapter, run_async): + adapter, mock_make_request = nvidia_adapter + mock_make_request.return_value = { + "message": "Resource deleted successfully.", + "id": "dataset-81RSQp7FKX3rdBtKvF9Skn", + "deleted_at": None, + } + dataset_id = "test-dataset" - self.mock_make_request.assert_called_once() - self._assert_request( - self.mock_make_request, - "POST", - "/v1/datasets", - expected_json={ - "name": "test-dataset", - "namespace": "default", - "files_url": "https://example.com/data.jsonl", - "project": "default", - "format": "jsonl", - "description": "Test dataset description", - }, - ) + run_async(adapter.unregister_dataset(dataset_id)) - def test_unregister_dataset(self): - self.mock_make_request.return_value = { - "message": "Resource deleted successfully.", - "id": "dataset-81RSQp7FKX3rdBtKvF9Skn", - "deleted_at": None, - } - dataset_id = "test-dataset" + mock_make_request.assert_called_once() + _assert_request(mock_make_request, "DELETE", "/v1/datasets/default/test-dataset") - self.run_async(self.adapter.unregister_dataset(dataset_id)) - self.mock_make_request.assert_called_once() - self._assert_request(self.mock_make_request, "DELETE", "/v1/datasets/default/test-dataset") +def test_register_dataset_with_custom_namespace_project(run_async): + """Test with custom namespace and project configuration.""" + os.environ["NVIDIA_DATASETS_URL"] = "http://nemo.test/datasets" - def test_register_dataset_with_custom_namespace_project(self): - custom_config = NvidiaDatasetIOConfig( - datasets_url=os.environ["NVIDIA_DATASETS_URL"], - dataset_namespace="custom-namespace", - project_id="custom-project", - ) - custom_adapter = NvidiaDatasetIOAdapter(custom_config) + custom_config = NvidiaDatasetIOConfig( + datasets_url=os.environ["NVIDIA_DATASETS_URL"], + dataset_namespace="custom-namespace", + project_id="custom-project", + ) + custom_adapter = NvidiaDatasetIOAdapter(custom_config) - self.mock_make_request.return_value = { + with patch( + "llama_stack.providers.remote.datasetio.nvidia.datasetio.NvidiaDatasetIOAdapter._make_request" + ) as mock_make_request: + mock_make_request.return_value = { "id": "dataset-123456", "name": "test-dataset", "namespace": "custom-namespace", @@ -109,7 +116,7 @@ class TestNvidiaDatastore(unittest.TestCase): dataset_def = Dataset( identifier="test-dataset", - type="dataset", + type=ResourceType.dataset, provider_resource_id="", provider_id="", purpose=DatasetPurpose.post_training_messages, @@ -117,11 +124,11 @@ class TestNvidiaDatastore(unittest.TestCase): metadata={"format": "jsonl"}, ) - self.run_async(custom_adapter.register_dataset(dataset_def)) + run_async(custom_adapter.register_dataset(dataset_def)) - self.mock_make_request.assert_called_once() - self._assert_request( - self.mock_make_request, + mock_make_request.assert_called_once() + _assert_request( + mock_make_request, "POST", "/v1/datasets", expected_json={ @@ -132,7 +139,3 @@ class TestNvidiaDatastore(unittest.TestCase): "format": "jsonl", }, ) - - -if __name__ == "__main__": - unittest.main() From 75480b01b8650d0fba3cb2aedaae75743b3dff67 Mon Sep 17 00:00:00 2001 From: Mustafa Elbehery Date: Fri, 18 Jul 2025 11:54:02 +0200 Subject: [PATCH 22/40] chore(test): migrate unit tests from unittest to pytest for system prompt (#2789) This PR replaces unittest with pytest. Part of https://github.com/meta-llama/llama-stack/issues/2680 cc @leseb Signed-off-by: Mustafa Elbehery --- tests/unit/models/test_system_prompts.py | 113 ++++++++++++----------- 1 file changed, 57 insertions(+), 56 deletions(-) diff --git a/tests/unit/models/test_system_prompts.py b/tests/unit/models/test_system_prompts.py index 1f4ccc7e3..f5580f4c5 100644 --- a/tests/unit/models/test_system_prompts.py +++ b/tests/unit/models/test_system_prompts.py @@ -12,7 +12,6 @@ # the top-level of this source tree. import textwrap -import unittest from datetime import datetime from llama_stack.models.llama.llama3.prompt_templates import ( @@ -24,59 +23,61 @@ from llama_stack.models.llama.llama3.prompt_templates import ( ) -class PromptTemplateTests(unittest.TestCase): - def check_generator_output(self, generator): - for example in generator.data_examples(): - pt = generator.gen(example) - text = pt.render() - # print(text) # debugging - if not example: - continue - for tool in example: - assert tool.tool_name in text - - def test_system_default(self): - generator = SystemDefaultGenerator() - today = datetime.now().strftime("%d %B %Y") - expected_text = f"Cutting Knowledge Date: December 2023\nToday Date: {today}" - assert expected_text.strip("\n") == generator.gen(generator.data_examples()[0]).render() - - def test_system_builtin_only(self): - generator = BuiltinToolGenerator() - expected_text = textwrap.dedent( - """ - Environment: ipython - Tools: brave_search, wolfram_alpha - """ - ) - assert expected_text.strip("\n") == generator.gen(generator.data_examples()[0]).render() - - def test_system_custom_only(self): - self.maxDiff = None - generator = JsonCustomToolGenerator() - self.check_generator_output(generator) - - def test_system_custom_function_tag(self): - self.maxDiff = None - generator = FunctionTagCustomToolGenerator() - self.check_generator_output(generator) - - def test_llama_3_2_system_zero_shot(self): - generator = PythonListCustomToolGenerator() - self.check_generator_output(generator) - - def test_llama_3_2_provided_system_prompt(self): - generator = PythonListCustomToolGenerator() - user_system_prompt = textwrap.dedent( - """ - Overriding message. - - {{ function_description }} - """ - ) - example = generator.data_examples()[0] - - pt = generator.gen(example, user_system_prompt) +def check_generator_output(generator): + for example in generator.data_examples(): + pt = generator.gen(example) text = pt.render() - assert "Overriding message." in text - assert '"name": "get_weather"' in text + if not example: + continue + for tool in example: + assert tool.tool_name in text + + +def test_system_default(): + generator = SystemDefaultGenerator() + today = datetime.now().strftime("%d %B %Y") + expected_text = f"Cutting Knowledge Date: December 2023\nToday Date: {today}" + assert expected_text.strip("\n") == generator.gen(generator.data_examples()[0]).render() + + +def test_system_builtin_only(): + generator = BuiltinToolGenerator() + expected_text = textwrap.dedent( + """ + Environment: ipython + Tools: brave_search, wolfram_alpha + """ + ) + assert expected_text.strip("\n") == generator.gen(generator.data_examples()[0]).render() + + +def test_system_custom_only(): + generator = JsonCustomToolGenerator() + check_generator_output(generator) + + +def test_system_custom_function_tag(): + generator = FunctionTagCustomToolGenerator() + check_generator_output(generator) + + +def test_llama_3_2_system_zero_shot(): + generator = PythonListCustomToolGenerator() + check_generator_output(generator) + + +def test_llama_3_2_provided_system_prompt(): + generator = PythonListCustomToolGenerator() + user_system_prompt = textwrap.dedent( + """ + Overriding message. + + {{ function_description }} + """ + ) + example = generator.data_examples()[0] + + pt = generator.gen(example, user_system_prompt) + text = pt.render() + assert "Overriding message." in text + assert '"name": "get_weather"' in text From ca7edcd6a4929e3c7e82899280a9b7d1936e2502 Mon Sep 17 00:00:00 2001 From: Mustafa Elbehery Date: Fri, 18 Jul 2025 11:56:53 +0200 Subject: [PATCH 23/40] chore(api): add `mypy` coverage to `chat_format` (#2654) # What does this PR do? This PR adds static type coverage to `llama-stack` Part of https://github.com/meta-llama/llama-stack/issues/2647 ## Test Plan Signed-off-by: Mustafa Elbehery --- .../models/llama/llama3/chat_format.py | 19 +++++++++++++++---- pyproject.toml | 1 - 2 files changed, 15 insertions(+), 5 deletions(-) diff --git a/llama_stack/models/llama/llama3/chat_format.py b/llama_stack/models/llama/llama3/chat_format.py index 7bb05d8db..0a973cf0c 100644 --- a/llama_stack/models/llama/llama3/chat_format.py +++ b/llama_stack/models/llama/llama3/chat_format.py @@ -8,6 +8,7 @@ import io import json import uuid from dataclasses import dataclass +from typing import Any from PIL import Image as PIL_Image @@ -184,16 +185,26 @@ class ChatFormat: content = content[: -len("<|eom_id|>")] stop_reason = StopReason.end_of_message - tool_name = None - tool_arguments = {} + tool_name: str | BuiltinTool | None = None + tool_arguments: dict[str, Any] = {} custom_tool_info = ToolUtils.maybe_extract_custom_tool_call(content) if custom_tool_info is not None: - tool_name, tool_arguments = custom_tool_info + # Type guard: ensure custom_tool_info is a tuple of correct types + if isinstance(custom_tool_info, tuple) and len(custom_tool_info) == 2: + extracted_tool_name, extracted_tool_arguments = custom_tool_info + # Handle both dict and str return types from the function + if isinstance(extracted_tool_arguments, dict): + tool_name, tool_arguments = extracted_tool_name, extracted_tool_arguments + else: + # If it's a string, treat it as a query parameter + tool_name, tool_arguments = extracted_tool_name, {"query": extracted_tool_arguments} + else: + tool_name, tool_arguments = None, {} # Sometimes when agent has custom tools alongside builin tools # Agent responds for builtin tool calls in the format of the custom tools # This code tries to handle that case - if tool_name in BuiltinTool.__members__: + if tool_name is not None and tool_name in BuiltinTool.__members__: tool_name = BuiltinTool[tool_name] if isinstance(tool_arguments, dict): tool_arguments = { diff --git a/pyproject.toml b/pyproject.toml index 72f3a323f..22ad816d0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -242,7 +242,6 @@ exclude = [ "^llama_stack/distribution/store/registry\\.py$", "^llama_stack/distribution/utils/exec\\.py$", "^llama_stack/distribution/utils/prompt_for_config\\.py$", - "^llama_stack/models/llama/llama3/chat_format\\.py$", "^llama_stack/models/llama/llama3/interface\\.py$", "^llama_stack/models/llama/llama3/tokenizer\\.py$", "^llama_stack/models/llama/llama3/tool_utils\\.py$", From b78b8e148641c73507e71cf636dffb6db77dc7bb Mon Sep 17 00:00:00 2001 From: Mustafa Elbehery Date: Fri, 18 Jul 2025 12:01:10 +0200 Subject: [PATCH 24/40] chore: add `mypy` inference parallel utils (#2670) # What does this PR do? This PR adds static type coverage to `llama-stack` Part of https://github.com/meta-llama/llama-stack/issues/2647 ## Test Plan Signed-off-by: Mustafa Elbehery --- .../inline/inference/meta_reference/parallel_utils.py | 6 +++--- pyproject.toml | 1 - 2 files changed, 3 insertions(+), 4 deletions(-) diff --git a/llama_stack/providers/inline/inference/meta_reference/parallel_utils.py b/llama_stack/providers/inline/inference/meta_reference/parallel_utils.py index 97e96b929..7ade75032 100644 --- a/llama_stack/providers/inline/inference/meta_reference/parallel_utils.py +++ b/llama_stack/providers/inline/inference/meta_reference/parallel_utils.py @@ -98,7 +98,7 @@ class ProcessingMessageWrapper(BaseModel): def mp_rank_0() -> bool: - return get_model_parallel_rank() == 0 + return bool(get_model_parallel_rank() == 0) def encode_msg(msg: ProcessingMessage) -> bytes: @@ -125,7 +125,7 @@ def retrieve_requests(reply_socket_url: str): reply_socket.send_multipart([client_id, encode_msg(obj)]) while True: - tasks = [None] + tasks: list[ProcessingMessage | None] = [None] if mp_rank_0(): client_id, maybe_task_json = maybe_get_work(reply_socket) if maybe_task_json is not None: @@ -152,7 +152,7 @@ def retrieve_requests(reply_socket_url: str): break for obj in out: - updates = [None] + updates: list[ProcessingMessage | None] = [None] if mp_rank_0(): _, update_json = maybe_get_work(reply_socket) update = maybe_parse_message(update_json) diff --git a/pyproject.toml b/pyproject.toml index 22ad816d0..4d54bece0 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -254,7 +254,6 @@ exclude = [ "^llama_stack/models/llama/llama3/generation\\.py$", "^llama_stack/models/llama/llama3/multimodal/model\\.py$", "^llama_stack/models/llama/llama4/", - "^llama_stack/providers/inline/inference/meta_reference/parallel_utils\\.py$", "^llama_stack/providers/inline/inference/meta_reference/quantization/fp8_impls\\.py$", "^llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers\\.py$", "^llama_stack/providers/inline/inference/vllm/", From fe6af7dc8b49cb86042a73ac6f314968b2d0082e Mon Sep 17 00:00:00 2001 From: Mustafa Elbehery Date: Fri, 18 Jul 2025 12:32:19 +0200 Subject: [PATCH 25/40] =?UTF-8?q?chore(test):=20migrate=20unit=20tests=20f?= =?UTF-8?q?rom=20unittest=20to=20pytest=20nvidia=20test=20f=E2=80=A6=20(#2?= =?UTF-8?q?794)?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit This PR replaces unittest with pytest. Part of https://github.com/meta-llama/llama-stack/issues/2680 cc @leseb Signed-off-by: Mustafa Elbehery --- .../nvidia/test_supervised_fine_tuning.py | 588 ++++++++---------- 1 file changed, 273 insertions(+), 315 deletions(-) diff --git a/tests/unit/providers/nvidia/test_supervised_fine_tuning.py b/tests/unit/providers/nvidia/test_supervised_fine_tuning.py index bbbb60a30..bc474f3bc 100644 --- a/tests/unit/providers/nvidia/test_supervised_fine_tuning.py +++ b/tests/unit/providers/nvidia/test_supervised_fine_tuning.py @@ -5,13 +5,11 @@ # the root directory of this source tree. import os -import unittest import warnings -from unittest.mock import AsyncMock, patch +from unittest.mock import patch import pytest -from llama_stack.apis.models import Model, ModelType from llama_stack.apis.post_training.post_training import ( DataConfig, DatasetFormat, @@ -22,7 +20,6 @@ from llama_stack.apis.post_training.post_training import ( TrainingConfig, ) from llama_stack.distribution.library_client import convert_pydantic_to_json_value -from llama_stack.providers.remote.inference.nvidia.nvidia import NVIDIAConfig, NVIDIAInferenceAdapter from llama_stack.providers.remote.post_training.nvidia.post_training import ( ListNvidiaPostTrainingJobs, NvidiaPostTrainingAdapter, @@ -32,336 +29,297 @@ from llama_stack.providers.remote.post_training.nvidia.post_training import ( ) -class TestNvidiaPostTraining(unittest.TestCase): - def setUp(self): - os.environ["NVIDIA_BASE_URL"] = "http://nemo.test" # needed for llm inference - os.environ["NVIDIA_CUSTOMIZER_URL"] = "http://nemo.test" # needed for nemo customizer +@pytest.fixture +def nvidia_post_training_adapter(): + """Fixture to create and configure the NVIDIA post training adapter.""" + os.environ["NVIDIA_CUSTOMIZER_URL"] = "http://nemo.test" # needed for nemo customizer - config = NvidiaPostTrainingConfig( - base_url=os.environ["NVIDIA_BASE_URL"], customizer_url=os.environ["NVIDIA_CUSTOMIZER_URL"], api_key=None + config = NvidiaPostTrainingConfig(customizer_url=os.environ["NVIDIA_CUSTOMIZER_URL"], api_key=None) + adapter = NvidiaPostTrainingAdapter(config) + + with patch.object(adapter, "_make_request") as mock_make_request: + yield adapter, mock_make_request + + +def _assert_request(mock_call, expected_method, expected_path, expected_params=None, expected_json=None): + """Helper method to verify request details in mock calls.""" + call_args = mock_call.call_args + + if expected_method and expected_path: + if isinstance(call_args[0], tuple) and len(call_args[0]) == 2: + assert call_args[0] == (expected_method, expected_path) + else: + assert call_args[1]["method"] == expected_method + assert call_args[1]["path"] == expected_path + + if expected_params: + assert call_args[1]["params"] == expected_params + + if expected_json: + for key, value in expected_json.items(): + assert call_args[1]["json"][key] == value + + +async def test_supervised_fine_tune(nvidia_post_training_adapter): + """Test the supervised fine-tuning API call.""" + adapter, mock_make_request = nvidia_post_training_adapter + mock_make_request.return_value = { + "id": "cust-JGTaMbJMdqjJU8WbQdN9Q2", + "created_at": "2024-12-09T04:06:28.542884", + "updated_at": "2024-12-09T04:06:28.542884", + "config": { + "schema_version": "1.0", + "id": "af783f5b-d985-4e5b-bbb7-f9eec39cc0b1", + "created_at": "2024-12-09T04:06:28.542657", + "updated_at": "2024-12-09T04:06:28.569837", + "custom_fields": {}, + "name": "meta-llama/Llama-3.1-8B-Instruct", + "base_model": "meta-llama/Llama-3.1-8B-Instruct", + "model_path": "llama-3_1-8b-instruct", + "training_types": [], + "finetuning_types": ["lora"], + "precision": "bf16", + "num_gpus": 4, + "num_nodes": 1, + "micro_batch_size": 1, + "tensor_parallel_size": 1, + "max_seq_length": 4096, + }, + "dataset": { + "schema_version": "1.0", + "id": "dataset-XU4pvGzr5tvawnbVxeJMTb", + "created_at": "2024-12-09T04:06:28.542657", + "updated_at": "2024-12-09T04:06:28.542660", + "custom_fields": {}, + "name": "sample-basic-test", + "version_id": "main", + "version_tags": [], + }, + "hyperparameters": { + "finetuning_type": "lora", + "training_type": "sft", + "batch_size": 16, + "epochs": 2, + "learning_rate": 0.0001, + "lora": {"alpha": 16}, + }, + "output_model": "default/job-1234", + "status": "created", + "project": "default", + "custom_fields": {}, + "ownership": {"created_by": "me", "access_policies": {}}, + } + + algorithm_config = LoraFinetuningConfig( + type="LoRA", + apply_lora_to_mlp=True, + apply_lora_to_output=True, + alpha=16, + rank=16, + lora_attn_modules=["q_proj", "k_proj", "v_proj", "o_proj"], + ) + + data_config = DataConfig( + dataset_id="sample-basic-test", batch_size=16, shuffle=False, data_format=DatasetFormat.instruct + ) + + optimizer_config = OptimizerConfig( + optimizer_type=OptimizerType.adam, + lr=0.0001, + weight_decay=0.01, + num_warmup_steps=100, + ) + + training_config = TrainingConfig( + n_epochs=2, + data_config=data_config, + optimizer_config=optimizer_config, + ) + + with warnings.catch_warnings(record=True): + warnings.simplefilter("always") + training_job = await adapter.supervised_fine_tune( + job_uuid="1234", + model="meta/llama-3.2-1b-instruct@v1.0.0+L40", + checkpoint_dir="", + algorithm_config=algorithm_config, + training_config=convert_pydantic_to_json_value(training_config), + logger_config={}, + hyperparam_search_config={}, ) - self.adapter = NvidiaPostTrainingAdapter(config) - self.make_request_patcher = patch( - "llama_stack.providers.remote.post_training.nvidia.post_training.NvidiaPostTrainingAdapter._make_request" - ) - self.mock_make_request = self.make_request_patcher.start() - # Mock the inference client - inference_config = NVIDIAConfig(base_url=os.environ["NVIDIA_BASE_URL"], api_key=None) - self.inference_adapter = NVIDIAInferenceAdapter(inference_config) + # check the output is a PostTrainingJob + assert isinstance(training_job, NvidiaPostTrainingJob) + assert training_job.job_uuid == "cust-JGTaMbJMdqjJU8WbQdN9Q2" - self.mock_client = unittest.mock.MagicMock() - self.mock_client.chat.completions.create = unittest.mock.AsyncMock() - self.inference_mock_make_request = self.mock_client.chat.completions.create - self.inference_make_request_patcher = patch( - "llama_stack.providers.remote.inference.nvidia.nvidia.NVIDIAInferenceAdapter._client", - new_callable=unittest.mock.PropertyMock, - return_value=self.mock_client, - ) - self.inference_make_request_patcher.start() - - def tearDown(self): - self.make_request_patcher.stop() - self.inference_make_request_patcher.stop() - - @pytest.fixture(autouse=True) - def inject_fixtures(self, run_async): - self.run_async = run_async - - def _assert_request(self, mock_call, expected_method, expected_path, expected_params=None, expected_json=None): - """Helper method to verify request details in mock calls.""" - call_args = mock_call.call_args - - if expected_method and expected_path: - if isinstance(call_args[0], tuple) and len(call_args[0]) == 2: - assert call_args[0] == (expected_method, expected_path) - else: - assert call_args[1]["method"] == expected_method - assert call_args[1]["path"] == expected_path - - if expected_params: - assert call_args[1]["params"] == expected_params - - if expected_json: - for key, value in expected_json.items(): - assert call_args[1]["json"][key] == value - - def test_supervised_fine_tune(self): - """Test the supervised fine-tuning API call.""" - self.mock_make_request.return_value = { - "id": "cust-JGTaMbJMdqjJU8WbQdN9Q2", - "created_at": "2024-12-09T04:06:28.542884", - "updated_at": "2024-12-09T04:06:28.542884", - "config": { - "schema_version": "1.0", - "id": "af783f5b-d985-4e5b-bbb7-f9eec39cc0b1", - "created_at": "2024-12-09T04:06:28.542657", - "updated_at": "2024-12-09T04:06:28.569837", - "custom_fields": {}, - "name": "meta-llama/Llama-3.1-8B-Instruct", - "base_model": "meta-llama/Llama-3.1-8B-Instruct", - "model_path": "llama-3_1-8b-instruct", - "training_types": [], - "finetuning_types": ["lora"], - "precision": "bf16", - "num_gpus": 4, - "num_nodes": 1, - "micro_batch_size": 1, - "tensor_parallel_size": 1, - "max_seq_length": 4096, - }, - "dataset": { - "schema_version": "1.0", - "id": "dataset-XU4pvGzr5tvawnbVxeJMTb", - "created_at": "2024-12-09T04:06:28.542657", - "updated_at": "2024-12-09T04:06:28.542660", - "custom_fields": {}, - "name": "sample-basic-test", - "version_id": "main", - "version_tags": [], - }, + mock_make_request.assert_called_once() + _assert_request( + mock_make_request, + "POST", + "/v1/customization/jobs", + expected_json={ + "config": "meta/llama-3.2-1b-instruct@v1.0.0+L40", + "dataset": {"name": "sample-basic-test", "namespace": "default"}, "hyperparameters": { - "finetuning_type": "lora", "training_type": "sft", - "batch_size": 16, + "finetuning_type": "lora", "epochs": 2, + "batch_size": 16, "learning_rate": 0.0001, + "weight_decay": 0.01, "lora": {"alpha": 16}, }, - "output_model": "default/job-1234", - "status": "created", - "project": "default", - "custom_fields": {}, - "ownership": {"created_by": "me", "access_policies": {}}, + }, + ) + + +async def test_supervised_fine_tune_with_qat(nvidia_post_training_adapter): + """Test that QAT configuration raises NotImplementedError.""" + adapter, mock_make_request = nvidia_post_training_adapter + + algorithm_config = QATFinetuningConfig(type="QAT", quantizer_name="quantizer_name", group_size=1) + data_config = DataConfig( + dataset_id="sample-basic-test", batch_size=16, shuffle=False, data_format=DatasetFormat.instruct + ) + optimizer_config = OptimizerConfig( + optimizer_type=OptimizerType.adam, + lr=0.0001, + weight_decay=0.01, + num_warmup_steps=100, + ) + training_config = TrainingConfig( + n_epochs=2, + data_config=data_config, + optimizer_config=optimizer_config, + ) + + # This will raise NotImplementedError since QAT is not supported + with pytest.raises(NotImplementedError): + await adapter.supervised_fine_tune( + job_uuid="1234", + model="meta/llama-3.2-1b-instruct@v1.0.0+L40", + checkpoint_dir="", + algorithm_config=algorithm_config, + training_config=convert_pydantic_to_json_value(training_config), + logger_config={}, + hyperparam_search_config={}, + ) + + +async def test_get_training_job_status(nvidia_post_training_adapter): + """Test getting training job status with different statuses.""" + adapter, mock_make_request = nvidia_post_training_adapter + + customizer_status_to_job_status = [ + ("running", "in_progress"), + ("completed", "completed"), + ("failed", "failed"), + ("cancelled", "cancelled"), + ("pending", "scheduled"), + ("unknown", "scheduled"), + ] + + for customizer_status, expected_status in customizer_status_to_job_status: + mock_make_request.return_value = { + "created_at": "2024-12-09T04:06:28.580220", + "updated_at": "2024-12-09T04:21:19.852832", + "status": customizer_status, + "steps_completed": 1210, + "epochs_completed": 2, + "percentage_done": 100.0, + "best_epoch": 2, + "train_loss": 1.718016266822815, + "val_loss": 1.8661999702453613, } - algorithm_config = LoraFinetuningConfig( - type="LoRA", - apply_lora_to_mlp=True, - apply_lora_to_output=True, - alpha=16, - rank=16, - lora_attn_modules=["q_proj", "k_proj", "v_proj", "o_proj"], - ) - - data_config = DataConfig( - dataset_id="sample-basic-test", batch_size=16, shuffle=False, data_format=DatasetFormat.instruct - ) - - optimizer_config = OptimizerConfig( - optimizer_type=OptimizerType.adam, - lr=0.0001, - weight_decay=0.01, - num_warmup_steps=100, - ) - - training_config = TrainingConfig( - n_epochs=2, - data_config=data_config, - optimizer_config=optimizer_config, - ) - - with warnings.catch_warnings(record=True): - warnings.simplefilter("always") - training_job = self.run_async( - self.adapter.supervised_fine_tune( - job_uuid="1234", - model="meta/llama-3.2-1b-instruct@v1.0.0+L40", - checkpoint_dir="", - algorithm_config=algorithm_config, - training_config=convert_pydantic_to_json_value(training_config), - logger_config={}, - hyperparam_search_config={}, - ) - ) - - # check the output is a PostTrainingJob - assert isinstance(training_job, NvidiaPostTrainingJob) - assert training_job.job_uuid == "cust-JGTaMbJMdqjJU8WbQdN9Q2" - - self.mock_make_request.assert_called_once() - self._assert_request( - self.mock_make_request, - "POST", - "/v1/customization/jobs", - expected_json={ - "config": "meta/llama-3.2-1b-instruct@v1.0.0+L40", - "dataset": {"name": "sample-basic-test", "namespace": "default"}, - "hyperparameters": { - "training_type": "sft", - "finetuning_type": "lora", - "epochs": 2, - "batch_size": 16, - "learning_rate": 0.0001, - "weight_decay": 0.01, - "lora": {"alpha": 16}, - }, - }, - ) - - def test_supervised_fine_tune_with_qat(self): - algorithm_config = QATFinetuningConfig(type="QAT", quantizer_name="quantizer_name", group_size=1) - data_config = DataConfig( - dataset_id="sample-basic-test", batch_size=16, shuffle=False, data_format=DatasetFormat.instruct - ) - optimizer_config = OptimizerConfig( - optimizer_type=OptimizerType.adam, - lr=0.0001, - weight_decay=0.01, - num_warmup_steps=100, - ) - training_config = TrainingConfig( - n_epochs=2, - data_config=data_config, - optimizer_config=optimizer_config, - ) - # This will raise NotImplementedError since QAT is not supported - with self.assertRaises(NotImplementedError): - self.run_async( - self.adapter.supervised_fine_tune( - job_uuid="1234", - model="meta/llama-3.2-1b-instruct@v1.0.0+L40", - checkpoint_dir="", - algorithm_config=algorithm_config, - training_config=convert_pydantic_to_json_value(training_config), - logger_config={}, - hyperparam_search_config={}, - ) - ) - - def test_get_training_job_status(self): - customizer_status_to_job_status = [ - ("running", "in_progress"), - ("completed", "completed"), - ("failed", "failed"), - ("cancelled", "cancelled"), - ("pending", "scheduled"), - ("unknown", "scheduled"), - ] - - for customizer_status, expected_status in customizer_status_to_job_status: - with self.subTest(customizer_status=customizer_status, expected_status=expected_status): - self.mock_make_request.return_value = { - "created_at": "2024-12-09T04:06:28.580220", - "updated_at": "2024-12-09T04:21:19.852832", - "status": customizer_status, - "steps_completed": 1210, - "epochs_completed": 2, - "percentage_done": 100.0, - "best_epoch": 2, - "train_loss": 1.718016266822815, - "val_loss": 1.8661999702453613, - } - - job_id = "cust-JGTaMbJMdqjJU8WbQdN9Q2" - - status = self.run_async(self.adapter.get_training_job_status(job_uuid=job_id)) - - assert isinstance(status, NvidiaPostTrainingJobStatusResponse) - assert status.status.value == expected_status - assert status.steps_completed == 1210 - assert status.epochs_completed == 2 - assert status.percentage_done == 100.0 - assert status.best_epoch == 2 - assert status.train_loss == 1.718016266822815 - assert status.val_loss == 1.8661999702453613 - - self._assert_request( - self.mock_make_request, - "GET", - f"/v1/customization/jobs/{job_id}/status", - expected_params={"job_id": job_id}, - ) - - def test_get_training_jobs(self): job_id = "cust-JGTaMbJMdqjJU8WbQdN9Q2" - self.mock_make_request.return_value = { - "data": [ - { - "id": job_id, - "created_at": "2024-12-09T04:06:28.542884", - "updated_at": "2024-12-09T04:21:19.852832", - "config": { - "name": "meta-llama/Llama-3.1-8B-Instruct", - "base_model": "meta-llama/Llama-3.1-8B-Instruct", - }, - "dataset": {"name": "default/sample-basic-test"}, - "hyperparameters": { - "finetuning_type": "lora", - "training_type": "sft", - "batch_size": 16, - "epochs": 2, - "learning_rate": 0.0001, - "lora": {"adapter_dim": 16, "adapter_dropout": 0.1}, - }, - "output_model": "default/job-1234", - "status": "completed", - "project": "default", - } - ] - } - jobs = self.run_async(self.adapter.get_training_jobs()) + status = await adapter.get_training_job_status(job_uuid=job_id) - assert isinstance(jobs, ListNvidiaPostTrainingJobs) - assert len(jobs.data) == 1 - job = jobs.data[0] - assert job.job_uuid == job_id - assert job.status.value == "completed" + assert isinstance(status, NvidiaPostTrainingJobStatusResponse) + assert status.status.value == expected_status + # Note: The response object inherits extra fields via ConfigDict(extra="allow") + # So these attributes should be accessible using getattr with defaults + assert getattr(status, "steps_completed", None) == 1210 + assert getattr(status, "epochs_completed", None) == 2 + assert getattr(status, "percentage_done", None) == 100.0 + assert getattr(status, "best_epoch", None) == 2 + assert getattr(status, "train_loss", None) == 1.718016266822815 + assert getattr(status, "val_loss", None) == 1.8661999702453613 - self.mock_make_request.assert_called_once() - self._assert_request( - self.mock_make_request, + _assert_request( + mock_make_request, "GET", - "/v1/customization/jobs", - expected_params={"page": 1, "page_size": 10, "sort": "created_at"}, - ) - - def test_cancel_training_job(self): - self.mock_make_request.return_value = {} # Empty response for successful cancellation - job_id = "cust-JGTaMbJMdqjJU8WbQdN9Q2" - - result = self.run_async(self.adapter.cancel_training_job(job_uuid=job_id)) - - assert result is None - - self.mock_make_request.assert_called_once() - self._assert_request( - self.mock_make_request, - "POST", - f"/v1/customization/jobs/{job_id}/cancel", + f"/v1/customization/jobs/{job_id}/status", expected_params={"job_id": job_id}, ) - def test_inference_register_model(self): - model_id = "default/job-1234" - model_type = ModelType.llm - model = Model( - identifier=model_id, - provider_id="nvidia", - provider_model_id=model_id, - provider_resource_id=model_id, - model_type=model_type, - ) - - # simulate a NIM where default/job-1234 is an available model - with patch.object(self.inference_adapter, "check_model_availability", new_callable=AsyncMock) as mock_check: - mock_check.return_value = True - result = self.run_async(self.inference_adapter.register_model(model)) - assert result == model - assert len(self.inference_adapter.alias_to_provider_id_map) > 1 - assert self.inference_adapter.get_provider_model_id(model.provider_model_id) == model_id - - with patch.object(self.inference_adapter, "chat_completion") as mock_chat_completion: - self.run_async( - self.inference_adapter.chat_completion( - model_id=model_id, - messages=[{"role": "user", "content": "Hello, model"}], - ) - ) - - mock_chat_completion.assert_called() + mock_make_request.reset_mock() -if __name__ == "__main__": - unittest.main() +async def test_get_training_jobs(nvidia_post_training_adapter): + """Test getting list of training jobs.""" + adapter, mock_make_request = nvidia_post_training_adapter + + job_id = "cust-JGTaMbJMdqjJU8WbQdN9Q2" + mock_make_request.return_value = { + "data": [ + { + "id": job_id, + "created_at": "2024-12-09T04:06:28.542884", + "updated_at": "2024-12-09T04:21:19.852832", + "config": { + "name": "meta-llama/Llama-3.1-8B-Instruct", + "base_model": "meta-llama/Llama-3.1-8B-Instruct", + }, + "dataset": {"name": "default/sample-basic-test"}, + "hyperparameters": { + "finetuning_type": "lora", + "training_type": "sft", + "batch_size": 16, + "epochs": 2, + "learning_rate": 0.0001, + "lora": {"adapter_dim": 16, "adapter_dropout": 0.1}, + }, + "output_model": "default/job-1234", + "status": "completed", + "project": "default", + } + ] + } + + jobs = await adapter.get_training_jobs() + + assert isinstance(jobs, ListNvidiaPostTrainingJobs) + assert len(jobs.data) == 1 + job = jobs.data[0] + assert job.job_uuid == job_id + assert job.status.value == "completed" + + mock_make_request.assert_called_once() + _assert_request( + mock_make_request, + "GET", + "/v1/customization/jobs", + expected_params={"page": 1, "page_size": 10, "sort": "created_at"}, + ) + + +async def test_cancel_training_job(nvidia_post_training_adapter): + """Test canceling a training job.""" + adapter, mock_make_request = nvidia_post_training_adapter + + mock_make_request.return_value = {} # Empty response for successful cancellation + job_id = "cust-JGTaMbJMdqjJU8WbQdN9Q2" + + result = await adapter.cancel_training_job(job_uuid=job_id) + + assert result is None + + mock_make_request.assert_called_once() + _assert_request( + mock_make_request, + "POST", + f"/v1/customization/jobs/{job_id}/cancel", + expected_params={"job_id": job_id}, + ) From 0eb0583cdfd87c491592345f6b34a7e507e5eb9a Mon Sep 17 00:00:00 2001 From: Charlie Doern Date: Fri, 18 Jul 2025 09:23:36 -0400 Subject: [PATCH 26/40] fix: amend integration test workflow (#2812) # What does this PR do? trigger integration tests on ALL changes to `tests/` to catch failures before they merge into main Signed-off-by: Charlie Doern --- .github/workflows/integration-tests.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/integration-tests.yml b/.github/workflows/integration-tests.yml index a5883daf7..0b6c1be3b 100644 --- a/.github/workflows/integration-tests.yml +++ b/.github/workflows/integration-tests.yml @@ -7,7 +7,7 @@ on: branches: [ main ] paths: - 'llama_stack/**' - - 'tests/integration/**' + - 'tests/**' - 'uv.lock' - 'pyproject.toml' - 'requirements.txt' From 1785a6b39c15b27f8ea34af9c379b4219a2dfc12 Mon Sep 17 00:00:00 2001 From: Nathan Weinberg <31703736+nathan-weinberg@users.noreply.github.com> Date: Fri, 18 Jul 2025 12:07:43 -0400 Subject: [PATCH 27/40] docs: add virtualenv instructions for running starter distro (#2780) # What does this PR do? we had directions for a container and conda but not venv Signed-off-by: Nathan Weinberg --- .../distributions/self_hosted_distro/starter.md | 13 ++++--------- 1 file changed, 4 insertions(+), 9 deletions(-) diff --git a/docs/source/distributions/self_hosted_distro/starter.md b/docs/source/distributions/self_hosted_distro/starter.md index 753746d84..56cdd5e73 100644 --- a/docs/source/distributions/self_hosted_distro/starter.md +++ b/docs/source/distributions/self_hosted_distro/starter.md @@ -167,7 +167,7 @@ When using the `:` pattern (like `${env.OLLAMA_INFERENCE_MODEL:__disabled__}`), ## Running the Distribution -You can run the starter distribution via Docker or Conda. +You can run the starter distribution via Docker, Conda, or venv. ### Via Docker @@ -186,17 +186,12 @@ docker run \ --port $LLAMA_STACK_PORT ``` -### Via Conda +### Via Conda or venv -Make sure you have done `uv pip install llama-stack` and have the Llama Stack CLI available. +Ensure you have configured the starter distribution using the environment variables explained above. ```bash -llama stack build --template starter --image-type conda -llama stack run distributions/starter/run.yaml \ - --port 8321 \ - --env OPENAI_API_KEY=your_openai_key \ - --env FIREWORKS_API_KEY=your_fireworks_key \ - --env TOGETHER_API_KEY=your_together_key +uv run --with llama-stack llama stack build --template starter --image-type --run ``` ## Example Usage From e45543f7f385dd882878c511f49eb5e0337d9d1f Mon Sep 17 00:00:00 2001 From: Christian Zaccaria <73656840+ChristianZaccaria@users.noreply.github.com> Date: Fri, 18 Jul 2025 18:08:36 +0200 Subject: [PATCH 28/40] test: Measure and track code coverage (#2636) # What does this PR do? - Added coverage badge to README. - [See my fork](https://github.com/ChristianZaccaria/llama-stack) - Added a GitHub Actions workflow that runs the tests and updates the coverage badge. - [See run](https://github.com/ChristianZaccaria/llama-stack/actions/runs/16203511391/job/45748113233) - Documented steps in `testing.md` for running the tests locally, and viewing the `html` report. - Excluded non-essential files from coverage reporting to provide a more accurate measurement. Automatically created PR to update coverage badge: https://github.com/ChristianZaccaria/llama-stack/pull/9 # Note for reviewers 1. Currently the coverage report shows a 45% coverage. Wondering if there are other files or directories that should also be excluded from the report to increase the percentage. The directories with the least test coverage are `llama_stack/cli`, `llama_stack/models`, and `llama_stack/ui`. - Should we exclude these? 2. **[Required]** The `GITHUB_TOKEN` should have write permissions to open a PR to update the coverage badge. # GitHub Issue Closes #2355 ## Test Plan The `testing.md` file describes how to run the unit tests locally. --- .coveragerc | 6 +++ .github/workflows/coverage-badge.yml | 57 ++++++++++++++++++++++++++++ .github/workflows/unit-tests.yml | 2 +- README.md | 1 + coverage.svg | 21 ++++++++++ pyproject.toml | 1 + scripts/unit-tests.sh | 7 +++- tests/unit/README.md | 28 +++++++++++++- uv.lock | 2 + 9 files changed, 122 insertions(+), 3 deletions(-) create mode 100644 .github/workflows/coverage-badge.yml create mode 100644 coverage.svg diff --git a/.coveragerc b/.coveragerc index e16c2e461..d4925275f 100644 --- a/.coveragerc +++ b/.coveragerc @@ -4,3 +4,9 @@ omit = */llama_stack/providers/* */llama_stack/templates/* .venv/* + */llama_stack/cli/scripts/* + */llama_stack/ui/* + */llama_stack/distribution/ui/* + */llama_stack/strong_typing/* + */llama_stack/env.py + */__init__.py diff --git a/.github/workflows/coverage-badge.yml b/.github/workflows/coverage-badge.yml new file mode 100644 index 000000000..6b2f133dd --- /dev/null +++ b/.github/workflows/coverage-badge.yml @@ -0,0 +1,57 @@ +name: Coverage Badge + +on: + push: + branches: [ main ] + paths: + - 'llama_stack/**' + - 'tests/unit/**' + - 'uv.lock' + - 'pyproject.toml' + - 'requirements.txt' + - '.github/workflows/unit-tests.yml' + - '.github/workflows/coverage-badge.yml' # This workflow + workflow_dispatch: + +jobs: + unit-tests: + runs-on: ubuntu-latest + steps: + - name: Checkout repository + uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 + + - name: Install dependencies + uses: ./.github/actions/setup-runner + + - name: Run unit tests + run: | + ./scripts/unit-tests.sh + + - name: Coverage Badge + uses: tj-actions/coverage-badge-py@1788babcb24544eb5bbb6e0d374df5d1e54e670f # v2.0.4 + + - name: Verify Changed files + uses: tj-actions/verify-changed-files@a1c6acee9df209257a246f2cc6ae8cb6581c1edf # v20.0.4 + id: verify-changed-files + with: + files: coverage.svg + + - name: Commit files + if: steps.verify-changed-files.outputs.files_changed == 'true' + run: | + git config --local user.email "github-actions[bot]@users.noreply.github.com" + git config --local user.name "github-actions[bot]" + git add coverage.svg + git commit -m "Updated coverage.svg" + + - name: Create Pull Request + if: steps.verify-changed-files.outputs.files_changed == 'true' + uses: peter-evans/create-pull-request@271a8d0340265f705b14b6d32b9829c1cb33d45e # v7.0.8 + with: + token: ${{ secrets.GITHUB_TOKEN }} + title: "ci: [Automatic] Coverage Badge Update" + body: | + This PR updates the coverage badge based on the latest coverage report. + + Automatically generated by the [workflow coverage-badge.yaml](.github/workflows/coverage-badge.yaml) + delete-branch: true diff --git a/.github/workflows/unit-tests.yml b/.github/workflows/unit-tests.yml index e29045e52..41034b45f 100644 --- a/.github/workflows/unit-tests.yml +++ b/.github/workflows/unit-tests.yml @@ -36,7 +36,7 @@ jobs: - name: Run unit tests run: | - PYTHON_VERSION=${{ matrix.python }} ./scripts/unit-tests.sh --cov=llama_stack --junitxml=pytest-report-${{ matrix.python }}.xml --cov-report=html:htmlcov-${{ matrix.python }} + PYTHON_VERSION=${{ matrix.python }} ./scripts/unit-tests.sh --junitxml=pytest-report-${{ matrix.python }}.xml - name: Upload test results if: always() diff --git a/README.md b/README.md index 9148ce05d..7f0fed345 100644 --- a/README.md +++ b/README.md @@ -6,6 +6,7 @@ [![Discord](https://img.shields.io/discord/1257833999603335178?color=6A7EC2&logo=discord&logoColor=ffffff)](https://discord.gg/llama-stack) [![Unit Tests](https://github.com/meta-llama/llama-stack/actions/workflows/unit-tests.yml/badge.svg?branch=main)](https://github.com/meta-llama/llama-stack/actions/workflows/unit-tests.yml?query=branch%3Amain) [![Integration Tests](https://github.com/meta-llama/llama-stack/actions/workflows/integration-tests.yml/badge.svg?branch=main)](https://github.com/meta-llama/llama-stack/actions/workflows/integration-tests.yml?query=branch%3Amain) +![coverage badge](./coverage.svg) [**Quick Start**](https://llama-stack.readthedocs.io/en/latest/getting_started/index.html) | [**Documentation**](https://llama-stack.readthedocs.io/en/latest/index.html) | [**Colab Notebook**](./docs/getting_started.ipynb) | [**Discord**](https://discord.gg/llama-stack) diff --git a/coverage.svg b/coverage.svg new file mode 100644 index 000000000..636889bb0 --- /dev/null +++ b/coverage.svg @@ -0,0 +1,21 @@ + + + + + + + + + + + + + + + + coverage + coverage + 44% + 44% + + diff --git a/pyproject.toml b/pyproject.toml index 4d54bece0..15e2e10b4 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -91,6 +91,7 @@ unit = [ "pymilvus>=2.5.12", "litellm", "together", + "coverage", ] # These are the core dependencies required for running integration tests. They are shared across all # providers. If a provider requires additional dependencies, please add them to your environment diff --git a/scripts/unit-tests.sh b/scripts/unit-tests.sh index 68d6458fc..458cd383d 100755 --- a/scripts/unit-tests.sh +++ b/scripts/unit-tests.sh @@ -16,4 +16,9 @@ if [ $FOUND_PYTHON -ne 0 ]; then uv python install "$PYTHON_VERSION" fi -uv run --python "$PYTHON_VERSION" --with-editable . --group unit pytest -s -v tests/unit/ $@ +# Run unit tests with coverage +uv run --python "$PYTHON_VERSION" --with-editable . --group unit \ + coverage run --source=llama_stack -m pytest -s -v tests/unit/ "$@" + +# Generate HTML coverage report +uv run --python "$PYTHON_VERSION" coverage html -d htmlcov-$PYTHON_VERSION diff --git a/tests/unit/README.md b/tests/unit/README.md index c95c3a0e7..06e22fb8c 100644 --- a/tests/unit/README.md +++ b/tests/unit/README.md @@ -1,9 +1,17 @@ # Llama Stack Unit Tests +## Unit Tests + +Unit tests verify individual components and functions in isolation. They are fast, reliable, and don't require external services. + +### Prerequisites + +1. **Python Environment**: Ensure you have Python 3.12+ installed +2. **uv Package Manager**: Install `uv` if not already installed + You can run the unit tests by running: ```bash -source .venv/bin/activate ./scripts/unit-tests.sh [PYTEST_ARGS] ``` @@ -19,3 +27,21 @@ If you'd like to run for a non-default version of Python (currently 3.12), pass source .venv/bin/activate PYTHON_VERSION=3.13 ./scripts/unit-tests.sh ``` + +### Test Configuration + +- **Test Discovery**: Tests are automatically discovered in the `tests/unit/` directory +- **Async Support**: Tests use `--asyncio-mode=auto` for automatic async test handling +- **Coverage**: Tests generate coverage reports in `htmlcov/` directory +- **Python Version**: Defaults to Python 3.12, but can be overridden with `PYTHON_VERSION` environment variable + +### Coverage Reports + +After running tests, you can view coverage reports: + +```bash +# Open HTML coverage report in browser +open htmlcov/index.html # macOS +xdg-open htmlcov/index.html # Linux +start htmlcov/index.html # Windows +``` diff --git a/uv.lock b/uv.lock index 7a9c5cab0..2c5197988 100644 --- a/uv.lock +++ b/uv.lock @@ -1390,6 +1390,7 @@ unit = [ { name = "aiosqlite" }, { name = "blobfile" }, { name = "chardet" }, + { name = "coverage" }, { name = "faiss-cpu" }, { name = "litellm" }, { name = "mcp" }, @@ -1499,6 +1500,7 @@ unit = [ { name = "aiosqlite" }, { name = "blobfile" }, { name = "chardet" }, + { name = "coverage" }, { name = "faiss-cpu" }, { name = "litellm" }, { name = "mcp" }, From 2bb903917305900c9238fc7672d0ebd16f58ac3e Mon Sep 17 00:00:00 2001 From: Nathan Weinberg <31703736+nathan-weinberg@users.noreply.github.com> Date: Fri, 18 Jul 2025 12:08:46 -0400 Subject: [PATCH 29/40] docs: fix steps in the Quick Start Guide (#2800) # What does this PR do? 'build' command didn't take into account ENABLE flags for starter distro for some reason, I was having issues with HuggingFace access for the embedding model, so added a tip for that as well Closes #2779 ## Test Plan I ran the described steps manually, but it would be nice if someone else could try it and verify this still works We might consider having some CI job ensure the QSG remains functional - it's not a great experience for new users if they try Llama Stack for the first time and it doesn't work as we describe Signed-off-by: Nathan Weinberg --- docs/source/getting_started/quickstart.md | 8 +++++++- 1 file changed, 7 insertions(+), 1 deletion(-) diff --git a/docs/source/getting_started/quickstart.md b/docs/source/getting_started/quickstart.md index 881ddd29b..59791643d 100644 --- a/docs/source/getting_started/quickstart.md +++ b/docs/source/getting_started/quickstart.md @@ -19,7 +19,7 @@ ollama run llama3.2:3b --keepalive 60m #### Step 2: Run the Llama Stack server We will use `uv` to run the Llama Stack server. ```bash -INFERENCE_MODEL=llama3.2:3b uv run --with llama-stack llama stack build --template starter --image-type venv --run +ENABLE_OLLAMA=ollama OLLAMA_INFERENCE_MODEL=llama3.2:3b uv run --with llama-stack llama stack build --template starter --image-type venv --run ``` #### Step 3: Run the demo Now open up a new terminal and copy the following script into a file named `demo_script.py`. @@ -111,6 +111,12 @@ Ultimately, great work is about making a meaningful contribution and leaving a l ``` Congratulations! You've successfully built your first RAG application using Llama Stack! 🎉🥳 +```{admonition} HuggingFace access +:class: tip + +If you are getting a **401 Client Error** from HuggingFace for the **all-MiniLM-L6-v2** model, try setting **HF_TOKEN** to a valid HuggingFace token in your environment +``` + ### Next Steps Now you're ready to dive deeper into Llama Stack! From 9e3ae503060cc685749839630a379c6e63d0eae6 Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Fri, 18 Jul 2025 10:29:19 -0700 Subject: [PATCH 30/40] feat(server): construct the stack in a persistent event loop (#2818) When we call `construct_stack()`, providers are instantiated and `initialize()` is called. This call can end up doing _anything_ at all -- specifically, providers are free to create long running background tasks as part of this. If we wrapped this within a `asyncio.run()` as in the current code, these tasks get canceled when the stack construction finishes. This is not correct. The PR addresses the issue by creating a persistent event loop which is used for both the stack as well as for running the uvicorn server. In other words, the lifetime of the providers (and downstream async code) is now the same as the lifetime of the uvicorn server. ## Test Plan This should not affect any current code since we don't have background tasks created right now. However, https://github.com/meta-llama/llama-stack/pull/2805 will start using this functionality. --- llama_stack/distribution/server/server.py | 12 ++++++++++-- 1 file changed, 10 insertions(+), 2 deletions(-) diff --git a/llama_stack/distribution/server/server.py b/llama_stack/distribution/server/server.py index 974064b58..fb00b8384 100644 --- a/llama_stack/distribution/server/server.py +++ b/llama_stack/distribution/server/server.py @@ -455,6 +455,7 @@ def main(args: argparse.Namespace | None = None): redoc_url="/redoc", openapi_url="/openapi.json", ) + if not os.environ.get("LLAMA_STACK_DISABLE_VERSION_CHECK"): app.add_middleware(ClientVersionMiddleware) @@ -493,7 +494,13 @@ def main(args: argparse.Namespace | None = None): ) try: - impls = asyncio.run(construct_stack(config)) + # Create and set the event loop that will be used for both construction and server runtime + loop = asyncio.new_event_loop() + asyncio.set_event_loop(loop) + + # Construct the stack in the persistent event loop + impls = loop.run_until_complete(construct_stack(config)) + except InvalidProviderError as e: logger.error(f"Error: {str(e)}") sys.exit(1) @@ -591,7 +598,8 @@ def main(args: argparse.Namespace | None = None): if ssl_config: uvicorn_config.update(ssl_config) - uvicorn.run(**uvicorn_config) + # Run uvicorn in the existing event loop to preserve background tasks + loop.run_until_complete(uvicorn.Server(uvicorn.Config(**uvicorn_config)).serve()) def extract_path_params(route: str) -> list[str]: From 15916852e826c1c9e929164d95757498aec4ef60 Mon Sep 17 00:00:00 2001 From: slekkala1 Date: Fri, 18 Jul 2025 10:33:30 -0700 Subject: [PATCH 31/40] chore: Add slekkala1 to codeowners (#2817) Getting started on LLAMA Stack --- .github/CODEOWNERS | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/CODEOWNERS b/.github/CODEOWNERS index a1eed9432..85f781a4f 100644 --- a/.github/CODEOWNERS +++ b/.github/CODEOWNERS @@ -2,4 +2,4 @@ # These owners will be the default owners for everything in # the repo. Unless a later match takes precedence, -* @ashwinb @yanxi0830 @hardikjshah @raghotham @ehhuang @terrytangyuan @leseb @bbrowning @reluctantfuturist @mattf +* @ashwinb @yanxi0830 @hardikjshah @raghotham @ehhuang @terrytangyuan @leseb @bbrowning @reluctantfuturist @mattf @slekkala1 From d994305f0a30560c2d0cc9c83e9188cb7dee673c Mon Sep 17 00:00:00 2001 From: Charlie Doern Date: Fri, 18 Jul 2025 13:44:35 -0400 Subject: [PATCH 32/40] fix: remove disabled providers from model dump (#2784) # What does this PR do? currently when running `llama stack run --template starter...` the __disabled__ providers, their models, etc are printed alongside the enabled ones making the output really confusing in server.py add a utility `remove_disabled_providers` which post-processes the model_dump output to remove any dict with `provider_id: __disabled__` we also have `debug` logs printing the disabled providers, so I think its safe to say that is the only indicator we need when using starter. ## Test Plan before (output truncated because it was huge): ``` ... model_id: ${env.ENABLE_SAMBANOVA:=__disabled__}/sambanova/Llama-3.2-11B-Vision-Instruct model_type: llm provider_id: __disabled__ provider_model_id: sambanova/Llama-3.2-11B-Vision-Instruct - metadata: {} model_id: ${env.ENABLE_SAMBANOVA:=__disabled__}/meta-llama/Llama-3.2-11B-Vision-Instruct model_type: llm provider_id: __disabled__ provider_model_id: sambanova/Llama-3.2-11B-Vision-Instruct - metadata: {} model_id: ${env.ENABLE_SAMBANOVA:=__disabled__}/sambanova/Llama-3.2-90B-Vision-Instruct model_type: llm provider_id: __disabled__ provider_model_id: sambanova/Llama-3.2-90B-Vision-Instruct - metadata: {} model_id: ${env.ENABLE_SAMBANOVA:=__disabled__}/meta-llama/Llama-3.2-90B-Vision-Instruct model_type: llm provider_id: __disabled__ provider_model_id: sambanova/Llama-3.2-90B-Vision-Instruct - metadata: {} model_id: ${env.ENABLE_SAMBANOVA:=__disabled__}/sambanova/Llama-4-Scout-17B-16E-Instruct model_type: llm provider_id: __disabled__ provider_model_id: sambanova/Llama-4-Scout-17B-16E-Instruct - metadata: {} model_id: ${env.ENABLE_SAMBANOVA:=__disabled__}/meta-llama/Llama-4-Scout-17B-16E-Instruct model_type: llm provider_id: __disabled__ provider_model_id: sambanova/Llama-4-Scout-17B-16E-Instruct - metadata: {} model_id: ${env.ENABLE_SAMBANOVA:=__disabled__}/sambanova/Llama-4-Maverick-17B-128E-Instruct model_type: llm provider_id: __disabled__ provider_model_id: sambanova/Llama-4-Maverick-17B-128E-Instruct - metadata: {} model_id: ${env.ENABLE_SAMBANOVA:=__disabled__}/meta-llama/Llama-4-Maverick-17B-128E-Instruct model_type: llm provider_id: __disabled__ provider_model_id: sambanova/Llama-4-Maverick-17B-128E-Instruct - metadata: {} model_id: ${env.ENABLE_SAMBANOVA:=__disabled__}/sambanova/Meta-Llama-Guard-3-8B model_type: llm provider_id: __disabled__ provider_model_id: sambanova/Meta-Llama-Guard-3-8B - metadata: {} model_id: ${env.ENABLE_SAMBANOVA:=__disabled__}/meta-llama/Llama-Guard-3-8B model_type: llm provider_id: __disabled__ provider_model_id: sambanova/Meta-Llama-Guard-3-8B - metadata: embedding_dimension: 384 model_id: all-MiniLM-L6-v2 model_type: embedding provider_id: sentence-transformers provider_model_id: null providers: agents: - config: persistence_store: db_path: /Users/charliedoern/.llama/distributions/starter/agents_store.db type: sqlite responses_store: db_path: /Users/charliedoern/.llama/distributions/starter/responses_store.db type: sqlite provider_id: meta-reference provider_type: inline::meta-reference datasetio: - config: kvstore: db_path: /Users/charliedoern/.llama/distributions/starter/huggingface_datasetio.db type: sqlite provider_id: huggingface provider_type: remote::huggingface - config: kvstore: db_path: /Users/charliedoern/.llama/distributions/starter/localfs_datasetio.db type: sqlite provider_id: localfs provider_type: inline::localfs eval: - config: kvstore: db_path: /Users/charliedoern/.llama/distributions/starter/meta_reference_eval.db type: sqlite provider_id: meta-reference provider_type: inline::meta-reference files: - config: metadata_store: db_path: /Users/charliedoern/.llama/distributions/starter/files_metadata.db type: sqlite storage_dir: /Users/charliedoern/.llama/distributions/starter/files provider_id: meta-reference-files provider_type: inline::localfs inference: - config: api_key: '********' base_url: https://api.cerebras.ai provider_id: __disabled__ provider_type: remote::cerebras - config: url: http://localhost:11434 provider_id: ollama provider_type: remote::ollama - config: api_token: '********' max_tokens: ${env.VLLM_MAX_TOKENS:=4096} tls_verify: ${env.VLLM_TLS_VERIFY:=true} url: ${env.VLLM_URL} provider_id: __disabled__ provider_type: remote::vllm - config: url: ${env.TGI_URL} provider_id: __disabled__ provider_type: remote::tgi - config: api_token: '********' huggingface_repo: ${env.INFERENCE_MODEL} provider_id: __disabled__ provider_type: remote::hf::serverless - config: api_token: '********' endpoint_name: ${env.INFERENCE_ENDPOINT_NAME} provider_id: __disabled__ provider_type: remote::hf::endpoint - config: api_key: '********' url: https://api.fireworks.ai/inference/v1 provider_id: __disabled__ provider_type: remote::fireworks - config: api_key: '********' url: https://api.together.xyz/v1 provider_id: __disabled__ provider_type: remote::together - config: {} provider_id: __disabled__ provider_type: remote::bedrock - config: api_token: '********' url: ${env.DATABRICKS_URL} provider_id: __disabled__ provider_type: remote::databricks - config: api_key: '********' append_api_version: ${env.NVIDIA_APPEND_API_VERSION:=True} url: ${env.NVIDIA_BASE_URL:=https://integrate.api.nvidia.com} provider_id: __disabled__ provider_type: remote::nvidia - config: api_token: '********' url: ${env.RUNPOD_URL:=} provider_id: __disabled__ provider_type: remote::runpod - config: api_key: '********' provider_id: __disabled__ provider_type: remote::openai - config: api_key: '********' provider_id: __disabled__ provider_type: remote::anthropic - config: api_key: '********' provider_id: __disabled__ provider_type: remote::gemini - config: api_key: '********' url: https://api.groq.com provider_id: __disabled__ provider_type: remote::groq - config: api_key: '********' openai_compat_api_base: https://api.fireworks.ai/inference/v1 provider_id: __disabled__ provider_type: remote::fireworks-openai-compat - config: api_key: '********' openai_compat_api_base: https://api.llama.com/compat/v1/ provider_id: __disabled__ provider_type: remote::llama-openai-compat - config: api_key: '********' openai_compat_api_base: https://api.together.xyz/v1 provider_id: __disabled__ provider_type: remote::together-openai-compat - config: api_key: '********' openai_compat_api_base: https://api.groq.com/openai/v1 provider_id: __disabled__ provider_type: remote::groq-openai-compat - config: api_key: '********' openai_compat_api_base: https://api.sambanova.ai/v1 provider_id: __disabled__ provider_type: remote::sambanova-openai-compat - config: api_key: '********' openai_compat_api_base: https://api.cerebras.ai/v1 provider_id: __disabled__ provider_type: remote::cerebras-openai-compat - config: api_key: '********' url: https://api.sambanova.ai/v1 provider_id: __disabled__ provider_type: remote::sambanova - config: api_key: '********' url: ${env.PASSTHROUGH_URL} provider_id: __disabled__ provider_type: remote::passthrough - config: {} provider_id: sentence-transformers provider_type: inline::sentence-transformers post_training: - config: checkpoint_format: huggingface device: cpu distributed_backend: null provider_id: huggingface provider_type: inline::huggingface safety: - config: excluded_categories: [] provider_id: llama-guard provider_type: inline::llama-guard scoring: - config: {} provider_id: basic provider_type: inline::basic - config: {} provider_id: llm-as-judge provider_type: inline::llm-as-judge - config: openai_api_key: '********' provider_id: braintrust provider_type: inline::braintrust telemetry: - config: otel_exporter_otlp_endpoint: null service_name: "\u200B" sinks: console,sqlite sqlite_db_path: /Users/charliedoern/.llama/distributions/starter/trace_store.db provider_id: meta-reference provider_type: inline::meta-reference tool_runtime: - config: api_key: '********' max_results: 3 provider_id: brave-search provider_type: remote::brave-search - config: api_key: '********' max_results: 3 provider_id: tavily-search provider_type: remote::tavily-search - config: {} provider_id: rag-runtime provider_type: inline::rag-runtime - config: {} provider_id: model-context-protocol provider_type: remote::model-context-protocol vector_io: - config: kvstore: db_path: /Users/charliedoern/.llama/distributions/starter/faiss_store.db type: sqlite provider_id: faiss provider_type: inline::faiss - config: db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/sqlite_vec.db kvstore: db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/sqlite_vec_registry.db type: sqlite provider_id: __disabled__ provider_type: inline::sqlite-vec - config: db_path: ${env.MILVUS_DB_PATH:=~/.llama/distributions/starter}/milvus.db kvstore: db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/milvus_registry.db type: sqlite provider_id: __disabled__ provider_type: inline::milvus - config: url: ${env.CHROMADB_URL:=} provider_id: __disabled__ provider_type: remote::chromadb - config: db: ${env.PGVECTOR_DB:=} host: ${env.PGVECTOR_HOST:=localhost} kvstore: db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/pgvector_registry.db type: sqlite password: '********' port: ${env.PGVECTOR_PORT:=5432} user: ${env.PGVECTOR_USER:=} provider_id: __disabled__ provider_type: remote::pgvector scoring_fns: [] server: auth: null host: null port: 8321 quota: null tls_cafile: null tls_certfile: null tls_keyfile: null shields: - params: null provider_id: null provider_shield_id: ollama/__disabled__ shield_id: __disabled__ tool_groups: - args: null mcp_endpoint: null provider_id: tavily-search toolgroup_id: builtin::websearch - args: null mcp_endpoint: null provider_id: rag-runtime toolgroup_id: builtin::rag vector_dbs: [] version: 2 ``` after: ``` INFO 2025-07-16 13:00:32,604 __main__:448 server: Run configuration: INFO 2025-07-16 13:00:32,606 __main__:450 server: apis: - agents - datasetio - eval - files - inference - post_training - safety - scoring - telemetry - tool_runtime - vector_io benchmarks: [] datasets: [] image_name: starter inference_store: db_path: /Users/charliedoern/.llama/distributions/starter/inference_store.db type: sqlite metadata_store: db_path: /Users/charliedoern/.llama/distributions/starter/registry.db type: sqlite models: - metadata: {} model_id: ollama/llama3.2:3b model_type: llm provider_id: ollama provider_model_id: llama3.2:3b - metadata: embedding_dimension: 384 model_id: all-MiniLM-L6-v2 model_type: embedding provider_id: sentence-transformers providers: agents: - config: persistence_store: db_path: /Users/charliedoern/.llama/distributions/starter/agents_store.db type: sqlite responses_store: db_path: /Users/charliedoern/.llama/distributions/starter/responses_store.db type: sqlite provider_id: meta-reference provider_type: inline::meta-reference datasetio: - config: kvstore: db_path: /Users/charliedoern/.llama/distributions/starter/huggingface_datasetio.db type: sqlite provider_id: huggingface provider_type: remote::huggingface - config: kvstore: db_path: /Users/charliedoern/.llama/distributions/starter/localfs_datasetio.db type: sqlite provider_id: localfs provider_type: inline::localfs eval: - config: kvstore: db_path: /Users/charliedoern/.llama/distributions/starter/meta_reference_eval.db type: sqlite provider_id: meta-reference provider_type: inline::meta-reference files: - config: metadata_store: db_path: /Users/charliedoern/.llama/distributions/starter/files_metadata.db type: sqlite storage_dir: /Users/charliedoern/.llama/distributions/starter/files provider_id: meta-reference-files provider_type: inline::localfs inference: - config: url: http://localhost:11434 provider_id: ollama provider_type: remote::ollama - config: {} provider_id: sentence-transformers provider_type: inline::sentence-transformers post_training: - config: checkpoint_format: huggingface device: cpu provider_id: huggingface provider_type: inline::huggingface safety: - config: excluded_categories: [] provider_id: llama-guard provider_type: inline::llama-guard scoring: - config: {} provider_id: basic provider_type: inline::basic - config: {} provider_id: llm-as-judge provider_type: inline::llm-as-judge - config: openai_api_key: '********' provider_id: braintrust provider_type: inline::braintrust telemetry: - config: service_name: "\u200B" sinks: console,sqlite sqlite_db_path: /Users/charliedoern/.llama/distributions/starter/trace_store.db provider_id: meta-reference provider_type: inline::meta-reference tool_runtime: - config: api_key: '********' max_results: 3 provider_id: brave-search provider_type: remote::brave-search - config: api_key: '********' max_results: 3 provider_id: tavily-search provider_type: remote::tavily-search - config: {} provider_id: rag-runtime provider_type: inline::rag-runtime - config: {} provider_id: model-context-protocol provider_type: remote::model-context-protocol vector_io: - config: kvstore: db_path: /Users/charliedoern/.llama/distributions/starter/faiss_store.db type: sqlite provider_id: faiss provider_type: inline::faiss scoring_fns: [] server: port: 8321 shields: [] tool_groups: - provider_id: tavily-search toolgroup_id: builtin::websearch - provider_id: rag-runtime toolgroup_id: builtin::rag vector_dbs: [] version: 2 ``` Signed-off-by: Charlie Doern --- llama_stack/distribution/server/server.py | 27 ++++++++++++++++++++--- llama_stack/distribution/stack.py | 1 - 2 files changed, 24 insertions(+), 4 deletions(-) diff --git a/llama_stack/distribution/server/server.py b/llama_stack/distribution/server/server.py index fb00b8384..e7e9e5e88 100644 --- a/llama_stack/distribution/server/server.py +++ b/llama_stack/distribution/server/server.py @@ -445,9 +445,7 @@ def main(args: argparse.Namespace | None = None): # now that the logger is initialized, print the line about which type of config we are using. logger.info(log_line) - logger.info("Run configuration:") - safe_config = redact_sensitive_fields(config.model_dump(mode="json")) - logger.info(yaml.dump(safe_config, indent=2)) + _log_run_config(run_config=config) app = FastAPI( lifespan=lifespan, @@ -602,6 +600,14 @@ def main(args: argparse.Namespace | None = None): loop.run_until_complete(uvicorn.Server(uvicorn.Config(**uvicorn_config)).serve()) +def _log_run_config(run_config: StackRunConfig): + """Logs the run config with redacted fields and disabled providers removed.""" + logger.info("Run configuration:") + safe_config = redact_sensitive_fields(run_config.model_dump(mode="json")) + clean_config = remove_disabled_providers(safe_config) + logger.info(yaml.dump(clean_config, indent=2)) + + def extract_path_params(route: str) -> list[str]: segments = route.split("/") params = [seg[1:-1] for seg in segments if seg.startswith("{") and seg.endswith("}")] @@ -610,5 +616,20 @@ def extract_path_params(route: str) -> list[str]: return params +def remove_disabled_providers(obj): + if isinstance(obj, dict): + if ( + obj.get("provider_id") == "__disabled__" + or obj.get("shield_id") == "__disabled__" + or obj.get("provider_model_id") == "__disabled__" + ): + return None + return {k: v for k, v in ((k, remove_disabled_providers(v)) for k, v in obj.items()) if v is not None} + elif isinstance(obj, list): + return [item for item in (remove_disabled_providers(i) for i in obj) if item is not None] + else: + return obj + + if __name__ == "__main__": main() diff --git a/llama_stack/distribution/stack.py b/llama_stack/distribution/stack.py index 98634d8c9..d7270156a 100644 --- a/llama_stack/distribution/stack.py +++ b/llama_stack/distribution/stack.py @@ -172,7 +172,6 @@ def replace_env_vars(config: Any, path: str = "") -> Any: # Create a copy with resolved provider_id but original config disabled_provider = v.copy() disabled_provider["provider_id"] = resolved_provider_id - result.append(disabled_provider) continue except EnvVarError: # If we can't resolve the provider_id, continue with normal processing From 874b1cb00f5d3190b2e23a1edaca95b59ee98320 Mon Sep 17 00:00:00 2001 From: Nehanth Narendrula Date: Fri, 18 Jul 2025 14:56:00 -0400 Subject: [PATCH 33/40] fix: DPOAlignmentConfig schema to use correct DPO parameters (#2804) # What does this PR do? This PR fixes the `DPOAlignmentConfig` schema to use the correct Direct Preference Optimization (DPO) parameters. The current schema incorrectly uses PPO-inspired parameters (`reward_scale`, `reward_clip`, `epsilon`, `gamma`) that are not part of the DPO algorithm. This PR updates it to use the standard DPO parameters: - `beta`: The KL divergence coefficient that controls deviation from the reference model - `loss_type`: The type of DPO loss function (sigmoid, hinge, ipo, kto_pair) These parameters align with standard DPO implementations like HuggingFace's TRL library. --------- Co-authored-by: Ubuntu --- docs/_static/llama-stack-spec.html | 29 ++++++++++--------- docs/_static/llama-stack-spec.yaml | 25 +++++++++------- .../apis/post_training/post_training.py | 14 ++++++--- 3 files changed, 40 insertions(+), 28 deletions(-) diff --git a/docs/_static/llama-stack-spec.html b/docs/_static/llama-stack-spec.html index db5c57821..d7801ba1c 100644 --- a/docs/_static/llama-stack-spec.html +++ b/docs/_static/llama-stack-spec.html @@ -14470,28 +14470,31 @@ "DPOAlignmentConfig": { "type": "object", "properties": { - "reward_scale": { + "beta": { "type": "number" }, - "reward_clip": { - "type": "number" - }, - "epsilon": { - "type": "number" - }, - "gamma": { - "type": "number" + "loss_type": { + "$ref": "#/components/schemas/DPOLossType", + "default": "sigmoid" } }, "additionalProperties": false, "required": [ - "reward_scale", - "reward_clip", - "epsilon", - "gamma" + "beta", + "loss_type" ], "title": "DPOAlignmentConfig" }, + "DPOLossType": { + "type": "string", + "enum": [ + "sigmoid", + "hinge", + "ipo", + "kto_pair" + ], + "title": "DPOLossType" + }, "DataConfig": { "type": "object", "properties": { diff --git a/docs/_static/llama-stack-spec.yaml b/docs/_static/llama-stack-spec.yaml index 29ba9dede..be02e1e42 100644 --- a/docs/_static/llama-stack-spec.yaml +++ b/docs/_static/llama-stack-spec.yaml @@ -10111,21 +10111,24 @@ components: DPOAlignmentConfig: type: object properties: - reward_scale: - type: number - reward_clip: - type: number - epsilon: - type: number - gamma: + beta: type: number + loss_type: + $ref: '#/components/schemas/DPOLossType' + default: sigmoid additionalProperties: false required: - - reward_scale - - reward_clip - - epsilon - - gamma + - beta + - loss_type title: DPOAlignmentConfig + DPOLossType: + type: string + enum: + - sigmoid + - hinge + - ipo + - kto_pair + title: DPOLossType DataConfig: type: object properties: diff --git a/llama_stack/apis/post_training/post_training.py b/llama_stack/apis/post_training/post_training.py index b196c8a17..f6860ea4b 100644 --- a/llama_stack/apis/post_training/post_training.py +++ b/llama_stack/apis/post_training/post_training.py @@ -104,12 +104,18 @@ class RLHFAlgorithm(Enum): dpo = "dpo" +@json_schema_type +class DPOLossType(Enum): + sigmoid = "sigmoid" + hinge = "hinge" + ipo = "ipo" + kto_pair = "kto_pair" + + @json_schema_type class DPOAlignmentConfig(BaseModel): - reward_scale: float - reward_clip: float - epsilon: float - gamma: float + beta: float + loss_type: DPOLossType = DPOLossType.sigmoid @json_schema_type From 6d55f2f137c30d48330a994fcfb90ad23de386ed Mon Sep 17 00:00:00 2001 From: ehhuang Date: Fri, 18 Jul 2025 12:10:30 -0700 Subject: [PATCH 34/40] feat: enable ls client for files tests (#2769) # What does this PR do? titled ## Test Plan CI --- llama_stack/distribution/library_client.py | 71 ++++++++++++++++++++-- tests/integration/files/test_files.py | 16 +++-- tests/integration/fixtures/common.py | 5 ++ 3 files changed, 83 insertions(+), 9 deletions(-) diff --git a/llama_stack/distribution/library_client.py b/llama_stack/distribution/library_client.py index cebfabba5..6c51dc2c7 100644 --- a/llama_stack/distribution/library_client.py +++ b/llama_stack/distribution/library_client.py @@ -12,11 +12,13 @@ import os import sys from concurrent.futures import ThreadPoolExecutor from enum import Enum +from io import BytesIO from pathlib import Path from typing import Any, TypeVar, Union, get_args, get_origin import httpx import yaml +from fastapi import Response as FastAPIResponse from llama_stack_client import ( NOT_GIVEN, APIResponse, @@ -112,6 +114,27 @@ def convert_to_pydantic(annotation: Any, value: Any) -> Any: raise ValueError(f"Failed to convert parameter {value} into {annotation}: {e}") from e +class LibraryClientUploadFile: + """LibraryClient UploadFile object that mimics FastAPI's UploadFile interface.""" + + def __init__(self, filename: str, content: bytes): + self.filename = filename + self.content = content + self.content_type = "application/octet-stream" + + async def read(self) -> bytes: + return self.content + + +class LibraryClientHttpxResponse: + """LibraryClient httpx Response object for FastAPI Response conversion.""" + + def __init__(self, response): + self.content = response.body if isinstance(response.body, bytes) else response.body.encode() + self.status_code = response.status_code + self.headers = response.headers + + class LlamaStackAsLibraryClient(LlamaStackClient): def __init__( self, @@ -295,6 +318,31 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient): ) return response + def _handle_file_uploads(self, options: Any, body: dict) -> tuple[dict, list[str]]: + """Handle file uploads from OpenAI client and add them to the request body.""" + if not (hasattr(options, "files") and options.files): + return body, [] + + if not isinstance(options.files, list): + return body, [] + + field_names = [] + for file_tuple in options.files: + if not (isinstance(file_tuple, tuple) and len(file_tuple) >= 2): + continue + + field_name = file_tuple[0] + file_object = file_tuple[1] + + if isinstance(file_object, BytesIO): + file_object.seek(0) + file_content = file_object.read() + filename = getattr(file_object, "name", "uploaded_file") + field_names.append(field_name) + body[field_name] = LibraryClientUploadFile(filename, file_content) + + return body, field_names + async def _call_non_streaming( self, *, @@ -310,15 +358,23 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient): matched_func, path_params, route = find_matching_route(options.method, path, self.route_impls) body |= path_params - body = self._convert_body(path, options.method, body) + + body, field_names = self._handle_file_uploads(options, body) + + body = self._convert_body(path, options.method, body, exclude_params=set(field_names)) await start_trace(route, {"__location__": "library_client"}) try: result = await matched_func(**body) finally: await end_trace() + # Handle FastAPI Response objects (e.g., from file content retrieval) + if isinstance(result, FastAPIResponse): + return LibraryClientHttpxResponse(result) + json_content = json.dumps(convert_pydantic_to_json_value(result)) + filtered_body = {k: v for k, v in body.items() if not isinstance(v, LibraryClientUploadFile)} mock_response = httpx.Response( status_code=httpx.codes.OK, content=json_content.encode("utf-8"), @@ -330,7 +386,7 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient): url=options.url, params=options.params, headers=options.headers or {}, - json=convert_pydantic_to_json_value(body), + json=convert_pydantic_to_json_value(filtered_body), ), ) response = APIResponse( @@ -404,13 +460,17 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient): ) return await response.parse() - def _convert_body(self, path: str, method: str, body: dict | None = None) -> dict: + def _convert_body( + self, path: str, method: str, body: dict | None = None, exclude_params: set[str] | None = None + ) -> dict: if not body: return {} if self.route_impls is None: raise ValueError("Client not initialized") + exclude_params = exclude_params or set() + func, _, _ = find_matching_route(method, path, self.route_impls) sig = inspect.signature(func) @@ -422,6 +482,9 @@ class AsyncLlamaStackAsLibraryClient(AsyncLlamaStackClient): for param_name, param in sig.parameters.items(): if param_name in body: value = body.get(param_name) - converted_body[param_name] = convert_to_pydantic(param.annotation, value) + if param_name in exclude_params: + converted_body[param_name] = value + else: + converted_body[param_name] = convert_to_pydantic(param.annotation, value) return converted_body diff --git a/tests/integration/files/test_files.py b/tests/integration/files/test_files.py index 8375507dc..8547ef2f3 100644 --- a/tests/integration/files/test_files.py +++ b/tests/integration/files/test_files.py @@ -7,15 +7,16 @@ from io import BytesIO import pytest +from openai import OpenAI from llama_stack.distribution.library_client import LlamaStackAsLibraryClient -def test_openai_client_basic_operations(openai_client, client_with_models): +def test_openai_client_basic_operations(compat_client, client_with_models): """Test basic file operations through OpenAI client.""" - if isinstance(client_with_models, LlamaStackAsLibraryClient): - pytest.skip("OpenAI files are not supported when testing with library client yet.") - client = openai_client + if isinstance(client_with_models, LlamaStackAsLibraryClient) and isinstance(compat_client, OpenAI): + pytest.skip("OpenAI files are not supported when testing with LlamaStackAsLibraryClient") + client = compat_client test_content = b"files test content" @@ -41,7 +42,12 @@ def test_openai_client_basic_operations(openai_client, client_with_models): # Retrieve file content - OpenAI client returns httpx Response object content_response = client.files.content(uploaded_file.id) # The response is an httpx Response object with .content attribute containing bytes - content = content_response.content + if isinstance(content_response, str): + # Llama Stack Client returns a str + # TODO: fix Llama Stack Client + content = bytes(content_response, "utf-8") + else: + content = content_response.content assert content == test_content # Delete file diff --git a/tests/integration/fixtures/common.py b/tests/integration/fixtures/common.py index 749793b64..f6b5b3026 100644 --- a/tests/integration/fixtures/common.py +++ b/tests/integration/fixtures/common.py @@ -257,6 +257,11 @@ def openai_client(client_with_models): return OpenAI(base_url=base_url, api_key="fake") +@pytest.fixture(params=["openai_client", "llama_stack_client"]) +def compat_client(request): + return request.getfixturevalue(request.param) + + @pytest.fixture(scope="session", autouse=True) def cleanup_server_process(request): """Cleanup server process at the end of the test session.""" From 68a2dfbad72f29df78830e98daa9a05588715473 Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Fri, 18 Jul 2025 12:20:36 -0700 Subject: [PATCH 35/40] feat(ollama): periodically refresh models (#2805) For self-hosted providers like Ollama (or vLLM), the backing server is running a set of models. That server should be treated as the source of truth and the Stack registry should just be a cache for those models. Of course, in production environments, you may not want this (because you know what model you are running statically) hence there's a config boolean to control this behavior. _This is part of a series of PRs aimed at removing the requirement of needing to set `INFERENCE_MODEL` env variables for running Llama Stack server._ ## Test Plan Copy and modify the starter.yaml template / config and enable `refresh_models: true, refresh_models_interval: 10` for the ollama provider. Then, run: ``` LLAMA_STACK_LOGGING=all=debug \ ENABLE_OLLAMA=ollama uv run llama stack run --image-type venv /tmp/starter.yaml ``` See a gargantuan amount of logs, but verify that the provider is periodically refreshing models. Stop and prune a model from ollama server, restart the server. Verify that the model goes away when I call `uv run llama-stack-client models list` --- .../providers/inference/remote_ollama.md | 2 + llama_stack/apis/inference/inference.py | 6 ++ llama_stack/distribution/library_client.py | 11 +-- .../distribution/routing_tables/models.py | 31 +++++++ .../remote/inference/ollama/config.py | 4 +- .../remote/inference/ollama/ollama.py | 85 +++++++++++++++++-- 6 files changed, 123 insertions(+), 16 deletions(-) diff --git a/docs/source/providers/inference/remote_ollama.md b/docs/source/providers/inference/remote_ollama.md index fcb44c072..23b8f87a2 100644 --- a/docs/source/providers/inference/remote_ollama.md +++ b/docs/source/providers/inference/remote_ollama.md @@ -9,6 +9,8 @@ Ollama inference provider for running local models through the Ollama runtime. | Field | Type | Required | Default | Description | |-------|------|----------|---------|-------------| | `url` | `` | No | http://localhost:11434 | | +| `refresh_models` | `` | No | False | refresh and re-register models periodically | +| `refresh_models_interval` | `` | No | 300 | interval in seconds to refresh models | ## Sample Configuration diff --git a/llama_stack/apis/inference/inference.py b/llama_stack/apis/inference/inference.py index 222099064..26de04b68 100644 --- a/llama_stack/apis/inference/inference.py +++ b/llama_stack/apis/inference/inference.py @@ -819,6 +819,12 @@ class OpenAIEmbeddingsResponse(BaseModel): class ModelStore(Protocol): async def get_model(self, identifier: str) -> Model: ... + async def update_registered_models( + self, + provider_id: str, + models: list[Model], + ) -> None: ... + class TextTruncation(Enum): """Config for how to truncate text for embedding when text is longer than the model's max sequence length. Start and End semantics depend on whether the language is left-to-right or right-to-left. diff --git a/llama_stack/distribution/library_client.py b/llama_stack/distribution/library_client.py index 6c51dc2c7..5dc0078d4 100644 --- a/llama_stack/distribution/library_client.py +++ b/llama_stack/distribution/library_client.py @@ -151,6 +151,8 @@ class LlamaStackAsLibraryClient(LlamaStackClient): self.skip_logger_removal = skip_logger_removal self.provider_data = provider_data + self.loop = asyncio.new_event_loop() + def initialize(self): if in_notebook(): import nest_asyncio @@ -159,7 +161,7 @@ class LlamaStackAsLibraryClient(LlamaStackClient): if not self.skip_logger_removal: self._remove_root_logger_handlers() - return asyncio.run(self.async_client.initialize()) + return self.loop.run_until_complete(self.async_client.initialize()) def _remove_root_logger_handlers(self): """ @@ -172,10 +174,7 @@ class LlamaStackAsLibraryClient(LlamaStackClient): logger.info(f"Removed handler {handler.__class__.__name__} from root logger") def request(self, *args, **kwargs): - # NOTE: We are using AsyncLlamaStackClient under the hood - # A new event loop is needed to convert the AsyncStream - # from async client into SyncStream return type for streaming - loop = asyncio.new_event_loop() + loop = self.loop asyncio.set_event_loop(loop) if kwargs.get("stream"): @@ -192,7 +191,6 @@ class LlamaStackAsLibraryClient(LlamaStackClient): pending = asyncio.all_tasks(loop) if pending: loop.run_until_complete(asyncio.gather(*pending, return_exceptions=True)) - loop.close() return sync_generator() else: @@ -202,7 +200,6 @@ class LlamaStackAsLibraryClient(LlamaStackClient): pending = asyncio.all_tasks(loop) if pending: loop.run_until_complete(asyncio.gather(*pending, return_exceptions=True)) - loop.close() return result diff --git a/llama_stack/distribution/routing_tables/models.py b/llama_stack/distribution/routing_tables/models.py index c6a10ea9b..90f8afa1c 100644 --- a/llama_stack/distribution/routing_tables/models.py +++ b/llama_stack/distribution/routing_tables/models.py @@ -80,3 +80,34 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models): if existing_model is None: raise ValueError(f"Model {model_id} not found") await self.unregister_object(existing_model) + + async def update_registered_models( + self, + provider_id: str, + models: list[Model], + ) -> None: + existing_models = await self.get_all_with_type("model") + + # we may have an alias for the model registered by the user (or during initialization + # from run.yaml) that we need to keep track of + model_ids = {} + for model in existing_models: + if model.provider_id == provider_id: + model_ids[model.provider_resource_id] = model.identifier + logger.debug(f"unregistering model {model.identifier}") + await self.unregister_object(model) + + for model in models: + if model.provider_resource_id in model_ids: + model.identifier = model_ids[model.provider_resource_id] + + logger.debug(f"registering model {model.identifier} ({model.provider_resource_id})") + await self.register_object( + ModelWithOwner( + identifier=model.identifier, + provider_resource_id=model.provider_resource_id, + provider_id=provider_id, + metadata=model.metadata, + model_type=model.model_type, + ) + ) diff --git a/llama_stack/providers/remote/inference/ollama/config.py b/llama_stack/providers/remote/inference/ollama/config.py index 0145810a8..ae261f47c 100644 --- a/llama_stack/providers/remote/inference/ollama/config.py +++ b/llama_stack/providers/remote/inference/ollama/config.py @@ -6,13 +6,15 @@ from typing import Any -from pydantic import BaseModel +from pydantic import BaseModel, Field DEFAULT_OLLAMA_URL = "http://localhost:11434" class OllamaImplConfig(BaseModel): url: str = DEFAULT_OLLAMA_URL + refresh_models: bool = Field(default=False, description="refresh and re-register models periodically") + refresh_models_interval: int = Field(default=300, description="interval in seconds to refresh models") @classmethod def sample_run_config(cls, url: str = "${env.OLLAMA_URL:=http://localhost:11434}", **kwargs) -> dict[str, Any]: diff --git a/llama_stack/providers/remote/inference/ollama/ollama.py b/llama_stack/providers/remote/inference/ollama/ollama.py index 010e346bd..a1f7743d5 100644 --- a/llama_stack/providers/remote/inference/ollama/ollama.py +++ b/llama_stack/providers/remote/inference/ollama/ollama.py @@ -5,6 +5,7 @@ # the root directory of this source tree. +import asyncio import base64 import uuid from collections.abc import AsyncGenerator, AsyncIterator @@ -91,23 +92,88 @@ class OllamaInferenceAdapter( InferenceProvider, ModelsProtocolPrivate, ): + # automatically set by the resolver when instantiating the provider + __provider_id__: str + def __init__(self, config: OllamaImplConfig) -> None: self.register_helper = ModelRegistryHelper(MODEL_ENTRIES) - self.url = config.url + self.config = config + self._client = None + self._openai_client = None @property def client(self) -> AsyncClient: - return AsyncClient(host=self.url) + if self._client is None: + self._client = AsyncClient(host=self.config.url) + return self._client @property def openai_client(self) -> AsyncOpenAI: - return AsyncOpenAI(base_url=f"{self.url}/v1", api_key="ollama") + if self._openai_client is None: + self._openai_client = AsyncOpenAI(base_url=f"{self.config.url}/v1", api_key="ollama") + return self._openai_client async def initialize(self) -> None: - logger.debug(f"checking connectivity to Ollama at `{self.url}`...") + logger.info(f"checking connectivity to Ollama at `{self.config.url}`...") health_response = await self.health() if health_response["status"] == HealthStatus.ERROR: - raise RuntimeError("Ollama Server is not running, start it using `ollama serve` in a separate terminal") + logger.warning( + "Ollama Server is not running, make sure to start it using `ollama serve` in a separate terminal" + ) + + if self.config.refresh_models: + logger.debug("ollama starting background model refresh task") + self._refresh_task = asyncio.create_task(self._refresh_models()) + + def cb(task): + if task.cancelled(): + import traceback + + logger.error(f"ollama background refresh task canceled:\n{''.join(traceback.format_stack())}") + elif task.exception(): + logger.error(f"ollama background refresh task died: {task.exception()}") + else: + logger.error("ollama background refresh task completed unexpectedly") + + self._refresh_task.add_done_callback(cb) + + async def _refresh_models(self) -> None: + # Wait for model store to be available (with timeout) + waited_time = 0 + while not self.model_store and waited_time < 60: + await asyncio.sleep(1) + waited_time += 1 + + if not self.model_store: + raise ValueError("Model store not set after waiting 60 seconds") + + provider_id = self.__provider_id__ + while True: + try: + response = await self.client.list() + except Exception as e: + logger.warning(f"Failed to list models: {str(e)}") + await asyncio.sleep(self.config.refresh_models_interval) + continue + + models = [] + for m in response.models: + model_type = ModelType.embedding if m.details.family in ["bert"] else ModelType.llm + # unfortunately, ollama does not provide embedding dimension in the model list :( + # we should likely add a hard-coded mapping of model name to embedding dimension + models.append( + Model( + identifier=m.model, + provider_resource_id=m.model, + provider_id=provider_id, + metadata={"embedding_dimension": 384} if model_type == ModelType.embedding else {}, + model_type=model_type, + ) + ) + await self.model_store.update_registered_models(provider_id, models) + logger.debug(f"ollama refreshed model list ({len(models)} models)") + + await asyncio.sleep(self.config.refresh_models_interval) async def health(self) -> HealthResponse: """ @@ -124,7 +190,12 @@ class OllamaInferenceAdapter( return HealthResponse(status=HealthStatus.ERROR, message=f"Health check failed: {str(e)}") async def shutdown(self) -> None: - pass + if hasattr(self, "_refresh_task") and not self._refresh_task.done(): + logger.debug("ollama cancelling background refresh task") + self._refresh_task.cancel() + + self._client = None + self._openai_client = None async def unregister_model(self, model_id: str) -> None: pass @@ -354,8 +425,6 @@ class OllamaInferenceAdapter( raise ValueError("Model provider_resource_id cannot be None") if model.model_type == ModelType.embedding: - logger.info(f"Pulling embedding model `{model.provider_resource_id}` if necessary...") - # TODO: you should pull here only if the model is not found in a list response = await self.client.list() if model.provider_resource_id not in [m.model for m in response.models]: await self.client.pull(model.provider_resource_id) From ade075152ebd3ce650668f78f67aa970f440a724 Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Fri, 18 Jul 2025 15:52:18 -0700 Subject: [PATCH 36/40] chore: kill inline::vllm (#2824) Inline _inference_ providers haven't proved to be very useful -- they are rarely used. And for good reason -- it is almost never a good idea to include a complex (distributed) inference engine bundled into a distributed stateful front-end server serving many other things. Responsibility should be split properly. See Discord discussion: https://discord.com/channels/1257833999603335178/1347279619241414726/1395849853619408957 --- docs/source/providers/inference/index.md | 1 - .../source/providers/inference/inline_vllm.md | 29 - .../inline/inference/vllm/__init__.py | 17 - .../providers/inline/inference/vllm/config.py | 53 -- .../inline/inference/vllm/openai_utils.py | 170 ---- .../providers/inline/inference/vllm/vllm.py | 811 ------------------ llama_stack/providers/registry/inference.py | 10 - llama_stack/templates/vllm-gpu/__init__.py | 7 - llama_stack/templates/vllm-gpu/build.yaml | 35 - llama_stack/templates/vllm-gpu/run.yaml | 132 --- llama_stack/templates/vllm-gpu/vllm.py | 122 --- pyproject.toml | 1 - 12 files changed, 1388 deletions(-) delete mode 100644 docs/source/providers/inference/inline_vllm.md delete mode 100644 llama_stack/providers/inline/inference/vllm/__init__.py delete mode 100644 llama_stack/providers/inline/inference/vllm/config.py delete mode 100644 llama_stack/providers/inline/inference/vllm/openai_utils.py delete mode 100644 llama_stack/providers/inline/inference/vllm/vllm.py delete mode 100644 llama_stack/templates/vllm-gpu/__init__.py delete mode 100644 llama_stack/templates/vllm-gpu/build.yaml delete mode 100644 llama_stack/templates/vllm-gpu/run.yaml delete mode 100644 llama_stack/templates/vllm-gpu/vllm.py diff --git a/docs/source/providers/inference/index.md b/docs/source/providers/inference/index.md index 05773efce..6582e08de 100644 --- a/docs/source/providers/inference/index.md +++ b/docs/source/providers/inference/index.md @@ -4,7 +4,6 @@ This section contains documentation for all available providers for the **infere - [inline::meta-reference](inline_meta-reference.md) - [inline::sentence-transformers](inline_sentence-transformers.md) -- [inline::vllm](inline_vllm.md) - [remote::anthropic](remote_anthropic.md) - [remote::bedrock](remote_bedrock.md) - [remote::cerebras](remote_cerebras.md) diff --git a/docs/source/providers/inference/inline_vllm.md b/docs/source/providers/inference/inline_vllm.md deleted file mode 100644 index 6ea34acb8..000000000 --- a/docs/source/providers/inference/inline_vllm.md +++ /dev/null @@ -1,29 +0,0 @@ -# inline::vllm - -## Description - -vLLM inference provider for high-performance model serving with PagedAttention and continuous batching. - -## Configuration - -| Field | Type | Required | Default | Description | -|-------|------|----------|---------|-------------| -| `tensor_parallel_size` | `` | No | 1 | Number of tensor parallel replicas (number of GPUs to use). | -| `max_tokens` | `` | No | 4096 | Maximum number of tokens to generate. | -| `max_model_len` | `` | No | 4096 | Maximum context length to use during serving. | -| `max_num_seqs` | `` | No | 4 | Maximum parallel batch size for generation. | -| `enforce_eager` | `` | No | False | Whether to use eager mode for inference (otherwise cuda graphs are used). | -| `gpu_memory_utilization` | `` | No | 0.3 | How much GPU memory will be allocated when this provider has finished loading, including memory that was already allocated before loading. | - -## Sample Configuration - -```yaml -tensor_parallel_size: ${env.TENSOR_PARALLEL_SIZE:=1} -max_tokens: ${env.MAX_TOKENS:=4096} -max_model_len: ${env.MAX_MODEL_LEN:=4096} -max_num_seqs: ${env.MAX_NUM_SEQS:=4} -enforce_eager: ${env.ENFORCE_EAGER:=False} -gpu_memory_utilization: ${env.GPU_MEMORY_UTILIZATION:=0.3} - -``` - diff --git a/llama_stack/providers/inline/inference/vllm/__init__.py b/llama_stack/providers/inline/inference/vllm/__init__.py deleted file mode 100644 index d0ec3e084..000000000 --- a/llama_stack/providers/inline/inference/vllm/__init__.py +++ /dev/null @@ -1,17 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -from typing import Any - -from .config import VLLMConfig - - -async def get_provider_impl(config: VLLMConfig, _deps: dict[str, Any]): - from .vllm import VLLMInferenceImpl - - impl = VLLMInferenceImpl(config) - await impl.initialize() - return impl diff --git a/llama_stack/providers/inline/inference/vllm/config.py b/llama_stack/providers/inline/inference/vllm/config.py deleted file mode 100644 index 660ef206b..000000000 --- a/llama_stack/providers/inline/inference/vllm/config.py +++ /dev/null @@ -1,53 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -from typing import Any - -from pydantic import BaseModel, Field - -from llama_stack.schema_utils import json_schema_type - - -@json_schema_type -class VLLMConfig(BaseModel): - """Configuration for the vLLM inference provider. - - Note that the model name is no longer part of this static configuration. - You can bind an instance of this provider to a specific model with the - ``models.register()`` API call.""" - - tensor_parallel_size: int = Field( - default=1, - description="Number of tensor parallel replicas (number of GPUs to use).", - ) - max_tokens: int = Field( - default=4096, - description="Maximum number of tokens to generate.", - ) - max_model_len: int = Field(default=4096, description="Maximum context length to use during serving.") - max_num_seqs: int = Field(default=4, description="Maximum parallel batch size for generation.") - enforce_eager: bool = Field( - default=False, - description="Whether to use eager mode for inference (otherwise cuda graphs are used).", - ) - gpu_memory_utilization: float = Field( - default=0.3, - description=( - "How much GPU memory will be allocated when this provider has finished " - "loading, including memory that was already allocated before loading." - ), - ) - - @classmethod - def sample_run_config(cls, **kwargs: Any) -> dict[str, Any]: - return { - "tensor_parallel_size": "${env.TENSOR_PARALLEL_SIZE:=1}", - "max_tokens": "${env.MAX_TOKENS:=4096}", - "max_model_len": "${env.MAX_MODEL_LEN:=4096}", - "max_num_seqs": "${env.MAX_NUM_SEQS:=4}", - "enforce_eager": "${env.ENFORCE_EAGER:=False}", - "gpu_memory_utilization": "${env.GPU_MEMORY_UTILIZATION:=0.3}", - } diff --git a/llama_stack/providers/inline/inference/vllm/openai_utils.py b/llama_stack/providers/inline/inference/vllm/openai_utils.py deleted file mode 100644 index 77cbf0403..000000000 --- a/llama_stack/providers/inline/inference/vllm/openai_utils.py +++ /dev/null @@ -1,170 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - - -import vllm - -from llama_stack.apis.inference import ( - ChatCompletionRequest, - GrammarResponseFormat, - JsonSchemaResponseFormat, - Message, - ToolChoice, - ToolDefinition, - UserMessage, -) -from llama_stack.models.llama.datatypes import BuiltinTool -from llama_stack.providers.utils.inference.openai_compat import ( - convert_message_to_openai_dict, - get_sampling_options, -) - -############################################################################### -# This file contains OpenAI compatibility code that is currently only used -# by the inline vLLM connector. Some or all of this code may be moved to a -# central location at a later date. - - -def _merge_context_into_content(message: Message) -> Message: # type: ignore - """ - Merge the ``context`` field of a Llama Stack ``Message`` object into - the content field for compabilitiy with OpenAI-style APIs. - - Generates a content string that emulates the current behavior - of ``llama_models.llama3.api.chat_format.encode_message()``. - - :param message: Message that may include ``context`` field - - :returns: A version of ``message`` with any context merged into the - ``content`` field. - """ - if not isinstance(message, UserMessage): # Separate type check for linter - return message - if message.context is None: - return message - return UserMessage( - role=message.role, - # Emumate llama_models.llama3.api.chat_format.encode_message() - content=message.content + "\n\n" + message.context, - context=None, - ) - - -def _llama_stack_tools_to_openai_tools( - tools: list[ToolDefinition] | None = None, -) -> list[vllm.entrypoints.openai.protocol.ChatCompletionToolsParam]: - """ - Convert the list of available tools from Llama Stack's format to vLLM's - version of OpenAI's format. - """ - if tools is None: - return [] - - result = [] - for t in tools: - if isinstance(t.tool_name, BuiltinTool): - raise NotImplementedError("Built-in tools not yet implemented") - if t.parameters is None: - parameters = None - else: # if t.parameters is not None - # Convert the "required" flags to a list of required params - required_params = [k for k, v in t.parameters.items() if v.required] - parameters = { - "type": "object", # Mystery value that shows up in OpenAI docs - "properties": { - k: {"type": v.param_type, "description": v.description} for k, v in t.parameters.items() - }, - "required": required_params, - } - - function_def = vllm.entrypoints.openai.protocol.FunctionDefinition( - name=t.tool_name, description=t.description, parameters=parameters - ) - - # Every tool definition is double-boxed in a ChatCompletionToolsParam - result.append(vllm.entrypoints.openai.protocol.ChatCompletionToolsParam(function=function_def)) - return result - - -async def llama_stack_chat_completion_to_openai_chat_completion_dict( - request: ChatCompletionRequest, -) -> dict: - """ - Convert a chat completion request in Llama Stack format into an - equivalent set of arguments to pass to an OpenAI-compatible - chat completions API. - - :param request: Bundled request parameters in Llama Stack format. - - :returns: Dictionary of key-value pairs to use as an initializer - for a dataclass or to be converted directly to JSON and sent - over the wire. - """ - - converted_messages = [ - # This mystery async call makes the parent function also be async - await convert_message_to_openai_dict(_merge_context_into_content(m), download=True) - for m in request.messages - ] - converted_tools = _llama_stack_tools_to_openai_tools(request.tools) - - # Llama will try to use built-in tools with no tool catalog, so don't enable - # tool choice unless at least one tool is enabled. - converted_tool_choice = "none" - if ( - request.tool_config is not None - and request.tool_config.tool_choice == ToolChoice.auto - and request.tools is not None - and len(request.tools) > 0 - ): - converted_tool_choice = "auto" - - # TODO: Figure out what to do with the tool_prompt_format argument. - # Other connectors appear to drop it quietly. - - # Use Llama Stack shared code to translate sampling parameters. - sampling_options = get_sampling_options(request.sampling_params) - - # get_sampling_options() translates repetition penalties to an option that - # OpenAI's APIs don't know about. - # vLLM's OpenAI-compatible API also handles repetition penalties wrong. - # For now, translate repetition penalties into a format that vLLM's broken - # API will handle correctly. Two wrongs make a right... - if "repeat_penalty" in sampling_options: - del sampling_options["repeat_penalty"] - if request.sampling_params.repetition_penalty is not None and request.sampling_params.repetition_penalty != 1.0: - sampling_options["repetition_penalty"] = request.sampling_params.repetition_penalty - - # Convert a single response format into four different parameters, per - # the OpenAI spec - guided_decoding_options = dict() - if request.response_format is None: - # Use defaults - pass - elif isinstance(request.response_format, JsonSchemaResponseFormat): - guided_decoding_options["guided_json"] = request.response_format.json_schema - elif isinstance(request.response_format, GrammarResponseFormat): - guided_decoding_options["guided_grammar"] = request.response_format.bnf - else: - raise TypeError(f"ResponseFormat object is of unexpected subtype '{type(request.response_format)}'") - - logprob_options = dict() - if request.logprobs is not None: - logprob_options["logprobs"] = request.logprobs.top_k - - # Marshall together all the arguments for a ChatCompletionRequest - request_options = { - "model": request.model, - "messages": converted_messages, - "tools": converted_tools, - "tool_choice": converted_tool_choice, - "stream": request.stream, - **sampling_options, - **guided_decoding_options, - **logprob_options, - } - - return request_options diff --git a/llama_stack/providers/inline/inference/vllm/vllm.py b/llama_stack/providers/inline/inference/vllm/vllm.py deleted file mode 100644 index bf54462b5..000000000 --- a/llama_stack/providers/inline/inference/vllm/vllm.py +++ /dev/null @@ -1,811 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -import json -import re -import uuid -from collections.abc import AsyncGenerator, AsyncIterator - -# These vLLM modules contain names that overlap with Llama Stack names, so we import -# fully-qualified names -import vllm.entrypoints.openai.protocol -import vllm.sampling_params -from vllm.engine.arg_utils import AsyncEngineArgs -from vllm.engine.async_llm_engine import AsyncLLMEngine -from vllm.entrypoints.openai.serving_chat import OpenAIServingChat -from vllm.entrypoints.openai.serving_models import BaseModelPath, OpenAIServingModels - -from llama_stack.apis.common.content_types import ( - InterleavedContent, - InterleavedContentItem, - TextDelta, - ToolCallDelta, -) -from llama_stack.apis.inference import ( - ChatCompletionRequest, - ChatCompletionResponse, - ChatCompletionResponseEvent, - ChatCompletionResponseEventType, - ChatCompletionResponseStreamChunk, - CompletionMessage, - CompletionResponse, - CompletionResponseStreamChunk, - EmbeddingsResponse, - EmbeddingTaskType, - GrammarResponseFormat, - Inference, - JsonSchemaResponseFormat, - LogProbConfig, - Message, - OpenAIEmbeddingsResponse, - ResponseFormat, - SamplingParams, - TextTruncation, - TokenLogProbs, - ToolChoice, - ToolConfig, - TopKSamplingStrategy, - TopPSamplingStrategy, -) -from llama_stack.apis.models import Model -from llama_stack.log import get_logger -from llama_stack.models.llama import sku_list -from llama_stack.models.llama.datatypes import ( - StopReason, - ToolCall, - ToolDefinition, - ToolPromptFormat, -) -from llama_stack.models.llama.llama3.chat_format import ChatFormat -from llama_stack.models.llama.llama3.tokenizer import Tokenizer -from llama_stack.providers.remote.inference.vllm.vllm import build_hf_repo_model_entries -from llama_stack.providers.utils.inference.model_registry import ( - ModelRegistryHelper, - ModelsProtocolPrivate, -) -from llama_stack.providers.utils.inference.openai_compat import ( - OpenAIChatCompletionToLlamaStackMixin, - OpenAICompatCompletionChoice, - OpenAICompatCompletionResponse, - OpenAICompletionToLlamaStackMixin, - get_stop_reason, - process_chat_completion_stream_response, -) -from llama_stack.providers.utils.inference.prompt_adapter import ( - chat_completion_request_to_prompt, -) - -from .config import VLLMConfig -from .openai_utils import llama_stack_chat_completion_to_openai_chat_completion_dict - -# Map from Hugging Face model architecture name to appropriate tool parser. -# See vllm.entrypoints.openai.tool_parsers.ToolParserManager.tool_parsers for the full list of -# available parsers. -# TODO: Expand this list -CONFIG_TYPE_TO_TOOL_PARSER = { - "GraniteConfig": "granite", - "MllamaConfig": "llama3_json", - "LlamaConfig": "llama3_json", -} -DEFAULT_TOOL_PARSER = "pythonic" - - -logger = get_logger(__name__, category="inference") - - -def _random_uuid_str() -> str: - return str(uuid.uuid4().hex) - - -def _response_format_to_guided_decoding_params( - response_format: ResponseFormat | None, # type: ignore -) -> vllm.sampling_params.GuidedDecodingParams: - """ - Translate constrained decoding parameters from Llama Stack's format to vLLM's format. - - :param response_format: Llama Stack version of constrained decoding info. Can be ``None``, - indicating no constraints. - :returns: The equivalent dataclass object for the low-level inference layer of vLLM. - """ - if response_format is None: - # As of vLLM 0.6.3, the default constructor for GuidedDecodingParams() returns an invalid - # value that crashes the executor on some code paths. Use ``None`` instead. - return None - - # Llama Stack currently implements fewer types of constrained decoding than vLLM does. - # Translate the types that exist and detect if Llama Stack adds new ones. - if isinstance(response_format, JsonSchemaResponseFormat): - return vllm.sampling_params.GuidedDecodingParams(json=response_format.json_schema) - elif isinstance(response_format, GrammarResponseFormat): - # BNF grammar. - # Llama Stack uses the parse tree of the grammar, while vLLM uses the string - # representation of the grammar. - raise TypeError( - "Constrained decoding with BNF grammars is not currently implemented, because the " - "reference implementation does not implement it." - ) - else: - raise TypeError(f"ResponseFormat object is of unexpected subtype '{type(response_format)}'") - - -def _convert_sampling_params( - sampling_params: SamplingParams | None, - response_format: ResponseFormat | None, # type: ignore - log_prob_config: LogProbConfig | None, -) -> vllm.SamplingParams: - """Convert sampling and constrained decoding configuration from Llama Stack's format to vLLM's - format.""" - # In the absence of provided config values, use Llama Stack defaults as encoded in the Llama - # Stack dataclasses. These defaults are different from vLLM's defaults. - if sampling_params is None: - sampling_params = SamplingParams() - if log_prob_config is None: - log_prob_config = LogProbConfig() - - if isinstance(sampling_params.strategy, TopKSamplingStrategy): - if sampling_params.strategy.top_k == 0: - # vLLM treats "k" differently for top-k sampling - vllm_top_k = -1 - else: - vllm_top_k = sampling_params.strategy.top_k - else: - vllm_top_k = -1 - - if isinstance(sampling_params.strategy, TopPSamplingStrategy): - vllm_top_p = sampling_params.strategy.top_p - # Llama Stack only allows temperature with top-P. - vllm_temperature = sampling_params.strategy.temperature - else: - vllm_top_p = 1.0 - vllm_temperature = 0.0 - - # vLLM allows top-p and top-k at the same time. - vllm_sampling_params = vllm.SamplingParams.from_optional( - max_tokens=(None if sampling_params.max_tokens == 0 else sampling_params.max_tokens), - temperature=vllm_temperature, - top_p=vllm_top_p, - top_k=vllm_top_k, - repetition_penalty=sampling_params.repetition_penalty, - guided_decoding=_response_format_to_guided_decoding_params(response_format), - logprobs=log_prob_config.top_k, - ) - return vllm_sampling_params - - -class VLLMInferenceImpl( - Inference, - OpenAIChatCompletionToLlamaStackMixin, - OpenAICompletionToLlamaStackMixin, - ModelsProtocolPrivate, -): - """ - vLLM-based inference model adapter for Llama Stack with support for multiple models. - - Requires the configuration parameters documented in the :class:`VllmConfig2` class. - """ - - config: VLLMConfig - register_helper: ModelRegistryHelper - model_ids: set[str] - resolved_model_id: str | None - engine: AsyncLLMEngine | None - chat: OpenAIServingChat | None - is_meta_llama_model: bool - - def __init__(self, config: VLLMConfig): - self.config = config - logger.info(f"Config is: {self.config}") - - self.register_helper = ModelRegistryHelper(build_hf_repo_model_entries()) - self.formatter = ChatFormat(Tokenizer.get_instance()) - - # The following are initialized when paths are bound to this provider - self.resolved_model_id = None - self.model_ids = set() - self.engine = None - self.chat = None - self.is_meta_llama_model = False - - ########################################################################### - # METHODS INHERITED FROM IMPLICIT BASE CLASS. - # TODO: Make this class inherit from the new base class ProviderBase once that class exists. - - async def initialize(self) -> None: - """ - Callback that is invoked through many levels of indirection during provider class - instantiation, sometime after when __init__() is called and before any model registration - methods or methods connected to a REST API are called. - - It's not clear what assumptions the class can make about the platform's initialization - state here that can't be made during __init__(), and vLLM can't be started until we know - what model it's supposed to be serving, so nothing happens here currently. - """ - pass - - async def shutdown(self) -> None: - logger.info(f"Shutting down inline vLLM inference provider {self}.") - if self.engine is not None: - self.engine.shutdown_background_loop() - self.engine = None - self.chat = None - self.model_ids = set() - self.resolved_model_id = None - - ########################################################################### - # METHODS INHERITED FROM ModelsProtocolPrivate INTERFACE - - # Note that the return type of the superclass method is WRONG - async def register_model(self, model: Model) -> Model: - """ - Callback that is called when the server associates an inference endpoint with an - inference provider. - - :param model: Object that encapsulates parameters necessary for identifying a specific - LLM. - - :returns: The input ``Model`` object. It may or may not be permissible to change fields - before returning this object. - """ - logger.debug(f"In register_model({model})") - - # First attempt to interpret the model coordinates as a Llama model name - resolved_llama_model = sku_list.resolve_model(model.provider_model_id) - if resolved_llama_model is not None: - # Load from Hugging Face repo into default local cache dir - model_id_for_vllm = resolved_llama_model.huggingface_repo - - # Detect a genuine Meta Llama model to trigger Meta-specific preprocessing. - # Don't set self.is_meta_llama_model until we actually load the model. - is_meta_llama_model = True - else: # if resolved_llama_model is None - # Not a Llama model name. Pass the model id through to vLLM's loader - model_id_for_vllm = model.provider_model_id - is_meta_llama_model = False - - if self.resolved_model_id is not None: - if model_id_for_vllm != self.resolved_model_id: - raise ValueError( - f"Attempted to serve two LLMs (ids '{self.resolved_model_id}') and " - f"'{model_id_for_vllm}') from one copy of provider '{self}'. Use multiple " - f"copies of the provider instead." - ) - else: - # Model already loaded - logger.info( - f"Requested id {model} resolves to {model_id_for_vllm}, which is already loaded. Continuing." - ) - self.model_ids.add(model.model_id) - return model - - logger.info(f"Requested id {model} resolves to {model_id_for_vllm}. Loading {model_id_for_vllm}.") - if is_meta_llama_model: - logger.info(f"Model {model_id_for_vllm} is a Meta Llama model.") - self.is_meta_llama_model = is_meta_llama_model - - # If we get here, this is the first time registering a model. - # Preload so that the first inference request won't time out. - engine_args = AsyncEngineArgs( - model=model_id_for_vllm, - tokenizer=model_id_for_vllm, - tensor_parallel_size=self.config.tensor_parallel_size, - enforce_eager=self.config.enforce_eager, - gpu_memory_utilization=self.config.gpu_memory_utilization, - max_num_seqs=self.config.max_num_seqs, - max_model_len=self.config.max_model_len, - ) - self.engine = AsyncLLMEngine.from_engine_args(engine_args) - - # vLLM currently requires the user to specify the tool parser manually. To choose a tool - # parser, we need to determine what model architecture is being used. For now, we infer - # that information from what config class the model uses. - low_level_model_config = self.engine.engine.get_model_config() - hf_config = low_level_model_config.hf_config - hf_config_class_name = hf_config.__class__.__name__ - if hf_config_class_name in CONFIG_TYPE_TO_TOOL_PARSER: - tool_parser = CONFIG_TYPE_TO_TOOL_PARSER[hf_config_class_name] - else: - # No info -- choose a default so we can at least attempt tool - # use. - tool_parser = DEFAULT_TOOL_PARSER - logger.debug(f"{hf_config_class_name=}") - logger.debug(f"{tool_parser=}") - - # Wrap the lower-level engine in an OpenAI-compatible chat API - model_config = await self.engine.get_model_config() - self.chat = OpenAIServingChat( - engine_client=self.engine, - model_config=model_config, - models=OpenAIServingModels( - engine_client=self.engine, - model_config=model_config, - base_model_paths=[ - # The layer below us will only see resolved model IDs - BaseModelPath(model_id_for_vllm, model_id_for_vllm) - ], - ), - response_role="assistant", - request_logger=None, # Use default logging - chat_template=None, # Use default template from model checkpoint - enable_auto_tools=True, - tool_parser=tool_parser, - chat_template_content_format="auto", - ) - self.resolved_model_id = model_id_for_vllm - self.model_ids.add(model.model_id) - - logger.info(f"Finished preloading model: {model_id_for_vllm}") - - return model - - async def unregister_model(self, model_id: str) -> None: - """ - Callback that is called when the server removes an inference endpoint from an inference - provider. - - :param model_id: The same external ID that the higher layers of the stack previously passed - to :func:`register_model()` - """ - if model_id not in self.model_ids: - raise ValueError( - f"Attempted to unregister model ID '{model_id}', but that ID is not registered to this provider." - ) - self.model_ids.remove(model_id) - - if len(self.model_ids) == 0: - # Last model was just unregistered. Shut down the connection to vLLM and free up - # resources. - # Note that this operation may cause in-flight chat completion requests on the - # now-unregistered model to return errors. - self.resolved_model_id = None - self.chat = None - self.engine.shutdown_background_loop() - self.engine = None - - ########################################################################### - # METHODS INHERITED FROM Inference INTERFACE - - async def completion( - self, - model_id: str, - content: InterleavedContent, - sampling_params: SamplingParams | None = None, - response_format: ResponseFormat | None = None, - stream: bool | None = False, - logprobs: LogProbConfig | None = None, - ) -> CompletionResponse | AsyncIterator[CompletionResponseStreamChunk]: - if model_id not in self.model_ids: - raise ValueError( - f"This adapter is not registered to model id '{model_id}'. Registered IDs are: {self.model_ids}" - ) - if not isinstance(content, str): - raise NotImplementedError("Multimodal input not currently supported") - if sampling_params is None: - sampling_params = SamplingParams() - - converted_sampling_params = _convert_sampling_params(sampling_params, response_format, logprobs) - - logger.debug(f"{converted_sampling_params=}") - - if stream: - return self._streaming_completion(content, converted_sampling_params) - else: - streaming_result = None - async for _ in self._streaming_completion(content, converted_sampling_params): - pass - return CompletionResponse( - content=streaming_result.delta, - stop_reason=streaming_result.stop_reason, - logprobs=streaming_result.logprobs, - ) - - async def embeddings( - self, - model_id: str, - contents: list[str] | list[InterleavedContentItem], - text_truncation: TextTruncation | None = TextTruncation.none, - output_dimension: int | None = None, - task_type: EmbeddingTaskType | None = None, - ) -> EmbeddingsResponse: - raise NotImplementedError() - - async def openai_embeddings( - self, - model: str, - input: str | list[str], - encoding_format: str | None = "float", - dimensions: int | None = None, - user: str | None = None, - ) -> OpenAIEmbeddingsResponse: - raise NotImplementedError() - - async def chat_completion( - self, - model_id: str, - messages: list[Message], # type: ignore - sampling_params: SamplingParams | None = None, - response_format: ResponseFormat | None = None, # type: ignore - tools: list[ToolDefinition] | None = None, - tool_choice: ToolChoice | None = ToolChoice.auto, - tool_prompt_format: ToolPromptFormat | None = None, - stream: bool | None = False, - logprobs: LogProbConfig | None = None, - tool_config: ToolConfig | None = None, - ) -> ChatCompletionResponse | ChatCompletionResponseStreamChunk: - sampling_params = sampling_params or SamplingParams() - if model_id not in self.model_ids: - raise ValueError( - f"This adapter is not registered to model id '{model_id}'. Registered IDs are: {self.model_ids}" - ) - - # Convert to Llama Stack internal format for consistency - request = ChatCompletionRequest( - model=self.resolved_model_id, - messages=messages, - sampling_params=sampling_params, - response_format=response_format, - tools=tools, - tool_choice=tool_choice, - tool_prompt_format=tool_prompt_format, - stream=stream, - logprobs=logprobs, - ) - - if self.is_meta_llama_model: - # Bypass vLLM chat templating layer for Meta Llama models, because the - # templating layer in Llama Stack currently produces better results. - logger.debug( - f"Routing {self.resolved_model_id} chat completion through " - f"Llama Stack's templating layer instead of vLLM's." - ) - return await self._chat_completion_for_meta_llama(request) - - logger.debug(f"{self.resolved_model_id} is not a Meta Llama model") - - # Arguments to the vLLM call must be packaged as a ChatCompletionRequest dataclass. - # Note that this dataclass has the same name as a similar dataclass in Llama Stack. - request_options = await llama_stack_chat_completion_to_openai_chat_completion_dict(request) - chat_completion_request = vllm.entrypoints.openai.protocol.ChatCompletionRequest(**request_options) - - logger.debug(f"Converted request: {chat_completion_request}") - - vllm_result = await self.chat.create_chat_completion(chat_completion_request) - logger.debug(f"Result from vLLM: {vllm_result}") - if isinstance(vllm_result, vllm.entrypoints.openai.protocol.ErrorResponse): - raise ValueError(f"Error from vLLM layer: {vllm_result}") - - # Return type depends on "stream" argument - if stream: - if not isinstance(vllm_result, AsyncGenerator): - raise TypeError(f"Unexpected result type {type(vllm_result)} for streaming inference call") - # vLLM client returns a stream of strings, which need to be parsed. - # Stream comes in the form of an async generator. - return self._convert_streaming_results(vllm_result) - else: - if not isinstance(vllm_result, vllm.entrypoints.openai.protocol.ChatCompletionResponse): - raise TypeError(f"Unexpected result type {type(vllm_result)} for non-streaming inference call") - return self._convert_non_streaming_results(vllm_result) - - ########################################################################### - # INTERNAL METHODS - - async def _streaming_completion( - self, content: str, sampling_params: vllm.SamplingParams - ) -> AsyncIterator[CompletionResponseStreamChunk]: - """Internal implementation of :func:`completion()` API for the streaming case. Assumes - that arguments have been validated upstream. - - :param content: Must be a string - :param sampling_params: Paramters from public API's ``response_format`` - and ``sampling_params`` arguments, converted to VLLM format - """ - # We run agains the vLLM generate() call directly instead of using the OpenAI-compatible - # layer, because doing so simplifies the code here. - - # The vLLM engine requires a unique identifier for each call to generate() - request_id = _random_uuid_str() - - # The vLLM generate() API is streaming-only and returns an async generator. - # The generator returns objects of type vllm.RequestOutput. - results_generator = self.engine.generate(content, sampling_params, request_id) - - # Need to know the model's EOS token ID for the conversion code below. - # AsyncLLMEngine is a wrapper around LLMEngine, and the tokenizer is only available if - # we drill down to the LLMEngine inside the AsyncLLMEngine. - # Similarly, the tokenizer in an LLMEngine is a wrapper around a BaseTokenizerGroup, - # and we need to drill down to the Hugging Face tokenizer inside the BaseTokenizerGroup. - llm_engine = self.engine.engine - tokenizer_group = llm_engine.tokenizer - eos_token_id = tokenizer_group.tokenizer.eos_token_id - - request_output: vllm.RequestOutput = None - async for request_output in results_generator: - # Check for weird inference failures - if request_output.outputs is None or len(request_output.outputs) == 0: - # This case also should never happen - raise ValueError("Inference produced empty result") - - # If we get here, then request_output contains the final output of the generate() call. - # The result may include multiple alternate outputs, but Llama Stack APIs only allow - # us to return one. - output: vllm.CompletionOutput = request_output.outputs[0] - completion_string = output.text - - # Convert logprobs from vLLM's format to Llama Stack's format - logprobs = [ - TokenLogProbs(logprobs_by_token={v.decoded_token: v.logprob for _, v in logprob_dict.items()}) - for logprob_dict in output.logprobs - ] - - # The final output chunk should be labeled with the reason that the overall generate() - # call completed. - logger.debug(f"{output.stop_reason=}; {type(output.stop_reason)=}") - if output.stop_reason is None: - stop_reason = None # Still going - elif output.stop_reason == "stop": - stop_reason = StopReason.end_of_turn - elif output.stop_reason == "length": - stop_reason = StopReason.out_of_tokens - elif isinstance(output.stop_reason, int): - # If the model config specifies multiple end-of-sequence tokens, then vLLM - # will return the token ID of the EOS token in the stop_reason field. - stop_reason = StopReason.end_of_turn - else: - raise ValueError(f"Unrecognized stop reason '{output.stop_reason}'") - - # vLLM's protocol outputs the stop token, then sets end of message on the next step for - # some reason. - if request_output.outputs[-1].token_ids[-1] == eos_token_id: - stop_reason = StopReason.end_of_message - - yield CompletionResponseStreamChunk(delta=completion_string, stop_reason=stop_reason, logprobs=logprobs) - - # Llama Stack requires that the last chunk have a stop reason, but vLLM doesn't always - # provide one if it runs out of tokens. - if stop_reason is None: - yield CompletionResponseStreamChunk( - delta=completion_string, - stop_reason=StopReason.out_of_tokens, - logprobs=logprobs, - ) - - def _convert_non_streaming_results( - self, vllm_result: vllm.entrypoints.openai.protocol.ChatCompletionResponse - ) -> ChatCompletionResponse: - """ - Subroutine to convert the non-streaming output of vLLM's OpenAI-compatible API into an - equivalent Llama Stack object. - - The result from vLLM's non-streaming API is a dataclass with the same name as the Llama - Stack ChatCompletionResponse dataclass, but with more and different field names. We ignore - the fields that aren't currently present in the Llama Stack dataclass. - """ - - # There may be multiple responses, but we can only pass through the first one. - if len(vllm_result.choices) == 0: - raise ValueError("Don't know how to convert response object without any responses") - vllm_message = vllm_result.choices[0].message - vllm_finish_reason = vllm_result.choices[0].finish_reason - - converted_message = CompletionMessage( - role=vllm_message.role, - # Llama Stack API won't accept None for content field. - content=("" if vllm_message.content is None else vllm_message.content), - stop_reason=get_stop_reason(vllm_finish_reason), - tool_calls=[ - ToolCall( - call_id=t.id, - tool_name=t.function.name, - # vLLM function args come back as a string. Llama Stack expects JSON. - arguments=json.loads(t.function.arguments), - arguments_json=t.function.arguments, - ) - for t in vllm_message.tool_calls - ], - ) - - # TODO: Convert logprobs - - logger.debug(f"Converted message: {converted_message}") - - return ChatCompletionResponse( - completion_message=converted_message, - ) - - async def _chat_completion_for_meta_llama( - self, request: ChatCompletionRequest - ) -> ChatCompletionResponse | AsyncIterator[ChatCompletionResponseStreamChunk]: - """ - Subroutine that routes chat completions for Meta Llama models through Llama Stack's - chat template instead of using vLLM's version of that template. The Llama Stack version - of the chat template currently produces more reliable outputs. - - Once vLLM's support for Meta Llama models has matured more, we should consider routing - Meta Llama requests through the vLLM chat completions API instead of using this method. - """ - formatter = ChatFormat(Tokenizer.get_instance()) - - # Note that this function call modifies `request` in place. - prompt = await chat_completion_request_to_prompt(request, self.resolved_model_id) - - model_id = list(self.model_ids)[0] # Any model ID will do here - completion_response_or_iterator = await self.completion( - model_id=model_id, - content=prompt, - sampling_params=request.sampling_params, - response_format=request.response_format, - stream=request.stream, - logprobs=request.logprobs, - ) - - if request.stream: - if not isinstance(completion_response_or_iterator, AsyncIterator): - raise TypeError( - f"Received unexpected result type {type(completion_response_or_iterator)}for streaming request." - ) - return self._chat_completion_for_meta_llama_streaming(completion_response_or_iterator, request) - - # elsif not request.stream: - if not isinstance(completion_response_or_iterator, CompletionResponse): - raise TypeError( - f"Received unexpected result type {type(completion_response_or_iterator)}for non-streaming request." - ) - completion_response: CompletionResponse = completion_response_or_iterator - raw_message = formatter.decode_assistant_message_from_content( - completion_response.content, completion_response.stop_reason - ) - return ChatCompletionResponse( - completion_message=CompletionMessage( - content=raw_message.content, - stop_reason=raw_message.stop_reason, - tool_calls=raw_message.tool_calls, - ), - logprobs=completion_response.logprobs, - ) - - async def _chat_completion_for_meta_llama_streaming( - self, results_iterator: AsyncIterator, request: ChatCompletionRequest - ) -> AsyncIterator: - """ - Code from :func:`_chat_completion_for_meta_llama()` that needs to be a separate - method to keep asyncio happy. - """ - - # Convert to OpenAI format, then use shared code to convert to Llama Stack format. - async def _generate_and_convert_to_openai_compat(): - chunk: CompletionResponseStreamChunk # Make Pylance happy - last_text_len = 0 - async for chunk in results_iterator: - if chunk.stop_reason == StopReason.end_of_turn: - finish_reason = "stop" - elif chunk.stop_reason == StopReason.end_of_message: - finish_reason = "eos" - elif chunk.stop_reason == StopReason.out_of_tokens: - finish_reason = "length" - else: - finish_reason = None - - # Convert delta back to an actual delta - text_delta = chunk.delta[last_text_len:] - last_text_len = len(chunk.delta) - - logger.debug(f"{text_delta=}; {finish_reason=}") - - yield OpenAICompatCompletionResponse( - choices=[OpenAICompatCompletionChoice(finish_reason=finish_reason, text=text_delta)] - ) - - stream = _generate_and_convert_to_openai_compat() - async for chunk in process_chat_completion_stream_response(stream, request): - logger.debug(f"Returning chunk: {chunk}") - yield chunk - - async def _convert_streaming_results(self, vllm_result: AsyncIterator) -> AsyncIterator: - """ - Subroutine that wraps the streaming outputs of vLLM's OpenAI-compatible - API into a second async iterator that returns Llama Stack objects. - - :param vllm_result: Stream of strings that need to be parsed - """ - # Tool calls come in pieces, but Llama Stack expects them in bigger chunks. We build up - # those chunks and output them at the end. - # This data structure holds the current set of partial tool calls. - index_to_tool_call: dict[int, dict] = dict() - - # The Llama Stack event stream must always start with a start event. Use an empty one to - # simplify logic below - yield ChatCompletionResponseStreamChunk( - event=ChatCompletionResponseEvent( - event_type=ChatCompletionResponseEventType.start, - delta=TextDelta(text=""), - stop_reason=None, - ) - ) - - converted_stop_reason = None - async for chunk_str in vllm_result: - # Due to OpenAI compatibility, each event in the stream will start with "data: " and - # end with "\n\n". - _prefix = "data: " - _suffix = "\n\n" - if not chunk_str.startswith(_prefix) or not chunk_str.endswith(_suffix): - raise ValueError(f"Can't parse result string from vLLM: '{re.escape(chunk_str)}'") - - # In between the "data: " and newlines is an event record - data_str = chunk_str[len(_prefix) : -len(_suffix)] - - # The end of the stream is indicated with "[DONE]" - if data_str == "[DONE]": - yield ChatCompletionResponseStreamChunk( - event=ChatCompletionResponseEvent( - event_type=ChatCompletionResponseEventType.complete, - delta=TextDelta(text=""), - stop_reason=converted_stop_reason, - ) - ) - return - - # Anything that is not "[DONE]" should be a JSON record - parsed_chunk = json.loads(data_str) - - logger.debug(f"Parsed JSON event to:\n{json.dumps(parsed_chunk, indent=2)}") - - # The result may contain multiple completions, but Llama Stack APIs only support - # returning one. - first_choice = parsed_chunk["choices"][0] - converted_stop_reason = get_stop_reason(first_choice["finish_reason"]) - delta_record = first_choice["delta"] - - if "content" in delta_record: - # Text delta - yield ChatCompletionResponseStreamChunk( - event=ChatCompletionResponseEvent( - event_type=ChatCompletionResponseEventType.progress, - delta=TextDelta(text=delta_record["content"]), - stop_reason=converted_stop_reason, - ) - ) - elif "tool_calls" in delta_record: - # Tool call(s). Llama Stack APIs do not have a clear way to return partial tool - # calls, so buffer until we get a "tool calls" stop reason - for tc in delta_record["tool_calls"]: - index = tc["index"] - if index not in index_to_tool_call: - # First time this tool call is showing up - index_to_tool_call[index] = dict() - tool_call = index_to_tool_call[index] - if "id" in tc: - tool_call["call_id"] = tc["id"] - if "function" in tc: - if "name" in tc["function"]: - tool_call["tool_name"] = tc["function"]["name"] - if "arguments" in tc["function"]: - # Arguments comes in as pieces of a string - if "arguments_str" not in tool_call: - tool_call["arguments_str"] = "" - tool_call["arguments_str"] += tc["function"]["arguments"] - else: - raise ValueError(f"Don't know how to parse event delta: {delta_record}") - - if first_choice["finish_reason"] == "tool_calls": - # Special OpenAI code for "tool calls complete". - # Output the buffered tool calls. Llama Stack requires a separate event per tool - # call. - for tool_call_record in index_to_tool_call.values(): - # Arguments come in as a string. Parse the completed string. - tool_call_record["arguments"] = json.loads(tool_call_record["arguments_str"]) - del tool_call_record["arguments_str"] - - yield ChatCompletionResponseStreamChunk( - event=ChatCompletionResponseEvent( - event_type=ChatCompletionResponseEventType.progress, - delta=ToolCallDelta(tool_call=tool_call_record, parse_status="succeeded"), - stop_reason=converted_stop_reason, - ) - ) - - # If we get here, we've lost the connection with the vLLM event stream before it ended - # normally. - raise ValueError("vLLM event stream ended without [DONE] message.") diff --git a/llama_stack/providers/registry/inference.py b/llama_stack/providers/registry/inference.py index 217870ec9..ffd30a5b5 100644 --- a/llama_stack/providers/registry/inference.py +++ b/llama_stack/providers/registry/inference.py @@ -37,16 +37,6 @@ def available_providers() -> list[ProviderSpec]: config_class="llama_stack.providers.inline.inference.meta_reference.MetaReferenceInferenceConfig", description="Meta's reference implementation of inference with support for various model formats and optimization techniques.", ), - InlineProviderSpec( - api=Api.inference, - provider_type="inline::vllm", - pip_packages=[ - "vllm", - ], - module="llama_stack.providers.inline.inference.vllm", - config_class="llama_stack.providers.inline.inference.vllm.VLLMConfig", - description="vLLM inference provider for high-performance model serving with PagedAttention and continuous batching.", - ), InlineProviderSpec( api=Api.inference, provider_type="inline::sentence-transformers", diff --git a/llama_stack/templates/vllm-gpu/__init__.py b/llama_stack/templates/vllm-gpu/__init__.py deleted file mode 100644 index 7b3d59a01..000000000 --- a/llama_stack/templates/vllm-gpu/__init__.py +++ /dev/null @@ -1,7 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -from .vllm import get_distribution_template # noqa: F401 diff --git a/llama_stack/templates/vllm-gpu/build.yaml b/llama_stack/templates/vllm-gpu/build.yaml deleted file mode 100644 index 147dca50d..000000000 --- a/llama_stack/templates/vllm-gpu/build.yaml +++ /dev/null @@ -1,35 +0,0 @@ -version: 2 -distribution_spec: - description: Use a built-in vLLM engine for running LLM inference - providers: - inference: - - inline::vllm - - inline::sentence-transformers - vector_io: - - inline::faiss - - remote::chromadb - - remote::pgvector - safety: - - inline::llama-guard - agents: - - inline::meta-reference - telemetry: - - inline::meta-reference - eval: - - inline::meta-reference - datasetio: - - remote::huggingface - - inline::localfs - scoring: - - inline::basic - - inline::llm-as-judge - - inline::braintrust - tool_runtime: - - remote::brave-search - - remote::tavily-search - - inline::rag-runtime - - remote::model-context-protocol -image_type: conda -additional_pip_packages: -- aiosqlite -- sqlalchemy[asyncio] diff --git a/llama_stack/templates/vllm-gpu/run.yaml b/llama_stack/templates/vllm-gpu/run.yaml deleted file mode 100644 index 4241569a4..000000000 --- a/llama_stack/templates/vllm-gpu/run.yaml +++ /dev/null @@ -1,132 +0,0 @@ -version: 2 -image_name: vllm-gpu -apis: -- agents -- datasetio -- eval -- inference -- safety -- scoring -- telemetry -- tool_runtime -- vector_io -providers: - inference: - - provider_id: vllm - provider_type: inline::vllm - config: - tensor_parallel_size: ${env.TENSOR_PARALLEL_SIZE:=1} - max_tokens: ${env.MAX_TOKENS:=4096} - max_model_len: ${env.MAX_MODEL_LEN:=4096} - max_num_seqs: ${env.MAX_NUM_SEQS:=4} - enforce_eager: ${env.ENFORCE_EAGER:=False} - gpu_memory_utilization: ${env.GPU_MEMORY_UTILIZATION:=0.3} - - provider_id: sentence-transformers - provider_type: inline::sentence-transformers - config: {} - vector_io: - - provider_id: faiss - provider_type: inline::faiss - config: - kvstore: - type: sqlite - db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/vllm-gpu}/faiss_store.db - safety: - - provider_id: llama-guard - provider_type: inline::llama-guard - config: - excluded_categories: [] - agents: - - provider_id: meta-reference - provider_type: inline::meta-reference - config: - persistence_store: - type: sqlite - db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/vllm-gpu}/agents_store.db - responses_store: - type: sqlite - db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/vllm-gpu}/responses_store.db - telemetry: - - provider_id: meta-reference - provider_type: inline::meta-reference - config: - service_name: "${env.OTEL_SERVICE_NAME:=\u200B}" - sinks: ${env.TELEMETRY_SINKS:=console,sqlite} - sqlite_db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/vllm-gpu}/trace_store.db - otel_exporter_otlp_endpoint: ${env.OTEL_EXPORTER_OTLP_ENDPOINT:=} - eval: - - provider_id: meta-reference - provider_type: inline::meta-reference - config: - kvstore: - type: sqlite - db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/vllm-gpu}/meta_reference_eval.db - datasetio: - - provider_id: huggingface - provider_type: remote::huggingface - config: - kvstore: - type: sqlite - db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/vllm-gpu}/huggingface_datasetio.db - - provider_id: localfs - provider_type: inline::localfs - config: - kvstore: - type: sqlite - db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/vllm-gpu}/localfs_datasetio.db - scoring: - - provider_id: basic - provider_type: inline::basic - config: {} - - provider_id: llm-as-judge - provider_type: inline::llm-as-judge - config: {} - - provider_id: braintrust - provider_type: inline::braintrust - config: - openai_api_key: ${env.OPENAI_API_KEY:=} - tool_runtime: - - provider_id: brave-search - provider_type: remote::brave-search - config: - api_key: ${env.BRAVE_SEARCH_API_KEY:=} - max_results: 3 - - provider_id: tavily-search - provider_type: remote::tavily-search - config: - api_key: ${env.TAVILY_SEARCH_API_KEY:=} - max_results: 3 - - provider_id: rag-runtime - provider_type: inline::rag-runtime - config: {} - - provider_id: model-context-protocol - provider_type: remote::model-context-protocol - config: {} -metadata_store: - type: sqlite - db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/vllm-gpu}/registry.db -inference_store: - type: sqlite - db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/vllm-gpu}/inference_store.db -models: -- metadata: {} - model_id: ${env.INFERENCE_MODEL} - provider_id: vllm - model_type: llm -- metadata: - embedding_dimension: 384 - model_id: all-MiniLM-L6-v2 - provider_id: sentence-transformers - model_type: embedding -shields: [] -vector_dbs: [] -datasets: [] -scoring_fns: [] -benchmarks: [] -tool_groups: -- toolgroup_id: builtin::websearch - provider_id: tavily-search -- toolgroup_id: builtin::rag - provider_id: rag-runtime -server: - port: 8321 diff --git a/llama_stack/templates/vllm-gpu/vllm.py b/llama_stack/templates/vllm-gpu/vllm.py deleted file mode 100644 index 443fcd7a3..000000000 --- a/llama_stack/templates/vllm-gpu/vllm.py +++ /dev/null @@ -1,122 +0,0 @@ -# Copyright (c) Meta Platforms, Inc. and affiliates. -# All rights reserved. -# -# This source code is licensed under the terms described in the LICENSE file in -# the root directory of this source tree. - -from llama_stack.apis.models import ModelType -from llama_stack.distribution.datatypes import ModelInput, Provider -from llama_stack.providers.inline.inference.sentence_transformers import ( - SentenceTransformersInferenceConfig, -) -from llama_stack.providers.inline.inference.vllm import VLLMConfig -from llama_stack.providers.inline.vector_io.faiss.config import FaissVectorIOConfig -from llama_stack.templates.template import ( - DistributionTemplate, - RunConfigSettings, - ToolGroupInput, -) - - -def get_distribution_template() -> DistributionTemplate: - providers = { - "inference": ["inline::vllm", "inline::sentence-transformers"], - "vector_io": ["inline::faiss", "remote::chromadb", "remote::pgvector"], - "safety": ["inline::llama-guard"], - "agents": ["inline::meta-reference"], - "telemetry": ["inline::meta-reference"], - "eval": ["inline::meta-reference"], - "datasetio": ["remote::huggingface", "inline::localfs"], - "scoring": ["inline::basic", "inline::llm-as-judge", "inline::braintrust"], - "tool_runtime": [ - "remote::brave-search", - "remote::tavily-search", - "inline::rag-runtime", - "remote::model-context-protocol", - ], - } - - name = "vllm-gpu" - inference_provider = Provider( - provider_id="vllm", - provider_type="inline::vllm", - config=VLLMConfig.sample_run_config(), - ) - vector_io_provider = Provider( - provider_id="faiss", - provider_type="inline::faiss", - config=FaissVectorIOConfig.sample_run_config(f"~/.llama/distributions/{name}"), - ) - embedding_provider = Provider( - provider_id="sentence-transformers", - provider_type="inline::sentence-transformers", - config=SentenceTransformersInferenceConfig.sample_run_config(), - ) - - inference_model = ModelInput( - model_id="${env.INFERENCE_MODEL}", - provider_id="vllm", - ) - embedding_model = ModelInput( - model_id="all-MiniLM-L6-v2", - provider_id="sentence-transformers", - model_type=ModelType.embedding, - metadata={ - "embedding_dimension": 384, - }, - ) - default_tool_groups = [ - ToolGroupInput( - toolgroup_id="builtin::websearch", - provider_id="tavily-search", - ), - ToolGroupInput( - toolgroup_id="builtin::rag", - provider_id="rag-runtime", - ), - ] - - return DistributionTemplate( - name=name, - distro_type="self_hosted", - description="Use a built-in vLLM engine for running LLM inference", - container_image=None, - template_path=None, - providers=providers, - run_configs={ - "run.yaml": RunConfigSettings( - provider_overrides={ - "inference": [inference_provider, embedding_provider], - "vector_io": [vector_io_provider], - }, - default_models=[inference_model, embedding_model], - default_tool_groups=default_tool_groups, - ), - }, - run_config_env_vars={ - "LLAMA_STACK_PORT": ( - "8321", - "Port for the Llama Stack distribution server", - ), - "INFERENCE_MODEL": ( - "meta-llama/Llama-3.2-3B-Instruct", - "Inference model loaded into the vLLM engine", - ), - "TENSOR_PARALLEL_SIZE": ( - "1", - "Number of tensor parallel replicas (number of GPUs to use).", - ), - "MAX_TOKENS": ( - "4096", - "Maximum number of tokens to generate.", - ), - "ENFORCE_EAGER": ( - "False", - "Whether to use eager mode for inference (otherwise cuda graphs are used).", - ), - "GPU_MEMORY_UTILIZATION": ( - "0.7", - "GPU memory utilization for the vLLM engine.", - ), - }, - ) diff --git a/pyproject.toml b/pyproject.toml index 15e2e10b4..30e768dcb 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -257,7 +257,6 @@ exclude = [ "^llama_stack/models/llama/llama4/", "^llama_stack/providers/inline/inference/meta_reference/quantization/fp8_impls\\.py$", "^llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers\\.py$", - "^llama_stack/providers/inline/inference/vllm/", "^llama_stack/providers/inline/post_training/common/validator\\.py$", "^llama_stack/providers/inline/safety/code_scanner/", "^llama_stack/providers/inline/safety/llama_guard/", From 199f859eec8688b19b56a0b6f044daf8dc7e26ae Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Fri, 18 Jul 2025 15:53:09 -0700 Subject: [PATCH 37/40] feat(vllm): periodically refresh models (#2823) Just like #2805 but for vLLM. We also make VLLM_URL env variable optional (not required) -- if not specified, the provider silently sits idle and yells eventually if someone tries to call a completion on it. This is done so as to allow this provider to be present in the `starter` distribution. ## Test Plan Set up vLLM, copy the starter template and set `{ refresh_models: true, refresh_models_interval: 10 }` for the vllm provider and then run: ``` ENABLE_VLLM=vllm VLLM_URL=http://localhost:8000/v1 \ uv run llama stack run --image-type venv /tmp/starter.yaml ``` Verify that `llama-stack-client models list` brings up the model correctly from vLLM. --- .../source/providers/inference/remote_vllm.md | 4 +- llama_stack/apis/inference/inference.py | 2 +- .../distribution/routing_tables/models.py | 8 +- .../remote/inference/ollama/ollama.py | 8 +- .../providers/remote/inference/vllm/config.py | 10 ++- .../providers/remote/inference/vllm/vllm.py | 78 ++++++++++++++++++- llama_stack/templates/starter/run.yaml | 2 +- 7 files changed, 98 insertions(+), 14 deletions(-) diff --git a/docs/source/providers/inference/remote_vllm.md b/docs/source/providers/inference/remote_vllm.md index 6c725fb41..5291199a4 100644 --- a/docs/source/providers/inference/remote_vllm.md +++ b/docs/source/providers/inference/remote_vllm.md @@ -12,11 +12,13 @@ Remote vLLM inference provider for connecting to vLLM servers. | `max_tokens` | `` | No | 4096 | Maximum number of tokens to generate. | | `api_token` | `str \| None` | No | fake | The API token | | `tls_verify` | `bool \| str` | No | True | Whether to verify TLS certificates. Can be a boolean or a path to a CA certificate file. | +| `refresh_models` | `` | No | False | Whether to refresh models periodically | +| `refresh_models_interval` | `` | No | 300 | Interval in seconds to refresh models | ## Sample Configuration ```yaml -url: ${env.VLLM_URL} +url: ${env.VLLM_URL:=} max_tokens: ${env.VLLM_MAX_TOKENS:=4096} api_token: ${env.VLLM_API_TOKEN:=fake} tls_verify: ${env.VLLM_TLS_VERIFY:=true} diff --git a/llama_stack/apis/inference/inference.py b/llama_stack/apis/inference/inference.py index 26de04b68..b2bb8a8e6 100644 --- a/llama_stack/apis/inference/inference.py +++ b/llama_stack/apis/inference/inference.py @@ -819,7 +819,7 @@ class OpenAIEmbeddingsResponse(BaseModel): class ModelStore(Protocol): async def get_model(self, identifier: str) -> Model: ... - async def update_registered_models( + async def update_registered_llm_models( self, provider_id: str, models: list[Model], diff --git a/llama_stack/distribution/routing_tables/models.py b/llama_stack/distribution/routing_tables/models.py index 90f8afa1c..9a9db7257 100644 --- a/llama_stack/distribution/routing_tables/models.py +++ b/llama_stack/distribution/routing_tables/models.py @@ -81,7 +81,7 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models): raise ValueError(f"Model {model_id} not found") await self.unregister_object(existing_model) - async def update_registered_models( + async def update_registered_llm_models( self, provider_id: str, models: list[Model], @@ -92,12 +92,16 @@ class ModelsRoutingTable(CommonRoutingTableImpl, Models): # from run.yaml) that we need to keep track of model_ids = {} for model in existing_models: - if model.provider_id == provider_id: + # we leave embeddings models alone because often we don't get metadata + # (embedding dimension, etc.) from the provider + if model.provider_id == provider_id and model.model_type == ModelType.llm: model_ids[model.provider_resource_id] = model.identifier logger.debug(f"unregistering model {model.identifier}") await self.unregister_object(model) for model in models: + if model.model_type != ModelType.llm: + continue if model.provider_resource_id in model_ids: model.identifier = model_ids[model.provider_resource_id] diff --git a/llama_stack/providers/remote/inference/ollama/ollama.py b/llama_stack/providers/remote/inference/ollama/ollama.py index a1f7743d5..76d789d07 100644 --- a/llama_stack/providers/remote/inference/ollama/ollama.py +++ b/llama_stack/providers/remote/inference/ollama/ollama.py @@ -159,18 +159,18 @@ class OllamaInferenceAdapter( models = [] for m in response.models: model_type = ModelType.embedding if m.details.family in ["bert"] else ModelType.llm - # unfortunately, ollama does not provide embedding dimension in the model list :( - # we should likely add a hard-coded mapping of model name to embedding dimension + if model_type == ModelType.embedding: + continue models.append( Model( identifier=m.model, provider_resource_id=m.model, provider_id=provider_id, - metadata={"embedding_dimension": 384} if model_type == ModelType.embedding else {}, + metadata={}, model_type=model_type, ) ) - await self.model_store.update_registered_models(provider_id, models) + await self.model_store.update_registered_llm_models(provider_id, models) logger.debug(f"ollama refreshed model list ({len(models)} models)") await asyncio.sleep(self.config.refresh_models_interval) diff --git a/llama_stack/providers/remote/inference/vllm/config.py b/llama_stack/providers/remote/inference/vllm/config.py index e11efa7f0..ee72f974a 100644 --- a/llama_stack/providers/remote/inference/vllm/config.py +++ b/llama_stack/providers/remote/inference/vllm/config.py @@ -29,6 +29,14 @@ class VLLMInferenceAdapterConfig(BaseModel): default=True, description="Whether to verify TLS certificates. Can be a boolean or a path to a CA certificate file.", ) + refresh_models: bool = Field( + default=False, + description="Whether to refresh models periodically", + ) + refresh_models_interval: int = Field( + default=300, + description="Interval in seconds to refresh models", + ) @field_validator("tls_verify") @classmethod @@ -46,7 +54,7 @@ class VLLMInferenceAdapterConfig(BaseModel): @classmethod def sample_run_config( cls, - url: str = "${env.VLLM_URL}", + url: str = "${env.VLLM_URL:=}", **kwargs, ): return { diff --git a/llama_stack/providers/remote/inference/vllm/vllm.py b/llama_stack/providers/remote/inference/vllm/vllm.py index d1455acaa..8bdba1e88 100644 --- a/llama_stack/providers/remote/inference/vllm/vllm.py +++ b/llama_stack/providers/remote/inference/vllm/vllm.py @@ -3,8 +3,8 @@ # # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. +import asyncio import json -import logging from collections.abc import AsyncGenerator, AsyncIterator from typing import Any @@ -38,6 +38,7 @@ from llama_stack.apis.inference import ( JsonSchemaResponseFormat, LogProbConfig, Message, + ModelStore, OpenAIChatCompletion, OpenAICompletion, OpenAIEmbeddingData, @@ -54,6 +55,7 @@ from llama_stack.apis.inference import ( ToolPromptFormat, ) from llama_stack.apis.models import Model, ModelType +from llama_stack.log import get_logger from llama_stack.models.llama.datatypes import BuiltinTool, StopReason, ToolCall from llama_stack.models.llama.sku_list import all_registered_models from llama_stack.providers.datatypes import ( @@ -84,7 +86,7 @@ from llama_stack.providers.utils.inference.prompt_adapter import ( from .config import VLLMInferenceAdapterConfig -log = logging.getLogger(__name__) +log = get_logger(name=__name__, category="inference") def build_hf_repo_model_entries(): @@ -288,16 +290,76 @@ async def _process_vllm_chat_completion_stream_response( class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate): + # automatically set by the resolver when instantiating the provider + __provider_id__: str + model_store: ModelStore | None = None + _refresh_task: asyncio.Task | None = None + def __init__(self, config: VLLMInferenceAdapterConfig) -> None: self.register_helper = ModelRegistryHelper(build_hf_repo_model_entries()) self.config = config self.client = None async def initialize(self) -> None: - pass + if not self.config.url: + # intentionally don't raise an error here, we want to allow the provider to be "dormant" + # or available in distributions like "starter" without causing a ruckus + return + + if self.config.refresh_models: + self._refresh_task = asyncio.create_task(self._refresh_models()) + + def cb(task): + import traceback + + if task.cancelled(): + log.error(f"vLLM background refresh task canceled:\n{''.join(traceback.format_stack())}") + elif task.exception(): + # print the stack trace for the exception + exc = task.exception() + log.error(f"vLLM background refresh task died: {exc}") + traceback.print_exception(exc) + else: + log.error("vLLM background refresh task completed unexpectedly") + + self._refresh_task.add_done_callback(cb) + + async def _refresh_models(self) -> None: + provider_id = self.__provider_id__ + waited_time = 0 + while not self.model_store and waited_time < 60: + await asyncio.sleep(1) + waited_time += 1 + + if not self.model_store: + raise ValueError("Model store not set after waiting 60 seconds") + + self._lazy_initialize_client() + assert self.client is not None # mypy + while True: + try: + models = [] + async for m in self.client.models.list(): + model_type = ModelType.llm # unclear how to determine embedding vs. llm models + models.append( + Model( + identifier=m.id, + provider_resource_id=m.id, + provider_id=provider_id, + metadata={}, + model_type=model_type, + ) + ) + await self.model_store.update_registered_llm_models(provider_id, models) + log.debug(f"vLLM refreshed model list ({len(models)} models)") + except Exception as e: + log.error(f"vLLM background refresh task failed: {e}") + await asyncio.sleep(self.config.refresh_models_interval) async def shutdown(self) -> None: - pass + if self._refresh_task: + self._refresh_task.cancel() + self._refresh_task = None async def unregister_model(self, model_id: str) -> None: pass @@ -312,6 +374,9 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate): HealthResponse: A dictionary containing the health status. """ try: + if not self.config.url: + return HealthResponse(status=HealthStatus.ERROR, message="vLLM URL is not set") + client = self._create_client() if self.client is None else self.client _ = [m async for m in client.models.list()] # Ensure the client is initialized return HealthResponse(status=HealthStatus.OK) @@ -327,6 +392,11 @@ class VLLMInferenceAdapter(Inference, ModelsProtocolPrivate): if self.client is not None: return + if not self.config.url: + raise ValueError( + "You must provide a vLLM URL in the run.yaml file (or set the VLLM_URL environment variable)" + ) + log.info(f"Initializing vLLM client with base_url={self.config.url}") self.client = self._create_client() diff --git a/llama_stack/templates/starter/run.yaml b/llama_stack/templates/starter/run.yaml index 27400348a..46573848c 100644 --- a/llama_stack/templates/starter/run.yaml +++ b/llama_stack/templates/starter/run.yaml @@ -26,7 +26,7 @@ providers: - provider_id: ${env.ENABLE_VLLM:=__disabled__} provider_type: remote::vllm config: - url: ${env.VLLM_URL} + url: ${env.VLLM_URL:=} max_tokens: ${env.VLLM_MAX_TOKENS:=4096} api_token: ${env.VLLM_API_TOKEN:=fake} tls_verify: ${env.VLLM_TLS_VERIFY:=true} From dd303327f3a15d02a1300f20a800e968c42b5e0c Mon Sep 17 00:00:00 2001 From: Ashwin Bharambe Date: Fri, 18 Jul 2025 17:11:06 -0700 Subject: [PATCH 38/40] feat(ci): add a ci-tests distro (#2826) --- .github/workflows/integration-tests.yml | 6 +- .../workflows/integration-vector-io-tests.yml | 2 +- .github/workflows/providers-build.yml | 10 +- llama_stack/templates/ci-tests/__init__.py | 7 + llama_stack/templates/ci-tests/build.yaml | 65 + llama_stack/templates/ci-tests/ci_tests.py | 19 + llama_stack/templates/ci-tests/run.yaml | 1189 +++++++++++++++++ 7 files changed, 1289 insertions(+), 9 deletions(-) create mode 100644 llama_stack/templates/ci-tests/__init__.py create mode 100644 llama_stack/templates/ci-tests/build.yaml create mode 100644 llama_stack/templates/ci-tests/ci_tests.py create mode 100644 llama_stack/templates/ci-tests/run.yaml diff --git a/.github/workflows/integration-tests.yml b/.github/workflows/integration-tests.yml index 0b6c1be3b..f8f01756d 100644 --- a/.github/workflows/integration-tests.yml +++ b/.github/workflows/integration-tests.yml @@ -71,7 +71,7 @@ jobs: - name: Build Llama Stack run: | - uv run llama stack build --template starter --image-type venv + uv run llama stack build --template ci-tests --image-type venv - name: Check Storage and Memory Available Before Tests if: ${{ always() }} @@ -92,9 +92,9 @@ jobs: shell: bash run: | if [ "${{ matrix.client-type }}" == "library" ]; then - stack_config="starter" + stack_config="ci-tests" else - stack_config="server:starter" + stack_config="server:ci-tests" fi uv run pytest -s -v tests/integration/${{ matrix.test-type }} --stack-config=${stack_config} \ -k "not(builtin_tool or safety_with_image or code_interpreter or test_rag)" \ diff --git a/.github/workflows/integration-vector-io-tests.yml b/.github/workflows/integration-vector-io-tests.yml index c11720b4b..ec236b33b 100644 --- a/.github/workflows/integration-vector-io-tests.yml +++ b/.github/workflows/integration-vector-io-tests.yml @@ -93,7 +93,7 @@ jobs: - name: Build Llama Stack run: | - uv run llama stack build --template starter --image-type venv + uv run llama stack build --template ci-tests --image-type venv - name: Check Storage and Memory Available Before Tests if: ${{ always() }} diff --git a/.github/workflows/providers-build.yml b/.github/workflows/providers-build.yml index 6de72cd60..392fddda6 100644 --- a/.github/workflows/providers-build.yml +++ b/.github/workflows/providers-build.yml @@ -97,9 +97,9 @@ jobs: - name: Build a single provider run: | - yq -i '.image_type = "container"' llama_stack/templates/starter/build.yaml - yq -i '.image_name = "test"' llama_stack/templates/starter/build.yaml - USE_COPY_NOT_MOUNT=true LLAMA_STACK_DIR=. uv run llama stack build --config llama_stack/templates/starter/build.yaml + yq -i '.image_type = "container"' llama_stack/templates/ci-tests/build.yaml + yq -i '.image_name = "test"' llama_stack/templates/ci-tests/build.yaml + USE_COPY_NOT_MOUNT=true LLAMA_STACK_DIR=. uv run llama stack build --config llama_stack/templates/ci-tests/build.yaml - name: Inspect the container image entrypoint run: | @@ -126,14 +126,14 @@ jobs: .image_type = "container" | .image_name = "ubi9-test" | .distribution_spec.container_image = "registry.access.redhat.com/ubi9:latest" - ' llama_stack/templates/starter/build.yaml + ' llama_stack/templates/ci-tests/build.yaml - name: Build dev container (UBI9) env: USE_COPY_NOT_MOUNT: "true" LLAMA_STACK_DIR: "." run: | - uv run llama stack build --config llama_stack/templates/starter/build.yaml + uv run llama stack build --config llama_stack/templates/ci-tests/build.yaml - name: Inspect UBI9 image run: | diff --git a/llama_stack/templates/ci-tests/__init__.py b/llama_stack/templates/ci-tests/__init__.py new file mode 100644 index 000000000..b309587f5 --- /dev/null +++ b/llama_stack/templates/ci-tests/__init__.py @@ -0,0 +1,7 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + +from .ci_tests import get_distribution_template # noqa: F401 diff --git a/llama_stack/templates/ci-tests/build.yaml b/llama_stack/templates/ci-tests/build.yaml new file mode 100644 index 000000000..0aed1d185 --- /dev/null +++ b/llama_stack/templates/ci-tests/build.yaml @@ -0,0 +1,65 @@ +version: 2 +distribution_spec: + description: CI tests for Llama Stack + providers: + inference: + - remote::cerebras + - remote::ollama + - remote::vllm + - remote::tgi + - remote::hf::serverless + - remote::hf::endpoint + - remote::fireworks + - remote::together + - remote::bedrock + - remote::databricks + - remote::nvidia + - remote::runpod + - remote::openai + - remote::anthropic + - remote::gemini + - remote::groq + - remote::fireworks-openai-compat + - remote::llama-openai-compat + - remote::together-openai-compat + - remote::groq-openai-compat + - remote::sambanova-openai-compat + - remote::cerebras-openai-compat + - remote::sambanova + - remote::passthrough + - inline::sentence-transformers + vector_io: + - inline::faiss + - inline::sqlite-vec + - inline::milvus + - remote::chromadb + - remote::pgvector + files: + - inline::localfs + safety: + - inline::llama-guard + agents: + - inline::meta-reference + telemetry: + - inline::meta-reference + post_training: + - inline::huggingface + eval: + - inline::meta-reference + datasetio: + - remote::huggingface + - inline::localfs + scoring: + - inline::basic + - inline::llm-as-judge + - inline::braintrust + tool_runtime: + - remote::brave-search + - remote::tavily-search + - inline::rag-runtime + - remote::model-context-protocol +image_type: conda +additional_pip_packages: +- aiosqlite +- asyncpg +- sqlalchemy[asyncio] diff --git a/llama_stack/templates/ci-tests/ci_tests.py b/llama_stack/templates/ci-tests/ci_tests.py new file mode 100644 index 000000000..49cb36e39 --- /dev/null +++ b/llama_stack/templates/ci-tests/ci_tests.py @@ -0,0 +1,19 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# All rights reserved. +# +# This source code is licensed under the terms described in the LICENSE file in +# the root directory of this source tree. + + +from llama_stack.templates.template import DistributionTemplate + +from ..starter.starter import get_distribution_template as get_starter_distribution_template + + +def get_distribution_template() -> DistributionTemplate: + template = get_starter_distribution_template() + name = "ci-tests" + template.name = name + template.description = "CI tests for Llama Stack" + + return template diff --git a/llama_stack/templates/ci-tests/run.yaml b/llama_stack/templates/ci-tests/run.yaml new file mode 100644 index 000000000..cc7378c97 --- /dev/null +++ b/llama_stack/templates/ci-tests/run.yaml @@ -0,0 +1,1189 @@ +version: 2 +image_name: ci-tests +apis: +- agents +- datasetio +- eval +- files +- inference +- post_training +- safety +- scoring +- telemetry +- tool_runtime +- vector_io +providers: + inference: + - provider_id: ${env.ENABLE_CEREBRAS:=__disabled__} + provider_type: remote::cerebras + config: + base_url: https://api.cerebras.ai + api_key: ${env.CEREBRAS_API_KEY} + - provider_id: ${env.ENABLE_OLLAMA:=__disabled__} + provider_type: remote::ollama + config: + url: ${env.OLLAMA_URL:=http://localhost:11434} + - provider_id: ${env.ENABLE_VLLM:=__disabled__} + provider_type: remote::vllm + config: + url: ${env.VLLM_URL:=} + max_tokens: ${env.VLLM_MAX_TOKENS:=4096} + api_token: ${env.VLLM_API_TOKEN:=fake} + tls_verify: ${env.VLLM_TLS_VERIFY:=true} + - provider_id: ${env.ENABLE_TGI:=__disabled__} + provider_type: remote::tgi + config: + url: ${env.TGI_URL} + - provider_id: ${env.ENABLE_HF_SERVERLESS:=__disabled__} + provider_type: remote::hf::serverless + config: + huggingface_repo: ${env.INFERENCE_MODEL} + api_token: ${env.HF_API_TOKEN} + - provider_id: ${env.ENABLE_HF_ENDPOINT:=__disabled__} + provider_type: remote::hf::endpoint + config: + endpoint_name: ${env.INFERENCE_ENDPOINT_NAME} + api_token: ${env.HF_API_TOKEN} + - provider_id: ${env.ENABLE_FIREWORKS:=__disabled__} + provider_type: remote::fireworks + config: + url: https://api.fireworks.ai/inference/v1 + api_key: ${env.FIREWORKS_API_KEY} + - provider_id: ${env.ENABLE_TOGETHER:=__disabled__} + provider_type: remote::together + config: + url: https://api.together.xyz/v1 + api_key: ${env.TOGETHER_API_KEY} + - provider_id: ${env.ENABLE_BEDROCK:=__disabled__} + provider_type: remote::bedrock + config: {} + - provider_id: ${env.ENABLE_DATABRICKS:=__disabled__} + provider_type: remote::databricks + config: + url: ${env.DATABRICKS_URL} + api_token: ${env.DATABRICKS_API_TOKEN} + - provider_id: ${env.ENABLE_NVIDIA:=__disabled__} + provider_type: remote::nvidia + config: + url: ${env.NVIDIA_BASE_URL:=https://integrate.api.nvidia.com} + api_key: ${env.NVIDIA_API_KEY:=} + append_api_version: ${env.NVIDIA_APPEND_API_VERSION:=True} + - provider_id: ${env.ENABLE_RUNPOD:=__disabled__} + provider_type: remote::runpod + config: + url: ${env.RUNPOD_URL:=} + api_token: ${env.RUNPOD_API_TOKEN} + - provider_id: ${env.ENABLE_OPENAI:=__disabled__} + provider_type: remote::openai + config: + api_key: ${env.OPENAI_API_KEY} + - provider_id: ${env.ENABLE_ANTHROPIC:=__disabled__} + provider_type: remote::anthropic + config: + api_key: ${env.ANTHROPIC_API_KEY} + - provider_id: ${env.ENABLE_GEMINI:=__disabled__} + provider_type: remote::gemini + config: + api_key: ${env.GEMINI_API_KEY} + - provider_id: ${env.ENABLE_GROQ:=__disabled__} + provider_type: remote::groq + config: + url: https://api.groq.com + api_key: ${env.GROQ_API_KEY} + - provider_id: ${env.ENABLE_FIREWORKS_OPENAI_COMPAT:=__disabled__} + provider_type: remote::fireworks-openai-compat + config: + openai_compat_api_base: https://api.fireworks.ai/inference/v1 + api_key: ${env.FIREWORKS_API_KEY} + - provider_id: ${env.ENABLE_LLAMA_OPENAI_COMPAT:=__disabled__} + provider_type: remote::llama-openai-compat + config: + openai_compat_api_base: https://api.llama.com/compat/v1/ + api_key: ${env.LLAMA_API_KEY} + - provider_id: ${env.ENABLE_TOGETHER_OPENAI_COMPAT:=__disabled__} + provider_type: remote::together-openai-compat + config: + openai_compat_api_base: https://api.together.xyz/v1 + api_key: ${env.TOGETHER_API_KEY} + - provider_id: ${env.ENABLE_GROQ_OPENAI_COMPAT:=__disabled__} + provider_type: remote::groq-openai-compat + config: + openai_compat_api_base: https://api.groq.com/openai/v1 + api_key: ${env.GROQ_API_KEY} + - provider_id: ${env.ENABLE_SAMBANOVA_OPENAI_COMPAT:=__disabled__} + provider_type: remote::sambanova-openai-compat + config: + openai_compat_api_base: https://api.sambanova.ai/v1 + api_key: ${env.SAMBANOVA_API_KEY} + - provider_id: ${env.ENABLE_CEREBRAS_OPENAI_COMPAT:=__disabled__} + provider_type: remote::cerebras-openai-compat + config: + openai_compat_api_base: https://api.cerebras.ai/v1 + api_key: ${env.CEREBRAS_API_KEY} + - provider_id: ${env.ENABLE_SAMBANOVA:=__disabled__} + provider_type: remote::sambanova + config: + url: https://api.sambanova.ai/v1 + api_key: ${env.SAMBANOVA_API_KEY} + - provider_id: ${env.ENABLE_PASSTHROUGH:=__disabled__} + provider_type: remote::passthrough + config: + url: ${env.PASSTHROUGH_URL} + api_key: ${env.PASSTHROUGH_API_KEY} + - provider_id: ${env.ENABLE_SENTENCE_TRANSFORMERS:=sentence-transformers} + provider_type: inline::sentence-transformers + config: {} + vector_io: + - provider_id: ${env.ENABLE_FAISS:=faiss} + provider_type: inline::faiss + config: + kvstore: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/faiss_store.db + - provider_id: ${env.ENABLE_SQLITE_VEC:=__disabled__} + provider_type: inline::sqlite-vec + config: + db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/sqlite_vec.db + kvstore: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/sqlite_vec_registry.db + - provider_id: ${env.ENABLE_MILVUS:=__disabled__} + provider_type: inline::milvus + config: + db_path: ${env.MILVUS_DB_PATH:=~/.llama/distributions/starter}/milvus.db + kvstore: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/milvus_registry.db + - provider_id: ${env.ENABLE_CHROMADB:=__disabled__} + provider_type: remote::chromadb + config: + url: ${env.CHROMADB_URL:=} + - provider_id: ${env.ENABLE_PGVECTOR:=__disabled__} + provider_type: remote::pgvector + config: + host: ${env.PGVECTOR_HOST:=localhost} + port: ${env.PGVECTOR_PORT:=5432} + db: ${env.PGVECTOR_DB:=} + user: ${env.PGVECTOR_USER:=} + password: ${env.PGVECTOR_PASSWORD:=} + kvstore: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/pgvector_registry.db + files: + - provider_id: meta-reference-files + provider_type: inline::localfs + config: + storage_dir: ${env.FILES_STORAGE_DIR:=~/.llama/distributions/starter/files} + metadata_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/starter}/files_metadata.db + safety: + - provider_id: llama-guard + provider_type: inline::llama-guard + config: + excluded_categories: [] + agents: + - provider_id: meta-reference + provider_type: inline::meta-reference + config: + persistence_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ci-tests}/agents_store.db + responses_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ci-tests}/responses_store.db + telemetry: + - provider_id: meta-reference + provider_type: inline::meta-reference + config: + service_name: "${env.OTEL_SERVICE_NAME:=\u200B}" + sinks: ${env.TELEMETRY_SINKS:=console,sqlite} + sqlite_db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ci-tests}/trace_store.db + otel_exporter_otlp_endpoint: ${env.OTEL_EXPORTER_OTLP_ENDPOINT:=} + post_training: + - provider_id: huggingface + provider_type: inline::huggingface + config: + checkpoint_format: huggingface + distributed_backend: null + device: cpu + eval: + - provider_id: meta-reference + provider_type: inline::meta-reference + config: + kvstore: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ci-tests}/meta_reference_eval.db + datasetio: + - provider_id: huggingface + provider_type: remote::huggingface + config: + kvstore: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ci-tests}/huggingface_datasetio.db + - provider_id: localfs + provider_type: inline::localfs + config: + kvstore: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ci-tests}/localfs_datasetio.db + scoring: + - provider_id: basic + provider_type: inline::basic + config: {} + - provider_id: llm-as-judge + provider_type: inline::llm-as-judge + config: {} + - provider_id: braintrust + provider_type: inline::braintrust + config: + openai_api_key: ${env.OPENAI_API_KEY:=} + tool_runtime: + - provider_id: brave-search + provider_type: remote::brave-search + config: + api_key: ${env.BRAVE_SEARCH_API_KEY:=} + max_results: 3 + - provider_id: tavily-search + provider_type: remote::tavily-search + config: + api_key: ${env.TAVILY_SEARCH_API_KEY:=} + max_results: 3 + - provider_id: rag-runtime + provider_type: inline::rag-runtime + config: {} + - provider_id: model-context-protocol + provider_type: remote::model-context-protocol + config: {} +metadata_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ci-tests}/registry.db +inference_store: + type: sqlite + db_path: ${env.SQLITE_STORE_DIR:=~/.llama/distributions/ci-tests}/inference_store.db +models: +- metadata: + embedding_dimension: 384 + model_id: all-MiniLM-L6-v2 + provider_id: ${env.ENABLE_SENTENCE_TRANSFORMERS:=sentence-transformers} + model_type: embedding +- metadata: {} + model_id: ${env.ENABLE_CEREBRAS:=__disabled__}/llama3.1-8b + provider_id: ${env.ENABLE_CEREBRAS:=__disabled__} + provider_model_id: llama3.1-8b + model_type: llm +- metadata: {} + model_id: ${env.ENABLE_CEREBRAS:=__disabled__}/meta-llama/Llama-3.1-8B-Instruct + provider_id: ${env.ENABLE_CEREBRAS:=__disabled__} + provider_model_id: llama3.1-8b + model_type: llm +- metadata: {} + model_id: ${env.ENABLE_CEREBRAS:=__disabled__}/llama-3.3-70b + provider_id: ${env.ENABLE_CEREBRAS:=__disabled__} + provider_model_id: llama-3.3-70b + model_type: llm +- metadata: {} + model_id: ${env.ENABLE_CEREBRAS:=__disabled__}/meta-llama/Llama-3.3-70B-Instruct + provider_id: ${env.ENABLE_CEREBRAS:=__disabled__} + provider_model_id: llama-3.3-70b + model_type: llm +- metadata: {} + model_id: ${env.ENABLE_CEREBRAS:=__disabled__}/llama-4-scout-17b-16e-instruct + provider_id: ${env.ENABLE_CEREBRAS:=__disabled__} + provider_model_id: llama-4-scout-17b-16e-instruct + model_type: llm +- metadata: {} + model_id: ${env.ENABLE_CEREBRAS:=__disabled__}/meta-llama/Llama-4-Scout-17B-16E-Instruct + provider_id: ${env.ENABLE_CEREBRAS:=__disabled__} + provider_model_id: llama-4-scout-17b-16e-instruct + model_type: llm +- metadata: {} + model_id: ${env.ENABLE_OLLAMA:=__disabled__}/${env.OLLAMA_INFERENCE_MODEL:=__disabled__} + provider_id: ${env.ENABLE_OLLAMA:=__disabled__} + provider_model_id: ${env.OLLAMA_INFERENCE_MODEL:=__disabled__} + model_type: llm +- metadata: {} + model_id: ${env.ENABLE_OLLAMA:=__disabled__}/${env.SAFETY_MODEL:=__disabled__} + provider_id: ${env.ENABLE_OLLAMA:=__disabled__} + provider_model_id: ${env.SAFETY_MODEL:=__disabled__} + model_type: llm +- metadata: + embedding_dimension: ${env.OLLAMA_EMBEDDING_DIMENSION:=384} + model_id: ${env.ENABLE_OLLAMA:=__disabled__}/${env.OLLAMA_EMBEDDING_MODEL:=__disabled__} + provider_id: ${env.ENABLE_OLLAMA:=__disabled__} + provider_model_id: ${env.OLLAMA_EMBEDDING_MODEL:=__disabled__} + model_type: embedding +- metadata: {} + model_id: ${env.ENABLE_VLLM:=__disabled__}/${env.VLLM_INFERENCE_MODEL:=__disabled__} + provider_id: ${env.ENABLE_VLLM:=__disabled__} + provider_model_id: ${env.VLLM_INFERENCE_MODEL:=__disabled__} + model_type: llm +- metadata: {} + model_id: ${env.ENABLE_FIREWORKS:=__disabled__}/accounts/fireworks/models/llama-v3p1-8b-instruct + provider_id: ${env.ENABLE_FIREWORKS:=__disabled__} + provider_model_id: accounts/fireworks/models/llama-v3p1-8b-instruct + model_type: llm +- metadata: {} + model_id: ${env.ENABLE_FIREWORKS:=__disabled__}/meta-llama/Llama-3.1-8B-Instruct + provider_id: ${env.ENABLE_FIREWORKS:=__disabled__} + provider_model_id: accounts/fireworks/models/llama-v3p1-8b-instruct + model_type: llm +- metadata: {} + model_id: ${env.ENABLE_FIREWORKS:=__disabled__}/accounts/fireworks/models/llama-v3p1-70b-instruct + provider_id: ${env.ENABLE_FIREWORKS:=__disabled__} + provider_model_id: accounts/fireworks/models/llama-v3p1-70b-instruct + model_type: llm +- metadata: {} + model_id: ${env.ENABLE_FIREWORKS:=__disabled__}/meta-llama/Llama-3.1-70B-Instruct + provider_id: ${env.ENABLE_FIREWORKS:=__disabled__} + provider_model_id: accounts/fireworks/models/llama-v3p1-70b-instruct + model_type: llm +- metadata: {} + model_id: ${env.ENABLE_FIREWORKS:=__disabled__}/accounts/fireworks/models/llama-v3p1-405b-instruct + provider_id: ${env.ENABLE_FIREWORKS:=__disabled__} + provider_model_id: accounts/fireworks/models/llama-v3p1-405b-instruct + model_type: llm +- metadata: {} + model_id: ${env.ENABLE_FIREWORKS:=__disabled__}/meta-llama/Llama-3.1-405B-Instruct-FP8 + provider_id: ${env.ENABLE_FIREWORKS:=__disabled__} + provider_model_id: accounts/fireworks/models/llama-v3p1-405b-instruct + model_type: llm +- metadata: {} + model_id: ${env.ENABLE_FIREWORKS:=__disabled__}/accounts/fireworks/models/llama-v3p2-3b-instruct + provider_id: ${env.ENABLE_FIREWORKS:=__disabled__} + provider_model_id: accounts/fireworks/models/llama-v3p2-3b-instruct + model_type: llm +- metadata: {} + model_id: ${env.ENABLE_FIREWORKS:=__disabled__}/meta-llama/Llama-3.2-3B-Instruct + provider_id: ${env.ENABLE_FIREWORKS:=__disabled__} + provider_model_id: accounts/fireworks/models/llama-v3p2-3b-instruct + model_type: llm +- metadata: {} + model_id: ${env.ENABLE_FIREWORKS:=__disabled__}/accounts/fireworks/models/llama-v3p2-11b-vision-instruct + provider_id: ${env.ENABLE_FIREWORKS:=__disabled__} + provider_model_id: accounts/fireworks/models/llama-v3p2-11b-vision-instruct + model_type: llm +- metadata: {} + model_id: ${env.ENABLE_FIREWORKS:=__disabled__}/meta-llama/Llama-3.2-11B-Vision-Instruct + provider_id: ${env.ENABLE_FIREWORKS:=__disabled__} + provider_model_id: accounts/fireworks/models/llama-v3p2-11b-vision-instruct + model_type: llm +- metadata: {} + model_id: ${env.ENABLE_FIREWORKS:=__disabled__}/accounts/fireworks/models/llama-v3p2-90b-vision-instruct + provider_id: ${env.ENABLE_FIREWORKS:=__disabled__} + provider_model_id: accounts/fireworks/models/llama-v3p2-90b-vision-instruct + model_type: llm +- metadata: {} + model_id: ${env.ENABLE_FIREWORKS:=__disabled__}/meta-llama/Llama-3.2-90B-Vision-Instruct + provider_id: ${env.ENABLE_FIREWORKS:=__disabled__} + provider_model_id: accounts/fireworks/models/llama-v3p2-90b-vision-instruct + model_type: llm +- metadata: {} + model_id: ${env.ENABLE_FIREWORKS:=__disabled__}/accounts/fireworks/models/llama-v3p3-70b-instruct + provider_id: ${env.ENABLE_FIREWORKS:=__disabled__} + provider_model_id: accounts/fireworks/models/llama-v3p3-70b-instruct + model_type: llm +- metadata: {} + model_id: ${env.ENABLE_FIREWORKS:=__disabled__}/meta-llama/Llama-3.3-70B-Instruct + provider_id: ${env.ENABLE_FIREWORKS:=__disabled__} + provider_model_id: accounts/fireworks/models/llama-v3p3-70b-instruct + model_type: llm +- metadata: {} + model_id: ${env.ENABLE_FIREWORKS:=__disabled__}/accounts/fireworks/models/llama4-scout-instruct-basic + provider_id: ${env.ENABLE_FIREWORKS:=__disabled__} + provider_model_id: accounts/fireworks/models/llama4-scout-instruct-basic + model_type: llm +- metadata: {} + model_id: ${env.ENABLE_FIREWORKS:=__disabled__}/meta-llama/Llama-4-Scout-17B-16E-Instruct + provider_id: ${env.ENABLE_FIREWORKS:=__disabled__} + provider_model_id: accounts/fireworks/models/llama4-scout-instruct-basic + model_type: llm +- metadata: {} + model_id: ${env.ENABLE_FIREWORKS:=__disabled__}/accounts/fireworks/models/llama4-maverick-instruct-basic + provider_id: ${env.ENABLE_FIREWORKS:=__disabled__} + provider_model_id: accounts/fireworks/models/llama4-maverick-instruct-basic + model_type: llm +- metadata: {} + model_id: ${env.ENABLE_FIREWORKS:=__disabled__}/meta-llama/Llama-4-Maverick-17B-128E-Instruct + provider_id: ${env.ENABLE_FIREWORKS:=__disabled__} + provider_model_id: accounts/fireworks/models/llama4-maverick-instruct-basic + model_type: llm +- metadata: + embedding_dimension: 768 + context_length: 8192 + model_id: ${env.ENABLE_FIREWORKS:=__disabled__}/nomic-ai/nomic-embed-text-v1.5 + provider_id: ${env.ENABLE_FIREWORKS:=__disabled__} + provider_model_id: nomic-ai/nomic-embed-text-v1.5 + model_type: embedding +- metadata: {} + model_id: ${env.ENABLE_FIREWORKS:=__disabled__}/accounts/fireworks/models/llama-guard-3-8b + provider_id: ${env.ENABLE_FIREWORKS:=__disabled__} + provider_model_id: accounts/fireworks/models/llama-guard-3-8b + model_type: llm +- metadata: {} + model_id: ${env.ENABLE_FIREWORKS:=__disabled__}/meta-llama/Llama-Guard-3-8B + provider_id: ${env.ENABLE_FIREWORKS:=__disabled__} + provider_model_id: accounts/fireworks/models/llama-guard-3-8b + model_type: llm +- metadata: {} + model_id: ${env.ENABLE_FIREWORKS:=__disabled__}/accounts/fireworks/models/llama-guard-3-11b-vision + provider_id: ${env.ENABLE_FIREWORKS:=__disabled__} + provider_model_id: accounts/fireworks/models/llama-guard-3-11b-vision + model_type: llm +- metadata: {} + model_id: ${env.ENABLE_FIREWORKS:=__disabled__}/meta-llama/Llama-Guard-3-11B-Vision + provider_id: ${env.ENABLE_FIREWORKS:=__disabled__} + provider_model_id: accounts/fireworks/models/llama-guard-3-11b-vision + model_type: llm +- metadata: {} + model_id: ${env.ENABLE_TOGETHER:=__disabled__}/meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo + provider_id: ${env.ENABLE_TOGETHER:=__disabled__} + provider_model_id: meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo + model_type: llm +- metadata: {} + model_id: ${env.ENABLE_TOGETHER:=__disabled__}/meta-llama/Llama-3.1-8B-Instruct + provider_id: ${env.ENABLE_TOGETHER:=__disabled__} + provider_model_id: meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo + model_type: llm +- metadata: {} + model_id: ${env.ENABLE_TOGETHER:=__disabled__}/meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo + provider_id: ${env.ENABLE_TOGETHER:=__disabled__} + provider_model_id: meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo + model_type: llm +- metadata: {} + model_id: ${env.ENABLE_TOGETHER:=__disabled__}/meta-llama/Llama-3.1-70B-Instruct + provider_id: ${env.ENABLE_TOGETHER:=__disabled__} + provider_model_id: meta-llama/Meta-Llama-3.1-70B-Instruct-Turbo + model_type: llm +- metadata: {} + model_id: ${env.ENABLE_TOGETHER:=__disabled__}/meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo + provider_id: ${env.ENABLE_TOGETHER:=__disabled__} + provider_model_id: meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo + model_type: llm +- metadata: {} + model_id: ${env.ENABLE_TOGETHER:=__disabled__}/meta-llama/Llama-3.1-405B-Instruct-FP8 + provider_id: ${env.ENABLE_TOGETHER:=__disabled__} + provider_model_id: meta-llama/Meta-Llama-3.1-405B-Instruct-Turbo + model_type: llm +- metadata: {} + model_id: ${env.ENABLE_TOGETHER:=__disabled__}/meta-llama/Llama-3.2-3B-Instruct-Turbo + provider_id: ${env.ENABLE_TOGETHER:=__disabled__} + provider_model_id: meta-llama/Llama-3.2-3B-Instruct-Turbo + model_type: llm +- metadata: {} + model_id: ${env.ENABLE_TOGETHER:=__disabled__}/meta-llama/Llama-3.2-3B-Instruct + provider_id: ${env.ENABLE_TOGETHER:=__disabled__} + provider_model_id: meta-llama/Llama-3.2-3B-Instruct-Turbo + model_type: llm +- metadata: {} + model_id: ${env.ENABLE_TOGETHER:=__disabled__}/meta-llama/Llama-3.2-11B-Vision-Instruct-Turbo + provider_id: ${env.ENABLE_TOGETHER:=__disabled__} + provider_model_id: meta-llama/Llama-3.2-11B-Vision-Instruct-Turbo + model_type: llm +- metadata: {} + model_id: ${env.ENABLE_TOGETHER:=__disabled__}/meta-llama/Llama-3.2-11B-Vision-Instruct + provider_id: ${env.ENABLE_TOGETHER:=__disabled__} + provider_model_id: meta-llama/Llama-3.2-11B-Vision-Instruct-Turbo + model_type: llm +- metadata: {} + model_id: ${env.ENABLE_TOGETHER:=__disabled__}/meta-llama/Llama-3.2-90B-Vision-Instruct-Turbo + provider_id: ${env.ENABLE_TOGETHER:=__disabled__} + provider_model_id: meta-llama/Llama-3.2-90B-Vision-Instruct-Turbo + model_type: llm +- metadata: {} + model_id: ${env.ENABLE_TOGETHER:=__disabled__}/meta-llama/Llama-3.2-90B-Vision-Instruct + provider_id: ${env.ENABLE_TOGETHER:=__disabled__} + provider_model_id: meta-llama/Llama-3.2-90B-Vision-Instruct-Turbo + model_type: llm +- metadata: {} + model_id: ${env.ENABLE_TOGETHER:=__disabled__}/meta-llama/Llama-3.3-70B-Instruct-Turbo + provider_id: ${env.ENABLE_TOGETHER:=__disabled__} + provider_model_id: meta-llama/Llama-3.3-70B-Instruct-Turbo + model_type: llm +- metadata: {} + model_id: ${env.ENABLE_TOGETHER:=__disabled__}/meta-llama/Llama-3.3-70B-Instruct + provider_id: ${env.ENABLE_TOGETHER:=__disabled__} + provider_model_id: meta-llama/Llama-3.3-70B-Instruct-Turbo + model_type: llm +- metadata: + embedding_dimension: 768 + context_length: 8192 + model_id: ${env.ENABLE_TOGETHER:=__disabled__}/togethercomputer/m2-bert-80M-8k-retrieval + provider_id: ${env.ENABLE_TOGETHER:=__disabled__} + provider_model_id: togethercomputer/m2-bert-80M-8k-retrieval + model_type: embedding +- metadata: + embedding_dimension: 768 + context_length: 32768 + model_id: ${env.ENABLE_TOGETHER:=__disabled__}/togethercomputer/m2-bert-80M-32k-retrieval + provider_id: ${env.ENABLE_TOGETHER:=__disabled__} + provider_model_id: togethercomputer/m2-bert-80M-32k-retrieval + model_type: embedding +- metadata: {} + model_id: ${env.ENABLE_TOGETHER:=__disabled__}/meta-llama/Llama-4-Scout-17B-16E-Instruct + provider_id: ${env.ENABLE_TOGETHER:=__disabled__} + provider_model_id: meta-llama/Llama-4-Scout-17B-16E-Instruct + model_type: llm +- metadata: {} + model_id: ${env.ENABLE_TOGETHER:=__disabled__}/meta-llama/Llama-4-Scout-17B-16E-Instruct + provider_id: ${env.ENABLE_TOGETHER:=__disabled__} + provider_model_id: meta-llama/Llama-4-Scout-17B-16E-Instruct + model_type: llm +- metadata: {} + model_id: ${env.ENABLE_TOGETHER:=__disabled__}/together/meta-llama/Llama-4-Scout-17B-16E-Instruct + provider_id: ${env.ENABLE_TOGETHER:=__disabled__} + provider_model_id: meta-llama/Llama-4-Scout-17B-16E-Instruct + model_type: llm +- metadata: {} + model_id: ${env.ENABLE_TOGETHER:=__disabled__}/meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8 + provider_id: ${env.ENABLE_TOGETHER:=__disabled__} + provider_model_id: meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8 + model_type: llm +- metadata: {} + model_id: ${env.ENABLE_TOGETHER:=__disabled__}/meta-llama/Llama-4-Maverick-17B-128E-Instruct + provider_id: ${env.ENABLE_TOGETHER:=__disabled__} + provider_model_id: meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8 + model_type: llm +- metadata: {} + model_id: ${env.ENABLE_TOGETHER:=__disabled__}/together/meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8 + provider_id: ${env.ENABLE_TOGETHER:=__disabled__} + provider_model_id: meta-llama/Llama-4-Maverick-17B-128E-Instruct-FP8 + model_type: llm +- metadata: {} + model_id: ${env.ENABLE_TOGETHER:=__disabled__}/meta-llama/Llama-Guard-3-8B + provider_id: ${env.ENABLE_TOGETHER:=__disabled__} + provider_model_id: meta-llama/Llama-Guard-3-8B + model_type: llm +- metadata: {} + model_id: ${env.ENABLE_TOGETHER:=__disabled__}/meta-llama/Llama-Guard-3-8B + provider_id: ${env.ENABLE_TOGETHER:=__disabled__} + provider_model_id: meta-llama/Llama-Guard-3-8B + model_type: llm +- metadata: {} + model_id: ${env.ENABLE_TOGETHER:=__disabled__}/meta-llama/Llama-Guard-3-11B-Vision-Turbo + provider_id: ${env.ENABLE_TOGETHER:=__disabled__} + provider_model_id: meta-llama/Llama-Guard-3-11B-Vision-Turbo + model_type: llm +- metadata: {} + model_id: ${env.ENABLE_TOGETHER:=__disabled__}/meta-llama/Llama-Guard-3-11B-Vision + provider_id: ${env.ENABLE_TOGETHER:=__disabled__} + provider_model_id: meta-llama/Llama-Guard-3-11B-Vision-Turbo + model_type: llm +- metadata: {} + model_id: ${env.ENABLE_BEDROCK:=__disabled__}/meta.llama3-1-8b-instruct-v1:0 + provider_id: ${env.ENABLE_BEDROCK:=__disabled__} + provider_model_id: meta.llama3-1-8b-instruct-v1:0 + model_type: llm +- metadata: {} + model_id: ${env.ENABLE_BEDROCK:=__disabled__}/meta-llama/Llama-3.1-8B-Instruct + provider_id: ${env.ENABLE_BEDROCK:=__disabled__} + provider_model_id: meta.llama3-1-8b-instruct-v1:0 + model_type: llm +- metadata: {} + model_id: ${env.ENABLE_BEDROCK:=__disabled__}/meta.llama3-1-70b-instruct-v1:0 + provider_id: ${env.ENABLE_BEDROCK:=__disabled__} + provider_model_id: meta.llama3-1-70b-instruct-v1:0 + model_type: llm +- metadata: {} + model_id: ${env.ENABLE_BEDROCK:=__disabled__}/meta-llama/Llama-3.1-70B-Instruct + provider_id: ${env.ENABLE_BEDROCK:=__disabled__} + provider_model_id: meta.llama3-1-70b-instruct-v1:0 + model_type: llm +- metadata: {} + model_id: ${env.ENABLE_BEDROCK:=__disabled__}/meta.llama3-1-405b-instruct-v1:0 + provider_id: ${env.ENABLE_BEDROCK:=__disabled__} + provider_model_id: meta.llama3-1-405b-instruct-v1:0 + model_type: llm +- metadata: {} + model_id: ${env.ENABLE_BEDROCK:=__disabled__}/meta-llama/Llama-3.1-405B-Instruct-FP8 + provider_id: ${env.ENABLE_BEDROCK:=__disabled__} + provider_model_id: meta.llama3-1-405b-instruct-v1:0 + model_type: llm +- metadata: {} + model_id: ${env.ENABLE_DATABRICKS:=__disabled__}/databricks-meta-llama-3-1-70b-instruct + provider_id: ${env.ENABLE_DATABRICKS:=__disabled__} + provider_model_id: databricks-meta-llama-3-1-70b-instruct + model_type: llm +- metadata: {} + model_id: ${env.ENABLE_DATABRICKS:=__disabled__}/meta-llama/Llama-3.1-70B-Instruct + provider_id: ${env.ENABLE_DATABRICKS:=__disabled__} + provider_model_id: databricks-meta-llama-3-1-70b-instruct + model_type: llm +- metadata: {} + model_id: ${env.ENABLE_DATABRICKS:=__disabled__}/databricks-meta-llama-3-1-405b-instruct + provider_id: ${env.ENABLE_DATABRICKS:=__disabled__} + provider_model_id: databricks-meta-llama-3-1-405b-instruct + model_type: llm +- metadata: {} + model_id: ${env.ENABLE_DATABRICKS:=__disabled__}/meta-llama/Llama-3.1-405B-Instruct-FP8 + provider_id: ${env.ENABLE_DATABRICKS:=__disabled__} + provider_model_id: databricks-meta-llama-3-1-405b-instruct + model_type: llm +- metadata: {} + model_id: ${env.ENABLE_NVIDIA:=__disabled__}/meta/llama3-8b-instruct + provider_id: ${env.ENABLE_NVIDIA:=__disabled__} + provider_model_id: meta/llama3-8b-instruct + model_type: llm +- metadata: {} + model_id: ${env.ENABLE_NVIDIA:=__disabled__}/meta-llama/Llama-3-8B-Instruct + provider_id: ${env.ENABLE_NVIDIA:=__disabled__} + provider_model_id: meta/llama3-8b-instruct + model_type: llm +- metadata: {} + model_id: ${env.ENABLE_NVIDIA:=__disabled__}/meta/llama3-70b-instruct + provider_id: ${env.ENABLE_NVIDIA:=__disabled__} + provider_model_id: meta/llama3-70b-instruct + model_type: llm +- metadata: {} + model_id: ${env.ENABLE_NVIDIA:=__disabled__}/meta-llama/Llama-3-70B-Instruct + provider_id: ${env.ENABLE_NVIDIA:=__disabled__} + provider_model_id: meta/llama3-70b-instruct + model_type: llm +- metadata: {} + model_id: ${env.ENABLE_NVIDIA:=__disabled__}/meta/llama-3.1-8b-instruct + provider_id: ${env.ENABLE_NVIDIA:=__disabled__} + provider_model_id: meta/llama-3.1-8b-instruct + model_type: llm +- metadata: {} + model_id: ${env.ENABLE_NVIDIA:=__disabled__}/meta-llama/Llama-3.1-8B-Instruct + provider_id: ${env.ENABLE_NVIDIA:=__disabled__} + provider_model_id: meta/llama-3.1-8b-instruct + model_type: llm +- metadata: {} + model_id: ${env.ENABLE_NVIDIA:=__disabled__}/meta/llama-3.1-70b-instruct + provider_id: ${env.ENABLE_NVIDIA:=__disabled__} + provider_model_id: meta/llama-3.1-70b-instruct + model_type: llm +- metadata: {} + model_id: ${env.ENABLE_NVIDIA:=__disabled__}/meta-llama/Llama-3.1-70B-Instruct + provider_id: ${env.ENABLE_NVIDIA:=__disabled__} + provider_model_id: meta/llama-3.1-70b-instruct + model_type: llm +- metadata: {} + model_id: ${env.ENABLE_NVIDIA:=__disabled__}/meta/llama-3.1-405b-instruct + provider_id: ${env.ENABLE_NVIDIA:=__disabled__} + provider_model_id: meta/llama-3.1-405b-instruct + model_type: llm +- metadata: {} + model_id: ${env.ENABLE_NVIDIA:=__disabled__}/meta-llama/Llama-3.1-405B-Instruct-FP8 + provider_id: ${env.ENABLE_NVIDIA:=__disabled__} + provider_model_id: meta/llama-3.1-405b-instruct + model_type: llm +- metadata: {} + model_id: ${env.ENABLE_NVIDIA:=__disabled__}/meta/llama-3.2-1b-instruct + provider_id: ${env.ENABLE_NVIDIA:=__disabled__} + provider_model_id: meta/llama-3.2-1b-instruct + model_type: llm +- metadata: {} + model_id: ${env.ENABLE_NVIDIA:=__disabled__}/meta-llama/Llama-3.2-1B-Instruct + provider_id: ${env.ENABLE_NVIDIA:=__disabled__} + provider_model_id: meta/llama-3.2-1b-instruct + model_type: llm +- metadata: {} + model_id: ${env.ENABLE_NVIDIA:=__disabled__}/meta/llama-3.2-3b-instruct + provider_id: ${env.ENABLE_NVIDIA:=__disabled__} + provider_model_id: meta/llama-3.2-3b-instruct + model_type: llm +- metadata: {} + model_id: ${env.ENABLE_NVIDIA:=__disabled__}/meta-llama/Llama-3.2-3B-Instruct + provider_id: ${env.ENABLE_NVIDIA:=__disabled__} + provider_model_id: meta/llama-3.2-3b-instruct + model_type: llm +- metadata: {} + model_id: ${env.ENABLE_NVIDIA:=__disabled__}/meta/llama-3.2-11b-vision-instruct + provider_id: ${env.ENABLE_NVIDIA:=__disabled__} + provider_model_id: meta/llama-3.2-11b-vision-instruct + model_type: llm +- metadata: {} + model_id: ${env.ENABLE_NVIDIA:=__disabled__}/meta-llama/Llama-3.2-11B-Vision-Instruct + provider_id: ${env.ENABLE_NVIDIA:=__disabled__} + provider_model_id: meta/llama-3.2-11b-vision-instruct + model_type: llm +- metadata: {} + model_id: ${env.ENABLE_NVIDIA:=__disabled__}/meta/llama-3.2-90b-vision-instruct + provider_id: ${env.ENABLE_NVIDIA:=__disabled__} + provider_model_id: meta/llama-3.2-90b-vision-instruct + model_type: llm +- metadata: {} + model_id: ${env.ENABLE_NVIDIA:=__disabled__}/meta-llama/Llama-3.2-90B-Vision-Instruct + provider_id: ${env.ENABLE_NVIDIA:=__disabled__} + provider_model_id: meta/llama-3.2-90b-vision-instruct + model_type: llm +- metadata: {} + model_id: ${env.ENABLE_NVIDIA:=__disabled__}/meta/llama-3.3-70b-instruct + provider_id: ${env.ENABLE_NVIDIA:=__disabled__} + provider_model_id: meta/llama-3.3-70b-instruct + model_type: llm +- metadata: {} + model_id: ${env.ENABLE_NVIDIA:=__disabled__}/meta-llama/Llama-3.3-70B-Instruct + provider_id: ${env.ENABLE_NVIDIA:=__disabled__} + provider_model_id: meta/llama-3.3-70b-instruct + model_type: llm +- metadata: + embedding_dimension: 2048 + context_length: 8192 + model_id: ${env.ENABLE_NVIDIA:=__disabled__}/nvidia/llama-3.2-nv-embedqa-1b-v2 + provider_id: ${env.ENABLE_NVIDIA:=__disabled__} + provider_model_id: nvidia/llama-3.2-nv-embedqa-1b-v2 + model_type: embedding +- metadata: + embedding_dimension: 1024 + context_length: 512 + model_id: ${env.ENABLE_NVIDIA:=__disabled__}/nvidia/nv-embedqa-e5-v5 + provider_id: ${env.ENABLE_NVIDIA:=__disabled__} + provider_model_id: nvidia/nv-embedqa-e5-v5 + model_type: embedding +- metadata: + embedding_dimension: 4096 + context_length: 512 + model_id: ${env.ENABLE_NVIDIA:=__disabled__}/nvidia/nv-embedqa-mistral-7b-v2 + provider_id: ${env.ENABLE_NVIDIA:=__disabled__} + provider_model_id: nvidia/nv-embedqa-mistral-7b-v2 + model_type: embedding +- metadata: + embedding_dimension: 1024 + context_length: 512 + model_id: ${env.ENABLE_NVIDIA:=__disabled__}/snowflake/arctic-embed-l + provider_id: ${env.ENABLE_NVIDIA:=__disabled__} + provider_model_id: snowflake/arctic-embed-l + model_type: embedding +- metadata: {} + model_id: ${env.ENABLE_RUNPOD:=__disabled__}/Llama3.1-8B + provider_id: ${env.ENABLE_RUNPOD:=__disabled__} + provider_model_id: Llama3.1-8B + model_type: llm +- metadata: {} + model_id: ${env.ENABLE_RUNPOD:=__disabled__}/Llama3.1-70B + provider_id: ${env.ENABLE_RUNPOD:=__disabled__} + provider_model_id: Llama3.1-70B + model_type: llm +- metadata: {} + model_id: ${env.ENABLE_RUNPOD:=__disabled__}/Llama3.1-405B:bf16-mp8 + provider_id: ${env.ENABLE_RUNPOD:=__disabled__} + provider_model_id: Llama3.1-405B:bf16-mp8 + model_type: llm +- metadata: {} + model_id: ${env.ENABLE_RUNPOD:=__disabled__}/Llama3.1-405B + provider_id: ${env.ENABLE_RUNPOD:=__disabled__} + provider_model_id: Llama3.1-405B + model_type: llm +- metadata: {} + model_id: ${env.ENABLE_RUNPOD:=__disabled__}/Llama3.1-405B:bf16-mp16 + provider_id: ${env.ENABLE_RUNPOD:=__disabled__} + provider_model_id: Llama3.1-405B:bf16-mp16 + model_type: llm +- metadata: {} + model_id: ${env.ENABLE_RUNPOD:=__disabled__}/Llama3.1-8B-Instruct + provider_id: ${env.ENABLE_RUNPOD:=__disabled__} + provider_model_id: Llama3.1-8B-Instruct + model_type: llm +- metadata: {} + model_id: ${env.ENABLE_RUNPOD:=__disabled__}/Llama3.1-70B-Instruct + provider_id: ${env.ENABLE_RUNPOD:=__disabled__} + provider_model_id: Llama3.1-70B-Instruct + model_type: llm +- metadata: {} + model_id: ${env.ENABLE_RUNPOD:=__disabled__}/Llama3.1-405B-Instruct:bf16-mp8 + provider_id: ${env.ENABLE_RUNPOD:=__disabled__} + provider_model_id: Llama3.1-405B-Instruct:bf16-mp8 + model_type: llm +- metadata: {} + model_id: ${env.ENABLE_RUNPOD:=__disabled__}/Llama3.1-405B-Instruct + provider_id: ${env.ENABLE_RUNPOD:=__disabled__} + provider_model_id: Llama3.1-405B-Instruct + model_type: llm +- metadata: {} + model_id: ${env.ENABLE_RUNPOD:=__disabled__}/Llama3.1-405B-Instruct:bf16-mp16 + provider_id: ${env.ENABLE_RUNPOD:=__disabled__} + provider_model_id: Llama3.1-405B-Instruct:bf16-mp16 + model_type: llm +- metadata: {} + model_id: ${env.ENABLE_RUNPOD:=__disabled__}/Llama3.2-1B + provider_id: ${env.ENABLE_RUNPOD:=__disabled__} + provider_model_id: Llama3.2-1B + model_type: llm +- metadata: {} + model_id: ${env.ENABLE_RUNPOD:=__disabled__}/Llama3.2-3B + provider_id: ${env.ENABLE_RUNPOD:=__disabled__} + provider_model_id: Llama3.2-3B + model_type: llm +- metadata: {} + model_id: ${env.ENABLE_OPENAI:=__disabled__}/openai/gpt-4o + provider_id: ${env.ENABLE_OPENAI:=__disabled__} + provider_model_id: openai/gpt-4o + model_type: llm +- metadata: {} + model_id: ${env.ENABLE_OPENAI:=__disabled__}/openai/gpt-4o-mini + provider_id: ${env.ENABLE_OPENAI:=__disabled__} + provider_model_id: openai/gpt-4o-mini + model_type: llm +- metadata: {} + model_id: ${env.ENABLE_OPENAI:=__disabled__}/openai/chatgpt-4o-latest + provider_id: ${env.ENABLE_OPENAI:=__disabled__} + provider_model_id: openai/chatgpt-4o-latest + model_type: llm +- metadata: {} + model_id: ${env.ENABLE_OPENAI:=__disabled__}/gpt-3.5-turbo-0125 + provider_id: ${env.ENABLE_OPENAI:=__disabled__} + provider_model_id: gpt-3.5-turbo-0125 + model_type: llm +- metadata: {} + model_id: ${env.ENABLE_OPENAI:=__disabled__}/gpt-3.5-turbo + provider_id: ${env.ENABLE_OPENAI:=__disabled__} + provider_model_id: gpt-3.5-turbo + model_type: llm +- metadata: {} + model_id: ${env.ENABLE_OPENAI:=__disabled__}/gpt-3.5-turbo-instruct + provider_id: ${env.ENABLE_OPENAI:=__disabled__} + provider_model_id: gpt-3.5-turbo-instruct + model_type: llm +- metadata: {} + model_id: ${env.ENABLE_OPENAI:=__disabled__}/gpt-4 + provider_id: ${env.ENABLE_OPENAI:=__disabled__} + provider_model_id: gpt-4 + model_type: llm +- metadata: {} + model_id: ${env.ENABLE_OPENAI:=__disabled__}/gpt-4-turbo + provider_id: ${env.ENABLE_OPENAI:=__disabled__} + provider_model_id: gpt-4-turbo + model_type: llm +- metadata: {} + model_id: ${env.ENABLE_OPENAI:=__disabled__}/gpt-4o + provider_id: ${env.ENABLE_OPENAI:=__disabled__} + provider_model_id: gpt-4o + model_type: llm +- metadata: {} + model_id: ${env.ENABLE_OPENAI:=__disabled__}/gpt-4o-2024-08-06 + provider_id: ${env.ENABLE_OPENAI:=__disabled__} + provider_model_id: gpt-4o-2024-08-06 + model_type: llm +- metadata: {} + model_id: ${env.ENABLE_OPENAI:=__disabled__}/gpt-4o-mini + provider_id: ${env.ENABLE_OPENAI:=__disabled__} + provider_model_id: gpt-4o-mini + model_type: llm +- metadata: {} + model_id: ${env.ENABLE_OPENAI:=__disabled__}/gpt-4o-audio-preview + provider_id: ${env.ENABLE_OPENAI:=__disabled__} + provider_model_id: gpt-4o-audio-preview + model_type: llm +- metadata: {} + model_id: ${env.ENABLE_OPENAI:=__disabled__}/chatgpt-4o-latest + provider_id: ${env.ENABLE_OPENAI:=__disabled__} + provider_model_id: chatgpt-4o-latest + model_type: llm +- metadata: {} + model_id: ${env.ENABLE_OPENAI:=__disabled__}/o1 + provider_id: ${env.ENABLE_OPENAI:=__disabled__} + provider_model_id: o1 + model_type: llm +- metadata: {} + model_id: ${env.ENABLE_OPENAI:=__disabled__}/o1-mini + provider_id: ${env.ENABLE_OPENAI:=__disabled__} + provider_model_id: o1-mini + model_type: llm +- metadata: {} + model_id: ${env.ENABLE_OPENAI:=__disabled__}/o3-mini + provider_id: ${env.ENABLE_OPENAI:=__disabled__} + provider_model_id: o3-mini + model_type: llm +- metadata: {} + model_id: ${env.ENABLE_OPENAI:=__disabled__}/o4-mini + provider_id: ${env.ENABLE_OPENAI:=__disabled__} + provider_model_id: o4-mini + model_type: llm +- metadata: + embedding_dimension: 1536 + context_length: 8192 + model_id: ${env.ENABLE_OPENAI:=__disabled__}/openai/text-embedding-3-small + provider_id: ${env.ENABLE_OPENAI:=__disabled__} + provider_model_id: openai/text-embedding-3-small + model_type: embedding +- metadata: + embedding_dimension: 3072 + context_length: 8192 + model_id: ${env.ENABLE_OPENAI:=__disabled__}/openai/text-embedding-3-large + provider_id: ${env.ENABLE_OPENAI:=__disabled__} + provider_model_id: openai/text-embedding-3-large + model_type: embedding +- metadata: + embedding_dimension: 1536 + context_length: 8192 + model_id: ${env.ENABLE_OPENAI:=__disabled__}/text-embedding-3-small + provider_id: ${env.ENABLE_OPENAI:=__disabled__} + provider_model_id: text-embedding-3-small + model_type: embedding +- metadata: + embedding_dimension: 3072 + context_length: 8192 + model_id: ${env.ENABLE_OPENAI:=__disabled__}/text-embedding-3-large + provider_id: ${env.ENABLE_OPENAI:=__disabled__} + provider_model_id: text-embedding-3-large + model_type: embedding +- metadata: {} + model_id: ${env.ENABLE_ANTHROPIC:=__disabled__}/anthropic/claude-3-5-sonnet-latest + provider_id: ${env.ENABLE_ANTHROPIC:=__disabled__} + provider_model_id: anthropic/claude-3-5-sonnet-latest + model_type: llm +- metadata: {} + model_id: ${env.ENABLE_ANTHROPIC:=__disabled__}/anthropic/claude-3-7-sonnet-latest + provider_id: ${env.ENABLE_ANTHROPIC:=__disabled__} + provider_model_id: anthropic/claude-3-7-sonnet-latest + model_type: llm +- metadata: {} + model_id: ${env.ENABLE_ANTHROPIC:=__disabled__}/anthropic/claude-3-5-haiku-latest + provider_id: ${env.ENABLE_ANTHROPIC:=__disabled__} + provider_model_id: anthropic/claude-3-5-haiku-latest + model_type: llm +- metadata: + embedding_dimension: 1024 + context_length: 32000 + model_id: ${env.ENABLE_ANTHROPIC:=__disabled__}/anthropic/voyage-3 + provider_id: ${env.ENABLE_ANTHROPIC:=__disabled__} + provider_model_id: anthropic/voyage-3 + model_type: embedding +- metadata: + embedding_dimension: 512 + context_length: 32000 + model_id: ${env.ENABLE_ANTHROPIC:=__disabled__}/anthropic/voyage-3-lite + provider_id: ${env.ENABLE_ANTHROPIC:=__disabled__} + provider_model_id: anthropic/voyage-3-lite + model_type: embedding +- metadata: + embedding_dimension: 1024 + context_length: 32000 + model_id: ${env.ENABLE_ANTHROPIC:=__disabled__}/anthropic/voyage-code-3 + provider_id: ${env.ENABLE_ANTHROPIC:=__disabled__} + provider_model_id: anthropic/voyage-code-3 + model_type: embedding +- metadata: {} + model_id: ${env.ENABLE_GEMINI:=__disabled__}/gemini/gemini-1.5-flash + provider_id: ${env.ENABLE_GEMINI:=__disabled__} + provider_model_id: gemini/gemini-1.5-flash + model_type: llm +- metadata: {} + model_id: ${env.ENABLE_GEMINI:=__disabled__}/gemini/gemini-1.5-pro + provider_id: ${env.ENABLE_GEMINI:=__disabled__} + provider_model_id: gemini/gemini-1.5-pro + model_type: llm +- metadata: {} + model_id: ${env.ENABLE_GEMINI:=__disabled__}/gemini/gemini-2.0-flash + provider_id: ${env.ENABLE_GEMINI:=__disabled__} + provider_model_id: gemini/gemini-2.0-flash + model_type: llm +- metadata: {} + model_id: ${env.ENABLE_GEMINI:=__disabled__}/gemini/gemini-2.5-flash + provider_id: ${env.ENABLE_GEMINI:=__disabled__} + provider_model_id: gemini/gemini-2.5-flash + model_type: llm +- metadata: {} + model_id: ${env.ENABLE_GEMINI:=__disabled__}/gemini/gemini-2.5-pro + provider_id: ${env.ENABLE_GEMINI:=__disabled__} + provider_model_id: gemini/gemini-2.5-pro + model_type: llm +- metadata: + embedding_dimension: 768 + context_length: 2048 + model_id: ${env.ENABLE_GEMINI:=__disabled__}/gemini/text-embedding-004 + provider_id: ${env.ENABLE_GEMINI:=__disabled__} + provider_model_id: gemini/text-embedding-004 + model_type: embedding +- metadata: {} + model_id: ${env.ENABLE_GROQ:=__disabled__}/groq/llama3-8b-8192 + provider_id: ${env.ENABLE_GROQ:=__disabled__} + provider_model_id: groq/llama3-8b-8192 + model_type: llm +- metadata: {} + model_id: ${env.ENABLE_GROQ:=__disabled__}/meta-llama/Llama-3.1-8B-Instruct + provider_id: ${env.ENABLE_GROQ:=__disabled__} + provider_model_id: groq/llama3-8b-8192 + model_type: llm +- metadata: {} + model_id: ${env.ENABLE_GROQ:=__disabled__}/groq/llama-3.1-8b-instant + provider_id: ${env.ENABLE_GROQ:=__disabled__} + provider_model_id: groq/llama-3.1-8b-instant + model_type: llm +- metadata: {} + model_id: ${env.ENABLE_GROQ:=__disabled__}/groq/llama3-70b-8192 + provider_id: ${env.ENABLE_GROQ:=__disabled__} + provider_model_id: groq/llama3-70b-8192 + model_type: llm +- metadata: {} + model_id: ${env.ENABLE_GROQ:=__disabled__}/meta-llama/Llama-3-70B-Instruct + provider_id: ${env.ENABLE_GROQ:=__disabled__} + provider_model_id: groq/llama3-70b-8192 + model_type: llm +- metadata: {} + model_id: ${env.ENABLE_GROQ:=__disabled__}/groq/llama-3.3-70b-versatile + provider_id: ${env.ENABLE_GROQ:=__disabled__} + provider_model_id: groq/llama-3.3-70b-versatile + model_type: llm +- metadata: {} + model_id: ${env.ENABLE_GROQ:=__disabled__}/meta-llama/Llama-3.3-70B-Instruct + provider_id: ${env.ENABLE_GROQ:=__disabled__} + provider_model_id: groq/llama-3.3-70b-versatile + model_type: llm +- metadata: {} + model_id: ${env.ENABLE_GROQ:=__disabled__}/groq/llama-3.2-3b-preview + provider_id: ${env.ENABLE_GROQ:=__disabled__} + provider_model_id: groq/llama-3.2-3b-preview + model_type: llm +- metadata: {} + model_id: ${env.ENABLE_GROQ:=__disabled__}/meta-llama/Llama-3.2-3B-Instruct + provider_id: ${env.ENABLE_GROQ:=__disabled__} + provider_model_id: groq/llama-3.2-3b-preview + model_type: llm +- metadata: {} + model_id: ${env.ENABLE_GROQ:=__disabled__}/groq/llama-4-scout-17b-16e-instruct + provider_id: ${env.ENABLE_GROQ:=__disabled__} + provider_model_id: groq/llama-4-scout-17b-16e-instruct + model_type: llm +- metadata: {} + model_id: ${env.ENABLE_GROQ:=__disabled__}/meta-llama/Llama-4-Scout-17B-16E-Instruct + provider_id: ${env.ENABLE_GROQ:=__disabled__} + provider_model_id: groq/llama-4-scout-17b-16e-instruct + model_type: llm +- metadata: {} + model_id: ${env.ENABLE_GROQ:=__disabled__}/groq/meta-llama/llama-4-scout-17b-16e-instruct + provider_id: ${env.ENABLE_GROQ:=__disabled__} + provider_model_id: groq/meta-llama/llama-4-scout-17b-16e-instruct + model_type: llm +- metadata: {} + model_id: ${env.ENABLE_GROQ:=__disabled__}/meta-llama/Llama-4-Scout-17B-16E-Instruct + provider_id: ${env.ENABLE_GROQ:=__disabled__} + provider_model_id: groq/meta-llama/llama-4-scout-17b-16e-instruct + model_type: llm +- metadata: {} + model_id: ${env.ENABLE_GROQ:=__disabled__}/groq/llama-4-maverick-17b-128e-instruct + provider_id: ${env.ENABLE_GROQ:=__disabled__} + provider_model_id: groq/llama-4-maverick-17b-128e-instruct + model_type: llm +- metadata: {} + model_id: ${env.ENABLE_GROQ:=__disabled__}/meta-llama/Llama-4-Maverick-17B-128E-Instruct + provider_id: ${env.ENABLE_GROQ:=__disabled__} + provider_model_id: groq/llama-4-maverick-17b-128e-instruct + model_type: llm +- metadata: {} + model_id: ${env.ENABLE_GROQ:=__disabled__}/groq/meta-llama/llama-4-maverick-17b-128e-instruct + provider_id: ${env.ENABLE_GROQ:=__disabled__} + provider_model_id: groq/meta-llama/llama-4-maverick-17b-128e-instruct + model_type: llm +- metadata: {} + model_id: ${env.ENABLE_GROQ:=__disabled__}/meta-llama/Llama-4-Maverick-17B-128E-Instruct + provider_id: ${env.ENABLE_GROQ:=__disabled__} + provider_model_id: groq/meta-llama/llama-4-maverick-17b-128e-instruct + model_type: llm +- metadata: {} + model_id: ${env.ENABLE_SAMBANOVA:=__disabled__}/sambanova/Meta-Llama-3.1-8B-Instruct + provider_id: ${env.ENABLE_SAMBANOVA:=__disabled__} + provider_model_id: sambanova/Meta-Llama-3.1-8B-Instruct + model_type: llm +- metadata: {} + model_id: ${env.ENABLE_SAMBANOVA:=__disabled__}/meta-llama/Llama-3.1-8B-Instruct + provider_id: ${env.ENABLE_SAMBANOVA:=__disabled__} + provider_model_id: sambanova/Meta-Llama-3.1-8B-Instruct + model_type: llm +- metadata: {} + model_id: ${env.ENABLE_SAMBANOVA:=__disabled__}/sambanova/Meta-Llama-3.1-405B-Instruct + provider_id: ${env.ENABLE_SAMBANOVA:=__disabled__} + provider_model_id: sambanova/Meta-Llama-3.1-405B-Instruct + model_type: llm +- metadata: {} + model_id: ${env.ENABLE_SAMBANOVA:=__disabled__}/meta-llama/Llama-3.1-405B-Instruct-FP8 + provider_id: ${env.ENABLE_SAMBANOVA:=__disabled__} + provider_model_id: sambanova/Meta-Llama-3.1-405B-Instruct + model_type: llm +- metadata: {} + model_id: ${env.ENABLE_SAMBANOVA:=__disabled__}/sambanova/Meta-Llama-3.2-1B-Instruct + provider_id: ${env.ENABLE_SAMBANOVA:=__disabled__} + provider_model_id: sambanova/Meta-Llama-3.2-1B-Instruct + model_type: llm +- metadata: {} + model_id: ${env.ENABLE_SAMBANOVA:=__disabled__}/meta-llama/Llama-3.2-1B-Instruct + provider_id: ${env.ENABLE_SAMBANOVA:=__disabled__} + provider_model_id: sambanova/Meta-Llama-3.2-1B-Instruct + model_type: llm +- metadata: {} + model_id: ${env.ENABLE_SAMBANOVA:=__disabled__}/sambanova/Meta-Llama-3.2-3B-Instruct + provider_id: ${env.ENABLE_SAMBANOVA:=__disabled__} + provider_model_id: sambanova/Meta-Llama-3.2-3B-Instruct + model_type: llm +- metadata: {} + model_id: ${env.ENABLE_SAMBANOVA:=__disabled__}/meta-llama/Llama-3.2-3B-Instruct + provider_id: ${env.ENABLE_SAMBANOVA:=__disabled__} + provider_model_id: sambanova/Meta-Llama-3.2-3B-Instruct + model_type: llm +- metadata: {} + model_id: ${env.ENABLE_SAMBANOVA:=__disabled__}/sambanova/Meta-Llama-3.3-70B-Instruct + provider_id: ${env.ENABLE_SAMBANOVA:=__disabled__} + provider_model_id: sambanova/Meta-Llama-3.3-70B-Instruct + model_type: llm +- metadata: {} + model_id: ${env.ENABLE_SAMBANOVA:=__disabled__}/meta-llama/Llama-3.3-70B-Instruct + provider_id: ${env.ENABLE_SAMBANOVA:=__disabled__} + provider_model_id: sambanova/Meta-Llama-3.3-70B-Instruct + model_type: llm +- metadata: {} + model_id: ${env.ENABLE_SAMBANOVA:=__disabled__}/sambanova/Llama-3.2-11B-Vision-Instruct + provider_id: ${env.ENABLE_SAMBANOVA:=__disabled__} + provider_model_id: sambanova/Llama-3.2-11B-Vision-Instruct + model_type: llm +- metadata: {} + model_id: ${env.ENABLE_SAMBANOVA:=__disabled__}/meta-llama/Llama-3.2-11B-Vision-Instruct + provider_id: ${env.ENABLE_SAMBANOVA:=__disabled__} + provider_model_id: sambanova/Llama-3.2-11B-Vision-Instruct + model_type: llm +- metadata: {} + model_id: ${env.ENABLE_SAMBANOVA:=__disabled__}/sambanova/Llama-3.2-90B-Vision-Instruct + provider_id: ${env.ENABLE_SAMBANOVA:=__disabled__} + provider_model_id: sambanova/Llama-3.2-90B-Vision-Instruct + model_type: llm +- metadata: {} + model_id: ${env.ENABLE_SAMBANOVA:=__disabled__}/meta-llama/Llama-3.2-90B-Vision-Instruct + provider_id: ${env.ENABLE_SAMBANOVA:=__disabled__} + provider_model_id: sambanova/Llama-3.2-90B-Vision-Instruct + model_type: llm +- metadata: {} + model_id: ${env.ENABLE_SAMBANOVA:=__disabled__}/sambanova/Llama-4-Scout-17B-16E-Instruct + provider_id: ${env.ENABLE_SAMBANOVA:=__disabled__} + provider_model_id: sambanova/Llama-4-Scout-17B-16E-Instruct + model_type: llm +- metadata: {} + model_id: ${env.ENABLE_SAMBANOVA:=__disabled__}/meta-llama/Llama-4-Scout-17B-16E-Instruct + provider_id: ${env.ENABLE_SAMBANOVA:=__disabled__} + provider_model_id: sambanova/Llama-4-Scout-17B-16E-Instruct + model_type: llm +- metadata: {} + model_id: ${env.ENABLE_SAMBANOVA:=__disabled__}/sambanova/Llama-4-Maverick-17B-128E-Instruct + provider_id: ${env.ENABLE_SAMBANOVA:=__disabled__} + provider_model_id: sambanova/Llama-4-Maverick-17B-128E-Instruct + model_type: llm +- metadata: {} + model_id: ${env.ENABLE_SAMBANOVA:=__disabled__}/meta-llama/Llama-4-Maverick-17B-128E-Instruct + provider_id: ${env.ENABLE_SAMBANOVA:=__disabled__} + provider_model_id: sambanova/Llama-4-Maverick-17B-128E-Instruct + model_type: llm +- metadata: {} + model_id: ${env.ENABLE_SAMBANOVA:=__disabled__}/sambanova/Meta-Llama-Guard-3-8B + provider_id: ${env.ENABLE_SAMBANOVA:=__disabled__} + provider_model_id: sambanova/Meta-Llama-Guard-3-8B + model_type: llm +- metadata: {} + model_id: ${env.ENABLE_SAMBANOVA:=__disabled__}/meta-llama/Llama-Guard-3-8B + provider_id: ${env.ENABLE_SAMBANOVA:=__disabled__} + provider_model_id: sambanova/Meta-Llama-Guard-3-8B + model_type: llm +shields: +- shield_id: ${env.SAFETY_MODEL:=__disabled__} + provider_shield_id: ${env.ENABLE_OLLAMA:=__disabled__}/${env.SAFETY_MODEL:=__disabled__} +vector_dbs: [] +datasets: [] +scoring_fns: [] +benchmarks: [] +tool_groups: +- toolgroup_id: builtin::websearch + provider_id: tavily-search +- toolgroup_id: builtin::rag + provider_id: rag-runtime +server: + port: 8321 From 0a6e588f6881ecccc2b918c342138127912ee7e5 Mon Sep 17 00:00:00 2001 From: ehhuang Date: Fri, 18 Jul 2025 19:11:01 -0700 Subject: [PATCH 39/40] feat: enable auth for LocalFS Files Provider (#2773) # What does this PR do? Supports authentication for LocalFS Files provider. closes https://github.com/meta-llama/llama-stack/issues/2760 ## Test Plan CI. added tests. --- .../inline/files/localfs/__init__.py | 6 +- .../providers/inline/files/localfs/files.py | 18 +- tests/integration/files/test_files.py | 217 ++++++++++++++++++ tests/unit/files/test_files.py | 3 +- 4 files changed, 233 insertions(+), 11 deletions(-) diff --git a/llama_stack/providers/inline/files/localfs/__init__.py b/llama_stack/providers/inline/files/localfs/__init__.py index 7a04e61c6..71664efad 100644 --- a/llama_stack/providers/inline/files/localfs/__init__.py +++ b/llama_stack/providers/inline/files/localfs/__init__.py @@ -6,7 +6,7 @@ from typing import Any -from llama_stack.distribution.datatypes import Api +from llama_stack.distribution.datatypes import AccessRule, Api from .config import LocalfsFilesImplConfig from .files import LocalfsFilesImpl @@ -14,7 +14,7 @@ from .files import LocalfsFilesImpl __all__ = ["LocalfsFilesImpl", "LocalfsFilesImplConfig"] -async def get_provider_impl(config: LocalfsFilesImplConfig, deps: dict[Api, Any]): - impl = LocalfsFilesImpl(config) +async def get_provider_impl(config: LocalfsFilesImplConfig, deps: dict[Api, Any], policy: list[AccessRule]): + impl = LocalfsFilesImpl(config, policy) await impl.initialize() return impl diff --git a/llama_stack/providers/inline/files/localfs/files.py b/llama_stack/providers/inline/files/localfs/files.py index bdf8c42c7..433762c5a 100644 --- a/llama_stack/providers/inline/files/localfs/files.py +++ b/llama_stack/providers/inline/files/localfs/files.py @@ -19,16 +19,19 @@ from llama_stack.apis.files import ( OpenAIFileObject, OpenAIFilePurpose, ) +from llama_stack.distribution.datatypes import AccessRule from llama_stack.providers.utils.sqlstore.api import ColumnDefinition, ColumnType -from llama_stack.providers.utils.sqlstore.sqlstore import SqlStore, sqlstore_impl +from llama_stack.providers.utils.sqlstore.authorized_sqlstore import AuthorizedSqlStore +from llama_stack.providers.utils.sqlstore.sqlstore import sqlstore_impl from .config import LocalfsFilesImplConfig class LocalfsFilesImpl(Files): - def __init__(self, config: LocalfsFilesImplConfig) -> None: + def __init__(self, config: LocalfsFilesImplConfig, policy: list[AccessRule]) -> None: self.config = config - self.sql_store: SqlStore | None = None + self.policy = policy + self.sql_store: AuthorizedSqlStore | None = None async def initialize(self) -> None: """Initialize the files provider by setting up storage directory and metadata database.""" @@ -37,7 +40,7 @@ class LocalfsFilesImpl(Files): storage_path.mkdir(parents=True, exist_ok=True) # Initialize SQL store for metadata - self.sql_store = sqlstore_impl(self.config.metadata_store) + self.sql_store = AuthorizedSqlStore(sqlstore_impl(self.config.metadata_store)) await self.sql_store.create_table( "openai_files", { @@ -126,6 +129,7 @@ class LocalfsFilesImpl(Files): paginated_result = await self.sql_store.fetch_all( table="openai_files", + policy=self.policy, where=where_conditions if where_conditions else None, order_by=[("created_at", order.value)], cursor=("id", after) if after else None, @@ -156,7 +160,7 @@ class LocalfsFilesImpl(Files): if not self.sql_store: raise RuntimeError("Files provider not initialized") - row = await self.sql_store.fetch_one("openai_files", where={"id": file_id}) + row = await self.sql_store.fetch_one("openai_files", policy=self.policy, where={"id": file_id}) if not row: raise ValueError(f"File with id {file_id} not found") @@ -174,7 +178,7 @@ class LocalfsFilesImpl(Files): if not self.sql_store: raise RuntimeError("Files provider not initialized") - row = await self.sql_store.fetch_one("openai_files", where={"id": file_id}) + row = await self.sql_store.fetch_one("openai_files", policy=self.policy, where={"id": file_id}) if not row: raise ValueError(f"File with id {file_id} not found") @@ -197,7 +201,7 @@ class LocalfsFilesImpl(Files): raise RuntimeError("Files provider not initialized") # Get file metadata - row = await self.sql_store.fetch_one("openai_files", where={"id": file_id}) + row = await self.sql_store.fetch_one("openai_files", policy=self.policy, where={"id": file_id}) if not row: raise ValueError(f"File with id {file_id} not found") diff --git a/tests/integration/files/test_files.py b/tests/integration/files/test_files.py index 8547ef2f3..118a751f0 100644 --- a/tests/integration/files/test_files.py +++ b/tests/integration/files/test_files.py @@ -5,10 +5,12 @@ # the root directory of this source tree. from io import BytesIO +from unittest.mock import patch import pytest from openai import OpenAI +from llama_stack.distribution.datatypes import User from llama_stack.distribution.library_client import LlamaStackAsLibraryClient @@ -61,3 +63,218 @@ def test_openai_client_basic_operations(compat_client, client_with_models): except Exception: pass raise e + + +@patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user") +def test_files_authentication_isolation(mock_get_authenticated_user, compat_client, client_with_models): + """Test that users can only access their own files.""" + if isinstance(client_with_models, LlamaStackAsLibraryClient) and isinstance(compat_client, OpenAI): + pytest.skip("OpenAI files are not supported when testing with LlamaStackAsLibraryClient") + if not isinstance(client_with_models, LlamaStackAsLibraryClient): + pytest.skip("Authentication tests require LlamaStackAsLibraryClient (library mode)") + + client = compat_client + + # Create two test users + user1 = User("user1", {"roles": ["user"], "teams": ["team-a"]}) + user2 = User("user2", {"roles": ["user"], "teams": ["team-b"]}) + + # User 1 uploads a file + mock_get_authenticated_user.return_value = user1 + test_content_1 = b"User 1's private file content" + + with BytesIO(test_content_1) as file_buffer: + file_buffer.name = "user1_file.txt" + user1_file = client.files.create(file=file_buffer, purpose="assistants") + + # User 2 uploads a file + mock_get_authenticated_user.return_value = user2 + test_content_2 = b"User 2's private file content" + + with BytesIO(test_content_2) as file_buffer: + file_buffer.name = "user2_file.txt" + user2_file = client.files.create(file=file_buffer, purpose="assistants") + + try: + # User 1 can see their own file + mock_get_authenticated_user.return_value = user1 + user1_files = client.files.list() + user1_file_ids = [f.id for f in user1_files.data] + assert user1_file.id in user1_file_ids + assert user2_file.id not in user1_file_ids # Cannot see user2's file + + # User 2 can see their own file + mock_get_authenticated_user.return_value = user2 + user2_files = client.files.list() + user2_file_ids = [f.id for f in user2_files.data] + assert user2_file.id in user2_file_ids + assert user1_file.id not in user2_file_ids # Cannot see user1's file + + # User 1 can retrieve their own file + mock_get_authenticated_user.return_value = user1 + retrieved_file = client.files.retrieve(user1_file.id) + assert retrieved_file.id == user1_file.id + + # User 1 cannot retrieve user2's file + mock_get_authenticated_user.return_value = user1 + with pytest.raises(ValueError, match="not found"): + client.files.retrieve(user2_file.id) + + # User 1 can access their file content + mock_get_authenticated_user.return_value = user1 + content_response = client.files.content(user1_file.id) + if isinstance(content_response, str): + content = bytes(content_response, "utf-8") + else: + content = content_response.content + assert content == test_content_1 + + # User 1 cannot access user2's file content + mock_get_authenticated_user.return_value = user1 + with pytest.raises(ValueError, match="not found"): + client.files.content(user2_file.id) + + # User 1 can delete their own file + mock_get_authenticated_user.return_value = user1 + delete_response = client.files.delete(user1_file.id) + assert delete_response.deleted is True + + # User 1 cannot delete user2's file + mock_get_authenticated_user.return_value = user1 + with pytest.raises(ValueError, match="not found"): + client.files.delete(user2_file.id) + + # User 2 can still access their file after user1's file is deleted + mock_get_authenticated_user.return_value = user2 + retrieved_file = client.files.retrieve(user2_file.id) + assert retrieved_file.id == user2_file.id + + # Cleanup user2's file + mock_get_authenticated_user.return_value = user2 + client.files.delete(user2_file.id) + + except Exception as e: + # Cleanup in case of failure + try: + mock_get_authenticated_user.return_value = user1 + client.files.delete(user1_file.id) + except Exception: + pass + try: + mock_get_authenticated_user.return_value = user2 + client.files.delete(user2_file.id) + except Exception: + pass + raise e + + +@patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user") +def test_files_authentication_shared_attributes(mock_get_authenticated_user, compat_client, client_with_models): + """Test access control with users having identical attributes.""" + if isinstance(client_with_models, LlamaStackAsLibraryClient) and isinstance(compat_client, OpenAI): + pytest.skip("OpenAI files are not supported when testing with LlamaStackAsLibraryClient") + if not isinstance(client_with_models, LlamaStackAsLibraryClient): + pytest.skip("Authentication tests require LlamaStackAsLibraryClient (library mode)") + + client = compat_client + + # Create users with identical attributes (required for default policy) + user_a = User("user-a", {"roles": ["user"], "teams": ["shared-team"]}) + user_b = User("user-b", {"roles": ["user"], "teams": ["shared-team"]}) + + # User A uploads a file + mock_get_authenticated_user.return_value = user_a + test_content = b"Shared attributes file content" + + with BytesIO(test_content) as file_buffer: + file_buffer.name = "shared_attributes_file.txt" + shared_file = client.files.create(file=file_buffer, purpose="assistants") + + try: + # User B with identical attributes can access the file + mock_get_authenticated_user.return_value = user_b + files_list = client.files.list() + file_ids = [f.id for f in files_list.data] + + # User B should be able to see the file due to identical attributes + assert shared_file.id in file_ids + + # User B can retrieve file info + retrieved_file = client.files.retrieve(shared_file.id) + assert retrieved_file.id == shared_file.id + + # User B can access file content + content_response = client.files.content(shared_file.id) + if isinstance(content_response, str): + content = bytes(content_response, "utf-8") + else: + content = content_response.content + assert content == test_content + + # Cleanup + mock_get_authenticated_user.return_value = user_a + client.files.delete(shared_file.id) + + except Exception as e: + # Cleanup in case of failure + try: + mock_get_authenticated_user.return_value = user_a + client.files.delete(shared_file.id) + except Exception: + pass + try: + mock_get_authenticated_user.return_value = user_b + client.files.delete(shared_file.id) + except Exception: + pass + raise e + + +@patch("llama_stack.providers.utils.sqlstore.authorized_sqlstore.get_authenticated_user") +def test_files_authentication_anonymous_access(mock_get_authenticated_user, compat_client, client_with_models): + """Test anonymous user behavior when no authentication is present.""" + if isinstance(client_with_models, LlamaStackAsLibraryClient) and isinstance(compat_client, OpenAI): + pytest.skip("OpenAI files are not supported when testing with LlamaStackAsLibraryClient") + if not isinstance(client_with_models, LlamaStackAsLibraryClient): + pytest.skip("Authentication tests require LlamaStackAsLibraryClient (library mode)") + + client = compat_client + + # Simulate anonymous user (no authentication) + mock_get_authenticated_user.return_value = None + + test_content = b"Anonymous file content" + + with BytesIO(test_content) as file_buffer: + file_buffer.name = "anonymous_file.txt" + anonymous_file = client.files.create(file=file_buffer, purpose="assistants") + + try: + # Anonymous user should be able to access their own uploaded file + files_list = client.files.list() + file_ids = [f.id for f in files_list.data] + assert anonymous_file.id in file_ids + + # Can retrieve file info + retrieved_file = client.files.retrieve(anonymous_file.id) + assert retrieved_file.id == anonymous_file.id + + # Can access file content + content_response = client.files.content(anonymous_file.id) + if isinstance(content_response, str): + content = bytes(content_response, "utf-8") + else: + content = content_response.content + assert content == test_content + + # Can delete the file + delete_response = client.files.delete(anonymous_file.id) + assert delete_response.deleted is True + + except Exception as e: + # Cleanup in case of failure + try: + client.files.delete(anonymous_file.id) + except Exception: + pass + raise e diff --git a/tests/unit/files/test_files.py b/tests/unit/files/test_files.py index 785077e91..c3ec25116 100644 --- a/tests/unit/files/test_files.py +++ b/tests/unit/files/test_files.py @@ -9,6 +9,7 @@ import pytest from llama_stack.apis.common.responses import Order from llama_stack.apis.files import OpenAIFilePurpose +from llama_stack.distribution.access_control.access_control import default_policy from llama_stack.providers.inline.files.localfs import ( LocalfsFilesImpl, LocalfsFilesImplConfig, @@ -38,7 +39,7 @@ async def files_provider(tmp_path): storage_dir=storage_dir.as_posix(), metadata_store=SqliteSqlStoreConfig(db_path=db_path.as_posix()) ) - provider = LocalfsFilesImpl(config) + provider = LocalfsFilesImpl(config, default_policy()) await provider.initialize() yield provider From 28956f9447d445b9301e4b94767d3fa6c9f9e00a Mon Sep 17 00:00:00 2001 From: "dependabot[bot]" <49699333+dependabot[bot]@users.noreply.github.com> Date: Sat, 19 Jul 2025 21:10:35 -0500 Subject: [PATCH 40/40] chore(github-deps): bump astral-sh/setup-uv from 6.3.1 to 6.4.1 (#2827) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Bumps [astral-sh/setup-uv](https://github.com/astral-sh/setup-uv) from 6.3.1 to 6.4.1.
Release notes

Sourced from astral-sh/setup-uv's releases.

v6.4.1 🌈 Hotfix: Ignore deps starting with uv when finding uv version

Changes

Thank you @​phpmypython for raising a PR to fix this issue!

🐛 Bug fixes

v6.4.0 🌈 Add input version-file

Changes

You can now use the version-file input to specify a file that contains the version of uv to install. This can either be a pyproject.toml or uv.toml file which defines a required-version or uv defined as a dependency in pyproject.toml or requirements.txt.

- name: Install uv based on the version defined
in requirements.txt
  uses: astral-sh/setup-uv@v6
  with:
    version-file: "requirements.txt"

🚀 Enhancements

🧰 Maintenance

📚 Documentation

⬆️ Dependency updates

Commits

[![Dependabot compatibility score](https://dependabot-badges.githubapp.com/badges/compatibility_score?dependency-name=astral-sh/setup-uv&package-manager=github_actions&previous-version=6.3.1&new-version=6.4.1)](https://docs.github.com/en/github/managing-security-vulnerabilities/about-dependabot-security-updates#about-compatibility-scores) Dependabot will resolve any conflicts with this PR as long as you don't alter it yourself. You can also trigger a rebase manually by commenting `@dependabot rebase`. [//]: # (dependabot-automerge-start) [//]: # (dependabot-automerge-end) ---
Dependabot commands and options
You can trigger Dependabot actions by commenting on this PR: - `@dependabot rebase` will rebase this PR - `@dependabot recreate` will recreate this PR, overwriting any edits that have been made to it - `@dependabot merge` will merge this PR after your CI passes on it - `@dependabot squash and merge` will squash and merge this PR after your CI passes on it - `@dependabot cancel merge` will cancel a previously requested merge and block automerging - `@dependabot reopen` will reopen this PR if it is closed - `@dependabot close` will close this PR and stop Dependabot recreating it. You can achieve the same result by closing it manually - `@dependabot show ignore conditions` will show all of the ignore conditions of the specified dependency - `@dependabot ignore this major version` will close this PR and stop Dependabot creating any more for this major version (unless you reopen the PR or upgrade to it yourself) - `@dependabot ignore this minor version` will close this PR and stop Dependabot creating any more for this minor version (unless you reopen the PR or upgrade to it yourself) - `@dependabot ignore this dependency` will close this PR and stop Dependabot creating any more for this dependency (unless you reopen the PR or upgrade to it yourself)
Signed-off-by: dependabot[bot] Co-authored-by: dependabot[bot] <49699333+dependabot[bot]@users.noreply.github.com> --- .github/workflows/python-build-test.yml | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/.github/workflows/python-build-test.yml b/.github/workflows/python-build-test.yml index 63ddd9b54..efd1f2cc9 100644 --- a/.github/workflows/python-build-test.yml +++ b/.github/workflows/python-build-test.yml @@ -20,7 +20,7 @@ jobs: uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2 - name: Install uv - uses: astral-sh/setup-uv@bd01e18f51369d5a26f1651c3cb451d3417e3bba # v6.3.1 + uses: astral-sh/setup-uv@7edac99f961f18b581bbd960d59d049f04c0002f # v6.4.1 with: python-version: ${{ matrix.python-version }} activate-environment: true