test: migrate unit tests from unittest to pytest

Signed-off-by: Mustafa Elbehery <melbeher@redhat.com>
Mustafa Elbehery 2025-07-09 12:55:16 +02:00
parent ff9d4d8a9d
commit 9331253894
8 changed files with 1440 additions and 1405 deletions
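
The change applies the same translation in every file: unittest.TestCase and IsolatedAsyncioTestCase classes become module-level pytest functions, setUp/tearDown pairs become yield fixtures, and self.assert* calls become plain assert statements. The before/after sketch below illustrates that pattern only; the class, fixture, and patched target are hypothetical and do not appear in this commit.

# Before: unittest style (illustrative sketch, not code from this commit)
import unittest
from unittest.mock import patch


class ExampleTests(unittest.TestCase):
    def setUp(self):
        # start a patch in setUp, stop it in tearDown
        self.patcher = patch("os.getcwd")
        self.mock_getcwd = self.patcher.start()

    def tearDown(self):
        self.patcher.stop()

    def test_call(self):
        self.mock_getcwd.return_value = "/fake/dir"
        self.assertEqual(self.mock_getcwd(), "/fake/dir")


# After: pytest style (same test, same hypothetical target)
import pytest
from unittest.mock import patch


@pytest.fixture
def mock_getcwd():
    # setup and teardown collapse into a single yield fixture
    with patch("os.getcwd") as mock:
        yield mock


def test_call(mock_getcwd):
    mock_getcwd.return_value = "/fake/dir"
    assert mock_getcwd() == "/fake/dir"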

View file

@@ -4,8 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.

from llama_stack.apis.inference import (
    ChatCompletionRequest,
@@ -31,258 +29,259 @@ MODEL = "Llama3.1-8B-Instruct"
MODEL3_2 = "Llama3.2-3B-Instruct"


async def test_system_default():
    content = "Hello !"
    request = ChatCompletionRequest(
        model=MODEL,
        messages=[
            UserMessage(content=content),
        ],
    )
    messages = chat_completion_request_to_messages(request, MODEL)
    assert len(messages) == 2
    assert messages[-1].content == content
    assert "Cutting Knowledge Date: December 2023" in messages[0].content


async def test_system_builtin_only():
    content = "Hello !"
    request = ChatCompletionRequest(
        model=MODEL,
        messages=[
            UserMessage(content=content),
        ],
        tools=[
            ToolDefinition(tool_name=BuiltinTool.code_interpreter),
            ToolDefinition(tool_name=BuiltinTool.brave_search),
        ],
    )
    messages = chat_completion_request_to_messages(request, MODEL)
    assert len(messages) == 2
    assert messages[-1].content == content
    assert "Cutting Knowledge Date: December 2023" in messages[0].content
    assert "Tools: brave_search" in messages[0].content


async def test_system_custom_only():
    content = "Hello !"
    request = ChatCompletionRequest(
        model=MODEL,
        messages=[
            UserMessage(content=content),
        ],
        tools=[
            ToolDefinition(
                tool_name="custom1",
                description="custom1 tool",
                parameters={
                    "param1": ToolParamDefinition(
                        param_type="str",
                        description="param1 description",
                        required=True,
                    ),
                },
            )
        ],
        tool_config=ToolConfig(tool_prompt_format=ToolPromptFormat.json),
    )
    messages = chat_completion_request_to_messages(request, MODEL)
    assert len(messages) == 3
    assert "Environment: ipython" in messages[0].content
    assert "Return function calls in JSON format" in messages[1].content
    assert messages[-1].content == content


async def test_system_custom_and_builtin():
    content = "Hello !"
    request = ChatCompletionRequest(
        model=MODEL,
        messages=[
            UserMessage(content=content),
        ],
        tools=[
            ToolDefinition(tool_name=BuiltinTool.code_interpreter),
            ToolDefinition(tool_name=BuiltinTool.brave_search),
            ToolDefinition(
                tool_name="custom1",
                description="custom1 tool",
                parameters={
                    "param1": ToolParamDefinition(
                        param_type="str",
                        description="param1 description",
                        required=True,
                    ),
                },
            ),
        ],
    )
    messages = chat_completion_request_to_messages(request, MODEL)
    assert len(messages) == 3

    assert "Environment: ipython" in messages[0].content
    assert "Tools: brave_search" in messages[0].content

    assert "Return function calls in JSON format" in messages[1].content
    assert messages[-1].content == content


async def test_completion_message_encoding():
    request = ChatCompletionRequest(
        model=MODEL3_2,
        messages=[
            UserMessage(content="hello"),
            CompletionMessage(
                content="",
                stop_reason=StopReason.end_of_turn,
                tool_calls=[
                    ToolCall(
                        tool_name="custom1",
                        arguments={"param1": "value1"},
                        call_id="123",
                    )
                ],
            ),
        ],
        tools=[
            ToolDefinition(
                tool_name="custom1",
                description="custom1 tool",
                parameters={
                    "param1": ToolParamDefinition(
                        param_type="str",
                        description="param1 description",
                        required=True,
                    ),
                },
            ),
        ],
        tool_config=ToolConfig(tool_prompt_format=ToolPromptFormat.python_list),
    )
    prompt = await chat_completion_request_to_prompt(request, request.model)
    assert '[custom1(param1="value1")]' in prompt

    request.model = MODEL
    request.tool_config.tool_prompt_format = ToolPromptFormat.json
    prompt = await chat_completion_request_to_prompt(request, request.model)
    assert '{"type": "function", "name": "custom1", "parameters": {"param1": "value1"}}' in prompt


async def test_user_provided_system_message():
    content = "Hello !"
    system_prompt = "You are a pirate"
    request = ChatCompletionRequest(
        model=MODEL,
        messages=[
            SystemMessage(content=system_prompt),
            UserMessage(content=content),
        ],
        tools=[
            ToolDefinition(tool_name=BuiltinTool.code_interpreter),
        ],
    )
    messages = chat_completion_request_to_messages(request, MODEL)
    assert len(messages) == 2
    assert messages[0].content.endswith(system_prompt)
    assert messages[-1].content == content


async def test_replace_system_message_behavior_builtin_tools():
    content = "Hello !"
    system_prompt = "You are a pirate"
    request = ChatCompletionRequest(
        model=MODEL,
        messages=[
            SystemMessage(content=system_prompt),
            UserMessage(content=content),
        ],
        tools=[
            ToolDefinition(tool_name=BuiltinTool.code_interpreter),
        ],
        tool_config=ToolConfig(
            tool_choice="auto",
            tool_prompt_format="python_list",
            system_message_behavior="replace",
        ),
    )
    messages = chat_completion_request_to_messages(request, MODEL3_2)
    assert len(messages) == 2
    assert messages[0].content.endswith(system_prompt)
    assert "Environment: ipython" in messages[0].content
    assert messages[-1].content == content


async def test_replace_system_message_behavior_custom_tools():
    content = "Hello !"
    system_prompt = "You are a pirate"
    request = ChatCompletionRequest(
        model=MODEL,
        messages=[
            SystemMessage(content=system_prompt),
            UserMessage(content=content),
        ],
        tools=[
            ToolDefinition(tool_name=BuiltinTool.code_interpreter),
            ToolDefinition(
                tool_name="custom1",
                description="custom1 tool",
                parameters={
                    "param1": ToolParamDefinition(
                        param_type="str",
                        description="param1 description",
                        required=True,
                    ),
                },
            ),
        ],
        tool_config=ToolConfig(
            tool_choice="auto",
            tool_prompt_format="python_list",
            system_message_behavior="replace",
        ),
    )
    messages = chat_completion_request_to_messages(request, MODEL3_2)

    assert len(messages) == 2
    assert messages[0].content.endswith(system_prompt)
    assert "Environment: ipython" in messages[0].content
    assert messages[-1].content == content


async def test_replace_system_message_behavior_custom_tools_with_template():
    content = "Hello !"
    system_prompt = "You are a pirate {{ function_description }}"
    request = ChatCompletionRequest(
        model=MODEL,
        messages=[
            SystemMessage(content=system_prompt),
            UserMessage(content=content),
        ],
        tools=[
            ToolDefinition(tool_name=BuiltinTool.code_interpreter),
            ToolDefinition(
                tool_name="custom1",
                description="custom1 tool",
                parameters={
                    "param1": ToolParamDefinition(
                        param_type="str",
                        description="param1 description",
                        required=True,
                    ),
                },
            ),
        ],
        tool_config=ToolConfig(
            tool_choice="auto",
            tool_prompt_format="python_list",
            system_message_behavior="replace",
        ),
    )
    messages = chat_completion_request_to_messages(request, MODEL3_2)

    assert len(messages) == 2
    assert "Environment: ipython" in messages[0].content
    assert "You are a pirate" in messages[0].content
    # function description is present in the system prompt
    assert '"name": "custom1"' in messages[0].content
    assert messages[-1].content == content

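The tests above are bare async def functions with no per-test event-loop setup, so collecting and awaiting them depends on an async-capable pytest plugin. The functions carry no explicit marker, which suggests a project-wide setting such as pytest-asyncio's asyncio_mode = "auto"; that is an assumption, since the pytest configuration is not part of this diff. The explicit per-test form would look like this sketch:

import pytest


# Assumes the pytest-asyncio plugin is installed; in "auto" mode this marker
# is applied implicitly and can be omitted, as in the migrated tests above.
@pytest.mark.asyncio
async def test_example():
    assert 1 + 1 == 2
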
View file

@@ -12,7 +12,6 @@
# the top-level of this source tree.

import textwrap
from datetime import datetime

from llama_stack.models.llama.llama3.prompt_templates import (
@@ -24,59 +23,61 @@ from llama_stack.models.llama.llama3.prompt_templates import (
)


def check_generator_output(generator):
    for example in generator.data_examples():
        pt = generator.gen(example)
        text = pt.render()
        if not example:
            continue
        for tool in example:
            assert tool.tool_name in text


def test_system_default():
    generator = SystemDefaultGenerator()
    today = datetime.now().strftime("%d %B %Y")
    expected_text = f"Cutting Knowledge Date: December 2023\nToday Date: {today}"
    assert expected_text.strip("\n") == generator.gen(generator.data_examples()[0]).render()


def test_system_builtin_only():
    generator = BuiltinToolGenerator()
    expected_text = textwrap.dedent(
        """
        Environment: ipython
        Tools: brave_search, wolfram_alpha
        """
    )
    assert expected_text.strip("\n") == generator.gen(generator.data_examples()[0]).render()


def test_system_custom_only():
    generator = JsonCustomToolGenerator()
    check_generator_output(generator)


def test_system_custom_function_tag():
    generator = FunctionTagCustomToolGenerator()
    check_generator_output(generator)


def test_llama_3_2_system_zero_shot():
    generator = PythonListCustomToolGenerator()
    check_generator_output(generator)


def test_llama_3_2_provided_system_prompt():
    generator = PythonListCustomToolGenerator()
    user_system_prompt = textwrap.dedent(
        """
        Overriding message.
        {{ function_description }}
        """
    )
    example = generator.data_examples()[0]
    pt = generator.gen(example, user_system_prompt)
    text = pt.render()
    assert "Overriding message." in text
    assert '"name": "get_weather"' in text

View file

@@ -5,7 +5,6 @@
# the root directory of this source tree.

import os
from unittest.mock import patch

import pytest
@@ -15,124 +14,125 @@ from llama_stack.providers.remote.datasetio.nvidia.config import NvidiaDatasetIOConfig
from llama_stack.providers.remote.datasetio.nvidia.datasetio import NvidiaDatasetIOAdapter


@pytest.fixture
def nvidia_dataset_adapter():
    """Set up the NVIDIA dataset adapter with mocked dependencies"""
    os.environ["NVIDIA_DATASETS_URL"] = "http://nemo.test/datasets"

    config = NvidiaDatasetIOConfig(
        datasets_url=os.environ["NVIDIA_DATASETS_URL"], dataset_namespace="default", project_id="default"
    )
    adapter = NvidiaDatasetIOAdapter(config)

    with patch(
        "llama_stack.providers.remote.datasetio.nvidia.datasetio.NvidiaDatasetIOAdapter._make_request"
    ) as mock_make_request:
        yield adapter, mock_make_request


def assert_request(mock_call, expected_method, expected_path, expected_json=None):
    """Helper function to verify request details in mock calls."""
    call_args = mock_call.call_args
    assert call_args[0][0] == expected_method
    assert call_args[0][1] == expected_path
    if expected_json:
        for key, value in expected_json.items():
            assert call_args[1]["json"][key] == value


async def test_register_dataset(nvidia_dataset_adapter):
    adapter, mock_make_request = nvidia_dataset_adapter

    mock_make_request.return_value = {
        "id": "dataset-123456",
        "name": "test-dataset",
        "namespace": "default",
    }
    dataset_def = Dataset(
        identifier="test-dataset",
        type="dataset",
        provider_resource_id="",
        provider_id="",
        purpose=DatasetPurpose.post_training_messages,
        source=URIDataSource(uri="https://example.com/data.jsonl"),
        metadata={"provider_id": "nvidia", "format": "jsonl", "description": "Test dataset description"},
    )

    await adapter.register_dataset(dataset_def)

    mock_make_request.assert_called_once()
    assert_request(
        mock_make_request,
        "POST",
        "/v1/datasets",
        expected_json={
            "name": "test-dataset",
            "namespace": "default",
            "files_url": "https://example.com/data.jsonl",
            "project": "default",
            "format": "jsonl",
            "description": "Test dataset description",
        },
    )


async def test_unregister_dataset(nvidia_dataset_adapter):
    adapter, mock_make_request = nvidia_dataset_adapter

    mock_make_request.return_value = {
        "message": "Resource deleted successfully.",
        "id": "dataset-81RSQp7FKX3rdBtKvF9Skn",
        "deleted_at": None,
    }
    dataset_id = "test-dataset"

    await adapter.unregister_dataset(dataset_id)

    mock_make_request.assert_called_once()
    assert_request(mock_make_request, "DELETE", "/v1/datasets/default/test-dataset")


async def test_register_dataset_with_custom_namespace_project(nvidia_dataset_adapter):
    adapter, mock_make_request = nvidia_dataset_adapter

    custom_config = NvidiaDatasetIOConfig(
        datasets_url=os.environ["NVIDIA_DATASETS_URL"],
        dataset_namespace="custom-namespace",
        project_id="custom-project",
    )
    custom_adapter = NvidiaDatasetIOAdapter(custom_config)

    mock_make_request.return_value = {
        "id": "dataset-123456",
        "name": "test-dataset",
        "namespace": "custom-namespace",
    }
    dataset_def = Dataset(
        identifier="test-dataset",
        type="dataset",
        provider_resource_id="",
        provider_id="",
        purpose=DatasetPurpose.post_training_messages,
        source=URIDataSource(uri="https://example.com/data.jsonl"),
        metadata={"format": "jsonl"},
    )

    await custom_adapter.register_dataset(dataset_def)

    mock_make_request.assert_called_once()
    assert_request(
        mock_make_request,
        "POST",
        "/v1/datasets",
        expected_json={
            "name": "test-dataset",
            "namespace": "custom-namespace",
            "files_url": "https://example.com/data.jsonl",
            "project": "custom-project",
            "format": "jsonl",
        },
    )

View file

@@ -5,7 +5,6 @@
# the root directory of this source tree.

import os
from unittest.mock import MagicMock, patch

import pytest
@@ -21,181 +20,186 @@ MOCK_DATASET_ID = "default/test-dataset"
MOCK_BENCHMARK_ID = "test-benchmark"


@pytest.fixture
def nvidia_eval_impl():
    """Set up the NVIDIA eval implementation with mocked dependencies"""
    os.environ["NVIDIA_EVALUATOR_URL"] = "http://nemo.test"

    # Create mock APIs
    datasetio_api = MagicMock()
    datasets_api = MagicMock()
    scoring_api = MagicMock()
    inference_api = MagicMock()
    agents_api = MagicMock()

    config = NVIDIAEvalConfig(
        evaluator_url=os.environ["NVIDIA_EVALUATOR_URL"],
    )

    eval_impl = NVIDIAEvalImpl(
        config=config,
        datasetio_api=datasetio_api,
        datasets_api=datasets_api,
        scoring_api=scoring_api,
        inference_api=inference_api,
        agents_api=agents_api,
    )

    # Mock the HTTP request methods
    with (
        patch("llama_stack.providers.remote.eval.nvidia.eval.NVIDIAEvalImpl._evaluator_get") as mock_evaluator_get,
        patch("llama_stack.providers.remote.eval.nvidia.eval.NVIDIAEvalImpl._evaluator_post") as mock_evaluator_post,
    ):
        yield eval_impl, mock_evaluator_get, mock_evaluator_post


def assert_request_body(mock_evaluator_post, expected_json):
    """Helper function to verify request body in Evaluator POST request is correct"""
    call_args = mock_evaluator_post.call_args
    actual_json = call_args[0][1]

    # Check that all expected keys contain the expected values in the actual JSON
    for key, value in expected_json.items():
        assert key in actual_json, f"Key '{key}' missing in actual JSON"

        if isinstance(value, dict):
            for nested_key, nested_value in value.items():
                assert nested_key in actual_json[key], f"Nested key '{nested_key}' missing in actual JSON['{key}']"
                assert actual_json[key][nested_key] == nested_value, f"Value mismatch for '{key}.{nested_key}'"
        else:
            assert actual_json[key] == value, f"Value mismatch for '{key}'"


async def test_register_benchmark(nvidia_eval_impl):
    eval_impl, mock_evaluator_get, mock_evaluator_post = nvidia_eval_impl

    eval_config = {
        "type": "custom",
        "params": {"parallelism": 8},
        "tasks": {
            "qa": {
                "type": "completion",
                "params": {"template": {"prompt": "{{prompt}}", "max_tokens": 200}},
                "dataset": {"files_url": f"hf://datasets/{MOCK_DATASET_ID}/testing/testing.jsonl"},
                "metrics": {"bleu": {"type": "bleu", "params": {"references": ["{{ideal_response}}"]}}},
            }
        },
    }

    benchmark = Benchmark(
        provider_id="nvidia",
        type="benchmark",
        identifier=MOCK_BENCHMARK_ID,
        dataset_id=MOCK_DATASET_ID,
        scoring_functions=["basic::equality"],
        metadata=eval_config,
    )

    # Mock Evaluator API response
    mock_evaluator_response = {"id": MOCK_BENCHMARK_ID, "status": "created"}
    mock_evaluator_post.return_value = mock_evaluator_response

    # Register the benchmark
    await eval_impl.register_benchmark(benchmark)

    # Verify the Evaluator API was called correctly
    mock_evaluator_post.assert_called_once()
    assert_request_body(
        mock_evaluator_post, {"namespace": benchmark.provider_id, "name": benchmark.identifier, **eval_config}
    )


async def test_run_eval(nvidia_eval_impl):
    eval_impl, mock_evaluator_get, mock_evaluator_post = nvidia_eval_impl

    benchmark_config = BenchmarkConfig(
        eval_candidate=ModelCandidate(
            type="model",
            model=CoreModelId.llama3_1_8b_instruct.value,
            sampling_params=SamplingParams(max_tokens=100, temperature=0.7),
        )
    )

    # Mock Evaluator API response
    mock_evaluator_response = {"id": "job-123", "status": "created"}
    mock_evaluator_post.return_value = mock_evaluator_response

    # Run the Evaluation job
    result = await eval_impl.run_eval(benchmark_id=MOCK_BENCHMARK_ID, benchmark_config=benchmark_config)

    # Verify the Evaluator API was called correctly
    mock_evaluator_post.assert_called_once()
    assert_request_body(
        mock_evaluator_post,
        {
            "config": f"nvidia/{MOCK_BENCHMARK_ID}",
            "target": {"type": "model", "model": "meta/llama-3.1-8b-instruct"},
        },
    )

    # Verify the result
    assert isinstance(result, Job)
    assert result.job_id == "job-123"
    assert result.status == JobStatus.in_progress


async def test_job_status(nvidia_eval_impl):
    eval_impl, mock_evaluator_get, mock_evaluator_post = nvidia_eval_impl

    # Mock Evaluator API response
    mock_evaluator_response = {"id": "job-123", "status": "completed"}
    mock_evaluator_get.return_value = mock_evaluator_response

    # Get the Evaluation job
    result = await eval_impl.job_status(benchmark_id=MOCK_BENCHMARK_ID, job_id="job-123")

    # Verify the result
    assert isinstance(result, Job)
    assert result.job_id == "job-123"
    assert result.status == JobStatus.completed

    # Verify the API was called correctly
    mock_evaluator_get.assert_called_once_with(f"/v1/evaluation/jobs/{result.job_id}")


async def test_job_cancel(nvidia_eval_impl):
    eval_impl, mock_evaluator_get, mock_evaluator_post = nvidia_eval_impl

    # Mock Evaluator API response
    mock_evaluator_response = {"id": "job-123", "status": "cancelled"}
    mock_evaluator_post.return_value = mock_evaluator_response

    # Cancel the Evaluation job
    await eval_impl.job_cancel(benchmark_id=MOCK_BENCHMARK_ID, job_id="job-123")

    # Verify the API was called correctly
    mock_evaluator_post.assert_called_once_with("/v1/evaluation/jobs/job-123/cancel", {})


async def test_job_result(nvidia_eval_impl):
    eval_impl, mock_evaluator_get, mock_evaluator_post = nvidia_eval_impl

    # Mock Evaluator API responses
    mock_job_status_response = {"id": "job-123", "status": "completed"}
    mock_job_results_response = {
        "id": "job-123",
        "status": "completed",
        "results": {MOCK_BENCHMARK_ID: {"score": 0.85, "details": {"accuracy": 0.85, "f1": 0.84}}},
    }
    mock_evaluator_get.side_effect = [
        mock_job_status_response,  # First call to retrieve job
        mock_job_results_response,  # Second call to retrieve job results
    ]

    # Get the Evaluation job results
    result = await eval_impl.job_result(benchmark_id=MOCK_BENCHMARK_ID, job_id="job-123")

    # Verify the result
    assert isinstance(result, EvaluateResponse)
    assert MOCK_BENCHMARK_ID in result.scores
    assert result.scores[MOCK_BENCHMARK_ID].aggregated_results["results"][MOCK_BENCHMARK_ID]["score"] == 0.85

    # Verify the API was called correctly
    assert mock_evaluator_get.call_count == 2
    mock_evaluator_get.assert_any_call("/v1/evaluation/jobs/job-123")
    mock_evaluator_get.assert_any_call("/v1/evaluation/jobs/job-123/results")

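The eval fixture above uses a parenthesized multi-patch with statement, which requires Python 3.10 or newer. On older interpreters the same setup can be written with contextlib.ExitStack; the sketch below is illustrative only and patches standard-library functions rather than anything from this commit.

from contextlib import ExitStack
from unittest.mock import patch

import pytest


@pytest.fixture
def mocked_calls():
    # Equivalent of the parenthesized multi-patch "with" used in the fixture
    # above, but compatible with Python versions before 3.10.
    with ExitStack() as stack:
        mock_cwd = stack.enter_context(patch("os.getcwd"))
        mock_env = stack.enter_context(patch("os.getenv"))
        yield mock_cwd, mock_env
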
View file

@@ -5,7 +5,6 @@
# the root directory of this source tree.

import os
import warnings
from unittest.mock import patch
@@ -27,253 +26,249 @@ from llama_stack.providers.remote.post_training.nvidia.post_training import (
)


@pytest.fixture
def nvidia_adapter():
    """Set up the NVIDIA adapter with mock configuration"""
    os.environ["NVIDIA_BASE_URL"] = "http://nemo.test"
    os.environ["NVIDIA_CUSTOMIZER_URL"] = "http://nemo.test"

    config = NvidiaPostTrainingConfig(
        base_url=os.environ["NVIDIA_BASE_URL"], customizer_url=os.environ["NVIDIA_CUSTOMIZER_URL"], api_key=None
    )
    adapter = NvidiaPostTrainingAdapter(config)

    # Mock the _make_request method
    with patch(
        "llama_stack.providers.remote.post_training.nvidia.post_training.NvidiaPostTrainingAdapter._make_request"
    ) as mock_make_request:
        mock_make_request.return_value = {
            "id": "job-123",
            "status": "created",
            "created_at": "2025-03-04T13:07:47.543605",
            "updated_at": "2025-03-04T13:07:47.543605",
        }
        yield adapter, mock_make_request


def assert_request_params(mock_make_request, expected_json):
    """Helper function to verify parameters in the request JSON."""
    call_args = mock_make_request.call_args
    actual_json = call_args[1]["json"]

    for key, value in expected_json.items():
        if isinstance(value, dict):
            for nested_key, nested_value in value.items():
                assert actual_json[key][nested_key] == nested_value
        else:
            assert actual_json[key] == value


async def test_customizer_parameters_passed(nvidia_adapter):
    """Test scenario 1: When an optional parameter is passed and value is correctly set."""
    adapter, mock_make_request = nvidia_adapter

    algorithm_config = LoraFinetuningConfig(
        type="LoRA",
        apply_lora_to_mlp=True,
        apply_lora_to_output=True,
        alpha=16,
        rank=16,
        lora_attn_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    )
    data_config = DataConfig(
        dataset_id="test-dataset", batch_size=16, shuffle=False, data_format=DatasetFormat.instruct
    )
    optimizer_config = OptimizerConfig(
        optimizer_type=OptimizerType.adam,
        lr=0.0002,
        weight_decay=0.01,
        num_warmup_steps=100,
    )
    training_config = TrainingConfig(
        n_epochs=3,
        data_config=data_config,
        optimizer_config=optimizer_config,
    )

    with warnings.catch_warnings(record=True) as w:
        warnings.simplefilter("always")
        await adapter.supervised_fine_tune(
            job_uuid="test-job",
            model="meta-llama/Llama-3.1-8B-Instruct",
            checkpoint_dir="",
            algorithm_config=algorithm_config,
            training_config=convert_pydantic_to_json_value(training_config),
            logger_config={},
            hyperparam_search_config={},
        )

        warning_texts = [str(warning.message) for warning in w]

        fields = [
            "apply_lora_to_output",
            "lora_attn_modules",
            "apply_lora_to_mlp",
        ]
        for field in fields:
            assert any(field in text for text in warning_texts)

    assert_request_params(
        mock_make_request,
        {
            "hyperparameters": {
                "lora": {"alpha": 16},
                "epochs": 3,
                "learning_rate": 0.0002,
                "batch_size": 16,
            }
        },
    )


async def test_required_parameters_passed(nvidia_adapter):
    """Test scenario 2: When required parameters are passed."""
    adapter, mock_make_request = nvidia_adapter

    required_model = "meta/llama-3.2-1b-instruct@v1.0.0+L40"
    required_dataset_id = "required-dataset"
    required_job_uuid = "required-job"

    algorithm_config = LoraFinetuningConfig(
        type="LoRA",
        apply_lora_to_mlp=True,
        apply_lora_to_output=True,
        alpha=16,
        rank=16,
        lora_attn_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    )
    data_config = DataConfig(
        dataset_id=required_dataset_id, batch_size=8, shuffle=False, data_format=DatasetFormat.instruct
    )
    optimizer_config = OptimizerConfig(
        optimizer_type=OptimizerType.adam,
        lr=0.0001,
        weight_decay=0.01,
        num_warmup_steps=100,
    )
    training_config = TrainingConfig(
        n_epochs=1,
        data_config=data_config,
        optimizer_config=optimizer_config,
    )

    with warnings.catch_warnings(record=True) as w:
        warnings.simplefilter("always")
        await adapter.supervised_fine_tune(
            job_uuid=required_job_uuid,  # Required parameter
            model=required_model,  # Required parameter
            checkpoint_dir="",
            algorithm_config=algorithm_config,
            training_config=convert_pydantic_to_json_value(training_config),
            logger_config={},
            hyperparam_search_config={},
        )

        warning_texts = [str(warning.message) for warning in w]

        fields = [
            "rank",
            "apply_lora_to_output",
            "lora_attn_modules",
            "apply_lora_to_mlp",
        ]
        for field in fields:
            assert any(field in text for text in warning_texts)

    mock_make_request.assert_called_once()
    call_args = mock_make_request.call_args
    assert call_args[1]["json"]["config"] == required_model
    assert call_args[1]["json"]["dataset"]["name"] == required_dataset_id


async def test_unsupported_parameters_warning(nvidia_adapter):
    """Test that warnings are raised for unsupported parameters."""
    adapter, mock_make_request = nvidia_adapter

    data_config = DataConfig(
        dataset_id="test-dataset",
        batch_size=8,
        # Unsupported parameters
        shuffle=True,
        data_format=DatasetFormat.instruct,
        validation_dataset_id="val-dataset",
    )
    optimizer_config = OptimizerConfig(
        lr=0.0001,
        weight_decay=0.01,
        # Unsupported parameters
        optimizer_type=OptimizerType.adam,
        num_warmup_steps=100,
    )
    efficiency_config = EfficiencyConfig(
        enable_activation_checkpointing=True  # Unsupported parameter
    )
    training_config = TrainingConfig(
        n_epochs=1,
        data_config=data_config,
        optimizer_config=optimizer_config,
        # Unsupported parameters
        efficiency_config=efficiency_config,
        max_steps_per_epoch=1000,
        gradient_accumulation_steps=4,
        max_validation_steps=100,
        dtype="bf16",
    )

    # Capture warnings
    with warnings.catch_warnings(record=True) as w:
        warnings.simplefilter("always")
        await adapter.supervised_fine_tune(
            job_uuid="test-job",
            model="meta-llama/Llama-3.1-8B-Instruct",
            checkpoint_dir="test-dir",  # Unsupported parameter
            algorithm_config=LoraFinetuningConfig(
                type="LoRA",
                apply_lora_to_mlp=True,
                apply_lora_to_output=True,
                alpha=16,
                rank=16,
                lora_attn_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
            ),
            training_config=convert_pydantic_to_json_value(training_config),
            logger_config={"test": "value"},  # Unsupported parameter
            hyperparam_search_config={"test": "value"},  # Unsupported parameter
        )

        assert len(w) >= 4
        warning_texts = [str(warning.message) for warning in w]

        fields = [
            "checkpoint_dir",
            "hyperparam_search_config",
            "logger_config",
            "TrainingConfig",
            "DataConfig",
            "OptimizerConfig",
            "max_steps_per_epoch",
            "gradient_accumulation_steps",
            "max_validation_steps",
            "dtype",
            # required unsupported parameters
            "rank",
            "apply_lora_to_output",
            "lora_attn_modules",
            "apply_lora_to_mlp",
        ]
        for field in fields:
            assert any(field in text for text in warning_texts)

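The post-training tests above keep warnings.catch_warnings(record=True) from the original unittest version and only swap the assertions. pytest also ships a pytest.warns helper that can express the same check; the comparison below is a standalone sketch with a stand-in warning function, not code from this commit.

import warnings

import pytest


def emit_warning():
    # Stand-in for the adapter call that warns about an unsupported parameter
    warnings.warn("apply_lora_to_output is not supported", UserWarning, stacklevel=2)


def test_capture_with_catch_warnings():
    # Pattern used in the migrated tests: record warnings and inspect them
    with warnings.catch_warnings(record=True) as w:
        warnings.simplefilter("always")
        emit_warning()
        assert any("apply_lora_to_output" in str(item.message) for item in w)


def test_capture_with_pytest_warns():
    # Equivalent check using pytest's built-in helper
    with pytest.warns(UserWarning, match="apply_lora_to_output"):
        emit_warning()
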
View file

@@ -5,7 +5,6 @@
# the root directory of this source tree.

import os
from typing import Any
from unittest.mock import AsyncMock, MagicMock, patch
@@ -18,308 +17,325 @@ from llama_stack.providers.remote.safety.nvidia.config import NVIDIASafetyConfig
from llama_stack.providers.remote.safety.nvidia.nvidia import NVIDIASafetyAdapter


@pytest.fixture
def nvidia_safety_adapter():
    """Set up the NVIDIA safety adapter with mocked dependencies"""
    os.environ["NVIDIA_GUARDRAILS_URL"] = "http://nemo.test"

    # Initialize the adapter
    config = NVIDIASafetyConfig(
        guardrails_service_url=os.environ["NVIDIA_GUARDRAILS_URL"],
    )
    adapter = NVIDIASafetyAdapter(config=config)
    shield_store = AsyncMock()
    adapter.shield_store = shield_store

    # Mock the HTTP request methods
    with patch(
        "llama_stack.providers.remote.safety.nvidia.nvidia.NeMoGuardrails._guardrails_post"
    ) as mock_guardrails_post:
        mock_guardrails_post.return_value = {"status": "allowed"}
        yield adapter, shield_store, mock_guardrails_post


def assert_request(
    mock_call: MagicMock,
    expected_url: str,
    expected_headers: dict[str, str] | None = None,
    expected_json: dict[str, Any] | None = None,
) -> None:
    """
    Helper function to verify request details in mock API calls.

    Args:
        mock_call: The MagicMock object that was called
        expected_url: The expected URL to which the request was made
        expected_headers: Optional dictionary of expected request headers
        expected_json: Optional dictionary of expected JSON payload
    """
    call_args = mock_call.call_args

    # Check URL
    assert call_args[0][0] == expected_url

    # Check headers if provided
    if expected_headers:
        for key, value in expected_headers.items():
            assert call_args[1]["headers"][key] == value

    # Check JSON if provided
    if expected_json:
        for key, value in expected_json.items():
            if isinstance(value, dict):
                for nested_key, nested_value in value.items():
                    assert call_args[1]["json"][key][nested_key] == nested_value
            else:
                assert call_args[1]["json"][key] == value


async def test_register_shield_with_valid_id(nvidia_safety_adapter):
    adapter, shield_store, mock_guardrails_post = nvidia_safety_adapter

    shield = Shield(
        provider_id="nvidia",
        type="shield",
        identifier="test-shield",
        provider_resource_id="test-model",
    )

    # Register the shield
    await adapter.register_shield(shield)


async def test_register_shield_without_id(nvidia_safety_adapter):
    adapter, shield_store, mock_guardrails_post = nvidia_safety_adapter

    shield = Shield(
        provider_id="nvidia",
        type="shield",
        identifier="test-shield",
        provider_resource_id="",
    )

    # Register the shield should raise a ValueError
    with pytest.raises(ValueError):
        await adapter.register_shield(shield)


async def test_run_shield_allowed(nvidia_safety_adapter):
    adapter, shield_store, mock_guardrails_post = nvidia_safety_adapter

    # Set up the shield
    shield_id = "test-shield"
    shield = Shield(
        provider_id="nvidia",
        type="shield",
        identifier=shield_id,
        provider_resource_id="test-model",
    )
    shield_store.get_shield.return_value = shield

    # Mock Guardrails API response
    mock_guardrails_post.return_value = {"status": "allowed"}

    # Run the shield
    messages = [
        UserMessage(role="user", content="Hello, how are you?"),
        CompletionMessage(
            role="assistant",
            content="I'm doing well, thank you for asking!",
            stop_reason="end_of_message",
            tool_calls=[],
        ),
    ]
    result = await adapter.run_shield(shield_id, messages)

    # Verify the shield store was called
    shield_store.get_shield.assert_called_once_with(shield_id)

    # Verify the Guardrails API was called correctly
    mock_guardrails_post.assert_called_once_with(
        path="/v1/guardrail/checks",
        data={
            "model": shield_id,
            "messages": [
                {"role": "user", "content": "Hello, how are you?"},
                {"role": "assistant", "content": "I'm doing well, thank you for asking!"},
            ],
            "temperature": 1.0,
            "top_p": 1,
            "frequency_penalty": 0,
            "presence_penalty": 0,
            "max_tokens": 160,
            "stream": False,
            "guardrails": {
                "config_id": "self-check",
            },
        },
    )

    # Verify the result
    assert isinstance(result, RunShieldResponse)
    assert result.violation is None


async def test_run_shield_blocked(nvidia_safety_adapter):
    adapter, shield_store, mock_guardrails_post = nvidia_safety_adapter

    # Set up the shield
    shield_id = "test-shield"
    shield = Shield(
        provider_id="nvidia",
        type="shield",
        identifier=shield_id,
        provider_resource_id="test-model",
    )
    shield_store.get_shield.return_value = shield

    # Mock Guardrails API response
    mock_guardrails_post.return_value = {"status": "blocked", "rails_status": {"reason": "harmful_content"}}

    # Run the shield
    messages = [
        UserMessage(role="user", content="Hello, how are you?"),
        CompletionMessage(
            role="assistant",
            content="I'm doing well, thank you for asking!",
            stop_reason="end_of_message",
            tool_calls=[],
        ),
    ]
    result = await adapter.run_shield(shield_id, messages)

    # Verify the shield store was called
    shield_store.get_shield.assert_called_once_with(shield_id)

    # Verify the Guardrails API was called correctly
    mock_guardrails_post.assert_called_once_with(
        path="/v1/guardrail/checks",
        data={
            "model": shield_id,
            "messages": [
                {"role": "user", "content": "Hello, how are you?"},
                {"role": "assistant", "content": "I'm doing well, thank you for asking!"},
            ],
            "temperature": 1.0,
            "top_p": 1,
            "frequency_penalty": 0,
            "presence_penalty": 0,
            "max_tokens": 160,
            "stream": False,
            "guardrails": {
                "config_id": "self-check",
            },
        },
    )

    # Verify the result
    assert result.violation is not None
    assert isinstance(result, RunShieldResponse)
    assert result.violation.user_message == "Sorry I cannot do this."
    assert result.violation.violation_level == ViolationLevel.ERROR
    assert result.violation.metadata == {"reason": "harmful_content"}


async def test_run_shield_not_found(nvidia_safety_adapter):
UserMessage(role="user", content="Hello, how are you?"), adapter, shield_store, mock_guardrails_post = nvidia_safety_adapter
]
with self.assertRaises(ValueError): # Set up shield store to return None
self.run_async(self.adapter.run_shield(shield_id, messages)) shield_id = "non-existent-shield"
shield_store.get_shield.return_value = None
# Verify the shield store was called messages = [
self.shield_store.get_shield.assert_called_once_with(shield_id) UserMessage(role="user", content="Hello, how are you?"),
]
# Verify the Guardrails API was not called with pytest.raises(ValueError):
self.mock_guardrails_post.assert_not_called() await adapter.run_shield(shield_id, messages)
def test_run_shield_http_error(self): # Verify the shield store was called
shield_id = "test-shield" shield_store.get_shield.assert_called_once_with(shield_id)
shield = Shield(
provider_id="nvidia",
type="shield",
identifier=shield_id,
provider_resource_id="test-model",
)
self.shield_store.get_shield.return_value = shield
# Mock Guardrails API to raise an exception # Verify the Guardrails API was not called
error_msg = "API Error: 500 Internal Server Error" mock_guardrails_post.assert_not_called()
self.mock_guardrails_post.side_effect = Exception(error_msg)
# Running the shield should raise an exception
messages = [
UserMessage(role="user", content="Hello, how are you?"),
CompletionMessage(
role="assistant",
content="I'm doing well, thank you for asking!",
stop_reason="end_of_message",
tool_calls=[],
),
]
with self.assertRaises(Exception) as context:
self.run_async(self.adapter.run_shield(shield_id, messages))
# Verify the shield store was called async def test_run_shield_http_error(nvidia_safety_adapter):
self.shield_store.get_shield.assert_called_once_with(shield_id) adapter, shield_store, mock_guardrails_post = nvidia_safety_adapter
# Verify the Guardrails API was called correctly shield_id = "test-shield"
self.mock_guardrails_post.assert_called_once_with( shield = Shield(
path="/v1/guardrail/checks", provider_id="nvidia",
data={ type="shield",
"model": shield_id, identifier=shield_id,
"messages": [ provider_resource_id="test-model",
{"role": "user", "content": "Hello, how are you?"}, )
{"role": "assistant", "content": "I'm doing well, thank you for asking!"}, shield_store.get_shield.return_value = shield
],
"temperature": 1.0, # Mock Guardrails API to raise an exception
"top_p": 1, error_msg = "API Error: 500 Internal Server Error"
"frequency_penalty": 0, mock_guardrails_post.side_effect = Exception(error_msg)
"presence_penalty": 0,
"max_tokens": 160, # Running the shield should raise an exception
"stream": False, messages = [
"guardrails": { UserMessage(role="user", content="Hello, how are you?"),
"config_id": "self-check", CompletionMessage(
}, role="assistant",
content="I'm doing well, thank you for asking!",
stop_reason="end_of_message",
tool_calls=[],
),
]
with pytest.raises(Exception) as excinfo:
await adapter.run_shield(shield_id, messages)
# Verify the shield store was called
shield_store.get_shield.assert_called_once_with(shield_id)
# Verify the Guardrails API was called correctly
mock_guardrails_post.assert_called_once_with(
path="/v1/guardrail/checks",
data={
"model": shield_id,
"messages": [
{"role": "user", "content": "Hello, how are you?"},
{"role": "assistant", "content": "I'm doing well, thank you for asking!"},
],
"temperature": 1.0,
"top_p": 1,
"frequency_penalty": 0,
"presence_penalty": 0,
"max_tokens": 160,
"stream": False,
"guardrails": {
"config_id": "self-check",
}, },
) },
# Verify the exception message )
assert error_msg in str(context.exception) # Verify the exception message
assert error_msg in str(excinfo.value)
def test_init_nemo_guardrails(self):
from llama_stack.providers.remote.safety.nvidia.nvidia import NeMoGuardrails
test_config_id = "test-custom-config-id" def test_init_nemo_guardrails():
config = NVIDIASafetyConfig( from llama_stack.providers.remote.safety.nvidia.nvidia import NeMoGuardrails
guardrails_service_url=os.environ["NVIDIA_GUARDRAILS_URL"],
config_id=test_config_id,
)
# Initialize with default parameters
test_model = "test-model"
guardrails = NeMoGuardrails(config, test_model)
# Verify the attributes are set correctly os.environ["NVIDIA_GUARDRAILS_URL"] = "http://nemo.test"
assert guardrails.config_id == test_config_id
assert guardrails.model == test_model
assert guardrails.threshold == 0.9 # Default value
assert guardrails.temperature == 1.0 # Default value
assert guardrails.guardrails_service_url == os.environ["NVIDIA_GUARDRAILS_URL"]
# Initialize with custom parameters test_config_id = "test-custom-config-id"
guardrails = NeMoGuardrails(config, test_model, threshold=0.8, temperature=0.7) config = NVIDIASafetyConfig(
guardrails_service_url=os.environ["NVIDIA_GUARDRAILS_URL"],
config_id=test_config_id,
)
# Initialize with default parameters
test_model = "test-model"
guardrails = NeMoGuardrails(config, test_model)
# Verify the attributes are set correctly # Verify the attributes are set correctly
assert guardrails.config_id == test_config_id assert guardrails.config_id == test_config_id
assert guardrails.model == test_model assert guardrails.model == test_model
assert guardrails.threshold == 0.8 assert guardrails.threshold == 0.9 # Default value
assert guardrails.temperature == 0.7 assert guardrails.temperature == 1.0 # Default value
assert guardrails.guardrails_service_url == os.environ["NVIDIA_GUARDRAILS_URL"] assert guardrails.guardrails_service_url == os.environ["NVIDIA_GUARDRAILS_URL"]
def test_init_nemo_guardrails_invalid_temperature(self): # Initialize with custom parameters
from llama_stack.providers.remote.safety.nvidia.nvidia import NeMoGuardrails guardrails = NeMoGuardrails(config, test_model, threshold=0.8, temperature=0.7)
config = NVIDIASafetyConfig( # Verify the attributes are set correctly
guardrails_service_url=os.environ["NVIDIA_GUARDRAILS_URL"], assert guardrails.config_id == test_config_id
config_id="test-custom-config-id", assert guardrails.model == test_model
) assert guardrails.threshold == 0.8
with self.assertRaises(ValueError): assert guardrails.temperature == 0.7
NeMoGuardrails(config, "test-model", temperature=0) assert guardrails.guardrails_service_url == os.environ["NVIDIA_GUARDRAILS_URL"]
def test_init_nemo_guardrails_invalid_temperature():
from llama_stack.providers.remote.safety.nvidia.nvidia import NeMoGuardrails
os.environ["NVIDIA_GUARDRAILS_URL"] = "http://nemo.test"
config = NVIDIASafetyConfig(
guardrails_service_url=os.environ["NVIDIA_GUARDRAILS_URL"],
config_id="test-custom-config-id",
)
with pytest.raises(ValueError):
NeMoGuardrails(config, "test-model", temperature=0)
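The converted safety tests above are plain `async def` functions with no explicit asyncio marker, so the suite has to tell pytest how to execute coroutines; that wiring is not shown in this excerpt. A minimal sketch of one option, assuming the pytest-asyncio plugin is installed (the project may instead enable `asyncio_mode = "auto"` in its pytest configuration):

import pytest

# Placed at the top of a test module, this applies pytest-asyncio's marker to every
# test in that module, so bare `async def` tests are collected and awaited.
pytestmark = pytest.mark.asyncio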


@@ -5,9 +5,8 @@
# the root directory of this source tree.

import os
import warnings
from unittest.mock import AsyncMock, MagicMock, patch

import pytest

@@ -32,331 +31,334 @@ from llama_stack.providers.remote.post_training.nvidia.post_training import (
)


@pytest.fixture
def nvidia_adapters():
    """Set up the NVIDIA adapters with mocked dependencies"""
    os.environ["NVIDIA_BASE_URL"] = "http://nemo.test"  # needed for llm inference
    os.environ["NVIDIA_CUSTOMIZER_URL"] = "http://nemo.test"  # needed for nemo customizer

    config = NvidiaPostTrainingConfig(
        base_url=os.environ["NVIDIA_BASE_URL"], customizer_url=os.environ["NVIDIA_CUSTOMIZER_URL"], api_key=None
    )
    adapter = NvidiaPostTrainingAdapter(config)

    # Mock the inference client
    inference_config = NVIDIAConfig(base_url=os.environ["NVIDIA_BASE_URL"], api_key=None)
    inference_adapter = NVIDIAInferenceAdapter(inference_config)

    mock_client = MagicMock()
    mock_client.chat.completions.create = AsyncMock()

    with (
        patch(
            "llama_stack.providers.remote.post_training.nvidia.post_training.NvidiaPostTrainingAdapter._make_request"
        ) as mock_make_request,
        patch(
            "llama_stack.providers.remote.inference.nvidia.nvidia.NVIDIAInferenceAdapter._get_client",
            return_value=mock_client,
        ),
    ):
        yield adapter, inference_adapter, mock_make_request, mock_client


def assert_request(mock_call, expected_method, expected_path, expected_params=None, expected_json=None):
    """Helper function to verify request details in mock calls."""
    call_args = mock_call.call_args

    if expected_method and expected_path:
        if isinstance(call_args[0], tuple) and len(call_args[0]) == 2:
            assert call_args[0] == (expected_method, expected_path)
        else:
            assert call_args[1]["method"] == expected_method
            assert call_args[1]["path"] == expected_path

    if expected_params:
        assert call_args[1]["params"] == expected_params

    if expected_json:
        for key, value in expected_json.items():
            assert call_args[1]["json"][key] == value


async def test_supervised_fine_tune(nvidia_adapters):
    """Test the supervised fine-tuning API call."""
    adapter, inference_adapter, mock_make_request, mock_client = nvidia_adapters

    mock_make_request.return_value = {
        "id": "cust-JGTaMbJMdqjJU8WbQdN9Q2",
        "created_at": "2024-12-09T04:06:28.542884",
        "updated_at": "2024-12-09T04:06:28.542884",
        "config": {
            "schema_version": "1.0",
            "id": "af783f5b-d985-4e5b-bbb7-f9eec39cc0b1",
            "created_at": "2024-12-09T04:06:28.542657",
            "updated_at": "2024-12-09T04:06:28.569837",
            "custom_fields": {},
            "name": "meta-llama/Llama-3.1-8B-Instruct",
            "base_model": "meta-llama/Llama-3.1-8B-Instruct",
            "model_path": "llama-3_1-8b-instruct",
            "training_types": [],
            "finetuning_types": ["lora"],
            "precision": "bf16",
            "num_gpus": 4,
            "num_nodes": 1,
            "micro_batch_size": 1,
            "tensor_parallel_size": 1,
            "max_seq_length": 4096,
        },
        "dataset": {
            "schema_version": "1.0",
            "id": "dataset-XU4pvGzr5tvawnbVxeJMTb",
            "created_at": "2024-12-09T04:06:28.542657",
            "updated_at": "2024-12-09T04:06:28.542660",
            "custom_fields": {},
            "name": "sample-basic-test",
            "version_id": "main",
            "version_tags": [],
        },
        "hyperparameters": {
            "finetuning_type": "lora",
            "training_type": "sft",
            "batch_size": 16,
            "epochs": 2,
            "learning_rate": 0.0001,
            "lora": {"alpha": 16},
        },
        "output_model": "default/job-1234",
        "status": "created",
        "project": "default",
        "custom_fields": {},
        "ownership": {"created_by": "me", "access_policies": {}},
    }

    algorithm_config = LoraFinetuningConfig(
        type="LoRA",
        apply_lora_to_mlp=True,
        apply_lora_to_output=True,
        alpha=16,
        rank=16,
        lora_attn_modules=["q_proj", "k_proj", "v_proj", "o_proj"],
    )

    data_config = DataConfig(
        dataset_id="sample-basic-test", batch_size=16, shuffle=False, data_format=DatasetFormat.instruct
    )

    optimizer_config = OptimizerConfig(
        optimizer_type=OptimizerType.adam,
        lr=0.0001,
        weight_decay=0.01,
        num_warmup_steps=100,
    )

    training_config = TrainingConfig(
        n_epochs=2,
        data_config=data_config,
        optimizer_config=optimizer_config,
    )

    with warnings.catch_warnings(record=True):
        warnings.simplefilter("always")
        training_job = await adapter.supervised_fine_tune(
            job_uuid="1234",
            model="meta/llama-3.2-1b-instruct@v1.0.0+L40",
            checkpoint_dir="",
            algorithm_config=algorithm_config,
            training_config=convert_pydantic_to_json_value(training_config),
            logger_config={},
            hyperparam_search_config={},
        )

    # check the output is a PostTrainingJob
    assert isinstance(training_job, NvidiaPostTrainingJob)
    assert training_job.job_uuid == "cust-JGTaMbJMdqjJU8WbQdN9Q2"

    mock_make_request.assert_called_once()
    assert_request(
        mock_make_request,
        "POST",
        "/v1/customization/jobs",
        expected_json={
            "config": "meta/llama-3.2-1b-instruct@v1.0.0+L40",
            "dataset": {"name": "sample-basic-test", "namespace": "default"},
            "hyperparameters": {
                "training_type": "sft",
                "finetuning_type": "lora",
                "epochs": 2,
                "batch_size": 16,
                "learning_rate": 0.0001,
                "weight_decay": 0.01,
                "lora": {"alpha": 16},
            },
        },
    )


async def test_supervised_fine_tune_with_qat(nvidia_adapters):
    adapter, inference_adapter, mock_make_request, mock_client = nvidia_adapters

    algorithm_config = QATFinetuningConfig(type="QAT", quantizer_name="quantizer_name", group_size=1)
    data_config = DataConfig(
        dataset_id="sample-basic-test", batch_size=16, shuffle=False, data_format=DatasetFormat.instruct
    )
    optimizer_config = OptimizerConfig(
        optimizer_type=OptimizerType.adam,
        lr=0.0001,
        weight_decay=0.01,
        num_warmup_steps=100,
    )
    training_config = TrainingConfig(
        n_epochs=2,
        data_config=data_config,
        optimizer_config=optimizer_config,
    )
    # This will raise NotImplementedError since QAT is not supported
    with pytest.raises(NotImplementedError):
        await adapter.supervised_fine_tune(
            job_uuid="1234",
            model="meta/llama-3.2-1b-instruct@v1.0.0+L40",
            checkpoint_dir="",
            algorithm_config=algorithm_config,
            training_config=convert_pydantic_to_json_value(training_config),
            logger_config={},
            hyperparam_search_config={},
        )


@pytest.mark.parametrize(
    "customizer_status,expected_status",
    [
        ("running", "in_progress"),
        ("completed", "completed"),
        ("failed", "failed"),
        ("cancelled", "cancelled"),
        ("pending", "scheduled"),
        ("unknown", "scheduled"),
    ],
)
async def test_get_training_job_status(nvidia_adapters, customizer_status, expected_status):
    adapter, inference_adapter, mock_make_request, mock_client = nvidia_adapters

    mock_make_request.return_value = {
        "created_at": "2024-12-09T04:06:28.580220",
        "updated_at": "2024-12-09T04:21:19.852832",
        "status": customizer_status,
        "steps_completed": 1210,
        "epochs_completed": 2,
        "percentage_done": 100.0,
        "best_epoch": 2,
        "train_loss": 1.718016266822815,
        "val_loss": 1.8661999702453613,
    }

    job_id = "cust-JGTaMbJMdqjJU8WbQdN9Q2"

    status = await adapter.get_training_job_status(job_uuid=job_id)

    assert isinstance(status, NvidiaPostTrainingJobStatusResponse)
    assert status.status.value == expected_status
    assert status.steps_completed == 1210
    assert status.epochs_completed == 2
    assert status.percentage_done == 100.0
    assert status.best_epoch == 2
    assert status.train_loss == 1.718016266822815
    assert status.val_loss == 1.8661999702453613

    assert_request(
        mock_make_request,
        "GET",
        f"/v1/customization/jobs/{job_id}/status",
        expected_params={"job_id": job_id},
    )


async def test_get_training_jobs(nvidia_adapters):
    adapter, inference_adapter, mock_make_request, mock_client = nvidia_adapters

    job_id = "cust-JGTaMbJMdqjJU8WbQdN9Q2"
    mock_make_request.return_value = {
        "data": [
            {
                "id": job_id,
                "created_at": "2024-12-09T04:06:28.542884",
                "updated_at": "2024-12-09T04:21:19.852832",
                "config": {
                    "name": "meta-llama/Llama-3.1-8B-Instruct",
                    "base_model": "meta-llama/Llama-3.1-8B-Instruct",
                },
                "dataset": {"name": "default/sample-basic-test"},
                "hyperparameters": {
                    "finetuning_type": "lora",
                    "training_type": "sft",
                    "batch_size": 16,
                    "epochs": 2,
                    "learning_rate": 0.0001,
                    "lora": {"adapter_dim": 16, "adapter_dropout": 0.1},
                },
                "output_model": "default/job-1234",
                "status": "completed",
                "project": "default",
            }
        ]
    }

    jobs = await adapter.get_training_jobs()

    assert isinstance(jobs, ListNvidiaPostTrainingJobs)
    assert len(jobs.data) == 1
    job = jobs.data[0]
    assert job.job_uuid == job_id
    assert job.status.value == "completed"

    mock_make_request.assert_called_once()
    assert_request(
        mock_make_request,
        "GET",
        "/v1/customization/jobs",
        expected_params={"page": 1, "page_size": 10, "sort": "created_at"},
    )


async def test_cancel_training_job(nvidia_adapters):
    adapter, inference_adapter, mock_make_request, mock_client = nvidia_adapters

    mock_make_request.return_value = {}  # Empty response for successful cancellation
    job_id = "cust-JGTaMbJMdqjJU8WbQdN9Q2"

    result = await adapter.cancel_training_job(job_uuid=job_id)

    assert result is None

    mock_make_request.assert_called_once()
    assert_request(
        mock_make_request,
        "POST",
        f"/v1/customization/jobs/{job_id}/cancel",
        expected_params={"job_id": job_id},
    )


async def test_inference_register_model(nvidia_adapters):
    adapter, inference_adapter, mock_make_request, mock_client = nvidia_adapters

    model_id = "default/job-1234"
    model_type = ModelType.llm
    model = Model(
        identifier=model_id,
        provider_id="nvidia",
        provider_model_id=model_id,
        provider_resource_id=model_id,
        model_type=model_type,
    )
    result = await inference_adapter.register_model(model)
    assert result == model
    assert len(inference_adapter.alias_to_provider_id_map) > 1
    assert inference_adapter.get_provider_model_id(model.provider_model_id) == model_id

    with patch.object(inference_adapter, "chat_completion") as mock_chat_completion:
        await inference_adapter.chat_completion(
            model_id=model_id,
            messages=[{"role": "user", "content": "Hello, model"}],
        )

        mock_chat_completion.assert_called()
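A note on the fixture style above: `nvidia_adapters` is a yield fixture, so the stacked `unittest.mock.patch` context managers stay active while each test body runs and are unwound automatically when the fixture resumes after `yield`, replacing the old `setUp`/`tearDown` patcher bookkeeping. A minimal, self-contained sketch of that pattern with made-up names (it patches `os.getcwd` purely for illustration):

import os
from unittest.mock import patch

import pytest


@pytest.fixture
def fake_cwd():
    # The patch is entered before the test body runs and reverted right after it,
    # mirroring how nvidia_adapters keeps its patches alive for the whole test.
    with patch("os.getcwd", return_value="/tmp/fake") as mock_getcwd:
        yield mock_getcwd


def test_uses_fake_cwd(fake_cwd):
    assert os.getcwd() == "/tmp/fake"
    fake_cwd.assert_called_once()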


@@ -5,73 +5,91 @@
# the root directory of this source tree.

import os

import pytest

from llama_stack.distribution.stack import replace_env_vars


@pytest.fixture(autouse=True)
def setup_env_vars():
    """Set up test environment variables and clean up after test"""
    # Store original values
    original_vars = {}
    for var in ["TEST_VAR", "EMPTY_VAR", "ZERO_VAR"]:
        original_vars[var] = os.environ.get(var)
        if var in os.environ:
            del os.environ[var]

    # Set up test environment variables
    os.environ["TEST_VAR"] = "test_value"
    os.environ["EMPTY_VAR"] = ""
    os.environ["ZERO_VAR"] = "0"

    yield

    # Cleanup: restore original values
    for var, original_value in original_vars.items():
        if original_value is not None:
            os.environ[var] = original_value
        elif var in os.environ:
            del os.environ[var]


def test_simple_replacement():
    assert replace_env_vars("${env.TEST_VAR}") == "test_value"


def test_default_value_when_not_set():
    assert replace_env_vars("${env.NOT_SET:=default}") == "default"


def test_default_value_when_set():
    assert replace_env_vars("${env.TEST_VAR:=default}") == "test_value"


def test_default_value_when_empty():
    assert replace_env_vars("${env.EMPTY_VAR:=default}") == "default"


def test_none_value_when_empty():
    assert replace_env_vars("${env.EMPTY_VAR:=}") is None


def test_value_when_set():
    assert replace_env_vars("${env.TEST_VAR:=}") == "test_value"


def test_empty_var_no_default():
    assert replace_env_vars("${env.EMPTY_VAR_NO_DEFAULT:+}") is None


def test_conditional_value_when_set():
    assert replace_env_vars("${env.TEST_VAR:+conditional}") == "conditional"


def test_conditional_value_when_not_set():
    assert replace_env_vars("${env.NOT_SET:+conditional}") is None


def test_conditional_value_when_empty():
    assert replace_env_vars("${env.EMPTY_VAR:+conditional}") is None


def test_conditional_value_with_zero():
    assert replace_env_vars("${env.ZERO_VAR:+conditional}") == "conditional"


def test_mixed_syntax():
    assert replace_env_vars("${env.TEST_VAR:=default} and ${env.NOT_SET:+conditional}") == "test_value and "
    assert replace_env_vars("${env.NOT_SET:=default} and ${env.TEST_VAR:+conditional}") == "default and conditional"


def test_nested_structures():
    data = {
        "key1": "${env.TEST_VAR:=default}",
        "key2": ["${env.NOT_SET:=default}", "${env.TEST_VAR:+conditional}"],
        "key3": {"nested": "${env.NOT_SET:+conditional}"},
    }
    expected = {"key1": "test_value", "key2": ["default", "conditional"], "key3": {"nested": None}}
    assert replace_env_vars(data) == expected