diff --git a/.coveragerc b/.coveragerc
index e16c2e461..d4925275f 100644
--- a/.coveragerc
+++ b/.coveragerc
@@ -4,3 +4,9 @@ omit =
*/llama_stack/providers/*
*/llama_stack/templates/*
.venv/*
+ */llama_stack/cli/scripts/*
+ */llama_stack/ui/*
+ */llama_stack/distribution/ui/*
+ */llama_stack/strong_typing/*
+ */llama_stack/env.py
+ */__init__.py
diff --git a/.github/workflows/coverage-badge.yml b/.github/workflows/coverage-badge.yml
new file mode 100644
index 000000000..6b2f133dd
--- /dev/null
+++ b/.github/workflows/coverage-badge.yml
@@ -0,0 +1,57 @@
+name: Coverage Badge
+
+on:
+ push:
+ branches: [ main ]
+ paths:
+ - 'llama_stack/**'
+ - 'tests/unit/**'
+ - 'uv.lock'
+ - 'pyproject.toml'
+ - 'requirements.txt'
+ - '.github/workflows/unit-tests.yml'
+ - '.github/workflows/coverage-badge.yml' # This workflow
+ workflow_dispatch:
+
+jobs:
+ unit-tests:
+ runs-on: ubuntu-latest
+ steps:
+ - name: Checkout repository
+ uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
+
+ - name: Install dependencies
+ uses: ./.github/actions/setup-runner
+
+ - name: Run unit tests
+ run: |
+ ./scripts/unit-tests.sh
+
+ - name: Coverage Badge
+ uses: tj-actions/coverage-badge-py@1788babcb24544eb5bbb6e0d374df5d1e54e670f # v2.0.4
+
+ - name: Verify Changed files
+ uses: tj-actions/verify-changed-files@a1c6acee9df209257a246f2cc6ae8cb6581c1edf # v20.0.4
+ id: verify-changed-files
+ with:
+ files: coverage.svg
+
+ - name: Commit files
+ if: steps.verify-changed-files.outputs.files_changed == 'true'
+ run: |
+ git config --local user.email "github-actions[bot]@users.noreply.github.com"
+ git config --local user.name "github-actions[bot]"
+ git add coverage.svg
+ git commit -m "Updated coverage.svg"
+
+ - name: Create Pull Request
+ if: steps.verify-changed-files.outputs.files_changed == 'true'
+ uses: peter-evans/create-pull-request@271a8d0340265f705b14b6d32b9829c1cb33d45e # v7.0.8
+ with:
+ token: ${{ secrets.GITHUB_TOKEN }}
+ title: "ci: [Automatic] Coverage Badge Update"
+ body: |
+ This PR updates the coverage badge based on the latest coverage report.
+
+          Automatically generated by the [coverage-badge workflow](.github/workflows/coverage-badge.yml)
+ delete-branch: true
diff --git a/.github/workflows/integration-tests.yml b/.github/workflows/integration-tests.yml
index a5883daf7..0b6c1be3b 100644
--- a/.github/workflows/integration-tests.yml
+++ b/.github/workflows/integration-tests.yml
@@ -7,7 +7,7 @@ on:
branches: [ main ]
paths:
- 'llama_stack/**'
- - 'tests/integration/**'
+ - 'tests/**'
- 'uv.lock'
- 'pyproject.toml'
- 'requirements.txt'
diff --git a/.github/workflows/unit-tests.yml b/.github/workflows/unit-tests.yml
index e29045e52..41034b45f 100644
--- a/.github/workflows/unit-tests.yml
+++ b/.github/workflows/unit-tests.yml
@@ -36,7 +36,7 @@ jobs:
- name: Run unit tests
run: |
- PYTHON_VERSION=${{ matrix.python }} ./scripts/unit-tests.sh --cov=llama_stack --junitxml=pytest-report-${{ matrix.python }}.xml --cov-report=html:htmlcov-${{ matrix.python }}
+ PYTHON_VERSION=${{ matrix.python }} ./scripts/unit-tests.sh --junitxml=pytest-report-${{ matrix.python }}.xml
- name: Upload test results
if: always()
diff --git a/README.md b/README.md
index 9148ce05d..7f0fed345 100644
--- a/README.md
+++ b/README.md
@@ -6,6 +6,7 @@
[](https://discord.gg/llama-stack)
[](https://github.com/meta-llama/llama-stack/actions/workflows/unit-tests.yml?query=branch%3Amain)
[](https://github.com/meta-llama/llama-stack/actions/workflows/integration-tests.yml?query=branch%3Amain)
+
[**Quick Start**](https://llama-stack.readthedocs.io/en/latest/getting_started/index.html) | [**Documentation**](https://llama-stack.readthedocs.io/en/latest/index.html) | [**Colab Notebook**](./docs/getting_started.ipynb) | [**Discord**](https://discord.gg/llama-stack)
diff --git a/coverage.svg b/coverage.svg
new file mode 100644
index 000000000..636889bb0
--- /dev/null
+++ b/coverage.svg
@@ -0,0 +1,21 @@
+
+
diff --git a/docs/source/distributions/self_hosted_distro/starter.md b/docs/source/distributions/self_hosted_distro/starter.md
index 753746d84..56cdd5e73 100644
--- a/docs/source/distributions/self_hosted_distro/starter.md
+++ b/docs/source/distributions/self_hosted_distro/starter.md
@@ -167,7 +167,7 @@ When using the `:` pattern (like `${env.OLLAMA_INFERENCE_MODEL:__disabled__}`),
## Running the Distribution
-You can run the starter distribution via Docker or Conda.
+You can run the starter distribution via Docker, Conda, or venv.
### Via Docker
@@ -186,17 +186,12 @@ docker run \
--port $LLAMA_STACK_PORT
```
-### Via Conda
+### Via Conda or venv
-Make sure you have done `uv pip install llama-stack` and have the Llama Stack CLI available.
+Ensure you have configured the starter distribution using the environment variables explained above.
```bash
-llama stack build --template starter --image-type conda
-llama stack run distributions/starter/run.yaml \
- --port 8321 \
- --env OPENAI_API_KEY=your_openai_key \
- --env FIREWORKS_API_KEY=your_fireworks_key \
- --env TOGETHER_API_KEY=your_together_key
+uv run --with llama-stack llama stack build --template starter --image-type <conda|venv> --run
```
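+
+For example, to build and run with the Ollama provider enabled (the model name is illustrative; any of the environment variables explained above can be supplied the same way):
+
+```bash
+# example values; adjust the provider/model variables for your setup
+ENABLE_OLLAMA=ollama OLLAMA_INFERENCE_MODEL=llama3.2:3b \
+  uv run --with llama-stack llama stack build --template starter --image-type venv --run
+```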
## Example Usage
diff --git a/docs/source/getting_started/quickstart.md b/docs/source/getting_started/quickstart.md
index 881ddd29b..59791643d 100644
--- a/docs/source/getting_started/quickstart.md
+++ b/docs/source/getting_started/quickstart.md
@@ -19,7 +19,7 @@ ollama run llama3.2:3b --keepalive 60m
#### Step 2: Run the Llama Stack server
We will use `uv` to run the Llama Stack server.
```bash
-INFERENCE_MODEL=llama3.2:3b uv run --with llama-stack llama stack build --template starter --image-type venv --run
+ENABLE_OLLAMA=ollama OLLAMA_INFERENCE_MODEL=llama3.2:3b uv run --with llama-stack llama stack build --template starter --image-type venv --run
```
#### Step 3: Run the demo
Now open up a new terminal and copy the following script into a file named `demo_script.py`.
@@ -111,6 +111,12 @@ Ultimately, great work is about making a meaningful contribution and leaving a l
```
Congratulations! You've successfully built your first RAG application using Llama Stack! 🎉🥳
+```{admonition} HuggingFace access
+:class: tip
+
+If you are getting a **401 Client Error** from HuggingFace for the **all-MiniLM-L6-v2** model, try setting **HF_TOKEN** to a valid HuggingFace token in your environment.
+```
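+
+For example, you can export a token before starting the server (the value below is only a placeholder):
+
+```bash
+export HF_TOKEN=hf_xxxxxxxxxxxxxxxx  # placeholder; use your own Hugging Face access token
+```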
+
### Next Steps
Now you're ready to dive deeper into Llama Stack!
diff --git a/llama_stack/models/llama/llama3/chat_format.py b/llama_stack/models/llama/llama3/chat_format.py
index 7bb05d8db..0a973cf0c 100644
--- a/llama_stack/models/llama/llama3/chat_format.py
+++ b/llama_stack/models/llama/llama3/chat_format.py
@@ -8,6 +8,7 @@ import io
import json
import uuid
from dataclasses import dataclass
+from typing import Any
from PIL import Image as PIL_Image
@@ -184,16 +185,26 @@ class ChatFormat:
content = content[: -len("<|eom_id|>")]
stop_reason = StopReason.end_of_message
- tool_name = None
- tool_arguments = {}
+ tool_name: str | BuiltinTool | None = None
+ tool_arguments: dict[str, Any] = {}
custom_tool_info = ToolUtils.maybe_extract_custom_tool_call(content)
if custom_tool_info is not None:
- tool_name, tool_arguments = custom_tool_info
+ # Type guard: ensure custom_tool_info is a tuple of correct types
+ if isinstance(custom_tool_info, tuple) and len(custom_tool_info) == 2:
+ extracted_tool_name, extracted_tool_arguments = custom_tool_info
+ # Handle both dict and str return types from the function
+ if isinstance(extracted_tool_arguments, dict):
+ tool_name, tool_arguments = extracted_tool_name, extracted_tool_arguments
+ else:
+ # If it's a string, treat it as a query parameter
+ tool_name, tool_arguments = extracted_tool_name, {"query": extracted_tool_arguments}
+ else:
+ tool_name, tool_arguments = None, {}
# Sometimes when agent has custom tools alongside builin tools
# Agent responds for builtin tool calls in the format of the custom tools
# This code tries to handle that case
- if tool_name in BuiltinTool.__members__:
+ if tool_name is not None and tool_name in BuiltinTool.__members__:
tool_name = BuiltinTool[tool_name]
if isinstance(tool_arguments, dict):
tool_arguments = {
diff --git a/llama_stack/providers/inline/inference/meta_reference/parallel_utils.py b/llama_stack/providers/inline/inference/meta_reference/parallel_utils.py
index 97e96b929..7ade75032 100644
--- a/llama_stack/providers/inline/inference/meta_reference/parallel_utils.py
+++ b/llama_stack/providers/inline/inference/meta_reference/parallel_utils.py
@@ -98,7 +98,7 @@ class ProcessingMessageWrapper(BaseModel):
def mp_rank_0() -> bool:
- return get_model_parallel_rank() == 0
+ return bool(get_model_parallel_rank() == 0)
def encode_msg(msg: ProcessingMessage) -> bytes:
@@ -125,7 +125,7 @@ def retrieve_requests(reply_socket_url: str):
reply_socket.send_multipart([client_id, encode_msg(obj)])
while True:
- tasks = [None]
+ tasks: list[ProcessingMessage | None] = [None]
if mp_rank_0():
client_id, maybe_task_json = maybe_get_work(reply_socket)
if maybe_task_json is not None:
@@ -152,7 +152,7 @@ def retrieve_requests(reply_socket_url: str):
break
for obj in out:
- updates = [None]
+ updates: list[ProcessingMessage | None] = [None]
if mp_rank_0():
_, update_json = maybe_get_work(reply_socket)
update = maybe_parse_message(update_json)
diff --git a/pyproject.toml b/pyproject.toml
index 72f3a323f..15e2e10b4 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -91,6 +91,7 @@ unit = [
"pymilvus>=2.5.12",
"litellm",
"together",
+ "coverage",
]
# These are the core dependencies required for running integration tests. They are shared across all
# providers. If a provider requires additional dependencies, please add them to your environment
@@ -242,7 +243,6 @@ exclude = [
"^llama_stack/distribution/store/registry\\.py$",
"^llama_stack/distribution/utils/exec\\.py$",
"^llama_stack/distribution/utils/prompt_for_config\\.py$",
- "^llama_stack/models/llama/llama3/chat_format\\.py$",
"^llama_stack/models/llama/llama3/interface\\.py$",
"^llama_stack/models/llama/llama3/tokenizer\\.py$",
"^llama_stack/models/llama/llama3/tool_utils\\.py$",
@@ -255,7 +255,6 @@ exclude = [
"^llama_stack/models/llama/llama3/generation\\.py$",
"^llama_stack/models/llama/llama3/multimodal/model\\.py$",
"^llama_stack/models/llama/llama4/",
- "^llama_stack/providers/inline/inference/meta_reference/parallel_utils\\.py$",
"^llama_stack/providers/inline/inference/meta_reference/quantization/fp8_impls\\.py$",
"^llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers\\.py$",
"^llama_stack/providers/inline/inference/vllm/",
diff --git a/scripts/unit-tests.sh b/scripts/unit-tests.sh
index 68d6458fc..458cd383d 100755
--- a/scripts/unit-tests.sh
+++ b/scripts/unit-tests.sh
@@ -16,4 +16,9 @@ if [ $FOUND_PYTHON -ne 0 ]; then
uv python install "$PYTHON_VERSION"
fi
-uv run --python "$PYTHON_VERSION" --with-editable . --group unit pytest -s -v tests/unit/ $@
+# Run unit tests with coverage
+uv run --python "$PYTHON_VERSION" --with-editable . --group unit \
+ coverage run --source=llama_stack -m pytest -s -v tests/unit/ "$@"
+
+# Generate HTML coverage report
+uv run --python "$PYTHON_VERSION" coverage html -d "htmlcov-$PYTHON_VERSION"
diff --git a/tests/integration/post_training/test_post_training.py b/tests/integration/post_training/test_post_training.py
index bb4639d17..3d56b322f 100644
--- a/tests/integration/post_training/test_post_training.py
+++ b/tests/integration/post_training/test_post_training.py
@@ -123,14 +123,14 @@ class TestPostTraining:
logger.info(f"Job artifacts: {artifacts}")
# TODO: Fix these tests to properly represent the Jobs API in training
- # @pytest.mark.asyncio
+ #
# async def test_get_training_jobs(self, post_training_stack):
# post_training_impl = post_training_stack
# jobs_list = await post_training_impl.get_training_jobs()
# assert isinstance(jobs_list, list)
# assert jobs_list[0].job_uuid == "1234"
- # @pytest.mark.asyncio
+ #
# async def test_get_training_job_status(self, post_training_stack):
# post_training_impl = post_training_stack
# job_status = await post_training_impl.get_training_job_status("1234")
@@ -139,7 +139,7 @@ class TestPostTraining:
# assert job_status.status == JobStatus.completed
# assert isinstance(job_status.checkpoints[0], Checkpoint)
- # @pytest.mark.asyncio
+ #
# async def test_get_training_job_artifacts(self, post_training_stack):
# post_training_impl = post_training_stack
# job_artifacts = await post_training_impl.get_training_job_artifacts("1234")
diff --git a/tests/unit/README.md b/tests/unit/README.md
index c95c3a0e7..06e22fb8c 100644
--- a/tests/unit/README.md
+++ b/tests/unit/README.md
@@ -1,9 +1,17 @@
# Llama Stack Unit Tests
+## Unit Tests
+
+Unit tests verify individual components and functions in isolation. They are fast, reliable, and don't require external services.
+
+### Prerequisites
+
+1. **Python Environment**: Ensure you have Python 3.12+ installed
+2. **uv Package Manager**: Install `uv` if not already installed
+
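+If `uv` is not already installed, one common option is the official installer script (see the uv documentation for alternatives):
+
+```bash
+curl -LsSf https://astral.sh/uv/install.sh | sh  # official uv installer script
+```
+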
You can run the unit tests by running:
```bash
-source .venv/bin/activate
./scripts/unit-tests.sh [PYTEST_ARGS]
```
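+
+Any arguments after the script name are passed through to `pytest`. For example, to run a single test file with extra verbosity (the path is just an illustration from this repository's unit tests):
+
+```bash
+# any pytest arguments are forwarded; the test file below is one example
+./scripts/unit-tests.sh tests/unit/models/test_prompt_adapter.py -vvv
+```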
@@ -19,3 +27,21 @@ If you'd like to run for a non-default version of Python (currently 3.12), pass
source .venv/bin/activate
PYTHON_VERSION=3.13 ./scripts/unit-tests.sh
```
+
+### Test Configuration
+
+- **Test Discovery**: Tests are automatically discovered in the `tests/unit/` directory
+- **Async Support**: Tests use `--asyncio-mode=auto` for automatic async test handling
+- **Coverage**: Tests generate HTML coverage reports in the `htmlcov-{PYTHON_VERSION}/` directory (e.g. `htmlcov-3.12/`)
+- **Python Version**: Defaults to Python 3.12, but can be overridden with the `PYTHON_VERSION` environment variable
+
+### Coverage Reports
+
+After running tests, you can view coverage reports:
+
+```bash
+# Open HTML coverage report in browser (the directory name includes the Python version, 3.12 by default)
+open htmlcov-3.12/index.html      # macOS
+xdg-open htmlcov-3.12/index.html  # Linux
+start htmlcov-3.12/index.html     # Windows
+```
diff --git a/tests/unit/models/test_prompt_adapter.py b/tests/unit/models/test_prompt_adapter.py
index 577496cec..0362eb5dd 100644
--- a/tests/unit/models/test_prompt_adapter.py
+++ b/tests/unit/models/test_prompt_adapter.py
@@ -4,7 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
-import pytest
from llama_stack.apis.inference import (
ChatCompletionRequest,
@@ -32,7 +31,6 @@ MODEL = "Llama3.1-8B-Instruct"
MODEL3_2 = "Llama3.2-3B-Instruct"
-@pytest.mark.asyncio
async def test_system_default():
content = "Hello !"
request = ChatCompletionRequest(
@@ -47,7 +45,6 @@ async def test_system_default():
assert "Cutting Knowledge Date: December 2023" in interleaved_content_as_str(messages[0].content)
-@pytest.mark.asyncio
async def test_system_builtin_only():
content = "Hello !"
request = ChatCompletionRequest(
@@ -67,7 +64,6 @@ async def test_system_builtin_only():
assert "Tools: brave_search" in interleaved_content_as_str(messages[0].content)
-@pytest.mark.asyncio
async def test_system_custom_only():
content = "Hello !"
request = ChatCompletionRequest(
@@ -98,7 +94,6 @@ async def test_system_custom_only():
assert messages[-1].content == content
-@pytest.mark.asyncio
async def test_system_custom_and_builtin():
content = "Hello !"
request = ChatCompletionRequest(
@@ -132,7 +127,6 @@ async def test_system_custom_and_builtin():
assert messages[-1].content == content
-@pytest.mark.asyncio
async def test_completion_message_encoding():
request = ChatCompletionRequest(
model=MODEL3_2,
@@ -174,7 +168,6 @@ async def test_completion_message_encoding():
assert '{"type": "function", "name": "custom1", "parameters": {"param1": "value1"}}' in prompt
-@pytest.mark.asyncio
async def test_user_provided_system_message():
content = "Hello !"
system_prompt = "You are a pirate"
@@ -195,7 +188,6 @@ async def test_user_provided_system_message():
assert messages[-1].content == content
-@pytest.mark.asyncio
async def test_replace_system_message_behavior_builtin_tools():
content = "Hello !"
system_prompt = "You are a pirate"
@@ -221,7 +213,6 @@ async def test_replace_system_message_behavior_builtin_tools():
assert messages[-1].content == content
-@pytest.mark.asyncio
async def test_replace_system_message_behavior_custom_tools():
content = "Hello !"
system_prompt = "You are a pirate"
@@ -259,7 +250,6 @@ async def test_replace_system_message_behavior_custom_tools():
assert messages[-1].content == content
-@pytest.mark.asyncio
async def test_replace_system_message_behavior_custom_tools_with_template():
content = "Hello !"
system_prompt = "You are a pirate {{ function_description }}"
diff --git a/tests/unit/models/test_system_prompts.py b/tests/unit/models/test_system_prompts.py
index 1f4ccc7e3..f5580f4c5 100644
--- a/tests/unit/models/test_system_prompts.py
+++ b/tests/unit/models/test_system_prompts.py
@@ -12,7 +12,6 @@
# the top-level of this source tree.
import textwrap
-import unittest
from datetime import datetime
from llama_stack.models.llama.llama3.prompt_templates import (
@@ -24,59 +23,61 @@ from llama_stack.models.llama.llama3.prompt_templates import (
)
-class PromptTemplateTests(unittest.TestCase):
- def check_generator_output(self, generator):
- for example in generator.data_examples():
- pt = generator.gen(example)
- text = pt.render()
- # print(text) # debugging
- if not example:
- continue
- for tool in example:
- assert tool.tool_name in text
-
- def test_system_default(self):
- generator = SystemDefaultGenerator()
- today = datetime.now().strftime("%d %B %Y")
- expected_text = f"Cutting Knowledge Date: December 2023\nToday Date: {today}"
- assert expected_text.strip("\n") == generator.gen(generator.data_examples()[0]).render()
-
- def test_system_builtin_only(self):
- generator = BuiltinToolGenerator()
- expected_text = textwrap.dedent(
- """
- Environment: ipython
- Tools: brave_search, wolfram_alpha
- """
- )
- assert expected_text.strip("\n") == generator.gen(generator.data_examples()[0]).render()
-
- def test_system_custom_only(self):
- self.maxDiff = None
- generator = JsonCustomToolGenerator()
- self.check_generator_output(generator)
-
- def test_system_custom_function_tag(self):
- self.maxDiff = None
- generator = FunctionTagCustomToolGenerator()
- self.check_generator_output(generator)
-
- def test_llama_3_2_system_zero_shot(self):
- generator = PythonListCustomToolGenerator()
- self.check_generator_output(generator)
-
- def test_llama_3_2_provided_system_prompt(self):
- generator = PythonListCustomToolGenerator()
- user_system_prompt = textwrap.dedent(
- """
- Overriding message.
-
- {{ function_description }}
- """
- )
- example = generator.data_examples()[0]
-
- pt = generator.gen(example, user_system_prompt)
+def check_generator_output(generator):
+ for example in generator.data_examples():
+ pt = generator.gen(example)
text = pt.render()
- assert "Overriding message." in text
- assert '"name": "get_weather"' in text
+ if not example:
+ continue
+ for tool in example:
+ assert tool.tool_name in text
+
+
+def test_system_default():
+ generator = SystemDefaultGenerator()
+ today = datetime.now().strftime("%d %B %Y")
+ expected_text = f"Cutting Knowledge Date: December 2023\nToday Date: {today}"
+ assert expected_text.strip("\n") == generator.gen(generator.data_examples()[0]).render()
+
+
+def test_system_builtin_only():
+ generator = BuiltinToolGenerator()
+ expected_text = textwrap.dedent(
+ """
+ Environment: ipython
+ Tools: brave_search, wolfram_alpha
+ """
+ )
+ assert expected_text.strip("\n") == generator.gen(generator.data_examples()[0]).render()
+
+
+def test_system_custom_only():
+ generator = JsonCustomToolGenerator()
+ check_generator_output(generator)
+
+
+def test_system_custom_function_tag():
+ generator = FunctionTagCustomToolGenerator()
+ check_generator_output(generator)
+
+
+def test_llama_3_2_system_zero_shot():
+ generator = PythonListCustomToolGenerator()
+ check_generator_output(generator)
+
+
+def test_llama_3_2_provided_system_prompt():
+ generator = PythonListCustomToolGenerator()
+ user_system_prompt = textwrap.dedent(
+ """
+ Overriding message.
+
+ {{ function_description }}
+ """
+ )
+ example = generator.data_examples()[0]
+
+ pt = generator.gen(example, user_system_prompt)
+ text = pt.render()
+ assert "Overriding message." in text
+ assert '"name": "get_weather"' in text
diff --git a/tests/unit/providers/nvidia/test_datastore.py b/tests/unit/providers/nvidia/test_datastore.py
index a17e51a9c..b59636f7b 100644
--- a/tests/unit/providers/nvidia/test_datastore.py
+++ b/tests/unit/providers/nvidia/test_datastore.py
@@ -5,103 +5,110 @@
# the root directory of this source tree.
import os
-import unittest
from unittest.mock import patch
import pytest
from llama_stack.apis.datasets import Dataset, DatasetPurpose, URIDataSource
+from llama_stack.apis.resource import ResourceType
from llama_stack.providers.remote.datasetio.nvidia.config import NvidiaDatasetIOConfig
from llama_stack.providers.remote.datasetio.nvidia.datasetio import NvidiaDatasetIOAdapter
-class TestNvidiaDatastore(unittest.TestCase):
- def setUp(self):
- os.environ["NVIDIA_DATASETS_URL"] = "http://nemo.test/datasets"
+@pytest.fixture
+def nvidia_adapter():
+ """Fixture to set up NvidiaDatasetIOAdapter with mocked requests."""
+ os.environ["NVIDIA_DATASETS_URL"] = "http://nemo.test/datasets"
- config = NvidiaDatasetIOConfig(
- datasets_url=os.environ["NVIDIA_DATASETS_URL"], dataset_namespace="default", project_id="default"
- )
- self.adapter = NvidiaDatasetIOAdapter(config)
- self.make_request_patcher = patch(
- "llama_stack.providers.remote.datasetio.nvidia.datasetio.NvidiaDatasetIOAdapter._make_request"
- )
- self.mock_make_request = self.make_request_patcher.start()
+ config = NvidiaDatasetIOConfig(
+ datasets_url=os.environ["NVIDIA_DATASETS_URL"], dataset_namespace="default", project_id="default"
+ )
+ adapter = NvidiaDatasetIOAdapter(config)
- def tearDown(self):
- self.make_request_patcher.stop()
+ with patch(
+ "llama_stack.providers.remote.datasetio.nvidia.datasetio.NvidiaDatasetIOAdapter._make_request"
+ ) as mock_make_request:
+ yield adapter, mock_make_request
- @pytest.fixture(autouse=True)
- def inject_fixtures(self, run_async):
- self.run_async = run_async
- def _assert_request(self, mock_call, expected_method, expected_path, expected_json=None):
- """Helper method to verify request details in mock calls."""
- call_args = mock_call.call_args
+def _assert_request(mock_call, expected_method, expected_path, expected_json=None):
+ """Helper function to verify request details in mock calls."""
+ call_args = mock_call.call_args
- assert call_args[0][0] == expected_method
- assert call_args[0][1] == expected_path
+ assert call_args[0][0] == expected_method
+ assert call_args[0][1] == expected_path
- if expected_json:
- for key, value in expected_json.items():
- assert call_args[1]["json"][key] == value
+ if expected_json:
+ for key, value in expected_json.items():
+ assert call_args[1]["json"][key] == value
- def test_register_dataset(self):
- self.mock_make_request.return_value = {
- "id": "dataset-123456",
+
+def test_register_dataset(nvidia_adapter, run_async):
+ adapter, mock_make_request = nvidia_adapter
+ mock_make_request.return_value = {
+ "id": "dataset-123456",
+ "name": "test-dataset",
+ "namespace": "default",
+ }
+
+ dataset_def = Dataset(
+ identifier="test-dataset",
+ type=ResourceType.dataset,
+ provider_resource_id="",
+ provider_id="",
+ purpose=DatasetPurpose.post_training_messages,
+ source=URIDataSource(uri="https://example.com/data.jsonl"),
+ metadata={"provider_id": "nvidia", "format": "jsonl", "description": "Test dataset description"},
+ )
+
+ run_async(adapter.register_dataset(dataset_def))
+
+ mock_make_request.assert_called_once()
+ _assert_request(
+ mock_make_request,
+ "POST",
+ "/v1/datasets",
+ expected_json={
"name": "test-dataset",
"namespace": "default",
- }
+ "files_url": "https://example.com/data.jsonl",
+ "project": "default",
+ "format": "jsonl",
+ "description": "Test dataset description",
+ },
+ )
- dataset_def = Dataset(
- identifier="test-dataset",
- type="dataset",
- provider_resource_id="",
- provider_id="",
- purpose=DatasetPurpose.post_training_messages,
- source=URIDataSource(uri="https://example.com/data.jsonl"),
- metadata={"provider_id": "nvidia", "format": "jsonl", "description": "Test dataset description"},
- )
- self.run_async(self.adapter.register_dataset(dataset_def))
+def test_unregister_dataset(nvidia_adapter, run_async):
+ adapter, mock_make_request = nvidia_adapter
+ mock_make_request.return_value = {
+ "message": "Resource deleted successfully.",
+ "id": "dataset-81RSQp7FKX3rdBtKvF9Skn",
+ "deleted_at": None,
+ }
+ dataset_id = "test-dataset"
- self.mock_make_request.assert_called_once()
- self._assert_request(
- self.mock_make_request,
- "POST",
- "/v1/datasets",
- expected_json={
- "name": "test-dataset",
- "namespace": "default",
- "files_url": "https://example.com/data.jsonl",
- "project": "default",
- "format": "jsonl",
- "description": "Test dataset description",
- },
- )
+ run_async(adapter.unregister_dataset(dataset_id))
- def test_unregister_dataset(self):
- self.mock_make_request.return_value = {
- "message": "Resource deleted successfully.",
- "id": "dataset-81RSQp7FKX3rdBtKvF9Skn",
- "deleted_at": None,
- }
- dataset_id = "test-dataset"
+ mock_make_request.assert_called_once()
+ _assert_request(mock_make_request, "DELETE", "/v1/datasets/default/test-dataset")
- self.run_async(self.adapter.unregister_dataset(dataset_id))
- self.mock_make_request.assert_called_once()
- self._assert_request(self.mock_make_request, "DELETE", "/v1/datasets/default/test-dataset")
+def test_register_dataset_with_custom_namespace_project(run_async):
+ """Test with custom namespace and project configuration."""
+ os.environ["NVIDIA_DATASETS_URL"] = "http://nemo.test/datasets"
- def test_register_dataset_with_custom_namespace_project(self):
- custom_config = NvidiaDatasetIOConfig(
- datasets_url=os.environ["NVIDIA_DATASETS_URL"],
- dataset_namespace="custom-namespace",
- project_id="custom-project",
- )
- custom_adapter = NvidiaDatasetIOAdapter(custom_config)
+ custom_config = NvidiaDatasetIOConfig(
+ datasets_url=os.environ["NVIDIA_DATASETS_URL"],
+ dataset_namespace="custom-namespace",
+ project_id="custom-project",
+ )
+ custom_adapter = NvidiaDatasetIOAdapter(custom_config)
- self.mock_make_request.return_value = {
+ with patch(
+ "llama_stack.providers.remote.datasetio.nvidia.datasetio.NvidiaDatasetIOAdapter._make_request"
+ ) as mock_make_request:
+ mock_make_request.return_value = {
"id": "dataset-123456",
"name": "test-dataset",
"namespace": "custom-namespace",
@@ -109,7 +116,7 @@ class TestNvidiaDatastore(unittest.TestCase):
dataset_def = Dataset(
identifier="test-dataset",
- type="dataset",
+ type=ResourceType.dataset,
provider_resource_id="",
provider_id="",
purpose=DatasetPurpose.post_training_messages,
@@ -117,11 +124,11 @@ class TestNvidiaDatastore(unittest.TestCase):
metadata={"format": "jsonl"},
)
- self.run_async(custom_adapter.register_dataset(dataset_def))
+ run_async(custom_adapter.register_dataset(dataset_def))
- self.mock_make_request.assert_called_once()
- self._assert_request(
- self.mock_make_request,
+ mock_make_request.assert_called_once()
+ _assert_request(
+ mock_make_request,
"POST",
"/v1/datasets",
expected_json={
@@ -132,7 +139,3 @@ class TestNvidiaDatastore(unittest.TestCase):
"format": "jsonl",
},
)
-
-
-if __name__ == "__main__":
- unittest.main()
diff --git a/tests/unit/providers/nvidia/test_parameters.py b/tests/unit/providers/nvidia/test_parameters.py
index cc33f7609..7e4323bd7 100644
--- a/tests/unit/providers/nvidia/test_parameters.py
+++ b/tests/unit/providers/nvidia/test_parameters.py
@@ -5,7 +5,6 @@
# the root directory of this source tree.
import os
-import unittest
import warnings
from unittest.mock import patch
@@ -27,14 +26,13 @@ from llama_stack.providers.remote.post_training.nvidia.post_training import (
)
-class TestNvidiaParameters(unittest.TestCase):
- def setUp(self):
- os.environ["NVIDIA_BASE_URL"] = "http://nemo.test"
+class TestNvidiaParameters:
+ @pytest.fixture(autouse=True)
+ def setup_and_teardown(self):
+ """Setup and teardown for each test method."""
os.environ["NVIDIA_CUSTOMIZER_URL"] = "http://nemo.test"
- config = NvidiaPostTrainingConfig(
- base_url=os.environ["NVIDIA_BASE_URL"], customizer_url=os.environ["NVIDIA_CUSTOMIZER_URL"], api_key=None
- )
+ config = NvidiaPostTrainingConfig(customizer_url=os.environ["NVIDIA_CUSTOMIZER_URL"], api_key=None)
self.adapter = NvidiaPostTrainingAdapter(config)
self.make_request_patcher = patch(
@@ -48,7 +46,8 @@ class TestNvidiaParameters(unittest.TestCase):
"updated_at": "2025-03-04T13:07:47.543605",
}
- def tearDown(self):
+ yield
+
self.make_request_patcher.stop()
def _assert_request_params(self, expected_json):
@@ -166,8 +165,8 @@ class TestNvidiaParameters(unittest.TestCase):
self.run_async(
self.adapter.supervised_fine_tune(
- job_uuid=required_job_uuid, # Required parameter
- model=required_model, # Required parameter
+ job_uuid=required_job_uuid,
+ model=required_model,
checkpoint_dir="",
algorithm_config=algorithm_config,
training_config=convert_pydantic_to_json_value(training_config),
@@ -198,7 +197,6 @@ class TestNvidiaParameters(unittest.TestCase):
data_config = DataConfig(
dataset_id="test-dataset",
batch_size=8,
- # Unsupported parameters
shuffle=True,
data_format=DatasetFormat.instruct,
validation_dataset_id="val-dataset",
@@ -207,20 +205,16 @@ class TestNvidiaParameters(unittest.TestCase):
optimizer_config = OptimizerConfig(
lr=0.0001,
weight_decay=0.01,
- # Unsupported parameters
optimizer_type=OptimizerType.adam,
num_warmup_steps=100,
)
- efficiency_config = EfficiencyConfig(
- enable_activation_checkpointing=True # Unsupported parameter
- )
+ efficiency_config = EfficiencyConfig(enable_activation_checkpointing=True)
training_config = TrainingConfig(
n_epochs=1,
data_config=data_config,
optimizer_config=optimizer_config,
- # Unsupported parameters
efficiency_config=efficiency_config,
max_steps_per_epoch=1000,
gradient_accumulation_steps=4,
@@ -228,7 +222,6 @@ class TestNvidiaParameters(unittest.TestCase):
dtype="bf16",
)
- # Capture warnings
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
@@ -236,7 +229,7 @@ class TestNvidiaParameters(unittest.TestCase):
self.adapter.supervised_fine_tune(
job_uuid="test-job",
model="meta-llama/Llama-3.1-8B-Instruct",
- checkpoint_dir="test-dir", # Unsupported parameter
+ checkpoint_dir="test-dir",
algorithm_config=LoraFinetuningConfig(
type="LoRA",
apply_lora_to_mlp=True,
@@ -246,8 +239,8 @@ class TestNvidiaParameters(unittest.TestCase):
lora_attn_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
),
training_config=convert_pydantic_to_json_value(training_config),
- logger_config={"test": "value"}, # Unsupported parameter
- hyperparam_search_config={"test": "value"}, # Unsupported parameter
+ logger_config={"test": "value"},
+ hyperparam_search_config={"test": "value"},
)
)
@@ -265,7 +258,6 @@ class TestNvidiaParameters(unittest.TestCase):
"gradient_accumulation_steps",
"max_validation_steps",
"dtype",
- # required unsupported parameters
"rank",
"apply_lora_to_output",
"lora_attn_modules",
@@ -273,7 +265,3 @@ class TestNvidiaParameters(unittest.TestCase):
]
for field in fields:
assert any(field in text for text in warning_texts)
-
-
-if __name__ == "__main__":
- unittest.main()
diff --git a/tests/unit/providers/nvidia/test_supervised_fine_tuning.py b/tests/unit/providers/nvidia/test_supervised_fine_tuning.py
index bbbb60a30..bc474f3bc 100644
--- a/tests/unit/providers/nvidia/test_supervised_fine_tuning.py
+++ b/tests/unit/providers/nvidia/test_supervised_fine_tuning.py
@@ -5,13 +5,11 @@
# the root directory of this source tree.
import os
-import unittest
import warnings
-from unittest.mock import AsyncMock, patch
+from unittest.mock import patch
import pytest
-from llama_stack.apis.models import Model, ModelType
from llama_stack.apis.post_training.post_training import (
DataConfig,
DatasetFormat,
@@ -22,7 +20,6 @@ from llama_stack.apis.post_training.post_training import (
TrainingConfig,
)
from llama_stack.distribution.library_client import convert_pydantic_to_json_value
-from llama_stack.providers.remote.inference.nvidia.nvidia import NVIDIAConfig, NVIDIAInferenceAdapter
from llama_stack.providers.remote.post_training.nvidia.post_training import (
ListNvidiaPostTrainingJobs,
NvidiaPostTrainingAdapter,
@@ -32,336 +29,297 @@ from llama_stack.providers.remote.post_training.nvidia.post_training import (
)
-class TestNvidiaPostTraining(unittest.TestCase):
- def setUp(self):
- os.environ["NVIDIA_BASE_URL"] = "http://nemo.test" # needed for llm inference
- os.environ["NVIDIA_CUSTOMIZER_URL"] = "http://nemo.test" # needed for nemo customizer
+@pytest.fixture
+def nvidia_post_training_adapter():
+ """Fixture to create and configure the NVIDIA post training adapter."""
+ os.environ["NVIDIA_CUSTOMIZER_URL"] = "http://nemo.test" # needed for nemo customizer
- config = NvidiaPostTrainingConfig(
- base_url=os.environ["NVIDIA_BASE_URL"], customizer_url=os.environ["NVIDIA_CUSTOMIZER_URL"], api_key=None
+ config = NvidiaPostTrainingConfig(customizer_url=os.environ["NVIDIA_CUSTOMIZER_URL"], api_key=None)
+ adapter = NvidiaPostTrainingAdapter(config)
+
+ with patch.object(adapter, "_make_request") as mock_make_request:
+ yield adapter, mock_make_request
+
+
+def _assert_request(mock_call, expected_method, expected_path, expected_params=None, expected_json=None):
+    """Helper function to verify request details in mock calls."""
+ call_args = mock_call.call_args
+
+ if expected_method and expected_path:
+ if isinstance(call_args[0], tuple) and len(call_args[0]) == 2:
+ assert call_args[0] == (expected_method, expected_path)
+ else:
+ assert call_args[1]["method"] == expected_method
+ assert call_args[1]["path"] == expected_path
+
+ if expected_params:
+ assert call_args[1]["params"] == expected_params
+
+ if expected_json:
+ for key, value in expected_json.items():
+ assert call_args[1]["json"][key] == value
+
+
+async def test_supervised_fine_tune(nvidia_post_training_adapter):
+ """Test the supervised fine-tuning API call."""
+ adapter, mock_make_request = nvidia_post_training_adapter
+ mock_make_request.return_value = {
+ "id": "cust-JGTaMbJMdqjJU8WbQdN9Q2",
+ "created_at": "2024-12-09T04:06:28.542884",
+ "updated_at": "2024-12-09T04:06:28.542884",
+ "config": {
+ "schema_version": "1.0",
+ "id": "af783f5b-d985-4e5b-bbb7-f9eec39cc0b1",
+ "created_at": "2024-12-09T04:06:28.542657",
+ "updated_at": "2024-12-09T04:06:28.569837",
+ "custom_fields": {},
+ "name": "meta-llama/Llama-3.1-8B-Instruct",
+ "base_model": "meta-llama/Llama-3.1-8B-Instruct",
+ "model_path": "llama-3_1-8b-instruct",
+ "training_types": [],
+ "finetuning_types": ["lora"],
+ "precision": "bf16",
+ "num_gpus": 4,
+ "num_nodes": 1,
+ "micro_batch_size": 1,
+ "tensor_parallel_size": 1,
+ "max_seq_length": 4096,
+ },
+ "dataset": {
+ "schema_version": "1.0",
+ "id": "dataset-XU4pvGzr5tvawnbVxeJMTb",
+ "created_at": "2024-12-09T04:06:28.542657",
+ "updated_at": "2024-12-09T04:06:28.542660",
+ "custom_fields": {},
+ "name": "sample-basic-test",
+ "version_id": "main",
+ "version_tags": [],
+ },
+ "hyperparameters": {
+ "finetuning_type": "lora",
+ "training_type": "sft",
+ "batch_size": 16,
+ "epochs": 2,
+ "learning_rate": 0.0001,
+ "lora": {"alpha": 16},
+ },
+ "output_model": "default/job-1234",
+ "status": "created",
+ "project": "default",
+ "custom_fields": {},
+ "ownership": {"created_by": "me", "access_policies": {}},
+ }
+
+ algorithm_config = LoraFinetuningConfig(
+ type="LoRA",
+ apply_lora_to_mlp=True,
+ apply_lora_to_output=True,
+ alpha=16,
+ rank=16,
+ lora_attn_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
+ )
+
+ data_config = DataConfig(
+ dataset_id="sample-basic-test", batch_size=16, shuffle=False, data_format=DatasetFormat.instruct
+ )
+
+ optimizer_config = OptimizerConfig(
+ optimizer_type=OptimizerType.adam,
+ lr=0.0001,
+ weight_decay=0.01,
+ num_warmup_steps=100,
+ )
+
+ training_config = TrainingConfig(
+ n_epochs=2,
+ data_config=data_config,
+ optimizer_config=optimizer_config,
+ )
+
+ with warnings.catch_warnings(record=True):
+ warnings.simplefilter("always")
+ training_job = await adapter.supervised_fine_tune(
+ job_uuid="1234",
+ model="meta/llama-3.2-1b-instruct@v1.0.0+L40",
+ checkpoint_dir="",
+ algorithm_config=algorithm_config,
+ training_config=convert_pydantic_to_json_value(training_config),
+ logger_config={},
+ hyperparam_search_config={},
)
- self.adapter = NvidiaPostTrainingAdapter(config)
- self.make_request_patcher = patch(
- "llama_stack.providers.remote.post_training.nvidia.post_training.NvidiaPostTrainingAdapter._make_request"
- )
- self.mock_make_request = self.make_request_patcher.start()
- # Mock the inference client
- inference_config = NVIDIAConfig(base_url=os.environ["NVIDIA_BASE_URL"], api_key=None)
- self.inference_adapter = NVIDIAInferenceAdapter(inference_config)
+ # check the output is a PostTrainingJob
+ assert isinstance(training_job, NvidiaPostTrainingJob)
+ assert training_job.job_uuid == "cust-JGTaMbJMdqjJU8WbQdN9Q2"
- self.mock_client = unittest.mock.MagicMock()
- self.mock_client.chat.completions.create = unittest.mock.AsyncMock()
- self.inference_mock_make_request = self.mock_client.chat.completions.create
- self.inference_make_request_patcher = patch(
- "llama_stack.providers.remote.inference.nvidia.nvidia.NVIDIAInferenceAdapter._client",
- new_callable=unittest.mock.PropertyMock,
- return_value=self.mock_client,
- )
- self.inference_make_request_patcher.start()
-
- def tearDown(self):
- self.make_request_patcher.stop()
- self.inference_make_request_patcher.stop()
-
- @pytest.fixture(autouse=True)
- def inject_fixtures(self, run_async):
- self.run_async = run_async
-
- def _assert_request(self, mock_call, expected_method, expected_path, expected_params=None, expected_json=None):
- """Helper method to verify request details in mock calls."""
- call_args = mock_call.call_args
-
- if expected_method and expected_path:
- if isinstance(call_args[0], tuple) and len(call_args[0]) == 2:
- assert call_args[0] == (expected_method, expected_path)
- else:
- assert call_args[1]["method"] == expected_method
- assert call_args[1]["path"] == expected_path
-
- if expected_params:
- assert call_args[1]["params"] == expected_params
-
- if expected_json:
- for key, value in expected_json.items():
- assert call_args[1]["json"][key] == value
-
- def test_supervised_fine_tune(self):
- """Test the supervised fine-tuning API call."""
- self.mock_make_request.return_value = {
- "id": "cust-JGTaMbJMdqjJU8WbQdN9Q2",
- "created_at": "2024-12-09T04:06:28.542884",
- "updated_at": "2024-12-09T04:06:28.542884",
- "config": {
- "schema_version": "1.0",
- "id": "af783f5b-d985-4e5b-bbb7-f9eec39cc0b1",
- "created_at": "2024-12-09T04:06:28.542657",
- "updated_at": "2024-12-09T04:06:28.569837",
- "custom_fields": {},
- "name": "meta-llama/Llama-3.1-8B-Instruct",
- "base_model": "meta-llama/Llama-3.1-8B-Instruct",
- "model_path": "llama-3_1-8b-instruct",
- "training_types": [],
- "finetuning_types": ["lora"],
- "precision": "bf16",
- "num_gpus": 4,
- "num_nodes": 1,
- "micro_batch_size": 1,
- "tensor_parallel_size": 1,
- "max_seq_length": 4096,
- },
- "dataset": {
- "schema_version": "1.0",
- "id": "dataset-XU4pvGzr5tvawnbVxeJMTb",
- "created_at": "2024-12-09T04:06:28.542657",
- "updated_at": "2024-12-09T04:06:28.542660",
- "custom_fields": {},
- "name": "sample-basic-test",
- "version_id": "main",
- "version_tags": [],
- },
+ mock_make_request.assert_called_once()
+ _assert_request(
+ mock_make_request,
+ "POST",
+ "/v1/customization/jobs",
+ expected_json={
+ "config": "meta/llama-3.2-1b-instruct@v1.0.0+L40",
+ "dataset": {"name": "sample-basic-test", "namespace": "default"},
"hyperparameters": {
- "finetuning_type": "lora",
"training_type": "sft",
- "batch_size": 16,
+ "finetuning_type": "lora",
"epochs": 2,
+ "batch_size": 16,
"learning_rate": 0.0001,
+ "weight_decay": 0.01,
"lora": {"alpha": 16},
},
- "output_model": "default/job-1234",
- "status": "created",
- "project": "default",
- "custom_fields": {},
- "ownership": {"created_by": "me", "access_policies": {}},
+ },
+ )
+
+
+async def test_supervised_fine_tune_with_qat(nvidia_post_training_adapter):
+ """Test that QAT configuration raises NotImplementedError."""
+ adapter, mock_make_request = nvidia_post_training_adapter
+
+ algorithm_config = QATFinetuningConfig(type="QAT", quantizer_name="quantizer_name", group_size=1)
+ data_config = DataConfig(
+ dataset_id="sample-basic-test", batch_size=16, shuffle=False, data_format=DatasetFormat.instruct
+ )
+ optimizer_config = OptimizerConfig(
+ optimizer_type=OptimizerType.adam,
+ lr=0.0001,
+ weight_decay=0.01,
+ num_warmup_steps=100,
+ )
+ training_config = TrainingConfig(
+ n_epochs=2,
+ data_config=data_config,
+ optimizer_config=optimizer_config,
+ )
+
+ # This will raise NotImplementedError since QAT is not supported
+ with pytest.raises(NotImplementedError):
+ await adapter.supervised_fine_tune(
+ job_uuid="1234",
+ model="meta/llama-3.2-1b-instruct@v1.0.0+L40",
+ checkpoint_dir="",
+ algorithm_config=algorithm_config,
+ training_config=convert_pydantic_to_json_value(training_config),
+ logger_config={},
+ hyperparam_search_config={},
+ )
+
+
+async def test_get_training_job_status(nvidia_post_training_adapter):
+ """Test getting training job status with different statuses."""
+ adapter, mock_make_request = nvidia_post_training_adapter
+
+ customizer_status_to_job_status = [
+ ("running", "in_progress"),
+ ("completed", "completed"),
+ ("failed", "failed"),
+ ("cancelled", "cancelled"),
+ ("pending", "scheduled"),
+ ("unknown", "scheduled"),
+ ]
+
+ for customizer_status, expected_status in customizer_status_to_job_status:
+ mock_make_request.return_value = {
+ "created_at": "2024-12-09T04:06:28.580220",
+ "updated_at": "2024-12-09T04:21:19.852832",
+ "status": customizer_status,
+ "steps_completed": 1210,
+ "epochs_completed": 2,
+ "percentage_done": 100.0,
+ "best_epoch": 2,
+ "train_loss": 1.718016266822815,
+ "val_loss": 1.8661999702453613,
}
- algorithm_config = LoraFinetuningConfig(
- type="LoRA",
- apply_lora_to_mlp=True,
- apply_lora_to_output=True,
- alpha=16,
- rank=16,
- lora_attn_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
- )
-
- data_config = DataConfig(
- dataset_id="sample-basic-test", batch_size=16, shuffle=False, data_format=DatasetFormat.instruct
- )
-
- optimizer_config = OptimizerConfig(
- optimizer_type=OptimizerType.adam,
- lr=0.0001,
- weight_decay=0.01,
- num_warmup_steps=100,
- )
-
- training_config = TrainingConfig(
- n_epochs=2,
- data_config=data_config,
- optimizer_config=optimizer_config,
- )
-
- with warnings.catch_warnings(record=True):
- warnings.simplefilter("always")
- training_job = self.run_async(
- self.adapter.supervised_fine_tune(
- job_uuid="1234",
- model="meta/llama-3.2-1b-instruct@v1.0.0+L40",
- checkpoint_dir="",
- algorithm_config=algorithm_config,
- training_config=convert_pydantic_to_json_value(training_config),
- logger_config={},
- hyperparam_search_config={},
- )
- )
-
- # check the output is a PostTrainingJob
- assert isinstance(training_job, NvidiaPostTrainingJob)
- assert training_job.job_uuid == "cust-JGTaMbJMdqjJU8WbQdN9Q2"
-
- self.mock_make_request.assert_called_once()
- self._assert_request(
- self.mock_make_request,
- "POST",
- "/v1/customization/jobs",
- expected_json={
- "config": "meta/llama-3.2-1b-instruct@v1.0.0+L40",
- "dataset": {"name": "sample-basic-test", "namespace": "default"},
- "hyperparameters": {
- "training_type": "sft",
- "finetuning_type": "lora",
- "epochs": 2,
- "batch_size": 16,
- "learning_rate": 0.0001,
- "weight_decay": 0.01,
- "lora": {"alpha": 16},
- },
- },
- )
-
- def test_supervised_fine_tune_with_qat(self):
- algorithm_config = QATFinetuningConfig(type="QAT", quantizer_name="quantizer_name", group_size=1)
- data_config = DataConfig(
- dataset_id="sample-basic-test", batch_size=16, shuffle=False, data_format=DatasetFormat.instruct
- )
- optimizer_config = OptimizerConfig(
- optimizer_type=OptimizerType.adam,
- lr=0.0001,
- weight_decay=0.01,
- num_warmup_steps=100,
- )
- training_config = TrainingConfig(
- n_epochs=2,
- data_config=data_config,
- optimizer_config=optimizer_config,
- )
- # This will raise NotImplementedError since QAT is not supported
- with self.assertRaises(NotImplementedError):
- self.run_async(
- self.adapter.supervised_fine_tune(
- job_uuid="1234",
- model="meta/llama-3.2-1b-instruct@v1.0.0+L40",
- checkpoint_dir="",
- algorithm_config=algorithm_config,
- training_config=convert_pydantic_to_json_value(training_config),
- logger_config={},
- hyperparam_search_config={},
- )
- )
-
- def test_get_training_job_status(self):
- customizer_status_to_job_status = [
- ("running", "in_progress"),
- ("completed", "completed"),
- ("failed", "failed"),
- ("cancelled", "cancelled"),
- ("pending", "scheduled"),
- ("unknown", "scheduled"),
- ]
-
- for customizer_status, expected_status in customizer_status_to_job_status:
- with self.subTest(customizer_status=customizer_status, expected_status=expected_status):
- self.mock_make_request.return_value = {
- "created_at": "2024-12-09T04:06:28.580220",
- "updated_at": "2024-12-09T04:21:19.852832",
- "status": customizer_status,
- "steps_completed": 1210,
- "epochs_completed": 2,
- "percentage_done": 100.0,
- "best_epoch": 2,
- "train_loss": 1.718016266822815,
- "val_loss": 1.8661999702453613,
- }
-
- job_id = "cust-JGTaMbJMdqjJU8WbQdN9Q2"
-
- status = self.run_async(self.adapter.get_training_job_status(job_uuid=job_id))
-
- assert isinstance(status, NvidiaPostTrainingJobStatusResponse)
- assert status.status.value == expected_status
- assert status.steps_completed == 1210
- assert status.epochs_completed == 2
- assert status.percentage_done == 100.0
- assert status.best_epoch == 2
- assert status.train_loss == 1.718016266822815
- assert status.val_loss == 1.8661999702453613
-
- self._assert_request(
- self.mock_make_request,
- "GET",
- f"/v1/customization/jobs/{job_id}/status",
- expected_params={"job_id": job_id},
- )
-
- def test_get_training_jobs(self):
job_id = "cust-JGTaMbJMdqjJU8WbQdN9Q2"
- self.mock_make_request.return_value = {
- "data": [
- {
- "id": job_id,
- "created_at": "2024-12-09T04:06:28.542884",
- "updated_at": "2024-12-09T04:21:19.852832",
- "config": {
- "name": "meta-llama/Llama-3.1-8B-Instruct",
- "base_model": "meta-llama/Llama-3.1-8B-Instruct",
- },
- "dataset": {"name": "default/sample-basic-test"},
- "hyperparameters": {
- "finetuning_type": "lora",
- "training_type": "sft",
- "batch_size": 16,
- "epochs": 2,
- "learning_rate": 0.0001,
- "lora": {"adapter_dim": 16, "adapter_dropout": 0.1},
- },
- "output_model": "default/job-1234",
- "status": "completed",
- "project": "default",
- }
- ]
- }
- jobs = self.run_async(self.adapter.get_training_jobs())
+ status = await adapter.get_training_job_status(job_uuid=job_id)
- assert isinstance(jobs, ListNvidiaPostTrainingJobs)
- assert len(jobs.data) == 1
- job = jobs.data[0]
- assert job.job_uuid == job_id
- assert job.status.value == "completed"
+ assert isinstance(status, NvidiaPostTrainingJobStatusResponse)
+ assert status.status.value == expected_status
+ # Note: The response object inherits extra fields via ConfigDict(extra="allow")
+ # So these attributes should be accessible using getattr with defaults
+ assert getattr(status, "steps_completed", None) == 1210
+ assert getattr(status, "epochs_completed", None) == 2
+ assert getattr(status, "percentage_done", None) == 100.0
+ assert getattr(status, "best_epoch", None) == 2
+ assert getattr(status, "train_loss", None) == 1.718016266822815
+ assert getattr(status, "val_loss", None) == 1.8661999702453613
- self.mock_make_request.assert_called_once()
- self._assert_request(
- self.mock_make_request,
+ _assert_request(
+ mock_make_request,
"GET",
- "/v1/customization/jobs",
- expected_params={"page": 1, "page_size": 10, "sort": "created_at"},
- )
-
- def test_cancel_training_job(self):
- self.mock_make_request.return_value = {} # Empty response for successful cancellation
- job_id = "cust-JGTaMbJMdqjJU8WbQdN9Q2"
-
- result = self.run_async(self.adapter.cancel_training_job(job_uuid=job_id))
-
- assert result is None
-
- self.mock_make_request.assert_called_once()
- self._assert_request(
- self.mock_make_request,
- "POST",
- f"/v1/customization/jobs/{job_id}/cancel",
+ f"/v1/customization/jobs/{job_id}/status",
expected_params={"job_id": job_id},
)
- def test_inference_register_model(self):
- model_id = "default/job-1234"
- model_type = ModelType.llm
- model = Model(
- identifier=model_id,
- provider_id="nvidia",
- provider_model_id=model_id,
- provider_resource_id=model_id,
- model_type=model_type,
- )
-
- # simulate a NIM where default/job-1234 is an available model
- with patch.object(self.inference_adapter, "check_model_availability", new_callable=AsyncMock) as mock_check:
- mock_check.return_value = True
- result = self.run_async(self.inference_adapter.register_model(model))
- assert result == model
- assert len(self.inference_adapter.alias_to_provider_id_map) > 1
- assert self.inference_adapter.get_provider_model_id(model.provider_model_id) == model_id
-
- with patch.object(self.inference_adapter, "chat_completion") as mock_chat_completion:
- self.run_async(
- self.inference_adapter.chat_completion(
- model_id=model_id,
- messages=[{"role": "user", "content": "Hello, model"}],
- )
- )
-
- mock_chat_completion.assert_called()
+ mock_make_request.reset_mock()
-if __name__ == "__main__":
- unittest.main()
+async def test_get_training_jobs(nvidia_post_training_adapter):
+ """Test getting list of training jobs."""
+ adapter, mock_make_request = nvidia_post_training_adapter
+
+ job_id = "cust-JGTaMbJMdqjJU8WbQdN9Q2"
+ mock_make_request.return_value = {
+ "data": [
+ {
+ "id": job_id,
+ "created_at": "2024-12-09T04:06:28.542884",
+ "updated_at": "2024-12-09T04:21:19.852832",
+ "config": {
+ "name": "meta-llama/Llama-3.1-8B-Instruct",
+ "base_model": "meta-llama/Llama-3.1-8B-Instruct",
+ },
+ "dataset": {"name": "default/sample-basic-test"},
+ "hyperparameters": {
+ "finetuning_type": "lora",
+ "training_type": "sft",
+ "batch_size": 16,
+ "epochs": 2,
+ "learning_rate": 0.0001,
+ "lora": {"adapter_dim": 16, "adapter_dropout": 0.1},
+ },
+ "output_model": "default/job-1234",
+ "status": "completed",
+ "project": "default",
+ }
+ ]
+ }
+
+ jobs = await adapter.get_training_jobs()
+
+ assert isinstance(jobs, ListNvidiaPostTrainingJobs)
+ assert len(jobs.data) == 1
+ job = jobs.data[0]
+ assert job.job_uuid == job_id
+ assert job.status.value == "completed"
+
+ mock_make_request.assert_called_once()
+ _assert_request(
+ mock_make_request,
+ "GET",
+ "/v1/customization/jobs",
+ expected_params={"page": 1, "page_size": 10, "sort": "created_at"},
+ )
+
+
+async def test_cancel_training_job(nvidia_post_training_adapter):
+ """Test canceling a training job."""
+ adapter, mock_make_request = nvidia_post_training_adapter
+
+ mock_make_request.return_value = {} # Empty response for successful cancellation
+ job_id = "cust-JGTaMbJMdqjJU8WbQdN9Q2"
+
+ result = await adapter.cancel_training_job(job_uuid=job_id)
+
+ assert result is None
+
+ mock_make_request.assert_called_once()
+ _assert_request(
+ mock_make_request,
+ "POST",
+ f"/v1/customization/jobs/{job_id}/cancel",
+ expected_params={"job_id": job_id},
+ )
diff --git a/tests/unit/providers/vector_io/remote/test_milvus.py b/tests/unit/providers/vector_io/remote/test_milvus.py
index 2f212e374..145edf7fb 100644
--- a/tests/unit/providers/vector_io/remote/test_milvus.py
+++ b/tests/unit/providers/vector_io/remote/test_milvus.py
@@ -8,7 +8,6 @@ from unittest.mock import MagicMock, patch
import numpy as np
import pytest
-import pytest_asyncio
from llama_stack.apis.vector_io import QueryChunksResponse
@@ -33,7 +32,7 @@ with patch.dict("sys.modules", {"pymilvus": pymilvus_mock}):
MILVUS_PROVIDER = "milvus"
-@pytest_asyncio.fixture
+@pytest.fixture
async def mock_milvus_client() -> MagicMock:
"""Create a mock Milvus client with common method behaviors."""
client = MagicMock()
@@ -84,7 +83,7 @@ async def mock_milvus_client() -> MagicMock:
return client
-@pytest_asyncio.fixture
+@pytest.fixture
async def milvus_index(mock_milvus_client):
"""Create a MilvusIndex with mocked client."""
index = MilvusIndex(client=mock_milvus_client, collection_name="test_collection")
@@ -92,7 +91,6 @@ async def milvus_index(mock_milvus_client):
# No real cleanup needed since we're using mocks
-@pytest.mark.asyncio
async def test_add_chunks(milvus_index, sample_chunks, sample_embeddings, mock_milvus_client):
# Setup: collection doesn't exist initially, then exists after creation
mock_milvus_client.has_collection.side_effect = [False, True]
@@ -108,7 +106,6 @@ async def test_add_chunks(milvus_index, sample_chunks, sample_embeddings, mock_m
assert len(insert_call[1]["data"]) == len(sample_chunks)
-@pytest.mark.asyncio
async def test_query_chunks_vector(
milvus_index, sample_chunks, sample_embeddings, embedding_dimension, mock_milvus_client
):
@@ -125,7 +122,6 @@ async def test_query_chunks_vector(
mock_milvus_client.search.assert_called_once()
-@pytest.mark.asyncio
async def test_query_chunks_keyword_search(milvus_index, sample_chunks, sample_embeddings, mock_milvus_client):
mock_milvus_client.has_collection.return_value = True
await milvus_index.add_chunks(sample_chunks, sample_embeddings)
@@ -138,7 +134,6 @@ async def test_query_chunks_keyword_search(milvus_index, sample_chunks, sample_e
assert len(response.chunks) == 2
-@pytest.mark.asyncio
async def test_bm25_fallback_to_simple_search(milvus_index, sample_chunks, sample_embeddings, mock_milvus_client):
"""Test that when BM25 search fails, the system falls back to simple text search."""
mock_milvus_client.has_collection.return_value = True
@@ -181,7 +176,6 @@ async def test_bm25_fallback_to_simple_search(milvus_index, sample_chunks, sampl
assert all(score == 1.0 for score in response.scores), "Simple text search should use binary scoring"
-@pytest.mark.asyncio
async def test_delete_collection(milvus_index, mock_milvus_client):
# Test collection deletion
mock_milvus_client.has_collection.return_value = True
diff --git a/tests/unit/rag/test_rag_query.py b/tests/unit/rag/test_rag_query.py
index ad155c205..a9149541a 100644
--- a/tests/unit/rag/test_rag_query.py
+++ b/tests/unit/rag/test_rag_query.py
@@ -64,7 +64,6 @@ class TestRagQuery:
with pytest.raises(ValueError):
RAGQueryConfig(mode="invalid_mode")
- @pytest.mark.asyncio
async def test_query_accepts_valid_modes(self):
RAGQueryConfig() # Test default (vector)
RAGQueryConfig(mode="vector") # Test vector
diff --git a/uv.lock b/uv.lock
index 7a9c5cab0..2c5197988 100644
--- a/uv.lock
+++ b/uv.lock
@@ -1390,6 +1390,7 @@ unit = [
{ name = "aiosqlite" },
{ name = "blobfile" },
{ name = "chardet" },
+ { name = "coverage" },
{ name = "faiss-cpu" },
{ name = "litellm" },
{ name = "mcp" },
@@ -1499,6 +1500,7 @@ unit = [
{ name = "aiosqlite" },
{ name = "blobfile" },
{ name = "chardet" },
+ { name = "coverage" },
{ name = "faiss-cpu" },
{ name = "litellm" },
{ name = "mcp" },