From 933125389457b0f56b33ba63c6e5a43f375051e5 Mon Sep 17 00:00:00 2001 From: Mustafa Elbehery Date: Wed, 9 Jul 2025 12:55:16 +0200 Subject: [PATCH] test: migrate unit tests from unittest to pytest Signed-off-by: Mustafa Elbehery --- tests/unit/models/test_prompt_adapter.py | 489 +++++++------- tests/unit/models/test_system_prompts.py | 113 ++-- tests/unit/providers/nvidia/test_datastore.py | 206 +++--- tests/unit/providers/nvidia/test_eval.py | 312 ++++----- .../unit/providers/nvidia/test_parameters.py | 451 +++++++------ tests/unit/providers/nvidia/test_safety.py | 538 ++++++++-------- .../nvidia/test_supervised_fine_tuning.py | 596 +++++++++--------- tests/unit/server/test_replace_env_vars.py | 140 ++-- 8 files changed, 1440 insertions(+), 1405 deletions(-) diff --git a/tests/unit/models/test_prompt_adapter.py b/tests/unit/models/test_prompt_adapter.py index 0e2780e50..3a4dae491 100644 --- a/tests/unit/models/test_prompt_adapter.py +++ b/tests/unit/models/test_prompt_adapter.py @@ -4,8 +4,6 @@ # This source code is licensed under the terms described in the LICENSE file in # the root directory of this source tree. -import asyncio -import unittest from llama_stack.apis.inference import ( ChatCompletionRequest, @@ -31,258 +29,259 @@ MODEL = "Llama3.1-8B-Instruct" MODEL3_2 = "Llama3.2-3B-Instruct" -class PrepareMessagesTests(unittest.IsolatedAsyncioTestCase): - async def asyncSetUp(self): - asyncio.get_running_loop().set_debug(False) +async def test_system_default(): + content = "Hello !" + request = ChatCompletionRequest( + model=MODEL, + messages=[ + UserMessage(content=content), + ], + ) + messages = chat_completion_request_to_messages(request, MODEL) + assert len(messages) == 2 + assert messages[-1].content == content + assert "Cutting Knowledge Date: December 2023" in messages[0].content - async def test_system_default(self): - content = "Hello !" - request = ChatCompletionRequest( - model=MODEL, - messages=[ - UserMessage(content=content), - ], - ) - messages = chat_completion_request_to_messages(request, MODEL) - self.assertEqual(len(messages), 2) - self.assertEqual(messages[-1].content, content) - self.assertTrue("Cutting Knowledge Date: December 2023" in messages[0].content) - async def test_system_builtin_only(self): - content = "Hello !" - request = ChatCompletionRequest( - model=MODEL, - messages=[ - UserMessage(content=content), - ], - tools=[ - ToolDefinition(tool_name=BuiltinTool.code_interpreter), - ToolDefinition(tool_name=BuiltinTool.brave_search), - ], - ) - messages = chat_completion_request_to_messages(request, MODEL) - self.assertEqual(len(messages), 2) - self.assertEqual(messages[-1].content, content) - self.assertTrue("Cutting Knowledge Date: December 2023" in messages[0].content) - self.assertTrue("Tools: brave_search" in messages[0].content) +async def test_system_builtin_only(): + content = "Hello !" + request = ChatCompletionRequest( + model=MODEL, + messages=[ + UserMessage(content=content), + ], + tools=[ + ToolDefinition(tool_name=BuiltinTool.code_interpreter), + ToolDefinition(tool_name=BuiltinTool.brave_search), + ], + ) + messages = chat_completion_request_to_messages(request, MODEL) + assert len(messages) == 2 + assert messages[-1].content == content + assert "Cutting Knowledge Date: December 2023" in messages[0].content + assert "Tools: brave_search" in messages[0].content - async def test_system_custom_only(self): - content = "Hello !" - request = ChatCompletionRequest( - model=MODEL, - messages=[ - UserMessage(content=content), - ], - tools=[ - ToolDefinition( - tool_name="custom1", - description="custom1 tool", - parameters={ - "param1": ToolParamDefinition( - param_type="str", - description="param1 description", - required=True, - ), - }, - ) - ], - tool_config=ToolConfig(tool_prompt_format=ToolPromptFormat.json), - ) - messages = chat_completion_request_to_messages(request, MODEL) - self.assertEqual(len(messages), 3) - self.assertTrue("Environment: ipython" in messages[0].content) - self.assertTrue("Return function calls in JSON format" in messages[1].content) - self.assertEqual(messages[-1].content, content) +async def test_system_custom_only(): + content = "Hello !" + request = ChatCompletionRequest( + model=MODEL, + messages=[ + UserMessage(content=content), + ], + tools=[ + ToolDefinition( + tool_name="custom1", + description="custom1 tool", + parameters={ + "param1": ToolParamDefinition( + param_type="str", + description="param1 description", + required=True, + ), + }, + ) + ], + tool_config=ToolConfig(tool_prompt_format=ToolPromptFormat.json), + ) + messages = chat_completion_request_to_messages(request, MODEL) + assert len(messages) == 3 + assert "Environment: ipython" in messages[0].content - async def test_system_custom_and_builtin(self): - content = "Hello !" - request = ChatCompletionRequest( - model=MODEL, - messages=[ - UserMessage(content=content), - ], - tools=[ - ToolDefinition(tool_name=BuiltinTool.code_interpreter), - ToolDefinition(tool_name=BuiltinTool.brave_search), - ToolDefinition( - tool_name="custom1", - description="custom1 tool", - parameters={ - "param1": ToolParamDefinition( - param_type="str", - description="param1 description", - required=True, - ), - }, - ), - ], - ) - messages = chat_completion_request_to_messages(request, MODEL) - self.assertEqual(len(messages), 3) + assert "Return function calls in JSON format" in messages[1].content + assert messages[-1].content == content - self.assertTrue("Environment: ipython" in messages[0].content) - self.assertTrue("Tools: brave_search" in messages[0].content) - self.assertTrue("Return function calls in JSON format" in messages[1].content) - self.assertEqual(messages[-1].content, content) - - async def test_completion_message_encoding(self): - request = ChatCompletionRequest( - model=MODEL3_2, - messages=[ - UserMessage(content="hello"), - CompletionMessage( - content="", - stop_reason=StopReason.end_of_turn, - tool_calls=[ - ToolCall( - tool_name="custom1", - arguments={"param1": "value1"}, - call_id="123", - ) - ], - ), - ], - tools=[ - ToolDefinition( - tool_name="custom1", - description="custom1 tool", - parameters={ - "param1": ToolParamDefinition( - param_type="str", - description="param1 description", - required=True, - ), - }, - ), - ], - tool_config=ToolConfig(tool_prompt_format=ToolPromptFormat.python_list), - ) - prompt = await chat_completion_request_to_prompt(request, request.model) - self.assertIn('[custom1(param1="value1")]', prompt) - - request.model = MODEL - request.tool_config.tool_prompt_format = ToolPromptFormat.json - prompt = await chat_completion_request_to_prompt(request, request.model) - self.assertIn( - '{"type": "function", "name": "custom1", "parameters": {"param1": "value1"}}', - prompt, - ) - - async def test_user_provided_system_message(self): - content = "Hello !" - system_prompt = "You are a pirate" - request = ChatCompletionRequest( - model=MODEL, - messages=[ - SystemMessage(content=system_prompt), - UserMessage(content=content), - ], - tools=[ - ToolDefinition(tool_name=BuiltinTool.code_interpreter), - ], - ) - messages = chat_completion_request_to_messages(request, MODEL) - self.assertEqual(len(messages), 2, messages) - self.assertTrue(messages[0].content.endswith(system_prompt)) - - self.assertEqual(messages[-1].content, content) - - async def test_repalce_system_message_behavior_builtin_tools(self): - content = "Hello !" - system_prompt = "You are a pirate" - request = ChatCompletionRequest( - model=MODEL, - messages=[ - SystemMessage(content=system_prompt), - UserMessage(content=content), - ], - tools=[ - ToolDefinition(tool_name=BuiltinTool.code_interpreter), - ], - tool_config=ToolConfig( - tool_choice="auto", - tool_prompt_format="python_list", - system_message_behavior="replace", +async def test_system_custom_and_builtin(): + content = "Hello !" + request = ChatCompletionRequest( + model=MODEL, + messages=[ + UserMessage(content=content), + ], + tools=[ + ToolDefinition(tool_name=BuiltinTool.code_interpreter), + ToolDefinition(tool_name=BuiltinTool.brave_search), + ToolDefinition( + tool_name="custom1", + description="custom1 tool", + parameters={ + "param1": ToolParamDefinition( + param_type="str", + description="param1 description", + required=True, + ), + }, ), - ) - messages = chat_completion_request_to_messages(request, MODEL3_2) - self.assertEqual(len(messages), 2, messages) - self.assertTrue(messages[0].content.endswith(system_prompt)) - self.assertIn("Environment: ipython", messages[0].content) - self.assertEqual(messages[-1].content, content) + ], + ) + messages = chat_completion_request_to_messages(request, MODEL) + assert len(messages) == 3 - async def test_repalce_system_message_behavior_custom_tools(self): - content = "Hello !" - system_prompt = "You are a pirate" - request = ChatCompletionRequest( - model=MODEL, - messages=[ - SystemMessage(content=system_prompt), - UserMessage(content=content), - ], - tools=[ - ToolDefinition(tool_name=BuiltinTool.code_interpreter), - ToolDefinition( - tool_name="custom1", - description="custom1 tool", - parameters={ - "param1": ToolParamDefinition( - param_type="str", - description="param1 description", - required=True, - ), - }, - ), - ], - tool_config=ToolConfig( - tool_choice="auto", - tool_prompt_format="python_list", - system_message_behavior="replace", + assert "Environment: ipython" in messages[0].content + assert "Tools: brave_search" in messages[0].content + + assert "Return function calls in JSON format" in messages[1].content + assert messages[-1].content == content + + +async def test_completion_message_encoding(): + request = ChatCompletionRequest( + model=MODEL3_2, + messages=[ + UserMessage(content="hello"), + CompletionMessage( + content="", + stop_reason=StopReason.end_of_turn, + tool_calls=[ + ToolCall( + tool_name="custom1", + arguments={"param1": "value1"}, + call_id="123", + ) + ], ), - ) - messages = chat_completion_request_to_messages(request, MODEL3_2) - - self.assertEqual(len(messages), 2, messages) - self.assertTrue(messages[0].content.endswith(system_prompt)) - self.assertIn("Environment: ipython", messages[0].content) - self.assertEqual(messages[-1].content, content) - - async def test_replace_system_message_behavior_custom_tools_with_template(self): - content = "Hello !" - system_prompt = "You are a pirate {{ function_description }}" - request = ChatCompletionRequest( - model=MODEL, - messages=[ - SystemMessage(content=system_prompt), - UserMessage(content=content), - ], - tools=[ - ToolDefinition(tool_name=BuiltinTool.code_interpreter), - ToolDefinition( - tool_name="custom1", - description="custom1 tool", - parameters={ - "param1": ToolParamDefinition( - param_type="str", - description="param1 description", - required=True, - ), - }, - ), - ], - tool_config=ToolConfig( - tool_choice="auto", - tool_prompt_format="python_list", - system_message_behavior="replace", + ], + tools=[ + ToolDefinition( + tool_name="custom1", + description="custom1 tool", + parameters={ + "param1": ToolParamDefinition( + param_type="str", + description="param1 description", + required=True, + ), + }, ), - ) - messages = chat_completion_request_to_messages(request, MODEL3_2) + ], + tool_config=ToolConfig(tool_prompt_format=ToolPromptFormat.python_list), + ) + prompt = await chat_completion_request_to_prompt(request, request.model) + assert '[custom1(param1="value1")]' in prompt - self.assertEqual(len(messages), 2, messages) - self.assertIn("Environment: ipython", messages[0].content) - self.assertIn("You are a pirate", messages[0].content) - # function description is present in the system prompt - self.assertIn('"name": "custom1"', messages[0].content) - self.assertEqual(messages[-1].content, content) + request.model = MODEL + request.tool_config.tool_prompt_format = ToolPromptFormat.json + prompt = await chat_completion_request_to_prompt(request, request.model) + assert '{"type": "function", "name": "custom1", "parameters": {"param1": "value1"}}' in prompt + + +async def test_user_provided_system_message(): + content = "Hello !" + system_prompt = "You are a pirate" + request = ChatCompletionRequest( + model=MODEL, + messages=[ + SystemMessage(content=system_prompt), + UserMessage(content=content), + ], + tools=[ + ToolDefinition(tool_name=BuiltinTool.code_interpreter), + ], + ) + messages = chat_completion_request_to_messages(request, MODEL) + assert len(messages) == 2 + assert messages[0].content.endswith(system_prompt) + + assert messages[-1].content == content + + +async def test_repalce_system_message_behavior_builtin_tools(): + content = "Hello !" + system_prompt = "You are a pirate" + request = ChatCompletionRequest( + model=MODEL, + messages=[ + SystemMessage(content=system_prompt), + UserMessage(content=content), + ], + tools=[ + ToolDefinition(tool_name=BuiltinTool.code_interpreter), + ], + tool_config=ToolConfig( + tool_choice="auto", + tool_prompt_format="python_list", + system_message_behavior="replace", + ), + ) + messages = chat_completion_request_to_messages(request, MODEL3_2) + assert len(messages) == 2 + assert messages[0].content.endswith(system_prompt) + assert "Environment: ipython" in messages[0].content + assert messages[-1].content == content + + +async def test_repalce_system_message_behavior_custom_tools(): + content = "Hello !" + system_prompt = "You are a pirate" + request = ChatCompletionRequest( + model=MODEL, + messages=[ + SystemMessage(content=system_prompt), + UserMessage(content=content), + ], + tools=[ + ToolDefinition(tool_name=BuiltinTool.code_interpreter), + ToolDefinition( + tool_name="custom1", + description="custom1 tool", + parameters={ + "param1": ToolParamDefinition( + param_type="str", + description="param1 description", + required=True, + ), + }, + ), + ], + tool_config=ToolConfig( + tool_choice="auto", + tool_prompt_format="python_list", + system_message_behavior="replace", + ), + ) + messages = chat_completion_request_to_messages(request, MODEL3_2) + + assert len(messages) == 2 + assert messages[0].content.endswith(system_prompt) + assert "Environment: ipython" in messages[0].content + assert messages[-1].content == content + + +async def test_replace_system_message_behavior_custom_tools_with_template(): + content = "Hello !" + system_prompt = "You are a pirate {{ function_description }}" + request = ChatCompletionRequest( + model=MODEL, + messages=[ + SystemMessage(content=system_prompt), + UserMessage(content=content), + ], + tools=[ + ToolDefinition(tool_name=BuiltinTool.code_interpreter), + ToolDefinition( + tool_name="custom1", + description="custom1 tool", + parameters={ + "param1": ToolParamDefinition( + param_type="str", + description="param1 description", + required=True, + ), + }, + ), + ], + tool_config=ToolConfig( + tool_choice="auto", + tool_prompt_format="python_list", + system_message_behavior="replace", + ), + ) + messages = chat_completion_request_to_messages(request, MODEL3_2) + + assert len(messages) == 2 + assert "Environment: ipython" in messages[0].content + assert "You are a pirate" in messages[0].content + # function description is present in the system prompt + assert '"name": "custom1"' in messages[0].content + assert messages[-1].content == content diff --git a/tests/unit/models/test_system_prompts.py b/tests/unit/models/test_system_prompts.py index 1f4ccc7e3..f5580f4c5 100644 --- a/tests/unit/models/test_system_prompts.py +++ b/tests/unit/models/test_system_prompts.py @@ -12,7 +12,6 @@ # the top-level of this source tree. import textwrap -import unittest from datetime import datetime from llama_stack.models.llama.llama3.prompt_templates import ( @@ -24,59 +23,61 @@ from llama_stack.models.llama.llama3.prompt_templates import ( ) -class PromptTemplateTests(unittest.TestCase): - def check_generator_output(self, generator): - for example in generator.data_examples(): - pt = generator.gen(example) - text = pt.render() - # print(text) # debugging - if not example: - continue - for tool in example: - assert tool.tool_name in text - - def test_system_default(self): - generator = SystemDefaultGenerator() - today = datetime.now().strftime("%d %B %Y") - expected_text = f"Cutting Knowledge Date: December 2023\nToday Date: {today}" - assert expected_text.strip("\n") == generator.gen(generator.data_examples()[0]).render() - - def test_system_builtin_only(self): - generator = BuiltinToolGenerator() - expected_text = textwrap.dedent( - """ - Environment: ipython - Tools: brave_search, wolfram_alpha - """ - ) - assert expected_text.strip("\n") == generator.gen(generator.data_examples()[0]).render() - - def test_system_custom_only(self): - self.maxDiff = None - generator = JsonCustomToolGenerator() - self.check_generator_output(generator) - - def test_system_custom_function_tag(self): - self.maxDiff = None - generator = FunctionTagCustomToolGenerator() - self.check_generator_output(generator) - - def test_llama_3_2_system_zero_shot(self): - generator = PythonListCustomToolGenerator() - self.check_generator_output(generator) - - def test_llama_3_2_provided_system_prompt(self): - generator = PythonListCustomToolGenerator() - user_system_prompt = textwrap.dedent( - """ - Overriding message. - - {{ function_description }} - """ - ) - example = generator.data_examples()[0] - - pt = generator.gen(example, user_system_prompt) +def check_generator_output(generator): + for example in generator.data_examples(): + pt = generator.gen(example) text = pt.render() - assert "Overriding message." in text - assert '"name": "get_weather"' in text + if not example: + continue + for tool in example: + assert tool.tool_name in text + + +def test_system_default(): + generator = SystemDefaultGenerator() + today = datetime.now().strftime("%d %B %Y") + expected_text = f"Cutting Knowledge Date: December 2023\nToday Date: {today}" + assert expected_text.strip("\n") == generator.gen(generator.data_examples()[0]).render() + + +def test_system_builtin_only(): + generator = BuiltinToolGenerator() + expected_text = textwrap.dedent( + """ + Environment: ipython + Tools: brave_search, wolfram_alpha + """ + ) + assert expected_text.strip("\n") == generator.gen(generator.data_examples()[0]).render() + + +def test_system_custom_only(): + generator = JsonCustomToolGenerator() + check_generator_output(generator) + + +def test_system_custom_function_tag(): + generator = FunctionTagCustomToolGenerator() + check_generator_output(generator) + + +def test_llama_3_2_system_zero_shot(): + generator = PythonListCustomToolGenerator() + check_generator_output(generator) + + +def test_llama_3_2_provided_system_prompt(): + generator = PythonListCustomToolGenerator() + user_system_prompt = textwrap.dedent( + """ + Overriding message. + + {{ function_description }} + """ + ) + example = generator.data_examples()[0] + + pt = generator.gen(example, user_system_prompt) + text = pt.render() + assert "Overriding message." in text + assert '"name": "get_weather"' in text diff --git a/tests/unit/providers/nvidia/test_datastore.py b/tests/unit/providers/nvidia/test_datastore.py index a17e51a9c..8f5213325 100644 --- a/tests/unit/providers/nvidia/test_datastore.py +++ b/tests/unit/providers/nvidia/test_datastore.py @@ -5,7 +5,6 @@ # the root directory of this source tree. import os -import unittest from unittest.mock import patch import pytest @@ -15,124 +14,125 @@ from llama_stack.providers.remote.datasetio.nvidia.config import NvidiaDatasetIO from llama_stack.providers.remote.datasetio.nvidia.datasetio import NvidiaDatasetIOAdapter -class TestNvidiaDatastore(unittest.TestCase): - def setUp(self): - os.environ["NVIDIA_DATASETS_URL"] = "http://nemo.test/datasets" +@pytest.fixture +def nvidia_dataset_adapter(): + """Set up the NVIDIA dataset adapter with mocked dependencies""" + os.environ["NVIDIA_DATASETS_URL"] = "http://nemo.test/datasets" - config = NvidiaDatasetIOConfig( - datasets_url=os.environ["NVIDIA_DATASETS_URL"], dataset_namespace="default", project_id="default" - ) - self.adapter = NvidiaDatasetIOAdapter(config) - self.make_request_patcher = patch( - "llama_stack.providers.remote.datasetio.nvidia.datasetio.NvidiaDatasetIOAdapter._make_request" - ) - self.mock_make_request = self.make_request_patcher.start() + config = NvidiaDatasetIOConfig( + datasets_url=os.environ["NVIDIA_DATASETS_URL"], dataset_namespace="default", project_id="default" + ) + adapter = NvidiaDatasetIOAdapter(config) - def tearDown(self): - self.make_request_patcher.stop() + with patch( + "llama_stack.providers.remote.datasetio.nvidia.datasetio.NvidiaDatasetIOAdapter._make_request" + ) as mock_make_request: + yield adapter, mock_make_request - @pytest.fixture(autouse=True) - def inject_fixtures(self, run_async): - self.run_async = run_async - def _assert_request(self, mock_call, expected_method, expected_path, expected_json=None): - """Helper method to verify request details in mock calls.""" - call_args = mock_call.call_args +def assert_request(mock_call, expected_method, expected_path, expected_json=None): + """Helper function to verify request details in mock calls.""" + call_args = mock_call.call_args - assert call_args[0][0] == expected_method - assert call_args[0][1] == expected_path + assert call_args[0][0] == expected_method + assert call_args[0][1] == expected_path - if expected_json: - for key, value in expected_json.items(): - assert call_args[1]["json"][key] == value + if expected_json: + for key, value in expected_json.items(): + assert call_args[1]["json"][key] == value - def test_register_dataset(self): - self.mock_make_request.return_value = { - "id": "dataset-123456", + +async def test_register_dataset(nvidia_dataset_adapter): + adapter, mock_make_request = nvidia_dataset_adapter + + mock_make_request.return_value = { + "id": "dataset-123456", + "name": "test-dataset", + "namespace": "default", + } + + dataset_def = Dataset( + identifier="test-dataset", + type="dataset", + provider_resource_id="", + provider_id="", + purpose=DatasetPurpose.post_training_messages, + source=URIDataSource(uri="https://example.com/data.jsonl"), + metadata={"provider_id": "nvidia", "format": "jsonl", "description": "Test dataset description"}, + ) + + await adapter.register_dataset(dataset_def) + + mock_make_request.assert_called_once() + assert_request( + mock_make_request, + "POST", + "/v1/datasets", + expected_json={ "name": "test-dataset", "namespace": "default", - } + "files_url": "https://example.com/data.jsonl", + "project": "default", + "format": "jsonl", + "description": "Test dataset description", + }, + ) - dataset_def = Dataset( - identifier="test-dataset", - type="dataset", - provider_resource_id="", - provider_id="", - purpose=DatasetPurpose.post_training_messages, - source=URIDataSource(uri="https://example.com/data.jsonl"), - metadata={"provider_id": "nvidia", "format": "jsonl", "description": "Test dataset description"}, - ) - self.run_async(self.adapter.register_dataset(dataset_def)) +async def test_unregister_dataset(nvidia_dataset_adapter): + adapter, mock_make_request = nvidia_dataset_adapter - self.mock_make_request.assert_called_once() - self._assert_request( - self.mock_make_request, - "POST", - "/v1/datasets", - expected_json={ - "name": "test-dataset", - "namespace": "default", - "files_url": "https://example.com/data.jsonl", - "project": "default", - "format": "jsonl", - "description": "Test dataset description", - }, - ) + mock_make_request.return_value = { + "message": "Resource deleted successfully.", + "id": "dataset-81RSQp7FKX3rdBtKvF9Skn", + "deleted_at": None, + } + dataset_id = "test-dataset" - def test_unregister_dataset(self): - self.mock_make_request.return_value = { - "message": "Resource deleted successfully.", - "id": "dataset-81RSQp7FKX3rdBtKvF9Skn", - "deleted_at": None, - } - dataset_id = "test-dataset" + await adapter.unregister_dataset(dataset_id) - self.run_async(self.adapter.unregister_dataset(dataset_id)) + mock_make_request.assert_called_once() + assert_request(mock_make_request, "DELETE", "/v1/datasets/default/test-dataset") - self.mock_make_request.assert_called_once() - self._assert_request(self.mock_make_request, "DELETE", "/v1/datasets/default/test-dataset") - def test_register_dataset_with_custom_namespace_project(self): - custom_config = NvidiaDatasetIOConfig( - datasets_url=os.environ["NVIDIA_DATASETS_URL"], - dataset_namespace="custom-namespace", - project_id="custom-project", - ) - custom_adapter = NvidiaDatasetIOAdapter(custom_config) +async def test_register_dataset_with_custom_namespace_project(nvidia_dataset_adapter): + adapter, mock_make_request = nvidia_dataset_adapter - self.mock_make_request.return_value = { - "id": "dataset-123456", + custom_config = NvidiaDatasetIOConfig( + datasets_url=os.environ["NVIDIA_DATASETS_URL"], + dataset_namespace="custom-namespace", + project_id="custom-project", + ) + custom_adapter = NvidiaDatasetIOAdapter(custom_config) + + mock_make_request.return_value = { + "id": "dataset-123456", + "name": "test-dataset", + "namespace": "custom-namespace", + } + + dataset_def = Dataset( + identifier="test-dataset", + type="dataset", + provider_resource_id="", + provider_id="", + purpose=DatasetPurpose.post_training_messages, + source=URIDataSource(uri="https://example.com/data.jsonl"), + metadata={"format": "jsonl"}, + ) + + await custom_adapter.register_dataset(dataset_def) + + mock_make_request.assert_called_once() + assert_request( + mock_make_request, + "POST", + "/v1/datasets", + expected_json={ "name": "test-dataset", "namespace": "custom-namespace", - } - - dataset_def = Dataset( - identifier="test-dataset", - type="dataset", - provider_resource_id="", - provider_id="", - purpose=DatasetPurpose.post_training_messages, - source=URIDataSource(uri="https://example.com/data.jsonl"), - metadata={"format": "jsonl"}, - ) - - self.run_async(custom_adapter.register_dataset(dataset_def)) - - self.mock_make_request.assert_called_once() - self._assert_request( - self.mock_make_request, - "POST", - "/v1/datasets", - expected_json={ - "name": "test-dataset", - "namespace": "custom-namespace", - "files_url": "https://example.com/data.jsonl", - "project": "custom-project", - "format": "jsonl", - }, - ) - - -if __name__ == "__main__": - unittest.main() + "files_url": "https://example.com/data.jsonl", + "project": "custom-project", + "format": "jsonl", + }, + ) diff --git a/tests/unit/providers/nvidia/test_eval.py b/tests/unit/providers/nvidia/test_eval.py index 584ca2101..9adf976e2 100644 --- a/tests/unit/providers/nvidia/test_eval.py +++ b/tests/unit/providers/nvidia/test_eval.py @@ -5,7 +5,6 @@ # the root directory of this source tree. import os -import unittest from unittest.mock import MagicMock, patch import pytest @@ -21,181 +20,186 @@ MOCK_DATASET_ID = "default/test-dataset" MOCK_BENCHMARK_ID = "test-benchmark" -class TestNVIDIAEvalImpl(unittest.TestCase): - def setUp(self): - os.environ["NVIDIA_EVALUATOR_URL"] = "http://nemo.test" +@pytest.fixture +def nvidia_eval_impl(): + """Set up the NVIDIA eval implementation with mocked dependencies""" + os.environ["NVIDIA_EVALUATOR_URL"] = "http://nemo.test" - # Create mock APIs - self.datasetio_api = MagicMock() - self.datasets_api = MagicMock() - self.scoring_api = MagicMock() - self.inference_api = MagicMock() - self.agents_api = MagicMock() + # Create mock APIs + datasetio_api = MagicMock() + datasets_api = MagicMock() + scoring_api = MagicMock() + inference_api = MagicMock() + agents_api = MagicMock() - self.config = NVIDIAEvalConfig( - evaluator_url=os.environ["NVIDIA_EVALUATOR_URL"], - ) + config = NVIDIAEvalConfig( + evaluator_url=os.environ["NVIDIA_EVALUATOR_URL"], + ) - self.eval_impl = NVIDIAEvalImpl( - config=self.config, - datasetio_api=self.datasetio_api, - datasets_api=self.datasets_api, - scoring_api=self.scoring_api, - inference_api=self.inference_api, - agents_api=self.agents_api, - ) + eval_impl = NVIDIAEvalImpl( + config=config, + datasetio_api=datasetio_api, + datasets_api=datasets_api, + scoring_api=scoring_api, + inference_api=inference_api, + agents_api=agents_api, + ) - # Mock the HTTP request methods - self.evaluator_get_patcher = patch( - "llama_stack.providers.remote.eval.nvidia.eval.NVIDIAEvalImpl._evaluator_get" - ) - self.evaluator_post_patcher = patch( - "llama_stack.providers.remote.eval.nvidia.eval.NVIDIAEvalImpl._evaluator_post" - ) + # Mock the HTTP request methods + with ( + patch("llama_stack.providers.remote.eval.nvidia.eval.NVIDIAEvalImpl._evaluator_get") as mock_evaluator_get, + patch("llama_stack.providers.remote.eval.nvidia.eval.NVIDIAEvalImpl._evaluator_post") as mock_evaluator_post, + ): + yield eval_impl, mock_evaluator_get, mock_evaluator_post - self.mock_evaluator_get = self.evaluator_get_patcher.start() - self.mock_evaluator_post = self.evaluator_post_patcher.start() - def tearDown(self): - """Clean up after each test.""" - self.evaluator_get_patcher.stop() - self.evaluator_post_patcher.stop() +def assert_request_body(mock_evaluator_post, expected_json): + """Helper function to verify request body in Evaluator POST request is correct""" + call_args = mock_evaluator_post.call_args + actual_json = call_args[0][1] - def _assert_request_body(self, expected_json): - """Helper method to verify request body in Evaluator POST request is correct""" - call_args = self.mock_evaluator_post.call_args - actual_json = call_args[0][1] + # Check that all expected keys contain the expected values in the actual JSON + for key, value in expected_json.items(): + assert key in actual_json, f"Key '{key}' missing in actual JSON" - # Check that all expected keys contain the expected values in the actual JSON - for key, value in expected_json.items(): - assert key in actual_json, f"Key '{key}' missing in actual JSON" + if isinstance(value, dict): + for nested_key, nested_value in value.items(): + assert nested_key in actual_json[key], f"Nested key '{nested_key}' missing in actual JSON['{key}']" + assert actual_json[key][nested_key] == nested_value, f"Value mismatch for '{key}.{nested_key}'" + else: + assert actual_json[key] == value, f"Value mismatch for '{key}'" - if isinstance(value, dict): - for nested_key, nested_value in value.items(): - assert nested_key in actual_json[key], f"Nested key '{nested_key}' missing in actual JSON['{key}']" - assert actual_json[key][nested_key] == nested_value, f"Value mismatch for '{key}.{nested_key}'" - else: - assert actual_json[key] == value, f"Value mismatch for '{key}'" - @pytest.fixture(autouse=True) - def inject_fixtures(self, run_async): - self.run_async = run_async +async def test_register_benchmark(nvidia_eval_impl): + eval_impl, mock_evaluator_get, mock_evaluator_post = nvidia_eval_impl - def test_register_benchmark(self): - eval_config = { - "type": "custom", - "params": {"parallelism": 8}, - "tasks": { - "qa": { - "type": "completion", - "params": {"template": {"prompt": "{{prompt}}", "max_tokens": 200}}, - "dataset": {"files_url": f"hf://datasets/{MOCK_DATASET_ID}/testing/testing.jsonl"}, - "metrics": {"bleu": {"type": "bleu", "params": {"references": ["{{ideal_response}}"]}}}, - } - }, - } - - benchmark = Benchmark( - provider_id="nvidia", - type="benchmark", - identifier=MOCK_BENCHMARK_ID, - dataset_id=MOCK_DATASET_ID, - scoring_functions=["basic::equality"], - metadata=eval_config, - ) - - # Mock Evaluator API response - mock_evaluator_response = {"id": MOCK_BENCHMARK_ID, "status": "created"} - self.mock_evaluator_post.return_value = mock_evaluator_response - - # Register the benchmark - self.run_async(self.eval_impl.register_benchmark(benchmark)) - - # Verify the Evaluator API was called correctly - self.mock_evaluator_post.assert_called_once() - self._assert_request_body({"namespace": benchmark.provider_id, "name": benchmark.identifier, **eval_config}) - - def test_run_eval(self): - benchmark_config = BenchmarkConfig( - eval_candidate=ModelCandidate( - type="model", - model=CoreModelId.llama3_1_8b_instruct.value, - sampling_params=SamplingParams(max_tokens=100, temperature=0.7), - ) - ) - - # Mock Evaluator API response - mock_evaluator_response = {"id": "job-123", "status": "created"} - self.mock_evaluator_post.return_value = mock_evaluator_response - - # Run the Evaluation job - result = self.run_async( - self.eval_impl.run_eval(benchmark_id=MOCK_BENCHMARK_ID, benchmark_config=benchmark_config) - ) - - # Verify the Evaluator API was called correctly - self.mock_evaluator_post.assert_called_once() - self._assert_request_body( - { - "config": f"nvidia/{MOCK_BENCHMARK_ID}", - "target": {"type": "model", "model": "meta/llama-3.1-8b-instruct"}, + eval_config = { + "type": "custom", + "params": {"parallelism": 8}, + "tasks": { + "qa": { + "type": "completion", + "params": {"template": {"prompt": "{{prompt}}", "max_tokens": 200}}, + "dataset": {"files_url": f"hf://datasets/{MOCK_DATASET_ID}/testing/testing.jsonl"}, + "metrics": {"bleu": {"type": "bleu", "params": {"references": ["{{ideal_response}}"]}}}, } + }, + } + + benchmark = Benchmark( + provider_id="nvidia", + type="benchmark", + identifier=MOCK_BENCHMARK_ID, + dataset_id=MOCK_DATASET_ID, + scoring_functions=["basic::equality"], + metadata=eval_config, + ) + + # Mock Evaluator API response + mock_evaluator_response = {"id": MOCK_BENCHMARK_ID, "status": "created"} + mock_evaluator_post.return_value = mock_evaluator_response + + # Register the benchmark + await eval_impl.register_benchmark(benchmark) + + # Verify the Evaluator API was called correctly + mock_evaluator_post.assert_called_once() + assert_request_body( + mock_evaluator_post, {"namespace": benchmark.provider_id, "name": benchmark.identifier, **eval_config} + ) + + +async def test_run_eval(nvidia_eval_impl): + eval_impl, mock_evaluator_get, mock_evaluator_post = nvidia_eval_impl + + benchmark_config = BenchmarkConfig( + eval_candidate=ModelCandidate( + type="model", + model=CoreModelId.llama3_1_8b_instruct.value, + sampling_params=SamplingParams(max_tokens=100, temperature=0.7), ) + ) - # Verify the result - assert isinstance(result, Job) - assert result.job_id == "job-123" - assert result.status == JobStatus.in_progress + # Mock Evaluator API response + mock_evaluator_response = {"id": "job-123", "status": "created"} + mock_evaluator_post.return_value = mock_evaluator_response - def test_job_status(self): - # Mock Evaluator API response - mock_evaluator_response = {"id": "job-123", "status": "completed"} - self.mock_evaluator_get.return_value = mock_evaluator_response + # Run the Evaluation job + result = await eval_impl.run_eval(benchmark_id=MOCK_BENCHMARK_ID, benchmark_config=benchmark_config) - # Get the Evaluation job - result = self.run_async(self.eval_impl.job_status(benchmark_id=MOCK_BENCHMARK_ID, job_id="job-123")) + # Verify the Evaluator API was called correctly + mock_evaluator_post.assert_called_once() + assert_request_body( + mock_evaluator_post, + { + "config": f"nvidia/{MOCK_BENCHMARK_ID}", + "target": {"type": "model", "model": "meta/llama-3.1-8b-instruct"}, + }, + ) - # Verify the result - assert isinstance(result, Job) - assert result.job_id == "job-123" - assert result.status == JobStatus.completed + # Verify the result + assert isinstance(result, Job) + assert result.job_id == "job-123" + assert result.status == JobStatus.in_progress - # Verify the API was called correctly - self.mock_evaluator_get.assert_called_once_with(f"/v1/evaluation/jobs/{result.job_id}") - def test_job_cancel(self): - # Mock Evaluator API response - mock_evaluator_response = {"id": "job-123", "status": "cancelled"} - self.mock_evaluator_post.return_value = mock_evaluator_response +async def test_job_status(nvidia_eval_impl): + eval_impl, mock_evaluator_get, mock_evaluator_post = nvidia_eval_impl - # Cancel the Evaluation job - self.run_async(self.eval_impl.job_cancel(benchmark_id=MOCK_BENCHMARK_ID, job_id="job-123")) + # Mock Evaluator API response + mock_evaluator_response = {"id": "job-123", "status": "completed"} + mock_evaluator_get.return_value = mock_evaluator_response - # Verify the API was called correctly - self.mock_evaluator_post.assert_called_once_with("/v1/evaluation/jobs/job-123/cancel", {}) + # Get the Evaluation job + result = await eval_impl.job_status(benchmark_id=MOCK_BENCHMARK_ID, job_id="job-123") - def test_job_result(self): - # Mock Evaluator API responses - mock_job_status_response = {"id": "job-123", "status": "completed"} - mock_job_results_response = { - "id": "job-123", - "status": "completed", - "results": {MOCK_BENCHMARK_ID: {"score": 0.85, "details": {"accuracy": 0.85, "f1": 0.84}}}, - } - self.mock_evaluator_get.side_effect = [ - mock_job_status_response, # First call to retrieve job - mock_job_results_response, # Second call to retrieve job results - ] + # Verify the result + assert isinstance(result, Job) + assert result.job_id == "job-123" + assert result.status == JobStatus.completed - # Get the Evaluation job results - result = self.run_async(self.eval_impl.job_result(benchmark_id=MOCK_BENCHMARK_ID, job_id="job-123")) + # Verify the API was called correctly + mock_evaluator_get.assert_called_once_with(f"/v1/evaluation/jobs/{result.job_id}") - # Verify the result - assert isinstance(result, EvaluateResponse) - assert MOCK_BENCHMARK_ID in result.scores - assert result.scores[MOCK_BENCHMARK_ID].aggregated_results["results"][MOCK_BENCHMARK_ID]["score"] == 0.85 - # Verify the API was called correctly - assert self.mock_evaluator_get.call_count == 2 - self.mock_evaluator_get.assert_any_call("/v1/evaluation/jobs/job-123") - self.mock_evaluator_get.assert_any_call("/v1/evaluation/jobs/job-123/results") +async def test_job_cancel(nvidia_eval_impl): + eval_impl, mock_evaluator_get, mock_evaluator_post = nvidia_eval_impl + + # Mock Evaluator API response + mock_evaluator_response = {"id": "job-123", "status": "cancelled"} + mock_evaluator_post.return_value = mock_evaluator_response + + # Cancel the Evaluation job + await eval_impl.job_cancel(benchmark_id=MOCK_BENCHMARK_ID, job_id="job-123") + + # Verify the API was called correctly + mock_evaluator_post.assert_called_once_with("/v1/evaluation/jobs/job-123/cancel", {}) + + +async def test_job_result(nvidia_eval_impl): + eval_impl, mock_evaluator_get, mock_evaluator_post = nvidia_eval_impl + + # Mock Evaluator API responses + mock_job_status_response = {"id": "job-123", "status": "completed"} + mock_job_results_response = { + "id": "job-123", + "status": "completed", + "results": {MOCK_BENCHMARK_ID: {"score": 0.85, "details": {"accuracy": 0.85, "f1": 0.84}}}, + } + mock_evaluator_get.side_effect = [ + mock_job_status_response, # First call to retrieve job + mock_job_results_response, # Second call to retrieve job results + ] + + # Get the Evaluation job results + result = await eval_impl.job_result(benchmark_id=MOCK_BENCHMARK_ID, job_id="job-123") + + # Verify the result + assert isinstance(result, EvaluateResponse) + assert MOCK_BENCHMARK_ID in result.scores + assert result.scores[MOCK_BENCHMARK_ID].aggregated_results["results"][MOCK_BENCHMARK_ID]["score"] == 0.85 + + # Verify the API was called correctly + assert mock_evaluator_get.call_count == 2 + mock_evaluator_get.assert_any_call("/v1/evaluation/jobs/job-123") + mock_evaluator_get.assert_any_call("/v1/evaluation/jobs/job-123/results") diff --git a/tests/unit/providers/nvidia/test_parameters.py b/tests/unit/providers/nvidia/test_parameters.py index cc33f7609..3ed1b4a95 100644 --- a/tests/unit/providers/nvidia/test_parameters.py +++ b/tests/unit/providers/nvidia/test_parameters.py @@ -5,7 +5,6 @@ # the root directory of this source tree. import os -import unittest import warnings from unittest.mock import patch @@ -27,253 +26,249 @@ from llama_stack.providers.remote.post_training.nvidia.post_training import ( ) -class TestNvidiaParameters(unittest.TestCase): - def setUp(self): - os.environ["NVIDIA_BASE_URL"] = "http://nemo.test" - os.environ["NVIDIA_CUSTOMIZER_URL"] = "http://nemo.test" +@pytest.fixture +def nvidia_adapter(): + """Set up the NVIDIA adapter with mock configuration""" + os.environ["NVIDIA_BASE_URL"] = "http://nemo.test" + os.environ["NVIDIA_CUSTOMIZER_URL"] = "http://nemo.test" - config = NvidiaPostTrainingConfig( - base_url=os.environ["NVIDIA_BASE_URL"], customizer_url=os.environ["NVIDIA_CUSTOMIZER_URL"], api_key=None - ) - self.adapter = NvidiaPostTrainingAdapter(config) + config = NvidiaPostTrainingConfig( + base_url=os.environ["NVIDIA_BASE_URL"], customizer_url=os.environ["NVIDIA_CUSTOMIZER_URL"], api_key=None + ) + adapter = NvidiaPostTrainingAdapter(config) - self.make_request_patcher = patch( - "llama_stack.providers.remote.post_training.nvidia.post_training.NvidiaPostTrainingAdapter._make_request" - ) - self.mock_make_request = self.make_request_patcher.start() - self.mock_make_request.return_value = { + # Mock the _make_request method + with patch( + "llama_stack.providers.remote.post_training.nvidia.post_training.NvidiaPostTrainingAdapter._make_request" + ) as mock_make_request: + mock_make_request.return_value = { "id": "job-123", "status": "created", "created_at": "2025-03-04T13:07:47.543605", "updated_at": "2025-03-04T13:07:47.543605", } + yield adapter, mock_make_request - def tearDown(self): - self.make_request_patcher.stop() - def _assert_request_params(self, expected_json): - """Helper method to verify parameters in the request JSON.""" - call_args = self.mock_make_request.call_args - actual_json = call_args[1]["json"] +def assert_request_params(mock_make_request, expected_json): + """Helper function to verify parameters in the request JSON.""" + call_args = mock_make_request.call_args + actual_json = call_args[1]["json"] - for key, value in expected_json.items(): - if isinstance(value, dict): - for nested_key, nested_value in value.items(): - assert actual_json[key][nested_key] == nested_value - else: - assert actual_json[key] == value + for key, value in expected_json.items(): + if isinstance(value, dict): + for nested_key, nested_value in value.items(): + assert actual_json[key][nested_key] == nested_value + else: + assert actual_json[key] == value - @pytest.fixture(autouse=True) - def inject_fixtures(self, run_async): - self.run_async = run_async - def test_customizer_parameters_passed(self): - """Test scenario 1: When an optional parameter is passed and value is correctly set.""" - algorithm_config = LoraFinetuningConfig( - type="LoRA", - apply_lora_to_mlp=True, - apply_lora_to_output=True, - alpha=16, - rank=16, - lora_attn_modules=["q_proj", "k_proj", "v_proj", "o_proj"], +async def test_customizer_parameters_passed(nvidia_adapter): + """Test scenario 1: When an optional parameter is passed and value is correctly set.""" + adapter, mock_make_request = nvidia_adapter + + algorithm_config = LoraFinetuningConfig( + type="LoRA", + apply_lora_to_mlp=True, + apply_lora_to_output=True, + alpha=16, + rank=16, + lora_attn_modules=["q_proj", "k_proj", "v_proj", "o_proj"], + ) + + data_config = DataConfig( + dataset_id="test-dataset", batch_size=16, shuffle=False, data_format=DatasetFormat.instruct + ) + optimizer_config = OptimizerConfig( + optimizer_type=OptimizerType.adam, + lr=0.0002, + weight_decay=0.01, + num_warmup_steps=100, + ) + training_config = TrainingConfig( + n_epochs=3, + data_config=data_config, + optimizer_config=optimizer_config, + ) + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + + await adapter.supervised_fine_tune( + job_uuid="test-job", + model="meta-llama/Llama-3.1-8B-Instruct", + checkpoint_dir="", + algorithm_config=algorithm_config, + training_config=convert_pydantic_to_json_value(training_config), + logger_config={}, + hyperparam_search_config={}, ) - data_config = DataConfig( - dataset_id="test-dataset", batch_size=16, shuffle=False, data_format=DatasetFormat.instruct - ) - optimizer_config = OptimizerConfig( - optimizer_type=OptimizerType.adam, - lr=0.0002, - weight_decay=0.01, - num_warmup_steps=100, - ) - training_config = TrainingConfig( - n_epochs=3, - data_config=data_config, - optimizer_config=optimizer_config, - ) + warning_texts = [str(warning.message) for warning in w] - with warnings.catch_warnings(record=True) as w: - warnings.simplefilter("always") + fields = [ + "apply_lora_to_output", + "lora_attn_modules", + "apply_lora_to_mlp", + ] + for field in fields: + assert any(field in text for text in warning_texts) - self.run_async( - self.adapter.supervised_fine_tune( - job_uuid="test-job", - model="meta-llama/Llama-3.1-8B-Instruct", - checkpoint_dir="", - algorithm_config=algorithm_config, - training_config=convert_pydantic_to_json_value(training_config), - logger_config={}, - hyperparam_search_config={}, - ) - ) - - warning_texts = [str(warning.message) for warning in w] - - fields = [ - "apply_lora_to_output", - "lora_attn_modules", - "apply_lora_to_mlp", - ] - for field in fields: - assert any(field in text for text in warning_texts) - - self._assert_request_params( - { - "hyperparameters": { - "lora": {"alpha": 16}, - "epochs": 3, - "learning_rate": 0.0002, - "batch_size": 16, - } + assert_request_params( + mock_make_request, + { + "hyperparameters": { + "lora": {"alpha": 16}, + "epochs": 3, + "learning_rate": 0.0002, + "batch_size": 16, } + }, + ) + + +async def test_required_parameters_passed(nvidia_adapter): + """Test scenario 2: When required parameters are passed.""" + adapter, mock_make_request = nvidia_adapter + + required_model = "meta/llama-3.2-1b-instruct@v1.0.0+L40" + required_dataset_id = "required-dataset" + required_job_uuid = "required-job" + + algorithm_config = LoraFinetuningConfig( + type="LoRA", + apply_lora_to_mlp=True, + apply_lora_to_output=True, + alpha=16, + rank=16, + lora_attn_modules=["q_proj", "k_proj", "v_proj", "o_proj"], + ) + + data_config = DataConfig( + dataset_id=required_dataset_id, batch_size=8, shuffle=False, data_format=DatasetFormat.instruct + ) + + optimizer_config = OptimizerConfig( + optimizer_type=OptimizerType.adam, + lr=0.0001, + weight_decay=0.01, + num_warmup_steps=100, + ) + + training_config = TrainingConfig( + n_epochs=1, + data_config=data_config, + optimizer_config=optimizer_config, + ) + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + + await adapter.supervised_fine_tune( + job_uuid=required_job_uuid, # Required parameter + model=required_model, # Required parameter + checkpoint_dir="", + algorithm_config=algorithm_config, + training_config=convert_pydantic_to_json_value(training_config), + logger_config={}, + hyperparam_search_config={}, ) - def test_required_parameters_passed(self): - """Test scenario 2: When required parameters are passed.""" - required_model = "meta/llama-3.2-1b-instruct@v1.0.0+L40" - required_dataset_id = "required-dataset" - required_job_uuid = "required-job" + warning_texts = [str(warning.message) for warning in w] - algorithm_config = LoraFinetuningConfig( - type="LoRA", - apply_lora_to_mlp=True, - apply_lora_to_output=True, - alpha=16, - rank=16, - lora_attn_modules=["q_proj", "k_proj", "v_proj", "o_proj"], + fields = [ + "rank", + "apply_lora_to_output", + "lora_attn_modules", + "apply_lora_to_mlp", + ] + for field in fields: + assert any(field in text for text in warning_texts) + + mock_make_request.assert_called_once() + call_args = mock_make_request.call_args + + assert call_args[1]["json"]["config"] == required_model + assert call_args[1]["json"]["dataset"]["name"] == required_dataset_id + + +async def test_unsupported_parameters_warning(nvidia_adapter): + """Test that warnings are raised for unsupported parameters.""" + adapter, mock_make_request = nvidia_adapter + + data_config = DataConfig( + dataset_id="test-dataset", + batch_size=8, + # Unsupported parameters + shuffle=True, + data_format=DatasetFormat.instruct, + validation_dataset_id="val-dataset", + ) + + optimizer_config = OptimizerConfig( + lr=0.0001, + weight_decay=0.01, + # Unsupported parameters + optimizer_type=OptimizerType.adam, + num_warmup_steps=100, + ) + + efficiency_config = EfficiencyConfig( + enable_activation_checkpointing=True # Unsupported parameter + ) + + training_config = TrainingConfig( + n_epochs=1, + data_config=data_config, + optimizer_config=optimizer_config, + # Unsupported parameters + efficiency_config=efficiency_config, + max_steps_per_epoch=1000, + gradient_accumulation_steps=4, + max_validation_steps=100, + dtype="bf16", + ) + + # Capture warnings + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + + await adapter.supervised_fine_tune( + job_uuid="test-job", + model="meta-llama/Llama-3.1-8B-Instruct", + checkpoint_dir="test-dir", # Unsupported parameter + algorithm_config=LoraFinetuningConfig( + type="LoRA", + apply_lora_to_mlp=True, + apply_lora_to_output=True, + alpha=16, + rank=16, + lora_attn_modules=["q_proj", "k_proj", "v_proj", "o_proj"], + ), + training_config=convert_pydantic_to_json_value(training_config), + logger_config={"test": "value"}, # Unsupported parameter + hyperparam_search_config={"test": "value"}, # Unsupported parameter ) - data_config = DataConfig( - dataset_id=required_dataset_id, batch_size=8, shuffle=False, data_format=DatasetFormat.instruct - ) + assert len(w) >= 4 + warning_texts = [str(warning.message) for warning in w] - optimizer_config = OptimizerConfig( - optimizer_type=OptimizerType.adam, - lr=0.0001, - weight_decay=0.01, - num_warmup_steps=100, - ) - - training_config = TrainingConfig( - n_epochs=1, - data_config=data_config, - optimizer_config=optimizer_config, - ) - - with warnings.catch_warnings(record=True) as w: - warnings.simplefilter("always") - - self.run_async( - self.adapter.supervised_fine_tune( - job_uuid=required_job_uuid, # Required parameter - model=required_model, # Required parameter - checkpoint_dir="", - algorithm_config=algorithm_config, - training_config=convert_pydantic_to_json_value(training_config), - logger_config={}, - hyperparam_search_config={}, - ) - ) - - warning_texts = [str(warning.message) for warning in w] - - fields = [ - "rank", - "apply_lora_to_output", - "lora_attn_modules", - "apply_lora_to_mlp", - ] - for field in fields: - assert any(field in text for text in warning_texts) - - self.mock_make_request.assert_called_once() - call_args = self.mock_make_request.call_args - - assert call_args[1]["json"]["config"] == required_model - assert call_args[1]["json"]["dataset"]["name"] == required_dataset_id - - def test_unsupported_parameters_warning(self): - """Test that warnings are raised for unsupported parameters.""" - data_config = DataConfig( - dataset_id="test-dataset", - batch_size=8, - # Unsupported parameters - shuffle=True, - data_format=DatasetFormat.instruct, - validation_dataset_id="val-dataset", - ) - - optimizer_config = OptimizerConfig( - lr=0.0001, - weight_decay=0.01, - # Unsupported parameters - optimizer_type=OptimizerType.adam, - num_warmup_steps=100, - ) - - efficiency_config = EfficiencyConfig( - enable_activation_checkpointing=True # Unsupported parameter - ) - - training_config = TrainingConfig( - n_epochs=1, - data_config=data_config, - optimizer_config=optimizer_config, - # Unsupported parameters - efficiency_config=efficiency_config, - max_steps_per_epoch=1000, - gradient_accumulation_steps=4, - max_validation_steps=100, - dtype="bf16", - ) - - # Capture warnings - with warnings.catch_warnings(record=True) as w: - warnings.simplefilter("always") - - self.run_async( - self.adapter.supervised_fine_tune( - job_uuid="test-job", - model="meta-llama/Llama-3.1-8B-Instruct", - checkpoint_dir="test-dir", # Unsupported parameter - algorithm_config=LoraFinetuningConfig( - type="LoRA", - apply_lora_to_mlp=True, - apply_lora_to_output=True, - alpha=16, - rank=16, - lora_attn_modules=["q_proj", "k_proj", "v_proj", "o_proj"], - ), - training_config=convert_pydantic_to_json_value(training_config), - logger_config={"test": "value"}, # Unsupported parameter - hyperparam_search_config={"test": "value"}, # Unsupported parameter - ) - ) - - assert len(w) >= 4 - warning_texts = [str(warning.message) for warning in w] - - fields = [ - "checkpoint_dir", - "hyperparam_search_config", - "logger_config", - "TrainingConfig", - "DataConfig", - "OptimizerConfig", - "max_steps_per_epoch", - "gradient_accumulation_steps", - "max_validation_steps", - "dtype", - # required unsupported parameters - "rank", - "apply_lora_to_output", - "lora_attn_modules", - "apply_lora_to_mlp", - ] - for field in fields: - assert any(field in text for text in warning_texts) - - -if __name__ == "__main__": - unittest.main() + fields = [ + "checkpoint_dir", + "hyperparam_search_config", + "logger_config", + "TrainingConfig", + "DataConfig", + "OptimizerConfig", + "max_steps_per_epoch", + "gradient_accumulation_steps", + "max_validation_steps", + "dtype", + # required unsupported parameters + "rank", + "apply_lora_to_output", + "lora_attn_modules", + "apply_lora_to_mlp", + ] + for field in fields: + assert any(field in text for text in warning_texts) diff --git a/tests/unit/providers/nvidia/test_safety.py b/tests/unit/providers/nvidia/test_safety.py index 73fc32a02..37a861dff 100644 --- a/tests/unit/providers/nvidia/test_safety.py +++ b/tests/unit/providers/nvidia/test_safety.py @@ -5,7 +5,6 @@ # the root directory of this source tree. import os -import unittest from typing import Any from unittest.mock import AsyncMock, MagicMock, patch @@ -18,308 +17,325 @@ from llama_stack.providers.remote.safety.nvidia.config import NVIDIASafetyConfig from llama_stack.providers.remote.safety.nvidia.nvidia import NVIDIASafetyAdapter -class TestNVIDIASafetyAdapter(unittest.TestCase): - def setUp(self): - os.environ["NVIDIA_GUARDRAILS_URL"] = "http://nemo.test" +@pytest.fixture +def nvidia_safety_adapter(): + """Set up the NVIDIA safety adapter with mocked dependencies""" + os.environ["NVIDIA_GUARDRAILS_URL"] = "http://nemo.test" - # Initialize the adapter - self.config = NVIDIASafetyConfig( - guardrails_service_url=os.environ["NVIDIA_GUARDRAILS_URL"], - ) - self.adapter = NVIDIASafetyAdapter(config=self.config) - self.shield_store = AsyncMock() - self.adapter.shield_store = self.shield_store + # Initialize the adapter + config = NVIDIASafetyConfig( + guardrails_service_url=os.environ["NVIDIA_GUARDRAILS_URL"], + ) + adapter = NVIDIASafetyAdapter(config=config) + shield_store = AsyncMock() + adapter.shield_store = shield_store - # Mock the HTTP request methods - self.guardrails_post_patcher = patch( - "llama_stack.providers.remote.safety.nvidia.nvidia.NeMoGuardrails._guardrails_post" - ) - self.mock_guardrails_post = self.guardrails_post_patcher.start() - self.mock_guardrails_post.return_value = {"status": "allowed"} + # Mock the HTTP request methods + with patch( + "llama_stack.providers.remote.safety.nvidia.nvidia.NeMoGuardrails._guardrails_post" + ) as mock_guardrails_post: + mock_guardrails_post.return_value = {"status": "allowed"} + yield adapter, shield_store, mock_guardrails_post - def tearDown(self): - """Clean up after each test.""" - self.guardrails_post_patcher.stop() - @pytest.fixture(autouse=True) - def inject_fixtures(self, run_async): - self.run_async = run_async +def assert_request( + mock_call: MagicMock, + expected_url: str, + expected_headers: dict[str, str] | None = None, + expected_json: dict[str, Any] | None = None, +) -> None: + """ + Helper function to verify request details in mock API calls. - def _assert_request( - self, - mock_call: MagicMock, - expected_url: str, - expected_headers: dict[str, str] | None = None, - expected_json: dict[str, Any] | None = None, - ) -> None: - """ - Helper method to verify request details in mock API calls. + Args: + mock_call: The MagicMock object that was called + expected_url: The expected URL to which the request was made + expected_headers: Optional dictionary of expected request headers + expected_json: Optional dictionary of expected JSON payload + """ + call_args = mock_call.call_args - Args: - mock_call: The MagicMock object that was called - expected_url: The expected URL to which the request was made - expected_headers: Optional dictionary of expected request headers - expected_json: Optional dictionary of expected JSON payload - """ - call_args = mock_call.call_args + # Check URL + assert call_args[0][0] == expected_url - # Check URL - assert call_args[0][0] == expected_url + # Check headers if provided + if expected_headers: + for key, value in expected_headers.items(): + assert call_args[1]["headers"][key] == value - # Check headers if provided - if expected_headers: - for key, value in expected_headers.items(): - assert call_args[1]["headers"][key] == value + # Check JSON if provided + if expected_json: + for key, value in expected_json.items(): + if isinstance(value, dict): + for nested_key, nested_value in value.items(): + assert call_args[1]["json"][key][nested_key] == nested_value + else: + assert call_args[1]["json"][key] == value - # Check JSON if provided - if expected_json: - for key, value in expected_json.items(): - if isinstance(value, dict): - for nested_key, nested_value in value.items(): - assert call_args[1]["json"][key][nested_key] == nested_value - else: - assert call_args[1]["json"][key] == value - def test_register_shield_with_valid_id(self): - shield = Shield( - provider_id="nvidia", - type="shield", - identifier="test-shield", - provider_resource_id="test-model", - ) +async def test_register_shield_with_valid_id(nvidia_safety_adapter): + adapter, shield_store, mock_guardrails_post = nvidia_safety_adapter - # Register the shield - self.run_async(self.adapter.register_shield(shield)) + shield = Shield( + provider_id="nvidia", + type="shield", + identifier="test-shield", + provider_resource_id="test-model", + ) - def test_register_shield_without_id(self): - shield = Shield( - provider_id="nvidia", - type="shield", - identifier="test-shield", - provider_resource_id="", - ) + # Register the shield + await adapter.register_shield(shield) - # Register the shield should raise a ValueError - with self.assertRaises(ValueError): - self.run_async(self.adapter.register_shield(shield)) - def test_run_shield_allowed(self): - # Set up the shield - shield_id = "test-shield" - shield = Shield( - provider_id="nvidia", - type="shield", - identifier=shield_id, - provider_resource_id="test-model", - ) - self.shield_store.get_shield.return_value = shield +async def test_register_shield_without_id(nvidia_safety_adapter): + adapter, shield_store, mock_guardrails_post = nvidia_safety_adapter - # Mock Guardrails API response - self.mock_guardrails_post.return_value = {"status": "allowed"} + shield = Shield( + provider_id="nvidia", + type="shield", + identifier="test-shield", + provider_resource_id="", + ) - # Run the shield - messages = [ - UserMessage(role="user", content="Hello, how are you?"), - CompletionMessage( - role="assistant", - content="I'm doing well, thank you for asking!", - stop_reason="end_of_message", - tool_calls=[], - ), - ] - result = self.run_async(self.adapter.run_shield(shield_id, messages)) + # Register the shield should raise a ValueError + with pytest.raises(ValueError): + await adapter.register_shield(shield) - # Verify the shield store was called - self.shield_store.get_shield.assert_called_once_with(shield_id) - # Verify the Guardrails API was called correctly - self.mock_guardrails_post.assert_called_once_with( - path="/v1/guardrail/checks", - data={ - "model": shield_id, - "messages": [ - {"role": "user", "content": "Hello, how are you?"}, - {"role": "assistant", "content": "I'm doing well, thank you for asking!"}, - ], - "temperature": 1.0, - "top_p": 1, - "frequency_penalty": 0, - "presence_penalty": 0, - "max_tokens": 160, - "stream": False, - "guardrails": { - "config_id": "self-check", - }, +async def test_run_shield_allowed(nvidia_safety_adapter): + adapter, shield_store, mock_guardrails_post = nvidia_safety_adapter + + # Set up the shield + shield_id = "test-shield" + shield = Shield( + provider_id="nvidia", + type="shield", + identifier=shield_id, + provider_resource_id="test-model", + ) + shield_store.get_shield.return_value = shield + + # Mock Guardrails API response + mock_guardrails_post.return_value = {"status": "allowed"} + + # Run the shield + messages = [ + UserMessage(role="user", content="Hello, how are you?"), + CompletionMessage( + role="assistant", + content="I'm doing well, thank you for asking!", + stop_reason="end_of_message", + tool_calls=[], + ), + ] + result = await adapter.run_shield(shield_id, messages) + + # Verify the shield store was called + shield_store.get_shield.assert_called_once_with(shield_id) + + # Verify the Guardrails API was called correctly + mock_guardrails_post.assert_called_once_with( + path="/v1/guardrail/checks", + data={ + "model": shield_id, + "messages": [ + {"role": "user", "content": "Hello, how are you?"}, + {"role": "assistant", "content": "I'm doing well, thank you for asking!"}, + ], + "temperature": 1.0, + "top_p": 1, + "frequency_penalty": 0, + "presence_penalty": 0, + "max_tokens": 160, + "stream": False, + "guardrails": { + "config_id": "self-check", }, - ) + }, + ) - # Verify the result - assert isinstance(result, RunShieldResponse) - assert result.violation is None + # Verify the result + assert isinstance(result, RunShieldResponse) + assert result.violation is None - def test_run_shield_blocked(self): - # Set up the shield - shield_id = "test-shield" - shield = Shield( - provider_id="nvidia", - type="shield", - identifier=shield_id, - provider_resource_id="test-model", - ) - self.shield_store.get_shield.return_value = shield - # Mock Guardrails API response - self.mock_guardrails_post.return_value = {"status": "blocked", "rails_status": {"reason": "harmful_content"}} +async def test_run_shield_blocked(nvidia_safety_adapter): + adapter, shield_store, mock_guardrails_post = nvidia_safety_adapter - # Run the shield - messages = [ - UserMessage(role="user", content="Hello, how are you?"), - CompletionMessage( - role="assistant", - content="I'm doing well, thank you for asking!", - stop_reason="end_of_message", - tool_calls=[], - ), - ] - result = self.run_async(self.adapter.run_shield(shield_id, messages)) + # Set up the shield + shield_id = "test-shield" + shield = Shield( + provider_id="nvidia", + type="shield", + identifier=shield_id, + provider_resource_id="test-model", + ) + shield_store.get_shield.return_value = shield - # Verify the shield store was called - self.shield_store.get_shield.assert_called_once_with(shield_id) + # Mock Guardrails API response + mock_guardrails_post.return_value = {"status": "blocked", "rails_status": {"reason": "harmful_content"}} - # Verify the Guardrails API was called correctly - self.mock_guardrails_post.assert_called_once_with( - path="/v1/guardrail/checks", - data={ - "model": shield_id, - "messages": [ - {"role": "user", "content": "Hello, how are you?"}, - {"role": "assistant", "content": "I'm doing well, thank you for asking!"}, - ], - "temperature": 1.0, - "top_p": 1, - "frequency_penalty": 0, - "presence_penalty": 0, - "max_tokens": 160, - "stream": False, - "guardrails": { - "config_id": "self-check", - }, + # Run the shield + messages = [ + UserMessage(role="user", content="Hello, how are you?"), + CompletionMessage( + role="assistant", + content="I'm doing well, thank you for asking!", + stop_reason="end_of_message", + tool_calls=[], + ), + ] + result = await adapter.run_shield(shield_id, messages) + + # Verify the shield store was called + shield_store.get_shield.assert_called_once_with(shield_id) + + # Verify the Guardrails API was called correctly + mock_guardrails_post.assert_called_once_with( + path="/v1/guardrail/checks", + data={ + "model": shield_id, + "messages": [ + {"role": "user", "content": "Hello, how are you?"}, + {"role": "assistant", "content": "I'm doing well, thank you for asking!"}, + ], + "temperature": 1.0, + "top_p": 1, + "frequency_penalty": 0, + "presence_penalty": 0, + "max_tokens": 160, + "stream": False, + "guardrails": { + "config_id": "self-check", }, - ) + }, + ) - # Verify the result - assert result.violation is not None - assert isinstance(result, RunShieldResponse) - assert result.violation.user_message == "Sorry I cannot do this." - assert result.violation.violation_level == ViolationLevel.ERROR - assert result.violation.metadata == {"reason": "harmful_content"} + # Verify the result + assert result.violation is not None + assert isinstance(result, RunShieldResponse) + assert result.violation.user_message == "Sorry I cannot do this." + assert result.violation.violation_level == ViolationLevel.ERROR + assert result.violation.metadata == {"reason": "harmful_content"} - def test_run_shield_not_found(self): - # Set up shield store to return None - shield_id = "non-existent-shield" - self.shield_store.get_shield.return_value = None - messages = [ - UserMessage(role="user", content="Hello, how are you?"), - ] +async def test_run_shield_not_found(nvidia_safety_adapter): + adapter, shield_store, mock_guardrails_post = nvidia_safety_adapter - with self.assertRaises(ValueError): - self.run_async(self.adapter.run_shield(shield_id, messages)) + # Set up shield store to return None + shield_id = "non-existent-shield" + shield_store.get_shield.return_value = None - # Verify the shield store was called - self.shield_store.get_shield.assert_called_once_with(shield_id) + messages = [ + UserMessage(role="user", content="Hello, how are you?"), + ] - # Verify the Guardrails API was not called - self.mock_guardrails_post.assert_not_called() + with pytest.raises(ValueError): + await adapter.run_shield(shield_id, messages) - def test_run_shield_http_error(self): - shield_id = "test-shield" - shield = Shield( - provider_id="nvidia", - type="shield", - identifier=shield_id, - provider_resource_id="test-model", - ) - self.shield_store.get_shield.return_value = shield + # Verify the shield store was called + shield_store.get_shield.assert_called_once_with(shield_id) - # Mock Guardrails API to raise an exception - error_msg = "API Error: 500 Internal Server Error" - self.mock_guardrails_post.side_effect = Exception(error_msg) + # Verify the Guardrails API was not called + mock_guardrails_post.assert_not_called() - # Running the shield should raise an exception - messages = [ - UserMessage(role="user", content="Hello, how are you?"), - CompletionMessage( - role="assistant", - content="I'm doing well, thank you for asking!", - stop_reason="end_of_message", - tool_calls=[], - ), - ] - with self.assertRaises(Exception) as context: - self.run_async(self.adapter.run_shield(shield_id, messages)) - # Verify the shield store was called - self.shield_store.get_shield.assert_called_once_with(shield_id) +async def test_run_shield_http_error(nvidia_safety_adapter): + adapter, shield_store, mock_guardrails_post = nvidia_safety_adapter - # Verify the Guardrails API was called correctly - self.mock_guardrails_post.assert_called_once_with( - path="/v1/guardrail/checks", - data={ - "model": shield_id, - "messages": [ - {"role": "user", "content": "Hello, how are you?"}, - {"role": "assistant", "content": "I'm doing well, thank you for asking!"}, - ], - "temperature": 1.0, - "top_p": 1, - "frequency_penalty": 0, - "presence_penalty": 0, - "max_tokens": 160, - "stream": False, - "guardrails": { - "config_id": "self-check", - }, + shield_id = "test-shield" + shield = Shield( + provider_id="nvidia", + type="shield", + identifier=shield_id, + provider_resource_id="test-model", + ) + shield_store.get_shield.return_value = shield + + # Mock Guardrails API to raise an exception + error_msg = "API Error: 500 Internal Server Error" + mock_guardrails_post.side_effect = Exception(error_msg) + + # Running the shield should raise an exception + messages = [ + UserMessage(role="user", content="Hello, how are you?"), + CompletionMessage( + role="assistant", + content="I'm doing well, thank you for asking!", + stop_reason="end_of_message", + tool_calls=[], + ), + ] + with pytest.raises(Exception) as excinfo: + await adapter.run_shield(shield_id, messages) + + # Verify the shield store was called + shield_store.get_shield.assert_called_once_with(shield_id) + + # Verify the Guardrails API was called correctly + mock_guardrails_post.assert_called_once_with( + path="/v1/guardrail/checks", + data={ + "model": shield_id, + "messages": [ + {"role": "user", "content": "Hello, how are you?"}, + {"role": "assistant", "content": "I'm doing well, thank you for asking!"}, + ], + "temperature": 1.0, + "top_p": 1, + "frequency_penalty": 0, + "presence_penalty": 0, + "max_tokens": 160, + "stream": False, + "guardrails": { + "config_id": "self-check", }, - ) - # Verify the exception message - assert error_msg in str(context.exception) + }, + ) + # Verify the exception message + assert error_msg in str(excinfo.value) - def test_init_nemo_guardrails(self): - from llama_stack.providers.remote.safety.nvidia.nvidia import NeMoGuardrails - test_config_id = "test-custom-config-id" - config = NVIDIASafetyConfig( - guardrails_service_url=os.environ["NVIDIA_GUARDRAILS_URL"], - config_id=test_config_id, - ) - # Initialize with default parameters - test_model = "test-model" - guardrails = NeMoGuardrails(config, test_model) +def test_init_nemo_guardrails(): + from llama_stack.providers.remote.safety.nvidia.nvidia import NeMoGuardrails - # Verify the attributes are set correctly - assert guardrails.config_id == test_config_id - assert guardrails.model == test_model - assert guardrails.threshold == 0.9 # Default value - assert guardrails.temperature == 1.0 # Default value - assert guardrails.guardrails_service_url == os.environ["NVIDIA_GUARDRAILS_URL"] + os.environ["NVIDIA_GUARDRAILS_URL"] = "http://nemo.test" - # Initialize with custom parameters - guardrails = NeMoGuardrails(config, test_model, threshold=0.8, temperature=0.7) + test_config_id = "test-custom-config-id" + config = NVIDIASafetyConfig( + guardrails_service_url=os.environ["NVIDIA_GUARDRAILS_URL"], + config_id=test_config_id, + ) + # Initialize with default parameters + test_model = "test-model" + guardrails = NeMoGuardrails(config, test_model) - # Verify the attributes are set correctly - assert guardrails.config_id == test_config_id - assert guardrails.model == test_model - assert guardrails.threshold == 0.8 - assert guardrails.temperature == 0.7 - assert guardrails.guardrails_service_url == os.environ["NVIDIA_GUARDRAILS_URL"] + # Verify the attributes are set correctly + assert guardrails.config_id == test_config_id + assert guardrails.model == test_model + assert guardrails.threshold == 0.9 # Default value + assert guardrails.temperature == 1.0 # Default value + assert guardrails.guardrails_service_url == os.environ["NVIDIA_GUARDRAILS_URL"] - def test_init_nemo_guardrails_invalid_temperature(self): - from llama_stack.providers.remote.safety.nvidia.nvidia import NeMoGuardrails + # Initialize with custom parameters + guardrails = NeMoGuardrails(config, test_model, threshold=0.8, temperature=0.7) - config = NVIDIASafetyConfig( - guardrails_service_url=os.environ["NVIDIA_GUARDRAILS_URL"], - config_id="test-custom-config-id", - ) - with self.assertRaises(ValueError): - NeMoGuardrails(config, "test-model", temperature=0) + # Verify the attributes are set correctly + assert guardrails.config_id == test_config_id + assert guardrails.model == test_model + assert guardrails.threshold == 0.8 + assert guardrails.temperature == 0.7 + assert guardrails.guardrails_service_url == os.environ["NVIDIA_GUARDRAILS_URL"] + + +def test_init_nemo_guardrails_invalid_temperature(): + from llama_stack.providers.remote.safety.nvidia.nvidia import NeMoGuardrails + + os.environ["NVIDIA_GUARDRAILS_URL"] = "http://nemo.test" + + config = NVIDIASafetyConfig( + guardrails_service_url=os.environ["NVIDIA_GUARDRAILS_URL"], + config_id="test-custom-config-id", + ) + with pytest.raises(ValueError): + NeMoGuardrails(config, "test-model", temperature=0) diff --git a/tests/unit/providers/nvidia/test_supervised_fine_tuning.py b/tests/unit/providers/nvidia/test_supervised_fine_tuning.py index 97ca02fba..422e991fa 100644 --- a/tests/unit/providers/nvidia/test_supervised_fine_tuning.py +++ b/tests/unit/providers/nvidia/test_supervised_fine_tuning.py @@ -5,9 +5,8 @@ # the root directory of this source tree. import os -import unittest import warnings -from unittest.mock import patch +from unittest.mock import AsyncMock, MagicMock, patch import pytest @@ -32,331 +31,334 @@ from llama_stack.providers.remote.post_training.nvidia.post_training import ( ) -class TestNvidiaPostTraining(unittest.TestCase): - def setUp(self): - os.environ["NVIDIA_BASE_URL"] = "http://nemo.test" # needed for llm inference - os.environ["NVIDIA_CUSTOMIZER_URL"] = "http://nemo.test" # needed for nemo customizer +@pytest.fixture +def nvidia_adapters(): + """Set up the NVIDIA adapters with mocked dependencies""" + os.environ["NVIDIA_BASE_URL"] = "http://nemo.test" # needed for llm inference + os.environ["NVIDIA_CUSTOMIZER_URL"] = "http://nemo.test" # needed for nemo customizer - config = NvidiaPostTrainingConfig( - base_url=os.environ["NVIDIA_BASE_URL"], customizer_url=os.environ["NVIDIA_CUSTOMIZER_URL"], api_key=None - ) - self.adapter = NvidiaPostTrainingAdapter(config) - self.make_request_patcher = patch( + config = NvidiaPostTrainingConfig( + base_url=os.environ["NVIDIA_BASE_URL"], customizer_url=os.environ["NVIDIA_CUSTOMIZER_URL"], api_key=None + ) + adapter = NvidiaPostTrainingAdapter(config) + + # Mock the inference client + inference_config = NVIDIAConfig(base_url=os.environ["NVIDIA_BASE_URL"], api_key=None) + inference_adapter = NVIDIAInferenceAdapter(inference_config) + + mock_client = MagicMock() + mock_client.chat.completions.create = AsyncMock() + + with ( + patch( "llama_stack.providers.remote.post_training.nvidia.post_training.NvidiaPostTrainingAdapter._make_request" - ) - self.mock_make_request = self.make_request_patcher.start() - - # Mock the inference client - inference_config = NVIDIAConfig(base_url=os.environ["NVIDIA_BASE_URL"], api_key=None) - self.inference_adapter = NVIDIAInferenceAdapter(inference_config) - - self.mock_client = unittest.mock.MagicMock() - self.mock_client.chat.completions.create = unittest.mock.AsyncMock() - self.inference_mock_make_request = self.mock_client.chat.completions.create - self.inference_make_request_patcher = patch( + ) as mock_make_request, + patch( "llama_stack.providers.remote.inference.nvidia.nvidia.NVIDIAInferenceAdapter._get_client", - return_value=self.mock_client, + return_value=mock_client, + ), + ): + yield adapter, inference_adapter, mock_make_request, mock_client + + +def assert_request(mock_call, expected_method, expected_path, expected_params=None, expected_json=None): + """Helper function to verify request details in mock calls.""" + call_args = mock_call.call_args + + if expected_method and expected_path: + if isinstance(call_args[0], tuple) and len(call_args[0]) == 2: + assert call_args[0] == (expected_method, expected_path) + else: + assert call_args[1]["method"] == expected_method + assert call_args[1]["path"] == expected_path + + if expected_params: + assert call_args[1]["params"] == expected_params + + if expected_json: + for key, value in expected_json.items(): + assert call_args[1]["json"][key] == value + + +async def test_supervised_fine_tune(nvidia_adapters): + """Test the supervised fine-tuning API call.""" + adapter, inference_adapter, mock_make_request, mock_client = nvidia_adapters + + mock_make_request.return_value = { + "id": "cust-JGTaMbJMdqjJU8WbQdN9Q2", + "created_at": "2024-12-09T04:06:28.542884", + "updated_at": "2024-12-09T04:06:28.542884", + "config": { + "schema_version": "1.0", + "id": "af783f5b-d985-4e5b-bbb7-f9eec39cc0b1", + "created_at": "2024-12-09T04:06:28.542657", + "updated_at": "2024-12-09T04:06:28.569837", + "custom_fields": {}, + "name": "meta-llama/Llama-3.1-8B-Instruct", + "base_model": "meta-llama/Llama-3.1-8B-Instruct", + "model_path": "llama-3_1-8b-instruct", + "training_types": [], + "finetuning_types": ["lora"], + "precision": "bf16", + "num_gpus": 4, + "num_nodes": 1, + "micro_batch_size": 1, + "tensor_parallel_size": 1, + "max_seq_length": 4096, + }, + "dataset": { + "schema_version": "1.0", + "id": "dataset-XU4pvGzr5tvawnbVxeJMTb", + "created_at": "2024-12-09T04:06:28.542657", + "updated_at": "2024-12-09T04:06:28.542660", + "custom_fields": {}, + "name": "sample-basic-test", + "version_id": "main", + "version_tags": [], + }, + "hyperparameters": { + "finetuning_type": "lora", + "training_type": "sft", + "batch_size": 16, + "epochs": 2, + "learning_rate": 0.0001, + "lora": {"alpha": 16}, + }, + "output_model": "default/job-1234", + "status": "created", + "project": "default", + "custom_fields": {}, + "ownership": {"created_by": "me", "access_policies": {}}, + } + + algorithm_config = LoraFinetuningConfig( + type="LoRA", + apply_lora_to_mlp=True, + apply_lora_to_output=True, + alpha=16, + rank=16, + lora_attn_modules=["q_proj", "k_proj", "v_proj", "o_proj"], + ) + + data_config = DataConfig( + dataset_id="sample-basic-test", batch_size=16, shuffle=False, data_format=DatasetFormat.instruct + ) + + optimizer_config = OptimizerConfig( + optimizer_type=OptimizerType.adam, + lr=0.0001, + weight_decay=0.01, + num_warmup_steps=100, + ) + + training_config = TrainingConfig( + n_epochs=2, + data_config=data_config, + optimizer_config=optimizer_config, + ) + + with warnings.catch_warnings(record=True): + warnings.simplefilter("always") + training_job = await adapter.supervised_fine_tune( + job_uuid="1234", + model="meta/llama-3.2-1b-instruct@v1.0.0+L40", + checkpoint_dir="", + algorithm_config=algorithm_config, + training_config=convert_pydantic_to_json_value(training_config), + logger_config={}, + hyperparam_search_config={}, ) - self.inference_make_request_patcher.start() - def tearDown(self): - self.make_request_patcher.stop() - self.inference_make_request_patcher.stop() + # check the output is a PostTrainingJob + assert isinstance(training_job, NvidiaPostTrainingJob) + assert training_job.job_uuid == "cust-JGTaMbJMdqjJU8WbQdN9Q2" - @pytest.fixture(autouse=True) - def inject_fixtures(self, run_async): - self.run_async = run_async - - def _assert_request(self, mock_call, expected_method, expected_path, expected_params=None, expected_json=None): - """Helper method to verify request details in mock calls.""" - call_args = mock_call.call_args - - if expected_method and expected_path: - if isinstance(call_args[0], tuple) and len(call_args[0]) == 2: - assert call_args[0] == (expected_method, expected_path) - else: - assert call_args[1]["method"] == expected_method - assert call_args[1]["path"] == expected_path - - if expected_params: - assert call_args[1]["params"] == expected_params - - if expected_json: - for key, value in expected_json.items(): - assert call_args[1]["json"][key] == value - - def test_supervised_fine_tune(self): - """Test the supervised fine-tuning API call.""" - self.mock_make_request.return_value = { - "id": "cust-JGTaMbJMdqjJU8WbQdN9Q2", - "created_at": "2024-12-09T04:06:28.542884", - "updated_at": "2024-12-09T04:06:28.542884", - "config": { - "schema_version": "1.0", - "id": "af783f5b-d985-4e5b-bbb7-f9eec39cc0b1", - "created_at": "2024-12-09T04:06:28.542657", - "updated_at": "2024-12-09T04:06:28.569837", - "custom_fields": {}, - "name": "meta-llama/Llama-3.1-8B-Instruct", - "base_model": "meta-llama/Llama-3.1-8B-Instruct", - "model_path": "llama-3_1-8b-instruct", - "training_types": [], - "finetuning_types": ["lora"], - "precision": "bf16", - "num_gpus": 4, - "num_nodes": 1, - "micro_batch_size": 1, - "tensor_parallel_size": 1, - "max_seq_length": 4096, - }, - "dataset": { - "schema_version": "1.0", - "id": "dataset-XU4pvGzr5tvawnbVxeJMTb", - "created_at": "2024-12-09T04:06:28.542657", - "updated_at": "2024-12-09T04:06:28.542660", - "custom_fields": {}, - "name": "sample-basic-test", - "version_id": "main", - "version_tags": [], - }, + mock_make_request.assert_called_once() + assert_request( + mock_make_request, + "POST", + "/v1/customization/jobs", + expected_json={ + "config": "meta/llama-3.2-1b-instruct@v1.0.0+L40", + "dataset": {"name": "sample-basic-test", "namespace": "default"}, "hyperparameters": { - "finetuning_type": "lora", "training_type": "sft", - "batch_size": 16, + "finetuning_type": "lora", "epochs": 2, + "batch_size": 16, "learning_rate": 0.0001, + "weight_decay": 0.01, "lora": {"alpha": 16}, }, - "output_model": "default/job-1234", - "status": "created", - "project": "default", - "custom_fields": {}, - "ownership": {"created_by": "me", "access_policies": {}}, - } + }, + ) - algorithm_config = LoraFinetuningConfig( - type="LoRA", - apply_lora_to_mlp=True, - apply_lora_to_output=True, - alpha=16, - rank=16, - lora_attn_modules=["q_proj", "k_proj", "v_proj", "o_proj"], + +async def test_supervised_fine_tune_with_qat(nvidia_adapters): + adapter, inference_adapter, mock_make_request, mock_client = nvidia_adapters + + algorithm_config = QATFinetuningConfig(type="QAT", quantizer_name="quantizer_name", group_size=1) + data_config = DataConfig( + dataset_id="sample-basic-test", batch_size=16, shuffle=False, data_format=DatasetFormat.instruct + ) + optimizer_config = OptimizerConfig( + optimizer_type=OptimizerType.adam, + lr=0.0001, + weight_decay=0.01, + num_warmup_steps=100, + ) + training_config = TrainingConfig( + n_epochs=2, + data_config=data_config, + optimizer_config=optimizer_config, + ) + # This will raise NotImplementedError since QAT is not supported + with pytest.raises(NotImplementedError): + await adapter.supervised_fine_tune( + job_uuid="1234", + model="meta/llama-3.2-1b-instruct@v1.0.0+L40", + checkpoint_dir="", + algorithm_config=algorithm_config, + training_config=convert_pydantic_to_json_value(training_config), + logger_config={}, + hyperparam_search_config={}, ) - data_config = DataConfig( - dataset_id="sample-basic-test", batch_size=16, shuffle=False, data_format=DatasetFormat.instruct - ) - optimizer_config = OptimizerConfig( - optimizer_type=OptimizerType.adam, - lr=0.0001, - weight_decay=0.01, - num_warmup_steps=100, - ) +@pytest.mark.parametrize( + "customizer_status,expected_status", + [ + ("running", "in_progress"), + ("completed", "completed"), + ("failed", "failed"), + ("cancelled", "cancelled"), + ("pending", "scheduled"), + ("unknown", "scheduled"), + ], +) +async def test_get_training_job_status(nvidia_adapters, customizer_status, expected_status): + adapter, inference_adapter, mock_make_request, mock_client = nvidia_adapters - training_config = TrainingConfig( - n_epochs=2, - data_config=data_config, - optimizer_config=optimizer_config, - ) + mock_make_request.return_value = { + "created_at": "2024-12-09T04:06:28.580220", + "updated_at": "2024-12-09T04:21:19.852832", + "status": customizer_status, + "steps_completed": 1210, + "epochs_completed": 2, + "percentage_done": 100.0, + "best_epoch": 2, + "train_loss": 1.718016266822815, + "val_loss": 1.8661999702453613, + } - with warnings.catch_warnings(record=True): - warnings.simplefilter("always") - training_job = self.run_async( - self.adapter.supervised_fine_tune( - job_uuid="1234", - model="meta/llama-3.2-1b-instruct@v1.0.0+L40", - checkpoint_dir="", - algorithm_config=algorithm_config, - training_config=convert_pydantic_to_json_value(training_config), - logger_config={}, - hyperparam_search_config={}, - ) - ) + job_id = "cust-JGTaMbJMdqjJU8WbQdN9Q2" - # check the output is a PostTrainingJob - assert isinstance(training_job, NvidiaPostTrainingJob) - assert training_job.job_uuid == "cust-JGTaMbJMdqjJU8WbQdN9Q2" + status = await adapter.get_training_job_status(job_uuid=job_id) - self.mock_make_request.assert_called_once() - self._assert_request( - self.mock_make_request, - "POST", - "/v1/customization/jobs", - expected_json={ - "config": "meta/llama-3.2-1b-instruct@v1.0.0+L40", - "dataset": {"name": "sample-basic-test", "namespace": "default"}, - "hyperparameters": { - "training_type": "sft", - "finetuning_type": "lora", - "epochs": 2, - "batch_size": 16, - "learning_rate": 0.0001, - "weight_decay": 0.01, - "lora": {"alpha": 16}, + assert isinstance(status, NvidiaPostTrainingJobStatusResponse) + assert status.status.value == expected_status + assert status.steps_completed == 1210 + assert status.epochs_completed == 2 + assert status.percentage_done == 100.0 + assert status.best_epoch == 2 + assert status.train_loss == 1.718016266822815 + assert status.val_loss == 1.8661999702453613 + + assert_request( + mock_make_request, + "GET", + f"/v1/customization/jobs/{job_id}/status", + expected_params={"job_id": job_id}, + ) + + +async def test_get_training_jobs(nvidia_adapters): + adapter, inference_adapter, mock_make_request, mock_client = nvidia_adapters + + job_id = "cust-JGTaMbJMdqjJU8WbQdN9Q2" + mock_make_request.return_value = { + "data": [ + { + "id": job_id, + "created_at": "2024-12-09T04:06:28.542884", + "updated_at": "2024-12-09T04:21:19.852832", + "config": { + "name": "meta-llama/Llama-3.1-8B-Instruct", + "base_model": "meta-llama/Llama-3.1-8B-Instruct", }, - }, - ) - - def test_supervised_fine_tune_with_qat(self): - algorithm_config = QATFinetuningConfig(type="QAT", quantizer_name="quantizer_name", group_size=1) - data_config = DataConfig( - dataset_id="sample-basic-test", batch_size=16, shuffle=False, data_format=DatasetFormat.instruct - ) - optimizer_config = OptimizerConfig( - optimizer_type=OptimizerType.adam, - lr=0.0001, - weight_decay=0.01, - num_warmup_steps=100, - ) - training_config = TrainingConfig( - n_epochs=2, - data_config=data_config, - optimizer_config=optimizer_config, - ) - # This will raise NotImplementedError since QAT is not supported - with self.assertRaises(NotImplementedError): - self.run_async( - self.adapter.supervised_fine_tune( - job_uuid="1234", - model="meta/llama-3.2-1b-instruct@v1.0.0+L40", - checkpoint_dir="", - algorithm_config=algorithm_config, - training_config=convert_pydantic_to_json_value(training_config), - logger_config={}, - hyperparam_search_config={}, - ) - ) - - def test_get_training_job_status(self): - customizer_status_to_job_status = [ - ("running", "in_progress"), - ("completed", "completed"), - ("failed", "failed"), - ("cancelled", "cancelled"), - ("pending", "scheduled"), - ("unknown", "scheduled"), + "dataset": {"name": "default/sample-basic-test"}, + "hyperparameters": { + "finetuning_type": "lora", + "training_type": "sft", + "batch_size": 16, + "epochs": 2, + "learning_rate": 0.0001, + "lora": {"adapter_dim": 16, "adapter_dropout": 0.1}, + }, + "output_model": "default/job-1234", + "status": "completed", + "project": "default", + } ] + } - for customizer_status, expected_status in customizer_status_to_job_status: - with self.subTest(customizer_status=customizer_status, expected_status=expected_status): - self.mock_make_request.return_value = { - "created_at": "2024-12-09T04:06:28.580220", - "updated_at": "2024-12-09T04:21:19.852832", - "status": customizer_status, - "steps_completed": 1210, - "epochs_completed": 2, - "percentage_done": 100.0, - "best_epoch": 2, - "train_loss": 1.718016266822815, - "val_loss": 1.8661999702453613, - } + jobs = await adapter.get_training_jobs() - job_id = "cust-JGTaMbJMdqjJU8WbQdN9Q2" + assert isinstance(jobs, ListNvidiaPostTrainingJobs) + assert len(jobs.data) == 1 + job = jobs.data[0] + assert job.job_uuid == job_id + assert job.status.value == "completed" - status = self.run_async(self.adapter.get_training_job_status(job_uuid=job_id)) + mock_make_request.assert_called_once() + assert_request( + mock_make_request, + "GET", + "/v1/customization/jobs", + expected_params={"page": 1, "page_size": 10, "sort": "created_at"}, + ) - assert isinstance(status, NvidiaPostTrainingJobStatusResponse) - assert status.status.value == expected_status - assert status.steps_completed == 1210 - assert status.epochs_completed == 2 - assert status.percentage_done == 100.0 - assert status.best_epoch == 2 - assert status.train_loss == 1.718016266822815 - assert status.val_loss == 1.8661999702453613 - self._assert_request( - self.mock_make_request, - "GET", - f"/v1/customization/jobs/{job_id}/status", - expected_params={"job_id": job_id}, - ) +async def test_cancel_training_job(nvidia_adapters): + adapter, inference_adapter, mock_make_request, mock_client = nvidia_adapters - def test_get_training_jobs(self): - job_id = "cust-JGTaMbJMdqjJU8WbQdN9Q2" - self.mock_make_request.return_value = { - "data": [ - { - "id": job_id, - "created_at": "2024-12-09T04:06:28.542884", - "updated_at": "2024-12-09T04:21:19.852832", - "config": { - "name": "meta-llama/Llama-3.1-8B-Instruct", - "base_model": "meta-llama/Llama-3.1-8B-Instruct", - }, - "dataset": {"name": "default/sample-basic-test"}, - "hyperparameters": { - "finetuning_type": "lora", - "training_type": "sft", - "batch_size": 16, - "epochs": 2, - "learning_rate": 0.0001, - "lora": {"adapter_dim": 16, "adapter_dropout": 0.1}, - }, - "output_model": "default/job-1234", - "status": "completed", - "project": "default", - } - ] - } + mock_make_request.return_value = {} # Empty response for successful cancellation + job_id = "cust-JGTaMbJMdqjJU8WbQdN9Q2" - jobs = self.run_async(self.adapter.get_training_jobs()) + result = await adapter.cancel_training_job(job_uuid=job_id) - assert isinstance(jobs, ListNvidiaPostTrainingJobs) - assert len(jobs.data) == 1 - job = jobs.data[0] - assert job.job_uuid == job_id - assert job.status.value == "completed" + assert result is None - self.mock_make_request.assert_called_once() - self._assert_request( - self.mock_make_request, - "GET", - "/v1/customization/jobs", - expected_params={"page": 1, "page_size": 10, "sort": "created_at"}, + mock_make_request.assert_called_once() + assert_request( + mock_make_request, + "POST", + f"/v1/customization/jobs/{job_id}/cancel", + expected_params={"job_id": job_id}, + ) + + +async def test_inference_register_model(nvidia_adapters): + adapter, inference_adapter, mock_make_request, mock_client = nvidia_adapters + + model_id = "default/job-1234" + model_type = ModelType.llm + model = Model( + identifier=model_id, + provider_id="nvidia", + provider_model_id=model_id, + provider_resource_id=model_id, + model_type=model_type, + ) + result = await inference_adapter.register_model(model) + assert result == model + assert len(inference_adapter.alias_to_provider_id_map) > 1 + assert inference_adapter.get_provider_model_id(model.provider_model_id) == model_id + + with patch.object(inference_adapter, "chat_completion") as mock_chat_completion: + await inference_adapter.chat_completion( + model_id=model_id, + messages=[{"role": "user", "content": "Hello, model"}], ) - def test_cancel_training_job(self): - self.mock_make_request.return_value = {} # Empty response for successful cancellation - job_id = "cust-JGTaMbJMdqjJU8WbQdN9Q2" - - result = self.run_async(self.adapter.cancel_training_job(job_uuid=job_id)) - - assert result is None - - self.mock_make_request.assert_called_once() - self._assert_request( - self.mock_make_request, - "POST", - f"/v1/customization/jobs/{job_id}/cancel", - expected_params={"job_id": job_id}, - ) - - def test_inference_register_model(self): - model_id = "default/job-1234" - model_type = ModelType.llm - model = Model( - identifier=model_id, - provider_id="nvidia", - provider_model_id=model_id, - provider_resource_id=model_id, - model_type=model_type, - ) - result = self.run_async(self.inference_adapter.register_model(model)) - assert result == model - assert len(self.inference_adapter.alias_to_provider_id_map) > 1 - assert self.inference_adapter.get_provider_model_id(model.provider_model_id) == model_id - - with patch.object(self.inference_adapter, "chat_completion") as mock_chat_completion: - self.run_async( - self.inference_adapter.chat_completion( - model_id=model_id, - messages=[{"role": "user", "content": "Hello, model"}], - ) - ) - - mock_chat_completion.assert_called() - - -if __name__ == "__main__": - unittest.main() + mock_chat_completion.assert_called() diff --git a/tests/unit/server/test_replace_env_vars.py b/tests/unit/server/test_replace_env_vars.py index 432d6aee5..ed0b02817 100644 --- a/tests/unit/server/test_replace_env_vars.py +++ b/tests/unit/server/test_replace_env_vars.py @@ -5,73 +5,91 @@ # the root directory of this source tree. import os -import unittest + +import pytest from llama_stack.distribution.stack import replace_env_vars -class TestReplaceEnvVars(unittest.TestCase): - def setUp(self): - # Clear any existing environment variables we'll use in tests - for var in ["TEST_VAR", "EMPTY_VAR", "ZERO_VAR"]: - if var in os.environ: - del os.environ[var] +@pytest.fixture(autouse=True) +def setup_env_vars(): + """Set up test environment variables and clean up after test""" + # Store original values + original_vars = {} + for var in ["TEST_VAR", "EMPTY_VAR", "ZERO_VAR"]: + original_vars[var] = os.environ.get(var) + if var in os.environ: + del os.environ[var] - # Set up test environment variables - os.environ["TEST_VAR"] = "test_value" - os.environ["EMPTY_VAR"] = "" - os.environ["ZERO_VAR"] = "0" + # Set up test environment variables + os.environ["TEST_VAR"] = "test_value" + os.environ["EMPTY_VAR"] = "" + os.environ["ZERO_VAR"] = "0" - def test_simple_replacement(self): - self.assertEqual(replace_env_vars("${env.TEST_VAR}"), "test_value") + yield - def test_default_value_when_not_set(self): - self.assertEqual(replace_env_vars("${env.NOT_SET:=default}"), "default") - - def test_default_value_when_set(self): - self.assertEqual(replace_env_vars("${env.TEST_VAR:=default}"), "test_value") - - def test_default_value_when_empty(self): - self.assertEqual(replace_env_vars("${env.EMPTY_VAR:=default}"), "default") - - def test_none_value_when_empty(self): - self.assertEqual(replace_env_vars("${env.EMPTY_VAR:=}"), None) - - def test_value_when_set(self): - self.assertEqual(replace_env_vars("${env.TEST_VAR:=}"), "test_value") - - def test_empty_var_no_default(self): - self.assertEqual(replace_env_vars("${env.EMPTY_VAR_NO_DEFAULT:+}"), None) - - def test_conditional_value_when_set(self): - self.assertEqual(replace_env_vars("${env.TEST_VAR:+conditional}"), "conditional") - - def test_conditional_value_when_not_set(self): - self.assertEqual(replace_env_vars("${env.NOT_SET:+conditional}"), None) - - def test_conditional_value_when_empty(self): - self.assertEqual(replace_env_vars("${env.EMPTY_VAR:+conditional}"), None) - - def test_conditional_value_with_zero(self): - self.assertEqual(replace_env_vars("${env.ZERO_VAR:+conditional}"), "conditional") - - def test_mixed_syntax(self): - self.assertEqual( - replace_env_vars("${env.TEST_VAR:=default} and ${env.NOT_SET:+conditional}"), "test_value and " - ) - self.assertEqual( - replace_env_vars("${env.NOT_SET:=default} and ${env.TEST_VAR:+conditional}"), "default and conditional" - ) - - def test_nested_structures(self): - data = { - "key1": "${env.TEST_VAR:=default}", - "key2": ["${env.NOT_SET:=default}", "${env.TEST_VAR:+conditional}"], - "key3": {"nested": "${env.NOT_SET:+conditional}"}, - } - expected = {"key1": "test_value", "key2": ["default", "conditional"], "key3": {"nested": None}} - self.assertEqual(replace_env_vars(data), expected) + # Cleanup: restore original values + for var, original_value in original_vars.items(): + if original_value is not None: + os.environ[var] = original_value + elif var in os.environ: + del os.environ[var] -if __name__ == "__main__": - unittest.main() +def test_simple_replacement(): + assert replace_env_vars("${env.TEST_VAR}") == "test_value" + + +def test_default_value_when_not_set(): + assert replace_env_vars("${env.NOT_SET:=default}") == "default" + + +def test_default_value_when_set(): + assert replace_env_vars("${env.TEST_VAR:=default}") == "test_value" + + +def test_default_value_when_empty(): + assert replace_env_vars("${env.EMPTY_VAR:=default}") == "default" + + +def test_none_value_when_empty(): + assert replace_env_vars("${env.EMPTY_VAR:=}") is None + + +def test_value_when_set(): + assert replace_env_vars("${env.TEST_VAR:=}") == "test_value" + + +def test_empty_var_no_default(): + assert replace_env_vars("${env.EMPTY_VAR_NO_DEFAULT:+}") is None + + +def test_conditional_value_when_set(): + assert replace_env_vars("${env.TEST_VAR:+conditional}") == "conditional" + + +def test_conditional_value_when_not_set(): + assert replace_env_vars("${env.NOT_SET:+conditional}") is None + + +def test_conditional_value_when_empty(): + assert replace_env_vars("${env.EMPTY_VAR:+conditional}") is None + + +def test_conditional_value_with_zero(): + assert replace_env_vars("${env.ZERO_VAR:+conditional}") == "conditional" + + +def test_mixed_syntax(): + assert replace_env_vars("${env.TEST_VAR:=default} and ${env.NOT_SET:+conditional}") == "test_value and " + assert replace_env_vars("${env.NOT_SET:=default} and ${env.TEST_VAR:+conditional}") == "default and conditional" + + +def test_nested_structures(): + data = { + "key1": "${env.TEST_VAR:=default}", + "key2": ["${env.NOT_SET:=default}", "${env.TEST_VAR:+conditional}"], + "key3": {"nested": "${env.NOT_SET:+conditional}"}, + } + expected = {"key1": "test_value", "key2": ["default", "conditional"], "key3": {"nested": None}} + assert replace_env_vars(data) == expected