test: migrate unit tests from unittest to pytest

Signed-off-by: Mustafa Elbehery <melbeher@redhat.com>
Mustafa Elbehery 2025-07-09 12:55:16 +02:00
parent ff9d4d8a9d
commit 9331253894
8 changed files with 1440 additions and 1405 deletions
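
The commit applies one mechanical pattern throughout: each unittest.TestCase or IsolatedAsyncioTestCase class is dissolved into module-level pytest functions, self.assert* calls become bare assert statements, and setUp/tearDown pairs become yield fixtures. A minimal sketch of the before/after shape, assuming an async-capable pytest setup (e.g. pytest-asyncio in auto mode); make_greeting is a hypothetical helper used only for illustration:

import unittest


async def make_greeting(name: str) -> str:  # hypothetical helper for the example
    return f"Hello {name}"


# Before: class-based style; assertions go through self.assert* helpers.
class GreetingTests(unittest.IsolatedAsyncioTestCase):
    async def test_greeting(self):
        self.assertEqual(await make_greeting("world"), "Hello world")


# After: a plain module-level coroutine with a bare assert. pytest's
# assertion rewriting produces rich failure messages, and an async plugin
# (pytest-asyncio or anyio) awaits the coroutine at run time.
async def test_greeting():
    assert await make_greeting("world") == "Hello world"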

View file

@@ -4,8 +4,6 @@
 # This source code is licensed under the terms described in the LICENSE file in
 # the root directory of this source tree.
-import asyncio
-import unittest
 from llama_stack.apis.inference import (
     ChatCompletionRequest,
@@ -31,11 +29,7 @@ MODEL = "Llama3.1-8B-Instruct"
 MODEL3_2 = "Llama3.2-3B-Instruct"
-class PrepareMessagesTests(unittest.IsolatedAsyncioTestCase):
-    async def asyncSetUp(self):
-        asyncio.get_running_loop().set_debug(False)
-    async def test_system_default(self):
+async def test_system_default():
     content = "Hello !"
     request = ChatCompletionRequest(
         model=MODEL,
@@ -44,11 +38,12 @@ class PrepareMessagesTests(unittest.IsolatedAsyncioTestCase):
         ],
     )
     messages = chat_completion_request_to_messages(request, MODEL)
-        self.assertEqual(len(messages), 2)
-        self.assertEqual(messages[-1].content, content)
-        self.assertTrue("Cutting Knowledge Date: December 2023" in messages[0].content)
-    async def test_system_builtin_only(self):
+    assert len(messages) == 2
+    assert messages[-1].content == content
+    assert "Cutting Knowledge Date: December 2023" in messages[0].content
+async def test_system_builtin_only():
     content = "Hello !"
     request = ChatCompletionRequest(
         model=MODEL,
@@ -61,12 +56,13 @@ class PrepareMessagesTests(unittest.IsolatedAsyncioTestCase):
         ],
     )
     messages = chat_completion_request_to_messages(request, MODEL)
-        self.assertEqual(len(messages), 2)
-        self.assertEqual(messages[-1].content, content)
-        self.assertTrue("Cutting Knowledge Date: December 2023" in messages[0].content)
-        self.assertTrue("Tools: brave_search" in messages[0].content)
-    async def test_system_custom_only(self):
+    assert len(messages) == 2
+    assert messages[-1].content == content
+    assert "Cutting Knowledge Date: December 2023" in messages[0].content
+    assert "Tools: brave_search" in messages[0].content
+async def test_system_custom_only():
     content = "Hello !"
     request = ChatCompletionRequest(
         model=MODEL,
@@ -89,13 +85,14 @@ class PrepareMessagesTests(unittest.IsolatedAsyncioTestCase):
         tool_config=ToolConfig(tool_prompt_format=ToolPromptFormat.json),
     )
     messages = chat_completion_request_to_messages(request, MODEL)
-        self.assertEqual(len(messages), 3)
-        self.assertTrue("Environment: ipython" in messages[0].content)
-        self.assertTrue("Return function calls in JSON format" in messages[1].content)
-        self.assertEqual(messages[-1].content, content)
-    async def test_system_custom_and_builtin(self):
+    assert len(messages) == 3
+    assert "Environment: ipython" in messages[0].content
+    assert "Return function calls in JSON format" in messages[1].content
+    assert messages[-1].content == content
+async def test_system_custom_and_builtin():
     content = "Hello !"
     request = ChatCompletionRequest(
         model=MODEL,
@@ -119,15 +116,16 @@ class PrepareMessagesTests(unittest.IsolatedAsyncioTestCase):
         ],
     )
     messages = chat_completion_request_to_messages(request, MODEL)
-        self.assertEqual(len(messages), 3)
-        self.assertTrue("Environment: ipython" in messages[0].content)
-        self.assertTrue("Tools: brave_search" in messages[0].content)
-        self.assertTrue("Return function calls in JSON format" in messages[1].content)
-        self.assertEqual(messages[-1].content, content)
-    async def test_completion_message_encoding(self):
+    assert len(messages) == 3
+    assert "Environment: ipython" in messages[0].content
+    assert "Tools: brave_search" in messages[0].content
+    assert "Return function calls in JSON format" in messages[1].content
+    assert messages[-1].content == content
+async def test_completion_message_encoding():
     request = ChatCompletionRequest(
         model=MODEL3_2,
         messages=[
@@ -160,17 +158,15 @@ class PrepareMessagesTests(unittest.IsolatedAsyncioTestCase):
         tool_config=ToolConfig(tool_prompt_format=ToolPromptFormat.python_list),
     )
     prompt = await chat_completion_request_to_prompt(request, request.model)
-        self.assertIn('[custom1(param1="value1")]', prompt)
+    assert '[custom1(param1="value1")]' in prompt
     request.model = MODEL
     request.tool_config.tool_prompt_format = ToolPromptFormat.json
     prompt = await chat_completion_request_to_prompt(request, request.model)
-        self.assertIn(
-            '{"type": "function", "name": "custom1", "parameters": {"param1": "value1"}}',
-            prompt,
-        )
-    async def test_user_provided_system_message(self):
+    assert '{"type": "function", "name": "custom1", "parameters": {"param1": "value1"}}' in prompt
+async def test_user_provided_system_message():
     content = "Hello !"
     system_prompt = "You are a pirate"
     request = ChatCompletionRequest(
@@ -184,12 +180,13 @@ class PrepareMessagesTests(unittest.IsolatedAsyncioTestCase):
         ],
     )
     messages = chat_completion_request_to_messages(request, MODEL)
-        self.assertEqual(len(messages), 2, messages)
-        self.assertTrue(messages[0].content.endswith(system_prompt))
-        self.assertEqual(messages[-1].content, content)
-    async def test_repalce_system_message_behavior_builtin_tools(self):
+    assert len(messages) == 2
+    assert messages[0].content.endswith(system_prompt)
+    assert messages[-1].content == content
+async def test_repalce_system_message_behavior_builtin_tools():
     content = "Hello !"
     system_prompt = "You are a pirate"
     request = ChatCompletionRequest(
@@ -208,12 +205,13 @@ class PrepareMessagesTests(unittest.IsolatedAsyncioTestCase):
         ),
     )
     messages = chat_completion_request_to_messages(request, MODEL3_2)
-        self.assertEqual(len(messages), 2, messages)
-        self.assertTrue(messages[0].content.endswith(system_prompt))
-        self.assertIn("Environment: ipython", messages[0].content)
-        self.assertEqual(messages[-1].content, content)
-    async def test_repalce_system_message_behavior_custom_tools(self):
+    assert len(messages) == 2
+    assert messages[0].content.endswith(system_prompt)
+    assert "Environment: ipython" in messages[0].content
+    assert messages[-1].content == content
+async def test_repalce_system_message_behavior_custom_tools():
     content = "Hello !"
     system_prompt = "You are a pirate"
     request = ChatCompletionRequest(
@@ -244,12 +242,13 @@ class PrepareMessagesTests(unittest.IsolatedAsyncioTestCase):
     )
     messages = chat_completion_request_to_messages(request, MODEL3_2)
-        self.assertEqual(len(messages), 2, messages)
-        self.assertTrue(messages[0].content.endswith(system_prompt))
-        self.assertIn("Environment: ipython", messages[0].content)
-        self.assertEqual(messages[-1].content, content)
-    async def test_replace_system_message_behavior_custom_tools_with_template(self):
+    assert len(messages) == 2
+    assert messages[0].content.endswith(system_prompt)
+    assert "Environment: ipython" in messages[0].content
+    assert messages[-1].content == content
+async def test_replace_system_message_behavior_custom_tools_with_template():
     content = "Hello !"
     system_prompt = "You are a pirate {{ function_description }}"
     request = ChatCompletionRequest(
@@ -280,9 +279,9 @@ class PrepareMessagesTests(unittest.IsolatedAsyncioTestCase):
     )
     messages = chat_completion_request_to_messages(request, MODEL3_2)
-        self.assertEqual(len(messages), 2, messages)
-        self.assertIn("Environment: ipython", messages[0].content)
-        self.assertIn("You are a pirate", messages[0].content)
+    assert len(messages) == 2
+    assert "Environment: ipython" in messages[0].content
+    assert "You are a pirate" in messages[0].content
     # function description is present in the system prompt
-        self.assertIn('"name": "custom1"', messages[0].content)
-        self.assertEqual(messages[-1].content, content)
+    assert '"name": "custom1"' in messages[0].content
+    assert messages[-1].content == content

View file

@@ -12,7 +12,6 @@
 # the top-level of this source tree.
 import textwrap
-import unittest
 from datetime import datetime
 from llama_stack.models.llama.llama3.prompt_templates import (
@@ -24,24 +23,24 @@ from llama_stack.models.llama.llama3.prompt_templates import (
 )
-class PromptTemplateTests(unittest.TestCase):
-    def check_generator_output(self, generator):
+def check_generator_output(generator):
     for example in generator.data_examples():
         pt = generator.gen(example)
         text = pt.render()
-            # print(text)  # debugging
         if not example:
             continue
         for tool in example:
             assert tool.tool_name in text
-    def test_system_default(self):
+def test_system_default():
     generator = SystemDefaultGenerator()
     today = datetime.now().strftime("%d %B %Y")
     expected_text = f"Cutting Knowledge Date: December 2023\nToday Date: {today}"
     assert expected_text.strip("\n") == generator.gen(generator.data_examples()[0]).render()
-    def test_system_builtin_only(self):
+def test_system_builtin_only():
     generator = BuiltinToolGenerator()
     expected_text = textwrap.dedent(
         """
@@ -51,21 +50,23 @@ class PromptTemplateTests(unittest.TestCase):
     )
     assert expected_text.strip("\n") == generator.gen(generator.data_examples()[0]).render()
-    def test_system_custom_only(self):
-        self.maxDiff = None
+def test_system_custom_only():
     generator = JsonCustomToolGenerator()
-        self.check_generator_output(generator)
+    check_generator_output(generator)
-    def test_system_custom_function_tag(self):
-        self.maxDiff = None
+def test_system_custom_function_tag():
     generator = FunctionTagCustomToolGenerator()
-        self.check_generator_output(generator)
+    check_generator_output(generator)
-    def test_llama_3_2_system_zero_shot(self):
+def test_llama_3_2_system_zero_shot():
     generator = PythonListCustomToolGenerator()
-        self.check_generator_output(generator)
+    check_generator_output(generator)
-    def test_llama_3_2_provided_system_prompt(self):
+def test_llama_3_2_provided_system_prompt():
     generator = PythonListCustomToolGenerator()
     user_system_prompt = textwrap.dedent(
         """

View file

@@ -5,7 +5,6 @@
 # the root directory of this source tree.
 import os
-import unittest
 from unittest.mock import patch
 import pytest
@@ -15,28 +14,24 @@ from llama_stack.providers.remote.datasetio.nvidia.config import NvidiaDatasetIO
 from llama_stack.providers.remote.datasetio.nvidia.datasetio import NvidiaDatasetIOAdapter
-class TestNvidiaDatastore(unittest.TestCase):
-    def setUp(self):
-        """Set up the NVIDIA dataset adapter with mocked dependencies"""
+@pytest.fixture
+def nvidia_dataset_adapter():
     os.environ["NVIDIA_DATASETS_URL"] = "http://nemo.test/datasets"
     config = NvidiaDatasetIOConfig(
         datasets_url=os.environ["NVIDIA_DATASETS_URL"], dataset_namespace="default", project_id="default"
     )
-        self.adapter = NvidiaDatasetIOAdapter(config)
-        self.make_request_patcher = patch(
+    adapter = NvidiaDatasetIOAdapter(config)
+    with patch(
         "llama_stack.providers.remote.datasetio.nvidia.datasetio.NvidiaDatasetIOAdapter._make_request"
-        )
-        self.mock_make_request = self.make_request_patcher.start()
-    def tearDown(self):
-        self.make_request_patcher.stop()
+    ) as mock_make_request:
+        yield adapter, mock_make_request
-    @pytest.fixture(autouse=True)
-    def inject_fixtures(self, run_async):
-        self.run_async = run_async
-    def _assert_request(self, mock_call, expected_method, expected_path, expected_json=None):
-        """Helper method to verify request details in mock calls."""
+def assert_request(mock_call, expected_method, expected_path, expected_json=None):
+    """Helper function to verify request details in mock calls."""
     call_args = mock_call.call_args
     assert call_args[0][0] == expected_method
@@ -46,8 +41,11 @@ class TestNvidiaDatastore(unittest.TestCase):
         for key, value in expected_json.items():
             assert call_args[1]["json"][key] == value
-    def test_register_dataset(self):
-        self.mock_make_request.return_value = {
+async def test_register_dataset(nvidia_dataset_adapter):
+    adapter, mock_make_request = nvidia_dataset_adapter
+    mock_make_request.return_value = {
         "id": "dataset-123456",
         "name": "test-dataset",
         "namespace": "default",
@@ -63,11 +61,11 @@ class TestNvidiaDatastore(unittest.TestCase):
         metadata={"provider_id": "nvidia", "format": "jsonl", "description": "Test dataset description"},
     )
-        self.run_async(self.adapter.register_dataset(dataset_def))
-        self.mock_make_request.assert_called_once()
-        self._assert_request(
-            self.mock_make_request,
+    await adapter.register_dataset(dataset_def)
+    mock_make_request.assert_called_once()
+    assert_request(
+        mock_make_request,
         "POST",
         "/v1/datasets",
         expected_json={
@@ -80,20 +78,26 @@ class TestNvidiaDatastore(unittest.TestCase):
         },
     )
-    def test_unregister_dataset(self):
-        self.mock_make_request.return_value = {
+async def test_unregister_dataset(nvidia_dataset_adapter):
+    adapter, mock_make_request = nvidia_dataset_adapter
+    mock_make_request.return_value = {
         "message": "Resource deleted successfully.",
         "id": "dataset-81RSQp7FKX3rdBtKvF9Skn",
         "deleted_at": None,
     }
     dataset_id = "test-dataset"
-        self.run_async(self.adapter.unregister_dataset(dataset_id))
-        self.mock_make_request.assert_called_once()
-        self._assert_request(self.mock_make_request, "DELETE", "/v1/datasets/default/test-dataset")
-    def test_register_dataset_with_custom_namespace_project(self):
+    await adapter.unregister_dataset(dataset_id)
+    mock_make_request.assert_called_once()
+    assert_request(mock_make_request, "DELETE", "/v1/datasets/default/test-dataset")
+async def test_register_dataset_with_custom_namespace_project(nvidia_dataset_adapter):
+    adapter, mock_make_request = nvidia_dataset_adapter
     custom_config = NvidiaDatasetIOConfig(
         datasets_url=os.environ["NVIDIA_DATASETS_URL"],
         dataset_namespace="custom-namespace",
@@ -101,7 +105,7 @@ class TestNvidiaDatastore(unittest.TestCase):
     )
     custom_adapter = NvidiaDatasetIOAdapter(custom_config)
-        self.mock_make_request.return_value = {
+    mock_make_request.return_value = {
         "id": "dataset-123456",
         "name": "test-dataset",
         "namespace": "custom-namespace",
@@ -117,11 +121,11 @@ class TestNvidiaDatastore(unittest.TestCase):
         metadata={"format": "jsonl"},
     )
-        self.run_async(custom_adapter.register_dataset(dataset_def))
-        self.mock_make_request.assert_called_once()
-        self._assert_request(
-            self.mock_make_request,
+    await custom_adapter.register_dataset(dataset_def)
+    mock_make_request.assert_called_once()
+    assert_request(
+        mock_make_request,
         "POST",
         "/v1/datasets",
         expected_json={
@@ -132,7 +136,3 @@ class TestNvidiaDatastore(unittest.TestCase):
             "format": "jsonl",
         },
     )
-if __name__ == "__main__":
-    unittest.main()
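
The setUp/tearDown pair above collapses into a single yield fixture: everything before the yield is setup, the with-block keeps the patch active while the test runs, and unwinding the block is the teardown. A minimal sketch of the same shape, using a hypothetical Adapter class rather than the real NvidiaDatasetIOAdapter:

from unittest.mock import patch

import pytest


class Adapter:  # hypothetical stand-in for the illustration
    def ping(self):
        return self._make_request("GET", "/ping")

    def _make_request(self, method, path):
        raise RuntimeError("network disabled in tests")


@pytest.fixture
def adapter_with_mock():
    adapter = Adapter()
    # Code before the yield plays the role of setUp; leaving the
    # with-block after the test finishes plays the role of tearDown.
    with patch.object(Adapter, "_make_request") as mock_make_request:
        yield adapter, mock_make_request


def test_ping(adapter_with_mock):
    adapter, mock_make_request = adapter_with_mock
    mock_make_request.return_value = {"ok": True}
    assert adapter.ping() == {"ok": True}
    mock_make_request.assert_called_once_with("GET", "/ping")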

View file

@@ -5,7 +5,6 @@
 # the root directory of this source tree.
 import os
-import unittest
 from unittest.mock import MagicMock, patch
 import pytest
@@ -21,49 +20,42 @@ MOCK_DATASET_ID = "default/test-dataset"
 MOCK_BENCHMARK_ID = "test-benchmark"
-class TestNVIDIAEvalImpl(unittest.TestCase):
-    def setUp(self):
-        """Set up the NVIDIA eval implementation with mocked dependencies"""
+@pytest.fixture
+def nvidia_eval_impl():
     os.environ["NVIDIA_EVALUATOR_URL"] = "http://nemo.test"
     # Create mock APIs
-        self.datasetio_api = MagicMock()
-        self.datasets_api = MagicMock()
-        self.scoring_api = MagicMock()
-        self.inference_api = MagicMock()
-        self.agents_api = MagicMock()
-        self.config = NVIDIAEvalConfig(
+    datasetio_api = MagicMock()
+    datasets_api = MagicMock()
+    scoring_api = MagicMock()
+    inference_api = MagicMock()
+    agents_api = MagicMock()
+    config = NVIDIAEvalConfig(
         evaluator_url=os.environ["NVIDIA_EVALUATOR_URL"],
     )
-        self.eval_impl = NVIDIAEvalImpl(
-            config=self.config,
-            datasetio_api=self.datasetio_api,
-            datasets_api=self.datasets_api,
-            scoring_api=self.scoring_api,
-            inference_api=self.inference_api,
-            agents_api=self.agents_api,
+    eval_impl = NVIDIAEvalImpl(
+        config=config,
+        datasetio_api=datasetio_api,
+        datasets_api=datasets_api,
+        scoring_api=scoring_api,
+        inference_api=inference_api,
+        agents_api=agents_api,
     )
     # Mock the HTTP request methods
-        self.evaluator_get_patcher = patch(
-            "llama_stack.providers.remote.eval.nvidia.eval.NVIDIAEvalImpl._evaluator_get"
-        )
-        self.evaluator_post_patcher = patch(
-            "llama_stack.providers.remote.eval.nvidia.eval.NVIDIAEvalImpl._evaluator_post"
-        )
-        self.mock_evaluator_get = self.evaluator_get_patcher.start()
-        self.mock_evaluator_post = self.evaluator_post_patcher.start()
-    def tearDown(self):
-        """Clean up after each test."""
-        self.evaluator_get_patcher.stop()
-        self.evaluator_post_patcher.stop()
-    def _assert_request_body(self, expected_json):
-        """Helper method to verify request body in Evaluator POST request is correct"""
-        call_args = self.mock_evaluator_post.call_args
+    with (
+        patch("llama_stack.providers.remote.eval.nvidia.eval.NVIDIAEvalImpl._evaluator_get") as mock_evaluator_get,
+        patch("llama_stack.providers.remote.eval.nvidia.eval.NVIDIAEvalImpl._evaluator_post") as mock_evaluator_post,
+    ):
+        yield eval_impl, mock_evaluator_get, mock_evaluator_post
+def assert_request_body(mock_evaluator_post, expected_json):
+    """Helper function to verify request body in Evaluator POST request is correct"""
+    call_args = mock_evaluator_post.call_args
     actual_json = call_args[0][1]
     # Check that all expected keys contain the expected values in the actual JSON
@@ -77,11 +69,10 @@ class TestNVIDIAEvalImpl(unittest.TestCase):
         else:
             assert actual_json[key] == value, f"Value mismatch for '{key}'"
-    @pytest.fixture(autouse=True)
-    def inject_fixtures(self, run_async):
-        self.run_async = run_async
-    def test_register_benchmark(self):
+async def test_register_benchmark(nvidia_eval_impl):
+    eval_impl, mock_evaluator_get, mock_evaluator_post = nvidia_eval_impl
     eval_config = {
         "type": "custom",
         "params": {"parallelism": 8},
@@ -106,16 +97,21 @@ class TestNVIDIAEvalImpl(unittest.TestCase):
     # Mock Evaluator API response
     mock_evaluator_response = {"id": MOCK_BENCHMARK_ID, "status": "created"}
-        self.mock_evaluator_post.return_value = mock_evaluator_response
+    mock_evaluator_post.return_value = mock_evaluator_response
     # Register the benchmark
-        self.run_async(self.eval_impl.register_benchmark(benchmark))
+    await eval_impl.register_benchmark(benchmark)
     # Verify the Evaluator API was called correctly
-        self.mock_evaluator_post.assert_called_once()
-        self._assert_request_body({"namespace": benchmark.provider_id, "name": benchmark.identifier, **eval_config})
+    mock_evaluator_post.assert_called_once()
+    assert_request_body(
+        mock_evaluator_post, {"namespace": benchmark.provider_id, "name": benchmark.identifier, **eval_config}
+    )
-    def test_run_eval(self):
+async def test_run_eval(nvidia_eval_impl):
+    eval_impl, mock_evaluator_get, mock_evaluator_post = nvidia_eval_impl
     benchmark_config = BenchmarkConfig(
         eval_candidate=ModelCandidate(
             type="model",
@@ -126,20 +122,19 @@ class TestNVIDIAEvalImpl(unittest.TestCase):
     # Mock Evaluator API response
     mock_evaluator_response = {"id": "job-123", "status": "created"}
-        self.mock_evaluator_post.return_value = mock_evaluator_response
+    mock_evaluator_post.return_value = mock_evaluator_response
     # Run the Evaluation job
-        result = self.run_async(
-            self.eval_impl.run_eval(benchmark_id=MOCK_BENCHMARK_ID, benchmark_config=benchmark_config)
-        )
+    result = await eval_impl.run_eval(benchmark_id=MOCK_BENCHMARK_ID, benchmark_config=benchmark_config)
     # Verify the Evaluator API was called correctly
-        self.mock_evaluator_post.assert_called_once()
-        self._assert_request_body(
+    mock_evaluator_post.assert_called_once()
+    assert_request_body(
+        mock_evaluator_post,
         {
             "config": f"nvidia/{MOCK_BENCHMARK_ID}",
             "target": {"type": "model", "model": "meta/llama-3.1-8b-instruct"},
-            }
+        },
     )
     # Verify the result
@@ -147,13 +142,16 @@ class TestNVIDIAEvalImpl(unittest.TestCase):
     assert result.job_id == "job-123"
     assert result.status == JobStatus.in_progress
-    def test_job_status(self):
+async def test_job_status(nvidia_eval_impl):
+    eval_impl, mock_evaluator_get, mock_evaluator_post = nvidia_eval_impl
     # Mock Evaluator API response
     mock_evaluator_response = {"id": "job-123", "status": "completed"}
-        self.mock_evaluator_get.return_value = mock_evaluator_response
+    mock_evaluator_get.return_value = mock_evaluator_response
     # Get the Evaluation job
-        result = self.run_async(self.eval_impl.job_status(benchmark_id=MOCK_BENCHMARK_ID, job_id="job-123"))
+    result = await eval_impl.job_status(benchmark_id=MOCK_BENCHMARK_ID, job_id="job-123")
     # Verify the result
     assert isinstance(result, Job)
@@ -161,20 +159,26 @@ class TestNVIDIAEvalImpl(unittest.TestCase):
     assert result.status == JobStatus.completed
     # Verify the API was called correctly
-        self.mock_evaluator_get.assert_called_once_with(f"/v1/evaluation/jobs/{result.job_id}")
+    mock_evaluator_get.assert_called_once_with(f"/v1/evaluation/jobs/{result.job_id}")
-    def test_job_cancel(self):
+async def test_job_cancel(nvidia_eval_impl):
+    eval_impl, mock_evaluator_get, mock_evaluator_post = nvidia_eval_impl
     # Mock Evaluator API response
     mock_evaluator_response = {"id": "job-123", "status": "cancelled"}
-        self.mock_evaluator_post.return_value = mock_evaluator_response
+    mock_evaluator_post.return_value = mock_evaluator_response
     # Cancel the Evaluation job
-        self.run_async(self.eval_impl.job_cancel(benchmark_id=MOCK_BENCHMARK_ID, job_id="job-123"))
+    await eval_impl.job_cancel(benchmark_id=MOCK_BENCHMARK_ID, job_id="job-123")
     # Verify the API was called correctly
-        self.mock_evaluator_post.assert_called_once_with("/v1/evaluation/jobs/job-123/cancel", {})
+    mock_evaluator_post.assert_called_once_with("/v1/evaluation/jobs/job-123/cancel", {})
-    def test_job_result(self):
+async def test_job_result(nvidia_eval_impl):
+    eval_impl, mock_evaluator_get, mock_evaluator_post = nvidia_eval_impl
     # Mock Evaluator API responses
     mock_job_status_response = {"id": "job-123", "status": "completed"}
     mock_job_results_response = {
@@ -182,13 +186,13 @@ class TestNVIDIAEvalImpl(unittest.TestCase):
         "status": "completed",
         "results": {MOCK_BENCHMARK_ID: {"score": 0.85, "details": {"accuracy": 0.85, "f1": 0.84}}},
     }
-        self.mock_evaluator_get.side_effect = [
+    mock_evaluator_get.side_effect = [
        mock_job_status_response,  # First call to retrieve job
        mock_job_results_response,  # Second call to retrieve job results
     ]
     # Get the Evaluation job results
-        result = self.run_async(self.eval_impl.job_result(benchmark_id=MOCK_BENCHMARK_ID, job_id="job-123"))
+    result = await eval_impl.job_result(benchmark_id=MOCK_BENCHMARK_ID, job_id="job-123")
     # Verify the result
     assert isinstance(result, EvaluateResponse)
@@ -196,6 +200,6 @@ class TestNVIDIAEvalImpl(unittest.TestCase):
     assert result.scores[MOCK_BENCHMARK_ID].aggregated_results["results"][MOCK_BENCHMARK_ID]["score"] == 0.85
     # Verify the API was called correctly
-        assert self.mock_evaluator_get.call_count == 2
-        self.mock_evaluator_get.assert_any_call("/v1/evaluation/jobs/job-123")
-        self.mock_evaluator_get.assert_any_call("/v1/evaluation/jobs/job-123/results")
+    assert mock_evaluator_get.call_count == 2
+    mock_evaluator_get.assert_any_call("/v1/evaluation/jobs/job-123")
+    mock_evaluator_get.assert_any_call("/v1/evaluation/jobs/job-123/results")
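
When a fixture has to manage several patches at once, as the eval fixture above does, they can be stacked in one parenthesized with statement (Python 3.10+) and every handle yielded as a tuple for the test to unpack. A small sketch of that shape, with a hypothetical Client class standing in for the real implementation:

from unittest.mock import patch

import pytest


class Client:  # hypothetical service client for the example
    def get(self, path):
        raise RuntimeError("network disabled in tests")

    def post(self, path, body):
        raise RuntimeError("network disabled in tests")


@pytest.fixture
def client_with_mocks():
    client = Client()
    # Both patches stay active for the duration of the test and are
    # unwound together when the with-block exits after the yield.
    with (
        patch.object(Client, "get") as mock_get,
        patch.object(Client, "post") as mock_post,
    ):
        yield client, mock_get, mock_post


def test_get_and_post(client_with_mocks):
    client, mock_get, mock_post = client_with_mocks
    mock_get.return_value = {"status": "ok"}
    assert client.get("/health") == {"status": "ok"}
    mock_post.assert_not_called()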

View file

@@ -5,7 +5,6 @@
 # the root directory of this source tree.
 import os
-import unittest
 import warnings
 from unittest.mock import patch
@@ -27,33 +26,33 @@ from llama_stack.providers.remote.post_training.nvidia.post_training import (
 )
-class TestNvidiaParameters(unittest.TestCase):
-    def setUp(self):
-        """Set up the NVIDIA adapter with mock configuration"""
+@pytest.fixture
+def nvidia_adapter():
     os.environ["NVIDIA_BASE_URL"] = "http://nemo.test"
     os.environ["NVIDIA_CUSTOMIZER_URL"] = "http://nemo.test"
     config = NvidiaPostTrainingConfig(
         base_url=os.environ["NVIDIA_BASE_URL"], customizer_url=os.environ["NVIDIA_CUSTOMIZER_URL"], api_key=None
     )
-        self.adapter = NvidiaPostTrainingAdapter(config)
-        self.make_request_patcher = patch(
+    adapter = NvidiaPostTrainingAdapter(config)
+    # Mock the _make_request method
+    with patch(
         "llama_stack.providers.remote.post_training.nvidia.post_training.NvidiaPostTrainingAdapter._make_request"
-        )
-        self.mock_make_request = self.make_request_patcher.start()
-        self.mock_make_request.return_value = {
+    ) as mock_make_request:
+        mock_make_request.return_value = {
            "id": "job-123",
            "status": "created",
            "created_at": "2025-03-04T13:07:47.543605",
            "updated_at": "2025-03-04T13:07:47.543605",
        }
+        yield adapter, mock_make_request
-    def tearDown(self):
-        self.make_request_patcher.stop()
-    def _assert_request_params(self, expected_json):
-        """Helper method to verify parameters in the request JSON."""
-        call_args = self.mock_make_request.call_args
+def assert_request_params(mock_make_request, expected_json):
+    """Helper function to verify parameters in the request JSON."""
+    call_args = mock_make_request.call_args
     actual_json = call_args[1]["json"]
     for key, value in expected_json.items():
@@ -63,12 +62,11 @@ class TestNvidiaParameters(unittest.TestCase):
         else:
             assert actual_json[key] == value
-    @pytest.fixture(autouse=True)
-    def inject_fixtures(self, run_async):
-        self.run_async = run_async
-    def test_customizer_parameters_passed(self):
+async def test_customizer_parameters_passed(nvidia_adapter):
     """Test scenario 1: When an optional parameter is passed and value is correctly set."""
+    adapter, mock_make_request = nvidia_adapter
     algorithm_config = LoraFinetuningConfig(
         type="LoRA",
         apply_lora_to_mlp=True,
@@ -96,8 +94,7 @@ class TestNvidiaParameters(unittest.TestCase):
     with warnings.catch_warnings(record=True) as w:
         warnings.simplefilter("always")
-            self.run_async(
-                self.adapter.supervised_fine_tune(
+        await adapter.supervised_fine_tune(
             job_uuid="test-job",
             model="meta-llama/Llama-3.1-8B-Instruct",
             checkpoint_dir="",
@@ -106,7 +103,6 @@ class TestNvidiaParameters(unittest.TestCase):
             logger_config={},
             hyperparam_search_config={},
         )
-            )
     warning_texts = [str(warning.message) for warning in w]
@@ -118,7 +114,8 @@ class TestNvidiaParameters(unittest.TestCase):
     for field in fields:
         assert any(field in text for text in warning_texts)
-        self._assert_request_params(
+    assert_request_params(
+        mock_make_request,
         {
             "hyperparameters": {
                 "lora": {"alpha": 16},
@@ -126,11 +123,14 @@ class TestNvidiaParameters(unittest.TestCase):
                 "learning_rate": 0.0002,
                 "batch_size": 16,
             }
-        }
+        },
     )
-    def test_required_parameters_passed(self):
+async def test_required_parameters_passed(nvidia_adapter):
     """Test scenario 2: When required parameters are passed."""
+    adapter, mock_make_request = nvidia_adapter
     required_model = "meta/llama-3.2-1b-instruct@v1.0.0+L40"
     required_dataset_id = "required-dataset"
     required_job_uuid = "required-job"
@@ -164,8 +164,7 @@ class TestNvidiaParameters(unittest.TestCase):
     with warnings.catch_warnings(record=True) as w:
         warnings.simplefilter("always")
-            self.run_async(
-                self.adapter.supervised_fine_tune(
+        await adapter.supervised_fine_tune(
             job_uuid=required_job_uuid,  # Required parameter
             model=required_model,  # Required parameter
             checkpoint_dir="",
@@ -174,7 +173,6 @@ class TestNvidiaParameters(unittest.TestCase):
             logger_config={},
             hyperparam_search_config={},
         )
-            )
     warning_texts = [str(warning.message) for warning in w]
@@ -187,14 +185,17 @@ class TestNvidiaParameters(unittest.TestCase):
     for field in fields:
         assert any(field in text for text in warning_texts)
-        self.mock_make_request.assert_called_once()
-        call_args = self.mock_make_request.call_args
+    mock_make_request.assert_called_once()
+    call_args = mock_make_request.call_args
     assert call_args[1]["json"]["config"] == required_model
     assert call_args[1]["json"]["dataset"]["name"] == required_dataset_id
-    def test_unsupported_parameters_warning(self):
+async def test_unsupported_parameters_warning(nvidia_adapter):
     """Test that warnings are raised for unsupported parameters."""
+    adapter, mock_make_request = nvidia_adapter
     data_config = DataConfig(
         dataset_id="test-dataset",
         batch_size=8,
@@ -232,8 +233,7 @@ class TestNvidiaParameters(unittest.TestCase):
     with warnings.catch_warnings(record=True) as w:
         warnings.simplefilter("always")
-            self.run_async(
-                self.adapter.supervised_fine_tune(
+        await adapter.supervised_fine_tune(
             job_uuid="test-job",
             model="meta-llama/Llama-3.1-8B-Instruct",
             checkpoint_dir="test-dir",  # Unsupported parameter
@@ -249,7 +249,6 @@ class TestNvidiaParameters(unittest.TestCase):
             logger_config={"test": "value"},  # Unsupported parameter
             hyperparam_search_config={"test": "value"},  # Unsupported parameter
         )
-            )
     assert len(w) >= 4
     warning_texts = [str(warning.message) for warning in w]
@@ -273,7 +272,3 @@ class TestNvidiaParameters(unittest.TestCase):
     ]
     for field in fields:
         assert any(field in text for text in warning_texts)
-if __name__ == "__main__":
-    unittest.main()

View file

@@ -5,7 +5,6 @@
 # the root directory of this source tree.
 import os
-import unittest
 from typing import Any
 from unittest.mock import AsyncMock, MagicMock, patch
@@ -18,42 +17,35 @@ from llama_stack.providers.remote.safety.nvidia.config import NVIDIASafetyConfig
 from llama_stack.providers.remote.safety.nvidia.nvidia import NVIDIASafetyAdapter
-class TestNVIDIASafetyAdapter(unittest.TestCase):
-    def setUp(self):
-        """Set up the NVIDIA safety adapter with mocked dependencies"""
+@pytest.fixture
+def nvidia_safety_adapter():
     os.environ["NVIDIA_GUARDRAILS_URL"] = "http://nemo.test"
     # Initialize the adapter
-        self.config = NVIDIASafetyConfig(
+    config = NVIDIASafetyConfig(
         guardrails_service_url=os.environ["NVIDIA_GUARDRAILS_URL"],
     )
-        self.adapter = NVIDIASafetyAdapter(config=self.config)
-        self.shield_store = AsyncMock()
-        self.adapter.shield_store = self.shield_store
+    adapter = NVIDIASafetyAdapter(config=config)
+    shield_store = AsyncMock()
+    adapter.shield_store = shield_store
     # Mock the HTTP request methods
-        self.guardrails_post_patcher = patch(
+    with patch(
         "llama_stack.providers.remote.safety.nvidia.nvidia.NeMoGuardrails._guardrails_post"
-        )
-        self.mock_guardrails_post = self.guardrails_post_patcher.start()
-        self.mock_guardrails_post.return_value = {"status": "allowed"}
-    def tearDown(self):
-        """Clean up after each test."""
-        self.guardrails_post_patcher.stop()
+    ) as mock_guardrails_post:
+        mock_guardrails_post.return_value = {"status": "allowed"}
+        yield adapter, shield_store, mock_guardrails_post
-    @pytest.fixture(autouse=True)
-    def inject_fixtures(self, run_async):
-        self.run_async = run_async
-    def _assert_request(
-        self,
+def assert_request(
     mock_call: MagicMock,
     expected_url: str,
     expected_headers: dict[str, str] | None = None,
     expected_json: dict[str, Any] | None = None,
 ) -> None:
     """
-        Helper method to verify request details in mock API calls.
+    Helper function to verify request details in mock API calls.
     Args:
         mock_call: The MagicMock object that was called
@@ -80,7 +72,10 @@ class TestNVIDIASafetyAdapter(unittest.TestCase):
         else:
             assert call_args[1]["json"][key] == value
-    def test_register_shield_with_valid_id(self):
+async def test_register_shield_with_valid_id(nvidia_safety_adapter):
+    adapter, shield_store, mock_guardrails_post = nvidia_safety_adapter
     shield = Shield(
         provider_id="nvidia",
         type="shield",
@@ -89,9 +84,12 @@ class TestNVIDIASafetyAdapter(unittest.TestCase):
     )
     # Register the shield
-        self.run_async(self.adapter.register_shield(shield))
+    await adapter.register_shield(shield)
-    def test_register_shield_without_id(self):
+async def test_register_shield_without_id(nvidia_safety_adapter):
+    adapter, shield_store, mock_guardrails_post = nvidia_safety_adapter
     shield = Shield(
         provider_id="nvidia",
         type="shield",
@@ -100,10 +98,13 @@ class TestNVIDIASafetyAdapter(unittest.TestCase):
     )
     # Register the shield should raise a ValueError
-        with self.assertRaises(ValueError):
-            self.run_async(self.adapter.register_shield(shield))
+    with pytest.raises(ValueError):
+        await adapter.register_shield(shield)
-    def test_run_shield_allowed(self):
+async def test_run_shield_allowed(nvidia_safety_adapter):
+    adapter, shield_store, mock_guardrails_post = nvidia_safety_adapter
     # Set up the shield
     shield_id = "test-shield"
     shield = Shield(
@@ -112,10 +113,10 @@ class TestNVIDIASafetyAdapter(unittest.TestCase):
         identifier=shield_id,
         provider_resource_id="test-model",
     )
-        self.shield_store.get_shield.return_value = shield
+    shield_store.get_shield.return_value = shield
     # Mock Guardrails API response
-        self.mock_guardrails_post.return_value = {"status": "allowed"}
+    mock_guardrails_post.return_value = {"status": "allowed"}
     # Run the shield
     messages = [
@@ -127,13 +128,13 @@ class TestNVIDIASafetyAdapter(unittest.TestCase):
             tool_calls=[],
         ),
     ]
-        result = self.run_async(self.adapter.run_shield(shield_id, messages))
+    result = await adapter.run_shield(shield_id, messages)
     # Verify the shield store was called
-        self.shield_store.get_shield.assert_called_once_with(shield_id)
+    shield_store.get_shield.assert_called_once_with(shield_id)
     # Verify the Guardrails API was called correctly
-        self.mock_guardrails_post.assert_called_once_with(
+    mock_guardrails_post.assert_called_once_with(
         path="/v1/guardrail/checks",
         data={
             "model": shield_id,
@@ -157,7 +158,10 @@ class TestNVIDIASafetyAdapter(unittest.TestCase):
     assert isinstance(result, RunShieldResponse)
     assert result.violation is None
-    def test_run_shield_blocked(self):
+async def test_run_shield_blocked(nvidia_safety_adapter):
+    adapter, shield_store, mock_guardrails_post = nvidia_safety_adapter
     # Set up the shield
     shield_id = "test-shield"
     shield = Shield(
@@ -166,10 +170,10 @@ class TestNVIDIASafetyAdapter(unittest.TestCase):
         identifier=shield_id,
         provider_resource_id="test-model",
     )
-        self.shield_store.get_shield.return_value = shield
+    shield_store.get_shield.return_value = shield
     # Mock Guardrails API response
-        self.mock_guardrails_post.return_value = {"status": "blocked", "rails_status": {"reason": "harmful_content"}}
+    mock_guardrails_post.return_value = {"status": "blocked", "rails_status": {"reason": "harmful_content"}}
     # Run the shield
     messages = [
@@ -181,13 +185,13 @@ class TestNVIDIASafetyAdapter(unittest.TestCase):
             tool_calls=[],
         ),
     ]
-        result = self.run_async(self.adapter.run_shield(shield_id, messages))
+    result = await adapter.run_shield(shield_id, messages)
     # Verify the shield store was called
-        self.shield_store.get_shield.assert_called_once_with(shield_id)
+    shield_store.get_shield.assert_called_once_with(shield_id)
     # Verify the Guardrails API was called correctly
-        self.mock_guardrails_post.assert_called_once_with(
+    mock_guardrails_post.assert_called_once_with(
         path="/v1/guardrail/checks",
         data={
             "model": shield_id,
@@ -214,25 +218,31 @@ class TestNVIDIASafetyAdapter(unittest.TestCase):
     assert result.violation.violation_level == ViolationLevel.ERROR
     assert result.violation.metadata == {"reason": "harmful_content"}
-    def test_run_shield_not_found(self):
+async def test_run_shield_not_found(nvidia_safety_adapter):
+    adapter, shield_store, mock_guardrails_post = nvidia_safety_adapter
     # Set up shield store to return None
     shield_id = "non-existent-shield"
-        self.shield_store.get_shield.return_value = None
+    shield_store.get_shield.return_value = None
     messages = [
         UserMessage(role="user", content="Hello, how are you?"),
     ]
-        with self.assertRaises(ValueError):
-            self.run_async(self.adapter.run_shield(shield_id, messages))
+    with pytest.raises(ValueError):
+        await adapter.run_shield(shield_id, messages)
     # Verify the shield store was called
-        self.shield_store.get_shield.assert_called_once_with(shield_id)
+    shield_store.get_shield.assert_called_once_with(shield_id)
     # Verify the Guardrails API was not called
-        self.mock_guardrails_post.assert_not_called()
+    mock_guardrails_post.assert_not_called()
-    def test_run_shield_http_error(self):
+async def test_run_shield_http_error(nvidia_safety_adapter):
+    adapter, shield_store, mock_guardrails_post = nvidia_safety_adapter
     shield_id = "test-shield"
     shield = Shield(
         provider_id="nvidia",
@@ -240,11 +250,11 @@ class TestNVIDIASafetyAdapter(unittest.TestCase):
         identifier=shield_id,
         provider_resource_id="test-model",
     )
-        self.shield_store.get_shield.return_value = shield
+    shield_store.get_shield.return_value = shield
     # Mock Guardrails API to raise an exception
     error_msg = "API Error: 500 Internal Server Error"
-        self.mock_guardrails_post.side_effect = Exception(error_msg)
+    mock_guardrails_post.side_effect = Exception(error_msg)
     # Running the shield should raise an exception
     messages = [
@@ -256,14 +266,14 @@ class TestNVIDIASafetyAdapter(unittest.TestCase):
             tool_calls=[],
         ),
     ]
-        with self.assertRaises(Exception) as context:
-            self.run_async(self.adapter.run_shield(shield_id, messages))
+    with pytest.raises(Exception) as excinfo:
+        await adapter.run_shield(shield_id, messages)
     # Verify the shield store was called
-        self.shield_store.get_shield.assert_called_once_with(shield_id)
+    shield_store.get_shield.assert_called_once_with(shield_id)
     # Verify the Guardrails API was called correctly
-        self.mock_guardrails_post.assert_called_once_with(
+    mock_guardrails_post.assert_called_once_with(
         path="/v1/guardrail/checks",
         data={
             "model": shield_id,
@@ -283,11 +293,14 @@ class TestNVIDIASafetyAdapter(unittest.TestCase):
         },
     )
     # Verify the exception message
-        assert error_msg in str(context.exception)
+    assert error_msg in str(excinfo.value)
-    def test_init_nemo_guardrails(self):
+def test_init_nemo_guardrails():
     from llama_stack.providers.remote.safety.nvidia.nvidia import NeMoGuardrails
+    os.environ["NVIDIA_GUARDRAILS_URL"] = "http://nemo.test"
     test_config_id = "test-custom-config-id"
     config = NVIDIASafetyConfig(
         guardrails_service_url=os.environ["NVIDIA_GUARDRAILS_URL"],
@@ -314,12 +327,15 @@ class TestNVIDIASafetyAdapter(unittest.TestCase):
     assert guardrails.temperature == 0.7
     assert guardrails.guardrails_service_url == os.environ["NVIDIA_GUARDRAILS_URL"]
-    def test_init_nemo_guardrails_invalid_temperature(self):
+def test_init_nemo_guardrails_invalid_temperature():
     from llama_stack.providers.remote.safety.nvidia.nvidia import NeMoGuardrails
+    os.environ["NVIDIA_GUARDRAILS_URL"] = "http://nemo.test"
     config = NVIDIASafetyConfig(
         guardrails_service_url=os.environ["NVIDIA_GUARDRAILS_URL"],
         config_id="test-custom-config-id",
     )
-        with self.assertRaises(ValueError):
+    with pytest.raises(ValueError):
         NeMoGuardrails(config, "test-model", temperature=0)
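
The safety tests also swap unittest's exception assertions for pytest's context manager: assertRaises(...) as context becomes pytest.raises(...) as excinfo, and the raised exception moves from context.exception to excinfo.value. A minimal sketch with a hypothetical risky() function:

import pytest


def risky(value: int) -> int:  # hypothetical function for the example
    if value < 0:
        raise ValueError("value must be non-negative")
    return value


def test_risky_rejects_negatives():
    # unittest: with self.assertRaises(ValueError) as context: ...
    # pytest exposes the exception as excinfo.value, where unittest's
    # context manager exposes it as context.exception.
    with pytest.raises(ValueError) as excinfo:
        risky(-1)
    assert "non-negative" in str(excinfo.value)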

View file

@ -5,9 +5,8 @@
# the root directory of this source tree. # the root directory of this source tree.
import os import os
import unittest
import warnings import warnings
from unittest.mock import patch from unittest.mock import AsyncMock, MagicMock, patch
import pytest import pytest
@ -32,43 +31,38 @@ from llama_stack.providers.remote.post_training.nvidia.post_training import (
) )
class TestNvidiaPostTraining(unittest.TestCase): @pytest.fixture
def setUp(self): def nvidia_adapters():
"""Set up the NVIDIA adapters with mocked dependencies"""
os.environ["NVIDIA_BASE_URL"] = "http://nemo.test" # needed for llm inference os.environ["NVIDIA_BASE_URL"] = "http://nemo.test" # needed for llm inference
os.environ["NVIDIA_CUSTOMIZER_URL"] = "http://nemo.test" # needed for nemo customizer os.environ["NVIDIA_CUSTOMIZER_URL"] = "http://nemo.test" # needed for nemo customizer
config = NvidiaPostTrainingConfig( config = NvidiaPostTrainingConfig(
base_url=os.environ["NVIDIA_BASE_URL"], customizer_url=os.environ["NVIDIA_CUSTOMIZER_URL"], api_key=None base_url=os.environ["NVIDIA_BASE_URL"], customizer_url=os.environ["NVIDIA_CUSTOMIZER_URL"], api_key=None
) )
self.adapter = NvidiaPostTrainingAdapter(config) adapter = NvidiaPostTrainingAdapter(config)
self.make_request_patcher = patch(
"llama_stack.providers.remote.post_training.nvidia.post_training.NvidiaPostTrainingAdapter._make_request"
)
self.mock_make_request = self.make_request_patcher.start()
# Mock the inference client # Mock the inference client
inference_config = NVIDIAConfig(base_url=os.environ["NVIDIA_BASE_URL"], api_key=None) inference_config = NVIDIAConfig(base_url=os.environ["NVIDIA_BASE_URL"], api_key=None)
self.inference_adapter = NVIDIAInferenceAdapter(inference_config) inference_adapter = NVIDIAInferenceAdapter(inference_config)
self.mock_client = unittest.mock.MagicMock() mock_client = MagicMock()
self.mock_client.chat.completions.create = unittest.mock.AsyncMock() mock_client.chat.completions.create = AsyncMock()
self.inference_mock_make_request = self.mock_client.chat.completions.create
self.inference_make_request_patcher = patch( with (
patch(
"llama_stack.providers.remote.post_training.nvidia.post_training.NvidiaPostTrainingAdapter._make_request"
) as mock_make_request,
patch(
"llama_stack.providers.remote.inference.nvidia.nvidia.NVIDIAInferenceAdapter._get_client", "llama_stack.providers.remote.inference.nvidia.nvidia.NVIDIAInferenceAdapter._get_client",
return_value=self.mock_client, return_value=mock_client,
) ),
self.inference_make_request_patcher.start() ):
yield adapter, inference_adapter, mock_make_request, mock_client
def tearDown(self):
self.make_request_patcher.stop()
self.inference_make_request_patcher.stop()
@pytest.fixture(autouse=True) def assert_request(mock_call, expected_method, expected_path, expected_params=None, expected_json=None):
def inject_fixtures(self, run_async): """Helper function to verify request details in mock calls."""
self.run_async = run_async
def _assert_request(self, mock_call, expected_method, expected_path, expected_params=None, expected_json=None):
"""Helper method to verify request details in mock calls."""
call_args = mock_call.call_args call_args = mock_call.call_args
if expected_method and expected_path: if expected_method and expected_path:
@ -85,9 +79,12 @@ class TestNvidiaPostTraining(unittest.TestCase):
for key, value in expected_json.items(): for key, value in expected_json.items():
assert call_args[1]["json"][key] == value assert call_args[1]["json"][key] == value
def test_supervised_fine_tune(self):
async def test_supervised_fine_tune(nvidia_adapters):
"""Test the supervised fine-tuning API call.""" """Test the supervised fine-tuning API call."""
self.mock_make_request.return_value = { adapter, inference_adapter, mock_make_request, mock_client = nvidia_adapters
mock_make_request.return_value = {
"id": "cust-JGTaMbJMdqjJU8WbQdN9Q2", "id": "cust-JGTaMbJMdqjJU8WbQdN9Q2",
"created_at": "2024-12-09T04:06:28.542884", "created_at": "2024-12-09T04:06:28.542884",
"updated_at": "2024-12-09T04:06:28.542884", "updated_at": "2024-12-09T04:06:28.542884",
@ -162,8 +159,7 @@ class TestNvidiaPostTraining(unittest.TestCase):
with warnings.catch_warnings(record=True): with warnings.catch_warnings(record=True):
warnings.simplefilter("always") warnings.simplefilter("always")
training_job = self.run_async( training_job = await adapter.supervised_fine_tune(
self.adapter.supervised_fine_tune(
job_uuid="1234", job_uuid="1234",
model="meta/llama-3.2-1b-instruct@v1.0.0+L40", model="meta/llama-3.2-1b-instruct@v1.0.0+L40",
checkpoint_dir="", checkpoint_dir="",
@ -172,15 +168,14 @@ class TestNvidiaPostTraining(unittest.TestCase):
logger_config={}, logger_config={},
hyperparam_search_config={}, hyperparam_search_config={},
) )
)
# check the output is a PostTrainingJob # check the output is a PostTrainingJob
assert isinstance(training_job, NvidiaPostTrainingJob) assert isinstance(training_job, NvidiaPostTrainingJob)
assert training_job.job_uuid == "cust-JGTaMbJMdqjJU8WbQdN9Q2" assert training_job.job_uuid == "cust-JGTaMbJMdqjJU8WbQdN9Q2"
self.mock_make_request.assert_called_once() mock_make_request.assert_called_once()
self._assert_request( assert_request(
self.mock_make_request, mock_make_request,
"POST", "POST",
"/v1/customization/jobs", "/v1/customization/jobs",
expected_json={ expected_json={
@ -198,7 +193,10 @@ class TestNvidiaPostTraining(unittest.TestCase):
}, },
) )
def test_supervised_fine_tune_with_qat(self):
async def test_supervised_fine_tune_with_qat(nvidia_adapters):
adapter, inference_adapter, mock_make_request, mock_client = nvidia_adapters
algorithm_config = QATFinetuningConfig(type="QAT", quantizer_name="quantizer_name", group_size=1) algorithm_config = QATFinetuningConfig(type="QAT", quantizer_name="quantizer_name", group_size=1)
data_config = DataConfig( data_config = DataConfig(
dataset_id="sample-basic-test", batch_size=16, shuffle=False, data_format=DatasetFormat.instruct dataset_id="sample-basic-test", batch_size=16, shuffle=False, data_format=DatasetFormat.instruct
@ -215,9 +213,8 @@ class TestNvidiaPostTraining(unittest.TestCase):
optimizer_config=optimizer_config, optimizer_config=optimizer_config,
) )
# This will raise NotImplementedError since QAT is not supported # This will raise NotImplementedError since QAT is not supported
with self.assertRaises(NotImplementedError): with pytest.raises(NotImplementedError):
self.run_async( await adapter.supervised_fine_tune(
self.adapter.supervised_fine_tune(
job_uuid="1234", job_uuid="1234",
model="meta/llama-3.2-1b-instruct@v1.0.0+L40", model="meta/llama-3.2-1b-instruct@v1.0.0+L40",
checkpoint_dir="", checkpoint_dir="",
@ -226,21 +223,23 @@ class TestNvidiaPostTraining(unittest.TestCase):
logger_config={}, logger_config={},
hyperparam_search_config={}, hyperparam_search_config={},
) )
)
def test_get_training_job_status(self):
customizer_status_to_job_status = [ @pytest.mark.parametrize(
"customizer_status,expected_status",
[
("running", "in_progress"), ("running", "in_progress"),
("completed", "completed"), ("completed", "completed"),
("failed", "failed"), ("failed", "failed"),
("cancelled", "cancelled"), ("cancelled", "cancelled"),
("pending", "scheduled"), ("pending", "scheduled"),
("unknown", "scheduled"), ("unknown", "scheduled"),
] ],
)
async def test_get_training_job_status(nvidia_adapters, customizer_status, expected_status):
adapter, inference_adapter, mock_make_request, mock_client = nvidia_adapters
for customizer_status, expected_status in customizer_status_to_job_status: mock_make_request.return_value = {
with self.subTest(customizer_status=customizer_status, expected_status=expected_status):
self.mock_make_request.return_value = {
"created_at": "2024-12-09T04:06:28.580220", "created_at": "2024-12-09T04:06:28.580220",
"updated_at": "2024-12-09T04:21:19.852832", "updated_at": "2024-12-09T04:21:19.852832",
"status": customizer_status, "status": customizer_status,
@ -254,7 +253,7 @@ class TestNvidiaPostTraining(unittest.TestCase):
job_id = "cust-JGTaMbJMdqjJU8WbQdN9Q2" job_id = "cust-JGTaMbJMdqjJU8WbQdN9Q2"
status = self.run_async(self.adapter.get_training_job_status(job_uuid=job_id)) status = await adapter.get_training_job_status(job_uuid=job_id)
assert isinstance(status, NvidiaPostTrainingJobStatusResponse) assert isinstance(status, NvidiaPostTrainingJobStatusResponse)
assert status.status.value == expected_status assert status.status.value == expected_status
@ -265,16 +264,19 @@ class TestNvidiaPostTraining(unittest.TestCase):
assert status.train_loss == 1.718016266822815 assert status.train_loss == 1.718016266822815
assert status.val_loss == 1.8661999702453613 assert status.val_loss == 1.8661999702453613
self._assert_request( assert_request(
self.mock_make_request, mock_make_request,
"GET", "GET",
f"/v1/customization/jobs/{job_id}/status", f"/v1/customization/jobs/{job_id}/status",
expected_params={"job_id": job_id}, expected_params={"job_id": job_id},
) )
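
The old for-loop over customizer_status_to_job_status with self.subTest(...) becomes a @pytest.mark.parametrize decorator, so each status pair is collected and reported as its own test item rather than one test with sub-failures. The same pattern in isolation:

import pytest

STATUS_MAP = {
    "running": "in_progress",
    "completed": "completed",
    "pending": "scheduled",
}

@pytest.mark.parametrize(
    "customizer_status,expected_status",
    [("running", "in_progress"), ("completed", "completed"), ("pending", "scheduled")],
)
def test_status_mapping(customizer_status, expected_status):
    # One collected test per tuple, reported e.g. as test_status_mapping[running-in_progress]
    assert STATUS_MAP[customizer_status] == expected_status
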
def test_get_training_jobs(self):
async def test_get_training_jobs(nvidia_adapters):
adapter, inference_adapter, mock_make_request, mock_client = nvidia_adapters
job_id = "cust-JGTaMbJMdqjJU8WbQdN9Q2" job_id = "cust-JGTaMbJMdqjJU8WbQdN9Q2"
self.mock_make_request.return_value = { mock_make_request.return_value = {
"data": [ "data": [
{ {
"id": job_id, "id": job_id,
@ -300,7 +302,7 @@ class TestNvidiaPostTraining(unittest.TestCase):
] ]
} }
jobs = self.run_async(self.adapter.get_training_jobs()) jobs = await adapter.get_training_jobs()
assert isinstance(jobs, ListNvidiaPostTrainingJobs) assert isinstance(jobs, ListNvidiaPostTrainingJobs)
assert len(jobs.data) == 1 assert len(jobs.data) == 1
@ -308,31 +310,37 @@ class TestNvidiaPostTraining(unittest.TestCase):
assert job.job_uuid == job_id assert job.job_uuid == job_id
assert job.status.value == "completed" assert job.status.value == "completed"
self.mock_make_request.assert_called_once() mock_make_request.assert_called_once()
self._assert_request( assert_request(
self.mock_make_request, mock_make_request,
"GET", "GET",
"/v1/customization/jobs", "/v1/customization/jobs",
expected_params={"page": 1, "page_size": 10, "sort": "created_at"}, expected_params={"page": 1, "page_size": 10, "sort": "created_at"},
) )
def test_cancel_training_job(self):
self.mock_make_request.return_value = {} # Empty response for successful cancellation async def test_cancel_training_job(nvidia_adapters):
adapter, inference_adapter, mock_make_request, mock_client = nvidia_adapters
mock_make_request.return_value = {} # Empty response for successful cancellation
job_id = "cust-JGTaMbJMdqjJU8WbQdN9Q2" job_id = "cust-JGTaMbJMdqjJU8WbQdN9Q2"
result = self.run_async(self.adapter.cancel_training_job(job_uuid=job_id)) result = await adapter.cancel_training_job(job_uuid=job_id)
assert result is None assert result is None
self.mock_make_request.assert_called_once() mock_make_request.assert_called_once()
self._assert_request( assert_request(
self.mock_make_request, mock_make_request,
"POST", "POST",
f"/v1/customization/jobs/{job_id}/cancel", f"/v1/customization/jobs/{job_id}/cancel",
expected_params={"job_id": job_id}, expected_params={"job_id": job_id},
) )
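
Unpacking a four-element tuple at the top of every test works, though it is positional and easy to get wrong; a named container is one alternative the fixture could yield instead (a design sketch, not what this commit does):

from dataclasses import dataclass
from unittest.mock import MagicMock

@dataclass
class AdapterContext:
    adapter: object
    inference_adapter: object
    mock_make_request: MagicMock
    mock_client: MagicMock

# A fixture yielding AdapterContext(...) lets tests write
# ctx.mock_make_request.assert_called_once() without caring about tuple order.
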
def test_inference_register_model(self):
async def test_inference_register_model(nvidia_adapters):
adapter, inference_adapter, mock_make_request, mock_client = nvidia_adapters
model_id = "default/job-1234" model_id = "default/job-1234"
model_type = ModelType.llm model_type = ModelType.llm
model = Model( model = Model(
@ -342,21 +350,15 @@ class TestNvidiaPostTraining(unittest.TestCase):
provider_resource_id=model_id, provider_resource_id=model_id,
model_type=model_type, model_type=model_type,
) )
result = self.run_async(self.inference_adapter.register_model(model)) result = await inference_adapter.register_model(model)
assert result == model assert result == model
assert len(self.inference_adapter.alias_to_provider_id_map) > 1 assert len(inference_adapter.alias_to_provider_id_map) > 1
assert self.inference_adapter.get_provider_model_id(model.provider_model_id) == model_id assert inference_adapter.get_provider_model_id(model.provider_model_id) == model_id
with patch.object(self.inference_adapter, "chat_completion") as mock_chat_completion: with patch.object(inference_adapter, "chat_completion") as mock_chat_completion:
self.run_async( await inference_adapter.chat_completion(
self.inference_adapter.chat_completion(
model_id=model_id, model_id=model_id,
messages=[{"role": "user", "content": "Hello, model"}], messages=[{"role": "user", "content": "Hello, model"}],
) )
)
mock_chat_completion.assert_called() mock_chat_completion.assert_called()
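
One migration detail worth calling out: plain async def tests are not awaited by stock pytest, so this file relies on an async plugin such as pytest-asyncio, presumably configured project-wide since no per-test markers appear in the diff. With pytest-asyncio's default configuration the marker would be explicit:

import pytest

@pytest.mark.asyncio  # unnecessary when asyncio_mode = "auto" is set in the project config
async def test_cancel_returns_none():
    async def cancel_training_job(job_uuid):  # stand-in for the adapter method
        return None

    assert await cancel_training_job("cust-1234") is None
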
if __name__ == "__main__":
unittest.main()


@ -5,15 +5,19 @@
# the root directory of this source tree. # the root directory of this source tree.
import os import os
import unittest
import pytest
from llama_stack.distribution.stack import replace_env_vars from llama_stack.distribution.stack import replace_env_vars
class TestReplaceEnvVars(unittest.TestCase): @pytest.fixture(autouse=True)
def setUp(self): def setup_env_vars():
# Clear any existing environment variables we'll use in tests """Set up the test environment variables and clean up after each test."""
# Store original values
original_vars = {}
for var in ["TEST_VAR", "EMPTY_VAR", "ZERO_VAR"]: for var in ["TEST_VAR", "EMPTY_VAR", "ZERO_VAR"]:
original_vars[var] = os.environ.get(var)
if var in os.environ: if var in os.environ:
del os.environ[var] del os.environ[var]
@ -22,56 +26,70 @@ class TestReplaceEnvVars(unittest.TestCase):
os.environ["EMPTY_VAR"] = "" os.environ["EMPTY_VAR"] = ""
os.environ["ZERO_VAR"] = "0" os.environ["ZERO_VAR"] = "0"
def test_simple_replacement(self): yield
self.assertEqual(replace_env_vars("${env.TEST_VAR}"), "test_value")
def test_default_value_when_not_set(self): # Cleanup: restore original values
self.assertEqual(replace_env_vars("${env.NOT_SET:=default}"), "default") for var, original_value in original_vars.items():
if original_value is not None:
os.environ[var] = original_value
elif var in os.environ:
del os.environ[var]
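
The save/restore bookkeeping above is exactly what pytest's built-in monkeypatch fixture automates; an equivalent, shorter variant of the same fixture (an alternative sketch, not what the commit does):

import pytest

@pytest.fixture(autouse=True)
def setup_env_vars(monkeypatch):
    # monkeypatch undoes setenv/delenv automatically after each test
    monkeypatch.setenv("TEST_VAR", "test_value")
    monkeypatch.setenv("EMPTY_VAR", "")
    monkeypatch.setenv("ZERO_VAR", "0")
    monkeypatch.delenv("NOT_SET", raising=False)
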
def test_default_value_when_set(self):
self.assertEqual(replace_env_vars("${env.TEST_VAR:=default}"), "test_value")
def test_default_value_when_empty(self): def test_simple_replacement():
self.assertEqual(replace_env_vars("${env.EMPTY_VAR:=default}"), "default") assert replace_env_vars("${env.TEST_VAR}") == "test_value"
def test_none_value_when_empty(self):
self.assertEqual(replace_env_vars("${env.EMPTY_VAR:=}"), None)
def test_value_when_set(self): def test_default_value_when_not_set():
self.assertEqual(replace_env_vars("${env.TEST_VAR:=}"), "test_value") assert replace_env_vars("${env.NOT_SET:=default}") == "default"
def test_empty_var_no_default(self):
self.assertEqual(replace_env_vars("${env.EMPTY_VAR_NO_DEFAULT:+}"), None)
def test_conditional_value_when_set(self): def test_default_value_when_set():
self.assertEqual(replace_env_vars("${env.TEST_VAR:+conditional}"), "conditional") assert replace_env_vars("${env.TEST_VAR:=default}") == "test_value"
def test_conditional_value_when_not_set(self):
self.assertEqual(replace_env_vars("${env.NOT_SET:+conditional}"), None)
def test_conditional_value_when_empty(self): def test_default_value_when_empty():
self.assertEqual(replace_env_vars("${env.EMPTY_VAR:+conditional}"), None) assert replace_env_vars("${env.EMPTY_VAR:=default}") == "default"
def test_conditional_value_with_zero(self):
self.assertEqual(replace_env_vars("${env.ZERO_VAR:+conditional}"), "conditional")
def test_mixed_syntax(self): def test_none_value_when_empty():
self.assertEqual( assert replace_env_vars("${env.EMPTY_VAR:=}") is None
replace_env_vars("${env.TEST_VAR:=default} and ${env.NOT_SET:+conditional}"), "test_value and "
)
self.assertEqual(
replace_env_vars("${env.NOT_SET:=default} and ${env.TEST_VAR:+conditional}"), "default and conditional"
)
def test_nested_structures(self):
def test_value_when_set():
assert replace_env_vars("${env.TEST_VAR:=}") == "test_value"
def test_empty_var_no_default():
assert replace_env_vars("${env.EMPTY_VAR_NO_DEFAULT:+}") is None
def test_conditional_value_when_set():
assert replace_env_vars("${env.TEST_VAR:+conditional}") == "conditional"
def test_conditional_value_when_not_set():
assert replace_env_vars("${env.NOT_SET:+conditional}") is None
def test_conditional_value_when_empty():
assert replace_env_vars("${env.EMPTY_VAR:+conditional}") is None
def test_conditional_value_with_zero():
assert replace_env_vars("${env.ZERO_VAR:+conditional}") == "conditional"
def test_mixed_syntax():
assert replace_env_vars("${env.TEST_VAR:=default} and ${env.NOT_SET:+conditional}") == "test_value and "
assert replace_env_vars("${env.NOT_SET:=default} and ${env.TEST_VAR:+conditional}") == "default and conditional"
def test_nested_structures():
data = { data = {
"key1": "${env.TEST_VAR:=default}", "key1": "${env.TEST_VAR:=default}",
"key2": ["${env.NOT_SET:=default}", "${env.TEST_VAR:+conditional}"], "key2": ["${env.NOT_SET:=default}", "${env.TEST_VAR:+conditional}"],
"key3": {"nested": "${env.NOT_SET:+conditional}"}, "key3": {"nested": "${env.NOT_SET:+conditional}"},
} }
expected = {"key1": "test_value", "key2": ["default", "conditional"], "key3": {"nested": None}} expected = {"key1": "test_value", "key2": ["default", "conditional"], "key3": {"nested": None}}
self.assertEqual(replace_env_vars(data), expected) assert replace_env_vars(data) == expected
if __name__ == "__main__":
unittest.main()
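
With unittest.main() gone, nothing needs to live under __main__: pytest discovers these module-level test functions by name. If a directly runnable module were still wanted, the pytest equivalent would be:

import pytest

if __name__ == "__main__":
    # pytest.main returns an exit code, mirroring the old unittest.main() behavior
    raise SystemExit(pytest.main([__file__]))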