test: migrate unit tests from unittest to pytest

Signed-off-by: Mustafa Elbehery <melbeher@redhat.com>
Mustafa Elbehery 2025-07-09 12:55:16 +02:00
parent ff9d4d8a9d
commit 9331253894
8 changed files with 1440 additions and 1405 deletions
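The migration follows a few mechanical patterns: test classes become module-level functions, self.assert* calls become plain assert statements, setUp/tearDown pairs become yield fixtures, assertRaises becomes pytest.raises, and subTest loops become @pytest.mark.parametrize. The sketch below is a hypothetical illustration of the core shape, not code from this repository; the explicit @pytest.mark.asyncio marker assumes pytest-asyncio in its default strict mode, whereas the tests in this commit may rely on an auto-mode configuration instead.

import pytest


class _FakeClient:
    """Hypothetical stand-in; not a class from this repository."""

    async def chat(self, prompt: str) -> str:
        return f"echo: {prompt}"


@pytest.mark.asyncio
async def test_chat_echoes_prompt():
    client = _FakeClient()
    reply = await client.chat("Hello !")
    # Plain asserts replace self.assertEqual / self.assertIn / self.assertTrue.
    assert reply == "echo: Hello !"
    assert "Hello" in reply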


@ -4,8 +4,6 @@
# This source code is licensed under the terms described in the LICENSE file in
# the root directory of this source tree.
import asyncio
import unittest
from llama_stack.apis.inference import (
ChatCompletionRequest,
@ -31,11 +29,7 @@ MODEL = "Llama3.1-8B-Instruct"
MODEL3_2 = "Llama3.2-3B-Instruct"
class PrepareMessagesTests(unittest.IsolatedAsyncioTestCase):
async def asyncSetUp(self):
asyncio.get_running_loop().set_debug(False)
async def test_system_default(self):
async def test_system_default():
content = "Hello !"
request = ChatCompletionRequest(
model=MODEL,
@ -44,11 +38,12 @@ class PrepareMessagesTests(unittest.IsolatedAsyncioTestCase):
],
)
messages = chat_completion_request_to_messages(request, MODEL)
self.assertEqual(len(messages), 2)
self.assertEqual(messages[-1].content, content)
self.assertTrue("Cutting Knowledge Date: December 2023" in messages[0].content)
assert len(messages) == 2
assert messages[-1].content == content
assert "Cutting Knowledge Date: December 2023" in messages[0].content
async def test_system_builtin_only(self):
async def test_system_builtin_only():
content = "Hello !"
request = ChatCompletionRequest(
model=MODEL,
@ -61,12 +56,13 @@ class PrepareMessagesTests(unittest.IsolatedAsyncioTestCase):
],
)
messages = chat_completion_request_to_messages(request, MODEL)
self.assertEqual(len(messages), 2)
self.assertEqual(messages[-1].content, content)
self.assertTrue("Cutting Knowledge Date: December 2023" in messages[0].content)
self.assertTrue("Tools: brave_search" in messages[0].content)
assert len(messages) == 2
assert messages[-1].content == content
assert "Cutting Knowledge Date: December 2023" in messages[0].content
assert "Tools: brave_search" in messages[0].content
async def test_system_custom_only(self):
async def test_system_custom_only():
content = "Hello !"
request = ChatCompletionRequest(
model=MODEL,
@ -89,13 +85,14 @@ class PrepareMessagesTests(unittest.IsolatedAsyncioTestCase):
tool_config=ToolConfig(tool_prompt_format=ToolPromptFormat.json),
)
messages = chat_completion_request_to_messages(request, MODEL)
self.assertEqual(len(messages), 3)
self.assertTrue("Environment: ipython" in messages[0].content)
assert len(messages) == 3
assert "Environment: ipython" in messages[0].content
self.assertTrue("Return function calls in JSON format" in messages[1].content)
self.assertEqual(messages[-1].content, content)
assert "Return function calls in JSON format" in messages[1].content
assert messages[-1].content == content
async def test_system_custom_and_builtin(self):
async def test_system_custom_and_builtin():
content = "Hello !"
request = ChatCompletionRequest(
model=MODEL,
@ -119,15 +116,16 @@ class PrepareMessagesTests(unittest.IsolatedAsyncioTestCase):
],
)
messages = chat_completion_request_to_messages(request, MODEL)
self.assertEqual(len(messages), 3)
assert len(messages) == 3
self.assertTrue("Environment: ipython" in messages[0].content)
self.assertTrue("Tools: brave_search" in messages[0].content)
assert "Environment: ipython" in messages[0].content
assert "Tools: brave_search" in messages[0].content
self.assertTrue("Return function calls in JSON format" in messages[1].content)
self.assertEqual(messages[-1].content, content)
assert "Return function calls in JSON format" in messages[1].content
assert messages[-1].content == content
async def test_completion_message_encoding(self):
async def test_completion_message_encoding():
request = ChatCompletionRequest(
model=MODEL3_2,
messages=[
@ -160,17 +158,15 @@ class PrepareMessagesTests(unittest.IsolatedAsyncioTestCase):
tool_config=ToolConfig(tool_prompt_format=ToolPromptFormat.python_list),
)
prompt = await chat_completion_request_to_prompt(request, request.model)
self.assertIn('[custom1(param1="value1")]', prompt)
assert '[custom1(param1="value1")]' in prompt
request.model = MODEL
request.tool_config.tool_prompt_format = ToolPromptFormat.json
prompt = await chat_completion_request_to_prompt(request, request.model)
self.assertIn(
'{"type": "function", "name": "custom1", "parameters": {"param1": "value1"}}',
prompt,
)
assert '{"type": "function", "name": "custom1", "parameters": {"param1": "value1"}}' in prompt
async def test_user_provided_system_message(self):
async def test_user_provided_system_message():
content = "Hello !"
system_prompt = "You are a pirate"
request = ChatCompletionRequest(
@ -184,12 +180,13 @@ class PrepareMessagesTests(unittest.IsolatedAsyncioTestCase):
],
)
messages = chat_completion_request_to_messages(request, MODEL)
self.assertEqual(len(messages), 2, messages)
self.assertTrue(messages[0].content.endswith(system_prompt))
assert len(messages) == 2
assert messages[0].content.endswith(system_prompt)
self.assertEqual(messages[-1].content, content)
assert messages[-1].content == content
async def test_repalce_system_message_behavior_builtin_tools(self):
async def test_repalce_system_message_behavior_builtin_tools():
content = "Hello !"
system_prompt = "You are a pirate"
request = ChatCompletionRequest(
@ -208,12 +205,13 @@ class PrepareMessagesTests(unittest.IsolatedAsyncioTestCase):
),
)
messages = chat_completion_request_to_messages(request, MODEL3_2)
self.assertEqual(len(messages), 2, messages)
self.assertTrue(messages[0].content.endswith(system_prompt))
self.assertIn("Environment: ipython", messages[0].content)
self.assertEqual(messages[-1].content, content)
assert len(messages) == 2
assert messages[0].content.endswith(system_prompt)
assert "Environment: ipython" in messages[0].content
assert messages[-1].content == content
async def test_repalce_system_message_behavior_custom_tools(self):
async def test_repalce_system_message_behavior_custom_tools():
content = "Hello !"
system_prompt = "You are a pirate"
request = ChatCompletionRequest(
@ -244,12 +242,13 @@ class PrepareMessagesTests(unittest.IsolatedAsyncioTestCase):
)
messages = chat_completion_request_to_messages(request, MODEL3_2)
self.assertEqual(len(messages), 2, messages)
self.assertTrue(messages[0].content.endswith(system_prompt))
self.assertIn("Environment: ipython", messages[0].content)
self.assertEqual(messages[-1].content, content)
assert len(messages) == 2
assert messages[0].content.endswith(system_prompt)
assert "Environment: ipython" in messages[0].content
assert messages[-1].content == content
async def test_replace_system_message_behavior_custom_tools_with_template(self):
async def test_replace_system_message_behavior_custom_tools_with_template():
content = "Hello !"
system_prompt = "You are a pirate {{ function_description }}"
request = ChatCompletionRequest(
@ -280,9 +279,9 @@ class PrepareMessagesTests(unittest.IsolatedAsyncioTestCase):
)
messages = chat_completion_request_to_messages(request, MODEL3_2)
self.assertEqual(len(messages), 2, messages)
self.assertIn("Environment: ipython", messages[0].content)
self.assertIn("You are a pirate", messages[0].content)
assert len(messages) == 2
assert "Environment: ipython" in messages[0].content
assert "You are a pirate" in messages[0].content
# function description is present in the system prompt
self.assertIn('"name": "custom1"', messages[0].content)
self.assertEqual(messages[-1].content, content)
assert '"name": "custom1"' in messages[0].content
assert messages[-1].content == content


@ -12,7 +12,6 @@
# the top-level of this source tree.
import textwrap
import unittest
from datetime import datetime
from llama_stack.models.llama.llama3.prompt_templates import (
@ -24,24 +23,24 @@ from llama_stack.models.llama.llama3.prompt_templates import (
)
class PromptTemplateTests(unittest.TestCase):
def check_generator_output(self, generator):
def check_generator_output(generator):
for example in generator.data_examples():
pt = generator.gen(example)
text = pt.render()
# print(text) # debugging
if not example:
continue
for tool in example:
assert tool.tool_name in text
def test_system_default(self):
def test_system_default():
generator = SystemDefaultGenerator()
today = datetime.now().strftime("%d %B %Y")
expected_text = f"Cutting Knowledge Date: December 2023\nToday Date: {today}"
assert expected_text.strip("\n") == generator.gen(generator.data_examples()[0]).render()
def test_system_builtin_only(self):
def test_system_builtin_only():
generator = BuiltinToolGenerator()
expected_text = textwrap.dedent(
"""
@ -51,21 +50,23 @@ class PromptTemplateTests(unittest.TestCase):
)
assert expected_text.strip("\n") == generator.gen(generator.data_examples()[0]).render()
def test_system_custom_only(self):
self.maxDiff = None
def test_system_custom_only():
generator = JsonCustomToolGenerator()
self.check_generator_output(generator)
check_generator_output(generator)
def test_system_custom_function_tag(self):
self.maxDiff = None
def test_system_custom_function_tag():
generator = FunctionTagCustomToolGenerator()
self.check_generator_output(generator)
check_generator_output(generator)
def test_llama_3_2_system_zero_shot(self):
def test_llama_3_2_system_zero_shot():
generator = PythonListCustomToolGenerator()
self.check_generator_output(generator)
check_generator_output(generator)
def test_llama_3_2_provided_system_prompt(self):
def test_llama_3_2_provided_system_prompt():
generator = PythonListCustomToolGenerator()
user_system_prompt = textwrap.dedent(
"""


@ -5,7 +5,6 @@
# the root directory of this source tree.
import os
import unittest
from unittest.mock import patch
import pytest
@ -15,28 +14,24 @@ from llama_stack.providers.remote.datasetio.nvidia.config import NvidiaDatasetIO
from llama_stack.providers.remote.datasetio.nvidia.datasetio import NvidiaDatasetIOAdapter
class TestNvidiaDatastore(unittest.TestCase):
def setUp(self):
@pytest.fixture
def nvidia_dataset_adapter():
"""Set up the NVIDIA dataset adapter with mocked dependencies"""
os.environ["NVIDIA_DATASETS_URL"] = "http://nemo.test/datasets"
config = NvidiaDatasetIOConfig(
datasets_url=os.environ["NVIDIA_DATASETS_URL"], dataset_namespace="default", project_id="default"
)
self.adapter = NvidiaDatasetIOAdapter(config)
self.make_request_patcher = patch(
adapter = NvidiaDatasetIOAdapter(config)
with patch(
"llama_stack.providers.remote.datasetio.nvidia.datasetio.NvidiaDatasetIOAdapter._make_request"
)
self.mock_make_request = self.make_request_patcher.start()
) as mock_make_request:
yield adapter, mock_make_request
def tearDown(self):
self.make_request_patcher.stop()
@pytest.fixture(autouse=True)
def inject_fixtures(self, run_async):
self.run_async = run_async
def _assert_request(self, mock_call, expected_method, expected_path, expected_json=None):
"""Helper method to verify request details in mock calls."""
def assert_request(mock_call, expected_method, expected_path, expected_json=None):
"""Helper function to verify request details in mock calls."""
call_args = mock_call.call_args
assert call_args[0][0] == expected_method
@ -46,8 +41,11 @@ class TestNvidiaDatastore(unittest.TestCase):
for key, value in expected_json.items():
assert call_args[1]["json"][key] == value
def test_register_dataset(self):
self.mock_make_request.return_value = {
async def test_register_dataset(nvidia_dataset_adapter):
adapter, mock_make_request = nvidia_dataset_adapter
mock_make_request.return_value = {
"id": "dataset-123456",
"name": "test-dataset",
"namespace": "default",
@ -63,11 +61,11 @@ class TestNvidiaDatastore(unittest.TestCase):
metadata={"provider_id": "nvidia", "format": "jsonl", "description": "Test dataset description"},
)
self.run_async(self.adapter.register_dataset(dataset_def))
await adapter.register_dataset(dataset_def)
self.mock_make_request.assert_called_once()
self._assert_request(
self.mock_make_request,
mock_make_request.assert_called_once()
assert_request(
mock_make_request,
"POST",
"/v1/datasets",
expected_json={
@ -80,20 +78,26 @@ class TestNvidiaDatastore(unittest.TestCase):
},
)
def test_unregister_dataset(self):
self.mock_make_request.return_value = {
async def test_unregister_dataset(nvidia_dataset_adapter):
adapter, mock_make_request = nvidia_dataset_adapter
mock_make_request.return_value = {
"message": "Resource deleted successfully.",
"id": "dataset-81RSQp7FKX3rdBtKvF9Skn",
"deleted_at": None,
}
dataset_id = "test-dataset"
self.run_async(self.adapter.unregister_dataset(dataset_id))
await adapter.unregister_dataset(dataset_id)
self.mock_make_request.assert_called_once()
self._assert_request(self.mock_make_request, "DELETE", "/v1/datasets/default/test-dataset")
mock_make_request.assert_called_once()
assert_request(mock_make_request, "DELETE", "/v1/datasets/default/test-dataset")
async def test_register_dataset_with_custom_namespace_project(nvidia_dataset_adapter):
adapter, mock_make_request = nvidia_dataset_adapter
def test_register_dataset_with_custom_namespace_project(self):
custom_config = NvidiaDatasetIOConfig(
datasets_url=os.environ["NVIDIA_DATASETS_URL"],
dataset_namespace="custom-namespace",
@ -101,7 +105,7 @@ class TestNvidiaDatastore(unittest.TestCase):
)
custom_adapter = NvidiaDatasetIOAdapter(custom_config)
self.mock_make_request.return_value = {
mock_make_request.return_value = {
"id": "dataset-123456",
"name": "test-dataset",
"namespace": "custom-namespace",
@ -117,11 +121,11 @@ class TestNvidiaDatastore(unittest.TestCase):
metadata={"format": "jsonl"},
)
self.run_async(custom_adapter.register_dataset(dataset_def))
await custom_adapter.register_dataset(dataset_def)
self.mock_make_request.assert_called_once()
self._assert_request(
self.mock_make_request,
mock_make_request.assert_called_once()
assert_request(
mock_make_request,
"POST",
"/v1/datasets",
expected_json={
@ -132,7 +136,3 @@ class TestNvidiaDatastore(unittest.TestCase):
"format": "jsonl",
},
)
if __name__ == "__main__":
unittest.main()
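Note on the fixture pattern above: the setUp/tearDown pair that started and stopped a patcher becomes a single yield fixture. The patch stays active while the test body runs because the fixture generator is suspended at yield, and it is undone when the with-block exits after the test. Below is a minimal, self-contained sketch with hypothetical names (not the real NvidiaDatasetIOAdapter), again assuming pytest-asyncio for the async test.

from unittest.mock import AsyncMock, patch

import pytest


class FakeAdapter:
    """Hypothetical adapter whose HTTP helper gets patched out."""

    async def register(self, name: str) -> dict:
        return await self._make_request("POST", "/v1/datasets", json={"name": name})

    async def _make_request(self, method: str, path: str, **kwargs) -> dict:
        raise RuntimeError("no network calls in unit tests")


@pytest.fixture
def adapter_with_mock():
    # setUp and tearDown in one place: the patch is active while the test runs
    # and removed once the fixture resumes past `yield`.
    adapter = FakeAdapter()
    with patch.object(FakeAdapter, "_make_request", new_callable=AsyncMock) as mock_request:
        yield adapter, mock_request


@pytest.mark.asyncio
async def test_register(adapter_with_mock):
    adapter, mock_request = adapter_with_mock
    mock_request.return_value = {"id": "dataset-1"}

    result = await adapter.register("test-dataset")

    assert result["id"] == "dataset-1"
    mock_request.assert_awaited_once_with("POST", "/v1/datasets", json={"name": "test-dataset"})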


@ -5,7 +5,6 @@
# the root directory of this source tree.
import os
import unittest
from unittest.mock import MagicMock, patch
import pytest
@ -21,49 +20,42 @@ MOCK_DATASET_ID = "default/test-dataset"
MOCK_BENCHMARK_ID = "test-benchmark"
class TestNVIDIAEvalImpl(unittest.TestCase):
def setUp(self):
@pytest.fixture
def nvidia_eval_impl():
"""Set up the NVIDIA eval implementation with mocked dependencies"""
os.environ["NVIDIA_EVALUATOR_URL"] = "http://nemo.test"
# Create mock APIs
self.datasetio_api = MagicMock()
self.datasets_api = MagicMock()
self.scoring_api = MagicMock()
self.inference_api = MagicMock()
self.agents_api = MagicMock()
datasetio_api = MagicMock()
datasets_api = MagicMock()
scoring_api = MagicMock()
inference_api = MagicMock()
agents_api = MagicMock()
self.config = NVIDIAEvalConfig(
config = NVIDIAEvalConfig(
evaluator_url=os.environ["NVIDIA_EVALUATOR_URL"],
)
self.eval_impl = NVIDIAEvalImpl(
config=self.config,
datasetio_api=self.datasetio_api,
datasets_api=self.datasets_api,
scoring_api=self.scoring_api,
inference_api=self.inference_api,
agents_api=self.agents_api,
eval_impl = NVIDIAEvalImpl(
config=config,
datasetio_api=datasetio_api,
datasets_api=datasets_api,
scoring_api=scoring_api,
inference_api=inference_api,
agents_api=agents_api,
)
# Mock the HTTP request methods
self.evaluator_get_patcher = patch(
"llama_stack.providers.remote.eval.nvidia.eval.NVIDIAEvalImpl._evaluator_get"
)
self.evaluator_post_patcher = patch(
"llama_stack.providers.remote.eval.nvidia.eval.NVIDIAEvalImpl._evaluator_post"
)
with (
patch("llama_stack.providers.remote.eval.nvidia.eval.NVIDIAEvalImpl._evaluator_get") as mock_evaluator_get,
patch("llama_stack.providers.remote.eval.nvidia.eval.NVIDIAEvalImpl._evaluator_post") as mock_evaluator_post,
):
yield eval_impl, mock_evaluator_get, mock_evaluator_post
self.mock_evaluator_get = self.evaluator_get_patcher.start()
self.mock_evaluator_post = self.evaluator_post_patcher.start()
def tearDown(self):
"""Clean up after each test."""
self.evaluator_get_patcher.stop()
self.evaluator_post_patcher.stop()
def _assert_request_body(self, expected_json):
"""Helper method to verify request body in Evaluator POST request is correct"""
call_args = self.mock_evaluator_post.call_args
def assert_request_body(mock_evaluator_post, expected_json):
"""Helper function to verify request body in Evaluator POST request is correct"""
call_args = mock_evaluator_post.call_args
actual_json = call_args[0][1]
# Check that all expected keys contain the expected values in the actual JSON
@ -77,11 +69,10 @@ class TestNVIDIAEvalImpl(unittest.TestCase):
else:
assert actual_json[key] == value, f"Value mismatch for '{key}'"
@pytest.fixture(autouse=True)
def inject_fixtures(self, run_async):
self.run_async = run_async
def test_register_benchmark(self):
async def test_register_benchmark(nvidia_eval_impl):
eval_impl, mock_evaluator_get, mock_evaluator_post = nvidia_eval_impl
eval_config = {
"type": "custom",
"params": {"parallelism": 8},
@ -106,16 +97,21 @@ class TestNVIDIAEvalImpl(unittest.TestCase):
# Mock Evaluator API response
mock_evaluator_response = {"id": MOCK_BENCHMARK_ID, "status": "created"}
self.mock_evaluator_post.return_value = mock_evaluator_response
mock_evaluator_post.return_value = mock_evaluator_response
# Register the benchmark
self.run_async(self.eval_impl.register_benchmark(benchmark))
await eval_impl.register_benchmark(benchmark)
# Verify the Evaluator API was called correctly
self.mock_evaluator_post.assert_called_once()
self._assert_request_body({"namespace": benchmark.provider_id, "name": benchmark.identifier, **eval_config})
mock_evaluator_post.assert_called_once()
assert_request_body(
mock_evaluator_post, {"namespace": benchmark.provider_id, "name": benchmark.identifier, **eval_config}
)
async def test_run_eval(nvidia_eval_impl):
eval_impl, mock_evaluator_get, mock_evaluator_post = nvidia_eval_impl
def test_run_eval(self):
benchmark_config = BenchmarkConfig(
eval_candidate=ModelCandidate(
type="model",
@ -126,20 +122,19 @@ class TestNVIDIAEvalImpl(unittest.TestCase):
# Mock Evaluator API response
mock_evaluator_response = {"id": "job-123", "status": "created"}
self.mock_evaluator_post.return_value = mock_evaluator_response
mock_evaluator_post.return_value = mock_evaluator_response
# Run the Evaluation job
result = self.run_async(
self.eval_impl.run_eval(benchmark_id=MOCK_BENCHMARK_ID, benchmark_config=benchmark_config)
)
result = await eval_impl.run_eval(benchmark_id=MOCK_BENCHMARK_ID, benchmark_config=benchmark_config)
# Verify the Evaluator API was called correctly
self.mock_evaluator_post.assert_called_once()
self._assert_request_body(
mock_evaluator_post.assert_called_once()
assert_request_body(
mock_evaluator_post,
{
"config": f"nvidia/{MOCK_BENCHMARK_ID}",
"target": {"type": "model", "model": "meta/llama-3.1-8b-instruct"},
}
},
)
# Verify the result
@ -147,13 +142,16 @@ class TestNVIDIAEvalImpl(unittest.TestCase):
assert result.job_id == "job-123"
assert result.status == JobStatus.in_progress
def test_job_status(self):
async def test_job_status(nvidia_eval_impl):
eval_impl, mock_evaluator_get, mock_evaluator_post = nvidia_eval_impl
# Mock Evaluator API response
mock_evaluator_response = {"id": "job-123", "status": "completed"}
self.mock_evaluator_get.return_value = mock_evaluator_response
mock_evaluator_get.return_value = mock_evaluator_response
# Get the Evaluation job
result = self.run_async(self.eval_impl.job_status(benchmark_id=MOCK_BENCHMARK_ID, job_id="job-123"))
result = await eval_impl.job_status(benchmark_id=MOCK_BENCHMARK_ID, job_id="job-123")
# Verify the result
assert isinstance(result, Job)
@ -161,20 +159,26 @@ class TestNVIDIAEvalImpl(unittest.TestCase):
assert result.status == JobStatus.completed
# Verify the API was called correctly
self.mock_evaluator_get.assert_called_once_with(f"/v1/evaluation/jobs/{result.job_id}")
mock_evaluator_get.assert_called_once_with(f"/v1/evaluation/jobs/{result.job_id}")
async def test_job_cancel(nvidia_eval_impl):
eval_impl, mock_evaluator_get, mock_evaluator_post = nvidia_eval_impl
def test_job_cancel(self):
# Mock Evaluator API response
mock_evaluator_response = {"id": "job-123", "status": "cancelled"}
self.mock_evaluator_post.return_value = mock_evaluator_response
mock_evaluator_post.return_value = mock_evaluator_response
# Cancel the Evaluation job
self.run_async(self.eval_impl.job_cancel(benchmark_id=MOCK_BENCHMARK_ID, job_id="job-123"))
await eval_impl.job_cancel(benchmark_id=MOCK_BENCHMARK_ID, job_id="job-123")
# Verify the API was called correctly
self.mock_evaluator_post.assert_called_once_with("/v1/evaluation/jobs/job-123/cancel", {})
mock_evaluator_post.assert_called_once_with("/v1/evaluation/jobs/job-123/cancel", {})
async def test_job_result(nvidia_eval_impl):
eval_impl, mock_evaluator_get, mock_evaluator_post = nvidia_eval_impl
def test_job_result(self):
# Mock Evaluator API responses
mock_job_status_response = {"id": "job-123", "status": "completed"}
mock_job_results_response = {
@ -182,13 +186,13 @@ class TestNVIDIAEvalImpl(unittest.TestCase):
"status": "completed",
"results": {MOCK_BENCHMARK_ID: {"score": 0.85, "details": {"accuracy": 0.85, "f1": 0.84}}},
}
self.mock_evaluator_get.side_effect = [
mock_evaluator_get.side_effect = [
mock_job_status_response, # First call to retrieve job
mock_job_results_response, # Second call to retrieve job results
]
# Get the Evaluation job results
result = self.run_async(self.eval_impl.job_result(benchmark_id=MOCK_BENCHMARK_ID, job_id="job-123"))
result = await eval_impl.job_result(benchmark_id=MOCK_BENCHMARK_ID, job_id="job-123")
# Verify the result
assert isinstance(result, EvaluateResponse)
@ -196,6 +200,6 @@ class TestNVIDIAEvalImpl(unittest.TestCase):
assert result.scores[MOCK_BENCHMARK_ID].aggregated_results["results"][MOCK_BENCHMARK_ID]["score"] == 0.85
# Verify the API was called correctly
assert self.mock_evaluator_get.call_count == 2
self.mock_evaluator_get.assert_any_call("/v1/evaluation/jobs/job-123")
self.mock_evaluator_get.assert_any_call("/v1/evaluation/jobs/job-123/results")
assert mock_evaluator_get.call_count == 2
mock_evaluator_get.assert_any_call("/v1/evaluation/jobs/job-123")
mock_evaluator_get.assert_any_call("/v1/evaluation/jobs/job-123/results")


@ -5,7 +5,6 @@
# the root directory of this source tree.
import os
import unittest
import warnings
from unittest.mock import patch
@ -27,33 +26,33 @@ from llama_stack.providers.remote.post_training.nvidia.post_training import (
)
class TestNvidiaParameters(unittest.TestCase):
def setUp(self):
@pytest.fixture
def nvidia_adapter():
"""Set up the NVIDIA adapter with mock configuration"""
os.environ["NVIDIA_BASE_URL"] = "http://nemo.test"
os.environ["NVIDIA_CUSTOMIZER_URL"] = "http://nemo.test"
config = NvidiaPostTrainingConfig(
base_url=os.environ["NVIDIA_BASE_URL"], customizer_url=os.environ["NVIDIA_CUSTOMIZER_URL"], api_key=None
)
self.adapter = NvidiaPostTrainingAdapter(config)
adapter = NvidiaPostTrainingAdapter(config)
self.make_request_patcher = patch(
# Mock the _make_request method
with patch(
"llama_stack.providers.remote.post_training.nvidia.post_training.NvidiaPostTrainingAdapter._make_request"
)
self.mock_make_request = self.make_request_patcher.start()
self.mock_make_request.return_value = {
) as mock_make_request:
mock_make_request.return_value = {
"id": "job-123",
"status": "created",
"created_at": "2025-03-04T13:07:47.543605",
"updated_at": "2025-03-04T13:07:47.543605",
}
yield adapter, mock_make_request
def tearDown(self):
self.make_request_patcher.stop()
def _assert_request_params(self, expected_json):
"""Helper method to verify parameters in the request JSON."""
call_args = self.mock_make_request.call_args
def assert_request_params(mock_make_request, expected_json):
"""Helper function to verify parameters in the request JSON."""
call_args = mock_make_request.call_args
actual_json = call_args[1]["json"]
for key, value in expected_json.items():
@ -63,12 +62,11 @@ class TestNvidiaParameters(unittest.TestCase):
else:
assert actual_json[key] == value
@pytest.fixture(autouse=True)
def inject_fixtures(self, run_async):
self.run_async = run_async
def test_customizer_parameters_passed(self):
async def test_customizer_parameters_passed(nvidia_adapter):
"""Test scenario 1: When an optional parameter is passed and value is correctly set."""
adapter, mock_make_request = nvidia_adapter
algorithm_config = LoraFinetuningConfig(
type="LoRA",
apply_lora_to_mlp=True,
@ -96,8 +94,7 @@ class TestNvidiaParameters(unittest.TestCase):
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
self.run_async(
self.adapter.supervised_fine_tune(
await adapter.supervised_fine_tune(
job_uuid="test-job",
model="meta-llama/Llama-3.1-8B-Instruct",
checkpoint_dir="",
@ -106,7 +103,6 @@ class TestNvidiaParameters(unittest.TestCase):
logger_config={},
hyperparam_search_config={},
)
)
warning_texts = [str(warning.message) for warning in w]
@ -118,7 +114,8 @@ class TestNvidiaParameters(unittest.TestCase):
for field in fields:
assert any(field in text for text in warning_texts)
self._assert_request_params(
assert_request_params(
mock_make_request,
{
"hyperparameters": {
"lora": {"alpha": 16},
@ -126,11 +123,14 @@ class TestNvidiaParameters(unittest.TestCase):
"learning_rate": 0.0002,
"batch_size": 16,
}
}
},
)
def test_required_parameters_passed(self):
async def test_required_parameters_passed(nvidia_adapter):
"""Test scenario 2: When required parameters are passed."""
adapter, mock_make_request = nvidia_adapter
required_model = "meta/llama-3.2-1b-instruct@v1.0.0+L40"
required_dataset_id = "required-dataset"
required_job_uuid = "required-job"
@ -164,8 +164,7 @@ class TestNvidiaParameters(unittest.TestCase):
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
self.run_async(
self.adapter.supervised_fine_tune(
await adapter.supervised_fine_tune(
job_uuid=required_job_uuid, # Required parameter
model=required_model, # Required parameter
checkpoint_dir="",
@ -174,7 +173,6 @@ class TestNvidiaParameters(unittest.TestCase):
logger_config={},
hyperparam_search_config={},
)
)
warning_texts = [str(warning.message) for warning in w]
@ -187,14 +185,17 @@ class TestNvidiaParameters(unittest.TestCase):
for field in fields:
assert any(field in text for text in warning_texts)
self.mock_make_request.assert_called_once()
call_args = self.mock_make_request.call_args
mock_make_request.assert_called_once()
call_args = mock_make_request.call_args
assert call_args[1]["json"]["config"] == required_model
assert call_args[1]["json"]["dataset"]["name"] == required_dataset_id
def test_unsupported_parameters_warning(self):
async def test_unsupported_parameters_warning(nvidia_adapter):
"""Test that warnings are raised for unsupported parameters."""
adapter, mock_make_request = nvidia_adapter
data_config = DataConfig(
dataset_id="test-dataset",
batch_size=8,
@ -232,8 +233,7 @@ class TestNvidiaParameters(unittest.TestCase):
with warnings.catch_warnings(record=True) as w:
warnings.simplefilter("always")
self.run_async(
self.adapter.supervised_fine_tune(
await adapter.supervised_fine_tune(
job_uuid="test-job",
model="meta-llama/Llama-3.1-8B-Instruct",
checkpoint_dir="test-dir", # Unsupported parameter
@ -249,7 +249,6 @@ class TestNvidiaParameters(unittest.TestCase):
logger_config={"test": "value"}, # Unsupported parameter
hyperparam_search_config={"test": "value"}, # Unsupported parameter
)
)
assert len(w) >= 4
warning_texts = [str(warning.message) for warning in w]
@ -273,7 +272,3 @@ class TestNvidiaParameters(unittest.TestCase):
]
for field in fields:
assert any(field in text for text in warning_texts)
if __name__ == "__main__":
unittest.main()


@ -5,7 +5,6 @@
# the root directory of this source tree.
import os
import unittest
from typing import Any
from unittest.mock import AsyncMock, MagicMock, patch
@ -18,42 +17,35 @@ from llama_stack.providers.remote.safety.nvidia.config import NVIDIASafetyConfig
from llama_stack.providers.remote.safety.nvidia.nvidia import NVIDIASafetyAdapter
class TestNVIDIASafetyAdapter(unittest.TestCase):
def setUp(self):
@pytest.fixture
def nvidia_safety_adapter():
"""Set up the NVIDIA safety adapter with mocked dependencies"""
os.environ["NVIDIA_GUARDRAILS_URL"] = "http://nemo.test"
# Initialize the adapter
self.config = NVIDIASafetyConfig(
config = NVIDIASafetyConfig(
guardrails_service_url=os.environ["NVIDIA_GUARDRAILS_URL"],
)
self.adapter = NVIDIASafetyAdapter(config=self.config)
self.shield_store = AsyncMock()
self.adapter.shield_store = self.shield_store
adapter = NVIDIASafetyAdapter(config=config)
shield_store = AsyncMock()
adapter.shield_store = shield_store
# Mock the HTTP request methods
self.guardrails_post_patcher = patch(
with patch(
"llama_stack.providers.remote.safety.nvidia.nvidia.NeMoGuardrails._guardrails_post"
)
self.mock_guardrails_post = self.guardrails_post_patcher.start()
self.mock_guardrails_post.return_value = {"status": "allowed"}
) as mock_guardrails_post:
mock_guardrails_post.return_value = {"status": "allowed"}
yield adapter, shield_store, mock_guardrails_post
def tearDown(self):
"""Clean up after each test."""
self.guardrails_post_patcher.stop()
@pytest.fixture(autouse=True)
def inject_fixtures(self, run_async):
self.run_async = run_async
def _assert_request(
self,
def assert_request(
mock_call: MagicMock,
expected_url: str,
expected_headers: dict[str, str] | None = None,
expected_json: dict[str, Any] | None = None,
) -> None:
"""
Helper method to verify request details in mock API calls.
Helper function to verify request details in mock API calls.
Args:
mock_call: The MagicMock object that was called
@ -80,7 +72,10 @@ class TestNVIDIASafetyAdapter(unittest.TestCase):
else:
assert call_args[1]["json"][key] == value
def test_register_shield_with_valid_id(self):
async def test_register_shield_with_valid_id(nvidia_safety_adapter):
adapter, shield_store, mock_guardrails_post = nvidia_safety_adapter
shield = Shield(
provider_id="nvidia",
type="shield",
@ -89,9 +84,12 @@ class TestNVIDIASafetyAdapter(unittest.TestCase):
)
# Register the shield
self.run_async(self.adapter.register_shield(shield))
await adapter.register_shield(shield)
async def test_register_shield_without_id(nvidia_safety_adapter):
adapter, shield_store, mock_guardrails_post = nvidia_safety_adapter
def test_register_shield_without_id(self):
shield = Shield(
provider_id="nvidia",
type="shield",
@ -100,10 +98,13 @@ class TestNVIDIASafetyAdapter(unittest.TestCase):
)
# Register the shield should raise a ValueError
with self.assertRaises(ValueError):
self.run_async(self.adapter.register_shield(shield))
with pytest.raises(ValueError):
await adapter.register_shield(shield)
async def test_run_shield_allowed(nvidia_safety_adapter):
adapter, shield_store, mock_guardrails_post = nvidia_safety_adapter
def test_run_shield_allowed(self):
# Set up the shield
shield_id = "test-shield"
shield = Shield(
@ -112,10 +113,10 @@ class TestNVIDIASafetyAdapter(unittest.TestCase):
identifier=shield_id,
provider_resource_id="test-model",
)
self.shield_store.get_shield.return_value = shield
shield_store.get_shield.return_value = shield
# Mock Guardrails API response
self.mock_guardrails_post.return_value = {"status": "allowed"}
mock_guardrails_post.return_value = {"status": "allowed"}
# Run the shield
messages = [
@ -127,13 +128,13 @@ class TestNVIDIASafetyAdapter(unittest.TestCase):
tool_calls=[],
),
]
result = self.run_async(self.adapter.run_shield(shield_id, messages))
result = await adapter.run_shield(shield_id, messages)
# Verify the shield store was called
self.shield_store.get_shield.assert_called_once_with(shield_id)
shield_store.get_shield.assert_called_once_with(shield_id)
# Verify the Guardrails API was called correctly
self.mock_guardrails_post.assert_called_once_with(
mock_guardrails_post.assert_called_once_with(
path="/v1/guardrail/checks",
data={
"model": shield_id,
@ -157,7 +158,10 @@ class TestNVIDIASafetyAdapter(unittest.TestCase):
assert isinstance(result, RunShieldResponse)
assert result.violation is None
def test_run_shield_blocked(self):
async def test_run_shield_blocked(nvidia_safety_adapter):
adapter, shield_store, mock_guardrails_post = nvidia_safety_adapter
# Set up the shield
shield_id = "test-shield"
shield = Shield(
@ -166,10 +170,10 @@ class TestNVIDIASafetyAdapter(unittest.TestCase):
identifier=shield_id,
provider_resource_id="test-model",
)
self.shield_store.get_shield.return_value = shield
shield_store.get_shield.return_value = shield
# Mock Guardrails API response
self.mock_guardrails_post.return_value = {"status": "blocked", "rails_status": {"reason": "harmful_content"}}
mock_guardrails_post.return_value = {"status": "blocked", "rails_status": {"reason": "harmful_content"}}
# Run the shield
messages = [
@ -181,13 +185,13 @@ class TestNVIDIASafetyAdapter(unittest.TestCase):
tool_calls=[],
),
]
result = self.run_async(self.adapter.run_shield(shield_id, messages))
result = await adapter.run_shield(shield_id, messages)
# Verify the shield store was called
self.shield_store.get_shield.assert_called_once_with(shield_id)
shield_store.get_shield.assert_called_once_with(shield_id)
# Verify the Guardrails API was called correctly
self.mock_guardrails_post.assert_called_once_with(
mock_guardrails_post.assert_called_once_with(
path="/v1/guardrail/checks",
data={
"model": shield_id,
@ -214,25 +218,31 @@ class TestNVIDIASafetyAdapter(unittest.TestCase):
assert result.violation.violation_level == ViolationLevel.ERROR
assert result.violation.metadata == {"reason": "harmful_content"}
def test_run_shield_not_found(self):
async def test_run_shield_not_found(nvidia_safety_adapter):
adapter, shield_store, mock_guardrails_post = nvidia_safety_adapter
# Set up shield store to return None
shield_id = "non-existent-shield"
self.shield_store.get_shield.return_value = None
shield_store.get_shield.return_value = None
messages = [
UserMessage(role="user", content="Hello, how are you?"),
]
with self.assertRaises(ValueError):
self.run_async(self.adapter.run_shield(shield_id, messages))
with pytest.raises(ValueError):
await adapter.run_shield(shield_id, messages)
# Verify the shield store was called
self.shield_store.get_shield.assert_called_once_with(shield_id)
shield_store.get_shield.assert_called_once_with(shield_id)
# Verify the Guardrails API was not called
self.mock_guardrails_post.assert_not_called()
mock_guardrails_post.assert_not_called()
async def test_run_shield_http_error(nvidia_safety_adapter):
adapter, shield_store, mock_guardrails_post = nvidia_safety_adapter
def test_run_shield_http_error(self):
shield_id = "test-shield"
shield = Shield(
provider_id="nvidia",
@ -240,11 +250,11 @@ class TestNVIDIASafetyAdapter(unittest.TestCase):
identifier=shield_id,
provider_resource_id="test-model",
)
self.shield_store.get_shield.return_value = shield
shield_store.get_shield.return_value = shield
# Mock Guardrails API to raise an exception
error_msg = "API Error: 500 Internal Server Error"
self.mock_guardrails_post.side_effect = Exception(error_msg)
mock_guardrails_post.side_effect = Exception(error_msg)
# Running the shield should raise an exception
messages = [
@ -256,14 +266,14 @@ class TestNVIDIASafetyAdapter(unittest.TestCase):
tool_calls=[],
),
]
with self.assertRaises(Exception) as context:
self.run_async(self.adapter.run_shield(shield_id, messages))
with pytest.raises(Exception) as excinfo:
await adapter.run_shield(shield_id, messages)
# Verify the shield store was called
self.shield_store.get_shield.assert_called_once_with(shield_id)
shield_store.get_shield.assert_called_once_with(shield_id)
# Verify the Guardrails API was called correctly
self.mock_guardrails_post.assert_called_once_with(
mock_guardrails_post.assert_called_once_with(
path="/v1/guardrail/checks",
data={
"model": shield_id,
@ -283,11 +293,14 @@ class TestNVIDIASafetyAdapter(unittest.TestCase):
},
)
# Verify the exception message
assert error_msg in str(context.exception)
assert error_msg in str(excinfo.value)
def test_init_nemo_guardrails(self):
def test_init_nemo_guardrails():
from llama_stack.providers.remote.safety.nvidia.nvidia import NeMoGuardrails
os.environ["NVIDIA_GUARDRAILS_URL"] = "http://nemo.test"
test_config_id = "test-custom-config-id"
config = NVIDIASafetyConfig(
guardrails_service_url=os.environ["NVIDIA_GUARDRAILS_URL"],
@ -314,12 +327,15 @@ class TestNVIDIASafetyAdapter(unittest.TestCase):
assert guardrails.temperature == 0.7
assert guardrails.guardrails_service_url == os.environ["NVIDIA_GUARDRAILS_URL"]
def test_init_nemo_guardrails_invalid_temperature(self):
def test_init_nemo_guardrails_invalid_temperature():
from llama_stack.providers.remote.safety.nvidia.nvidia import NeMoGuardrails
os.environ["NVIDIA_GUARDRAILS_URL"] = "http://nemo.test"
config = NVIDIASafetyConfig(
guardrails_service_url=os.environ["NVIDIA_GUARDRAILS_URL"],
config_id="test-custom-config-id",
)
with self.assertRaises(ValueError):
with pytest.raises(ValueError):
NeMoGuardrails(config, "test-model", temperature=0)
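Note on the exception assertions above: self.assertRaises(...) as a context manager becomes pytest.raises(...), and the raised exception is reached through excinfo.value rather than context.exception. A minimal sketch with a hypothetical helper:

import pytest


def fetch_shield(shield_id: str | None) -> str:
    if shield_id is None:
        raise ValueError("shield not found")
    return shield_id


def test_fetch_shield_not_found():
    # pytest.raises replaces self.assertRaises; the exception object is excinfo.value.
    with pytest.raises(ValueError) as excinfo:
        fetch_shield(None)
    assert "shield not found" in str(excinfo.value)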


@ -5,9 +5,8 @@
# the root directory of this source tree.
import os
import unittest
import warnings
from unittest.mock import patch
from unittest.mock import AsyncMock, MagicMock, patch
import pytest
@ -32,43 +31,38 @@ from llama_stack.providers.remote.post_training.nvidia.post_training import (
)
class TestNvidiaPostTraining(unittest.TestCase):
def setUp(self):
@pytest.fixture
def nvidia_adapters():
"""Set up the NVIDIA adapters with mocked dependencies"""
os.environ["NVIDIA_BASE_URL"] = "http://nemo.test" # needed for llm inference
os.environ["NVIDIA_CUSTOMIZER_URL"] = "http://nemo.test" # needed for nemo customizer
config = NvidiaPostTrainingConfig(
base_url=os.environ["NVIDIA_BASE_URL"], customizer_url=os.environ["NVIDIA_CUSTOMIZER_URL"], api_key=None
)
self.adapter = NvidiaPostTrainingAdapter(config)
self.make_request_patcher = patch(
"llama_stack.providers.remote.post_training.nvidia.post_training.NvidiaPostTrainingAdapter._make_request"
)
self.mock_make_request = self.make_request_patcher.start()
adapter = NvidiaPostTrainingAdapter(config)
# Mock the inference client
inference_config = NVIDIAConfig(base_url=os.environ["NVIDIA_BASE_URL"], api_key=None)
self.inference_adapter = NVIDIAInferenceAdapter(inference_config)
inference_adapter = NVIDIAInferenceAdapter(inference_config)
self.mock_client = unittest.mock.MagicMock()
self.mock_client.chat.completions.create = unittest.mock.AsyncMock()
self.inference_mock_make_request = self.mock_client.chat.completions.create
self.inference_make_request_patcher = patch(
mock_client = MagicMock()
mock_client.chat.completions.create = AsyncMock()
with (
patch(
"llama_stack.providers.remote.post_training.nvidia.post_training.NvidiaPostTrainingAdapter._make_request"
) as mock_make_request,
patch(
"llama_stack.providers.remote.inference.nvidia.nvidia.NVIDIAInferenceAdapter._get_client",
return_value=self.mock_client,
)
self.inference_make_request_patcher.start()
return_value=mock_client,
),
):
yield adapter, inference_adapter, mock_make_request, mock_client
def tearDown(self):
self.make_request_patcher.stop()
self.inference_make_request_patcher.stop()
@pytest.fixture(autouse=True)
def inject_fixtures(self, run_async):
self.run_async = run_async
def _assert_request(self, mock_call, expected_method, expected_path, expected_params=None, expected_json=None):
"""Helper method to verify request details in mock calls."""
def assert_request(mock_call, expected_method, expected_path, expected_params=None, expected_json=None):
"""Helper function to verify request details in mock calls."""
call_args = mock_call.call_args
if expected_method and expected_path:
@ -85,9 +79,12 @@ class TestNvidiaPostTraining(unittest.TestCase):
for key, value in expected_json.items():
assert call_args[1]["json"][key] == value
def test_supervised_fine_tune(self):
async def test_supervised_fine_tune(nvidia_adapters):
"""Test the supervised fine-tuning API call."""
self.mock_make_request.return_value = {
adapter, inference_adapter, mock_make_request, mock_client = nvidia_adapters
mock_make_request.return_value = {
"id": "cust-JGTaMbJMdqjJU8WbQdN9Q2",
"created_at": "2024-12-09T04:06:28.542884",
"updated_at": "2024-12-09T04:06:28.542884",
@ -162,8 +159,7 @@ class TestNvidiaPostTraining(unittest.TestCase):
with warnings.catch_warnings(record=True):
warnings.simplefilter("always")
training_job = self.run_async(
self.adapter.supervised_fine_tune(
training_job = await adapter.supervised_fine_tune(
job_uuid="1234",
model="meta/llama-3.2-1b-instruct@v1.0.0+L40",
checkpoint_dir="",
@ -172,15 +168,14 @@ class TestNvidiaPostTraining(unittest.TestCase):
logger_config={},
hyperparam_search_config={},
)
)
# check the output is a PostTrainingJob
assert isinstance(training_job, NvidiaPostTrainingJob)
assert training_job.job_uuid == "cust-JGTaMbJMdqjJU8WbQdN9Q2"
self.mock_make_request.assert_called_once()
self._assert_request(
self.mock_make_request,
mock_make_request.assert_called_once()
assert_request(
mock_make_request,
"POST",
"/v1/customization/jobs",
expected_json={
@ -198,7 +193,10 @@ class TestNvidiaPostTraining(unittest.TestCase):
},
)
def test_supervised_fine_tune_with_qat(self):
async def test_supervised_fine_tune_with_qat(nvidia_adapters):
adapter, inference_adapter, mock_make_request, mock_client = nvidia_adapters
algorithm_config = QATFinetuningConfig(type="QAT", quantizer_name="quantizer_name", group_size=1)
data_config = DataConfig(
dataset_id="sample-basic-test", batch_size=16, shuffle=False, data_format=DatasetFormat.instruct
@ -215,9 +213,8 @@ class TestNvidiaPostTraining(unittest.TestCase):
optimizer_config=optimizer_config,
)
# This will raise NotImplementedError since QAT is not supported
with self.assertRaises(NotImplementedError):
self.run_async(
self.adapter.supervised_fine_tune(
with pytest.raises(NotImplementedError):
await adapter.supervised_fine_tune(
job_uuid="1234",
model="meta/llama-3.2-1b-instruct@v1.0.0+L40",
checkpoint_dir="",
@ -226,21 +223,23 @@ class TestNvidiaPostTraining(unittest.TestCase):
logger_config={},
hyperparam_search_config={},
)
)
def test_get_training_job_status(self):
customizer_status_to_job_status = [
@pytest.mark.parametrize(
"customizer_status,expected_status",
[
("running", "in_progress"),
("completed", "completed"),
("failed", "failed"),
("cancelled", "cancelled"),
("pending", "scheduled"),
("unknown", "scheduled"),
]
],
)
async def test_get_training_job_status(nvidia_adapters, customizer_status, expected_status):
adapter, inference_adapter, mock_make_request, mock_client = nvidia_adapters
for customizer_status, expected_status in customizer_status_to_job_status:
with self.subTest(customizer_status=customizer_status, expected_status=expected_status):
self.mock_make_request.return_value = {
mock_make_request.return_value = {
"created_at": "2024-12-09T04:06:28.580220",
"updated_at": "2024-12-09T04:21:19.852832",
"status": customizer_status,
@ -254,7 +253,7 @@ class TestNvidiaPostTraining(unittest.TestCase):
job_id = "cust-JGTaMbJMdqjJU8WbQdN9Q2"
status = self.run_async(self.adapter.get_training_job_status(job_uuid=job_id))
status = await adapter.get_training_job_status(job_uuid=job_id)
assert isinstance(status, NvidiaPostTrainingJobStatusResponse)
assert status.status.value == expected_status
@ -265,16 +264,19 @@ class TestNvidiaPostTraining(unittest.TestCase):
assert status.train_loss == 1.718016266822815
assert status.val_loss == 1.8661999702453613
self._assert_request(
self.mock_make_request,
assert_request(
mock_make_request,
"GET",
f"/v1/customization/jobs/{job_id}/status",
expected_params={"job_id": job_id},
)
def test_get_training_jobs(self):
async def test_get_training_jobs(nvidia_adapters):
adapter, inference_adapter, mock_make_request, mock_client = nvidia_adapters
job_id = "cust-JGTaMbJMdqjJU8WbQdN9Q2"
self.mock_make_request.return_value = {
mock_make_request.return_value = {
"data": [
{
"id": job_id,
@ -300,7 +302,7 @@ class TestNvidiaPostTraining(unittest.TestCase):
]
}
jobs = self.run_async(self.adapter.get_training_jobs())
jobs = await adapter.get_training_jobs()
assert isinstance(jobs, ListNvidiaPostTrainingJobs)
assert len(jobs.data) == 1
@ -308,31 +310,37 @@ class TestNvidiaPostTraining(unittest.TestCase):
assert job.job_uuid == job_id
assert job.status.value == "completed"
self.mock_make_request.assert_called_once()
self._assert_request(
self.mock_make_request,
mock_make_request.assert_called_once()
assert_request(
mock_make_request,
"GET",
"/v1/customization/jobs",
expected_params={"page": 1, "page_size": 10, "sort": "created_at"},
)
def test_cancel_training_job(self):
self.mock_make_request.return_value = {} # Empty response for successful cancellation
async def test_cancel_training_job(nvidia_adapters):
adapter, inference_adapter, mock_make_request, mock_client = nvidia_adapters
mock_make_request.return_value = {} # Empty response for successful cancellation
job_id = "cust-JGTaMbJMdqjJU8WbQdN9Q2"
result = self.run_async(self.adapter.cancel_training_job(job_uuid=job_id))
result = await adapter.cancel_training_job(job_uuid=job_id)
assert result is None
self.mock_make_request.assert_called_once()
self._assert_request(
self.mock_make_request,
mock_make_request.assert_called_once()
assert_request(
mock_make_request,
"POST",
f"/v1/customization/jobs/{job_id}/cancel",
expected_params={"job_id": job_id},
)
def test_inference_register_model(self):
async def test_inference_register_model(nvidia_adapters):
adapter, inference_adapter, mock_make_request, mock_client = nvidia_adapters
model_id = "default/job-1234"
model_type = ModelType.llm
model = Model(
@ -342,21 +350,15 @@ class TestNvidiaPostTraining(unittest.TestCase):
provider_resource_id=model_id,
model_type=model_type,
)
result = self.run_async(self.inference_adapter.register_model(model))
result = await inference_adapter.register_model(model)
assert result == model
assert len(self.inference_adapter.alias_to_provider_id_map) > 1
assert self.inference_adapter.get_provider_model_id(model.provider_model_id) == model_id
assert len(inference_adapter.alias_to_provider_id_map) > 1
assert inference_adapter.get_provider_model_id(model.provider_model_id) == model_id
with patch.object(self.inference_adapter, "chat_completion") as mock_chat_completion:
self.run_async(
self.inference_adapter.chat_completion(
with patch.object(inference_adapter, "chat_completion") as mock_chat_completion:
await inference_adapter.chat_completion(
model_id=model_id,
messages=[{"role": "user", "content": "Hello, model"}],
)
)
mock_chat_completion.assert_called()
if __name__ == "__main__":
unittest.main()
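Note on the status-mapping test above: the for-loop over cases wrapped in self.subTest becomes @pytest.mark.parametrize, so each case runs and is reported as its own test. A minimal sketch with a hypothetical mapping, mirroring the shape of the customizer-status conversion:

import pytest

_STATUS_MAP = {"running": "in_progress", "completed": "completed", "failed": "failed"}


def map_status(remote_status: str) -> str:
    return _STATUS_MAP.get(remote_status, "scheduled")


@pytest.mark.parametrize(
    "remote_status,expected",
    [
        ("running", "in_progress"),
        ("completed", "completed"),
        ("unknown", "scheduled"),
    ],
)
def test_map_status(remote_status, expected):
    assert map_status(remote_status) == expected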


@ -5,15 +5,19 @@
# the root directory of this source tree.
import os
import unittest
import pytest
from llama_stack.distribution.stack import replace_env_vars
class TestReplaceEnvVars(unittest.TestCase):
def setUp(self):
# Clear any existing environment variables we'll use in tests
@pytest.fixture(autouse=True)
def setup_env_vars():
"""Set up test environment variables and clean up after test"""
# Store original values
original_vars = {}
for var in ["TEST_VAR", "EMPTY_VAR", "ZERO_VAR"]:
original_vars[var] = os.environ.get(var)
if var in os.environ:
del os.environ[var]
@ -22,56 +26,70 @@ class TestReplaceEnvVars(unittest.TestCase):
os.environ["EMPTY_VAR"] = ""
os.environ["ZERO_VAR"] = "0"
def test_simple_replacement(self):
self.assertEqual(replace_env_vars("${env.TEST_VAR}"), "test_value")
yield
def test_default_value_when_not_set(self):
self.assertEqual(replace_env_vars("${env.NOT_SET:=default}"), "default")
# Cleanup: restore original values
for var, original_value in original_vars.items():
if original_value is not None:
os.environ[var] = original_value
elif var in os.environ:
del os.environ[var]
def test_default_value_when_set(self):
self.assertEqual(replace_env_vars("${env.TEST_VAR:=default}"), "test_value")
def test_default_value_when_empty(self):
self.assertEqual(replace_env_vars("${env.EMPTY_VAR:=default}"), "default")
def test_simple_replacement():
assert replace_env_vars("${env.TEST_VAR}") == "test_value"
def test_none_value_when_empty(self):
self.assertEqual(replace_env_vars("${env.EMPTY_VAR:=}"), None)
def test_value_when_set(self):
self.assertEqual(replace_env_vars("${env.TEST_VAR:=}"), "test_value")
def test_default_value_when_not_set():
assert replace_env_vars("${env.NOT_SET:=default}") == "default"
def test_empty_var_no_default(self):
self.assertEqual(replace_env_vars("${env.EMPTY_VAR_NO_DEFAULT:+}"), None)
def test_conditional_value_when_set(self):
self.assertEqual(replace_env_vars("${env.TEST_VAR:+conditional}"), "conditional")
def test_default_value_when_set():
assert replace_env_vars("${env.TEST_VAR:=default}") == "test_value"
def test_conditional_value_when_not_set(self):
self.assertEqual(replace_env_vars("${env.NOT_SET:+conditional}"), None)
def test_conditional_value_when_empty(self):
self.assertEqual(replace_env_vars("${env.EMPTY_VAR:+conditional}"), None)
def test_default_value_when_empty():
assert replace_env_vars("${env.EMPTY_VAR:=default}") == "default"
def test_conditional_value_with_zero(self):
self.assertEqual(replace_env_vars("${env.ZERO_VAR:+conditional}"), "conditional")
def test_mixed_syntax(self):
self.assertEqual(
replace_env_vars("${env.TEST_VAR:=default} and ${env.NOT_SET:+conditional}"), "test_value and "
)
self.assertEqual(
replace_env_vars("${env.NOT_SET:=default} and ${env.TEST_VAR:+conditional}"), "default and conditional"
)
def test_none_value_when_empty():
assert replace_env_vars("${env.EMPTY_VAR:=}") is None
def test_nested_structures(self):
def test_value_when_set():
assert replace_env_vars("${env.TEST_VAR:=}") == "test_value"
def test_empty_var_no_default():
assert replace_env_vars("${env.EMPTY_VAR_NO_DEFAULT:+}") is None
def test_conditional_value_when_set():
assert replace_env_vars("${env.TEST_VAR:+conditional}") == "conditional"
def test_conditional_value_when_not_set():
assert replace_env_vars("${env.NOT_SET:+conditional}") is None
def test_conditional_value_when_empty():
assert replace_env_vars("${env.EMPTY_VAR:+conditional}") is None
def test_conditional_value_with_zero():
assert replace_env_vars("${env.ZERO_VAR:+conditional}") == "conditional"
def test_mixed_syntax():
assert replace_env_vars("${env.TEST_VAR:=default} and ${env.NOT_SET:+conditional}") == "test_value and "
assert replace_env_vars("${env.NOT_SET:=default} and ${env.TEST_VAR:+conditional}") == "default and conditional"
def test_nested_structures():
data = {
"key1": "${env.TEST_VAR:=default}",
"key2": ["${env.NOT_SET:=default}", "${env.TEST_VAR:+conditional}"],
"key3": {"nested": "${env.NOT_SET:+conditional}"},
}
expected = {"key1": "test_value", "key2": ["default", "conditional"], "key3": {"nested": None}}
self.assertEqual(replace_env_vars(data), expected)
if __name__ == "__main__":
unittest.main()
assert replace_env_vars(data) == expected
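Note on the environment-variable fixture above: the commit saves and restores the variables by hand. An alternative (not what this commit does) is pytest's built-in monkeypatch fixture, which undoes every change automatically after each test. A possible sketch, reusing the replace_env_vars import from the file above:

import pytest

from llama_stack.distribution.stack import replace_env_vars


@pytest.fixture(autouse=True)
def setup_env_vars(monkeypatch):
    # monkeypatch restores or removes these automatically when the test ends.
    monkeypatch.setenv("TEST_VAR", "test_value")
    monkeypatch.setenv("EMPTY_VAR", "")
    monkeypatch.setenv("ZERO_VAR", "0")
    monkeypatch.delenv("NOT_SET", raising=False)


def test_simple_replacement_with_monkeypatch():
    assert replace_env_vars("${env.TEST_VAR}") == "test_value"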