Merge branch 'main' into chore/add-upstream-to-sl-config

commit 36c8196d17
Authored by slekkala1 on 2025-07-18 09:52:32 -07:00, committed by GitHub
22 changed files with 583 additions and 521 deletions


@@ -4,3 +4,9 @@ omit =
     */llama_stack/providers/*
     */llama_stack/templates/*
     .venv/*
+    */llama_stack/cli/scripts/*
+    */llama_stack/ui/*
+    */llama_stack/distribution/ui/*
+    */llama_stack/strong_typing/*
+    */llama_stack/env.py
+    */__init__.py

.github/workflows/coverage-badge.yml (new file)

@@ -0,0 +1,57 @@
+name: Coverage Badge
+on:
+  push:
+    branches: [ main ]
+    paths:
+      - 'llama_stack/**'
+      - 'tests/unit/**'
+      - 'uv.lock'
+      - 'pyproject.toml'
+      - 'requirements.txt'
+      - '.github/workflows/unit-tests.yml'
+      - '.github/workflows/coverage-badge.yml' # This workflow
+  workflow_dispatch:
+jobs:
+  unit-tests:
+    runs-on: ubuntu-latest
+    steps:
+      - name: Checkout repository
+        uses: actions/checkout@11bd71901bbe5b1630ceea73d27597364c9af683 # v4.2.2
+      - name: Install dependencies
+        uses: ./.github/actions/setup-runner
+      - name: Run unit tests
+        run: |
+          ./scripts/unit-tests.sh
+      - name: Coverage Badge
+        uses: tj-actions/coverage-badge-py@1788babcb24544eb5bbb6e0d374df5d1e54e670f # v2.0.4
+      - name: Verify Changed files
+        uses: tj-actions/verify-changed-files@a1c6acee9df209257a246f2cc6ae8cb6581c1edf # v20.0.4
+        id: verify-changed-files
+        with:
+          files: coverage.svg
+      - name: Commit files
+        if: steps.verify-changed-files.outputs.files_changed == 'true'
+        run: |
+          git config --local user.email "github-actions[bot]@users.noreply.github.com"
+          git config --local user.name "github-actions[bot]"
+          git add coverage.svg
+          git commit -m "Updated coverage.svg"
+      - name: Create Pull Request
+        if: steps.verify-changed-files.outputs.files_changed == 'true'
+        uses: peter-evans/create-pull-request@271a8d0340265f705b14b6d32b9829c1cb33d45e # v7.0.8
+        with:
+          token: ${{ secrets.GITHUB_TOKEN }}
+          title: "ci: [Automatic] Coverage Badge Update"
+          body: |
+            This PR updates the coverage badge based on the latest coverage report.
+            Automatically generated by the [workflow coverage-badge.yaml](.github/workflows/coverage-badge.yaml)
+          delete-branch: true


@@ -7,7 +7,7 @@ on:
     branches: [ main ]
     paths:
       - 'llama_stack/**'
-      - 'tests/integration/**'
+      - 'tests/**'
       - 'uv.lock'
       - 'pyproject.toml'
       - 'requirements.txt'


@@ -36,7 +36,7 @@ jobs:
       - name: Run unit tests
         run: |
-          PYTHON_VERSION=${{ matrix.python }} ./scripts/unit-tests.sh --cov=llama_stack --junitxml=pytest-report-${{ matrix.python }}.xml --cov-report=html:htmlcov-${{ matrix.python }}
+          PYTHON_VERSION=${{ matrix.python }} ./scripts/unit-tests.sh --junitxml=pytest-report-${{ matrix.python }}.xml
       - name: Upload test results
         if: always()


@@ -6,6 +6,7 @@
 [![Discord](https://img.shields.io/discord/1257833999603335178?color=6A7EC2&logo=discord&logoColor=ffffff)](https://discord.gg/llama-stack)
 [![Unit Tests](https://github.com/meta-llama/llama-stack/actions/workflows/unit-tests.yml/badge.svg?branch=main)](https://github.com/meta-llama/llama-stack/actions/workflows/unit-tests.yml?query=branch%3Amain)
 [![Integration Tests](https://github.com/meta-llama/llama-stack/actions/workflows/integration-tests.yml/badge.svg?branch=main)](https://github.com/meta-llama/llama-stack/actions/workflows/integration-tests.yml?query=branch%3Amain)
+![coverage badge](./coverage.svg)
 
 [**Quick Start**](https://llama-stack.readthedocs.io/en/latest/getting_started/index.html) | [**Documentation**](https://llama-stack.readthedocs.io/en/latest/index.html) | [**Colab Notebook**](./docs/getting_started.ipynb) | [**Discord**](https://discord.gg/llama-stack)

coverage.svg (new file)

@@ -0,0 +1,21 @@
+<?xml version="1.0" encoding="UTF-8"?>
+<svg xmlns="http://www.w3.org/2000/svg" width="99" height="20">
+    <linearGradient id="b" x2="0" y2="100%">
+        <stop offset="0" stop-color="#bbb" stop-opacity=".1"/>
+        <stop offset="1" stop-opacity=".1"/>
+    </linearGradient>
+    <mask id="a">
+        <rect width="99" height="20" rx="3" fill="#fff"/>
+    </mask>
+    <g mask="url(#a)">
+        <path fill="#555" d="M0 0h63v20H0z"/>
+        <path fill="#fe7d37" d="M63 0h36v20H63z"/>
+        <path fill="url(#b)" d="M0 0h99v20H0z"/>
+    </g>
+    <g fill="#fff" text-anchor="middle" font-family="DejaVu Sans,Verdana,Geneva,sans-serif" font-size="11">
+        <text x="31.5" y="15" fill="#010101" fill-opacity=".3">coverage</text>
+        <text x="31.5" y="14">coverage</text>
+        <text x="80" y="15" fill="#010101" fill-opacity=".3">44%</text>
+        <text x="80" y="14">44%</text>
+    </g>
+</svg>



@@ -167,7 +167,7 @@ When using the `:` pattern (like `${env.OLLAMA_INFERENCE_MODEL:__disabled__}`),
 
 ## Running the Distribution
 
-You can run the starter distribution via Docker or Conda.
+You can run the starter distribution via Docker, Conda, or venv.
 
 ### Via Docker
 
@@ -186,17 +186,12 @@ docker run \
   --port $LLAMA_STACK_PORT
 ```
 
-### Via Conda
+### Via Conda or venv
 
-Make sure you have done `uv pip install llama-stack` and have the Llama Stack CLI available.
+Ensure you have configured the starter distribution using the environment variables explained above.
 
 ```bash
-llama stack build --template starter --image-type conda
-llama stack run distributions/starter/run.yaml \
-  --port 8321 \
-  --env OPENAI_API_KEY=your_openai_key \
-  --env FIREWORKS_API_KEY=your_fireworks_key \
-  --env TOGETHER_API_KEY=your_together_key
+uv run --with llama-stack llama stack build --template starter --image-type <conda|venv> --run
 ```
 
 ## Example Usage


@@ -19,7 +19,7 @@ ollama run llama3.2:3b --keepalive 60m
 #### Step 2: Run the Llama Stack server
 
 We will use `uv` to run the Llama Stack server.
 ```bash
-INFERENCE_MODEL=llama3.2:3b uv run --with llama-stack llama stack build --template starter --image-type venv --run
+ENABLE_OLLAMA=ollama OLLAMA_INFERENCE_MODEL=llama3.2:3b uv run --with llama-stack llama stack build --template starter --image-type venv --run
 ```
 #### Step 3: Run the demo
 Now open up a new terminal and copy the following script into a file named `demo_script.py`.
@@ -111,6 +111,12 @@ Ultimately, great work is about making a meaningful contribution and leaving a l
 ```
 Congratulations! You've successfully built your first RAG application using Llama Stack! 🎉🥳
 
+```{admonition} HuggingFace access
+:class: tip
+
+If you are getting a **401 Client Error** from HuggingFace for the **all-MiniLM-L6-v2** model, try setting **HF_TOKEN** to a valid HuggingFace token in your environment
+```
+
 ### Next Steps
 
 Now you're ready to dive deeper into Llama Stack!


@@ -8,6 +8,7 @@ import io
 import json
 import uuid
 from dataclasses import dataclass
+from typing import Any
 
 from PIL import Image as PIL_Image
 
@@ -184,16 +185,26 @@ class ChatFormat:
             content = content[: -len("<|eom_id|>")]
             stop_reason = StopReason.end_of_message
 
-        tool_name = None
-        tool_arguments = {}
+        tool_name: str | BuiltinTool | None = None
+        tool_arguments: dict[str, Any] = {}
 
         custom_tool_info = ToolUtils.maybe_extract_custom_tool_call(content)
         if custom_tool_info is not None:
-            tool_name, tool_arguments = custom_tool_info
+            # Type guard: ensure custom_tool_info is a tuple of correct types
+            if isinstance(custom_tool_info, tuple) and len(custom_tool_info) == 2:
+                extracted_tool_name, extracted_tool_arguments = custom_tool_info
+                # Handle both dict and str return types from the function
+                if isinstance(extracted_tool_arguments, dict):
+                    tool_name, tool_arguments = extracted_tool_name, extracted_tool_arguments
+                else:
+                    # If it's a string, treat it as a query parameter
+                    tool_name, tool_arguments = extracted_tool_name, {"query": extracted_tool_arguments}
+            else:
+                tool_name, tool_arguments = None, {}
+
             # Sometimes when agent has custom tools alongside builin tools
             # Agent responds for builtin tool calls in the format of the custom tools
             # This code tries to handle that case
-            if tool_name in BuiltinTool.__members__:
+            if tool_name is not None and tool_name in BuiltinTool.__members__:
                 tool_name = BuiltinTool[tool_name]
                 if isinstance(tool_arguments, dict):
                     tool_arguments = {
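For reference, here is a minimal standalone sketch of the normalization that the new type guard performs. The tuple shape and the dict-versus-str handling are taken from the diff above; the helper name and the example values are hypothetical.

```python
from typing import Any


def normalize_tool_call(info: Any) -> tuple[Any, dict[str, Any]]:
    """Mirror the guard above: accept (name, dict) or (name, str); anything else falls back."""
    if not (isinstance(info, tuple) and len(info) == 2):
        return None, {}
    name, args = info
    if isinstance(args, dict):
        return name, args
    # A bare string is treated as a query parameter, as in the diff.
    return name, {"query": args}


# Both return shapes normalize to (name, dict):
assert normalize_tool_call(("get_weather", {"city": "Paris"})) == ("get_weather", {"city": "Paris"})
assert normalize_tool_call(("brave_search", "llama stack")) == ("brave_search", {"query": "llama stack"})
```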


@@ -98,7 +98,7 @@ class ProcessingMessageWrapper(BaseModel):
 
 def mp_rank_0() -> bool:
-    return get_model_parallel_rank() == 0
+    return bool(get_model_parallel_rank() == 0)
 
 
 def encode_msg(msg: ProcessingMessage) -> bytes:
@@ -125,7 +125,7 @@ def retrieve_requests(reply_socket_url: str):
                 reply_socket.send_multipart([client_id, encode_msg(obj)])
 
     while True:
-        tasks = [None]
+        tasks: list[ProcessingMessage | None] = [None]
         if mp_rank_0():
             client_id, maybe_task_json = maybe_get_work(reply_socket)
             if maybe_task_json is not None:
@@ -152,7 +152,7 @@ def retrieve_requests(reply_socket_url: str):
                     break
 
             for obj in out:
-                updates = [None]
+                updates: list[ProcessingMessage | None] = [None]
                 if mp_rank_0():
                     _, update_json = maybe_get_work(reply_socket)
                     update = maybe_parse_message(update_json)


@@ -91,6 +91,7 @@ unit = [
     "pymilvus>=2.5.12",
     "litellm",
     "together",
+    "coverage",
 ]
 # These are the core dependencies required for running integration tests. They are shared across all
 # providers. If a provider requires additional dependencies, please add them to your environment
@@ -242,7 +243,6 @@ exclude = [
     "^llama_stack/distribution/store/registry\\.py$",
     "^llama_stack/distribution/utils/exec\\.py$",
     "^llama_stack/distribution/utils/prompt_for_config\\.py$",
-    "^llama_stack/models/llama/llama3/chat_format\\.py$",
     "^llama_stack/models/llama/llama3/interface\\.py$",
     "^llama_stack/models/llama/llama3/tokenizer\\.py$",
     "^llama_stack/models/llama/llama3/tool_utils\\.py$",
@@ -255,7 +255,6 @@ exclude = [
     "^llama_stack/models/llama/llama3/generation\\.py$",
     "^llama_stack/models/llama/llama3/multimodal/model\\.py$",
     "^llama_stack/models/llama/llama4/",
-    "^llama_stack/providers/inline/inference/meta_reference/parallel_utils\\.py$",
     "^llama_stack/providers/inline/inference/meta_reference/quantization/fp8_impls\\.py$",
     "^llama_stack/providers/inline/inference/sentence_transformers/sentence_transformers\\.py$",
     "^llama_stack/providers/inline/inference/vllm/",


@@ -16,4 +16,9 @@ if [ $FOUND_PYTHON -ne 0 ]; then
     uv python install "$PYTHON_VERSION"
 fi
 
-uv run --python "$PYTHON_VERSION" --with-editable . --group unit pytest -s -v tests/unit/ $@
+# Run unit tests with coverage
+uv run --python "$PYTHON_VERSION" --with-editable . --group unit \
+    coverage run --source=llama_stack -m pytest -s -v tests/unit/ "$@"
+
+# Generate HTML coverage report
+uv run --python "$PYTHON_VERSION" coverage html -d htmlcov-$PYTHON_VERSION


@@ -123,14 +123,14 @@ class TestPostTraining:
        logger.info(f"Job artifacts: {artifacts}")
 
     # TODO: Fix these tests to properly represent the Jobs API in training
-    # @pytest.mark.asyncio
+    #
     # async def test_get_training_jobs(self, post_training_stack):
     #     post_training_impl = post_training_stack
     #     jobs_list = await post_training_impl.get_training_jobs()
     #     assert isinstance(jobs_list, list)
     #     assert jobs_list[0].job_uuid == "1234"
 
-    # @pytest.mark.asyncio
+    #
     # async def test_get_training_job_status(self, post_training_stack):
     #     post_training_impl = post_training_stack
     #     job_status = await post_training_impl.get_training_job_status("1234")
@@ -139,7 +139,7 @@ class TestPostTraining:
     #     assert job_status.status == JobStatus.completed
     #     assert isinstance(job_status.checkpoints[0], Checkpoint)
 
-    # @pytest.mark.asyncio
+    #
     # async def test_get_training_job_artifacts(self, post_training_stack):
     #     post_training_impl = post_training_stack
     #     job_artifacts = await post_training_impl.get_training_job_artifacts("1234")


@@ -1,9 +1,17 @@
 # Llama Stack Unit Tests
 
+## Unit Tests
+
+Unit tests verify individual components and functions in isolation. They are fast, reliable, and don't require external services.
+
+### Prerequisites
+
+1. **Python Environment**: Ensure you have Python 3.12+ installed
+2. **uv Package Manager**: Install `uv` if not already installed
+
 You can run the unit tests by running:
 
 ```bash
+source .venv/bin/activate
 ./scripts/unit-tests.sh [PYTEST_ARGS]
 ```
 
@@ -19,3 +27,21 @@ If you'd like to run for a non-default version of Python (currently 3.12), pass
 source .venv/bin/activate
 PYTHON_VERSION=3.13 ./scripts/unit-tests.sh
 ```
+
+### Test Configuration
+
+- **Test Discovery**: Tests are automatically discovered in the `tests/unit/` directory
+- **Async Support**: Tests use `--asyncio-mode=auto` for automatic async test handling
+- **Coverage**: Tests generate coverage reports in `htmlcov/` directory
+- **Python Version**: Defaults to Python 3.12, but can be overridden with `PYTHON_VERSION` environment variable
+
+### Coverage Reports
+
+After running tests, you can view coverage reports:
+
+```bash
+# Open HTML coverage report in browser
+open htmlcov/index.html      # macOS
+xdg-open htmlcov/index.html  # Linux
+start htmlcov/index.html     # Windows
+```
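As a cross-platform alternative to the OS-specific commands above, a small throwaway Python script (not part of the repository) can locate and open the same report. Note that the updated `scripts/unit-tests.sh` writes the report to `htmlcov-<PYTHON_VERSION>/` rather than a plain `htmlcov/` directory.

```python
import webbrowser
from pathlib import Path

# The script writes htmlcov-<PYTHON_VERSION>/ (e.g. htmlcov-3.12/); older docs refer to htmlcov/.
candidates = sorted(Path(".").glob("htmlcov*/index.html"))
if candidates:
    webbrowser.open(candidates[-1].resolve().as_uri())
else:
    print("No coverage report found; run ./scripts/unit-tests.sh first.")
```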


@@ -4,7 +4,6 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
 
-import pytest
 
 from llama_stack.apis.inference import (
     ChatCompletionRequest,
@@ -32,7 +31,6 @@ MODEL = "Llama3.1-8B-Instruct"
 MODEL3_2 = "Llama3.2-3B-Instruct"
 
 
-@pytest.mark.asyncio
 async def test_system_default():
     content = "Hello !"
     request = ChatCompletionRequest(
@@ -47,7 +45,6 @@ async def test_system_default():
     assert "Cutting Knowledge Date: December 2023" in interleaved_content_as_str(messages[0].content)
 
 
-@pytest.mark.asyncio
 async def test_system_builtin_only():
     content = "Hello !"
     request = ChatCompletionRequest(
@@ -67,7 +64,6 @@ async def test_system_builtin_only():
     assert "Tools: brave_search" in interleaved_content_as_str(messages[0].content)
 
 
-@pytest.mark.asyncio
 async def test_system_custom_only():
     content = "Hello !"
     request = ChatCompletionRequest(
@@ -98,7 +94,6 @@ async def test_system_custom_only():
     assert messages[-1].content == content
 
 
-@pytest.mark.asyncio
 async def test_system_custom_and_builtin():
     content = "Hello !"
     request = ChatCompletionRequest(
@@ -132,7 +127,6 @@ async def test_system_custom_and_builtin():
     assert messages[-1].content == content
 
 
-@pytest.mark.asyncio
 async def test_completion_message_encoding():
     request = ChatCompletionRequest(
         model=MODEL3_2,
@@ -174,7 +168,6 @@ async def test_completion_message_encoding():
     assert '{"type": "function", "name": "custom1", "parameters": {"param1": "value1"}}' in prompt
 
 
-@pytest.mark.asyncio
 async def test_user_provided_system_message():
     content = "Hello !"
     system_prompt = "You are a pirate"
@@ -195,7 +188,6 @@ async def test_user_provided_system_message():
     assert messages[-1].content == content
 
 
-@pytest.mark.asyncio
 async def test_replace_system_message_behavior_builtin_tools():
     content = "Hello !"
     system_prompt = "You are a pirate"
@@ -221,7 +213,6 @@ async def test_replace_system_message_behavior_builtin_tools():
     assert messages[-1].content == content
 
 
-@pytest.mark.asyncio
 async def test_replace_system_message_behavior_custom_tools():
     content = "Hello !"
     system_prompt = "You are a pirate"
@@ -259,7 +250,6 @@ async def test_replace_system_message_behavior_custom_tools():
     assert messages[-1].content == content
 
 
-@pytest.mark.asyncio
 async def test_replace_system_message_behavior_custom_tools_with_template():
     content = "Hello !"
     system_prompt = "You are a pirate {{ function_description }}"


@@ -12,7 +12,6 @@
 # the top-level of this source tree.
 
 import textwrap
-import unittest
 from datetime import datetime
 
 from llama_stack.models.llama.llama3.prompt_templates import (
@@ -24,59 +23,61 @@ from llama_stack.models.llama.llama3.prompt_templates import (
 )
 
 
-class PromptTemplateTests(unittest.TestCase):
-    def check_generator_output(self, generator):
-        for example in generator.data_examples():
-            pt = generator.gen(example)
-            text = pt.render()
-            # print(text) # debugging
-            if not example:
-                continue
-            for tool in example:
-                assert tool.tool_name in text
-
-    def test_system_default(self):
-        generator = SystemDefaultGenerator()
-        today = datetime.now().strftime("%d %B %Y")
-        expected_text = f"Cutting Knowledge Date: December 2023\nToday Date: {today}"
-        assert expected_text.strip("\n") == generator.gen(generator.data_examples()[0]).render()
-
-    def test_system_builtin_only(self):
-        generator = BuiltinToolGenerator()
-        expected_text = textwrap.dedent(
-            """
-            Environment: ipython
-            Tools: brave_search, wolfram_alpha
-            """
-        )
-        assert expected_text.strip("\n") == generator.gen(generator.data_examples()[0]).render()
-
-    def test_system_custom_only(self):
-        self.maxDiff = None
-        generator = JsonCustomToolGenerator()
-        self.check_generator_output(generator)
-
-    def test_system_custom_function_tag(self):
-        self.maxDiff = None
-        generator = FunctionTagCustomToolGenerator()
-        self.check_generator_output(generator)
-
-    def test_llama_3_2_system_zero_shot(self):
-        generator = PythonListCustomToolGenerator()
-        self.check_generator_output(generator)
-
-    def test_llama_3_2_provided_system_prompt(self):
-        generator = PythonListCustomToolGenerator()
-        user_system_prompt = textwrap.dedent(
-            """
-            Overriding message.
-            {{ function_description }}
-            """
-        )
-        example = generator.data_examples()[0]
-        pt = generator.gen(example, user_system_prompt)
-        text = pt.render()
-        assert "Overriding message." in text
-        assert '"name": "get_weather"' in text
+def check_generator_output(generator):
+    for example in generator.data_examples():
+        pt = generator.gen(example)
+        text = pt.render()
+        if not example:
+            continue
+        for tool in example:
+            assert tool.tool_name in text
+
+
+def test_system_default():
+    generator = SystemDefaultGenerator()
+    today = datetime.now().strftime("%d %B %Y")
+    expected_text = f"Cutting Knowledge Date: December 2023\nToday Date: {today}"
+    assert expected_text.strip("\n") == generator.gen(generator.data_examples()[0]).render()
+
+
+def test_system_builtin_only():
+    generator = BuiltinToolGenerator()
+    expected_text = textwrap.dedent(
+        """
+        Environment: ipython
+        Tools: brave_search, wolfram_alpha
+        """
+    )
+    assert expected_text.strip("\n") == generator.gen(generator.data_examples()[0]).render()
+
+
+def test_system_custom_only():
+    generator = JsonCustomToolGenerator()
+    check_generator_output(generator)
+
+
+def test_system_custom_function_tag():
+    generator = FunctionTagCustomToolGenerator()
+    check_generator_output(generator)
+
+
+def test_llama_3_2_system_zero_shot():
+    generator = PythonListCustomToolGenerator()
+    check_generator_output(generator)
+
+
+def test_llama_3_2_provided_system_prompt():
+    generator = PythonListCustomToolGenerator()
+    user_system_prompt = textwrap.dedent(
+        """
+        Overriding message.
+        {{ function_description }}
+        """
+    )
+    example = generator.data_examples()[0]
+    pt = generator.gen(example, user_system_prompt)
+    text = pt.render()
+    assert "Overriding message." in text
+    assert '"name": "get_weather"' in text


@@ -5,103 +5,110 @@
 # the root directory of this source tree.
 
 import os
-import unittest
 from unittest.mock import patch
 
 import pytest
 
 from llama_stack.apis.datasets import Dataset, DatasetPurpose, URIDataSource
+from llama_stack.apis.resource import ResourceType
 from llama_stack.providers.remote.datasetio.nvidia.config import NvidiaDatasetIOConfig
 from llama_stack.providers.remote.datasetio.nvidia.datasetio import NvidiaDatasetIOAdapter
 
 
-class TestNvidiaDatastore(unittest.TestCase):
-    def setUp(self):
-        os.environ["NVIDIA_DATASETS_URL"] = "http://nemo.test/datasets"
-
-        config = NvidiaDatasetIOConfig(
-            datasets_url=os.environ["NVIDIA_DATASETS_URL"], dataset_namespace="default", project_id="default"
-        )
-        self.adapter = NvidiaDatasetIOAdapter(config)
-        self.make_request_patcher = patch(
-            "llama_stack.providers.remote.datasetio.nvidia.datasetio.NvidiaDatasetIOAdapter._make_request"
-        )
-        self.mock_make_request = self.make_request_patcher.start()
-
-    def tearDown(self):
-        self.make_request_patcher.stop()
-
-    @pytest.fixture(autouse=True)
-    def inject_fixtures(self, run_async):
-        self.run_async = run_async
-
-    def _assert_request(self, mock_call, expected_method, expected_path, expected_json=None):
-        """Helper method to verify request details in mock calls."""
-        call_args = mock_call.call_args
-
-        assert call_args[0][0] == expected_method
-        assert call_args[0][1] == expected_path
-
-        if expected_json:
-            for key, value in expected_json.items():
-                assert call_args[1]["json"][key] == value
-
-    def test_register_dataset(self):
-        self.mock_make_request.return_value = {
-            "id": "dataset-123456",
-            "name": "test-dataset",
-            "namespace": "default",
-        }
-        dataset_def = Dataset(
-            identifier="test-dataset",
-            type="dataset",
-            provider_resource_id="",
-            provider_id="",
-            purpose=DatasetPurpose.post_training_messages,
-            source=URIDataSource(uri="https://example.com/data.jsonl"),
-            metadata={"provider_id": "nvidia", "format": "jsonl", "description": "Test dataset description"},
-        )
-
-        self.run_async(self.adapter.register_dataset(dataset_def))
-
-        self.mock_make_request.assert_called_once()
-        self._assert_request(
-            self.mock_make_request,
-            "POST",
-            "/v1/datasets",
-            expected_json={
-                "name": "test-dataset",
-                "namespace": "default",
-                "files_url": "https://example.com/data.jsonl",
-                "project": "default",
-                "format": "jsonl",
-                "description": "Test dataset description",
-            },
-        )
-
-    def test_unregister_dataset(self):
-        self.mock_make_request.return_value = {
-            "message": "Resource deleted successfully.",
-            "id": "dataset-81RSQp7FKX3rdBtKvF9Skn",
-            "deleted_at": None,
-        }
-        dataset_id = "test-dataset"
-
-        self.run_async(self.adapter.unregister_dataset(dataset_id))
-
-        self.mock_make_request.assert_called_once()
-        self._assert_request(self.mock_make_request, "DELETE", "/v1/datasets/default/test-dataset")
-
-    def test_register_dataset_with_custom_namespace_project(self):
-        custom_config = NvidiaDatasetIOConfig(
-            datasets_url=os.environ["NVIDIA_DATASETS_URL"],
-            dataset_namespace="custom-namespace",
-            project_id="custom-project",
-        )
-        custom_adapter = NvidiaDatasetIOAdapter(custom_config)
-
-        self.mock_make_request.return_value = {
+@pytest.fixture
+def nvidia_adapter():
+    """Fixture to set up NvidiaDatasetIOAdapter with mocked requests."""
+    os.environ["NVIDIA_DATASETS_URL"] = "http://nemo.test/datasets"
+
+    config = NvidiaDatasetIOConfig(
+        datasets_url=os.environ["NVIDIA_DATASETS_URL"], dataset_namespace="default", project_id="default"
+    )
+    adapter = NvidiaDatasetIOAdapter(config)
+
+    with patch(
+        "llama_stack.providers.remote.datasetio.nvidia.datasetio.NvidiaDatasetIOAdapter._make_request"
+    ) as mock_make_request:
+        yield adapter, mock_make_request
+
+
+def _assert_request(mock_call, expected_method, expected_path, expected_json=None):
+    """Helper function to verify request details in mock calls."""
+    call_args = mock_call.call_args
+
+    assert call_args[0][0] == expected_method
+    assert call_args[0][1] == expected_path
+
+    if expected_json:
+        for key, value in expected_json.items():
+            assert call_args[1]["json"][key] == value
+
+
+def test_register_dataset(nvidia_adapter, run_async):
+    adapter, mock_make_request = nvidia_adapter
+    mock_make_request.return_value = {
+        "id": "dataset-123456",
+        "name": "test-dataset",
+        "namespace": "default",
+    }
+    dataset_def = Dataset(
+        identifier="test-dataset",
+        type=ResourceType.dataset,
+        provider_resource_id="",
+        provider_id="",
+        purpose=DatasetPurpose.post_training_messages,
+        source=URIDataSource(uri="https://example.com/data.jsonl"),
+        metadata={"provider_id": "nvidia", "format": "jsonl", "description": "Test dataset description"},
+    )
+
+    run_async(adapter.register_dataset(dataset_def))
+
+    mock_make_request.assert_called_once()
+    _assert_request(
+        mock_make_request,
+        "POST",
+        "/v1/datasets",
+        expected_json={
+            "name": "test-dataset",
+            "namespace": "default",
+            "files_url": "https://example.com/data.jsonl",
+            "project": "default",
+            "format": "jsonl",
+            "description": "Test dataset description",
+        },
+    )
+
+
+def test_unregister_dataset(nvidia_adapter, run_async):
+    adapter, mock_make_request = nvidia_adapter
+    mock_make_request.return_value = {
+        "message": "Resource deleted successfully.",
+        "id": "dataset-81RSQp7FKX3rdBtKvF9Skn",
+        "deleted_at": None,
+    }
+    dataset_id = "test-dataset"
+
+    run_async(adapter.unregister_dataset(dataset_id))
+
+    mock_make_request.assert_called_once()
+    _assert_request(mock_make_request, "DELETE", "/v1/datasets/default/test-dataset")
+
+
+def test_register_dataset_with_custom_namespace_project(run_async):
+    """Test with custom namespace and project configuration."""
+    os.environ["NVIDIA_DATASETS_URL"] = "http://nemo.test/datasets"
+    custom_config = NvidiaDatasetIOConfig(
+        datasets_url=os.environ["NVIDIA_DATASETS_URL"],
+        dataset_namespace="custom-namespace",
+        project_id="custom-project",
+    )
+    custom_adapter = NvidiaDatasetIOAdapter(custom_config)
+
+    with patch(
+        "llama_stack.providers.remote.datasetio.nvidia.datasetio.NvidiaDatasetIOAdapter._make_request"
+    ) as mock_make_request:
+        mock_make_request.return_value = {
             "id": "dataset-123456",
             "name": "test-dataset",
             "namespace": "custom-namespace",
@@ -109,7 +116,7 @@ class TestNvidiaDatastore(unittest.TestCase):
         dataset_def = Dataset(
             identifier="test-dataset",
-            type="dataset",
+            type=ResourceType.dataset,
             provider_resource_id="",
             provider_id="",
             purpose=DatasetPurpose.post_training_messages,
@@ -117,11 +124,11 @@
             metadata={"format": "jsonl"},
         )
 
-        self.run_async(custom_adapter.register_dataset(dataset_def))
+        run_async(custom_adapter.register_dataset(dataset_def))
 
-        self.mock_make_request.assert_called_once()
-        self._assert_request(
-            self.mock_make_request,
+        mock_make_request.assert_called_once()
+        _assert_request(
+            mock_make_request,
             "POST",
             "/v1/datasets",
             expected_json={
@@ -132,7 +139,3 @@
                 "format": "jsonl",
             },
         )
-
-
-if __name__ == "__main__":
-    unittest.main()
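The rewritten tests above reference a `run_async` fixture that is not shown in this diff. A minimal sketch of what such a conftest fixture could look like, assuming it simply drives a coroutine to completion, is below; the real fixture in the repository may differ.

```python
# conftest.py (hypothetical sketch; the actual fixture may be implemented differently)
import asyncio

import pytest


@pytest.fixture
def run_async():
    """Return a helper that runs a coroutine to completion on a fresh event loop."""

    def _run(coro):
        return asyncio.run(coro)

    return _run
```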


@@ -5,7 +5,6 @@
 # the root directory of this source tree.
 
 import os
-import unittest
 import warnings
 from unittest.mock import patch
 
@@ -27,14 +26,13 @@ from llama_stack.providers.remote.post_training.nvidia.post_training import (
 )
 
 
-class TestNvidiaParameters(unittest.TestCase):
-    def setUp(self):
-        os.environ["NVIDIA_BASE_URL"] = "http://nemo.test"
+class TestNvidiaParameters:
+    @pytest.fixture(autouse=True)
+    def setup_and_teardown(self):
+        """Setup and teardown for each test method."""
         os.environ["NVIDIA_CUSTOMIZER_URL"] = "http://nemo.test"
 
-        config = NvidiaPostTrainingConfig(
-            base_url=os.environ["NVIDIA_BASE_URL"], customizer_url=os.environ["NVIDIA_CUSTOMIZER_URL"], api_key=None
-        )
+        config = NvidiaPostTrainingConfig(customizer_url=os.environ["NVIDIA_CUSTOMIZER_URL"], api_key=None)
         self.adapter = NvidiaPostTrainingAdapter(config)
         self.make_request_patcher = patch(
@@ -48,7 +46,8 @@ class TestNvidiaParameters(unittest.TestCase):
             "updated_at": "2025-03-04T13:07:47.543605",
         }
 
-    def tearDown(self):
+        yield
+
         self.make_request_patcher.stop()
 
     def _assert_request_params(self, expected_json):
@@ -166,8 +165,8 @@ class TestNvidiaParameters(unittest.TestCase):
         self.run_async(
             self.adapter.supervised_fine_tune(
-                job_uuid=required_job_uuid,  # Required parameter
-                model=required_model,  # Required parameter
+                job_uuid=required_job_uuid,
+                model=required_model,
                 checkpoint_dir="",
                 algorithm_config=algorithm_config,
                 training_config=convert_pydantic_to_json_value(training_config),
@@ -198,7 +197,6 @@ class TestNvidiaParameters(unittest.TestCase):
         data_config = DataConfig(
             dataset_id="test-dataset",
             batch_size=8,
-            # Unsupported parameters
             shuffle=True,
             data_format=DatasetFormat.instruct,
             validation_dataset_id="val-dataset",
@@ -207,20 +205,16 @@ class TestNvidiaParameters(unittest.TestCase):
         optimizer_config = OptimizerConfig(
             lr=0.0001,
             weight_decay=0.01,
-            # Unsupported parameters
             optimizer_type=OptimizerType.adam,
             num_warmup_steps=100,
         )
-        efficiency_config = EfficiencyConfig(
-            enable_activation_checkpointing=True  # Unsupported parameter
-        )
+        efficiency_config = EfficiencyConfig(enable_activation_checkpointing=True)
 
         training_config = TrainingConfig(
             n_epochs=1,
             data_config=data_config,
             optimizer_config=optimizer_config,
-            # Unsupported parameters
             efficiency_config=efficiency_config,
             max_steps_per_epoch=1000,
             gradient_accumulation_steps=4,
@@ -228,7 +222,6 @@ class TestNvidiaParameters(unittest.TestCase):
             dtype="bf16",
         )
 
-        # Capture warnings
         with warnings.catch_warnings(record=True) as w:
             warnings.simplefilter("always")
 
@@ -236,7 +229,7 @@ class TestNvidiaParameters(unittest.TestCase):
             self.adapter.supervised_fine_tune(
                 job_uuid="test-job",
                 model="meta-llama/Llama-3.1-8B-Instruct",
-                checkpoint_dir="test-dir",  # Unsupported parameter
+                checkpoint_dir="test-dir",
                 algorithm_config=LoraFinetuningConfig(
                     type="LoRA",
                     apply_lora_to_mlp=True,
@@ -246,8 +239,8 @@ class TestNvidiaParameters(unittest.TestCase):
                     lora_attn_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
                 ),
                 training_config=convert_pydantic_to_json_value(training_config),
-                logger_config={"test": "value"},  # Unsupported parameter
-                hyperparam_search_config={"test": "value"},  # Unsupported parameter
+                logger_config={"test": "value"},
+                hyperparam_search_config={"test": "value"},
             )
         )
 
@@ -265,7 +258,6 @@ class TestNvidiaParameters(unittest.TestCase):
             "gradient_accumulation_steps",
             "max_validation_steps",
             "dtype",
-            # required unsupported parameters
             "rank",
             "apply_lora_to_output",
             "lora_attn_modules",
@@ -273,7 +265,3 @@ class TestNvidiaParameters(unittest.TestCase):
         ]
         for field in fields:
             assert any(field in text for text in warning_texts)
-
-
-if __name__ == "__main__":
-    unittest.main()


@@ -5,13 +5,11 @@
 # the root directory of this source tree.
 
 import os
-import unittest
 import warnings
-from unittest.mock import AsyncMock, patch
+from unittest.mock import patch
 
 import pytest
 
-from llama_stack.apis.models import Model, ModelType
 from llama_stack.apis.post_training.post_training import (
     DataConfig,
     DatasetFormat,
@@ -22,7 +20,6 @@ from llama_stack.apis.post_training.post_training import (
     TrainingConfig,
 )
 from llama_stack.distribution.library_client import convert_pydantic_to_json_value
-from llama_stack.providers.remote.inference.nvidia.nvidia import NVIDIAConfig, NVIDIAInferenceAdapter
 from llama_stack.providers.remote.post_training.nvidia.post_training import (
     ListNvidiaPostTrainingJobs,
     NvidiaPostTrainingAdapter,
@@ -32,336 +29,297 @@ from llama_stack.providers.remote.post_training.nvidia.post_training import (
 )
-class TestNvidiaPostTraining(unittest.TestCase):
-    def setUp(self):
-        os.environ["NVIDIA_BASE_URL"] = "http://nemo.test"  # needed for llm inference
-        os.environ["NVIDIA_CUSTOMIZER_URL"] = "http://nemo.test"  # needed for nemo customizer
-        config = NvidiaPostTrainingConfig(
-            base_url=os.environ["NVIDIA_BASE_URL"], customizer_url=os.environ["NVIDIA_CUSTOMIZER_URL"], api_key=None
+@pytest.fixture
+def nvidia_post_training_adapter():
+    """Fixture to create and configure the NVIDIA post training adapter."""
+    os.environ["NVIDIA_CUSTOMIZER_URL"] = "http://nemo.test"  # needed for nemo customizer
+    config = NvidiaPostTrainingConfig(customizer_url=os.environ["NVIDIA_CUSTOMIZER_URL"], api_key=None)
+    adapter = NvidiaPostTrainingAdapter(config)
+
+    with patch.object(adapter, "_make_request") as mock_make_request:
+        yield adapter, mock_make_request
def _assert_request(mock_call, expected_method, expected_path, expected_params=None, expected_json=None):
"""Helper method to verify request details in mock calls."""
call_args = mock_call.call_args
if expected_method and expected_path:
if isinstance(call_args[0], tuple) and len(call_args[0]) == 2:
assert call_args[0] == (expected_method, expected_path)
else:
assert call_args[1]["method"] == expected_method
assert call_args[1]["path"] == expected_path
if expected_params:
assert call_args[1]["params"] == expected_params
if expected_json:
for key, value in expected_json.items():
assert call_args[1]["json"][key] == value
async def test_supervised_fine_tune(nvidia_post_training_adapter):
"""Test the supervised fine-tuning API call."""
adapter, mock_make_request = nvidia_post_training_adapter
mock_make_request.return_value = {
"id": "cust-JGTaMbJMdqjJU8WbQdN9Q2",
"created_at": "2024-12-09T04:06:28.542884",
"updated_at": "2024-12-09T04:06:28.542884",
"config": {
"schema_version": "1.0",
"id": "af783f5b-d985-4e5b-bbb7-f9eec39cc0b1",
"created_at": "2024-12-09T04:06:28.542657",
"updated_at": "2024-12-09T04:06:28.569837",
"custom_fields": {},
"name": "meta-llama/Llama-3.1-8B-Instruct",
"base_model": "meta-llama/Llama-3.1-8B-Instruct",
"model_path": "llama-3_1-8b-instruct",
"training_types": [],
"finetuning_types": ["lora"],
"precision": "bf16",
"num_gpus": 4,
"num_nodes": 1,
"micro_batch_size": 1,
"tensor_parallel_size": 1,
"max_seq_length": 4096,
},
"dataset": {
"schema_version": "1.0",
"id": "dataset-XU4pvGzr5tvawnbVxeJMTb",
"created_at": "2024-12-09T04:06:28.542657",
"updated_at": "2024-12-09T04:06:28.542660",
"custom_fields": {},
"name": "sample-basic-test",
"version_id": "main",
"version_tags": [],
},
"hyperparameters": {
"finetuning_type": "lora",
"training_type": "sft",
"batch_size": 16,
"epochs": 2,
"learning_rate": 0.0001,
"lora": {"alpha": 16},
},
"output_model": "default/job-1234",
"status": "created",
"project": "default",
"custom_fields": {},
"ownership": {"created_by": "me", "access_policies": {}},
}
algorithm_config = LoraFinetuningConfig(
type="LoRA",
apply_lora_to_mlp=True,
apply_lora_to_output=True,
alpha=16,
rank=16,
lora_attn_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
)
data_config = DataConfig(
dataset_id="sample-basic-test", batch_size=16, shuffle=False, data_format=DatasetFormat.instruct
)
optimizer_config = OptimizerConfig(
optimizer_type=OptimizerType.adam,
lr=0.0001,
weight_decay=0.01,
num_warmup_steps=100,
)
training_config = TrainingConfig(
n_epochs=2,
data_config=data_config,
optimizer_config=optimizer_config,
)
with warnings.catch_warnings(record=True):
warnings.simplefilter("always")
training_job = await adapter.supervised_fine_tune(
job_uuid="1234",
model="meta/llama-3.2-1b-instruct@v1.0.0+L40",
checkpoint_dir="",
algorithm_config=algorithm_config,
training_config=convert_pydantic_to_json_value(training_config),
logger_config={},
hyperparam_search_config={},
) )
self.adapter = NvidiaPostTrainingAdapter(config)
self.make_request_patcher = patch(
"llama_stack.providers.remote.post_training.nvidia.post_training.NvidiaPostTrainingAdapter._make_request"
)
self.mock_make_request = self.make_request_patcher.start()
# Mock the inference client # check the output is a PostTrainingJob
inference_config = NVIDIAConfig(base_url=os.environ["NVIDIA_BASE_URL"], api_key=None) assert isinstance(training_job, NvidiaPostTrainingJob)
self.inference_adapter = NVIDIAInferenceAdapter(inference_config) assert training_job.job_uuid == "cust-JGTaMbJMdqjJU8WbQdN9Q2"
self.mock_client = unittest.mock.MagicMock() mock_make_request.assert_called_once()
self.mock_client.chat.completions.create = unittest.mock.AsyncMock() _assert_request(
self.inference_mock_make_request = self.mock_client.chat.completions.create mock_make_request,
self.inference_make_request_patcher = patch( "POST",
"llama_stack.providers.remote.inference.nvidia.nvidia.NVIDIAInferenceAdapter._client", "/v1/customization/jobs",
new_callable=unittest.mock.PropertyMock, expected_json={
return_value=self.mock_client, "config": "meta/llama-3.2-1b-instruct@v1.0.0+L40",
) "dataset": {"name": "sample-basic-test", "namespace": "default"},
self.inference_make_request_patcher.start()
def tearDown(self):
self.make_request_patcher.stop()
self.inference_make_request_patcher.stop()
@pytest.fixture(autouse=True)
def inject_fixtures(self, run_async):
self.run_async = run_async
def _assert_request(self, mock_call, expected_method, expected_path, expected_params=None, expected_json=None):
"""Helper method to verify request details in mock calls."""
call_args = mock_call.call_args
if expected_method and expected_path:
if isinstance(call_args[0], tuple) and len(call_args[0]) == 2:
assert call_args[0] == (expected_method, expected_path)
else:
assert call_args[1]["method"] == expected_method
assert call_args[1]["path"] == expected_path
if expected_params:
assert call_args[1]["params"] == expected_params
if expected_json:
for key, value in expected_json.items():
assert call_args[1]["json"][key] == value
def test_supervised_fine_tune(self):
"""Test the supervised fine-tuning API call."""
self.mock_make_request.return_value = {
"id": "cust-JGTaMbJMdqjJU8WbQdN9Q2",
"created_at": "2024-12-09T04:06:28.542884",
"updated_at": "2024-12-09T04:06:28.542884",
"config": {
"schema_version": "1.0",
"id": "af783f5b-d985-4e5b-bbb7-f9eec39cc0b1",
"created_at": "2024-12-09T04:06:28.542657",
"updated_at": "2024-12-09T04:06:28.569837",
"custom_fields": {},
"name": "meta-llama/Llama-3.1-8B-Instruct",
"base_model": "meta-llama/Llama-3.1-8B-Instruct",
"model_path": "llama-3_1-8b-instruct",
"training_types": [],
"finetuning_types": ["lora"],
"precision": "bf16",
"num_gpus": 4,
"num_nodes": 1,
"micro_batch_size": 1,
"tensor_parallel_size": 1,
"max_seq_length": 4096,
},
"dataset": {
"schema_version": "1.0",
"id": "dataset-XU4pvGzr5tvawnbVxeJMTb",
"created_at": "2024-12-09T04:06:28.542657",
"updated_at": "2024-12-09T04:06:28.542660",
"custom_fields": {},
"name": "sample-basic-test",
"version_id": "main",
"version_tags": [],
},
"hyperparameters": { "hyperparameters": {
"finetuning_type": "lora",
"training_type": "sft", "training_type": "sft",
"batch_size": 16, "finetuning_type": "lora",
"epochs": 2, "epochs": 2,
"batch_size": 16,
"learning_rate": 0.0001, "learning_rate": 0.0001,
"weight_decay": 0.01,
"lora": {"alpha": 16}, "lora": {"alpha": 16},
}, },
"output_model": "default/job-1234", },
"status": "created", )
"project": "default",
"custom_fields": {},
"ownership": {"created_by": "me", "access_policies": {}}, async def test_supervised_fine_tune_with_qat(nvidia_post_training_adapter):
"""Test that QAT configuration raises NotImplementedError."""
adapter, mock_make_request = nvidia_post_training_adapter
algorithm_config = QATFinetuningConfig(type="QAT", quantizer_name="quantizer_name", group_size=1)
data_config = DataConfig(
dataset_id="sample-basic-test", batch_size=16, shuffle=False, data_format=DatasetFormat.instruct
)
optimizer_config = OptimizerConfig(
optimizer_type=OptimizerType.adam,
lr=0.0001,
weight_decay=0.01,
num_warmup_steps=100,
)
training_config = TrainingConfig(
n_epochs=2,
data_config=data_config,
optimizer_config=optimizer_config,
)
# This will raise NotImplementedError since QAT is not supported
with pytest.raises(NotImplementedError):
await adapter.supervised_fine_tune(
job_uuid="1234",
model="meta/llama-3.2-1b-instruct@v1.0.0+L40",
checkpoint_dir="",
algorithm_config=algorithm_config,
training_config=convert_pydantic_to_json_value(training_config),
logger_config={},
hyperparam_search_config={},
)
async def test_get_training_job_status(nvidia_post_training_adapter):
"""Test getting training job status with different statuses."""
adapter, mock_make_request = nvidia_post_training_adapter
customizer_status_to_job_status = [
("running", "in_progress"),
("completed", "completed"),
("failed", "failed"),
("cancelled", "cancelled"),
("pending", "scheduled"),
("unknown", "scheduled"),
]
for customizer_status, expected_status in customizer_status_to_job_status:
mock_make_request.return_value = {
"created_at": "2024-12-09T04:06:28.580220",
"updated_at": "2024-12-09T04:21:19.852832",
"status": customizer_status,
"steps_completed": 1210,
"epochs_completed": 2,
"percentage_done": 100.0,
"best_epoch": 2,
"train_loss": 1.718016266822815,
"val_loss": 1.8661999702453613,
} }
algorithm_config = LoraFinetuningConfig(
type="LoRA",
apply_lora_to_mlp=True,
apply_lora_to_output=True,
alpha=16,
rank=16,
lora_attn_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
)
data_config = DataConfig(
dataset_id="sample-basic-test", batch_size=16, shuffle=False, data_format=DatasetFormat.instruct
)
optimizer_config = OptimizerConfig(
optimizer_type=OptimizerType.adam,
lr=0.0001,
weight_decay=0.01,
num_warmup_steps=100,
)
training_config = TrainingConfig(
n_epochs=2,
data_config=data_config,
optimizer_config=optimizer_config,
)
with warnings.catch_warnings(record=True):
warnings.simplefilter("always")
training_job = self.run_async(
self.adapter.supervised_fine_tune(
job_uuid="1234",
model="meta/llama-3.2-1b-instruct@v1.0.0+L40",
checkpoint_dir="",
algorithm_config=algorithm_config,
training_config=convert_pydantic_to_json_value(training_config),
logger_config={},
hyperparam_search_config={},
)
)
# check the output is a PostTrainingJob
assert isinstance(training_job, NvidiaPostTrainingJob)
assert training_job.job_uuid == "cust-JGTaMbJMdqjJU8WbQdN9Q2"
self.mock_make_request.assert_called_once()
self._assert_request(
self.mock_make_request,
"POST",
"/v1/customization/jobs",
expected_json={
"config": "meta/llama-3.2-1b-instruct@v1.0.0+L40",
"dataset": {"name": "sample-basic-test", "namespace": "default"},
"hyperparameters": {
"training_type": "sft",
"finetuning_type": "lora",
"epochs": 2,
"batch_size": 16,
"learning_rate": 0.0001,
"weight_decay": 0.01,
"lora": {"alpha": 16},
},
},
)
def test_supervised_fine_tune_with_qat(self):
algorithm_config = QATFinetuningConfig(type="QAT", quantizer_name="quantizer_name", group_size=1)
data_config = DataConfig(
dataset_id="sample-basic-test", batch_size=16, shuffle=False, data_format=DatasetFormat.instruct
)
optimizer_config = OptimizerConfig(
optimizer_type=OptimizerType.adam,
lr=0.0001,
weight_decay=0.01,
num_warmup_steps=100,
)
training_config = TrainingConfig(
n_epochs=2,
data_config=data_config,
optimizer_config=optimizer_config,
)
# This will raise NotImplementedError since QAT is not supported
with self.assertRaises(NotImplementedError):
self.run_async(
self.adapter.supervised_fine_tune(
job_uuid="1234",
model="meta/llama-3.2-1b-instruct@v1.0.0+L40",
checkpoint_dir="",
algorithm_config=algorithm_config,
training_config=convert_pydantic_to_json_value(training_config),
logger_config={},
hyperparam_search_config={},
)
)
def test_get_training_job_status(self):
customizer_status_to_job_status = [
("running", "in_progress"),
("completed", "completed"),
("failed", "failed"),
("cancelled", "cancelled"),
("pending", "scheduled"),
("unknown", "scheduled"),
]
for customizer_status, expected_status in customizer_status_to_job_status:
with self.subTest(customizer_status=customizer_status, expected_status=expected_status):
self.mock_make_request.return_value = {
"created_at": "2024-12-09T04:06:28.580220",
"updated_at": "2024-12-09T04:21:19.852832",
"status": customizer_status,
"steps_completed": 1210,
"epochs_completed": 2,
"percentage_done": 100.0,
"best_epoch": 2,
"train_loss": 1.718016266822815,
"val_loss": 1.8661999702453613,
}
job_id = "cust-JGTaMbJMdqjJU8WbQdN9Q2"
status = self.run_async(self.adapter.get_training_job_status(job_uuid=job_id))
assert isinstance(status, NvidiaPostTrainingJobStatusResponse)
assert status.status.value == expected_status
assert status.steps_completed == 1210
assert status.epochs_completed == 2
assert status.percentage_done == 100.0
assert status.best_epoch == 2
assert status.train_loss == 1.718016266822815
assert status.val_loss == 1.8661999702453613
self._assert_request(
self.mock_make_request,
"GET",
f"/v1/customization/jobs/{job_id}/status",
expected_params={"job_id": job_id},
)
def test_get_training_jobs(self):
job_id = "cust-JGTaMbJMdqjJU8WbQdN9Q2" job_id = "cust-JGTaMbJMdqjJU8WbQdN9Q2"
self.mock_make_request.return_value = {
"data": [
{
"id": job_id,
"created_at": "2024-12-09T04:06:28.542884",
"updated_at": "2024-12-09T04:21:19.852832",
"config": {
"name": "meta-llama/Llama-3.1-8B-Instruct",
"base_model": "meta-llama/Llama-3.1-8B-Instruct",
},
"dataset": {"name": "default/sample-basic-test"},
"hyperparameters": {
"finetuning_type": "lora",
"training_type": "sft",
"batch_size": 16,
"epochs": 2,
"learning_rate": 0.0001,
"lora": {"adapter_dim": 16, "adapter_dropout": 0.1},
},
"output_model": "default/job-1234",
"status": "completed",
"project": "default",
}
]
}
jobs = self.run_async(self.adapter.get_training_jobs()) status = await adapter.get_training_job_status(job_uuid=job_id)
assert isinstance(jobs, ListNvidiaPostTrainingJobs) assert isinstance(status, NvidiaPostTrainingJobStatusResponse)
assert len(jobs.data) == 1 assert status.status.value == expected_status
job = jobs.data[0] # Note: The response object inherits extra fields via ConfigDict(extra="allow")
assert job.job_uuid == job_id # So these attributes should be accessible using getattr with defaults
assert job.status.value == "completed" assert getattr(status, "steps_completed", None) == 1210
assert getattr(status, "epochs_completed", None) == 2
assert getattr(status, "percentage_done", None) == 100.0
assert getattr(status, "best_epoch", None) == 2
assert getattr(status, "train_loss", None) == 1.718016266822815
assert getattr(status, "val_loss", None) == 1.8661999702453613
self.mock_make_request.assert_called_once() _assert_request(
self._assert_request( mock_make_request,
self.mock_make_request,
"GET", "GET",
"/v1/customization/jobs", f"/v1/customization/jobs/{job_id}/status",
expected_params={"page": 1, "page_size": 10, "sort": "created_at"},
)
def test_cancel_training_job(self):
self.mock_make_request.return_value = {} # Empty response for successful cancellation
job_id = "cust-JGTaMbJMdqjJU8WbQdN9Q2"
result = self.run_async(self.adapter.cancel_training_job(job_uuid=job_id))
assert result is None
self.mock_make_request.assert_called_once()
self._assert_request(
self.mock_make_request,
"POST",
f"/v1/customization/jobs/{job_id}/cancel",
expected_params={"job_id": job_id}, expected_params={"job_id": job_id},
) )
def test_inference_register_model(self): mock_make_request.reset_mock()
model_id = "default/job-1234"
model_type = ModelType.llm
model = Model(
identifier=model_id,
provider_id="nvidia",
provider_model_id=model_id,
provider_resource_id=model_id,
model_type=model_type,
)
# simulate a NIM where default/job-1234 is an available model
with patch.object(self.inference_adapter, "check_model_availability", new_callable=AsyncMock) as mock_check:
mock_check.return_value = True
result = self.run_async(self.inference_adapter.register_model(model))
assert result == model
assert len(self.inference_adapter.alias_to_provider_id_map) > 1
assert self.inference_adapter.get_provider_model_id(model.provider_model_id) == model_id
with patch.object(self.inference_adapter, "chat_completion") as mock_chat_completion:
self.run_async(
self.inference_adapter.chat_completion(
model_id=model_id,
messages=[{"role": "user", "content": "Hello, model"}],
)
)
mock_chat_completion.assert_called()
-if __name__ == "__main__":
-    unittest.main()
+async def test_get_training_jobs(nvidia_post_training_adapter):
+    """Test getting list of training jobs."""
adapter, mock_make_request = nvidia_post_training_adapter
job_id = "cust-JGTaMbJMdqjJU8WbQdN9Q2"
mock_make_request.return_value = {
"data": [
{
"id": job_id,
"created_at": "2024-12-09T04:06:28.542884",
"updated_at": "2024-12-09T04:21:19.852832",
"config": {
"name": "meta-llama/Llama-3.1-8B-Instruct",
"base_model": "meta-llama/Llama-3.1-8B-Instruct",
},
"dataset": {"name": "default/sample-basic-test"},
"hyperparameters": {
"finetuning_type": "lora",
"training_type": "sft",
"batch_size": 16,
"epochs": 2,
"learning_rate": 0.0001,
"lora": {"adapter_dim": 16, "adapter_dropout": 0.1},
},
"output_model": "default/job-1234",
"status": "completed",
"project": "default",
}
]
}
jobs = await adapter.get_training_jobs()
assert isinstance(jobs, ListNvidiaPostTrainingJobs)
assert len(jobs.data) == 1
job = jobs.data[0]
assert job.job_uuid == job_id
assert job.status.value == "completed"
mock_make_request.assert_called_once()
_assert_request(
mock_make_request,
"GET",
"/v1/customization/jobs",
expected_params={"page": 1, "page_size": 10, "sort": "created_at"},
)
async def test_cancel_training_job(nvidia_post_training_adapter):
"""Test canceling a training job."""
adapter, mock_make_request = nvidia_post_training_adapter
mock_make_request.return_value = {} # Empty response for successful cancellation
job_id = "cust-JGTaMbJMdqjJU8WbQdN9Q2"
result = await adapter.cancel_training_job(job_uuid=job_id)
assert result is None
mock_make_request.assert_called_once()
_assert_request(
mock_make_request,
"POST",
f"/v1/customization/jobs/{job_id}/cancel",
expected_params={"job_id": job_id},
)


@@ -8,7 +8,6 @@ from unittest.mock import MagicMock, patch
 
 import numpy as np
 import pytest
-import pytest_asyncio
 
 from llama_stack.apis.vector_io import QueryChunksResponse
 
@@ -33,7 +32,7 @@ with patch.dict("sys.modules", {"pymilvus": pymilvus_mock}):
 
 MILVUS_PROVIDER = "milvus"
 
-@pytest_asyncio.fixture
+@pytest.fixture
 async def mock_milvus_client() -> MagicMock:
     """Create a mock Milvus client with common method behaviors."""
     client = MagicMock()
@@ -84,7 +83,7 @@ async def mock_milvus_client() -> MagicMock:
     return client
 
 
-@pytest_asyncio.fixture
+@pytest.fixture
 async def milvus_index(mock_milvus_client):
     """Create a MilvusIndex with mocked client."""
     index = MilvusIndex(client=mock_milvus_client, collection_name="test_collection")
@@ -92,7 +91,6 @@ async def milvus_index(mock_milvus_client):
     # No real cleanup needed since we're using mocks
 
 
-@pytest.mark.asyncio
 async def test_add_chunks(milvus_index, sample_chunks, sample_embeddings, mock_milvus_client):
     # Setup: collection doesn't exist initially, then exists after creation
     mock_milvus_client.has_collection.side_effect = [False, True]
@@ -108,7 +106,6 @@ async def test_add_chunks(milvus_index, sample_chunks, sample_embeddings, mock_milvus_client):
     assert len(insert_call[1]["data"]) == len(sample_chunks)
 
 
-@pytest.mark.asyncio
 async def test_query_chunks_vector(
     milvus_index, sample_chunks, sample_embeddings, embedding_dimension, mock_milvus_client
 ):
@@ -125,7 +122,6 @@ async def test_query_chunks_vector(
     mock_milvus_client.search.assert_called_once()
 
 
-@pytest.mark.asyncio
 async def test_query_chunks_keyword_search(milvus_index, sample_chunks, sample_embeddings, mock_milvus_client):
     mock_milvus_client.has_collection.return_value = True
     await milvus_index.add_chunks(sample_chunks, sample_embeddings)
@@ -138,7 +134,6 @@ async def test_query_chunks_keyword_search(milvus_index, sample_chunks, sample_embeddings, mock_milvus_client):
     assert len(response.chunks) == 2
 
 
-@pytest.mark.asyncio
 async def test_bm25_fallback_to_simple_search(milvus_index, sample_chunks, sample_embeddings, mock_milvus_client):
     """Test that when BM25 search fails, the system falls back to simple text search."""
     mock_milvus_client.has_collection.return_value = True
@@ -181,7 +176,6 @@ async def test_bm25_fallback_to_simple_search(milvus_index, sample_chunks, sample_embeddings, mock_milvus_client):
     assert all(score == 1.0 for score in response.scores), "Simple text search should use binary scoring"
 
 
-@pytest.mark.asyncio
 async def test_delete_collection(milvus_index, mock_milvus_client):
     # Test collection deletion
     mock_milvus_client.has_collection.return_value = True


@@ -64,7 +64,6 @@ class TestRagQuery:
         with pytest.raises(ValueError):
             RAGQueryConfig(mode="invalid_mode")
 
-    @pytest.mark.asyncio
     async def test_query_accepts_valid_modes(self):
         RAGQueryConfig()  # Test default (vector)
         RAGQueryConfig(mode="vector")  # Test vector
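Dropping `pytest_asyncio.fixture` and the explicit `@pytest.mark.asyncio` markers in these test files relies on pytest-asyncio running in auto mode (the unit-test README above mentions `--asyncio-mode=auto`). A minimal sketch of a test module written under that assumption, with no per-test markers, is shown here; the fixture and test names are hypothetical.

```python
# Hypothetical example; assumes pytest-asyncio is installed and asyncio_mode is set to "auto".
import asyncio

import pytest


@pytest.fixture
async def delayed_value():
    await asyncio.sleep(0)  # any awaitable setup work
    return 42


async def test_delayed_value(delayed_value):
    # Collected and awaited automatically in auto mode; no @pytest.mark.asyncio needed.
    assert delayed_value == 42
```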

uv.lock (generated)

@@ -1390,6 +1390,7 @@ unit = [
     { name = "aiosqlite" },
     { name = "blobfile" },
     { name = "chardet" },
+    { name = "coverage" },
     { name = "faiss-cpu" },
     { name = "litellm" },
     { name = "mcp" },
@@ -1499,6 +1500,7 @@ unit = [
     { name = "aiosqlite" },
     { name = "blobfile" },
     { name = "chardet" },
+    { name = "coverage" },
     { name = "faiss-cpu" },
     { name = "litellm" },
     { name = "mcp" },